Browse Source

Update: support vararg in coroutine

AnnulusGames 1 year ago
parent
commit
96bb6762ea

+ 35 - 1
src/Lua/Internal/FastStackCore.cs

@@ -1,3 +1,4 @@
+using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
 
 namespace Lua.Internal;
@@ -10,7 +11,7 @@ public struct FastStackCore<T>
     T?[] array;
     int tail;
 
-    public int Size => tail;
+    public int Count => tail;
 
     public readonly ReadOnlySpan<T> AsSpan()
     {
@@ -18,6 +19,12 @@ public struct FastStackCore<T>
         return array.AsSpan(0, tail);
     }
 
+    public readonly Span<T?> GetBuffer()
+    {
+        if (array == null) return [];
+        return array.AsSpan();
+    }
+
     public readonly T? this[int index]
     {
         get
@@ -41,6 +48,7 @@ public struct FastStackCore<T>
         tail++;
     }
 
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
     public bool TryPop(out T value)
     {
         if (tail == 0)
@@ -62,6 +70,7 @@ public struct FastStackCore<T>
         return result;
     }
 
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
     public bool TryPeek(out T value)
     {
         if (tail == 0)
@@ -80,6 +89,31 @@ public struct FastStackCore<T>
         return result;
     }
 
+
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    public void EnsureCapacity(int capacity)
+    {
+        if (array == null)
+        {
+            array = new T[InitialCapacity];
+        }
+
+        var newSize = array.Length;
+        while (newSize < capacity)
+        {
+            newSize *= 2;
+        }
+
+        Array.Resize(ref array, newSize);
+    }
+
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    public void NotifyTop(int top)
+    {
+        if (tail < top) tail = top;
+    }
+
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
     public void Clear()
     {
         array.AsSpan(0, tail).Clear();

+ 79 - 25
src/Lua/LuaCoroutine.cs

@@ -1,5 +1,7 @@
+using System.Buffers;
 using System.Threading.Tasks.Sources;
 using Lua.Internal;
+using Lua.Runtime;
 
 namespace Lua;
 
@@ -7,21 +9,31 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
 {
     struct YieldContext
     {
+        public required LuaValue[] Results;
     }
 
     struct ResumeContext
     {
-        public LuaValue[] Results;
+        public required LuaValue[] Results;
     }
 
     byte status;
     bool isFirstCall = true;
-    LuaState threadState;
     ValueTask<int> functionTask;
+    LuaValue[] buffer;
 
     ManualResetValueTaskSourceCore<ResumeContext> resume;
     ManualResetValueTaskSourceCore<YieldContext> yield;
 
+    public LuaCoroutine(LuaFunction function, bool isProtectedMode)
+    {
+        IsProtectedMode = isProtectedMode;
+        Function = function;
+
+        buffer = ArrayPool<LuaValue>.Shared.Rent(1024);
+        buffer.AsSpan().Clear();
+    }
+
     public override LuaThreadStatus GetStatus() => (LuaThreadStatus)status;
 
     public override void UnsafeSetStatus(LuaThreadStatus status)
@@ -32,13 +44,6 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
     public bool IsProtectedMode { get; }
     public LuaFunction Function { get; }
 
-    internal LuaCoroutine(LuaState state, LuaFunction function, bool isProtectedMode)
-    {
-        IsProtectedMode = isProtectedMode;
-        threadState = state;
-        Function = function;
-    }
-
     public override async ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     {
         var baseThread = context.Thread;
@@ -59,18 +64,19 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
                         baseThread.Stack.AsSpan().CopyTo(Stack.GetBuffer());
                         Stack.NotifyTop(baseThread.Stack.Count);
 
-                        functionTask = Function.InvokeAsync(new()
-                        {
-                            State = threadState,
-                            Thread = this,
-                            ArgumentCount = context.ArgumentCount - 1,
-                            ChunkName = Function.Name,
-                            RootChunkName = context.RootChunkName,
-                        }, buffer[1..], cancellationToken).Preserve();
+                        // copy callstack value
+                        CallStack.EnsureCapacity(baseThread.CallStack.Count);
+                        baseThread.CallStack.AsSpan().CopyTo(CallStack.GetBuffer());
+                        CallStack.NotifyTop(baseThread.CallStack.Count);
                     }
                     else
                     {
-                        yield.SetResult(new());
+                        yield.SetResult(new()
+                        {
+                            Results = context.ArgumentCount == 1
+                                ? []
+                                : context.Arguments[1..].ToArray()
+                        });
                     }
                     break;
                 case LuaThreadStatus.Normal:
@@ -114,11 +120,47 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
             {
                 if (isFirstCall)
                 {
-                    for (int i = 0; i < context.ArgumentCount - 1; i++)
+                    int frameBase;
+                    var variableArgumentCount = Function.GetVariableArgumentCount(context.ArgumentCount - 1);
+                    
+                    if (variableArgumentCount > 0)
+                    {
+                        var fixedArgumentCount = context.ArgumentCount - 1 - variableArgumentCount;
+
+                        for (int i = 0; i < variableArgumentCount; i++)
+                        {
+                            Stack.Push(context.GetArgument(i + fixedArgumentCount + 1));
+                        }
+
+                        Stack.Push(Function);
+                        frameBase = Stack.Count;
+
+                        for (int i = 0; i < fixedArgumentCount; i++)
+                        {
+                            Stack.Push(context.GetArgument(i + 1));
+                        }
+                    }
+                    else
                     {
-                        threadState.Push(context.Arguments[i + 1]);
+                        Stack.Push(Function);
+                        frameBase = Stack.Count;
+
+                        for (int i = 0; i < context.ArgumentCount - 1; i++)
+                        {
+                            Stack.Push(context.GetArgument(i + 1));
+                        }
                     }
 
+                    functionTask = Function.InvokeAsync(new()
+                    {
+                        State = context.State,
+                        Thread = this,
+                        ArgumentCount = context.ArgumentCount - 1,
+                        FrameBase = frameBase,
+                        ChunkName = Function.Name,
+                        RootChunkName = context.RootChunkName,
+                    }, this.buffer, cancellationToken).Preserve();
+
                     Volatile.Write(ref isFirstCall, false);
                 }
 
@@ -138,6 +180,8 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
                 }
                 else
                 {
+                    ArrayPool<LuaValue>.Shared.Return(this.buffer);
+                    
                     Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
                     buffer.Span[0] = true;
                     return 1 + functionTask!.Result;
@@ -147,6 +191,8 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
             {
                 if (IsProtectedMode)
                 {
+                    ArrayPool<LuaValue>.Shared.Return(this.buffer);
+
                     Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
                     buffer.Span[0] = false;
                     buffer.Span[1] = ex.Message;
@@ -170,7 +216,7 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
         }
     }
 
-    public override async ValueTask Yield(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
+    public override async ValueTask<int> Yield(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     {
         if (Volatile.Read(ref status) != (byte)LuaThreadStatus.Running)
         {
@@ -197,16 +243,24 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
     RETRY:
         try
         {
-            await new ValueTask<YieldContext>(this, yield.Version);
+            var result = await new ValueTask<YieldContext>(this, yield.Version);
+            for (int i = 0; i < result.Results.Length; i++)
+            {
+                buffer.Span[i] = result.Results[i];
+            }
+
+            return result.Results.Length;
         }
         catch (Exception ex) when (ex is not OperationCanceledException)
         {
             yield.Reset();
             goto RETRY;
         }
-
-        registration.Dispose();
-        yield.Reset();
+        finally
+        {
+            registration.Dispose();
+            yield.Reset();
+        }
     }
 
     YieldContext IValueTaskSource<YieldContext>.GetResult(short token)

+ 11 - 4
src/Lua/LuaFunction.cs

@@ -9,11 +9,18 @@ public abstract partial class LuaFunction
     public async ValueTask<int> InvokeAsync(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
         var state = context.State;
-        var thread = state.CurrentThread;
+        
+        if (context.FrameBase == null)
+        {
+            context = context with
+            {
+                FrameBase = context.Thread.Stack.Count - context.ArgumentCount
+            };
+        }
 
         var frame = new CallStackFrame
         {
-            Base = context.StackPosition == null ? thread.Stack.Count - context.ArgumentCount : context.StackPosition.Value,
+            Base = context.FrameBase.Value,
             CallPosition = context.SourcePosition,
             ChunkName = context.ChunkName ?? LuaState.DefaultChunkName,
             RootChunkName = context.RootChunkName ?? LuaState.DefaultChunkName,
@@ -21,14 +28,14 @@ public abstract partial class LuaFunction
             Function = this,
         };
 
-        thread.PushCallStackFrame(frame);
+        context.Thread.PushCallStackFrame(frame);
         try
         {
             return await InvokeAsyncCore(context, buffer, cancellationToken);
         }
         finally
         {
-            thread.PopCallStackFrame();
+            context.Thread.PopCallStackFrame();
         }
     }
 

+ 2 - 2
src/Lua/LuaFunctionExecutionContext.cs

@@ -8,7 +8,7 @@ public readonly record struct LuaFunctionExecutionContext
     public required LuaState State { get; init; }
     public required LuaThread Thread { get; init; }
     public required int ArgumentCount { get; init; }
-    public int? StackPosition { get; init; }
+    public int? FrameBase { get; init; }
     public SourcePosition? SourcePosition { get; init; }
     public string? RootChunkName { get; init; }
     public string? ChunkName { get; init; }
@@ -17,7 +17,7 @@ public readonly record struct LuaFunctionExecutionContext
     {
         get
         {
-            return Thread.GetStackValues().Slice(Thread.GetCurrentFrame().Base, ArgumentCount);
+            return Thread.GetStackValues().Slice(FrameBase!.Value, ArgumentCount);
         }
     }
 

+ 1 - 1
src/Lua/LuaMainThread.cs

@@ -19,7 +19,7 @@ public sealed class LuaMainThread : LuaThread
         return new(2);
     }
 
-    public override ValueTask Yield(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
+    public override ValueTask<int> Yield(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     {
         throw new LuaRuntimeException(context.State.GetTraceback(), "attempt to yield from outside a coroutine");
     }

+ 1 - 6
src/Lua/LuaState.cs

@@ -66,7 +66,7 @@ public sealed class LuaState
                 State = this,
                 Thread = CurrentThread,
                 ArgumentCount = 0,
-                StackPosition = 0,
+                FrameBase = 0,
                 SourcePosition = null,
                 RootChunkName = chunk.Name ?? DefaultChunkName,
                 ChunkName = chunk.Name ?? DefaultChunkName,
@@ -83,11 +83,6 @@ public sealed class LuaState
         CurrentThread.Stack.Push(value);
     }
 
-    public LuaThread CreateThread(LuaFunction function, bool isProtectedMode = true)
-    {
-        return new LuaCoroutine(this, function, isProtectedMode);
-    }
-
     public Traceback GetTraceback()
     {
         // TODO: optimize

+ 4 - 2
src/Lua/LuaThread.cs

@@ -1,3 +1,4 @@
+using System.Diagnostics;
 using Lua.Internal;
 using Lua.Runtime;
 
@@ -8,12 +9,13 @@ public abstract class LuaThread
     public abstract LuaThreadStatus GetStatus();
     public abstract void UnsafeSetStatus(LuaThreadStatus status);
     public abstract ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default);
-    public abstract ValueTask Yield(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default);
+    public abstract ValueTask<int> Yield(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default);
 
     LuaStack stack = new();
     FastStackCore<CallStackFrame> callStack;
 
     internal LuaStack Stack => stack;
+    internal ref FastStackCore<CallStackFrame> CallStack => ref callStack;
 
     public CallStackFrame GetCurrentFrame()
     {
@@ -25,7 +27,7 @@ public abstract class LuaThread
         return stack.AsSpan();
     }
 
-    internal ReadOnlySpan<CallStackFrame> GetCallStackFrames()
+    public ReadOnlySpan<CallStackFrame> GetCallStackFrames()
     {
         return callStack.AsSpan();
     }

+ 1 - 1
src/Lua/LuaValue.cs

@@ -410,7 +410,7 @@ public readonly struct LuaValue : IEquatable<LuaValue>
             return func.InvokeAsync(context with
             {
                 ArgumentCount = 1,
-                StackPosition = context.Thread.Stack.Count,
+                FrameBase = context.Thread.Stack.Count,
             }, buffer, cancellationToken);
         }
         else

+ 9 - 0
src/Lua/Runtime/LuaValueRuntimeExtensions.cs

@@ -5,6 +5,7 @@ namespace Lua.Runtime;
 
 internal static class LuaRuntimeExtensions
 {
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
     public static bool TryGetMetamethod(this LuaValue value, LuaState state, string methodName, out LuaValue result)
     {
         result = default;
@@ -12,6 +13,14 @@ internal static class LuaRuntimeExtensions
             metatable.TryGetValue(methodName, out result);
     }
 
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    public static int GetVariableArgumentCount(this LuaFunction function, int argumentCount)
+    {
+        return function is Closure luaClosure
+            ? argumentCount - luaClosure.Proto.ParameterCount
+            : 0;
+    }
+
 #if NET6_0_OR_GREATER
     [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
 #endif

+ 10 - 12
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -694,7 +694,7 @@ public static partial class LuaVirtualMachine
                                 }
                             }
 
-                            (var newBase, var argumentCount) = PrepareForFunctionCall(state, func, instruction, RA, false);
+                            (var newBase, var argumentCount) = PrepareForFunctionCall(thread, func, instruction, RA, false);
 
                             var resultBuffer = ArrayPool<LuaValue>.Shared.Rent(1024);
                             resultBuffer.AsSpan().Clear();
@@ -705,7 +705,7 @@ public static partial class LuaVirtualMachine
                                     State = state,
                                     Thread = thread,
                                     ArgumentCount = argumentCount,
-                                    StackPosition = newBase,
+                                    FrameBase = newBase,
                                     SourcePosition = chunk.SourcePositions[pc],
                                     ChunkName = chunk.Name,
                                     RootChunkName = chunk.GetRoot().Name,
@@ -749,14 +749,14 @@ public static partial class LuaVirtualMachine
                                 }
                             }
 
-                            (var newBase, var argumentCount) = PrepareForFunctionCall(state, func, instruction, RA, true);
+                            (var newBase, var argumentCount) = PrepareForFunctionCall(thread, func, instruction, RA, true);
 
                             return await func.InvokeAsync(new()
                             {
                                 State = state,
                                 Thread = thread,
                                 ArgumentCount = argumentCount,
-                                StackPosition = newBase,
+                                FrameBase = newBase,
                                 SourcePosition = chunk.SourcePositions[pc],
                                 ChunkName = chunk.Name,
                                 RootChunkName = chunk.GetRoot().Name,
@@ -830,7 +830,7 @@ public static partial class LuaVirtualMachine
                                     State = state,
                                     Thread = thread,
                                     ArgumentCount = 2,
-                                    StackPosition = nextBase,
+                                    FrameBase = nextBase,
                                     SourcePosition = chunk.SourcePositions[pc],
                                     ChunkName = chunk.Name,
                                     RootChunkName = chunk.GetRoot().Name,
@@ -887,7 +887,7 @@ public static partial class LuaVirtualMachine
                             for (int i = 0; i < count; i++)
                             {
                                 stack.UnsafeGet(RA + i) = frame.VariableArgumentCount > i
-                                    ? stack.UnsafeGet(frame.Base - (frame.VariableArgumentCount - i))
+                                    ? stack.UnsafeGet(frame.Base - (frame.VariableArgumentCount - i + 1))
                                     : LuaValue.Nil;
                             }
                             stack.NotifyTop(RA + count);
@@ -1024,9 +1024,9 @@ public static partial class LuaVirtualMachine
         }
     }
 
-    static (int FrameBase, int ArgumentCount) PrepareForFunctionCall(LuaState state, LuaFunction function, Instruction instruction, int RA, bool isTailCall)
+    static (int FrameBase, int ArgumentCount) PrepareForFunctionCall(LuaThread thread, LuaFunction function, Instruction instruction, int RA, bool isTailCall)
     {
-        var stack = state.CurrentThread.Stack;
+        var stack = thread.Stack;
 
         var argumentCount = instruction.B - 1;
         if (instruction.B == 0)
@@ -1040,15 +1040,13 @@ public static partial class LuaVirtualMachine
         // Therefore, a call can be made without allocating new registers.
         if (isTailCall)
         {
-            var currentBase = state.CurrentThread.GetCurrentFrame().Base;
+            var currentBase = thread.GetCurrentFrame().Base;
             var stackBuffer = stack.GetBuffer();
             stackBuffer.Slice(newBase, argumentCount).CopyTo(stackBuffer.Slice(currentBase, argumentCount));
             newBase = currentBase;
         }
 
-        var variableArgumentCount = function is Closure luaClosure
-            ? argumentCount - luaClosure.Proto.ParameterCount
-            : 0;
+        var variableArgumentCount = function.GetVariableArgumentCount(argumentCount);
 
         // If there are variable arguments, the base of the stack is moved by that number and the values ​​of the variable arguments are placed in front of it.
         // see: https://wubingzheng.github.io/build-lua-in-rust/en/ch08-02.arguments.html

+ 1 - 1
src/Lua/Standard/Basic/PCallFunction.cs

@@ -19,7 +19,7 @@ public sealed class PCallFunction : LuaFunction
             {
                 State = context.State,
                 ArgumentCount = context.ArgumentCount - 1,
-                StackPosition = context.StackPosition + 1,
+                FrameBase = context.FrameBase + 1,
             }, methodBuffer.AsMemory(), cancellationToken);
 
             buffer.Span[0] = true;

+ 2 - 2
src/Lua/Standard/Basic/XPCallFunction.cs

@@ -21,7 +21,7 @@ public sealed class XPCallFunction : LuaFunction
             {
                 State = context.State,
                 ArgumentCount = context.ArgumentCount - 2,
-                StackPosition = context.StackPosition + 2,
+                FrameBase = context.FrameBase + 2,
             }, methodBuffer.AsMemory(), cancellationToken);
 
             buffer.Span[0] = true;
@@ -40,7 +40,7 @@ public sealed class XPCallFunction : LuaFunction
             {
                 State = context.State,
                 ArgumentCount = 1,
-                StackPosition = null,
+                FrameBase = null,
             }, methodBuffer.AsMemory(), cancellationToken);
 
             buffer.Span[0] = false;

+ 1 - 1
src/Lua/Standard/Coroutines/CoroutineCreateFunction.cs

@@ -9,7 +9,7 @@ public sealed class CoroutineCreateFunction : LuaFunction
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
         var arg0 = context.GetArgument<LuaFunction>(0);
-        buffer.Span[0] = context.State.CreateThread(arg0, true);
+        buffer.Span[0] = new LuaCoroutine(arg0, true);
         return new(1);
     }
 }

