Răsfoiți Sursa

merged refactors required for compute shaders

Nicolas Cannasse 1 an în urmă
părinte
comite
e80ba03d19

+ 2 - 1
all.hxml

@@ -3,7 +3,7 @@
 -lib hxbit
 --macro include('h3d')
 --macro include('h2d',true,['h2d.domkit'])
---macro include('hxsl',true,['hxsl.Macros','hxsl.CacheFile','hxsl.CacheFileBuilder','hxsl.Checker'])
+--macro include('hxsl',true,['hxsl.Macros','hxsl.CacheFileBuilder','hxsl.Checker'])
 --macro include('hxd',true,['hxd.res.FileTree','hxd.Res','hxd.impl.BitsBuilder','hxd.fmt.pak.Build','hxd.snd.openal','hxd.res.Config'])
 --no-output
 -D apicheck
@@ -18,6 +18,7 @@
 -lib hlsdl
 -lib hlopenal
 -xml heaps_hl.xml
+hxsl.CacheFileBuilder
 
 --next
 

+ 4 - 0
h3d/Buffer.hx

@@ -13,6 +13,10 @@ enum BufferFlag {
 		Used for shader input buffer
 	**/
 	UniformBuffer;
+	/**
+		Can be written
+	**/
+	ReadWriteBuffer;
 }
 
 @:allow(h3d.impl.MemoryManager)

+ 7 - 7
h3d/impl/DX12Driver.hx

