Browse Source

fix: async coroutine bug

Akeit0 7 months ago
parent
commit
5ce0765fd1

+ 239 - 0
src/Lua/Internal/CompilerServices/StateMachineRunner.cs

@@ -0,0 +1,239 @@
+#pragma warning disable CS1591
+/*
+
+IStateMachineRunnerPromise is  based on UniTask
+https://github.com/Cysharp/UniTask
+
+MIT License
+
+Copyright (c) 2019 Cysharp, Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+*/
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using System.Threading.Tasks.Sources;
+// ReSharper disable ArrangeTypeMemberModifiers
+
+namespace Lua.Internal.CompilerServices
+{
+
+    internal interface IStateMachineRunnerPromise : IValueTaskSource
+    {
+        Action MoveNext { get; }
+        ValueTask Task { get; }
+        void SetResult();
+        void SetException(Exception exception);
+    }
+
+    internal interface IStateMachineRunnerPromise<T> : IValueTaskSource<T>
+    {
+        Action MoveNext { get; }
+        ValueTask<T> Task { get; }
+        void SetResult(T result);
+        void SetException(Exception exception);
+    }
+
+
+    internal sealed class LightAsyncValueTask<TStateMachine> : IStateMachineRunnerPromise, IPoolNode<LightAsyncValueTask<TStateMachine>>
+        where TStateMachine : IAsyncStateMachine
+    {
+        static LinkedPool<LightAsyncValueTask<TStateMachine>> pool;
+
+        public Action MoveNext { get; }
+
+        TStateMachine? stateMachine;
+        ManualResetValueTaskSourceCore<byte> core;
+
+        LightAsyncValueTask()
+        {
+            MoveNext = Run;
+        }
+
+        public static void SetStateMachine(ref TStateMachine stateMachine, ref IStateMachineRunnerPromise? runnerPromiseFieldRef)
+        {
+            if (!pool.TryPop(out var result))
+            {
+                result = new();
+            }
+
+            runnerPromiseFieldRef = result; // set runner before copied.
+            result.stateMachine = stateMachine; // copy struct StateMachine(in release build).
+        }
+
+        LightAsyncValueTask<TStateMachine>? nextNode;
+        public ref LightAsyncValueTask<TStateMachine>? NextNode => ref nextNode;
+
+        void Return()
+        {
+            core.Reset();
+            stateMachine = default;
+            pool.TryPush(this);
+        }
+
+
+        [DebuggerHidden]
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        void Run()
+        {
+            stateMachine!.MoveNext();
+        }
+
+        public ValueTask Task
+        {
+            [DebuggerHidden]
+            get => new(this, core.Version);
+        }
+
+        [DebuggerHidden]
+        public void SetResult()
+        {
+            core.SetResult(0);
+        }
+
+        [DebuggerHidden]
+        public void SetException(Exception exception)
+        {
+            core.SetException(exception);
+        }
+
+        [DebuggerHidden]
+        public void GetResult(short token)
+        {
+            try
+            {
+                core.GetResult(token);
+            }
+            finally
+            {
+                Return();
+            }
+        }
+
+        [DebuggerHidden]
+        public ValueTaskSourceStatus GetStatus(short token)
+        {
+            return core.GetStatus(token);
+        }
+
+
+        [DebuggerHidden]
+        public void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
+        {
+            core.OnCompleted(continuation, state, token, flags);
+        }
+    }
+
+    internal sealed class LightAsyncValueTask<TStateMachine, T> : IStateMachineRunnerPromise<T>, IPoolNode<LightAsyncValueTask<TStateMachine, T>>
+        where TStateMachine : IAsyncStateMachine
+    {
+        static LinkedPool<LightAsyncValueTask<TStateMachine, T>> pool;
+
+        public Action MoveNext { get; }
+
+        TStateMachine? stateMachine;
+        ManualResetValueTaskSourceCore<T> core;
+
+        LightAsyncValueTask()
+        {
+            MoveNext = Run;
+        }
+
+        public static void SetStateMachine(ref TStateMachine stateMachine, ref IStateMachineRunnerPromise<T>? runnerPromiseFieldRef)
+        {
+            if (!pool.TryPop(out var result))
+            {
+                result = new();
+            }
+
+            runnerPromiseFieldRef = result; // set runner before copied.
+            result.stateMachine = stateMachine; // copy struct StateMachine(in release build).
+        }
+
+        LightAsyncValueTask<TStateMachine, T>? nextNode;
+        public ref LightAsyncValueTask<TStateMachine, T>? NextNode => ref nextNode;
+
+
+        void Return()
+        {
+            core.Reset();
+            stateMachine = default!;
+            pool.TryPush(this);
+        }
+
+
+        [DebuggerHidden]
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        void Run()
+        {
+            stateMachine!.MoveNext();
+        }
+
+        public ValueTask<T> Task
+        {
+            [DebuggerHidden]
+            get => new(this, core.Version);
+        }
+
+        [DebuggerHidden]
+        public void SetResult(T result)
+        {
+            core.SetResult(result);
+        }
+
+        [DebuggerHidden]
+        public void SetException(Exception exception)
+        {
+            core.SetException(exception);
+        }
+
+        [DebuggerHidden]
+        public T GetResult(short token)
+        {
+            try
+            {
+                return core.GetResult(token);
+            }
+            finally
+            {
+                Return();
+            }
+        }
+
+        [DebuggerHidden]
+        T IValueTaskSource<T>.GetResult(short token)
+        {
+            return GetResult(token);
+        }
+
+        [DebuggerHidden]
+        public ValueTaskSourceStatus GetStatus(short token)
+        {
+            return core.GetStatus(token);
+        }
+
+
+        [DebuggerHidden]
+        public void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
+        {
+            core.OnCompleted(continuation, state, token, flags);
+        }
+    }
+}

