Extend all FOLD rules to work on 64 bit integers.

Mike Pall committed 14 years ago · commit 42f9b38663
2 changed files with 131 additions and 32 deletions
  1. src/lj_iropt.h (+1 -0)
  2. src/lj_opt_fold.c (+130 -32)

+ 1 - 0
src/lj_iropt.h

@@ -105,6 +105,7 @@ enum {
 };
 
 #define INTFOLD(k)	((J->fold.ins.i = (k)), (TRef)KINTFOLD)
+#define INT64FOLD(k)	(lj_ir_kint64(J, (k)))
 #define CONDFOLD(cond)	((TRef)FAILFOLD + (TRef)(cond))
 #define LEFTFOLD	(J->fold.ins.op1)
 #define RIGHTFOLD	(J->fold.ins.op2)
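
INTFOLD stashes its 32 bit constant directly in the pending fold instruction (J->fold.ins.i) and returns the KINTFOLD sentinel, but a 64 bit constant does not fit in that 32 bit field, so INT64FOLD has to intern it at once via lj_ir_kint64 and return the resulting reference. A sketch of what a 32 bit immediate would lose (illustrative standalone C, assumes the usual two's complement truncation; not LuaJIT code):

#include <stdint.h>
#include <stdio.h>

int main(void)
{
  int64_t k = 0x123456789abcdef0LL;
  int32_t imm = (int32_t)k;  /* what a 32 bit immediate field would keep */
  printf("stored:   %08x\n", (unsigned)imm);    /* 9abcdef0: upper half lost */
  printf("survives: %d\n", (int64_t)imm == k);  /* 0 */
  return 0;
}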

+ 130 - 32
src/lj_opt_fold.c

@@ -291,6 +291,98 @@ LJFOLDF(kfold_intcomp0)
   return NEXTFOLD;
 }
 
+/* -- Constant folding for 64 bit integers -------------------------------- */
+
+static uint64_t kfold_int64arith(uint64_t k1, uint64_t k2, IROp op)
+{
+  switch (op) {
+  case IR_ADD: k1 += k2; break;
+  case IR_SUB: k1 -= k2; break;
+  case IR_MUL: k1 *= k2; break;
+  case IR_BAND: k1 &= k2; break;
+  case IR_BOR: k1 |= k2; break;
+  case IR_BXOR: k1 ^= k2; break;
+  default: lua_assert(0); break;
+  }
+  return k1;
+}
+
+LJFOLD(ADD KINT64 KINT64)
+LJFOLD(SUB KINT64 KINT64)
+LJFOLD(MUL KINT64 KINT64)
+LJFOLD(BAND KINT64 KINT64)
+LJFOLD(BOR KINT64 KINT64)
+LJFOLD(BXOR KINT64 KINT64)
+LJFOLDF(kfold_int64arith)
+{
+  return INT64FOLD(kfold_int64arith(ir_k64(fleft)->u64,
+				    ir_k64(fright)->u64, (IROp)fins->o));
+}
+
+LJFOLD(BSHL KINT64 KINT)
+LJFOLD(BSHR KINT64 KINT)
+LJFOLD(BSAR KINT64 KINT)
+LJFOLD(BROL KINT64 KINT)
+LJFOLD(BROR KINT64 KINT)
+LJFOLDF(kfold_int64shift)
+{
+  uint64_t k = ir_k64(fleft)->u64;
+  int32_t sh = (fright->i & 63);
+  switch ((IROp)fins->o) {
+  case IR_BSHL: k <<= sh; break;
+  case IR_BSHR: k >>= sh; break;
+  case IR_BSAR: k = (uint64_t)((int64_t)k >> sh); break;
+  case IR_BROL: k = lj_rol(k, sh); break;
+  case IR_BROR: k = lj_ror(k, sh); break;
+  default: lua_assert(0); break;
+  }
+  return INT64FOLD(k);
+}
+
+LJFOLD(BNOT KINT64)
+LJFOLDF(kfold_bnot64)
+{
+  return INT64FOLD(~ir_k64(fleft)->u64);
+}
+
+LJFOLD(BSWAP KINT64)
+LJFOLDF(kfold_bswap64)
+{
+  return INT64FOLD(lj_bswap64(ir_k64(fleft)->u64));
+}
+
+LJFOLD(LT KINT64 KINT64)
+LJFOLD(GE KINT64 KINT64)
+LJFOLD(LE KINT64 KINT64)
+LJFOLD(GT KINT64 KINT64)
+LJFOLD(ULT KINT64 KINT64)
+LJFOLD(UGE KINT64 KINT64)
+LJFOLD(ULE KINT64 KINT64)
+LJFOLD(UGT KINT64 KINT64)
+LJFOLDF(kfold_int64comp)
+{
+  uint64_t a = ir_k64(fleft)->u64, b = ir_k64(fright)->u64;
+  switch ((IROp)fins->o) {
+  case IR_LT: return CONDFOLD((int64_t)a < (int64_t)b);
+  case IR_GE: return CONDFOLD((int64_t)a >= (int64_t)b);
+  case IR_LE: return CONDFOLD((int64_t)a <= (int64_t)b);
+  case IR_GT: return CONDFOLD((int64_t)a > (int64_t)b);
+  case IR_ULT: return CONDFOLD(a < b);
+  case IR_UGE: return CONDFOLD(a >= b);
+  case IR_ULE: return CONDFOLD(a <= b);
+  case IR_UGT: return CONDFOLD(a > b);
+  default: lua_assert(0); return FAILFOLD;
+  }
+}
+
+LJFOLD(UGE any KINT64)
+LJFOLDF(kfold_int64comp0)
+{
+  if (ir_k64(fright)->u64 == 0)
+    return DROPFOLD;
+  return NEXTFOLD;
+}
+
 /* -- Constant folding for strings ---------------------------------------- */
 
 LJFOLD(SNEW KPTR KINT)
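
The arithmetic folds above rely on the wrap-around semantics of uint64_t, the shift folds mask the count to 0..63, and BSAR is an arithmetic (sign-propagating) shift; LT and ULT fold differently because a and b are reinterpreted as signed or unsigned. A standalone sketch of these semantics that compiles outside LuaJIT (rol64 is an assumed stand-in for lj_rol; illustrative, not part of the patch):

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

/* Stand-in for lj_rol on 64 bit operands (assumed rotate semantics). */
static uint64_t rol64(uint64_t k, int sh)
{
  return sh ? (k << sh) | (k >> (64 - sh)) : k;
}

int main(void)
{
  uint64_t k = 0x8000000000000001ULL;
  uint64_t a = (uint64_t)-1, b = 1;
  int sh = 65 & 63;  /* the rule masks the shift count, so 65 acts as 1 */
  assert(sh == 1);
  printf("BSHL: %016llx\n", (unsigned long long)(k << sh));
  printf("BSHR: %016llx\n", (unsigned long long)(k >> sh));
  printf("BSAR: %016llx\n", (unsigned long long)(uint64_t)((int64_t)k >> sh));
  printf("BROL: %016llx\n", (unsigned long long)rol64(k, sh));
  /* Why LT and ULT fold differently: -1 is the largest unsigned value. */
  printf("LT:  %d\n", (int64_t)a < (int64_t)b);  /* 1: -1 < 1 signed */
  printf("ULT: %d\n", a < b);                    /* 0: 2^64-1 > 1 unsigned */
  return 0;
}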
@@ -385,16 +477,16 @@ LJFOLDF(kfold_toi64_kint)
 {
   lua_assert(fins->op2 == IRTOINT_ZEXT64 || fins->op2 == IRTOINT_SEXT64);
   if (fins->op2 == IRTOINT_ZEXT64)
-    return lj_ir_kint64(J, (int64_t)(uint32_t)fleft->i);
+    return INT64FOLD((uint64_t)(uint32_t)fleft->i);
   else
-    return lj_ir_kint64(J, (int64_t)(int32_t)fleft->i);
+    return INT64FOLD((uint64_t)(int32_t)fleft->i);
 }
 
 LJFOLD(TOI64 KNUM any)
 LJFOLDF(kfold_toi64_knum)
 {
   lua_assert(fins->op2 == IRTOINT_TRUNCI64);
-  return lj_ir_kint64(J, (int64_t)knumleft);
+  return INT64FOLD((uint64_t)(int64_t)knumleft);
 }
 
 LJFOLD(TOSTR KNUM)
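
The double casts in kfold_toi64_kint select the extension: going through uint32_t zero-extends, going through int32_t sign-extends. A minimal check (illustrative, not part of the patch):

#include <stdint.h>
#include <stdio.h>

int main(void)
{
  int32_t i = -1;
  /* ZEXT64: via uint32_t, the upper 32 bits become zero. */
  printf("zext: %016llx\n", (unsigned long long)(uint64_t)(uint32_t)i);
  /* SEXT64: via int32_t, the sign bit is replicated. */
  printf("sext: %016llx\n", (unsigned long long)(uint64_t)(int32_t)i);
  return 0;  /* prints 00000000ffffffff and ffffffffffffffff */
}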
@@ -432,6 +524,8 @@ LJFOLD(EQ KNULL any)
 LJFOLD(NE KNULL any)
 LJFOLD(EQ KINT KINT)  /* Constants are unique, so same refs <==> same value. */
 LJFOLD(NE KINT KINT)
+LJFOLD(EQ KINT64 KINT64)
+LJFOLD(NE KINT64 KINT64)
 LJFOLD(EQ KGC KGC)
 LJFOLD(NE KGC KGC)
 LJFOLDF(kfold_kref)
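
IR constants are interned, so two KINT64 operands with the same value share one reference, and EQ/NE can fold to a reference comparison, exactly as the KINT comment above says. A toy model of that invariant (assumed simplified interning, not LuaJIT's actual constant chain):

#include <stdint.h>
#include <stdio.h>

/* Toy interning pool: same 64 bit value -> same reference (index). */
static int64_t pool[64];
static int npool;

static int intern_k64(int64_t k)  /* rough stand-in for lj_ir_kint64() */
{
  int i;
  for (i = 0; i < npool; i++)
    if (pool[i] == k) return i;
  pool[npool] = k;
  return npool++;
}

int main(void)
{
  int a = intern_k64(0x123456789abcdef0LL);
  int b = intern_k64(0x123456789abcdef0LL);
  int c = intern_k64(42);
  printf("a==b: %d (same value, same ref)\n", a == b);  /* 1 */
  printf("a==c: %d (different value)\n", a == c);       /* 0 */
  return 0;
}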
@@ -790,7 +884,7 @@ LJFOLD(MUL any KINT64)
 LJFOLDF(simplify_intmul_k64)
 {
   if (ir_kint64(fright)->u64 == 0)  /* i * 0 ==> 0 */
-    return lj_ir_kint64(J, 0);
+    return INT64FOLD(0);
   else if (ir_kint64(fright)->u64 < 0x80000000u)
     return simplify_intmul_k(J, (int32_t)ir_kint64(fright)->u64);
   return NEXTFOLD;
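
The fallback to simplify_intmul_k is only taken for constants below 0x80000000, where the (int32_t) truncation is lossless. A quick check of that bound (sketch, assumes the usual two's complement conversion):

#include <stdint.h>
#include <stdio.h>

int main(void)
{
  uint64_t k = 0x7fffffffULL;  /* largest value the rule narrows */
  printf("lossless: %d\n", (uint64_t)(int32_t)k == k);  /* 1 */
  k = 0x80000000ULL;           /* first value it must not narrow */
  printf("lossless: %d\n", (uint64_t)(int32_t)k == k);  /* 0: sign-extends */
  return 0;
}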
@@ -893,31 +987,40 @@ LJFOLDF(simplify_intsubaddadd_cancel)
 }
 
 LJFOLD(BAND any KINT)
+LJFOLD(BAND any KINT64)
 LJFOLDF(simplify_band_k)
 {
-  if (fright->i == 0)  /* i & 0 ==> 0 */
+  int64_t k = fright->o == IR_KINT ? (int64_t)fright->i :
+				     (int64_t)ir_k64(fright)->u64;
+  if (k == 0)  /* i & 0 ==> 0 */
     return RIGHTFOLD;
-  if (fright->i == -1)  /* i & -1 ==> i */
+  if (k == -1)  /* i & -1 ==> i */
     return LEFTFOLD;
   return NEXTFOLD;
 }
 
 LJFOLD(BOR any KINT)
+LJFOLD(BOR any KINT64)
 LJFOLDF(simplify_bor_k)
 {
-  if (fright->i == 0)  /* i | 0 ==> i */
+  int64_t k = fright->o == IR_KINT ? (int64_t)fright->i :
+				     (int64_t)ir_k64(fright)->u64;
+  if (k == 0)  /* i | 0 ==> i */
     return LEFTFOLD;
-  if (fright->i == -1)  /* i | -1 ==> -1 */
+  if (k == -1)  /* i | -1 ==> -1 */
     return RIGHTFOLD;
   return NEXTFOLD;
 }
 
 LJFOLD(BXOR any KINT)
+LJFOLD(BXOR any KINT64)
 LJFOLDF(simplify_bxor_k)
 {
-  if (fright->i == 0)  /* i xor 0 ==> i */
+  int64_t k = fright->o == IR_KINT ? (int64_t)fright->i :
+				     (int64_t)ir_k64(fright)->u64;
+  if (k == 0)  /* i xor 0 ==> i */
     return LEFTFOLD;
-  if (fright->i == -1) {  /* i xor -1 ==> ~i */
+  if (k == -1) {  /* i xor -1 ==> ~i */
     fins->o = IR_BNOT;
     fins->op2 = 0;
     return RETRYFOLD;
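
Reading the constant into a single int64_t works because -1 from a sign-extended 32 bit KINT and -1 from a KINT64 are the same all-ones pattern, so the identities hold at either width (sketch, not part of the patch):

#include <stdint.h>
#include <stdio.h>

int main(void)
{
  int64_t x = 0x0123456789abcdefLL;
  int64_t k32 = (int64_t)(int32_t)-1;  /* -1 taken from a 32 bit constant */
  int64_t k64 = (int64_t)-1;           /* -1 taken from a 64 bit constant */
  printf("widths agree: %d\n", k32 == k64);        /* 1 */
  printf("x & -1 == x:  %d\n", (x & k64) == x);    /* BAND identity */
  printf("x | -1 == -1: %d\n", (x | k64) == k64);  /* BOR identity */
  printf("x ^ -1 == ~x: %d\n", (x ^ k64) == ~x);   /* BXOR -> BNOT rewrite */
  return 0;
}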
@@ -976,9 +1079,13 @@ LJFOLDF(simplify_shift_andk)
 
 LJFOLD(BSHL KINT any)
 LJFOLD(BSHR KINT any)
+LJFOLD(BSHL KINT64 any)
+LJFOLD(BSHR KINT64 any)
 LJFOLDF(simplify_shift1_ki)
 {
-  if (fleft->i == 0)  /* 0 o i ==> 0 */
+  int64_t k = fleft->o == IR_KINT ? (int64_t)fleft->i :
+				    (int64_t)ir_k64(fleft)->u64;
+  if (k == 0)  /* 0 o i ==> 0 */
     return LEFTFOLD;
   return NEXTFOLD;
 }
@@ -986,29 +1093,15 @@ LJFOLDF(simplify_shift1_ki)
 LJFOLD(BSAR KINT any)
 LJFOLD(BROL KINT any)
 LJFOLD(BROR KINT any)
-LJFOLDF(simplify_shift2_ki)
-{
-  if (fleft->i == 0 || fleft->i == -1)  /* 0 o i ==> 0; -1 o i ==> -1 */
-    return LEFTFOLD;
-  return NEXTFOLD;
-}
-
-LJFOLD(BSHL KINT64 any)
-LJFOLD(BSHR KINT64 any)
-LJFOLDF(simplify_shift1_ki64)
-{
-  if (ir_kint64(fleft)->u64 == 0)  /* 0 o i ==> 0 */
-    return LEFTFOLD;
-  return NEXTFOLD;
-}
-
 LJFOLD(BSAR KINT64 any)
 LJFOLD(BROL KINT64 any)
 LJFOLD(BROR KINT64 any)
-LJFOLDF(simplify_shift2_ki64)
+LJFOLDF(simplify_shift2_ki)
 {
-  if (ir_kint64(fleft)->u64 == 0 || (int64_t)ir_kint64(fleft)->u64 == -1)
-    return LEFTFOLD;  /* 0 o i ==> 0; -1 o i ==> -1 */
+  int64_t k = fleft->o == IR_KINT ? (int64_t)fleft->i :
+				    (int64_t)ir_k64(fleft)->u64;
+  if (k == 0 || k == -1)  /* 0 o i ==> 0; -1 o i ==> -1 */
+    return LEFTFOLD;
   return NEXTFOLD;
 }
 
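The merged rule rests on 0 and -1 being fixed points of BSAR and the rotates for every shift count (sketch; assumes the usual arithmetic right shift for negative values):

#include <stdint.h>
#include <stdio.h>

static uint64_t ror64(uint64_t k, int sh)
{
  return sh ? (k >> sh) | (k << (64 - sh)) : k;
}

int main(void)
{
  int sh;
  for (sh = 0; sh < 64; sh++) {
    if ((((int64_t)0) >> sh) != 0) printf("BSAR 0 broken at %d\n", sh);
    if ((((int64_t)-1) >> sh) != -1) printf("BSAR -1 broken at %d\n", sh);
    if (ror64(0, sh) != 0) printf("BROR 0 broken at %d\n", sh);
    if (ror64((uint64_t)-1, sh) != (uint64_t)-1)
      printf("BROR -1 broken at %d\n", sh);
  }
  printf("done\n");  /* no other output: identities hold for all counts */
  return 0;
}
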
@@ -1035,11 +1128,16 @@ LJFOLDF(reassoc_intarith_k)
 }
 
 LJFOLD(ADD ADD KINT64)
+LJFOLD(MUL MUL KINT64)
+LJFOLD(BAND BAND KINT64)
+LJFOLD(BOR BOR KINT64)
+LJFOLD(BXOR BXOR KINT64)
 LJFOLDF(reassoc_intarith_k64)
 {
   IRIns *irk = IR(fleft->op2);
   if (irk->o == IR_KINT64) {
-    uint64_t k = ir_kint64(irk)->u64 + ir_kint64(fright)->u64;
+    uint64_t k = kfold_int64arith(ir_k64(irk)->u64,
+				  ir_k64(fright)->u64, (IROp)fins->o);
     PHIBARRIER(fleft);
     fins->op1 = fleft->op1;
     fins->op2 = (IRRef1)lj_ir_kint64(J, k);
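
Reassociating (i op k1) op k2 into i op (k1 op k2) is sound for ADD, MUL, BAND, BOR and BXOR because all of them are associative under 64 bit wrap-around, which is why the rule can reuse kfold_int64arith. A spot check (illustrative):

#include <stdint.h>
#include <stdio.h>

int main(void)
{
  /* The inner sum wraps past 2^64, yet reassociation still holds. */
  uint64_t i = 0xfffffffffffffff0ULL, k1 = 0x20, k2 = 0x30;
  printf("ADD reassoc: %d\n", (i + k1) + k2 == i + (k1 + k2));
  printf("MUL reassoc: %d\n", (i * k1) * k2 == i * (k1 * k2));
  printf("XOR reassoc: %d\n", ((i ^ k1) ^ k2) == (i ^ (k1 ^ k2)));
  return 0;  /* prints 1 for all three */
}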
@@ -1085,7 +1183,7 @@ LJFOLDF(reassoc_shift)
     int32_t k = (irk->i & mask) + (fright->i & mask);
     if (k > mask) {  /* Combined shift too wide? */
       if (fins->o == IR_BSHL || fins->o == IR_BSHR)
-	return mask == 31 ? INTFOLD(0) : lj_ir_kint64(J, 0);
+	return mask == 31 ? INTFOLD(0) : INT64FOLD(0);
       else if (fins->o == IR_BSAR)
 	k = mask;
       else
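
When the combined count exceeds the operand width, BSHL and BSHR can only produce 0, while BSAR saturates the count at the width minus one. A demonstration (sketch; note that a single C shift by 70 would be undefined behaviour, the rule folds the composition instead):

#include <stdint.h>
#include <stdio.h>

int main(void)
{
  uint64_t x = 0x8000000000000001ULL;
  /* (x << 40) << 30 composes a 70 bit shift: each step is legal and the
     result is always 0, so the fold can return the constant directly. */
  printf("composed BSHL:  %016llx\n", (unsigned long long)((x << 40) << 30));
  /* For BSAR, saturating the count at 63 gives the same all-sign result. */
  printf("composed BSAR:  %lld\n", (long long)((((int64_t)x) >> 40) >> 30));
  printf("saturated BSAR: %lld\n", (long long)(((int64_t)x) >> 63));
  return 0;  /* both BSAR lines print -1 */
}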