Browse Source

Merge pull request #4 from AnnulusGames/feat-coroutine

Add: coroutine
Annulus Games 1 year ago
parent
commit
0968c2cc81

+ 11 - 45
sandbox/Benchmark/AddBenchmark.cs

@@ -1,4 +1,3 @@
-using System.Buffers;
 using BenchmarkDotNet.Attributes;
 using BenchmarkDotNet.Attributes;
 using Lua;
 using Lua;
 using MoonSharp.Interpreter;
 using MoonSharp.Interpreter;
@@ -6,83 +5,50 @@ using MoonSharp.Interpreter;
 [Config(typeof(BenchmarkConfig))]
 [Config(typeof(BenchmarkConfig))]
 public class AddBenchmark
 public class AddBenchmark
 {
 {
-    NLua.Lua nLuaState = default!;
-    Script moonSharpState = default!;
-    LuaState luaCSharpState = default!;
-
-    string filePath = default!;
-    string sourceText = default!;
+    BenchmarkCore core = new();
+    LuaValue[] buffer = new LuaValue[1];
 
 
     [GlobalSetup]
     [GlobalSetup]
     public void GlobalSetup()
     public void GlobalSetup()
     {
     {
-        // moonsharp
-        moonSharpState = new Script();
-        Script.WarmUp();
-
-        // NLua
-        nLuaState = new();
-
-        // Lua-CSharp
-        luaCSharpState = LuaState.Create();
-
-        filePath = FileHelper.GetAbsolutePath("add.lua");
-        sourceText = File.ReadAllText(filePath);
+        core.Setup("add.lua");
     }
     }
 
 
     [Benchmark(Description = "MoonSharp (RunString)")]
     [Benchmark(Description = "MoonSharp (RunString)")]
     public DynValue Benchmark_MoonSharp_String()
     public DynValue Benchmark_MoonSharp_String()
     {
     {
-        var result = moonSharpState.DoString(sourceText);
-        return result;
+        return core.MoonSharpState.DoString(core.SourceText);
     }
     }
 
 
     [Benchmark(Description = "MoonSharp (RunFile)")]
     [Benchmark(Description = "MoonSharp (RunFile)")]
     public DynValue Benchmark_MoonSharp_File()
     public DynValue Benchmark_MoonSharp_File()
     {
     {
-        var result = moonSharpState.DoFile(filePath);
-        return result;
+        return core.MoonSharpState.DoFile(core.FilePath);
     }
     }
 
 
     [Benchmark(Description = "NLua (DoString)")]
     [Benchmark(Description = "NLua (DoString)")]
     public object[] Benchmark_NLua_String()
     public object[] Benchmark_NLua_String()
     {
     {
-        return nLuaState.DoString(sourceText);
+        return core.NLuaState.DoString(core.SourceText);
     }
     }
 
 
     [Benchmark(Description = "NLua (DoFile)")]
     [Benchmark(Description = "NLua (DoFile)")]
     public object[] Benchmark_NLua_File()
     public object[] Benchmark_NLua_File()
     {
     {
-        return nLuaState.DoFile(filePath);
+        return core.NLuaState.DoFile(core.FilePath);
     }
     }
 
 
     [Benchmark(Description = "Lua-CSharp (DoString)")]
     [Benchmark(Description = "Lua-CSharp (DoString)")]
     public async Task<LuaValue> Benchmark_LuaCSharp_String()
     public async Task<LuaValue> Benchmark_LuaCSharp_String()
     {
     {
-        var buffer = ArrayPool<LuaValue>.Shared.Rent(1);
-        try
-        {
-            await luaCSharpState.DoStringAsync(sourceText, buffer);
-            return buffer[0];
-        }
-        finally
-        {
-            ArrayPool<LuaValue>.Shared.Return(buffer);
-        }
+        await core.LuaCSharpState.DoStringAsync(core.SourceText, buffer);
+        return buffer[0];
     }
     }
 
 
     [Benchmark(Description = "Lua-CSharp (DoFileAsync)")]
     [Benchmark(Description = "Lua-CSharp (DoFileAsync)")]
     public async Task<LuaValue> Benchmark_LuaCSharp_File()
     public async Task<LuaValue> Benchmark_LuaCSharp_File()
     {
     {
-        var buffer = ArrayPool<LuaValue>.Shared.Rent(1);
-        try
-        {
-            await luaCSharpState.DoFileAsync(filePath, buffer);
-            return buffer[0];
-        }
-        finally
-        {
-            ArrayPool<LuaValue>.Shared.Return(buffer);
-        }
+        await core.LuaCSharpState.DoFileAsync(core.FilePath, buffer);
+        return buffer[0];
     }
     }
 }
 }

+ 33 - 0
sandbox/Benchmark/BenchmarkCore.cs

@@ -0,0 +1,33 @@
+using Lua;
+using MoonSharp.Interpreter;
+
+public class BenchmarkCore
+{
+    public NLua.Lua NLuaState => nLuaState;
+    public Script MoonSharpState => moonSharpState;
+    public LuaState LuaCSharpState => luaCSharpState;
+    public string FilePath => filePath;
+    public string SourceText => sourceText;
+
+    NLua.Lua nLuaState = default!;
+    Script moonSharpState = default!;
+    LuaState luaCSharpState = default!;
+    string filePath = default!;
+    string sourceText = default!;
+
+    public void Setup(string fileName)
+    {
+        // moonsharp
+        moonSharpState = new Script();
+        Script.WarmUp();
+
+        // NLua
+        nLuaState = new();
+
+        // Lua-CSharp
+        luaCSharpState = LuaState.Create();
+
+        filePath = FileHelper.GetAbsolutePath(fileName);
+        sourceText = File.ReadAllText(filePath);
+    }
+}

+ 54 - 0
sandbox/Benchmark/CoroutineBenchmark.cs

@@ -0,0 +1,54 @@
+using BenchmarkDotNet.Attributes;
+using Lua;
+using MoonSharp.Interpreter;
+
+[Config(typeof(BenchmarkConfig))]
+public class CoroutineBenchmark
+{
+    BenchmarkCore core = new();
+    LuaValue[] buffer = new LuaValue[1];
+
+    [GlobalSetup]
+    public void GlobalSetup()
+    {
+        core.Setup("coroutine.lua");
+    }
+
+    [Benchmark(Description = "MoonSharp (RunString)")]
+    public DynValue Benchmark_MoonSharp_String()
+    {
+        return core.MoonSharpState.DoString(core.SourceText);
+    }
+
+    [Benchmark(Description = "MoonSharp (RunFile)")]
+    public DynValue Benchmark_MoonSharp_File()
+    {
+        return core.MoonSharpState.DoFile(core.FilePath);
+    }
+
+    [Benchmark(Description = "NLua (DoString)")]
+    public object[] Benchmark_NLua_String()
+    {
+        return core.NLuaState.DoString(core.SourceText);
+    }
+
+    [Benchmark(Description = "NLua (DoFile)")]
+    public object[] Benchmark_NLua_File()
+    {
+        return core.NLuaState.DoFile(core.FilePath);
+    }
+
+    [Benchmark(Description = "Lua-CSharp (DoString)")]
+    public async Task<LuaValue> Benchmark_LuaCSharp_String()
+    {
+        await core.LuaCSharpState.DoStringAsync(core.SourceText, buffer);
+        return buffer[0];
+    }
+
+    [Benchmark(Description = "Lua-CSharp (DoFileAsync)")]
+    public async Task<LuaValue> Benchmark_LuaCSharp_File()
+    {
+        await core.LuaCSharpState.DoFileAsync(core.FilePath, buffer);
+        return buffer[0];
+    }
+}

+ 17 - 0
sandbox/Benchmark/coroutine.lua

@@ -0,0 +1,17 @@
+local function c()
+    for i = 1, 100 do
+        coroutine.yield(i)
+    end
+
+    return 1000
+end
+
+local co = coroutine.create(c)
+
+local x = 0
+while coroutine.status(co) ~= "dead" do
+    local _, i = coroutine.resume(co)
+    x = x + i
+end
+
+return x

