Browse Source

First "complete" implementation of to-be-closed variables

Still missing:
- handling of memory errors when creating upvalue (must run closing
method all the same)
- interaction with coroutines
Roberto Ierusalimschy 6 năm trước cách đây
mục cha
commit
bd96330d03
13 tập tin đã thay đổi với 145 bổ sung26 xóa
  1. 13 7
      ldo.c
  2. 1 0
      ldo.h
  3. 44 3
      lfunc.c
  4. 1 1
      lfunc.h
  5. 5 0
      lgc.c
  6. 1 1
      lparser.c
  7. 2 2
      lstate.c
  8. 1 1
      ltests.c
  9. 1 1
      ltm.c
  10. 1 0
      ltm.h
  11. 3 4
      lvm.c
  12. 13 1
      testes/api.lua
  13. 59 5
      testes/locals.lua

+ 13 - 7
ldo.c

@@ -88,7 +88,7 @@ struct lua_longjmp {
 };
 
 
-static void seterrorobj (lua_State *L, int errcode, StkId oldtop) {
+void luaD_seterrorobj (lua_State *L, int errcode, StkId oldtop) {
   switch (errcode) {
     case LUA_ERRMEM: {  /* memory error? */
       setsvalue2s(L, oldtop, G(L)->memerrmsg); /* reuse preregistered msg. */
@@ -121,7 +121,7 @@ l_noret luaD_throw (lua_State *L, int errcode) {
     }
     else {  /* no handler at all; abort */
       if (g->panic) {  /* panic function? */
-        seterrorobj(L, errcode, L->top);  /* assume EXTRA_STACK */
+        luaD_seterrorobj(L, errcode, L->top);  /* assume EXTRA_STACK */
         if (L->ci->top < L->top)
           L->ci->top = L->top;  /* pushing msg. can break this invariant */
         lua_unlock(L);
@@ -584,8 +584,8 @@ static int recover (lua_State *L, int status) {
   if (ci == NULL) return 0;  /* no recovery point */
   /* "finish" luaD_pcall */
   oldtop = restorestack(L, ci->u2.funcidx);
-  luaF_close(L, oldtop);
-  seterrorobj(L, status, oldtop);
+  luaF_close(L, oldtop, status);
+  luaD_seterrorobj(L, status, oldtop);
   L->ci = ci;
   L->allowhook = getoah(ci->callstatus);  /* restore original 'allowhook' */
   L->nny = 0;  /* should be zero to be yieldable */
@@ -678,7 +678,7 @@ LUA_API int lua_resume (lua_State *L, lua_State *from, int nargs,
     }
     if (unlikely(errorstatus(status))) {  /* unrecoverable error? */
       L->status = cast_byte(status);  /* mark thread as 'dead' */
-      seterrorobj(L, status, L->top);  /* push error message */
+      luaD_seterrorobj(L, status, L->top);  /* push error message */
       L->ci->top = L->top;
     }
     else lua_assert(status == L->status);  /* normal end or yield */
@@ -726,6 +726,11 @@ LUA_API int lua_yieldk (lua_State *L, int nresults, lua_KContext ctx,
 }
 
 
+/*
+** Call the C function 'func' in protected mode, restoring basic
+** thread information ('allowhook', 'nny', etc.) and in particular
+** its stack level in case of errors.
+*/
 int luaD_pcall (lua_State *L, Pfunc func, void *u,
                 ptrdiff_t old_top, ptrdiff_t ef) {
   int status;
@@ -737,11 +742,12 @@ int luaD_pcall (lua_State *L, Pfunc func, void *u,
   status = luaD_rawrunprotected(L, func, u);
   if (unlikely(status != LUA_OK)) {  /* an error occurred? */
     StkId oldtop = restorestack(L, old_top);
-    luaF_close(L, oldtop);  /* close possible pending closures */
-    seterrorobj(L, status, oldtop);
     L->ci = old_ci;
     L->allowhook = old_allowhooks;
     L->nny = old_nny;
+    status = luaF_close(L, oldtop, status);
+    oldtop = restorestack(L, old_top);  /* previous call may change stack */
+    luaD_seterrorobj(L, status, oldtop);
     luaD_shrinkstack(L);
   }
   L->errfunc = old_errfunc;

+ 1 - 0
ldo.h

@@ -50,6 +50,7 @@
 /* type of protected functions, to be ran by 'runprotected' */
 typedef void (*Pfunc) (lua_State *L, void *ud);
 
+LUAI_FUNC void luaD_seterrorobj (lua_State *L, int errcode, StkId oldtop);
 LUAI_FUNC int luaD_protectedparser (lua_State *L, ZIO *z, const char *name,
                                                   const char *mode);
 LUAI_FUNC void luaD_hook (lua_State *L, int event, int line,

+ 44 - 3
lfunc.c

@@ -14,6 +14,7 @@
 
 #include "lua.h"
 
+#include "ldo.h"
 #include "lfunc.h"
 #include "lgc.h"
 #include "lmem.h"
@@ -83,6 +84,40 @@ UpVal *luaF_findupval (lua_State *L, StkId level) {
 }
 
 
+static void callclose (lua_State *L, void *ud) {
+  luaD_callnoyield(L, cast(StkId, ud), 0);
+}
+
+
+static int closeupval (lua_State *L, UpVal *uv, StkId level, int status) {
+  StkId func = level + 1;  /* save slot for old error message */
+  if (status != LUA_OK)  /* was there an error? */
+    luaD_seterrorobj(L, status, level);  /* save error message */
+  else
+    setnilvalue(s2v(level));
+  if (ttisfunction(uv->v)) {  /* object to-be-closed is a function? */
+    setobj2s(L, func, uv->v);  /* will call it */
+    setobjs2s(L, func + 1, level);  /* error msg. as argument */
+  }
+  else {  /* try '__close' metamethod */
+    const TValue *tm = luaT_gettmbyobj(L, uv->v, TM_CLOSE);
+    if (ttisnil(tm))
+      return status;  /* no metamethod */
+    setobj2s(L, func, tm);  /* will call metamethod */
+    setobj2s(L, func + 1, uv->v);  /* with 'self' as argument */
+  }
+  L->top = func + 2;  /* add function and argument */
+  if (status == LUA_OK)  /* not in "error mode"? */
+    callclose(L, func);  /* call closing method */
+  else {  /* already inside error handler; cannot raise another error */
+    int newstatus = luaD_pcall(L, callclose, func, savestack(L, level), 0);
+    if (newstatus != LUA_OK)  /* error when closing? */
+      status = newstatus;  /* this will be the new error */
+  }
+  return status;
+}
+
+
 void luaF_unlinkupval (UpVal *uv) {
   lua_assert(upisopen(uv));
   *uv->u.open.previous = uv->u.open.next;
@@ -91,10 +126,10 @@ void luaF_unlinkupval (UpVal *uv) {
 }
 
 
-void luaF_close (lua_State *L, StkId level) {
+int luaF_close (lua_State *L, StkId level, int status) {
   UpVal *uv;
-  while (L->openupval != NULL &&
-        (uv = L->openupval, uplevel(uv) >= level)) {
+  while ((uv = L->openupval) != NULL && uplevel(uv) >= level) {
+    StkId upl = uplevel(uv);
     TValue *slot = &uv->u.value;  /* new position for value */
     luaF_unlinkupval(uv);
     setobj(L, slot, uv->v);  /* move value to upvalue slot */
@@ -102,7 +137,13 @@ void luaF_close (lua_State *L, StkId level) {
     if (!iswhite(uv))
       gray2black(uv);  /* closed upvalues cannot be gray */
     luaC_barrier(L, uv, slot);
+    if (status >= 0 && uv->tt == LUA_TUPVALTBC) {  /* must be closed? */
+      ptrdiff_t levelrel = savestack(L, level);
+      status = closeupval(L, uv, upl, status);  /* may reallocate the stack */
+      level = restorestack(L, levelrel);
+    }
   }
+  return status;
 }
 
 

+ 1 - 1
lfunc.h

@@ -47,7 +47,7 @@ LUAI_FUNC CClosure *luaF_newCclosure (lua_State *L, int nelems);
 LUAI_FUNC LClosure *luaF_newLclosure (lua_State *L, int nelems);
 LUAI_FUNC void luaF_initupvals (lua_State *L, LClosure *cl);
 LUAI_FUNC UpVal *luaF_findupval (lua_State *L, StkId level);
-LUAI_FUNC void luaF_close (lua_State *L, StkId level);
+LUAI_FUNC int luaF_close (lua_State *L, StkId level, int status);
 LUAI_FUNC void luaF_unlinkupval (UpVal *uv);
 LUAI_FUNC void luaF_freeproto (lua_State *L, Proto *f);
 LUAI_FUNC const char *luaF_getlocalname (const Proto *func, int local_number,

+ 5 - 0
lgc.c

@@ -609,6 +609,7 @@ static int traverseLclosure (global_State *g, LClosure *cl) {
 ** That ensures that the entire stack have valid (non-dead) objects.
 */
 static int traversethread (global_State *g, lua_State *th) {
+  UpVal *uv;
   StkId o = th->stack;
   if (o == NULL)
     return 1;  /* stack not completely built yet */
@@ -616,6 +617,10 @@ static int traversethread (global_State *g, lua_State *th) {
              th->openupval == NULL || isintwups(th));
   for (; o < th->top; o++)  /* mark live elements in the stack */
     markvalue(g, s2v(o));
+  for (uv = th->openupval; uv != NULL; uv = uv->u.open.next) {
+    if (uv->tt == LUA_TUPVALTBC)  /* to be closed? */
+      markobject(g, uv);  /* cannot be collected */
+  }
   if (g->gcstate == GCSatomic) {  /* final traversal? */
     StkId lim = th->stack + th->stacksize;  /* real end of stack */
     for (; o < lim; o++)  /* clear not-marked stack slice */

+ 1 - 1
lparser.c

@@ -1536,9 +1536,9 @@ static void scopedlocalstat (LexState *ls) {
   FuncState *fs = ls->fs;
   new_localvar(ls, str_checkname(ls));
   checknext(ls, '=');
+  exp1(ls, 0);
   luaK_codeABC(fs, OP_TBC, fs->nactvar, 0, 0);
   markupval(fs, fs->nactvar);
-  exp1(ls, 0);
   adjustlocalvars(ls, 1);
 }
 

+ 2 - 2
lstate.c

@@ -258,7 +258,7 @@ static void preinit_thread (lua_State *L, global_State *g) {
 
 static void close_state (lua_State *L) {
   global_State *g = G(L);
-  luaF_close(L, L->stack);  /* close all upvalues for this thread */
+  luaF_close(L, L->stack, -1);  /* close all upvalues for this thread */
   luaC_freeallobjects(L);  /* collect all objects */
   if (ttisnil(&g->nilvalue))  /* closing a fully built state? */
     luai_userstateclose(L);
@@ -301,7 +301,7 @@ LUA_API lua_State *lua_newthread (lua_State *L) {
 
 void luaE_freethread (lua_State *L, lua_State *L1) {
   LX *l = fromstate(L1);
-  luaF_close(L1, L1->stack);  /* close all upvalues for this thread */
+  luaF_close(L1, L1->stack, -1);  /* close all upvalues for this thread */
   lua_assert(L1->openupval == NULL);
   luai_userstatefree(L, L1);
   freestack(L1);

+ 1 - 1
ltests.c

@@ -1208,7 +1208,7 @@ static int getindex_aux (lua_State *L, lua_State *L1, const char **pc) {
 
 static void pushcode (lua_State *L, int code) {
   static const char *const codes[] = {"OK", "YIELD", "ERRRUN",
-                   "ERRSYNTAX", "ERRMEM", "ERRGCMM", "ERRERR"};
+                   "ERRSYNTAX", MEMERRMSG, "ERRGCMM", "ERRERR"};
   lua_pushstring(L, codes[code]);
 }
 

+ 1 - 1
ltm.c

@@ -43,7 +43,7 @@ void luaT_init (lua_State *L) {
     "__div", "__idiv",
     "__band", "__bor", "__bxor", "__shl", "__shr",
     "__unm", "__bnot", "__lt", "__le",
-    "__concat", "__call"
+    "__concat", "__call", "__close"
   };
   int i;
   for (i=0; i<TM_N; i++) {

+ 1 - 0
ltm.h

@@ -40,6 +40,7 @@ typedef enum {
   TM_LE,
   TM_CONCAT,
   TM_CALL,
+  TM_CLOSE,
   TM_N		/* number of elements in the enum */
 } TMS;
 

+ 3 - 4
lvm.c

@@ -1452,13 +1452,12 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
         vmbreak;
       }
       vmcase(OP_CLOSE) {
-        luaF_close(L, ra);
+        luaF_close(L, ra, LUA_OK);
         vmbreak;
       }
       vmcase(OP_TBC) {
         UpVal *up = luaF_findupval(L, ra);  /* create new upvalue */
         up->tt = LUA_TUPVALTBC;  /* mark it to be closed */
-        setnilvalue(s2v(ra));  /* intialize it with nil */
         vmbreak;
       }
       vmcase(OP_JMP) {
@@ -1591,7 +1590,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
           int nparams1 = GETARG_C(i);
           if (nparams1)  /* vararg function? */
             delta = ci->u.l.nextraargs + nparams1;
-          luaF_close(L, base);  /* close upvalues from current call */
+          luaF_close(L, base, LUA_OK);  /* close upvalues from current call */
         }
         if (!ttisfunction(s2v(ra))) {  /* not a function? */
           luaD_tryfuncTM(L, ra);  /* try '__call' metamethod */
@@ -1625,7 +1624,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
           int nparams1 = GETARG_C(i);
           if (nparams1)  /* vararg function? */
             ci->func -= ci->u.l.nextraargs + nparams1;
-          luaF_close(L, base);  /* there may be open upvalues */
+          luaF_close(L, base, LUA_OK);  /* there may be open upvalues */
         }
         halfProtect(luaD_poscall(L, ci, n));
         return;

+ 13 - 1
testes/api.lua

@@ -1,4 +1,4 @@
--- $Id: testes/api.lua $
+-- $Id: testes/api.lua 2018-07-25 15:31:04 -0300 $
 -- See Copyright Notice in file all.lua
 
 if T==nil then
@@ -1027,6 +1027,18 @@ testamem("coroutine creation", function()
 end)
 
 
+-- testing to-be-closed variables
+testamem("to-be-closed variables", function()
+  local flag
+  do
+    local scoped x = function () flag = true end
+    flag = false
+    local x = {}
+  end
+  return flag
+end)
+
+
 -- testing threads
 
 -- get main thread from registry (at index LUA_RIDX_MAINTHREAD == 1)

+ 59 - 5
testes/locals.lua

@@ -173,15 +173,69 @@ end
 assert(x==20)
 
 
--- tests for to-be-closed variables
+print"testing to-be-closed variables"
+
+do
+  local a = {}
+  do
+    local scoped x = setmetatable({"x"}, {__close = function (self)
+                                                   a[#a + 1] = self[1] end})
+    local scoped y = function () a[#a + 1] = "y" end
+    a[#a + 1] = "in"
+  end
+  a[#a + 1] = "out"
+  assert(a[1] == "in" and a[2] == "y" and a[3] == "x" and a[4] == "out")
+end
+
+
+do   -- errors in __close
+  local log = {}
+  local function foo (err)
+    local scoped x = function (msg) log[#log + 1] = msg; error(1) end
+    local scoped x1 = function (msg) log[#log + 1] = msg; end
+    local scoped gc = function () collectgarbage() end
+    local scoped y = function (msg) log[#log + 1] = msg; error(2) end
+    local scoped z = function (msg) log[#log + 1] = msg or 10; error(3) end
+    if err then error(4) end
+  end
+  local stat, msg = pcall(foo, false)
+  assert(msg == 1)
+  assert(log[1] == 10 and log[2] == 3 and log[3] == 2 and log[4] == 2
+         and #log == 4)
+
+  log = {}
+  local stat, msg = pcall(foo, true)
+  assert(msg == 1)
+  assert(log[1] == 4 and log[2] == 3 and log[3] == 2 and log[4] == 2
+         and #log == 4)
+end
+
 do
-  local scoped x = 3
-  local a
-  local scoped y = 5
-  assert(x == 3 and y == 5)
+  -- memory error inside closing function
+  local function foo ()
+    local scoped y = function () io.write(2); T.alloccount() end
+    local scoped x = setmetatable({}, {__close = function ()
+      T.alloccount(0); local x = {}   -- force a memory error
+    end})
+    io.write("1\n")
+    error("a")   -- common error inside the function's body
+  end
+
+  local _, msg = pcall(foo)
+T.alloccount()
+  assert(msg == "not enough memory")
+
 end
 
 
+-- a suspended coroutine should not close its variables when collected
+local co = coroutine.wrap(function()
+  local scoped x = function () os.exit(1) end    -- should not run
+   coroutine.yield()
+end)
+co()
+co = nil
+
 print('OK')
 
 return 5,f