Browse Source

Handles '__close' errors in coroutines in "coroutine style"

Errors in '__close' metamethods in coroutines are handled by the same
logic that handles other errors, through 'recover'.
Roberto Ierusalimschy 4 years ago
parent
commit
ce101dcaf7
2 changed files with 85 additions and 22 deletions
  1. 46 20
      ldo.c
  2. 39 2
      testes/coroutine.lua

+ 46 - 20
ldo.c

@@ -103,7 +103,7 @@ void luaD_seterrorobj (lua_State *L, int errcode, StkId oldtop) {
       break;
     }
     default: {
-      lua_assert(errcode >= LUA_ERRRUN);  /* real error */
+      lua_assert(errorstatus(errcode));  /* real error */
       setobjs2s(L, oldtop, L->top - 1);  /* error message on current top */
       break;
     }
@@ -593,15 +593,11 @@ static void finishCcall (lua_State *L, int status) {
 /*
 ** Executes "full continuation" (everything in the stack) of a
 ** previously interrupted coroutine until the stack is empty (or another
-** interruption long-jumps out of the loop). If the coroutine is
-** recovering from an error, 'ud' points to the error status, which must
-** be passed to the first continuation function (otherwise the default
-** status is LUA_YIELD).
+** interruption long-jumps out of the loop).
 */
 static void unroll (lua_State *L, void *ud) {
   CallInfo *ci;
-  if (ud != NULL)  /* error status? */
-    finishCcall(L, *(int *)ud);  /* finish 'lua_pcallk' callee */
+  UNUSED(ud);
   while ((ci = L->ci) != &L->base_ci) {  /* something in the stack */
     if (!isLua(ci))  /* C function? */
       finishCcall(L, LUA_YIELD);  /* complete its execution */
@@ -628,21 +624,36 @@ static CallInfo *findpcall (lua_State *L) {
 
 
 /*
-** Recovers from an error in a coroutine. Finds a recover point (if
-** there is one) and completes the execution of the interrupted
-** 'luaD_pcall'. If there is no recover point, returns zero.
+** Auxiliary structure to call 'recover' in protected mode.
 */
-static int recover (lua_State *L, int status) {
-  CallInfo *ci = findpcall(L);
-  if (ci == NULL) return 0;  /* no recovery point */
+struct RecoverS {
+  int status;
+  CallInfo *ci;
+};
+
+
+/*
+** Recovers from an error in a coroutine: completes the execution of the
+** interrupted 'luaD_pcall', completes the interrupted C function which
+** called 'lua_pcallk', and continues running the coroutine. If there is
+** an error in 'luaF_close', this function will be called again and the
+** coroutine will continue from where it left.
+*/
+static void recover (lua_State *L, void *ud) {
+  struct RecoverS *r = cast(struct RecoverS *, ud);
+  int status = r->status;
+  CallInfo *ci = r->ci;  /* recover point */
+  StkId func = restorestack(L, ci->u2.funcidx);
   /* "finish" luaD_pcall */
   L->ci = ci;
   L->allowhook = getoah(ci->callstatus);  /* restore original 'allowhook' */
-  status = luaD_closeprotected(L, ci->u2.funcidx, status);
-  luaD_seterrorobj(L, status, restorestack(L, ci->u2.funcidx));
+  luaF_close(L, func, status);  /* may change the stack */
+  func = restorestack(L, ci->u2.funcidx);
+  luaD_seterrorobj(L, status, func);
   luaD_shrinkstack(L);   /* restore stack size in case of overflow */
   L->errfunc = ci->u.c.old_errfunc;
-  return 1;  /* continue running the coroutine */
+  finishCcall(L, status);  /* finish 'lua_pcallk' callee */
+  unroll(L, NULL);  /* continue running the coroutine */
 }
 
 
@@ -692,6 +703,24 @@ static void resume (lua_State *L, void *ud) {
   }
 }
 
+
+/*
+** Calls 'recover' in protected mode, repeating while there are
+** recoverable errors, that is, errors inside a protected call. (Any
+** error interrupts 'recover', and this loop protects it again so it
+** can continue.) Stops with a normal end (status == LUA_OK), an yield
+** (status == LUA_YIELD), or an unprotected error ('findpcall' doesn't
+** find a recover point).
+*/
+static int p_recover (lua_State *L, int status) {
+  struct RecoverS r;
+  r.status = status;
+  while (errorstatus(status) && (r.ci = findpcall(L)) != NULL)
+    r.status = luaD_rawrunprotected(L, recover, &r);
+  return r.status;
+}
+
+
 LUA_API int lua_resume (lua_State *L, lua_State *from, int nargs,
                                       int *nresults) {
   int status;
@@ -709,10 +738,7 @@ LUA_API int lua_resume (lua_State *L, lua_State *from, int nargs,
   api_checknelems(L, (L->status == LUA_OK) ? nargs + 1 : nargs);
   status = luaD_rawrunprotected(L, resume, &nargs);
    /* continue running after recoverable errors */
-  while (errorstatus(status) && recover(L, status)) {
-    /* unroll continuation */
-    status = luaD_rawrunprotected(L, unroll, &status);
-  }
+  status = p_recover(L, status);
   if (likely(!errorstatus(status)))
     lua_assert(status == L->status);  /* normal end or yield */
   else {  /* unrecoverable error */

+ 39 - 2
testes/coroutine.lua

@@ -123,7 +123,7 @@ assert(#a == 22 and a[#a] == 79)
 x, a = nil
 
 
--- coroutine closing
+print("to-be-closed variables in coroutines")
 
 local function func2close (f)
   return setmetatable({}, {__close = f})
@@ -189,7 +189,6 @@ do
   local st, msg = coroutine.close(co)
   assert(st == false and coroutine.status(co) == "dead" and msg == 200)
   assert(x == 200)
-
 end
 
 do
@@ -207,6 +206,44 @@ do
   local st1, st2, err = coroutine.resume(co)
   assert(st1 and not st2 and err == 43)
   assert(X == 43 and Y.name == "pcall")
+
+  -- recovering from errors in __close metamethods
+  local track = {}
+
+  local function h (o)
+    local hv <close> = o
+    return 1
+  end
+
+  local function foo ()
+    local x <close> = func2close(function(_,msg)
+      track[#track + 1] = msg or false
+      error(20)
+    end)
+    local y <close> = func2close(function(_,msg)
+      track[#track + 1] = msg or false
+      return 1000
+    end)
+    local z <close> = func2close(function(_,msg)
+      track[#track + 1] = msg or false
+      error(10)
+    end)
+    coroutine.yield(1)
+    h(func2close(function(_,msg)
+        track[#track + 1] = msg or false
+        error(2)
+      end))
+  end
+
+  local co = coroutine.create(pcall)
+
+  local st, res = coroutine.resume(co, foo)    -- call 'foo' protected
+  assert(st and res == 1)   -- yield 1
+  local st, res1, res2 = coroutine.resume(co)   -- continue
+  assert(coroutine.status(co) == "dead")
+  assert(st and not res1 and res2 == 20)   -- last error (20)
+  assert(track[1] == false and track[2] == 2 and track[3] == 10 and
+         track[4] == 10)
 end