Eval.hx 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573
  1. package hxsl;
  2. using hxsl.Ast;
  3. /**
  4. Evaluator : will substitute some variables (usually constants) by their runtime value and will
  5. evaluate and reduce the expression, unroll loops, etc.
  6. **/
  7. class Eval {
  8. public var varMap : Map<TVar,TVar>;
  9. public var inlineCalls : Bool;
  10. public var unrollLoops : Bool;
  11. public var eliminateConditionals : Bool;
  12. var constants : Map<Int,TExprDef>;
  13. var funMap : Map<TVar,TFunction>;
  14. var curFun : TFunction;
  15. var mapped : Array<TVar> = [];
  16. public function new() {
  17. varMap = new Map();
  18. funMap = new Map();
  19. constants = new Map();
  20. }
  21. public function setConstant( v : TVar, c : Const ) {
  22. constants.set(v.id, TConst(c));
  23. }
  24. function mapVar( v : TVar, isLocal : Bool ) {
  25. var v2 = varMap.get(v);
  26. if( v2 != null )
  27. return v2;
  28. if( v.parent != null ) {
  29. mapVar(v.parent, isLocal); // always map parent first
  30. v2 = varMap.get(v);
  31. if( v2 != null )
  32. return v == v2 ? v2 : mapVar(v2, isLocal);
  33. }
  34. v2 = {
  35. id : v.type.match(TChannel(_)) ? v.id : Tools.allocVarId(),
  36. name : v.name,
  37. type : v.type,
  38. kind : v.kind,
  39. };
  40. if( v.parent != null ) v2.parent = mapVar(v.parent, isLocal);
  41. if( v.qualifiers != null ) v2.qualifiers = v.qualifiers.copy();
  42. varMap.set(v, v2);
  43. varMap.set(v2, v2); // make it safe to have multiple eval
  44. if (isLocal)
  45. mapped.push(v);
  46. switch( v2.type ) {
  47. case TStruct(vl):
  48. v2.type = TStruct([for( v in vl ) mapVar(v, isLocal)]);
  49. case TArray(t, SVar(vs)), TBuffer(t, SVar(vs), _):
  50. var c = constants.get(vs.id);
  51. if( c != null )
  52. switch( c ) {
  53. case TConst(CInt(v)):
  54. v2.type = switch( v2.type ) {
  55. case TArray(_): TArray(t, SConst(v));
  56. case TBuffer(_,_,kind): TBuffer(t, SConst(v), kind);
  57. default: throw "assert";
  58. };
  59. default:
  60. Error.t("Integer value expected for array size constant " + vs.name, null);
  61. }
  62. else {
  63. var vs2 = mapVar(vs, isLocal);
  64. v2.type = switch( v2.type ) {
  65. case TArray(_): TArray(t, SVar(vs2));
  66. case TBuffer(_,_,kind): TBuffer(t, SVar(vs2), kind);
  67. default: throw "assert";
  68. }
  69. }
  70. default:
  71. }
  72. return v2;
  73. }
  74. function checkTextureRec(t:Type) {
  75. if( t.isTexture() )
  76. return true;
  77. switch( t ) {
  78. case TStruct(vl):
  79. for( v in vl )
  80. if( checkTextureRec(v.type) )
  81. return true;
  82. return false;
  83. case TArray(t, _):
  84. return checkTextureRec(t);
  85. case TBuffer(_):
  86. return true;
  87. default:
  88. }
  89. return false;
  90. }
  91. function needsInline(f:TFunction) {
  92. for( a in f.args )
  93. if( checkTextureRec(a.type) )
  94. return true;
  95. return false;
  96. }
  97. public function eval( s : ShaderData ) : ShaderData {
  98. var funs = [];
  99. for( f in s.funs ) {
  100. var f2 : TFunction = {
  101. kind : f.kind,
  102. ref : mapVar(f.ref, false),
  103. args : [for( a in f.args ) mapVar(a, false)],
  104. ret : f.ret,
  105. expr : f.expr,
  106. };
  107. if( (f.kind == Helper && inlineCalls) || needsInline(f2) )
  108. funMap.set(f2.ref, f);
  109. else
  110. funs.push(f2);
  111. }
  112. for( i in 0...funs.length ) {
  113. curFun = funs[i];
  114. curFun.expr = evalExpr(curFun.expr,false);
  115. }
  116. return {
  117. name : s.name,
  118. vars : [for( v in s.vars ) mapVar(v, false)],
  119. funs : funs,
  120. };
  121. }
  122. var markReturn : Bool;
  123. function hasReturn( e : TExpr ) {
  124. markReturn = false;
  125. hasReturnLoop(e);
  126. return markReturn;
  127. }
  128. function hasReturnLoop( e : TExpr ) {
  129. switch( e.e ) {
  130. case TReturn(_):
  131. markReturn = true;
  132. default:
  133. if( !markReturn ) e.iter(hasReturnLoop);
  134. }
  135. }
  136. function handleReturn( e : TExpr, isFinal : Bool = false ) : TExpr {
  137. switch( e.e ) {
  138. case TReturn(v):
  139. if( !isFinal )
  140. Error.t("Cannot inline not final return", e.p);
  141. if( v == null )
  142. return { e : TBlock([]), t : TVoid, p : e.p };
  143. return handleReturn(v, true);
  144. case TBlock(el):
  145. var i = 0, last = el.length;
  146. var out = [];
  147. while( i < last ) {
  148. var e = el[i++];
  149. if( i == last )
  150. out.push(handleReturn(e, isFinal));
  151. else switch( e.e ) {
  152. case TIf(econd, eif, null) if( isFinal && hasReturn(eif) ):
  153. out.push(handleReturn( { e : TIf(econd, eif, { e : TBlock(el.slice(i)), t : e.t, p : e.p } ), t : e.t, p : e.p } ));
  154. break;
  155. case TReturn(e):
  156. out.push(handleReturn(e, isFinal));
  157. break;
  158. default:
  159. out.push(handleReturn(e));
  160. }
  161. }
  162. var t = if( isFinal ) (out.length == 0 ? TVoid : out[out.length - 1].t) else e.t;
  163. return { e : TBlock(out), t : t, p : e.p };
  164. case TParenthesis(v):
  165. var v = handleReturn(v, isFinal);
  166. return { e : TParenthesis(v), t : v.t, p : e.p };
  167. case TIf(cond, eif, eelse) if( eelse != null && isFinal ):
  168. var cond = handleReturn(cond);
  169. var eif = handleReturn(eif, isFinal);
  170. return { e : TIf(cond, eif, handleReturn(eelse, isFinal)), t : eif.t, p : e.p };
  171. default:
  172. return e.map(handleReturnDef);
  173. }
  174. }
  175. function handleReturnDef(e) {
  176. return handleReturn(e);
  177. }
  178. function evalCall( g : TGlobal, args : Array<TExpr>, oldArgs : Array<TExpr>, pos : Position ) {
  179. return switch( [g,args] ) {
  180. case [ToFloat, [ { e : TConst(CInt(i)) } ]]: TConst(CFloat(i));
  181. case [Trace, args]:
  182. for( a in args )
  183. haxe.Log.trace(Printer.toString(a), { fileName : #if macro haxe.macro.Context.getPosInfos(a.p).file #else a.p.file #end, lineNumber : 0, className : null, methodName : null });
  184. TBlock([]);
  185. case [Length, [{ e : TVar(v) }]]:
  186. switch( v.type ) {
  187. case TArray(_, SConst(v)):
  188. TConst(CInt(v));
  189. default:
  190. null;
  191. }
  192. case [ChannelRead|ChannelReadLod, _]:
  193. var i = switch( args[0].e ) { case TConst(CInt(i)): i; default: Error.t("Cannot eval complex channel " + Printer.toString(args[0],true)+" "+constantsToString(), pos); throw "assert"; };
  194. var channel = oldArgs[0];
  195. channel = { e : switch( channel.e ) {
  196. case TVar(v): TVar(mapVar(v, false));
  197. default: throw "assert";
  198. }, t : channel.t, p : channel.p };
  199. var count = switch( channel.t ) { case TChannel(i): i; default: throw "assert"; };
  200. var channelMode = hxsl.Channel.createByIndex(i & 7);
  201. var targs = [channel];
  202. for( i in 1...args.length )
  203. targs.push(args[i]);
  204. targs.push({ e : TConst(CInt(i >> 3)), t : TInt, p : pos });
  205. var tget = {
  206. e : TCall({ e : TGlobal(g), t : TVoid, p : pos }, targs),
  207. t : TVoid,
  208. p : pos,
  209. };
  210. switch( channelMode ) {
  211. case R, G, B, A:
  212. return TSwiz(tget, switch( [count,channelMode] ) {
  213. case [1,R]: [X];
  214. case [1,G]: [Y];
  215. case [1,B]: [Z];
  216. case [1,A]: [W];
  217. case [2,R]: [X,Y];
  218. case [2,G]: [Y,Z];
  219. case [2,B]: [Z,W];
  220. case [3,R]: [X,Y,Z];
  221. case [3,G]: [Y,Z,W];
  222. default: throw "Invalid channel value "+channelMode+" for "+count+" channels";
  223. });
  224. case Unknown:
  225. var zero = { e : TConst(CFloat(0.)), t : TFloat, p : pos };
  226. if( count == 1 )
  227. return zero.e;
  228. return TCall({ e : TGlobal([Vec2, Vec3, Vec4][count - 2]), t : TVoid, p : pos }, [zero]);
  229. case PackedFloat:
  230. return TCall({ e : TGlobal(Unpack), t:TVoid, p:pos}, [tget]);
  231. case PackedNormal:
  232. return TCall({ e : TGlobal(UnpackNormal), t:TVoid, p:pos}, [tget]);
  233. }
  234. default: null;
  235. }
  236. }
  237. function constantsToString() {
  238. return [for( c in constants.keys() ) c + " => " + Printer.toString({ e : constants.get(c), t : TVoid, p : null }, true)].toString();
  239. }
  240. function ifBlock( e : TExpr ) {
  241. if( e == null || !e.e.match(TIf(_)) )
  242. return e;
  243. return { e : TBlock([e]), t : e.t, p : e.p };
  244. }
  245. function evalExpr( e : TExpr, isVal = true ) : TExpr {
  246. var t = e.t;
  247. var d : TExprDef = switch( e.e ) {
  248. case TGlobal(_), TConst(_): e.e;
  249. case TVar(v):
  250. var c = constants.get(v.id);
  251. if( c != null )
  252. c;
  253. else {
  254. var v2 = mapVar(v, false);
  255. t = v2.type;
  256. TVar(v2);
  257. }
  258. case TVarDecl(v, init):
  259. TVarDecl(mapVar(v, true), init == null ? null : evalExpr(init));
  260. case TArray(e1, e2):
  261. var e1 = evalExpr(e1);
  262. var e2 = evalExpr(e2);
  263. switch( [e1.e, e2.e] ) {
  264. case [TArrayDecl(el),TConst(CInt(i))] if( i >= 0 && i < el.length ):
  265. el[i].e;
  266. default:
  267. switch( e1.t ) {
  268. case TArray(at, _), TBuffer(at,_,_): t = at;
  269. default:
  270. }
  271. TArray(e1, e2);
  272. }
  273. case TSwiz(e, r):
  274. TSwiz(evalExpr(e), r.copy());
  275. case TReturn(e):
  276. TReturn(e == null ? null : evalExpr(e));
  277. case TCall(c, eargs):
  278. var c = evalExpr(c);
  279. var args = [for( a in eargs ) evalExpr(a)];
  280. switch( c.e ) {
  281. case TGlobal(g):
  282. var v = evalCall(g, args, eargs, e.p);
  283. if( v != null ) v else TCall(c, args);
  284. case TVar(v) if( funMap.exists(v) ):
  285. // inline the function call
  286. var f = funMap.get(v);
  287. var outExprs = [], undo = [];
  288. for( i in 0...f.args.length ) {
  289. var v = f.args[i];
  290. var e = args[i];
  291. switch( e.e ) {
  292. case TConst(_), TVar({ kind : (Input|Param|Global) }):
  293. var old = constants.get(v.id);
  294. undo.push(function() old == null ? constants.remove(v.id) : constants.set(v.id, old));
  295. constants.set(v.id, e.e);
  296. default:
  297. var old = varMap.get(v);
  298. if( old == null )
  299. undo.push(function() varMap.remove(v));
  300. else {
  301. varMap.remove(v);
  302. undo.push(function() varMap.set(v, old));
  303. }
  304. var v2 = mapVar(v, false);
  305. outExprs.push( { e : TVarDecl(v2, e), t : TVoid, p : e.p } );
  306. }
  307. }
  308. var e = handleReturn(evalExpr(f.expr,false), true);
  309. for( u in undo ) u();
  310. switch( e.e ) {
  311. case TBlock(el):
  312. for( e in el )
  313. outExprs.push(e);
  314. default:
  315. outExprs.push(e);
  316. }
  317. TBlock(outExprs);
  318. case TVar(_):
  319. TCall(c, args);
  320. default:
  321. Error.t("Cannot eval non-static call expresssion '" + new Printer().exprString(c)+"'", c.p);
  322. }
  323. case TBlock(el):
  324. var index = mapped.length;
  325. var out = [];
  326. var last = el.length - 1;
  327. for( i in 0...el.length ) {
  328. var isVal = isVal && i == last;
  329. var e = evalExpr(el[i], isVal);
  330. switch( e.e ) {
  331. case TConst(_), TVar(_) if( !isVal ):
  332. default:
  333. out.push(e);
  334. }
  335. }
  336. // unmap previous vars
  337. while( mapped.length > index ) {
  338. var v = mapped.pop();
  339. var v2 = varMap.get(v);
  340. if (v2 != null ) {
  341. varMap.remove(v);
  342. varMap.remove(v2);
  343. }
  344. }
  345. if( out.length == 1 && curFun.kind != Init )
  346. out[0].e
  347. else
  348. TBlock(out);
  349. case TBinop(op, e1, e2):
  350. var e1 = evalExpr(e1);
  351. var e2 = evalExpr(e2);
  352. inline function fop(callb:Float->Float->Float) {
  353. return switch( [e1.e, e2.e] ) {
  354. case [TConst(CInt(a)), TConst(CInt(b))]:
  355. TConst(CInt(Std.int(callb(a, b))));
  356. case [TConst(CFloat(a)), TConst(CFloat(b))]:
  357. TConst(CFloat(callb(a, b)));
  358. default:
  359. TBinop(op, e1, e2);
  360. }
  361. }
  362. inline function iop(callb:Int->Int->Int) {
  363. return switch( [e1.e, e2.e] ) {
  364. case [TConst(CInt(a)), TConst(CInt(b))]:
  365. TConst(CInt(callb(a, b)));
  366. default:
  367. TBinop(op, e1, e2);
  368. }
  369. }
  370. inline function bop(callb:Bool->Bool->Bool,def) {
  371. return switch( [e1.e, e2.e] ) {
  372. case [TConst(CBool(a)), TConst(CBool(b))]:
  373. TConst(CBool(callb(a, b)));
  374. case [TConst(CBool(a)), _]:
  375. if( a == def )
  376. TConst(CBool(a));
  377. else
  378. e2.e;
  379. case [_, TConst(CBool(a))]:
  380. if( a == def )
  381. TConst(CBool(a)); // ignore e1 side effects ?
  382. else
  383. e1.e;
  384. default:
  385. TBinop(op, e1, e2);
  386. }
  387. }
  388. inline function compare(callb:Int->Bool) {
  389. return switch( [e1.e, e2.e] ) {
  390. case [TConst(CNull), TConst(CNull)]:
  391. TConst(CBool(callb(0)));
  392. case [TConst(_), TConst(CNull)]:
  393. TConst(CBool(callb(1)));
  394. case [TConst(CNull), TConst(_)]:
  395. TConst(CBool(callb(-1)));
  396. case [TConst(CBool(a)), TConst(CBool(b))]:
  397. TConst(CBool(callb(a == b ? 0 : 1)));
  398. case [TConst(CInt(a)), TConst(CInt(b))]:
  399. TConst(CBool(callb(a - b)));
  400. case [TConst(CFloat(a)), TConst(CFloat(b))]:
  401. TConst(CBool(callb(a > b ? 1 : (a == b) ? 0 : -1)));
  402. case [TConst(CString(a)), TConst(CString(b))]:
  403. TConst(CBool(callb(a > b ? 1 : (a == b) ? 0 : -1)));
  404. default:
  405. TBinop(op, e1, e2);
  406. }
  407. }
  408. switch( op ) {
  409. case OpAdd: fop(function(a, b) return a + b);
  410. case OpSub: fop(function(a, b) return a - b);
  411. case OpMult: fop(function(a, b) return a * b);
  412. case OpDiv: fop(function(a, b) return a / b);
  413. case OpMod: fop(function(a, b) return a % b);
  414. case OpXor: iop(function(a, b) return a ^ b);
  415. case OpOr: iop(function(a, b) return a | b);
  416. case OpAnd: iop(function(a, b) return a & b);
  417. case OpShr: iop(function(a, b) return a >> b);
  418. case OpUShr: iop(function(a, b) return a >>> b);
  419. case OpShl: iop(function(a, b) return a << b);
  420. case OpBoolAnd: bop(function(a, b) return a && b, false);
  421. case OpBoolOr: bop(function(a, b) return a || b, true);
  422. case OpEq: compare(function(x) return x == 0);
  423. case OpNotEq: compare(function(x) return x != 0);
  424. case OpGt: compare(function(x) return x > 0);
  425. case OpGte: compare(function(x) return x >= 0);
  426. case OpLt: compare(function(x) return x < 0);
  427. case OpLte: compare(function(x) return x <= 0);
  428. case OpInterval, OpAssign, OpAssignOp(_): TBinop(op, e1, e2);
  429. default: throw "assert";
  430. }
  431. case TUnop(op, e):
  432. var e = evalExpr(e);
  433. switch( e.e ) {
  434. case TConst(c):
  435. switch( [op, c] ) {
  436. case [OpNot, CBool(b)]: TConst(CBool(!b));
  437. case [OpNeg, CInt(i)]: TConst(CInt( -i));
  438. case [OpNeg, CFloat(f)]: TConst(CFloat( -f));
  439. default:
  440. TUnop(op, e);
  441. }
  442. default:
  443. TUnop(op, e);
  444. }
  445. case TParenthesis(e):
  446. var e = evalExpr(e, isVal);
  447. switch( e.e ) {
  448. case TConst(_): e.e;
  449. default: TParenthesis(e);
  450. }
  451. case TIf(econd, eif, eelse):
  452. var econd = evalExpr(econd);
  453. switch( econd.e ) {
  454. case TConst(CBool(b)): b ? evalExpr(eif, isVal).e : eelse == null ? TConst(CNull) : evalExpr(eelse, isVal).e;
  455. default:
  456. if( isVal && eelse != null && eliminateConditionals )
  457. TCall( { e : TGlobal(Mix), t : e.t, p : e.p }, [evalExpr(eelse,true), evalExpr(eif,true), { e : TCall( { e : TGlobal(ToFloat), t : TFun([]), p : econd.p }, [econd]), t : TFloat, p : e.p } ]);
  458. else {
  459. eif = evalExpr(eif, isVal);
  460. if( eelse != null ) {
  461. eelse = evalExpr(eelse,isVal);
  462. if( eelse.e.match(TConst(CNull)) ) eelse = null;
  463. }
  464. eif = ifBlock(eif);
  465. eelse = ifBlock(eelse);
  466. TIf(econd, eif, eelse);
  467. }
  468. }
  469. case TBreak:
  470. TBreak;
  471. case TContinue:
  472. TContinue;
  473. case TDiscard:
  474. TDiscard;
  475. case TFor(v, it, loop):
  476. var v2 = mapVar(v, true);
  477. var it = evalExpr(it);
  478. var e = switch( it.e ) {
  479. case TBinop(OpInterval, { e : TConst(CInt(start)) }, { e : TConst(CInt(len)) } ) if( unrollLoops ):
  480. var out = [];
  481. for( i in start...len ) {
  482. constants.set(v.id, TConst(CInt(i)));
  483. out.push(evalExpr(loop,false));
  484. }
  485. constants.remove(v.id);
  486. TBlock(out);
  487. default:
  488. TFor(v2, it, ifBlock(evalExpr(loop,false)));
  489. }
  490. varMap.remove(v);
  491. e;
  492. case TWhile(cond, loop, normalWhile):
  493. var cond = evalExpr(cond);
  494. var loop = evalExpr(loop, false);
  495. TWhile(cond, ifBlock(loop), normalWhile);
  496. case TSwitch(e, cases, def):
  497. var e = evalExpr(e);
  498. var cases = [for( c in cases ) { values : [for( v in c.values ) evalExpr(v)], expr : evalExpr(c.expr, isVal) }];
  499. var def = def == null ? null : evalExpr(def, isVal);
  500. var hasCase = false;
  501. switch( e.e ) {
  502. case TConst(c):
  503. switch( c ) {
  504. case CInt(val):
  505. for( c in cases ) {
  506. for( v in c.values )
  507. switch( v.e ) {
  508. case TConst(cst):
  509. switch( cst ) {
  510. case CInt(k) if( k == val ): return c.expr;
  511. case CFloat(k) if( k == val ): return c.expr;
  512. default:
  513. }
  514. default:
  515. hasCase = true;
  516. }
  517. }
  518. default:
  519. throw "Unsupported switch constant "+c;
  520. }
  521. default:
  522. hasCase = true;
  523. }
  524. if( hasCase )
  525. TSwitch(e, cases, def);
  526. else if( def == null )
  527. TBlock([]);
  528. else
  529. def.e;
  530. case TArrayDecl(el):
  531. TArrayDecl([for( e in el ) evalExpr(e)]);
  532. case TMeta(name, args, e):
  533. var e2;
  534. switch( name ) {
  535. case "unroll":
  536. var old = unrollLoops;
  537. unrollLoops = true;
  538. e2 = evalExpr(e, isVal);
  539. unrollLoops = false;
  540. default:
  541. e2 = evalExpr(e, isVal);
  542. }
  543. TMeta(name, args, e2);
  544. case TField(e, name):
  545. TField(evalExpr(e), name);
  546. case TSyntax(target, code, args):
  547. TSyntax(target, code, [for ( arg in args ) ({ e : evalExpr(arg.e), access : arg.access })]);
  548. };
  549. return { e : d, t : t, p : e.p }
  550. }
  551. }