fold.c 11 KB


  1. #include "all.h"
  2. enum {
  3. Bot = -1, /* lattice bottom */
  4. Top = 0, /* lattice top (matches UNDEF) */
  5. };
  6. typedef struct Edge Edge;
  7. struct Edge {
  8. int dest;
  9. int dead;
  10. Edge *work;
  11. };
  12. static int *val;
  13. static Edge *flowrk, (*edge)[2];
  14. static Use **usewrk;
  15. static uint nuse;
  16. static int
  17. iscon(Con *c, int w, uint64_t k)
  18. {
  19. if (c->type != CBits)
  20. return 0;
  21. if (w)
  22. return (uint64_t)c->bits.i == k;
  23. else
  24. return (uint32_t)c->bits.i == (uint32_t)k;
  25. }
  26. static int
  27. latval(Ref r)
  28. {
  29. switch (rtype(r)) {
  30. case RTmp:
  31. return val[r.val];
  32. case RCon:
  33. return r.val;
  34. default:
  35. die("unreachable");
  36. }
  37. }
  38. static int
  39. latmerge(int v, int m)
  40. {
  41. return m == Top ? v : (v == Top || v == m) ? m : Bot;
  42. }
  43. static void
  44. update(int t, int m, Fn *fn)
  45. {
  46. Tmp *tmp;
  47. uint u;
  48. m = latmerge(val[t], m);
  49. if (m != val[t]) {
  50. tmp = &fn->tmp[t];
  51. for (u=0; u<tmp->nuse; u++) {
  52. vgrow(&usewrk, ++nuse);
  53. usewrk[nuse-1] = &tmp->use[u];
  54. }
  55. val[t] = m;
  56. }
  57. }
  58. static int
  59. deadedge(int s, int d)
  60. {
  61. Edge *e;
  62. e = edge[s];
  63. if (e[0].dest == d && !e[0].dead)
  64. return 0;
  65. if (e[1].dest == d && !e[1].dead)
  66. return 0;
  67. return 1;
  68. }
  69. static void
  70. visitphi(Phi *p, int n, Fn *fn)
  71. {
  72. int v;
  73. uint a;
  74. v = Top;
  75. for (a=0; a<p->narg; a++)
  76. if (!deadedge(p->blk[a]->id, n))
  77. v = latmerge(v, latval(p->arg[a]));
  78. update(p->to.val, v, fn);
  79. }
  80. static int opfold(int, int, Con *, Con *, Fn *);
  81. static void
  82. visitins(Ins *i, Fn *fn)
  83. {
  84. int v, l, r;
  85. if (rtype(i->to) != RTmp)
  86. return;
  87. if (optab[i->op].canfold) {
  88. l = latval(i->arg[0]);
  89. if (!req(i->arg[1], R))
  90. r = latval(i->arg[1]);
  91. else
  92. r = CON_Z.val;
  93. if (l == Bot || r == Bot)
  94. v = Bot;
  95. else if (l == Top || r == Top)
  96. v = Top;
  97. else
  98. v = opfold(i->op, i->cls, &fn->con[l], &fn->con[r], fn);
  99. } else
  100. v = Bot;
  101. /* fprintf(stderr, "\nvisiting %s (%p)", optab[i->op].name, (void *)i); */
  102. update(i->to.val, v, fn);
  103. }
  104. static void
  105. visitjmp(Blk *b, int n, Fn *fn)
  106. {
  107. int l;
  108. switch (b->jmp.type) {
  109. case Jjnz:
  110. l = latval(b->jmp.arg);
  111. if (l == Bot) {
  112. edge[n][1].work = flowrk;
  113. edge[n][0].work = &edge[n][1];
  114. flowrk = &edge[n][0];
  115. }
  116. else if (iscon(&fn->con[l], 0, 0)) {
  117. assert(edge[n][0].dead);
  118. edge[n][1].work = flowrk;
  119. flowrk = &edge[n][1];
  120. }
  121. else {
  122. assert(edge[n][1].dead);
  123. edge[n][0].work = flowrk;
  124. flowrk = &edge[n][0];
  125. }
  126. break;
  127. case Jjmp:
  128. edge[n][0].work = flowrk;
  129. flowrk = &edge[n][0];
  130. break;
  131. case Jhlt:
  132. break;
  133. default:
  134. if (isret(b->jmp.type))
  135. break;
  136. die("unreachable");
  137. }
  138. }
  139. static void
  140. initedge(Edge *e, Blk *s)
  141. {
  142. if (s)
  143. e->dest = s->id;
  144. else
  145. e->dest = -1;
  146. e->dead = 1;
  147. e->work = 0;
  148. }
  149. static int
  150. renref(Ref *r)
  151. {
  152. int l;
  153. if (rtype(*r) == RTmp)
  154. if ((l=val[r->val]) != Bot) {
  155. *r = CON(l);
  156. return 1;
  157. }
  158. return 0;
  159. }
  160. /* require rpo, use, pred */
  161. void
  162. fold(Fn *fn)
  163. {
  164. Edge *e, start;
  165. Use *u;
  166. Blk *b, **pb;
  167. Phi *p, **pp;
  168. Ins *i;
  169. int t, d;
  170. uint n, a;
  171. val = emalloc(fn->ntmp * sizeof val[0]);
  172. edge = emalloc(fn->nblk * sizeof edge[0]);
  173. usewrk = vnew(0, sizeof usewrk[0], PHeap);
  174. for (t=0; t<fn->ntmp; t++)
  175. val[t] = Top;
  176. for (n=0; n<fn->nblk; n++) {
  177. b = fn->rpo[n];
  178. b->visit = 0;
  179. initedge(&edge[n][0], b->s1);
  180. initedge(&edge[n][1], b->s2);
  181. }
  182. initedge(&start, fn->start);
  183. flowrk = &start;
  184. nuse = 0;
  185. /* 1. find out constants and dead cfg edges */
  186. for (;;) {
  187. e = flowrk;
  188. if (e) {
  189. flowrk = e->work;
  190. e->work = 0;
  191. if (e->dest == -1 || !e->dead)
  192. continue;
  193. e->dead = 0;
  194. n = e->dest;
  195. b = fn->rpo[n];
  196. for (p=b->phi; p; p=p->link)
  197. visitphi(p, n, fn);
  198. if (b->visit == 0) {
  199. for (i=b->ins; i<&b->ins[b->nins]; i++)
  200. visitins(i, fn);
  201. visitjmp(b, n, fn);
  202. }
  203. b->visit++;
  204. assert(b->jmp.type != Jjmp
  205. || !edge[n][0].dead
  206. || flowrk == &edge[n][0]);
  207. }
  208. else if (nuse) {
  209. u = usewrk[--nuse];
  210. n = u->bid;
  211. b = fn->rpo[n];
  212. if (b->visit == 0)
  213. continue;
  214. switch (u->type) {
  215. case UPhi:
  216. visitphi(u->u.phi, u->bid, fn);
  217. break;
  218. case UIns:
  219. visitins(u->u.ins, fn);
  220. break;
  221. case UJmp:
  222. visitjmp(b, n, fn);
  223. break;
  224. default:
  225. die("unreachable");
  226. }
  227. }
  228. else
  229. break;
  230. }
  231. if (debug['F']) {
  232. fprintf(stderr, "\n> SCCP findings:");
  233. for (t=Tmp0; t<fn->ntmp; t++) {
  234. if (val[t] == Bot)
  235. continue;
  236. fprintf(stderr, "\n%10s: ", fn->tmp[t].name);
  237. if (val[t] == Top)
  238. fprintf(stderr, "Top");
  239. else
  240. printref(CON(val[t]), fn, stderr);
  241. }
  242. fprintf(stderr, "\n dead code: ");
  243. }
  244. /* 2. trim dead code, replace constants */
  245. d = 0;
  246. for (pb=&fn->start; (b=*pb);) {
  247. if (b->visit == 0) {
  248. d = 1;
  249. if (debug['F'])
  250. fprintf(stderr, "%s ", b->name);
  251. edgedel(b, &b->s1);
  252. edgedel(b, &b->s2);
  253. *pb = b->link;
  254. continue;
  255. }
  256. for (pp=&b->phi; (p=*pp);)
  257. if (val[p->to.val] != Bot)
  258. *pp = p->link;
  259. else {
  260. for (a=0; a<p->narg; a++)
  261. if (!deadedge(p->blk[a]->id, b->id))
  262. renref(&p->arg[a]);
  263. pp = &p->link;
  264. }
  265. for (i=b->ins; i<&b->ins[b->nins]; i++)
  266. if (renref(&i->to))
  267. *i = (Ins){.op = Onop};
  268. else {
  269. for (n=0; n<2; n++)
  270. renref(&i->arg[n]);
  271. if (isstore(i->op))
  272. if (req(i->arg[0], UNDEF))
  273. *i = (Ins){.op = Onop};
  274. }
  275. renref(&b->jmp.arg);
  276. if (b->jmp.type == Jjnz && rtype(b->jmp.arg) == RCon) {
  277. if (iscon(&fn->con[b->jmp.arg.val], 0, 0)) {
  278. edgedel(b, &b->s1);
  279. b->s1 = b->s2;
  280. b->s2 = 0;
  281. } else
  282. edgedel(b, &b->s2);
  283. b->jmp.type = Jjmp;
  284. b->jmp.arg = R;
  285. }
  286. pb = &b->link;
  287. }
  288. if (debug['F']) {
  289. if (!d)
  290. fprintf(stderr, "(none)");
  291. fprintf(stderr, "\n\n> After constant folding:\n");
  292. printfn(fn, stderr);
  293. }
  294. free(val);
  295. free(edge);
  296. vfree(usewrk);
  297. }
