Sfoglia il codice sorgente

[shgraph] Correct tpying of shader local vars, each node now has a ref to the graph

Clément Espeute 7 mesi fa
parent
commit
fbafeb0689

+ 2 - 2
hide/view/shadereditor/ShaderEditor.hx

@@ -214,6 +214,7 @@ class ShaderEditor extends hide.view.FileView implements GraphInterface.IGraphEd
 				@:privateAccess var id = currentGraph.current_node_id++;
 				inst.id = id;
 				inst.setPos(posCursor);
+				inst.graph = currentGraph;
 
 				graphEditor.opBox(inst, true, graphEditor.currentUndoBuffer);
 				graphEditor.commitUndo();
@@ -254,7 +255,6 @@ class ShaderEditor extends hide.view.FileView implements GraphInterface.IGraphEd
 
 			var inst = new ShaderParam();
 			inst.parameterId = draggedParamId;
-			inst.shaderGraph = shaderGraph;
 			addNode(inst);
 		};
 
@@ -1702,7 +1702,7 @@ class ShaderEditor extends hide.view.FileView implements GraphInterface.IGraphEd
 	}
 
 	public function unserializeNode(data : Dynamic, newId : Bool) : IGraphNode {
-		var node = ShaderNode.createFromDynamic(data, shaderGraph);
+		var node = ShaderNode.createFromDynamic(data, currentGraph);
 		if (newId) {
 			@:privateAccess var newId = currentGraph.current_node_id++;
 			node.setId(newId);

+ 13 - 0
hrt/shgraph/AstTools.hx

@@ -23,6 +23,19 @@ class AstTools {
 		);
 	}
 
+	public static function makeDynamic(type: Type, value: Dynamic) : TExpr {
+		switch(type) {
+			case TInt:
+				return makeInt(value is Int ? cast value : 0);
+			case TFloat:
+				return makeFloat(value is Float ? cast value : 0);
+			case TVec(size, VFloat):
+				return makeVec((value is Array && value.len == size) ? cast value : [for (_ in 0...size) 0.0]);
+			default:
+				throw "unsupported type " + type;
+		}
+	}
+
 	public inline static function makeAssign(to: TExpr, from: TExpr) : TExpr {
 		return makeExpr(TBinop(OpAssign, to, from), to.t);
 	}

+ 15 - 6
hrt/shgraph/NodeGenContext.hx

@@ -9,8 +9,8 @@ import hrt.shgraph.ShaderGraph;
 import hrt.shgraph.ShaderNode;
 
 class NodeGenContextSubGraph extends NodeGenContext {
-	public function new(parentCtx : NodeGenContext) {
-		super(parentCtx?.domain ?? Fragment);
+	public function new(graph: ShaderGraph.Graph, parentCtx : NodeGenContext) {
+		super(graph, parentCtx?.domain ?? Fragment);
 		this.parentCtx = parentCtx;
 	}
 
@@ -77,9 +77,11 @@ class NodeGenContext {
 	// Pour les rares nodes qui ont besoin de differencier entre vertex et fragment
 	public var domain : ShaderGraph.Domain;
 	public var previewDomain: ShaderGraph.Domain = null;
+	public var graph: ShaderGraph.Graph = null;
 
-	public function new(domain: ShaderGraph.Domain) {
+	public function new(graph: ShaderGraph.Graph, domain: ShaderGraph.Domain) {
 		this.domain = domain;
+		this.graph = graph;
 	}
 
 	// For general input/output of the shader graph. Allocate a new global var if name is not found,
@@ -121,8 +123,15 @@ class NodeGenContext {
 		expressions.push(makeAssign(v, expr));
 	}
 
-	public function getLocalTVar(name: String, type: Type) : TVar {
-		return MapUtils.getOrPut(localVars, name, {id: hxsl.Ast.Tools.allocVarId(), name: name, type: type, kind: Local});
+	public function getLocalTVar(id: Int, init: TExpr = null) : TVar {
+		var graphVar = graph.variables[id];
+		var type = ShaderGraph.sgTypeToType(graphVar.type);
+		var variable = MapUtils.getOrPut(localVars, id, {variable: {id: hxsl.Ast.Tools.allocVarId(), name: '_sg_local_$id', type: type, kind: Local}, isInit: false});
+		if (init != null && !variable.isInit) {
+			variable.isInit = true;
+			addExpr(AstTools.makeVarDecl(variable.variable, init));
+		}
+		return variable.variable;
 	}
 
 	function getOrAllocateFromTVar(tvar: TVar) : TVar {
@@ -369,5 +378,5 @@ class NodeGenContext {
 
 	var nodeInputInfo : Array<InputInfo>;
 	var globalVars: Map<String, ShaderGraph.ExternVarDef> = [];
-	var localVars: Map<String, TVar> = [];
+	var localVars: Map<Int, {variable: TVar, isInit: Bool}> = [];
 }

+ 15 - 6
hrt/shgraph/ShaderGraph.hx

@@ -229,7 +229,7 @@ class ShaderGraphGenContext {
 		initNodes();
 		var sortedNodes = sortGraph();
 
-		genContext = genContext ?? new NodeGenContext(graph.domain);
+		genContext = genContext ?? new NodeGenContext(graph, graph.domain);
 		var expressions : Array<TExpr> = [];
 		genContext.expressions = expressions;
 
@@ -269,6 +269,14 @@ class ShaderGraphGenContext {
 			global.paramIndex = p.index;
 		}
 
+		// Default init uninitialized local vars
+		for (id => variable in genContext.localVars) {
+			if (variable.isInit)
+				continue;
+			var initExpr = AstTools.makeVarDecl(variable.variable, AstTools.makeDynamic(variable.variable.type, graph.variables[id].defValue));
+			expressions.unshift(initExpr);
+		}
+
 		return AstTools.makeExpr(TBlock(expressions), TVoid);
 	}
 
@@ -451,13 +459,14 @@ class ShaderGraph extends hrt.prefab.Prefab {
 		};
 
 
-		var nodeGen = new NodeGenContext(Vertex);
+		var nodeGen = new NodeGenContext(null, Vertex);
 		nodeGen.previewDomain = previewDomain;
 
 		for (i => graph in graphs) {
 			if (previewDomain != null && previewDomain != graph.domain)
 				continue;
 			nodeGen.domain = graph.domain;
+			nodeGen.graph = graph;
 			var ctx = new ShaderGraphGenContext(graph);
 			var gen = ctx.generate(nodeGen);
 
@@ -717,7 +726,7 @@ class Graph {
 	var current_node_id = 0;
 	var nodes : Map<Int, ShaderNode> = [];
 
-	var variables : Array<ShaderGraphVariable> = [];
+	public var variables : Array<ShaderGraphVariable> = [];
 
 	public var parent : ShaderGraph = null;
 
@@ -731,8 +740,6 @@ class Graph {
 
 	public function load(json : Dynamic) {
 		nodes = [];
-		generate(Reflect.getProperty(json, "nodes"), Reflect.getProperty(json, "edges"));
-
 		for (variable in json.variables ?? []) {
 			variables.push({
 				name: variable.name,
@@ -740,12 +747,14 @@ class Graph {
 				defValue: variable.defValue,
 			});
 		}
+
+		generate(Reflect.getProperty(json, "nodes"), Reflect.getProperty(json, "edges"));
 	}
 
 	public function generate(nodes : Array<Dynamic>, edges : Array<Edge>) {
 		current_node_id = 0;
 		for (n in nodes) {
-			var node = ShaderNode.createFromDynamic(n, parent);
+			var node = ShaderNode.createFromDynamic(n, this);
 			this.nodes.set(node.id, node);
 			current_node_id = hxd.Math.imax(current_node_id, node.id+1);
 		}

+ 4 - 5
hrt/shgraph/ShaderNode.hx

@@ -53,6 +53,8 @@ implements hide.view.GraphInterface.IGraphNode
 	public var x : Float;
 	public var y : Float;
 	public var showPreview : Bool = true;
+	public var graph : ShaderGraph.Graph;
+
 	@prop public var nameOverride : String;
 
 
@@ -156,16 +158,13 @@ implements hide.view.GraphInterface.IGraphNode
 		};
 	}
 
-	public static function createFromDynamic(data: Dynamic, graph: ShaderGraph) : ShaderNode {
+	public static function createFromDynamic(data: Dynamic, graph: ShaderGraph.Graph) : ShaderNode {
 		var type = std.Type.resolveClass(data.type);
 		var inst = std.Type.createInstance(type, []);
-		var shaderParam = Std.downcast(inst, ShaderParam);
-		if (shaderParam != null) {
-			shaderParam.shaderGraph = graph;
-		}
 		inst.x = data.x;
 		inst.y = data.y;
 		inst.id = data.id;
+		inst.graph = graph;
 		inst.connections = [];
 		inst.loadProperties(data.properties);
 		return inst;

+ 1 - 3
hrt/shgraph/ShaderParam.hx

@@ -9,8 +9,6 @@ class ShaderParam extends ShaderNode {
 	@prop() public var parameterId : Int;
 	@prop() public var perInstance : Bool;
 
-	public var shaderGraph : ShaderGraph;
-
 	public function new() {
 
 	}
@@ -46,7 +44,7 @@ class ShaderParam extends ShaderNode {
 	}
 
 	function getVariable() : TVar {
-		return shaderGraph.getParameter(parameterId).variable;
+		return graph.parent.getParameter(parameterId).variable;
 	}
 
 	override public function loadProperties(props : Dynamic) {

+ 3 - 3
hrt/shgraph/nodes/SubGraph.hx

@@ -25,7 +25,7 @@ class SubGraph extends ShaderNode {
 		var graph = shader.getGraph(ctx.domain);
 
 		var genCtx = new ShaderGraphGenContext(graph, false);
-		genCtx.generate(new NodeGenContext.NodeGenContextSubGraph(ctx));
+		genCtx.generate(new NodeGenContext.NodeGenContextSubGraph(graph, ctx));
 	}
 
 	override public function getInputs() : Array<ShaderNode.InputInfo> {
@@ -37,7 +37,7 @@ class SubGraph extends ShaderNode {
 		var graph = shader.getGraph(hrt.shgraph.ShaderGraph.Domain.Fragment);
 
 		var genCtx = new ShaderGraphGenContext(graph, false);
-		var nodeGenCtx = new NodeGenContext.NodeGenContextSubGraph(null);
+		var nodeGenCtx = new NodeGenContext.NodeGenContextSubGraph(graph, null);
 		genCtx.generate(nodeGenCtx);
 		var inputs: Array<ShaderNode.InputInfo> = [];
 
@@ -59,7 +59,7 @@ class SubGraph extends ShaderNode {
 		var graph = shader.getGraph(hrt.shgraph.ShaderGraph.Domain.Fragment);
 
 		var genCtx = new ShaderGraphGenContext(graph, false);
-		var nodeGenCtx = new NodeGenContext.NodeGenContextSubGraph(null);
+		var nodeGenCtx = new NodeGenContext.NodeGenContextSubGraph(graph, null);
 		genCtx.generate(nodeGenCtx);
 		var outputs: Array<ShaderNode.InputInfo> = [];
 

+ 2 - 2
hrt/shgraph/nodes/VarRead.hx

@@ -13,13 +13,13 @@ class VarRead extends ShaderVar {
 	var outputs: Array<ShaderNode.OutputInfo>;
 	override public function getOutputs() : Array<ShaderNode.OutputInfo> {
 		if (outputs == null) {
-			outputs  = [{name: "output", type: SgFloat(4)}];
+			outputs  = [{name: "output", type: graph.variables[varId].type}];
 		}
 		return outputs;
 	}
 
 	override function generate(ctx:NodeGenContext) {
-		var out = AstTools.makeVar(ctx.getLocalTVar('_sg_var_$varId', TVec(4, VFloat)));
+		var out = AstTools.makeVar(ctx.getLocalTVar(varId));
 		ctx.setOutput(0, out);
 		#if editor
 		ctx.addPreview(out);

+ 3 - 3
hrt/shgraph/nodes/VarWrite.hx

@@ -12,15 +12,15 @@ class VarWrite extends ShaderVar {
 	var inputs: Array<ShaderNode.InputInfo>;
 	override public function getInputs() : Array<ShaderNode.InputInfo> {
 		if (inputs == null) {
-			inputs = [{name: "input", type: SgFloat(4)}];
+			inputs = [{name: "input", type: graph.variables[varId].type}];
 		}
 		return inputs;
 	}
 
 	override function generate(ctx:NodeGenContext) {
 		var input = ctx.getInput(0);
-		ctx.addExpr(AstTools.makeVarDecl(ctx.getLocalTVar('_sg_var_$varId', TVec(4, VFloat)), input));
-		ctx.addPreview(input);
+		var tVar = ctx.getLocalTVar(varId, input);
+		ctx.addPreview(AstTools.makeVar(tVar));
 	}
 
 	#if editor