Browse Source

Per-frame SRV/sampler heaps, fixed shader params/registers, CBV support for large params, implemented selectBuffer()

Nicolas Cannasse 3 years ago
parent
commit
2ea557b407
1 changed file with 196 additions and 92 deletions
  1. 196 92
      h3d/impl/DX12Driver.hx

+ 196 - 92
h3d/impl/DX12Driver.hx

@@ -18,6 +18,8 @@ class DxFrame {
 	public var commandList : CommandList;
 	public var fenceValue : Int64;
 	public var toRelease : Array<Resource> = [];
+	public var shaderResourceViews : ManagedHeap;
+	public var samplerViews : ManagedHeap;
 	public function new() {
 	}
 }
@@ -31,10 +33,11 @@ class CachedPipeline {
 }
 
 class ShaderRegisters {
-	public var globalsReg : Int;
-	public var paramsReg : Int;
-	public var texturesParam : Int;
-	public var samplersParam : Int;
+	public var globals : Int;
+	public var params : Int;
+	public var textures : Int;
+	public var samplers : Int;
+	public var texturesCount : Int;
 	public function new() {
 	}
 }
@@ -48,6 +51,7 @@ class CompiledShader {
 	public var pipelines : Map<Int,hl.NativeArray<CachedPipeline>> = new Map();
 	public var rootSignature : RootSignature;
 	public var inputLayout : hl.CArray<InputElementDesc>;
+	public var inputOffsets : Array<Int>;
 	public var shader : hxsl.RuntimeShader;
 	public function new() {
 	}
@@ -64,6 +68,10 @@ class CompiledShader {
 	@:packed public var clearValue(default,null) : ClearValue;
 	@:packed public var viewport(default,null) : Viewport;
 	@:packed public var rect(default,null) : Rect;
+	@:packed public var tex2DSRV(default,null) : Tex2DSRV;
+	@:packed public var bufferSRV(default,null) : BufferSRV;
+	@:packed public var samplerDesc(default,null) : SamplerDesc;
+	@:packed public var cbvDesc(default,null) : ConstantBufferViewDesc;
 
 	public var pass : h3d.mat.Pass;
 
@@ -73,6 +81,13 @@ class CompiledShader {
 		vertexViews = hl.CArray.alloc(VertexBufferView, 16);
 		pass = new h3d.mat.Pass("default");
 		pass.stencil = new h3d.mat.Stencil();
+		tex2DSRV.dimension = TEXTURE2D;
+		tex2DSRV.shader4ComponentMapping = ShaderComponentMapping.DEFAULT;
+		bufferSRV.dimension = BUFFER;
+		bufferSRV.flags = RAW;
+		bufferSRV.shader4ComponentMapping = ShaderComponentMapping.DEFAULT;
+		samplerDesc.comparisonFunc = NEVER;
+		samplerDesc.maxLod = 1e30;
 	}
 
 }
@@ -90,9 +105,9 @@ class ManagedHeapFL {
 
 class ManagedHeap {
 
-	var heap : DescriptorHeap;
+	public var stride(default,null) : Int;
 	var size : Int;
-	var stride : Int;
+	var heap : DescriptorHeap;
 	var address : Address;
 	var cpuToGpu : Int64;
 	var free : ManagedHeapFL;
@@ -183,9 +198,6 @@ class DX12Driver extends h3d.impl.Driver {
 
 	var renderTargetViews : ManagedHeap;
 	var depthStenciViews : ManagedHeap;
-	var shaderResourceViews : ManagedHeap;
-	var samplerViews : ManagedHeap;
-	var heaps : Array<ManagedHeap> = [];
 
 	var currentFrame : Int;
 	var fenceValue : Int64 = 0;
@@ -219,6 +231,8 @@ class DX12Driver extends h3d.impl.Driver {
 			f.allocator = new CommandAllocator(DIRECT);
 			f.commandList = new CommandList(DIRECT, f.allocator, null);
 			f.commandList.close();
+			f.shaderResourceViews = new ManagedHeap(CBV_SRV_UAV, 1024);
+			f.samplerViews = new ManagedHeap(SAMPLER, 1024);
 			frames.push(f);
 		}
 		fence = new Fence(0, NONE);
@@ -227,14 +241,6 @@ class DX12Driver extends h3d.impl.Driver {
 
 		renderTargetViews = new ManagedHeap(RTV);
 		depthStenciViews = new ManagedHeap(DSV);
-		shaderResourceViews = new ManagedHeap(CBV_SRV_UAV);
-		samplerViews = new ManagedHeap(SAMPLER);
-		heaps = [
-			renderTargetViews,
-			depthStenciViews,
-			shaderResourceViews,
-			samplerViews,
-		];
 
 		compiler = new ShaderCompiler();
 		resize(window.width, window.height);
@@ -269,12 +275,12 @@ class DX12Driver extends h3d.impl.Driver {
 		frame.commandList.omSetRenderTargets(1, tmp.renderTargets, true, tmp.depthStencils);
 
 		var arr = new hl.NativeArray(2);
-		arr[0] = @:privateAccess shaderResourceViews.heap;
-		arr[1] = @:privateAccess samplerViews.heap;
+		arr[0] = @:privateAccess frame.shaderResourceViews.heap;
+		arr[1] = @:privateAccess frame.samplerViews.heap;
 		frame.commandList.setDescriptorHeaps(arr);
 
-		shaderResourceViews.reset();
-		samplerViews.reset();
+		frame.shaderResourceViews.reset();
+		frame.samplerViews.reset();
 	}
 
 	inline function unsafeCastTo<T,R>( v : T, c : Class<R> ) : R {
@@ -321,8 +327,9 @@ class DX12Driver extends h3d.impl.Driver {
 
 		Driver.resize(width, height, BUFFER_COUNT, R8G8B8A8_UNORM);
 
-		for( h in heaps )
-			h.reset();
+		// TODO : use circular buffer instead
+		renderTargetViews.reset();
+		depthStenciViews.reset();
 
 		for( i => f in frames ) {
 			f.backBuffer = Driver.getBackBuffer(i);
@@ -406,66 +413,107 @@ class DX12Driver extends h3d.impl.Driver {
 
 	function compileShader( shader : hxsl.RuntimeShader ) : CompiledShader {
 
-		var params = hl.CArray.alloc(RootParameterConstants,6);
+		var params = hl.CArray.alloc(RootParameterConstants,8);
 		var paramsCount = 0, regCount = 0;
 		var texDescs = [];
+		var vertexParamsCBV = false;
+		var fragmentParamsCBV = false;
 		var c = new CompiledShader();
 
-		function allocConsts(size,vis) {
-			if( size == 0 ) return;
-			var p = params[paramsCount++];
-			p.parameterType = CONSTANTS;
-			p.shaderRegister = regCount++;
+		function allocDescTable(vis) {
+			var p = unsafeCastTo(params[paramsCount++], RootParameterDescriptorTable);
+			p.parameterType = DESCRIPTOR_TABLE;
+			p.numDescriptorRanges = 1;
+			var range = new DescriptorRange();
+			texDescs.push(range);
+			p.descriptorRanges = range;
 			p.shaderVisibility = vis;
-			p.num32BitValues = size << 2;
+			return range;
 		}
 
-		function allocDescTable(size,vis) {
-			var p = unsafeCastTo(params[paramsCount++], RootParameterDescriptorTable);
-			p.parameterType = DESCRIPTOR_TABLE;
-			p.numDescriptorRanges = size;
-			var descs = hl.CArray.alloc(DescriptorRange, size);
-			texDescs.push(descs);
-			p.descriptorRanges = descs[0];
+		function allocConsts(size,vis,useCBV) {
+			var reg = regCount++;
+			if( size == 0 ) return -1;
+
+			if( useCBV ) {
+				var pid = paramsCount;
+				var r = allocDescTable(vis);
+				r.rangeType = CBV;
+				r.numDescriptors = 1;
+				r.baseShaderRegister = reg;
+				r.registerSpace = 0;
+				return pid | 0x100;
+			}
+
+			var pid = paramsCount++;
+			var p = params[pid];
+			p.parameterType = CONSTANTS;
+			p.shaderRegister = reg;
 			p.shaderVisibility = vis;
-			return descs;
+			p.num32BitValues = size << 2;
+			return pid;
 		}
 
+
 		function allocParams( sh : hxsl.RuntimeShader.RuntimeShaderData ) {
 			var vis = sh.vertex ? VERTEX : PIXEL;
 			var regs = new ShaderRegisters();
-			regs.globalsReg = regCount;
-			allocConsts(sh.globalsSize, vis);
-			regs.paramsReg = regCount;
-			allocConsts(sh.paramsSize, vis);
+			regs.globals = allocConsts(sh.globalsSize, vis, false);
+			regs.params = allocConsts(sh.paramsSize, vis, sh.vertex ? vertexParamsCBV : fragmentParamsCBV);
 			if( sh.texturesCount > 0 ) {
-				regs.texturesParam = paramsCount;
-				var descs = allocDescTable(sh.texturesCount, vis);
-				for( i in 0...sh.texturesCount ) {
-					var d = descs[i];
-					d.rangeType = SRV;
-					d.baseShaderRegister = 0;
-					d.registerSpace = 0;
-					d.numDescriptors = 1;
-				}
-				regs.samplersParam = paramsCount;
-				var descs = allocDescTable(sh.texturesCount, vis);
-				for( i in 0...sh.texturesCount ) {
-					var d = descs[i];
-					d.rangeType = SAMPLER;
-					d.baseShaderRegister = 0;
-					d.registerSpace = 0;
-					d.numDescriptors = 1;
-				}
+				regs.texturesCount = sh.texturesCount;
+				regs.textures = paramsCount;
+
+				var r = allocDescTable(vis);
+				r.rangeType = SRV;
+				r.baseShaderRegister = 0;
+				r.registerSpace = 0;
+				r.numDescriptors = sh.texturesCount;
+
+				regs.samplers = paramsCount;
+				var r = allocDescTable(vis);
+				r.rangeType = SAMPLER;
+				r.baseShaderRegister = 0;
+				r.registerSpace = 0;
+				r.numDescriptors = sh.texturesCount;
 			}
 			return regs;
 		}
 
+		function calcSize( sh : hxsl.RuntimeShader.RuntimeShaderData ) {
+			var s = (sh.globalsSize + sh.paramsSize) << 2;
+			if( sh.texturesCount > 0 ) s += 2;
+			return s;
+		}
+
+		var totalVertex = calcSize(shader.vertex);
+		var totalFragment = calcSize(shader.fragment);
+		var total = totalVertex + totalFragment;
+
+		if( total > 64 ) {
+			var withoutVP = total - (shader.vertex.paramsSize << 2);
+			var withoutFP = total - (shader.fragment.paramsSize << 2);
+			if( total > 64 && (withoutVP < 64 || withoutFP > 64) ) {
+				vertexParamsCBV = true;
+				total -= (shader.vertex.paramsSize << 2);
+			}
+			if( total > 64 ) {
+				fragmentParamsCBV = true;
+				total -= (shader.fragment.paramsSize << 2);
+			}
+			if( total > 64 )
+				throw "Too many globals";
+		}
+
 		c.vertexRegisters = allocParams(shader.vertex);
+		var fragmentRegStart = regCount;
 		c.fragmentRegisters = allocParams(shader.fragment);
 
+		if( paramsCount > params.length )
+			throw "ASSERT : Too many parameters";
+
 		var vs = compileSource(shader.vertex, "vs_6_0", 0);
-		var ps = compileSource(shader.fragment, "ps_6_0", c.fragmentRegisters.globalsReg);
+		var ps = compileSource(shader.fragment, "ps_6_0", fragmentRegStart);
 
 		var inputs = [];
 		for( v in shader.vertex.data.vars )
@@ -477,6 +525,9 @@ class DX12Driver extends h3d.impl.Driver {
 
 		var sign = new RootSignatureDesc();
 		sign.flags.set(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT);
+		sign.flags.set(DENY_HULL_SHADER_ROOT_ACCESS);
+		sign.flags.set(DENY_DOMAIN_SHADER_ROOT_ACCESS);
+		sign.flags.set(DENY_GEOMETRY_SHADER_ROOT_ACCESS);
 		sign.numParameters = paramsCount;
 		sign.parameters = params[0];
 
@@ -485,9 +536,12 @@ class DX12Driver extends h3d.impl.Driver {
 		var sign = new RootSignature(signBytes,signSize);
 
 		var inputLayout = hl.CArray.alloc(InputElementDesc, inputs.length);
+		var inputOffsets = [];
+		var offset = 0;
 		for( i => v in inputs ) {
 			var d = inputLayout[i];
 			var perInst = 0;
+			inputOffsets.push(offset);
 			if( v.qualifiers != null )
 				for( q in v.qualifiers )
 					switch( q ) {
@@ -496,11 +550,11 @@ class DX12Driver extends h3d.impl.Driver {
 					}
 			d.semanticName = @:privateAccess hxsl.HlslOut.semanticName(v.name).toUtf8();
 			d.format = switch( v.type ) {
-				case TFloat: R32_FLOAT;
-				case TVec(2, VFloat): R32G32_FLOAT;
-				case TVec(3, VFloat): R32G32B32_FLOAT;
-				case TVec(4, VFloat): R32G32B32A32_FLOAT;
-				case TBytes(4): R8G8B8A8_UINT;
+				case TFloat: offset++; R32_FLOAT;
+				case TVec(2, VFloat): offset += 2; R32G32_FLOAT;
+				case TVec(3, VFloat): offset += 3;R32G32B32_FLOAT;
+				case TVec(4, VFloat): offset += 4;R32G32B32A32_FLOAT;
+				case TBytes(4): offset++; R8G8B8A8_UINT;
 				default:
 					throw "Unsupported input type " + hxsl.Ast.Tools.toString(v.type);
 			};
@@ -536,7 +590,9 @@ class DX12Driver extends h3d.impl.Driver {
 		c.rootSignature = sign;
 		c.inputLayout = inputLayout;
 		c.inputCount = inputs.length;
+		c.inputOffsets = inputOffsets;
 		c.shader = shader;
+
 		for( i in 0...inputs.length )
 			inputLayout[i].alignedByteOffset = 1; // will trigger error if not set in makePipeline()
 	   return c;
@@ -733,43 +789,77 @@ class DX12Driver extends h3d.impl.Driver {
 		uploadBuffers(buffers.fragment, which, currentShader.shader.fragment, currentShader.fragmentRegisters);
 	}
 
+	function calcCBVSize( dataSize : Int ) { // returns dataSize rounded up to the next multiple of 256 (unchanged if already aligned)
+		// the view size (sizeInBytes of the CBV desc) must be a multiple of 256
+		var sz = dataSize & ~0xFF; // clear the low 8 bits: largest multiple of 256 that is <= dataSize
+		if( sz != dataSize ) sz += 0x100; // dataSize was not aligned: bump to the following multiple
+		return sz;
+ 	}
+
+	function allocDynamicCBV( data : hl.Bytes, dataSize : Int ) { // allocates a new buffer, copies dataSize bytes of data into it and returns it
+		var tmpBuf = allocBuffer(calcCBVSize(dataSize), UPLOAD, GENERIC_READ); // allocation size is dataSize padded by calcCBVSize()
+		var ptr = tmpBuf.map(0, null);
+		ptr.blit(0, data, 0, dataSize); // only dataSize bytes are written; the padding up to the allocated size is not initialised here
+		tmpBuf.unmap(0,null);
+		frame.toRelease.push(tmpBuf); // queued on frame.toRelease — presumably freed once that frame has completed; confirm where toRelease is drained
+		return tmpBuf;
+	}
+
 	function uploadBuffers( buf : h3d.shader.Buffers.ShaderBuffers, which:h3d.shader.Buffers.BufferKind, shader : hxsl.RuntimeShader.RuntimeShaderData, regs : ShaderRegisters ) {
 		switch( which ) {
 		case Params:
-			frame.commandList.setGraphicsRoot32BitConstants(regs.paramsReg, shader.paramsSize << 2,  hl.Bytes.getArray(buf.params.toData()), 0);
+			if( shader.paramsSize > 0 ) {
+				var data = hl.Bytes.getArray(buf.params.toData());
+				var dataSize = shader.paramsSize << 4;
+				if( regs.params & 0x100 != 0 ) {
+					// update CBV
+					var srv = frame.shaderResourceViews.alloc(1);
+					var cbv = allocDynamicCBV(data,dataSize);
+					var desc = tmp.cbvDesc;
+					desc.bufferLocation = cbv.getGpuVirtualAddress();
+					desc.sizeInBytes = calcCBVSize(dataSize);
+					Driver.createConstantBufferView(desc, srv);
+					frame.commandList.setGraphicsRootDescriptorTable(regs.params & 0xFF, frame.shaderResourceViews.toGPU(srv));
+				} else
+					frame.commandList.setGraphicsRoot32BitConstants(regs.params, dataSize >> 2, data, 0);
+			}
 		case Globals:
-			frame.commandList.setGraphicsRoot32BitConstants(regs.globalsReg, shader.globalsSize << 2,  hl.Bytes.getArray(buf.globals.toData()), 0);
+			if( shader.globalsSize > 0 )
+				frame.commandList.setGraphicsRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
 		case Textures:
-			if( buf.tex.length > 0 ) {
-				var address = shaderResourceViews.alloc(buf.tex.length);
-
-				var t = buf.tex[0];
-				var desc = new Text2DSRV();
-				desc.mipLevels = 1;
-				desc.format = t.t.format;
-				desc.shader4ComponentMapping = ShaderComponentMapping.DEFAULT;
-				Driver.createShaderResourceView(t.t.res, desc, address);
-
-
-				var sampler = samplerViews.alloc(1);
-				var desc = new SamplerDesc();
-				desc.addressU = desc.addressV = desc.addressW = CLAMP;
-				desc.comparisonFunc = NEVER;
-				desc.maxLod = 3.402823466e+38;
-				Driver.createSampler(desc, sampler);
-
-				frame.commandList.setGraphicsRootDescriptorTable(regs.texturesParam, shaderResourceViews.toGPU(address));
-				frame.commandList.setGraphicsRootDescriptorTable(regs.samplersParam, samplerViews.toGPU(sampler));
+			if( regs.texturesCount > 0 ) {
+				var srv = frame.shaderResourceViews.alloc(regs.texturesCount);
+				var sampler = frame.samplerViews.alloc(regs.texturesCount);
+				for( i in 0...regs.texturesCount ) {
+					var t = buf.tex[i];
+					var desc = tmp.tex2DSRV;
+					desc.mipLevels = t.mipLevels;
+					desc.format = t.t.format;
+					Driver.createShaderResourceView(t.t.res, desc, srv.offset(i * frame.shaderResourceViews.stride));
+
+					var desc = tmp.samplerDesc;
+					desc.filter = switch( [t.filter, t.mipMap] ) {
+					case [Nearest, None|Nearest]: MIN_MAG_MIP_POINT;
+					case [Nearest, Linear]: MIN_MAG_POINT_MIP_LINEAR;
+					case [Linear, None|Nearest]: MIN_MAG_LINEAR_MIP_POINT;
+					case [Linear, Linear]: MIN_MAG_MIP_LINEAR;
+					}
+					desc.addressU = desc.addressV = desc.addressW = switch( t.wrap ) {
+					case Clamp: CLAMP;
+					case Repeat: WRAP;
+					}
+					desc.mipLODBias = t.lodBias;
+					Driver.createSampler(desc, sampler.offset(i * frame.samplerViews.stride));
+				}
+
+				frame.commandList.setGraphicsRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
+				frame.commandList.setGraphicsRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
 			}
 		case Buffers:
 			if( buf.buffers != null && buf.buffers.length > 0 ) throw "TODO";
 		}
 	}
 
-	override function selectBuffer(buffer:Buffer) {
-		throw "TODO";
-	}
-
 	override function selectShader( shader : hxsl.RuntimeShader ) {
 		var sh = compiledShaders.get(shader.id);
 		if( sh == null ) {
@@ -798,6 +888,20 @@ class DX12Driver extends h3d.impl.Driver {
 		}
 	}
 
+	override function selectBuffer(buffer:Buffer) { // binds `buffer` as the vertex source for every input of currentShader
+		var views = tmp.vertexViews;
+		var bview = @:privateAccess buffer.buffer.vbuf.view;
+		for( i in 0...currentShader.inputCount ) {
+			var v = views[i];
+			v.bufferLocation = bview.bufferLocation; // each input slot receives the same location, size and stride from the single source view
+			v.sizeInBytes = bview.sizeInBytes;
+			v.strideInBytes = bview.strideInBytes;
+			pipelineSignature.setUI8(PSIGN_BUF_OFFSETS + i, currentShader.inputOffsets[i]); // per-input offset (as computed in compileShader) stored as one byte in the pipeline signature
+		}
+		needPipelineFlush = true; // pipelineSignature bytes were rewritten above
+		frame.commandList.iaSetVertexBuffers(0, currentShader.inputCount, views[0]); // binds inputCount views starting at slot 0
+	}
+
 	override function selectMultiBuffers(buffers:h3d.Buffer.BufferOffset) {
 		var views = tmp.vertexViews;
 		var bufferCount = 0;