Browse Source

'__close' methods can yield in the return of a C function

When, inside a coroutine, a C function with to-be-closed slots return,
the corresponding metamethods can yield. ('__close' metamethods called
through 'lua_closeslot' still cannot yield, as there is no continuation
to go when resuming.)
Roberto Ierusalimschy 4 years ago
parent
commit
bc970005ce
5 changed files with 131 additions and 34 deletions
  1. 2 0
      lapi.h
  2. 43 29
      ldo.c
  3. 8 4
      lstate.h
  4. 3 0
      manual/manual.of
  5. 75 1
      testes/locals.lua

+ 2 - 0
lapi.h

@@ -42,6 +42,8 @@
 
 
 #define hastocloseCfunc(n)	((n) < LUA_MULTRET)
 #define hastocloseCfunc(n)	((n) < LUA_MULTRET)
 
 
+/* Map [-1, inf) (range of 'nresults') into (-inf, -2] */
 #define codeNresults(n)		(-(n) - 3)
 #define codeNresults(n)		(-(n) - 3)
+#define decodeNresults(n)	(-(n) - 3)
 
 
 #endif
 #endif

+ 43 - 29
ldo.c

@@ -408,24 +408,27 @@ static void moveresults (lua_State *L, StkId res, int nres, int wanted) {
     case LUA_MULTRET:
     case LUA_MULTRET:
       wanted = nres;  /* we want all results */
       wanted = nres;  /* we want all results */
       break;
       break;
-    default:  /* multiple results (or to-be-closed variables) */
+    default:  /* two/more results and/or to-be-closed variables */
       if (hastocloseCfunc(wanted)) {  /* to-be-closed variables? */
       if (hastocloseCfunc(wanted)) {  /* to-be-closed variables? */
         ptrdiff_t savedres = savestack(L, res);
         ptrdiff_t savedres = savestack(L, res);
-        luaF_close(L, res, CLOSEKTOP, 0);  /* may change the stack */
-        wanted = codeNresults(wanted);  /* correct value */
-        if (wanted == LUA_MULTRET)
-          wanted = nres;
+        L->ci->callstatus |= CIST_CLSRET;  /* in case of yields */
+        L->ci->u2.nres = nres;
+        luaF_close(L, res, CLOSEKTOP, 1);
+        L->ci->callstatus &= ~CIST_CLSRET;
         if (L->hookmask)  /* if needed, call hook after '__close's */
         if (L->hookmask)  /* if needed, call hook after '__close's */
           rethook(L, L->ci, nres);
           rethook(L, L->ci, nres);
         res = restorestack(L, savedres);  /* close and hook can move stack */
         res = restorestack(L, savedres);  /* close and hook can move stack */
+        wanted = decodeNresults(wanted);
+        if (wanted == LUA_MULTRET)
+          wanted = nres;  /* we want all results */
       }
       }
       break;
       break;
   }
   }
+  /* generic case */
   firstresult = L->top - nres;  /* index of first result */
   firstresult = L->top - nres;  /* index of first result */
-  /* move all results to correct place */
-  if (nres > wanted)
-    nres = wanted;  /* don't need more than that */
-  for (i = 0; i < nres; i++)
+  if (nres > wanted)  /* extra results? */
+    nres = wanted;  /* don't need them */
+  for (i = 0; i < nres; i++)  /* move all results to correct place */
     setobjs2s(L, res + i, firstresult + i);
     setobjs2s(L, res + i, firstresult + i);
   for (; i < wanted; i++)  /* complete wanted number of results */
   for (; i < wanted; i++)  /* complete wanted number of results */
     setnilvalue(s2v(res + i));
     setnilvalue(s2v(res + i));
@@ -445,6 +448,9 @@ void luaD_poscall (lua_State *L, CallInfo *ci, int nres) {
     rethook(L, ci, nres);
     rethook(L, ci, nres);
   /* move results to proper place */
   /* move results to proper place */
   moveresults(L, ci->func, nres, wanted);
   moveresults(L, ci->func, nres, wanted);
+  /* function cannot be in any of these cases when returning */
+  lua_assert(!(ci->callstatus &
+        (CIST_HOOKED | CIST_YPCALL | CIST_FIN | CIST_TRAN | CIST_CLSRET)));
   L->ci = ci->previous;  /* back to caller (after closing variables) */
   L->ci = ci->previous;  /* back to caller (after closing variables) */
 }
 }
 
 