+ 44 - 16
sandbox/ConsoleApp1/Program.cs

@@ -5,31 +5,59 @@ using Lua;
 using Lua.Standard;
 using Lua.Standard;
 
 
 var state = LuaState.Create();
 var state = LuaState.Create();
-state.OpenBaseLibrary();
+state.OpenBasicLibrary();
 
 
 try
 try
 {
 {
     var source =
     var source =
-@"
-metatable = {
-    __add = function(a, b)
-        local t = { }
-
-        for i = 1, #a do
-            t[i] = a[i] + b[i]
+"""
+-- メインコルーチンの定義
+local co_main = coroutine.create(function ()
+    print("Main coroutine starts")
+
+    -- コルーチンAの定義
+    local co_a = coroutine.create(function()
+        for i = 1, 3 do
+            print("Coroutine A, iteration "..i)
+            coroutine.yield()
         end
         end
-
-        return t
+        print("Coroutine A ends")
+    end)
+
+    --コルーチンBの定義
+    local co_b = coroutine.create(function()
+        print("Coroutine B starts")
+        coroutine.yield()-- 一時停止
+        print("Coroutine B resumes")
+    end)
+
+    -- コルーチンCの定義(コルーチンBを呼び出す)
+    local co_c = coroutine.create(function()
+        print("Coroutine C starts")
+        coroutine.resume(co_b)-- コルーチンBを実行
+        print("Coroutine C calls B and resumes")
+        coroutine.yield()-- 一時停止
+        print("Coroutine C resumes")
+    end)
+
+    -- コルーチンAとCの交互実行
+    for _ = 1, 2 do
+            coroutine.resume(co_a)
+        coroutine.resume(co_c)
     end
     end
-}
 
 
-local a = { 1, 2, 3 }
-local b = { 4, 5, 6 }
+    -- コルーチンAを再開し完了させる
+    coroutine.resume(co_a)
+
+    -- コルーチンCを再開し完了させる
+    coroutine.resume(co_c)
 
 
-setmetatable(a, metatable)
+    print("Main coroutine ends")
+end)
 
 
-return a + b
-";
+--メインコルーチンを開始
+coroutine.resume(co_main)
+""";
 
 
     var syntaxTree = LuaSyntaxTree.Parse(source, "main.lua");
     var syntaxTree = LuaSyntaxTree.Parse(source, "main.lua");
 
 

+ 0 - 87
src/Lua/Internal/AutoResizeArrayCore.cs

@@ -1,87 +0,0 @@
-using System.Runtime.CompilerServices;
-using System.Runtime.InteropServices;
-
-namespace Lua.Internal;
-
-internal struct AutoResizeArrayCore<T>
-{
-    T[]? array;
-    int size;
-
-    public int Size => size;
-    public int Capacity => array == null ? 0 : array.Length;
-
-    [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public void Add(T item)
-    {
-        this[size] = item;
-    }
-
-    public ref T this[int index]
-    {
-        get
-        {
-            EnsureCapacity(index);
-            size = Math.Max(size, index + 1);
-
-#if NET6_0_OR_GREATER
-            ref var reference = ref MemoryMarshal.GetArrayDataReference(array!);
-#else
-            ref var reference = ref MemoryMarshal.GetReference(array.AsSpan());
-#endif
-
-            return ref Unsafe.Add(ref reference, index);
-        }
-    }
-
-    [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public Span<T> AsSpan()
-    {
-        return array.AsSpan(0, size);
-    }
-
-    [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public T[] GetInternalArray()
-    {
-        return array ?? [];
-    }
-
-    [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public void Clear()
-    {
-        array.AsSpan().Clear();
-    }
-
-    public void EnsureCapacity(int newCapacity, bool overrideSize = false)
-    {
-        var capacity = 64;
-        while (capacity <= newCapacity)
-        {
-            capacity = MathEx.NewArrayCapacity(capacity);
-        }
-
-        if (array == null)
-        {
-            array = new T[capacity];
-        }
-        else
-        {
-            Array.Resize(ref array, capacity);
-        }
-
-        if (overrideSize)
-        {
-            size = newCapacity;
-        }
-    }
-
-    public void Shrink(int newSize)
-    {
-        if (array != null && array.Length > newSize)
-        {
-            array.AsSpan(newSize).Clear();
-        }
-
-        size = newSize;
-    }
-}

+ 13 - 0
src/Lua/Internal/CancellationExtensions.cs

@@ -0,0 +1,13 @@
+#if NETSTANDARD2_0 || NETSTANDARD2_1
+
+namespace System.Threading;
+
+internal static class CancellationTokenExtensions
+{
+    public static CancellationTokenRegistration UnsafeRegister(this CancellationToken cancellationToken, Action<object?> callback, object? state)
+    {
+        return cancellationToken.Register(callback, state, useSynchronizationContext: false);
+    }
+}
+
+#endif

+ 2 - 4
src/Lua/Internal/FastListCore.cs

@@ -85,12 +85,10 @@ public struct FastListCore<T>
         AsSpan().CopyTo(destination.AsSpan());
         AsSpan().CopyTo(destination.AsSpan());
     }
     }
 
 
-    public readonly T this[int index]
+    public ref T this[int index]
     {
     {
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        get => array![index];
-        [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        set => array![index] = value;
+        get => ref array![index];
     }
     }
 
 
     public readonly int Length
     public readonly int Length

+ 263 - 0
src/Lua/Internal/ValueTaskEx.cs

@@ -0,0 +1,263 @@
+using System.Runtime.CompilerServices;
+using System.Runtime.ExceptionServices;
+using System.Threading.Tasks.Sources;
+
+/*
+
+ValueTaskEx based on ValueTaskSupprement
+https://github.com/Cysharp/ValueTaskSupplement
+
+MIT License
+
+Copyright (c) 2019 Cysharp, Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+*/
+
+namespace Lua.Internal;
+
+internal static class ContinuationSentinel
+{
+    public static readonly Action<object?> AvailableContinuation = _ => { };
+    public static readonly Action<object?> CompletedContinuation = _ => { };
+}
+
+internal static class ValueTaskEx
+{
+    public static ValueTask<(int winArgumentIndex, T0 result0, T1 result1)> WhenAny<T0, T1>(ValueTask<T0> task0, ValueTask<T1> task1)
+    {
+        return new ValueTask<(int winArgumentIndex, T0 result0, T1 result1)>(new WhenAnyPromise<T0, T1>(task0, task1), 0);
+    }
+
+    class WhenAnyPromise<T0, T1> : IValueTaskSource<(int winArgumentIndex, T0 result0, T1 result1)>
+    {
+        static readonly ContextCallback execContextCallback = ExecutionContextCallback!;
+        static readonly SendOrPostCallback syncContextCallback = SynchronizationContextCallback!;
+
+        T0 t0 = default!;
+        T1 t1 = default!;
+        ValueTaskAwaiter<T0> awaiter0;
+        ValueTaskAwaiter<T1> awaiter1;
+
+        int completedCount = 0;
+        int winArgumentIndex = -1;
+        ExceptionDispatchInfo? exception;
+        Action<object?> continuation = ContinuationSentinel.AvailableContinuation;
+        Action<object?>? invokeContinuation;
+        object? state;
+        SynchronizationContext? syncContext;
+        ExecutionContext? execContext;
+
+        public WhenAnyPromise(ValueTask<T0> task0, ValueTask<T1> task1)
+        {
+            {
+                var awaiter = task0.GetAwaiter();
+                if (awaiter.IsCompleted)
+                {
+                    try
+                    {
+                        t0 = awaiter.GetResult();
+                        TryInvokeContinuationWithIncrement(0);
+                        return;
+                    }
+                    catch (Exception ex)
+                    {
+                        exception = ExceptionDispatchInfo.Capture(ex);
+                        return;
+                    }
+                }
+                else
+                {
+                    awaiter0 = awaiter;
+                    awaiter.UnsafeOnCompleted(ContinuationT0);
+                }
+            }
+            {
+                var awaiter = task1.GetAwaiter();
+                if (awaiter.IsCompleted)
+                {
+                    try
+                    {
+                        t1 = awaiter.GetResult();
+                        TryInvokeContinuationWithIncrement(1);
+                        return;
+                    }
+                    catch (Exception ex)
+                    {
+                        exception = ExceptionDispatchInfo.Capture(ex);
+                        return;
+                    }
+                }
+                else
+                {
+                    awaiter1 = awaiter;
+                    awaiter.UnsafeOnCompleted(ContinuationT1);
+                }
+            }
+        }
+
+        void ContinuationT0()
+        {
+            try
+            {
+                t0 = awaiter0.GetResult();
+            }
+            catch (Exception ex)
+            {
+                exception = ExceptionDispatchInfo.Capture(ex);
+                TryInvokeContinuation();
+                return;
+            }
+            TryInvokeContinuationWithIncrement(0);
+        }
+
+        void ContinuationT1()
+        {
+            try
+            {
+                t1 = awaiter1.GetResult();
+            }
+            catch (Exception ex)
+            {
+                exception = ExceptionDispatchInfo.Capture(ex);
+                TryInvokeContinuation();
+                return;
+            }
+            TryInvokeContinuationWithIncrement(1);
+        }
+
+
+        void TryInvokeContinuationWithIncrement(int index)
+        {
+            if (Interlocked.Increment(ref completedCount) == 1)
+            {
+                Volatile.Write(ref winArgumentIndex, index);
+                TryInvokeContinuation();
+            }
+        }
+
+        void TryInvokeContinuation()
+        {
+            var c = Interlocked.Exchange(ref continuation, ContinuationSentinel.CompletedContinuation);
+            if (c != ContinuationSentinel.AvailableContinuation && c != ContinuationSentinel.CompletedContinuation)
+            {
+                var spinWait = new SpinWait();
+                while (state == null) // worst case, state is not set yet so wait.
+                {
+                    spinWait.SpinOnce();
+                }
+
+                if (execContext != null)
+                {
+                    invokeContinuation = c;
+                    ExecutionContext.Run(execContext, execContextCallback, this);
+                }
+                else if (syncContext != null)
+                {
+                    invokeContinuation = c;
+                    syncContext.Post(syncContextCallback, this);
+                }
+                else
+                {
+                    c(state);
+                }
+            }
+        }
+
+        public (int winArgumentIndex, T0 result0, T1 result1) GetResult(short token)
+        {
+            if (exception != null)
+            {
+                exception.Throw();
+            }
+            var i = winArgumentIndex;
+            return (winArgumentIndex, t0, t1);
+        }
+
+        public ValueTaskSourceStatus GetStatus(short token)
+        {
+            return Volatile.Read(ref winArgumentIndex) != -1 ? ValueTaskSourceStatus.Succeeded
+                : exception != null ? exception.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled : ValueTaskSourceStatus.Faulted
+                : ValueTaskSourceStatus.Pending;
+        }
+
+        public void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
+        {
+            var c = Interlocked.CompareExchange(ref this.continuation, continuation, ContinuationSentinel.AvailableContinuation);
+            if (c == ContinuationSentinel.CompletedContinuation)
+            {
+                continuation(state);
+                return;
+            }
+
+            if (c != ContinuationSentinel.AvailableContinuation)
+            {
+                throw new InvalidOperationException("does not allow multiple await.");
+            }
+
+            if (state == null)
+            {
+                throw new InvalidOperationException("invalid state.");
+            }
+
+            if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) == ValueTaskSourceOnCompletedFlags.FlowExecutionContext)
+            {
+                execContext = ExecutionContext.Capture();
+            }
+            if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) == ValueTaskSourceOnCompletedFlags.UseSchedulingContext)
+            {
+                syncContext = SynchronizationContext.Current;
+            }
+            this.state = state;
+
+            if (GetStatus(token) != ValueTaskSourceStatus.Pending)
+            {
+                TryInvokeContinuation();
+            }
+        }
+
+        static void ExecutionContextCallback(object state)
+        {
+            var self = (WhenAnyPromise<T0, T1>)state;
+            if (self.syncContext != null)
+            {
+                self.syncContext.Post(syncContextCallback, self);
+            }
+            else
+            {
+                var invokeContinuation = self.invokeContinuation!;
+                var invokeState = self.state;
+                self.invokeContinuation = null;
+                self.state = null;
+                invokeContinuation(invokeState);
+            }
+        }
+
+        static void SynchronizationContextCallback(object state)
+        {
+            var self = (WhenAnyPromise<T0, T1>)state;
+            var invokeContinuation = self.invokeContinuation!;
+            var invokeState = self.state;
+            self.invokeContinuation = null;
+            self.state = null;
+            invokeContinuation(invokeState);
+        }
+    }
+}

+ 217 - 0
src/Lua/LuaCoroutine.cs

@@ -0,0 +1,217 @@
+using System.Threading.Tasks.Sources;
+using Lua.Internal;
+
+namespace Lua;
+
+public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.YieldContext>, IValueTaskSource<LuaCoroutine.ResumeContext>
+{
+    struct YieldContext
+    {
+    }
+
+    struct ResumeContext
+    {
+        public LuaValue[] Results;
+    }
+
+    byte status;
+    LuaState threadState;
+    ValueTask<int> functionTask;
+
+    ManualResetValueTaskSourceCore<ResumeContext> resume;
+    ManualResetValueTaskSourceCore<YieldContext> yield;
+
+    public override LuaThreadStatus GetStatus() => (LuaThreadStatus)status;
+    public bool IsProtectedMode { get; }
+    public LuaFunction Function { get; }
+
+    internal LuaCoroutine(LuaState state, LuaFunction function, bool isProtectedMode)
+    {
+        IsProtectedMode = isProtectedMode;
+        threadState = state;
+        Function = function;
+    }
+
+    public override async ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
+    {
+        context.State.ThreadStack.Push(this);
+        try
+        {
+            switch ((LuaThreadStatus)Volatile.Read(ref status))
+            {
+                case LuaThreadStatus.Normal:
+                    Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
+
+                    // first argument is LuaThread object
+                    for (int i = 0; i < context.ArgumentCount - 1; i++)
+                    {
+                        threadState.Push(context.Arguments[i + 1]);
+                    }
+
+                    functionTask = Function.InvokeAsync(new()
+                    {
+                        State = threadState,
+                        ArgumentCount = context.ArgumentCount - 1,
+                        ChunkName = Function.Name,
+                        RootChunkName = context.RootChunkName,
+                    }, buffer[1..], cancellationToken).Preserve();
+
+                    break;
+                case LuaThreadStatus.Suspended:
+                    Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
+                    yield.SetResult(new());
+                    break;
+                case LuaThreadStatus.Running:
+                    if (IsProtectedMode)
+                    {
+                        buffer.Span[0] = false;
+                        buffer.Span[1] = "cannot resume non-suspended coroutine";
+                        return 2;
+                    }
+                    else
+                    {
+                        throw new InvalidOperationException("cannot resume non-suspended coroutine");
+                    }
+                case LuaThreadStatus.Dead:
+                    if (IsProtectedMode)
+                    {
+                        buffer.Span[0] = false;
+                        buffer.Span[1] = "cannot resume dead coroutine";
+                        return 2;
+                    }
+                    else
+                    {
+                        throw new InvalidOperationException("cannot resume dead coroutine");
+                    }
+            }
+
+            var resumeTask = new ValueTask<ResumeContext>(this, resume.Version);
+
+            CancellationTokenRegistration registration = default;
+            if (cancellationToken.CanBeCanceled)
+            {
+                registration = cancellationToken.UnsafeRegister(static x =>
+                {
+                    var coroutine = (LuaCoroutine)x!;
+                    coroutine.yield.SetException(new OperationCanceledException());
+                }, this);
+            }
+
+            try
+            {
+                (var index, var result0, var result1) = await ValueTaskEx.WhenAny(resumeTask, functionTask!);
+
+                if (index == 0)
+                {
+                    var results = result0.Results;
+
+                    buffer.Span[0] = true;
+                    for (int i = 0; i < results.Length; i++)
+                    {
+                        buffer.Span[i + 1] = results[i];
+                    }
+
+                    return results.Length + 1;
+                }
+                else
+                {
+                    Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
+                    buffer.Span[0] = true;
+                    return 1 + functionTask!.Result;
+                }
+            }
+            catch (Exception ex) when (ex is not OperationCanceledException)
+            {
+                if (IsProtectedMode)
+                {
+                    Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
+                    buffer.Span[0] = false;
+                    buffer.Span[1] = ex.Message;
+                    return 2;
+                }
+                else
+                {
+                    throw;
+                }
+            }
+            finally
+            {
+                registration.Dispose();
+                resume.Reset();
+            }
+        }
+        finally
+        {
+            context.State.ThreadStack.Pop();
+        }
+    }
+
+    public override async ValueTask Yield(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
+    {
+        if (Volatile.Read(ref status) != (byte)LuaThreadStatus.Running)
+        {
+            throw new InvalidOperationException("cannot call yield on a coroutine that is not currently running");
+        }
+
+        resume.SetResult(new()
+        {
+            Results = context.Arguments.ToArray(),
+        });
+
+        Volatile.Write(ref status, (byte)LuaThreadStatus.Suspended);
+
+        CancellationTokenRegistration registration = default;
+        if (cancellationToken.CanBeCanceled)
+        {
+            registration = cancellationToken.UnsafeRegister(static x =>
+            {
+                var coroutine = (LuaCoroutine)x!;
+                coroutine.yield.SetException(new OperationCanceledException());
+            }, this);
+        }
+
+    RETRY:
+        try
+        {
+            await new ValueTask<YieldContext>(this, yield.Version);
+        }
+        catch (Exception ex) when (ex is not OperationCanceledException)
+        {
+            yield.Reset();
+            goto RETRY;
+        }
+
+        registration.Dispose();
+        yield.Reset();
+    }
+
+    YieldContext IValueTaskSource<YieldContext>.GetResult(short token)
+    {
+        return yield.GetResult(token);
+    }
+
+    ValueTaskSourceStatus IValueTaskSource<YieldContext>.GetStatus(short token)
+    {
+        return yield.GetStatus(token);
+    }
+
+    void IValueTaskSource<YieldContext>.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
+    {
+        yield.OnCompleted(continuation, state, token, flags);
+    }
+
+    ResumeContext IValueTaskSource<ResumeContext>.GetResult(short token)
+    {
+        return resume.GetResult(token);
+    }
+
+    ValueTaskSourceStatus IValueTaskSource<ResumeContext>.GetStatus(short token)
+    {
+        return resume.GetStatus(token);
+    }
+
+    void IValueTaskSource<ResumeContext>.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
+    {
+        resume.OnCompleted(continuation, state, token, flags);
+    }
+}

+ 10 - 4
src/Lua/LuaFunction.cs

@@ -4,13 +4,16 @@ namespace Lua;
 
 
 public abstract partial class LuaFunction
 public abstract partial class LuaFunction
 {
 {
+    public virtual string Name => GetType().Name;
+
     public async ValueTask<int> InvokeAsync(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     public async ValueTask<int> InvokeAsync(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
     {
         var state = context.State;
         var state = context.State;
+        var thread = state.CurrentThread;
 
 
         var frame = new CallStackFrame
         var frame = new CallStackFrame
         {
         {
-            Base = context.StackPosition == null ? state.Stack.Count - context.ArgumentCount : context.StackPosition.Value,
+            Base = context.StackPosition == null ? thread.Stack.Count - context.ArgumentCount : context.StackPosition.Value,
             CallPosition = context.SourcePosition,
             CallPosition = context.SourcePosition,
             ChunkName = context.ChunkName ?? LuaState.DefaultChunkName,
             ChunkName = context.ChunkName ?? LuaState.DefaultChunkName,
             RootChunkName = context.RootChunkName ?? LuaState.DefaultChunkName,
             RootChunkName = context.RootChunkName ?? LuaState.DefaultChunkName,
@@ -18,17 +21,20 @@ public abstract partial class LuaFunction
             Function = this,
             Function = this,
         };
         };
 
 
-        state.PushCallStackFrame(frame);
+        thread.PushCallStackFrame(frame);
         try
         try
         {
         {
             return await InvokeAsyncCore(context, buffer, cancellationToken);
             return await InvokeAsyncCore(context, buffer, cancellationToken);
         }
         }
+        catch (Exception ex) when (ex is not (LuaException or OperationCanceledException))
+        {
+            throw new LuaRuntimeException(thread.GetTracebacks(), ex.Message);
+        }
         finally
         finally
         {
         {
-            state.PopCallStackFrame();
+            thread.PopCallStackFrame();
         }
         }
     }
     }
 
 
-    public virtual string Name => GetType().Name;
     protected abstract ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken);
     protected abstract ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken);
 }
 }

+ 12 - 5
src/Lua/LuaFunctionExecutionContext.cs

@@ -12,8 +12,14 @@ public readonly record struct LuaFunctionExecutionContext
     public string? RootChunkName { get; init; }
     public string? RootChunkName { get; init; }
     public string? ChunkName { get; init; }
     public string? ChunkName { get; init; }
 
 
-    public int FrameBase => State.GetCurrentFrame().Base;
-    public ReadOnlySpan<LuaValue> Arguments => State.GetStackValues().Slice(FrameBase, ArgumentCount);
+    public ReadOnlySpan<LuaValue> Arguments
+    {
+        get
+        {
+            var thread = State.CurrentThread;
+            return thread.GetStackValues().Slice(thread.GetCurrentFrame().Base, ArgumentCount);
+        }
+    }
     
     
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     public LuaValue ReadArgument(int index)
     public LuaValue ReadArgument(int index)
@@ -30,13 +36,14 @@ public readonly record struct LuaFunctionExecutionContext
         var arg = Arguments[index];
         var arg = Arguments[index];
         if (!arg.TryRead<T>(out var argValue))
         if (!arg.TryRead<T>(out var argValue))
         {
         {
+            var thread = State.CurrentThread;
             if (LuaValue.TryGetLuaValueType(typeof(T), out var type))
             if (LuaValue.TryGetLuaValueType(typeof(T), out var type))
             {
             {
-                LuaRuntimeException.BadArgument(State.GetTracebacks(), 1, State.GetCurrentFrame().Function.Name, type.ToString(), arg.Type.ToString());
+                LuaRuntimeException.BadArgument(State.GetTracebacks(), 1, thread.GetCurrentFrame().Function.Name, type.ToString(), arg.Type.ToString());
             }
             }
             else
             else
             {
             {
-                LuaRuntimeException.BadArgument(State.GetTracebacks(), 1, State.GetCurrentFrame().Function.Name, typeof(T).Name, arg.Type.ToString());
+                LuaRuntimeException.BadArgument(State.GetTracebacks(), 1, thread.GetCurrentFrame().Function.Name, typeof(T).Name, arg.Type.ToString());
             }
             }
         }
         }
 
 
@@ -47,7 +54,7 @@ public readonly record struct LuaFunctionExecutionContext
     {
     {
         if (ArgumentCount <= index)
         if (ArgumentCount <= index)
         {
         {
-            LuaRuntimeException.BadArgument(State.GetTracebacks(), index + 1, State.GetCurrentFrame().Function.Name);
+            LuaRuntimeException.BadArgument(State.GetTracebacks(), index + 1, State.CurrentThread.GetCurrentFrame().Function.Name);
         }
         }
     }
     }
 }
 }

