Explorar o código

TSL: Introduce ShaderCallNode & `tslFn` improvements (#26824)

* Add ShaderCallNode

* cleanup

* Use tslFn as default
sunag hai 1 ano
pai
achega
9d2d7ebc78

+ 2 - 2
examples/jsm/nodes/procedural/CheckerNode.js

@@ -25,9 +25,9 @@ class CheckerNode extends TempNode {
 
 
 	}
 	}
 
 
-	generate( builder ) {
+	construct() {
 
 
-		return checkerShaderNode( { uv: this.uvNode } ).build( builder );
+		return checkerShaderNode( { uv: this.uvNode } );
 
 
 	}
 	}
 
 

+ 56 - 18
examples/jsm/nodes/shadernode/ShaderNode.js

@@ -196,29 +196,33 @@ const ShaderNodeImmutable = function ( NodeClass, ...params ) {
 
 
 };
 };
 
 
-class ShaderNodeInternal extends Node {
+class ShaderCallNodeInternal extends Node {
 
 
-	constructor( jsFunc ) {
+	constructor( shaderNode, inputNodes ) {
 
 
 		super();
 		super();
 
 
-		this._jsFunc = jsFunc;
+		this.shaderNode = shaderNode;
+		this.inputNodes = inputNodes;
 
 
 	}
 	}
 
 
-	call( inputs, stack, builder ) {
+	getNodeType( builder ) {
 
 
-		inputs = nodeObjects( inputs );
+		const { outputNode } = builder.getNodeProperties( this );
 
 
-		return nodeObject( this._jsFunc( inputs, stack, builder ) );
+		return outputNode ? outputNode.getNodeType( builder ) : super.getNodeType( builder );
 
 
 	}
 	}
 
 
-	getNodeType( builder ) {
+	call( builder ) {
 
 
-		const { outputNode } = builder.getNodeProperties( this );
+		const { shaderNode, inputNodes } = this;
 
 
-		return outputNode ? outputNode.getNodeType( builder ) : super.getNodeType( builder );
+		const jsFunc = shaderNode.jsFunc;
+		const outputNode = inputNodes !== null ? jsFunc( nodeObjects( inputNodes ), builder.stack, builder ) : jsFunc( builder.stack, builder );
+
+		return nodeObject( outputNode );
 
 
 	}
 	}
 
 
@@ -226,12 +230,52 @@ class ShaderNodeInternal extends Node {
 
 
 		builder.addStack();
 		builder.addStack();
 
 
-		builder.stack.outputNode = nodeObject( this._jsFunc( builder.stack, builder ) );
+		builder.stack.outputNode = this.call( builder );
 
 
 		return builder.removeStack();
 		return builder.removeStack();
 
 
 	}
 	}
 
 
+	generate( builder, output ) {
+
+		const { outputNode } = builder.getNodeProperties( this );
+
+		if ( outputNode === null ) {
+
+			// TSL: It's recommended to use `tslFn` in construct() pass.
+
+			return this.call( builder ).build( builder, output );
+
+		}
+
+		return super.generate( builder, output );
+
+	}
+
+}
+
+class ShaderNodeInternal extends Node {
+
+	constructor( jsFunc ) {
+
+		super();
+
+		this.jsFunc = jsFunc;
+
+	}
+
+	call( inputs = null ) {
+
+		return nodeObject( new ShaderCallNodeInternal( this, inputs ) );
+
+	}
+
+	construct() {
+
+		return this.call();
+
+	}
+
 }
 }
 
 
 const bools = [ false, true ];
 const bools = [ false, true ];