@@ -615,28 +621,36 @@ static int finishpcallk (lua_State *L,  CallInfo *ci) {
 
 
 /*
 /*
 ** Completes the execution of a C function interrupted by an yield.
 ** Completes the execution of a C function interrupted by an yield.
-** The interruption must have happened while the function was
-** executing 'lua_callk' or 'lua_pcallk'. In the second case, the
-** call to 'finishpcallk' finishes the interrupted execution of
-** 'lua_pcallk'. After that, it calls the continuation of the
-** interrupted function and finally it completes the job of the
-** 'luaD_call' that called the function.
-** In the call to 'adjustresults', we do not know the number of
-** results of the function called by 'lua_callk'/'lua_pcallk',
-** so we are conservative and use LUA_MULTRET (always adjust).
+** The interruption must have happened while the function was either
+** closing its tbc variables in 'moveresults' or executing
+** 'lua_callk'/'lua_pcallk'. In the first case, it just redoes
+** 'luaD_poscall'. In the second case, the call to 'finishpcallk'
+** finishes the interrupted execution of 'lua_pcallk'.  After that, it
+** calls the continuation of the interrupted function and finally it
+** completes the job of the 'luaD_call' that called the function.  In
+** the call to 'adjustresults', we do not know the number of results
+** of the function called by 'lua_callk'/'lua_pcallk', so we are
+** conservative and use LUA_MULTRET (always adjust).
 */
 */
 static void finishCcall (lua_State *L, CallInfo *ci) {
 static void finishCcall (lua_State *L, CallInfo *ci) {
-  int n;
-  int status = LUA_YIELD;  /* default if there were no errors */
-  /* must have a continuation and must be able to call it */
-  lua_assert(ci->u.c.k != NULL && yieldable(L));
-  if (ci->callstatus & CIST_YPCALL)   /* was inside a 'lua_pcallk'? */
-    status = finishpcallk(L, ci);  /* finish it */
-  adjustresults(L, LUA_MULTRET);  /* finish 'lua_callk' */
-  lua_unlock(L);
-  n = (*ci->u.c.k)(L, status, ci->u.c.ctx);  /* call continuation */
-  lua_lock(L);
-  api_checknelems(L, n);
+  int n;  /* actual number of results from C function */
+  if (ci->callstatus & CIST_CLSRET) {  /* was returning? */
+    lua_assert(hastocloseCfunc(ci->nresults));
+    n = ci->u2.nres;  /* just redo 'luaD_poscall' */
+    /* don't need to reset CIST_CLSRET, as it will be set again anyway */
+  }
+  else {
+    int status = LUA_YIELD;  /* default if there were no errors */
+    /* must have a continuation and must be able to call it */
+    lua_assert(ci->u.c.k != NULL && yieldable(L));
+    if (ci->callstatus & CIST_YPCALL)   /* was inside a 'lua_pcallk'? */
+      status = finishpcallk(L, ci);  /* finish it */
+    adjustresults(L, LUA_MULTRET);  /* finish 'lua_callk' */
+    lua_unlock(L);
+    n = (*ci->u.c.k)(L, status, ci->u.c.ctx);  /* call continuation */
+    lua_lock(L);
+    api_checknelems(L, n);
+  }
   luaD_poscall(L, ci, n);  /* finish 'luaD_call' */
   luaD_poscall(L, ci, n);  /* finish 'luaD_call' */
 }
 }
 
 

+ 8 - 4
lstate.h

@@ -164,6 +164,8 @@ typedef struct stringtable {
 ** protected call;
 ** protected call;
 ** - field 'nyield' is used only while a function is "doing" an
 ** - field 'nyield' is used only while a function is "doing" an
 ** yield (from the yield until the next resume);
 ** yield (from the yield until the next resume);
+** - field 'nres' is used only while closing tbc variables when
+** returning from a C function;
 ** - field 'transferinfo' is used only during call/returnhooks,
 ** - field 'transferinfo' is used only during call/returnhooks,
 ** before the function starts or after it ends.
 ** before the function starts or after it ends.
 */
 */
@@ -186,6 +188,7 @@ typedef struct CallInfo {
   union {
   union {
     int funcidx;  /* called-function index */
     int funcidx;  /* called-function index */
     int nyield;  /* number of values yielded */
     int nyield;  /* number of values yielded */
+    int nres;  /* number of values returned */
     struct {  /* info about transferred values (for call/return hooks) */
     struct {  /* info about transferred values (for call/return hooks) */
       unsigned short ftransfer;  /* offset of first value transferred */
       unsigned short ftransfer;  /* offset of first value transferred */
       unsigned short ntransfer;  /* number of values transferred */
       unsigned short ntransfer;  /* number of values transferred */
@@ -203,15 +206,16 @@ typedef struct CallInfo {
 #define CIST_C		(1<<1)	/* call is running a C function */
 #define CIST_C		(1<<1)	/* call is running a C function */
 #define CIST_FRESH	(1<<2)	/* call is on a fresh "luaV_execute" frame */
 #define CIST_FRESH	(1<<2)	/* call is on a fresh "luaV_execute" frame */
 #define CIST_HOOKED	(1<<3)	/* call is running a debug hook */
 #define CIST_HOOKED	(1<<3)	/* call is running a debug hook */
-#define CIST_YPCALL	(1<<4)	/* call is a yieldable protected call */
+#define CIST_YPCALL	(1<<4)	/* doing a yieldable protected call */
 #define CIST_TAIL	(1<<5)	/* call was tail called */
 #define CIST_TAIL	(1<<5)	/* call was tail called */
 #define CIST_HOOKYIELD	(1<<6)	/* last hook called yielded */
 #define CIST_HOOKYIELD	(1<<6)	/* last hook called yielded */
 #define CIST_FIN	(1<<7)	/* call is running a finalizer */
 #define CIST_FIN	(1<<7)	/* call is running a finalizer */
 #define CIST_TRAN	(1<<8)	/* 'ci' has transfer information */
 #define CIST_TRAN	(1<<8)	/* 'ci' has transfer information */
-/* Bits 9-11 are used for CIST_RECST (see below) */
-#define CIST_RECST	9
+#define CIST_CLSRET	(1<<9)  /* function is closing tbc variables */
+/* Bits 10-12 are used for CIST_RECST (see below) */
+#define CIST_RECST	10
 #if defined(LUA_COMPAT_LT_LE)
 #if defined(LUA_COMPAT_LT_LE)
-#define CIST_LEQ	(1<<12)  /* using __lt for __le */
+#define CIST_LEQ	(1<<13)  /* using __lt for __le */
 #endif
 #endif
 
 
 
 

+ 3 - 0
manual/manual.of

@@ -3102,6 +3102,9 @@ Close the to-be-closed slot at the given index and set its value to @nil.
 The index must be the last index previously marked to be closed
 The index must be the last index previously marked to be closed
 @see{lua_toclose} that is still active (that is, not closed yet).
 @see{lua_toclose} that is still active (that is, not closed yet).
 
 
+A @Lid{__close} metamethod cannot yield
+when called through this function.
+
 (Exceptionally, this function was introduced in release 5.4.3.
 (Exceptionally, this function was introduced in release 5.4.3.
 It is not present in previous 5.4 releases.)
 It is not present in previous 5.4 releases.)
 
 

+ 75 - 1
testes/locals.lua

@@ -707,7 +707,6 @@ if rawget(_G, "T") then
     -- results are correct
     -- results are correct
     checktable(t, {10, 20})
     checktable(t, {10, 20})
   end
   end
-
 end
 end
 
 
 
 
@@ -930,6 +929,81 @@ assert(co == nil)    -- eventually it will be collected
 collectgarbage()
 collectgarbage()
 
 
 
 
+if rawget(_G, "T") then
+  print("to-be-closed variables x coroutines in C")
+  do
+    local token = 0
+    local count = 0
+    local f = T.makeCfunc[[
+      toclose 1
+      toclose 2
+      return .
+    ]]
+
+    local obj = func2close(function (_, msg)
+      count = count + 1
+      token = coroutine.yield(count, token)
+    end)
+
+    local co = coroutine.wrap(f)
+    local ct, res = co(obj, obj, 10, 20, 30, 3)   -- will return 10, 20, 30
+    -- initial token value, after closing 2nd obj
+    assert(ct == 1 and res == 0)
+    -- run until yield when closing 1st obj
+    ct, res = co(100)
+    assert(ct == 2 and res == 100)
+    res = {co(200)}      -- run until end
+    assert(res[1] == 10 and res[2] == 20 and res[3] == 30 and res[4] == nil)
+    assert(token == 200)
+  end
+
+  do
+    local f = T.makeCfunc[[
+      toclose 1
+      return .
+    ]]
+
+    local obj = func2close(function ()
+      local temp
+      local x <close> = func2close(function ()
+        coroutine.yield(temp)
+        return 1,2,3    -- to be ignored
+      end)
+      temp = coroutine.yield("closing obj")
+      return 1,2,3    -- to be ignored
+    end)
+
+    local co = coroutine.wrap(f)
+    local res = co(obj, 10, 30, 1)   -- will return only 30
+    assert(res == "closing obj")
+    res = co("closing x")
+    assert(res == "closing x")
+    res = {co()}
+    assert(res[1] == 30 and res[2] == nil)
+  end
+
+  do
+    -- still cannot yield inside 'closeslot'
+    local f = T.makeCfunc[[
+      toclose 1
+      closeslot 1
+    ]]
+    local obj = func2close(coroutine.yield)
+    local co = coroutine.create(f)
+    local st, msg = coroutine.resume(co, obj)
+    assert(not st and string.find(msg, "attempt to yield across"))
+
+    -- nor outside a coroutine
+    local f = T.makeCfunc[[
+      toclose 1
+    ]]
+    local st, msg = pcall(f, obj)
+    assert(not st and string.find(msg, "attempt to yield from outside"))
+  end
+end
+
+
+
 -- to-be-closed variables in generic for loops
 -- to-be-closed variables in generic for loops
 do
 do
   local numopen = 0
   local numopen = 0