Browse Source

Add: LuaThread and Coroutine library

AnnulusGames 1 year ago
parent
commit
e805dd7d2e

+ 5 - 1
src/Lua/LuaFunction.cs

@@ -4,6 +4,11 @@ namespace Lua;
 
 public abstract partial class LuaFunction
 {
+    internal LuaThread? thread;
+    public LuaThread? Thread => thread;
+
+    public virtual string Name => GetType().Name;
+
     public async ValueTask<int> InvokeAsync(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
         var state = context.State;
@@ -29,6 +34,5 @@ public abstract partial class LuaFunction
         }
     }
 
-    public virtual string Name => GetType().Name;
     protected abstract ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken);
 }

+ 35 - 1
src/Lua/LuaState.cs

@@ -1,3 +1,4 @@
+using System.Diagnostics.CodeAnalysis;
 using Lua.Internal;
 using Lua.Runtime;
 
@@ -25,12 +26,23 @@ public sealed class LuaState
         return new();
     }
 
+    internal LuaState CreateCoroutineState()
+    {
+        return new LuaState(this);
+    }
+
     LuaState()
     {
         environment = new();
         EnvUpValue = UpValue.Closed(environment);
     }
 
+    LuaState(LuaState parent)
+    {
+        environment = parent.Environment;
+        EnvUpValue = parent.EnvUpValue;
+    }
+
     public async ValueTask<int> RunAsync(Chunk chunk, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     {
         ThrowIfRunning();
@@ -38,7 +50,8 @@ public sealed class LuaState
         Volatile.Write(ref isRunning, true);
         try
         {
-            return await new Closure(this, chunk).InvokeAsync(new()
+            var closure = new Closure(this, chunk);
+            return await closure.InvokeAsync(new()
             {
                 State = this,
                 ArgumentCount = 0,
@@ -59,6 +72,8 @@ public sealed class LuaState
         return stack.AsSpan();
     }
 
+    public int StackCount => stack.Count;
+
     public void Push(LuaValue value)
     {
         ThrowIfRunning();
@@ -76,6 +91,25 @@ public sealed class LuaState
         stack.PopUntil(frame.Base);
     }
 
+    internal ReadOnlySpan<CallStackFrame> GetCallStackSpan()
+    {
+        return callStack.AsSpan();
+    }
+
+    public bool TryGetCurrentThread([NotNullWhen(true)] out LuaThread? result)
+    {
+        var span = GetCallStackSpan();
+        
+        for (int i = 0; i < span.Length; i++)
+        {
+            result = span[i].Function.Thread;
+            if (result != null) return true;
+        }
+
+        result = default;
+        return false;
+    }
+
     public CallStackFrame GetCurrentFrame()
     {
         return callStack.Peek();

+ 140 - 0
src/Lua/LuaThread.cs

@@ -0,0 +1,140 @@
+namespace Lua;
+
+public sealed class LuaThread
+{
+    LuaThreadStatus status;
+    LuaState threadState;
+    Task<int>? functionTask;
+
+    TaskCompletionSource<LuaValue[]> resume = new();
+    TaskCompletionSource<object?> yield = new();
+
+    public LuaThreadStatus Status => status;
+    public LuaFunction Function { get; }
+
+    internal LuaThread(LuaState state, LuaFunction function)
+    {
+        threadState = state.CreateCoroutineState();
+        Function = function;
+        function.thread = this;
+    }
+
+    public async Task<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
+    {
+        if (status is LuaThreadStatus.Running or LuaThreadStatus.Dead)
+        {
+            throw new Exception(); // TODO:
+        }
+
+        if (status is LuaThreadStatus.Normal)
+        {
+            status = LuaThreadStatus.Running;
+
+            // first argument is LuaThread object
+            for (int i = 0; i < context.ArgumentCount - 1; i++)
+            {
+                threadState.Push(context.Arguments[i + 1]);
+            }
+
+            functionTask = Function.InvokeAsync(new()
+            {
+                State = threadState,
+                ArgumentCount = context.ArgumentCount - 1,
+                ChunkName = Function.Name,
+                RootChunkName = context.RootChunkName,
+            }, buffer[1..], cancellationToken).AsTask();
+        }
+        else
+        {
+            status = LuaThreadStatus.Running;
+
+            if (cancellationToken.IsCancellationRequested)
+            {
+                yield.TrySetCanceled();
+            }
+            else
+            {
+                yield.TrySetResult(null);
+            }
+        }
+
+        var resumeTask = resume.Task;
+        var completedTask = await Task.WhenAny(resumeTask, functionTask!);
+
+        if (!completedTask.IsCompletedSuccessfully)
+        {
+            status = LuaThreadStatus.Dead;
+            buffer.Span[0] = false;
+            buffer.Span[1] = completedTask.Exception.InnerException.Message;
+            return 2;
+        }
+
+        if (completedTask == resumeTask)
+        {
+            resume = new();
+            var results = resumeTask.Result;
+
+            buffer.Span[0] = true;
+            for (int i = 0; i < results.Length; i++)
+            {
+                buffer.Span[i + 1] = results[i];
+            }
+
+            return results.Length + 1;
+        }
+        else
+        {
+            status = LuaThreadStatus.Dead;
+            buffer.Span[0] = true;
+            return 1 + functionTask!.Result;
+        }
+    }
+
+    public async Task Yield(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
+    {
+        if (status is not LuaThreadStatus.Running)
+        {
+            throw new Exception(); // TODO:
+        }
+
+        if (cancellationToken.IsCancellationRequested)
+        {
+            resume.TrySetCanceled();
+        }
+        else
+        {
+            resume.TrySetResult(context.Arguments.ToArray());
+        }
+
+        status = LuaThreadStatus.Suspended;
+
+RETRY:
+        try
+        {
+            await yield.Task;
+        }
+        catch (Exception ex) when (ex is not OperationCanceledException)
+        {
+            yield = new();
+            goto RETRY;
+        }
+
+        yield = new();
+    }
+
+    public Task<int> Close(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
+    {
+        if (status is LuaThreadStatus.Normal or LuaThreadStatus.Running)
+        {
+            throw new Exception(); // TODO:
+        }
+
+        threadState.CloseUpValues(0);
+        yield.TrySetCanceled();
+
+        status = LuaThreadStatus.Dead;
+
+        buffer.Span[0] = true;
+        return Task.FromResult(1);
+    }
+}

+ 9 - 0
src/Lua/LuaThreadStatus.cs

@@ -0,0 +1,9 @@
+namespace Lua;
+
+public enum LuaThreadStatus
+{
+    Normal,
+    Suspended,
+    Running,
+    Dead,
+}

+ 36 - 0
src/Lua/LuaValue.cs

@@ -11,6 +11,7 @@ public enum LuaValueType : byte
     String,
     Number,
     Function,
+    Thread,
     UserData,
     Table,
 }
@@ -114,6 +115,22 @@ public readonly struct LuaValue : IEquatable<LuaValue>
                 {
                     break;
                 }
+            case LuaValueType.Thread:
+                if (t == typeof(LuaThread))
+                {
+                    var v = (LuaThread)referenceValue!;
+                    result = Unsafe.As<LuaThread, T>(ref v);
+                    return true;
+                }
+                else if (t == typeof(object))
+                {
+                    result = (T)referenceValue!;
+                    return true;
+                }
+                else
+                {
+                    break;
+                }
             case LuaValueType.UserData:
                 if (referenceValue is T userData)
                 {
@@ -187,6 +204,12 @@ public readonly struct LuaValue : IEquatable<LuaValue>
         referenceValue = value;
     }
 
+    public LuaValue(LuaThread value)
+    {
+        type = LuaValueType.Thread;
+        referenceValue = value;
+    }
+
     public LuaValue(object? value)
     {
         type = LuaValueType.UserData;
@@ -218,6 +241,11 @@ public readonly struct LuaValue : IEquatable<LuaValue>
         return new(value);
     }
 
+    public static implicit operator LuaValue(LuaThread value)
+    {
+        return new(value);
+    }
+
     public override int GetHashCode()
     {
         var valueHash = type switch
@@ -227,6 +255,7 @@ public readonly struct LuaValue : IEquatable<LuaValue>
             LuaValueType.String => Read<string>().GetHashCode(),
             LuaValueType.Number => Read<double>().GetHashCode(),
             LuaValueType.Function => Read<LuaFunction>().GetHashCode(),
+            LuaValueType.Thread => Read<LuaThread>().GetHashCode(),
             LuaValueType.Table => Read<LuaTable>().GetHashCode(),
             LuaValueType.UserData => referenceValue == null ? 0 : referenceValue.GetHashCode(),
             _ => 0,
@@ -246,6 +275,7 @@ public readonly struct LuaValue : IEquatable<LuaValue>
             LuaValueType.String => Read<string>().Equals(other.Read<string>()),
             LuaValueType.Number => Read<double>().Equals(other.Read<double>()),
             LuaValueType.Function => Read<LuaFunction>().Equals(other.Read<LuaFunction>()),
+            LuaValueType.Thread => Read<LuaThread>().Equals(other.Read<LuaThread>()),
             LuaValueType.Table => Read<LuaTable>().Equals(other.Read<LuaTable>()),
             LuaValueType.UserData => referenceValue == other.referenceValue,
             _ => false,
@@ -276,6 +306,7 @@ public readonly struct LuaValue : IEquatable<LuaValue>
             LuaValueType.String => Read<string>().ToString(),
             LuaValueType.Number => Read<double>().ToString(),
             LuaValueType.Function => Read<LuaFunction>().ToString(),
+            LuaValueType.Thread => Read<LuaThread>().ToString(),
             LuaValueType.Table => Read<LuaTable>().ToString(),
             LuaValueType.UserData => referenceValue?.ToString(),
             _ => "",
@@ -309,6 +340,11 @@ public readonly struct LuaValue : IEquatable<LuaValue>
             result = LuaValueType.Table;
             return true;
         }
+        else if (type == typeof(LuaThread))
+        {
+            result = LuaValueType.Thread;
+            return true;
+        }
 
         result = default;
         return false;

+ 1 - 1
src/Lua/Runtime/LuaValueRuntimeExtensions.cs

@@ -29,7 +29,7 @@ internal static class LuaRuntimeExtensions
         buffer.AsSpan().Clear();
         try
         {
-            return await function.InvokeAsync(context, cancellationToken);
+            return await function.InvokeAsync(context, buffer, cancellationToken);
         }
         finally
         {

+ 16 - 0
src/Lua/Standard/Coroutines/CoroutineCreateFunction.cs

@@ -0,0 +1,16 @@
+
+namespace Lua.Standard.Coroutines;
+
+public sealed class CoroutineCreateFunction : LuaFunction
+{
+    public const string FunctionName = "create";
+
+    public override string Name => FunctionName;
+
+    protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+    {
+        var arg0 = context.ReadArgument<LuaFunction>(0);
+        buffer.Span[0] = new LuaThread(context.State, arg0);
+        return new(1);
+    }
+}

+ 16 - 0
src/Lua/Standard/Coroutines/CoroutineResumeFunction.cs

@@ -0,0 +1,16 @@
+
+namespace Lua.Standard.Coroutines;
+
+public sealed class CoroutineResumeFunction : LuaFunction
+{
+    public const string FunctionName = "resume";
+
+    public override string Name => FunctionName;
+
+    protected override async ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+    {
+        var thread = context.ReadArgument<LuaThread>(0);
+
+        return await thread.Resume(context, buffer, cancellationToken);
+    }
+}

+ 15 - 0
src/Lua/Standard/Coroutines/CoroutineRunningFunction.cs

@@ -0,0 +1,15 @@
+
+namespace Lua.Standard.Coroutines;
+
+public sealed class CoroutineRunningFunction : LuaFunction
+{
+    public const string FunctionName = "running";
+
+    public override string Name => FunctionName;
+
+    protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+    {
+        buffer.Span[0] = context.State.TryGetCurrentThread(out _);
+        return new(1);
+    }
+}

+ 23 - 0
src/Lua/Standard/Coroutines/CoroutineStatusFunction.cs

@@ -0,0 +1,23 @@
+
+namespace Lua.Standard.Coroutines;
+
+public sealed class CoroutineStatusFunction : LuaFunction
+{
+    public const string FunctionName = "status";
+
+    public override string Name => FunctionName;
+
+    protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+    {
+        var thread = context.ReadArgument<LuaThread>(0);
+        buffer.Span[0] = thread.Status switch
+        {
+            LuaThreadStatus.Normal => "normal",
+            LuaThreadStatus.Suspended => "suspended",
+            LuaThreadStatus.Running => "running",
+            LuaThreadStatus.Dead => "dead",
+            _ => throw new NotImplementedException(),
+        };
+        return new(1);
+    }
+}

+ 21 - 0
src/Lua/Standard/Coroutines/CoroutineYieldFunction.cs

@@ -0,0 +1,21 @@
+
+namespace Lua.Standard.Coroutines;
+
+public sealed class CoroutineYieldFunction : LuaFunction
+{
+    public const string FunctionName = "yield";
+
+    public override string Name => FunctionName;
+
+    protected override async ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+    {
+        if (!context.State.TryGetCurrentThread(out var thread))
+        {
+            throw new LuaRuntimeException(context.State.GetTracebacks(), "attempt to yield from outside a coroutine");
+        }
+
+        await thread.Yield(context, cancellationToken);
+
+        return 0;
+    }
+}

+ 13 - 0
src/Lua/Standard/OpenLibExtensions.cs

@@ -1,4 +1,5 @@
 using Lua.Standard.Base;
+using Lua.Standard.Coroutines;
 using Lua.Standard.Mathematics;
 
 namespace Lua.Standard;
@@ -70,4 +71,16 @@ public static class OpenLibExtensions
 
         state.Environment["math"] = table;
     }
+    
+    public static void OpenCoroutineLibrary(this LuaState state)
+    {
+        var table = new LuaTable(0, 6);
+        table[CoroutineCreateFunction.FunctionName] = new CoroutineCreateFunction();
+        table[CoroutineResumeFunction.FunctionName] = new CoroutineResumeFunction();
+        table[CoroutineYieldFunction.FunctionName] = new CoroutineYieldFunction();
+        table[CoroutineStatusFunction.FunctionName] = new CoroutineStatusFunction();
+        table[CoroutineRunningFunction.FunctionName] = new CoroutineRunningFunction();
+
+        state.Environment["coroutine"] = table;
+    }
 }