Browse Source

Fix: behavior when multiple coroutines share a function

AnnulusGames 1 year ago
parent
commit
5ec08eb86e

+ 0 - 43
src/Lua/Internal/FastList.cs

@@ -1,43 +0,0 @@
-using System.Runtime.CompilerServices;
-
-namespace Lua.Internal;
-
-public class FastList<T>
-{
-    FastListCore<T> core;
-
-    public int Length
-    {
-        [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        get
-        {
-            return core.Length;
-        }
-    }
-
-    public ref T this[int index]
-    {
-        get
-        {
-            return ref core[index];
-        }
-    }
-
-    [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public void Add(T value)
-    {
-        core.Add(value);
-    }
-
-    [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public void RemoveAtSwapback(int index)
-    {
-        core.RemoveAtSwapback(index);
-    }
-
-    [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public Span<T> AsSpan()
-    {
-        return core.AsSpan();
-    }
-}

+ 0 - 8
src/Lua/LuaFunction.cs

@@ -4,14 +4,6 @@ namespace Lua;
 
 public abstract partial class LuaFunction
 {
-    LuaThread? thread;
-    public LuaThread? Thread => thread;
-
-    internal void SetCurrentThread(LuaThread thread)
-    {
-        this.thread = thread;
-    }
-
     public virtual string Name => GetType().Name;
 
     public async ValueTask<int> InvokeAsync(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)

+ 31 - 29
src/Lua/LuaState.cs

@@ -8,17 +8,31 @@ public sealed class LuaState
 {
     public const string DefaultChunkName = "chunk";
 
-    LuaStack stack = new();
-    FastStackCore<CallStackFrame> callStack;
-    FastList<UpValue> openUpValues;
+    class GlobalState
+    {
+        public FastStackCore<LuaThread> threadStack;
+        public FastListCore<UpValue> openUpValues;
+        public readonly LuaTable environment;
+        public readonly UpValue envUpValue;
+
+        public GlobalState(LuaState state)
+        {
+            environment = new();
+            envUpValue = UpValue.Closed(state, environment);
+        }
+    }
+
+    readonly GlobalState globalState;
 
-    LuaTable environment;
-    internal UpValue EnvUpValue { get; }
+    readonly LuaStack stack = new();
+    FastStackCore<CallStackFrame> callStack;
     bool isRunning;
 
     internal LuaStack Stack => stack;
+    internal UpValue EnvUpValue => globalState.envUpValue;
+    internal ref FastStackCore<LuaThread> ThreadStack => ref globalState.threadStack;
 
-    public LuaTable Environment => environment;
+    public LuaTable Environment => globalState.environment;
     public bool IsRunning => Volatile.Read(ref isRunning);
 
     public static LuaState Create()
@@ -26,23 +40,14 @@ public sealed class LuaState
         return new();
     }
 
-    internal LuaState CreateCoroutineState()
-    {
-        return new LuaState(this);
-    }
-
     LuaState()
     {
-        environment = new();
-        EnvUpValue = UpValue.Closed(this, environment);
-        openUpValues = new();
+        globalState = new(this);
     }
 
     LuaState(LuaState parent)
     {
-        environment = parent.Environment;
-        EnvUpValue = parent.EnvUpValue;
-        openUpValues = parent.openUpValues;
+        globalState = parent.globalState;
     }
 
     public async ValueTask<int> RunAsync(Chunk chunk, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
@@ -98,18 +103,14 @@ public sealed class LuaState
         return callStack.AsSpan();
     }
 
-    public bool TryGetCurrentThread([NotNullWhen(true)] out LuaThread? result)
+    public LuaThread CreateThread(LuaFunction function, bool isProtectedMode = true)
     {
-        var span = GetCallStackSpan();
-
-        for (int i = 0; i < span.Length; i++)
-        {
-            result = span[i].Function.Thread;
-            if (result != null) return true;
-        }
+        return new LuaThread(new LuaState(this), function, isProtectedMode);
+    }
 
-        result = default;
-        return false;
+    public bool TryGetCurrentThread([NotNullWhen(true)] out LuaThread? result)
+    {
+        return ThreadStack.TryPeek(out result);
     }
 
     public CallStackFrame GetCurrentFrame()
@@ -127,7 +128,7 @@ public sealed class LuaState
 
     internal UpValue GetOrAddUpValue(int registerIndex)
     {
-        foreach (var upValue in openUpValues.AsSpan())
+        foreach (var upValue in globalState.openUpValues.AsSpan())
         {
             if (upValue.RegisterIndex == registerIndex && upValue.State == this)
             {
@@ -136,12 +137,13 @@ public sealed class LuaState
         }
 
         var newUpValue = UpValue.Open(this, registerIndex);
-        openUpValues.Add(newUpValue);
+        globalState.openUpValues.Add(newUpValue);
         return newUpValue;
     }
 
     internal void CloseUpValues(int frameBase)
     {
+        var openUpValues = globalState.openUpValues;
         for (int i = 0; i < openUpValues.Length; i++)
         {
             var upValue = openUpValues[i];

+ 82 - 75
src/Lua/LuaThread.cs

@@ -28,105 +28,112 @@ public sealed class LuaThread : IValueTaskSource<LuaThread.YieldContext>, IValue
     internal LuaThread(LuaState state, LuaFunction function, bool isProtectedMode)
     {
         IsProtectedMode = isProtectedMode;
-        threadState = state.CreateCoroutineState();
+        threadState = state;
         Function = function;
-        function.SetCurrentThread(this);
     }
 
     public async ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     {
-        switch ((LuaThreadStatus)Volatile.Read(ref status))
+        context.State.ThreadStack.Push(this);
+        try
         {
-            case LuaThreadStatus.Normal:
-                Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
+            switch ((LuaThreadStatus)Volatile.Read(ref status))
+            {
+                case LuaThreadStatus.Normal:
+                    Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
+
+                    // first argument is LuaThread object
+                    for (int i = 0; i < context.ArgumentCount - 1; i++)
+                    {
+                        threadState.Push(context.Arguments[i + 1]);
+                    }
+
+                    functionTask = Function.InvokeAsync(new()
+                    {
+                        State = threadState,
+                        ArgumentCount = context.ArgumentCount - 1,
+                        ChunkName = Function.Name,
+                        RootChunkName = context.RootChunkName,
+                    }, buffer[1..], cancellationToken).Preserve();
+
+                    break;
+                case LuaThreadStatus.Suspended:
+                    Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
+                    yield.SetResult(new());
+                    break;
+                case LuaThreadStatus.Running:
+                    throw new InvalidOperationException("cannot resume running coroutine");
+                case LuaThreadStatus.Dead:
+                    if (IsProtectedMode)
+                    {
+                        buffer.Span[0] = false;
+                        buffer.Span[1] = "cannot resume dead coroutine";
+                        return 2;
+                    }
+                    else
+                    {
+                        throw new InvalidOperationException("cannot resume dead coroutine");
+                    }
+            }
 
-                // first argument is LuaThread object
-                for (int i = 0; i < context.ArgumentCount - 1; i++)
-                {
-                    threadState.Push(context.Arguments[i + 1]);
-                }
+            var resumeTask = new ValueTask<ResumeContext>(this, resume.Version);
 
-                functionTask = Function.InvokeAsync(new()
-                {
-                    State = threadState,
-                    ArgumentCount = context.ArgumentCount - 1,
-                    ChunkName = Function.Name,
-                    RootChunkName = context.RootChunkName,
-                }, buffer[1..], cancellationToken).Preserve();
-
-                break;
-            case LuaThreadStatus.Suspended:
-                Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
-                yield.SetResult(new());
-                break;
-            case LuaThreadStatus.Running:
-                throw new InvalidOperationException("cannot resume running coroutine");
-            case LuaThreadStatus.Dead:
-                if (IsProtectedMode)
-                {
-                    buffer.Span[0] = false;
-                    buffer.Span[1] = "cannot resume dead coroutine";
-                    return 2;
-                }
-                else
+            CancellationTokenRegistration registration = default;
+            if (cancellationToken.CanBeCanceled)
+            {
+                registration = cancellationToken.UnsafeRegister(static x =>
                 {
-                    throw new InvalidOperationException("cannot resume dead coroutine");
-                }
-        }
-
-        var resumeTask = new ValueTask<ResumeContext>(this, resume.Version);
+                    var thread = (LuaThread)x!;
+                    thread.yield.SetException(new OperationCanceledException());
+                }, this);
+            }
 
-        CancellationTokenRegistration registration = default;
-        if (cancellationToken.CanBeCanceled)
-        {
-            registration = cancellationToken.UnsafeRegister(static x =>
+            try
             {
-                var thread = (LuaThread)x!;
-                thread.yield.SetException(new OperationCanceledException());
-            }, this);
-        }
+                (var index, var result0, var result1) = await ValueTaskEx.WhenAny(resumeTask, functionTask!);
 
-        try
-        {
-            (var index, var result0, var result1) = await ValueTaskEx.WhenAny(resumeTask, functionTask!);
+                if (index == 0)
+                {
+                    var results = result0.Results;
 
-            if (index == 0)
-            {
-                var results = result0.Results;
+                    buffer.Span[0] = true;
+                    for (int i = 0; i < results.Length; i++)
+                    {
+                        buffer.Span[i + 1] = results[i];
+                    }
 
-                buffer.Span[0] = true;
-                for (int i = 0; i < results.Length; i++)
+                    return results.Length + 1;
+                }
+                else
                 {
-                    buffer.Span[i + 1] = results[i];
+                    Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
+                    buffer.Span[0] = true;
+                    return 1 + functionTask!.Result;
                 }
-
-                return results.Length + 1;
-            }
-            else
-            {
-                Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
-                buffer.Span[0] = true;
-                return 1 + functionTask!.Result;
             }
-        }
-        catch (Exception ex) when (ex is not OperationCanceledException)
-        {
-            if (IsProtectedMode)
+            catch (Exception ex) when (ex is not OperationCanceledException)
             {
-                Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
-                buffer.Span[0] = false;
-                buffer.Span[1] = ex.Message;
-                return 2;
+                if (IsProtectedMode)
+                {
+                    Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
+                    buffer.Span[0] = false;
+                    buffer.Span[1] = ex.Message;
+                    return 2;
+                }
+                else
+                {
+                    throw;
+                }
             }
-            else
+            finally
             {
-                throw;
+                registration.Dispose();
+                resume.Reset();
             }
         }
         finally
         {
-            registration.Dispose();
-            resume.Reset();
+            context.State.ThreadStack.Pop();
         }
     }
 

+ 1 - 1
src/Lua/Standard/Coroutines/CoroutineCreateFunction.cs

@@ -10,7 +10,7 @@ public sealed class CoroutineCreateFunction : LuaFunction
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
         var arg0 = context.ReadArgument<LuaFunction>(0);
-        buffer.Span[0] = new LuaThread(context.State, arg0, true);
+        buffer.Span[0] = context.State.CreateThread(arg0, true);
         return new(1);
     }
 }

+ 1 - 1
src/Lua/Standard/Coroutines/CoroutineWrapFunction.cs

@@ -10,7 +10,7 @@ public sealed class CoroutineWrapFunction : LuaFunction
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
         var arg0 = context.ReadArgument<LuaFunction>(0);
-        var thread = new LuaThread(context.State, arg0, false);
+        var thread = context.State.CreateThread(arg0, false);
         buffer.Span[0] = new Wrapper(thread);
         return new(1);
     }