Sfoglia il codice sorgente

working compute shaders in dx12

Nicolas Cannasse 1 anno fa
parent
commit
24b2d20ccf

+ 4 - 0
h3d/Buffer.hx

@@ -13,6 +13,10 @@ enum BufferFlag {
 		Used for shader input buffer
 	**/
 	UniformBuffer;
+	/**
+		Can be written
+	**/
+	ReadWriteBuffer;
 }
 
 @:allow(h3d.impl.MemoryManager)

+ 127 - 48
h3d/impl/DX12Driver.hx

@@ -97,6 +97,7 @@ class ShaderRegisters {
 	public var samplers : Int;
 	public var texturesCount : Int;
 	public var textures2DCount : Int;
+	public var bufferTypes : Array<hxsl.Ast.BufferKind>;
 	public function new() {
 	}
 }
@@ -111,6 +112,8 @@ class CompiledShader {
 	public var inputLayout : hl.CArray<InputElementDesc>;
 	public var inputCount : Int;
 	public var shader : hxsl.RuntimeShader;
+	public var isCompute : Bool;
+	public var computePipeline : ComputePipelineState;
 	public function new() {
 	}
 }
@@ -133,6 +136,7 @@ class CompiledShader {
 	@:packed public var bufferSRV(default,null) : BufferSRV;
 	@:packed public var samplerDesc(default,null) : SamplerDesc;
 	@:packed public var cbvDesc(default,null) : ConstantBufferViewDesc;
+	@:packed public var uavDesc(default,null) : UAVBufferViewDesc;
 	@:packed public var rtvDesc(default,null) : RenderTargetViewDesc;
 
 	public var pass : h3d.mat.Pass;
@@ -156,6 +160,7 @@ class CompiledShader {
 		samplerDesc.comparisonFunc = NEVER;
 		samplerDesc.maxLod = 1e30;
 		descriptors2 = new hl.NativeArray(2);
+		uavDesc.viewDimension = BUFFER;
 		barrier.subResource = -1; // all
 	}
 
@@ -341,7 +346,7 @@ class DX12Driver extends h3d.impl.Driver {
 	public static var INITIAL_RT_COUNT = 1024;
 	public static var BUFFER_COUNT = 2;
 	public static var DEVICE_NAME = null;
-	public static var DEBUG = false;
+	public static var DEBUG = false; // requires dxil.dll when set to true
 
 	public function new() {
 		window = @:privateAccess dx.Window.windows[0];
@@ -875,7 +880,7 @@ class DX12Driver extends h3d.impl.Driver {
 
 	static var VERTEX_FORMATS = [null,null,R32G32_FLOAT,R32G32B32_FLOAT,R32G32B32A32_FLOAT];
 
-	function getBinaryPayload( vertex : Bool, code : String ) {
+	function getBinaryPayload( code : String ) {
 		var bin = code.indexOf("//BIN=");
 		if( bin >= 0 ) {
 			var end = code.indexOf("#", bin);
@@ -895,7 +900,7 @@ class DX12Driver extends h3d.impl.Driver {
 			sh.code = out.run(sh.data);
 			sh.code = rootStr + sh.code;
 		}
-		var bytes = getBinaryPayload(sh.vertex, sh.code);
+		var bytes = getBinaryPayload(sh.code);
 		if ( bytes == null ) {
 			return compiler.compile(sh.code, profile, args);
 		}
@@ -905,6 +910,8 @@ class DX12Driver extends h3d.impl.Driver {
 	override function getNativeShaderCode( shader : hxsl.RuntimeShader ) {
 		var out = new hxsl.HlslOut();
 		var vsSource = out.run(shader.vertex.data);
+		if( shader.mode == Compute )
+			return vsSource;
 		var out = new hxsl.HlslOut();
 		var psSource = out.run(shader.fragment.data);
 		return vsSource+"\n\n\n\n"+psSource;
@@ -985,14 +992,14 @@ class DX12Driver extends h3d.impl.Driver {
 			return range;
 		}
 
-		function allocConsts(size,vis,useCBV) {
+		function allocConsts(size,vis,type) {
 			var reg = regCount++;
 			if( size == 0 ) return -1;
 
-			if( useCBV ) {
+			if( type != null ) {
 				var pid = paramsCount;
 				var r = allocDescTable(vis);
-				r.rangeType = CBV;
+				r.rangeType = type;
 				r.numDescriptors = 1;
 				r.baseShaderRegister = reg;
 				r.registerSpace = 0;
@@ -1010,14 +1017,30 @@ class DX12Driver extends h3d.impl.Driver {
 
 
 		function allocParams( sh : hxsl.RuntimeShader.RuntimeShaderData ) {
-			var vis = sh.vertex ? VERTEX : PIXEL;
+			var vis = switch( sh.kind ) {
+			case Vertex: VERTEX;
+			case Fragment: PIXEL;
+			default: ALL;
+			}
 			var regs = new ShaderRegisters();
-			regs.globals = allocConsts(sh.globalsSize, vis, false);
-			regs.params = allocConsts(sh.paramsSize, vis, sh.vertex ? vertexParamsCBV : fragmentParamsCBV);
+			regs.globals = allocConsts(sh.globalsSize, vis, null);
+			regs.params = allocConsts(sh.paramsSize, vis, (sh.kind == Fragment ? fragmentParamsCBV : vertexParamsCBV) ? CBV : null);
 			if( sh.bufferCount > 0 ) {
 				regs.buffers = paramsCount;
-				for( i in 0...sh.bufferCount )
-					allocConsts(1, vis, true);
+				regs.bufferTypes = [];
+				var p = sh.buffers;
+				while( p != null ) {
+					var kind = switch( p.type ) {
+					case TBuffer(_,_,kind): kind;
+					default: throw "assert";
+					}
+					regs.bufferTypes.push(kind);
+					allocConsts(1, vis, switch( kind ) {
+					case Uniform: CBV;
+					case RW: UAV;
+					});
+					p = p.next;
+				}
 			}
 			if( sh.texturesCount > 0 ) {
 				regs.texturesCount = sh.texturesCount;
@@ -1061,7 +1084,7 @@ class DX12Driver extends h3d.impl.Driver {
 		}
 
 		var totalVertex = calcSize(shader.vertex);
-		var totalFragment = calcSize(shader.fragment);
+		var totalFragment = shader.mode == Compute ? 0 : calcSize(shader.fragment);
 		var total = totalVertex + totalFragment;
 
 		if( total > 64 ) {
@@ -1083,22 +1106,25 @@ class DX12Driver extends h3d.impl.Driver {
 				throw "Too many globals";
 		}
 
-		var vertexRegisters = allocParams(shader.vertex);
-		var fragmentRegStart = regCount;
-		var fragmentRegisters = allocParams(shader.fragment);
-
+		var regs = [];
+		for( s in shader.getShaders() )
+			regs.push({ start : regCount, registers : allocParams(s) });
 		if( paramsCount > allocatedParams )
 			throw "ASSERT : Too many parameters";
 
 		var sign = new RootSignatureDesc();
-		sign.flags.set(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT);
+		if( shader.mode == Compute ) {
+			sign.flags.set(DENY_PIXEL_SHADER_ROOT_ACCESS);
+			sign.flags.set(DENY_VERTEX_SHADER_ROOT_ACCESS);
+		} else
+			sign.flags.set(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT);
 		sign.flags.set(DENY_HULL_SHADER_ROOT_ACCESS);
 		sign.flags.set(DENY_DOMAIN_SHADER_ROOT_ACCESS);
 		sign.flags.set(DENY_GEOMETRY_SHADER_ROOT_ACCESS);
 		sign.numParameters = paramsCount;
 		sign.parameters = cast params;
 
-		return { sign : sign, fragmentRegStart : fragmentRegStart, vertexRegisters : vertexRegisters, fragmentRegisters : fragmentRegisters, params : params, paramsCount : paramsCount, texDescs : texDescs };
+		return { sign : sign, registers : regs, params : params, paramsCount : paramsCount, texDescs : texDescs };
 	}
 
 	function compileShader( shader : hxsl.RuntimeShader ) : CompiledShader {
@@ -1106,16 +1132,31 @@ class DX12Driver extends h3d.impl.Driver {
 		var res = computeRootSignature(shader);
 
 		var c = new CompiledShader();
-		c.vertexRegisters = res.vertexRegisters;
-		c.fragmentRegisters = res.fragmentRegisters;
 
 		var rootStr = stringifyRootSignature(res.sign, "ROOT_SIGNATURE", res.params, res.paramsCount);
-		var vs = compileSource(shader.vertex, "vs_6_0", 0, rootStr);
-		var ps = compileSource(shader.fragment, "ps_6_0", res.fragmentRegStart, rootStr);
+		var vs = shader.mode == Compute ? null : compileSource(shader.vertex, "vs_6_0", 0, rootStr);
+		var ps = shader.mode == Compute ? null : compileSource(shader.fragment, "ps_6_0", res.registers[1].start, rootStr);
+		var cs = shader.mode == Compute ? compileSource(shader.compute, "cs_6_0", 0, rootStr) : null;
 
 		var signSize = 0;
 		var signBytes = Driver.serializeRootSignature(res.sign, 1, signSize);
 		var sign = new RootSignature(signBytes,signSize);
+		c.rootSignature = sign;
+		c.shader = shader;
+
+		if( shader.mode == Compute ) {
+			c.isCompute = true;
+			var desc = new ComputePipelineStateDesc();
+			desc.rootSignature = sign;
+			desc.cs.shaderBytecode = cs;
+			desc.cs.bytecodeLength = cs.length;
+			c.computePipeline = Driver.createComputePipelineState(desc);
+			c.vertexRegisters = res.registers[0].registers;
+			return c;
+		}
+
+		c.vertexRegisters = res.registers[0].registers;
+		c.fragmentRegisters = res.registers[1].registers;
 
 		var inputs = [];
 		for( v in shader.vertex.data.vars )
@@ -1166,10 +1207,8 @@ class DX12Driver extends h3d.impl.Driver {
 
 		c.format = hxd.BufferFormat.make(format);
 		c.pipeline = p;
-		c.rootSignature = sign;
 		c.inputLayout = inputLayout;
 		c.inputCount = inputs.length;
-		c.shader = shader;
 
 		for( i in 0...inputs.length )
 			inputLayout[i].alignedByteOffset = 1; // will trigger error if not set in makePipeline()
@@ -1184,7 +1223,7 @@ class DX12Driver extends h3d.impl.Driver {
 
 	// ----- BUFFERS
 
-	function allocGPU( size : Int, heapType, state ) {
+	function allocGPU( size : Int, heapType, state, uav=false ) {
 		var desc = new ResourceDesc();
 		var flags = new haxe.EnumFlags();
 		desc.dimension = BUFFER;
@@ -1194,6 +1233,7 @@ class DX12Driver extends h3d.impl.Driver {
 		desc.mipLevels = 1;
 		desc.sampleDesc.count = 1;
 		desc.layout = ROW_MAJOR;
+		if( uav ) desc.flags.set(ALLOW_UNORDERED_ACCESS);
 		tmp.heap.type = heapType;
 		return Driver.createCommittedResource(tmp.heap, flags, desc, state, null);
 	}
@@ -1201,9 +1241,9 @@ class DX12Driver extends h3d.impl.Driver {
 	override function allocBuffer( m : h3d.Buffer ) : GPUBuffer {
 		var buf = new VertexBufferData();
 		var size = m.getMemSize();
-		var bufSize = m.flags.has(UniformBuffer) ? calcCBVSize(size) : size;
+		var bufSize = m.flags.has(UniformBuffer) || m.flags.has(ReadWriteBuffer) ? calcCBVSize(size) : size;
 		buf.state = COPY_DEST;
-		buf.res = allocGPU(bufSize, DEFAULT, COMMON);
+		buf.res = allocGPU(bufSize, DEFAULT, COMMON, m.flags.has(ReadWriteBuffer));
 		if( !m.flags.has(UniformBuffer) ) {
 			var view = new VertexBufferView();
 			view.bufferLocation = buf.res.getGpuVirtualAddress();
@@ -1488,7 +1528,8 @@ class DX12Driver extends h3d.impl.Driver {
 
 	override function uploadShaderBuffers(buffers:h3d.shader.Buffers, which:h3d.shader.Buffers.BufferKind) {
 		uploadBuffers(buffers, buffers.vertex, which, currentShader.shader.vertex, currentShader.vertexRegisters);
-		uploadBuffers(buffers, buffers.fragment, which, currentShader.shader.fragment, currentShader.fragmentRegisters);
+		if( !currentShader.isCompute )
+			uploadBuffers(buffers, buffers.fragment, which, currentShader.shader.fragment, currentShader.fragmentRegisters);
 	}
 
 	function calcCBVSize( dataSize : Int ) {
@@ -1547,13 +1588,22 @@ class DX12Driver extends h3d.impl.Driver {
 					desc.bufferLocation = cbv.getGpuVirtualAddress();
 					desc.sizeInBytes = calcCBVSize(dataSize);
 					Driver.createConstantBufferView(desc, srv);
-					frame.commandList.setGraphicsRootDescriptorTable(regs.params & 0xFF, frame.shaderResourceViews.toGPU(srv));
-				} else
+					if( currentShader.isCompute )
+						frame.commandList.setComputeRootDescriptorTable(regs.params & 0xFF, frame.shaderResourceViews.toGPU(srv));
+					else
+						frame.commandList.setGraphicsRootDescriptorTable(regs.params & 0xFF, frame.shaderResourceViews.toGPU(srv));
+				} else if( currentShader.isCompute )
+					frame.commandList.setComputeRoot32BitConstants(regs.params, dataSize >> 2, data, 0);
+				else
 					frame.commandList.setGraphicsRoot32BitConstants(regs.params, dataSize >> 2, data, 0);
 			}
 		case Globals:
-			if( shader.globalsSize > 0 )
-				frame.commandList.setGraphicsRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
+			if( shader.globalsSize > 0 ) {
+				if( currentShader.isCompute )
+					frame.commandList.setComputeRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
+				else
+					frame.commandList.setGraphicsRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
+			}
 		case Textures:
 			if( regs.texturesCount > 0 ) {
 				var srv = frame.shaderResourceViews.alloc(regs.texturesCount);
@@ -1612,10 +1662,10 @@ class DX12Driver extends h3d.impl.Driver {
 					t.lastFrame = frameCount;
 					var state = if ( t.isDepth() )
 						DEPTH_READ;
-					else if ( shader.vertex )
-						NON_PIXEL_SHADER_RESOURCE;
-					else
+					else if ( shader.kind == Fragment )
 						PIXEL_SHADER_RESOURCE;
+					else
+						NON_PIXEL_SHADER_RESOURCE;
 					transition(t.t, state);
 					Driver.createShaderResourceView(t.t.res, tdesc, srv.offset(i * frame.shaderResourceViews.stride));
 
@@ -1634,8 +1684,13 @@ class DX12Driver extends h3d.impl.Driver {
 					Driver.createSampler(desc, sampler.offset(i * frame.samplerViews.stride));
 				}
 
-				frame.commandList.setGraphicsRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
-				frame.commandList.setGraphicsRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
+				if( currentShader.isCompute ) {
+					frame.commandList.setComputeRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
+					frame.commandList.setComputeRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
+				} else {
+					frame.commandList.setGraphicsRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
+					frame.commandList.setGraphicsRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
+				}
 			}
 		case Buffers:
 			if( shader.bufferCount > 0 ) {
@@ -1643,15 +1698,29 @@ class DX12Driver extends h3d.impl.Driver {
 					var srv = frame.shaderResourceViews.alloc(1);
 					var b = buf.buffers[i];
 					var cbv = b.vbuf;
-					if( cbv.view != null )
-						throw "Buffer was allocated without UniformBuffer flag";
-					transition(cbv, VERTEX_AND_CONSTANT_BUFFER);
-					var desc = tmp.cbvDesc;
-					desc.bufferLocation = cbv.res.getGpuVirtualAddress();
-					desc.sizeInBytes = cbv.size;
-					Driver.createConstantBufferView(desc, srv);
-					frame.commandList.setGraphicsRootDescriptorTable(regs.buffers + i, frame.shaderResourceViews.toGPU(srv));
-				}
+					switch( regs.bufferTypes[i] ) {
+					case Uniform:
+						if( cbv.view != null )
+							throw "Buffer was allocated without UniformBuffer flag";
+						transition(cbv, VERTEX_AND_CONSTANT_BUFFER);
+						var desc = tmp.cbvDesc;
+						desc.bufferLocation = cbv.res.getGpuVirtualAddress();
+						desc.sizeInBytes = cbv.size;
+						Driver.createConstantBufferView(desc, srv);
+					case RW:
+						if( !b.flags.has(ReadWriteBuffer) )
+							throw "Buffer was allocated without ReadWriteBuffer flag";
+						transition(cbv, UNORDERED_ACCESS);
+						var desc = tmp.uavDesc;
+						desc.numElements = b.vertices;
+						desc.structureSizeInBytes = b.format.strideBytes;
+						Driver.createUnorderedAccessView(cbv.res, null, desc, srv);
+					}
+					if( currentShader.isCompute )
+						frame.commandList.setComputeRootDescriptorTable(regs.buffers + i, frame.shaderResourceViews.toGPU(srv));
+					else
+						frame.commandList.setGraphicsRootDescriptorTable(regs.buffers + i, frame.shaderResourceViews.toGPU(srv));
+			}
 			}
 		}
 	}
@@ -1665,8 +1734,14 @@ class DX12Driver extends h3d.impl.Driver {
 		if( currentShader == sh )
 			return false;
 		currentShader = sh;
-		needPipelineFlush = true;
-		frame.commandList.setGraphicsRootSignature(currentShader.rootSignature);
+		if( sh.isCompute ) {
+			needPipelineFlush = false;
+			frame.commandList.setComputeRootSignature(currentShader.rootSignature);
+			frame.commandList.setPipelineState(currentShader.computePipeline);
+		} else {
+			needPipelineFlush = true;
+			frame.commandList.setGraphicsRootSignature(currentShader.rootSignature);
+		}
 		return true;
 	}
 
@@ -2026,6 +2101,10 @@ class DX12Driver extends h3d.impl.Driver {
 		}
 	}
 
+	override function computeDispatch( x : Int = 1, y : Int = 1, z : Int = 1 ) {
+		frame.commandList.dispatch(x,y,z);
+	}
+
 }
 
 #end

+ 6 - 0
h3d/impl/Driver.hx

@@ -317,4 +317,10 @@ class Driver {
 		return 0.;
 	}
 
+	// --- COMPUTE
+
+	public function computeDispatch( x : Int = 1, y : Int = 1, z : Int = 1 ) {
+		throw "Not implemented";
+	}
+
 }

+ 2 - 2
h3d/pass/Default.hx

@@ -43,7 +43,7 @@ class Default extends Base {
 		var o = @:privateAccess new h3d.pass.PassObject();
 		o.pass = p;
 		setupShaders(new h3d.pass.PassList(o));
-		return manager.compileShaders(o.shaders, p.batchMode);
+		return manager.compileShaders(o.shaders, p.batchMode ? Batch : Default);
 	}
 
 	function processShaders( p : h3d.pass.PassObject, shaders : hxsl.ShaderList ) {
@@ -68,7 +68,7 @@ class Default extends Base {
 				}
 				shaders = ctx.lightSystem.computeLight(p.obj, shaders);
 			}
-			p.shader = manager.compileShaders(shaders, p.pass.batchMode);
+			p.shader = manager.compileShaders(shaders, p.pass.batchMode ? Batch : Default);
 			p.shaders = shaders;
 			var t = p.shader.fragment.textures;
 			if( t == null || t.type.match(TArray(_)) )

+ 38 - 13
h3d/pass/ShaderManager.hx

@@ -7,6 +7,8 @@ class ShaderManager {
 	public var globals : hxsl.Globals;
 	var shaderCache : hxsl.Cache;
 	var currentOutput : hxsl.ShaderList;
+	var currentCompute : hxsl.ShaderList;
+	var computeBuffers : h3d.shader.Buffers;
 
 	public function new(?output:Array<hxsl.Output>) {
 		shaderCache = hxsl.Cache.get();
@@ -192,20 +194,14 @@ class ShaderManager {
 			var ptr = getPtr(buf.globals);
 			while( g != null ) {
 				var v = globals.fastGet(g.gid);
-				if( v == null ) {
-					if( g.path == "__consts__" ) {
-						fillRec(s.consts, g.type, ptr, g.pos);
-						g = g.next;
-						continue;
-					}
+				if( v == null )
 					throw "Missing global value " + g.path;
-				}
 				fillRec(v, g.type, ptr, g.pos);
 				g = g.next;
 			}
 		}
 		fill(buf.vertex, s.vertex);
-		fill(buf.fragment, s.fragment);
+		if( s.fragment != null ) fill(buf.fragment, s.fragment);
 	}
 
 	public function fillParams( buf : h3d.shader.Buffers, s : hxsl.RuntimeShader, shaders : hxsl.ShaderList ) {
@@ -273,16 +269,45 @@ class ShaderManager {
 			}
 		}
 		fill(buf.vertex, s.vertex);
-		fill(buf.fragment, s.fragment);
+		if( s.fragment != null ) fill(buf.fragment, s.fragment);
 	}
 
-	public function compileShaders( shaders : hxsl.ShaderList, batchMode : Bool = false ) {
+	public function compileShaders( shaders : hxsl.ShaderList, mode : hxsl.Linker.LinkMode = Default ) {
 		globals.resetChannels();
 		for( s in shaders ) s.updateConstants(globals);
-		currentOutput.next = shaders;
-		var s = shaderCache.link(currentOutput, batchMode);
-		currentOutput.next = null;
+		var s;
+		if( mode == Compute )
+			s = shaderCache.link(shaders, mode);
+		else {
+			currentOutput.next = shaders;
+			s = shaderCache.link(currentOutput, mode);
+			currentOutput.next = null;
+		}
 		return s;
 	}
 
+	public function computeDispatch( shader : hxsl.Shader, x = 1, y = 1, z = 1 ) {
+		if( currentCompute == null )
+			currentCompute = new hxsl.ShaderList(null);
+		shader.updateConstants(globals);
+		currentCompute.s = shader;
+		var rt = shaderCache.link(currentCompute, Compute);
+		var bufs = computeBuffers;
+		if( bufs == null ) {
+			bufs = new h3d.shader.Buffers(rt);
+			computeBuffers = bufs;
+		}
+		bufs.grow(rt);
+		fillParams(bufs, rt, currentCompute);
+		fillGlobals(bufs, rt);
+		var e = h3d.Engine.getCurrent();
+		e.driver.selectShader(rt);
+		e.driver.uploadShaderBuffers(bufs, Params);
+		e.driver.uploadShaderBuffers(bufs, Globals);
+		e.driver.uploadShaderBuffers(bufs, Buffers);
+		e.driver.uploadShaderBuffers(bufs, Textures);
+		e.driver.computeDispatch(x,y,z);
+		currentCompute.s = null;
+	}
+
 }

+ 1 - 1
h3d/scene/MeshBatch.hx

@@ -121,7 +121,7 @@ class MeshBatch extends MultiMaterial {
 
 				var manager = cast(ctx,h3d.pass.Default).manager;
 				var shaders = p.getShadersRec();
-				var rt = manager.compileShaders(shaders, false);
+				var rt = manager.compileShaders(shaders, Default);
 				var shader = manager.shaderCache.makeBatchShader(rt, shaders, instancedParams);
 
 				var b = new BatchData();

+ 4 - 2
h3d/shader/Buffers.hx

@@ -43,12 +43,14 @@ class Buffers {
 
 	public function new( s : hxsl.RuntimeShader ) {
 		vertex = new ShaderBuffers(s.vertex);
-		fragment = new ShaderBuffers(s.fragment);
+		if( s.fragment != null )
+			fragment = new ShaderBuffers(s.fragment);
 	}
 
 	public inline function grow( s : hxsl.RuntimeShader ) {
 		vertex.grow(s.vertex);
-		fragment.grow(s.fragment);
+		if( s.fragment != null )
+			fragment.grow(s.fragment);
 	}
 }
 

+ 18 - 3
hxsl/Ast.hx

@@ -1,5 +1,10 @@
 package hxsl;
 
+enum BufferKind {
+	Uniform;
+	RW;
+}
+
 enum Type {
 	TVoid;
 	TInt;
@@ -17,7 +22,7 @@ enum Type {
 	TStruct( vl : Array<TVar> );
 	TFun( variants : Array<FunType> );
 	TArray( t : Type, size : SizeDecl );
-	TBuffer( t : Type, size : SizeDecl );
+	TBuffer( t : Type, size : SizeDecl, kind : BufferKind );
 	TChannel( size : Int );
 	TMat2;
 }
@@ -187,6 +192,7 @@ enum FunctionKind {
 	Fragment;
 	Init;
 	Helper;
+	Main;
 }
 
 enum TGlobal {
@@ -280,6 +286,8 @@ enum TGlobal {
 	IntBitsToFloat;
 	UintBitsToFloat;
 	RoundEven;
+	// compute
+	SetLayout;
 }
 
 enum Component {
@@ -418,7 +426,12 @@ class Tools {
 			prefix + "Vec" + size;
 		case TStruct(vl):"{" + [for( v in vl ) v.name + " : " + toString(v.type)].join(",") + "}";
 		case TArray(t, s): toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
-		case TBuffer(t, s): "buffer "+toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
+		case TBuffer(t, s, k):
+			var prefix = switch( k ) {
+			case Uniform: "buffer";
+			case RW: "rwbuffer";
+			};
+			prefix+" "+toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
 		case TBytes(n): "Bytes" + n;
 		default: t.getName().substr(1);
 		}
@@ -457,6 +470,8 @@ class Tools {
 			return hasSideEffect(e) || hasSideEffect(index);
 		case TConst(_), TVar(_), TGlobal(_):
 			return false;
+		case TCall({ e : TGlobal(SetLayout) },_):
+			return true;
 		case TCall(e, pl):
 			if( !e.e.match(TGlobal(_)) )
 				return true;
@@ -545,7 +560,7 @@ class Tools {
 		case TMat3x4: 12;
 		case TBytes(s): s;
 		case TBool, TString, TSampler2D, TSampler2DArray, TSamplerCube, TFun(_): 0;
-		case TArray(t, SConst(v)), TBuffer(t, SConst(v)): size(t) * v;
+		case TArray(t, SConst(v)), TBuffer(t, SConst(v),_): size(t) * v;
 		case TArray(_, SVar(_)), TBuffer(_): 0;
 		}
 	}

+ 57 - 37
hxsl/Cache.hx

@@ -1,6 +1,7 @@
 package hxsl;
 using hxsl.Ast;
 import hxsl.RuntimeShader;
+import hxsl.Linker.LinkMode;
 
 class BatchInstanceParams {
 
@@ -195,7 +196,7 @@ class Cache {
 	}
 
 	@:noDebug
-	public function link( shaders : hxsl.ShaderList, batchMode : Bool ) {
+	public function link( shaders : hxsl.ShaderList, mode : LinkMode ) {
 		var c = linkCache;
 		for( s in shaders ) {
 			var i = @:privateAccess s.instance;
@@ -207,11 +208,11 @@ class Cache {
 			c = cs;
 		}
 		if( c.linked == null )
-			c.linked = compileRuntimeShader(shaders, batchMode);
+			c.linked = compileRuntimeShader(shaders, mode);
 		return c.linked;
 	}
 
-	function compileRuntimeShader( shaders : hxsl.ShaderList, batchMode : Bool ) {
+	function compileRuntimeShader( shaders : hxsl.ShaderList, mode : LinkMode ) {
 		var shaderDatas = [];
 		var index = 0;
 		for( s in shaders ) {
@@ -262,14 +263,14 @@ class Cache {
 		//TRACE = shaderId == 0;
 		#end
 
-		var linker = new hxsl.Linker(batchMode);
+		var linker = new hxsl.Linker(mode);
 		var s = try linker.link([for( s in shaderDatas ) s.inst.shader]) catch( e : Error ) {
 			var shaders = [for( s in shaderDatas ) Printer.shaderToString(s.inst.shader)];
 			e.msg += "\n\nin\n\n" + shaders.join("\n-----\n");
 			throw e;
 		}
 
-		if( batchMode ) {
+		if( mode == Batch ) {
 			function checkRec( v : TVar ) {
 				if( v.qualifiers != null && v.qualifiers.indexOf(PerObject) >= 0 ) {
 					if( v.qualifiers.length == 1 ) v.qualifiers = null else {
@@ -302,7 +303,7 @@ class Cache {
 
 		var prev = s;
 		var splitter = new hxsl.Splitter();
-		var s = try splitter.split(s) catch( e : Error ) { e.msg += "\n\nin\n\n"+Printer.shaderToString(s); throw e; };
+		var sl = try splitter.split(s) catch( e : Error ) { e.msg += "\n\nin\n\n"+Printer.shaderToString(s); throw e; };
 
 		// params tracking
 		var paramVars = new Map();
@@ -319,41 +320,42 @@ class Cache {
 
 
 		#if debug
-		Printer.check(s.vertex,[prev]);
-		Printer.check(s.fragment,[prev]);
+		for( s in sl )
+			Printer.check(s,[prev]);
 		#end
 
 		#if shader_debug_dump
 		if( dbg != null ) {
 			dbg.writeString("----- SPLIT ----\n\n");
-			dbg.writeString(Printer.shaderToString(s.vertex, DEBUG_IDS) + "\n\n");
-			dbg.writeString(Printer.shaderToString(s.fragment, DEBUG_IDS) + "\n\n");
+			for( s in sl )
+				dbg.writeString(Printer.shaderToString(s, DEBUG_IDS) + "\n\n");
 		}
 		#end
 
-		var prev = s;
-		var s = new hxsl.Dce().dce(s.vertex, s.fragment);
+		var prev = sl;
+		var sl = new hxsl.Dce().dce(sl);
 
 		#if debug
-		Printer.check(s.vertex,[prev.vertex]);
-		Printer.check(s.fragment,[prev.fragment]);
+		for( i => s in sl )
+			Printer.check(s,[prev[i]]);
 		#end
 
 		#if shader_debug_dump
 		if( dbg != null ) {
 			dbg.writeString("----- DCE ----\n\n");
-			dbg.writeString(Printer.shaderToString(s.vertex, DEBUG_IDS) + "\n\n");
-			dbg.writeString(Printer.shaderToString(s.fragment, DEBUG_IDS) + "\n\n");
+			for( s in sl )
+				dbg.writeString(Printer.shaderToString(s, DEBUG_IDS) + "\n\n");
 		}
 		#end
 
-		var r = buildRuntimeShader(s.vertex, s.fragment, paramVars);
+		var r = buildRuntimeShader(sl, paramVars);
+		r.mode = mode;
 
 		#if shader_debug_dump
 		if( dbg != null ) {
 			dbg.writeString("----- FLATTEN ----\n\n");
-			dbg.writeString(Printer.shaderToString(r.vertex.data, DEBUG_IDS) + "\n\n");
-			dbg.writeString(Printer.shaderToString(r.fragment.data,DEBUG_IDS)+"\n\n");
+			for( s in r.getShaders() )
+				dbg.writeString(Printer.shaderToString(s.data, DEBUG_IDS) + "\n\n");
 		}
 		#end
 
@@ -366,9 +368,7 @@ class Cache {
 
 		var signParts = [for( i in r.spec.instances ) i.shader.data.name+"_" + i.bits + "_" + i.index];
 		r.spec.signature = haxe.crypto.Md5.encode(signParts.join(":"));
-		r.signature = haxe.crypto.Md5.encode(Printer.shaderToString(r.vertex.data) + Printer.shaderToString(r.fragment.data));
-		r.batchMode = batchMode;
-
+		r.signature = haxe.crypto.Md5.encode([for( s in r.getShaders() ) Printer.shaderToString(s.data)].join(""));
 		var r2 = byID.get(r.signature);
 		if( r2 != null )
 			r.id = r2.id; // same id but different variable mapping
@@ -385,19 +385,33 @@ class Cache {
 		return r;
 	}
 
-	function buildRuntimeShader( vertex : ShaderData, fragment : ShaderData, paramVars ) {
+	function buildRuntimeShader( shaders : Array<ShaderData>, paramVars ) {
 		var r = new RuntimeShader();
-		r.vertex = flattenShader(vertex, Vertex, paramVars);
-		r.vertex.vertex = true;
-		r.fragment = flattenShader(fragment, Fragment, paramVars);
 		r.globals = new Map();
-		initGlobals(r, r.vertex);
-		initGlobals(r, r.fragment);
-
-		#if debug
-		Printer.check(r.vertex.data,[vertex]);
-		Printer.check(r.fragment.data,[fragment]);
-		#end
+		for( s in shaders ) {
+			var kind = switch( s.name ) {
+			case "vertex": Vertex;
+			case "fragment": Fragment;
+			case "main": Main;
+			default: throw "assert";
+			}
+			var fl = flattenShader(s, kind, paramVars);
+			fl.kind = kind;
+			switch( kind ) {
+			case Vertex:
+				r.vertex = fl;
+			case Fragment:
+				r.fragment = fl;
+			case Main:
+				r.compute = fl;
+			default:
+				throw "assert";
+			}
+			initGlobals(r, fl);
+			#if debug
+			Printer.check(fl.data,[s]); // validate the flattened data of this stage against the stage it was flattened from
+			#end
+		}
 		return r;
 	}
 
@@ -429,7 +443,6 @@ class Cache {
 		data = hl.Api.compact(data, null, 0, null);
 		#end
 		var textures = [];
-		c.consts = flat.consts;
 		c.texturesCount = 0;
 		for( g in flat.allocData.keys() ) {
 			var alloc = flat.allocData.get(g);
@@ -468,8 +481,15 @@ class Cache {
 					c.params = out[0];
 					c.paramsSize = size;
 				case TArray(TBuffer(_), _):
-					c.buffers = out[0];
-					c.bufferCount = out.length;
+					if( c.buffers == null ) {
+						c.buffers = out[0];
+						c.bufferCount = out.length;
+					} else {
+						var p = c.buffers;
+						while( p.next != null ) p = p.next;
+						p.next = out[0];
+						c.bufferCount += out.length;
+					}
 				default: throw "assert";
 				}
 			case Global:
@@ -575,7 +595,7 @@ class Cache {
 		inputOffset.qualifiers = [PerInstance(1)];
 
 		var vcount = declVar("Batch_Count",TInt,Param);
-		var vbuffer = declVar("Batch_Buffer",TBuffer(TVec(4,VFloat),SVar(vcount)),Param);
+		var vbuffer = declVar("Batch_Buffer",TBuffer(TVec(4,VFloat),SVar(vcount),Uniform),Param);
 		var voffset = declVar("Batch_Offset", TInt, Local);
 		var ebuffer = { e : TVar(vbuffer), p : pos, t : vbuffer.type };
 		var eoffset = { e : TVar(voffset), p : pos, t : voffset.type };

+ 12 - 3
hxsl/Checker.hx

@@ -188,6 +188,8 @@ class Checker {
 				[for( i => t in genType ) { args : [ { name: "x", type: t } ], ret: genIType[i] }];
 			case IntBitsToFloat, UintBitsToFloat:
 				[for( i => t in genType ) { args : [ { name: "x", type: genIType[i] } ], ret: t }];
+			case SetLayout:
+				[{ args : [{ name : "x", type : TInt },{ name : "y", type : TInt },{ name : "z", type : TInt }], ret : TVoid }];
 			case VertexID, InstanceID, FragCoord, FrontFacing:
 				null;
 			}
@@ -246,6 +248,7 @@ class Checker {
 			var kind = switch( f.name ) {
 			case "vertex":  Vertex;
 			case "fragment": Fragment;
+			case "main": Main;
 			default: StringTools.startsWith(f.name,"__init__") ? Init : Helper;
 			}
 			if( args.length != 0 && kind != Helper )
@@ -337,6 +340,8 @@ class Checker {
 			switch( v.kind ) {
 			case Local, Var, Output:
 				return;
+			case Param if( v.type.match(TBuffer(_,_,RW)) ):
+				return;
 			default:
 			}
 		case TSwiz(e, _):
@@ -631,7 +636,7 @@ class Checker {
 			default: unify(e2.t, TInt, e2.p);
 			}
 			switch( e1.t ) {
-			case TArray(t, size), TBuffer(t,size):
+			case TArray(t, size), TBuffer(t,size,_):
 				switch( [size, e2.e] ) {
 				case [SConst(v), TConst(CInt(i))] if( i >= v ):
 					error("Indexing outside array bounds", e.pos);
@@ -849,7 +854,7 @@ class Checker {
 				vl[i] = makeVar( { type : v.type, qualifiers : v.qualifiers, name : v.name, kind : v.kind, expr : null }, pos, parent);
 			}
 			return parent.type;
-		case TArray(t, size), TBuffer(t,size):
+		case TArray(t, size), TBuffer(t,size,_):
 			switch( t ) {
 			case TArray(_):
 				error("Multidimentional arrays are not allowed", pos);
@@ -894,7 +899,11 @@ class Checker {
 				SVar(v2);
 			}
 			t = makeVarType(t,parent,pos);
-			return vt.match(TArray(_)) ? TArray(t, s) : TBuffer(t,s);
+			return switch( vt ) {
+			case TArray(_): TArray(t, s);
+			case TBuffer(_,_,kind): TBuffer(t,s,kind);
+			default: throw "assert";
+			}
 		default:
 			return vt;
 		}

+ 24 - 30
hxsl/Dce.hx

@@ -34,30 +34,27 @@ class Dce {
 		#end
 	}
 
-	public function dce( vertex : ShaderData, fragment : ShaderData ) {
+	public function dce( shaders : Array<ShaderData> ) {
 		// collect vars dependencies
 		used = new Map();
 		channelVars = [];
 
 		var inputs = [];
-		for( v in vertex.vars ) {
-			var i = get(v);
-			if( v.kind == Input )
-				inputs.push(i);
-			if( v.kind == Output )
-				i.keep = true;
-		}
-		for( v in fragment.vars ) {
-			var i = get(v);
-			if( v.kind == Output )
-				i.keep = true;
+		for( s in shaders ) {
+			for( v in s.vars ) {
+				var i = get(v);
+				if( v.kind == Input )
+					inputs.push(i);
+				if( v.kind == Output || v.type.match(TBuffer(_,_,RW)) )
+					i.keep = true;
+			}
 		}
 
 		// collect dependencies
-		for( f in vertex.funs )
-			check(f.expr, [], []);
-		for( f in fragment.funs )
-			check(f.expr, [], []);
+		for( s in shaders ) {
+			for( f in s.funs )
+				check(f.expr, [], []);
+		}
 
 		var outExprs = [];
 		while( true ) {
@@ -74,10 +71,10 @@ class Dce {
 				markRec(v);
 
 			outExprs = [];
-			for( f in vertex.funs )
-				outExprs.push(mapExpr(f.expr, false));
-			for( f in fragment.funs )
-				outExprs.push(mapExpr(f.expr, false));
+			for( s in shaders ) {
+				for( f in s.funs )
+					outExprs.push(mapExpr(f.expr, false));
+			}
 
 			// post add conditional branches
 			markAsKeep = false;
@@ -86,22 +83,19 @@ class Dce {
 			if( !markAsKeep ) break;
 		}
 
-		for( f in vertex.funs )
-			f.expr = outExprs.shift();
-		for( f in fragment.funs )
-			f.expr = outExprs.shift();
+		for( s in shaders ) {
+			for( f in s.funs )
+				f.expr = outExprs.shift();
+		}
 
 		for( v in used ) {
 			if( v.used ) continue;
 			if( v.v.kind == VarKind.Input) continue;
-			vertex.vars.remove(v.v);
-			fragment.vars.remove(v.v);
+			for( s in shaders )
+				s.vars.remove(v.v);
 		}
 
-		return {
-			fragment : fragment,
-			vertex : vertex,
-		}
+		return shaders.copy();
 	}
 
 	function get( v : TVar ) {

+ 12 - 4
hxsl/Eval.hx

@@ -52,18 +52,26 @@ class Eval {
 		switch( v2.type ) {
 		case TStruct(vl):
 			v2.type = TStruct([for( v in vl ) mapVar(v)]);
-		case TArray(t, SVar(vs)), TBuffer(t, SVar(vs)):
+		case TArray(t, SVar(vs)), TBuffer(t, SVar(vs), _):
 			var c = constants.get(vs.id);
 			if( c != null )
 				switch( c ) {
 				case TConst(CInt(v)):
-					v2.type = v2.type.match(TArray(_)) ? TArray(t, SConst(v)) : TBuffer(t, SConst(v));
+					v2.type = switch( v2.type ) {
+					case TArray(_): TArray(t, SConst(v));
+					case TBuffer(_,_,kind): TBuffer(t, SConst(v), kind);
+					default: throw "assert";
+					};
 				default:
 					Error.t("Integer value expected for array size constant " + vs.name, null);
 				}
 			else {
 				var vs2 = mapVar(vs);
-				v2.type = v2.type.match(TArray(_)) ? TArray(t, SVar(vs2)) : TBuffer(t, SVar(vs2));
+				v2.type = switch( v2.type ) {
+				case TArray(_): TArray(t, SVar(vs2));
+				case TBuffer(_,_,kind): TBuffer(t, SVar(vs2), kind);
+				default: throw "assert";
+				}
 			}
 		default:
 		}
@@ -81,7 +89,7 @@ class Eval {
 			return false;
 		case TArray(t, _):
 			return checkSamplerRec(t);
-		case TBuffer(_, size):
+		case TBuffer(_, _, _):
 			return true;
 		default:
 		}

+ 9 - 120
hxsl/Flatten.hx

@@ -26,8 +26,6 @@ class Flatten {
 	var params : Array<TVar>;
 	var outVars : Array<TVar>;
 	var varMap : Map<TVar,Alloc>;
-	var econsts : TExpr;
-	public var consts : Array<Float>;
 	public var allocData : Map< TVar, Array<Alloc> >;
 
 	public function new() {
@@ -44,6 +42,7 @@ class Flatten {
 		var prefix = switch( kind ) {
 		case Vertex: "vertex";
 		case Fragment: "fragment";
+		case Main: "compute";
 		default: throw "assert";
 		}
 		pack(prefix + "Globals", Global, globals, VFloat);
@@ -52,7 +51,8 @@ class Flatten {
 		var textures = packTextures(prefix + "Textures", allVars, TSampler2D)
 			.concat(packTextures(prefix+"TexturesCube", allVars, TSamplerCube))
 			.concat(packTextures(prefix+"TexturesArray", allVars, TSampler2DArray));
-		packBuffers(allVars);
+		packBuffers("buffers", allVars, Uniform);
+		packBuffers("rwbuffers", allVars, RW);
 		var funs = [for( f in s.funs ) mapFun(f, mapExpr)];
 		return {
 			name : s.name,
@@ -104,119 +104,6 @@ class Flatten {
 		return optimize(e);
 	}
 
-	function mapConsts( e : TExpr ) : TExpr {
-		switch( e.e ) {
-		case TArray(ea, eindex = { e : TConst(CInt(_)) } ):
-			return { e : TArray(mapConsts(ea), eindex), t : e.t, p : e.p };
-		case TBinop(OpMult, _, { t : TMat3x4 } ):
-			allocConst(1, e.p); // pre-alloc
-		case TArray(ea, eindex):
-			switch( ea.t ) {
-			case TArray(t, _):
-				var stride = varSize(t, VFloat) >> 2;
-				allocConst(stride, e.p); // pre-alloc
-			default:
-			}
-		case TConst(c):
-			switch( c ) {
-			case CFloat(v):
-				return allocConst(v, e.p);
-			case CInt(v):
-				return allocConst(v, e.p);
-			default:
-				return e;
-			}
-		case TGlobal(g):
-			switch( g ) {
-			case Pack:
-				allocConsts([1, 255, 255 * 255, 255 * 255 * 255], e.p);
-				allocConsts([1/255, 1/255, 1/255, 0], e.p);
-			case Unpack:
-				allocConsts([1, 1 / 255, 1 / (255 * 255), 1 / (255 * 255 * 255)], e.p);
-			case Radians:
-				allocConst(Math.PI / 180, e.p);
-			case Degrees:
-				allocConst(180 / Math.PI, e.p);
-			case Log:
-				allocConst(0.6931471805599453, e.p);
-			case Exp:
-				allocConst(1.4426950408889634, e.p);
-			case Mix:
-				allocConst(1, e.p);
-			case UnpackNormal:
-				allocConst(0.5, e.p);
-			case PackNormal:
-				allocConst(1, e.p);
-				allocConst(0.5, e.p);
-			case ScreenToUv:
-				allocConsts([0.5,0.5], e.p);
-				allocConsts([0.5,-0.5], e.p);
-			case UvToScreen:
-				allocConsts([2,-2], e.p);
-				allocConsts([-1,1], e.p);
-			case Smoothstep:
-				allocConst(2.0, e.p);
-				allocConst(3.0, e.p);
-			default:
-			}
-		case TCall( { e : TGlobal(Vec4) }, [ { e : TVar( { kind : Global | Param | Input | Var } ), t : TVec(3, VFloat) }, { e : TConst(CInt(1)) } ]):
-			// allow var expansion without relying on a constant
-			return e;
-		default:
-		}
-		return e.map(mapConsts);
-	}
-
-	function allocConst( v : Float, p ) : TExpr {
-		var index = consts.indexOf(v);
-		if( index < 0 ) {
-			index = consts.length;
-			consts.push(v);
-		}
-		return { e : TArray(econsts, { e : TConst(CInt(index)), t : TInt, p : p } ), t : TFloat, p : p };
-	}
-
-	function allocConsts( va : Array<Float>, p ) : TExpr {
-		var pad = (va.length - 1) & 3;
-		var index = -1;
-		for( i in 0...consts.length - (va.length - 1) ) {
-			if( (i >> 2) != (i + pad) >> 2 ) continue;
-			var found = true;
-			for( j in 0...va.length )
-				if( consts[i + j] != va[j] ) {
-					found = false;
-					break;
-				}
-			if( found ) {
-				index = i;
-				break;
-			}
-		}
-		if( index < 0 ) {
-			// pad
-			while( consts.length >> 2 != (consts.length + pad) >> 2 )
-				consts.push(0);
-			index = consts.length;
-			for( v in va )
-				consts.push(v);
-		}
-		inline function get(i) : TExpr {
-			return { e : TArray(econsts, { e : TConst(CInt(index+i)), t : TInt, p : p } ), t : TFloat, p : p };
-		}
-		switch( va.length ) {
-		case 1:
-			return get(0);
-		case 2:
-			return { e : TCall( { e : TGlobal(Vec2), t : TVoid, p : p }, [get(0), get(1)]), t : TVec(2, VFloat), p : p };
-		case 3:
-			return { e : TCall( { e : TGlobal(Vec3), t : TVoid, p : p }, [get(0), get(1), get(2)]), t : TVec(3, VFloat), p : p };
-		case 4:
-			return { e : TCall( { e : TGlobal(Vec4), t : TVoid, p : p }, [get(0), get(1), get(3), get(4)]), t : TVec(4, VFloat), p : p };
-		default:
-			throw "assert";
-		}
-	}
-
 	inline function mkInt(v:Int,pos) {
 		return { e : TConst(CInt(v)), t : TInt, p : pos };
 	}
@@ -374,22 +261,24 @@ class Flatten {
 		return alloc;
 	}
 
-	function packBuffers( vars : Array<TVar> ) {
+	function packBuffers( name : String, vars : Array<TVar>, kind ) {
 		var alloc = new Array<Alloc>();
 		var g : TVar = {
 			id : Tools.allocVarId(),
-			name : "buffers",
+			name : name,
 			type : TVoid,
 			kind : Param,
 		};
 		for( v in vars )
-			if( v.type.match(TBuffer(_)) ) {
+			switch( v.type ) {
+			case TBuffer(_,_,k) if( kind == k ):
 				var a = new Alloc(g, null, alloc.length, 1);
 				a.v = v;
 				alloc.push(a);
 				outVars.push(v);
+			default:
 			}
-		g.type = TArray(TBuffer(TVoid,SConst(0)),SConst(alloc.length));
+		g.type = TArray(TBuffer(TVoid,SConst(0),kind),SConst(alloc.length));
 		allocData.set(g, alloc);
 	}
 

+ 3 - 2
hxsl/GlslOut.hx

@@ -181,12 +181,13 @@ class GlslOut {
 			case SConst(n): add(n);
 			}
 			add("]");
-		case TBuffer(t, size):
+		case TBuffer(t, size, kind):
+			if( kind != Uniform ) throw "TODO";
 			add((isVertex ? "vertex_" : "") + "uniform_buffer"+(uniformBuffer++));
 			add(" { ");
 			v.type = TArray(t,size);
 			addVar(v);
-			v.type = TBuffer(t,size);
+			v.type = TBuffer(t,size, kind);
 			add("; }");
 		default:
 			addType(v.type);

+ 53 - 21
hxsl/HlslOut.hx

@@ -101,13 +101,19 @@ class HlslOut {
 	var exprValues : Array<String>;
 	var locals : Map<Int,TVar>;
 	var decls : Array<String>;
-	var isVertex : Bool;
+	var kind : FunctionKind;
 	var allNames : Map<String, Int>;
 	var samplers : Map<Int, Array<Int>>;
+	var computeLayout = [1,1,1];
 	public var varNames : Map<Int,String>;
 	public var baseRegister : Int = 0;
 
 	var varAccess : Map<Int,String>;
+	var isVertex(get,never) : Bool;
+	var isCompute(get,never) : Bool;
+
+	inline function get_isCompute() return kind == Main;
+	inline function get_isVertex() return kind == Vertex;
 
 	public function new() {
 		varNames = new Map();
@@ -175,7 +181,7 @@ class HlslOut {
 			add(" }");
 		case TFun(_):
 			add("function");
-		case TArray(t, size), TBuffer(t,size):
+		case TArray(t, size), TBuffer(t,size,_):
 			addType(t);
 			add("[");
 			switch( size ) {
@@ -201,7 +207,7 @@ class HlslOut {
 
 	function addVar( v : TVar ) {
 		switch( v.type ) {
-		case TArray(t, size), TBuffer(t,size):
+		case TArray(t, size), TBuffer(t,size,_):
 			addVar({
 				id : v.id,
 				name : v.name,
@@ -298,6 +304,8 @@ class HlslOut {
 			var acc = varAccess.get(v.id);
 			if( acc != null ) add(acc);
 			ident(v);
+		case TCall({ e : TGlobal(SetLayout) },_):
+			// ignore
 		case TCall({ e : TGlobal(g = (Texture | TextureLod)) }, args):
 			addValue(args[0], tabs);
 			switch( g ) {
@@ -687,6 +695,8 @@ class HlslOut {
 	function collectGlobals( m : Map<TGlobal,Bool>, e : TExpr ) {
 		switch( e.e )  {
 		case TGlobal(g): m.set(g,true);
+		case TCall({ e : TGlobal(SetLayout) }, [{ e : TConst(CInt(x)) }, { e : TConst(CInt(y)) }, { e : TConst(CInt(z)) }]):
+			computeLayout = [x,y,z];
 		default: e.iter(collectGlobals.bind(m));
 		}
 	}
@@ -709,7 +719,7 @@ class HlslOut {
 			collectGlobals(foundGlobals, f.expr);
 
 		add("struct s_input {\n");
-		if( !isVertex )
+		if( kind == Fragment )
 			add("\tfloat4 __pos__ : "+SV_POSITION+";\n");
 		for( v in s.vars )
 			if( v.kind == Input || (v.kind == Var && !isVertex) )
@@ -722,14 +732,16 @@ class HlslOut {
 			add("\tbool isFrontFace : "+SV_IsFrontFace+";\n");
 		add("};\n\n");
 
-		add("struct s_output {\n");
-		for( v in s.vars )
-			if( v.kind == Output )
-				declVar("_out.", v);
-		for( v in s.vars )
-			if( v.kind == Var && isVertex )
-				declVar("_out.", v);
-		add("};\n\n");
+		if( !isCompute ) {
+			add("struct s_output {\n");
+			for( v in s.vars )
+				if( v.kind == Output )
+					declVar("_out.", v);
+			for( v in s.vars )
+				if( v.kind == Var && isVertex )
+					declVar("_out.", v);
+			add("};\n\n");
+		}
 	}
 
 	function initGlobals( s : ShaderData ) {
@@ -770,10 +782,25 @@ class HlslOut {
 
 		var bufCount = 0;
 		for( b in buffers ) {
-			add('cbuffer _buffer$bufCount : register(b${bufCount+baseRegister+2}) { ');
-			addVar(b);
-			add("; };\n");
-			bufCount++;
+			switch( b.type ) {
+			case TBuffer(t, size, kind):
+				switch( kind ) {
+				case Uniform:
+					add('cbuffer _buffer$bufCount : register(b${bufCount+baseRegister+2}) { ');
+					addVar(b);
+					add("; };\n");
+					bufCount++;
+				case RW:
+					add('RWStructuredBuffer<');
+					addType(t);
+					add('> ');
+					ident(b);
+					add(' : register(u${bufCount+baseRegister+2});');
+					bufCount++;
+				}
+			default:
+				throw "assert";
+			}
 		}
 		if( bufCount > 0 ) add("\n");
 
@@ -795,7 +822,8 @@ class HlslOut {
 
 	function initStatics( s : ShaderData ) {
 		add(STATIC + "s_input _in;\n");
-		add(STATIC + "s_output _out;\n");
+		if( !isCompute )
+			add(STATIC + "s_output _out;\n");
 
 		add("\n");
 		for( v in s.vars )
@@ -808,7 +836,11 @@ class HlslOut {
 	}
 
 	function emitMain( expr : TExpr ) {
-		add("s_output main( s_input __in ) {\n");
+		if( isCompute )
+			add('[numthreads(${computeLayout[0]},${computeLayout[1]},${computeLayout[2]})] void ');
+		else
+			add('s_output ');
+		add("main( s_input __in ) {\n");
 		add("\t_in = __in;\n");
 		switch( expr.e ) {
 		case TBlock(el):
@@ -820,7 +852,8 @@ class HlslOut {
 		default:
 			addExpr(expr, "");
 		}
-		add("\treturn _out;\n");
+		if( !isCompute )
+			add("\treturn _out;\n");
 		add("}");
 	}
 
@@ -848,8 +881,7 @@ class HlslOut {
 
 		if( s.funs.length != 1 ) throw "assert";
 		var f = s.funs[0];
-		isVertex = f.kind == Vertex;
-
+		kind = f.kind;
 		varAccess = new Map();
 		samplers = new Map();
 		initVars(s);

+ 19 - 7
hxsl/Linker.hx

@@ -1,6 +1,12 @@
 package hxsl;
 using hxsl.Ast;
 
+enum LinkMode {
+	Default;
+	Batch;
+	Compute;
+}
+
 private class AllocatedVar {
 	public var id : Int;
 	public var v : TVar;
@@ -28,6 +34,7 @@ private class ShaderInfos {
 	public var vertex : Null<Bool>;
 	public var onStack : Bool;
 	public var hasDiscard : Bool;
+	public var isCompute : Bool;
 	public var marked : Null<Bool>;
 	public function new(n, v) {
 		this.name = n;
@@ -49,12 +56,12 @@ class Linker {
 	var varIdMap : Map<Int,Int>;
 	var locals : Map<Int,Bool>;
 	var curInstance : Int;
-	var batchMode : Bool;
+	var mode : LinkMode;
 	var isBatchShader : Bool;
 	var debugDepth = 0;
 
-	public function new(batchMode=false) {
-		this.batchMode = batchMode;
+	public function new(mode) {
+		this.mode = mode;
 	}
 
 	inline function debug( msg : String, ?pos : haxe.PosInfos ) {
@@ -348,7 +355,7 @@ class Linker {
 		curInstance = 0;
 		var outVars = [];
 		for( s in shadersData ) {
-			isBatchShader = batchMode && StringTools.startsWith(s.name,"batchShader_");
+			isBatchShader = mode == Batch && StringTools.startsWith(s.name,"batchShader_");
 			for( v in s.vars ) {
 				var v2 = allocVar(v, null, s.name);
 				if( isBatchShader && v2.v.kind == Param && !StringTools.startsWith(v2.path,"Batch_") )
@@ -375,8 +382,13 @@ class Linker {
 				if( v.kind == null ) throw "assert";
 				switch( v.kind ) {
 				case Vertex, Fragment:
+					if( mode == Compute )
+						throw "Unexpected "+v.kind.getName().toLowerCase()+"() function in compute shader";
 					addShader(s.name + "." + (v.kind == Vertex ? "vertex" : "fragment"), v.kind == Vertex, f.expr, priority);
-
+				case Main:
+					if( mode != Compute )
+						throw "Unexpected main() outside compute shader";
+					addShader(s.name, true, f.expr, priority).isCompute = true;
 				case Init:
 					var prio : Array<Int>;
 					var status : Null<Bool> = switch( f.ref.name ) {
@@ -408,7 +420,7 @@ class Linker {
 
 		// force shaders containing discard to be included
 		for( s in shaders )
-			if( s.hasDiscard ) {
+			if( s.hasDiscard || s.isCompute ) {
 				initDependencies(s);
 				entry.deps.set(s, true);
 			}
@@ -505,7 +517,7 @@ class Linker {
 				expr : expr,
 			};
 		}
-		var funs = [
+		var funs = mode == Compute ? [build(Main,"main",v)] : [
 			build(Vertex, "vertex", v),
 			build(Fragment, "fragment", f),
 		];

+ 7 - 2
hxsl/MacroParser.hx

@@ -115,7 +115,7 @@ class MacroParser {
 			case "Channel3": return TChannel(3);
 			case "Channel4": return TChannel(4);
 			}
-		case TPath( { pack : [], name : name = ("Array"|"Buffer"), sub : null, params : [t, size] } ):
+		case TPath( { pack : [], name : name = ("Array"|"Buffer"|"RWBuffer"), sub : null, params : [t, size] } ):
 			var t = switch( t ) {
 			case TPType(t): parseType(t, pos);
 			default: null;
@@ -129,7 +129,12 @@ class MacroParser {
 			default: null;
 			}
 			if( t != null && size != null )
-				return name == "Array" ? TArray(t, size) : TBuffer(t,size);
+				return switch( name ) {
+				case "Array": TArray(t, size);
+				case "Buffer": TBuffer(t,size,Uniform);
+				case "RWBuffer": TBuffer(t,size,RW);
+				default: throw "assert";
+				}
 		case TAnonymous(fl):
 			return TStruct([for( f in fl ) {
 				switch( f.kind ) {

+ 9 - 3
hxsl/RuntimeShader.hx

@@ -44,7 +44,7 @@ class AllocGlobal {
 }
 
 class RuntimeShaderData {
-	public var vertex : Bool;
+	public var kind : hxsl.Ast.FunctionKind;
 	public var data : Ast.ShaderData;
 	public var code : String;
 	public var params : AllocParam;
@@ -55,7 +55,6 @@ class RuntimeShaderData {
 	public var texturesCount : Int;
 	public var buffers : AllocParam;
 	public var bufferCount : Int;
-	public var consts : Array<Float>;
 	public function new() {
 	}
 }
@@ -76,14 +75,18 @@ class RuntimeShader {
 	public var id : Int;
 	public var vertex : RuntimeShaderData;
 	public var fragment : RuntimeShaderData;
+	public var compute(get,set) : RuntimeShaderData;
 	public var globals : Map<Int,Bool>;
 
+	inline function get_compute() return vertex;
+	inline function set_compute(v) return vertex = v;
+
 	/**
 		Signature of the resulting HxSL code.
 		Several shaders with the different specification might still get the same resulting signature.
 	**/
 	public var signature : String;
-	public var batchMode : Bool;
+	public var mode : hxsl.Linker.LinkMode;
 	public var spec : { instances : Array<ShaderInstanceDesc>, signature : String };
 
 	public function new() {
@@ -94,5 +97,8 @@ class RuntimeShader {
 		return globals.exists(gid);
 	}
 
+	public function getShaders() {
+		return mode == Compute ? [compute] : [vertex, fragment];
+	}
 
 }

+ 15 - 2
hxsl/Serializer.hx

@@ -86,12 +86,19 @@ class Serializer {
 				writeArr(vl,writeVar);
 		case TFun(variants):
 			// not serialized
-		case TArray(t, size), TBuffer(t, size):
+		case TArray(t, size), TBuffer(t, size, Uniform):
 			writeType(t);
 			switch (size) {
 			case SConst(v): out.addByte(0); writeVarInt(v);
 			case SVar(v): writeVar(v);
 			}
+		case TBuffer(t, size, kind):
+			out.addByte(kind.getIndex() + 0x80);
+			writeType(t);
+			switch (size) {
+				case SConst(v): out.addByte(0); writeVarInt(v);
+				case SVar(v): writeVar(v);
+				}
 		case TChannel(size):
 			out.addByte(size);
 		case TVoid, TInt, TBool, TFloat, TString, TMat2, TMat3, TMat4, TMat3x4, TSampler2D, TSampler2DArray, TSamplerCube:
@@ -136,9 +143,15 @@ class Serializer {
 			var v = readVar();
 			TArray(t, v == null ? SConst(readVarInt()) : SVar(v));
 		case 16:
+			var tag = input.readByte();
+			var kind = Uniform;
+			if( tag & 0x80 == 0 )
+				input.position--;
+			else
+				kind = BufferKind.createByIndex(tag & 0x7F);
 			var t = readType();
 			var v = readVar();
-			TBuffer(t, v == null ? SConst(readVarInt()) : SVar(v));
+			TBuffer(t, v == null ? SConst(readVarInt()) : SVar(v), kind);
 		case 17:
 			TChannel(input.readByte());
 		case 18: TMat2;

+ 31 - 21
hxsl/Splitter.hx

@@ -24,17 +24,19 @@ class Splitter {
 	public function new() {
 	}
 
-	public function split( s : ShaderData ) : { vertex : ShaderData, fragment : ShaderData } {
+	public function split( s : ShaderData ) : Array<ShaderData> {
 		var vfun = null, vvars = new Map();
 		var ffun = null, fvars = new Map();
+		var isCompute = false;
 		varNames = new Map();
 		varMap = new Map();
 		for( f in s.funs )
 			switch( f.kind ) {
-			case Vertex:
+			case Vertex, Main:
 				vars = vvars;
 				vfun = f;
 				checkExpr(f.expr);
+				if( f.kind == Main ) isCompute = true;
 			case Fragment:
 				vars = fvars;
 				ffun = f;
@@ -144,20 +146,22 @@ class Splitter {
 		for( v in fvars )
 			checkVar(v, false, vvars, ffun.expr.p);
 
-		ffun = {
-			ret : ffun.ret,
-			ref : ffun.ref,
-			kind : ffun.kind,
-			args : ffun.args,
-			expr : mapVars(ffun.expr),
-		};
-		switch( ffun.expr.e ) {
-		case TBlock(el):
-			for( e in finits )
-				el.unshift(e);
-		default:
-			finits.push(ffun.expr);
-			ffun.expr = { e : TBlock(finits), t : TVoid, p : ffun.expr.p };
+		if( ffun != null ) {
+			ffun = {
+				ret : ffun.ret,
+				ref : ffun.ref,
+				kind : ffun.kind,
+				args : ffun.args,
+				expr : mapVars(ffun.expr),
+			};
+			switch( ffun.expr.e ) {
+			case TBlock(el):
+				for( e in finits )
+					el.unshift(e);
+			default:
+				finits.push(ffun.expr);
+				ffun.expr = { e : TBlock(finits), t : TVoid, p : ffun.expr.p };
+			}
 		}
 
 		var vvars = [for( v in vvars ) if( !v.local ) v];
@@ -167,18 +171,24 @@ class Splitter {
 		vvars.sort(function(v1, v2) return getId(v1) - getId(v2));
 		fvars.sort(function(v1, v2) return getId(v1) - getId(v2));
 
-		return {
-			vertex : {
+		return isCompute ? [
+			{
+				name : "main",
+				vars : [for( v in vvars ) v.v],
+				funs : [vfun],
+			}
+		] : [
+			{
 				name : "vertex",
 				vars : [for( v in vvars ) v.v],
 				funs : [vfun],
 			},
-			fragment : {
+			{
 				name : "fragment",
 				vars : [for( v in fvars ) v.v],
 				funs : [ffun],
-			},
-		};
+			}
+		];
 	}
 
 	function addExpr( f : TFunction, e : TExpr ) {