Browse Source

fixed binding groups for params/globals

Nicolas Cannasse 1 năm trước cách đây
mục cha
commit
bb83ba2b91
2 tập tin đã thay đổi với 42 bổ sung30 xóa
  1. 35 25
      h3d/impl/WebGpuDriver.hx
  2. 7 5
      hxsl/WgslOut.hx

+ 35 - 25
h3d/impl/WebGpuDriver.hx

@@ -7,8 +7,10 @@ import h3d.mat.Pass;
 class WebGpuSubShader {
 	public var kind : GPUShaderStage;
 	public var module : GPUShaderModule;
-	public var groups : Array<GPUBindGroupLayout>;
 	public var paramsBufferSize : Int = 0;
+	public var globalsBufferSize : Int = 0;
+	public var paramsGroup : Int = 0;
+	public var globalsGroup : Int = 0;
 	public function new() {
 	}
 }
@@ -20,6 +22,7 @@ class WebGpuShader {
 	public var layout : GPUPipelineLayout;
 	public var inputCount : Int;
 	public var pipelines : PipelineCache<GPURenderPipeline> = new PipelineCache();
+	public var groups : Array<GPUBindGroupLayout>;
 	public function new() {
 	}
 }
@@ -296,20 +299,17 @@ class WebGpuDriver extends h3d.impl.Driver {
 		buf.unmap();
 	}
 
-
-	function compile( shader : hxsl.RuntimeShader.RuntimeShaderData, kind ) {
+	function compile( parent : WebGpuShader, shader : hxsl.RuntimeShader.RuntimeShaderData, kind ) {
 		var comp = new hxsl.WgslOut();
-		var source = comp.run(shader.data);
 		var sh = new WebGpuSubShader();
 		sh.kind = kind;
-		sh.module = device.createShaderModule({ code : source });
-		sh.groups = [];
 		for( v in shader.data.vars ) {
 			switch( v.kind ) {
-			case Param:
+			case Param, Global:
 				var size = hxsl.Ast.Tools.size(v.type) * 4;
+				var index = parent.groups.length;
 				var g = device.createBindGroupLayout({ entries : [{
-					binding: 0,
+					binding: index,
 					visibility : kind,
 					buffer : {
 						type : Uniform,
@@ -317,11 +317,23 @@ class WebGpuDriver extends h3d.impl.Driver {
 						minBindingSize: size,
 					}
 				}]});
-				sh.paramsBufferSize = size;
-				sh.groups.push(g);
+				switch( v.kind ) {
+				case Param:
+					sh.paramsBufferSize = size;
+					sh.paramsGroup = index;
+					comp.paramsBinding = comp.paramsGroup = index;
+				case Global:
+					sh.globalsBufferSize = size;
+					sh.globalsGroup = index;
+					comp.globalsBinding = comp.globalsGroup = index;
+				default:
+				}
+				parent.groups.push(g);
 			default:
 			}
 		}
+		var source = comp.run(shader.data);
+		sh.module = device.createShaderModule({ code : source });
 		return sh;
 	}
 
@@ -336,10 +348,11 @@ class WebGpuDriver extends h3d.impl.Driver {
 			default:
 			}
 		}
+		sh.groups = [];
 		sh.format = hxd.BufferFormat.make(format);
-		sh.vertex = compile(shader.vertex,VERTEX);
-		sh.fragment = compile(shader.fragment,FRAGMENT);
-		sh.layout = device.createPipelineLayout({ bindGroupLayouts: sh.vertex.groups.concat(sh.fragment.groups) });
+		sh.vertex = compile(sh,shader.vertex,VERTEX);
+		sh.fragment = compile(sh,shader.fragment,FRAGMENT);
+		sh.layout = device.createPipelineLayout({ bindGroupLayouts: sh.groups });
 		sh.inputCount = format.length;
 		return sh;
 	}
@@ -382,32 +395,29 @@ class WebGpuDriver extends h3d.impl.Driver {
 
 	function _uploadShaderBuffers(buffers:h3d.shader.Buffers.ShaderBuffers, which:h3d.shader.Buffers.BufferKind, sh:WebGpuSubShader) {
 		switch( which ) {
-		case Globals:
-			if( buffers.globals.length == 0 )
-				return;
-			throw "TODO";
-		case Params:
-			if( sh.paramsBufferSize == 0 )
+		case Globals, Params:
+			var size = which == Params ? sh.paramsBufferSize : sh.globalsBufferSize;
+			if( size == 0 )
 				return;
-			var flags = new haxe.EnumFlags();
 			var buffer = device.createBuffer({
-				size : sh.paramsBufferSize,
+				size : size,
 				usage : UNIFORM,
 				mappedAtCreation : true,
 			});
 			var map = buffer.getMappedRange();
-			new js.lib.Float32Array(map).set(cast buffers.params);
+			new js.lib.Float32Array(map).set(cast (which == Params ? buffers.params : buffers.globals));
 			buffer.unmap();
+			var index = which == Params ? sh.paramsGroup : sh.globalsGroup;
 			var group = device.createBindGroup({
-				layout : sh.groups[0],
+				layout : currentShader.groups[index],
 				entries: [{
-					binding: 0,
+					binding: index,
 					resource: {
 						buffer : buffer,
 					}
 				}],
 			});
-			renderPass.setBindGroup(0, group);
+			renderPass.setBindGroup(index, group);
 		case Textures:
 			if( buffers.tex.length == 0 )
 				return;

+ 7 - 5
hxsl/WgslOut.hx

@@ -26,6 +26,10 @@ class WgslOut {
 	var allNames : Map<String, Int>;
 	var hasVarying : Bool;
 	public var varNames : Map<Int,String>;
+	public var paramsGroup : Int = 0;
+	public var paramsBinding : Int = 0;
+	public var globalsGroup : Int = 0;
+	public var globalsBinding : Int = 0;
 
 	var varAccess : Map<Int,String>;
 
@@ -454,14 +458,12 @@ class WgslOut {
 		if( !found )
 			return;
 
-		add('struct _globals {\n');
 		for( v in s.vars )
 			if( v.kind == Global ) {
-				add("\t");
+				add('@group(${globalsGroup}) @binding(${globalsBinding}) var<uniform> ');
 				addVar(v);
-				add(",\n");
+				add(";\n");
 			}
-		add("};\n\n");
 	}
 
 	function initParams( s : ShaderData ) {
@@ -491,7 +493,7 @@ class WgslOut {
 						continue;
 					}
 				}
-				add("@group(0) @binding(0) var<uniform> ");
+				add('@group($paramsGroup) @binding($paramsBinding) var<uniform> ');
 				addVar(v);
 				add(";\n");
 			}