Explorar o código

constant folding and API arithmetic with integers

Roberto Ierusalimschy %!s(int64=12) %!d(string=hai) anos
pai
achega
67532d5a10
Modificáronse 4 ficheiros con 83 adicións e 31 borrados
  1. 4 11
      lapi.c
  2. 31 16
      lcode.c
  3. 45 2
      lobject.c
  4. 3 2
      lobject.h

+ 4 - 11
lapi.c

@@ -1,5 +1,5 @@
 /*
 /*
-** $Id: lapi.c,v 2.177 2013/04/26 19:51:17 roberto Exp roberto $
+** $Id: lapi.c,v 2.178 2013/04/29 17:12:50 roberto Exp roberto $
 ** Lua API
 ** Lua API
 ** See Copyright Notice in lua.h
 ** See Copyright Notice in lua.h
 */
 */
@@ -297,9 +297,6 @@ LUA_API int lua_rawequal (lua_State *L, int index1, int index2) {
 
 
 
 
 LUA_API void lua_arith (lua_State *L, int op) {
 LUA_API void lua_arith (lua_State *L, int op) {
-  StkId o1;  /* 1st operand */
-  StkId o2;  /* 2nd operand */
-  lua_Number n1, n2;
   lua_lock(L);
   lua_lock(L);
   if (op != LUA_OPUNM) /* all other operations expect two operands */
   if (op != LUA_OPUNM) /* all other operations expect two operands */
     api_checknelems(L, 2);
     api_checknelems(L, 2);
@@ -308,13 +305,9 @@ LUA_API void lua_arith (lua_State *L, int op) {
     setobjs2s(L, L->top, L->top - 1);
     setobjs2s(L, L->top, L->top - 1);
     L->top++;
     L->top++;
   }
   }
-  o1 = L->top - 2;
-  o2 = L->top - 1;
-  if (tonumber(o1, &n1) && tonumber(o2, &n2)) {
-    setnvalue(o1, luaO_numarith(op, n1, n2));
-  }
-  else luaT_trybinTM(L, o1, o2, o1, cast(TMS, op - LUA_OPADD + TM_ADD));
-  L->top--;
+  /* first operand at top - 2, second at top - 1; result go to top - 2 */
+  luaO_arith(L, op, L->top - 2, L->top - 1, L->top - 2);
+  L->top--;  /* remove second operand */
   lua_unlock(L);
   lua_unlock(L);
 }
 }
 
 

+ 31 - 16
lcode.c

@@ -1,5 +1,5 @@
 /*
 /*
-** $Id: lcode.c,v 2.66 2013/04/26 13:07:53 roberto Exp roberto $
+** $Id: lcode.c,v 2.67 2013/04/29 16:57:48 roberto Exp roberto $
 ** Code generator for Lua
 ** Code generator for Lua
 ** See Copyright Notice in lua.h
 ** See Copyright Notice in lua.h
 */
 */
@@ -29,8 +29,18 @@
 #define hasjumps(e)	((e)->t != (e)->f)
 #define hasjumps(e)	((e)->t != (e)->f)
 
 
 
 
-static int isnumeral(expdesc *e) {
-  return (e->k == VKFLT && e->t == NO_JUMP && e->f == NO_JUMP);
+static int tonumeral(expdesc *e, TValue *v) {
+  if (e->t != NO_JUMP || e->f != NO_JUMP)
+    return 0;  /* not a numeral */
+  switch (e->k) {
+    case VKINT:
+      if (v) setivalue(v, e->u.ival);
+      return 1;
+    case VKFLT:
+      if (v) setnvalue(v, e->u.nval);
+      return 1;
+    default: return 0;
+  }
 }
 }
 
 
 
 
