Browse Source

Add: UnsafeSetStatus

AnnulusGames 1 year ago
parent
commit
b73b52b3bd
4 changed files with 42 additions and 18 deletions
  1. 35 17
      src/Lua/LuaCoroutine.cs
  2. 5 0
      src/Lua/LuaMainThread.cs
  3. 1 0
      src/Lua/LuaThread.cs
  4. 1 1
      src/Lua/LuaThreadStatus.cs

+ 35 - 17
src/Lua/LuaCoroutine.cs

@@ -15,6 +15,7 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
     }
 
     byte status;
+    bool isFirstCall = true;
     LuaState threadState;
     ValueTask<int> functionTask;
 
@@ -22,6 +23,12 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
     ManualResetValueTaskSourceCore<YieldContext> yield;
 
     public override LuaThreadStatus GetStatus() => (LuaThreadStatus)status;
+
+    public override void UnsafeSetStatus(LuaThreadStatus status)
+    {
+        this.status = (byte)status;
+    }
+
     public bool IsProtectedMode { get; }
     public LuaFunction Function { get; }
 
@@ -34,33 +41,33 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
 
     public override async ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     {
+        var baseThread = context.State.CurrentThread;
+        baseThread.UnsafeSetStatus(LuaThreadStatus.Normal);
+
         context.State.ThreadStack.Push(this);
         try
         {
             switch ((LuaThreadStatus)Volatile.Read(ref status))
             {
-                case LuaThreadStatus.Normal:
+                case LuaThreadStatus.Suspended:
                     Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
-
-                    // first argument is LuaThread object
-                    for (int i = 0; i < context.ArgumentCount - 1; i++)
+                    
+                    if (isFirstCall)
                     {
-                        threadState.Push(context.Arguments[i + 1]);
+                        functionTask = Function.InvokeAsync(new()
+                        {
+                            State = threadState,
+                            ArgumentCount = context.ArgumentCount - 1,
+                            ChunkName = Function.Name,
+                            RootChunkName = context.RootChunkName,
+                        }, buffer[1..], cancellationToken).Preserve();
                     }
-
-                    functionTask = Function.InvokeAsync(new()
+                    else
                     {
-                        State = threadState,
-                        ArgumentCount = context.ArgumentCount - 1,
-                        ChunkName = Function.Name,
-                        RootChunkName = context.RootChunkName,
-                    }, buffer[1..], cancellationToken).Preserve();
-
-                    break;
-                case LuaThreadStatus.Suspended:
-                    Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
-                    yield.SetResult(new());
+                        yield.SetResult(new());
+                    }
                     break;
+                case LuaThreadStatus.Normal:
                 case LuaThreadStatus.Running:
                     if (IsProtectedMode)
                     {
@@ -99,6 +106,16 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
 
             try
             {
+                if (isFirstCall)
+                {
+                    for (int i = 0; i < context.ArgumentCount - 1; i++)
+                    {
+                        threadState.Push(context.Arguments[i + 1]);
+                    }
+
+                    Volatile.Write(ref isFirstCall, false);
+                }
+
                 (var index, var result0, var result1) = await ValueTaskEx.WhenAny(resumeTask, functionTask!);
 
                 if (index == 0)
@@ -143,6 +160,7 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
         finally
         {
             context.State.ThreadStack.Pop();
+            baseThread.UnsafeSetStatus(LuaThreadStatus.Running);
         }
     }
 

+ 5 - 0
src/Lua/LuaMainThread.cs

@@ -7,6 +7,11 @@ public sealed class LuaMainThread : LuaThread
         return LuaThreadStatus.Running;
     }
 
+    public override void UnsafeSetStatus(LuaThreadStatus status)
+    {
+        // Do nothing
+    }
+
     public override ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     {
         buffer.Span[0] = false;

+ 1 - 0
src/Lua/LuaThread.cs

@@ -6,6 +6,7 @@ namespace Lua;
 public abstract class LuaThread
 {
     public abstract LuaThreadStatus GetStatus();
+    public abstract void UnsafeSetStatus(LuaThreadStatus status);
     public abstract ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default);
     public abstract ValueTask Yield(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default);
 

+ 1 - 1
src/Lua/LuaThreadStatus.cs

@@ -2,8 +2,8 @@ namespace Lua;
 
 public enum LuaThreadStatus : byte
 {
-    Normal,
     Suspended,
+    Normal,
     Running,
     Dead,
 }