浏览代码

added compute shader support (dx12 only)

Nicolas Cannasse 1 年之前
父节点
当前提交
3b35db388f
共有 3 个文件被更改,包括 158 次插入和 44 次删除
  1. 123 44
      h3d/impl/DX12Driver.hx
  2. 31 0
      h3d/scene/RenderContext.hx
  3. 4 0
      h3d/scene/Renderer.hx

+ 123 - 44
h3d/impl/DX12Driver.hx

@@ -98,6 +98,7 @@ class ShaderRegisters {
 	public var samplers : Int;
 	public var texturesCount : Int;
 	public var textures2DCount : Int;
+	public var bufferTypes : Array<hxsl.Ast.BufferKind>;
 	public function new() {
 	}
 }
@@ -112,6 +113,8 @@ class CompiledShader {
 	public var inputLayout : hl.CArray<InputElementDesc>;
 	public var inputCount : Int;
 	public var shader : hxsl.RuntimeShader;
+	public var isCompute : Bool;
+	public var computePipeline : ComputePipelineState;
 	public function new() {
 	}
 }
@@ -134,6 +137,7 @@ class CompiledShader {
 	@:packed public var bufferSRV(default,null) : BufferSRV;
 	@:packed public var samplerDesc(default,null) : SamplerDesc;
 	@:packed public var cbvDesc(default,null) : ConstantBufferViewDesc;
+	@:packed public var uavDesc(default,null) : UAVBufferViewDesc;
 	@:packed public var rtvDesc(default,null) : RenderTargetViewDesc;
 
 	public var pass : h3d.mat.Pass;
@@ -157,6 +161,7 @@ class CompiledShader {
 		samplerDesc.comparisonFunc = NEVER;
 		samplerDesc.maxLod = 1e30;
 		descriptors2 = new hl.NativeArray(2);
+		uavDesc.viewDimension = BUFFER;
 		barrier.subResource = -1; // all
 	}
 
@@ -337,7 +342,7 @@ class DX12Driver extends h3d.impl.Driver {
 	public static var INITIAL_RT_COUNT = 1024;
 	public static var BUFFER_COUNT = 2;
 	public static var DEVICE_NAME = null;
-	public static var DEBUG = false;
+	public static var DEBUG = false; // requires dxil.dll when set to true
 
 	public function new() {
 		window = @:privateAccess dx.Window.windows[0];
@@ -908,6 +913,8 @@ class DX12Driver extends h3d.impl.Driver {
 	override function getNativeShaderCode( shader : hxsl.RuntimeShader ) {
 		var out = new hxsl.HlslOut();
 		var vsSource = out.run(shader.vertex.data);
+		if( shader.mode == Compute )
+			return vsSource;
 		var out = new hxsl.HlslOut();
 		var psSource = out.run(shader.fragment.data);
 		return vsSource+"\n\n\n\n"+psSource;
@@ -988,14 +995,14 @@ class DX12Driver extends h3d.impl.Driver {
 			return range;
 		}
 
-		function allocConsts(size,vis,useCBV) {
+		function allocConsts(size,vis,type) {
 			var reg = regCount++;
 			if( size == 0 ) return -1;
 
-			if( useCBV ) {
+			if( type != null ) {
 				var pid = paramsCount;
 				var r = allocDescTable(vis);
-				r.rangeType = CBV;
+				r.rangeType = type;
 				r.numDescriptors = 1;
 				r.baseShaderRegister = reg;
 				r.registerSpace = 0;
@@ -1013,14 +1020,30 @@ class DX12Driver extends h3d.impl.Driver {
 
 
 		function allocParams( sh : hxsl.RuntimeShader.RuntimeShaderData ) {
-			var vis = sh.kind == Vertex ? VERTEX : PIXEL;
+			var vis = switch( sh.kind ) {
+			case Vertex: VERTEX;
+			case Fragment: PIXEL;
+			default: ALL;
+			}
 			var regs = new ShaderRegisters();
-			regs.globals = allocConsts(sh.globalsSize, vis, false);
-			regs.params = allocConsts(sh.paramsSize, vis, sh.kind == Vertex ? vertexParamsCBV : fragmentParamsCBV);
+			regs.globals = allocConsts(sh.globalsSize, vis, null);
+			regs.params = allocConsts(sh.paramsSize, vis, (sh.kind == Fragment ? fragmentParamsCBV : vertexParamsCBV) ? CBV : null);
 			if( sh.bufferCount > 0 ) {
 				regs.buffers = paramsCount;
-				for( i in 0...sh.bufferCount )
-					allocConsts(1, vis, true);
+				regs.bufferTypes = [];
+				var p = sh.buffers;
+				while( p != null ) {
+					var kind = switch( p.type ) {
+					case TBuffer(_,_,kind): kind;
+					default: throw "assert";
+					}
+					regs.bufferTypes.push(kind);
+					allocConsts(1, vis, switch( kind ) {
+					case Uniform: CBV;
+					case RW: UAV;
+					});
+					p = p.next;
+				}
 			}
 			if( sh.texturesCount > 0 ) {
 				regs.texturesCount = sh.texturesCount;
@@ -1064,7 +1087,7 @@ class DX12Driver extends h3d.impl.Driver {
 		}
 
 		var totalVertex = calcSize(shader.vertex);
-		var totalFragment = calcSize(shader.fragment);
+		var totalFragment = shader.mode == Compute ? 0 : calcSize(shader.fragment);
 		var total = totalVertex + totalFragment;
 
 		if( total > 64 ) {
@@ -1086,22 +1109,25 @@ class DX12Driver extends h3d.impl.Driver {
 				throw "Too many globals";
 		}
 
-		var vertexRegisters = allocParams(shader.vertex);
-		var fragmentRegStart = regCount;
-		var fragmentRegisters = allocParams(shader.fragment);
-
+		var regs = [];
+		for( s in shader.getShaders() )
+			regs.push({ start : regCount, registers : allocParams(s) });
 		if( paramsCount > allocatedParams )
 			throw "ASSERT : Too many parameters";
 
 		var sign = new RootSignatureDesc();
-		sign.flags.set(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT);
+		if( shader.mode == Compute ) {
+			sign.flags.set(DENY_PIXEL_SHADER_ROOT_ACCESS);
+			sign.flags.set(DENY_VERTEX_SHADER_ROOT_ACCESS);
+		} else
+			sign.flags.set(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT);
 		sign.flags.set(DENY_HULL_SHADER_ROOT_ACCESS);
 		sign.flags.set(DENY_DOMAIN_SHADER_ROOT_ACCESS);
 		sign.flags.set(DENY_GEOMETRY_SHADER_ROOT_ACCESS);
 		sign.numParameters = paramsCount;
 		sign.parameters = cast params;
 
-		return { sign : sign, fragmentRegStart : fragmentRegStart, vertexRegisters : vertexRegisters, fragmentRegisters : fragmentRegisters, params : params, paramsCount : paramsCount, texDescs : texDescs };
+		return { sign : sign, registers : regs, params : params, paramsCount : paramsCount, texDescs : texDescs };
 	}
 
 	function compileShader( shader : hxsl.RuntimeShader ) : CompiledShader {
@@ -1109,16 +1135,31 @@ class DX12Driver extends h3d.impl.Driver {
 		var res = computeRootSignature(shader);
 
 		var c = new CompiledShader();
-		c.vertexRegisters = res.vertexRegisters;
-		c.fragmentRegisters = res.fragmentRegisters;
 
 		var rootStr = stringifyRootSignature(res.sign, "ROOT_SIGNATURE", res.params, res.paramsCount);
-		var vs = compileSource(shader.vertex, "vs_6_0", 0, rootStr);
-		var ps = compileSource(shader.fragment, "ps_6_0", res.fragmentRegStart, rootStr);
+		var vs = shader.mode == Compute ? null : compileSource(shader.vertex, "vs_6_0", 0, rootStr);
+		var ps = shader.mode == Compute ? null : compileSource(shader.fragment, "ps_6_0", res.registers[1].start, rootStr);
+		var cs = shader.mode == Compute ? compileSource(shader.compute, "cs_6_0", 0, rootStr) : null;
 
 		var signSize = 0;
 		var signBytes = Driver.serializeRootSignature(res.sign, 1, signSize);
 		var sign = new RootSignature(signBytes,signSize);
+		c.rootSignature = sign;
+		c.shader = shader;
+
+		if( shader.mode == Compute ) {
+			c.isCompute = true;
+			var desc = new ComputePipelineStateDesc();
+			desc.rootSignature = sign;
+			desc.cs.shaderBytecode = cs;
+			desc.cs.bytecodeLength = cs.length;
+			c.computePipeline = Driver.createComputePipelineState(desc);
+			c.vertexRegisters = res.registers[0].registers;
+			return c;
+		}
+
+		c.vertexRegisters = res.registers[0].registers;
+		c.fragmentRegisters = res.registers[1].registers;
 
 		var inputs = [];
 		for( v in shader.vertex.data.vars )
@@ -1169,10 +1210,8 @@ class DX12Driver extends h3d.impl.Driver {
 
 		c.format = hxd.BufferFormat.make(format);
 		c.pipeline = p;
-		c.rootSignature = sign;
 		c.inputLayout = inputLayout;
 		c.inputCount = inputs.length;
-		c.shader = shader;
 
 		for( i in 0...inputs.length )
 			inputLayout[i].alignedByteOffset = 1; // will trigger error if not set in makePipeline()
@@ -1187,7 +1226,7 @@ class DX12Driver extends h3d.impl.Driver {
 
 	// ----- BUFFERS
 
-	function allocGPU( size : Int, heapType, state ) {
+	function allocGPU( size : Int, heapType, state, uav=false ) {
 		var desc = new ResourceDesc();
 		var flags = new haxe.EnumFlags();
 		desc.dimension = BUFFER;
@@ -1197,6 +1236,7 @@ class DX12Driver extends h3d.impl.Driver {
 		desc.mipLevels = 1;
 		desc.sampleDesc.count = 1;
 		desc.layout = ROW_MAJOR;
+		if( uav ) desc.flags.set(ALLOW_UNORDERED_ACCESS);
 		tmp.heap.type = heapType;
 		return Driver.createCommittedResource(tmp.heap, flags, desc, state, null);
 	}
@@ -1204,9 +1244,9 @@ class DX12Driver extends h3d.impl.Driver {
 	override function allocBuffer( m : h3d.Buffer ) : GPUBuffer {
 		var buf = new VertexBufferData();
 		var size = m.getMemSize();
-		var bufSize = m.flags.has(UniformBuffer) ? calcCBVSize(size) : size;
+		var bufSize = m.flags.has(UniformBuffer) || m.flags.has(ReadWriteBuffer) ? calcCBVSize(size) : size;
 		buf.state = COPY_DEST;
-		buf.res = allocGPU(bufSize, DEFAULT, COMMON);
+		buf.res = allocGPU(bufSize, DEFAULT, COMMON,  m.flags.has(ReadWriteBuffer));
 		if( m.flags.has(UniformBuffer) ) {
 			// no view
 		} else if( m.flags.has(IndexBuffer) ) {
@@ -1475,7 +1515,8 @@ class DX12Driver extends h3d.impl.Driver {
 
 	override function uploadShaderBuffers(buffers:h3d.shader.Buffers, which:h3d.shader.Buffers.BufferKind) {
 		uploadBuffers(buffers, buffers.vertex, which, currentShader.shader.vertex, currentShader.vertexRegisters);
-		uploadBuffers(buffers, buffers.fragment, which, currentShader.shader.fragment, currentShader.fragmentRegisters);
+		if( !currentShader.isCompute )
+			uploadBuffers(buffers, buffers.fragment, which, currentShader.shader.fragment, currentShader.fragmentRegisters);
 	}
 
 	function calcCBVSize( dataSize : Int ) {
@@ -1534,13 +1575,22 @@ class DX12Driver extends h3d.impl.Driver {
 					desc.bufferLocation = cbv.getGpuVirtualAddress();
 					desc.sizeInBytes = calcCBVSize(dataSize);
 					Driver.createConstantBufferView(desc, srv);
-					frame.commandList.setGraphicsRootDescriptorTable(regs.params & 0xFF, frame.shaderResourceViews.toGPU(srv));
-				} else
+					if( currentShader.isCompute )
+						frame.commandList.setComputeRootDescriptorTable(regs.params & 0xFF, frame.shaderResourceViews.toGPU(srv));
+					else
+						frame.commandList.setGraphicsRootDescriptorTable(regs.params & 0xFF, frame.shaderResourceViews.toGPU(srv));
+				} else if( currentShader.isCompute )
+					frame.commandList.setComputeRoot32BitConstants(regs.params, dataSize >> 2, data, 0);
+				else
 					frame.commandList.setGraphicsRoot32BitConstants(regs.params, dataSize >> 2, data, 0);
 			}
 		case Globals:
-			if( shader.globalsSize > 0 )
-				frame.commandList.setGraphicsRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
+			if( shader.globalsSize > 0 ) {
+				if( currentShader.isCompute )
+					frame.commandList.setComputeRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
+				else
+					frame.commandList.setGraphicsRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
+			}
 		case Textures:
 			if( regs.texturesCount > 0 ) {
 				var srv = frame.shaderResourceViews.alloc(regs.texturesCount);
@@ -1600,7 +1650,7 @@ class DX12Driver extends h3d.impl.Driver {
 					var state = if ( t.isDepth() )
 						DEPTH_READ;
 					else if ( shader.kind == Fragment )
-						PIXEL_SHADER_RESOURCE
+						PIXEL_SHADER_RESOURCE;
 					else
 						NON_PIXEL_SHADER_RESOURCE;
 					transition(t.t, state);
@@ -1621,8 +1671,13 @@ class DX12Driver extends h3d.impl.Driver {
 					Driver.createSampler(desc, sampler.offset(i * frame.samplerViews.stride));
 				}
 
-				frame.commandList.setGraphicsRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
-				frame.commandList.setGraphicsRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
+				if( currentShader.isCompute ) {
+					frame.commandList.setComputeRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
+					frame.commandList.setComputeRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
+				} else {
+					frame.commandList.setGraphicsRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
+					frame.commandList.setGraphicsRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
+				}
 			}
 		case Buffers:
 			if( shader.bufferCount > 0 ) {
@@ -1630,15 +1685,29 @@ class DX12Driver extends h3d.impl.Driver {
 					var srv = frame.shaderResourceViews.alloc(1);
 					var b = buf.buffers[i];
 					var cbv = b.vbuf;
-					if( cbv.view != null )
-						throw "Buffer was allocated without UniformBuffer flag";
-					transition(cbv, VERTEX_AND_CONSTANT_BUFFER);
-					var desc = tmp.cbvDesc;
-					desc.bufferLocation = cbv.res.getGpuVirtualAddress();
-					desc.sizeInBytes = cbv.size;
-					Driver.createConstantBufferView(desc, srv);
-					frame.commandList.setGraphicsRootDescriptorTable(regs.buffers + i, frame.shaderResourceViews.toGPU(srv));
-				}
+					switch( regs.bufferTypes[i] ) {
+					case Uniform:
+						if( cbv.view != null )
+							throw "Buffer was allocated without UniformBuffer flag";
+						transition(cbv, VERTEX_AND_CONSTANT_BUFFER);
+						var desc = tmp.cbvDesc;
+						desc.bufferLocation = cbv.res.getGpuVirtualAddress();
+						desc.sizeInBytes = cbv.size;
+						Driver.createConstantBufferView(desc, srv);
+					case RW:
+						if( !b.flags.has(ReadWriteBuffer) )
+							throw "Buffer was allocated without ReadWriteBuffer flag";
+						transition(cbv, UNORDERED_ACCESS);
+						var desc = tmp.uavDesc;
+						desc.numElements = b.vertices;
+						desc.structureSizeInBytes = b.format.strideBytes;
+						Driver.createUnorderedAccessView(cbv.res, null, desc, srv);
+					}
+					if( currentShader.isCompute )
+						frame.commandList.setComputeRootDescriptorTable(regs.buffers + i, frame.shaderResourceViews.toGPU(srv));
+					else
+						frame.commandList.setGraphicsRootDescriptorTable(regs.buffers + i, frame.shaderResourceViews.toGPU(srv));
+			}
 			}
 		}
 	}
@@ -1652,8 +1721,14 @@ class DX12Driver extends h3d.impl.Driver {
 		if( currentShader == sh )
 			return false;
 		currentShader = sh;
-		needPipelineFlush = true;
-		frame.commandList.setGraphicsRootSignature(currentShader.rootSignature);
+		if( sh.isCompute ) {
+			needPipelineFlush = false;
+			frame.commandList.setComputeRootSignature(currentShader.rootSignature);
+			frame.commandList.setPipelineState(currentShader.computePipeline);
+		} else {
+			needPipelineFlush = true;
+			frame.commandList.setGraphicsRootSignature(currentShader.rootSignature);
+		}
 		return true;
 	}
 
@@ -2013,6 +2088,10 @@ class DX12Driver extends h3d.impl.Driver {
 		}
 	}
 
+	override function computeDispatch( x : Int = 1, y : Int = 1, z : Int = 1 ) {
+		frame.commandList.dispatch(x,y,z);
+	}
+
 }
 
 #end

+ 31 - 0
h3d/scene/RenderContext.hx

@@ -41,6 +41,7 @@ class RenderContext extends h3d.impl.RenderContext {
 
 	var allocPool : h3d.pass.PassObject;
 	var allocFirst : h3d.pass.PassObject;
+	var computeLink = new hxsl.ShaderList(null,null);
 	var cachedShaderList : Array<hxsl.ShaderList>;
 	var cachedPassObjects : Array<Renderer.PassObjects>;
 	var cachedPos : Int;
@@ -145,6 +146,36 @@ class RenderContext extends h3d.impl.RenderContext {
 		return sl;
 	}
 
+	public function computeDispatch( shader : hxsl.Shader, x = 1, y = 1, z = 1 ) {
+
+		var prev = h3d.impl.RenderContext.get();
+		if( prev != this )
+			start();
+
+		// compile shader
+		globals.resetChannels();
+		shader.updateConstants(globals);
+		computeLink.s = shader;
+		var rt = hxsl.Cache.get().link(computeLink, Compute);
+		// upload buffers
+		engine.driver.selectShader(rt);
+		var buf = shaderBuffers;
+		buf.grow(rt);
+		fillGlobals(buf, rt);
+		engine.uploadShaderBuffers(buf, Globals);
+		fillParams(buf, rt, computeLink);
+		engine.uploadShaderBuffers(buf, Params);
+		engine.uploadShaderBuffers(buf, Textures);
+		engine.uploadShaderBuffers(buf, Buffers);
+		engine.driver.computeDispatch(x,y,z);
+		computeLink.s = null;
+
+		if( prev != this ) {
+			done();
+			if( prev != null ) prev.setCurrent();
+		}
+	}
+
 	public function emitLight( l : Light ) {
 		l.next = lights;
 		lights = l;

+ 4 - 0
h3d/scene/Renderer.hx

@@ -192,4 +192,8 @@ class Renderer extends hxd.impl.AnyProps {
 			passObjects.set(p.name, null);
 	}
 
+	public function computeDispatch( shader, x = 1, y = 1, z = 1 ) {
+		ctx.computeDispatch(shader, x, y, z);
+	}
+
 }