浏览代码

A coroutine can close itself

A call to close itself will close all its to-be-closed variables and
return to the resume that (re)started the coroutine.
Roberto Ierusalimschy 1 月之前
父节点
当前提交
fd897027f1
共有 6 个文件被更改,包括 103 次插入21 次删除
  1. 11 2
      lcorolib.c
  2. 10 0
      ldo.c
  3. 1 0
      ldo.h
  4. 2 0
      lstate.c
  5. 26 10
      manual/manual.of
  6. 53 9
      testes/coroutine.lua

+ 11 - 2
lcorolib.c

@@ -154,8 +154,13 @@ static int luaB_costatus (lua_State *L) {
 }
 }
 
 
 
 
+static lua_State *getoptco (lua_State *L) {
+  return (lua_isnone(L, 1) ? L : getco(L));
+}
+
+
 static int luaB_yieldable (lua_State *L) {
 static int luaB_yieldable (lua_State *L) {
-  lua_State *co = lua_isnone(L, 1) ? L : getco(L);
+  lua_State *co = getoptco(L);
   lua_pushboolean(L, lua_isyieldable(co));
   lua_pushboolean(L, lua_isyieldable(co));
   return 1;
   return 1;
 }
 }
@@ -169,7 +174,7 @@ static int luaB_corunning (lua_State *L) {
 
 
 
 
 static int luaB_close (lua_State *L) {
 static int luaB_close (lua_State *L) {
-  lua_State *co = getco(L);
+  lua_State *co = getoptco(L);
   int status = auxstatus(L, co);
   int status = auxstatus(L, co);
   switch (status) {
   switch (status) {
     case COS_DEAD: case COS_YIELD: {
     case COS_DEAD: case COS_YIELD: {
@@ -184,6 +189,10 @@ static int luaB_close (lua_State *L) {
         return 2;
         return 2;
       }
       }
     }
     }
+    case COS_RUN:  /* running coroutine? */
+      lua_closethread(co, L);  /* close itself */
+      lua_assert(0);  /* previous call does not return */
+      return 0;
     default:  /* normal or running coroutine */
     default:  /* normal or running coroutine */
       return luaL_error(L, "cannot close a %s coroutine", statname[status]);
       return luaL_error(L, "cannot close a %s coroutine", statname[status]);
   }
   }

+ 10 - 0
ldo.c

@@ -139,6 +139,16 @@ l_noret luaD_throw (lua_State *L, TStatus errcode) {
 }
 }
 
 
 
 
+l_noret luaD_throwbaselevel (lua_State *L, TStatus errcode) {
+  if (L->errorJmp) {
+    /* unroll error entries up to the first level */
+    while (L->errorJmp->previous != NULL)
+      L->errorJmp = L->errorJmp->previous;
+  }
+  luaD_throw(L, errcode);
+}
+
+
 TStatus luaD_rawrunprotected (lua_State *L, Pfunc f, void *ud) {
 TStatus luaD_rawrunprotected (lua_State *L, Pfunc f, void *ud) {
   l_uint32 oldnCcalls = L->nCcalls;
   l_uint32 oldnCcalls = L->nCcalls;
   struct lua_longjmp lj;
   struct lua_longjmp lj;

+ 1 - 0
ldo.h

@@ -91,6 +91,7 @@ LUAI_FUNC void luaD_shrinkstack (lua_State *L);
 LUAI_FUNC void luaD_inctop (lua_State *L);
 LUAI_FUNC void luaD_inctop (lua_State *L);
 
 
 LUAI_FUNC l_noret luaD_throw (lua_State *L, TStatus errcode);
 LUAI_FUNC l_noret luaD_throw (lua_State *L, TStatus errcode);
+LUAI_FUNC l_noret luaD_throwbaselevel (lua_State *L, TStatus errcode);
 LUAI_FUNC TStatus luaD_rawrunprotected (lua_State *L, Pfunc f, void *ud);
 LUAI_FUNC TStatus luaD_rawrunprotected (lua_State *L, Pfunc f, void *ud);
 
 
 #endif
 #endif

+ 2 - 0
lstate.c

@@ -326,6 +326,8 @@ LUA_API int lua_closethread (lua_State *L, lua_State *from) {
   lua_lock(L);
   lua_lock(L);
   L->nCcalls = (from) ? getCcalls(from) : 0;
   L->nCcalls = (from) ? getCcalls(from) : 0;
   status = luaE_resetthread(L, L->status);
   status = luaE_resetthread(L, L->status);
+  if (L == from)  /* closing itself? */
+    luaD_throwbaselevel(L, status);
   lua_unlock(L);
   lua_unlock(L);
   return APIstatus(status);
   return APIstatus(status);
 }
 }

+ 26 - 10
manual/manual.of

@@ -3267,17 +3267,25 @@ when called through this function.
 
 
 Resets a thread, cleaning its call stack and closing all pending
 Resets a thread, cleaning its call stack and closing all pending
 to-be-closed variables.
 to-be-closed variables.
-Returns a status code:
+The parameter @id{from} represents the coroutine that is resetting @id{L}.
+If there is no such coroutine,
+this parameter can be @id{NULL}.
+
+Unless @id{L} is equal to @id{from},
+the call returns a status code:
 @Lid{LUA_OK} for no errors in the thread
 @Lid{LUA_OK} for no errors in the thread
 (either the original error that stopped the thread or
 (either the original error that stopped the thread or
 errors in closing methods),
 errors in closing methods),
 or an error status otherwise.
 or an error status otherwise.
 In case of error,
 In case of error,
-leaves the error object on the top of the stack.
+the error object is put on the top of the stack.
 
 
-The parameter @id{from} represents the coroutine that is resetting @id{L}.
-If there is no such coroutine,
-this parameter can be @id{NULL}.
+If @id{L} is equal to @id{from},
+it corresponds to a thread closing itself.
+In that case,
+the call does not return;
+instead, the resume or the protected call
+that (re)started the thread returns.
 
 
 }
 }
 
 
