ShaderGraph.hx 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898
  1. package hrt.shgraph;
  2. import hxsl.SharedShader;
  3. using hxsl.Ast;
  4. using haxe.EnumTools.EnumValueTools;
  5. using Lambda;
  6. import hrt.shgraph.AstTools.*;
  7. import hrt.shgraph.SgHxslVar.ShaderDefInput;
  8. enum abstract AngleUnit(String) {
  9. var Radian;
  10. var Degree;
  11. }
  12. final angleUnits = [Radian, Degree];
  13. #if editor
  14. function getAngleUnitDropdown(self: Dynamic, width: Float) : hide.Element {
  15. var element = new hide.Element('<div style="width: ${width * 0.8}px; height: 40px"></div>');
  16. element.append('<span>Unit</span>');
  17. element.append(new hide.Element('<select id="unit"></select>'));
  18. if (self.unit == null) {
  19. self.unit = angleUnits[0];
  20. }
  21. var input = element.children("#unit");
  22. var indexOption = 0;
  23. for (i => curAngle in angleUnits) {
  24. input.append(new hide.Element('<option value="${i}">${curAngle}</option>'));
  25. if (self.unit == curAngle) {
  26. input.val(i);
  27. }
  28. indexOption++;
  29. }
  30. input.on("change", function(e) {
  31. var value = input.val();
  32. self.unit = angleUnits[value];
  33. });
  34. return element;
  35. }
  36. #end
/**
	Type of a shader graph node port.
	NOTE: constructor order is part of the enum's identity (e.g. `getIndex()`); do not reorder.
**/
enum SgType {
	/** Float port; `sgTypeToType` maps dimension 1 to `TFloat` and any other dimension n to `TVec(n, VFloat)`. **/
	SgFloat(dimension: Int);
	/** Maps to `TSampler(T2D, false)`. **/
	SgSampler;
	SgInt;
	SgBool;
	/**
		All the generics in the same shader node with the same id unify to the
		same type.
		Constraint :
		newType : the type we are trying to constrain. If null, the function should return previousType.
		previousType : the type previously constrained to this generic.
		If both newType and previousType are null, the function should return the most generic type for the constraint.
		return : null if newType can't be constrained, or a type that can fit both the new and previous types.
	**/
	SgGeneric(id: Int, constraint: (newType: Type, previousType: Type) -> Null<Type>);
}
  53. function typeToSgType(t: Type) : SgType {
  54. return switch(t) {
  55. case TFloat:
  56. SgFloat(1);
  57. case TVec(n, VFloat):
  58. SgFloat(n);
  59. case TSampler(T2D, false):
  60. SgSampler;
  61. case TInt:
  62. SgInt;
  63. case TBool:
  64. SgBool;
  65. default:
  66. throw "Unsuported type";
  67. }
  68. }
  69. function sgTypeToType(t: SgType) : Type {
  70. return switch(t) {
  71. case SgBool:
  72. return TBool;
  73. case SgFloat(1):
  74. return TFloat;
  75. case SgFloat(n):
  76. return TVec(n, VFloat);
  77. case SgSampler:
  78. return TSampler(T2D, false);
  79. case SgInt:
  80. return TInt;
  81. case SgGeneric(id, consDtraint):
  82. throw "Can't resolve generic without context";
  83. }
  84. }
  85. function ConstraintFloat(newType: Type, previousType: Type) : Null<Type> {
  86. function getN(type:Type) {
  87. return switch(type) {
  88. case TFloat:
  89. 1;
  90. case TVec(n, VFloat):
  91. n;
  92. case null, _:
  93. null;
  94. };
  95. }
  96. var newN = getN(newType) ?? return (previousType ?? TFloat);
  97. var oldN = getN(previousType) ?? newN;
  98. var maxN = hxd.Math.imax(newN, oldN);
  99. switch (maxN) {
  100. case 1:
  101. return TFloat;
  102. case 2,3,4:
  103. return TVec(maxN, VFloat);
  104. default:
  105. throw "invalid float size " + maxN;
  106. }
  107. }
  108. typedef ShaderNodeDefInVar = {v: TVar, internal: Bool, ?defVal: ShaderDefInput, isDynamic: Bool};
  109. typedef ShaderNodeDefOutVar = {v: TVar, internal: Bool, isDynamic: Bool};
  110. typedef ShaderNodeDef = {
  111. expr: TExpr,
  112. inVars: Array<ShaderNodeDefInVar>, // If internal = true, don't show input in ui
  113. outVars: Array<ShaderNodeDefOutVar>,
  114. externVars: Array<TVar>, // other external variables like globals and stuff
  115. inits: Array<{variable: TVar, value: Dynamic}>, // Default values for some variables
  116. ?__inits__: Array<{name: String, e:TExpr}>,
  117. ?functions: Array<TFunction>,
  118. };
  119. typedef Edge = {
  120. ?outputNodeId : Int,
  121. ?nameOutput : String, // Fallback if name has changed
  122. ?outputId : Int,
  123. ?inputNodeId : Int,
  124. ?nameInput : String, // Fallback if name has changed
  125. ?inputId : Int,
  126. };
  127. typedef Connection = {
  128. from : ShaderNode,
  129. outputId : Int,
  130. };
  131. typedef Parameter = {
  132. name : String,
  133. type : Type,
  134. defaultValue : Dynamic,
  135. ?id : Int,
  136. ?variable : TVar,
  137. ?internal: Bool,
  138. index : Int
  139. };
