Browse Source

fixed uav texture register / desc table

Nicolas Cannasse 1 year ago
parent
commit
5ca7b092a3
2 changed files with 94 additions and 42 deletions
  1. 87 35
      h3d/impl/DX12Driver.hx
  2. 7 7
      hxsl/HlslOut.hx

+ 87 - 35
h3d/impl/DX12Driver.hx

@@ -137,8 +137,9 @@ class CompiledShader {
 	@:packed public var bufferSRV(default,null) : BufferSRV;
 	@:packed public var samplerDesc(default,null) : SamplerDesc;
 	@:packed public var cbvDesc(default,null) : ConstantBufferViewDesc;
-	@:packed public var uavDesc(default,null) : UAVBufferViewDesc;
 	@:packed public var rtvDesc(default,null) : RenderTargetViewDesc;
+	@:packed public var uavDesc(default,null) : UAVBufferViewDesc;
+	@:packed public var wtexDesc(default,null) : UAVTextureViewDesc;
 
 	public var pass : h3d.mat.Pass;
 
@@ -899,12 +900,12 @@ class DX12Driver extends h3d.impl.Driver {
 		var args = [];
 		var out = new hxsl.HlslOut();
 		out.baseRegister = baseRegister;
-		if ( sh.code == null ) {
+		if( sh.code == null ) {
 			sh.code = out.run(sh.data);
 			sh.code = rootStr + sh.code;
 		}
 		var bytes = getBinaryPayload(sh.code);
-		if ( bytes == null ) {
+		if( bytes == null ) {
 			return compiler.compile(sh.code, profile, args);
 		}
 		return bytes;
@@ -937,7 +938,7 @@ class DX12Driver extends h3d.impl.Driver {
 
 		for ( i in 0...paramsCount ) {
 			var param = params[i];
-			var vis = 'SHADER_VISIBILITY_${param.shaderVisibility == VERTEX ? "VERTEX" : "PIXEL"}';
+			var vis = "SHADER_VISIBILITY_"+switch( param.shaderVisibility ) { case VERTEX: "VERTEX"; case PIXEL: "PIXEL"; default: "ALL"; };
 			if ( param.parameterType == CONSTANTS ) {
 				var shaderRegister = param.shaderRegister;
 				s += 'RootConstants(num32BitConstants=${param.num32BitValues},b${shaderRegister}, visibility=${vis}),';
@@ -956,7 +957,8 @@ class DX12Driver extends h3d.impl.Driver {
 						var baseShaderRegister = descRange.baseShaderRegister;
 						s += 'DescriptorTable(Sampler(s${baseShaderRegister}, space=${descRange.registerSpace}, numDescriptors = ${descRange.numDescriptors}), visibility = ${vis}),';
 					case UAV:
-						throw "Not supported";
+						var reg = descRange.baseShaderRegister;
+						s += 'UAV(u${reg}, visibility = ${vis})';
 					}
 				} catch ( e : Dynamic ) {
 					continue;
@@ -1028,8 +1030,8 @@ class DX12Driver extends h3d.impl.Driver {
 			var regs = new ShaderRegisters();
 			regs.globals = allocConsts(sh.globalsSize, vis, null);
 			regs.params = allocConsts(sh.paramsSize, vis, (sh.kind == Fragment ? fragmentParamsCBV : vertexParamsCBV) ? CBV : null);
+			regs.buffers = paramsCount;
 			if( sh.bufferCount > 0 ) {
-				regs.buffers = paramsCount;
 				regs.bufferTypes = [];
 				var p = sh.buffers;
 				while( p != null ) {
@@ -1046,33 +1048,41 @@ class DX12Driver extends h3d.impl.Driver {
 				}
 			}
 			if( sh.texturesCount > 0 ) {
-				regs.texturesCount = sh.texturesCount;
-				regs.textures = paramsCount;
+				regs.texturesCount = 0;
 				regs.texturesTypes = [];
 
-				var p = sh.textures;
-				while( p != null ) {
-					switch( p.type ) {
-					case TArray( t = TSampler(_) | TRWTexture(_) , SConst(n) ):
+				var p = sh.data.vars;
+				for( v in sh.data.vars ) {
+					switch( v.type ) {
+					case TArray(t = TSampler(_) | TRWTexture(_), SConst(n)):
 						for( i in 0...n )
 							regs.texturesTypes.push(t);
+						if( t.match(TSampler(_)) )
+							regs.texturesCount += n;
+						else {
+							for( i in 0...n )
+								allocConsts(1, vis, UAV);
+						}
 					default:
 					}
-					p = p.next;
 				}
 
-				var r = allocDescTable(vis);
-				r.rangeType = SRV;
-				r.baseShaderRegister = 0;
-				r.registerSpace = 0;
-				r.numDescriptors = sh.texturesCount;
-
-				regs.samplers = paramsCount;
-				var r = allocDescTable(vis);
-				r.rangeType = SAMPLER;
-				r.baseShaderRegister = 0;
-				r.registerSpace = 0;
-				r.numDescriptors = sh.texturesCount;
+				if( regs.texturesCount > 0 ) {
+					regs.textures = paramsCount;
+
+					var r = allocDescTable(vis);
+					r.rangeType = SRV;
+					r.baseShaderRegister = 0;
+					r.registerSpace = 0;
+					r.numDescriptors = regs.texturesCount;
+
+					regs.samplers = paramsCount;
+					var r = allocDescTable(vis);
+					r.rangeType = SAMPLER;
+					r.baseShaderRegister = 0;
+					r.registerSpace = 0;
+					r.numDescriptors = regs.texturesCount;
+				}
 			}
 			return regs;
 		}
@@ -1404,6 +1414,8 @@ class DX12Driver extends h3d.impl.Driver {
 			clear.color.a = color.a;
 			td.color = color;
 		}
+		if( t.flags.has(Writable) )
+			desc.flags.set(ALLOW_UNORDERED_ACCESS);
 
 		td.state = isRT ? RENDER_TARGET : COPY_DEST;
 		td.res = Driver.createCommittedResource(tmp.heap, flags, desc, isRT ? RENDER_TARGET : COMMON, clear);
@@ -1595,10 +1607,12 @@ class DX12Driver extends h3d.impl.Driver {
 					frame.commandList.setGraphicsRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
 			}
 		case Textures:
-			if( regs.texturesCount > 0 ) {
+			if( shader.texturesCount > 0 ) {
 				var srv = frame.shaderResourceViews.alloc(regs.texturesCount);
 				var sampler = frame.samplerViews.alloc(regs.texturesCount);
-				for( i in 0...regs.texturesCount ) {
+				var regIndex = regs.buffers + shader.bufferCount;
+				var outIndex = 0;
+				for( i in 0...shader.texturesCount ) {
 					var t = buf.tex[i];
 					var pt = regs.texturesTypes[i];
 					if( t == null || t.isDisposed() ) {
@@ -1629,6 +1643,41 @@ class DX12Driver extends h3d.impl.Driver {
 						}
 					}
 
+					switch( pt ) {
+					case TRWTexture(dim,arr,chans):
+						var tdim : hxsl.Ast.TexDimension = t.flags.has(Cube) ? TCube : T2D;
+						var fmt;
+						if( (arr != t.flags.has(IsArray)) || dim != tdim )
+							throw "Texture format does not match: "+t+"["+t.format+"] should be "+hxsl.Ast.Tools.toString(pt);
+						var srv = frame.shaderResourceViews.alloc(1);
+						if( !t.flags.has(Writable) )
+							throw "Texture was allocated without Writable flag";
+						transition(t.t, UNORDERED_ACCESS);
+						var desc = tmp.wtexDesc;
+						desc.format = cast getTextureFormat(t);
+						desc.viewDimension = switch( [dim,arr] ) {
+						case [T1D, false]: TEXTURE1D;
+						case [T2D, false]: TEXTURE2D;
+						case [T3D, false]: TEXTURE3D;
+						case [T1D, true]: TEXTURE1DARRAY;
+						case [T2D, true]: TEXTURE2DARRAY;
+						default: throw "Unsupported RWTexture "+t;
+						}
+						desc.mipSlice = 0;
+						desc.planeSlice = 0;
+						if( arr ) {
+							desc.firstArraySlice = 0;
+							desc.arraySize = 1;
+						}
+						Driver.createUnorderedAccessView(t.t.res, null, desc, srv);
+						if( currentShader.isCompute )
+							frame.commandList.setComputeRootDescriptorTable(regIndex++, frame.shaderResourceViews.toGPU(srv));
+						else
+							frame.commandList.setGraphicsRootDescriptorTable(regIndex++, frame.shaderResourceViews.toGPU(srv));
+						continue;
+					default:
+					}
+
 					var tdesc : ShaderResourceViewDesc;
 					if( t.flags.has(Cube) ) {
 						var desc = tmp.texCubeSRV;
@@ -1660,7 +1709,7 @@ class DX12Driver extends h3d.impl.Driver {
 					else
 						NON_PIXEL_SHADER_RESOURCE;
 					transition(t.t, state);
-					Driver.createShaderResourceView(t.t.res, tdesc, srv.offset(i * frame.shaderResourceViews.stride));
+					Driver.createShaderResourceView(t.t.res, tdesc, srv.offset(outIndex * frame.shaderResourceViews.stride));
 
 					var desc = tmp.samplerDesc;
 					desc.filter = switch( [t.filter, t.mipMap] ) {
@@ -1674,15 +1723,18 @@ class DX12Driver extends h3d.impl.Driver {
 					case Repeat: WRAP;
 					}
 					desc.mipLODBias = t.lodBias;
-					Driver.createSampler(desc, sampler.offset(i * frame.samplerViews.stride));
+					Driver.createSampler(desc, sampler.offset(outIndex * frame.samplerViews.stride));
+					outIndex++;
 				}
 
-				if( currentShader.isCompute ) {
-					frame.commandList.setComputeRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
-					frame.commandList.setComputeRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
-				} else {
-					frame.commandList.setGraphicsRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
-					frame.commandList.setGraphicsRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
+				if( regs.texturesCount > 0 ) {
+					if( currentShader.isCompute ) {
+						frame.commandList.setComputeRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
+						frame.commandList.setComputeRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
+					} else {
+						frame.commandList.setGraphicsRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
+						frame.commandList.setGraphicsRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
+					}
 				}
 			}
 		case Buffers:

+ 7 - 7
hxsl/HlslOut.hx

@@ -809,12 +809,13 @@ class HlslOut {
 	function initParams( s : ShaderData ) {
 		var textures = [];
 		var buffers = [];
+		var uavs = [];
 		add('cbuffer _params : register(b${baseRegister+1}) {\n');
 		for( v in s.vars )
 			if( v.kind == Param ) {
 				switch( v.type ) {
 				case TArray(TRWTexture(_), _):
-					buffers.push(v);
+					uavs.push(v);
 					continue;
 				case TArray(t, _) if( t.isTexture() ):
 					textures.push(v);
@@ -834,20 +835,19 @@ class HlslOut {
 			}
 		add("};\n\n");
 
-		var bufCount = 0;
-		for( b in buffers ) {
+		var regCount = baseRegister + 2;
+		for( b in buffers.concat(uavs) ) {
 			switch( b.type ) {
 			case TBuffer(t, size, Uniform):
-				add('cbuffer _buffer$bufCount : register(b${bufCount+baseRegister+2}) { ');
+				add('cbuffer _buffer$regCount : register(b${regCount++}) { ');
 				addVar(b);
 				add("; };\n");
 			default:
 				addVar(b);
-				add(' : register(u${bufCount+baseRegister+2});\n');
+				add(' : register(u${regCount++});\n');
 			}
-			bufCount++;
 		}
-		if( bufCount > 0 ) add("\n");
+		if( buffers.length + uavs.length > 0 ) add("\n");
 
 		var ctx = new Samplers();
 		var texCount = 0;