2
0
Эх сурвалжийг харах

merged refactors required for compute shaders

Nicolas Cannasse 1 жил өмнө
parent
commit
e80ba03d19

+ 2 - 1
all.hxml

@@ -3,7 +3,7 @@
 -lib hxbit
 -lib hxbit
 --macro include('h3d')
 --macro include('h3d')
 --macro include('h2d',true,['h2d.domkit'])
 --macro include('h2d',true,['h2d.domkit'])
---macro include('hxsl',true,['hxsl.Macros','hxsl.CacheFile','hxsl.CacheFileBuilder','hxsl.Checker'])
+--macro include('hxsl',true,['hxsl.Macros','hxsl.CacheFileBuilder','hxsl.Checker'])
 --macro include('hxd',true,['hxd.res.FileTree','hxd.Res','hxd.impl.BitsBuilder','hxd.fmt.pak.Build','hxd.snd.openal','hxd.res.Config'])
 --macro include('hxd',true,['hxd.res.FileTree','hxd.Res','hxd.impl.BitsBuilder','hxd.fmt.pak.Build','hxd.snd.openal','hxd.res.Config'])
 --no-output
 --no-output
 -D apicheck
 -D apicheck
@@ -18,6 +18,7 @@
 -lib hlsdl
 -lib hlsdl
 -lib hlopenal
 -lib hlopenal
 -xml heaps_hl.xml
 -xml heaps_hl.xml
+hxsl.CacheFileBuilder
 
 
 --next
 --next
 
 

+ 4 - 0
h3d/Buffer.hx

@@ -13,6 +13,10 @@ enum BufferFlag {
 		Used for shader input buffer
 		Used for shader input buffer
 	**/
 	**/
 	UniformBuffer;
 	UniformBuffer;
+	/**
+		Can be written
+	**/
+	ReadWriteBuffer;
 }
 }
 
 
 @:allow(h3d.impl.MemoryManager)
 @:allow(h3d.impl.MemoryManager)

+ 7 - 7
h3d/impl/DX12Driver.hx

