Browse Source

Optimize: coroutine async methods work with ValueTask

AnnulusGames 1 year ago
parent
commit
8c38ccabf0

+ 13 - 0
src/Lua/Internal/CancellationExtensions.cs

@@ -0,0 +1,13 @@
+#if NETSTANDARD2_0 || NETSTANDARD2_1
+
+namespace System.Threading;
+
+internal static class CancellationTokenExtensions
+{
+    public static CancellationTokenRegistration UnsafeRegister(this CancellationToken cancellationToken, Action<object?> callback, object? state)
+    {
+        return cancellationToken.Register(callback, state, useSynchronizationContext: false);
+    }
+}
+
+#endif

+ 263 - 0
src/Lua/Internal/ValueTaskEx.cs

@@ -0,0 +1,263 @@
+using System.Runtime.CompilerServices;
+using System.Runtime.ExceptionServices;
+using System.Threading.Tasks.Sources;
+
+/*
+
+ValueTaskEx based on ValueTaskSupprement
+https://github.com/Cysharp/ValueTaskSupplement
+
+MIT License
+
+Copyright (c) 2019 Cysharp, Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+*/
+
+namespace Lua.Internal;
+
+internal static class ContinuationSentinel
+{
+    public static readonly Action<object?> AvailableContinuation = _ => { };
+    public static readonly Action<object?> CompletedContinuation = _ => { };
+}
+
+internal static class ValueTaskEx
+{
+    public static ValueTask<(int winArgumentIndex, T0 result0, T1 result1)> WhenAny<T0, T1>(ValueTask<T0> task0, ValueTask<T1> task1)
+    {
+        return new ValueTask<(int winArgumentIndex, T0 result0, T1 result1)>(new WhenAnyPromise<T0, T1>(task0, task1), 0);
+    }
+
+    class WhenAnyPromise<T0, T1> : IValueTaskSource<(int winArgumentIndex, T0 result0, T1 result1)>
+    {
+        static readonly ContextCallback execContextCallback = ExecutionContextCallback!;
+        static readonly SendOrPostCallback syncContextCallback = SynchronizationContextCallback!;
+
+        T0 t0 = default!;
+        T1 t1 = default!;
+        ValueTaskAwaiter<T0> awaiter0;
+        ValueTaskAwaiter<T1> awaiter1;
+
+        int completedCount = 0;
+        int winArgumentIndex = -1;
+        ExceptionDispatchInfo? exception;
+        Action<object?> continuation = ContinuationSentinel.AvailableContinuation;
+        Action<object?>? invokeContinuation;
+        object? state;
+        SynchronizationContext? syncContext;
+        ExecutionContext? execContext;
+
+        public WhenAnyPromise(ValueTask<T0> task0, ValueTask<T1> task1)
+        {
+            {
+                var awaiter = task0.GetAwaiter();
+                if (awaiter.IsCompleted)
+                {
+                    try
+                    {
+                        t0 = awaiter.GetResult();
+                        TryInvokeContinuationWithIncrement(0);
+                        return;
+                    }
+                    catch (Exception ex)
+                    {
+                        exception = ExceptionDispatchInfo.Capture(ex);
+                        return;
+                    }
+                }
+                else
+                {
+                    awaiter0 = awaiter;
+                    awaiter.UnsafeOnCompleted(ContinuationT0);
+                }
+            }
+            {
+                var awaiter = task1.GetAwaiter();
+                if (awaiter.IsCompleted)
+                {
+                    try
+                    {
+                        t1 = awaiter.GetResult();
+                        TryInvokeContinuationWithIncrement(1);
+                        return;
+                    }
+                    catch (Exception ex)
+                    {
+                        exception = ExceptionDispatchInfo.Capture(ex);
+                        return;
+                    }
+                }
+                else
+                {
+                    awaiter1 = awaiter;
+                    awaiter.UnsafeOnCompleted(ContinuationT1);
+                }
+            }
+        }
+
+        void ContinuationT0()
+        {
+            try
+            {
+                t0 = awaiter0.GetResult();
+            }
+            catch (Exception ex)
+            {
+                exception = ExceptionDispatchInfo.Capture(ex);
+                TryInvokeContinuation();
+                return;
+            }
+            TryInvokeContinuationWithIncrement(0);
+        }
+
+        void ContinuationT1()
+        {
+            try
+            {
+                t1 = awaiter1.GetResult();
+            }
+            catch (Exception ex)
+            {
+                exception = ExceptionDispatchInfo.Capture(ex);
+                TryInvokeContinuation();
+                return;
+            }
+            TryInvokeContinuationWithIncrement(1);
+        }
+
+
+        void TryInvokeContinuationWithIncrement(int index)
+        {
+            if (Interlocked.Increment(ref completedCount) == 1)
+            {
+                Volatile.Write(ref winArgumentIndex, index);
+                TryInvokeContinuation();
+            }
+        }
+
+        void TryInvokeContinuation()
+        {
+            var c = Interlocked.Exchange(ref continuation, ContinuationSentinel.CompletedContinuation);
+            if (c != ContinuationSentinel.AvailableContinuation && c != ContinuationSentinel.CompletedContinuation)
+            {
+                var spinWait = new SpinWait();
+                while (state == null) // worst case, state is not set yet so wait.
+                {
+                    spinWait.SpinOnce();
+                }
+
+                if (execContext != null)
+                {
+                    invokeContinuation = c;
+                    ExecutionContext.Run(execContext, execContextCallback, this);
+                }
+                else if (syncContext != null)
+                {
+                    invokeContinuation = c;
+                    syncContext.Post(syncContextCallback, this);
+                }
+                else
+                {
+                    c(state);
+                }
+            }
+        }
+
+        public (int winArgumentIndex, T0 result0, T1 result1) GetResult(short token)
+        {
+            if (exception != null)
+            {
+                exception.Throw();
+            }
+            var i = winArgumentIndex;
+            return (winArgumentIndex, t0, t1);
+        }
+
+        public ValueTaskSourceStatus GetStatus(short token)
+        {
+            return Volatile.Read(ref winArgumentIndex) != -1 ? ValueTaskSourceStatus.Succeeded
+                : exception != null ? exception.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled : ValueTaskSourceStatus.Faulted
+                : ValueTaskSourceStatus.Pending;
+        }
+
+        public void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
+        {
+            var c = Interlocked.CompareExchange(ref this.continuation, continuation, ContinuationSentinel.AvailableContinuation);
+            if (c == ContinuationSentinel.CompletedContinuation)
+            {
+                continuation(state);
+                return;
+            }
+
+            if (c != ContinuationSentinel.AvailableContinuation)
+            {
+                throw new InvalidOperationException("does not allow multiple await.");
+            }
+
+            if (state == null)
+            {
+                throw new InvalidOperationException("invalid state.");
+            }
+
+            if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) == ValueTaskSourceOnCompletedFlags.FlowExecutionContext)
+            {
+                execContext = ExecutionContext.Capture();
+            }
+            if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) == ValueTaskSourceOnCompletedFlags.UseSchedulingContext)
+            {
+                syncContext = SynchronizationContext.Current;
+            }
+            this.state = state;
+
+            if (GetStatus(token) != ValueTaskSourceStatus.Pending)
+            {
+                TryInvokeContinuation();
+            }
+        }
+
+        static void ExecutionContextCallback(object state)
+        {
+            var self = (WhenAnyPromise<T0, T1>)state;
+            if (self.syncContext != null)
+            {
+                self.syncContext.Post(syncContextCallback, self);
+            }
+            else
+            {
+                var invokeContinuation = self.invokeContinuation!;
+                var invokeState = self.state;
+                self.invokeContinuation = null;
+                self.state = null;
+                invokeContinuation(invokeState);
+            }
+        }
+
+        static void SynchronizationContextCallback(object state)
+        {
+            var self = (WhenAnyPromise<T0, T1>)state;
+            var invokeContinuation = self.invokeContinuation!;
+            var invokeState = self.state;
+            self.invokeContinuation = null;
+            self.state = null;
+            invokeContinuation(invokeState);
+        }
+    }
+}