+ 21 - 0
src/Lua/LuaMainThread.cs

@@ -0,0 +1,21 @@
+namespace Lua;
+
+public sealed class LuaMainThread : LuaThread
+{
+    public override LuaThreadStatus GetStatus()
+    {
+        return LuaThreadStatus.Running;
+    }
+
+    public override ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
+    {
+        buffer.Span[0] = false;
+        buffer.Span[1] = "cannot resume non-suspended coroutine";
+        return new(2);
+    }
+
+    public override ValueTask Yield(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
+    {
+        throw new LuaRuntimeException(context.State.GetTracebacks(), "attempt to yield from outside a coroutine");
+    }
+}

+ 30 - 46
src/Lua/LuaState.cs

@@ -7,18 +7,27 @@ public sealed class LuaState
 {
 {
     public const string DefaultChunkName = "chunk";
     public const string DefaultChunkName = "chunk";
 
 
-    LuaStack stack = new();
-    FastStackCore<CallStackFrame> callStack;
+    readonly LuaMainThread mainThread = new();
     FastListCore<UpValue> openUpValues;
     FastListCore<UpValue> openUpValues;
-
-    LuaTable environment;
-    internal UpValue EnvUpValue { get; }
+    FastStackCore<LuaThread> threadStack;
+    readonly LuaTable environment;
+    readonly UpValue envUpValue;
     bool isRunning;
     bool isRunning;
 
 
-    internal LuaStack Stack => stack;
+    internal UpValue EnvUpValue => envUpValue;
+    internal ref FastStackCore<LuaThread> ThreadStack => ref threadStack;
+    internal ref FastListCore<UpValue> OpenUpValues => ref openUpValues;
 
 
     public LuaTable Environment => environment;
     public LuaTable Environment => environment;
-    public bool IsRunning => Volatile.Read(ref isRunning);
+    public LuaMainThread MainThread => mainThread;
+    public LuaThread CurrentThread
+    {
+        get
+        {
+            if (threadStack.TryPeek(out var thread)) return thread;
+            return mainThread;
+        }
+    }
 
 
     public static LuaState Create()
     public static LuaState Create()
     {
     {
@@ -28,7 +37,7 @@ public sealed class LuaState
     LuaState()
     LuaState()
     {
     {
         environment = new();
         environment = new();
-        EnvUpValue = UpValue.Closed(environment);
+        envUpValue = UpValue.Closed(mainThread, environment);
     }
     }
 
 
     public async ValueTask<int> RunAsync(Chunk chunk, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     public async ValueTask<int> RunAsync(Chunk chunk, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
@@ -38,7 +47,8 @@ public sealed class LuaState
         Volatile.Write(ref isRunning, true);
         Volatile.Write(ref isRunning, true);
         try
         try
         {
         {
-            return await new Closure(this, chunk).InvokeAsync(new()
+            var closure = new Closure(this, chunk);
+            return await closure.InvokeAsync(new()
             {
             {
                 State = this,
                 State = this,
                 ArgumentCount = 0,
                 ArgumentCount = 0,
@@ -54,79 +64,53 @@ public sealed class LuaState
         }
         }
     }
     }
 
 
-    public ReadOnlySpan<LuaValue> GetStackValues()
-    {
-        return stack.AsSpan();
-    }
-
     public void Push(LuaValue value)
     public void Push(LuaValue value)
     {
     {
         ThrowIfRunning();
         ThrowIfRunning();
-        stack.Push(value);
-    }
-
-    internal void PushCallStackFrame(CallStackFrame frame)
-    {
-        callStack.Push(frame);
+        mainThread.Stack.Push(value);
     }
     }
 
 
-    internal void PopCallStackFrame()
+    public LuaThread CreateThread(LuaFunction function, bool isProtectedMode = true)
     {
     {
-        var frame = callStack.Pop();
-        stack.PopUntil(frame.Base);
-    }
-
-    public CallStackFrame GetCurrentFrame()
-    {
-        return callStack.Peek();
+        return new LuaCoroutine(this, function, isProtectedMode);
     }
     }
 
 
     public Tracebacks GetTracebacks()
     public Tracebacks GetTracebacks()
     {
     {
-        return new()
-        {
-            StackFrames = callStack.AsSpan()[1..].ToArray()
-        };
+        return MainThread.GetTracebacks();
     }
     }
 
 
-    internal UpValue GetOrAddUpValue(int registerIndex)
+    internal UpValue GetOrAddUpValue(LuaThread thread, int registerIndex)
     {
     {
         foreach (var upValue in openUpValues.AsSpan())
         foreach (var upValue in openUpValues.AsSpan())
         {
         {
-            if (upValue.RegisterIndex == registerIndex)
+            if (upValue.RegisterIndex == registerIndex && upValue.Thread == thread)
             {
             {
                 return upValue;
                 return upValue;
             }
             }
         }
         }
 
 
-        var newUpValue = UpValue.Open(registerIndex);
+        var newUpValue = UpValue.Open(thread, registerIndex);
         openUpValues.Add(newUpValue);
         openUpValues.Add(newUpValue);
         return newUpValue;
         return newUpValue;
     }
     }
 
 
-    internal void CloseUpValues(int frameBase)
+    internal void CloseUpValues(LuaThread thread, int frameBase)
     {
     {
         for (int i = 0; i < openUpValues.Length; i++)
         for (int i = 0; i < openUpValues.Length; i++)
         {
         {
             var upValue = openUpValues[i];
             var upValue = openUpValues[i];
+            if (upValue.Thread != thread) continue;
+
             if (upValue.RegisterIndex >= frameBase)
             if (upValue.RegisterIndex >= frameBase)
             {
             {
-                upValue.Close(this);
+                upValue.Close();
                 openUpValues.RemoveAtSwapback(i);
                 openUpValues.RemoveAtSwapback(i);
                 i--;
                 i--;
             }
             }
         }
         }
     }
     }
 
 
-    internal void DumpStackValues()
-    {
-        var span = GetStackValues();
-        for (int i = 0; i < span.Length; i++)
-        {
-            Console.WriteLine($"LuaStack [{i}]\t{span[i]}");
-        }
-    }
-
     void ThrowIfRunning()
     void ThrowIfRunning()
     {
     {
         if (Volatile.Read(ref isRunning))
         if (Volatile.Read(ref isRunning))

+ 54 - 0
src/Lua/LuaThread.cs

@@ -0,0 +1,54 @@
+using Lua.Internal;
+using Lua.Runtime;
+
+namespace Lua;
+
+public abstract class LuaThread
+{
+    public abstract LuaThreadStatus GetStatus();
+    public abstract ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default);
+    public abstract ValueTask Yield(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default);
+
+    LuaStack stack = new();
+    FastStackCore<CallStackFrame> callStack;
+
+    internal LuaStack Stack => stack;
+
+    public CallStackFrame GetCurrentFrame()
+    {
+        return callStack.Peek();
+    }
+
+    public ReadOnlySpan<LuaValue> GetStackValues()
+    {
+        return stack.AsSpan();
+    }
+
+    internal Tracebacks GetTracebacks()
+    {
+        return new()
+        {
+            StackFrames = callStack.AsSpan()[1..].ToArray()
+        };
+    }
+
+    internal void PushCallStackFrame(CallStackFrame frame)
+    {
+        callStack.Push(frame);
+    }
+
+    internal void PopCallStackFrame()
+    {
+        var frame = callStack.Pop();
+        stack.PopUntil(frame.Base);
+    }
+
+    internal void DumpStackValues()
+    {
+        var span = GetStackValues();
+        for (int i = 0; i < span.Length; i++)
+        {
+            Console.WriteLine($"LuaStack [{i}]\t{span[i]}");
+        }
+    }
+}

+ 9 - 0
src/Lua/LuaThreadStatus.cs

@@ -0,0 +1,9 @@
+namespace Lua;
+
+public enum LuaThreadStatus : byte
+{
+    Normal,
+    Suspended,
+    Running,
+    Dead,
+}

+ 36 - 0
src/Lua/LuaValue.cs

@@ -11,6 +11,7 @@ public enum LuaValueType : byte
     String,
     String,
     Number,
     Number,
     Function,
     Function,
+    Thread,
     UserData,
     UserData,
     Table,
     Table,
 }
 }
@@ -114,6 +115,22 @@ public readonly struct LuaValue : IEquatable<LuaValue>
                 {
                 {
                     break;
                     break;
                 }
                 }
+            case LuaValueType.Thread:
+                if (t == typeof(LuaThread))
+                {
+                    var v = (LuaThread)referenceValue!;
+                    result = Unsafe.As<LuaThread, T>(ref v);
+                    return true;
+                }
+                else if (t == typeof(object))
+                {
+                    result = (T)referenceValue!;
+                    return true;
+                }
+                else
+                {
+                    break;
+                }
             case LuaValueType.UserData:
             case LuaValueType.UserData:
                 if (referenceValue is T userData)
                 if (referenceValue is T userData)
                 {
                 {
@@ -187,6 +204,12 @@ public readonly struct LuaValue : IEquatable<LuaValue>
         referenceValue = value;
         referenceValue = value;
     }
     }
 
 
+    public LuaValue(LuaThread value)
+    {
+        type = LuaValueType.Thread;
+        referenceValue = value;
+    }
+
     public LuaValue(object? value)
     public LuaValue(object? value)
     {
     {
         type = LuaValueType.UserData;
         type = LuaValueType.UserData;
@@ -218,6 +241,11 @@ public readonly struct LuaValue : IEquatable<LuaValue>
         return new(value);
         return new(value);
     }
     }
 
 
+    public static implicit operator LuaValue(LuaThread value)
+    {
+        return new(value);
+    }
+
     public override int GetHashCode()
     public override int GetHashCode()
     {
     {
         var valueHash = type switch
         var valueHash = type switch
@@ -227,6 +255,7 @@ public readonly struct LuaValue : IEquatable<LuaValue>
             LuaValueType.String => Read<string>().GetHashCode(),
             LuaValueType.String => Read<string>().GetHashCode(),
             LuaValueType.Number => Read<double>().GetHashCode(),
             LuaValueType.Number => Read<double>().GetHashCode(),
             LuaValueType.Function => Read<LuaFunction>().GetHashCode(),
             LuaValueType.Function => Read<LuaFunction>().GetHashCode(),
+            LuaValueType.Thread => Read<LuaThread>().GetHashCode(),
             LuaValueType.Table => Read<LuaTable>().GetHashCode(),
             LuaValueType.Table => Read<LuaTable>().GetHashCode(),
             LuaValueType.UserData => referenceValue == null ? 0 : referenceValue.GetHashCode(),
             LuaValueType.UserData => referenceValue == null ? 0 : referenceValue.GetHashCode(),
             _ => 0,
             _ => 0,
@@ -246,6 +275,7 @@ public readonly struct LuaValue : IEquatable<LuaValue>
             LuaValueType.String => Read<string>().Equals(other.Read<string>()),
             LuaValueType.String => Read<string>().Equals(other.Read<string>()),
             LuaValueType.Number => Read<double>().Equals(other.Read<double>()),
             LuaValueType.Number => Read<double>().Equals(other.Read<double>()),
             LuaValueType.Function => Read<LuaFunction>().Equals(other.Read<LuaFunction>()),
             LuaValueType.Function => Read<LuaFunction>().Equals(other.Read<LuaFunction>()),
+            LuaValueType.Thread => Read<LuaThread>().Equals(other.Read<LuaThread>()),
             LuaValueType.Table => Read<LuaTable>().Equals(other.Read<LuaTable>()),
             LuaValueType.Table => Read<LuaTable>().Equals(other.Read<LuaTable>()),
             LuaValueType.UserData => referenceValue == other.referenceValue,
             LuaValueType.UserData => referenceValue == other.referenceValue,
             _ => false,
             _ => false,
@@ -276,6 +306,7 @@ public readonly struct LuaValue : IEquatable<LuaValue>
             LuaValueType.String => Read<string>().ToString(),
             LuaValueType.String => Read<string>().ToString(),
             LuaValueType.Number => Read<double>().ToString(),
             LuaValueType.Number => Read<double>().ToString(),
             LuaValueType.Function => Read<LuaFunction>().ToString(),
             LuaValueType.Function => Read<LuaFunction>().ToString(),
+            LuaValueType.Thread => Read<LuaThread>().ToString(),
             LuaValueType.Table => Read<LuaTable>().ToString(),
             LuaValueType.Table => Read<LuaTable>().ToString(),
             LuaValueType.UserData => referenceValue?.ToString(),
             LuaValueType.UserData => referenceValue?.ToString(),
             _ => "",
             _ => "",
@@ -309,6 +340,11 @@ public readonly struct LuaValue : IEquatable<LuaValue>
             result = LuaValueType.Table;
             result = LuaValueType.Table;
             return true;
             return true;
         }
         }
+        else if (type == typeof(LuaThread))
+        {
+            result = LuaValueType.Thread;
+            return true;
+        }
 
 
         result = default;
         result = default;
         return false;
         return false;

+ 3 - 2
src/Lua/Runtime/Closure.cs

@@ -27,14 +27,15 @@ public sealed class Closure : LuaFunction
 
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
     {
-        return LuaVirtualMachine.ExecuteClosureAsync(context.State, this, context.State.GetCurrentFrame(), buffer, cancellationToken);
+        return LuaVirtualMachine.ExecuteClosureAsync(context.State, this, context.State.CurrentThread.GetCurrentFrame(), buffer, cancellationToken);
     }
     }
 
 
     static UpValue GetUpValueFromDescription(LuaState state, Chunk proto, UpValueInfo description)
     static UpValue GetUpValueFromDescription(LuaState state, Chunk proto, UpValueInfo description)
     {
     {
         if (description.IsInRegister)
         if (description.IsInRegister)
         {
         {
-            return state.GetOrAddUpValue(state.GetCurrentFrame().Base + description.Index);
+            var thread = state.CurrentThread;
+            return state.GetOrAddUpValue(thread, thread.GetCurrentFrame().Base + description.Index);
         }
         }
         else if (description.Index == -1) // -1 is global environment
         else if (description.Index == -1) // -1 is global environment
         {
         {

+ 1 - 1
src/Lua/Runtime/LuaValueRuntimeExtensions.cs

@@ -29,7 +29,7 @@ internal static class LuaRuntimeExtensions
         buffer.AsSpan().Clear();
         buffer.AsSpan().Clear();
         try
         try
         {
         {
-            return await function.InvokeAsync(context, cancellationToken);
+            return await function.InvokeAsync(context, buffer, cancellationToken);
         }
         }
         finally
         finally
         {
         {

+ 16 - 15
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -8,7 +8,8 @@ public static partial class LuaVirtualMachine
 {
 {
     internal async static ValueTask<int> ExecuteClosureAsync(LuaState state, Closure closure, CallStackFrame frame, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     internal async static ValueTask<int> ExecuteClosureAsync(LuaState state, Closure closure, CallStackFrame frame, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
     {
-        var stack = state.Stack;
+        var thread = state.CurrentThread;
+        var stack = thread.Stack;
         var chunk = closure.Proto;
         var chunk = closure.Proto;
 
 
         for (var pc = 0; pc < chunk.Instructions.Length; pc++)
         for (var pc = 0; pc < chunk.Instructions.Length; pc++)
@@ -47,7 +48,7 @@ public static partial class LuaVirtualMachine
                     {
                     {
                         stack.EnsureCapacity(RA + 1);
                         stack.EnsureCapacity(RA + 1);
                         var upValue = closure.UpValues[instruction.B];
                         var upValue = closure.UpValues[instruction.B];
-                        stack.UnsafeGet(RA) = upValue.GetValue(state);
+                        stack.UnsafeGet(RA) = upValue.GetValue();
                         stack.NotifyTop(RA + 1);
                         stack.NotifyTop(RA + 1);
                         break;
                         break;
                     }
                     }
@@ -56,7 +57,7 @@ public static partial class LuaVirtualMachine
                         stack.EnsureCapacity(RA + 1);
                         stack.EnsureCapacity(RA + 1);
                         var vc = RK(stack, chunk, instruction.C, frame.Base);
                         var vc = RK(stack, chunk, instruction.C, frame.Base);
                         var upValue = closure.UpValues[instruction.B];
                         var upValue = closure.UpValues[instruction.B];
-                        var table = upValue.GetValue(state);
+                        var table = upValue.GetValue();
                         var value = await GetTableValue(state, chunk, pc, table, vc, cancellationToken);
                         var value = await GetTableValue(state, chunk, pc, table, vc, cancellationToken);
                         stack.UnsafeGet(RA) = value;
                         stack.UnsafeGet(RA) = value;
                         stack.NotifyTop(RA + 1);
                         stack.NotifyTop(RA + 1);
@@ -78,14 +79,14 @@ public static partial class LuaVirtualMachine
                         var vc = RK(stack, chunk, instruction.C, frame.Base);
                         var vc = RK(stack, chunk, instruction.C, frame.Base);
 
 
                         var upValue = closure.UpValues[instruction.A];
                         var upValue = closure.UpValues[instruction.A];
-                        var table = upValue.GetValue(state);
+                        var table = upValue.GetValue();
                         await SetTableValue(state, chunk, pc, table, vb, vc, cancellationToken);
                         await SetTableValue(state, chunk, pc, table, vb, vc, cancellationToken);
                         break;
                         break;
                     }
                     }
                 case OpCode.SetUpVal:
                 case OpCode.SetUpVal:
                     {
                     {
                         var upValue = closure.UpValues[instruction.B];
                         var upValue = closure.UpValues[instruction.B];
-                        upValue.SetValue(state, stack.UnsafeGet(RA));
+                        upValue.SetValue(stack.UnsafeGet(RA));
                         break;
                         break;
                     }
                     }
                 case OpCode.SetTable:
                 case OpCode.SetTable:
@@ -498,7 +499,7 @@ public static partial class LuaVirtualMachine
                     pc += instruction.SBx;
                     pc += instruction.SBx;
                     if (instruction.A != 0)
                     if (instruction.A != 0)
                     {
                     {
-                        state.CloseUpValues(instruction.A);
+                        state.CloseUpValues(thread, instruction.A);
                     }
                     }
                     break;
                     break;
                 case OpCode.Eq:
                 case OpCode.Eq:
@@ -711,7 +712,7 @@ public static partial class LuaVirtualMachine
                     break;
                     break;
                 case OpCode.TailCall:
                 case OpCode.TailCall:
                     {
                     {
-                        state.CloseUpValues(frame.Base);
+                        state.CloseUpValues(thread, frame.Base);
 
 
                         var va = stack.UnsafeGet(RA);
                         var va = stack.UnsafeGet(RA);
                         if (!va.TryRead<LuaFunction>(out var func))
                         if (!va.TryRead<LuaFunction>(out var func))
@@ -736,7 +737,7 @@ public static partial class LuaVirtualMachine
                     }
                     }
                 case OpCode.Return:
                 case OpCode.Return:
                     {
                     {
-                        state.CloseUpValues(frame.Base);
+                        state.CloseUpValues(thread, frame.Base);
 
 
                         var retCount = instruction.B - 1;
                         var retCount = instruction.B - 1;
 
 
@@ -880,7 +881,7 @@ public static partial class LuaVirtualMachine
 #endif
 #endif
     static async ValueTask<LuaValue> GetTableValue(LuaState state, Chunk chunk, int pc, LuaValue table, LuaValue key, CancellationToken cancellationToken)
     static async ValueTask<LuaValue> GetTableValue(LuaState state, Chunk chunk, int pc, LuaValue table, LuaValue key, CancellationToken cancellationToken)
     {
     {
-        var stack = state.Stack;
+        var stack = state.CurrentThread.Stack;
         var isTable = table.TryRead<LuaTable>(out var t);
         var isTable = table.TryRead<LuaTable>(out var t);
 
 
         if (isTable && t.TryGetValue(key, out var result))
         if (isTable && t.TryGetValue(key, out var result))
@@ -931,7 +932,7 @@ public static partial class LuaVirtualMachine
 #endif
 #endif
     static async ValueTask SetTableValue(LuaState state, Chunk chunk, int pc, LuaValue table, LuaValue key, LuaValue value, CancellationToken cancellationToken)
     static async ValueTask SetTableValue(LuaState state, Chunk chunk, int pc, LuaValue table, LuaValue key, LuaValue value, CancellationToken cancellationToken)
     {
     {
-        var stack = state.Stack;
+        var stack = state.CurrentThread.Stack;
         var isTable = table.TryRead<LuaTable>(out var t);
         var isTable = table.TryRead<LuaTable>(out var t);
 
 
         if (isTable && t.ContainsKey(key))
         if (isTable && t.ContainsKey(key))
@@ -977,7 +978,7 @@ public static partial class LuaVirtualMachine
 
 
     static (int FrameBase, int ArgumentCount) PrepareForFunctionCall(LuaState state, LuaFunction function, Instruction instruction, int RA, bool isTailCall)
     static (int FrameBase, int ArgumentCount) PrepareForFunctionCall(LuaState state, LuaFunction function, Instruction instruction, int RA, bool isTailCall)
     {
     {
-        var stack = state.Stack;
+        var stack = state.CurrentThread.Stack;
 
 
         var argumentCount = instruction.B - 1;
         var argumentCount = instruction.B - 1;
         if (instruction.B == 0)
         if (instruction.B == 0)
@@ -991,7 +992,7 @@ public static partial class LuaVirtualMachine
         // Therefore, a call can be made without allocating new registers.
         // Therefore, a call can be made without allocating new registers.
         if (isTailCall)
         if (isTailCall)
         {
         {
-            var currentBase = state.GetCurrentFrame().Base;
+            var currentBase = state.CurrentThread.GetCurrentFrame().Base;
             var stackBuffer = stack.GetBuffer();
             var stackBuffer = stack.GetBuffer();
             stackBuffer.Slice(newBase, argumentCount).CopyTo(stackBuffer.Slice(currentBase, argumentCount));
             stackBuffer.Slice(newBase, argumentCount).CopyTo(stackBuffer.Slice(currentBase, argumentCount));
             newBase = currentBase;
             newBase = currentBase;
@@ -1031,15 +1032,15 @@ public static partial class LuaVirtualMachine
 
 
     static Tracebacks GetTracebacks(LuaState state, Chunk chunk, int pc)
     static Tracebacks GetTracebacks(LuaState state, Chunk chunk, int pc)
     {
     {
-        var frame = state.GetCurrentFrame();
-        state.PushCallStackFrame(frame with
+        var frame = state.CurrentThread.GetCurrentFrame();
+        state.CurrentThread.PushCallStackFrame(frame with
         {
         {
             CallPosition = chunk.SourcePositions[pc],
             CallPosition = chunk.SourcePositions[pc],
             ChunkName = chunk.Name,
             ChunkName = chunk.Name,
             RootChunkName = chunk.GetRoot().Name,
             RootChunkName = chunk.GetRoot().Name,
         });
         });
         var tracebacks = state.GetTracebacks();
         var tracebacks = state.GetTracebacks();
-        state.PopCallStackFrame();
+        state.CurrentThread.PopCallStackFrame();
 
 
         return tracebacks;
         return tracebacks;
     }
     }

+ 13 - 11
src/Lua/Runtime/UpValue.cs

@@ -7,24 +7,26 @@ public sealed class UpValue
 {
 {
     LuaValue value;
     LuaValue value;
 
 
+    public LuaThread Thread { get; }
     public bool IsClosed { get; private set; }
     public bool IsClosed { get; private set; }
     public int RegisterIndex { get; private set; }
     public int RegisterIndex { get; private set; }
 
 
-    UpValue()
+    UpValue(LuaThread thread)
     {
     {
+        Thread = thread;
     }
     }
 
 
-    public static UpValue Open(int registerIndex)
+    public static UpValue Open(LuaThread thread, int registerIndex)
     {
     {
-        return new()
+        return new(thread)
         {
         {
             RegisterIndex = registerIndex
             RegisterIndex = registerIndex
         };
         };
     }
     }
 
 
-    public static UpValue Closed(LuaValue value)
+    public static UpValue Closed(LuaThread thread, LuaValue value)
     {
     {
-        return new()
+        return new(thread)
         {
         {
             IsClosed = true,
             IsClosed = true,
             value = value
             value = value
@@ -32,7 +34,7 @@ public sealed class UpValue
     }
     }
 
 
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public LuaValue GetValue(LuaState state)
+    public LuaValue GetValue()
     {
     {
         if (IsClosed)
         if (IsClosed)
         {
         {
@@ -40,12 +42,12 @@ public sealed class UpValue
         }
         }
         else
         else
         {
         {
-            return state.Stack.UnsafeGet(RegisterIndex);
+            return Thread.Stack.UnsafeGet(RegisterIndex);
         }
         }
     }
     }
 
 
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public void SetValue(LuaState state, LuaValue value)
+    public void SetValue(LuaValue value)
     {
     {
         if (IsClosed)
         if (IsClosed)
         {
         {
@@ -53,16 +55,16 @@ public sealed class UpValue
         }
         }
         else
         else
         {
         {
-            state.Stack.UnsafeGet(RegisterIndex) = value;
+            Thread.Stack.UnsafeGet(RegisterIndex) = value;
         }
         }
     }
     }
 
 
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public void Close(LuaState state)
+    public void Close()
     {
     {
         if (!IsClosed)
         if (!IsClosed)
         {
         {
-            value = state.Stack.UnsafeGet(RegisterIndex);
+            value = Thread.Stack.UnsafeGet(RegisterIndex);
         }
         }
 
 
         IsClosed = true;
         IsClosed = true;

+ 16 - 0
src/Lua/Standard/Coroutines/CoroutineCreateFunction.cs

@@ -0,0 +1,16 @@
+
+namespace Lua.Standard.Coroutines;
+
+public sealed class CoroutineCreateFunction : LuaFunction
+{
+    public const string FunctionName = "create";
+
+    public override string Name => FunctionName;
+
+    protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+    {
+        var arg0 = context.ReadArgument<LuaFunction>(0);
+        buffer.Span[0] = context.State.CreateThread(arg0, true);
+        return new(1);
+    }
+}

+ 15 - 0
src/Lua/Standard/Coroutines/CoroutineResumeFunction.cs

@@ -0,0 +1,15 @@
+
+namespace Lua.Standard.Coroutines;
+
+public sealed class CoroutineResumeFunction : LuaFunction
+{
+    public const string FunctionName = "resume";
+
+    public override string Name => FunctionName;
+
+    protected override async ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+    {
+        var thread = context.ReadArgument<LuaThread>(0);
+        return await thread.Resume(context, buffer, cancellationToken);
+    }
+}

+ 16 - 0
src/Lua/Standard/Coroutines/CoroutineRunningFunction.cs

@@ -0,0 +1,16 @@
+
+namespace Lua.Standard.Coroutines;
+
+public sealed class CoroutineRunningFunction : LuaFunction
+{
+    public const string FunctionName = "running";
+
+    public override string Name => FunctionName;
+
+    protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+    {
+        buffer.Span[0] = context.State.CurrentThread;
+        buffer.Span[1] = context.State.CurrentThread == context.State.MainThread;
+        return new(2);
+    }
+}

+ 23 - 0
src/Lua/Standard/Coroutines/CoroutineStatusFunction.cs

@@ -0,0 +1,23 @@
+
+namespace Lua.Standard.Coroutines;
+
+public sealed class CoroutineStatusFunction : LuaFunction
+{
+    public const string FunctionName = "status";
+
+    public override string Name => FunctionName;
+
+    protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+    {
+        var thread = context.ReadArgument<LuaThread>(0);
+        buffer.Span[0] = thread.GetStatus() switch
+        {
+            LuaThreadStatus.Normal => "normal",
+            LuaThreadStatus.Suspended => "suspended",
+            LuaThreadStatus.Running => "running",
+            LuaThreadStatus.Dead => "dead",
+            _ => throw new NotImplementedException(),
+        };
+        return new(1);
+    }
+}

+ 25 - 0
src/Lua/Standard/Coroutines/CoroutineWrapFunction.cs

@@ -0,0 +1,25 @@
+
+namespace Lua.Standard.Coroutines;
+
+public sealed class CoroutineWrapFunction : LuaFunction
+{
+    public const string FunctionName = "wrap";
+
+    public override string Name => FunctionName;
+
+    protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+    {
+        var arg0 = context.ReadArgument<LuaFunction>(0);
+        var thread = context.State.CreateThread(arg0, false);
+        buffer.Span[0] = new Wrapper(thread);
+        return new(1);
+    }
+
+    class Wrapper(LuaThread targetThread) : LuaFunction
+    {
+        protected override async ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+        {
+            return await targetThread.Resume(context, buffer, cancellationToken);
+        }
+    }
+}

+ 15 - 0
src/Lua/Standard/Coroutines/CoroutineYieldFunction.cs

@@ -0,0 +1,15 @@
+
+namespace Lua.Standard.Coroutines;
+
+public sealed class CoroutineYieldFunction : LuaFunction
+{
+    public const string FunctionName = "yield";
+
+    public override string Name => FunctionName;
+
+    protected override async ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+    {
+        await context.State.CurrentThread.Yield(context, cancellationToken);
+        return 0;
+    }
+}

+ 14 - 1
src/Lua/Standard/OpenLibExtensions.cs

@@ -1,4 +1,5 @@
 using Lua.Standard.Base;
 using Lua.Standard.Base;
+using Lua.Standard.Coroutines;
 using Lua.Standard.Mathematics;
 using Lua.Standard.Mathematics;
 
 
 namespace Lua.Standard;
 namespace Lua.Standard;
@@ -46,14 +47,26 @@ public static class OpenLibExtensions
         TanhFunction.Instance,
         TanhFunction.Instance,
     ];
     ];
 
 
-    public static void OpenBaseLibrary(this LuaState state)
+    public static void OpenBasicLibrary(this LuaState state)
     {
     {
+        // basic
         state.Environment["_G"] = state.Environment;
         state.Environment["_G"] = state.Environment;
         state.Environment["_VERSION"] = "Lua 5.2";
         state.Environment["_VERSION"] = "Lua 5.2";
         foreach (var func in baseFunctions)
         foreach (var func in baseFunctions)
         {
         {
             state.Environment[func.Name] = func;
             state.Environment[func.Name] = func;
         }
         }
+
+        // coroutine
+        var coroutine = new LuaTable(0, 6);
+        coroutine[CoroutineCreateFunction.FunctionName] = new CoroutineCreateFunction();
+        coroutine[CoroutineResumeFunction.FunctionName] = new CoroutineResumeFunction();
+        coroutine[CoroutineYieldFunction.FunctionName] = new CoroutineYieldFunction();
+        coroutine[CoroutineStatusFunction.FunctionName] = new CoroutineStatusFunction();
+        coroutine[CoroutineRunningFunction.FunctionName] = new CoroutineRunningFunction();
+        coroutine[CoroutineWrapFunction.FunctionName] = new CoroutineWrapFunction();
+
+        state.Environment["coroutine"] = coroutine;
     }
     }
 
 
     public static void OpenMathLibrary(this LuaState state)
     public static void OpenMathLibrary(this LuaState state)

+ 1 - 1
tests/Lua.Tests/MetatableTests.cs

@@ -10,7 +10,7 @@ public class MetatableTests
     public void SetUp()
     public void SetUp()
     {
     {
         state = LuaState.Create();
         state = LuaState.Create();
-        state.OpenBaseLibrary();
+        state.OpenBasicLibrary();
     }
     }
 
 
     [Test]
     [Test]