/**
	Shader stage a graph belongs to.
	Constructor order matters: `ShaderGraph.load` creates one graph per constructor in declaration
	order, and `ShaderGraph.getGraph` indexes that list with `getIndex()`.
**/
enum Domain {
	Vertex;
	Fragment;
}
/**
	A global/extern variable collected while generating the graphs.
	For graph parameters, `defValue` and `paramIndex` are filled from the parameter list.
	`__init__`, when set, is emitted into the shader's `__init__` function.
**/
@:structInit @:publicFields
class
ExternVarDef {
	var v: TVar;
	var defValue: Dynamic;
	var __init__: TExpr;
	var paramIndex: Null<Int> = null; // stays null for globals that are not graph parameters
}
  152. @:access(hrt.shgraph.Graph)
  153. class ShaderGraphGenContext {
  154. var graph : Graph;
  155. var includePreviews : Bool;
  156. public function new(graph: Graph, includePreviews: Bool = false) {
  157. this.graph = graph;
  158. this.includePreviews = includePreviews;
  159. }
  160. var nodes : Array<{
  161. var outputs: Array<Array<{to: Int, input: Int}>>;
  162. var inputs : Array<TExpr>;
  163. var node : ShaderNode;
  164. }>;
  165. var inputNodes : Array<Int> = [];
  166. public function initNodes() {
  167. nodes = [];
  168. for (id => node in graph.nodes) {
  169. nodes[id] = {node: node, inputs : [], outputs : []};
  170. }
  171. }
  172. public function generate(?genContext: NodeGenContext) : TExpr {
  173. initNodes();
  174. var sortedNodes = sortGraph();
  175. genContext = genContext ?? new NodeGenContext(graph.domain);
  176. var expressions : Array<TExpr> = [];
  177. genContext.expressions = expressions;
  178. for (nodeId in sortedNodes) {
  179. var node = nodes[nodeId];
  180. genContext.initForNode(node.node, node.inputs);
  181. node.node.generate(genContext);
  182. for (outputId => expr in genContext.outputs) {
  183. if (expr == null) throw "null expr for output " + outputId;
  184. var targets = node.outputs[outputId];
  185. if (targets == null) continue;
  186. for (target in targets) {
  187. nodes[target.to].inputs[target.input] = expr;
  188. }
  189. }
  190. genContext.finishNode();
  191. }
  192. // Assign preview color to pixel color as last operation
  193. var previewColor = genContext.globalVars.get(Variables.Globals[Variables.Global.PreviewColor].name);
  194. if (previewColor != null) {
  195. var previewSelect = genContext.getOrAllocateGlobal(PreviewSelect);
  196. var pixelColor = genContext.getOrAllocateGlobal(PixelColor);
  197. var assign = makeAssign(makeVar(pixelColor), makeVar(previewColor.v));
  198. var ifExpr = makeIf(makeBinop(makeVar(previewSelect), OpNotEq, makeInt(0)), assign);
  199. expressions.push(ifExpr);
  200. }
  201. for (id => p in graph.parent.parametersAvailable) {
  202. var global = genContext.globalVars.get(p.name);
  203. if (global == null)
  204. continue;
  205. global.defValue = p.defaultValue;
  206. global.paramIndex = p.index;
  207. }
  208. return AstTools.makeExpr(TBlock(expressions), TVoid);
  209. }
  210. // returns null if the graph couldn't be sorted (i.e. contains cycles)
  211. function sortGraph() : Array<Int>
  212. {
  213. // Topological sort all the nodes from input to ouputs
  214. var nodeToExplore : Array<Int> = [];
  215. var nodeTopology : Array<{to: Array<Int>, incoming: Int}> = [];
  216. nodeTopology.resize(nodes.length);
  217. for (id => node in nodes) {
  218. if (node == null) continue;
  219. nodeTopology[id] = {to: [], incoming: 0};
  220. }
  221. var totalEdges = 0;
  222. for (id => node in nodes) {
  223. if (node == null) continue;
  224. var inst = node.node;
  225. var empty = true;
  226. var inputs = inst.getInputs();
  227. // Todo : store ID of input in connections instead of relying on the "name" at runtime
  228. for (inputId => connection in inst.connections) {
  229. if (connection == null)
  230. continue;
  231. empty = false;
  232. var nodeOutputs = connection.from.getOutputs();
  233. var outputs = nodes[connection.from.id].outputs;
  234. if (outputs == null) {
  235. outputs = [];
  236. nodes[connection.from.id].outputs = [];
  237. }
  238. var outputId = connection.outputId;
  239. var output = outputs[outputId];
  240. if (output == null) {
  241. output = [];
  242. outputs[outputId] = output;
  243. }
  244. output.push({to: id, input: inputId});
  245. nodeTopology[connection.from.id].to.push(id);
  246. nodeTopology[id].incoming ++;
  247. totalEdges++;
  248. }
  249. for (inputId => input in inputs) {
  250. }
  251. if (empty) {
  252. nodeToExplore.push(id);
  253. }
  254. }
  255. var sortedNodes : Array<Int> = [];
  256. // Perform the sort
  257. while (nodeToExplore.length > 0) {
  258. var currentNodeId = nodeToExplore.pop();
  259. sortedNodes.push(currentNodeId);
  260. var currentTopology = nodeTopology[currentNodeId];
  261. for (to in currentTopology.to) {
  262. var remaining = --nodeTopology[to].incoming;
  263. totalEdges --;
  264. if (remaining == 0) {
  265. nodeToExplore.push(to);
  266. }
  267. }
  268. }
  269. if (totalEdges > 0) {
  270. return null;
  271. }
  272. return sortedNodes;
  273. }
  274. }
  275. class ShaderGraph extends hrt.prefab.Prefab {
  276. var graphs : Array<Graph> = [];
  277. var cachedDef : hrt.prefab.Cache.ShaderDef = null;
  278. static var _ = hrt.prefab.Prefab.register("shgraph", hrt.shgraph.ShaderGraph, "shgraph");
  279. override public function load(json : Dynamic) : Void {
  280. super.load(json);
  281. graphs = [];
  282. parametersAvailable = [];
  283. parametersKeys = [];
  284. loadParameters(json.parameters ?? []);
  285. for (domain in haxe.EnumTools.getConstructors(Domain)) {
  286. var graph = new Graph(this, haxe.EnumTools.createByName(Domain, domain));
  287. var graphJson = Reflect.getProperty(json, domain);
  288. if (graphJson != null) {
  289. graph.load(graphJson);
  290. }
  291. graphs.push(graph);
  292. }
  293. }
  294. override public function copy(other: hrt.prefab.Prefab) : Void {
  295. throw "Shadergraph is not meant to be put in a prefab tree. Use a dynamic shader that references this shadergraph instead";
  296. }
  297. override function save() {
  298. var json = super.save();
  299. json.parameters = [
  300. for (p in parametersAvailable) { id : p.id, name : p.name, type : [p.type.getName(), p.type.getParameters().toString()], defaultValue : p.defaultValue, index : p.index, internal : p.internal }
  301. ];
  302. for (graph in graphs) {
  303. var serName = EnumValueTools.getName(graph.domain);
  304. Reflect.setField(json, serName, graph.saveToDynamic());
  305. }
  306. return json;
  307. }
  308. public function saveToText() : String {
  309. return haxe.Json.stringify(save(), "\t");
  310. }
  311. static public function resolveDynamicType(inputTypes: Array<Type>, inVars: Array<ShaderNodeDefInVar>) : Type {
  312. var dynamicType : Type = TFloat;
  313. for (i => t in inputTypes) {
  314. var targetInput = inVars[i];
  315. if (targetInput == null)
  316. throw "More input types than inputs";
  317. if (!targetInput.isDynamic)
  318. continue; // Skip variables not marked as dynamic
  319. switch (t) {
  320. case null:
  321. case TFloat:
  322. if (dynamicType == null)
  323. dynamicType = TFloat;
  324. case TVec(size, t1): // Vec2 always convert to it because it's the smallest vec type
  325. switch(dynamicType) {
  326. case TFloat, null:
  327. dynamicType = t;
  328. case TVec(size2, t2):
  329. if (t1 != t2)
  330. throw "Incompatible vectors types";
  331. dynamicType = TVec(size < size2 ? size : size2, t1);
  332. default:
  333. }
  334. default:
  335. throw "Type " + t + " is incompatible with Dynamic";
  336. }
  337. }
  338. return dynamicType;
  339. }
  340. public function compile(?previewDomain: Domain) : hrt.prefab.Cache.ShaderDef {
  341. #if !editor
  342. if ( cachedDef != null )
  343. return cachedDef;
  344. #end
  345. var inits : Array<{variable: TVar, value: Dynamic}>= [];
  346. var shaderData : ShaderData = {
  347. name: this.shared.currentPath ?? "",
  348. vars: [],
  349. funs: [],
  350. };
  351. var nodeGen = new NodeGenContext(Vertex);
  352. nodeGen.previewDomain = previewDomain;
  353. for (i => graph in graphs) {
  354. if (previewDomain != null && previewDomain != graph.domain)
  355. continue;
  356. nodeGen.domain = graph.domain;
  357. var ctx = new ShaderGraphGenContext(graph);
  358. var gen = ctx.generate(nodeGen);
  359. var fnKind : FunctionKind = switch(previewDomain != null ? Fragment : graph.domain) {
  360. case Fragment: Fragment;
  361. case Vertex: Vertex;
  362. };
  363. var functionName : String = EnumValueTools.getName(fnKind).toLowerCase();
  364. var funcVar : TVar = {
  365. name : functionName,
  366. id : hxsl.Tools.allocVarId(),
  367. kind : Function,
  368. type : TFun([{ ret : TVoid, args : [] }])
  369. };
  370. var fn : TFunction = {
  371. ret: TVoid, kind: fnKind,
  372. ref: funcVar,
  373. expr: gen,
  374. args: [],
  375. };
  376. shaderData.funs.push(fn);
  377. shaderData.vars.push(funcVar);
  378. }
  379. var externs = [for (v in nodeGen.globalVars) v];
  380. var __init__exprs : Array<TExpr>= [];
  381. externs.sort((a,b) -> Reflect.compare(a.paramIndex ?? -1, b.paramIndex ?? -1));
  382. for (v in externs) {
  383. // Patch unknow global variables to be locals instead with a dummy value
  384. // so the preview shader doesn't crash
  385. if (previewDomain != null && v.paramIndex == null) {
  386. var fullName = AstTools.getFullName(v.v);
  387. if (Variables.getGlobalNameMap().get(fullName) == null) {
  388. AstTools.removeFromParent(v.v);
  389. v.v.name = StringTools.replace(fullName, ".", "_") + "_SG";
  390. v.v.kind = Local;
  391. var expr = switch (v.v.type) {
  392. case TInt:
  393. AstTools.makeInt(0);
  394. case TFloat:
  395. AstTools.makeFloat(0.0);
  396. case TVec(size, VFloat):
  397. AstTools.makeVec([for (i in 0...size) 0.0]);
  398. case TMat3:
  399. AstTools.makeGlobalCall(Mat3, [
  400. AstTools.makeVec([1.0,0.0,0.0]),
  401. AstTools.makeVec([0.0,1.0,0.0]),
  402. AstTools.makeVec([0.0,0.0,1.0]),
  403. ], TMat3);
  404. case TMat4:
  405. AstTools.makeGlobalCall(Mat4, [
  406. AstTools.makeVec([1.0,0.0,0.0,0.0]),
  407. AstTools.makeVec([0.0,1.0,0.0,0.0]),
  408. AstTools.makeVec([0.0,0.0,1.0,0.0]),
  409. AstTools.makeVec([0.0,0.0,0.0,1.0]),
  410. ], TMat4);
  411. case TChannel(_), TSampler(T2D, false):
  412. v.v.name = "blackChannel";
  413. v.v.kind = Global;
  414. null;
  415. default:
  416. throw 'Can not default initialize global vaiable $fullName in preview shader (type ${v.v.type})';
  417. }
  418. if (expr != null)
  419. v.__init__ = AstTools.makeAssign(AstTools.makeVar(v.v), expr);
  420. }
  421. }
  422. if (v.v.parent == null) {
  423. shaderData.vars.push(v.v);
  424. }
  425. if (v.defValue != null) {
  426. switch(v.v.kind) {
  427. case Param:
  428. inits.push({variable:v.v, value:v.defValue});
  429. default:
  430. throw "unsupported default value for variable kind";
  431. }
  432. }
  433. if (v.__init__ != null) {
  434. __init__exprs.push(v.__init__);
  435. }
  436. }
  437. if (__init__exprs.length != 0) {
  438. var funcVar : TVar = {
  439. name : "__init__",
  440. id : hxsl.Tools.allocVarId(),
  441. kind : Function,
  442. type : TFun([{ ret : TVoid, args : [] }])
  443. };
  444. var fn : TFunction = {
  445. ret : TVoid, kind : Init,
  446. ref : funcVar,
  447. expr : makeExpr(TBlock(__init__exprs), TVoid),
  448. args : []
  449. };
  450. shaderData.funs.push(fn);
  451. shaderData.vars.push(funcVar);
  452. }
  453. var shared = new SharedShader("");
  454. @:privateAccess shared.data = shaderData;
  455. @:privateAccess shared.initialize();
  456. cachedDef = {shader : shared, inits: inits}
  457. return cachedDef;
  458. }
  459. public function makeShaderInstance() : hxsl.DynamicShader {
  460. var def = compile(null);
  461. var s = new hxsl.DynamicShader(def.shader);
  462. for (init in def.inits)
  463. setParamValue(s, init.variable, init.value);
  464. return s;
  465. }
  466. static function setParamValue(shader : hxsl.DynamicShader, variable : hxsl.Ast.TVar, value : Dynamic) {
  467. try {
  468. switch (variable.type) {
  469. case TSampler(_):
  470. var t = hrt.impl.TextureType.Utils.getTextureFromValue(value, Repeat);
  471. shader.setParamValue(variable, t);
  472. case TVec(size, _):
  473. shader.setParamValue(variable, h3d.Vector.fromArray(value));
  474. default:
  475. shader.setParamValue(variable, value);
  476. }
  477. } catch (e : Dynamic) {
  478. // The parameter is not used
  479. }
  480. }
  481. var allParameters = [];
  482. var current_param_id = 0;
  483. public var parametersAvailable : Map<Int, Parameter> = [];
  484. public var parametersKeys : Array<Int> = [];
  485. function generateParameter(name : String, type : Type) : TVar {
  486. return {
  487. parent: null,
  488. id: 0,
  489. kind:Param,
  490. name: name,
  491. type: type
  492. };
  493. }
  494. public function getParameter(id : Int) {
  495. return parametersAvailable.get(id);
  496. }
  497. public function addParameter(type : Type) {
  498. var name = "Param_" + current_param_id;
  499. parametersAvailable.set(current_param_id, {id: current_param_id, name : name, type : type, defaultValue : null, variable : generateParameter(name, type), index : parametersKeys.length});
  500. parametersKeys.push(current_param_id);
  501. current_param_id++;
  502. return current_param_id-1;
  503. }
  504. function loadParameters(parameters: Array<Dynamic>) {
  505. for (p in parameters) {
  506. var typeString : Array<Dynamic> = Reflect.field(p, "type");
  507. if (Std.isOfType(typeString, Array)) {
  508. typeString[1] = typeString[1] ?? "";
  509. var enumParamsString = typeString[1].split(",");
  510. switch(typeString[0]) {
  511. case "TSampler2D": // Legacy parameters conversion
  512. p.type = Type.TSampler(T2D, false);
  513. case "TSampler":
  514. var params : Array<Dynamic> = [std.Type.createEnum(TexDimension, enumParamsString[0] ?? "T2D"), enumParamsString[1] == "true"];
  515. p.type = std.Type.createEnum(Type, typeString[0], params);
  516. case "TVec":
  517. var params : Array<Dynamic> = [Std.parseInt(enumParamsString[0]), std.Type.createEnum(VecType, enumParamsString[1])];
  518. p.type = std.Type.createEnum(Type, typeString[0], params);
  519. case "TFloat":
  520. p.type = TFloat;
  521. default:
  522. throw "Couldn't unserialize type " + typeString[0];
  523. }
  524. }
  525. p.variable = generateParameter(p.name, p.type);
  526. this.parametersAvailable.set(p.id, p);
  527. parametersKeys.push(p.id);
  528. current_param_id = p.id + 1;
  529. }
  530. checkParameterOrder();
  531. }
  532. public function checkParameterOrder() {
  533. parametersKeys.sort((x,y) -> Reflect.compare(parametersAvailable.get(x).index, parametersAvailable.get(y).index));
  534. }
  535. public function setParameterTitle(id : Int, newName : String) {
  536. var p = parametersAvailable.get(id);
  537. if (p != null) {
  538. if (newName != null) {
  539. for (p in parametersAvailable) {
  540. if (p.name == newName) {
  541. return false;
  542. }
  543. }
  544. p.name = newName;
  545. p.variable = generateParameter(newName, p.type);
  546. return true;
  547. }
  548. }
  549. return false;
  550. }
  551. public function setParameterDefaultValue(id : Int, newDefaultValue : Dynamic) : Bool {
  552. var p = parametersAvailable.get(id);
  553. if (p != null) {
  554. p.defaultValue = newDefaultValue;
  555. return true;
  556. }
  557. return false;
  558. }
  559. public function removeParameter(id : Int) {
  560. parametersAvailable.remove(id);
  561. parametersKeys.remove(id);
  562. checkParameterIndex();
  563. }
  564. public function checkParameterIndex() {
  565. for (k in parametersKeys) {
  566. var oldParam = parametersAvailable.get(k);
  567. oldParam.index = parametersKeys.indexOf(k);
  568. parametersAvailable.set(k, oldParam);
  569. }
  570. }
  571. public function getGraph(domain: Domain) {
  572. return graphs[domain.getIndex()];
  573. }
  574. }
  575. class Graph {
  576. var cachedGen : ShaderNodeDef = null;
  577. var allParamDefaultValue = [];
  578. var current_node_id = 0;
  579. var nodes : Map<Int, ShaderNode> = [];
  580. public var parent : ShaderGraph = null;
  581. public var domain : Domain = Fragment;
  582. public function new(parent: ShaderGraph, domain: Domain) {
  583. this.parent = parent;
  584. this.domain = domain;
  585. }
  586. public function load(json : Dynamic) {
  587. nodes = [];
  588. generate(Reflect.getProperty(json, "nodes"), Reflect.getProperty(json, "edges"));
  589. }
  590. public function generate(nodes : Array<Dynamic>, edges : Array<Edge>) {
  591. current_node_id = 0;
  592. for (n in nodes) {
  593. var node = ShaderNode.createFromDynamic(n, parent);
  594. this.nodes.set(node.id, node);
  595. current_node_id = hxd.Math.imax(current_node_id, node.id+1);
  596. }
  597. // Migration patch
  598. for (e in edges) {
  599. if (e.inputNodeId == null)
  600. e.inputNodeId = (e:Dynamic).idInput;
  601. if (e.outputNodeId == null)
  602. e.outputNodeId = (e:Dynamic).idOutput;
  603. }
  604. for (e in edges) {
  605. addEdge(e, false);
  606. }
  607. }
  608. public function canAddEdge(edge : Edge) {
  609. var node = this.nodes.get(edge.inputNodeId);
  610. var output = this.nodes.get(edge.outputNodeId);
  611. var inputs = node.getInputs();
  612. var outputs = output.getOutputs();
  613. var inputType = inputs[edge.inputId].type;
  614. var outputType = outputs[edge.outputId].type;
  615. if (!areTypesCompatible(inputType, outputType)) {
  616. return false;
  617. }
  618. function hasCycle(node: ShaderNode, ?visited: Map<ShaderNode, Bool>) : Bool {
  619. var visited = visited?.copy() ?? [];
  620. if (visited.get(node) != null) {
  621. return true;
  622. }
  623. visited.set(node, true);
  624. for (id => conn in node.connections) {
  625. if (conn != null) {
  626. if (hasCycle(conn.from, visited))
  627. return true;
  628. }
  629. }
  630. return false;
  631. }
  632. var prev = node.connections[edge.inputId];
  633. node.connections[edge.inputId] = {from: output, outputId: edge.outputId};
  634. var res = hasCycle(node);
  635. node.connections[edge.inputId] = prev;
  636. if (res)
  637. return false;
  638. return true;
  639. }
  640. public function addEdge(edge : Edge, checkCycles: Bool = true) {
  641. var node = this.nodes.get(edge.inputNodeId);
  642. var output = this.nodes.get(edge.outputNodeId);
  643. var inputs = node.getInputs();
  644. var outputs = output.getOutputs();
  645. var outputId = edge.outputId;
  646. var inputId = edge.inputId;
  647. {
  648. // Check if there is an output with that id and if it has the same name
  649. // else try to find the id of another output with the same name
  650. var output = outputs[outputId];
  651. if (output == null || output.name != edge.nameOutput) {
  652. for (id => o in outputs) {
  653. if (o.name == edge.nameOutput) {
  654. outputId = id;
  655. break;
  656. }
  657. }
  658. };
  659. }
  660. {
  661. var input = inputs[inputId];
  662. if (input == null || input.name != edge.nameInput) {
  663. for (id => i in inputs) {
  664. if (i.name == edge.nameInput) {
  665. inputId = id;
  666. break;
  667. }
  668. }
  669. }
  670. }
  671. node.connections[inputId] = {from: output, outputId: outputId};
  672. #if editor
  673. if (checkCycles && hasCycle()){
  674. removeEdge(edge.inputNodeId, inputId, false);
  675. return false;
  676. }
  677. var inputType = inputs[inputId].type;
  678. var outputType = outputs[outputId].type;
  679. if (!areTypesCompatible(inputType, outputType)) {
  680. removeEdge(edge.inputNodeId, inputId);
  681. }
  682. try {
  683. } catch (e : Dynamic) {
  684. removeEdge(edge.inputNodeId, inputId);
  685. throw e;
  686. }
  687. #end
  688. return true;
  689. }
  690. public function addNode(shNode : ShaderNode) {
  691. this.nodes.set(shNode.id, shNode);
  692. }
  693. public function areTypesCompatible(input: SgType, output: SgType) : Bool {
  694. return switch (input) {
  695. case SgFloat(_):
  696. switch (output) {
  697. case SgFloat(_), SgGeneric(_,_): true;
  698. default: false;
  699. };
  700. case SgGeneric(_, fn):
  701. switch (output) {
  702. case SgFloat(_), SgGeneric(_,_): true;
  703. default: false;
  704. };
  705. default: haxe.EnumTools.EnumValueTools.equals(input, output);
  706. }
  707. }
  708. public function removeEdge(idNode, inputId, update = true) {
  709. var node = this.nodes.get(idNode);
  710. if (node.connections[inputId] == null) return;
  711. node.connections[inputId] = null;
  712. }
  713. public function setPosition(idNode : Int, x : Float, y : Float) {
  714. var node = this.nodes.get(idNode);
  715. node.x = x;
  716. node.y = y;
  717. }
  718. public function getNodes() {
  719. return this.nodes;
  720. }
  721. public function getNode(id : Int) {
  722. return this.nodes.get(id);
  723. }
  724. public function getParameter(id : Int) {
  725. return parent.getParameter(id);
  726. }
  727. public function hasCycle() : Bool {
  728. var ctx = new ShaderGraphGenContext(this, false);
  729. @:privateAccess ctx.initNodes();
  730. var res = @:privateAccess ctx.sortGraph();
  731. return res == null;
  732. }
  733. public function removeNode(idNode : Int) {
  734. this.nodes.remove(idNode);
  735. }
  736. public function saveToDynamic() : Dynamic {
  737. var edgesJson : Array<Edge> = [];
  738. for (n in nodes) {
  739. for (inputId => connection in n.connections) {
  740. if (connection == null) continue;
  741. var outputId = connection.outputId;
  742. edgesJson.push({ outputNodeId: connection.from.id, nameOutput: connection.from.getOutputs()[outputId].name, inputNodeId: n.id, nameInput: n.getInputs()[inputId].name, inputId: inputId, outputId: outputId });
  743. }
  744. }
  745. var json = {
  746. nodes: [
  747. for (n in nodes) n.serializeToDynamic(),
  748. ],
  749. edges: edgesJson
  750. };
  751. return json;
  752. }
  753. }