Преглед изворни кода

Shadergraph better variable handling

Clement Espeute пре 2 година
родитељ
комит
aca523416b

+ 4 - 3
hrt/shgraph/ShaderGlobalInput.hx

@@ -16,11 +16,12 @@ class ShaderGlobalInput extends ShaderNode {
 										{ parent: null, id: 0, kind: Global, name: "global.modelView", type: TMat4 },
 										{ parent: null, id: 0, kind: Global, name: "global.modelViewInverse", type: TMat4 } ];
 
-	override function getShaderDef(domain: ShaderGraph.Domain):hrt.shgraph.ShaderGraph.ShaderNodeDef {
+	override function getShaderDef(domain: ShaderGraph.Domain, getNewIdFn : () -> Int ):hrt.shgraph.ShaderGraph.ShaderNodeDef {
 		var pos : Position = {file: "", min: 0, max: 0};
 
-		var inVar : TVar = globalInputs[variableIdx];
-		var output : TVar = {name: "output", id:1, type: inVar.type, kind: Local, qualifiers: []};
+		var inVar : TVar = Reflect.copy(globalInputs[variableIdx]);
+		inVar.id = getNewIdFn();
+		var output : TVar = {name: "output", id: getNewIdFn(), type: inVar.type, kind: Local, qualifiers: []};
 		var finalExpr : TExpr = {e: TBinop(OpAssign, {e:TVar(output), p:pos, t:output.type}, {e: TVar(inVar), p: pos, t: output.type}), p: pos, t: output.type};
 
 		return {expr: finalExpr, inVars: [], outVars:[{v: output, internal: false}], externVars: [inVar], inits: []};

+ 84 - 29
hrt/shgraph/ShaderGraph.hx

@@ -371,7 +371,7 @@ class Graph {
 
 		// Patch I/O if name have changed
 		if (!outputs.exists(outputName)) {
-			var def = output.instance.getShaderDef(domain);
+			var def = output.instance.getShaderDef(domain, () -> 0);
 			if(edge.outputId != null && def.outVars.length > edge.outputId) {
 				outputName = def.outVars[edge.outputId].v.name;
 			}
@@ -381,7 +381,7 @@ class Graph {
 		}
 
 		if (!inputs.exists(inputName)) {
-			var def = node.instance.getShaderDef(domain);
+			var def = node.instance.getShaderDef(domain,  () -> 0);
 			if (edge.inputId != null && def.inVars.length > edge.inputId) {
 				inputName = def.inVars[edge.inputId].v.name;
 			}
@@ -448,24 +448,37 @@ class Graph {
 
 
 
-	public function generate2(?specificOutput: ShaderNode) : ShaderNodeDef {
+	public function generate2(?specificOutput: ShaderNode, ?getNewVarId: () -> Int) : ShaderNodeDef {
 
-		var varIdCount = 0;
-		var getNewVarId = function()
-		{
-			return varIdCount++;
-		};
+		if (getNewVarId == null) {
+			var varIdCount = 0;
+			getNewVarId = function()
+			{
+				return hxsl.Tools.allocVarId();
+			};
+		}
 
 		inline function getNewVarName(node: Node, id: Int) : String {
 			return '_sg_${(node.type).split(".").pop()}_var_$id';
 		}
 
 		var nodeOutputs : Map<Node, Map<String, TVar>> = [];
+		var nodeDef : Map<Node, ShaderGraph.ShaderNodeDef> = [];
+
+		function getDef(node: Node) : ShaderGraph.ShaderNodeDef {
+			var def = nodeDef.get(node);
+			if (def != null)
+				return def;
+			def = node.instance.getShaderDef(domain, getNewVarId);
+			nodeDef.set(node, def);
+			return def;
+		}
+
 		function getOutputs(node: Node) : Map<String, TVar> {
 			if (!nodeOutputs.exists(node)) {
 				var outputs : Map<String, TVar> = [];
 
-				var def = node.instance.getShaderDef(domain);
+				var def = getDef(node);
 				for (output in def.outVars) {
 					if (output.internal)
 						continue;
@@ -481,7 +494,9 @@ class Graph {
 			return nodeOutputs.get(node);
 		}
 
-		// Recursively replace the to tvar with from tvar in the given expression
+
+
+		// Recursively replace the "what" tvar with "with" tvar in the given expression
 		function replaceVar(expr: TExpr, what: TVar, with: TExpr) : TExpr {
 			if(!what.type.equals(with.t))
 				throw "type missmatch " + what.type + " != " + with.t;
@@ -516,10 +531,29 @@ class Graph {
 			nodeHasOutputs.set(connection.from, true);
 		}
 
-		var graphInputVars  = [];
+		var graphInputVars = [];
 		var graphOutputVars  = [];
 		var externs : Array<TVar> = [];
 
+		var outsideVars : Map<String, TVar> = [];
+		function getOutsideVar(name: String, original: TVar, isInput: Bool) : TVar {
+			var v : TVar = outsideVars.get(name);
+			if (v == null) {
+				v = Reflect.copy(original);
+				v.id = getNewVarId();
+				v.name = name;
+				outsideVars.set(name, v);
+			}
+			if (isInput) {
+				graphInputVars.pushUnique({v: v, internal: false, defVal: null});
+			}
+			else {
+				graphOutputVars.pushUnique({v: v, internal: false});
+			}
+
+			return v;
+		}
+
 		var nodeToExplore : Array<Node> = [];
 
 		for (node => hasOutputs in nodeHasOutputs) {
@@ -615,7 +649,7 @@ class Graph {
 			var outputs = getOutputs(currentNode);
 
 			{
-				var def = currentNode.instance.getShaderDef(domain);
+				var def = getDef(currentNode);
 				var expr = def.expr;
 
 				var outputDecls : Array<TVar> = [];
@@ -632,19 +666,37 @@ class Graph {
 						replacement = convertToType(nodeVar.v.type,  {e: TVar(outputVar), p:pos, t: outputVar.type});
 					}
 					else {
-						var shParam = Std.downcast(currentNode.instance, ShaderParam);
-						if (shParam != null) {
-							var outVar = outputs["output"];
-							var id = getNewVarId();
-							outVar.id = id;
-							outVar.name = nodeVar.v.name;
-							outVar.type = nodeVar.v.type;
-							outVar.kind = Param;
-							outVar.qualifiers = [];
-							graphInputVars.push({v: outVar, internal: false});
-							var param = getParameter(shParam.parameterId);
-							inits.push({variable: outVar, value: param.defaultValue});
-							continue;
+						if (nodeVar.internal) {
+							if (nodeVar.v.type == TSampler2D) {
+								// Rewrite output var to be the sampler directly because we can't assign
+								// a sampler to a temporary variable
+								var outVar = outputs["output"];
+								outVar.id = nodeVar.v.id;
+								outVar.name = nodeVar.v.name;
+								outVar.type = nodeVar.v.type;
+								outVar.qualifiers = nodeVar.v.qualifiers;
+								outVar.parent = nodeVar.v.parent;
+								outVar.kind = nodeVar.v.kind;
+
+								expr = null;
+
+								graphInputVars.pushUnique({v: outVar, internal: false, defVal: null});
+
+								var shParam = Std.downcast(currentNode.instance, ShaderParam);
+								var param = getParameter(shParam.parameterId);
+								inits.push({variable: outVar, value: param.defaultValue});
+
+								continue;
+							}
+
+							var inVar = getOutsideVar(nodeVar.v.name, nodeVar.v, true);
+
+							var shParam = Std.downcast(currentNode.instance, ShaderParam);
+							if (shParam != null) {
+								var param = getParameter(shParam.parameterId);
+								inits.push({variable: inVar, value: param.defaultValue});
+							}
+							replacement = {e: TVar(inVar), p: pos, t:nodeVar.v.type};
 						}
 						else {
 							// default parameter if no connection
@@ -663,7 +715,7 @@ class Graph {
 											id: getNewVarId(),
 											name: name,
 											type: nodeVar.v.type,
-											kind: Input
+											kind: Local
 										};
 										externs.push(tvar);
 									}
@@ -686,7 +738,10 @@ class Graph {
 						continue;
 					}
 					if (outputVar == null) {
-						graphOutputVars.push({v: nodeVar.v, internal: false});
+						var v = getOutsideVar(nodeVar.v.name, nodeVar.v, false);
+						expr = replaceVar(expr, nodeVar.v, {e: TVar(v), p:pos, t: nodeVar.v.type});
+
+						//graphOutputVars.push({v: nodeVar.v, internal: false});
 					} else {
 						expr = replaceVar(expr, nodeVar.v, {e: TVar(outputVar), p:pos, t: nodeVar.v.type});
 						outputDecls.push(outputVar);
@@ -777,7 +832,7 @@ class Graph {
 		var edgesJson : Array<Edge> = [];
 		for (n in nodes) {
 			for (inputName => connection in n.instance.connections) {
-				var def = n.instance.getShaderDef(domain);
+				var def = n.instance.getShaderDef(domain, () -> 0);
 				var inputId = null;
 				for (i => inVar in def.inVars) {
 					if (inVar.v.name == inputName) {
@@ -786,7 +841,7 @@ class Graph {
 					}
 				}
 
-				var def = connection.from.instance.getShaderDef(domain);
+				var def = connection.from.instance.getShaderDef(domain, () -> 0);
 				var outputId = null;
 				for (i => outVar in def.outVars) {
 					if (outVar.v.name == connection.fromName) {

+ 5 - 4
hrt/shgraph/ShaderInput.hx

@@ -19,15 +19,16 @@ class ShaderInput extends ShaderNode {
 	// 	return null;
 	// }
 
-	override function getShaderDef(domain: ShaderGraph.Domain):hrt.shgraph.ShaderGraph.ShaderNodeDef {
+	override function getShaderDef(domain: ShaderGraph.Domain, getNewIdFn : () -> Int ):hrt.shgraph.ShaderGraph.ShaderNodeDef {
 		var pos : Position = {file: "", min: 0, max: 0};
 
-		var inVar : TVar = variable;
-		var output : TVar = {name: "output", id:1, type: this.variable.type, kind: Local, qualifiers: []};
+		var inVar : TVar = Reflect.copy(variable);
+		inVar.id = getNewIdFn();
+		var output : TVar = {name: "output", id: getNewIdFn(), type: this.variable.type, kind: Local, qualifiers: []};
 		var finalExpr : TExpr = {e: TBinop(OpAssign, {e:TVar(output), p:pos, t:output.type}, {e: TVar(inVar), p: pos, t: output.type}), p: pos, t: output.type};
 
 
-		return {expr: finalExpr, inVars: [], outVars:[{v:output, internal: false}], externVars: [inVar], inits: []};
+		return {expr: finalExpr, inVars: [{v:inVar, internal: true}], outVars:[{v:output, internal: false}], externVars: [], inits: []};
 	}
 
 	public static var availableInputs : Array<TVar> = [

+ 3 - 3
hrt/shgraph/ShaderNode.hx

@@ -23,7 +23,7 @@ class ShaderNode {
 					}];
 
 
-	public function getShaderDef(domain: ShaderGraph.Domain) : ShaderGraph.ShaderNodeDef {
+	public function getShaderDef(domain: ShaderGraph.Domain, getNewIdFn : () -> Int ) : ShaderGraph.ShaderNodeDef {
 		throw "getShaderDef is not defined for class " + Type.getClassName(Type.getClass(this));
 		return {expr: null, inVars: [], outVars: [], inits: [], externVars: []};
 	}
@@ -37,7 +37,7 @@ class ShaderNode {
 
 	// TODO(ces) : caching
 	public function getOutputs2(domain: ShaderGraph.Domain) : Map<String, TVar> {
-		var def = getShaderDef(domain);
+		var def = getShaderDef(domain, () -> 0);
 		var map : Map<String, TVar> = [];
 		for (tvar in def.outVars) {
 			if (!tvar.internal)
@@ -48,7 +48,7 @@ class ShaderNode {
 
 	// TODO(ces) : caching
 	public function getInputs2(domain: ShaderGraph.Domain) : Map<String, {v: TVar, ?def: hrt.shgraph.ShaderGraph.ShaderDefInput}> {
-		var def = getShaderDef(domain);
+		var def = getShaderDef(domain, () -> 0);
 		var map : Map<String, {v: TVar, ?def: hrt.shgraph.ShaderGraph.ShaderDefInput}> = [];
 		for (i => tvar in def.inVars) {
 			if (!tvar.internal) {

+ 25 - 3
hrt/shgraph/ShaderNodeHxsl.hx

@@ -1,21 +1,43 @@
 package hrt.shgraph;
 
+import hxsl.Ast.TExpr;
+using hxsl.Ast;
+
 @:autoBuild(hrt.shgraph.Macros.buildNode())
 class ShaderNodeHxsl extends ShaderNode {
 
 	static var nodeCache : Map<String, ShaderGraph.ShaderNodeDef> = [];
 
-	override public function getShaderDef(domain: ShaderGraph.Domain) : ShaderGraph.ShaderNodeDef {
+	override public function getShaderDef(domain: ShaderGraph.Domain, getNewIdFn : () -> Int ) : ShaderGraph.ShaderNodeDef {
 		var cl = Type.getClass(this);
 		var className = Type.getClassName(cl);
-		var def = nodeCache.get(className);
+		var def = null;//nodeCache.get(className);
 		if (def == null) {
-
 			var unser = new hxsl.Serializer();
 			var toUnser = (cl:Dynamic).SRC;
 			if (toUnser == null) throw "Node " + className + " has no SRC";
 			var data = @:privateAccess unser.unserialize(toUnser);
 			var expr = data.funs[0].expr;
+
+			var idToNewId : Map<Int, Int> = [];
+
+			function patchExprId(expr: TExpr) : TExpr {
+				switch (expr.e) {
+					case TVar(v):
+						var newId = idToNewId.get(v.id);
+						if (newId == null) {
+							newId = getNewIdFn();
+							idToNewId.set(v.id, newId);
+						}
+						v.id = newId;
+						return expr;
+					default:
+						return expr.map(patchExprId);
+				}
+			}
+
+			patchExprId(expr);
+
 			var inVars = [];
 			var outVars = [];
 			var externVars = [];

+ 4 - 3
hrt/shgraph/ShaderOutput.hx

@@ -1,3 +1,4 @@
+
 package hrt.shgraph;
 
 using hxsl.Ast;
@@ -13,11 +14,11 @@ class ShaderOutput extends ShaderNode {
 	var components = [X, Y, Z, W];
 
 
-	override function getShaderDef(domain: ShaderGraph.Domain):hrt.shgraph.ShaderGraph.ShaderNodeDef {
+	override function getShaderDef(domain: ShaderGraph.Domain, getNewIdFn : () -> Int ):hrt.shgraph.ShaderGraph.ShaderNodeDef {
 		var pos : Position = {file: "", min: 0, max: 0};
 
-		var inVar : TVar = {name: "input", id:0, type: this.variable.type, kind: Param, qualifiers: []};
-		var output : TVar = {name: variable.name, id:1, type: this.variable.type, kind: Local, qualifiers: []};
+		var inVar : TVar = {name: "input", id: getNewIdFn(), type: this.variable.type, kind: Param, qualifiers: []};
+		var output : TVar = {name: variable.name, id: getNewIdFn(), type: this.variable.type, kind: Local, qualifiers: []};
 		var finalExpr : TExpr = {e: TBinop(OpAssign, {e:TVar(output), p:pos, t:output.type}, {e: TVar(inVar), p: pos, t: output.type}), p: pos, t: output.type};
 
 		//var param = getParameter(inputNode.parameterId);

+ 5 - 6
hrt/shgraph/ShaderParam.hx

@@ -12,19 +12,18 @@ class ShaderParam extends ShaderNode {
 	@prop() public var perInstance : Bool;
 
 
-	override function getShaderDef(domain: ShaderGraph.Domain):hrt.shgraph.ShaderGraph.ShaderNodeDef {
+	override function getShaderDef(domain: ShaderGraph.Domain, getNewIdFn : () -> Int ):hrt.shgraph.ShaderGraph.ShaderNodeDef {
 		var pos : Position = {file: "", min: 0, max: 0};
 
 		var qual = [];
 		if (this.variable.type == TSampler2D) {
 			qual.push(Sampler(this.variable.name));
 		}
-		var inVar : TVar = {name: this.variable.name, id:0, type: this.variable.type, kind: Param, qualifiers: qual};
-		var output : TVar = {name: "output", id:1, type: this.variable.type, kind: Local, qualifiers: []};
-		//var finalExpr : TExpr = {e: TBinop(OpAssign, {e:TVar(output), p:pos, t:output.type}, {e: TVar(inVar), p: pos, t: output.type}), p: pos, t: output.type};
+		var inVar : TVar = {name: this.variable.name, id: getNewIdFn(), type: this.variable.type, kind: Param, qualifiers: qual};
+		var output : TVar = {name: "output", id: getNewIdFn(), type: this.variable.type, kind: Local, qualifiers: []};
+		var finalExpr : TExpr = {e: TBinop(OpAssign, {e:TVar(output), p:pos, t:output.type}, {e: TVar(inVar), p: pos, t: output.type}), p: pos, t: output.type};
 
-
-		return {expr: null, inVars: [{v:inVar, internal: true}], outVars:[{v:output, internal: false}], externVars: [], inits: []};
+		return {expr: finalExpr, inVars: [{v:inVar, internal: true}], outVars:[{v:output, internal: false}], externVars: [], inits: []};
 	}
 
 	public var variable : TVar;

+ 2 - 2
hrt/shgraph/nodes/BoolConst.hx

@@ -12,10 +12,10 @@ class BoolConst extends ShaderConst {
 
 	@prop() var value : Bool = true;
 
-	override function getShaderDef(domain: ShaderGraph.Domain):hrt.shgraph.ShaderGraph.ShaderNodeDef {
+	override function getShaderDef(domain: ShaderGraph.Domain, getNewIdFn : () -> Int ):hrt.shgraph.ShaderGraph.ShaderNodeDef {
 		var pos : Position = {file: "", min: 0, max: 0};
 
-		var output : TVar = {name: "output", id:1, type: TBool, kind: Local, qualifiers: []};
+		var output : TVar = {name: "output", id: getNewIdFn(), type: TBool, kind: Local, qualifiers: []};
 		var finalExpr : TExpr = {e: TBinop(OpAssign, {e:TVar(output), p:pos, t:output.type}, {e: TConst(CBool(value)), p: pos, t: output.type}), p: pos, t: output.type};
 
 		return {expr: finalExpr, inVars: [], outVars:[{v: output, internal: false}], externVars: [], inits: []};

+ 2 - 2
hrt/shgraph/nodes/Color.hx

@@ -26,10 +26,10 @@ class Color extends ShaderConst {
 	// 	};
 	// }
 
-	override function getShaderDef(domain: ShaderGraph.Domain):hrt.shgraph.ShaderGraph.ShaderNodeDef {
+	override function getShaderDef(domain: ShaderGraph.Domain, getNewIdFn : () -> Int ):hrt.shgraph.ShaderGraph.ShaderNodeDef {
 		var pos : Position = {file: "", min: 0, max: 0};
 
-		var output : TVar = {name: "output", id:1, type: TVec(4, VFloat), kind: Local, qualifiers: []};
+		var output : TVar = {name: "output", id:getNewIdFn(), type: TVec(4, VFloat), kind: Local, qualifiers: []};
 		var finalExpr : TExpr =
 		{ e: TBinop(OpAssign, {
 				e: TVar(output),

+ 2 - 2
hrt/shgraph/nodes/FloatConst.hx

@@ -9,10 +9,10 @@ using hxsl.Ast;
 @noheader()
 class FloatConst extends ShaderConst {
 
-	override function getShaderDef(domain: ShaderGraph.Domain):hrt.shgraph.ShaderGraph.ShaderNodeDef {
+	override function getShaderDef(domain: ShaderGraph.Domain, getNewIdFn : () -> Int ):hrt.shgraph.ShaderGraph.ShaderNodeDef {
 		var pos : Position = {file: "", min: 0, max: 0};
 
-		var output : TVar = {name: "output", id:1, type: TFloat, kind: Local, qualifiers: []};
+		var output : TVar = {name: "output", id: getNewIdFn(), type: TFloat, kind: Local, qualifiers: []};
 		var finalExpr : TExpr = {e: TBinop(OpAssign, {e:TVar(output), p:pos, t:output.type}, {e: TConst(CFloat(value)), p: pos, t: output.type}), p: pos, t: output.type};
 
 		return {expr: finalExpr, inVars: [], outVars:[{v: output, internal: false}], externVars: [], inits: []};

+ 3 - 3
hrt/shgraph/nodes/Preview.hx

@@ -32,11 +32,11 @@ class AlphaPreview extends hxsl.Shader {
 @noheader()
 class Preview extends ShaderNode {
 
-	override function getShaderDef(domain: ShaderGraph.Domain):hrt.shgraph.ShaderGraph.ShaderNodeDef {
+	override function getShaderDef(domain: ShaderGraph.Domain, getNewIdFn : () -> Int ):hrt.shgraph.ShaderGraph.ShaderNodeDef {
 		var pos : Position = {file: "", min: 0, max: 0};
 
-		var inVar : TVar = {name: "input", id:0, type: TVec(4, VFloat), kind: Param, qualifiers: []};
-		var output : TVar = {name: "pixelColor", id:1, type: TVec(4, VFloat), kind: Local, qualifiers: []};
+		var inVar : TVar = {name: "input", id: getNewIdFn(), type: TVec(4, VFloat), kind: Param, qualifiers: []};
+		var output : TVar = {name: "pixelColor", id: getNewIdFn(), type: TVec(4, VFloat), kind: Local, qualifiers: []};
 		var finalExpr : TExpr = {e: TBinop(OpAssign, {e:TVar(output), p:pos, t:output.type}, {e: TVar(inVar), p: pos, t: output.type}), p: pos, t: output.type};
 
 		//var param = getParameter(inputNode.parameterId);

+ 1 - 1
hrt/shgraph/nodes/Sampler.hx

@@ -9,7 +9,7 @@ class Sampler extends ShaderNodeHxsl {
 
 	static var SRC = {
 		@sginput var texture : Sampler2D;
-		@sginput(uv) var uv : Vec2;
+		@sginput var uv : Vec2;
 		@sgoutput var RGBA : Vec4;
 		@sgoutput var RGB : Vec3;
 		@sgoutput var A : Float;

+ 2 - 2
hrt/shgraph/nodes/SubGraph.hx

@@ -10,9 +10,9 @@ class SubGraph extends ShaderNode {
 
 	@prop() public var pathShaderGraph : String;
 
-	override public function getShaderDef(domain: ShaderGraph.Domain):hrt.shgraph.ShaderGraph.ShaderNodeDef {
+	override public function getShaderDef(domain: ShaderGraph.Domain, getNewIdFn : () -> Int ):hrt.shgraph.ShaderGraph.ShaderNodeDef {
 		var shader = new ShaderGraph(pathShaderGraph);
-		var gen = shader.getGraph(domain).generate2();
+		var gen = shader.getGraph(domain).generate2(getNewIdFn);
 
 		// for (tvar in gen.externVars) {
 		// 	if (tvar.qualifiers != null) {

+ 1 - 1
hrt/shgraph/nodes/Text.hx

@@ -12,7 +12,7 @@ class Text extends ShaderNode {
 
 	@prop() var text : String = "";
 
-	override function getShaderDef(domain: ShaderGraph.Domain):hrt.shgraph.ShaderGraph.ShaderNodeDef {
+	override function getShaderDef(domain: ShaderGraph.Domain, getNewIdFn : () -> Int ):hrt.shgraph.ShaderGraph.ShaderNodeDef {
 		return {expr: null, inVars: [], outVars: [], inits: [], externVars: []};
 	}