+ 20 - 14
src/Lua/LuaCoroutine.cs

@@ -1,5 +1,6 @@
 using System.Threading.Tasks.Sources;
 using System.Threading.Tasks.Sources;
 using Lua.Internal;
 using Lua.Internal;
+using Lua.Internal.CompilerServices;
 using Lua.Runtime;
 using Lua.Runtime;
 using System.Runtime.CompilerServices;
 using System.Runtime.CompilerServices;
 
 
@@ -38,9 +39,10 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
         public ReadOnlySpan<LuaValue> Results => stack.AsSpan()[^argCount..];
         public ReadOnlySpan<LuaValue> Results => stack.AsSpan()[^argCount..];
     }
     }
 
 
-    readonly struct ResumeContext(LuaStack stack, int argCount)
+    readonly struct ResumeContext(LuaStack? stack, int argCount)
     {
     {
-        public ReadOnlySpan<LuaValue> Results => stack.AsSpan()[^argCount..];
+        public ReadOnlySpan<LuaValue> Results => stack!.AsSpan()[^argCount..];
+        public bool IsDead => stack == null;
     }
     }
 
 
     byte status;
     byte status;
@@ -154,29 +156,34 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
                 if (isFirstCall)
                 if (isFirstCall)
                 {
                 {
                     Stack.PushRange(stack.AsSpan()[^argCount..]);
                     Stack.PushRange(stack.AsSpan()[^argCount..]);
-                    functionTask = Function.InvokeAsync(new() { Thread = this, ArgumentCount = Stack.Count, ReturnFrameBase = 0 }, cancellationToken).Preserve();
+                    functionTask = Function.InvokeAsync(new() { Thread = this, ArgumentCount = Stack.Count, ReturnFrameBase = 0 }, cancellationToken);
 
 
                     Volatile.Write(ref isFirstCall, false);
                     Volatile.Write(ref isFirstCall, false);
+                    if (!functionTask.IsCompleted)
+                    {
+                        functionTask.GetAwaiter().OnCompleted(() => this.resume.SetResult(default));
+                    }
                 }
                 }
 
 
-                var (index, result0, _, promise) = await ValueTaskEx.WhenAnyPooled(resumeTask, functionTask!);
-                promise.Dispose();
-                if (index == 0)
+                ResumeContext result0;
+                if (functionTask.IsCompleted || (result0 = await resumeTask).IsDead)
                 {
                 {
-                    var results = result0.Results;
+                    _ = functionTask.Result;
+                    Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
                     stack.PopUntil(returnBase);
                     stack.PopUntil(returnBase);
                     stack.Push(true);
                     stack.Push(true);
-                    stack.PushRange(results);
-                    return results.Length + 1;
+                    stack.PushRange(Stack.AsSpan());
+                    ReleaseCore();
+                    return stack.Count - returnBase;
                 }
                 }
                 else
                 else
                 {
                 {
-                    Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
+                    Volatile.Write(ref status, (byte)LuaThreadStatus.Suspended);
+                    var results = result0.Results;
                     stack.PopUntil(returnBase);
                     stack.PopUntil(returnBase);
                     stack.Push(true);
                     stack.Push(true);
-                    stack.PushRange(Stack.AsSpan());
-                    ReleaseCore();
-                    return stack.Count - returnBase;
+                    stack.PushRange(results);
+                    return results.Length + 1;
                 }
                 }
             }
             }
             catch (Exception ex) when (ex is not OperationCanceledException)
             catch (Exception ex) when (ex is not OperationCanceledException)
@@ -247,7 +254,6 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
 
 
         resume.SetResult(new(stack, argCount));
         resume.SetResult(new(stack, argCount));
 
 
-        Volatile.Write(ref status, (byte)LuaThreadStatus.Suspended);
 
 
         CancellationTokenRegistration registration = default;
         CancellationTokenRegistration registration = default;
         if (cancellationToken.CanBeCanceled)
         if (cancellationToken.CanBeCanceled)

+ 37 - 0
tests/Lua.Tests/AsyncTests.cs

@@ -0,0 +1,37 @@
+using Lua.Standard;
+
+namespace Lua.Tests;
+
+public class AsyncTests
+{
+    LuaState state = default!;
+
+    [SetUp]
+    public void SetUp()
+    {
+        state = LuaState.Create();
+        state.OpenStandardLibraries();
+        var assert = state.Environment["assert"].Read<LuaFunction>() ;
+        state.Environment["assert"] = new LuaFunction("wait",
+            async (c, ct) =>
+            {
+                await Task.Delay(1, ct);
+                return await assert.InvokeAsync(c, ct);
+            });
+    }
+    
+    [Test]
+    public  async Task Test_Async()
+    {
+        var path = FileHelper.GetAbsolutePath("tests-lua/coroutine.lua");
+        try
+        {
+            await state.DoFileAsync(path);
+        }
+        catch (LuaRuntimeException e)
+        {
+            var line = e.LuaTraceback.LastLine;
+            throw new Exception($"{path}:line {line}\n{e.InnerException}\n {e}");
+        }
+    }
+}