Răsfoiți Sursa

WebGPUComputePipelines: Use WebGPUProgrammableStage. (#21757)

Michael Herzog 4 ani în urmă
părinte
comite
31d1746a17

+ 14 - 19
examples/jsm/renderers/webgpu/WebGPUComputePipelines.js

@@ -1,3 +1,5 @@
+import WebGPUProgrammableStage from './WebGPUProgrammableStage.js';
+
 class WebGPUComputePipelines {
 
 	constructor( device, glslang ) {
@@ -6,7 +8,7 @@ class WebGPUComputePipelines {
 		this.glslang = glslang;
 
 		this.pipelines = new WeakMap();
-		this.shaderModules = {
+		this.stages = {
 			compute: new WeakMap()
 		};
 
@@ -16,38 +18,31 @@ class WebGPUComputePipelines {
 
 		let pipeline = this.pipelines.get( param );
 
+		// @TODO: Reuse compute pipeline if possible
+
 		if ( pipeline === undefined ) {
 
 			const device = this.device;
+			const glslang = this.glslang;
+
 			const shader = {
 				computeShader: param.shader
 			};
 
-			// shader modules
-
-			const glslang = this.glslang;
-
-			let moduleCompute = this.shaderModules.compute.get( shader );
+			// programmable stage
 
-			if ( moduleCompute === undefined ) {
+			let stageCompute = this.stages.compute.get( shader );
 
-				const byteCodeCompute = glslang.compileGLSL( shader.computeShader, 'compute' );
+			if ( stageCompute === undefined ) {
 
-				moduleCompute = device.createShaderModule( { code: byteCodeCompute } );
+ 				stageCompute = new WebGPUProgrammableStage( device, glslang, shader.computeShader, 'compute' );
 
-				this.shaderModules.compute.set( shader, moduleCompute );
+				this.stages.compute.set( shader, stageCompute );
 
 			}
 
-			//
-
-			const compute = {
-				module: moduleCompute,
-				entryPoint: 'main'
-			};
-
 			pipeline = device.createComputePipeline( {
-				compute: compute
+				compute: stageCompute.stage
 			} );
 
 			this.pipelines.set( param, pipeline );
@@ -61,7 +56,7 @@ class WebGPUComputePipelines {
 	dispose() {
 
 		this.pipelines = new WeakMap();
-		this.shaderModules = {
+		this.stages = {
 			compute: new WeakMap()
 		};
 

+ 3 - 3
examples/jsm/renderers/webgpu/WebGPURenderPipelines.js

@@ -61,9 +61,6 @@ class WebGPURenderPipelines {
 
 			}
 
-			stageVertex.usedTimes ++;
-			stageFragment.usedTimes ++;
-
 			// determine render pipeline
 
 			currentPipeline = this._acquirePipeline( stageVertex, stageFragment, object, nodeBuilder );
@@ -83,7 +80,10 @@ class WebGPURenderPipelines {
 			if ( materialPipelines.has( currentPipeline ) === false ) {
 
 				materialPipelines.add( currentPipeline );
+
 				currentPipeline.usedTimes ++;
+				stageVertex.usedTimes ++;
+				stageFragment.usedTimes ++;
 
 			}