瀏覽代碼

Pool tasks used in coroutines

Akeit0 7 月之前
父節點
當前提交
7d2e974d1d
共有 2 個文件被更改,包括 105 次插入13 次删除
  1. 91 6
      src/Lua/Internal/ValueTaskEx.cs
  2. 14 7
      src/Lua/LuaCoroutine.cs

+ 91 - 6
src/Lua/Internal/ValueTaskEx.cs

@@ -4,7 +4,7 @@ using System.Threading.Tasks.Sources;
 
 /*
 
-ValueTaskEx based on ValueTaskSupprement
+ValueTaskEx based on ValueTaskSupprement but modified for pooling
 https://github.com/Cysharp/ValueTaskSupplement
 
 MIT License
@@ -41,16 +41,53 @@ internal static class ContinuationSentinel
 
 internal static class ValueTaskEx
 {
-    public static ValueTask<(int winArgumentIndex, T0 result0, T1 result1)> WhenAny<T0, T1>(ValueTask<T0> task0, ValueTask<T1> task1)
+    public static ValueTask<(int winArgumentIndex, T0 result0, T1 result1, IDisposable promis)> WhenAnyPooled<T0, T1>(ValueTask<T0> task0, ValueTask<T1> task1)
     {
-        return new ValueTask<(int winArgumentIndex, T0 result0, T1 result1)>(new WhenAnyPromise<T0, T1>(task0, task1), 0);
+        var promise = WhenAnyPromise<T0, T1>.Get(task0, task1);
+        return new (promise, 0);
     }
 
-    class WhenAnyPromise<T0, T1> : IValueTaskSource<(int winArgumentIndex, T0 result0, T1 result1)>
+    class WhenAnyPromise<T0, T1> : IValueTaskSource<(int winArgumentIndex, T0 result0, T1 result1, IDisposable)>, IPoolNode<WhenAnyPromise<T0, T1>>, IDisposable
     {
         static readonly ContextCallback execContextCallback = ExecutionContextCallback!;
         static readonly SendOrPostCallback syncContextCallback = SynchronizationContextCallback!;
 
+        static LinkedPool<WhenAnyPromise<T0, T1>> pool;
+        WhenAnyPromise<T0, T1>? nextNode;
+        public ref WhenAnyPromise<T0, T1>? NextNode => ref nextNode;
+
+        public static WhenAnyPromise<T0, T1> Get(ValueTask<T0> task0, ValueTask<T1> task1)
+        {
+            if (pool.TryPop(out var f))
+            {
+                f.Init(task0, task1);
+            }
+            else
+            {
+                f = new WhenAnyPromise<T0, T1>(task0, task1);
+            }
+
+            return f;
+        }
+
+        public void Dispose()
+        {
+            t0 = default!;
+            t1 = default!;
+            awaiter0 = default!;
+            awaiter1 = default!;
+
+            completedCount = default!;
+            winArgumentIndex = -1;
+            exception = default!;
+            continuation = ContinuationSentinel.AvailableContinuation;
+            invokeContinuation = default!;
+            state = default!;
+            syncContext = default!;
+            execContext = default!;
+            pool.TryPush(this);
+        }
+
         T0 t0 = default!;
         T1 t1 = default!;
         ValueTaskAwaiter<T0> awaiter0;
@@ -65,6 +102,54 @@ internal static class ValueTaskEx
         SynchronizationContext? syncContext;
         ExecutionContext? execContext;
 
+        public void Init(ValueTask<T0> task0, ValueTask<T1> task1)
+        {
+            {
+                var awaiter = task0.GetAwaiter();
+                if (awaiter.IsCompleted)
+                {
+                    try
+                    {
+                        t0 = awaiter.GetResult();
+                        TryInvokeContinuationWithIncrement(0);
+                        return;
+                    }
+                    catch (Exception ex)
+                    {
+                        exception = ExceptionDispatchInfo.Capture(ex);
+                        return;
+                    }
+                }
+                else
+                {
+                    awaiter0 = awaiter;
+                    awaiter.UnsafeOnCompleted(ContinuationT0);
+                }
+            }
+            {
+                var awaiter = task1.GetAwaiter();
+                if (awaiter.IsCompleted)
+                {
+                    try
+                    {
+                        t1 = awaiter.GetResult();
+                        TryInvokeContinuationWithIncrement(1);
+                        return;
+                    }
+                    catch (Exception ex)
+                    {
+                        exception = ExceptionDispatchInfo.Capture(ex);
+                        return;
+                    }
+                }
+                else
+                {
+                    awaiter1 = awaiter;
+                    awaiter.UnsafeOnCompleted(ContinuationT1);
+                }
+            }
+        }
+
         public WhenAnyPromise(ValueTask<T0> task0, ValueTask<T1> task1)
         {
             {
@@ -183,7 +268,7 @@ internal static class ValueTaskEx
             }
         }
 
-        public (int winArgumentIndex, T0 result0, T1 result1) GetResult(short token)
+        public (int winArgumentIndex, T0 result0, T1 result1, IDisposable) GetResult(short token)
         {
             if (exception != null)
             {
@@ -191,7 +276,7 @@ internal static class ValueTaskEx
             }
 
             var i = winArgumentIndex;
-            return (winArgumentIndex, t0, t1);
+            return (winArgumentIndex, t0, t1, this);
         }
 
         public ValueTaskSourceStatus GetStatus(short token)

+ 14 - 7
src/Lua/LuaCoroutine.cs

@@ -1,6 +1,7 @@
 using System.Threading.Tasks.Sources;
 using Lua.Internal;
 using Lua.Runtime;
+using System.Runtime.CompilerServices;
 
 namespace Lua;
 
@@ -76,7 +77,9 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
     {
         return ResumeAsync(stack, stack.Count, 0, cancellationToken);
     }
-
+#if NET6_0_OR_GREATER
+    [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
+#endif
     public async ValueTask<int> ResumeAsync(LuaStack stack, int argCount, int returnBase, CancellationToken cancellationToken = default)
     {
         if (isFirstCall)
@@ -145,8 +148,8 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
                 Volatile.Write(ref isFirstCall, false);
             }
 
-            var (index, result0, result1) = await ValueTaskEx.WhenAny(resumeTask, functionTask!);
-
+            var (index, result0, _, promise) = await ValueTaskEx.WhenAnyPooled(resumeTask, functionTask!);
+            promise.Dispose();
             if (index == 0)
             {
                 var results = result0.Results;
@@ -188,7 +191,9 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
             resume.Reset();
         }
     }
-
+#if NET6_0_OR_GREATER
+    [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
+#endif
     public override async ValueTask<int> ResumeAsync(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
     {
         var baseThread = context.Thread;
@@ -251,8 +256,8 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
                     Volatile.Write(ref isFirstCall, false);
                 }
 
-                var (index, result0, result1) = await ValueTaskEx.WhenAny(resumeTask, functionTask!);
-
+                var (index, result0, _, promise) = await ValueTaskEx.WhenAnyPooled(resumeTask, functionTask!);
+                promise.Dispose();
                 if (index == 0)
                 {
                     var results = result0.Results;
@@ -292,7 +297,9 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
             baseThread.UnsafeSetStatus(LuaThreadStatus.Running);
         }
     }
-
+#if NET6_0_OR_GREATER
+    [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
+#endif
     public override async ValueTask<int> YieldAsync(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
     {
         if (Volatile.Read(ref status) != (byte)LuaThreadStatus.Running)