Browse Source

Fix: LuaCoroutine

AnnulusGames 1 year ago
parent
commit
1e64eb6f42

+ 8 - 2
src/Lua/LuaCoroutine.cs

@@ -41,7 +41,7 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
 
 
     public override async ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     public override async ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     {
     {
-        var baseThread = context.State.CurrentThread;
+        var baseThread = context.Thread;
         baseThread.UnsafeSetStatus(LuaThreadStatus.Normal);
         baseThread.UnsafeSetStatus(LuaThreadStatus.Normal);
 
 
         context.State.ThreadStack.Push(this);
         context.State.ThreadStack.Push(this);
@@ -51,12 +51,18 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
             {
             {
                 case LuaThreadStatus.Suspended:
                 case LuaThreadStatus.Suspended:
                     Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
                     Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
-                    
+
                     if (isFirstCall)
                     if (isFirstCall)
                     {
                     {
+                        // copy stack value
+                        Stack.EnsureCapacity(baseThread.Stack.Count);
+                        baseThread.Stack.AsSpan().CopyTo(Stack.GetBuffer());
+                        Stack.NotifyTop(baseThread.Stack.Count);
+
                         functionTask = Function.InvokeAsync(new()
                         functionTask = Function.InvokeAsync(new()
                         {
                         {
                             State = threadState,
                             State = threadState,
+                            Thread = this,
                             ArgumentCount = context.ArgumentCount - 1,
                             ArgumentCount = context.ArgumentCount - 1,
                             ChunkName = Function.Name,
                             ChunkName = Function.Name,
                             RootChunkName = context.RootChunkName,
                             RootChunkName = context.RootChunkName,

+ 5 - 6
src/Lua/LuaFunctionExecutionContext.cs

@@ -6,6 +6,7 @@ namespace Lua;
 public readonly record struct LuaFunctionExecutionContext
 public readonly record struct LuaFunctionExecutionContext
 {
 {
     public required LuaState State { get; init; }
     public required LuaState State { get; init; }
+    public required LuaThread Thread { get; init; }
     public required int ArgumentCount { get; init; }
     public required int ArgumentCount { get; init; }
     public int? StackPosition { get; init; }
     public int? StackPosition { get; init; }
     public SourcePosition? SourcePosition { get; init; }
     public SourcePosition? SourcePosition { get; init; }
@@ -16,8 +17,7 @@ public readonly record struct LuaFunctionExecutionContext
     {
     {
         get
         get
         {
         {
-            var thread = State.CurrentThread;
-            return thread.GetStackValues().Slice(thread.GetCurrentFrame().Base, ArgumentCount);
+            return Thread.GetStackValues().Slice(Thread.GetCurrentFrame().Base, ArgumentCount);
         }
         }
     }
     }
 
 
@@ -42,14 +42,13 @@ public readonly record struct LuaFunctionExecutionContext
         var arg = Arguments[index];
         var arg = Arguments[index];
         if (!arg.TryRead<T>(out var argValue))
         if (!arg.TryRead<T>(out var argValue))
         {
         {
-            var thread = State.CurrentThread;
             if (LuaValue.TryGetLuaValueType(typeof(T), out var type))
             if (LuaValue.TryGetLuaValueType(typeof(T), out var type))
             {
             {
-                LuaRuntimeException.BadArgument(State.GetTraceback(), index + 1, thread.GetCurrentFrame().Function.Name, type.ToString(), arg.Type.ToString());
+                LuaRuntimeException.BadArgument(State.GetTraceback(), index + 1, Thread.GetCurrentFrame().Function.Name, type.ToString(), arg.Type.ToString());
             }
             }
             else
             else
             {
             {
-                LuaRuntimeException.BadArgument(State.GetTraceback(), index + 1, thread.GetCurrentFrame().Function.Name, typeof(T).Name, arg.Type.ToString());
+                LuaRuntimeException.BadArgument(State.GetTraceback(), index + 1, Thread.GetCurrentFrame().Function.Name, typeof(T).Name, arg.Type.ToString());
             }
             }
         }
         }
 
 
@@ -60,7 +59,7 @@ public readonly record struct LuaFunctionExecutionContext
     {
     {
         if (ArgumentCount <= index)
         if (ArgumentCount <= index)
         {
         {
-            LuaRuntimeException.BadArgument(State.GetTraceback(), index + 1, State.CurrentThread.GetCurrentFrame().Function.Name);
+            LuaRuntimeException.BadArgument(State.GetTraceback(), index + 1, Thread.GetCurrentFrame().Function.Name);
         }
         }
     }
     }
 }
 }

+ 1 - 0
src/Lua/LuaState.cs

@@ -64,6 +64,7 @@ public sealed class LuaState
             return await closure.InvokeAsync(new()
             return await closure.InvokeAsync(new()
             {
             {
                 State = this,
                 State = this,
+                Thread = CurrentThread,
                 ArgumentCount = 0,
                 ArgumentCount = 0,
                 StackPosition = 0,
                 StackPosition = 0,
                 SourcePosition = null,
                 SourcePosition = null,

+ 1 - 1
src/Lua/LuaValue.cs

@@ -410,7 +410,7 @@ public readonly struct LuaValue : IEquatable<LuaValue>
             return func.InvokeAsync(context with
             return func.InvokeAsync(context with
             {
             {
                 ArgumentCount = 1,
                 ArgumentCount = 1,
-                StackPosition = context.State.CurrentThread.Stack.Count,
+                StackPosition = context.Thread.Stack.Count,
             }, buffer, cancellationToken);
             }, buffer, cancellationToken);
         }
         }
         else
         else

+ 1 - 1
src/Lua/Runtime/Closure.cs

@@ -27,7 +27,7 @@ public sealed class Closure : LuaFunction
 
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
     {
-        return LuaVirtualMachine.ExecuteClosureAsync(context.State, this, context.State.CurrentThread.GetCurrentFrame(), buffer, cancellationToken);
+        return LuaVirtualMachine.ExecuteClosureAsync(context.State, this, context.Thread.GetCurrentFrame(), buffer, cancellationToken);
     }
     }
 
 
     static UpValue GetUpValueFromDescription(LuaState state, UpValue envUpValue, Chunk proto, UpValueInfo description, int depth)
     static UpValue GetUpValueFromDescription(LuaState state, UpValue envUpValue, Chunk proto, UpValueInfo description, int depth)

+ 21 - 2
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -141,6 +141,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 await func.InvokeAsync(new()
                                 {
                                 {
                                     State = state,
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -180,6 +181,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 await func.InvokeAsync(new()
                                 {
                                 {
                                     State = state,
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -219,6 +221,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 await func.InvokeAsync(new()
                                 {
                                 {
                                     State = state,
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -258,6 +261,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 await func.InvokeAsync(new()
                                 {
                                 {
                                     State = state,
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -302,6 +306,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 await func.InvokeAsync(new()
                                 {
                                 {
                                     State = state,
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -341,6 +346,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 await func.InvokeAsync(new()
                                 {
                                 {
                                     State = state,
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -378,6 +384,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 await func.InvokeAsync(new()
                                 {
                                 {
                                     State = state,
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 1,
                                     ArgumentCount = 1,
                                     SourcePosition = chunk.SourcePositions[pc],
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -422,6 +429,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 await func.InvokeAsync(new()
                                 {
                                 {
                                     State = state,
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 1,
                                     ArgumentCount = 1,
                                     SourcePosition = chunk.SourcePositions[pc],
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -479,6 +487,7 @@ public static partial class LuaVirtualMachine
                                 await func.InvokeAsync(new()
                                 await func.InvokeAsync(new()
                                 {
                                 {
                                     State = state,
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     ArgumentCount = 2,
                                     SourcePosition = chunk.SourcePositions[pc],
                                     SourcePosition = chunk.SourcePositions[pc],
                                 }, methodBuffer.AsMemory(), cancellationToken);
                                 }, methodBuffer.AsMemory(), cancellationToken);
@@ -523,6 +532,7 @@ public static partial class LuaVirtualMachine
                                     await func.InvokeAsync(new()
                                     await func.InvokeAsync(new()
                                     {
                                     {
                                         State = state,
                                         State = state,
+                                        Thread = thread,
                                         ArgumentCount = 2,
                                         ArgumentCount = 2,
                                         SourcePosition = chunk.SourcePositions[pc],
                                         SourcePosition = chunk.SourcePositions[pc],
                                     }, methodBuffer, cancellationToken);
                                     }, methodBuffer, cancellationToken);
@@ -572,6 +582,7 @@ public static partial class LuaVirtualMachine
                                     await func.InvokeAsync(new()
                                     await func.InvokeAsync(new()
                                     {
                                     {
                                         State = state,
                                         State = state,
+                                        Thread = thread,
                                         ArgumentCount = 2,
                                         ArgumentCount = 2,
                                         SourcePosition = chunk.SourcePositions[pc],
                                         SourcePosition = chunk.SourcePositions[pc],
                                     }, methodBuffer, cancellationToken);
                                     }, methodBuffer, cancellationToken);
@@ -625,6 +636,7 @@ public static partial class LuaVirtualMachine
                                     await func.InvokeAsync(new()
                                     await func.InvokeAsync(new()
                                     {
                                     {
                                         State = state,
                                         State = state,
+                                        Thread = thread,
                                         ArgumentCount = 2,
                                         ArgumentCount = 2,
                                         SourcePosition = chunk.SourcePositions[pc],
                                         SourcePosition = chunk.SourcePositions[pc],
                                     }, methodBuffer, cancellationToken);
                                     }, methodBuffer, cancellationToken);
@@ -691,6 +703,7 @@ public static partial class LuaVirtualMachine
                                 var resultCount = await func.InvokeAsync(new()
                                 var resultCount = await func.InvokeAsync(new()
                                 {
                                 {
                                     State = state,
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = argumentCount,
                                     ArgumentCount = argumentCount,
                                     StackPosition = newBase,
                                     StackPosition = newBase,
                                     SourcePosition = chunk.SourcePositions[pc],
                                     SourcePosition = chunk.SourcePositions[pc],
@@ -741,6 +754,7 @@ public static partial class LuaVirtualMachine
                             return await func.InvokeAsync(new()
                             return await func.InvokeAsync(new()
                             {
                             {
                                 State = state,
                                 State = state,
+                                Thread = thread,
                                 ArgumentCount = argumentCount,
                                 ArgumentCount = argumentCount,
                                 StackPosition = newBase,
                                 StackPosition = newBase,
                                 SourcePosition = chunk.SourcePositions[pc],
                                 SourcePosition = chunk.SourcePositions[pc],
@@ -814,6 +828,7 @@ public static partial class LuaVirtualMachine
                                 await iterator.InvokeAsync(new()
                                 await iterator.InvokeAsync(new()
                                 {
                                 {
                                     State = state,
                                     State = state,
+                                    Thread = thread,
                                     ArgumentCount = 2,
                                     ArgumentCount = 2,
                                     StackPosition = nextBase,
                                     StackPosition = nextBase,
                                     SourcePosition = chunk.SourcePositions[pc],
                                     SourcePosition = chunk.SourcePositions[pc],
@@ -905,7 +920,8 @@ public static partial class LuaVirtualMachine
 #endif
 #endif
     static async ValueTask<LuaValue> GetTableValue(LuaState state, Chunk chunk, int pc, LuaValue table, LuaValue key, CancellationToken cancellationToken)
     static async ValueTask<LuaValue> GetTableValue(LuaState state, Chunk chunk, int pc, LuaValue table, LuaValue key, CancellationToken cancellationToken)
     {
     {
-        var stack = state.CurrentThread.Stack;
+        var thread = state.CurrentThread;
+        var stack = thread.Stack;
         var isTable = table.TryRead<LuaTable>(out var t);
         var isTable = table.TryRead<LuaTable>(out var t);
 
 
         if (isTable && t.TryGetValue(key, out var result))
         if (isTable && t.TryGetValue(key, out var result))
@@ -929,6 +945,7 @@ public static partial class LuaVirtualMachine
                 await indexTable.InvokeAsync(new()
                 await indexTable.InvokeAsync(new()
                 {
                 {
                     State = state,
                     State = state,
+                    Thread = thread,
                     ArgumentCount = 2,
                     ArgumentCount = 2,
                     SourcePosition = chunk.SourcePositions[pc],
                     SourcePosition = chunk.SourcePositions[pc],
                 }, methodBuffer, cancellationToken);
                 }, methodBuffer, cancellationToken);
@@ -956,7 +973,8 @@ public static partial class LuaVirtualMachine
 #endif
 #endif
     static async ValueTask SetTableValue(LuaState state, Chunk chunk, int pc, LuaValue table, LuaValue key, LuaValue value, CancellationToken cancellationToken)
     static async ValueTask SetTableValue(LuaState state, Chunk chunk, int pc, LuaValue table, LuaValue key, LuaValue value, CancellationToken cancellationToken)
     {
     {
-        var stack = state.CurrentThread.Stack;
+        var thread = state.CurrentThread;
+        var stack = thread.Stack;
         var isTable = table.TryRead<LuaTable>(out var t);
         var isTable = table.TryRead<LuaTable>(out var t);
 
 
         if (key.Type is LuaValueType.Number && key.TryRead<double>(out var d) && double.IsNaN(d))
         if (key.Type is LuaValueType.Number && key.TryRead<double>(out var d) && double.IsNaN(d))
@@ -986,6 +1004,7 @@ public static partial class LuaVirtualMachine
                 await indexTable.InvokeAsync(new()
                 await indexTable.InvokeAsync(new()
                 {
                 {
                     State = state,
                     State = state,
+                    Thread = thread,
                     ArgumentCount = 3,
                     ArgumentCount = 3,
                     SourcePosition = chunk.SourcePositions[pc],
                     SourcePosition = chunk.SourcePositions[pc],
                 }, methodBuffer, cancellationToken);
                 }, methodBuffer, cancellationToken);

+ 2 - 2
src/Lua/Standard/Coroutines/CoroutineRunningFunction.cs

@@ -9,8 +9,8 @@ public sealed class CoroutineRunningFunction : LuaFunction
 
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
     {
-        buffer.Span[0] = context.State.CurrentThread;
-        buffer.Span[1] = context.State.CurrentThread == context.State.MainThread;
+        buffer.Span[0] = context.Thread;
+        buffer.Span[1] = context.Thread == context.State.MainThread;
         return new(2);
         return new(2);
     }
     }
 }
 }

+ 1 - 1
src/Lua/Standard/Coroutines/CoroutineYieldFunction.cs

@@ -8,7 +8,7 @@ public sealed class CoroutineYieldFunction : LuaFunction
 
 
     protected override async ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     protected override async ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
     {
-        await context.State.CurrentThread.Yield(context, cancellationToken);
+        await context.Thread.Yield(context, cancellationToken);
         return 0;
         return 0;
     }
     }
 }
 }