2
0
Эх сурвалжийг харах

added rw buffer and compute shader support, upgraded glsl keyword list

Nicolas Cannasse 1 жил өмнө
parent
commit
4370575fe9
2 өөрчлөгдсөн 155 нэмэгдсэн , 26 устгасан
  1. 74 12
      h3d/impl/GlDriver.hx
  2. 81 14
      hxsl/GlslOut.hx

+ 74 - 12
h3d/impl/GlDriver.hx

@@ -46,6 +46,7 @@ private class CompiledShader {
 	public var params : Uniform;
 	public var textures : Array<{ u : Uniform, t : hxsl.Ast.Type, mode : Int }>;
 	public var buffers : Array<Int>;
+	public var bufferTypes : Array<hxsl.Ast.BufferKind>;
 	public var shader : hxsl.RuntimeShader.RuntimeShaderData;
 	public function new(s,kind,shader) {
 		this.s = s;
@@ -217,6 +218,12 @@ class GlDriver extends Driver {
 		gl.pixelStorei(GL.UNPACK_ALIGNMENT, 1);
 	}
 
+	#if hlsdl
+	public static function enableComputeShaders() {
+		sdl.Sdl.setGLVersion(4, 3);
+	}
+	#end
+
 	override function setRenderFlag( r : RenderFlag, value : Int ) {
 		switch( r ) {
 		case CameraHandness:
@@ -259,6 +266,8 @@ class GlDriver extends Driver {
 		inline function compile(sh) {
 			return makeCompiler().run(sh);
 		}
+		if( shader.mode == Compute )
+			return "// compute:\n" + compile(shader.compute.data);
 		return "// vertex:\n" + compile(shader.vertex.data) + "// fragment:\n" + compile(shader.fragment.data);
 	}
 
@@ -275,7 +284,12 @@ class GlDriver extends Driver {
 	}
 
 	function compileShader( glout : ShaderCompiler, shader : hxsl.RuntimeShader.RuntimeShaderData ) {
-		var type = shader.kind == Vertex ? GL.VERTEX_SHADER : GL.FRAGMENT_SHADER;
+		var type = switch( shader.kind ) {
+		case Vertex: GL.VERTEX_SHADER;
+		case Fragment: GL.FRAGMENT_SHADER;
+		case Main: GL.COMPUTE_SHADER;
+		default: throw "assert";
+		};
 		var s = gl.createShader(type);
 		if( shader.code == null ){
 			shader.code = glout.run(shader.data);
@@ -300,7 +314,11 @@ class GlDriver extends Driver {
 	}
 
 	function initShader( p : CompiledProgram, s : CompiledShader, shader : hxsl.RuntimeShader.RuntimeShaderData, rt : hxsl.RuntimeShader ) {
-		var prefix = s.kind == Vertex ? "vertex" : "fragment";
+		var prefix = switch( s.kind ) {
+		case Vertex: "vertex";
+		case Fragment: "fragment";
+		default: "compute";
+		}
 		s.globals = gl.getUniformLocation(p.p, prefix + "Globals");
 		s.params = gl.getUniformLocation(p.p, prefix + "Params");
 		s.textures = [];
@@ -346,11 +364,41 @@ class GlDriver extends Driver {
 			t = t.next;
 		}
 		if( shader.bufferCount > 0 ) {
-			s.buffers = [for( i in 0...shader.bufferCount ) gl.getUniformBlockIndex(p.p,(shader.kind==Vertex?"vertex_":"")+"uniform_buffer"+i)];
+			s.bufferTypes = [];
+			var bp = s.shader.buffers;
+			while( bp != null ) {
+				var kind = switch( bp.type ) {
+				case TBuffer(_,_,kind): kind;
+				default: throw "assert";
+				}
+				s.bufferTypes.push(kind);
+				bp = bp.next;
+			}
+			s.buffers = [for( i in 0...shader.bufferCount ) {
+				switch( s.bufferTypes[i] ) {
+				case RW:
+					#if js
+					throw "RW buffer not supported in WebGL";
+					#elseif (hl_ver < version("1.15.0"))
+					throw "RW buffer support requires -D hl-ver=1.15.0";
+					#else
+					gl.getProgramResourceIndex(p.p,GL.SHADER_STORAGE_BLOCK, "rw_uniform_buffer"+i);
+					#end
+				case Uniform:
+					gl.getUniformBlockIndex(p.p,(shader.kind==Vertex?"vertex_":"")+"uniform_buffer"+i);
+				}
+			}];
 			var start = 0;
 			if( s.kind == Fragment ) start = rt.vertex.bufferCount;
 			for( i in 0...shader.bufferCount )
-				gl.uniformBlockBinding(p.p,s.buffers[i],i + start);
+				switch( s.bufferTypes[i] ) {
+				case Uniform:
+					gl.uniformBlockBinding(p.p,s.buffers[i],i + start);
+				case RW:
+					#if (hl_ver >= version("1.15.0"))
+					gl.shaderStorageBlockBinding(p.p,s.buffers[i], i + start);
+					#end
+				}
 		}
 	}
 
@@ -360,11 +408,12 @@ class GlDriver extends Driver {
 			p = new CompiledProgram();
 			var glout = makeCompiler();
 			p.vertex = compileShader(glout,shader.vertex);
-			p.fragment = compileShader(glout,shader.fragment);
+			if( shader.fragment != null )
+				p.fragment = compileShader(glout,shader.fragment);
 
 			p.p = gl.createProgram();
 			#if ((hlsdl || usegl) && !hlmesa)
-			if( glES == null ) {
+			if( glES == null && shader.fragment != null ) {
 				var outCount = 0;
 				for( v in shader.fragment.data.vars )
 					switch( v.kind ) {
@@ -375,7 +424,8 @@ class GlDriver extends Driver {
 			}
 			#end
 			gl.attachShader(p.p, p.vertex.s);
-			gl.attachShader(p.p, p.fragment.s);
+			if( p.fragment != null )
+				gl.attachShader(p.p, p.fragment.s);
 			var log = null;
 			try {
 				gl.linkProgram(p.p);
@@ -385,7 +435,8 @@ class GlDriver extends Driver {
 				throw "Shader linkage error: "+Std.string(e)+" ("+getDriverName(false)+")";
 			}
 			gl.deleteShader(p.vertex.s);
-			gl.deleteShader(p.fragment.s);
+			if( p.fragment != null )
+				gl.deleteShader(p.fragment.s);
 			if( log != null ) {
 				#if js
 				gl.deleteProgram(p.p);
@@ -399,11 +450,12 @@ class GlDriver extends Driver {
 					return selectShader(shader);
 				}
 				#end
-				throw "Program linkage failure: "+log+"\nVertex=\n"+shader.vertex.code+"\n\nFragment=\n"+shader.fragment.code;
+				throw "Program linkage failure: "+log+"\nVertex=\n"+shader.vertex.code+(shader.fragment == null ? "" : "\n\nFragment=\n"+shader.fragment.code);
 			}
 			firstShader = false;
 			initShader(p, p.vertex, shader.vertex, shader);
-			initShader(p, p.fragment, shader.fragment, shader);
+			if( p.fragment != null )
+				initShader(p, p.fragment, shader.fragment, shader);
 			p.attribs = [];
 			p.hasAttribIndex = 0;
 			var format : Array<hxd.BufferFormat.BufferInput> = [];
@@ -479,7 +531,8 @@ class GlDriver extends Driver {
 
 	override function uploadShaderBuffers( buf : h3d.shader.Buffers, which : h3d.shader.Buffers.BufferKind ) {
 		uploadBuffer(buf, curShader.vertex, buf.vertex, which);
-		uploadBuffer(buf, curShader.fragment, buf.fragment, which);
+		if( curShader.fragment != null )
+			uploadBuffer(buf, curShader.fragment, buf.fragment, which);
 	}
 
 	function uploadBuffer( buffer : h3d.shader.Buffers, s : CompiledShader, buf : h3d.shader.Buffers.ShaderBuffers, which : h3d.shader.Buffers.BufferKind ) {
@@ -508,7 +561,12 @@ class GlDriver extends Driver {
 				if( s.kind == Fragment && curShader.vertex.buffers != null )
 					start = curShader.vertex.buffers.length;
 				for( i in 0...s.buffers.length )
-					gl.bindBufferBase(GL.UNIFORM_BUFFER, i + start, buf.buffers[i].vbuf);
+					switch( s.bufferTypes[i] ) {
+					case Uniform:
+						gl.bindBufferBase(GL.UNIFORM_BUFFER, i + start, buf.buffers[i].vbuf);
+					case RW:
+						gl.bindBufferBase(0x90D2 /*GL.SHADER STORAGE BUFFER*/, i + start, buf.buffers[i].vbuf);
+					}
 			}
 		case Textures:
 			var tcount = s.textures.length;
@@ -1779,6 +1837,10 @@ class GlDriver extends Driver {
 
 	#if hl
 
+	override function computeDispatch(x:Int = 1, y:Int = 1, z:Int = 1) {
+		GL.dispatchCompute(x,y,z);
+	}
+
 	override function allocQuery(kind:QueryKind) {
 		return { q : GL.createQuery(), kind : kind };
 	}

+ 81 - 14
hxsl/GlslOut.hx

@@ -3,16 +3,46 @@ import hxsl.Ast;
 
 class GlslOut {
 
-	static var KWD_LIST = [
-		"input", "output", "discard", #if js "sample", #end
-		"dvec2", "dvec3", "dvec4", "hvec2", "hvec3", "hvec4", "fvec2", "fvec3", "fvec4",
-		"int", "float", "bool", "long", "short", "double", "half", "fixed", "unsigned", "superp",
-		"lowp", "mediump", "highp", "precision", "invariant", "discard",
-		"struct", "asm", "union", "template", "this", "packed", "goto", "sizeof","namespace",
-		"noline", "volatile", "external", "flat", "input", "output",
-		"out","attribute","const","uniform","varying","inout","void",
-	];
-	static var KWDS = [for( k in KWD_LIST ) k => true];
+	static var KWD_LIST = "attribute const uniform varying buffer shared
+	coherent volatile restrict readonly writeonly
+	atomic_uint
+	layout
+	centroid flat smooth noperspective
+	patch sample
+	break continue do for while switch case default
+	if else
+	subroutine
+	in out inout
+	float double int void bool true false
+	invariant precise
+	discard return
+	mat2 mat3 mat4 dmat2 dmat3 dmat4
+	mat2x2 mat2x3 mat2x4 dmat2x2 dmat2x3 dmat2x4
+	mat3x2 mat3x3 mat3x4 dmat3x2 dmat3x3 dmat3x4
+	mat4x2 mat4x3 mat4x4 dmat4x2 dmat4x3 dmat4x4
+	vec2 vec3 vec4 ivec2 ivec3 ivec4 bvec2 bvec3 bvec4 dvec2 dvec3 dvec4
+	uint uvec2 uvec3 uvec4
+	lowp mediump highp precision
+	image1D iimage1D uimage1D
+	image2D iimage2D uimage2D
+	image3D iimage3D uimage3D
+	struct
+	common partition active
+	asm
+	class union enum typedef template this packed
+	resource
+	goto
+	inline noinline public static extern external interface
+	long short half fixed unsigned superp
+	input output
+	hvec2 hvec3 hvec4 fvec2 fvec3 fvec4
+	sampler3DRect
+	filter
+	sizeof cast
+	namespace using
+	row_major";
+
+	static var KWDS = [for( k in ~/[ \t\r\n]+/g.split(KWD_LIST) ) k => true];
 	static var GLOBALS = {
 		var gl = [];
 		inline function set(g:hxsl.Ast.TGlobal,str:String) {
@@ -55,6 +85,7 @@ class GlslOut {
 	var isVertex : Bool;
 	var allNames : Map<String, Int>;
 	var outIndexes : Map<Int, Int>;
+	var isCompute : Bool;
 
 	var isES(get,never) : Bool;
 	var isES2(get,never) : Bool;
@@ -182,7 +213,11 @@ class GlslOut {
 			}
 			add("]");
 		case TBuffer(t, size, kind):
-			if( kind != Uniform ) throw "TODO";
+			switch( kind ) {
+			case Uniform:
+			case RW:
+				add("rw_");
+			}
 			add((isVertex ? "vertex_" : "") + "uniform_buffer"+(uniformBuffer++));
 			add(" { ");
 			v.type = TArray(t,size);
@@ -426,6 +461,8 @@ class GlslOut {
 			} else {
 				add("/*var*/");
 			}
+		case TCall( { e : TGlobal(SetLayout) }, _):
+			// nothing
 		case TCall( { e : TGlobal(Saturate) }, [e]):
 			add("clamp(");
 			addValue(e, tabs);
@@ -633,9 +670,18 @@ class GlslOut {
 	function initVar( v : TVar ){
 		switch( v.kind ) {
 		case Param, Global:
-			if( v.type.match(TBuffer(_)) )
+			switch( v.type ) {
+			case TBuffer(_, _, kind):
 				add("layout(std140) ");
-			add("uniform ");
+				switch( kind ) {
+				case Uniform:
+					add("uniform ");
+				case RW:
+					add("buffer ");
+				}
+			default:
+				add("uniform ");
+			}
 		case Input:
 			add( isES2 ? "attribute " : "in ");
 		case Var:
@@ -682,7 +728,22 @@ class GlslOut {
 			decl("#extension GL_EXT_draw_buffers : enable");
 	}
 
+	var computeLayout = [1,1,1];
+	function collectGlobals( m : Map<TGlobal,Bool>, e : TExpr ) {
+		switch( e.e )  {
+		case TGlobal(g): m.set(g,true);
+		case TCall({ e : TGlobal(SetLayout) }, [{ e : TConst(CInt(x)) }, { e : TConst(CInt(y)) }, { e : TConst(CInt(z)) }]):
+			computeLayout = [x,y,z];
+		default: hxsl.Tools.iter(e,collectGlobals.bind(m));
+		}
+	}
+
 	public function run( s : ShaderData ) {
+
+		var foundGlobals = new Map();
+		for( f in s.funs )
+			collectGlobals(foundGlobals, f.expr);
+
 		locals = new Map();
 		decls = [];
 		buf = new StringBuf();
@@ -691,14 +752,18 @@ class GlslOut {
 		if( s.funs.length != 1 ) throw "assert";
 		var f = s.funs[0];
 		isVertex = f.kind == Vertex;
+		isCompute = f.kind == Main;
 
-		if (isVertex)
+		if (isVertex || isCompute)
 			decl("precision highp float;");
 		else
 			decl("precision mediump float;");
 
 		initVars(s);
 
+		if( isCompute )
+			decl('layout(local_size_x = ${computeLayout[0]}, local_size_y = ${computeLayout[1]}, local_size_z = ${computeLayout[2]}) in;');
+
 		var tmp = buf;
 		buf = new StringBuf();
 		add("void main(void) {\n");
@@ -740,6 +805,8 @@ class GlslOut {
 
 		if( isES )
 			decl("#version " + (version < 100 ? 100 : version) + (version > 150 ? " es" : ""));
+		else if( isCompute )
+			decl("#version 430");
 		else if( version != null )
 			decl("#version " + (version > 150 ? 150 : version));
 		else