+ 1 - 1
src/Lua/LuaFunction.cs

@@ -33,7 +33,7 @@ public abstract partial class LuaFunction
         {
         {
             return await InvokeAsyncCore(context, buffer, cancellationToken);
             return await InvokeAsyncCore(context, buffer, cancellationToken);
         }
         }
-        catch (Exception ex) when (ex is not LuaException)
+        catch (Exception ex) when (ex is not (LuaException or OperationCanceledException))
         {
         {
             throw new LuaRuntimeException(state.GetTracebacks(), ex.Message);
             throw new LuaRuntimeException(state.GetTracebacks(), ex.Message);
         }
         }

+ 104 - 46
src/Lua/LuaThread.cs

@@ -1,13 +1,25 @@
+using System.Threading.Tasks.Sources;
+using Lua.Internal;
+
 namespace Lua;
 namespace Lua;
 
 
-public sealed class LuaThread
+public sealed class LuaThread : IValueTaskSource<LuaThread.YieldContext>, IValueTaskSource<LuaThread.ResumeContext>
 {
 {
+    struct YieldContext
+    {
+    }
+
+    struct ResumeContext
+    {
+        public LuaValue[] Results;
+    }
+
     LuaThreadStatus status;
     LuaThreadStatus status;
     LuaState threadState;
     LuaState threadState;
-    Task<int>? functionTask;
+    ValueTask<int> functionTask;
 
 
-    TaskCompletionSource<LuaValue[]> resume = new();
-    TaskCompletionSource<object?> yield = new();
+    ManualResetValueTaskSourceCore<ResumeContext> resume;
+    ManualResetValueTaskSourceCore<YieldContext> yield;
 
 
     public LuaThreadStatus Status => status;
     public LuaThreadStatus Status => status;
     public bool IsProtectedMode { get; }
     public bool IsProtectedMode { get; }
@@ -21,7 +33,7 @@ public sealed class LuaThread
         function.SetCurrentThread(this);
         function.SetCurrentThread(this);
     }
     }
 
 