@@ -875,7 +875,7 @@ class DX12Driver extends h3d.impl.Driver {
 
 	static var VERTEX_FORMATS = [null,null,R32G32_FLOAT,R32G32B32_FLOAT,R32G32B32A32_FLOAT];
 
-	function getBinaryPayload( vertex : Bool, code : String ) {
+	function getBinaryPayload( code : String ) {
 		var bin = code.indexOf("//BIN=");
 		if( bin >= 0 ) {
 			var end = code.indexOf("#", bin);
@@ -895,7 +895,7 @@ class DX12Driver extends h3d.impl.Driver {
 			sh.code = out.run(sh.data);
 			sh.code = rootStr + sh.code;
 		}
-		var bytes = getBinaryPayload(sh.vertex, sh.code);
+		var bytes = getBinaryPayload(sh.code);
 		if ( bytes == null ) {
 			return compiler.compile(sh.code, profile, args);
 		}
@@ -1010,10 +1010,10 @@ class DX12Driver extends h3d.impl.Driver {
 
 
 		function allocParams( sh : hxsl.RuntimeShader.RuntimeShaderData ) {
-			var vis = sh.vertex ? VERTEX : PIXEL;
+			var vis = sh.kind == Vertex ? VERTEX : PIXEL;
 			var regs = new ShaderRegisters();
 			regs.globals = allocConsts(sh.globalsSize, vis, false);
-			regs.params = allocConsts(sh.paramsSize, vis, sh.vertex ? vertexParamsCBV : fragmentParamsCBV);
+			regs.params = allocConsts(sh.paramsSize, vis, sh.kind == Vertex ? vertexParamsCBV : fragmentParamsCBV);
 			if( sh.bufferCount > 0 ) {
 				regs.buffers = paramsCount;
 				for( i in 0...sh.bufferCount )
@@ -1612,10 +1612,10 @@ class DX12Driver extends h3d.impl.Driver {
 					t.lastFrame = frameCount;
 					var state = if ( t.isDepth() )
 						DEPTH_READ;
-					else if ( shader.vertex )
-						NON_PIXEL_SHADER_RESOURCE;
+					else if ( shader.kind == Fragment )
+						PIXEL_SHADER_RESOURCE
 					else
-						PIXEL_SHADER_RESOURCE;
+						NON_PIXEL_SHADER_RESOURCE;
 					transition(t.t, state);
 					Driver.createShaderResourceView(t.t.res, tdesc, srv.offset(i * frame.shaderResourceViews.stride));
 

+ 3 - 3
h3d/impl/DirectXDriver.hx

@@ -861,9 +861,9 @@ class DirectXDriver extends h3d.impl.Driver {
 			shader.data.funs = null;
 			#end
 		}
-		var bytes = getBinaryPayload(shader.vertex, shader.code);
+		var bytes = getBinaryPayload(shader.kind == Vertex, shader.code);
 		if( bytes == null ) {
-			bytes = try dx.Driver.compileShader(shader.code, "", "main", (shader.vertex?"vs_":"ps_") + shaderVersion, OptimizationLevel3) catch( err : String ) {
+			bytes = try dx.Driver.compileShader(shader.code, "", "main", (shader.kind==Vertex?"vs_":"ps_") + shaderVersion, OptimizationLevel3) catch( err : String ) {
 				err = ~/^\(([0-9]+),([0-9]+)-([0-9]+)\)/gm.map(err, function(r) {
 					var line = Std.parseInt(r.matched(1));
 					var char = Std.parseInt(r.matched(2));
@@ -877,7 +877,7 @@ class DirectXDriver extends h3d.impl.Driver {
 		}
 		if( compileOnly )
 			return { s : null, bytes : bytes };
-		var s = shader.vertex ? Driver.createVertexShader(bytes) : Driver.createPixelShader(bytes);
+		var s = shader.kind == Vertex ? Driver.createVertexShader(bytes) : Driver.createPixelShader(bytes);
 		if( s == null ) {
 			if( hasDeviceError ) return null;
 			throw "Failed to create shader\n" + shader.code;

+ 6 - 0
h3d/impl/Driver.hx

@@ -317,4 +317,10 @@ class Driver {
 		return 0.;
 	}
 
+	// --- COMPUTE
+
+	public function computeDispatch( x : Int = 1, y : Int = 1, z : Int = 1 ) {
+		throw "Not implemented";
+	}
+
 }

+ 10 - 10
h3d/impl/GlDriver.hx

@@ -41,15 +41,15 @@ private typedef ShaderCompiler = hxsl.GlslOut;
 
 private class CompiledShader {
 	public var s : GLShader;
-	public var vertex : Bool;
+	public var kind : hxsl.Ast.FunctionKind;
 	public var globals : Uniform;
 	public var params : Uniform;
 	public var textures : Array<{ u : Uniform, t : hxsl.Ast.Type, mode : Int }>;
 	public var buffers : Array<Int>;
 	public var shader : hxsl.RuntimeShader.RuntimeShaderData;
-	public function new(s,vertex,shader) {
+	public function new(s,kind,shader) {
 		this.s = s;
-		this.vertex = vertex;
+		this.kind = kind;
 		this.shader = shader;
 	}
 }
@@ -275,7 +275,7 @@ class GlDriver extends Driver {
 	}
 
 	function compileShader( glout : ShaderCompiler, shader : hxsl.RuntimeShader.RuntimeShaderData ) {
-		var type = shader.vertex ? GL.VERTEX_SHADER : GL.FRAGMENT_SHADER;
+		var type = shader.kind == Vertex ? GL.VERTEX_SHADER : GL.FRAGMENT_SHADER;
 		var s = gl.createShader(type);
 		if( shader.code == null ){
 			shader.code = glout.run(shader.data);
@@ -296,11 +296,11 @@ class GlDriver extends Driver {
 				codeLines[i] = (i+1) + "\t" + codeLines[i];
 			throw "An error occurred compiling the shaders: " + log + line+"\n\n"+codeLines.join("\n");
 		}
-		return new CompiledShader(s, shader.vertex, shader);
+		return new CompiledShader(s, shader.kind, shader);
 	}
 
 	function initShader( p : CompiledProgram, s : CompiledShader, shader : hxsl.RuntimeShader.RuntimeShaderData, rt : hxsl.RuntimeShader ) {
-		var prefix = s.vertex ? "vertex" : "fragment";
+		var prefix = s.kind == Vertex ? "vertex" : "fragment";
 		s.globals = gl.getUniformLocation(p.p, prefix + "Globals");
 		s.params = gl.getUniformLocation(p.p, prefix + "Params");
 		s.textures = [];
@@ -346,9 +346,9 @@ class GlDriver extends Driver {
 			t = t.next;
 		}
 		if( shader.bufferCount > 0 ) {
-			s.buffers = [for( i in 0...shader.bufferCount ) gl.getUniformBlockIndex(p.p,(shader.vertex?"vertex_":"")+"uniform_buffer"+i)];
+			s.buffers = [for( i in 0...shader.bufferCount ) gl.getUniformBlockIndex(p.p,(shader.kind==Vertex?"vertex_":"")+"uniform_buffer"+i)];
 			var start = 0;
-			if( !s.vertex ) start = rt.vertex.bufferCount;
+			if( s.kind == Fragment ) start = rt.vertex.bufferCount;
 			for( i in 0...shader.bufferCount )
 				gl.uniformBlockBinding(p.p,s.buffers[i],i + start);
 		}
@@ -505,7 +505,7 @@ class GlDriver extends Driver {
 		case Buffers:
 			if( s.buffers != null ) {
 				var start = 0;
-				if( !s.vertex && curShader.vertex.buffers != null )
+				if( s.kind == Fragment && curShader.vertex.buffers != null )
 					start = curShader.vertex.buffers.length;
 				for( i in 0...s.buffers.length )
 					gl.bindBufferBase(GL.UNIFORM_BUFFER, i + start, buf.buffers[i].vbuf);
@@ -544,7 +544,7 @@ class GlDriver extends Driver {
 
 				if( pt.u == null ) continue;
 
-				var idx = s.vertex ? i : curShader.vertex.textures.length + i;
+				var idx = s.kind == Fragment ? curShader.vertex.textures.length + i : i;
 				if( boundTextures[idx] != t.t ) {
 					boundTextures[idx] = t.t;
 

+ 2 - 2
h3d/pass/Default.hx

@@ -43,7 +43,7 @@ class Default extends Base {
 		var o = @:privateAccess new h3d.pass.PassObject();
 		o.pass = p;
 		setupShaders(new h3d.pass.PassList(o));
-		return manager.compileShaders(o.shaders, p.batchMode);
+		return manager.compileShaders(o.shaders, p.batchMode ? Batch : Default);
 	}
 
 	function processShaders( p : h3d.pass.PassObject, shaders : hxsl.ShaderList ) {
@@ -68,7 +68,7 @@ class Default extends Base {
 				}
 				shaders = ctx.lightSystem.computeLight(p.obj, shaders);
 			}
-			p.shader = manager.compileShaders(shaders, p.pass.batchMode);
+			p.shader = manager.compileShaders(shaders, p.pass.batchMode ? Batch : Default);
 			p.shaders = shaders;
 			var t = p.shader.fragment.textures;
 			if( t == null || t.type.match(TArray(_)) )

+ 2 - 2
h3d/pass/ShaderManager.hx

@@ -270,11 +270,11 @@ class ShaderManager {
 		fill(buf.fragment, s.fragment);
 	}
 
-	public function compileShaders( shaders : hxsl.ShaderList, batchMode : Bool = false ) {
+	public function compileShaders( shaders : hxsl.ShaderList, mode : hxsl.Linker.LinkMode = Default ) {
 		globals.resetChannels();
 		for( s in shaders ) s.updateConstants(globals);
 		currentOutput.next = shaders;
-		var s = shaderCache.link(currentOutput, batchMode);
+		var s = shaderCache.link(currentOutput, mode);
 		currentOutput.next = null;
 		return s;
 	}

+ 1 - 1
h3d/scene/MeshBatch.hx

@@ -121,7 +121,7 @@ class MeshBatch extends MultiMaterial {
 
 				var manager = cast(ctx,h3d.pass.Default).manager;
 				var shaders = p.getShadersRec();
-				var rt = manager.compileShaders(shaders, false);
+				var rt = manager.compileShaders(shaders, Default);
 				var shader = manager.shaderCache.makeBatchShader(rt, shaders, instancedParams);
 
 				var b = new BatchData();

+ 18 - 3
hxsl/Ast.hx

@@ -1,5 +1,10 @@
 package hxsl;
 
+enum BufferKind {
+	Uniform;
+	RW;
+}
+
 enum Type {
 	TVoid;
 	TInt;
@@ -17,7 +22,7 @@ enum Type {
 	TStruct( vl : Array<TVar> );
 	TFun( variants : Array<FunType> );
 	TArray( t : Type, size : SizeDecl );
-	TBuffer( t : Type, size : SizeDecl );
+	TBuffer( t : Type, size : SizeDecl, kind : BufferKind );
 	TChannel( size : Int );
 	TMat2;
 }
@@ -187,6 +192,7 @@ enum FunctionKind {
 	Fragment;
 	Init;
 	Helper;
+	Main;
 }
 
 enum TGlobal {
@@ -280,6 +286,8 @@ enum TGlobal {
 	IntBitsToFloat;
 	UintBitsToFloat;
 	RoundEven;
+	// compute
+	SetLayout;
 }
 
 enum Component {
@@ -418,7 +426,12 @@ class Tools {
 			prefix + "Vec" + size;
 		case TStruct(vl):"{" + [for( v in vl ) v.name + " : " + toString(v.type)].join(",") + "}";
 		case TArray(t, s): toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
-		case TBuffer(t, s): "buffer "+toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
+		case TBuffer(t, s, k):
+			var prefix = switch( k ) {
+			case Uniform: "buffer";
+			case RW: "rwbuffer";
+			};
+			prefix+" "+toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
 		case TBytes(n): "Bytes" + n;
 		default: t.getName().substr(1);
 		}
@@ -457,6 +470,8 @@ class Tools {
 			return hasSideEffect(e) || hasSideEffect(index);
 		case TConst(_), TVar(_), TGlobal(_):
 			return false;
+		case TCall({ e : TGlobal(SetLayout) },_):
+			return true;
 		case TCall(e, pl):
 			if( !e.e.match(TGlobal(_)) )
 				return true;
@@ -545,7 +560,7 @@ class Tools {
 		case TMat3x4: 12;
 		case TBytes(s): s;
 		case TBool, TString, TSampler2D, TSampler2DArray, TSamplerCube, TFun(_): 0;
-		case TArray(t, SConst(v)), TBuffer(t, SConst(v)): size(t) * v;
+		case TArray(t, SConst(v)), TBuffer(t, SConst(v),_): size(t) * v;
 		case TArray(_, SVar(_)), TBuffer(_): 0;
 		}
 	}

+ 57 - 36
hxsl/Cache.hx

@@ -1,6 +1,7 @@
 package hxsl;
 using hxsl.Ast;
 import hxsl.RuntimeShader;
+import hxsl.Linker.LinkMode;
 
 class BatchInstanceParams {
 
@@ -195,7 +196,7 @@ class Cache {
 	}
 
 	@:noDebug
-	public function link( shaders : hxsl.ShaderList, batchMode : Bool ) {
+	public function link( shaders : hxsl.ShaderList, mode : LinkMode ) {
 		var c = linkCache;
 		for( s in shaders ) {
 			var i = @:privateAccess s.instance;
@@ -207,11 +208,11 @@ class Cache {
 			c = cs;
 		}
 		if( c.linked == null )
-			c.linked = compileRuntimeShader(shaders, batchMode);
+			c.linked = compileRuntimeShader(shaders, mode);
 		return c.linked;
 	}
 
-	function compileRuntimeShader( shaders : hxsl.ShaderList, batchMode : Bool ) {
+	function compileRuntimeShader( shaders : hxsl.ShaderList, mode : LinkMode ) {
 		var shaderDatas = [];
 		var index = 0;
 		for( s in shaders ) {
@@ -262,14 +263,14 @@ class Cache {
 		//TRACE = shaderId == 0;
 		#end
 
-		var linker = new hxsl.Linker(batchMode);
+		var linker = new hxsl.Linker(mode);
 		var s = try linker.link([for( s in shaderDatas ) s.inst.shader]) catch( e : Error ) {
 			var shaders = [for( s in shaderDatas ) Printer.shaderToString(s.inst.shader)];
 			e.msg += "\n\nin\n\n" + shaders.join("\n-----\n");
 			throw e;
 		}
 
-		if( batchMode ) {
+		if( mode == Batch ) {
 			function checkRec( v : TVar ) {
 				if( v.qualifiers != null && v.qualifiers.indexOf(PerObject) >= 0 ) {
 					if( v.qualifiers.length == 1 ) v.qualifiers = null else {
@@ -302,7 +303,7 @@ class Cache {
 
 		var prev = s;
 		var splitter = new hxsl.Splitter();
-		var s = try splitter.split(s) catch( e : Error ) { e.msg += "\n\nin\n\n"+Printer.shaderToString(s); throw e; };
+		var sl = try splitter.split(s) catch( e : Error ) { e.msg += "\n\nin\n\n"+Printer.shaderToString(s); throw e; };
 
 		// params tracking
 		var paramVars = new Map();
@@ -319,41 +320,42 @@ class Cache {
 
 
 		#if debug
-		Printer.check(s.vertex,[prev]);
-		Printer.check(s.fragment,[prev]);
+		for( s in sl )
+			Printer.check(s,[prev]);
 		#end
 
 		#if shader_debug_dump
 		if( dbg != null ) {
 			dbg.writeString("----- SPLIT ----\n\n");
-			dbg.writeString(Printer.shaderToString(s.vertex, DEBUG_IDS) + "\n\n");
-			dbg.writeString(Printer.shaderToString(s.fragment, DEBUG_IDS) + "\n\n");
+			for( s in sl )
+				dbg.writeString(Printer.shaderToString(s, DEBUG_IDS) + "\n\n");
 		}
 		#end
 
-		var prev = s;
-		var s = new hxsl.Dce().dce(s.vertex, s.fragment);
+		var prev = sl;
+		var sl = new hxsl.Dce().dce(sl);
 
 		#if debug
-		Printer.check(s.vertex,[prev.vertex]);
-		Printer.check(s.fragment,[prev.fragment]);
+		for( i => s in sl )
+			Printer.check(s,[prev[i]]);
 		#end
 
 		#if shader_debug_dump
 		if( dbg != null ) {
 			dbg.writeString("----- DCE ----\n\n");
-			dbg.writeString(Printer.shaderToString(s.vertex, DEBUG_IDS) + "\n\n");
-			dbg.writeString(Printer.shaderToString(s.fragment, DEBUG_IDS) + "\n\n");
+			for( s in sl )
+				dbg.writeString(Printer.shaderToString(s, DEBUG_IDS) + "\n\n");
 		}
 		#end
 
-		var r = buildRuntimeShader(s.vertex, s.fragment, paramVars);
+		var r = buildRuntimeShader(sl, paramVars);
+		r.mode = mode;
 
 		#if shader_debug_dump
 		if( dbg != null ) {
 			dbg.writeString("----- FLATTEN ----\n\n");
-			dbg.writeString(Printer.shaderToString(r.vertex.data, DEBUG_IDS) + "\n\n");
-			dbg.writeString(Printer.shaderToString(r.fragment.data,DEBUG_IDS)+"\n\n");
+			for( s in r.getShaders() )
+				dbg.writeString(Printer.shaderToString(s.data, DEBUG_IDS) + "\n\n");
 		}
 		#end
 
@@ -366,9 +368,7 @@ class Cache {
 
 		var signParts = [for( i in r.spec.instances ) i.shader.data.name+"_" + i.bits + "_" + i.index];
 		r.spec.signature = haxe.crypto.Md5.encode(signParts.join(":"));
-		r.signature = haxe.crypto.Md5.encode(Printer.shaderToString(r.vertex.data) + Printer.shaderToString(r.fragment.data));
-		r.batchMode = batchMode;
-
+		r.signature = haxe.crypto.Md5.encode([for( s in r.getShaders() ) Printer.shaderToString(s.data)].join(""));
 		var r2 = byID.get(r.signature);
 		if( r2 != null )
 			r.id = r2.id; // same id but different variable mapping
@@ -385,19 +385,33 @@ class Cache {
 		return r;
 	}
 
-	function buildRuntimeShader( vertex : ShaderData, fragment : ShaderData, paramVars ) {
+	function buildRuntimeShader( shaders : Array<ShaderData>, paramVars ) {
 		var r = new RuntimeShader();
-		r.vertex = flattenShader(vertex, Vertex, paramVars);
-		r.vertex.vertex = true;
-		r.fragment = flattenShader(fragment, Fragment, paramVars);
 		r.globals = new Map();
-		initGlobals(r, r.vertex);
-		initGlobals(r, r.fragment);
-
-		#if debug
-		Printer.check(r.vertex.data,[vertex]);
-		Printer.check(r.fragment.data,[fragment]);
-		#end
+		for( s in shaders ) {
+			var kind = switch( s.name ) {
+			case "vertex": Vertex;
+			case "fragment": Fragment;
+			case "main": Main;
+			default: throw "assert";
+			}
+			var fl = flattenShader(s, kind, paramVars);
+			fl.kind = kind;
+			switch( kind ) {
+			case Vertex:
+				r.vertex = fl;
+			case Fragment:
+				r.fragment = fl;
+			case Main:
+				r.compute = fl;
+			default:
+				throw "assert";
+			}
+			initGlobals(r, fl);
+			#if debug
+			Printer.check(fl.data,[s]);
+			#end
+		}
 		return r;
 	}
 
@@ -467,8 +481,15 @@ class Cache {
 					c.params = out[0];
 					c.paramsSize = size;
 				case TArray(TBuffer(_), _):
-					c.buffers = out[0];
-					c.bufferCount = out.length;
+					if( c.buffers == null ) {
+						c.buffers = out[0];
+						c.bufferCount = out.length;
+					} else {
+						var p = c.buffers;
+						while( p.next != null ) p = p.next;
+						p.next = out[0];
+						c.bufferCount += out.length;
+					}
 				default: throw "assert";
 				}
 			case Global:
@@ -574,7 +595,7 @@ class Cache {
 		inputOffset.qualifiers = [PerInstance(1)];
 
 		var vcount = declVar("Batch_Count",TInt,Param);
-		var vbuffer = declVar("Batch_Buffer",TBuffer(TVec(4,VFloat),SVar(vcount)),Param);
+		var vbuffer = declVar("Batch_Buffer",TBuffer(TVec(4,VFloat),SVar(vcount),Uniform),Param);
 		var voffset = declVar("Batch_Offset", TInt, Local);
 		var ebuffer = { e : TVar(vbuffer), p : pos, t : vbuffer.type };
 		var eoffset = { e : TVar(voffset), p : pos, t : voffset.type };

+ 17 - 12
hxsl/CacheFile.hx

@@ -1,6 +1,8 @@
 package hxsl;
 import hxsl.Ast.Tools;
 
+#if (sys || nodejs)
+
 private class NullShader extends hxsl.Shader {
 	static var SRC = {
 		var output : {
@@ -103,7 +105,7 @@ class CacheFile extends Cache {
 		} else if( !allowCompile )
 			throw "Missing " + file;
 		if( linkCache.linked == null ) {
-			var rt = link(makeDefaultShader(), false);
+			var rt = link(makeDefaultShader(), Default);
 			linkCache.linked = rt;
 			if( rt.vertex.code == null || rt.fragment.code == null ) {
 				wait.push(rt);
@@ -269,7 +271,7 @@ class CacheFile extends Cache {
 
 			for( r in runtimes ) {
 				var shaderList = null;
-				var batchMode = false;
+				var mode : hxsl.Linker.LinkMode = Default;
 				r.inst.reverse();
 				for( i in r.inst ) {
 					var s = Type.createEmptyInstance(hxsl.Shader);
@@ -282,7 +284,7 @@ class CacheFile extends Cache {
 							}
 							var sh = makeBatchShader(rt.rt, rt.shaders.next, i.batch.params);
 							i.shader = { version : null, shader : sh.shader };
-							batchMode = true;
+							mode = Batch;
 						}
 						s.constBits = i.bits;
 						s.shader = i.shader.shader;
@@ -293,7 +295,7 @@ class CacheFile extends Cache {
 				}
 				if( r == null ) continue;
 				//log("Recompile "+[for( s in shaderList ) shaderName(s)]);
-				var rt = link(shaderList, batchMode); // will compile + update linkMap
+				var rt = link(shaderList, mode); // will compile + update linkMap
 				if( rt.spec.signature != r.specSign ) {
 					var signParts = [for( i in rt.spec.instances ) i.shader.data.name+"_" + i.bits + "_" + i.index];
 					throw "assert";
@@ -320,6 +322,7 @@ class CacheFile extends Cache {
 				if( spec == null )
 					continue;
 
+				r.mode = Default;
 				r.signature = spec.signature;
 				var shaderList = null;
 				spec.inst.reverse();
@@ -334,7 +337,7 @@ class CacheFile extends Cache {
 							}
 							var sh = makeBatchShader(rt.rt, rt.shaders.next, i.batch.params);
 							i.shader = { version : null, shader : sh.shader };
-							r.batchMode = true;
+							r.mode = Batch;
 						}
 						// pseudo instance
 						var scache = i.shader.shader.instanceCache;
@@ -586,14 +589,14 @@ class CacheFile extends Cache {
 
 	function cleanRuntimeData(r:hxsl.RuntimeShader.RuntimeShaderData) {
 		var rc = new hxsl.RuntimeShader.RuntimeShaderData();
-		rc.vertex = r.vertex;
+		rc.kind = r.kind;
 		rc.data = {
 			name : null,
 			vars : [],
 			funs : null,
 		};
 		for( v in r.data.vars )
-			if( v.kind == (r.vertex ? Input : Output) ) {
+			if( v.kind == (r.kind == Vertex ? Input : Output) ) {
 				rc.data.vars.push({
 					id : v.id,
 					name : v.name,
@@ -646,8 +649,8 @@ class CacheFile extends Cache {
 	}
 
 	function sortBySpec( r1 : RuntimeShader, r2 : RuntimeShader ) {
-		if( r1.batchMode != r2.batchMode )
-			return r1.batchMode ? 1 : -1;
+		if( r1.mode != r2.mode )
+			return r1.mode.getIndex() - r2.mode.getIndex();
 		var minLen = hxd.Math.imin(r1.spec.instances.length, r2.spec.instances.length);
 		for( i in 0...minLen ) {
 			var i1 = r1.spec.instances[i];
@@ -706,7 +709,7 @@ class CacheFile extends Cache {
 
 	public dynamic function onMissingShader(shaders:hxsl.ShaderList) {
 		log("Missing shader " + [for( s in shaders ) shaderName(s)]);
-		return link(null, false); // default fallback
+		return link(null, Default); // default fallback
 	}
 
 	public dynamic function onNewShader(r:RuntimeShader) {
@@ -738,7 +741,7 @@ class CacheFile extends Cache {
 		for( i in s.spec.instances ) {
 			var inst = shaders.get(i.shader.data.name);
 			if( inst == null ) {
-				if( s.batchMode && StringTools.startsWith(i.shader.data.name,"batchShader_") )
+				if( s.mode == Batch && StringTools.startsWith(i.shader.data.name,"batchShader_") )
 					continue;
 				var version = getShaderVersion(i.shader);
 				inst = { shader : i.shader, version : version };
@@ -779,4 +782,6 @@ class CacheFile extends Cache {
 		}
 	}
 
-}
+}
+
+#end

+ 4 - 5
hxsl/CacheFileBuilder.hx

@@ -137,7 +137,7 @@ class CacheFileBuilder {
 			throw "DirectX compilation requires -lib hldx without -D dx12";
 			#end
 		case OpenGL:
-			if( rd.vertex ) {
+			if( rd.kind == Vertex ) {
 				// both vertex and fragment needs to be compiled with the same GlslOut !
 				glout = new GlslOut();
 				glout.version = 150;
@@ -172,7 +172,7 @@ class CacheFileBuilder {
 			var tmpSrc = tmpFile + ".hlsl";
 			var tmpOut = tmpFile + ".sb";
 			sys.io.File.saveContent(tmpSrc, code);
-			var args = ["-T", (rd.vertex ? "vs_" : "ps_") + dxShaderVersion,"-O3","-Fo", tmpOut, tmpSrc];
+			var args = ["-T", (rd.kind == Vertex ? "vs_" : "ps_") + dxShaderVersion,"-O3","-Fo", tmpOut, tmpSrc];
 			var p = new sys.io.Process(Sys.getEnv("XboxOneXDKLatest")+ "xdk\\FXC\\amd64\\fxc.exe", args);
 			var error = p.stderr.readAll().toString();
 			var ecode = p.exitCode();
@@ -216,13 +216,12 @@ class CacheFileBuilder {
 			throw "-lib hldx and -D dx12 are required to generate binaries for XBoxSeries";
 			#end
 		case NX:
-			if( rd.vertex )
+			if( rd.kind == Vertex )
 				glout = new hxsl.NXGlslOut();
 			return { code : glout.run(rd.data), bytes : null };
 		case NXBinaries:
-			if( rd.vertex )
+			if( rd.kind == Vertex ) {
 				glout = new hxsl.NXGlslOut();
-			if ( rd.vertex ) {
 				vertexOut = glout.run(rd.data);
 				return { code : vertexOut, bytes : null }; // binary is in fragment.code
 			}

+ 12 - 3
hxsl/Checker.hx

@@ -188,6 +188,8 @@ class Checker {
 				[for( i => t in genType ) { args : [ { name: "x", type: t } ], ret: genIType[i] }];
 			case IntBitsToFloat, UintBitsToFloat:
 				[for( i => t in genType ) { args : [ { name: "x", type: genIType[i] } ], ret: t }];
+			case SetLayout:
+				[{ args : [{ name : "x", type : TInt },{ name : "y", type : TInt },{ name : "z", type : TInt }], ret : TVoid }];
 			case VertexID, InstanceID, FragCoord, FrontFacing:
 				null;
 			}
@@ -246,6 +248,7 @@ class Checker {
 			var kind = switch( f.name ) {
 			case "vertex":  Vertex;
 			case "fragment": Fragment;
+			case "main": Main;
 			default: StringTools.startsWith(f.name,"__init__") ? Init : Helper;
 			}
 			if( args.length != 0 && kind != Helper )
@@ -337,6 +340,8 @@ class Checker {
 			switch( v.kind ) {
 			case Local, Var, Output:
 				return;
+			case Param if( v.type.match(TBuffer(_,_,RW)) ):
+				return;
 			default:
 			}
 		case TSwiz(e, _):
@@ -631,7 +636,7 @@ class Checker {
 			default: unify(e2.t, TInt, e2.p);
 			}
 			switch( e1.t ) {
-			case TArray(t, size), TBuffer(t,size):
+			case TArray(t, size), TBuffer(t,size,_):
 				switch( [size, e2.e] ) {
 				case [SConst(v), TConst(CInt(i))] if( i >= v ):
 					error("Indexing outside array bounds", e.pos);
@@ -849,7 +854,7 @@ class Checker {
 				vl[i] = makeVar( { type : v.type, qualifiers : v.qualifiers, name : v.name, kind : v.kind, expr : null }, pos, parent);
 			}
 			return parent.type;
-		case TArray(t, size), TBuffer(t,size):
+		case TArray(t, size), TBuffer(t,size,_):
 			switch( t ) {
 			case TArray(_):
 				error("Multidimentional arrays are not allowed", pos);
@@ -894,7 +899,11 @@ class Checker {
 				SVar(v2);
 			}
 			t = makeVarType(t,parent,pos);
-			return vt.match(TArray(_)) ? TArray(t, s) : TBuffer(t,s);
+			return switch( vt ) {
+			case TArray(_): TArray(t, s);
+			case TBuffer(_,_,kind): TBuffer(t,s,kind);
+			default: throw "assert";
+			}
 		default:
 			return vt;
 		}

+ 24 - 30
hxsl/Dce.hx

@@ -34,30 +34,27 @@ class Dce {
 		#end
 	}
 
-	public function dce( vertex : ShaderData, fragment : ShaderData ) {
+	public function dce( shaders : Array<ShaderData> ) {
 		// collect vars dependencies
 		used = new Map();
 		channelVars = [];
 
 		var inputs = [];
-		for( v in vertex.vars ) {
-			var i = get(v);
-			if( v.kind == Input )
-				inputs.push(i);
-			if( v.kind == Output )
-				i.keep = true;
-		}
-		for( v in fragment.vars ) {
-			var i = get(v);
-			if( v.kind == Output )
-				i.keep = true;
+		for( s in shaders ) {
+			for( v in s.vars ) {
+				var i = get(v);
+				if( v.kind == Input )
+					inputs.push(i);
+				if( v.kind == Output || v.type.match(TBuffer(_,_,RW)) )
+					i.keep = true;
+			}
 		}
 
 		// collect dependencies
-		for( f in vertex.funs )
-			check(f.expr, [], []);
-		for( f in fragment.funs )
-			check(f.expr, [], []);
+		for( s in shaders ) {
+			for( f in s.funs )
+				check(f.expr, [], []);
+		}
 
 		var outExprs = [];
 		while( true ) {
@@ -74,10 +71,10 @@ class Dce {
 				markRec(v);
 
 			outExprs = [];
-			for( f in vertex.funs )
-				outExprs.push(mapExpr(f.expr, false));
-			for( f in fragment.funs )
-				outExprs.push(mapExpr(f.expr, false));
+			for( s in shaders ) {
+				for( f in s.funs )
+					outExprs.push(mapExpr(f.expr, false));
+			}
 
 			// post add conditional branches
 			markAsKeep = false;
@@ -86,22 +83,19 @@ class Dce {
 			if( !markAsKeep ) break;
 		}
 
-		for( f in vertex.funs )
-			f.expr = outExprs.shift();
-		for( f in fragment.funs )
-			f.expr = outExprs.shift();
+		for( s in shaders ) {
+			for( f in s.funs )
+				f.expr = outExprs.shift();
+		}
 
 		for( v in used ) {
 			if( v.used ) continue;
 			if( v.v.kind == VarKind.Input) continue;
-			vertex.vars.remove(v.v);
-			fragment.vars.remove(v.v);
+			for( s in shaders )
+				s.vars.remove(v.v);
 		}
 
-		return {
-			fragment : fragment,
-			vertex : vertex,
-		}
+		return shaders.copy();
 	}
 
 	function get( v : TVar ) {

+ 12 - 4
hxsl/Eval.hx

@@ -52,18 +52,26 @@ class Eval {
 		switch( v2.type ) {
 		case TStruct(vl):
 			v2.type = TStruct([for( v in vl ) mapVar(v)]);
-		case TArray(t, SVar(vs)), TBuffer(t, SVar(vs)):
+		case TArray(t, SVar(vs)), TBuffer(t, SVar(vs), _):
 			var c = constants.get(vs.id);
 			if( c != null )
 				switch( c ) {
 				case TConst(CInt(v)):
-					v2.type = v2.type.match(TArray(_)) ? TArray(t, SConst(v)) : TBuffer(t, SConst(v));
+					v2.type = switch( v2.type ) {
+					case TArray(_): TArray(t, SConst(v));
+					case TBuffer(_,_,kind): TBuffer(t, SConst(v), kind);
+					default: throw "assert";
+					};
 				default:
 					Error.t("Integer value expected for array size constant " + vs.name, null);
 				}
 			else {
 				var vs2 = mapVar(vs);
-				v2.type = v2.type.match(TArray(_)) ? TArray(t, SVar(vs2)) : TBuffer(t, SVar(vs2));
+				v2.type = switch( v2.type ) {
+				case TArray(_): TArray(t, SVar(vs2));
+				case TBuffer(_,_,kind): TBuffer(t, SVar(vs2), kind);
+				default: throw "assert";
+				}
 			}
 		default:
 		}
@@ -81,7 +89,7 @@ class Eval {
 			return false;
 		case TArray(t, _):
 			return checkSamplerRec(t);
-		case TBuffer(_, size):
+		case TBuffer(_):
 			return true;
 		default:
 		}

+ 9 - 5
hxsl/Flatten.hx

@@ -42,6 +42,7 @@ class Flatten {
 		var prefix = switch( kind ) {
 		case Vertex: "vertex";
 		case Fragment: "fragment";
+		case Main: "compute";
 		default: throw "assert";
 		}
 		pack(prefix + "Globals", Global, globals, VFloat);
@@ -50,7 +51,8 @@ class Flatten {
 		var textures = packTextures(prefix + "Textures", allVars, TSampler2D)
 			.concat(packTextures(prefix+"TexturesCube", allVars, TSamplerCube))
 			.concat(packTextures(prefix+"TexturesArray", allVars, TSampler2DArray));
-		packBuffers(allVars);
+		packBuffers("buffers", allVars, Uniform);
+		packBuffers("rwbuffers", allVars, RW);
 		var funs = [for( f in s.funs ) mapFun(f, mapExpr)];
 		return {
 			name : s.name,
@@ -259,22 +261,24 @@ class Flatten {
 		return alloc;
 	}
 
-	function packBuffers( vars : Array<TVar> ) {
+	function packBuffers( name : String, vars : Array<TVar>, kind ) {
 		var alloc = new Array<Alloc>();
 		var g : TVar = {
 			id : Tools.allocVarId(),
-			name : "buffers",
+			name : name,
 			type : TVoid,
 			kind : Param,
 		};
 		for( v in vars )
-			if( v.type.match(TBuffer(_)) ) {
+			switch( v.type ) {
+			case TBuffer(_,_,k) if( kind == k ):
 				var a = new Alloc(g, null, alloc.length, 1);
 				a.v = v;
 				alloc.push(a);
 				outVars.push(v);
+			default:
 			}
-		g.type = TArray(TBuffer(TVoid,SConst(0)),SConst(alloc.length));
+		g.type = TArray(TBuffer(TVoid,SConst(0),kind),SConst(alloc.length));
 		allocData.set(g, alloc);
 	}
 

+ 3 - 2
hxsl/GlslOut.hx

@@ -181,12 +181,13 @@ class GlslOut {
 			case SConst(n): add(n);
 			}
 			add("]");
-		case TBuffer(t, size):
+		case TBuffer(t, size, kind):
+			if( kind != Uniform ) throw "TODO";
 			add((isVertex ? "vertex_" : "") + "uniform_buffer"+(uniformBuffer++));
 			add(" { ");
 			v.type = TArray(t,size);
 			addVar(v);
-			v.type = TBuffer(t,size);
+			v.type = TBuffer(t,size,kind);
 			add("; }");
 		default:
 			addType(v.type);

+ 53 - 21
hxsl/HlslOut.hx

@@ -101,13 +101,19 @@ class HlslOut {
 	var exprValues : Array<String>;
 	var locals : Map<Int,TVar>;
 	var decls : Array<String>;
-	var isVertex : Bool;
+	var kind : FunctionKind;
 	var allNames : Map<String, Int>;
 	var samplers : Map<Int, Array<Int>>;
+	var computeLayout = [1,1,1];
 	public var varNames : Map<Int,String>;
 	public var baseRegister : Int = 0;
 
 	var varAccess : Map<Int,String>;
+	var isVertex(get,never) : Bool;
+	var isCompute(get,never) : Bool;
+
+	inline function get_isCompute() return kind == Main;
+	inline function get_isVertex() return kind == Vertex;
 
 	public function new() {
 		varNames = new Map();
@@ -175,7 +181,7 @@ class HlslOut {
 			add(" }");
 		case TFun(_):
 			add("function");
-		case TArray(t, size), TBuffer(t,size):
+		case TArray(t, size), TBuffer(t,size,_):
 			addType(t);
 			add("[");
 			switch( size ) {
@@ -201,7 +207,7 @@ class HlslOut {
 
 	function addVar( v : TVar ) {
 		switch( v.type ) {
-		case TArray(t, size), TBuffer(t,size):
+		case TArray(t, size), TBuffer(t,size,_):
 			addVar({
 				id : v.id,
 				name : v.name,
@@ -298,6 +304,8 @@ class HlslOut {
 			var acc = varAccess.get(v.id);
 			if( acc != null ) add(acc);
 			ident(v);
+		case TCall({ e : TGlobal(SetLayout) },_):
+			// ignore
 		case TCall({ e : TGlobal(g = (Texture | TextureLod)) }, args):
 			addValue(args[0], tabs);
 			switch( g ) {
@@ -687,6 +695,8 @@ class HlslOut {
 	function collectGlobals( m : Map<TGlobal,Bool>, e : TExpr ) {
 		switch( e.e )  {
 		case TGlobal(g): m.set(g,true);
+		case TCall({ e : TGlobal(SetLayout) }, [{ e : TConst(CInt(x)) }, { e : TConst(CInt(y)) }, { e : TConst(CInt(z)) }]):
+			computeLayout = [x,y,z];
 		default: e.iter(collectGlobals.bind(m));
 		}
 	}
@@ -709,7 +719,7 @@ class HlslOut {
 			collectGlobals(foundGlobals, f.expr);
 
 		add("struct s_input {\n");
-		if( !isVertex )
+		if( kind == Fragment )
 			add("\tfloat4 __pos__ : "+SV_POSITION+";\n");
 		for( v in s.vars )
 			if( v.kind == Input || (v.kind == Var && !isVertex) )
@@ -722,14 +732,16 @@ class HlslOut {
 			add("\tbool isFrontFace : "+SV_IsFrontFace+";\n");
 		add("};\n\n");
 
-		add("struct s_output {\n");
-		for( v in s.vars )
-			if( v.kind == Output )
-				declVar("_out.", v);
-		for( v in s.vars )
-			if( v.kind == Var && isVertex )
-				declVar("_out.", v);
-		add("};\n\n");
+		if( !isCompute ) {
+			add("struct s_output {\n");
+			for( v in s.vars )
+				if( v.kind == Output )
+					declVar("_out.", v);
+			for( v in s.vars )
+				if( v.kind == Var && isVertex )
+					declVar("_out.", v);
+			add("};\n\n");
+		}
 	}
 
 	function initGlobals( s : ShaderData ) {
@@ -770,10 +782,25 @@ class HlslOut {
 
 		var bufCount = 0;
 		for( b in buffers ) {
-			add('cbuffer _buffer$bufCount : register(b${bufCount+baseRegister+2}) { ');
-			addVar(b);
-			add("; };\n");
-			bufCount++;
+			switch( b.type ) {
+			case TBuffer(t, size, kind):
+				switch( kind ) {
+				case Uniform:
+					add('cbuffer _buffer$bufCount : register(b${bufCount+baseRegister+2}) { ');
+					addVar(b);
+					add("; };\n");
+					bufCount++;
+				case RW:
+					add('RWStructuredBuffer<');
+					addType(t);
+					add('> ');
+					ident(b);
+					add(' : register(u${bufCount+baseRegister+2});');
+					bufCount++;
+				}
+			default:
+				throw "assert";
+			}
 		}
 		if( bufCount > 0 ) add("\n");
 
@@ -795,7 +822,8 @@ class HlslOut {
 
 	function initStatics( s : ShaderData ) {
 		add(STATIC + "s_input _in;\n");
-		add(STATIC + "s_output _out;\n");
+		if( !isCompute )
+			add(STATIC + "s_output _out;\n");
 
 		add("\n");
 		for( v in s.vars )
@@ -808,7 +836,11 @@ class HlslOut {
 	}
 
 	function emitMain( expr : TExpr ) {
-		add("s_output main( s_input __in ) {\n");
+		if( isCompute )
+			add('[numthreads(${computeLayout[0]},${computeLayout[1]},${computeLayout[2]})] void ');
+		else
+			add('s_output ');
+		add("main( s_input __in ) {\n");
 		add("\t_in = __in;\n");
 		switch( expr.e ) {
 		case TBlock(el):
@@ -820,7 +852,8 @@ class HlslOut {
 		default:
 			addExpr(expr, "");
 		}
-		add("\treturn _out;\n");
+		if( !isCompute )
+			add("\treturn _out;\n");
 		add("}");
 	}
 
@@ -848,8 +881,7 @@ class HlslOut {
 
 		if( s.funs.length != 1 ) throw "assert";
 		var f = s.funs[0];
-		isVertex = f.kind == Vertex;
-
+		kind = f.kind;
 		varAccess = new Map();
 		samplers = new Map();
 		initVars(s);

+ 19 - 7
hxsl/Linker.hx

@@ -1,6 +1,12 @@
 package hxsl;
 using hxsl.Ast;
 
+enum LinkMode {
+	Default;
+	Batch;
+	Compute;
+}
+
 private class AllocatedVar {
 	public var id : Int;
 	public var v : TVar;
@@ -28,6 +34,7 @@ private class ShaderInfos {
 	public var vertex : Null<Bool>;
 	public var onStack : Bool;
 	public var hasDiscard : Bool;
+	public var isCompute : Bool;
 	public var marked : Null<Bool>;
 	public function new(n, v) {
 		this.name = n;
@@ -49,12 +56,12 @@ class Linker {
 	var varIdMap : Map<Int,Int>;
 	var locals : Map<Int,Bool>;
 	var curInstance : Int;
-	var batchMode : Bool;
+	var mode : LinkMode;
 	var isBatchShader : Bool;
 	var debugDepth = 0;
 
-	public function new(batchMode=false) {
-		this.batchMode = batchMode;
+	public function new(mode) {
+		this.mode = mode;
 	}
 
 	inline function debug( msg : String, ?pos : haxe.PosInfos ) {
@@ -348,7 +355,7 @@ class Linker {
 		curInstance = 0;
 		var outVars = [];
 		for( s in shadersData ) {
-			isBatchShader = batchMode && StringTools.startsWith(s.name,"batchShader_");
+			isBatchShader = mode == Batch && StringTools.startsWith(s.name,"batchShader_");
 			for( v in s.vars ) {
 				var v2 = allocVar(v, null, s.name);
 				if( isBatchShader && v2.v.kind == Param && !StringTools.startsWith(v2.path,"Batch_") )
@@ -375,8 +382,13 @@ class Linker {
 				if( v.kind == null ) throw "assert";
 				switch( v.kind ) {
 				case Vertex, Fragment:
+					if( mode == Compute )
+						throw "Unexpected "+v.kind.getName().toLowerCase()+"() function in compute shader";
 					addShader(s.name + "." + (v.kind == Vertex ? "vertex" : "fragment"), v.kind == Vertex, f.expr, priority);
-
+				case Main:
+					if( mode != Compute )
+						throw "Unexpected main() outside compute shader";
+					addShader(s.name, true, f.expr, priority).isCompute = true;
 				case Init:
 					var prio : Array<Int>;
 					var status : Null<Bool> = switch( f.ref.name ) {
@@ -408,7 +420,7 @@ class Linker {
 
 		// force shaders containing discard to be included
 		for( s in shaders )
-			if( s.hasDiscard ) {
+			if( s.hasDiscard || s.isCompute ) {
 				initDependencies(s);
 				entry.deps.set(s, true);
 			}
@@ -505,7 +517,7 @@ class Linker {
 				expr : expr,
 			};
 		}
-		var funs = [
+		var funs = mode == Compute ? [build(Main,"main",v)] : [
 			build(Vertex, "vertex", v),
 			build(Fragment, "fragment", f),
 		];

+ 7 - 2
hxsl/MacroParser.hx

@@ -115,7 +115,7 @@ class MacroParser {
 			case "Channel3": return TChannel(3);
 			case "Channel4": return TChannel(4);
 			}
-		case TPath( { pack : [], name : name = ("Array"|"Buffer"), sub : null, params : [t, size] } ):
+		case TPath( { pack : [], name : name = ("Array"|"Buffer"|"RWBuffer"), sub : null, params : [t, size] } ):
 			var t = switch( t ) {
 			case TPType(t): parseType(t, pos);
 			default: null;
@@ -129,7 +129,12 @@ class MacroParser {
 			default: null;
 			}
 			if( t != null && size != null )
-				return name == "Array" ? TArray(t, size) : TBuffer(t,size);
+				return switch( name ) {
+				case "Array": TArray(t, size);
+				case "Buffer": TBuffer(t,size,Uniform);
+				case "RWBuffer": TBuffer(t,size,RW);
+				default: throw "assert";
+				}
 		case TAnonymous(fl):
 			return TStruct([for( f in fl ) {
 				switch( f.kind ) {

+ 9 - 2
hxsl/RuntimeShader.hx

@@ -44,7 +44,7 @@ class AllocGlobal {
 }
 
 class RuntimeShaderData {
-	public var vertex : Bool;
+	public var kind : hxsl.Ast.FunctionKind;
 	public var data : Ast.ShaderData;
 	public var code : String;
 	public var params : AllocParam;
@@ -75,14 +75,18 @@ class RuntimeShader {
 	public var id : Int;
 	public var vertex : RuntimeShaderData;
 	public var fragment : RuntimeShaderData;
+	public var compute(get,set) : RuntimeShaderData;
 	public var globals : Map<Int,Bool>;
 
+	inline function get_compute() return vertex;
+	inline function set_compute(v) return vertex = v;
+
 	/**
 		Signature of the resulting HxSL code.
 		Several shaders with the different specification might still get the same resulting signature.
 	**/
 	public var signature : String;
-	public var batchMode : Bool;
+	public var mode : hxsl.Linker.LinkMode;
 	public var spec : { instances : Array<ShaderInstanceDesc>, signature : String };
 
 	public function new() {
@@ -93,5 +97,8 @@ class RuntimeShader {
 		return globals.exists(gid);
 	}
 
+	public function getShaders() {
+		return mode == Compute ? [compute] : [vertex, fragment];
+	}
 
 }

+ 15 - 2
hxsl/Serializer.hx

@@ -86,7 +86,14 @@ class Serializer {
 				writeArr(vl,writeVar);
 		case TFun(variants):
 			// not serialized
-		case TArray(t, size), TBuffer(t, size):
+		case TArray(t, size), TBuffer(t, size, Uniform):
+			writeType(t);
+			switch (size) {
+			case SConst(v): out.addByte(0); writeVarInt(v);
+			case SVar(v): writeVar(v);
+			}
+		case TBuffer(t, size, kind):
+			out.addByte(kind.getIndex() + 0x80);
 			writeType(t);
 			switch (size) {
 			case SConst(v): out.addByte(0); writeVarInt(v);
@@ -136,9 +143,15 @@ class Serializer {
 			var v = readVar();
 			TArray(t, v == null ? SConst(readVarInt()) : SVar(v));
 		case 16:
+			var tag = input.readByte();
+			var kind = Uniform;
+			if( tag & 0x80 == 0 )
+				input.position--;
+			else
+				kind = BufferKind.createByIndex(tag & 0x7F);
 			var t = readType();
 			var v = readVar();
-			TBuffer(t, v == null ? SConst(readVarInt()) : SVar(v));
+			TBuffer(t, v == null ? SConst(readVarInt()) : SVar(v), kind);
 		case 17:
 			TChannel(input.readByte());
 		case 18: TMat2;

+ 31 - 21
hxsl/Splitter.hx

@@ -24,17 +24,19 @@ class Splitter {
 	public function new() {
 	}
 
-	public function split( s : ShaderData ) : { vertex : ShaderData, fragment : ShaderData } {
+	public function split( s : ShaderData ) : Array<ShaderData> {
 		var vfun = null, vvars = new Map();
 		var ffun = null, fvars = new Map();
+		var isCompute = false;
 		varNames = new Map();
 		varMap = new Map();
 		for( f in s.funs )
 			switch( f.kind ) {
-			case Vertex:
+			case Vertex, Main:
 				vars = vvars;
 				vfun = f;
 				checkExpr(f.expr);
+				if( f.kind == Main ) isCompute = true;
 			case Fragment:
 				vars = fvars;
 				ffun = f;
@@ -144,20 +146,22 @@ class Splitter {
 		for( v in fvars )
 			checkVar(v, false, vvars, ffun.expr.p);
 
-		ffun = {
-			ret : ffun.ret,
-			ref : ffun.ref,
-			kind : ffun.kind,
-			args : ffun.args,
-			expr : mapVars(ffun.expr),
-		};
-		switch( ffun.expr.e ) {
-		case TBlock(el):
-			for( e in finits )
-				el.unshift(e);
-		default:
-			finits.push(ffun.expr);
-			ffun.expr = { e : TBlock(finits), t : TVoid, p : ffun.expr.p };
+		if( ffun != null ) {
+			ffun = {
+				ret : ffun.ret,
+				ref : ffun.ref,
+				kind : ffun.kind,
+				args : ffun.args,
+				expr : mapVars(ffun.expr),
+			};
+			switch( ffun.expr.e ) {
+			case TBlock(el):
+				for( e in finits )
+					el.unshift(e);
+			default:
+				finits.push(ffun.expr);
+				ffun.expr = { e : TBlock(finits), t : TVoid, p : ffun.expr.p };
+			}
 		}
 
 		var vvars = [for( v in vvars ) if( !v.local ) v];
@@ -167,18 +171,24 @@ class Splitter {
 		vvars.sort(function(v1, v2) return getId(v1) - getId(v2));
 		fvars.sort(function(v1, v2) return getId(v1) - getId(v2));
 
-		return {
-			vertex : {
+		return isCompute ? [
+			{
+				name : "main",
+				vars : [for( v in vvars ) v.v],
+				funs : [vfun],
+			}
+		] : [
+			{
 				name : "vertex",
 				vars : [for( v in vvars ) v.v],
 				funs : [vfun],
 			},
-			fragment : {
+			{
 				name : "fragment",
 				vars : [for( v in fvars ) v.v],
 				funs : [ffun],
-			},
-		};
+			}
+		];
 	}
 
 	function addExpr( f : TFunction, e : TExpr ) {