package hrt.shgraph;

import hxsl.SharedShader;
using hxsl.Ast;
using haxe.EnumTools.EnumValueTools;
using Lambda;
import hrt.shgraph.AstTools.*;
import hrt.shgraph.SgHxslVar.ShaderDefInput;

/** Angle units selectable on nodes that take an angle. **/
enum abstract AngleUnit(String) {
	var Radian;
	var Degree;
}

/** Every AngleUnit, in dropdown order. Index 0 is the default used when a node has no unit yet. **/
final angleUnits = [Radian, Degree];

#if editor
/**
	Builds the "Unit" dropdown used to pick an AngleUnit in the editor.
	Reads and writes `self.unit`, and initializes it to `angleUnits[0]` when it is null.
	@param self the object owning the `unit` field (untyped)
	@param width presumably the dropdown width — not referenced by the visible code (TODO confirm)
**/
function getAngleUnitDropdown(self: Dynamic, width: Float) : hide.Element {
	// NOTE(review): the markup literals in this function look empty/stripped, yet "#unit" is
	// queried below and nothing visible here creates it — confirm the HTML against the original source.
	var element = new hide.Element('
');
	element.append('Unit');
	element.append(new hide.Element(''));

	if (self.unit == null) {
		self.unit = angleUnits[0];
	}

	var input = element.children("#unit");
	var indexOption = 0;
	for (i => curAngle in angleUnits) {
		input.append(new hide.Element(''));
		if (self.unit == curAngle) {
			input.val(i);
		}
		indexOption++;
	}

	// The dropdown value is an index into angleUnits
	input.on("change", function(e) {
		var value = input.val();
		self.unit = angleUnits[value];
	});

	return element;
}
#end

enum SgType {
	SgFloat(dimension: Int);
	SgSampler;
	SgInt;
	SgBool;

	/**
		All the generics in the same shader node with the same id unify to the same type.
		Constraint :
			newType : the type we are trying to constrain. If null the function should return previousType
			previousType : the type previously constrained to this generic.
			if both newType and previousType are null, the function should return the most generic type for the constraint
			return : null if the newType can't be constrained, or a type that can fit both new and previous types
	**/
	SgGeneric(id: Int, constraint: (newType: Type, previousType: Type) -> Null<Type>);
}

/**
	Converts an hxsl type to its shader graph equivalent (floats and float vectors become SgFloat(n)).
	@throws String if the type has no SgType equivalent
**/
function typeToSgType(t: Type) : SgType {
	return switch(t) {
		case TFloat: SgFloat(1);
		case TVec(n, VFloat): SgFloat(n);
		case TSampler(T2D, false): SgSampler;
		case TInt: SgInt;
		case TBool: SgBool;
		default: throw "Unsuported type";
	}
}

/**
	Converts a concrete SgType back to an hxsl type. SgFloat(1) maps to TFloat, bigger dimensions to TVec.
	@throws String for SgGeneric, which can only be resolved with the owning node's context
**/
function sgTypeToType(t: SgType) : Type {
	return switch(t) {
		case SgBool: TBool;
		case SgFloat(1): TFloat;
		case SgFloat(n): TVec(n, VFloat);
		case SgSampler: TSampler(T2D, false);
		case SgInt: TInt;
		case SgGeneric(_, _): throw "Can't resolve generic without context";
	}
}

/**
	SgGeneric constraint for floats: unifies both types to the float type with the largest dimension.
	@param newType type being constrained. When it is null or not a float / float vector,
		previousType is returned (TFloat if previousType is null too)
	@param previousType type previously bound to the generic, may be null
	@return TFloat for dimension 1, TVec(n, VFloat) otherwise
	@throws String if the resulting dimension is not in 1...4
**/
function ConstraintFloat(newType: Type, previousType: Type) : Null<Type> {
	// Float dimension of a type, or null if it isn't a float / float vector
	function getN(type:Type) {
		return switch(type) {
			case TFloat: 1;
			case TVec(n, VFloat): n;
			case null, _: null;
		};
	}

	var newN = getN(newType) ?? return (previousType ?? TFloat);
	var oldN = getN(previousType) ?? newN;
	var maxN = hxd.Math.imax(newN, oldN);
	switch (maxN) {
		case 1: return TFloat;
		case 2,3,4: return TVec(maxN, VFloat);
		default: throw "invalid float size " + maxN;
	}
}

typedef ShaderNodeDefInVar = {v: TVar, internal: Bool, ?defVal: ShaderDefInput, isDynamic: Bool};
typedef ShaderNodeDefOutVar = {v: TVar, internal: Bool, isDynamic: Bool};

typedef ShaderNodeDef = {
	expr: TExpr,
	inVars: Array<ShaderNodeDefInVar>, // If internal = true, don't show input in ui
	outVars: Array<ShaderNodeDefOutVar>,
	externVars: Array<TVar>, // other external variables like globals and stuff
	inits: Array<{variable: TVar, value: Dynamic}>, // Default values for some variables
	?__inits__: Array<{name: String, e:TExpr}>,
	?functions: Array<TFunction>,
};

/** Serialized connection between an output of one node and an input of another. **/
typedef Edge = {
	?outputNodeId : Int,
	?nameOutput : String, // Fallback if name has changed
	?outputId : Int,
	?inputNodeId : Int,
	?nameInput : String, // Fallback if name has changed
	?inputId : Int,
};

/** Runtime connection stored on the input side: the source node and which of its outputs is used. **/
typedef Connection = {
	from : ShaderNode,
	outputId : Int,
};

typedef Parameter = {
	name : String,
	type : Type,
	defaultValue : Dynamic,
	?id : Int,
	?variable : TVar,
	?internal: Bool,
	index : Int
};

enum Domain {
	Vertex;
	Fragment;
}

/** A global/extern variable collected during generation, with its optional default value and init expression. **/
@:structInit @:publicFields
class ExternVarDef {
	var v: TVar;
	var defValue: Dynamic;
	var __init__: TExpr;
	var paramIndex: Null<Int> = null;
}

/**
	Generates the hxsl expression of one Graph: the nodes are topologically sorted, then each node's
	output expressions are forwarded to the inputs of the nodes connected to it.
**/
@:access(hrt.shgraph.Graph)
class ShaderGraphGenContext {
	var graph : Graph;
	var includePreviews : Bool;

	public function new(graph: Graph, includePreviews: Bool = false) {
		this.graph = graph;
		this.includePreviews = includePreviews;
	}

	// Indexed by node id: the expressions received on each input, and for each output the
	// list of (target node id, target input id) it feeds.
	var nodes : Array<{
		var outputs: Array<Array<{to: Int, input: Int}>>;
		var inputs : Array<TExpr>;
		var node : ShaderNode;
	}>;

	var inputNodes : Array<Int> = [];

	/** Rebuilds the per-node bookkeeping from the graph's current nodes, indexed by node id. **/
	public function initNodes() {
		nodes = [];
		for (id => node in graph.nodes) {
			nodes[id] = {node: node, inputs : [], outputs : []};
		}
	}

	/**
		Generates the whole graph as a single TBlock expression.
		@param genContext context to generate into. A new one for the graph's domain is created when null
		@return the generated block, of type TVoid
		@throws String if the graph contains a cycle, or if a node leaves one of its outputs null
	**/
	public function generate(?genContext: NodeGenContext) : TExpr {
		initNodes();
		var sortedNodes = sortGraph();
		if (sortedNodes == null)
			throw "Shader graph contains a cycle";

		genContext = genContext ?? new NodeGenContext(graph.domain);
		var expressions : Array<TExpr> = [];
		genContext.expressions = expressions;

		for (nodeId in sortedNodes) {
			var node = nodes[nodeId];
			genContext.initForNode(node.node, node.inputs);

			node.node.generate(genContext);

			// Forward every produced output to the inputs connected to it
			for (outputId => expr in genContext.outputs) {
				if (expr == null) throw "null expr for output " + outputId;
				var targets = node.outputs[outputId];
				if (targets == null) continue;
				for (target in targets) {
					nodes[target.to].inputs[target.input] = expr;
				}
			}
			genContext.finishNode();
		}

		// Assign preview color to pixel color as last operation
		var previewColor = genContext.globalVars.get(Variables.Globals[Variables.Global.PreviewColor].name);
		if (previewColor != null) {
			var previewSelect = genContext.getOrAllocateGlobal(PreviewSelect);
			var pixelColor = genContext.getOrAllocateGlobal(PixelColor);
			var assign = makeAssign(makeVar(pixelColor), makeVar(previewColor.v));
			var ifExpr = makeIf(makeBinop(makeVar(previewSelect), OpNotEq, makeInt(0)), assign);
			expressions.push(ifExpr);
		}

		// Attach the shader graph parameters' default values and order to the globals that use them
		for (id => p in graph.parent.parametersAvailable) {
			var global = genContext.globalVars.get(p.name);
			if (global == null)
				continue;
			global.defValue = p.defaultValue;
			global.paramIndex = p.index;
		}

		return AstTools.makeExpr(TBlock(expressions), TVoid);
	}

	/**
		Topologically sorts the nodes so that each node comes after every node feeding its inputs,
		and fills `nodes[id].outputs` with the connection targets along the way.
		@return the sorted node ids, or null if the graph couldn't be sorted (i.e. contains cycles)
	**/
	function sortGraph() : Array<Int> {
		var nodeToExplore : Array<Int> = [];

		var nodeTopology : Array<{to: Array<Int>, incoming: Int}> = [];
		nodeTopology.resize(nodes.length);

		for (id => node in nodes) {
			if (node == null) continue;
			nodeTopology[id] = {to: [], incoming: 0};
		}

		var totalEdges = 0;
		for (id => node in nodes) {
			if (node == null) continue;
			var hasIncoming = false;

			// Todo : store ID of input in connections instead of relying on the "name" at runtime
			for (inputId => connection in node.node.connections) {
				if (connection == null) continue;
				hasIncoming = true;

				var fromId = connection.from.id;
				var outputs = nodes[fromId].outputs;
				if (outputs == null) {
					// Store the very array we push into, otherwise the targets would be lost
					outputs = [];
					nodes[fromId].outputs = outputs;
				}

				var targets = outputs[connection.outputId];
				if (targets == null) {
					targets = [];
					outputs[connection.outputId] = targets;
				}
				targets.push({to: id, input: inputId});

				nodeTopology[fromId].to.push(id);
				nodeTopology[id].incoming++;
				totalEdges++;
			}

			// Nodes without any incoming connection are the starting points of the sort
			if (!hasIncoming) {
				nodeToExplore.push(id);
			}
		}

		var sortedNodes : Array<Int> = [];

		// Kahn's algorithm: emit a node, drop its outgoing edges, and queue the targets left without incoming edges
		while (nodeToExplore.length > 0) {
			var currentNodeId = nodeToExplore.pop();
			sortedNodes.push(currentNodeId);
			for (to in nodeTopology[currentNodeId].to) {
				totalEdges--;
				if (--nodeTopology[to].incoming == 0) {
					nodeToExplore.push(to);
				}
			}
		}

		// Edges never consumed belong to a cycle
		if (totalEdges > 0) {
			return null;
		}

		return sortedNodes;
	}
}

/**
	Shader graph prefab: one Graph per Domain plus the user parameters, compiled into a single hxsl shader.
**/
class ShaderGraph extends hrt.prefab.Prefab {

	var graphs : Array<Graph> = [];

	var cachedDef : hrt.prefab.Cache.ShaderDef = null;

	static var _ = hrt.prefab.Prefab.register("shgraph", hrt.shgraph.ShaderGraph, "shgraph");

	/** Loads the parameters, then one graph per Domain constructor (read from the json field of the same name, if present). **/
	override public function load(json : Dynamic) : Void {
		super.load(json);
		graphs = [];
		parametersAvailable = [];
		parametersKeys = [];

		loadParameters(json.parameters ?? []);
		for (domain in haxe.EnumTools.getConstructors(Domain)) {
			var graph = new Graph(this, haxe.EnumTools.createByName(Domain, domain));
			var graphJson = Reflect.getProperty(json, domain);
			if (graphJson != null) {
				graph.load(graphJson);
			}

			graphs.push(graph);
		}
	}

	override public function copy(other: hrt.prefab.Prefab) : Void {
		throw "Shadergraph is not meant to be put in a prefab tree. Use a dynamic shader that references this shadergraph instead";
	}

	/** Serializes the parameters (type stored as [constructor name, stringified constructor params]) and each graph under its domain name. **/
	override function save() {
		var json = super.save();

		json.parameters = [
			for (p in parametersAvailable) {
				id : p.id,
				name : p.name,
				type : [p.type.getName(), p.type.getParameters().toString()],
				defaultValue : p.defaultValue,
				index : p.index,
				internal : p.internal
			}
		];

		for (graph in graphs) {
			var serName = EnumValueTools.getName(graph.domain);
			Reflect.setField(json, serName, graph.saveToDynamic());
		}

		return json;
	}

	public function saveToText() : String {
		return haxe.Json.stringify(save(), "\t");
	}

	/**
		Resolves the type shared by the inputs flagged `isDynamic`.
		Starts from TFloat; any vector wins over floats, and between two vectors the smallest size is kept.
		@param inputTypes type connected to each input (null entries are ignored)
		@param inVars the node's input definitions, matched by index with inputTypes
		@throws String if there are more types than inputs, vectors of different kinds are mixed,
			or a dynamic input receives something other than a float or vector
	**/
	static public function resolveDynamicType(inputTypes: Array<Type>, inVars: Array<ShaderNodeDefInVar>) : Type {
		var dynamicType : Type = TFloat;
		for (i => t in inputTypes) {
			var targetInput = inVars[i];
			if (targetInput == null)
				throw "More input types than inputs";
			if (!targetInput.isDynamic)
				continue; // Skip variables not marked as dynamic
			switch (t) {
				case null:
				case TFloat:
					if (dynamicType == null)
						dynamicType = TFloat;
				case TVec(size, t1):
					switch(dynamicType) {
						case TFloat, null:
							dynamicType = t;
						case TVec(size2, t2):
							if (t1 != t2)
								throw "Incompatible vectors types";
							// Keep the smallest vector size
							dynamicType = TVec(size < size2 ? size : size2, t1);
						default:
					}
				default:
					throw "Type " + t + " is incompatible with Dynamic";
			}
		}
		return dynamicType;
	}

	/**
		Compiles the graphs into one hxsl shader, plus the default values of its Param variables.
		Outside the editor the first result is cached and returned on every later call, whatever previewDomain is.
		@param previewDomain when set, only the graph of that domain is compiled, always as a fragment
			function, and unknown globals are turned into locals with dummy values so the preview shader can compile
		@throws String when a preview global can't be default-initialized, or a default value targets a non-Param variable
	**/
	public function compile(?previewDomain: Domain) : hrt.prefab.Cache.ShaderDef {
		#if !editor
		if (cachedDef != null)
			return cachedDef;
		#end

		var inits : Array<{variable: TVar, value: Dynamic}> = [];

		var shaderData : ShaderData = {
			name: this.shared.currentPath ?? "",
			vars: [],
			funs: [],
		};

		// One context is shared by every compiled graph; its globalVars are collected once below
		var nodeGen = new NodeGenContext(Vertex);
		nodeGen.previewDomain = previewDomain;

		for (graph in graphs) {
			if (previewDomain != null && previewDomain != graph.domain)
				continue;
			nodeGen.domain = graph.domain;
			var ctx = new ShaderGraphGenContext(graph);
			var gen = ctx.generate(nodeGen);

			var fnKind : FunctionKind = switch (previewDomain != null ? Fragment : graph.domain) {
				case Fragment: Fragment;
				case Vertex: Vertex;
			};

			var functionName : String = EnumValueTools.getName(fnKind).toLowerCase();

			var funcVar : TVar = {
				name : functionName,
				id : hxsl.Tools.allocVarId(),
				kind : Function,
				type : TFun([{ ret : TVoid, args : [] }])
			};

			var fn : TFunction = {
				ret: TVoid,
				kind: fnKind,
				ref: funcVar,
				expr: gen,
				args: [],
			};

			shaderData.funs.push(fn);
			shaderData.vars.push(funcVar);
		}

		// Declare the externs ordered by parameter index; non-parameters (-1) come first
		var externs = [for (v in nodeGen.globalVars) v];
		var __init__exprs : Array<TExpr> = [];
		externs.sort((a, b) -> Reflect.compare(a.paramIndex ?? -1, b.paramIndex ?? -1));

		for (v in externs) {
			// Patch unknow global variables to be locals instead with a dummy value
			// so the preview shader doesn't crash
			if (previewDomain != null && v.paramIndex == null) {
				var fullName = AstTools.getFullName(v.v);
				if (Variables.getGlobalNameMap().get(fullName) == null) {
					AstTools.removeFromParent(v.v);
					v.v.name = StringTools.replace(fullName, ".", "_") + "_SG";
					v.v.kind = Local;
					var expr = switch (v.v.type) {
						case TInt:
							AstTools.makeInt(0);
						case TFloat:
							AstTools.makeFloat(0.0);
						case TVec(size, VFloat):
							AstTools.makeVec([for (i in 0...size) 0.0]);
						case TMat3:
							AstTools.makeGlobalCall(Mat3, [
								AstTools.makeVec([1.0,0.0,0.0]),
								AstTools.makeVec([0.0,1.0,0.0]),
								AstTools.makeVec([0.0,0.0,1.0]),
							], TMat3);
						case TMat4:
							AstTools.makeGlobalCall(Mat4, [
								AstTools.makeVec([1.0,0.0,0.0,0.0]),
								AstTools.makeVec([0.0,1.0,0.0,0.0]),
								AstTools.makeVec([0.0,0.0,1.0,0.0]),
								AstTools.makeVec([0.0,0.0,0.0,1.0]),
							], TMat4);
						case TChannel(_), TSampler(T2D, false):
							// Textures are redirected to a global named "blackChannel" instead of a local
							v.v.name = "blackChannel";
							v.v.kind = Global;
							null;
						default:
							throw 'Can not default initialize global vaiable $fullName in preview shader (type ${v.v.type})';
					}
					if (expr != null)
						v.__init__ = AstTools.makeAssign(AstTools.makeVar(v.v), expr);
				}
			}

			// Only root variables are declared; child fields come with their parent
			if (v.v.parent == null) {
				shaderData.vars.push(v.v);
			}

			if (v.defValue != null) {
				switch (v.v.kind) {
					case Param:
						inits.push({variable: v.v, value: v.defValue});
					default:
						throw "unsupported default value for variable kind";
				}
			}

			if (v.__init__ != null) {
				__init__exprs.push(v.__init__);
			}
		}

		if (__init__exprs.length != 0) {
			var funcVar : TVar = {
				name : "__init__",
				id : hxsl.Tools.allocVarId(),
				kind : Function,
				type : TFun([{ ret : TVoid, args : [] }])
			};

			var fn : TFunction = {
				ret : TVoid,
				kind : Init,
				ref : funcVar,
				expr : makeExpr(TBlock(__init__exprs), TVoid),
				args : []
			};

			shaderData.funs.push(fn);
			shaderData.vars.push(funcVar);
		}

		var shared = new SharedShader("");
		@:privateAccess shared.data = shaderData;
		@:privateAccess shared.initialize();

		cachedDef = {shader : shared, inits: inits};
		return cachedDef;
	}

	/** Compiles the full shader (no preview) and returns a DynamicShader with the parameters' default values applied. **/
	public function makeShaderInstance() : hxsl.DynamicShader {
		var def = compile(null);
		var s = new hxsl.DynamicShader(def.shader);
		for (init in def.inits)
			setParamValue(s, init.variable, init.value);
		return s;
	}

	/**
		Sets a parameter on a DynamicShader, converting the serialized value for samplers (texture lookup)
		and vectors (array to h3d.Vector). Errors are deliberately ignored.
	**/
	static function setParamValue(shader : hxsl.DynamicShader, variable : hxsl.Ast.TVar, value : Dynamic) {
		try {
			switch (variable.type) {
				case TSampler(_):
					var t = hrt.impl.TextureType.Utils.getTextureFromValue(value, Repeat);
					shader.setParamValue(variable, t);
				case TVec(size, _):
					shader.setParamValue(variable, h3d.Vector.fromArray(value));
				default:
					shader.setParamValue(variable, value);
			}
		} catch (e : Dynamic) {
			// The parameter is not used
		}
	}

	var allParameters = [];
	var current_param_id = 0;
	public var parametersAvailable : Map<Int, Parameter> = [];
	public var parametersKeys : Array<Int> = [];

	/** Creates the Param TVar backing a shader graph parameter (id 0, no parent). **/
	function generateParameter(name : String, type : Type) : TVar {
		return {
			parent: null,
			id: 0,
			kind: Param,
			name: name,
			type: type
		};
	}

	public function getParameter(id : Int) {
		return parametersAvailable.get(id);
	}

	/**
		Adds a parameter named "Param_<id>" with no default value, placed after the existing ones.
		@return the id of the new parameter
	**/
	public function addParameter(type : Type) {
		var name = "Param_" + current_param_id;
		parametersAvailable.set(current_param_id, {
			id: current_param_id,
			name : name,
			type : type,
			defaultValue : null,
			variable : generateParameter(name, type),
			index : parametersKeys.length
		});
		parametersKeys.push(current_param_id);
		current_param_id++;
		return current_param_id - 1;
	}

	/**
		Registers serialized parameters, converting their [constructor, params] type arrays back to hxsl Types.
		@throws String if a type constructor is not supported
	**/
	function loadParameters(parameters: Array<Dynamic>) {
		for (p in parameters) {
			var typeString : Array<String> = Reflect.field(p, "type");
			if (Std.isOfType(typeString, Array)) {
				typeString[1] = typeString[1] ?? "";
				var enumParamsString = typeString[1].split(",");

				switch (typeString[0]) {
					case "TSampler2D": // Legacy parameters conversion
						p.type = Type.TSampler(T2D, false);
					case "TSampler":
						var params : Array<Dynamic> = [std.Type.createEnum(TexDimension, enumParamsString[0] ?? "T2D"), enumParamsString[1] == "true"];
						p.type = std.Type.createEnum(Type, typeString[0], params);
					case "TVec":
						var params : Array<Dynamic> = [Std.parseInt(enumParamsString[0]), std.Type.createEnum(VecType, enumParamsString[1])];
						p.type = std.Type.createEnum(Type, typeString[0], params);
					case "TFloat":
						p.type = TFloat;
					default:
						throw "Couldn't unserialize type " + typeString[0];
				}
			}
			p.variable = generateParameter(p.name, p.type);
			this.parametersAvailable.set(p.id, p);
			parametersKeys.push(p.id);
			// Keep the counter above every loaded id, whatever the order they are stored in,
			// so addParameter never reuses an existing id
			current_param_id = hxd.Math.imax(current_param_id, p.id + 1);
		}
		checkParameterOrder();
	}

	/** Sorts parametersKeys by each parameter's `index`. **/
	public function checkParameterOrder() {
		parametersKeys.sort((x, y) -> Reflect.compare(parametersAvailable.get(x).index, parametersAvailable.get(y).index));
	}

	/**
		Renames a parameter and regenerates its TVar.
		@return false if the parameter doesn't exist, newName is null, or any parameter already has that name
	**/
	public function setParameterTitle(id : Int, newName : String) : Bool {
		var p = parametersAvailable.get(id);
		if (p == null || newName == null)
			return false;
		for (other in parametersAvailable) {
			if (other.name == newName)
				return false;
		}
		p.name = newName;
		p.variable = generateParameter(newName, p.type);
		return true;
	}

	/** @return false if no parameter has this id **/
	public function setParameterDefaultValue(id : Int, newDefaultValue : Dynamic) : Bool {
		var p = parametersAvailable.get(id);
		if (p == null)
			return false;
		p.defaultValue = newDefaultValue;
		return true;
	}

	/** Removes a parameter and re-indexes the remaining ones. **/
	public function removeParameter(id : Int) {
		parametersAvailable.remove(id);
		parametersKeys.remove(id);
		checkParameterIndex();
	}

	/** Sets each parameter's `index` to its position in parametersKeys. **/
	public function checkParameterIndex() {
		for (index => key in parametersKeys) {
			parametersAvailable.get(key).index = index;
		}
	}

	public function getGraph(domain: Domain) {
		return graphs[domain.getIndex()];
	}
}

/** Nodes and connections of one shader Domain. Connections are stored on the input side of each node. **/
class Graph {

	var cachedGen : ShaderNodeDef = null;
	var allParamDefaultValue = [];
	var current_node_id = 0;
	var nodes : Map<Int, ShaderNode> = [];

	public var parent : ShaderGraph = null;

	public var domain : Domain = Fragment;

	public function new(parent: ShaderGraph, domain: Domain) {
		this.parent = parent;
		this.domain = domain;
	}

	public function load(json : Dynamic) {
		nodes = [];
		generate(Reflect.getProperty(json, "nodes"), Reflect.getProperty(json, "edges"));
	}

	/**
		Instantiates the serialized nodes, then adds the edges without cycle checking.
		Edges using the legacy `idInput`/`idOutput` fields are migrated to `inputNodeId`/`outputNodeId`.
	**/
	public function generate(nodes : Array<Dynamic>, edges : Array<Edge>) {
		current_node_id = 0;
		for (n in nodes) {
			var node = ShaderNode.createFromDynamic(n, parent);
			this.nodes.set(node.id, node);
			current_node_id = hxd.Math.imax(current_node_id, node.id + 1);
		}

		// Migration patch
		for (e in edges) {
			if (e.inputNodeId == null)
				e.inputNodeId = (e:Dynamic).idInput;
			if (e.outputNodeId == null)
				e.outputNodeId = (e:Dynamic).idOutput;
		}

		for (e in edges) {
			addEdge(e, false);
		}
	}

	/**
		Checks, without modifying the graph, whether `edge` could be added: the slot types must be
		compatible and the edge must not create a cycle. Uses the edge's ids only (no name fallback).
	**/
	public function canAddEdge(edge : Edge) {
		var node = this.nodes.get(edge.inputNodeId);
		var output = this.nodes.get(edge.outputNodeId);

		var inputs = node.getInputs();
		var outputs = output.getOutputs();

		var inputType = inputs[edge.inputId].type;
		var outputType = outputs[edge.outputId].type;

		if (!areTypesCompatible(inputType, outputType)) {
			return false;
		}

		// Walks the connections upstream; reaching a node already on the current path means a cycle
		function hasCycle(node: ShaderNode, ?visited: Map<ShaderNode, Bool>) : Bool {
			var visited = visited?.copy() ?? [];
			if (visited.get(node) != null) {
				return true;
			}
			visited.set(node, true);
			for (id => conn in node.connections) {
				if (conn != null) {
					if (hasCycle(conn.from, visited))
						return true;
				}
			}
			return false;
		}

		// Temporarily connect the edge to test it, then restore the previous connection
		var prev = node.connections[edge.inputId];
		node.connections[edge.inputId] = {from: output, outputId: edge.outputId};
		var res = hasCycle(node);
		node.connections[edge.inputId] = prev;

		return !res;
	}

	/**
		Connects an output to an input. When the stored id doesn't match the saved slot name,
		the slot with that name is used instead.
		In editor builds the edge is rolled back (and false returned) if it creates a cycle (when checkCycles),
		if a slot couldn't be resolved, or if the slot types are incompatible.
		@return true if the edge was kept
	**/
	public function addEdge(edge : Edge, checkCycles: Bool = true) : Bool {
		var node = this.nodes.get(edge.inputNodeId);
		var output = this.nodes.get(edge.outputNodeId);

		var inputs = node.getInputs();
		var outputs = output.getOutputs();

		// Check if there is an output with that id and if it has the same name
		// else try to find the id of another output with the same name
		var outputId = edge.outputId;
		var outputSlot = outputs[outputId];
		if (outputSlot == null || outputSlot.name != edge.nameOutput) {
			for (id => o in outputs) {
				if (o.name == edge.nameOutput) {
					outputId = id;
					break;
				}
			}
		}

		var inputId = edge.inputId;
		var inputSlot = inputs[inputId];
		if (inputSlot == null || inputSlot.name != edge.nameInput) {
			for (id => i in inputs) {
				if (i.name == edge.nameInput) {
					inputId = id;
					break;
				}
			}
		}

		node.connections[inputId] = {from: output, outputId: outputId};

		#if editor
		if (checkCycles && hasCycle()) {
			removeEdge(edge.inputNodeId, inputId, false);
			return false;
		}

		if (inputs[inputId] == null || outputs[outputId] == null
			|| !areTypesCompatible(inputs[inputId].type, outputs[outputId].type)) {
			removeEdge(edge.inputNodeId, inputId);
			return false;
		}
		#end

		return true;
	}

	public function addNode(shNode : ShaderNode) {
		this.nodes.set(shNode.id, shNode);
	}

	/** Floats and generics can be connected to each other; any other type must be strictly equal. **/
	public function areTypesCompatible(input: SgType, output: SgType) : Bool {
		return switch (input) {
			case SgFloat(_), SgGeneric(_, _):
				switch (output) {
					case SgFloat(_), SgGeneric(_, _): true;
					default: false;
				}
			default:
				haxe.EnumTools.EnumValueTools.equals(input, output);
		}
	}

	/** Disconnects an input of a node. `update` is currently unused. **/
	public function removeEdge(idNode, inputId, update = true) {
		var node = this.nodes.get(idNode);
		if (node.connections[inputId] == null) return;

		node.connections[inputId] = null;
	}

	public function setPosition(idNode : Int, x : Float, y : Float) {
		var node = this.nodes.get(idNode);
		node.x = x;
		node.y = y;
	}

	public function getNodes() {
		return this.nodes;
	}

	public function getNode(id : Int) {
		return this.nodes.get(id);
	}

	public function getParameter(id : Int) {
		return parent.getParameter(id);
	}

	/** @return true if the graph can't be topologically sorted **/
	public function hasCycle() : Bool {
		var ctx = new ShaderGraphGenContext(this, false);
		@:privateAccess ctx.initNodes();
		var res = @:privateAccess ctx.sortGraph();
		return res == null;
	}

	public function removeNode(idNode : Int) {
		this.nodes.remove(idNode);
	}

	/** Serializes the nodes and one edge per connected input, storing both the slot ids and the slot names. **/
	public function saveToDynamic() : Dynamic {
		var edgesJson : Array<Edge> = [];
		for (n in nodes) {
			for (inputId => connection in n.connections) {
				if (connection == null)
					continue;
				var outputId = connection.outputId;
				edgesJson.push({
					outputNodeId: connection.from.id,
					nameOutput: connection.from.getOutputs()[outputId].name,
					inputNodeId: n.id,
					nameInput: n.getInputs()[inputId].name,
					inputId: inputId,
					outputId: outputId
				});
			}
		}
		var json = {
			nodes: [
				for (n in nodes) n.serializeToDynamic(),
			],
			edges: edgesJson
		};

		return json;
	}
}