-    public async Task<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
+    public async ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     {
     {
         if (status is LuaThreadStatus.Dead)
         if (status is LuaThreadStatus.Dead)
         {
         {
@@ -57,90 +69,136 @@ public sealed class LuaThread
                 ArgumentCount = context.ArgumentCount - 1,
                 ArgumentCount = context.ArgumentCount - 1,
                 ChunkName = Function.Name,
                 ChunkName = Function.Name,
                 RootChunkName = context.RootChunkName,
                 RootChunkName = context.RootChunkName,
-            }, buffer[1..], cancellationToken).AsTask();
+            }, buffer[1..], cancellationToken).Preserve();
         }
         }
         else
         else
         {
         {
             status = LuaThreadStatus.Running;
             status = LuaThreadStatus.Running;
+            yield.SetResult(new());
+        }
+
+        var resumeTask = new ValueTask<ResumeContext>(this, resume.Version);
+
+        CancellationTokenRegistration registration = default;
+        if (cancellationToken.CanBeCanceled)
+        {
+            registration = cancellationToken.UnsafeRegister(static x =>
+            {
+                var thread = (LuaThread)x!;
+                thread.yield.SetException(new OperationCanceledException());
+            }, this);
+        }
+
+        try
+        {
+            (var index, var result0, var result1) = await ValueTaskEx.WhenAny(resumeTask, functionTask!);
 
 
-            if (cancellationToken.IsCancellationRequested)
+            if (index == 0)
             {
             {
-                yield.TrySetCanceled();
+                var results = result0.Results;
+
+                buffer.Span[0] = true;
+                for (int i = 0; i < results.Length; i++)
+                {
+                    buffer.Span[i + 1] = results[i];
+                }
+
+                return results.Length + 1;
             }
             }
             else
             else
             {
             {
-                yield.TrySetResult(null);
+                status = LuaThreadStatus.Dead;
+                buffer.Span[0] = true;
+                return 1 + functionTask!.Result;
             }
             }
         }
         }
