ShaderGraph.hx 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. package hrt.shgraph;
  2. import hxsl.SharedShader;
  3. using hxsl.Ast;
  /**
   * A node of the graph, as stored in JSON and kept in memory.
   */
  typedef Node = {
      // position of the node (updated by ShaderGraph.setPosition)
      var x : Float;
      var y : Float;
      // unique id of the node inside its graph
      var id : Int;
      // fully qualified class name of the ShaderNode, resolved with Type.resolveClass
      var type : String;
      // serialized node properties (fed to loadProperties / produced by savePropertiesNode)
      @:optional var properties : Null<Dynamic>;
      // live ShaderNode instance, created when the graph is loaded or the node added
      @:optional var instance : Null<ShaderNode>;
      // nodes having one of their inputs connected to this node
      @:optional var outputs : Null<Array<Node>>;
      // scratch counter used by ShaderGraph.hasCycle
      @:optional var indegree : Null<Int>;
  };
  /**
   * Serialized connection between two nodes: output `nameOutput` of node
   * `idOutput` feeds input `nameInput` of node `idInput`.
   */
  private typedef Edge = {
      var idOutput : Int;
      var nameOutput : String;
      var idInput : Int;
      var nameInput : String;
  };
  /**
   * A parameter exposed by the graph.
   */
  typedef Parameter = {
      var name : String;
      // hxsl type of the parameter
      var type : Type;
      // value assigned to the parameter when the shader is instantiated
      var defaultValue : Dynamic;
      // key of the parameter in ShaderGraph.parametersAvailable
      @:optional var id : Null<Int>;
      // hxsl variable built from name/type (see ShaderGraph.generateParameter)
      @:optional var variable : Null<TVar>;
  };
  /**
   * Editable / compilable shader graph.
   *
   * Holds the graph nodes (keyed by id), the edges between them (stored as
   * `NodeVar` inputs on the node instances plus the reverse `Node.outputs`
   * lists) and the parameters it exposes. It can be loaded from JSON, saved
   * back (editor builds only) and compiled into an hxsl shader.
   */
  class ShaderGraph {

      // Variables collected while building the shader; reset by generateShader.
      // allParamDefaultValue[i] is the default value of allParameters[i]: compile()
      // zips both arrays by index, so they must always grow together.
      var allVariables : Array<TVar> = [];
      var allParameters = [];
      var allParamDefaultValue = [];

      // Next ids handed out to new nodes / parameters
      var current_node_id = 0;
      var current_param_id = 0;

      var filepath : String;
      var nodes : Map<Int, Node> = [];

      public var parametersAvailable : Map<Int, Parameter> = [];

      // subgraph variable: set after the first generateShader call so output
      // names are prefixed at most once
      var variableNamesAlreadyUpdated = false;

      /**
       * Loads the graph stored at `filepath` (relative to the IDE resource
       * directory in editor builds, through the current resource loader
       * otherwise). A null path creates an empty graph; an empty file loads
       * nothing.
       * @throws String when the file content cannot be read or parsed.
       */
      public function new(filepath : String) {
          if (filepath == null) return;
          this.filepath = filepath;
          var json : Dynamic;
          try {
              var content : String = null;
              #if editor
              content = sys.io.File.getContent(hide.Ide.inst.resourceDir + "/" + this.filepath);
              #else
              content = hxd.res.Loader.currentInstance.load(this.filepath).toText();
              //content = hxd.Res.load(this.filepath).toText();
              #end
              if (content.length == 0) return;
              json = haxe.Json.parse(content);
          } catch( e : Dynamic ) {
              throw "Invalid shader graph parsing ("+e+")";
          }
          load(json);
      }

      /**
       * Replaces the current nodes and parameters with the content of a parsed
       * graph object (its `nodes`, `edges` and `parameters` fields).
       */
      public function load(json : Dynamic) {
          nodes = [];
          parametersAvailable = [];
          generate(Reflect.getProperty(json, "nodes"), Reflect.getProperty(json, "edges"), Reflect.getProperty(json, "parameters"));
      }

      /**
       * Instantiates parameters, nodes and edges from their serialized form.
       *
       * A parameter type may be stored as `[constructorName, "arg1,arg2"]` (the
       * format written by `save`); it is converted back to an hxsl `Type`.
       * Null arrays are treated as empty.
       *
       * @throws String if a node class, a node's parameter or an edge endpoint
       *         cannot be found.
       */
      public function generate(nodes : Array<Node>, edges : Array<Edge>, parameters : Array<Parameter>) {
          if (parameters != null) {
              for (p in parameters) {
                  var typeString : Array<Dynamic> = Reflect.field(p, "type");
                  if (Std.is(typeString, Array)) {
                      if (typeString[1] == null || typeString[1].length == 0)
                          p.type = std.Type.createEnum(Type, typeString[0]);
                      else {
                          var paramsEnum = typeString[1].split(",");
                          p.type = std.Type.createEnum(Type, typeString[0], [Std.parseInt(paramsEnum[0]), std.Type.createEnum(VecType, paramsEnum[1])]);
                      }
                  }
                  p.variable = generateParameter(p.name, p.type);
                  this.parametersAvailable.set(p.id, p);
                  // Ids are not guaranteed to be sorted: stay above every loaded id
                  if (p.id >= current_param_id)
                      current_param_id = p.id + 1;
              }
          }
          if (nodes != null) {
              for (n in nodes) {
                  n.outputs = [];
                  var cl = std.Type.resolveClass(n.type);
                  if( cl == null ) throw "Missing shader node "+n.type;
                  n.instance = std.Type.createInstance(cl, []);
                  n.instance.setId(n.id);
                  n.instance.loadProperties(n.properties);
                  this.nodes.set(n.id, n);
                  var shaderParam = Std.downcast(n.instance, ShaderParam);
                  if (shaderParam != null) {
                      var paramShader = getParameter(shaderParam.parameterId);
                      if (paramShader == null) throw "Missing shader parameter "+shaderParam.parameterId;
                      shaderParam.variable = paramShader.variable;
                      shaderParam.computeOutputs();
                  }
                  // Same as for parameters: the last node is not necessarily the highest id
                  if (n.id >= current_node_id)
                      current_node_id = n.id + 1;
              }
          }
          if (edges != null) {
              for (e in edges)
                  addEdge(e);
          }
      }

      /**
       * Connects output `nameOutput` of node `idOutput` to input `nameInput` of
       * node `idInput`. A connection already plugged into that input is
       * detached first.
       *
       * @return false when the edge was rejected: the sub-graph does not expose
       *         one of the keys, or (editor builds) the edge would create a
       *         cycle; true otherwise.
       * @throws String if one of the nodes does not exist. In editor builds,
       *         rethrows whatever output recomputation raised, after removing
       *         the edge.
       */
      public function addEdge(edge : Edge) {
          var node = this.nodes.get(edge.idInput);
          var output = this.nodes.get(edge.idOutput);
          if (node == null || output == null)
              throw "Invalid edge: missing node " + (node == null ? edge.idInput : edge.idOutput);
          // Rewiring an input: unregister the old connection from its source,
          // otherwise that source keeps a stale entry in its `outputs` list.
          if (node.instance.getInput(edge.nameInput) != null)
              removeEdge(edge.idInput, edge.nameInput, false);
          node.instance.setInput(edge.nameInput, new NodeVar(output.instance, edge.nameOutput));
          output.outputs.push(node);
          var subShaderIn = Std.downcast(node.instance, hrt.shgraph.nodes.SubGraph);
          var subShaderOut = Std.downcast(output.instance, hrt.shgraph.nodes.SubGraph);
          if( @:privateAccess ((subShaderIn != null) && !subShaderIn.inputInfoKeys.contains(edge.nameInput))
              || @:privateAccess ((subShaderOut != null) && !subShaderOut.outputInfoKeys.contains(edge.nameOutput))
          ) {
              // The sub-graph does not expose this key (anymore): the edge is dropped
              removeEdge(edge.idInput, edge.nameInput, false);
              return false;
          }
          #if editor
          if (hasCycle()){
              removeEdge(edge.idInput, edge.nameInput, false);
              return false;
          }
          try {
              updateOutputs(output);
          } catch (e : Dynamic) {
              removeEdge(edge.idInput, edge.nameInput);
              throw e;
          }
          #end
          return true;
      }

      /** Recomputes the outputs of node `idNode` and of everything downstream; ignores unknown ids. */
      public function nodeUpdated(idNode : Int) {
          var node = this.nodes.get(idNode);
          if (node != null) {
              updateOutputs(node);
          }
      }

      // Recursively recomputes `node` then every node connected to its outputs
      // (does not terminate on a cyclic graph, hence hasCycle checks in addEdge).
      function updateOutputs(node : Node) {
          node.instance.computeOutputs();
          for (o in node.outputs) {
              updateOutputs(o);
          }
      }

      /**
       * Disconnects input `nameInput` of node `idNode`. Unknown nodes and
       * already-disconnected inputs are tolerated.
       * @param update when true, recomputes the outputs of the node afterwards.
       */
      public function removeEdge(idNode, nameInput, update = true) {
          var node = this.nodes.get(idNode);
          if (node == null) return;
          var input = node.instance.getInput(nameInput);
          if (input != null && input.node != null) {
              var source = this.nodes.get(input.node.id);
              if (source != null)
                  source.outputs.remove(node);
          }
          node.instance.setInput(nameInput, null);
          if (update) {
              updateOutputs(node);
          }
      }

      public function setPosition(idNode : Int, x : Float, y : Float) {
          var node = this.nodes.get(idNode);
          node.x = x;
          node.y = y;
      }

      public function getNodes() {
          return this.nodes;
      }

      public function getNode(id : Int) {
          return this.nodes.get(id);
      }

      // Builds a fresh (id 0, allocated later) Param variable
      function generateParameter(name : String, type : Type) : TVar {
          return {
              parent: null,
              id: 0,
              kind:Param,
              name: name,
              type: type
          };
      }

      public function getParameter(id : Int) {
          return parametersAvailable.get(id);
      }

      /**
       * Builds, depth-first, the expressions producing `nodeVar`: first those of
       * every connected input, then the node's own expression. Registers the
       * variables the node relies on in `allVariables`, `allParameters` and
       * `allParamDefaultValue`.
       *
       * @throws ShaderException if a required input without property is not
       *         connected, if a parameter box references an unknown parameter,
       *         or if a sub-graph parameter has no default value.
       */
      function buildNodeVar(nodeVar : NodeVar) : Array<TExpr>{
          var node = nodeVar.node;
          if (node == null)
              return [];
          var res = [];
          for (key in node.getInputInfoKeys()) {
              var input = node.getInput(key);
              if (input != null) {
                  res = res.concat(buildNodeVar(input));
              } else {
                  var info = node.getInputInfo(key);
                  // An unconnected input is fine when it has a property or is optional
                  if (!info.hasProperty && info.isRequired)
                      throw ShaderException.t("This box has inputs not connected", node.id);
              }
          }

          var shaderParam = Std.downcast(node, ShaderParam);
          if (shaderParam != null && shaderParam.variable == null) {
              // Rebuild the missing variable from the parameter definition
              // (before getExpr, which builds the box expression)
              var param = getParameter(shaderParam.parameterId);
              if (param == null)
                  throw ShaderException.t("Parameter of this box not found", node.id);
              if (param.variable == null)
                  param.variable = generateParameter(param.name, param.type);
              shaderParam.variable = param.variable;
          }

          var build = nodeVar.getExpr();

          var shaderInput = Std.downcast(node, ShaderInput);
          if (shaderInput != null) {
              var variable = shaderInput.variable;
              if ((variable.kind == Param || variable.kind == Global || variable.kind == Input || variable.kind == Local) && !alreadyAddVariable(variable)) {
                  allVariables.push(variable);
              }
          }

          if (shaderParam != null && !alreadyAddVariable(shaderParam.variable)) {
              allVariables.push(shaderParam.variable);
              allParameters.push(shaderParam.variable);
              // Always push a default (possibly null) to keep both arrays index-aligned
              var param = getParameter(shaderParam.parameterId);
              allParamDefaultValue.push(param != null ? param.defaultValue : null);
          }

          var subGraph = Std.downcast(node, hrt.shgraph.nodes.SubGraph);
          if (subGraph != null) {
              var params = subGraph.subShaderGraph.parametersAvailable;
              for (subVar in subGraph.varsSubGraph) {
                  if (subVar.kind == Param) {
                      if (!alreadyAddVariable(subVar)) {
                          allVariables.push(subVar);
                          allParameters.push(subVar);
                          var defaultValueFound = false;
                          for (param in params) {
                              if (param.variable.name == subVar.name) {
                                  allParamDefaultValue.push(param.defaultValue);
                                  defaultValueFound = true;
                                  break;
                              }
                          }
                          if (!defaultValueFound) {
                              throw ShaderException.t("Default value of '" + subVar.name + "' parameter not found", node.id);
                          }
                      }
                  } else {
                      if (!alreadyAddVariable(subVar)) {
                          allVariables.push(subVar);
                      }
                  }
              }
              // Inline the sub-graph blocks into the parent expression list
              var buildWithoutTBlock = [];
              for (expr in build) {
                  switch (expr.e) {
                      case TBlock(block):
                          for (b in block) {
                              buildWithoutTBlock.push(b);
                          }
                      default:
                          buildWithoutTBlock.push(expr);
                  }
              }
              build = buildWithoutTBlock;
          }

          res = res.concat(build);
          return res;
      }

      // True if a variable with the same name and type is already collected.
      // NOTE(review): `==` on hxsl Type values with arguments (e.g. TVec) may be a
      // reference comparison on some targets — confirm whether std.Type.enumEq is wanted.
      function alreadyAddVariable(variable : TVar) {
          for (v in allVariables) {
              if (v.name == variable.name && v.type == variable.type) {
                  return true;
              }
          }
          return false;
      }

      // Output variable names built into the vertex function instead of the fragment one
      var variableNameAvailableOnlyInVertex = [];

      /**
       * Generates the hxsl `ShaderData` of the graph.
       *
       * @param specificOutput when set, only this node (read as a `Preview`) is
       *        built; otherwise every `ShaderOutput` node is.
       * @param subShaderId when set on the first call, output names of
       *        non-input nodes are prefixed with `sub_<id>_`.
       * @throws ShaderException if two outputs share the same variable name, or
       *         anything raised while building node expressions.
       */
      public function generateShader(specificOutput : ShaderNode = null, subShaderId : Int = null) : ShaderData {

          allVariables = [];
          allParameters = [];
          allParamDefaultValue = [];
          var contentVertex = [];
          var contentFragment = [];

          for (n in nodes) {
              if (!variableNamesAlreadyUpdated && subShaderId != null && !Std.is(n.instance, ShaderInput)) {
                  for (outputKey in n.instance.getOutputInfoKeys()) {
                      var output = n.instance.getOutput(outputKey);
                      if (output != null)
                          output.name = "sub_" + subShaderId + "_" + output.name;
                  }
              }
              n.instance.outputCompiled = [];
              #if !editor
              if (!n.instance.hasInputs()) {
                  updateOutputs(n);
              }
              #end
          }
          variableNamesAlreadyUpdated = true;

          var outputs : Array<String> = [];

          for (g in ShaderGlobalInput.globalInputs) {
              allVariables.push(g);
          }

          for (n in nodes) {
              var outNode;
              var outVar;
              if (specificOutput != null) {
                  if (n.instance != specificOutput) continue;
                  outNode = specificOutput;
                  outVar = Std.downcast(specificOutput, hrt.shgraph.nodes.Preview).variable;
              } else {
                  var shaderOutput = Std.downcast(n.instance, ShaderOutput);

                  if (shaderOutput != null) {
                      outVar = shaderOutput.variable;
                      outNode = n.instance;
                  } else {
                      continue;
                  }
              }
              if (outNode != null) {
                  if (outputs.indexOf(outVar.name) != -1) {
                      throw ShaderException.t("This output already exists", n.id);
                  }
                  outputs.push(outVar.name);
                  if ( !alreadyAddVariable(outVar) ) {
                      allVariables.push(outVar);
                  }
                  var nodeVar = new NodeVar(outNode, "input");
                  var isVertex = (variableNameAvailableOnlyInVertex.indexOf(outVar.name) != -1);
                  if (isVertex) {
                      contentVertex = contentVertex.concat(buildNodeVar(nodeVar));
                  } else {
                      contentFragment = contentFragment.concat(buildNodeVar(nodeVar));
                  }
                  if (specificOutput != null) break;
              }
          }

          // Allocate ids, and gather Input variables into a single `input` struct
          var shvars = [];
          var inputVar : TVar = null, inputVars = [], inputMap = new Map();
          for( v in allVariables ) {
              if( v.id == 0 )
                  v.id = hxsl.Tools.allocVarId();
              if( v.kind != Input ) {
                  shvars.push(v);
                  continue;
              }
              if( inputVar == null ) {
                  inputVar = {
                      id : hxsl.Tools.allocVarId(),
                      name : "input",
                      kind : Input,
                      type : TStruct(inputVars),
                  };
                  shvars.push(inputVar);
              }
              var prevId = v.id;
              v = Reflect.copy(v);
              v.id = hxsl.Tools.allocVarId();
              v.parent = inputVar;
              inputVars.push(v);
              inputMap.set(prevId, v);
          }

          // Re-point expressions from the original Input variables to the struct members
          if( inputVars.length > 0 ) {
              function remap(e:TExpr) {
                  return switch( e.e ) {
                  case TVar(v):
                      var i = inputMap.get(v.id);
                      if( i == null ) e else { e : TVar(i), p : e.p, t : e.t };
                  default:
                      hxsl.Tools.map(e, remap);
                  }
              }
              contentVertex = [for( e in contentVertex ) remap(e)];
              contentFragment = [for( e in contentFragment ) remap(e)];
          }

          var shaderData = {
              funs : [],
              name: "SHADER_GRAPH",
              vars: shvars
          };

          if (contentVertex.length > 0) {
              shaderData.funs.push({
                  ret : TVoid, kind : Vertex,
                  ref : {
                      name : "vertex",
                      id : 0,
                      kind : Function,
                      type : TFun([{ ret : TVoid, args : [] }])
                  },
                  expr : {
                      p : null,
                      t : TVoid,
                      e : TBlock(contentVertex)
                  },
                  args : []
              });
          }
          if (contentFragment.length > 0) {
              shaderData.funs.push({
                  ret : TVoid, kind : Fragment,
                  ref : {
                      name : "fragment",
                      id : 0,
                      kind : Function,
                      type : TFun([{ ret : TVoid, args : [] }])
                  },
                  expr : {
                      p : null,
                      t : TVoid,
                      e : TBlock(contentFragment)
                  },
                  args : []
              });
          }

          return shaderData;
      }

      /**
       * Generates and initializes a `SharedShader` for the graph, along with
       * the default value of each collected parameter.
       */
      public function compile(?specificOutput : ShaderNode, ?subShaderId : Int) : hrt.prefab.ContextShared.ShaderDef {

          var shaderData = generateShader(specificOutput, subShaderId);

          var s = new SharedShader("");
          s.data = shaderData;
          @:privateAccess s.initialize();

          var inits : Array<{ variable : hxsl.Ast.TVar, value : Dynamic }> = [];
          for (i in 0...allParameters.length) {
              inits.push({ variable : allParameters[i], value : allParamDefaultValue[i] });
          }

          var shaderDef = { shader : s, inits : inits };

          return shaderDef;
      }

      /** Compiles the graph and returns a `DynamicShader` with every parameter set to its default value. */
      public function makeInstance(ctx: hrt.prefab.ContextShared) : hxsl.DynamicShader {
          var def = compile();
          var s = new hxsl.DynamicShader(def.shader);
          for (init in def.inits)
              setParamValue(ctx, s, init.variable, init.value);
          return s;
      }

      // Textures are loaded through `ctx` (with Repeat wrapping), vectors built
      // from arrays; any other value is set as is. Errors are ignored.
      static function setParamValue(ctx: hrt.prefab.ContextShared, shader : hxsl.DynamicShader, variable : hxsl.Ast.TVar, value : Dynamic) {
          try {
              switch (variable.type) {
                  case TSampler2D:
                      var t = ctx.loadTexture(value);
                      t.wrap = Repeat;
                      shader.setParamValue(variable, t);
                  case TVec(_, _):
                      shader.setParamValue(variable, h3d.Vector.fromArray(value));
                  default:
                      shader.setParamValue(variable, value);
              }
          } catch (e : Dynamic) {
              // The parameter is not used
          }
      }

      #if editor
      /** Creates a node of class `nameClass` at (x, y) with the next free id and returns its instance. */
      public function addNode(x : Float, y : Float, nameClass : Class<ShaderNode>) {
          var node : Node = { x : x, y : y, id : current_node_id, type: std.Type.getClassName(nameClass) };

          node.instance = std.Type.createInstance(nameClass, []);
          node.instance.setId(current_node_id);
          node.instance.computeOutputs();
          node.outputs = [];

          this.nodes.set(node.id, node);
          current_node_id++;

          return node.instance;
      }

      /**
       * Kahn-style topological check, walked from the sinks: a node becomes
       * ready once every consumer of its outputs has been processed.
       * @return true if some nodes were never reached, i.e. the graph has a cycle.
       */
      public function hasCycle() : Bool {
          var queue : Array<Node> = [];

          var counter = 0;
          var nbNodes = 0;
          for (n in nodes) {
              n.indegree = n.outputs.length;
              if (n.indegree == 0) {
                  queue.push(n);
              }
              nbNodes++;
          }

          var currentIndex = 0;
          while (currentIndex < queue.length) {
              var node = queue[currentIndex];
              currentIndex++;

              for (input in node.instance.getInputs()) {
                  if (input == null || input.node == null) continue;
                  var nodeInput = nodes.get(input.node.id);
                  if (nodeInput == null) continue;
                  nodeInput.indegree -= 1;
                  if (nodeInput.indegree == 0) {
                      queue.push(nodeInput);
                  }
              }
              counter++;
          }

          return counter != nbNodes;
      }

      /** Adds a parameter named `Param_<id>` of the given type and returns its id. */
      public function addParameter(type : Type) {
          var name = "Param_" + current_param_id;
          parametersAvailable.set(current_param_id, {id: current_param_id, name : name, type : type, defaultValue : null, variable : generateParameter(name, type)});
          current_param_id++;
          return current_param_id-1;
      }

      /**
       * Renames parameter `id` and rebuilds its variable.
       * @return false if the parameter is unknown, the name is null, or the
       *         name is already used by a parameter.
       */
      public function setParameterTitle(id : Int, newName : String) {
          var p = parametersAvailable.get(id);
          if (p != null) {
              if (newName != null) {
                  for (other in parametersAvailable) {
                      if (other.name == newName) {
                          return false;
                      }
                  }
                  p.name = newName;
                  p.variable = generateParameter(newName, p.type);
                  return true;
              }
          }
          return false;
      }

      /** Sets the default value of parameter `id`; false if unknown or the value is null. */
      public function setParameterDefaultValue(id : Int, newDefaultValue : Dynamic) : Bool {
          var p = parametersAvailable.get(id);
          if (p != null) {
              if (newDefaultValue != null) {
                  p.defaultValue = newDefaultValue;
                  return true;
              }
          }
          return false;
      }

      public function removeParameter(id : Int) {
          parametersAvailable.remove(id);
      }

      /**
       * Removes a node and detaches every edge touching it. The node is removed
       * from its sources' `outputs`, and consumer inputs plugged into it are
       * cleared. Outputs of former consumers are not recomputed here.
       */
      public function removeNode(idNode : Int) {
          var node = this.nodes.get(idNode);
          if (node == null) return;
          for (k in node.instance.getInputsKey()) {
              var input = node.instance.getInput(k);
              if (input == null || input.node == null) continue;
              var source = this.nodes.get(input.node.id);
              if (source != null)
                  source.outputs.remove(node);
          }
          if (node.outputs != null) {
              for (consumer in node.outputs) {
                  // copy the keys: setInput may alter the collection being iterated
                  var keys = [for (k in consumer.instance.getInputsKey()) k];
                  for (k in keys) {
                      var input = consumer.instance.getInput(k);
                      if (input != null && input.node == node.instance)
                          consumer.instance.setInput(k, null);
                  }
              }
          }
          this.nodes.remove(idNode);
      }

      /** Serializes nodes, edges and parameters to a tab-indented JSON string. */
      public function save() {
          var edgesJson : Array<Edge> = [];
          for (n in nodes) {
              for (k in n.instance.getInputsKey()) {
                  var output = n.instance.getInput(k);
                  // skip inputs that were disconnected
                  if (output == null || output.node == null) continue;
                  edgesJson.push({ idOutput: output.node.id, nameOutput: output.keyOutput, idInput: n.id, nameInput: k });
              }
          }
          var json = haxe.Json.stringify({
              nodes: [
                  for (n in nodes) { x : Std.int(n.x), y : Std.int(n.y), id: n.id, type: n.type, properties : n.instance.savePropertiesNode() }
              ],
              edges: edgesJson,
              parameters: [
                  for (p in parametersAvailable) { id : p.id, name : p.name, type : [p.type.getName(), p.type.getParameters().toString()], defaultValue : p.defaultValue }
              ]
          }, "\t");

          return json;
      }
      #end
  }