浏览代码

Topological sorting

Clement Espeute 2 年之前
父节点
当前提交
bdcbc1a000
共有 1 个文件被更改,包括 77 次插入117 次删除
  1. 77 117
      hrt/shgraph/ShaderGraph.hx

+ 77 - 117
hrt/shgraph/ShaderGraph.hx

@@ -3,6 +3,8 @@ package hrt.shgraph;
 import hxsl.SharedShader;
 using hxsl.Ast;
 using hide.tools.Extensions.ArrayExtensions;
+using haxe.EnumTools.EnumValueTools;
+using Lambda;
 
 typedef Node = {
 	x : Float,
@@ -234,34 +236,19 @@ class ShaderGraph {
 	}
 
 	public function generate2(?getNewVarId: () -> Int) : {expr: TExpr, inVars: Array<{variable: TVar, value: Dynamic}>, outVars: Array<TVar>} {
-
-		var pos : Position = {file: "", min: 0, max: 0};
-
-		var fragExprs : Array<TExpr> = [];
-
-		var outputNodes : Array<Node> = [];
-
-		var inits : Array<{ variable : hxsl.Ast.TVar, value : Dynamic }> = [];
-
-		// find all outputs
-		for (id => node in nodes) {
-			var shaderOutput = Std.downcast(node.instance, ShaderOutput);
-			if (shaderOutput != null) {
-				outputNodes.push(node);
-			}
-		}
 		if (getNewVarId == null) {
 			var varIdCount = 0;
-			getNewVarId = function(){return varIdCount++;};
+			getNewVarId = function()
+				{
+					return varIdCount++;
+				};
 		}
 
 		inline function getNewVarName(id: Int) : String {
-			return 'nodeoutput_$id';
+			return '_sg_var_$id';
 		}
 
-		var nodesInputs : Map<Node, Map<String, TVar>> = [];
 		var nodeOutputs : Map<Node, Map<String, TVar>> = [];
-
 		function getOutputs(node: Node) : Map<String, TVar> {
 			if (!nodeOutputs.exists(node)) {
 				var outputs : Map<String, TVar> = [];
@@ -281,14 +268,38 @@ class ShaderGraph {
 			return nodeOutputs.get(node);
 		}
 
-		var graphInputVars = [];
+		// Recursively replace the to tvar with from tvar in the given expression
+		function replaceVar(expr: TExpr, to: TVar, from: TVar) : TExpr {
+			if(!to.type.equals(from.type))
+				throw "type missmatch " + to.type + " != " + from.type;
+			function repRec(f: TExpr) {
+				if (f.e.equals(TVar(to))) {
+					return {e: TVar(from), t: from.type, p:f.p};
+				} else {
+					return f.map(repRec);
+				}
+			}
+			return repRec(expr);
+		}
 
+		// Shader generation starts here
 
-		var graphOutputsVars = [];
+		var pos : Position = {file: "", min: 0, max: 0};
+		var outputNodes : Array<Node> = [];
+		var inits : Array<{ variable : hxsl.Ast.TVar, value : Dynamic }> = [];
 
-		var nodeToExplore : Array<Node> = [];
+		// find all outputs
+		for (node in nodes) {
+			var shaderOutput = Std.downcast(node.instance, ShaderOutput);
+			if (shaderOutput != null) {
+				outputNodes.push(node);
+			}
+		}
 
-		var exprsReverse : Array<TExpr> = [];
+		var graphInputVars : Array<TVar> = [];
+		var graphOutputsVars : Array<TVar> = [];
+
+		var nodeToExplore : Array<Node> = [];
 
 		for (outputNode in outputNodes) {
 			var outputs = getOutputs(outputNode);
@@ -299,62 +310,56 @@ class ShaderGraph {
 			nodeToExplore.push(outputNode);
 		}
 
-		// maps the naked local vars in shader to regroup them
-		var shaderLocalVars : Map<String, TVar> = [];
+		var sortedNodes : Array<Node> = [];
 
-		function replaceVar(expr: TExpr, to: TVar, from: TVar) : TExpr {
-			if(!haxe.EnumTools.EnumValueTools.equals(to.type, from.type))
-				throw "type missmatch " + to.type + " != " + from.type;
-			function repRec(f: TExpr) {
-				if (haxe.EnumTools.EnumValueTools.equals(f.e, TVar(to))) {
-					return {e: TVar(from), t: from.type, p:f.p};
-				} else {
-					return f.map(repRec);
+		// Topological sort the nodes with Kahn's algorithm
+		// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
+		{
+			var connectionsSortCopy : Array<Connection> = [for (node in nodes) for (connection in node.instance.inputs2) connection];
+
+			while (nodeToExplore.length > 0) {
+				var currentNode = nodeToExplore.pop();
+				sortedNodes.push(currentNode);
+				for (connection in currentNode.instance.inputs2) {
+					var targetNode = connection.from;
+					if (!connectionsSortCopy.remove(connection)) throw "connection not in graph";
+					if (connectionsSortCopy.find((n:Connection) -> n.from == targetNode) == null) {
+						nodeToExplore.push(targetNode);
+					}
 				}
 			}
-			return repRec(expr);
 		}
 
-		function syncLocalVar(expr: TExpr, localVar: TVar) : TExpr {
-			throw "locals vars not supported yet";
-			var prev = shaderLocalVars.get(localVar.name);
-			if (prev != null) {
-				return replaceVar(expr, localVar, prev);
-			}
-			else {
-				shaderLocalVars.set(localVar.name, localVar);
-				return expr;
-			}
-		}
 
-		while (nodeToExplore.length > 0) {
-			var currentNode = nodeToExplore[0];
-			nodeToExplore.remove(currentNode);
+
+		//sortedNodes.reverse();
+
+		// Actually build the final shader expression
+		var exprsReverse : Array<TExpr> = [];
+		for (currentNode in sortedNodes) {
+
 			var outputs = getOutputs(currentNode);
 
-			var tvars : Array<TVar> = [];
-			for (name => input in currentNode.instance.inputs2) {
+			var inputVars : Array<TVar> = [];
+			for (input in currentNode.instance.inputs2) {
 				var outputs = getOutputs(input.from);
-				var tvar = outputs[input.fromName];
-				if (tvar == null) throw "null tvar";
+				var outputVar = outputs[input.fromName];
+				if (outputVar == null) throw "null tvar";
 
-				tvars.push(tvar);
-				nodeToExplore.pushUnique(input.from);
+				inputVars.push(outputVar);
 			}
 
-			// Simple addition node
-
-
 
 			if (Std.downcast(currentNode.instance, ShaderOutput) != null) {
 				var outputNode : ShaderOutput = cast currentNode.instance;
 				var outVar : TVar = {name: outputNode.variable.name, id:getNewVarId(), type: TVec(4, VFloat), kind: Local};
-				var finalExpr : TExpr = {e: TBinop(OpAssign, {e: TVar(outVar), p: pos, t: outVar.type}, {e: TVar(tvars[0]), p: pos, t: outVar.type}), p: pos, t: outVar.type};
+				var finalExpr : TExpr = {e: TBinop(OpAssign, {e: TVar(outVar), p: pos, t: outVar.type}, {e: TVar(inputVars[0]), p: pos, t: outVar.type}), p: pos, t: outVar.type};
 
 				exprsReverse.push(finalExpr);
 				graphOutputsVars.push(outVar);
 			} else if (Std.downcast(currentNode.instance, ShaderParam) != null) {
 				var inputNode : ShaderParam = cast currentNode.instance;
+
 				var inVar : TVar = {name: inputNode.variable.name, id:getNewVarId(), type: TVec(4, VFloat), kind: Param};
 
 				for (output in outputs) {
@@ -373,12 +378,12 @@ class ShaderGraph {
 				var finalExprs = [];
 
 				// Patch outputs
-				for (i => output in outputs) {
+				for (output in outputs) {
 					gen.expr = replaceVar(gen.expr, gen.outVars[0], output);
 				}
 
 				// Patch inputs
-				for (i => tvar in tvars) {
+				for (i => tvar in inputVars) {
 					var originalInput = gen.inVars[i].variable;
 					var finalExpr : TExpr = {e: TVarDecl(originalInput, {e: TVar(tvar), p: pos, t: originalInput.type}), p: pos, t: tvar.type};
 					finalExprs.push(finalExpr);
@@ -394,48 +399,33 @@ class ShaderGraph {
 			}
 			else
 			{
-				for (outName => output in outputs) {
+				for (outputName => output in outputs) {
 					if (output == null) throw "null output";
 
 					var fullExpr : TExpr = null;
 
 					if (Std.downcast(currentNode.instance, hrt.shgraph.nodes.Add) != null) {
 						var unser = new hxsl.Serializer();
-						var data = @:privateAccess unser.unserialize(TestNewNode2.SRC);
-						var outputTVar : TVar = data.vars.find((v : TVar) -> v.name == outName);
-						var fn = data.funs[0];
+						var shaderData = @:privateAccess unser.unserialize(TestNewNode2.SRC);
+						var outputVar : TVar = shaderData.vars.find((v : TVar) -> v.name == outputName);
+						var fn = shaderData.funs[0]; // TODO : spec the function to use for shader node definitions
 						var expr = fn.expr;
 
-						expr = replaceVar(expr, outputTVar, output);
+						expr = replaceVar(expr, outputVar, output);
 
-						for (i => tvar in tvars) {
-							var inputTVar : TVar = data.vars.find((v : TVar) -> v.name == 'input$i');
+						for (i => tvar in inputVars) {
+							var inputTVar : TVar = shaderData.vars.find((v : TVar) -> v.name == 'input$i');
 							expr = replaceVar(expr, inputTVar, tvar);
 						}
 
-						/*for (localvar in data.vars) {
-							if (localvar.qualifiers == null || localvar.qualifiers.length == 0) {
-								expr = syncLocalVar(expr,localvar);
-							}
-						}*/
-
 						exprsReverse.push(expr);
 
 						fullExpr = null;
 					}
 					else {
-						for (tvar in tvars) {
-							if (tvar == null) throw "null tvar";
-							var expr : TExpr = {e: TVar(tvar), p: pos, t: output.type};
-							if (fullExpr == null)
-								fullExpr = expr;
-							else
-								fullExpr = {e: TBinop(OpAdd, fullExpr, expr), p: pos, t: output.type};
-						}
+						throw "unsuported node";
 					}
 
-
-
 					{
 						if (output.type == null) throw "no type";
 						var finalExpr : TExpr = {e: TVarDecl(output, fullExpr), p: pos, t: output.type};
@@ -456,7 +446,7 @@ class ShaderGraph {
 	}
 
 	public function compile2() : hrt.prefab.ContextShared.ShaderDef {
-
+		var start = haxe.Timer.stamp();
 
 		var gen = generate2();
 
@@ -485,39 +475,9 @@ class ShaderGraph {
 		@:privateAccess shared.data = shaderData;
 		@:privateAccess shared.initialize();
 
-
+		var time = haxe.Timer.stamp() - start;
+		trace("Shader compile2 in " + time * 1000 + " ms");
 		return {shader : shared, inits: gen.inVars};
-
-		/*var pixelColor : TVar = {id: 0, name: "pixelColor", type: TVec(4, VFloat), kind: Local};
-		shaderData.vars.push(pixelColor);
-
-		var fragExprs : Array<TExpr> = [];
-
-		var pos : Position = {file: "no", min: 0, max: 0};
-
-		fragExprs.push({e : TBinop(
-				OpAssign,
-				{e: TSwiz({e: TVar(pixelColor), t: TFloat, p: pos}, [X]), t: TFloat, p: pos},
-				{e: TConst(CFloat(1.0)), t: TFloat, p: pos}),
-			t: TFloat, p: pos});
-
-		shaderData.funs.push({
-				ret : TVoid, kind : Fragment,
-				ref : {
-					name : "fragment",
-					id : 0,
-					kind : Function,
-					type : TFun([{ ret : TVoid, args : [] }])
-				},
-				expr : {
-					p : null,
-					t : TVoid,
-					e : TBlock(fragExprs)
-				},
-				args : []
-			});
-
-		return shaderData;*/
 	}
 
 	public function getParameter(id : Int) {