/* boring folding code */

/* Fold the integer-class operation op over the constants cl and cr.
 * For single-operand instructions, cr is the CON_Z constant supplied
 * by visitins.  w is nonzero for a 64-bit (Kl) operation and zero for
 * a 32-bit (Kw) one; for Kw the caller (opfold) masks the result to
 * its low 32 bits.
 *
 * On success the folded constant is stored in *res and 0 is returned.
 * 1 is returned, leaving *res untouched, when the operation cannot be
 * folded:
 *  - address arithmetic other than addr+bits, bits+addr, addr-bits,
 *    or the difference of two addresses with the same symbol;
 *  - division or remainder by zero;
 *  - signed MIN / -1 or MIN % -1 (overflow).
 */
static int
foldint(Con *res, int op, int w, Con *cl, Con *cr)
{
	/* operand bits, viewable as signed, unsigned, float or double;
	 * the float views serve the float-compare cases below */
	union {
		int64_t s;
		uint64_t u;
		float fs;
		double fd;
	} l, r;
	uint64_t x;
	Sym sym;
	int typ;
	memset(&sym, 0, sizeof sym);
	typ = CBits;
	/* NOTE(review): for CAddr constants bits.i presumably holds the
	 * offset from the symbol, so add/sub below combine offsets;
	 * confirm in all.h */
	l.s = cl->bits.i;
	r.s = cr->bits.i;
	if (op == Oadd) {
		if (cl->type == CAddr) {
			if (cr->type == CAddr)
				return 1;
			typ = CAddr;
			sym = cl->sym;
		}
		else if (cr->type == CAddr) {
			typ = CAddr;
			sym = cr->sym;
		}
	}
	else if (op == Osub) {
		if (cl->type == CAddr) {
			if (cr->type != CAddr) {
				typ = CAddr;
				sym = cl->sym;
			} else if (!symeq(cl->sym, cr->sym))
				return 1;
			/* same symbol on both sides: the result is a plain
			 * CBits offset difference */
		}
		else if (cr->type == CAddr)
			return 1;
	}
	else if (cl->type == CAddr || cr->type == CAddr)
		/* no other operation accepts an address operand */
		return 1;
	if (op == Odiv || op == Orem || op == Oudiv || op == Ourem) {
		/* leave traps (division by zero, signed overflow) to run time */
		if (iscon(cr, w, 0))
			return 1;
		if (op == Odiv || op == Orem) {
			x = w ? INT64_MIN : INT32_MIN;
			if (iscon(cr, w, -1))
			if (iscon(cl, w, x))
				return 1;
		}
	}
	switch (op) {
	case Oadd: x = l.u + r.u; break;
	case Osub: x = l.u - r.u; break;
	case Oneg: x = -l.u; break;
	case Odiv: x = w ? l.s / r.s : (int32_t)l.s / (int32_t)r.s; break;
	case Orem: x = w ? l.s % r.s : (int32_t)l.s % (int32_t)r.s; break;
	case Oudiv: x = w ? l.u / r.u : (uint32_t)l.u / (uint32_t)r.u; break;
	case Ourem: x = w ? l.u % r.u : (uint32_t)l.u % (uint32_t)r.u; break;
	case Omul: x = l.u * r.u; break;
	case Oand: x = l.u & r.u; break;
	case Oor: x = l.u | r.u; break;
	case Oxor: x = l.u ^ r.u; break;
	/* shift counts are taken modulo the operation width (31 or 63) */
	case Osar: x = (w ? l.s : (int32_t)l.s) >> (r.u & (31|w<<5)); break;
	case Oshr: x = (w ? l.u : (uint32_t)l.u) >> (r.u & (31|w<<5)); break;
	case Oshl: x = l.u << (r.u & (31|w<<5)); break;
	case Oextsb: x = (int8_t)l.u; break;
	case Oextub: x = (uint8_t)l.u; break;
	case Oextsh: x = (int16_t)l.u; break;
	case Oextuh: x = (uint16_t)l.u; break;
	case Oextsw: x = (int32_t)l.u; break;
	case Oextuw: x = (uint32_t)l.u; break;
	/* NOTE(review): C leaves float-to-integer conversion of NaN or
	 * out-of-range values undefined; these four cases do not guard
	 * against that */
	case Ostosi: x = w ? (int64_t)cl->bits.s : (int32_t)cl->bits.s; break;
	case Ostoui: x = w ? (uint64_t)cl->bits.s : (uint32_t)cl->bits.s; break;
	case Odtosi: x = w ? (int64_t)cl->bits.d : (int32_t)cl->bits.d; break;
	case Odtoui: x = w ? (uint64_t)cl->bits.d : (uint32_t)cl->bits.d; break;
	case Ocast:
		x = l.u;
		/* NOTE(review): looks unreachable, since the prelude above
		 * already returned 1 for any op other than add/sub with a
		 * CAddr operand */
		if (cl->type == CAddr) {
			typ = CAddr;
			sym = cl->sym;
		}
		break;
	default:
		if (Ocmpw <= op && op <= Ocmpl1) {
			/* word compares sign-extend both operands, which keeps
			 * both signed and unsigned 32-bit ordering intact; long
			 * compares are remapped onto the word opcodes.  This
			 * relies on the long compare opcodes being laid out in
			 * the same order as the word ones (presumably guaranteed
			 * by the op table) */
			if (op <= Ocmpw1) {
				l.u = (int32_t)l.u;
				r.u = (int32_t)r.u;
			} else
				op -= Ocmpl - Ocmpw;
			switch (op - Ocmpw) {
			case Ciule: x = l.u <= r.u; break;
			case Ciult: x = l.u < r.u; break;
			case Cisle: x = l.s <= r.s; break;
			case Cislt: x = l.s < r.s; break;
			case Cisgt: x = l.s > r.s; break;
			case Cisge: x = l.s >= r.s; break;
			case Ciugt: x = l.u > r.u; break;
			case Ciuge: x = l.u >= r.u; break;
			case Cieq: x = l.u == r.u; break;
			case Cine: x = l.u != r.u; break;
			default: die("unreachable");
			}
		}
		else if (Ocmps <= op && op <= Ocmps1) {
			/* single-precision compare read through the union;
			 * Cfo is "neither operand is NaN", Cfuo its negation */
			switch (op - Ocmps) {
			case Cfle: x = l.fs <= r.fs; break;
			case Cflt: x = l.fs < r.fs; break;
			case Cfgt: x = l.fs > r.fs; break;
			case Cfge: x = l.fs >= r.fs; break;
			case Cfne: x = l.fs != r.fs; break;
			case Cfeq: x = l.fs == r.fs; break;
			case Cfo: x = l.fs < r.fs || l.fs >= r.fs; break;
			case Cfuo: x = !(l.fs < r.fs || l.fs >= r.fs); break;
			default: die("unreachable");
			}
		}
		else if (Ocmpd <= op && op <= Ocmpd1) {
			/* double-precision counterpart of the block above */
			switch (op - Ocmpd) {
			case Cfle: x = l.fd <= r.fd; break;
			case Cflt: x = l.fd < r.fd; break;
			case Cfgt: x = l.fd > r.fd; break;
			case Cfge: x = l.fd >= r.fd; break;
			case Cfne: x = l.fd != r.fd; break;
			case Cfeq: x = l.fd == r.fd; break;
			case Cfo: x = l.fd < r.fd || l.fd >= r.fd; break;
			case Cfuo: x = !(l.fd < r.fd || l.fd >= r.fd); break;
			default: die("unreachable");
			}
		}
		else
			die("unreachable");
	}
	*res = (Con){.type=typ, .sym=sym, .bits={.i=x}};
	return 0;
}
  435. static void
  436. foldflt(Con *res, int op, int w, Con *cl, Con *cr)
  437. {
  438. float xs, ls, rs;
  439. double xd, ld, rd;
  440. if (cl->type != CBits || cr->type != CBits)
  441. err("invalid address operand for '%s'", optab[op].name);
  442. *res = (Con){.type = CBits};
  443. memset(&res->bits, 0, sizeof(res->bits));
  444. if (w) {
  445. ld = cl->bits.d;
  446. rd = cr->bits.d;
  447. switch (op) {
  448. case Oadd: xd = ld + rd; break;
  449. case Osub: xd = ld - rd; break;
  450. case Oneg: xd = -ld; break;
  451. case Odiv: xd = ld / rd; break;
  452. case Omul: xd = ld * rd; break;
  453. case Oswtof: xd = (int32_t)cl->bits.i; break;
  454. case Ouwtof: xd = (uint32_t)cl->bits.i; break;
  455. case Osltof: xd = (int64_t)cl->bits.i; break;
  456. case Oultof: xd = (uint64_t)cl->bits.i; break;
  457. case Oexts: xd = cl->bits.s; break;
  458. case Ocast: xd = ld; break;
  459. default: die("unreachable");
  460. }
  461. res->bits.d = xd;
  462. res->flt = 2;
  463. } else {
  464. ls = cl->bits.s;
  465. rs = cr->bits.s;
  466. switch (op) {
  467. case Oadd: xs = ls + rs; break;
  468. case Osub: xs = ls - rs; break;
  469. case Oneg: xs = -ls; break;
  470. case Odiv: xs = ls / rs; break;
  471. case Omul: xs = ls * rs; break;
  472. case Oswtof: xs = (int32_t)cl->bits.i; break;
  473. case Ouwtof: xs = (uint32_t)cl->bits.i; break;
  474. case Osltof: xs = (int64_t)cl->bits.i; break;
  475. case Oultof: xs = (uint64_t)cl->bits.i; break;
  476. case Otruncd: xs = cl->bits.d; break;
  477. case Ocast: xs = ls; break;
  478. default: die("unreachable");
  479. }
  480. res->bits.s = xs;
  481. res->flt = 1;
  482. }
  483. }
  484. static int
  485. opfold(int op, int cls, Con *cl, Con *cr, Fn *fn)
  486. {
  487. Ref r;
  488. Con c;
  489. if (cls == Kw || cls == Kl) {
  490. if (foldint(&c, op, cls == Kl, cl, cr))
  491. return Bot;
  492. } else
  493. foldflt(&c, op, cls == Kd, cl, cr);
  494. if (!KWIDE(cls))
  495. c.bits.i &= 0xffffffff;
  496. r = newcon(&c, fn);
  497. assert(!(cls == Ks || cls == Kd) || c.flt);
  498. return r.val;
  499. }