+ 1 - 1
src/Lua/Standard/Coroutines/CoroutineWrapFunction.cs

@@ -10,7 +10,7 @@ public sealed class CoroutineWrapFunction : LuaFunction
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
         var arg0 = context.GetArgument<LuaFunction>(0);
-        var thread = context.State.CreateThread(arg0, false);
+        var thread = new LuaCoroutine(arg0, false);
         buffer.Span[0] = new Wrapper(thread);
         return new(1);
     }

+ 2 - 3
src/Lua/Standard/Coroutines/CoroutineYieldFunction.cs

@@ -6,9 +6,8 @@ public sealed class CoroutineYieldFunction : LuaFunction
     public static readonly CoroutineYieldFunction Instance = new();
     public override string Name => "yield";
 
-    protected override async ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+    protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
-        await context.Thread.Yield(context, cancellationToken);
-        return 0;
+        return context.Thread.Yield(context, buffer, cancellationToken);
     }
 }

+ 1 - 1
src/Lua/Standard/Table/SortFunction.cs

@@ -63,7 +63,7 @@ public sealed class SortFunction : LuaFunction
             await comparer.InvokeAsync(context with
             {
                 ArgumentCount = 2,
-                StackPosition = null,
+                FrameBase = null,
             }, methodBuffer.AsMemory(), cancellationToken);
 
             if (methodBuffer[0].ToBoolean())

+ 1 - 1
src/Lua/Standard/Text/GSubFunction.cs

@@ -65,7 +65,7 @@ public sealed class GSubFunction : LuaFunction
                 await func.InvokeAsync(context with
                 {
                     ArgumentCount = match.Groups.Count,
-                    StackPosition = null
+                    FrameBase = null
                 }, methodBuffer.AsMemory(), cancellationToken);
 
                 result = methodBuffer[0];