NodeGenContext.hx 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. package hrt.shgraph;
  2. using hxsl.Ast;
  3. using Lambda;
  4. using hrt.shgraph.Utils;
  5. using hrt.tools.MapUtils;
  6. import hrt.shgraph.AstTools.*;
  7. import hrt.shgraph.ShaderGraph;
  8. import hrt.shgraph.ShaderNode;
  9. class NodeGenContextSubGraph extends NodeGenContext {
  10. public function new(parentCtx : NodeGenContext) {
  11. super(parentCtx?.domain ?? Fragment);
  12. this.parentCtx = parentCtx;
  13. }
  14. override function getGlobalInput(id: Variables.Global) : TExpr {
  15. var global = Variables.Globals[id];
  16. var info = globalInVars.getOrPut(Variables.getFullPath(global), {type: global.type, id: inputCount++});
  17. return parentCtx?.nodeInputExprs[info.id] ?? parentCtx?.getGlobalInput(id) ?? super.getGlobalInput(id);
  18. }
  19. override function getGlobalTVar(tvar: TVar) : TExpr {
  20. if (parentCtx != null) {
  21. return parentCtx.getGlobalTVar(tvar);
  22. } else {
  23. return super.getGlobalTVar(tvar);
  24. }
  25. }
  26. override function setGlobalOutput(id: Variables.Global, expr: TExpr) : Void {
  27. var global = Variables.Globals[id];
  28. if (outputCount == 0 && parentCtx != null) {
  29. parentCtx.addPreview(expr);
  30. }
  31. var info = globalOutVars.getOrPut(Variables.getFullPath(global), {type: global.type, id: outputCount ++});
  32. if (parentCtx != null) {
  33. parentCtx.setOutput(info.id, expr);
  34. } else {
  35. super.setGlobalOutput(id, expr);
  36. }
  37. }
  38. override function getGlobalParam(name: String, type: Type) : TExpr {
  39. var info = globalInVars.getOrPut(name, {type: type, id: inputCount ++});
  40. return parentCtx?.nodeInputExprs[info.id] ?? parentCtx?.getGlobalParam(name, type) ?? super.getGlobalParam(name, type);
  41. }
  42. override function setGlobalCustomOutput(name: String, expr: TExpr) : Void {
  43. if (outputCount == 0 && parentCtx != null) {
  44. parentCtx.addPreview(expr);
  45. }
  46. var info = globalOutVars.getOrPut(name, {type : expr.t, id: outputCount ++});
  47. if (parentCtx != null) {
  48. parentCtx.setOutput(info.id, expr);
  49. } else {
  50. super.setGlobalCustomOutput(name, expr);
  51. }
  52. }
  53. override function addExpr(expr: TExpr) : Void {
  54. if (parentCtx != null) {
  55. parentCtx.addExpr(expr);
  56. }
  57. }
  58. var parentCtx : NodeGenContext;
  59. var globalInVars: Map<String, {type: Type, id: Int}> = [];
  60. var globalOutVars: Map<String, {type: Type, id: Int}> = [];
  61. var inputCount = 0;
  62. var outputCount = 0;
  63. }
  64. @:allow(hrt.shgraph.ShaderGraph)
  65. class NodeGenContext {
  66. // Pour les rares nodes qui ont besoin de differencier entre vertex et fragment
  67. public var domain : ShaderGraph.Domain;
  68. public var previewDomain: ShaderGraph.Domain = null;
  69. public function new(domain: ShaderGraph.Domain) {
  70. this.domain = domain;
  71. }
  72. // For general input/output of the shader graph. Allocate a new global var if name is not found,
  73. // else return the previously allocated variable and assert that v.type == type and devValue == v.defValue
  74. public function getGlobalInput(id: Variables.Global) : TExpr {
  75. var global = Variables.Globals[id];
  76. switch (global.varkind) {
  77. case KVar(_,_,_):
  78. var v = getOrAllocateGlobal(id);
  79. return makeVar(v);
  80. case KSwizzle(id, swiz):
  81. var v = getOrAllocateGlobal(id);
  82. return makeSwizzle(makeVar(v), swiz);
  83. }
  84. }
  85. public function getGlobalTVar(tvar: TVar) : TExpr {
  86. return makeVar(getOrAllocateFromTVar(tvar));
  87. }
  88. public function setGlobalOutput(id: Variables.Global, expr: TExpr) : Void {
  89. var global = Variables.Globals[id];
  90. switch (global.varkind) {
  91. case KVar(_,_,_):
  92. var v = getOrAllocateGlobal(id);
  93. expressions.push(makeAssign(makeVar(v), expr));
  94. case KSwizzle(otherId, swiz):
  95. var v = getOrAllocateGlobal(otherId);
  96. expressions.push(makeAssign(makeSwizzle(makeVar(v), swiz), expr));
  97. }
  98. }
  99. public function getGlobalParam(name: String, type: Type) : TExpr {
  100. return makeVar(globalVars.getOrPut(name, {v: {id: hxsl.Tools.allocVarId(), name: name, type: type, kind: Param}, defValue:null, __init__: null}).v);
  101. }
  102. public function setGlobalCustomOutput(name: String, expr: TExpr) : Void {
  103. var v = makeVar(globalVars.getOrPut(name, {v: {id: hxsl.Tools.allocVarId(), name: name, type: expr.t, kind: Param}, defValue:null, __init__: null}).v);
  104. expressions.push(makeAssign(v, expr));
  105. }
  106. function getOrAllocateFromTVar(tvar: TVar) : TVar {
  107. var fullName = AstTools.getFullName(tvar);
  108. // special case handling for normal because it gets replaced in the preview shader
  109. if (fullName == "input.normal")
  110. return getOrAllocateGlobal(Normal);
  111. var def = globalVars.get(fullName);
  112. if (def != null) {
  113. return def.v;
  114. }
  115. var type = tvar.type;
  116. switch (type) {
  117. case TStruct(_):
  118. type = TStruct([]);
  119. default:
  120. }
  121. var v : TVar = {id: hxsl.Tools.allocVarId(), name: tvar.name, type: type, kind: tvar.kind, qualifiers: tvar.qualifiers};
  122. def = {v:v, defValue: null, __init__: null};
  123. if (tvar.parent != null) {
  124. v.parent = getOrAllocateFromTVar(tvar.parent);
  125. switch(v.parent.type) {
  126. case TStruct(arr):
  127. arr.push(v);
  128. default: throw "parent must be a TStruct";
  129. }
  130. }
  131. globalVars.set(fullName, def);
  132. return def.v;
  133. }
  134. function getOrAllocateGlobal(id: Variables.Global) : TVar {
  135. // Remap id for certains variables
  136. switch (id) {
  137. case Normal if (previewDomain == domain):
  138. id = FakeNormal;
  139. default:
  140. }
  141. var global = Variables.Globals[id];
  142. switch (global.varkind)
  143. {
  144. case KVar(kind, parent, defValue):
  145. var fullName = Variables.getFullPath(global);
  146. var def : ShaderGraph.ExternVarDef = globalVars.get(fullName);
  147. if (def == null) {
  148. var v : TVar = {id: hxsl.Tools.allocVarId(), name: global.name, type: global.type, kind: kind};
  149. var __init__ = null;
  150. if (global.__init__ != null) {
  151. __init__ = AstTools.makeAssign(AstTools.makeVar(v), global.__init__);
  152. }
  153. def = {v: v, defValue: defValue, __init__: __init__};
  154. if (parent != null) {
  155. var p = Variables.Globals[parent];
  156. switch (p.varkind) {
  157. case KVar(kind, _, _):
  158. v.parent = globalVars.getOrPut(Variables.getFullPath(p), {v : {id : hxsl.Tools.allocVarId(), name: p.name, type: TStruct([]), kind: kind}, defValue: null, __init__: null}).v;
  159. default:
  160. throw "Parent var must be a KVar";
  161. }
  162. switch(v.parent.type) {
  163. case TStruct(arr):
  164. arr.push(v);
  165. default: throw "parent must be a TStruct";
  166. }
  167. }
  168. // Post process certain variables
  169. switch (id) {
  170. case CalculatedUV:
  171. var uv = getOrAllocateGlobal(UV);
  172. var expr = makeAssign(makeVar(v), makeVar(uv));
  173. def.__init__ = expr;
  174. default:
  175. }
  176. globalVars.set(fullName, def);
  177. }
  178. return def.v;
  179. default: throw "id must be a global Var";
  180. }
  181. }
  182. // Generate a preview block that displays the content of expr
  183. // in the preview box of the node. Expr must be a type that
  184. // can be casted a Vec4
  185. public function addPreview(expr: TExpr) {
  186. if (previewDomain != domain) return;
  187. var selector = getGlobalInput(PreviewSelect);
  188. var outputColor = getOrAllocateGlobal(PreviewColor);
  189. var previewExpr = makeAssign(makeVar(outputColor), convertToType(TVec(4, VFloat), expr));
  190. var ifExpr = makeIf(makeEq(selector, makeInt(currentPreviewId)), previewExpr);
  191. preview = ifExpr;
  192. }
  193. public static function convertToType(targetType: hxsl.Ast.Type, sourceExpr: TExpr) : TExpr {
  194. if (sourceExpr.t.equals(targetType))
  195. return sourceExpr;
  196. if (sourceExpr.t.match(TBool)) {
  197. sourceExpr = makeIf(sourceExpr, makeFloat(1.0), makeFloat(0.0), null, TFloat);
  198. }
  199. var sourceSize = switch (sourceExpr.t) {
  200. case TFloat: 1;
  201. case TVec(size, VFloat): size;
  202. default:
  203. throw "Unsupported source type " + sourceExpr.t;
  204. }
  205. var targetSize = switch (targetType) {
  206. case TFloat: 1;
  207. case TVec(size, VFloat): size;
  208. default:
  209. throw "Unsupported target type " + targetType;
  210. }
  211. var delta = targetSize - sourceSize;
  212. if (delta == 0)
  213. return sourceExpr;
  214. if (delta > 0) {
  215. var args = [];
  216. if (sourceSize == 1) {
  217. for (i in 0...targetSize) {
  218. args.push(sourceExpr);
  219. }
  220. }
  221. else {
  222. args.push(sourceExpr);
  223. for (i in 0...delta) {
  224. // Set alpha to 1.0 by default on upcasts casts
  225. var value = ((sourceSize + i) == 3) ? 1.0 : 0.0;
  226. args.push({e : TConst(CFloat(value)), p: sourceExpr.p, t: TFloat});
  227. }
  228. }
  229. var global : TGlobal = switch (targetSize) {
  230. case 2: Vec2;
  231. case 3: Vec3;
  232. case 4: Vec4;
  233. default: throw "unreachable";
  234. }
  235. return {e: TCall({e: TGlobal(global), p: sourceExpr.p, t:targetType}, args), p: sourceExpr.p, t: targetType};
  236. }
  237. if (delta < 0) {
  238. var swizz : Array<hxsl.Ast.Component> = [X,Y,Z,W];
  239. swizz.resize(targetSize);
  240. return {e: TSwiz(sourceExpr, swizz), p: sourceExpr.p, t: targetType};
  241. }
  242. throw "unreachable";
  243. }
  244. public function addExpr(e: TExpr) {
  245. expressions.push(e);
  246. }
  247. public function setOutput(id: Int, e: TExpr) {
  248. var expectedType = getType(nodeOutputInfo[id].type);
  249. if (!expectedType.equals(e.t))
  250. throw "Output " + id + " has different type than declared";
  251. outputs[id]=e;
  252. }
  253. public function getType(type: SgType) : Type {
  254. switch (type) {
  255. case SgGeneric(id, consDtraint):
  256. return getGenericType(id);
  257. default:
  258. return inline sgTypeToType(type);
  259. }
  260. }
  261. public inline function getGenericType(id: Int) {
  262. return genericTypes[id];
  263. }
  264. public function getInput(id: Int, ?defValue: SgHxslVar.ShaderDefInput) : Null<TExpr> {
  265. var input = nodeInputExprs[id];
  266. var inputType = getType(nodeInputInfo[id].type);
  267. if (input != null) {
  268. return convertToType(inputType, input);
  269. }
  270. if (defValue != null) {
  271. switch(defValue) {
  272. case Const(f):
  273. return convertToType(inputType, makeFloat(f));
  274. default:
  275. throw "def value not handled yet";
  276. }
  277. }
  278. return null;
  279. }
  280. /**
  281. API used by ShaderGraphGenContext
  282. **/
  283. function initForNode(node: ShaderNode, nodeInputExprs: Array<TExpr>) {
  284. nodeInputInfo = node.getInputs();
  285. nodeOutputInfo = node.getOutputs();
  286. this.node = node;
  287. this.nodeInputExprs = nodeInputExprs;
  288. outputs.resize(0);
  289. genericTypes.resize(0);
  290. preview = null;
  291. for (inputId => input in nodeInputInfo) {
  292. switch(input.type) {
  293. case SgGeneric(id, constraint):
  294. genericTypes[id] = constraint(nodeInputExprs[inputId]?.t, genericTypes[id]);
  295. default:
  296. }
  297. }
  298. currentPreviewId = node.id + 1;
  299. }
  300. function finishNode() {
  301. if (nodeOutputInfo.length != outputs.length) {
  302. throw "Missing outputs for node";
  303. }
  304. if (preview != null) {
  305. addExpr(preview);
  306. }
  307. }
  308. var node : ShaderNode = null;
  309. var currentPreviewId: Int = -1;
  310. var expressions: Array<TExpr> = [];
  311. var outputs: Array<TExpr> = [];
  312. var preview : TExpr = null;
  313. var nodeOutputInfo: Array<OutputInfo>;
  314. var genericTypes: Array<Type> = [];
  315. var nodeInputExprs : Array<TExpr>;
  316. var nodeInputInfo : Array<InputInfo>;
  317. var globalVars: Map<String, ShaderGraph.ExternVarDef> = [];
  318. }