Selaa lähdekoodia

working compute shaders in dx12

Nicolas Cannasse 1 vuosi sitten
vanhempi
commit
24b2d20ccf

+ 4 - 0
h3d/Buffer.hx

@@ -13,6 +13,10 @@ enum BufferFlag {
 		Used for shader input buffer
 		Used for shader input buffer
 	**/
 	**/
 	UniformBuffer;
 	UniformBuffer;
+	/**
+		Can be written to by shaders (read-write buffer)
+	**/
+	ReadWriteBuffer;
 }
 }
 
 
 @:allow(h3d.impl.MemoryManager)
 @:allow(h3d.impl.MemoryManager)

+ 127 - 48
h3d/impl/DX12Driver.hx

@@ -97,6 +97,7 @@ class ShaderRegisters {
 	public var samplers : Int;
 	public var samplers : Int;
 	public var texturesCount : Int;
 	public var texturesCount : Int;
 	public var textures2DCount : Int;
 	public var textures2DCount : Int;
+	public var bufferTypes : Array<hxsl.Ast.BufferKind>;
 	public function new() {
 	public function new() {
 	}
 	}
 }
 }
@@ -111,6 +112,8 @@ class CompiledShader {
 	public var inputLayout : hl.CArray<InputElementDesc>;
 	public var inputLayout : hl.CArray<InputElementDesc>;
 	public var inputCount : Int;
 	public var inputCount : Int;
 	public var shader : hxsl.RuntimeShader;
 	public var shader : hxsl.RuntimeShader;
+	public var isCompute : Bool;
+	public var computePipeline : ComputePipelineState;
 	public function new() {
 	public function new() {
 	}
 	}
 }
 }
@@ -133,6 +136,7 @@ class CompiledShader {
 	@:packed public var bufferSRV(default,null) : BufferSRV;
 	@:packed public var bufferSRV(default,null) : BufferSRV;
 	@:packed public var samplerDesc(default,null) : SamplerDesc;
 	@:packed public var samplerDesc(default,null) : SamplerDesc;
 	@:packed public var cbvDesc(default,null) : ConstantBufferViewDesc;
 	@:packed public var cbvDesc(default,null) : ConstantBufferViewDesc;
+	@:packed public var uavDesc(default,null) : UAVBufferViewDesc;
 	@:packed public var rtvDesc(default,null) : RenderTargetViewDesc;
 	@:packed public var rtvDesc(default,null) : RenderTargetViewDesc;
 
 
 	public var pass : h3d.mat.Pass;
 	public var pass : h3d.mat.Pass;
@@ -156,6 +160,7 @@ class CompiledShader {
 		samplerDesc.comparisonFunc = NEVER;
 		samplerDesc.comparisonFunc = NEVER;
 		samplerDesc.maxLod = 1e30;
 		samplerDesc.maxLod = 1e30;
 		descriptors2 = new hl.NativeArray(2);
 		descriptors2 = new hl.NativeArray(2);
+		uavDesc.viewDimension = BUFFER;
 		barrier.subResource = -1; // all
 		barrier.subResource = -1; // all
 	}
 	}
 
 