@@ -6939,18 +6947,26 @@ which come inside the table @defid{coroutine}.
 See @See{coroutine} for a general description of coroutines.
 See @See{coroutine} for a general description of coroutines.
 
 
 
 
-@LibEntry{coroutine.close (co)|
+@LibEntry{coroutine.close ([co])|
 
 
 Closes coroutine @id{co},
 Closes coroutine @id{co},
 that is,
 that is,
 closes all its pending to-be-closed variables
 closes all its pending to-be-closed variables
 and puts the coroutine in a dead state.
 and puts the coroutine in a dead state.
-The given coroutine must be dead or suspended.
-In case of error
+The default for @id{co} is the running coroutine.
+
+The given coroutine must be dead, suspended,
+or be the running coroutine.
+For the running coroutine,
+this function does not return.
+Instead, the resume that (re)started the coroutine returns.
+
+For other coroutines,
+in case of error
 (either the original error that stopped the coroutine or
 (either the original error that stopped the coroutine or
 errors in closing methods),
 errors in closing methods),
-returns @false plus the error object;
-otherwise returns @true.
+this function returns @false plus the error object;
+otherwise ir returns @true.
 
 
 }
 }
 
 

+ 53 - 9
testes/coroutine.lua

@@ -156,11 +156,6 @@ do
   st, msg = coroutine.close(co)
   st, msg = coroutine.close(co)
   assert(st and msg == nil)
   assert(st and msg == nil)
 
 
-
-  -- cannot close the running coroutine
-  local st, msg = pcall(coroutine.close, coroutine.running())
-  assert(not st and string.find(msg, "running"))
-
   local main = coroutine.running()
   local main = coroutine.running()
 
 
   -- cannot close a "normal" coroutine
   -- cannot close a "normal" coroutine
@@ -169,20 +164,19 @@ do
     assert(not st and string.find(msg, "normal"))
     assert(not st and string.find(msg, "normal"))
   end))()
   end))()
 
 
-  -- cannot close a coroutine while closing it
-  do
+  do   -- close a coroutine while closing it
     local co
     local co
     co = coroutine.create(
     co = coroutine.create(
       function()
       function()
         local x <close> = func2close(function()
         local x <close> = func2close(function()
-            coroutine.close(co)   -- try to close it again
+            coroutine.close(co)   -- close it again
          end)
          end)
         coroutine.yield(20)
         coroutine.yield(20)
       end)
       end)
     local st, msg = coroutine.resume(co)
     local st, msg = coroutine.resume(co)
     assert(st and msg == 20)
     assert(st and msg == 20)
     st, msg = coroutine.close(co)
     st, msg = coroutine.close(co)
-    assert(not st and string.find(msg, "running coroutine"))
+    assert(st and msg == nil)
   end
   end
 
 
   -- to-be-closed variables in coroutines
   -- to-be-closed variables in coroutines
@@ -289,6 +283,56 @@ do
 end
 end
 
 
 
 
+do print("coroutines closing itself")
+  global <const> coroutine, string, os
+  global <const> assert, error, pcall
+
+  local X = nil
+
+  local function new ()
+    return coroutine.create(function (what)
+
+      local <close>var = func2close(function (t, err)
+        if what == "yield" then
+          coroutine.yield()
+        elseif what == "error" then
+          error(200)
+        else
+          X = "Ok"
+          return X
+        end
+      end)
+
+      -- do an unprotected call so that coroutine becomes non-yieldable
+      string.gsub("a", "a", function ()
+        assert(not coroutine.isyieldable())
+        -- do protected calls while non-yieldable, to add recovery
+        -- entries (setjmp) to the stack
+        assert(pcall(pcall, function ()
+          -- 'close' works even while non-yieldable
+          coroutine.close()   -- close itself
+          os.exit(false)   -- not reacheable
+        end))
+      end)
+    end)
+  end
+
+  local co = new()
+  local st, msg = coroutine.resume(co, "ret")
+  assert(st and msg == nil)
+  assert(X == "Ok")
+
+  local co = new()
+  local st, msg = coroutine.resume(co, "error")
+  assert(not st and msg == 200)
+
+  local co = new()
+  local st, msg = coroutine.resume(co, "yield")
+  assert(not st and string.find(msg, "attempt to yield"))
+
+end
+
+
 -- yielding across C boundaries
 -- yielding across C boundaries
 
 
 local co = coroutine.wrap(function()
 local co = coroutine.wrap(function()