Browse Source

Fix: LuaCoroutine

AnnulusGames 1 year ago
parent
commit
1e64eb6f42

+ 8 - 2
src/Lua/LuaCoroutine.cs

@@ -41,7 +41,7 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
 
     public override async ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     {
-        var baseThread = context.State.CurrentThread;
+        var baseThread = context.Thread;
         baseThread.UnsafeSetStatus(LuaThreadStatus.Normal);
 
         context.State.ThreadStack.Push(this);
@@ -51,12 +51,18 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
             {
                 case LuaThreadStatus.Suspended:
                     Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
-                    
+
                     if (isFirstCall)
                     {
+                        // copy stack value
+                        Stack.EnsureCapacity(baseThread.Stack.Count);
+                        baseThread.Stack.AsSpan().CopyTo(Stack.GetBuffer());
+                        Stack.NotifyTop(baseThread.Stack.Count);
+
                         functionTask = Function.InvokeAsync(new()
                         {
                             State = threadState,
+                            Thread = this,
                             ArgumentCount = context.ArgumentCount - 1,
                             ChunkName = Function.Name,
                             RootChunkName = context.RootChunkName,

+ 5 - 6
src/Lua/LuaFunctionExecutionContext.cs

@@ -6,6 +6,7 @@ namespace Lua;
 public readonly record struct LuaFunctionExecutionContext
 {
     public required LuaState State { get; init; }
+    public required LuaThread Thread { get; init; }
     public required int ArgumentCount { get; init; }
     public int? StackPosition { get; init; }
     public SourcePosition? SourcePosition { get; init; }
@@ -16,8 +17,7 @@ public readonly record struct LuaFunctionExecutionContext
     {
         get
         {
-            var thread = State.CurrentThread;
-            return thread.GetStackValues().Slice(thread.GetCurrentFrame().Base, ArgumentCount);
+            return Thread.GetStackValues().Slice(Thread.GetCurrentFrame().Base, ArgumentCount);
         }
     }
 
@@ -42,14 +42,13 @@ public readonly record struct LuaFunctionExecutionContext
         var arg = Arguments[index];
         if (!arg.TryRead<T>(out var argValue))
         {
-            var thread = State.CurrentThread;
             if (LuaValue.TryGetLuaValueType(typeof(T), out var type))
             {
-                LuaRuntimeException.BadArgument(State.GetTraceback(), index + 1, thread.GetCurrentFrame().Function.Name, type.ToString(), arg.Type.ToString());
+                LuaRuntimeException.BadArgument(State.GetTraceback(), index + 1, Thread.GetCurrentFrame().Function.Name, type.ToString(), arg.Type.ToString());
             }
             else
             {
-                LuaRuntimeException.BadArgument(State.GetTraceback(), index + 1, thread.GetCurrentFrame().Function.Name, typeof(T).Name, arg.Type.ToString());
+                LuaRuntimeException.BadArgument(State.GetTraceback(), index + 1, Thread.GetCurrentFrame().Function.Name, typeof(T).Name, arg.Type.ToString());
             }
         }
 
@@ -60,7 +59,7 @@ public readonly record struct LuaFunctionExecutionContext
     {
         if (ArgumentCount <= index)
         {
-            LuaRuntimeException.BadArgument(State.GetTraceback(), index + 1, State.CurrentThread.GetCurrentFrame().Function.Name);
+            LuaRuntimeException.BadArgument(State.GetTraceback(), index + 1, Thread.GetCurrentFrame().Function.Name);
         }
     }
 }

+ 1 - 0
src/Lua/LuaState.cs

@@ -64,6 +64,7 @@ public sealed class LuaState
             return await closure.InvokeAsync(new()
             {
                 State = this,
+                Thread = CurrentThread,
                 ArgumentCount = 0,
                 StackPosition = 0,
                 SourcePosition = null,

+ 1 - 1
src/Lua/LuaValue.cs

@@ -410,7 +410,7 @@ public readonly struct LuaValue : IEquatable<LuaValue>
             return func.InvokeAsync(context with
             {
                 ArgumentCount = 1,
-                StackPosition = context.State.CurrentThread.Stack.Count,
+                StackPosition = context.Thread.Stack.Count,
             }, buffer, cancellationToken);
         }
         else

+ 1 - 1
src/Lua/Runtime/Closure.cs

@@ -27,7 +27,7 @@ public sealed class Closure : LuaFunction
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
-        return LuaVirtualMachine.ExecuteClosureAsync(context.State, this, context.State.CurrentThread.GetCurrentFrame(), buffer, cancellationToken);
+        return LuaVirtualMachine.ExecuteClosureAsync(context.State, this, context.Thread.GetCurrentFrame(), buffer, cancellationToken);
     }
 
     static UpValue GetUpValueFromDescription(LuaState state, UpValue envUpValue, Chunk proto, UpValueInfo description, int depth)

+ 21 - 2
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -141,6 +141,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 {
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -180,6 +181,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 {
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -219,6 +221,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 {
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -258,6 +261,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 {
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -302,6 +306,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 {
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -341,6 +346,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 {
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -378,6 +384,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 {
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 1,
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -422,6 +429,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 {
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 1,
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -479,6 +487,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 {
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -523,6 +532,7 @@ public static partial class LuaVirtualMachine
                                     await func.InvokeAsync(new()
                                     {
                                         State = state,
+                                        Thread = thread,
                                         ArgumentCount = 2,
                                         SourcePosition = chunk.SourcePositions[pc],
                                     }, methodBuffer, cancellationToken);
@@ -572,6 +582,7 @@ public static partial class LuaVirtualMachine
                                     await func.InvokeAsync(new()
                                     {
                                         State = state,
+                                        Thread = thread,
                                         ArgumentCount = 2,
                                         SourcePosition = chunk.SourcePositions[pc],
                                     }, methodBuffer, cancellationToken);
@@ -625,6 +636,7 @@ public static partial class LuaVirtualMachine
                                     await func.InvokeAsync(new()
                                     {
                                         State = state,
+                                        Thread = thread,
                                         ArgumentCount = 2,
                                         SourcePosition = chunk.SourcePositions[pc],
                                     }, methodBuffer, cancellationToken);
@@ -691,6 +703,7 @@ public static partial class LuaVirtualMachine
                                 var resultCount = await func.InvokeAsync(new()
                                 {
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = argumentCount,
                                     StackPosition = newBase,
                                     SourcePosition = chunk.SourcePositions[pc],
@@ -741,6 +754,7 @@ public static partial class LuaVirtualMachine
                             return await func.InvokeAsync(new()
                             {
                                 State = state,
+                                Thread = thread,
                                 ArgumentCount = argumentCount,
                                 StackPosition = newBase,
                                 SourcePosition = chunk.SourcePositions[pc],
@@ -814,6 +828,7 @@ public static partial class LuaVirtualMachine
                                 await iterator.InvokeAsync(new()
                                 {
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     StackPosition = nextBase,
                                     SourcePosition = chunk.SourcePositions[pc],
@@ -905,7 +920,8 @@ public static partial class LuaVirtualMachine
 #endif
     static async ValueTask<LuaValue> GetTableValue(LuaState state, Chunk chunk, int pc, LuaValue table, LuaValue key, CancellationToken cancellationToken)
     {
-        var stack = state.CurrentThread.Stack;
+        var thread = state.CurrentThread;
+        var stack = thread.Stack;
         var isTable = table.TryRead<LuaTable>(out var t);
 
         if (isTable && t.TryGetValue(key, out var result))
@@ -929,6 +945,7 @@ public static partial class LuaVirtualMachine
                 await indexTable.InvokeAsync(new()
                 {
                     State = state,
+                    Thread = thread,
                     ArgumentCount = 2,
                     SourcePosition = chunk.SourcePositions[pc],
                 }, methodBuffer, cancellationToken);
@@ -956,7 +973,8 @@ public static partial class LuaVirtualMachine
 #endif
     static async ValueTask SetTableValue(LuaState state, Chunk chunk, int pc, LuaValue table, LuaValue key, LuaValue value, CancellationToken cancellationToken)
     {
-        var stack = state.CurrentThread.Stack;
+        var thread = state.CurrentThread;
+        var stack = thread.Stack;
         var isTable = table.TryRead<LuaTable>(out var t);
 
         if (key.Type is LuaValueType.Number && key.TryRead<double>(out var d) && double.IsNaN(d))
@@ -986,6 +1004,7 @@ public static partial class LuaVirtualMachine
                 await indexTable.InvokeAsync(new()
                 {
                     State = state,
+                    Thread = thread,
                     ArgumentCount = 3,
                     SourcePosition = chunk.SourcePositions[pc],
                 }, methodBuffer, cancellationToken);

+ 2 - 2
src/Lua/Standard/Coroutines/CoroutineRunningFunction.cs

@@ -9,8 +9,8 @@ public sealed class CoroutineRunningFunction : LuaFunction
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
-        buffer.Span[0] = context.State.CurrentThread;
-        buffer.Span[1] = context.State.CurrentThread == context.State.MainThread;
+        buffer.Span[0] = context.Thread;
+        buffer.Span[1] = context.Thread == context.State.MainThread;
         return new(2);
     }
 }

+ 1 - 1
src/Lua/Standard/Coroutines/CoroutineYieldFunction.cs

@@ -8,7 +8,7 @@ public sealed class CoroutineYieldFunction : LuaFunction
 
     protected override async ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
-        await context.State.CurrentThread.Yield(context, cancellationToken);
+        await context.Thread.Yield(context, cancellationToken);
         return 0;
     }
 }