@@ -730,21 +740,28 @@ void luaK_indexed (FuncState *fs, expdesc *t, expdesc *k) {
 
 
 
 
 static int constfolding (OpCode op, expdesc *e1, expdesc *e2) {
 static int constfolding (OpCode op, expdesc *e1, expdesc *e2) {
-  lua_Number r;
-  if (!isnumeral(e1) || !isnumeral(e2)) return 0;
-  if ((op == OP_DIV || op == OP_IDIV || op == OP_MOD) && e2->u.nval == 0)
-    return 0;  /* do not attempt to divide by 0 */
-  r = luaO_numarith(op - OP_ADD + LUA_OPADD, e1->u.nval, e2->u.nval);
-  e1->u.nval = r;
+  TValue v1, v2, res;
+  lua_Integer i2;
+  if (!tonumeral(e1, &v1) || !tonumeral(e2, &v2))
+    return 0;
+  if ((op == OP_IDIV || op == OP_MOD) && tointeger(&v2, &i2) && i2 == 0)
+    return 0;  /* avoid division by 0 at compile time */
+  luaO_arith(NULL, op - OP_ADD + LUA_OPADD, &v1, &v2, &res);
+  if (ttisinteger(&res)) {
+    e1->k = VKINT;
+    e1->u.ival = ivalue(&res);
+  }
+  else {
+    e1->k = VKFLT;
+    e1->u.nval = fltvalue(&res);
+  }
   return 1;
   return 1;
 }
 }
 
 
 
 
 static void codearith (FuncState *fs, OpCode op,
 static void codearith (FuncState *fs, OpCode op,
                        expdesc *e1, expdesc *e2, int line) {
                        expdesc *e1, expdesc *e2, int line) {
-  if (constfolding(op, e1, e2))
-    return;
-  else {
+  if (!constfolding(op, e1, e2)) {  /* could not fold operation? */
     int o2 = (op != OP_UNM && op != OP_LEN) ? luaK_exp2RK(fs, e2) : 0;
     int o2 = (op != OP_UNM && op != OP_LEN) ? luaK_exp2RK(fs, e2) : 0;
     int o1 = luaK_exp2RK(fs, e1);
     int o1 = luaK_exp2RK(fs, e1);
     if (o1 > o2) {
     if (o1 > o2) {
@@ -783,9 +800,7 @@ void luaK_prefix (FuncState *fs, UnOpr op, expdesc *e, int line) {
   e2.t = e2.f = NO_JUMP; e2.k = VKFLT; e2.u.nval = 0;
   e2.t = e2.f = NO_JUMP; e2.k = VKFLT; e2.u.nval = 0;
   switch (op) {
   switch (op) {
     case OPR_MINUS: {
     case OPR_MINUS: {
-      if (isnumeral(e))  /* minus constant? */
-        e->u.nval = luai_numunm(NULL, e->u.nval);  /* fold it */
-      else {
+      if (!constfolding(OP_UNM, e, e)) {  /* cannot fold it? */
         luaK_exp2anyreg(fs, e);
         luaK_exp2anyreg(fs, e);
         codearith(fs, OP_UNM, e, &e2, line);
         codearith(fs, OP_UNM, e, &e2, line);
       }
       }
@@ -819,7 +834,7 @@ void luaK_infix (FuncState *fs, BinOpr op, expdesc *v) {
     case OPR_ADD: case OPR_SUB:
     case OPR_ADD: case OPR_SUB:
     case OPR_MUL: case OPR_DIV: case OPR_IDIV:
     case OPR_MUL: case OPR_DIV: case OPR_IDIV:
     case OPR_MOD: case OPR_POW: {
     case OPR_MOD: case OPR_POW: {
-      if (!isnumeral(v)) luaK_exp2RK(fs, v);
+      if (!tonumeral(v, NULL)) luaK_exp2RK(fs, v);
       break;
       break;
     }
     }
     default: {
     default: {

+ 45 - 2
lobject.c

@@ -1,5 +1,5 @@
 /*
 /*
-** $Id: lobject.c,v 2.60 2013/04/25 13:53:13 roberto Exp roberto $
+** $Id: lobject.c,v 2.61 2013/04/29 16:57:28 roberto Exp roberto $
 ** Some generic functions over Lua objects
 ** Some generic functions over Lua objects
 ** See Copyright Notice in lua.h
 ** See Copyright Notice in lua.h
 */
 */
@@ -70,7 +70,21 @@ int luaO_ceillog2 (unsigned int x) {
 }
 }
 
 
 
 
-lua_Number luaO_numarith (int op, lua_Number v1, lua_Number v2) {
+static lua_Integer intarith (lua_State *L, int op, lua_Integer v1,
+                                                   lua_Integer v2) {
+  switch (op) {
+    case LUA_OPADD: return intop(+, v1, v2);
+    case LUA_OPSUB:return intop(-, v1, v2);
+    case LUA_OPMUL:return intop(*, v1, v2);
+    case LUA_OPMOD: return luaV_mod(L, v1, v2);
+    case LUA_OPPOW: return luaV_pow(v1, v2);
+    case LUA_OPUNM: return -v1;
+    default: lua_assert(0); return 0;
+  }
+}
+
+
+static lua_Number numarith (int op, lua_Number v1, lua_Number v2) {
   switch (op) {
   switch (op) {
     case LUA_OPADD: return luai_numadd(NULL, v1, v2);
     case LUA_OPADD: return luai_numadd(NULL, v1, v2);
     case LUA_OPSUB: return luai_numsub(NULL, v1, v2);
     case LUA_OPSUB: return luai_numsub(NULL, v1, v2);
@@ -84,6 +98,35 @@ lua_Number luaO_numarith (int op, lua_Number v1, lua_Number v2) {
 }
 }
 
 
 
 
+void luaO_arith (lua_State *L, int op, const TValue *p1, const TValue *p2,
+                 TValue *res) {
+  if (op == LUA_OPIDIV) {  /* operates only on integers */
+    lua_Integer i1; lua_Integer i2;
+    if (tointeger(p1, &i1) && tointeger(p2, &i2)) {
+      setivalue(res, luaV_div(L, i1, i2));
+      return;
+    }
+    /* else go to the end */
+  }
+  else {  /* other operations */
+    lua_Number n1; lua_Number n2;
+    if (ttisinteger(p1) && ttisinteger(p2) && op != LUA_OPDIV &&
+        (op != LUA_OPPOW || ivalue(p2) >= 0)) {
+      setivalue(res, intarith(L, op, ivalue(p1), ivalue(p2)));
+      return;
+    }
+    else if (tonumber(p1, &n1) && tonumber(p2, &n2)) {
+      setnvalue(res, numarith(op, n1, n2));
+      return;
+    }
+    /* else go to the end */
+  }
+  /* could not perform raw operation; try metmethod */
+  lua_assert(L != NULL);  /* cannot fail when folding (compile time) */
+  luaT_trybinTM(L, p1, p2, res, cast(TMS, op - LUA_OPADD + TM_ADD));
+}
+
+
 int luaO_hexavalue (int c) {
 int luaO_hexavalue (int c) {
   if (lisdigit(c)) return c - '0';
   if (lisdigit(c)) return c - '0';
   else return ltolower(c) - 'a' + 10;
   else return ltolower(c) - 'a' + 10;

+ 3 - 2
lobject.h

@@ -1,5 +1,5 @@
 /*
 /*
-** $Id: lobject.h,v 2.74 2013/04/16 18:46:28 roberto Exp roberto $
+** $Id: lobject.h,v 2.75 2013/04/29 16:57:28 roberto Exp roberto $
 ** Type definitions for Lua objects
 ** Type definitions for Lua objects
 ** See Copyright Notice in lua.h
 ** See Copyright Notice in lua.h
 */
 */
@@ -496,7 +496,8 @@ LUAI_DDEC const TValue luaO_nilobject_;
 LUAI_FUNC int luaO_int2fb (unsigned int x);
 LUAI_FUNC int luaO_int2fb (unsigned int x);
 LUAI_FUNC int luaO_fb2int (int x);
 LUAI_FUNC int luaO_fb2int (int x);
 LUAI_FUNC int luaO_ceillog2 (unsigned int x);
 LUAI_FUNC int luaO_ceillog2 (unsigned int x);
-LUAI_FUNC lua_Number luaO_numarith (int op, lua_Number v1, lua_Number v2);
+LUAI_FUNC void luaO_arith (lua_State *L, int op, const TValue *p1,
+                           const TValue *p2, TValue *res);
 LUAI_FUNC int luaO_str2d (const char *s, size_t len, lua_Number *result);
 LUAI_FUNC int luaO_str2d (const char *s, size_t len, lua_Number *result);
 LUAI_FUNC int luaO_str2int (const char *s, lua_Integer *result);
 LUAI_FUNC int luaO_str2int (const char *s, lua_Integer *result);
 LUAI_FUNC int luaO_hexavalue (int c);
 LUAI_FUNC int luaO_hexavalue (int c);