Browse Source

Node shader caching and fix dangling nodes

Clement Espeute 2 years ago
parent
commit
45cb9645ab
1 changed files with 46 additions and 20 deletions
  1. 46 20
      hrt/shgraph/ShaderGraph.hx

+ 46 - 20
hrt/shgraph/ShaderGraph.hx

@@ -235,6 +235,18 @@ class ShaderGraph {
 			};
 	}
 
+	static var nodeCache : Map<String, ShaderData> = [];
+	function getShaderData(cl: Class<TestNewNode>) {
+		var className = Type.getClassName(cl);
+		var data = nodeCache.get(className);
+		if (data == null) {
+			var unser = new hxsl.Serializer();
+			data = @:privateAccess unser.unserialize((cl:Dynamic).SRC);
+			nodeCache.set(className, data);
+		}
+		return data;
+	}
+
 	public function generate2(?getNewVarId: () -> Int) : {expr: TExpr, inVars: Array<{variable: TVar, value: Dynamic}>, outVars: Array<TVar>} {
 		if (getNewVarId == null) {
 			var varIdCount = 0;
@@ -288,12 +300,16 @@ class ShaderGraph {
 		var outputNodes : Array<Node> = [];
 		var inits : Array<{ variable : hxsl.Ast.TVar, value : Dynamic }> = [];
 
-		// find all outputs
+		var allConnections : Array<Connection> = [for (node in nodes) for (connection in node.instance.inputs2) connection];
+
+
+		// find all node with no output
+		var nodeHasOutputs : Map<Node, Bool> = [];
 		for (node in nodes) {
-			var shaderOutput = Std.downcast(node.instance, ShaderOutput);
-			if (shaderOutput != null) {
-				outputNodes.push(node);
-			}
+			nodeHasOutputs.set(node, false);
+		}
+		for (connection in allConnections) {
+			nodeHasOutputs.set(connection.from, true);
 		}
 
 		var graphInputVars : Array<TVar> = [];
@@ -301,42 +317,41 @@ class ShaderGraph {
 
 		var nodeToExplore : Array<Node> = [];
 
-		for (outputNode in outputNodes) {
-			var outputs = getOutputs(outputNode);
-			for (outputVar in outputs) {
-				graphOutputsVars.push(outputVar);
-			}
-
-			nodeToExplore.push(outputNode);
+		for (node => hasOutputs in nodeHasOutputs) {
+			if (!hasOutputs)
+				nodeToExplore.push(node);
 		}
 
+
 		var sortedNodes : Array<Node> = [];
 
 		// Topological sort the nodes with Kahn's algorithm
 		// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
 		{
-			var connectionsSortCopy : Array<Connection> = [for (node in nodes) for (connection in node.instance.inputs2) connection];
-
 			while (nodeToExplore.length > 0) {
 				var currentNode = nodeToExplore.pop();
 				sortedNodes.push(currentNode);
 				for (connection in currentNode.instance.inputs2) {
 					var targetNode = connection.from;
-					if (!connectionsSortCopy.remove(connection)) throw "connection not in graph";
-					if (connectionsSortCopy.find((n:Connection) -> n.from == targetNode) == null) {
+					if (!allConnections.remove(connection)) throw "connection not in graph";
+					if (allConnections.find((n:Connection) -> n.from == targetNode) == null) {
 						nodeToExplore.push(targetNode);
 					}
 				}
 			}
 		}
 
-
-
 		//sortedNodes.reverse();
 
 		// Actually build the final shader expression
 		var exprsReverse : Array<TExpr> = [];
 		for (currentNode in sortedNodes) {
+			// Skip nodes with no outputs that arent a final node
+			if (Std.downcast(currentNode.instance, ShaderOutput)==null) {
+				if (!nodeHasOutputs.get(currentNode))
+					continue;
+			}
+
 
 			var outputs = getOutputs(currentNode);
 
@@ -350,6 +365,8 @@ class ShaderGraph {
 			}
 
 
+
+
 			if (Std.downcast(currentNode.instance, ShaderOutput) != null) {
 				var outputNode : ShaderOutput = cast currentNode.instance;
 				var outVar : TVar = {name: outputNode.variable.name, id:getNewVarId(), type: TVec(4, VFloat), kind: Local};
@@ -406,7 +423,7 @@ class ShaderGraph {
 
 					if (Std.downcast(currentNode.instance, hrt.shgraph.nodes.Add) != null) {
 						var unser = new hxsl.Serializer();
-						var shaderData = @:privateAccess unser.unserialize(TestNewNode2.SRC);
+						var shaderData = getShaderData(TestNewNode2);
 						var outputVar : TVar = shaderData.vars.find((v : TVar) -> v.name == outputName);
 						var fn = shaderData.funs[0]; // TODO : spec the function to use for shader node definitions
 						var expr = fn.expr;
@@ -445,10 +462,17 @@ class ShaderGraph {
 		};
 	}
 
+	public static function measure2<T>(f:Void->T, ?pos:haxe.PosInfos):T {
+		//var t0 = haxe.Timer.stamp();
+		var r = f();
+		//haxe.Log.trace((haxe.Timer.stamp() - t0) * 1000 + "ms", pos);
+		return r;
+	}
+
 	public function compile2() : hrt.prefab.ContextShared.ShaderDef {
 		var start = haxe.Timer.stamp();
 
-		var gen = generate2();
+		var gen = measure2(()->generate2());
 
 		var shaderData : ShaderData = {
 			name: "",
@@ -471,12 +495,14 @@ class ShaderGraph {
 			args : []
 		});
 
+
 		var shared = new SharedShader("");
 		@:privateAccess shared.data = shaderData;
 		@:privateAccess shared.initialize();
 
 		var time = haxe.Timer.stamp() - start;
 		trace("Shader compile2 in " + time * 1000 + " ms");
+
 		return {shader : shared, inits: gen.inVars};
 	}