@@ -341,7 +346,7 @@ class DX12Driver extends h3d.impl.Driver {
 	public static var INITIAL_RT_COUNT = 1024;
 	public static var INITIAL_RT_COUNT = 1024;
 	public static var BUFFER_COUNT = 2;
 	public static var BUFFER_COUNT = 2;
 	public static var DEVICE_NAME = null;
 	public static var DEVICE_NAME = null;
-	public static var DEBUG = false;
+	public static var DEBUG = false; // requires dxil.dll when set to true
 
 
 	public function new() {
 	public function new() {
 		window = @:privateAccess dx.Window.windows[0];
 		window = @:privateAccess dx.Window.windows[0];
@@ -875,7 +880,7 @@ class DX12Driver extends h3d.impl.Driver {
 
 
 	static var VERTEX_FORMATS = [null,null,R32G32_FLOAT,R32G32B32_FLOAT,R32G32B32A32_FLOAT];
 	static var VERTEX_FORMATS = [null,null,R32G32_FLOAT,R32G32B32_FLOAT,R32G32B32A32_FLOAT];
 
 
-	function getBinaryPayload( vertex : Bool, code : String ) {
+	function getBinaryPayload( code : String ) {
 		var bin = code.indexOf("//BIN=");
 		var bin = code.indexOf("//BIN=");
 		if( bin >= 0 ) {
 		if( bin >= 0 ) {
 			var end = code.indexOf("#", bin);
 			var end = code.indexOf("#", bin);
@@ -895,7 +900,7 @@ class DX12Driver extends h3d.impl.Driver {
 			sh.code = out.run(sh.data);
 			sh.code = out.run(sh.data);
 			sh.code = rootStr + sh.code;
 			sh.code = rootStr + sh.code;
 		}
 		}
-		var bytes = getBinaryPayload(sh.vertex, sh.code);
+		var bytes = getBinaryPayload(sh.code);
 		if ( bytes == null ) {
 		if ( bytes == null ) {
 			return compiler.compile(sh.code, profile, args);
 			return compiler.compile(sh.code, profile, args);
 		}
 		}
@@ -905,6 +910,8 @@ class DX12Driver extends h3d.impl.Driver {
 	override function getNativeShaderCode( shader : hxsl.RuntimeShader ) {
 	override function getNativeShaderCode( shader : hxsl.RuntimeShader ) {
 		var out = new hxsl.HlslOut();
 		var out = new hxsl.HlslOut();
 		var vsSource = out.run(shader.vertex.data);
 		var vsSource = out.run(shader.vertex.data);
+		if( shader.mode == Compute )
+			return vsSource;
 		var out = new hxsl.HlslOut();
 		var out = new hxsl.HlslOut();
 		var psSource = out.run(shader.fragment.data);
 		var psSource = out.run(shader.fragment.data);
 		return vsSource+"\n\n\n\n"+psSource;
 		return vsSource+"\n\n\n\n"+psSource;
@@ -985,14 +992,14 @@ class DX12Driver extends h3d.impl.Driver {
 			return range;
 			return range;
 		}
 		}
 
 
-		function allocConsts(size,vis,useCBV) {
+		function allocConsts(size,vis,type) {
 			var reg = regCount++;
 			var reg = regCount++;
 			if( size == 0 ) return -1;
 			if( size == 0 ) return -1;
 
 
-			if( useCBV ) {
+			if( type != null ) {
 				var pid = paramsCount;
 				var pid = paramsCount;
 				var r = allocDescTable(vis);
 				var r = allocDescTable(vis);
-				r.rangeType = CBV;
+				r.rangeType = type;
 				r.numDescriptors = 1;
 				r.numDescriptors = 1;
 				r.baseShaderRegister = reg;
 				r.baseShaderRegister = reg;
 				r.registerSpace = 0;
 				r.registerSpace = 0;
@@ -1010,14 +1017,30 @@ class DX12Driver extends h3d.impl.Driver {
 
 
 
 
 		function allocParams( sh : hxsl.RuntimeShader.RuntimeShaderData ) {
 		function allocParams( sh : hxsl.RuntimeShader.RuntimeShaderData ) {
-			var vis = sh.vertex ? VERTEX : PIXEL;
+			var vis = switch( sh.kind ) {
+			case Vertex: VERTEX;
+			case Fragment: PIXEL;
+			default: ALL;
+			}
 			var regs = new ShaderRegisters();
 			var regs = new ShaderRegisters();
-			regs.globals = allocConsts(sh.globalsSize, vis, false);
-			regs.params = allocConsts(sh.paramsSize, vis, sh.vertex ? vertexParamsCBV : fragmentParamsCBV);
+			regs.globals = allocConsts(sh.globalsSize, vis, null);
+			regs.params = allocConsts(sh.paramsSize, vis, (sh.kind == Fragment ? fragmentParamsCBV : vertexParamsCBV) ? CBV : null);
 			if( sh.bufferCount > 0 ) {
 			if( sh.bufferCount > 0 ) {
 				regs.buffers = paramsCount;
 				regs.buffers = paramsCount;
-				for( i in 0...sh.bufferCount )
-					allocConsts(1, vis, true);
+				regs.bufferTypes = [];
+				var p = sh.buffers;
+				while( p != null ) {
+					var kind = switch( p.type ) {
+					case TBuffer(_,_,kind): kind;
+					default: throw "assert";
+					}
+					regs.bufferTypes.push(kind);
+					allocConsts(1, vis, switch( kind ) {
+					case Uniform: CBV;
+					case RW: UAV;
+					});
+					p = p.next;
+				}
 			}
 			}
 			if( sh.texturesCount > 0 ) {
 			if( sh.texturesCount > 0 ) {
 				regs.texturesCount = sh.texturesCount;
 				regs.texturesCount = sh.texturesCount;
@@ -1061,7 +1084,7 @@ class DX12Driver extends h3d.impl.Driver {
 		}
 		}
 
 
 		var totalVertex = calcSize(shader.vertex);
 		var totalVertex = calcSize(shader.vertex);
-		var totalFragment = calcSize(shader.fragment);
+		var totalFragment = shader.mode == Compute ? 0 : calcSize(shader.fragment);
 		var total = totalVertex + totalFragment;
 		var total = totalVertex + totalFragment;
 
 
 		if( total > 64 ) {
 		if( total > 64 ) {
@@ -1083,22 +1106,25 @@ class DX12Driver extends h3d.impl.Driver {
 				throw "Too many globals";
 				throw "Too many globals";
 		}
 		}
 
 
-		var vertexRegisters = allocParams(shader.vertex);
-		var fragmentRegStart = regCount;
-		var fragmentRegisters = allocParams(shader.fragment);
-
+		var regs = [];
+		for( s in shader.getShaders() )
+			regs.push({ start : regCount, registers : allocParams(s) });
 		if( paramsCount > allocatedParams )
 		if( paramsCount > allocatedParams )
 			throw "ASSERT : Too many parameters";
 			throw "ASSERT : Too many parameters";
 
 
 		var sign = new RootSignatureDesc();
 		var sign = new RootSignatureDesc();
-		sign.flags.set(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT);
+		if( shader.mode == Compute ) {
+			sign.flags.set(DENY_PIXEL_SHADER_ROOT_ACCESS);
+			sign.flags.set(DENY_VERTEX_SHADER_ROOT_ACCESS);
+		} else
+			sign.flags.set(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT);
 		sign.flags.set(DENY_HULL_SHADER_ROOT_ACCESS);
 		sign.flags.set(DENY_HULL_SHADER_ROOT_ACCESS);
 		sign.flags.set(DENY_DOMAIN_SHADER_ROOT_ACCESS);
 		sign.flags.set(DENY_DOMAIN_SHADER_ROOT_ACCESS);
 		sign.flags.set(DENY_GEOMETRY_SHADER_ROOT_ACCESS);
 		sign.flags.set(DENY_GEOMETRY_SHADER_ROOT_ACCESS);
 		sign.numParameters = paramsCount;
 		sign.numParameters = paramsCount;
 		sign.parameters = cast params;
 		sign.parameters = cast params;
 
 
-		return { sign : sign, fragmentRegStart : fragmentRegStart, vertexRegisters : vertexRegisters, fragmentRegisters : fragmentRegisters, params : params, paramsCount : paramsCount, texDescs : texDescs };
+		return { sign : sign, registers : regs, params : params, paramsCount : paramsCount, texDescs : texDescs };
 	}
 	}
 
 
 	function compileShader( shader : hxsl.RuntimeShader ) : CompiledShader {
 	function compileShader( shader : hxsl.RuntimeShader ) : CompiledShader {
@@ -1106,16 +1132,31 @@ class DX12Driver extends h3d.impl.Driver {
 		var res = computeRootSignature(shader);
 		var res = computeRootSignature(shader);
 
 
 		var c = new CompiledShader();
 		var c = new CompiledShader();
-		c.vertexRegisters = res.vertexRegisters;
-		c.fragmentRegisters = res.fragmentRegisters;
 
 
 		var rootStr = stringifyRootSignature(res.sign, "ROOT_SIGNATURE", res.params, res.paramsCount);
 		var rootStr = stringifyRootSignature(res.sign, "ROOT_SIGNATURE", res.params, res.paramsCount);
-		var vs = compileSource(shader.vertex, "vs_6_0", 0, rootStr);
-		var ps = compileSource(shader.fragment, "ps_6_0", res.fragmentRegStart, rootStr);
+		var vs = shader.mode == Compute ? null : compileSource(shader.vertex, "vs_6_0", 0, rootStr);
+		var ps = shader.mode == Compute ? null : compileSource(shader.fragment, "ps_6_0", res.registers[1].start, rootStr);
+		var cs = shader.mode == Compute ? compileSource(shader.compute, "cs_6_0", 0, rootStr) : null;
 
 
 		var signSize = 0;
 		var signSize = 0;
 		var signBytes = Driver.serializeRootSignature(res.sign, 1, signSize);
 		var signBytes = Driver.serializeRootSignature(res.sign, 1, signSize);
 		var sign = new RootSignature(signBytes,signSize);
 		var sign = new RootSignature(signBytes,signSize);
+		c.rootSignature = sign;
+		c.shader = shader;
+
+		if( shader.mode == Compute ) {
+			c.isCompute = true;
+			var desc = new ComputePipelineStateDesc();
+			desc.rootSignature = sign;
+			desc.cs.shaderBytecode = cs;
+			desc.cs.bytecodeLength = cs.length;
+			c.computePipeline = Driver.createComputePipelineState(desc);
+			c.vertexRegisters = res.registers[0].registers;
+			return c;
+		}
+
+		c.vertexRegisters = res.registers[0].registers;
+		c.fragmentRegisters = res.registers[1].registers;
 
 
 		var inputs = [];
 		var inputs = [];
 		for( v in shader.vertex.data.vars )
 		for( v in shader.vertex.data.vars )
@@ -1166,10 +1207,8 @@ class DX12Driver extends h3d.impl.Driver {
 
 
 		c.format = hxd.BufferFormat.make(format);
 		c.format = hxd.BufferFormat.make(format);
 		c.pipeline = p;
 		c.pipeline = p;
-		c.rootSignature = sign;
 		c.inputLayout = inputLayout;
 		c.inputLayout = inputLayout;
 		c.inputCount = inputs.length;
 		c.inputCount = inputs.length;
-		c.shader = shader;
 
 
 		for( i in 0...inputs.length )
 		for( i in 0...inputs.length )
 			inputLayout[i].alignedByteOffset = 1; // will trigger error if not set in makePipeline()
 			inputLayout[i].alignedByteOffset = 1; // will trigger error if not set in makePipeline()
@@ -1184,7 +1223,7 @@ class DX12Driver extends h3d.impl.Driver {
 
 
 	// ----- BUFFERS
 	// ----- BUFFERS
 
 
-	function allocGPU( size : Int, heapType, state ) {
+	function allocGPU( size : Int, heapType, state, uav=false ) {
 		var desc = new ResourceDesc();
 		var desc = new ResourceDesc();
 		var flags = new haxe.EnumFlags();
 		var flags = new haxe.EnumFlags();
 		desc.dimension = BUFFER;
 		desc.dimension = BUFFER;
@@ -1194,6 +1233,7 @@ class DX12Driver extends h3d.impl.Driver {
 		desc.mipLevels = 1;
 		desc.mipLevels = 1;
 		desc.sampleDesc.count = 1;
 		desc.sampleDesc.count = 1;
 		desc.layout = ROW_MAJOR;
 		desc.layout = ROW_MAJOR;
+		if( uav ) desc.flags.set(ALLOW_UNORDERED_ACCESS);
 		tmp.heap.type = heapType;
 		tmp.heap.type = heapType;
 		return Driver.createCommittedResource(tmp.heap, flags, desc, state, null);
 		return Driver.createCommittedResource(tmp.heap, flags, desc, state, null);
 	}
 	}
@@ -1201,9 +1241,9 @@ class DX12Driver extends h3d.impl.Driver {
 	override function allocBuffer( m : h3d.Buffer ) : GPUBuffer {
 	override function allocBuffer( m : h3d.Buffer ) : GPUBuffer {
 		var buf = new VertexBufferData();
 		var buf = new VertexBufferData();
 		var size = m.getMemSize();
 		var size = m.getMemSize();
-		var bufSize = m.flags.has(UniformBuffer) ? calcCBVSize(size) : size;
+		var bufSize = m.flags.has(UniformBuffer) || m.flags.has(ReadWriteBuffer) ? calcCBVSize(size) : size;
 		buf.state = COPY_DEST;
 		buf.state = COPY_DEST;
-		buf.res = allocGPU(bufSize, DEFAULT, COMMON);
+		buf.res = allocGPU(bufSize, DEFAULT, COMMON, m.flags.has(ReadWriteBuffer));
 		if( !m.flags.has(UniformBuffer) ) {
 		if( !m.flags.has(UniformBuffer) ) {
 			var view = new VertexBufferView();
 			var view = new VertexBufferView();
 			view.bufferLocation = buf.res.getGpuVirtualAddress();
 			view.bufferLocation = buf.res.getGpuVirtualAddress();
@@ -1488,7 +1528,8 @@ class DX12Driver extends h3d.impl.Driver {
 
 
 	override function uploadShaderBuffers(buffers:h3d.shader.Buffers, which:h3d.shader.Buffers.BufferKind) {
 	override function uploadShaderBuffers(buffers:h3d.shader.Buffers, which:h3d.shader.Buffers.BufferKind) {
 		uploadBuffers(buffers, buffers.vertex, which, currentShader.shader.vertex, currentShader.vertexRegisters);
 		uploadBuffers(buffers, buffers.vertex, which, currentShader.shader.vertex, currentShader.vertexRegisters);
-		uploadBuffers(buffers, buffers.fragment, which, currentShader.shader.fragment, currentShader.fragmentRegisters);
+		if( !currentShader.isCompute )
+			uploadBuffers(buffers, buffers.fragment, which, currentShader.shader.fragment, currentShader.fragmentRegisters);
 	}
 	}
 
 
 	function calcCBVSize( dataSize : Int ) {
 	function calcCBVSize( dataSize : Int ) {
@@ -1547,13 +1588,22 @@ class DX12Driver extends h3d.impl.Driver {
 					desc.bufferLocation = cbv.getGpuVirtualAddress();
 					desc.bufferLocation = cbv.getGpuVirtualAddress();
 					desc.sizeInBytes = calcCBVSize(dataSize);
 					desc.sizeInBytes = calcCBVSize(dataSize);
 					Driver.createConstantBufferView(desc, srv);
 					Driver.createConstantBufferView(desc, srv);
-					frame.commandList.setGraphicsRootDescriptorTable(regs.params & 0xFF, frame.shaderResourceViews.toGPU(srv));
-				} else
+					if( currentShader.isCompute )
+						frame.commandList.setComputeRootDescriptorTable(regs.params & 0xFF, frame.shaderResourceViews.toGPU(srv));
+					else
+						frame.commandList.setGraphicsRootDescriptorTable(regs.params & 0xFF, frame.shaderResourceViews.toGPU(srv));
+				} else if( currentShader.isCompute )
+					frame.commandList.setComputeRoot32BitConstants(regs.params, dataSize >> 2, data, 0);
+				else
 					frame.commandList.setGraphicsRoot32BitConstants(regs.params, dataSize >> 2, data, 0);
 					frame.commandList.setGraphicsRoot32BitConstants(regs.params, dataSize >> 2, data, 0);
 			}
 			}
 		case Globals:
 		case Globals:
-			if( shader.globalsSize > 0 )
-				frame.commandList.setGraphicsRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
+			if( shader.globalsSize > 0 ) {
+				if( currentShader.isCompute )
+					frame.commandList.setComputeRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
+				else
+					frame.commandList.setGraphicsRoot32BitConstants(regs.globals, shader.globalsSize << 2, hl.Bytes.getArray(buf.globals.toData()), 0);
+			}
 		case Textures:
 		case Textures:
 			if( regs.texturesCount > 0 ) {
 			if( regs.texturesCount > 0 ) {
 				var srv = frame.shaderResourceViews.alloc(regs.texturesCount);
 				var srv = frame.shaderResourceViews.alloc(regs.texturesCount);
@@ -1612,10 +1662,10 @@ class DX12Driver extends h3d.impl.Driver {
 					t.lastFrame = frameCount;
 					t.lastFrame = frameCount;
 					var state = if ( t.isDepth() )
 					var state = if ( t.isDepth() )
 						DEPTH_READ;
 						DEPTH_READ;
-					else if ( shader.vertex )
-						NON_PIXEL_SHADER_RESOURCE;
-					else
+					else if ( shader.kind == Fragment )
 						PIXEL_SHADER_RESOURCE;
 						PIXEL_SHADER_RESOURCE;
+					else
+						NON_PIXEL_SHADER_RESOURCE;
 					transition(t.t, state);
 					transition(t.t, state);
 					Driver.createShaderResourceView(t.t.res, tdesc, srv.offset(i * frame.shaderResourceViews.stride));
 					Driver.createShaderResourceView(t.t.res, tdesc, srv.offset(i * frame.shaderResourceViews.stride));
 
 
@@ -1634,8 +1684,13 @@ class DX12Driver extends h3d.impl.Driver {
 					Driver.createSampler(desc, sampler.offset(i * frame.samplerViews.stride));
 					Driver.createSampler(desc, sampler.offset(i * frame.samplerViews.stride));
 				}
 				}
 
 
-				frame.commandList.setGraphicsRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
-				frame.commandList.setGraphicsRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
+				if( currentShader.isCompute ) {
+					frame.commandList.setComputeRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
+					frame.commandList.setComputeRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
+				} else {
+					frame.commandList.setGraphicsRootDescriptorTable(regs.textures, frame.shaderResourceViews.toGPU(srv));
+					frame.commandList.setGraphicsRootDescriptorTable(regs.samplers, frame.samplerViews.toGPU(sampler));
+				}
 			}
 			}
 		case Buffers:
 		case Buffers:
 			if( shader.bufferCount > 0 ) {
 			if( shader.bufferCount > 0 ) {
@@ -1643,15 +1698,29 @@ class DX12Driver extends h3d.impl.Driver {
 					var srv = frame.shaderResourceViews.alloc(1);
 					var srv = frame.shaderResourceViews.alloc(1);
 					var b = buf.buffers[i];
 					var b = buf.buffers[i];
 					var cbv = b.vbuf;
 					var cbv = b.vbuf;
-					if( cbv.view != null )
-						throw "Buffer was allocated without UniformBuffer flag";
-					transition(cbv, VERTEX_AND_CONSTANT_BUFFER);
-					var desc = tmp.cbvDesc;
-					desc.bufferLocation = cbv.res.getGpuVirtualAddress();
-					desc.sizeInBytes = cbv.size;
-					Driver.createConstantBufferView(desc, srv);
-					frame.commandList.setGraphicsRootDescriptorTable(regs.buffers + i, frame.shaderResourceViews.toGPU(srv));
-				}
+					switch( regs.bufferTypes[i] ) {
+					case Uniform:
+						if( cbv.view != null )
+							throw "Buffer was allocated without UniformBuffer flag";
+						transition(cbv, VERTEX_AND_CONSTANT_BUFFER);
+						var desc = tmp.cbvDesc;
+						desc.bufferLocation = cbv.res.getGpuVirtualAddress();
+						desc.sizeInBytes = cbv.size;
+						Driver.createConstantBufferView(desc, srv);
+					case RW:
+						if( !b.flags.has(ReadWriteBuffer) )
+							throw "Buffer was allocated without ReadWriteBuffer flag";
+						transition(cbv, UNORDERED_ACCESS);
+						var desc = tmp.uavDesc;
+						desc.numElements = b.vertices;
+						desc.structureSizeInBytes = b.format.strideBytes;
+						Driver.createUnorderedAccessView(cbv.res, null, desc, srv);
+					}
+					if( currentShader.isCompute )
+						frame.commandList.setComputeRootDescriptorTable(regs.buffers + i, frame.shaderResourceViews.toGPU(srv));
+					else
+						frame.commandList.setGraphicsRootDescriptorTable(regs.buffers + i, frame.shaderResourceViews.toGPU(srv));
+			}
 			}
 			}
 		}
 		}
 	}
 	}
@@ -1665,8 +1734,14 @@ class DX12Driver extends h3d.impl.Driver {
 		if( currentShader == sh )
 		if( currentShader == sh )
 			return false;
 			return false;
 		currentShader = sh;
 		currentShader = sh;
-		needPipelineFlush = true;
-		frame.commandList.setGraphicsRootSignature(currentShader.rootSignature);
+		if( sh.isCompute ) {
+			needPipelineFlush = false;
+			frame.commandList.setComputeRootSignature(currentShader.rootSignature);
+			frame.commandList.setPipelineState(currentShader.computePipeline);
+		} else {
+			needPipelineFlush = true;
+			frame.commandList.setGraphicsRootSignature(currentShader.rootSignature);
+		}
 		return true;
 		return true;
 	}
 	}
 
 
@@ -2026,6 +2101,10 @@ class DX12Driver extends h3d.impl.Driver {
 		}
 		}
 	}
 	}
 
 
+	override function computeDispatch( x : Int = 1, y : Int = 1, z : Int = 1 ) { // DX12 implementation of Driver.computeDispatch
+		frame.commandList.dispatch(x,y,z); // recorded on frame.commandList; x/y/z presumably map to the D3D12 thread-group counts - confirm against the dx12 binding
+	}
+
 }
 }
 
 
 #end
 #end

+ 6 - 0
h3d/impl/Driver.hx

@@ -317,4 +317,10 @@ class Driver {
 		return 0.;
 		return 0.;
 	}
 	}
 
 
+	// --- COMPUTE
+
+	public function computeDispatch( x : Int = 1, y : Int = 1, z : Int = 1 ) { // compute dispatch entry point; each dimension defaults to 1
+		throw "Not implemented"; // base driver has no compute support; drivers that do (e.g. DX12Driver) override this method
+	}
+
 }
 }

+ 2 - 2
h3d/pass/Default.hx

@@ -43,7 +43,7 @@ class Default extends Base {
 		var o = @:privateAccess new h3d.pass.PassObject();
 		var o = @:privateAccess new h3d.pass.PassObject();
 		o.pass = p;
 		o.pass = p;
 		setupShaders(new h3d.pass.PassList(o));
 		setupShaders(new h3d.pass.PassList(o));
-		return manager.compileShaders(o.shaders, p.batchMode);
+		return manager.compileShaders(o.shaders, p.batchMode ? Batch : Default);
 	}
 	}
 
 
 	function processShaders( p : h3d.pass.PassObject, shaders : hxsl.ShaderList ) {
 	function processShaders( p : h3d.pass.PassObject, shaders : hxsl.ShaderList ) {
@@ -68,7 +68,7 @@ class Default extends Base {
 				}
 				}
 				shaders = ctx.lightSystem.computeLight(p.obj, shaders);
 				shaders = ctx.lightSystem.computeLight(p.obj, shaders);
 			}
 			}
-			p.shader = manager.compileShaders(shaders, p.pass.batchMode);
+			p.shader = manager.compileShaders(shaders, p.pass.batchMode ? Batch : Default);
 			p.shaders = shaders;
 			p.shaders = shaders;
 			var t = p.shader.fragment.textures;
 			var t = p.shader.fragment.textures;
 			if( t == null || t.type.match(TArray(_)) )
 			if( t == null || t.type.match(TArray(_)) )

+ 38 - 13
h3d/pass/ShaderManager.hx

@@ -7,6 +7,8 @@ class ShaderManager {
 	public var globals : hxsl.Globals;
 	public var globals : hxsl.Globals;
 	var shaderCache : hxsl.Cache;
 	var shaderCache : hxsl.Cache;
 	var currentOutput : hxsl.ShaderList;
 	var currentOutput : hxsl.ShaderList;
+	var currentCompute : hxsl.ShaderList;
+	var computeBuffers : h3d.shader.Buffers;
 
 
 	public function new(?output:Array<hxsl.Output>) {
 	public function new(?output:Array<hxsl.Output>) {
 		shaderCache = hxsl.Cache.get();
 		shaderCache = hxsl.Cache.get();
@@ -192,20 +194,14 @@ class ShaderManager {
 			var ptr = getPtr(buf.globals);
 			var ptr = getPtr(buf.globals);
 			while( g != null ) {
 			while( g != null ) {
 				var v = globals.fastGet(g.gid);
 				var v = globals.fastGet(g.gid);
-				if( v == null ) {
-					if( g.path == "__consts__" ) {
-						fillRec(s.consts, g.type, ptr, g.pos);
-						g = g.next;
-						continue;
-					}
+				if( v == null )
 					throw "Missing global value " + g.path;
 					throw "Missing global value " + g.path;
-				}
 				fillRec(v, g.type, ptr, g.pos);
 				fillRec(v, g.type, ptr, g.pos);
 				g = g.next;
 				g = g.next;
 			}
 			}
 		}
 		}
 		fill(buf.vertex, s.vertex);
 		fill(buf.vertex, s.vertex);
-		fill(buf.fragment, s.fragment);
+		if( s.fragment != null ) fill(buf.fragment, s.fragment);
 	}
 	}
 
 
 	public function fillParams( buf : h3d.shader.Buffers, s : hxsl.RuntimeShader, shaders : hxsl.ShaderList ) {
 	public function fillParams( buf : h3d.shader.Buffers, s : hxsl.RuntimeShader, shaders : hxsl.ShaderList ) {
@@ -273,16 +269,45 @@ class ShaderManager {
 			}
 			}
 		}
 		}
 		fill(buf.vertex, s.vertex);
 		fill(buf.vertex, s.vertex);
-		fill(buf.fragment, s.fragment);
+		if( s.fragment != null ) fill(buf.fragment, s.fragment);
 	}
 	}
 
 
-	public function compileShaders( shaders : hxsl.ShaderList, batchMode : Bool = false ) {
+	public function compileShaders( shaders : hxsl.ShaderList, mode : hxsl.Linker.LinkMode = Default ) {
 		globals.resetChannels();
 		globals.resetChannels();
 		for( s in shaders ) s.updateConstants(globals);
 		for( s in shaders ) s.updateConstants(globals);
-		currentOutput.next = shaders;
-		var s = shaderCache.link(currentOutput, batchMode);
-		currentOutput.next = null;
+		var s;
+		if( mode == Compute )
+			s = shaderCache.link(shaders, mode);
+		else {
+			currentOutput.next = shaders;
+			s = shaderCache.link(currentOutput, mode);
+			currentOutput.next = null;
+		}
 		return s;
 		return s;
 	}
 	}
 
 
+	public function computeDispatch( shader : hxsl.Shader, x = 1, y = 1, z = 1 ) { // links `shader` in Compute mode, uploads its data and dispatches it with size x,y,z
+		if( currentCompute == null )
+			currentCompute = new hxsl.ShaderList(null); // single-entry list, created once and reused by every call
+		shader.updateConstants(globals);
+		currentCompute.s = shader;
+		var rt = shaderCache.link(currentCompute, Compute); // only the compute shader is linked, currentOutput is not prepended here
+		var bufs = computeBuffers;
+		if( bufs == null ) {
+			bufs = new h3d.shader.Buffers(rt); // allocated on first use, then shared by all compute dispatches
+			computeBuffers = bufs;
+		}
+		bufs.grow(rt); // the shared buffers may have been sized for another shader
+		fillParams(bufs, rt, currentCompute);
+		fillGlobals(bufs, rt);
+		var e = h3d.Engine.getCurrent();
+		e.driver.selectShader(rt);
+		e.driver.uploadShaderBuffers(bufs, Params);
+		e.driver.uploadShaderBuffers(bufs, Globals);
+		e.driver.uploadShaderBuffers(bufs, Buffers);
+		e.driver.uploadShaderBuffers(bufs, Textures);
+		e.driver.computeDispatch(x,y,z); // throws "Not implemented" on drivers that do not override it
+		currentCompute.s = null; // NOTE(review): not reached if a call above throws; the slot is overwritten on the next call anyway
+	}
+
 }
 }

+ 1 - 1
h3d/scene/MeshBatch.hx

@@ -121,7 +121,7 @@ class MeshBatch extends MultiMaterial {
 
 
 				var manager = cast(ctx,h3d.pass.Default).manager;
 				var manager = cast(ctx,h3d.pass.Default).manager;
 				var shaders = p.getShadersRec();
 				var shaders = p.getShadersRec();
-				var rt = manager.compileShaders(shaders, false);
+				var rt = manager.compileShaders(shaders, Default);
 				var shader = manager.shaderCache.makeBatchShader(rt, shaders, instancedParams);
 				var shader = manager.shaderCache.makeBatchShader(rt, shaders, instancedParams);
 
 
 				var b = new BatchData();
 				var b = new BatchData();

+ 4 - 2
h3d/shader/Buffers.hx

@@ -43,12 +43,14 @@ class Buffers {
 
 
 	public function new( s : hxsl.RuntimeShader ) {
 	public function new( s : hxsl.RuntimeShader ) {
 		vertex = new ShaderBuffers(s.vertex);
 		vertex = new ShaderBuffers(s.vertex);
-		fragment = new ShaderBuffers(s.fragment);
+		if( s.fragment != null )
+			fragment = new ShaderBuffers(s.fragment);
 	}
 	}
 
 
 	public inline function grow( s : hxsl.RuntimeShader ) {
 	public inline function grow( s : hxsl.RuntimeShader ) {
 		vertex.grow(s.vertex);
 		vertex.grow(s.vertex);
-		fragment.grow(s.fragment);
+		if( s.fragment != null )
+			fragment.grow(s.fragment);
 	}
 	}
 }
 }
 
 

+ 18 - 3
hxsl/Ast.hx

@@ -1,5 +1,10 @@
 package hxsl;
 package hxsl;
 
 
+enum BufferKind { // third argument of Type.TBuffer: how a shader buffer is accessed
+	Uniform; // printed as "buffer" by Tools.toString; the DX12 driver binds it through a CBV descriptor range
+	RW; // printed as "rwbuffer" by Tools.toString; the DX12 driver binds it through a UAV descriptor range
+}
+
 enum Type {
 enum Type {
 	TVoid;
 	TVoid;
 	TInt;
 	TInt;
@@ -17,7 +22,7 @@ enum Type {
 	TStruct( vl : Array<TVar> );
 	TStruct( vl : Array<TVar> );
 	TFun( variants : Array<FunType> );
 	TFun( variants : Array<FunType> );
 	TArray( t : Type, size : SizeDecl );
 	TArray( t : Type, size : SizeDecl );
-	TBuffer( t : Type, size : SizeDecl );
+	TBuffer( t : Type, size : SizeDecl, kind : BufferKind );
 	TChannel( size : Int );
 	TChannel( size : Int );
 	TMat2;
 	TMat2;
 }
 }
@@ -187,6 +192,7 @@ enum FunctionKind {
 	Fragment;
 	Fragment;
 	Init;
 	Init;
 	Helper;
 	Helper;
+	Main;
 }
 }
 
 
 enum TGlobal {
 enum TGlobal {
@@ -280,6 +286,8 @@ enum TGlobal {
 	IntBitsToFloat;
 	IntBitsToFloat;
 	UintBitsToFloat;
 	UintBitsToFloat;
 	RoundEven;
 	RoundEven;
+	// compute
+	SetLayout;
 }
 }
 
 
 enum Component {
 enum Component {
@@ -418,7 +426,12 @@ class Tools {
 			prefix + "Vec" + size;
 			prefix + "Vec" + size;
 		case TStruct(vl):"{" + [for( v in vl ) v.name + " : " + toString(v.type)].join(",") + "}";
 		case TStruct(vl):"{" + [for( v in vl ) v.name + " : " + toString(v.type)].join(",") + "}";
 		case TArray(t, s): toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
 		case TArray(t, s): toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
-		case TBuffer(t, s): "buffer "+toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
+		case TBuffer(t, s, k):
+			var prefix = switch( k ) {
+			case Uniform: "buffer";
+			case RW: "rwbuffer";
+			};
+			prefix+" "+toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
 		case TBytes(n): "Bytes" + n;
 		case TBytes(n): "Bytes" + n;
 		default: t.getName().substr(1);
 		default: t.getName().substr(1);
 		}
 		}
@@ -457,6 +470,8 @@ class Tools {
 			return hasSideEffect(e) || hasSideEffect(index);
 			return hasSideEffect(e) || hasSideEffect(index);
 		case TConst(_), TVar(_), TGlobal(_):
 		case TConst(_), TVar(_), TGlobal(_):
 			return false;
 			return false;
+		case TCall({ e : TGlobal(SetLayout) },_):
+			return true;
 		case TCall(e, pl):
 		case TCall(e, pl):
 			if( !e.e.match(TGlobal(_)) )
 			if( !e.e.match(TGlobal(_)) )
 				return true;
 				return true;
@@ -545,7 +560,7 @@ class Tools {
 		case TMat3x4: 12;
 		case TMat3x4: 12;
 		case TBytes(s): s;
 		case TBytes(s): s;
 		case TBool, TString, TSampler2D, TSampler2DArray, TSamplerCube, TFun(_): 0;
 		case TBool, TString, TSampler2D, TSampler2DArray, TSamplerCube, TFun(_): 0;
-		case TArray(t, SConst(v)), TBuffer(t, SConst(v)): size(t) * v;
+		case TArray(t, SConst(v)), TBuffer(t, SConst(v),_): size(t) * v;
 		case TArray(_, SVar(_)), TBuffer(_): 0;
 		case TArray(_, SVar(_)), TBuffer(_): 0;
 		}
 		}
 	}
 	}

+ 57 - 37
hxsl/Cache.hx

@@ -1,6 +1,7 @@
 package hxsl;
 package hxsl;
 using hxsl.Ast;
 using hxsl.Ast;
 import hxsl.RuntimeShader;
 import hxsl.RuntimeShader;
+import hxsl.Linker.LinkMode;
 
 
 class BatchInstanceParams {
 class BatchInstanceParams {
 
 
@@ -195,7 +196,7 @@ class Cache {
 	}
 	}
 
 
 	@:noDebug
 	@:noDebug
-	public function link( shaders : hxsl.ShaderList, batchMode : Bool ) {
+	public function link( shaders : hxsl.ShaderList, mode : LinkMode ) {
 		var c = linkCache;
 		var c = linkCache;
 		for( s in shaders ) {
 		for( s in shaders ) {
 			var i = @:privateAccess s.instance;
 			var i = @:privateAccess s.instance;
@@ -207,11 +208,11 @@ class Cache {
 			c = cs;
 			c = cs;
 		}
 		}
 		if( c.linked == null )
 		if( c.linked == null )
-			c.linked = compileRuntimeShader(shaders, batchMode);
+			c.linked = compileRuntimeShader(shaders, mode);
 		return c.linked;
 		return c.linked;
 	}
 	}
 
 
-	function compileRuntimeShader( shaders : hxsl.ShaderList, batchMode : Bool ) {
+	function compileRuntimeShader( shaders : hxsl.ShaderList, mode : LinkMode ) {
 		var shaderDatas = [];
 		var shaderDatas = [];
 		var index = 0;
 		var index = 0;
 		for( s in shaders ) {
 		for( s in shaders ) {
@@ -262,14 +263,14 @@ class Cache {
 		//TRACE = shaderId == 0;
 		//TRACE = shaderId == 0;
 		#end
 		#end
 
 
-		var linker = new hxsl.Linker(batchMode);
+		var linker = new hxsl.Linker(mode);
 		var s = try linker.link([for( s in shaderDatas ) s.inst.shader]) catch( e : Error ) {
 		var s = try linker.link([for( s in shaderDatas ) s.inst.shader]) catch( e : Error ) {
 			var shaders = [for( s in shaderDatas ) Printer.shaderToString(s.inst.shader)];
 			var shaders = [for( s in shaderDatas ) Printer.shaderToString(s.inst.shader)];
 			e.msg += "\n\nin\n\n" + shaders.join("\n-----\n");
 			e.msg += "\n\nin\n\n" + shaders.join("\n-----\n");
 			throw e;
 			throw e;
 		}
 		}
 
 
-		if( batchMode ) {
+		if( mode == Batch ) {
 			function checkRec( v : TVar ) {
 			function checkRec( v : TVar ) {
 				if( v.qualifiers != null && v.qualifiers.indexOf(PerObject) >= 0 ) {
 				if( v.qualifiers != null && v.qualifiers.indexOf(PerObject) >= 0 ) {
 					if( v.qualifiers.length == 1 ) v.qualifiers = null else {
 					if( v.qualifiers.length == 1 ) v.qualifiers = null else {
@@ -302,7 +303,7 @@ class Cache {
 
 
 		var prev = s;
 		var prev = s;
 		var splitter = new hxsl.Splitter();
 		var splitter = new hxsl.Splitter();
-		var s = try splitter.split(s) catch( e : Error ) { e.msg += "\n\nin\n\n"+Printer.shaderToString(s); throw e; };
+		var sl = try splitter.split(s) catch( e : Error ) { e.msg += "\n\nin\n\n"+Printer.shaderToString(s); throw e; };
 
 
 		// params tracking
 		// params tracking
 		var paramVars = new Map();
 		var paramVars = new Map();
@@ -319,41 +320,42 @@ class Cache {
 
 
 
 
 		#if debug
 		#if debug
-		Printer.check(s.vertex,[prev]);
-		Printer.check(s.fragment,[prev]);
+		for( s in sl )
+			Printer.check(s,[prev]);
 		#end
 		#end
 
 
 		#if shader_debug_dump
 		#if shader_debug_dump
 		if( dbg != null ) {
 		if( dbg != null ) {
 			dbg.writeString("----- SPLIT ----\n\n");
 			dbg.writeString("----- SPLIT ----\n\n");
-			dbg.writeString(Printer.shaderToString(s.vertex, DEBUG_IDS) + "\n\n");
-			dbg.writeString(Printer.shaderToString(s.fragment, DEBUG_IDS) + "\n\n");
+			for( s in sl )
+				dbg.writeString(Printer.shaderToString(s, DEBUG_IDS) + "\n\n");
 		}
 		}
 		#end
 		#end
 
 
-		var prev = s;
-		var s = new hxsl.Dce().dce(s.vertex, s.fragment);
+		var prev = sl;
+		var sl = new hxsl.Dce().dce(sl);
 
 
 		#if debug
 		#if debug
-		Printer.check(s.vertex,[prev.vertex]);
-		Printer.check(s.fragment,[prev.fragment]);
+		for( i => s in sl )
+			Printer.check(s,[prev[i]]);
 		#end
 		#end
 
 
 		#if shader_debug_dump
 		#if shader_debug_dump
 		if( dbg != null ) {
 		if( dbg != null ) {
 			dbg.writeString("----- DCE ----\n\n");
 			dbg.writeString("----- DCE ----\n\n");
-			dbg.writeString(Printer.shaderToString(s.vertex, DEBUG_IDS) + "\n\n");
-			dbg.writeString(Printer.shaderToString(s.fragment, DEBUG_IDS) + "\n\n");
+			for( s in sl )
+				dbg.writeString(Printer.shaderToString(s, DEBUG_IDS) + "\n\n");
 		}
 		}
 		#end
 		#end
 
 
-		var r = buildRuntimeShader(s.vertex, s.fragment, paramVars);
+		var r = buildRuntimeShader(sl, paramVars);
+		r.mode = mode;
 
 
 		#if shader_debug_dump
 		#if shader_debug_dump
 		if( dbg != null ) {
 		if( dbg != null ) {
 			dbg.writeString("----- FLATTEN ----\n\n");
 			dbg.writeString("----- FLATTEN ----\n\n");
-			dbg.writeString(Printer.shaderToString(r.vertex.data, DEBUG_IDS) + "\n\n");
-			dbg.writeString(Printer.shaderToString(r.fragment.data,DEBUG_IDS)+"\n\n");
+			for( s in r.getShaders() )
+				dbg.writeString(Printer.shaderToString(s.data, DEBUG_IDS) + "\n\n");
 		}
 		}
 		#end
 		#end
 
 
@@ -366,9 +368,7 @@ class Cache {
 
 
 		var signParts = [for( i in r.spec.instances ) i.shader.data.name+"_" + i.bits + "_" + i.index];
 		var signParts = [for( i in r.spec.instances ) i.shader.data.name+"_" + i.bits + "_" + i.index];
 		r.spec.signature = haxe.crypto.Md5.encode(signParts.join(":"));
 		r.spec.signature = haxe.crypto.Md5.encode(signParts.join(":"));
-		r.signature = haxe.crypto.Md5.encode(Printer.shaderToString(r.vertex.data) + Printer.shaderToString(r.fragment.data));
-		r.batchMode = batchMode;
-
+		r.signature = haxe.crypto.Md5.encode([for( s in r.getShaders() ) Printer.shaderToString(s.data)].join(""));
 		var r2 = byID.get(r.signature);
 		var r2 = byID.get(r.signature);
 		if( r2 != null )
 		if( r2 != null )
 			r.id = r2.id; // same id but different variable mapping
 			r.id = r2.id; // same id but different variable mapping
@@ -385,19 +385,33 @@ class Cache {
 		return r;
 		return r;
 	}
 	}
 
 
-	function buildRuntimeShader( vertex : ShaderData, fragment : ShaderData, paramVars ) {
+	/**
+		Flattens every split shader stage and stores the result in the matching
+		slot of a new RuntimeShader (vertex, fragment or compute), registering
+		the globals of each stage along the way.
+		The stage kind is derived from the ShaderData name ("vertex", "fragment"
+		or "main"); any other name throws "assert".
+	**/
+	function buildRuntimeShader( shaders : Array<ShaderData>, paramVars ) {
 		var r = new RuntimeShader();
 		var r = new RuntimeShader();
-		r.vertex = flattenShader(vertex, Vertex, paramVars);
-		r.vertex.vertex = true;
-		r.fragment = flattenShader(fragment, Fragment, paramVars);
 		r.globals = new Map();
 		r.globals = new Map();
-		initGlobals(r, r.vertex);
-		initGlobals(r, r.fragment);
-
-		#if debug
-		Printer.check(r.vertex.data,[vertex]);
-		Printer.check(r.fragment.data,[fragment]);
-		#end
+		for( s in shaders ) {
+			var kind = switch( s.name ) {
+			case "vertex": Vertex;
+			case "fragment": Fragment;
+			case "main": Main;
+			default: throw "assert";
+			}
+			var fl = flattenShader(s, kind, paramVars);
+			fl.kind = kind;
+			switch( kind ) {
+			case Vertex:
+				r.vertex = fl;
+			case Fragment:
+				r.fragment = fl;
+			case Main:
+				r.compute = fl;
+			default:
+				throw "assert";
+			}
+			initGlobals(r, fl);
+			#if debug
+			// check the flattened ShaderData against the stage it was built from
+			Printer.check(fl.data,[s]);
+			#end
+		}
 		return r;
 		return r;
 	}
 	}
 
 
@@ -429,7 +443,6 @@ class Cache {
 		data = hl.Api.compact(data, null, 0, null);
 		data = hl.Api.compact(data, null, 0, null);
 		#end
 		#end
 		var textures = [];
 		var textures = [];
-		c.consts = flat.consts;
 		c.texturesCount = 0;
 		c.texturesCount = 0;
 		for( g in flat.allocData.keys() ) {
 		for( g in flat.allocData.keys() ) {
 			var alloc = flat.allocData.get(g);
 			var alloc = flat.allocData.get(g);
@@ -468,8 +481,15 @@ class Cache {
 					c.params = out[0];
 					c.params = out[0];
 					c.paramsSize = size;
 					c.paramsSize = size;
 				case TArray(TBuffer(_), _):
 				case TArray(TBuffer(_), _):
-					c.buffers = out[0];
-					c.bufferCount = out.length;
+					if( c.buffers == null ) {
+						c.buffers = out[0];
+						c.bufferCount = out.length;
+					} else {
+						var p = c.buffers;
+						while( p.next != null ) p = p.next;
+						p.next = out[0];
+						c.bufferCount += out.length;
+					}
 				default: throw "assert";
 				default: throw "assert";
 				}
 				}
 			case Global:
 			case Global:
@@ -575,7 +595,7 @@ class Cache {
 		inputOffset.qualifiers = [PerInstance(1)];
 		inputOffset.qualifiers = [PerInstance(1)];
 
 
 		var vcount = declVar("Batch_Count",TInt,Param);
 		var vcount = declVar("Batch_Count",TInt,Param);
-		var vbuffer = declVar("Batch_Buffer",TBuffer(TVec(4,VFloat),SVar(vcount)),Param);
+		var vbuffer = declVar("Batch_Buffer",TBuffer(TVec(4,VFloat),SVar(vcount),Uniform),Param);
 		var voffset = declVar("Batch_Offset", TInt, Local);
 		var voffset = declVar("Batch_Offset", TInt, Local);
 		var ebuffer = { e : TVar(vbuffer), p : pos, t : vbuffer.type };
 		var ebuffer = { e : TVar(vbuffer), p : pos, t : vbuffer.type };
 		var eoffset = { e : TVar(voffset), p : pos, t : voffset.type };
 		var eoffset = { e : TVar(voffset), p : pos, t : voffset.type };

+ 12 - 3
hxsl/Checker.hx

@@ -188,6 +188,8 @@ class Checker {
 				[for( i => t in genType ) { args : [ { name: "x", type: t } ], ret: genIType[i] }];
 				[for( i => t in genType ) { args : [ { name: "x", type: t } ], ret: genIType[i] }];
 			case IntBitsToFloat, UintBitsToFloat:
 			case IntBitsToFloat, UintBitsToFloat:
 				[for( i => t in genType ) { args : [ { name: "x", type: genIType[i] } ], ret: t }];
 				[for( i => t in genType ) { args : [ { name: "x", type: genIType[i] } ], ret: t }];
+			case SetLayout:
+				[{ args : [{ name : "x", type : TInt },{ name : "y", type : TInt },{ name : "z", type : TInt }], ret : TVoid }];
 			case VertexID, InstanceID, FragCoord, FrontFacing:
 			case VertexID, InstanceID, FragCoord, FrontFacing:
 				null;
 				null;
 			}
 			}
@@ -246,6 +248,7 @@ class Checker {
 			var kind = switch( f.name ) {
 			var kind = switch( f.name ) {
 			case "vertex":  Vertex;
 			case "vertex":  Vertex;
 			case "fragment": Fragment;
 			case "fragment": Fragment;
+			case "main": Main;
 			default: StringTools.startsWith(f.name,"__init__") ? Init : Helper;
 			default: StringTools.startsWith(f.name,"__init__") ? Init : Helper;
 			}
 			}
 			if( args.length != 0 && kind != Helper )
 			if( args.length != 0 && kind != Helper )
@@ -337,6 +340,8 @@ class Checker {
 			switch( v.kind ) {
 			switch( v.kind ) {
 			case Local, Var, Output:
 			case Local, Var, Output:
 				return;
 				return;
+			case Param if( v.type.match(TBuffer(_,_,RW)) ):
+				return;
 			default:
 			default:
 			}
 			}
 		case TSwiz(e, _):
 		case TSwiz(e, _):
@@ -631,7 +636,7 @@ class Checker {
 			default: unify(e2.t, TInt, e2.p);
 			default: unify(e2.t, TInt, e2.p);
 			}
 			}
 			switch( e1.t ) {
 			switch( e1.t ) {
-			case TArray(t, size), TBuffer(t,size):
+			case TArray(t, size), TBuffer(t,size,_):
 				switch( [size, e2.e] ) {
 				switch( [size, e2.e] ) {
 				case [SConst(v), TConst(CInt(i))] if( i >= v ):
 				case [SConst(v), TConst(CInt(i))] if( i >= v ):
 					error("Indexing outside array bounds", e.pos);
 					error("Indexing outside array bounds", e.pos);
@@ -849,7 +854,7 @@ class Checker {
 				vl[i] = makeVar( { type : v.type, qualifiers : v.qualifiers, name : v.name, kind : v.kind, expr : null }, pos, parent);
 				vl[i] = makeVar( { type : v.type, qualifiers : v.qualifiers, name : v.name, kind : v.kind, expr : null }, pos, parent);
 			}
 			}
 			return parent.type;
 			return parent.type;
-		case TArray(t, size), TBuffer(t,size):
+		case TArray(t, size), TBuffer(t,size,_):
 			switch( t ) {
 			switch( t ) {
 			case TArray(_):
 			case TArray(_):
 				error("Multidimentional arrays are not allowed", pos);
 				error("Multidimentional arrays are not allowed", pos);
@@ -894,7 +899,11 @@ class Checker {
 				SVar(v2);
 				SVar(v2);
 			}
 			}
 			t = makeVarType(t,parent,pos);
 			t = makeVarType(t,parent,pos);
-			return vt.match(TArray(_)) ? TArray(t, s) : TBuffer(t,s);
+			return switch( vt ) {
+			case TArray(_): TArray(t, s);
+			case TBuffer(_,_,kind): TBuffer(t,s,kind);
+			default: throw "assert";
+			}
 		default:
 		default:
 			return vt;
 			return vt;
 		}
 		}

+ 24 - 30
hxsl/Dce.hx

@@ -34,30 +34,27 @@ class Dce {
 		#end
 		#end
 	}
 	}
 
 
-	public function dce( vertex : ShaderData, fragment : ShaderData ) {
+	public function dce( shaders : Array<ShaderData> ) {
 		// collect vars dependencies
 		// collect vars dependencies
 		used = new Map();
 		used = new Map();
 		channelVars = [];
 		channelVars = [];
 
 
 		var inputs = [];
 		var inputs = [];
-		for( v in vertex.vars ) {
-			var i = get(v);
-			if( v.kind == Input )
-				inputs.push(i);
-			if( v.kind == Output )
-				i.keep = true;
-		}
-		for( v in fragment.vars ) {
-			var i = get(v);
-			if( v.kind == Output )
-				i.keep = true;
+		for( s in shaders ) {
+			for( v in s.vars ) {
+				var i = get(v);
+				if( v.kind == Input )
+					inputs.push(i);
+				if( v.kind == Output || v.type.match(TBuffer(_,_,RW)) )
+					i.keep = true;
+			}
 		}
 		}
 
 
 		// collect dependencies
 		// collect dependencies
-		for( f in vertex.funs )
-			check(f.expr, [], []);
-		for( f in fragment.funs )
-			check(f.expr, [], []);
+		for( s in shaders ) {
+			for( f in s.funs )
+				check(f.expr, [], []);
+		}
 
 
 		var outExprs = [];
 		var outExprs = [];
 		while( true ) {
 		while( true ) {
@@ -74,10 +71,10 @@ class Dce {
 				markRec(v);
 				markRec(v);
 
 
 			outExprs = [];
 			outExprs = [];
-			for( f in vertex.funs )
-				outExprs.push(mapExpr(f.expr, false));
-			for( f in fragment.funs )
-				outExprs.push(mapExpr(f.expr, false));
+			for( s in shaders ) {
+				for( f in s.funs )
+					outExprs.push(mapExpr(f.expr, false));
+			}
 
 
 			// post add conditional branches
 			// post add conditional branches
 			markAsKeep = false;
 			markAsKeep = false;
@@ -86,22 +83,19 @@ class Dce {
 			if( !markAsKeep ) break;
 			if( !markAsKeep ) break;
 		}
 		}
 
 
-		for( f in vertex.funs )
-			f.expr = outExprs.shift();
-		for( f in fragment.funs )
-			f.expr = outExprs.shift();
+		for( s in shaders ) {
+			for( f in s.funs )
+				f.expr = outExprs.shift();
+		}
 
 
 		for( v in used ) {
 		for( v in used ) {
 			if( v.used ) continue;
 			if( v.used ) continue;
 			if( v.v.kind == VarKind.Input) continue;
 			if( v.v.kind == VarKind.Input) continue;
-			vertex.vars.remove(v.v);
-			fragment.vars.remove(v.v);
+			for( s in shaders )
+				s.vars.remove(v.v);
 		}
 		}
 
 
-		return {
-			fragment : fragment,
-			vertex : vertex,
-		}
+		return shaders.copy();
 	}
 	}
 
 
 	function get( v : TVar ) {
 	function get( v : TVar ) {

+ 12 - 4
hxsl/Eval.hx

@@ -52,18 +52,26 @@ class Eval {
 		switch( v2.type ) {
 		switch( v2.type ) {
 		case TStruct(vl):
 		case TStruct(vl):
 			v2.type = TStruct([for( v in vl ) mapVar(v)]);
 			v2.type = TStruct([for( v in vl ) mapVar(v)]);
-		case TArray(t, SVar(vs)), TBuffer(t, SVar(vs)):
+		case TArray(t, SVar(vs)), TBuffer(t, SVar(vs), _):
 			var c = constants.get(vs.id);
 			var c = constants.get(vs.id);
 			if( c != null )
 			if( c != null )
 				switch( c ) {
 				switch( c ) {
 				case TConst(CInt(v)):
 				case TConst(CInt(v)):
-					v2.type = v2.type.match(TArray(_)) ? TArray(t, SConst(v)) : TBuffer(t, SConst(v));
+					v2.type = switch( v2.type ) {
+					case TArray(_): TArray(t, SConst(v));
+					case TBuffer(_,_,kind): TBuffer(t, SConst(v), kind);
+					default: throw "assert";
+					};
 				default:
 				default:
 					Error.t("Integer value expected for array size constant " + vs.name, null);
 					Error.t("Integer value expected for array size constant " + vs.name, null);
 				}
 				}
 			else {
 			else {
 				var vs2 = mapVar(vs);
 				var vs2 = mapVar(vs);
-				v2.type = v2.type.match(TArray(_)) ? TArray(t, SVar(vs2)) : TBuffer(t, SVar(vs2));
+				v2.type = switch( v2.type ) {
+				case TArray(_): TArray(t, SVar(vs2));
+				case TBuffer(_,_,kind): TBuffer(t, SVar(vs2), kind);
+				default: throw "assert";
+				}
 			}
 			}
 		default:
 		default:
 		}
 		}
@@ -81,7 +89,7 @@ class Eval {
 			return false;
 			return false;
 		case TArray(t, _):
 		case TArray(t, _):
 			return checkSamplerRec(t);
 			return checkSamplerRec(t);
-		case TBuffer(_, size):
+		case TBuffer(_, _, _):
 			return true;
 			return true;
 		default:
 		default:
 		}
 		}

+ 9 - 120
hxsl/Flatten.hx

@@ -26,8 +26,6 @@ class Flatten {
 	var params : Array<TVar>;
 	var params : Array<TVar>;
 	var outVars : Array<TVar>;
 	var outVars : Array<TVar>;
 	var varMap : Map<TVar,Alloc>;
 	var varMap : Map<TVar,Alloc>;
-	var econsts : TExpr;
-	public var consts : Array<Float>;
 	public var allocData : Map< TVar, Array<Alloc> >;
 	public var allocData : Map< TVar, Array<Alloc> >;
 
 
 	public function new() {
 	public function new() {
@@ -44,6 +42,7 @@ class Flatten {
 		var prefix = switch( kind ) {
 		var prefix = switch( kind ) {
 		case Vertex: "vertex";
 		case Vertex: "vertex";
 		case Fragment: "fragment";
 		case Fragment: "fragment";
+		case Main: "compute";
 		default: throw "assert";
 		default: throw "assert";
 		}
 		}
 		pack(prefix + "Globals", Global, globals, VFloat);
 		pack(prefix + "Globals", Global, globals, VFloat);
@@ -52,7 +51,8 @@ class Flatten {
 		var textures = packTextures(prefix + "Textures", allVars, TSampler2D)
 		var textures = packTextures(prefix + "Textures", allVars, TSampler2D)
 			.concat(packTextures(prefix+"TexturesCube", allVars, TSamplerCube))
 			.concat(packTextures(prefix+"TexturesCube", allVars, TSamplerCube))
 			.concat(packTextures(prefix+"TexturesArray", allVars, TSampler2DArray));
 			.concat(packTextures(prefix+"TexturesArray", allVars, TSampler2DArray));
-		packBuffers(allVars);
+		packBuffers("buffers", allVars, Uniform);
+		packBuffers("rwbuffers", allVars, RW);
 		var funs = [for( f in s.funs ) mapFun(f, mapExpr)];
 		var funs = [for( f in s.funs ) mapFun(f, mapExpr)];
 		return {
 		return {
 			name : s.name,
 			name : s.name,
@@ -104,119 +104,6 @@ class Flatten {
 		return optimize(e);
 		return optimize(e);
 	}
 	}
 
 
-	function mapConsts( e : TExpr ) : TExpr {
-		switch( e.e ) {
-		case TArray(ea, eindex = { e : TConst(CInt(_)) } ):
-			return { e : TArray(mapConsts(ea), eindex), t : e.t, p : e.p };
-		case TBinop(OpMult, _, { t : TMat3x4 } ):
-			allocConst(1, e.p); // pre-alloc
-		case TArray(ea, eindex):
-			switch( ea.t ) {
-			case TArray(t, _):
-				var stride = varSize(t, VFloat) >> 2;
-				allocConst(stride, e.p); // pre-alloc
-			default:
-			}
-		case TConst(c):
-			switch( c ) {
-			case CFloat(v):
-				return allocConst(v, e.p);
-			case CInt(v):
-				return allocConst(v, e.p);
-			default:
-				return e;
-			}
-		case TGlobal(g):
-			switch( g ) {
-			case Pack:
-				allocConsts([1, 255, 255 * 255, 255 * 255 * 255], e.p);
-				allocConsts([1/255, 1/255, 1/255, 0], e.p);
-			case Unpack:
-				allocConsts([1, 1 / 255, 1 / (255 * 255), 1 / (255 * 255 * 255)], e.p);
-			case Radians:
-				allocConst(Math.PI / 180, e.p);
-			case Degrees:
-				allocConst(180 / Math.PI, e.p);
-			case Log:
-				allocConst(0.6931471805599453, e.p);
-			case Exp:
-				allocConst(1.4426950408889634, e.p);
-			case Mix:
-				allocConst(1, e.p);
-			case UnpackNormal:
-				allocConst(0.5, e.p);
-			case PackNormal:
-				allocConst(1, e.p);
-				allocConst(0.5, e.p);
-			case ScreenToUv:
-				allocConsts([0.5,0.5], e.p);
-				allocConsts([0.5,-0.5], e.p);
-			case UvToScreen:
-				allocConsts([2,-2], e.p);
-				allocConsts([-1,1], e.p);
-			case Smoothstep:
-				allocConst(2.0, e.p);
-				allocConst(3.0, e.p);
-			default:
-			}
-		case TCall( { e : TGlobal(Vec4) }, [ { e : TVar( { kind : Global | Param | Input | Var } ), t : TVec(3, VFloat) }, { e : TConst(CInt(1)) } ]):
-			// allow var expansion without relying on a constant
-			return e;
-		default:
-		}
-		return e.map(mapConsts);
-	}
-
-	function allocConst( v : Float, p ) : TExpr {
-		var index = consts.indexOf(v);
-		if( index < 0 ) {
-			index = consts.length;
-			consts.push(v);
-		}
-		return { e : TArray(econsts, { e : TConst(CInt(index)), t : TInt, p : p } ), t : TFloat, p : p };
-	}
-
-	function allocConsts( va : Array<Float>, p ) : TExpr {
-		var pad = (va.length - 1) & 3;
-		var index = -1;
-		for( i in 0...consts.length - (va.length - 1) ) {
-			if( (i >> 2) != (i + pad) >> 2 ) continue;
-			var found = true;
-			for( j in 0...va.length )
-				if( consts[i + j] != va[j] ) {
-					found = false;
-					break;
-				}
-			if( found ) {
-				index = i;
-				break;
-			}
-		}
-		if( index < 0 ) {
-			// pad
-			while( consts.length >> 2 != (consts.length + pad) >> 2 )
-				consts.push(0);
-			index = consts.length;
-			for( v in va )
-				consts.push(v);
-		}
-		inline function get(i) : TExpr {
-			return { e : TArray(econsts, { e : TConst(CInt(index+i)), t : TInt, p : p } ), t : TFloat, p : p };
-		}
-		switch( va.length ) {
-		case 1:
-			return get(0);
-		case 2:
-			return { e : TCall( { e : TGlobal(Vec2), t : TVoid, p : p }, [get(0), get(1)]), t : TVec(2, VFloat), p : p };
-		case 3:
-			return { e : TCall( { e : TGlobal(Vec3), t : TVoid, p : p }, [get(0), get(1), get(2)]), t : TVec(3, VFloat), p : p };
-		case 4:
-			return { e : TCall( { e : TGlobal(Vec4), t : TVoid, p : p }, [get(0), get(1), get(3), get(4)]), t : TVec(4, VFloat), p : p };
-		default:
-			throw "assert";
-		}
-	}
-
 	inline function mkInt(v:Int,pos) {
 	inline function mkInt(v:Int,pos) {
 		return { e : TConst(CInt(v)), t : TInt, p : pos };
 		return { e : TConst(CInt(v)), t : TInt, p : pos };
 	}
 	}
@@ -374,22 +261,24 @@ class Flatten {
 		return alloc;
 		return alloc;
 	}
 	}
 
 
-	function packBuffers( vars : Array<TVar> ) {
+	function packBuffers( name : String, vars : Array<TVar>, kind ) {
 		var alloc = new Array<Alloc>();
 		var alloc = new Array<Alloc>();
 		var g : TVar = {
 		var g : TVar = {
 			id : Tools.allocVarId(),
 			id : Tools.allocVarId(),
-			name : "buffers",
+			name : name,
 			type : TVoid,
 			type : TVoid,
 			kind : Param,
 			kind : Param,
 		};
 		};
 		for( v in vars )
 		for( v in vars )
-			if( v.type.match(TBuffer(_)) ) {
+			switch( v.type ) {
+			case TBuffer(_,_,k) if( kind == k ):
 				var a = new Alloc(g, null, alloc.length, 1);
 				var a = new Alloc(g, null, alloc.length, 1);
 				a.v = v;
 				a.v = v;
 				alloc.push(a);
 				alloc.push(a);
 				outVars.push(v);
 				outVars.push(v);
+			default:
 			}
 			}
-		g.type = TArray(TBuffer(TVoid,SConst(0)),SConst(alloc.length));
+		g.type = TArray(TBuffer(TVoid,SConst(0),kind),SConst(alloc.length));
 		allocData.set(g, alloc);
 		allocData.set(g, alloc);
 	}
 	}
 
 

+ 3 - 2
hxsl/GlslOut.hx

@@ -181,12 +181,13 @@ class GlslOut {
 			case SConst(n): add(n);
 			case SConst(n): add(n);
 			}
 			}
 			add("]");
 			add("]");
-		case TBuffer(t, size):
+		case TBuffer(t, size, kind):
+			if( kind != Uniform ) throw "TODO";
 			add((isVertex ? "vertex_" : "") + "uniform_buffer"+(uniformBuffer++));
 			add((isVertex ? "vertex_" : "") + "uniform_buffer"+(uniformBuffer++));
 			add(" { ");
 			add(" { ");
 			v.type = TArray(t,size);
 			v.type = TArray(t,size);
 			addVar(v);
 			addVar(v);
-			v.type = TBuffer(t,size);
+			v.type = TBuffer(t,size, kind);
 			add("; }");
 			add("; }");
 		default:
 		default:
 			addType(v.type);
 			addType(v.type);

+ 53 - 21
hxsl/HlslOut.hx

@@ -101,13 +101,19 @@ class HlslOut {
 	var exprValues : Array<String>;
 	var exprValues : Array<String>;
 	var locals : Map<Int,TVar>;
 	var locals : Map<Int,TVar>;
 	var decls : Array<String>;
 	var decls : Array<String>;
-	var isVertex : Bool;
+	var kind : FunctionKind;
 	var allNames : Map<String, Int>;
 	var allNames : Map<String, Int>;
 	var samplers : Map<Int, Array<Int>>;
 	var samplers : Map<Int, Array<Int>>;
+	var computeLayout = [1,1,1];
 	public var varNames : Map<Int,String>;
 	public var varNames : Map<Int,String>;
 	public var baseRegister : Int = 0;
 	public var baseRegister : Int = 0;
 
 
 	var varAccess : Map<Int,String>;
 	var varAccess : Map<Int,String>;
+	var isVertex(get,never) : Bool;
+	var isCompute(get,never) : Bool;
+
+	inline function get_isCompute() return kind == Main;
+	inline function get_isVertex() return kind == Vertex;
 
 
 	public function new() {
 	public function new() {
 		varNames = new Map();
 		varNames = new Map();
@@ -175,7 +181,7 @@ class HlslOut {
 			add(" }");
 			add(" }");
 		case TFun(_):
 		case TFun(_):
 			add("function");
 			add("function");
-		case TArray(t, size), TBuffer(t,size):
+		case TArray(t, size), TBuffer(t,size,_):
 			addType(t);
 			addType(t);
 			add("[");
 			add("[");
 			switch( size ) {
 			switch( size ) {
@@ -201,7 +207,7 @@ class HlslOut {
 
 
 	function addVar( v : TVar ) {
 	function addVar( v : TVar ) {
 		switch( v.type ) {
 		switch( v.type ) {
-		case TArray(t, size), TBuffer(t,size):
+		case TArray(t, size), TBuffer(t,size,_):
 			addVar({
 			addVar({
 				id : v.id,
 				id : v.id,
 				name : v.name,
 				name : v.name,
@@ -298,6 +304,8 @@ class HlslOut {
 			var acc = varAccess.get(v.id);
 			var acc = varAccess.get(v.id);
 			if( acc != null ) add(acc);
 			if( acc != null ) add(acc);
 			ident(v);
 			ident(v);
+		case TCall({ e : TGlobal(SetLayout) },_):
+			// ignore
 		case TCall({ e : TGlobal(g = (Texture | TextureLod)) }, args):
 		case TCall({ e : TGlobal(g = (Texture | TextureLod)) }, args):
 			addValue(args[0], tabs);
 			addValue(args[0], tabs);
 			switch( g ) {
 			switch( g ) {
@@ -687,6 +695,8 @@ class HlslOut {
 	function collectGlobals( m : Map<TGlobal,Bool>, e : TExpr ) {
 	function collectGlobals( m : Map<TGlobal,Bool>, e : TExpr ) {
 		switch( e.e )  {
 		switch( e.e )  {
 		case TGlobal(g): m.set(g,true);
 		case TGlobal(g): m.set(g,true);
+		case TCall({ e : TGlobal(SetLayout) }, [{ e : TConst(CInt(x)) }, { e : TConst(CInt(y)) }, { e : TConst(CInt(z)) }]):
+			computeLayout = [x,y,z];
 		default: e.iter(collectGlobals.bind(m));
 		default: e.iter(collectGlobals.bind(m));
 		}
 		}
 	}
 	}
@@ -709,7 +719,7 @@ class HlslOut {
 			collectGlobals(foundGlobals, f.expr);
 			collectGlobals(foundGlobals, f.expr);
 
 
 		add("struct s_input {\n");
 		add("struct s_input {\n");
-		if( !isVertex )
+		if( kind == Fragment )
 			add("\tfloat4 __pos__ : "+SV_POSITION+";\n");
 			add("\tfloat4 __pos__ : "+SV_POSITION+";\n");
 		for( v in s.vars )
 		for( v in s.vars )
 			if( v.kind == Input || (v.kind == Var && !isVertex) )
 			if( v.kind == Input || (v.kind == Var && !isVertex) )
@@ -722,14 +732,16 @@ class HlslOut {
 			add("\tbool isFrontFace : "+SV_IsFrontFace+";\n");
 			add("\tbool isFrontFace : "+SV_IsFrontFace+";\n");
 		add("};\n\n");
 		add("};\n\n");
 
 
-		add("struct s_output {\n");
-		for( v in s.vars )
-			if( v.kind == Output )
-				declVar("_out.", v);
-		for( v in s.vars )
-			if( v.kind == Var && isVertex )
-				declVar("_out.", v);
-		add("};\n\n");
+		if( !isCompute ) {
+			add("struct s_output {\n");
+			for( v in s.vars )
+				if( v.kind == Output )
+					declVar("_out.", v);
+			for( v in s.vars )
+				if( v.kind == Var && isVertex )
+					declVar("_out.", v);
+			add("};\n\n");
+		}
 	}
 	}
 
 
 	function initGlobals( s : ShaderData ) {
 	function initGlobals( s : ShaderData ) {
@@ -770,10 +782,25 @@ class HlslOut {
 
 
 		var bufCount = 0;
 		var bufCount = 0;
 		for( b in buffers ) {
 		for( b in buffers ) {
-			add('cbuffer _buffer$bufCount : register(b${bufCount+baseRegister+2}) { ');
-			addVar(b);
-			add("; };\n");
-			bufCount++;
+			switch( b.type ) {
+			case TBuffer(t, size, kind):
+				switch( kind ) {
+				case Uniform:
+					add('cbuffer _buffer$bufCount : register(b${bufCount+baseRegister+2}) { ');
+					addVar(b);
+					add("; };\n");
+					bufCount++;
+				case RW:
+					add('RWStructuredBuffer<');
+					addType(t);
+					add('> ');
+					ident(b);
+					add(' : register(u${bufCount+baseRegister+2});');
+					bufCount++;
+				}
+			default:
+				throw "assert";
+			}
 		}
 		}
 		if( bufCount > 0 ) add("\n");
 		if( bufCount > 0 ) add("\n");
 
 
@@ -795,7 +822,8 @@ class HlslOut {
 
 
 	function initStatics( s : ShaderData ) {
 	function initStatics( s : ShaderData ) {
 		add(STATIC + "s_input _in;\n");
 		add(STATIC + "s_input _in;\n");
-		add(STATIC + "s_output _out;\n");
+		if( !isCompute )
+			add(STATIC + "s_output _out;\n");
 
 
 		add("\n");
 		add("\n");
 		for( v in s.vars )
 		for( v in s.vars )
@@ -808,7 +836,11 @@ class HlslOut {
 	}
 	}
 
 
 	function emitMain( expr : TExpr ) {
 	function emitMain( expr : TExpr ) {
-		add("s_output main( s_input __in ) {\n");
+		if( isCompute )
+			add('[numthreads(${computeLayout[0]},${computeLayout[1]},${computeLayout[2]})] void ');
+		else
+			add('s_output ');
+		add("main( s_input __in ) {\n");
 		add("\t_in = __in;\n");
 		add("\t_in = __in;\n");
 		switch( expr.e ) {
 		switch( expr.e ) {
 		case TBlock(el):
 		case TBlock(el):
@@ -820,7 +852,8 @@ class HlslOut {
 		default:
 		default:
 			addExpr(expr, "");
 			addExpr(expr, "");
 		}
 		}
-		add("\treturn _out;\n");
+		if( !isCompute )
+			add("\treturn _out;\n");
 		add("}");
 		add("}");
 	}
 	}
 
 
@@ -848,8 +881,7 @@ class HlslOut {
 
 
 		if( s.funs.length != 1 ) throw "assert";
 		if( s.funs.length != 1 ) throw "assert";
 		var f = s.funs[0];
 		var f = s.funs[0];
-		isVertex = f.kind == Vertex;
-
+		kind = f.kind;
 		varAccess = new Map();
 		varAccess = new Map();
 		samplers = new Map();
 		samplers = new Map();
 		initVars(s);
 		initVars(s);

+ 19 - 7
hxsl/Linker.hx

@@ -1,6 +1,12 @@
 package hxsl;
 package hxsl;
 using hxsl.Ast;
 using hxsl.Ast;
 
 
+enum LinkMode {
+	Default;
+	Batch;
+	Compute;
+}
+
 private class AllocatedVar {
 private class AllocatedVar {
 	public var id : Int;
 	public var id : Int;
 	public var v : TVar;
 	public var v : TVar;
@@ -28,6 +34,7 @@ private class ShaderInfos {
 	public var vertex : Null<Bool>;
 	public var vertex : Null<Bool>;
 	public var onStack : Bool;
 	public var onStack : Bool;
 	public var hasDiscard : Bool;
 	public var hasDiscard : Bool;
+	public var isCompute : Bool;
 	public var marked : Null<Bool>;
 	public var marked : Null<Bool>;
 	public function new(n, v) {
 	public function new(n, v) {
 		this.name = n;
 		this.name = n;
@@ -49,12 +56,12 @@ class Linker {
 	var varIdMap : Map<Int,Int>;
 	var varIdMap : Map<Int,Int>;
 	var locals : Map<Int,Bool>;
 	var locals : Map<Int,Bool>;
 	var curInstance : Int;
 	var curInstance : Int;
-	var batchMode : Bool;
+	var mode : LinkMode;
 	var isBatchShader : Bool;
 	var isBatchShader : Bool;
 	var debugDepth = 0;
 	var debugDepth = 0;
 
 
-	public function new(batchMode=false) {
-		this.batchMode = batchMode;
+	public function new(mode) {
+		this.mode = mode;
 	}
 	}
 
 
 	inline function debug( msg : String, ?pos : haxe.PosInfos ) {
 	inline function debug( msg : String, ?pos : haxe.PosInfos ) {
@@ -348,7 +355,7 @@ class Linker {
 		curInstance = 0;
 		curInstance = 0;
 		var outVars = [];
 		var outVars = [];
 		for( s in shadersData ) {
 		for( s in shadersData ) {
-			isBatchShader = batchMode && StringTools.startsWith(s.name,"batchShader_");
+			isBatchShader = mode == Batch && StringTools.startsWith(s.name,"batchShader_");
 			for( v in s.vars ) {
 			for( v in s.vars ) {
 				var v2 = allocVar(v, null, s.name);
 				var v2 = allocVar(v, null, s.name);
 				if( isBatchShader && v2.v.kind == Param && !StringTools.startsWith(v2.path,"Batch_") )
 				if( isBatchShader && v2.v.kind == Param && !StringTools.startsWith(v2.path,"Batch_") )
@@ -375,8 +382,13 @@ class Linker {
 				if( v.kind == null ) throw "assert";
 				if( v.kind == null ) throw "assert";
 				switch( v.kind ) {
 				switch( v.kind ) {
 				case Vertex, Fragment:
 				case Vertex, Fragment:
+					if( mode == Compute )
+						throw "Unexpected "+v.kind.getName().toLowerCase()+"() function in compute shader";
 					addShader(s.name + "." + (v.kind == Vertex ? "vertex" : "fragment"), v.kind == Vertex, f.expr, priority);
 					addShader(s.name + "." + (v.kind == Vertex ? "vertex" : "fragment"), v.kind == Vertex, f.expr, priority);
-
+				case Main:
+					if( mode != Compute )
+						throw "Unexpected main() outside compute shader";
+					addShader(s.name, true, f.expr, priority).isCompute = true;
 				case Init:
 				case Init:
 					var prio : Array<Int>;
 					var prio : Array<Int>;
 					var status : Null<Bool> = switch( f.ref.name ) {
 					var status : Null<Bool> = switch( f.ref.name ) {
@@ -408,7 +420,7 @@ class Linker {
 
 
 		// force shaders containing discard to be included
 		// force shaders containing discard to be included
 		for( s in shaders )
 		for( s in shaders )
-			if( s.hasDiscard ) {
+			if( s.hasDiscard || s.isCompute ) {
 				initDependencies(s);
 				initDependencies(s);
 				entry.deps.set(s, true);
 				entry.deps.set(s, true);
 			}
 			}
@@ -505,7 +517,7 @@ class Linker {
 				expr : expr,
 				expr : expr,
 			};
 			};
 		}
 		}
-		var funs = [
+		var funs = mode == Compute ? [build(Main,"main",v)] : [
 			build(Vertex, "vertex", v),
 			build(Vertex, "vertex", v),
 			build(Fragment, "fragment", f),
 			build(Fragment, "fragment", f),
 		];
 		];

+ 7 - 2
hxsl/MacroParser.hx

@@ -115,7 +115,7 @@ class MacroParser {
 			case "Channel3": return TChannel(3);
 			case "Channel3": return TChannel(3);
 			case "Channel4": return TChannel(4);
 			case "Channel4": return TChannel(4);
 			}
 			}
-		case TPath( { pack : [], name : name = ("Array"|"Buffer"), sub : null, params : [t, size] } ):
+		case TPath( { pack : [], name : name = ("Array"|"Buffer"|"RWBuffer"), sub : null, params : [t, size] } ):
 			var t = switch( t ) {
 			var t = switch( t ) {
 			case TPType(t): parseType(t, pos);
 			case TPType(t): parseType(t, pos);
 			default: null;
 			default: null;
@@ -129,7 +129,12 @@ class MacroParser {
 			default: null;
 			default: null;
 			}
 			}
 			if( t != null && size != null )
 			if( t != null && size != null )
-				return name == "Array" ? TArray(t, size) : TBuffer(t,size);
+				return switch( name ) {
+				case "Array": TArray(t, size);
+				case "Buffer": TBuffer(t,size,Uniform);
+				case "RWBuffer": TBuffer(t,size,RW);
+				default: throw "assert";
+				}
 		case TAnonymous(fl):
 		case TAnonymous(fl):
 			return TStruct([for( f in fl ) {
 			return TStruct([for( f in fl ) {
 				switch( f.kind ) {
 				switch( f.kind ) {

+ 9 - 3
hxsl/RuntimeShader.hx

@@ -44,7 +44,7 @@ class AllocGlobal {
 }
 }
 
 
 class RuntimeShaderData {
 class RuntimeShaderData {
-	public var vertex : Bool;
+	public var kind : hxsl.Ast.FunctionKind;
 	public var data : Ast.ShaderData;
 	public var data : Ast.ShaderData;
 	public var code : String;
 	public var code : String;
 	public var params : AllocParam;
 	public var params : AllocParam;
@@ -55,7 +55,6 @@ class RuntimeShaderData {
 	public var texturesCount : Int;
 	public var texturesCount : Int;
 	public var buffers : AllocParam;
 	public var buffers : AllocParam;
 	public var bufferCount : Int;
 	public var bufferCount : Int;
-	public var consts : Array<Float>;
 	public function new() {
 	public function new() {
 	}
 	}
 }
 }
@@ -76,14 +75,18 @@ class RuntimeShader {
 	public var id : Int;
 	public var id : Int;
 	public var vertex : RuntimeShaderData;
 	public var vertex : RuntimeShaderData;
 	public var fragment : RuntimeShaderData;
 	public var fragment : RuntimeShaderData;
+	public var compute(get,set) : RuntimeShaderData;
 	public var globals : Map<Int,Bool>;
 	public var globals : Map<Int,Bool>;
 
 
+	inline function get_compute() return vertex;
+	inline function set_compute(v) return vertex = v;
+
 	/**
 	/**
 		Signature of the resulting HxSL code.
 		Signature of the resulting HxSL code.
 		Several shaders with the different specification might still get the same resulting signature.
 		Several shaders with the different specification might still get the same resulting signature.
 	**/
 	**/
 	public var signature : String;
 	public var signature : String;
-	public var batchMode : Bool;
+	public var mode : hxsl.Linker.LinkMode;
 	public var spec : { instances : Array<ShaderInstanceDesc>, signature : String };
 	public var spec : { instances : Array<ShaderInstanceDesc>, signature : String };
 
 
 	public function new() {
 	public function new() {
@@ -94,5 +97,8 @@ class RuntimeShader {
 		return globals.exists(gid);
 		return globals.exists(gid);
 	}
 	}
 
 
+	public function getShaders() {
+		return mode == Compute ? [compute] : [vertex, fragment];
+	}
 
 
 }
 }

+ 15 - 2
hxsl/Serializer.hx

@@ -86,12 +86,19 @@ class Serializer {
 				writeArr(vl,writeVar);
 				writeArr(vl,writeVar);
 		case TFun(variants):
 		case TFun(variants):
 			// not serialized
 			// not serialized
-		case TArray(t, size), TBuffer(t, size):
+		case TArray(t, size), TBuffer(t, size, Uniform):
 			writeType(t);
 			writeType(t);
 			switch (size) {
 			switch (size) {
 			case SConst(v): out.addByte(0); writeVarInt(v);
 			case SConst(v): out.addByte(0); writeVarInt(v);
 			case SVar(v): writeVar(v);
 			case SVar(v): writeVar(v);
 			}
 			}
+		case TBuffer(t, size, kind):
+			out.addByte(kind.getIndex() + 0x80);
+			writeType(t);
+			switch (size) {
+				case SConst(v): out.addByte(0); writeVarInt(v);
+				case SVar(v): writeVar(v);
+				}
 		case TChannel(size):
 		case TChannel(size):
 			out.addByte(size);
 			out.addByte(size);
 		case TVoid, TInt, TBool, TFloat, TString, TMat2, TMat3, TMat4, TMat3x4, TSampler2D, TSampler2DArray, TSamplerCube:
 		case TVoid, TInt, TBool, TFloat, TString, TMat2, TMat3, TMat4, TMat3x4, TSampler2D, TSampler2DArray, TSamplerCube:
@@ -136,9 +143,15 @@ class Serializer {
 			var v = readVar();
 			var v = readVar();
 			TArray(t, v == null ? SConst(readVarInt()) : SVar(v));
 			TArray(t, v == null ? SConst(readVarInt()) : SVar(v));
 		case 16:
 		case 16:
+			var tag = input.readByte();
+			var kind = Uniform;
+			if( tag & 0x80 == 0 )
+				input.position--;
+			else
+				kind = BufferKind.createByIndex(tag & 0x7F);
 			var t = readType();
 			var t = readType();
 			var v = readVar();
 			var v = readVar();
-			TBuffer(t, v == null ? SConst(readVarInt()) : SVar(v));
+			TBuffer(t, v == null ? SConst(readVarInt()) : SVar(v), kind);
 		case 17:
 		case 17:
 			TChannel(input.readByte());
 			TChannel(input.readByte());
 		case 18: TMat2;
 		case 18: TMat2;

+ 31 - 21
hxsl/Splitter.hx

@@ -24,17 +24,19 @@ class Splitter {
 	public function new() {
 	public function new() {
 	}
 	}
 
 
-	public function split( s : ShaderData ) : { vertex : ShaderData, fragment : ShaderData } {
+	public function split( s : ShaderData ) : Array<ShaderData> {
 		var vfun = null, vvars = new Map();
 		var vfun = null, vvars = new Map();
 		var ffun = null, fvars = new Map();
 		var ffun = null, fvars = new Map();
+		var isCompute = false;
 		varNames = new Map();
 		varNames = new Map();
 		varMap = new Map();
 		varMap = new Map();
 		for( f in s.funs )
 		for( f in s.funs )
 			switch( f.kind ) {
 			switch( f.kind ) {
-			case Vertex:
+			case Vertex, Main:
 				vars = vvars;
 				vars = vvars;
 				vfun = f;
 				vfun = f;
 				checkExpr(f.expr);
 				checkExpr(f.expr);
+				if( f.kind == Main ) isCompute = true;
 			case Fragment:
 			case Fragment:
 				vars = fvars;
 				vars = fvars;
 				ffun = f;
 				ffun = f;
@@ -144,20 +146,22 @@ class Splitter {
 		for( v in fvars )
 		for( v in fvars )
 			checkVar(v, false, vvars, ffun.expr.p);
 			checkVar(v, false, vvars, ffun.expr.p);
 
 
-		ffun = {
-			ret : ffun.ret,
-			ref : ffun.ref,
-			kind : ffun.kind,
-			args : ffun.args,
-			expr : mapVars(ffun.expr),
-		};
-		switch( ffun.expr.e ) {
-		case TBlock(el):
-			for( e in finits )
-				el.unshift(e);
-		default:
-			finits.push(ffun.expr);
-			ffun.expr = { e : TBlock(finits), t : TVoid, p : ffun.expr.p };
+		if( ffun != null ) {
+			ffun = {
+				ret : ffun.ret,
+				ref : ffun.ref,
+				kind : ffun.kind,
+				args : ffun.args,
+				expr : mapVars(ffun.expr),
+			};
+			switch( ffun.expr.e ) {
+			case TBlock(el):
+				for( e in finits )
+					el.unshift(e);
+			default:
+				finits.push(ffun.expr);
+				ffun.expr = { e : TBlock(finits), t : TVoid, p : ffun.expr.p };
+			}
 		}
 		}
 
 
 		var vvars = [for( v in vvars ) if( !v.local ) v];
 		var vvars = [for( v in vvars ) if( !v.local ) v];
@@ -167,18 +171,24 @@ class Splitter {
 		vvars.sort(function(v1, v2) return getId(v1) - getId(v2));
 		vvars.sort(function(v1, v2) return getId(v1) - getId(v2));
 		fvars.sort(function(v1, v2) return getId(v1) - getId(v2));
 		fvars.sort(function(v1, v2) return getId(v1) - getId(v2));
 
 
-		return {
-			vertex : {
+		return isCompute ? [
+			{
+				name : "main",
+				vars : [for( v in vvars ) v.v],
+				funs : [vfun],
+			}
+		] : [
+			{
 				name : "vertex",
 				name : "vertex",
 				vars : [for( v in vvars ) v.v],
 				vars : [for( v in vvars ) v.v],
 				funs : [vfun],
 				funs : [vfun],
 			},
 			},
-			fragment : {
+			{
 				name : "fragment",
 				name : "fragment",
 				vars : [for( v in fvars ) v.v],
 				vars : [for( v in fvars ) v.v],
 				funs : [ffun],
 				funs : [ffun],
-			},
-		};
+			}
+		];
 	}
 	}
 
 
 	function addExpr( f : TFunction, e : TExpr ) {
 	function addExpr( f : TFunction, e : TExpr ) {