浏览代码

Change in the handling of 'L->top' when calling metamethods

Instead of updating 'L->top' in every place that may call a
metamethod, the metamethod functions themselves (luaT_trybinTM and
luaT_callorderTM) correct the top. (When calling metamethods from
the C API, however, the callers must preserve 'L->top'.)
Roberto Ierusalimschy 6 年之前
父节点
当前提交
b80077b8f3
共有 10 个文件被更改,包括 75 次插入29 次删除
  1. 2 0
      lapi.c
  2. 2 0
      lobject.c
  3. 1 1
      lopcodes.c
  4. 9 3
      ltm.c
  5. 1 0
      ltm.h
  6. 24 20
      lvm.c
  7. 17 0
      testes/api.lua
  8. 1 1
      testes/coroutine.lua
  9. 13 2
      testes/events.lua
  10. 5 2
      testes/strings.lua

+ 2 - 0
lapi.c

@@ -329,12 +329,14 @@ LUA_API int lua_compare (lua_State *L, int index1, int index2, int op) {
   o1 = index2value(L, index1);
   o1 = index2value(L, index1);
   o2 = index2value(L, index2);
   o2 = index2value(L, index2);
   if (isvalid(L, o1) && isvalid(L, o2)) {
   if (isvalid(L, o1) && isvalid(L, o2)) {
+    ptrdiff_t top = savestack(L, L->top);
     switch (op) {
     switch (op) {
       case LUA_OPEQ: i = luaV_equalobj(L, o1, o2); break;
       case LUA_OPEQ: i = luaV_equalobj(L, o1, o2); break;
       case LUA_OPLT: i = luaV_lessthan(L, o1, o2); break;
       case LUA_OPLT: i = luaV_lessthan(L, o1, o2); break;
       case LUA_OPLE: i = luaV_lessequal(L, o1, o2); break;
       case LUA_OPLE: i = luaV_lessequal(L, o1, o2); break;
       default: api_check(L, 0, "invalid option");
       default: api_check(L, 0, "invalid option");
     }
     }
+    L->top = restorestack(L, top);
   }
   }
   lua_unlock(L);
   lua_unlock(L);
   return i;
   return i;

+ 2 - 0
lobject.c

@@ -127,7 +127,9 @@ void luaO_arith (lua_State *L, int op, const TValue *p1, const TValue *p2,
                  StkId res) {
                  StkId res) {
   if (!luaO_rawarith(L, op, p1, p2, s2v(res))) {
   if (!luaO_rawarith(L, op, p1, p2, s2v(res))) {
     /* could not perform raw operation; try metamethod */
     /* could not perform raw operation; try metamethod */
+    ptrdiff_t top = savestack(L, L->top);
     luaT_trybinTM(L, p1, p2, res, cast(TMS, (op - LUA_OPADD) + TM_ADD));
     luaT_trybinTM(L, p1, p2, res, cast(TMS, (op - LUA_OPADD) + TM_ADD));
+    L->top = restorestack(L, top);
   }
   }
 }
 }
 
 

+ 1 - 1
lopcodes.c

@@ -101,7 +101,7 @@ LUAI_DDEF const lu_byte luaP_opmodes[NUM_OPCODES] = {
  ,opmode(0, 1, 0, 0, iABC)		/* OP_SETLIST */
  ,opmode(0, 1, 0, 0, iABC)		/* OP_SETLIST */
  ,opmode(0, 0, 0, 1, iABx)		/* OP_CLOSURE */
  ,opmode(0, 0, 0, 1, iABx)		/* OP_CLOSURE */
  ,opmode(1, 0, 0, 1, iABC)		/* OP_VARARG */
  ,opmode(1, 0, 0, 1, iABC)		/* OP_VARARG */
- ,opmode(0, 0, 0, 1, iABC)		/* OP_VARARGPREP */
+ ,opmode(0, 1, 0, 1, iABC)		/* OP_VARARGPREP */
  ,opmode(0, 0, 0, 0, iAx)		/* OP_EXTRAARG */
  ,opmode(0, 0, 0, 0, iAx)		/* OP_EXTRAARG */
 };
 };
 
 

