Browse Source

Change: LuaThread status to thread-safe

AnnulusGames 1 year ago
parent
commit
77b38ad49f
2 changed files with 41 additions and 44 deletions
  1. 40 43
      src/Lua/LuaThread.cs
  2. 1 1
      src/Lua/LuaThreadStatus.cs

+ 40 - 43
src/Lua/LuaThread.cs

@@ -14,14 +14,14 @@ public sealed class LuaThread : IValueTaskSource<LuaThread.YieldContext>, IValue
         public LuaValue[] Results;
         public LuaValue[] Results;
     }
     }
 
 
-    LuaThreadStatus status;
+    byte status;
     LuaState threadState;
     LuaState threadState;
     ValueTask<int> functionTask;
     ValueTask<int> functionTask;
 
 
     ManualResetValueTaskSourceCore<ResumeContext> resume;
     ManualResetValueTaskSourceCore<ResumeContext> resume;
     ManualResetValueTaskSourceCore<YieldContext> yield;
     ManualResetValueTaskSourceCore<YieldContext> yield;
 
 
-    public LuaThreadStatus Status => status;
+    public LuaThreadStatus Status => (LuaThreadStatus)status;
     public bool IsProtectedMode { get; }
     public bool IsProtectedMode { get; }
     public LuaFunction Function { get; }
     public LuaFunction Function { get; }
 
 
@@ -35,46 +35,43 @@ public sealed class LuaThread : IValueTaskSource<LuaThread.YieldContext>, IValue
 
 
     public async ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     public async ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     {
     {
-        if (status is LuaThreadStatus.Dead)
+        switch ((LuaThreadStatus)Volatile.Read(ref status))
         {
         {
-            if (IsProtectedMode)
-            {
-                buffer.Span[0] = false;
-                buffer.Span[1] = "cannot resume dead coroutine";
-                return 2;
-            }
-            else
-            {
-                throw new InvalidOperationException("cannot resume dead coroutine");
-            }
-        }
-        else if (status is LuaThreadStatus.Running)
-        {
-            throw new InvalidOperationException("cannot resume running coroutine");
-        }
-
-        if (status is LuaThreadStatus.Normal)
-        {
-            status = LuaThreadStatus.Running;
+            case LuaThreadStatus.Normal:
+                Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
 
 
-            // first argument is LuaThread object
-            for (int i = 0; i < context.ArgumentCount - 1; i++)
-            {
-                threadState.Push(context.Arguments[i + 1]);
-            }
+                // first argument is LuaThread object
+                for (int i = 0; i < context.ArgumentCount - 1; i++)
+                {
+                    threadState.Push(context.Arguments[i + 1]);
+                }
 
 
-            functionTask = Function.InvokeAsync(new()
-            {
-                State = threadState,
-                ArgumentCount = context.ArgumentCount - 1,
-                ChunkName = Function.Name,
-                RootChunkName = context.RootChunkName,
-            }, buffer[1..], cancellationToken).Preserve();
-        }
-        else
-        {
-            status = LuaThreadStatus.Running;
-            yield.SetResult(new());
+                functionTask = Function.InvokeAsync(new()
+                {
+                    State = threadState,
+                    ArgumentCount = context.ArgumentCount - 1,
+                    ChunkName = Function.Name,
+                    RootChunkName = context.RootChunkName,
+                }, buffer[1..], cancellationToken).Preserve();
+
+                break;
+            case LuaThreadStatus.Suspended:
+                Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
+                yield.SetResult(new());
+                break;
+            case LuaThreadStatus.Running:
+                throw new InvalidOperationException("cannot resume running coroutine");
+            case LuaThreadStatus.Dead:
+                if (IsProtectedMode)
+                {
+                    buffer.Span[0] = false;
+                    buffer.Span[1] = "cannot resume dead coroutine";
+                    return 2;
+                }
+                else
+                {
+                    throw new InvalidOperationException("cannot resume dead coroutine");
+                }
         }
         }
 
 
         var resumeTask = new ValueTask<ResumeContext>(this, resume.Version);
         var resumeTask = new ValueTask<ResumeContext>(this, resume.Version);
@@ -107,7 +104,7 @@ public sealed class LuaThread : IValueTaskSource<LuaThread.YieldContext>, IValue
             }
             }
             else
             else
             {
             {
-                status = LuaThreadStatus.Dead;
+                Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
                 buffer.Span[0] = true;
                 buffer.Span[0] = true;
                 return 1 + functionTask!.Result;
                 return 1 + functionTask!.Result;
             }
             }
@@ -116,7 +113,7 @@ public sealed class LuaThread : IValueTaskSource<LuaThread.YieldContext>, IValue
         {
         {
             if (IsProtectedMode)
             if (IsProtectedMode)
             {
             {
-                status = LuaThreadStatus.Dead;
+                Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
                 buffer.Span[0] = false;
                 buffer.Span[0] = false;
                 buffer.Span[1] = ex.Message;
                 buffer.Span[1] = ex.Message;
                 return 2;
                 return 2;
@@ -135,7 +132,7 @@ public sealed class LuaThread : IValueTaskSource<LuaThread.YieldContext>, IValue
 
 
     public async ValueTask Yield(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
     public async ValueTask Yield(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
     {
     {
-        if (status is not LuaThreadStatus.Running)
+        if (Volatile.Read(ref status) != (byte)LuaThreadStatus.Running)
         {
         {
             throw new InvalidOperationException("cannot call yield on a coroutine that is not currently running");
             throw new InvalidOperationException("cannot call yield on a coroutine that is not currently running");
         }
         }
@@ -145,7 +142,7 @@ public sealed class LuaThread : IValueTaskSource<LuaThread.YieldContext>, IValue
             Results = context.Arguments.ToArray(),
             Results = context.Arguments.ToArray(),
         });
         });
 
 
-        status = LuaThreadStatus.Suspended;
+        Volatile.Write(ref status, (byte)LuaThreadStatus.Suspended);
 
 
         CancellationTokenRegistration registration = default;
         CancellationTokenRegistration registration = default;
         if (cancellationToken.CanBeCanceled)
         if (cancellationToken.CanBeCanceled)

+ 1 - 1
src/Lua/LuaThreadStatus.cs

@@ -1,6 +1,6 @@
 namespace Lua;
 namespace Lua;
 
 
-public enum LuaThreadStatus
+public enum LuaThreadStatus : byte
 {
 {
     Normal,
     Normal,
     Suspended,
     Suspended,