فهرست منبع

Allow Task.WaitAny completion to run synchronously (dotnet/coreclr#21245)

Signed-off-by: dotnet-bot <[email protected]>
Ben Adams 7 سال پیش
والد
کامیت
bce27a80de

+ 2 - 2
netcore/System.Private.CoreLib/shared/System/Threading/Tasks/Task.cs

@@ -4394,7 +4394,7 @@ namespace System.Threading.Tasks
             AddCompletionAction(action, addBeforeOthers: false);
         }
 
-        private void AddCompletionAction(ITaskCompletionAction action, bool addBeforeOthers)
+        internal void AddCompletionAction(ITaskCompletionAction action, bool addBeforeOthers)
         {
             if (!AddTaskContinuation(action, addBeforeOthers))
                 action.Invoke(this); // run the action directly if we failed to queue the continuation (i.e., the task completed)
@@ -5127,7 +5127,7 @@ namespace System.Threading.Tasks
 
             if (signaledTaskIndex == -1 && tasks.Length != 0)
             {
-                Task<Task> firstCompleted = TaskFactory.CommonCWAnyLogic(tasks);
+                Task<Task> firstCompleted = TaskFactory.CommonCWAnyLogic(tasks, isSyncBlocking: true);
                 bool waitCompleted = firstCompleted.Wait(millisecondsTimeout, cancellationToken);
                 if (waitCompleted)
                 {

+ 21 - 8
netcore/System.Private.CoreLib/shared/System/Threading/Tasks/TaskFactory.cs

@@ -2281,14 +2281,23 @@ namespace System.Threading.Tasks
         // Used in TaskFactory.CommonCWAnyLogic(), below.
         internal sealed class CompleteOnInvokePromise : Task<Task>, ITaskCompletionAction
         {
+            private const int CompletedFlag = 0b_01;
+            private const int SyncBlockingFlag = 0b_10;
+
             private IList<Task> _tasks; // must track this for cleanup
-            private int m_firstTaskAlreadyCompleted;
+            private int _stateFlags;
 
-            public CompleteOnInvokePromise(IList<Task> tasks) : base()
+            public CompleteOnInvokePromise(IList<Task> tasks, bool isSyncBlocking) : base()
             {
                 Debug.Assert(tasks != null, "Expected non-null collection of tasks");
                 _tasks = tasks;
 
+                if (isSyncBlocking)
+                {
+                    // Not completed, but blocking thread, set second bit
+                    _stateFlags = SyncBlockingFlag;
+                }
+
                 if (AsyncCausalityTracer.LoggingOn)
                     AsyncCausalityTracer.TraceOperationCreation(this, "TaskFactory.ContinueWhenAny");
 
@@ -2298,8 +2307,12 @@ namespace System.Threading.Tasks
 
             public void Invoke(Task completingTask)
             {
-                if (m_firstTaskAlreadyCompleted == 0 &&
-                    Interlocked.Exchange(ref m_firstTaskAlreadyCompleted, 1) == 0)
+                int flags = _stateFlags;
+                int isSyncBlockingFlag = flags & SyncBlockingFlag;
+                int isCompleted = flags & CompletedFlag;
+
+                if (isCompleted == 0 &&
+                    Interlocked.Exchange(ref _stateFlags, isSyncBlockingFlag | CompletedFlag) == isSyncBlockingFlag)
                 {
                     if (AsyncCausalityTracer.LoggingOn)
                     {
@@ -2330,20 +2343,20 @@ namespace System.Threading.Tasks
                 }
             }
 
-            public bool InvokeMayRunArbitraryCode { get { return true; } }
+            public bool InvokeMayRunArbitraryCode => (_stateFlags & SyncBlockingFlag) == 0;
         }
         // Common ContinueWhenAny logic
         // If the tasks list is not an array, it must be an internal defensive copy so that 
         // we don't need to be concerned about concurrent modifications to the list.  If the task list
         // is an array, it should be a defensive copy if this functionality is being used
         // asynchronously (e.g. WhenAny) rather than synchronously (e.g. WaitAny).
-        internal static Task<Task> CommonCWAnyLogic(IList<Task> tasks)
+        internal static Task<Task> CommonCWAnyLogic(IList<Task> tasks, bool isSyncBlocking = false)
         {
             Debug.Assert(tasks != null);
 
             // Create a promise task to be returned to the user.
             // (If this logic ever changes, also update CommonCWAnyLogicCleanup.)
-            var promise = new CompleteOnInvokePromise(tasks);
+            var promise = new CompleteOnInvokePromise(tasks, isSyncBlocking);
 
             // At the completion of any of the tasks, complete the promise.
 
@@ -2370,7 +2383,7 @@ namespace System.Threading.Tasks
                 // Otherwise, add the completion action and keep going.
                 else
                 {
-                    task.AddCompletionAction(promise);
+                    task.AddCompletionAction(promise, addBeforeOthers: isSyncBlocking);
                     if (promise.IsCompleted)
                     {
                         // One of the previous tasks that already had its continuation registered may have