+ 9 - 3
ltm.c

@@ -147,11 +147,9 @@ static int callbinTM (lua_State *L, const TValue *p1, const TValue *p2,
 
 
 void luaT_trybinTM (lua_State *L, const TValue *p1, const TValue *p2,
 void luaT_trybinTM (lua_State *L, const TValue *p1, const TValue *p2,
                     StkId res, TMS event) {
                     StkId res, TMS event) {
+  L->top = L->ci->top;
   if (!callbinTM(L, p1, p2, res, event)) {
   if (!callbinTM(L, p1, p2, res, event)) {
     switch (event) {
     switch (event) {
-      case TM_CONCAT:
-        luaG_concaterror(L, p1, p2);
-      /* call never returns, but to avoid warnings: *//* FALLTHROUGH */
       case TM_BAND: case TM_BOR: case TM_BXOR:
       case TM_BAND: case TM_BOR: case TM_BXOR:
       case TM_SHL: case TM_SHR: case TM_BNOT: {
       case TM_SHL: case TM_SHR: case TM_BNOT: {
         if (ttisnumber(p1) && ttisnumber(p2))
         if (ttisnumber(p1) && ttisnumber(p2))
@@ -167,6 +165,13 @@ void luaT_trybinTM (lua_State *L, const TValue *p1, const TValue *p2,
 }
 }
 
 
 
 
+void luaT_tryconcatTM (lua_State *L) {
+  StkId top = L->top;
+  if (!callbinTM(L, s2v(top - 2), s2v(top - 1), top - 2, TM_CONCAT))
+    luaG_concaterror(L, s2v(top - 2), s2v(top - 1));
+}
+
+
 void luaT_trybinassocTM (lua_State *L, const TValue *p1, const TValue *p2,
 void luaT_trybinassocTM (lua_State *L, const TValue *p1, const TValue *p2,
                                        StkId res, int flip, TMS event) {
                                        StkId res, int flip, TMS event) {
   if (flip)
   if (flip)
@@ -186,6 +191,7 @@ void luaT_trybiniTM (lua_State *L, const TValue *p1, lua_Integer i2,
 
 
 int luaT_callorderTM (lua_State *L, const TValue *p1, const TValue *p2,
 int luaT_callorderTM (lua_State *L, const TValue *p1, const TValue *p2,
                       TMS event) {
                       TMS event) {
+  L->top = L->ci->top;
   if (callbinTM(L, p1, p2, L->top, event))  /* try original event */
   if (callbinTM(L, p1, p2, L->top, event))  /* try original event */
     return !l_isfalse(s2v(L->top));
     return !l_isfalse(s2v(L->top));
 #if defined(LUA_COMPAT_LT_LE)
 #if defined(LUA_COMPAT_LT_LE)

+ 1 - 0
ltm.h

@@ -75,6 +75,7 @@ LUAI_FUNC void luaT_callTMres (lua_State *L, const TValue *f,
                             const TValue *p1, const TValue *p2, StkId p3);
                             const TValue *p1, const TValue *p2, StkId p3);
 LUAI_FUNC void luaT_trybinTM (lua_State *L, const TValue *p1, const TValue *p2,
 LUAI_FUNC void luaT_trybinTM (lua_State *L, const TValue *p1, const TValue *p2,
                               StkId res, TMS event);
                               StkId res, TMS event);
+LUAI_FUNC void luaT_tryconcatTM (lua_State *L);
 LUAI_FUNC void luaT_trybinassocTM (lua_State *L, const TValue *p1,
 LUAI_FUNC void luaT_trybinassocTM (lua_State *L, const TValue *p1,
        const TValue *p2, StkId res, int inv, TMS event);
        const TValue *p2, StkId res, int inv, TMS event);
 LUAI_FUNC void luaT_trybiniTM (lua_State *L, const TValue *p1, lua_Integer i2,
 LUAI_FUNC void luaT_trybiniTM (lua_State *L, const TValue *p1, lua_Integer i2,

+ 24 - 20
lvm.c

@@ -515,8 +515,11 @@ int luaV_equalobj (lua_State *L, const TValue *t1, const TValue *t2) {
   }
   }
   if (tm == NULL)  /* no TM? */
   if (tm == NULL)  /* no TM? */
     return 0;  /* objects are different */
     return 0;  /* objects are different */
-  luaT_callTMres(L, tm, t1, t2, L->top);  /* call TM */
-  return !l_isfalse(s2v(L->top));
+  else {
+    L->top = L->ci->top;
+    luaT_callTMres(L, tm, t1, t2, L->top);  /* call TM */
+    return !l_isfalse(s2v(L->top));
+  }
 }
 }
 
 
 
 
@@ -548,7 +551,7 @@ void luaV_concat (lua_State *L, int total) {
     int n = 2;  /* number of elements handled in this pass (at least 2) */
     int n = 2;  /* number of elements handled in this pass (at least 2) */
     if (!(ttisstring(s2v(top - 2)) || cvt2str(s2v(top - 2))) ||
     if (!(ttisstring(s2v(top - 2)) || cvt2str(s2v(top - 2))) ||
         !tostring(L, s2v(top - 1)))
         !tostring(L, s2v(top - 1)))
-      luaT_trybinTM(L, s2v(top - 2), s2v(top - 1), top - 2, TM_CONCAT);
+      luaT_tryconcatTM(L);
     else if (isemptystr(s2v(top - 1)))  /* second operand is empty? */
     else if (isemptystr(s2v(top - 1)))  /* second operand is empty? */
       cast_void(tostring(L, s2v(top - 2)));  /* result is first operand */
       cast_void(tostring(L, s2v(top - 2)));  /* result is first operand */
     else if (isemptystr(s2v(top - 2))) {  /* first operand is empty string? */
     else if (isemptystr(s2v(top - 2))) {  /* first operand is empty string? */
@@ -747,7 +750,7 @@ void luaV_finishOp (lua_State *L) {
       break;
       break;
     }
     }
     case OP_CONCAT: {
     case OP_CONCAT: {
-      StkId top = L->top - 1;  /* top when 'luaT_trybinTM' was called */
+      StkId top = L->top - 1;  /* top when 'luaT_tryconcatTM' was called */
       int a = GETARG_A(inst);      /* first element to concatenate */
       int a = GETARG_A(inst);      /* first element to concatenate */
       int total = cast_int(top - 1 - (base + a));  /* yet to concatenate */
       int total = cast_int(top - 1 - (base + a));  /* yet to concatenate */
       setobjs2s(L, top - 2, top);  /* put TM result in proper position */
       setobjs2s(L, top - 2, top);  /* put TM result in proper position */
@@ -801,7 +804,7 @@ void luaV_finishOp (lua_State *L) {
     setfltvalue(s2v(ra), fop(L, nb, fimm));  \
     setfltvalue(s2v(ra), fop(L, nb, fimm));  \
   }  \
   }  \
   else  \
   else  \
-    Protect(luaT_trybiniTM(L, v1, imm, flip, ra, tm)); }
+    ProtectNT(luaT_trybiniTM(L, v1, imm, flip, ra, tm)); }
 
 
 
 
 /*
 /*
@@ -836,7 +839,7 @@ void luaV_finishOp (lua_State *L) {
     setfltvalue(s2v(ra), fop(L, n1, n2));  \
     setfltvalue(s2v(ra), fop(L, n1, n2));  \
   }  \
   }  \
   else  \
   else  \
-    Protect(luaT_trybinTM(L, v1, v2, ra, tm)); }
+    ProtectNT(luaT_trybinTM(L, v1, v2, ra, tm)); }
 
 
 
 
 /*
 /*
@@ -877,7 +880,7 @@ void luaV_finishOp (lua_State *L) {
       setfltvalue(s2v(ra), fop(L, n1, n2));  \
       setfltvalue(s2v(ra), fop(L, n1, n2));  \
     }  \
     }  \
     else  \
     else  \
-      Protect(luaT_trybinassocTM(L, v1, v2, ra, flip, tm)); } }
+      ProtectNT(luaT_trybinassocTM(L, v1, v2, ra, flip, tm)); } }
 
 
 
 
 /*
 /*
@@ -891,7 +894,7 @@ void luaV_finishOp (lua_State *L) {
     setfltvalue(s2v(ra), fop(L, n1, n2));  \
     setfltvalue(s2v(ra), fop(L, n1, n2));  \
   }  \
   }  \
   else  \
   else  \
-    Protect(luaT_trybinTM(L, v1, v2, ra, tm)); }
+    ProtectNT(luaT_trybinTM(L, v1, v2, ra, tm)); }
 
 
 
 
 /*
 /*
@@ -906,7 +909,7 @@ void luaV_finishOp (lua_State *L) {
     setivalue(s2v(ra), op(L, i1, i2));  \
     setivalue(s2v(ra), op(L, i1, i2));  \
   }  \
   }  \
   else  \
   else  \
-    Protect(luaT_trybiniTM(L, v1, i2, TESTARG_k(i), ra, tm)); }
+    ProtectNT(luaT_trybiniTM(L, v1, i2, TESTARG_k(i), ra, tm)); }
 
 
 
 
 /*
 /*
@@ -920,7 +923,7 @@ void luaV_finishOp (lua_State *L) {
     setivalue(s2v(ra), op(L, i1, i2));  \
     setivalue(s2v(ra), op(L, i1, i2));  \
   }  \
   }  \
   else  \
   else  \
-    Protect(luaT_trybinTM(L, v1, v2, ra, tm)); }
+    ProtectNT(luaT_trybinTM(L, v1, v2, ra, tm)); }
 
 
 
 
 /*
 /*
@@ -937,7 +940,7 @@ void luaV_finishOp (lua_State *L) {
         else if (ttisnumber(s2v(ra)) && ttisnumber(rb))  \
         else if (ttisnumber(s2v(ra)) && ttisnumber(rb))  \
           cond = opf(s2v(ra), rb);  \
           cond = opf(s2v(ra), rb);  \
         else  \
         else  \
-          Protect(cond = other(L, s2v(ra), rb));  \
+          ProtectNT(cond = other(L, s2v(ra), rb));  \
         docondjump(); }
         docondjump(); }
 
 
 
 
@@ -956,7 +959,7 @@ void luaV_finishOp (lua_State *L) {
         }  \
         }  \
         else {  \
         else {  \
           int isf = GETARG_C(i);  \
           int isf = GETARG_C(i);  \
-          Protect(cond = luaT_callorderiTM(L, s2v(ra), im, inv, isf, tm));  \
+          ProtectNT(cond = luaT_callorderiTM(L, s2v(ra), im, inv, isf, tm));  \
         }  \
         }  \
         docondjump(); }
         docondjump(); }
 
 
@@ -1094,7 +1097,8 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
     vmfetch();
     vmfetch();
     lua_assert(base == ci->func + 1);
     lua_assert(base == ci->func + 1);
     lua_assert(base <= L->top && L->top < L->stack + L->stacksize);
     lua_assert(base <= L->top && L->top < L->stack + L->stacksize);
-    lua_assert(ci->top < L->stack + L->stacksize);
+    /* invalidate top for instructions not expecting it */
+    lua_assert(isIT(i) || (L->top = base));
     vmdispatch (GET_OPCODE(i)) {
     vmdispatch (GET_OPCODE(i)) {
       vmcase(OP_MOVE) {
       vmcase(OP_MOVE) {
         setobjs2s(L, ra, RB(i));
         setobjs2s(L, ra, RB(i));
@@ -1359,7 +1363,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
           if (TESTARG_k(i)) {
           if (TESTARG_k(i)) {
             ic = -ic;  ev = TM_SHL;
             ic = -ic;  ev = TM_SHL;
           }
           }
-          Protect(luaT_trybiniTM(L, rb, ic, 0, ra, ev));
+          ProtectNT(luaT_trybiniTM(L, rb, ic, 0, ra, ev));
         }
         }
         vmbreak;
         vmbreak;
       }
       }
@@ -1371,7 +1375,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
           setivalue(s2v(ra), luaV_shiftl(ic, ib));
           setivalue(s2v(ra), luaV_shiftl(ic, ib));
         }
         }
         else
         else
-          Protect(luaT_trybiniTM(L, rb, ic, 1, ra, TM_SHL));
+          ProtectNT(luaT_trybiniTM(L, rb, ic, 1, ra, TM_SHL));
         vmbreak;
         vmbreak;
       }
       }
       vmcase(OP_ADD) {
       vmcase(OP_ADD) {
@@ -1422,7 +1426,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
           setivalue(s2v(ra), luaV_shiftl(ib, -ic));
           setivalue(s2v(ra), luaV_shiftl(ib, -ic));
         }
         }
         else
         else
-          Protect(luaT_trybinTM(L, rb, rc, ra, TM_SHR));
+          ProtectNT(luaT_trybinTM(L, rb, rc, ra, TM_SHR));
         vmbreak;
         vmbreak;
       }
       }
       vmcase(OP_SHL) {
       vmcase(OP_SHL) {
@@ -1433,7 +1437,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
           setivalue(s2v(ra), luaV_shiftl(ib, ic));
           setivalue(s2v(ra), luaV_shiftl(ib, ic));
         }
         }
         else
         else
-          Protect(luaT_trybinTM(L, rb, rc, ra, TM_SHL));
+          ProtectNT(luaT_trybinTM(L, rb, rc, ra, TM_SHL));
         vmbreak;
         vmbreak;
       }
       }
       vmcase(OP_UNM) {
       vmcase(OP_UNM) {
@@ -1447,7 +1451,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
           setfltvalue(s2v(ra), luai_numunm(L, nb));
           setfltvalue(s2v(ra), luai_numunm(L, nb));
         }
         }
         else
         else
-          Protect(luaT_trybinTM(L, rb, rb, ra, TM_UNM));
+          ProtectNT(luaT_trybinTM(L, rb, rb, ra, TM_UNM));
         vmbreak;
         vmbreak;
       }
       }
       vmcase(OP_BNOT) {
       vmcase(OP_BNOT) {
@@ -1457,7 +1461,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
           setivalue(s2v(ra), intop(^, ~l_castS2U(0), ib));
           setivalue(s2v(ra), intop(^, ~l_castS2U(0), ib));
         }
         }
         else
         else
-          Protect(luaT_trybinTM(L, rb, rb, ra, TM_BNOT));
+          ProtectNT(luaT_trybinTM(L, rb, rb, ra, TM_BNOT));
         vmbreak;
         vmbreak;
       }
       }
       vmcase(OP_NOT) {
       vmcase(OP_NOT) {
@@ -1493,7 +1497,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
       vmcase(OP_EQ) {
       vmcase(OP_EQ) {
         int cond;
         int cond;
         TValue *rb = vRB(i);
         TValue *rb = vRB(i);
-        Protect(cond = luaV_equalobj(L, s2v(ra), rb));
+        ProtectNT(cond = luaV_equalobj(L, s2v(ra), rb));
         docondjump();
         docondjump();
         vmbreak;
         vmbreak;
       }
       }

+ 17 - 0
testes/api.lua

@@ -241,6 +241,23 @@ assert(a == 20 and b == false)
 a,b = T.testC("compare LE 5 -6, return 2", a1, 2, 2, a1, 2, 20)
 a,b = T.testC("compare LE 5 -6, return 2", a1, 2, 2, a1, 2, 20)
 assert(a == 20 and b == true)
 assert(a == 20 and b == true)
 
 
+
+do  -- testing lessthan and lessequal with metamethods
+  local mt = {__lt = function (a,b) return a[1] < b[1] end,
+              __le = function (a,b) return a[1] <= b[1] end,
+              __eq = function (a,b) return a[1] == b[1] end}
+  local function O (x)
+    return setmetatable({x}, mt)
+  end
+
+  local a, b = T.testC("compare LT 2 3; pushint 10; return 2", O(1), O(2))
+  assert(a == true and b == 10)
+  local a, b = T.testC("compare LE 2 3; pushint 10; return 2", O(3), O(2))
+  assert(a == false and b == 10)
+  local a, b = T.testC("compare EQ 2 3; pushint 10; return 2", O(3), O(3))
+  assert(a == true and b == 10)
+end
+
 -- testing length
 -- testing length
 local t = setmetatable({x = 20}, {__len = function (t) return t.x end})
 local t = setmetatable({x = 20}, {__len = function (t) return t.x end})
 a,b,c = T.testC([[
 a,b,c = T.testC([[

+ 1 - 1
testes/coroutine.lua

@@ -809,7 +809,7 @@ assert(run(function ()
 -- tests for coroutine API
 -- tests for coroutine API
 if T==nil then
 if T==nil then
   (Message or print)('\n >>> testC not active: skipping coroutine API tests <<<\n')
   (Message or print)('\n >>> testC not active: skipping coroutine API tests <<<\n')
-  return
+  print "OK"; return
 end
 end
 
 
 print('testing coroutine API')
 print('testing coroutine API')

+ 13 - 2
testes/events.lua

@@ -217,9 +217,16 @@ t.__le = function (a,b,c)
  return a<=b, "dummy"
  return a<=b, "dummy"
 end
 end
 
 
+t.__eq = function (a,b,c)
+  assert(c == nil)
+  if type(a) == 'table' then a = a.x end
+  if type(b) == 'table' then b = b.x end
+ return a == b, "dummy"
+end
+
 function Op(x) return setmetatable({x=x}, t) end
 function Op(x) return setmetatable({x=x}, t) end
 
 
-local function test ()
+local function test (a, b, c)
   assert(not(Op(1)<Op(1)) and (Op(1)<Op(2)) and not(Op(2)<Op(1)))
   assert(not(Op(1)<Op(1)) and (Op(1)<Op(2)) and not(Op(2)<Op(1)))
   assert(not(1 < Op(1)) and (Op(1) < 2) and not(2 < Op(1)))
   assert(not(1 < Op(1)) and (Op(1) < 2) and not(2 < Op(1)))
   assert(not(Op('a')<Op('a')) and (Op('a')<Op('b')) and not(Op('b')<Op('a')))
   assert(not(Op('a')<Op('a')) and (Op('a')<Op('b')) and not(Op('b')<Op('a')))
@@ -232,9 +239,13 @@ local function test ()
   assert((1 >= Op(1)) and not(1 >= Op(2)) and (Op(2) >= 1))
   assert((1 >= Op(1)) and not(1 >= Op(2)) and (Op(2) >= 1))
   assert((Op('a')>=Op('a')) and not(Op('a')>=Op('b')) and (Op('b')>=Op('a')))
   assert((Op('a')>=Op('a')) and not(Op('a')>=Op('b')) and (Op('b')>=Op('a')))
   assert(('a' >= Op('a')) and not(Op('a') >= 'b') and (Op('b') >= Op('a')))
   assert(('a' >= Op('a')) and not(Op('a') >= 'b') and (Op('b') >= Op('a')))
+  assert(Op(1) == Op(1) and Op(1) ~= Op(2))
+  assert(Op('a') == Op('a') and Op('a') ~= Op('b'))
+  assert(a == a and a ~= b)
+  assert(Op(3) == c)
 end
 end
 
 
-test()
+test(Op(1), Op(2), Op(3))
 
 
 
 
 -- test `partial order'
 -- test `partial order'

+ 5 - 2
testes/strings.lua

@@ -167,8 +167,11 @@ do  -- tests for '%p' format
     local t1 = {}; local t2 = {}
     local t1 = {}; local t2 = {}
     assert(string.format("%p", t1) ~= string.format("%p", t2))
     assert(string.format("%p", t1) ~= string.format("%p", t2))
   end
   end
-  assert(string.format("%p", string.rep("a", 10)) ==
-         string.format("%p", string.rep("a", 10)))     -- short strings
+  do     -- short strings
+    local s1 = string.rep("a", 10)
+    local s2 = string.rep("a", 10)
+  assert(string.format("%p", s1) == string.format("%p", s2))
+  end
   do     -- long strings
   do     -- long strings
     local s1 = string.rep("a", 300); local s2 = string.rep("a", 300)
     local s1 = string.rep("a", 300); local s2 = string.rep("a", 300)
     assert(string.format("%p", s1) ~= string.format("%p", s2))
     assert(string.format("%p", s1) ~= string.format("%p", s2))