@@ -349,15 +393,9 @@ export const shader = ( jsFunc ) => { // @deprecated, r154
 
 
 export const tslFn = ( jsFunc ) => {
 export const tslFn = ( jsFunc ) => {
 
 
-	let shaderNode = null;
+	const shaderNode = new ShaderNode( jsFunc );
 
 
-	return ( ...params ) => {
-
-		if ( shaderNode === null ) shaderNode = new ShaderNode( jsFunc );
-
-		return shaderNode.call( ...params );
-
-	};
+	return ( inputs ) => shaderNode.call( inputs );
 
 
 };
 };
 
 

+ 8 - 8
examples/jsm/nodes/utils/LoopNode.js

@@ -1,7 +1,7 @@
 import Node, { addNodeClass } from '../core/Node.js';
 import Node, { addNodeClass } from '../core/Node.js';
 import { expression } from '../code/ExpressionNode.js';
 import { expression } from '../code/ExpressionNode.js';
 import { bypass } from '../core/BypassNode.js';
 import { bypass } from '../core/BypassNode.js';
-import { context as contextNode } from '../core/ContextNode.js';
+import { context } from '../core/ContextNode.js';
 import { addNodeElement, nodeObject, nodeArray } from '../shadernode/ShaderNode.js';
 import { addNodeElement, nodeObject, nodeArray } from '../shadernode/ShaderNode.js';
 
 
 class LoopNode extends Node {
 class LoopNode extends Node {
@@ -65,13 +65,11 @@ class LoopNode extends Node {
 
 
 		const properties = this.getProperties( builder );
 		const properties = this.getProperties( builder );
 
 
-		const context = { tempWrite: false };
+		const contextData = { tempWrite: false };
 
 
 		const params = this.params;
 		const params = this.params;
 		const stackNode = properties.stackNode;
 		const stackNode = properties.stackNode;
 
 
-		const returnsSnippet = properties.returnsNode ? properties.returnsNode.build( builder ) : '';
-
 		for ( let i = 0, l = params.length - 1; i < l; i ++ ) {
 		for ( let i = 0, l = params.length - 1; i < l; i ++ ) {
 
 
 			const param = params[ i ];
 			const param = params[ i ];
@@ -82,7 +80,7 @@ class LoopNode extends Node {
 			if ( param.isNode ) {
 			if ( param.isNode ) {
 
 
 				start = '0';
 				start = '0';
-				end = param.generate( builder, 'int' );
+				end = param.build( builder, 'int' );
 				direction = 'forward';
 				direction = 'forward';
 
 
 			} else {
 			} else {
@@ -92,10 +90,10 @@ class LoopNode extends Node {
 				direction = param.direction;
 				direction = param.direction;
 
 
 				if ( typeof start === 'number' ) start = start.toString();
 				if ( typeof start === 'number' ) start = start.toString();
-				else if ( start && start.isNode ) start = start.generate( builder, 'int' );
+				else if ( start && start.isNode ) start = start.build( builder, 'int' );
 
 
 				if ( typeof end === 'number' ) end = end.toString();
 				if ( typeof end === 'number' ) end = end.toString();
-				else if ( end && end.isNode ) end = end.generate( builder, 'int' );
+				else if ( end && end.isNode ) end = end.build( builder, 'int' );
 
 
 				if ( start !== undefined && end === undefined ) {
 				if ( start !== undefined && end === undefined ) {
 
 
@@ -159,7 +157,9 @@ class LoopNode extends Node {
 
 
 		}
 		}
 
 
-		const stackSnippet = contextNode( stackNode, context ).build( builder, 'void' );
+		const stackSnippet = context( stackNode, contextData ).build( builder, 'void' );
+
+		const returnsSnippet = properties.returnsNode ? properties.returnsNode.build( builder ) : '';
 
 
 		builder.removeFlowTab().addFlowCode( '\n' + builder.tab + stackSnippet );
 		builder.removeFlowTab().addFlowCode( '\n' + builder.tab + stackSnippet );
 
 

+ 3 - 4
examples/webgpu_audio_processing.html

@@ -31,7 +31,7 @@
 		<script type="module">
 		<script type="module">
 
 
 			import * as THREE from 'three';
 			import * as THREE from 'three';
-			import { ShaderNode, uniform, storage, instanceIndex, float, texture, viewportTopLeft, color } from 'three/nodes';
+			import { tslFn, uniform, storage, instanceIndex, float, texture, viewportTopLeft, color } from 'three/nodes';
 
 
 			import { GUI } from 'three/addons/libs/lil-gui.module.min.js';
 			import { GUI } from 'three/addons/libs/lil-gui.module.min.js';
 
 
@@ -136,11 +136,10 @@
 
 
 				// compute (shader-node)
 				// compute (shader-node)
 
 
-				const computeShaderNode = new ShaderNode( ( stack ) => {
+				const computeShaderFn = tslFn( ( stack ) => {
 
 
 					const index = float( instanceIndex );
 					const index = float( instanceIndex );
 
 
-
 					// pitch
 					// pitch
 
 
 					const time = index.mul( pitch );
 					const time = index.mul( pitch );
@@ -171,7 +170,7 @@
 
 
 				// compute
 				// compute
 
 
-				computeNode = computeShaderNode.compute( waveBuffer.length );
+				computeNode = computeShaderFn().compute( waveBuffer.length );
 
 
 
 
 				// gui
 				// gui

+ 5 - 5
examples/webgpu_compute.html

@@ -26,7 +26,7 @@
 		<script type="module">
 		<script type="module">
 
 
 			import * as THREE from 'three';
 			import * as THREE from 'three';
-			import { ShaderNode, uniform, storage, attribute, float, vec2, vec3, color, instanceIndex, PointsNodeMaterial } from 'three/nodes';
+			import { tslFn, uniform, storage, attribute, float, vec2, vec3, color, instanceIndex, PointsNodeMaterial } from 'three/nodes';
 
 
 			import { GUI } from 'three/addons/libs/lil-gui.module.min.js';
 			import { GUI } from 'three/addons/libs/lil-gui.module.min.js';
 
 
@@ -74,7 +74,7 @@
 
 
 				// create function
 				// create function
 
 
-				const computeShaderNode = new ShaderNode( ( stack ) => {
+				const computeShaderFn = tslFn( ( stack ) => {
 
 
 					const particle = particleBufferNode.element( instanceIndex );
 					const particle = particleBufferNode.element( instanceIndex );
 					const velocity = velocityBufferNode.element( instanceIndex );
 					const velocity = velocityBufferNode.element( instanceIndex );
@@ -98,10 +98,10 @@
 
 
 				// compute
 				// compute
 
 
-				computeNode = computeShaderNode.compute( particleNum );
+				computeNode = computeShaderFn().compute( particleNum );
 				computeNode.onInit = ( { renderer } ) => {
 				computeNode.onInit = ( { renderer } ) => {
 
 
-					const precomputeShaderNode = new ShaderNode( ( stack ) => {
+					const precomputeShaderNode = tslFn( ( stack ) => {
 
 
 						const particleIndex = float( instanceIndex );
 						const particleIndex = float( instanceIndex );
 
 
@@ -117,7 +117,7 @@
 
 
 					} );
 					} );
 
 
-					renderer.compute( precomputeShaderNode.compute( particleNum ) );
+					renderer.compute( precomputeShaderNode().compute( particleNum ) );
 
 
 				};
 				};
 
 

+ 7 - 7
examples/webgpu_compute_particles.html

@@ -26,7 +26,7 @@
 		<script type="module">
 		<script type="module">
 
 
 			import * as THREE from 'three';
 			import * as THREE from 'three';
-			import { ShaderNode, uniform, texture, instanceIndex, float, vec3, storage, SpriteNodeMaterial } from 'three/nodes';
+			import { tslFn, uniform, texture, instanceIndex, float, vec3, storage, SpriteNodeMaterial } from 'three/nodes';
 
 
 			import WebGPU from 'three/addons/capabilities/WebGPU.js';
 			import WebGPU from 'three/addons/capabilities/WebGPU.js';
 			import WebGPURenderer from 'three/addons/renderers/webgpu/WebGPURenderer.js';
 			import WebGPURenderer from 'three/addons/renderers/webgpu/WebGPURenderer.js';
@@ -83,7 +83,7 @@
 
 
 				// compute
 				// compute
 
 
-				const computeInit = new ShaderNode( ( stack ) => {
+				const computeInit = tslFn( ( stack ) => {
 
 
 					const position = positionBuffer.element( instanceIndex );
 					const position = positionBuffer.element( instanceIndex );
 					const color = colorBuffer.element( instanceIndex );
 					const color = colorBuffer.element( instanceIndex );
@@ -98,11 +98,11 @@
 
 
 					stack.assign( color, vec3( randX, randY, randZ ) );
 					stack.assign( color, vec3( randX, randY, randZ ) );
 
 
-				} ).compute( particleCount );
+				} )().compute( particleCount );
 
 
 				//
 				//
 
 
-				const computeUpdate = new ShaderNode( ( stack ) => {
+				const computeUpdate = tslFn( ( stack ) => {
 
 
 					const position = positionBuffer.element( instanceIndex );
 					const position = positionBuffer.element( instanceIndex );
 					const velocity = velocityBuffer.element( instanceIndex );
 					const velocity = velocityBuffer.element( instanceIndex );
@@ -128,7 +128,7 @@
 
 
 				} );
 				} );
 
 
-				computeParticles = computeUpdate.compute( particleCount );
+				computeParticles = computeUpdate().compute( particleCount );
 
 
 				// create nodes
 				// create nodes
 
 
@@ -179,7 +179,7 @@
 
 
 				// click event
 				// click event
 
 
-				const computeHit = new ShaderNode( ( stack ) => {
+				const computeHit = tslFn( ( stack ) => {
 
 
 					const position = positionBuffer.element( instanceIndex );
 					const position = positionBuffer.element( instanceIndex );
 					const velocity = velocityBuffer.element( instanceIndex );
 					const velocity = velocityBuffer.element( instanceIndex );
@@ -193,7 +193,7 @@
 
 
 					stack.assign( velocity, velocity.add( direction.mul( relativePower ) ) );
 					stack.assign( velocity, velocity.add( direction.mul( relativePower ) ) );
 
 
-				} ).compute( particleCount );
+				} )().compute( particleCount );
 
 
 				//
 				//
 
 

+ 1 - 1
examples/webgpu_materials.html

@@ -29,7 +29,7 @@
 			import * as THREE from 'three';
 			import * as THREE from 'three';
 			import * as Nodes from 'three/nodes';
 			import * as Nodes from 'three/nodes';
 
 
-			import { tslFn, wgslFn, attribute, positionLocal, positionWorld, normalLocal, normalWorld, normalView, color, texture, uv, float, vec2, vec3, vec4, oscSine, triplanarTexture, viewportBottomLeft, js, string, global, loop, MeshBasicNodeMaterial, NodeObjectLoader } from 'three/nodes';
+			import { tslFn, wgslFn, positionLocal, positionWorld, normalLocal, normalWorld, normalView, color, texture, uv, float, vec2, vec3, vec4, oscSine, triplanarTexture, viewportBottomLeft, js, string, global, loop, MeshBasicNodeMaterial, NodeObjectLoader } from 'three/nodes';
 
 
 			import WebGPU from 'three/addons/capabilities/WebGPU.js';
 			import WebGPU from 'three/addons/capabilities/WebGPU.js';
 			import WebGPURenderer from 'three/addons/renderers/webgpu/WebGPURenderer.js';
 			import WebGPURenderer from 'three/addons/renderers/webgpu/WebGPURenderer.js';