ShaderGraph.hx 16 KB


  1. package hrt.shgraph;
  2. import hxsl.SharedShader;
  3. using hxsl.Ast;
  /**
      One node of a shader graph, as read from / written to the JSON file.
  **/
  typedef Node = {
      /** Horizontal position (written as an integer by ShaderGraph.save). **/
      var x : Float;
      /** Vertical position (written as an integer by ShaderGraph.save). **/
      var y : Float;
      /** Unique node id inside its graph; edges refer to nodes by this id. **/
      var id : Int;
      /** Fully-qualified class name of the ShaderNode implementation. **/
      var type : String;
      /** Serialized node properties, handed to ShaderNode.loadProperties. **/
      @:optional var properties : Dynamic;
      /** Live ShaderNode instance created for this entry. **/
      @:optional var instance : ShaderNode;
      /** Nodes whose inputs are fed by this node. **/
      @:optional var outputs : Array<Node>;
      /** Scratch counter used by the cycle check. **/
      @:optional var indegree : Int;
  };
  /**
      A connection: output `nameOutput` of node `idOutput` feeds
      input `nameInput` of node `idInput`.
  **/
  private typedef Edge = {
      var idOutput : Int;
      var nameOutput : String;
      var idInput : Int;
      var nameInput : String;
  };
  /**
      A user-declared shader parameter of a graph.
  **/
  typedef Parameter = {
      /** Display / variable name of the parameter. **/
      var name : String;
      /** hxsl type of the parameter. **/
      var type : Type;
      /** Value assigned to the parameter when an instance is created. **/
      var defaultValue : Dynamic;
      /** Key of the parameter in ShaderGraph.parametersAvailable. **/
      @:optional var id : Int;
      /** hxsl variable generated for this parameter. **/
      @:optional var variable : TVar;
      /** Position of the parameter in the user-visible ordering. **/
      var index : Int;
  };
  /**
      A shader graph: ShaderNode instances wired together by named
      input/output edges, plus a set of user-declared parameters.

      The graph is loaded from JSON (`new`, `load`, `generate`), compiled to
      hxsl data (`generateShader`, `compile`, `makeInstance`) and, in editor
      builds, edited and serialized back to JSON (`save`).
  **/
  class ShaderGraph {

      /** Variables declared by the shader being built; reset by generateShader(). **/
      var allVariables : Array<TVar> = [];
      /** Param variables of the last build, index-aligned with allParamDefaultValue. **/
      var allParameters : Array<TVar> = [];
      /** Default value of allParameters[i] (null when no parameter definition is known). **/
      var allParamDefaultValue : Array<Dynamic> = [];
      /** Next node id to hand out; kept above every loaded node id. **/
      var current_node_id = 0;
      /** Next parameter id to hand out; kept above every loaded parameter id. **/
      var current_param_id = 0;
      var filepath : String;
      var nodes : Map<Int, Node> = [];
      public var parametersAvailable : Map<Int, Parameter> = [];
      /** Parameter ids; checkParameterOrder() sorts them by each parameter's `index`. **/
      public var parametersKeys : Array<Int> = [];
      /** True once node output names were prefixed with "sub_<id>_" for subgraph use. **/
      var variableNamesAlreadyUpdated = false;
      /**
          Output variable names whose code goes into the vertex function instead
          of the fragment one. NOTE(review): nothing in this file fills it —
          presumably populated elsewhere; confirm.
      **/
      var variableNameAvailableOnlyInVertex : Array<String> = [];

      /**
          Creates a graph and, when `filepath` is not null, loads it from that file.
          Editor builds read the file relative to `hide.Ide.inst.resourceDir`; other
          builds go through `hxd.res.Loader.currentInstance`. An empty file yields
          an empty graph.
          @throws String when the file cannot be read or its JSON cannot be parsed.
      **/
      public function new(filepath : String) {
          if (filepath == null) return;
          this.filepath = filepath;
          var json : Dynamic;
          try {
              var content : String = null;
              #if editor
              content = sys.io.File.getContent(hide.Ide.inst.resourceDir + "/" + this.filepath);
              #else
              content = hxd.res.Loader.currentInstance.load(this.filepath).toText();
              #end
              if (content == null || content.length == 0) return;
              json = haxe.Json.parse(content);
          } catch( e : Dynamic ) {
              throw "Invalid shader graph parsing ("+e+")";
          }
          load(json);
      }

      /**
          Replaces the whole graph with the content of a parsed JSON object
          (fields `nodes`, `edges`, `parameters`).
      **/
      public function load(json : Dynamic) {
          nodes = [];
          parametersAvailable = [];
          // Keys and id counters must be reset together with the maps: stale keys
          // would make checkParameterOrder() look up parameters that no longer exist.
          parametersKeys = [];
          current_node_id = 0;
          current_param_id = 0;
          generate(Reflect.getProperty(json, "nodes"), Reflect.getProperty(json, "edges"), Reflect.getProperty(json, "parameters"));
      }

      /** Sorts parametersKeys by the `index` of the parameter each key refers to. **/
      public function checkParameterOrder() {
          parametersKeys.sort((a, b) -> Reflect.compare(parametersAvailable.get(a).index, parametersAvailable.get(b).index));
      }

      /**
          Adds deserialized parameters, nodes and edges to the graph.
          Any of the three arrays may be null (older files), which counts as empty.
          A parameter `type` stored as `[constructorName, "arg1,arg2"]` is turned
          back into an hxsl Type (only TVec-like `size, VecType` arguments are parsed).
          @throws String when a node's class cannot be resolved.
      **/
      public function generate(nodes : Array<Node>, edges : Array<Edge>, parameters : Array<Parameter>) {
          if (parameters != null) {
              for (p in parameters) {
                  var typeString : Array<Dynamic> = Reflect.field(p, "type");
                  if (Std.is(typeString, Array)) {
                      var args : String = typeString[1];
                      // Array.toString() encloses the list in [ ] on some targets: tolerate it.
                      if (args != null)
                          args = StringTools.trim(StringTools.replace(StringTools.replace(args, "[", ""), "]", ""));
                      if (args == null || args.length == 0) {
                          p.type = std.Type.createEnum(Type, typeString[0]);
                      } else {
                          var paramsEnum = [for (a in args.split(",")) StringTools.trim(a)];
                          p.type = std.Type.createEnum(Type, typeString[0], [Std.parseInt(paramsEnum[0]), std.Type.createEnum(VecType, paramsEnum[1])]);
                      }
                  }
                  p.variable = generateParameter(p.name, p.type);
                  this.parametersAvailable.set(p.id, p);
                  parametersKeys.push(p.id);
                  // The file is not guaranteed to list ids in ascending order.
                  if (p.id + 1 > current_param_id)
                      current_param_id = p.id + 1;
              }
          }
          checkParameterOrder();
          if (nodes != null) {
              for (n in nodes) {
                  n.outputs = [];
                  var cl = std.Type.resolveClass(n.type);
                  if( cl == null ) throw "Missing shader node "+n.type;
                  n.instance = std.Type.createInstance(cl, []);
                  n.instance.setId(n.id);
                  n.instance.loadProperties(n.properties);
                  this.nodes.set(n.id, n);
                  // Use the highest id, not the last one, so addNode() never reuses an id.
                  if (n.id + 1 > current_node_id)
                      current_node_id = n.id + 1;
                  var shaderParam = Std.downcast(n.instance, ShaderParam);
                  if (shaderParam != null) {
                      var paramShader = getParameter(shaderParam.parameterId);
                      if (paramShader != null) {
                          shaderParam.variable = paramShader.variable;
                          shaderParam.computeOutputs();
                      }
                  }
              }
          }
          if (edges != null) {
              for (e in edges)
                  addEdge(e);
          }
      }

      /**
          Connects output `edge.nameOutput` of node `edge.idOutput` to input
          `edge.nameInput` of node `edge.idInput`, replacing any previous
          connection of that input.
          @return false when the edge was not kept: an endpoint node is unknown,
                  a SubGraph endpoint does not expose the named pin, or (editor
                  builds) the edge would create a cycle.
          @throws whatever recomputing the downstream outputs throws (editor
                  builds); the edge is removed before rethrowing.
      **/
      public function addEdge(edge : Edge) {
          var node = this.nodes.get(edge.idInput);
          var output = this.nodes.get(edge.idOutput);
          // Saved files may reference nodes that no longer exist: skip such edges.
          if (node == null || output == null)
              return false;
          // An input holds one connection: detach the previous source first so its
          // `outputs` list does not keep a stale entry (which breaks hasCycle()).
          if (node.instance.getInput(edge.nameInput) != null)
              removeEdge(edge.idInput, edge.nameInput, false);
          node.instance.setInput(edge.nameInput, new NodeVar(output.instance, edge.nameOutput));
          output.outputs.push(node);
          var subShaderIn = Std.downcast(node.instance, hrt.shgraph.nodes.SubGraph);
          var subShaderOut = Std.downcast(output.instance, hrt.shgraph.nodes.SubGraph);
          if( @:privateAccess ((subShaderIn != null) && !subShaderIn.inputInfoKeys.contains(edge.nameInput))
              || @:privateAccess ((subShaderOut != null) && !subShaderOut.outputInfoKeys.contains(edge.nameOutput))
          ) {
              // The subgraph no longer exposes this pin: drop the edge.
              removeEdge(edge.idInput, edge.nameInput, false);
              return false;
          }
          #if editor
          if (hasCycle()) {
              removeEdge(edge.idInput, edge.nameInput, false);
              return false;
          }
          try {
              updateOutputs(output);
          } catch (e : Dynamic) {
              removeEdge(edge.idInput, edge.nameInput);
              throw e;
          }
          #end
          return true;
      }

      /** Recomputes the outputs of node `idNode` and of everything downstream of it. **/
      public function nodeUpdated(idNode : Int) {
          var node = this.nodes.get(idNode);
          if (node != null)
              updateOutputs(node);
      }

      /** Recomputes `node`'s outputs, then recursively those of every node it feeds. **/
      function updateOutputs(node : Node) {
          node.instance.computeOutputs();
          for (o in node.outputs)
              updateOutputs(o);
      }

      /**
          Disconnects input `nameInput` of node `idNode`. Does nothing if the node
          is unknown or the input is not connected.
          @param update when true, recomputes the node's outputs and downstream ones.
      **/
      public function removeEdge(idNode : Int, nameInput : String, update = true) {
          var node = this.nodes.get(idNode);
          if (node == null) return;
          var input = node.instance.getInput(nameInput);
          if (input == null) return;
          var source = this.nodes.get(input.node.id);
          if (source != null)
              source.outputs.remove(node);
          node.instance.setInput(nameInput, null);
          if (update)
              updateOutputs(node);
      }

      /** Moves node `idNode` to (`x`, `y`); unknown ids are ignored. **/
      public function setPosition(idNode : Int, x : Float, y : Float) {
          var node = this.nodes.get(idNode);
          if (node == null) return;
          node.x = x;
          node.y = y;
      }

      /** Returns the live id -> node map (not a copy). **/
      public function getNodes() {
          return this.nodes;
      }

      /** Returns the node with the given id, or null. **/
      public function getNode(id : Int) {
          return this.nodes.get(id);
      }

      /** Creates a fresh Param-kind hxsl variable (id 0, allocated at build time). **/
      function generateParameter(name : String, type : Type) : TVar {
          return {
              parent: null,
              id: 0,
              kind: Param,
              name: name,
              type: type
          };
      }

      /** Returns the parameter with the given id, or null. **/
      public function getParameter(id : Int) {
          return parametersAvailable.get(id);
      }

      /**
          Builds the expressions needed to compute `nodeVar`: first every connected
          input (depth-first), then the node itself. Also records in allVariables /
          allParameters / allParamDefaultValue the variables the node needs.
          @throws ShaderException when a required input without a property is not
                  connected, a parameter box has no parameter, or a subgraph
                  parameter has no default value.
      **/
      function buildNodeVar(nodeVar : NodeVar) : Array<TExpr> {
          var node = nodeVar.node;
          if (node == null)
              return [];
          var isSubGraph = Std.is(node, hrt.shgraph.nodes.SubGraph);
          var res : Array<TExpr> = [];
          var alreadyBuiltNodes : Array<Int> = [];
          for (key in node.getInputInfoKeys()) {
              var input = node.getInput(key);
              if (input != null) {
                  // Several inputs may come from the same node: build it once, and do
                  // not mistake the repeated input for a disconnected one.
                  if (!alreadyBuiltNodes.contains(input.node.id)) {
                      res = res.concat(buildNodeVar(input));
                      alreadyBuiltNodes.push(input.node.id);
                  }
              } else if (!node.getInputInfo(key).hasProperty && node.getInputInfo(key).isRequired) {
                  throw ShaderException.t("This box has inputs not connected", node.id);
              }
          }

          var shaderInput = Std.downcast(node, ShaderInput);
          if (shaderInput != null) {
              var variable = shaderInput.variable;
              if ((variable.kind == Param || variable.kind == Global || variable.kind == Input || variable.kind == Local) && !alreadyAddVariable(variable)) {
                  allVariables.push(variable);
              }
          }

          var shaderParam = Std.downcast(node, ShaderParam);
          if (shaderParam != null) {
              if (shaderParam.variable == null) {
                  // Re-link the box to the variable of the parameter it refers to.
                  var linked = getParameter(shaderParam.parameterId);
                  if (linked == null)
                      throw ShaderException.t("This box references a missing parameter", node.id);
                  if (linked.variable == null)
                      linked.variable = generateParameter(linked.name, linked.type);
                  shaderParam.variable = linked.variable;
              }
              if (!alreadyAddVariable(shaderParam.variable)) {
                  allVariables.push(shaderParam.variable);
                  allParameters.push(shaderParam.variable);
                  // Always push (possibly null) so allParamDefaultValue[i] stays
                  // paired with allParameters[i] in compile().
                  var paramDef = getParameter(shaderParam.parameterId);
                  allParamDefaultValue.push(paramDef != null ? paramDef.defaultValue : null);
              }
          }

          var build : Array<TExpr> = [];
          if (!isSubGraph) {
              build = nodeVar.getExpr();
          } else {
              var subGraph = Std.downcast(node, hrt.shgraph.nodes.SubGraph);
              var nodeBuild = node.build("");
              // Declare the subgraph's local outputs (built-in inputs are excluded).
              for (k in subGraph.getOutputInfoKeys()) {
                  var tvar = subGraph.getOutput(k);
                  if (tvar != null && tvar.kind == Local && ShaderInput.availableInputs.indexOf(tvar) < 0)
                      build.push({ e : TVarDecl(tvar), t : tvar.type, p : null });
              }
              if (nodeBuild != null)
                  build.push(nodeBuild);
              var params = subGraph.subShaderGraph.parametersAvailable;
              for (subVar in subGraph.varsSubGraph) {
                  if (alreadyAddVariable(subVar))
                      continue;
                  allVariables.push(subVar);
                  if (subVar.kind != Param)
                      continue;
                  // Subgraph parameters are exposed as parameters of this shader.
                  allParameters.push(subVar);
                  var defaultValueFound = false;
                  for (param in params) {
                      if (param.variable.name == subVar.name) {
                          allParamDefaultValue.push(param.defaultValue);
                          defaultValueFound = true;
                          break;
                      }
                  }
                  if (!defaultValueFound)
                      throw ShaderException.t("Default value of '" + subVar.name + "' parameter not found", node.id);
              }
              // Flatten top-level blocks so the statements land in the caller's block.
              var flattened : Array<TExpr> = [];
              for (expr in build) {
                  switch (expr.e) {
                      case TBlock(block):
                          for (b in block)
                              flattened.push(b);
                      default:
                          flattened.push(expr);
                  }
              }
              build = flattened;
          }
          return res.concat(build);
      }

      /**
          True if allVariables already holds a variable with the same name and type.
          NOTE(review): `==` on hxsl Type values that carry parameters
          (e.g. TVec(3, VFloat)) compares references in Haxe, not structure;
          confirm whether std.Type.enumEq was intended.
      **/
      function alreadyAddVariable(variable : TVar) {
          for (v in allVariables) {
              if (v.name == variable.name && v.type == variable.type)
                  return true;
          }
          return false;
      }

      /**
          Builds hxsl ShaderData for the graph.
          @param specificOutput when set (expected to be a Preview node), only this
                 output is built; otherwise every ShaderOutput node is.
          @param subShaderId when set, node outputs are renamed "sub_<id>_<name>"
                 (only on the first call) so the graph can be embedded as a subgraph.
          @throws ShaderException when two outputs write the same variable, or from
                  building the nodes.
      **/
      public function generateShader(specificOutput : ShaderNode = null, subShaderId : Int = null) : ShaderData {
          allVariables = [];
          allParameters = [];
          allParamDefaultValue = [];
          var contentVertex : Array<TExpr> = [];
          var contentFragment : Array<TExpr> = [];

          for (n in nodes) {
              if (!variableNamesAlreadyUpdated && subShaderId != null && !Std.is(n.instance, ShaderInput)) {
                  for (outputKey in n.instance.getOutputInfoKeys()) {
                      var output = n.instance.getOutput(outputKey);
                      if (output != null)
                          output.name = "sub_" + subShaderId + "_" + output.name;
                  }
              }
              n.instance.outputCompiled = [];
              #if !editor
              if (!n.instance.hasInputs()) {
                  updateOutputs(n);
              }
              #end
          }
          variableNamesAlreadyUpdated = true;

          for (g in ShaderGlobalInput.globalInputs)
              allVariables.push(g);

          var outputs : Array<String> = [];
          for (n in nodes) {
              var outNode : ShaderNode;
              var outVar : TVar;
              if (specificOutput != null) {
                  if (n.instance != specificOutput) continue;
                  outNode = specificOutput;
                  outVar = Std.downcast(specificOutput, hrt.shgraph.nodes.Preview).variable;
              } else {
                  var shaderOutput = Std.downcast(n.instance, ShaderOutput);
                  if (shaderOutput == null) continue;
                  outVar = shaderOutput.variable;
                  outNode = n.instance;
              }
              if (outputs.indexOf(outVar.name) != -1)
                  throw ShaderException.t("This output already exists", n.id);
              outputs.push(outVar.name);
              if (!alreadyAddVariable(outVar))
                  allVariables.push(outVar);
              var body = buildNodeVar(new NodeVar(outNode, "input"));
              if (variableNameAvailableOnlyInVertex.indexOf(outVar.name) != -1)
                  contentVertex = contentVertex.concat(body);
              else
                  contentFragment = contentFragment.concat(body);
              if (specificOutput != null) break;
          }

          // Input-kind variables are regrouped as fields of a single "input" struct
          // variable; expressions are remapped to the copies.
          var shvars : Array<TVar> = [];
          var inputVar : TVar = null;
          var inputVars : Array<TVar> = [];
          var inputMap = new Map<Int, TVar>();
          for (v in allVariables) {
              if (v.id == 0)
                  v.id = hxsl.Tools.allocVarId();
              if (v.kind != Input) {
                  shvars.push(v);
                  continue;
              }
              if (inputVar == null) {
                  inputVar = {
                      id : hxsl.Tools.allocVarId(),
                      name : "input",
                      kind : Input,
                      type : TStruct(inputVars)
                  };
                  shvars.push(inputVar);
              }
              var copy : TVar = Reflect.copy(v);
              copy.id = hxsl.Tools.allocVarId();
              copy.parent = inputVar;
              inputVars.push(copy);
              inputMap.set(v.id, copy);
          }
          if (inputVars.length > 0) {
              function remap(e : TExpr) : TExpr {
                  return switch (e.e) {
                      case TVar(v):
                          var i = inputMap.get(v.id);
                          if (i == null) e else { e : TVar(i), p : e.p, t : e.t };
                      default:
                          hxsl.Tools.map(e, remap);
                  }
              }
              contentVertex = [for (e in contentVertex) remap(e)];
              contentFragment = [for (e in contentFragment) remap(e)];
          }

          var shaderData : ShaderData = {
              funs : [],
              name: "SHADER_GRAPH",
              vars: shvars
          };
          if (contentVertex.length > 0)
              shaderData.funs.push(makeEntryPoint("vertex", Vertex, contentVertex));
          if (contentFragment.length > 0)
              shaderData.funs.push(makeEntryPoint("fragment", Fragment, contentFragment));
          return shaderData;
      }

      /** Wraps `body` into a parameterless, Void-returning hxsl function named `name`. **/
      function makeEntryPoint(name : String, kind : FunctionKind, body : Array<TExpr>) : TFunction {
          return {
              ret : TVoid,
              kind : kind,
              ref : {
                  name : name,
                  id : 0,
                  kind : Function,
                  type : TFun([{ ret : TVoid, args : [] }])
              },
              expr : {
                  p : null,
                  t : TVoid,
                  e : TBlock(body)
              },
              args : []
          };
      }

      /**
          Generates the shader (see generateShader) and wraps it in an initialized
          SharedShader. `inits` pairs every collected Param variable with its default.
      **/
      public function compile(?specificOutput : ShaderNode, ?subShaderId : Int) : hrt.prefab.ContextShared.ShaderDef {
          var shaderData = generateShader(specificOutput, subShaderId);
          var s = new SharedShader("");
          s.data = shaderData;
          @:privateAccess s.initialize();
          var inits : Array<{ variable : hxsl.Ast.TVar, value : Dynamic }> = [
              for (i in 0...allParameters.length) { variable : allParameters[i], value : allParamDefaultValue[i] }
          ];
          return { shader : s, inits : inits };
      }

      /** Compiles the graph and returns a DynamicShader with default parameter values applied. **/
      public function makeInstance(ctx: hrt.prefab.ContextShared) : hxsl.DynamicShader {
          var def = compile();
          var s = new hxsl.DynamicShader(def.shader);
          for (init in def.inits)
              setParamValue(ctx, s, init.variable, init.value);
          return s;
      }

      /**
          Assigns `value` to `variable` on `shader`: textures are loaded through
          `ctx` (with Repeat wrap), vectors are built from an array.
          Failures are deliberately ignored (best effort).
      **/
      static function setParamValue(ctx: hrt.prefab.ContextShared, shader : hxsl.DynamicShader, variable : hxsl.Ast.TVar, value : Dynamic) {
          try {
              switch (variable.type) {
                  case TSampler2D:
                      var t = ctx.loadTexture(value);
                      t.wrap = Repeat;
                      shader.setParamValue(variable, t);
                  case TVec(_, _):
                      shader.setParamValue(variable, h3d.Vector.fromArray(value));
                  default:
                      shader.setParamValue(variable, value);
              }
          } catch (e : Dynamic) {
              // The parameter is not used
          }
      }

      #if editor
      /** Creates a node of class `nameClass` at (`x`, `y`) and returns its instance. **/
      public function addNode(x : Float, y : Float, nameClass : Class<ShaderNode>) {
          var node : Node = { x : x, y : y, id : current_node_id, type: std.Type.getClassName(nameClass) };
          node.instance = std.Type.createInstance(nameClass, []);
          node.instance.setId(current_node_id);
          node.instance.computeOutputs();
          node.outputs = [];
          this.nodes.set(node.id, node);
          current_node_id++;
          return node.instance;
      }

      /**
          Kahn-style cycle check: a node's indegree is the number of nodes it feeds,
          and nodes are peeled from the sinks upwards through their inputs.
          @return true if some node could not be peeled, i.e. there is a cycle.
      **/
      public function hasCycle() : Bool {
          var queue : Array<Node> = [];
          var nbNodes = 0;
          for (n in nodes) {
              n.indegree = n.outputs.length;
              if (n.indegree == 0)
                  queue.push(n);
              nbNodes++;
          }
          var currentIndex = 0;
          while (currentIndex < queue.length) {
              var node = queue[currentIndex++];
              for (input in node.instance.getInputs()) {
                  var nodeInput = nodes.get(input.node.id);
                  if (nodeInput == null) continue;
                  nodeInput.indegree -= 1;
                  if (nodeInput.indegree == 0)
                      queue.push(nodeInput);
              }
          }
          // Every peeled node was pushed exactly once.
          return queue.length != nbNodes;
      }

      /** Adds a parameter named "Param_<id>" of the given type and returns its id. **/
      public function addParameter(type : Type) {
          var id = current_param_id;
          var name = "Param_" + id;
          parametersAvailable.set(id, {id: id, name : name, type : type, defaultValue : null, variable : generateParameter(name, type), index : parametersKeys.length});
          parametersKeys.push(id);
          current_param_id++;
          return id;
      }

      /**
          Renames parameter `id` and gives it a new variable.
          @return false if the parameter is unknown, `newName` is null, or another
                  parameter (or this one) already uses `newName`.
      **/
      public function setParameterTitle(id : Int, newName : String) {
          var p = parametersAvailable.get(id);
          if (p == null || newName == null)
              return false;
          for (other in parametersAvailable) {
              if (other.name == newName)
                  return false;
          }
          p.name = newName;
          p.variable = generateParameter(newName, p.type);
          return true;
      }

      /** Sets the default value of parameter `id`; null values are refused. **/
      public function setParameterDefaultValue(id : Int, newDefaultValue : Dynamic) : Bool {
          var p = parametersAvailable.get(id);
          if (p == null || newDefaultValue == null)
              return false;
          p.defaultValue = newDefaultValue;
          return true;
      }

      /** Removes parameter `id` and renumbers the remaining parameters' indexes. **/
      public function removeParameter(id : Int) {
          parametersAvailable.remove(id);
          parametersKeys.remove(id);
          checkParameterIndex();
      }

      /** Sets each parameter's `index` to its position in parametersKeys. **/
      public function checkParameterIndex() {
          for (i in 0...parametersKeys.length) {
              var param = parametersAvailable.get(parametersKeys[i]);
              if (param != null)
                  param.index = i;
          }
      }

      /**
          Removes node `idNode` together with every edge to or from it, so that no
          node keeps a reference to it (stale `outputs` entries would make hasCycle()
          report a cycle, and save() would write edges to a missing node).
          Downstream outputs are not recomputed.
      **/
      public function removeNode(idNode : Int) {
          var node = this.nodes.get(idNode);
          if (node != null) {
              // Keys are copied first: removeEdge mutates the inputs being iterated.
              for (k in [for (k in node.instance.getInputsKey()) k])
                  removeEdge(idNode, k, false);
              for (target in node.outputs.copy()) {
                  for (k in [for (k in target.instance.getInputsKey()) k]) {
                      var input = target.instance.getInput(k);
                      if (input != null && input.node.id == idNode)
                          removeEdge(target.id, k, false);
                  }
              }
          }
          this.nodes.remove(idNode);
      }

      /** Serializes the graph (nodes, edges, parameters) to tab-indented JSON. **/
      public function save() {
          var edgesJson : Array<Edge> = [];
          for (n in nodes) {
              for (k in n.instance.getInputsKey()) {
                  var output = n.instance.getInput(k);
                  if (output == null) continue;
                  edgesJson.push({ idOutput: output.node.id, nameOutput: output.keyOutput, idInput: n.id, nameInput: k });
              }
          }
          var json = haxe.Json.stringify({
              nodes: [
                  for (n in nodes) { x : Std.int(n.x), y : Std.int(n.y), id: n.id, type: n.type, properties : n.instance.savePropertiesNode() }
              ],
              edges: edgesJson,
              parameters: [
                  // join(",") instead of toString(): the latter adds [ ] on some targets.
                  for (p in parametersAvailable) { id : p.id, name : p.name, type : [p.type.getName(), p.type.getParameters().join(",")], defaultValue : p.defaultValue, index : p.index }
              ]
          }, "\t");
          return json;
      }
      #end
  }