@@ -875,7 +875,7 @@ class DX12Driver extends h3d.impl.Driver {
 
 
 	static var VERTEX_FORMATS = [null,null,R32G32_FLOAT,R32G32B32_FLOAT,R32G32B32A32_FLOAT];
 	static var VERTEX_FORMATS = [null,null,R32G32_FLOAT,R32G32B32_FLOAT,R32G32B32A32_FLOAT];
 
 
-	function getBinaryPayload( vertex : Bool, code : String ) {
+	function getBinaryPayload( code : String ) {
 		var bin = code.indexOf("//BIN=");
 		var bin = code.indexOf("//BIN=");
 		if( bin >= 0 ) {
 		if( bin >= 0 ) {
 			var end = code.indexOf("#", bin);
 			var end = code.indexOf("#", bin);
@@ -895,7 +895,7 @@ class DX12Driver extends h3d.impl.Driver {
 			sh.code = out.run(sh.data);
 			sh.code = out.run(sh.data);
 			sh.code = rootStr + sh.code;
 			sh.code = rootStr + sh.code;
 		}
 		}
-		var bytes = getBinaryPayload(sh.vertex, sh.code);
+		var bytes = getBinaryPayload(sh.code);
 		if ( bytes == null ) {
 		if ( bytes == null ) {
 			return compiler.compile(sh.code, profile, args);
 			return compiler.compile(sh.code, profile, args);
 		}
 		}
@@ -1010,10 +1010,10 @@ class DX12Driver extends h3d.impl.Driver {
 
 
 
 
 		function allocParams( sh : hxsl.RuntimeShader.RuntimeShaderData ) {
 		function allocParams( sh : hxsl.RuntimeShader.RuntimeShaderData ) {
-			var vis = sh.vertex ? VERTEX : PIXEL;
+			var vis = sh.kind == Vertex ? VERTEX : PIXEL;
 			var regs = new ShaderRegisters();
 			var regs = new ShaderRegisters();
 			regs.globals = allocConsts(sh.globalsSize, vis, false);
 			regs.globals = allocConsts(sh.globalsSize, vis, false);
-			regs.params = allocConsts(sh.paramsSize, vis, sh.vertex ? vertexParamsCBV : fragmentParamsCBV);
+			regs.params = allocConsts(sh.paramsSize, vis, sh.kind == Vertex ? vertexParamsCBV : fragmentParamsCBV);
 			if( sh.bufferCount > 0 ) {
 			if( sh.bufferCount > 0 ) {
 				regs.buffers = paramsCount;
 				regs.buffers = paramsCount;
 				for( i in 0...sh.bufferCount )
 				for( i in 0...sh.bufferCount )
@@ -1612,10 +1612,10 @@ class DX12Driver extends h3d.impl.Driver {
 					t.lastFrame = frameCount;
 					t.lastFrame = frameCount;
 					var state = if ( t.isDepth() )
 					var state = if ( t.isDepth() )
 						DEPTH_READ;
 						DEPTH_READ;
-					else if ( shader.vertex )
-						NON_PIXEL_SHADER_RESOURCE;
+					else if ( shader.kind == Fragment )
+						PIXEL_SHADER_RESOURCE
 					else
 					else
-						PIXEL_SHADER_RESOURCE;
+						NON_PIXEL_SHADER_RESOURCE;
 					transition(t.t, state);
 					transition(t.t, state);
 					Driver.createShaderResourceView(t.t.res, tdesc, srv.offset(i * frame.shaderResourceViews.stride));
 					Driver.createShaderResourceView(t.t.res, tdesc, srv.offset(i * frame.shaderResourceViews.stride));
 
 

+ 3 - 3
h3d/impl/DirectXDriver.hx

@@ -861,9 +861,9 @@ class DirectXDriver extends h3d.impl.Driver {
 			shader.data.funs = null;
 			shader.data.funs = null;
 			#end
 			#end
 		}
 		}
-		var bytes = getBinaryPayload(shader.vertex, shader.code);
+		var bytes = getBinaryPayload(shader.kind == Vertex, shader.code);
 		if( bytes == null ) {
 		if( bytes == null ) {
-			bytes = try dx.Driver.compileShader(shader.code, "", "main", (shader.vertex?"vs_":"ps_") + shaderVersion, OptimizationLevel3) catch( err : String ) {
+			bytes = try dx.Driver.compileShader(shader.code, "", "main", (shader.kind==Vertex?"vs_":"ps_") + shaderVersion, OptimizationLevel3) catch( err : String ) {
 				err = ~/^\(([0-9]+),([0-9]+)-([0-9]+)\)/gm.map(err, function(r) {
 				err = ~/^\(([0-9]+),([0-9]+)-([0-9]+)\)/gm.map(err, function(r) {
 					var line = Std.parseInt(r.matched(1));
 					var line = Std.parseInt(r.matched(1));
 					var char = Std.parseInt(r.matched(2));
 					var char = Std.parseInt(r.matched(2));
@@ -877,7 +877,7 @@ class DirectXDriver extends h3d.impl.Driver {
 		}
 		}
 		if( compileOnly )
 		if( compileOnly )
 			return { s : null, bytes : bytes };
 			return { s : null, bytes : bytes };
-		var s = shader.vertex ? Driver.createVertexShader(bytes) : Driver.createPixelShader(bytes);
+		var s = shader.kind == Vertex ? Driver.createVertexShader(bytes) : Driver.createPixelShader(bytes);
 		if( s == null ) {
 		if( s == null ) {
 			if( hasDeviceError ) return null;
 			if( hasDeviceError ) return null;
 			throw "Failed to create shader\n" + shader.code;
 			throw "Failed to create shader\n" + shader.code;

+ 6 - 0
h3d/impl/Driver.hx

@@ -317,4 +317,10 @@ class Driver {
 		return 0.;
 		return 0.;
 	}
 	}
 
 
+	// --- COMPUTE
+
+	public function computeDispatch( x : Int = 1, y : Int = 1, z : Int = 1 ) {
+		throw "Not implemented";
+	}
+
 }
 }

+ 10 - 10
h3d/impl/GlDriver.hx

@@ -41,15 +41,15 @@ private typedef ShaderCompiler = hxsl.GlslOut;
 
 
 private class CompiledShader {
 private class CompiledShader {
 	public var s : GLShader;
 	public var s : GLShader;
-	public var vertex : Bool;
+	public var kind : hxsl.Ast.FunctionKind;
 	public var globals : Uniform;
 	public var globals : Uniform;
 	public var params : Uniform;
 	public var params : Uniform;
 	public var textures : Array<{ u : Uniform, t : hxsl.Ast.Type, mode : Int }>;
 	public var textures : Array<{ u : Uniform, t : hxsl.Ast.Type, mode : Int }>;
 	public var buffers : Array<Int>;
 	public var buffers : Array<Int>;
 	public var shader : hxsl.RuntimeShader.RuntimeShaderData;
 	public var shader : hxsl.RuntimeShader.RuntimeShaderData;
-	public function new(s,vertex,shader) {
+	public function new(s,kind,shader) {
 		this.s = s;
 		this.s = s;
-		this.vertex = vertex;
+		this.kind = kind;
 		this.shader = shader;
 		this.shader = shader;
 	}
 	}
 }
 }
@@ -275,7 +275,7 @@ class GlDriver extends Driver {
 	}
 	}
 
 
 	function compileShader( glout : ShaderCompiler, shader : hxsl.RuntimeShader.RuntimeShaderData ) {
 	function compileShader( glout : ShaderCompiler, shader : hxsl.RuntimeShader.RuntimeShaderData ) {
-		var type = shader.vertex ? GL.VERTEX_SHADER : GL.FRAGMENT_SHADER;
+		var type = shader.kind == Vertex ? GL.VERTEX_SHADER : GL.FRAGMENT_SHADER;
 		var s = gl.createShader(type);
 		var s = gl.createShader(type);
 		if( shader.code == null ){
 		if( shader.code == null ){
 			shader.code = glout.run(shader.data);
 			shader.code = glout.run(shader.data);
@@ -296,11 +296,11 @@ class GlDriver extends Driver {
 				codeLines[i] = (i+1) + "\t" + codeLines[i];
 				codeLines[i] = (i+1) + "\t" + codeLines[i];
 			throw "An error occurred compiling the shaders: " + log + line+"\n\n"+codeLines.join("\n");
 			throw "An error occurred compiling the shaders: " + log + line+"\n\n"+codeLines.join("\n");
 		}
 		}
-		return new CompiledShader(s, shader.vertex, shader);
+		return new CompiledShader(s, shader.kind, shader);
 	}
 	}
 
 
 	function initShader( p : CompiledProgram, s : CompiledShader, shader : hxsl.RuntimeShader.RuntimeShaderData, rt : hxsl.RuntimeShader ) {
 	function initShader( p : CompiledProgram, s : CompiledShader, shader : hxsl.RuntimeShader.RuntimeShaderData, rt : hxsl.RuntimeShader ) {
-		var prefix = s.vertex ? "vertex" : "fragment";
+		var prefix = s.kind == Vertex ? "vertex" : "fragment";
 		s.globals = gl.getUniformLocation(p.p, prefix + "Globals");
 		s.globals = gl.getUniformLocation(p.p, prefix + "Globals");
 		s.params = gl.getUniformLocation(p.p, prefix + "Params");
 		s.params = gl.getUniformLocation(p.p, prefix + "Params");
 		s.textures = [];
 		s.textures = [];
@@ -346,9 +346,9 @@ class GlDriver extends Driver {
 			t = t.next;
 			t = t.next;
 		}
 		}
 		if( shader.bufferCount > 0 ) {
 		if( shader.bufferCount > 0 ) {
-			s.buffers = [for( i in 0...shader.bufferCount ) gl.getUniformBlockIndex(p.p,(shader.vertex?"vertex_":"")+"uniform_buffer"+i)];
+			s.buffers = [for( i in 0...shader.bufferCount ) gl.getUniformBlockIndex(p.p,(shader.kind==Vertex?"vertex_":"")+"uniform_buffer"+i)];
 			var start = 0;
 			var start = 0;
-			if( !s.vertex ) start = rt.vertex.bufferCount;
+			if( s.kind == Fragment ) start = rt.vertex.bufferCount;
 			for( i in 0...shader.bufferCount )
 			for( i in 0...shader.bufferCount )
 				gl.uniformBlockBinding(p.p,s.buffers[i],i + start);
 				gl.uniformBlockBinding(p.p,s.buffers[i],i + start);
 		}
 		}
@@ -505,7 +505,7 @@ class GlDriver extends Driver {
 		case Buffers:
 		case Buffers:
 			if( s.buffers != null ) {
 			if( s.buffers != null ) {
 				var start = 0;
 				var start = 0;
-				if( !s.vertex && curShader.vertex.buffers != null )
+				if( s.kind == Fragment && curShader.vertex.buffers != null )
 					start = curShader.vertex.buffers.length;
 					start = curShader.vertex.buffers.length;
 				for( i in 0...s.buffers.length )
 				for( i in 0...s.buffers.length )
 					gl.bindBufferBase(GL.UNIFORM_BUFFER, i + start, buf.buffers[i].vbuf);
 					gl.bindBufferBase(GL.UNIFORM_BUFFER, i + start, buf.buffers[i].vbuf);
@@ -544,7 +544,7 @@ class GlDriver extends Driver {
 
 
 				if( pt.u == null ) continue;
 				if( pt.u == null ) continue;
 
 
-				var idx = s.vertex ? i : curShader.vertex.textures.length + i;
+				var idx = s.kind == Fragment ? curShader.vertex.textures.length + i : i;
 				if( boundTextures[idx] != t.t ) {
 				if( boundTextures[idx] != t.t ) {
 					boundTextures[idx] = t.t;
 					boundTextures[idx] = t.t;
 
 

+ 2 - 2
h3d/pass/Default.hx

@@ -43,7 +43,7 @@ class Default extends Base {
 		var o = @:privateAccess new h3d.pass.PassObject();
 		var o = @:privateAccess new h3d.pass.PassObject();
 		o.pass = p;
 		o.pass = p;
 		setupShaders(new h3d.pass.PassList(o));
 		setupShaders(new h3d.pass.PassList(o));
-		return manager.compileShaders(o.shaders, p.batchMode);
+		return manager.compileShaders(o.shaders, p.batchMode ? Batch : Default);
 	}
 	}
 
 
 	function processShaders( p : h3d.pass.PassObject, shaders : hxsl.ShaderList ) {
 	function processShaders( p : h3d.pass.PassObject, shaders : hxsl.ShaderList ) {
@@ -68,7 +68,7 @@ class Default extends Base {
 				}
 				}
 				shaders = ctx.lightSystem.computeLight(p.obj, shaders);
 				shaders = ctx.lightSystem.computeLight(p.obj, shaders);
 			}
 			}
-			p.shader = manager.compileShaders(shaders, p.pass.batchMode);
+			p.shader = manager.compileShaders(shaders, p.pass.batchMode ? Batch : Default);
 			p.shaders = shaders;
 			p.shaders = shaders;
 			var t = p.shader.fragment.textures;
 			var t = p.shader.fragment.textures;
 			if( t == null || t.type.match(TArray(_)) )
 			if( t == null || t.type.match(TArray(_)) )

+ 2 - 2
h3d/pass/ShaderManager.hx

@@ -270,11 +270,11 @@ class ShaderManager {
 		fill(buf.fragment, s.fragment);
 		fill(buf.fragment, s.fragment);
 	}
 	}
 
 
-	public function compileShaders( shaders : hxsl.ShaderList, batchMode : Bool = false ) {
+	public function compileShaders( shaders : hxsl.ShaderList, mode : hxsl.Linker.LinkMode = Default ) {
 		globals.resetChannels();
 		globals.resetChannels();
 		for( s in shaders ) s.updateConstants(globals);
 		for( s in shaders ) s.updateConstants(globals);
 		currentOutput.next = shaders;
 		currentOutput.next = shaders;
-		var s = shaderCache.link(currentOutput, batchMode);
+		var s = shaderCache.link(currentOutput, mode);
 		currentOutput.next = null;
 		currentOutput.next = null;
 		return s;
 		return s;
 	}
 	}

+ 1 - 1
h3d/scene/MeshBatch.hx

@@ -121,7 +121,7 @@ class MeshBatch extends MultiMaterial {
 
 
 				var manager = cast(ctx,h3d.pass.Default).manager;
 				var manager = cast(ctx,h3d.pass.Default).manager;
 				var shaders = p.getShadersRec();
 				var shaders = p.getShadersRec();
-				var rt = manager.compileShaders(shaders, false);
+				var rt = manager.compileShaders(shaders, Default);
 				var shader = manager.shaderCache.makeBatchShader(rt, shaders, instancedParams);
 				var shader = manager.shaderCache.makeBatchShader(rt, shaders, instancedParams);
 
 
 				var b = new BatchData();
 				var b = new BatchData();

+ 18 - 3
hxsl/Ast.hx

@@ -1,5 +1,10 @@
 package hxsl;
 package hxsl;
 
 
+enum BufferKind {
+	Uniform;
+	RW;
+}
+
 enum Type {
 enum Type {
 	TVoid;
 	TVoid;
 	TInt;
 	TInt;
@@ -17,7 +22,7 @@ enum Type {
 	TStruct( vl : Array<TVar> );
 	TStruct( vl : Array<TVar> );
 	TFun( variants : Array<FunType> );
 	TFun( variants : Array<FunType> );
 	TArray( t : Type, size : SizeDecl );
 	TArray( t : Type, size : SizeDecl );
-	TBuffer( t : Type, size : SizeDecl );
+	TBuffer( t : Type, size : SizeDecl, kind : BufferKind );
 	TChannel( size : Int );
 	TChannel( size : Int );
 	TMat2;
 	TMat2;
 }
 }
@@ -187,6 +192,7 @@ enum FunctionKind {
 	Fragment;
 	Fragment;
 	Init;
 	Init;
 	Helper;
 	Helper;
+	Main;
 }
 }
 
 
 enum TGlobal {
 enum TGlobal {
@@ -280,6 +286,8 @@ enum TGlobal {
 	IntBitsToFloat;
 	IntBitsToFloat;
 	UintBitsToFloat;
 	UintBitsToFloat;
 	RoundEven;
 	RoundEven;
+	// compute
+	SetLayout;
 }
 }
 
 
 enum Component {
 enum Component {
@@ -418,7 +426,12 @@ class Tools {
 			prefix + "Vec" + size;
 			prefix + "Vec" + size;
 		case TStruct(vl):"{" + [for( v in vl ) v.name + " : " + toString(v.type)].join(",") + "}";
 		case TStruct(vl):"{" + [for( v in vl ) v.name + " : " + toString(v.type)].join(",") + "}";
 		case TArray(t, s): toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
 		case TArray(t, s): toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
-		case TBuffer(t, s): "buffer "+toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
+		case TBuffer(t, s, k):
+			var prefix = switch( k ) {
+			case Uniform: "buffer";
+			case RW: "rwbuffer";
+			};
+			prefix+" "+toString(t) + "[" + (switch( s ) { case SConst(i): "" + i; case SVar(v): v.name; } ) + "]";
 		case TBytes(n): "Bytes" + n;
 		case TBytes(n): "Bytes" + n;
 		default: t.getName().substr(1);
 		default: t.getName().substr(1);
 		}
 		}
@@ -457,6 +470,8 @@ class Tools {
 			return hasSideEffect(e) || hasSideEffect(index);
 			return hasSideEffect(e) || hasSideEffect(index);
 		case TConst(_), TVar(_), TGlobal(_):
 		case TConst(_), TVar(_), TGlobal(_):
 			return false;
 			return false;
+		case TCall({ e : TGlobal(SetLayout) },_):
+			return true;
 		case TCall(e, pl):
 		case TCall(e, pl):
 			if( !e.e.match(TGlobal(_)) )
 			if( !e.e.match(TGlobal(_)) )
 				return true;
 				return true;
@@ -545,7 +560,7 @@ class Tools {
 		case TMat3x4: 12;
 		case TMat3x4: 12;
 		case TBytes(s): s;
 		case TBytes(s): s;
 		case TBool, TString, TSampler2D, TSampler2DArray, TSamplerCube, TFun(_): 0;
 		case TBool, TString, TSampler2D, TSampler2DArray, TSamplerCube, TFun(_): 0;
-		case TArray(t, SConst(v)), TBuffer(t, SConst(v)): size(t) * v;
+		case TArray(t, SConst(v)), TBuffer(t, SConst(v),_): size(t) * v;
 		case TArray(_, SVar(_)), TBuffer(_): 0;
 		case TArray(_, SVar(_)), TBuffer(_): 0;
 		}
 		}
 	}
 	}

+ 57 - 36
hxsl/Cache.hx

@@ -1,6 +1,7 @@
 package hxsl;
 package hxsl;
 using hxsl.Ast;
 using hxsl.Ast;
 import hxsl.RuntimeShader;
 import hxsl.RuntimeShader;
+import hxsl.Linker.LinkMode;
 
 
 class BatchInstanceParams {
 class BatchInstanceParams {
 
 
@@ -195,7 +196,7 @@ class Cache {
 	}
 	}
 
 
 	@:noDebug
 	@:noDebug
-	public function link( shaders : hxsl.ShaderList, batchMode : Bool ) {
+	public function link( shaders : hxsl.ShaderList, mode : LinkMode ) {
 		var c = linkCache;
 		var c = linkCache;
 		for( s in shaders ) {
 		for( s in shaders ) {
 			var i = @:privateAccess s.instance;
 			var i = @:privateAccess s.instance;
@@ -207,11 +208,11 @@ class Cache {
 			c = cs;
 			c = cs;
 		}
 		}
 		if( c.linked == null )
 		if( c.linked == null )
-			c.linked = compileRuntimeShader(shaders, batchMode);
+			c.linked = compileRuntimeShader(shaders, mode);
 		return c.linked;
 		return c.linked;
 	}
 	}
 
 
-	function compileRuntimeShader( shaders : hxsl.ShaderList, batchMode : Bool ) {
+	function compileRuntimeShader( shaders : hxsl.ShaderList, mode : LinkMode ) {
 		var shaderDatas = [];
 		var shaderDatas = [];
 		var index = 0;
 		var index = 0;
 		for( s in shaders ) {
 		for( s in shaders ) {
@@ -262,14 +263,14 @@ class Cache {
 		//TRACE = shaderId == 0;
 		//TRACE = shaderId == 0;
 		#end
 		#end
 
 
-		var linker = new hxsl.Linker(batchMode);
+		var linker = new hxsl.Linker(mode);
 		var s = try linker.link([for( s in shaderDatas ) s.inst.shader]) catch( e : Error ) {
 		var s = try linker.link([for( s in shaderDatas ) s.inst.shader]) catch( e : Error ) {
 			var shaders = [for( s in shaderDatas ) Printer.shaderToString(s.inst.shader)];
 			var shaders = [for( s in shaderDatas ) Printer.shaderToString(s.inst.shader)];
 			e.msg += "\n\nin\n\n" + shaders.join("\n-----\n");
 			e.msg += "\n\nin\n\n" + shaders.join("\n-----\n");
 			throw e;
 			throw e;
 		}
 		}
 
 
-		if( batchMode ) {
+		if( mode == Batch ) {
 			function checkRec( v : TVar ) {
 			function checkRec( v : TVar ) {
 				if( v.qualifiers != null && v.qualifiers.indexOf(PerObject) >= 0 ) {
 				if( v.qualifiers != null && v.qualifiers.indexOf(PerObject) >= 0 ) {
 					if( v.qualifiers.length == 1 ) v.qualifiers = null else {
 					if( v.qualifiers.length == 1 ) v.qualifiers = null else {
@@ -302,7 +303,7 @@ class Cache {
 
 
 		var prev = s;
 		var prev = s;
 		var splitter = new hxsl.Splitter();
 		var splitter = new hxsl.Splitter();
-		var s = try splitter.split(s) catch( e : Error ) { e.msg += "\n\nin\n\n"+Printer.shaderToString(s); throw e; };
+		var sl = try splitter.split(s) catch( e : Error ) { e.msg += "\n\nin\n\n"+Printer.shaderToString(s); throw e; };
 
 
 		// params tracking
 		// params tracking
 		var paramVars = new Map();
 		var paramVars = new Map();
@@ -319,41 +320,42 @@ class Cache {
 
 
 
 
 		#if debug
 		#if debug
-		Printer.check(s.vertex,[prev]);
-		Printer.check(s.fragment,[prev]);
+		for( s in sl )
+			Printer.check(s,[prev]);
 		#end
 		#end
 
 
 		#if shader_debug_dump
 		#if shader_debug_dump
 		if( dbg != null ) {
 		if( dbg != null ) {
 			dbg.writeString("----- SPLIT ----\n\n");
 			dbg.writeString("----- SPLIT ----\n\n");
-			dbg.writeString(Printer.shaderToString(s.vertex, DEBUG_IDS) + "\n\n");
-			dbg.writeString(Printer.shaderToString(s.fragment, DEBUG_IDS) + "\n\n");
+			for( s in sl )
+				dbg.writeString(Printer.shaderToString(s, DEBUG_IDS) + "\n\n");
 		}
 		}
 		#end
 		#end
 
 
-		var prev = s;
-		var s = new hxsl.Dce().dce(s.vertex, s.fragment);
+		var prev = sl;
+		var sl = new hxsl.Dce().dce(sl);
 
 
 		#if debug
 		#if debug
-		Printer.check(s.vertex,[prev.vertex]);
-		Printer.check(s.fragment,[prev.fragment]);
+		for( i => s in sl )
+			Printer.check(s,[prev[i]]);
 		#end
 		#end
 
 
 		#if shader_debug_dump
 		#if shader_debug_dump
 		if( dbg != null ) {
 		if( dbg != null ) {
 			dbg.writeString("----- DCE ----\n\n");
 			dbg.writeString("----- DCE ----\n\n");
-			dbg.writeString(Printer.shaderToString(s.vertex, DEBUG_IDS) + "\n\n");
-			dbg.writeString(Printer.shaderToString(s.fragment, DEBUG_IDS) + "\n\n");
+			for( s in sl )
+				dbg.writeString(Printer.shaderToString(s, DEBUG_IDS) + "\n\n");
 		}
 		}
 		#end
 		#end
 
 
-		var r = buildRuntimeShader(s.vertex, s.fragment, paramVars);
+		var r = buildRuntimeShader(sl, paramVars);
+		r.mode = mode;
 
 
 		#if shader_debug_dump
 		#if shader_debug_dump
 		if( dbg != null ) {
 		if( dbg != null ) {
 			dbg.writeString("----- FLATTEN ----\n\n");
 			dbg.writeString("----- FLATTEN ----\n\n");
-			dbg.writeString(Printer.shaderToString(r.vertex.data, DEBUG_IDS) + "\n\n");
-			dbg.writeString(Printer.shaderToString(r.fragment.data,DEBUG_IDS)+"\n\n");
+			for( s in r.getShaders() )
+				dbg.writeString(Printer.shaderToString(s.data, DEBUG_IDS) + "\n\n");
 		}
 		}
 		#end
 		#end
 
 
@@ -366,9 +368,7 @@ class Cache {
 
 
 		var signParts = [for( i in r.spec.instances ) i.shader.data.name+"_" + i.bits + "_" + i.index];
 		var signParts = [for( i in r.spec.instances ) i.shader.data.name+"_" + i.bits + "_" + i.index];
 		r.spec.signature = haxe.crypto.Md5.encode(signParts.join(":"));
 		r.spec.signature = haxe.crypto.Md5.encode(signParts.join(":"));
-		r.signature = haxe.crypto.Md5.encode(Printer.shaderToString(r.vertex.data) + Printer.shaderToString(r.fragment.data));
-		r.batchMode = batchMode;
-
+		r.signature = haxe.crypto.Md5.encode([for( s in r.getShaders() ) Printer.shaderToString(s.data)].join(""));
 		var r2 = byID.get(r.signature);
 		var r2 = byID.get(r.signature);
 		if( r2 != null )
 		if( r2 != null )
 			r.id = r2.id; // same id but different variable mapping
 			r.id = r2.id; // same id but different variable mapping
@@ -385,19 +385,33 @@ class Cache {
 		return r;
 		return r;
 	}
 	}
 
 
-	function buildRuntimeShader( vertex : ShaderData, fragment : ShaderData, paramVars ) {
+	function buildRuntimeShader( shaders : Array<ShaderData>, paramVars ) {
 		var r = new RuntimeShader();
 		var r = new RuntimeShader();
-		r.vertex = flattenShader(vertex, Vertex, paramVars);
-		r.vertex.vertex = true;
-		r.fragment = flattenShader(fragment, Fragment, paramVars);
 		r.globals = new Map();
 		r.globals = new Map();
-		initGlobals(r, r.vertex);
-		initGlobals(r, r.fragment);
-
-		#if debug
-		Printer.check(r.vertex.data,[vertex]);
-		Printer.check(r.fragment.data,[fragment]);
-		#end
+		for( s in shaders ) {
+			var kind = switch( s.name ) {
+			case "vertex": Vertex;
+			case "fragment": Fragment;
+			case "main": Main;
+			default: throw "assert";
+			}
+			var fl = flattenShader(s, kind, paramVars);
+			fl.kind = kind;
+			switch( kind ) {
+			case Vertex:
+				r.vertex = fl;
+			case Fragment:
+				r.fragment = fl;
+			case Main:
+				r.compute = fl;
+			default:
+				throw "assert";
+			}
+			initGlobals(r, fl);
+			#if debug
+			Printer.check(fl.data,[s]);
+			#end
+		}
 		return r;
 		return r;
 	}
 	}
 
 
@@ -467,8 +481,15 @@ class Cache {
 					c.params = out[0];
 					c.params = out[0];
 					c.paramsSize = size;
 					c.paramsSize = size;
 				case TArray(TBuffer(_), _):
 				case TArray(TBuffer(_), _):
-					c.buffers = out[0];
-					c.bufferCount = out.length;
+					if( c.buffers == null ) {
+						c.buffers = out[0];
+						c.bufferCount = out.length;
+					} else {
+						var p = c.buffers;
+						while( p.next != null ) p = p.next;
+						p.next = out[0];
+						c.bufferCount += out.length;
+					}
 				default: throw "assert";
 				default: throw "assert";
 				}
 				}
 			case Global:
 			case Global:
@@ -574,7 +595,7 @@ class Cache {
 		inputOffset.qualifiers = [PerInstance(1)];
 		inputOffset.qualifiers = [PerInstance(1)];
 
 
 		var vcount = declVar("Batch_Count",TInt,Param);
 		var vcount = declVar("Batch_Count",TInt,Param);
-		var vbuffer = declVar("Batch_Buffer",TBuffer(TVec(4,VFloat),SVar(vcount)),Param);
+		var vbuffer = declVar("Batch_Buffer",TBuffer(TVec(4,VFloat),SVar(vcount),Uniform),Param);
 		var voffset = declVar("Batch_Offset", TInt, Local);
 		var voffset = declVar("Batch_Offset", TInt, Local);
 		var ebuffer = { e : TVar(vbuffer), p : pos, t : vbuffer.type };
 		var ebuffer = { e : TVar(vbuffer), p : pos, t : vbuffer.type };
 		var eoffset = { e : TVar(voffset), p : pos, t : voffset.type };
 		var eoffset = { e : TVar(voffset), p : pos, t : voffset.type };

+ 17 - 12
hxsl/CacheFile.hx

@@ -1,6 +1,8 @@
 package hxsl;
 package hxsl;
 import hxsl.Ast.Tools;
 import hxsl.Ast.Tools;
 
 
+#if (sys || nodejs)
+
 private class NullShader extends hxsl.Shader {
 private class NullShader extends hxsl.Shader {
 	static var SRC = {
 	static var SRC = {
 		var output : {
 		var output : {
@@ -103,7 +105,7 @@ class CacheFile extends Cache {
 		} else if( !allowCompile )
 		} else if( !allowCompile )
 			throw "Missing " + file;
 			throw "Missing " + file;
 		if( linkCache.linked == null ) {
 		if( linkCache.linked == null ) {
-			var rt = link(makeDefaultShader(), false);
+			var rt = link(makeDefaultShader(), Default);
 			linkCache.linked = rt;
 			linkCache.linked = rt;
 			if( rt.vertex.code == null || rt.fragment.code == null ) {
 			if( rt.vertex.code == null || rt.fragment.code == null ) {
 				wait.push(rt);
 				wait.push(rt);
@@ -269,7 +271,7 @@ class CacheFile extends Cache {
 
 
 			for( r in runtimes ) {
 			for( r in runtimes ) {
 				var shaderList = null;
 				var shaderList = null;
-				var batchMode = false;
+				var mode : hxsl.Linker.LinkMode = Default;
 				r.inst.reverse();
 				r.inst.reverse();
 				for( i in r.inst ) {
 				for( i in r.inst ) {
 					var s = Type.createEmptyInstance(hxsl.Shader);
 					var s = Type.createEmptyInstance(hxsl.Shader);
@@ -282,7 +284,7 @@ class CacheFile extends Cache {
 							}
 							}
 							var sh = makeBatchShader(rt.rt, rt.shaders.next, i.batch.params);
 							var sh = makeBatchShader(rt.rt, rt.shaders.next, i.batch.params);
 							i.shader = { version : null, shader : sh.shader };
 							i.shader = { version : null, shader : sh.shader };
-							batchMode = true;
+							mode = Batch;
 						}
 						}
 						s.constBits = i.bits;
 						s.constBits = i.bits;
 						s.shader = i.shader.shader;
 						s.shader = i.shader.shader;
@@ -293,7 +295,7 @@ class CacheFile extends Cache {
 				}
 				}
 				if( r == null ) continue;
 				if( r == null ) continue;
 				//log("Recompile "+[for( s in shaderList ) shaderName(s)]);
 				//log("Recompile "+[for( s in shaderList ) shaderName(s)]);
-				var rt = link(shaderList, batchMode); // will compile + update linkMap
+				var rt = link(shaderList, mode); // will compile + update linkMap
 				if( rt.spec.signature != r.specSign ) {
 				if( rt.spec.signature != r.specSign ) {
 					var signParts = [for( i in rt.spec.instances ) i.shader.data.name+"_" + i.bits + "_" + i.index];
 					var signParts = [for( i in rt.spec.instances ) i.shader.data.name+"_" + i.bits + "_" + i.index];
 					throw "assert";
 					throw "assert";
@@ -320,6 +322,7 @@ class CacheFile extends Cache {
 				if( spec == null )
 				if( spec == null )
 					continue;
 					continue;
 
 
+				r.mode = Default;
 				r.signature = spec.signature;
 				r.signature = spec.signature;
 				var shaderList = null;
 				var shaderList = null;
 				spec.inst.reverse();
 				spec.inst.reverse();
@@ -334,7 +337,7 @@ class CacheFile extends Cache {
 							}
 							}
 							var sh = makeBatchShader(rt.rt, rt.shaders.next, i.batch.params);
 							var sh = makeBatchShader(rt.rt, rt.shaders.next, i.batch.params);
 							i.shader = { version : null, shader : sh.shader };
 							i.shader = { version : null, shader : sh.shader };
-							r.batchMode = true;
+							r.mode = Batch;
 						}
 						}
 						// pseudo instance
 						// pseudo instance
 						var scache = i.shader.shader.instanceCache;
 						var scache = i.shader.shader.instanceCache;
@@ -586,14 +589,14 @@ class CacheFile extends Cache {
 
 
 	function cleanRuntimeData(r:hxsl.RuntimeShader.RuntimeShaderData) {
 	function cleanRuntimeData(r:hxsl.RuntimeShader.RuntimeShaderData) {
 		var rc = new hxsl.RuntimeShader.RuntimeShaderData();
 		var rc = new hxsl.RuntimeShader.RuntimeShaderData();
-		rc.vertex = r.vertex;
+		rc.kind = r.kind;
 		rc.data = {
 		rc.data = {
 			name : null,
 			name : null,
 			vars : [],
 			vars : [],
 			funs : null,
 			funs : null,
 		};
 		};
 		for( v in r.data.vars )
 		for( v in r.data.vars )
-			if( v.kind == (r.vertex ? Input : Output) ) {
+			if( v.kind == (r.kind == Vertex ? Input : Output) ) {
 				rc.data.vars.push({
 				rc.data.vars.push({
 					id : v.id,
 					id : v.id,
 					name : v.name,
 					name : v.name,
@@ -646,8 +649,8 @@ class CacheFile extends Cache {
 	}
 	}
 
 
 	function sortBySpec( r1 : RuntimeShader, r2 : RuntimeShader ) {
 	function sortBySpec( r1 : RuntimeShader, r2 : RuntimeShader ) {
-		if( r1.batchMode != r2.batchMode )
-			return r1.batchMode ? 1 : -1;
+		if( r1.mode != r2.mode )
+			return r1.mode.getIndex() - r2.mode.getIndex();
 		var minLen = hxd.Math.imin(r1.spec.instances.length, r2.spec.instances.length);
 		var minLen = hxd.Math.imin(r1.spec.instances.length, r2.spec.instances.length);
 		for( i in 0...minLen ) {
 		for( i in 0...minLen ) {
 			var i1 = r1.spec.instances[i];
 			var i1 = r1.spec.instances[i];
@@ -706,7 +709,7 @@ class CacheFile extends Cache {
 
 
 	public dynamic function onMissingShader(shaders:hxsl.ShaderList) {
 	public dynamic function onMissingShader(shaders:hxsl.ShaderList) {
 		log("Missing shader " + [for( s in shaders ) shaderName(s)]);
 		log("Missing shader " + [for( s in shaders ) shaderName(s)]);
-		return link(null, false); // default fallback
+		return link(null, Default); // default fallback
 	}
 	}
 
 
 	public dynamic function onNewShader(r:RuntimeShader) {
 	public dynamic function onNewShader(r:RuntimeShader) {
@@ -738,7 +741,7 @@ class CacheFile extends Cache {
 		for( i in s.spec.instances ) {
 		for( i in s.spec.instances ) {
 			var inst = shaders.get(i.shader.data.name);
 			var inst = shaders.get(i.shader.data.name);
 			if( inst == null ) {
 			if( inst == null ) {
-				if( s.batchMode && StringTools.startsWith(i.shader.data.name,"batchShader_") )
+				if( s.mode == Batch && StringTools.startsWith(i.shader.data.name,"batchShader_") )
 					continue;
 					continue;
 				var version = getShaderVersion(i.shader);
 				var version = getShaderVersion(i.shader);
 				inst = { shader : i.shader, version : version };
 				inst = { shader : i.shader, version : version };
@@ -779,4 +782,6 @@ class CacheFile extends Cache {
 		}
 		}
 	}
 	}
 
 
-}
+}
+
+#end

+ 4 - 5
hxsl/CacheFileBuilder.hx

@@ -137,7 +137,7 @@ class CacheFileBuilder {
 			throw "DirectX compilation requires -lib hldx without -D dx12";
 			throw "DirectX compilation requires -lib hldx without -D dx12";
 			#end
 			#end
 		case OpenGL:
 		case OpenGL:
-			if( rd.vertex ) {
+			if( rd.kind == Vertex ) {
 				// both vertex and fragment needs to be compiled with the same GlslOut !
 				// both vertex and fragment needs to be compiled with the same GlslOut !
 				glout = new GlslOut();
 				glout = new GlslOut();
 				glout.version = 150;
 				glout.version = 150;
@@ -172,7 +172,7 @@ class CacheFileBuilder {
 			var tmpSrc = tmpFile + ".hlsl";
 			var tmpSrc = tmpFile + ".hlsl";
 			var tmpOut = tmpFile + ".sb";
 			var tmpOut = tmpFile + ".sb";
 			sys.io.File.saveContent(tmpSrc, code);
 			sys.io.File.saveContent(tmpSrc, code);
-			var args = ["-T", (rd.vertex ? "vs_" : "ps_") + dxShaderVersion,"-O3","-Fo", tmpOut, tmpSrc];
+			var args = ["-T", (rd.kind == Vertex ? "vs_" : "ps_") + dxShaderVersion,"-O3","-Fo", tmpOut, tmpSrc];
 			var p = new sys.io.Process(Sys.getEnv("XboxOneXDKLatest")+ "xdk\\FXC\\amd64\\fxc.exe", args);
 			var p = new sys.io.Process(Sys.getEnv("XboxOneXDKLatest")+ "xdk\\FXC\\amd64\\fxc.exe", args);
 			var error = p.stderr.readAll().toString();
 			var error = p.stderr.readAll().toString();
 			var ecode = p.exitCode();
 			var ecode = p.exitCode();
@@ -216,13 +216,12 @@ class CacheFileBuilder {
 			throw "-lib hldx and -D dx12 are required to generate binaries for XBoxSeries";
 			throw "-lib hldx and -D dx12 are required to generate binaries for XBoxSeries";
 			#end
 			#end
 		case NX:
 		case NX:
-			if( rd.vertex )
+			if( rd.kind == Vertex )
 				glout = new hxsl.NXGlslOut();
 				glout = new hxsl.NXGlslOut();
 			return { code : glout.run(rd.data), bytes : null };
 			return { code : glout.run(rd.data), bytes : null };
 		case NXBinaries:
 		case NXBinaries:
-			if( rd.vertex )
+			if( rd.kind == Vertex ) {
 				glout = new hxsl.NXGlslOut();
 				glout = new hxsl.NXGlslOut();
-			if ( rd.vertex ) {
 				vertexOut = glout.run(rd.data);
 				vertexOut = glout.run(rd.data);
 				return { code : vertexOut, bytes : null }; // binary is in fragment.code
 				return { code : vertexOut, bytes : null }; // binary is in fragment.code
 			}
 			}

+ 12 - 3
hxsl/Checker.hx

@@ -188,6 +188,8 @@ class Checker {
 				[for( i => t in genType ) { args : [ { name: "x", type: t } ], ret: genIType[i] }];
 				[for( i => t in genType ) { args : [ { name: "x", type: t } ], ret: genIType[i] }];
 			case IntBitsToFloat, UintBitsToFloat:
 			case IntBitsToFloat, UintBitsToFloat:
 				[for( i => t in genType ) { args : [ { name: "x", type: genIType[i] } ], ret: t }];
 				[for( i => t in genType ) { args : [ { name: "x", type: genIType[i] } ], ret: t }];
+			case SetLayout:
+				[{ args : [{ name : "x", type : TInt },{ name : "y", type : TInt },{ name : "z", type : TInt }], ret : TVoid }];
 			case VertexID, InstanceID, FragCoord, FrontFacing:
 			case VertexID, InstanceID, FragCoord, FrontFacing:
 				null;
 				null;
 			}
 			}
@@ -246,6 +248,7 @@ class Checker {
 			var kind = switch( f.name ) {
 			var kind = switch( f.name ) {
 			case "vertex":  Vertex;
 			case "vertex":  Vertex;
 			case "fragment": Fragment;
 			case "fragment": Fragment;
+			case "main": Main;
 			default: StringTools.startsWith(f.name,"__init__") ? Init : Helper;
 			default: StringTools.startsWith(f.name,"__init__") ? Init : Helper;
 			}
 			}
 			if( args.length != 0 && kind != Helper )
 			if( args.length != 0 && kind != Helper )
@@ -337,6 +340,8 @@ class Checker {
 			switch( v.kind ) {
 			switch( v.kind ) {
 			case Local, Var, Output:
 			case Local, Var, Output:
 				return;
 				return;
+			case Param if( v.type.match(TBuffer(_,_,RW)) ):
+				return;
 			default:
 			default:
 			}
 			}
 		case TSwiz(e, _):
 		case TSwiz(e, _):
@@ -631,7 +636,7 @@ class Checker {
 			default: unify(e2.t, TInt, e2.p);
 			default: unify(e2.t, TInt, e2.p);
 			}
 			}
 			switch( e1.t ) {
 			switch( e1.t ) {
-			case TArray(t, size), TBuffer(t,size):
+			case TArray(t, size), TBuffer(t,size,_):
 				switch( [size, e2.e] ) {
 				switch( [size, e2.e] ) {
 				case [SConst(v), TConst(CInt(i))] if( i >= v ):
 				case [SConst(v), TConst(CInt(i))] if( i >= v ):
 					error("Indexing outside array bounds", e.pos);
 					error("Indexing outside array bounds", e.pos);
@@ -849,7 +854,7 @@ class Checker {
 				vl[i] = makeVar( { type : v.type, qualifiers : v.qualifiers, name : v.name, kind : v.kind, expr : null }, pos, parent);
 				vl[i] = makeVar( { type : v.type, qualifiers : v.qualifiers, name : v.name, kind : v.kind, expr : null }, pos, parent);
 			}
 			}
 			return parent.type;
 			return parent.type;
-		case TArray(t, size), TBuffer(t,size):
+		case TArray(t, size), TBuffer(t,size,_):
 			switch( t ) {
 			switch( t ) {
 			case TArray(_):
 			case TArray(_):
 				error("Multidimentional arrays are not allowed", pos);
 				error("Multidimentional arrays are not allowed", pos);
@@ -894,7 +899,11 @@ class Checker {
 				SVar(v2);
 				SVar(v2);
 			}
 			}
 			t = makeVarType(t,parent,pos);
 			t = makeVarType(t,parent,pos);
-			return vt.match(TArray(_)) ? TArray(t, s) : TBuffer(t,s);
+			return switch( vt ) {
+			case TArray(_): TArray(t, s);
+			case TBuffer(_,_,kind): TBuffer(t,s,kind);
+			default: throw "assert";
+			}
 		default:
 		default:
 			return vt;
 			return vt;
 		}
 		}

+ 24 - 30
hxsl/Dce.hx

@@ -34,30 +34,27 @@ class Dce {
 		#end
 		#end
 	}
 	}
 
 
-	public function dce( vertex : ShaderData, fragment : ShaderData ) {
+	public function dce( shaders : Array<ShaderData> ) {
 		// collect vars dependencies
 		// collect vars dependencies
 		used = new Map();
 		used = new Map();
 		channelVars = [];
 		channelVars = [];
 
 
 		var inputs = [];
 		var inputs = [];
-		for( v in vertex.vars ) {
-			var i = get(v);
-			if( v.kind == Input )
-				inputs.push(i);
-			if( v.kind == Output )
-				i.keep = true;
-		}
-		for( v in fragment.vars ) {
-			var i = get(v);
-			if( v.kind == Output )
-				i.keep = true;
+		for( s in shaders ) {
+			for( v in s.vars ) {
+				var i = get(v);
+				if( v.kind == Input )
+					inputs.push(i);
+				if( v.kind == Output || v.type.match(TBuffer(_,_,RW)) )
+					i.keep = true;
+			}
 		}
 		}
 
 
 		// collect dependencies
 		// collect dependencies
-		for( f in vertex.funs )
-			check(f.expr, [], []);
-		for( f in fragment.funs )
-			check(f.expr, [], []);
+		for( s in shaders ) {
+			for( f in s.funs )
+				check(f.expr, [], []);
+		}
 
 
 		var outExprs = [];
 		var outExprs = [];
 		while( true ) {
 		while( true ) {
@@ -74,10 +71,10 @@ class Dce {
 				markRec(v);
 				markRec(v);
 
 
 			outExprs = [];
 			outExprs = [];
-			for( f in vertex.funs )
-				outExprs.push(mapExpr(f.expr, false));
-			for( f in fragment.funs )
-				outExprs.push(mapExpr(f.expr, false));
+			for( s in shaders ) {
+				for( f in s.funs )
+					outExprs.push(mapExpr(f.expr, false));
+			}
 
 
 			// post add conditional branches
 			// post add conditional branches
 			markAsKeep = false;
 			markAsKeep = false;
@@ -86,22 +83,19 @@ class Dce {
 			if( !markAsKeep ) break;
 			if( !markAsKeep ) break;
 		}
 		}
 
 
-		for( f in vertex.funs )
-			f.expr = outExprs.shift();
-		for( f in fragment.funs )
-			f.expr = outExprs.shift();
+		for( s in shaders ) {
+			for( f in s.funs )
+				f.expr = outExprs.shift();
+		}
 
 
 		for( v in used ) {
 		for( v in used ) {
 			if( v.used ) continue;
 			if( v.used ) continue;
 			if( v.v.kind == VarKind.Input) continue;
 			if( v.v.kind == VarKind.Input) continue;
-			vertex.vars.remove(v.v);
-			fragment.vars.remove(v.v);
+			for( s in shaders )
+				s.vars.remove(v.v);
 		}
 		}
 
 
-		return {
-			fragment : fragment,
-			vertex : vertex,
-		}
+		return shaders.copy();
 	}
 	}
 
 
 	function get( v : TVar ) {
 	function get( v : TVar ) {

+ 12 - 4
hxsl/Eval.hx

@@ -52,18 +52,26 @@ class Eval {
 		switch( v2.type ) {
 		switch( v2.type ) {
 		case TStruct(vl):
 		case TStruct(vl):
 			v2.type = TStruct([for( v in vl ) mapVar(v)]);
 			v2.type = TStruct([for( v in vl ) mapVar(v)]);
-		case TArray(t, SVar(vs)), TBuffer(t, SVar(vs)):
+		case TArray(t, SVar(vs)), TBuffer(t, SVar(vs), _):
 			var c = constants.get(vs.id);
 			var c = constants.get(vs.id);
 			if( c != null )
 			if( c != null )
 				switch( c ) {
 				switch( c ) {
 				case TConst(CInt(v)):
 				case TConst(CInt(v)):
-					v2.type = v2.type.match(TArray(_)) ? TArray(t, SConst(v)) : TBuffer(t, SConst(v));
+					v2.type = switch( v2.type ) {
+					case TArray(_): TArray(t, SConst(v));
+					case TBuffer(_,_,kind): TBuffer(t, SConst(v), kind);
+					default: throw "assert";
+					};
 				default:
 				default:
 					Error.t("Integer value expected for array size constant " + vs.name, null);
 					Error.t("Integer value expected for array size constant " + vs.name, null);
 				}
 				}
 			else {
 			else {
 				var vs2 = mapVar(vs);
 				var vs2 = mapVar(vs);
-				v2.type = v2.type.match(TArray(_)) ? TArray(t, SVar(vs2)) : TBuffer(t, SVar(vs2));
+				v2.type = switch( v2.type ) {
+				case TArray(_): TArray(t, SVar(vs2));
+				case TBuffer(_,_,kind): TBuffer(t, SVar(vs2), kind);
+				default: throw "assert";
+				}
 			}
 			}
 		default:
 		default:
 		}
 		}
@@ -81,7 +89,7 @@ class Eval {
 			return false;
 			return false;
 		case TArray(t, _):
 		case TArray(t, _):
 			return checkSamplerRec(t);
 			return checkSamplerRec(t);
-		case TBuffer(_, size):
+		case TBuffer(_):
 			return true;
 			return true;
 		default:
 		default:
 		}
 		}

+ 9 - 5
hxsl/Flatten.hx

@@ -42,6 +42,7 @@ class Flatten {
 		var prefix = switch( kind ) {
 		var prefix = switch( kind ) {
 		case Vertex: "vertex";
 		case Vertex: "vertex";
 		case Fragment: "fragment";
 		case Fragment: "fragment";
+		case Main: "compute";
 		default: throw "assert";
 		default: throw "assert";
 		}
 		}
 		pack(prefix + "Globals", Global, globals, VFloat);
 		pack(prefix + "Globals", Global, globals, VFloat);
@@ -50,7 +51,8 @@ class Flatten {
 		var textures = packTextures(prefix + "Textures", allVars, TSampler2D)
 		var textures = packTextures(prefix + "Textures", allVars, TSampler2D)
 			.concat(packTextures(prefix+"TexturesCube", allVars, TSamplerCube))
 			.concat(packTextures(prefix+"TexturesCube", allVars, TSamplerCube))
 			.concat(packTextures(prefix+"TexturesArray", allVars, TSampler2DArray));
 			.concat(packTextures(prefix+"TexturesArray", allVars, TSampler2DArray));
-		packBuffers(allVars);
+		packBuffers("buffers", allVars, Uniform);
+		packBuffers("rwbuffers", allVars, RW);
 		var funs = [for( f in s.funs ) mapFun(f, mapExpr)];
 		var funs = [for( f in s.funs ) mapFun(f, mapExpr)];
 		return {
 		return {
 			name : s.name,
 			name : s.name,
@@ -259,22 +261,24 @@ class Flatten {
 		return alloc;
 		return alloc;
 	}
 	}
 
 
-	function packBuffers( vars : Array<TVar> ) {
+	function packBuffers( name : String, vars : Array<TVar>, kind ) {
 		var alloc = new Array<Alloc>();
 		var alloc = new Array<Alloc>();
 		var g : TVar = {
 		var g : TVar = {
 			id : Tools.allocVarId(),
 			id : Tools.allocVarId(),
-			name : "buffers",
+			name : name,
 			type : TVoid,
 			type : TVoid,
 			kind : Param,
 			kind : Param,
 		};
 		};
 		for( v in vars )
 		for( v in vars )
-			if( v.type.match(TBuffer(_)) ) {
+			switch( v.type ) {
+			case TBuffer(_,_,k) if( kind == k ):
 				var a = new Alloc(g, null, alloc.length, 1);
 				var a = new Alloc(g, null, alloc.length, 1);
 				a.v = v;
 				a.v = v;
 				alloc.push(a);
 				alloc.push(a);
 				outVars.push(v);
 				outVars.push(v);
+			default:
 			}
 			}
-		g.type = TArray(TBuffer(TVoid,SConst(0)),SConst(alloc.length));
+		g.type = TArray(TBuffer(TVoid,SConst(0),kind),SConst(alloc.length));
 		allocData.set(g, alloc);
 		allocData.set(g, alloc);
 	}
 	}
 
 

+ 3 - 2
hxsl/GlslOut.hx

@@ -181,12 +181,13 @@ class GlslOut {
 			case SConst(n): add(n);
 			case SConst(n): add(n);
 			}
 			}
 			add("]");
 			add("]");
-		case TBuffer(t, size):
+		case TBuffer(t, size, kind):
+			if( kind != Uniform ) throw "TODO";
 			add((isVertex ? "vertex_" : "") + "uniform_buffer"+(uniformBuffer++));
 			add((isVertex ? "vertex_" : "") + "uniform_buffer"+(uniformBuffer++));
 			add(" { ");
 			add(" { ");
 			v.type = TArray(t,size);
 			v.type = TArray(t,size);
 			addVar(v);
 			addVar(v);
-			v.type = TBuffer(t,size);
+			v.type = TBuffer(t,size,kind);
 			add("; }");
 			add("; }");
 		default:
 		default:
 			addType(v.type);
 			addType(v.type);

+ 53 - 21
hxsl/HlslOut.hx

@@ -101,13 +101,19 @@ class HlslOut {
 	var exprValues : Array<String>;
 	var exprValues : Array<String>;
 	var locals : Map<Int,TVar>;
 	var locals : Map<Int,TVar>;
 	var decls : Array<String>;
 	var decls : Array<String>;
-	var isVertex : Bool;
+	var kind : FunctionKind;
 	var allNames : Map<String, Int>;
 	var allNames : Map<String, Int>;
 	var samplers : Map<Int, Array<Int>>;
 	var samplers : Map<Int, Array<Int>>;
+	var computeLayout = [1,1,1];
 	public var varNames : Map<Int,String>;
 	public var varNames : Map<Int,String>;
 	public var baseRegister : Int = 0;
 	public var baseRegister : Int = 0;
 
 
 	var varAccess : Map<Int,String>;
 	var varAccess : Map<Int,String>;
+	var isVertex(get,never) : Bool;
+	var isCompute(get,never) : Bool;
+
+	inline function get_isCompute() return kind == Main;
+	inline function get_isVertex() return kind == Vertex;
 
 
 	public function new() {
 	public function new() {
 		varNames = new Map();
 		varNames = new Map();
@@ -175,7 +181,7 @@ class HlslOut {
 			add(" }");
 			add(" }");
 		case TFun(_):
 		case TFun(_):
 			add("function");
 			add("function");
-		case TArray(t, size), TBuffer(t,size):
+		case TArray(t, size), TBuffer(t,size,_):
 			addType(t);
 			addType(t);
 			add("[");
 			add("[");
 			switch( size ) {
 			switch( size ) {
@@ -201,7 +207,7 @@ class HlslOut {
 
 
 	function addVar( v : TVar ) {
 	function addVar( v : TVar ) {
 		switch( v.type ) {
 		switch( v.type ) {
-		case TArray(t, size), TBuffer(t,size):
+		case TArray(t, size), TBuffer(t,size,_):
 			addVar({
 			addVar({
 				id : v.id,
 				id : v.id,
 				name : v.name,
 				name : v.name,
@@ -298,6 +304,8 @@ class HlslOut {
 			var acc = varAccess.get(v.id);
 			var acc = varAccess.get(v.id);
 			if( acc != null ) add(acc);
 			if( acc != null ) add(acc);
 			ident(v);
 			ident(v);
+		case TCall({ e : TGlobal(SetLayout) },_):
+			// ignore
 		case TCall({ e : TGlobal(g = (Texture | TextureLod)) }, args):
 		case TCall({ e : TGlobal(g = (Texture | TextureLod)) }, args):
 			addValue(args[0], tabs);
 			addValue(args[0], tabs);
 			switch( g ) {
 			switch( g ) {
@@ -687,6 +695,8 @@ class HlslOut {
 	function collectGlobals( m : Map<TGlobal,Bool>, e : TExpr ) {
 	function collectGlobals( m : Map<TGlobal,Bool>, e : TExpr ) {
 		switch( e.e )  {
 		switch( e.e )  {
 		case TGlobal(g): m.set(g,true);
 		case TGlobal(g): m.set(g,true);
+		case TCall({ e : TGlobal(SetLayout) }, [{ e : TConst(CInt(x)) }, { e : TConst(CInt(y)) }, { e : TConst(CInt(z)) }]):
+			computeLayout = [x,y,z];
 		default: e.iter(collectGlobals.bind(m));
 		default: e.iter(collectGlobals.bind(m));
 		}
 		}
 	}
 	}
@@ -709,7 +719,7 @@ class HlslOut {
 			collectGlobals(foundGlobals, f.expr);
 			collectGlobals(foundGlobals, f.expr);
 
 
 		add("struct s_input {\n");
 		add("struct s_input {\n");
-		if( !isVertex )
+		if( kind == Fragment )
 			add("\tfloat4 __pos__ : "+SV_POSITION+";\n");
 			add("\tfloat4 __pos__ : "+SV_POSITION+";\n");
 		for( v in s.vars )
 		for( v in s.vars )
 			if( v.kind == Input || (v.kind == Var && !isVertex) )
 			if( v.kind == Input || (v.kind == Var && !isVertex) )
@@ -722,14 +732,16 @@ class HlslOut {
 			add("\tbool isFrontFace : "+SV_IsFrontFace+";\n");
 			add("\tbool isFrontFace : "+SV_IsFrontFace+";\n");
 		add("};\n\n");
 		add("};\n\n");
 
 
-		add("struct s_output {\n");
-		for( v in s.vars )
-			if( v.kind == Output )
-				declVar("_out.", v);
-		for( v in s.vars )
-			if( v.kind == Var && isVertex )
-				declVar("_out.", v);
-		add("};\n\n");
+		if( !isCompute ) {
+			add("struct s_output {\n");
+			for( v in s.vars )
+				if( v.kind == Output )
+					declVar("_out.", v);
+			for( v in s.vars )
+				if( v.kind == Var && isVertex )
+					declVar("_out.", v);
+			add("};\n\n");
+		}
 	}
 	}
 
 
 	function initGlobals( s : ShaderData ) {
 	function initGlobals( s : ShaderData ) {
@@ -770,10 +782,25 @@ class HlslOut {
 
 
 		var bufCount = 0;
 		var bufCount = 0;
 		for( b in buffers ) {
 		for( b in buffers ) {
-			add('cbuffer _buffer$bufCount : register(b${bufCount+baseRegister+2}) { ');
-			addVar(b);
-			add("; };\n");
-			bufCount++;
+			switch( b.type ) {
+			case TBuffer(t, size, kind):
+				switch( kind ) {
+				case Uniform:
+					add('cbuffer _buffer$bufCount : register(b${bufCount+baseRegister+2}) { ');
+					addVar(b);
+					add("; };\n");
+					bufCount++;
+				case RW:
+					add('RWStructuredBuffer<');
+					addType(t);
+					add('> ');
+					ident(b);
+					add(' : register(u${bufCount+baseRegister+2});');
+					bufCount++;
+				}
+			default:
+				throw "assert";
+			}
 		}
 		}
 		if( bufCount > 0 ) add("\n");
 		if( bufCount > 0 ) add("\n");
 
 
@@ -795,7 +822,8 @@ class HlslOut {
 
 
 	function initStatics( s : ShaderData ) {
 	function initStatics( s : ShaderData ) {
 		add(STATIC + "s_input _in;\n");
 		add(STATIC + "s_input _in;\n");
-		add(STATIC + "s_output _out;\n");
+		if( !isCompute )
+			add(STATIC + "s_output _out;\n");
 
 
 		add("\n");
 		add("\n");
 		for( v in s.vars )
 		for( v in s.vars )
@@ -808,7 +836,11 @@ class HlslOut {
 	}
 	}
 
 
 	function emitMain( expr : TExpr ) {
 	function emitMain( expr : TExpr ) {
-		add("s_output main( s_input __in ) {\n");
+		if( isCompute )
+			add('[numthreads(${computeLayout[0]},${computeLayout[1]},${computeLayout[2]})] void ');
+		else
+			add('s_output ');
+		add("main( s_input __in ) {\n");
 		add("\t_in = __in;\n");
 		add("\t_in = __in;\n");
 		switch( expr.e ) {
 		switch( expr.e ) {
 		case TBlock(el):
 		case TBlock(el):
@@ -820,7 +852,8 @@ class HlslOut {
 		default:
 		default:
 			addExpr(expr, "");
 			addExpr(expr, "");
 		}
 		}
-		add("\treturn _out;\n");
+		if( !isCompute )
+			add("\treturn _out;\n");
 		add("}");
 		add("}");
 	}
 	}
 
 
@@ -848,8 +881,7 @@ class HlslOut {
 
 
 		if( s.funs.length != 1 ) throw "assert";
 		if( s.funs.length != 1 ) throw "assert";
 		var f = s.funs[0];
 		var f = s.funs[0];
-		isVertex = f.kind == Vertex;
-
+		kind = f.kind;
 		varAccess = new Map();
 		varAccess = new Map();
 		samplers = new Map();
 		samplers = new Map();
 		initVars(s);
 		initVars(s);

+ 19 - 7
hxsl/Linker.hx

@@ -1,6 +1,12 @@
 package hxsl;
 package hxsl;
 using hxsl.Ast;
 using hxsl.Ast;
 
 
+enum LinkMode {
+	Default;
+	Batch;
+	Compute;
+}
+
 private class AllocatedVar {
 private class AllocatedVar {
 	public var id : Int;
 	public var id : Int;
 	public var v : TVar;
 	public var v : TVar;
@@ -28,6 +34,7 @@ private class ShaderInfos {
 	public var vertex : Null<Bool>;
 	public var vertex : Null<Bool>;
 	public var onStack : Bool;
 	public var onStack : Bool;
 	public var hasDiscard : Bool;
 	public var hasDiscard : Bool;
+	public var isCompute : Bool;
 	public var marked : Null<Bool>;
 	public var marked : Null<Bool>;
 	public function new(n, v) {
 	public function new(n, v) {
 		this.name = n;
 		this.name = n;
@@ -49,12 +56,12 @@ class Linker {
 	var varIdMap : Map<Int,Int>;
 	var varIdMap : Map<Int,Int>;
 	var locals : Map<Int,Bool>;
 	var locals : Map<Int,Bool>;
 	var curInstance : Int;
 	var curInstance : Int;
-	var batchMode : Bool;
+	var mode : LinkMode;
 	var isBatchShader : Bool;
 	var isBatchShader : Bool;
 	var debugDepth = 0;
 	var debugDepth = 0;
 
 
-	public function new(batchMode=false) {
-		this.batchMode = batchMode;
+	public function new(mode) {
+		this.mode = mode;
 	}
 	}
 
 
 	inline function debug( msg : String, ?pos : haxe.PosInfos ) {
 	inline function debug( msg : String, ?pos : haxe.PosInfos ) {
@@ -348,7 +355,7 @@ class Linker {
 		curInstance = 0;
 		curInstance = 0;
 		var outVars = [];
 		var outVars = [];
 		for( s in shadersData ) {
 		for( s in shadersData ) {
-			isBatchShader = batchMode && StringTools.startsWith(s.name,"batchShader_");
+			isBatchShader = mode == Batch && StringTools.startsWith(s.name,"batchShader_");
 			for( v in s.vars ) {
 			for( v in s.vars ) {
 				var v2 = allocVar(v, null, s.name);
 				var v2 = allocVar(v, null, s.name);
 				if( isBatchShader && v2.v.kind == Param && !StringTools.startsWith(v2.path,"Batch_") )
 				if( isBatchShader && v2.v.kind == Param && !StringTools.startsWith(v2.path,"Batch_") )
@@ -375,8 +382,13 @@ class Linker {
 				if( v.kind == null ) throw "assert";
 				if( v.kind == null ) throw "assert";
 				switch( v.kind ) {
 				switch( v.kind ) {
 				case Vertex, Fragment:
 				case Vertex, Fragment:
+					if( mode == Compute )
+						throw "Unexpected "+v.kind.getName().toLowerCase()+"() function in compute shader";
 					addShader(s.name + "." + (v.kind == Vertex ? "vertex" : "fragment"), v.kind == Vertex, f.expr, priority);
 					addShader(s.name + "." + (v.kind == Vertex ? "vertex" : "fragment"), v.kind == Vertex, f.expr, priority);
-
+				case Main:
+					if( mode != Compute )
+						throw "Unexpected main() outside compute shader";
+					addShader(s.name, true, f.expr, priority).isCompute = true;
 				case Init:
 				case Init:
 					var prio : Array<Int>;
 					var prio : Array<Int>;
 					var status : Null<Bool> = switch( f.ref.name ) {
 					var status : Null<Bool> = switch( f.ref.name ) {
@@ -408,7 +420,7 @@ class Linker {
 
 
 		// force shaders containing discard to be included
 		// force shaders containing discard to be included
 		for( s in shaders )
 		for( s in shaders )
-			if( s.hasDiscard ) {
+			if( s.hasDiscard || s.isCompute ) {
 				initDependencies(s);
 				initDependencies(s);
 				entry.deps.set(s, true);
 				entry.deps.set(s, true);
 			}
 			}
@@ -505,7 +517,7 @@ class Linker {
 				expr : expr,
 				expr : expr,
 			};
 			};
 		}
 		}
-		var funs = [
+		var funs = mode == Compute ? [build(Main,"main",v)] : [
 			build(Vertex, "vertex", v),
 			build(Vertex, "vertex", v),
 			build(Fragment, "fragment", f),
 			build(Fragment, "fragment", f),
 		];
 		];

+ 7 - 2
hxsl/MacroParser.hx

@@ -115,7 +115,7 @@ class MacroParser {
 			case "Channel3": return TChannel(3);
 			case "Channel3": return TChannel(3);
 			case "Channel4": return TChannel(4);
 			case "Channel4": return TChannel(4);
 			}
 			}
-		case TPath( { pack : [], name : name = ("Array"|"Buffer"), sub : null, params : [t, size] } ):
+		case TPath( { pack : [], name : name = ("Array"|"Buffer"|"RWBuffer"), sub : null, params : [t, size] } ):
 			var t = switch( t ) {
 			var t = switch( t ) {
 			case TPType(t): parseType(t, pos);
 			case TPType(t): parseType(t, pos);
 			default: null;
 			default: null;
@@ -129,7 +129,12 @@ class MacroParser {
 			default: null;
 			default: null;
 			}
 			}
 			if( t != null && size != null )
 			if( t != null && size != null )
-				return name == "Array" ? TArray(t, size) : TBuffer(t,size);
+				return switch( name ) {
+				case "Array": TArray(t, size);
+				case "Buffer": TBuffer(t,size,Uniform);
+				case "RWBuffer": TBuffer(t,size,RW);
+				default: throw "assert";
+				}
 		case TAnonymous(fl):
 		case TAnonymous(fl):
 			return TStruct([for( f in fl ) {
 			return TStruct([for( f in fl ) {
 				switch( f.kind ) {
 				switch( f.kind ) {

+ 9 - 2
hxsl/RuntimeShader.hx

@@ -44,7 +44,7 @@ class AllocGlobal {
 }
 }
 
 
 class RuntimeShaderData {
 class RuntimeShaderData {
-	public var vertex : Bool;
+	public var kind : hxsl.Ast.FunctionKind;
 	public var data : Ast.ShaderData;
 	public var data : Ast.ShaderData;
 	public var code : String;
 	public var code : String;
 	public var params : AllocParam;
 	public var params : AllocParam;
@@ -75,14 +75,18 @@ class RuntimeShader {
 	public var id : Int;
 	public var id : Int;
 	public var vertex : RuntimeShaderData;
 	public var vertex : RuntimeShaderData;
 	public var fragment : RuntimeShaderData;
 	public var fragment : RuntimeShaderData;
+	public var compute(get,set) : RuntimeShaderData;
 	public var globals : Map<Int,Bool>;
 	public var globals : Map<Int,Bool>;
 
 
+	inline function get_compute() return vertex;
+	inline function set_compute(v) return vertex = v;
+
 	/**
 	/**
 		Signature of the resulting HxSL code.
 		Signature of the resulting HxSL code.
 		Several shaders with the different specification might still get the same resulting signature.
 		Several shaders with the different specification might still get the same resulting signature.
 	**/
 	**/
 	public var signature : String;
 	public var signature : String;
-	public var batchMode : Bool;
+	public var mode : hxsl.Linker.LinkMode;
 	public var spec : { instances : Array<ShaderInstanceDesc>, signature : String };
 	public var spec : { instances : Array<ShaderInstanceDesc>, signature : String };
 
 
 	public function new() {
 	public function new() {
@@ -93,5 +97,8 @@ class RuntimeShader {
 		return globals.exists(gid);
 		return globals.exists(gid);
 	}
 	}
 
 
+	public function getShaders() {
+		return mode == Compute ? [compute] : [vertex, fragment];
+	}
 
 
 }
 }

+ 15 - 2
hxsl/Serializer.hx

@@ -86,7 +86,14 @@ class Serializer {
 				writeArr(vl,writeVar);
 				writeArr(vl,writeVar);
 		case TFun(variants):
 		case TFun(variants):
 			// not serialized
 			// not serialized
-		case TArray(t, size), TBuffer(t, size):
+		case TArray(t, size), TBuffer(t, size, Uniform):
+			writeType(t);
+			switch (size) {
+			case SConst(v): out.addByte(0); writeVarInt(v);
+			case SVar(v): writeVar(v);
+			}
+		case TBuffer(t, size, kind):
+			out.addByte(kind.getIndex() + 0x80);
 			writeType(t);
 			writeType(t);
 			switch (size) {
 			switch (size) {
 			case SConst(v): out.addByte(0); writeVarInt(v);
 			case SConst(v): out.addByte(0); writeVarInt(v);
@@ -136,9 +143,15 @@ class Serializer {
 			var v = readVar();
 			var v = readVar();
 			TArray(t, v == null ? SConst(readVarInt()) : SVar(v));
 			TArray(t, v == null ? SConst(readVarInt()) : SVar(v));
 		case 16:
 		case 16:
+			var tag = input.readByte();
+			var kind = Uniform;
+			if( tag & 0x80 == 0 )
+				input.position--;
+			else
+				kind = BufferKind.createByIndex(tag & 0x7F);
 			var t = readType();
 			var t = readType();
 			var v = readVar();
 			var v = readVar();
-			TBuffer(t, v == null ? SConst(readVarInt()) : SVar(v));
+			TBuffer(t, v == null ? SConst(readVarInt()) : SVar(v), kind);
 		case 17:
 		case 17:
 			TChannel(input.readByte());
 			TChannel(input.readByte());
 		case 18: TMat2;
 		case 18: TMat2;

+ 31 - 21
hxsl/Splitter.hx

@@ -24,17 +24,19 @@ class Splitter {
 	public function new() {
 	public function new() {
 	}
 	}
 
 
-	public function split( s : ShaderData ) : { vertex : ShaderData, fragment : ShaderData } {
+	public function split( s : ShaderData ) : Array<ShaderData> {
 		var vfun = null, vvars = new Map();
 		var vfun = null, vvars = new Map();
 		var ffun = null, fvars = new Map();
 		var ffun = null, fvars = new Map();
+		var isCompute = false;
 		varNames = new Map();
 		varNames = new Map();
 		varMap = new Map();
 		varMap = new Map();
 		for( f in s.funs )
 		for( f in s.funs )
 			switch( f.kind ) {
 			switch( f.kind ) {
-			case Vertex:
+			case Vertex, Main:
 				vars = vvars;
 				vars = vvars;
 				vfun = f;
 				vfun = f;
 				checkExpr(f.expr);
 				checkExpr(f.expr);
+				if( f.kind == Main ) isCompute = true;
 			case Fragment:
 			case Fragment:
 				vars = fvars;
 				vars = fvars;
 				ffun = f;
 				ffun = f;
@@ -144,20 +146,22 @@ class Splitter {
 		for( v in fvars )
 		for( v in fvars )
 			checkVar(v, false, vvars, ffun.expr.p);
 			checkVar(v, false, vvars, ffun.expr.p);
 
 
-		ffun = {
-			ret : ffun.ret,
-			ref : ffun.ref,
-			kind : ffun.kind,
-			args : ffun.args,
-			expr : mapVars(ffun.expr),
-		};
-		switch( ffun.expr.e ) {
-		case TBlock(el):
-			for( e in finits )
-				el.unshift(e);
-		default:
-			finits.push(ffun.expr);
-			ffun.expr = { e : TBlock(finits), t : TVoid, p : ffun.expr.p };
+		if( ffun != null ) {
+			ffun = {
+				ret : ffun.ret,
+				ref : ffun.ref,
+				kind : ffun.kind,
+				args : ffun.args,
+				expr : mapVars(ffun.expr),
+			};
+			switch( ffun.expr.e ) {
+			case TBlock(el):
+				for( e in finits )
+					el.unshift(e);
+			default:
+				finits.push(ffun.expr);
+				ffun.expr = { e : TBlock(finits), t : TVoid, p : ffun.expr.p };
+			}
 		}
 		}
 
 
 		var vvars = [for( v in vvars ) if( !v.local ) v];
 		var vvars = [for( v in vvars ) if( !v.local ) v];
@@ -167,18 +171,24 @@ class Splitter {
 		vvars.sort(function(v1, v2) return getId(v1) - getId(v2));
 		vvars.sort(function(v1, v2) return getId(v1) - getId(v2));
 		fvars.sort(function(v1, v2) return getId(v1) - getId(v2));
 		fvars.sort(function(v1, v2) return getId(v1) - getId(v2));
 
 
-		return {
-			vertex : {
+		return isCompute ? [
+			{
+				name : "main",
+				vars : [for( v in vvars ) v.v],
+				funs : [vfun],
+			}
+		] : [
+			{
 				name : "vertex",
 				name : "vertex",
 				vars : [for( v in vvars ) v.v],
 				vars : [for( v in vvars ) v.v],
 				funs : [vfun],
 				funs : [vfun],
 			},
 			},
-			fragment : {
+			{
 				name : "fragment",
 				name : "fragment",
 				vars : [for( v in fvars ) v.v],
 				vars : [for( v in fvars ) v.v],
 				funs : [ffun],
 				funs : [ffun],
-			},
-		};
+			}
+		];
 	}
 	}
 
 
 	function addExpr( f : TFunction, e : TExpr ) {
 	function addExpr( f : TFunction, e : TExpr ) {