-
-        var resumeTask = resume.Task;
-        var completedTask = await Task.WhenAny(resumeTask, functionTask!);
-
-        if (!completedTask.IsCompletedSuccessfully)
+        catch (Exception ex) when (ex is not OperationCanceledException)
         {
         {
             if (IsProtectedMode)
             if (IsProtectedMode)
             {
             {
                 status = LuaThreadStatus.Dead;
                 status = LuaThreadStatus.Dead;
                 buffer.Span[0] = false;
                 buffer.Span[0] = false;
-                buffer.Span[1] = completedTask.Exception.InnerException.Message;
+                buffer.Span[1] = ex.Message;
                 return 2;
                 return 2;
             }
             }
             else
             else
             {
             {
-                throw completedTask.Exception.InnerException;
-            }
-        }
-
-        if (completedTask == resumeTask)
-        {
-            resume = new();
-            var results = resumeTask.Result;
-
-            buffer.Span[0] = true;
-            for (int i = 0; i < results.Length; i++)
-            {
-                buffer.Span[i + 1] = results[i];
+                throw;
             }
             }
-
-            return results.Length + 1;
         }
         }
-        else
+        finally
         {
         {
-            status = LuaThreadStatus.Dead;
-            buffer.Span[0] = true;
-            return 1 + functionTask!.Result;
+            registration.Dispose();
+            resume.Reset();
         }
         }
     }
     }
 
 
-    public async Task Yield(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
+    public async ValueTask Yield(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
     {
     {
         if (status is not LuaThreadStatus.Running)
         if (status is not LuaThreadStatus.Running)
         {
         {
             throw new InvalidOperationException("cannot call yield on a coroutine that is not currently running");
             throw new InvalidOperationException("cannot call yield on a coroutine that is not currently running");
         }
         }
 
 
-        if (cancellationToken.IsCancellationRequested)
+        resume.SetResult(new()
         {
         {
-            resume.TrySetCanceled();
-        }
-        else
-        {
-            resume.TrySetResult(context.Arguments.ToArray());
-        }
+            Results = context.Arguments.ToArray(),
+        });
 
 
         status = LuaThreadStatus.Suspended;
         status = LuaThreadStatus.Suspended;
 
 
-RETRY:
+        CancellationTokenRegistration registration = default;
+        if (cancellationToken.CanBeCanceled)
+        {
+            registration = cancellationToken.UnsafeRegister(static x =>
+            {
+                var thread = (LuaThread)x!;
+                thread.yield.SetException(new OperationCanceledException());
+            }, this);
+        }
+
+    RETRY:
         try
         try
         {
         {
-            await yield.Task;
+            await new ValueTask<YieldContext>(this, yield.Version);
         }
         }
         catch (Exception ex) when (ex is not OperationCanceledException)
         catch (Exception ex) when (ex is not OperationCanceledException)
         {
         {
-            yield = new();
+            yield.Reset();
             goto RETRY;
             goto RETRY;
         }
         }
 
 
-        yield = new();
+        registration.Dispose();
+        yield.Reset();
+    }
+
+    YieldContext IValueTaskSource<YieldContext>.GetResult(short token)
+    {
+        return yield.GetResult(token);
+    }
+
+    ValueTaskSourceStatus IValueTaskSource<YieldContext>.GetStatus(short token)
+    {
+        return yield.GetStatus(token);
+    }
+
+    void IValueTaskSource<YieldContext>.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
+    {
+        yield.OnCompleted(continuation, state, token, flags);
+    }
+
+    ResumeContext IValueTaskSource<ResumeContext>.GetResult(short token)
+    {
+        return resume.GetResult(token);
+    }
+
+    ValueTaskSourceStatus IValueTaskSource<ResumeContext>.GetStatus(short token)
+    {
+        return resume.GetStatus(token);
+    }
+
+    void IValueTaskSource<ResumeContext>.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
+    {
+        resume.OnCompleted(continuation, state, token, flags);
     }
     }
 }
 }