Browse Source

Share resume/yield core implementations

Akeit0 7 months ago
parent
commit
e60b59b777
1 changed files with 69 additions and 134 deletions
  1. 69 134
      src/Lua/LuaCoroutine.cs

+ 69 - 134
src/Lua/LuaCoroutine.cs

@@ -75,131 +75,26 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
 
 
     public ValueTask<int> ResumeAsync(LuaStack stack, CancellationToken cancellationToken = default)
     public ValueTask<int> ResumeAsync(LuaStack stack, CancellationToken cancellationToken = default)
     {
     {
-        return ResumeAsync(stack, stack.Count, 0, cancellationToken);
+        return ResumeAsyncCore(stack, stack.Count, 0, null, cancellationToken);
     }
     }
+
+    public override ValueTask<int> ResumeAsync(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
+    {
+        return ResumeAsyncCore(context.Thread.Stack, context.ArgumentCount, context.ReturnFrameBase, context.Thread, cancellationToken);
+    }
+
 #if NET6_0_OR_GREATER
 #if NET6_0_OR_GREATER
     [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
     [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
 #endif
 #endif
-    public async ValueTask<int> ResumeAsync(LuaStack stack, int argCount, int returnBase, CancellationToken cancellationToken = default)
+    async ValueTask<int> ResumeAsyncCore(LuaStack stack, int argCount, int returnBase, LuaThread? baseThread, CancellationToken cancellationToken = default)
     {
     {
-        if (isFirstCall)
-        {
-            ThrowIfRunning();
-            IsRunning = true;
-        }
-
-        switch ((LuaThreadStatus)Volatile.Read(ref status))
-        {
-            case LuaThreadStatus.Suspended:
-                Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
-
-                if (!isFirstCall)
-                {
-                    yield.SetResult(new(stack, argCount));
-                }
-
-                break;
-            case LuaThreadStatus.Normal:
-            case LuaThreadStatus.Running:
-                if (IsProtectedMode)
-                {
-                    stack.PopUntil(returnBase);
-                    stack.Push(false);
-                    stack.Push("cannot resume non-suspended coroutine");
-                    return 2;
-                }
-                else
-                {
-                    throw new LuaException("cannot resume non-suspended coroutine");
-                }
-            case LuaThreadStatus.Dead:
-                if (IsProtectedMode)
-                {
-                    stack.PopUntil(returnBase);
-                    stack.Push(false);
-                    stack.Push("cannot resume non-suspended coroutine");
-                    return 2;
-                }
-                else
-                {
-                    throw new LuaException("cannot resume dead coroutine");
-                }
-        }
-
-        var resumeTask = new ValueTask<ResumeContext>(this, resume.Version);
-
-        CancellationTokenRegistration registration = default;
-        if (cancellationToken.CanBeCanceled)
-        {
-            registration = cancellationToken.UnsafeRegister(static x =>
-            {
-                var coroutine = (LuaCoroutine)x!;
-                coroutine.yield.SetException(new OperationCanceledException());
-            }, this);
-        }
-
-        try
+        if (baseThread != null)
         {
         {
-            if (isFirstCall)
-            {
-                Stack.PushRange(stack.AsSpan()[^argCount..]);
-                functionTask = Function.InvokeAsync(new() { Thread = this, ArgumentCount = Stack.Count, ReturnFrameBase = 0 }, cancellationToken).Preserve();
-
-                Volatile.Write(ref isFirstCall, false);
-            }
+            baseThread.UnsafeSetStatus(LuaThreadStatus.Normal);
 
 
-            var (index, result0, _, promise) = await ValueTaskEx.WhenAnyPooled(resumeTask, functionTask!);
-            promise.Dispose();
-            if (index == 0)
-            {
-                var results = result0.Results;
-                stack.PopUntil(returnBase);
-                stack.Push(true);
-                stack.PushRange(results);
-                return results.Length + 1;
-            }
-            else
-            {
-                Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
-                stack.PopUntil(returnBase);
-                stack.Push(true);
-                stack.PushRange(Stack.AsSpan());
-                ReleaseCore();
-                return stack.Count - returnBase;
-            }
-        }
-        catch (Exception ex) when (ex is not OperationCanceledException)
-        {
-            if (IsProtectedMode)
-            {
-                traceback = (ex as LuaRuntimeException)?.LuaTraceback;
-                Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
-                ReleaseCore();
-                stack.PopUntil(returnBase);
-                stack.Push(false);
-                stack.Push(ex is LuaRuntimeException luaEx ? luaEx.ErrorObject : ex.Message);
-                return 2;
-            }
-            else
-            {
-                throw;
-            }
+            baseThread.State.ThreadStack.Push(this);
         }
         }
-        finally
-        {
-            registration.Dispose();
-            resume.Reset();
-        }
-    }
-#if NET6_0_OR_GREATER
-    [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
-#endif
-    public override async ValueTask<int> ResumeAsync(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
-    {
-        var baseThread = context.Thread;
-        baseThread.UnsafeSetStatus(LuaThreadStatus.Normal);
 
 
-        context.State.ThreadStack.Push(this);
         try
         try
         {
         {
             switch ((LuaThreadStatus)Volatile.Read(ref status))
             switch ((LuaThreadStatus)Volatile.Read(ref status))
@@ -209,7 +104,7 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
 
 
                     if (!isFirstCall)
                     if (!isFirstCall)
                     {
                     {
-                        yield.SetResult(new(context.Thread.Stack, context.ArgumentCount));
+                        yield.SetResult(new(stack, argCount));
                     }
                     }
 
 
                     break;
                     break;
@@ -217,20 +112,28 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
                 case LuaThreadStatus.Running:
                 case LuaThreadStatus.Running:
                     if (IsProtectedMode)
                     if (IsProtectedMode)
                     {
                     {
-                        return context.Return(false, "cannot resume non-suspended coroutine");
+                        stack.PopUntil(returnBase);
+                        stack.Push(false);
+                        stack.Push("cannot resume non-suspended coroutine");
+                        return 2;
                     }
                     }
                     else
                     else
                     {
                     {
-                        throw new LuaRuntimeException(context.Thread.GetTraceback(), "cannot resume non-suspended coroutine");
+                        if (baseThread != null) throw new LuaRuntimeException(baseThread.GetTraceback(), "cannot resume non-suspended coroutine");
+                        else throw new LuaException("cannot resume non-suspended coroutine");
                     }
                     }
                 case LuaThreadStatus.Dead:
                 case LuaThreadStatus.Dead:
                     if (IsProtectedMode)
                     if (IsProtectedMode)
                     {
                     {
-                        return context.Return(false, "cannot resume dead coroutine");
+                        stack.PopUntil(returnBase);
+                        stack.Push(false);
+                        stack.Push("cannot resume dead coroutine");
+                        return 2;
                     }
                     }
                     else
                     else
                     {
                     {
-                        throw new LuaRuntimeException(context.Thread.GetTraceback(), "cannot resume dead coroutine");
+                        if (baseThread != null) throw new LuaRuntimeException(baseThread.GetTraceback(), "cannot resume dead coroutine");
+                        else throw new LuaException("cannot resume dead coroutine");
                     }
                     }
             }
             }
 
 
@@ -250,7 +153,7 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
             {
             {
                 if (isFirstCall)
                 if (isFirstCall)
                 {
                 {
-                    Stack.PushRange(context.Arguments);
+                    Stack.PushRange(stack.AsSpan()[^argCount..]);
                     functionTask = Function.InvokeAsync(new() { Thread = this, ArgumentCount = Stack.Count, ReturnFrameBase = 0 }, cancellationToken).Preserve();
                     functionTask = Function.InvokeAsync(new() { Thread = this, ArgumentCount = Stack.Count, ReturnFrameBase = 0 }, cancellationToken).Preserve();
 
 
                     Volatile.Write(ref isFirstCall, false);
                     Volatile.Write(ref isFirstCall, false);
@@ -261,14 +164,19 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
                 if (index == 0)
                 if (index == 0)
                 {
                 {
                     var results = result0.Results;
                     var results = result0.Results;
-                    return context.Return(true, results);
+                    stack.PopUntil(returnBase);
+                    stack.Push(true);
+                    stack.PushRange(results);
+                    return results.Length + 1;
                 }
                 }
                 else
                 else
                 {
                 {
                     Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
                     Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
-                    var count = context.Return(true, Stack.AsSpan());
+                    stack.PopUntil(returnBase);
+                    stack.Push(true);
+                    stack.PushRange(Stack.AsSpan());
                     ReleaseCore();
                     ReleaseCore();
-                    return count;
+                    return stack.Count - returnBase;
                 }
                 }
             }
             }
             catch (Exception ex) when (ex is not OperationCanceledException)
             catch (Exception ex) when (ex is not OperationCanceledException)
@@ -278,7 +186,10 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
                     traceback = (ex as LuaRuntimeException)?.LuaTraceback;
                     traceback = (ex as LuaRuntimeException)?.LuaTraceback;
                     Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
                     Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
                     ReleaseCore();
                     ReleaseCore();
-                    return context.Return(false, ex is LuaRuntimeException luaEx ? luaEx.ErrorObject : ex.Message);
+                    stack.PopUntil(returnBase);
+                    stack.Push(false);
+                    stack.Push(ex is LuaRuntimeException luaEx ? luaEx.ErrorObject : ex.Message);
+                    return 2;
                 }
                 }
                 else
                 else
                 {
                 {
@@ -293,26 +204,48 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
         }
         }
         finally
         finally
         {
         {
-            context.State.ThreadStack.Pop();
-            baseThread.UnsafeSetStatus(LuaThreadStatus.Running);
+            if (baseThread != null)
+            {
+                baseThread.State.ThreadStack.Pop();
+                baseThread.UnsafeSetStatus(LuaThreadStatus.Running);
+            }
         }
         }
     }
     }
+
+    public ValueTask<int> YieldAsync(LuaStack stack, CancellationToken cancellationToken = default)
+    {
+        return YieldAsyncCore(stack, stack.Count, 0, null, cancellationToken);
+    }
+
+    public override ValueTask<int> YieldAsync(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
+    {
+        return YieldAsyncCore(context.Thread.Stack, context.ArgumentCount, context.ReturnFrameBase, context.Thread, cancellationToken);
+    }
+
 #if NET6_0_OR_GREATER
 #if NET6_0_OR_GREATER
     [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
     [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
 #endif
 #endif
-    public override async ValueTask<int> YieldAsync(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
+    async ValueTask<int> YieldAsyncCore(LuaStack stack, int argCount, int returnBase, LuaThread? baseThread, CancellationToken cancellationToken = default)
     {
     {
         if (Volatile.Read(ref status) != (byte)LuaThreadStatus.Running)
         if (Volatile.Read(ref status) != (byte)LuaThreadStatus.Running)
         {
         {
-            throw new LuaRuntimeException(context.Thread.GetTraceback(), "cannot call yield on a coroutine that is not currently running");
+            if (baseThread != null)
+            {
+                throw new LuaRuntimeException(baseThread.GetTraceback(), "cannot yield from a non-running coroutine");
+            }
+
+            throw new LuaException("cannot call yield on a coroutine that is not currently running");
         }
         }
 
 
-        if (context.Thread.GetCallStackFrames()[^2].Function is not LuaClosure)
+        if (baseThread != null)
         {
         {
-            throw new LuaRuntimeException(context.Thread.GetTraceback(), "attempt to yield across a C#-call boundary");
+            if (baseThread.GetCallStackFrames()[^2].Function is not LuaClosure)
+            {
+                throw new LuaRuntimeException(baseThread.GetTraceback(), "attempt to yield across a C#-call boundary");
+            }
         }
         }
 
 
-        resume.SetResult(new(context.Thread.Stack, context.ArgumentCount));
+        resume.SetResult(new(stack, argCount));
 
 
         Volatile.Write(ref status, (byte)LuaThreadStatus.Suspended);
         Volatile.Write(ref status, (byte)LuaThreadStatus.Suspended);
 
 
@@ -330,7 +263,9 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
         try
         try
         {
         {
             var result = await new ValueTask<YieldContext>(this, yield.Version);
             var result = await new ValueTask<YieldContext>(this, yield.Version);
-            return (context.Return(result.Results));
+            stack.PopUntil(returnBase);
+            stack.PushRange(result.Results);
+            return (result.Results).Length;
         }
         }
         catch (Exception ex) when (ex is not OperationCanceledException)
         catch (Exception ex) when (ex is not OperationCanceledException)
         {
         {