Forráskód Böngészése

Fixes Task.WaitAny with further fixes to Task.Wait

Marek Safar 14 éve
szülő
commit
a3d97ba5df

+ 38 - 12
mcs/class/corlib/System.Threading.Tasks/CompletionContainer.cs

@@ -1,5 +1,5 @@
 //
-// CompletionContainer.cs
+// TaskCompletionQueue.cs
 //
 // Authors:
 //    Jérémie Laval <jeremie dot laval at xamarin dot com>
@@ -35,18 +35,30 @@ using System.Collections.Concurrent;
 
 namespace System.Threading.Tasks
 {
-	internal struct CompletionContainer
+	struct TaskCompletionQueue
 	{
-		Task single;
-		ConcurrentQueue<Task> completed;
+		object single;
+		ConcurrentQueue<object> completed;
 
 		public void Add (Task continuation)
 		{
-			if (single == null && Interlocked.CompareExchange (ref single, continuation, null) == null)
+			AddAction (continuation);
+		}
+
+		public void Add (ManualResetEventSlim resetEvent)
+		{
+			AddAction (resetEvent);
+		}
+
+		void AddAction (object action)
+		{
+			if (single == null && Interlocked.CompareExchange (ref single, action, null) == null)
 				return;
+
 			if (completed == null)
-				Interlocked.CompareExchange (ref completed, new ConcurrentQueue<Task> (), null);
-			completed.Enqueue (continuation);
+				Interlocked.CompareExchange (ref completed, new ConcurrentQueue<object> (), null);
+
+			completed.Enqueue (action);
 		}
 
 		public bool HasElements {
@@ -55,14 +67,28 @@ namespace System.Threading.Tasks
 			}
 		}
 
-		public bool TryGetNextCompletion (out Task continuation)
+		public bool TryGetNext (out object value)
 		{
-			continuation = null;
-
-			if (single != null && (continuation = Interlocked.Exchange (ref single, null)) != null)
+			if (single != null && (value = Interlocked.Exchange (ref single, null)) != null)
 				return true;
 
-			return completed != null && completed.TryDequeue (out continuation);
+			if (completed != null)
+				return completed.TryDequeue (out value);
+
+			value = null;
+			return false;
+		}
+
+		public void TryRemove (object value)
+		{
+			if (value == null)
+				throw new ArgumentNullException ("value");
+
+			if (single != null && (Interlocked.CompareExchange (ref single, null, value) != single))
+				return;
+
+			if (completed != null)
+				completed.TryDequeue (out value);
 		}
 	}
 }

+ 74 - 94
mcs/class/corlib/System.Threading.Tasks/Task.cs

@@ -73,7 +73,7 @@ namespace System.Threading.Tasks
 		object         state;
 		AtomicBooleanValue executing;
 
-		CompletionContainer completed;
+		TaskCompletionQueue completed;
 		// If this task is a continuation, this stuff gets filled
 		CompletionSlot Slot;
 
@@ -320,7 +320,7 @@ namespace System.Threading.Tasks
 			continuation.Slot = new CompletionSlot (kind, predicate);
 
 			if (IsCompleted) {
-				CompletionExecutor (continuation);
+				CompletionTaskExecutor (continuation);
 				return;
 			}
 			
@@ -328,7 +328,7 @@ namespace System.Threading.Tasks
 			
 			// Retry in case completion was achieved but event adding was too late
 			if (IsCompleted)
-				CompletionExecutor (continuation);
+				CompletionTaskExecutor (continuation);
 		}
 
 		
@@ -509,7 +509,30 @@ namespace System.Threading.Tasks
 			}
 		}
 
-		void CompletionExecutor (Task cont)
+		void ProcessCompleteDelegates ()
+		{
+			if (!completed.HasElements)
+				return;
+
+			object value;
+			while (completed.TryGetNext (out value)) {
+				var t = value as Task;
+				if (t != null) {
+					CompletionTaskExecutor (t);
+					continue;
+				}
+
+				var mre = value as ManualResetEventSlim;
+				if (mre != null) {
+					mre.Set ();
+					continue;
+				}
+
+				throw new NotImplementedException ("Unknown completition type " + t.GetType ());
+			}
+		}
+
+		void CompletionTaskExecutor (Task cont)
 		{
 			if (cont.Slot.Predicate != null && !cont.Slot.Predicate ())
 				return;
@@ -530,16 +553,6 @@ namespace System.Threading.Tasks
 				cont.Schedule ();
 		}
 
-		void ProcessCompleteDelegates ()
-		{
-			if (!completed.HasElements)
-				return;
-
-			Task continuation;
-			while (completed.TryGetNextCompletion (out continuation))
-				CompletionExecutor (continuation);
-		}
-
 		void ProcessChildExceptions ()
 		{
 			if (childExceptions == null)
@@ -603,29 +616,32 @@ namespace System.Threading.Tasks
 
 			bool result = IsCompleted;
 			if (!result) {
-				if (scheduler == null) {
-					Watch watch = Watch.StartNew ();
-
-					schedWait.Wait (millisecondsTimeout, cancellationToken);
-					millisecondsTimeout = ComputeTimeout (millisecondsTimeout, watch);
-				}
-
-				var wait_event = new ManualResetEventSlim (false);
 				CancellationTokenRegistration? registration = null;
+				var completed_event =  new ManualResetEventSlim (false);
 
 				try {
 					if (cancellationToken.CanBeCanceled) {
-						registration = cancellationToken.Register (wait_event.Set);
+						registration = cancellationToken.Register (completed_event.Set);
 					}
 
-					// FIXME: The implementation is wrong and slow
-					// It adds a continuation to the task which is then
-					// returned to parent causing all sort of problems when
-					// timeout is reached before task is finished
-					result = !scheduler.ParticipateUntil (this, wait_event, millisecondsTimeout);
+					completed.Add (completed_event);
+
+					// Task could complete while we were setting things up
+					if (IsCompleted) {
+						// Don't bother removing completed_event, GC can handle it
+						result = true;
+					} else {
+						result = completed_event.Wait (millisecondsTimeout);
+					}
 				} finally {
 					if (registration.HasValue)
 						registration.Value.Dispose ();
+
+					// Try to remove completition event when timeout expired
+					if (!result)
+						completed.TryRemove (completed_event);
+
+					completed_event.Dispose ();
 				}
 			}
 
@@ -662,6 +678,9 @@ namespace System.Threading.Tasks
 		{
 			if (tasks == null)
 				throw new ArgumentNullException ("tasks");
+
+			if (millisecondsTimeout < -1)
+				throw new ArgumentOutOfRangeException ("millisecondsTimeout");
 			
 			bool result = true;
 			bool simple_run = millisecondsTimeout == Timeout.Infinite || tasks.Length == 1;
@@ -705,7 +724,7 @@ namespace System.Threading.Tasks
 		
 		public static int WaitAny (params Task[] tasks)
 		{
-			return WaitAny (tasks, -1, CancellationToken.None);
+			return WaitAny (tasks, Timeout.Infinite, CancellationToken.None);
 		}
 
 		public static int WaitAny (Task[] tasks, TimeSpan timeout)
@@ -715,90 +734,51 @@ namespace System.Threading.Tasks
 		
 		public static int WaitAny (Task[] tasks, int millisecondsTimeout)
 		{
-			if (millisecondsTimeout < -1)
-				throw new ArgumentOutOfRangeException ("millisecondsTimeout");
-
-			if (millisecondsTimeout == -1)
-				return WaitAny (tasks);
-
 			return WaitAny (tasks, millisecondsTimeout, CancellationToken.None);
 		}
 
 		public static int WaitAny (Task[] tasks, CancellationToken cancellationToken)
 		{
-			return WaitAny (tasks, -1, cancellationToken);
+			return WaitAny (tasks, Timeout.Infinite, cancellationToken);
 		}
 
 		public static int WaitAny (Task[] tasks, int millisecondsTimeout, CancellationToken cancellationToken)
 		{
 			if (tasks == null)
 				throw new ArgumentNullException ("tasks");
-			if (tasks.Length == 0)
-				throw new ArgumentException ("tasks is empty", "tasks");
-			if (tasks.Length == 1) {
-				tasks[0].Wait (millisecondsTimeout, cancellationToken);
-				return 0;
-			}
-			
-			int numFinished = 0;
-			int indexFirstFinished = -1;
-			int index = 0;
-			TaskScheduler sched = null;
-			Task task = null;
-			Watch watch = Watch.StartNew ();
-			ManualResetEventSlim predicateEvt = new ManualResetEventSlim (false);
-
-			foreach (Task t in tasks) {
-				int indexResult = index++;
-				t.ContinueWith (delegate {
-					if (numFinished >= 1)
-						return;
-					int result = Interlocked.Increment (ref numFinished);
-
-					// Check if we are the first to have finished
-					if (result == 1)
-						indexFirstFinished = indexResult;
-
-					// Stop waiting
-					predicateEvt.Set ();
-				}, TaskContinuationOptions.ExecuteSynchronously);
-
-				if (sched == null && t.scheduler != null) {
-					task = t;
-					sched = t.scheduler;
-				}
-			}
 
-			// If none of task have a scheduler we are forced to wait for at least one to start
-			if (sched == null) {
-				var handles = Array.ConvertAll (tasks, t => t.schedWait.WaitHandle);
-				int shandle = -1;
-				if ((shandle = WaitHandle.WaitAny (handles, millisecondsTimeout)) == WaitHandle.WaitTimeout)
-					return -1;
-				sched = tasks[shandle].scheduler;
-				task = tasks[shandle];
-				millisecondsTimeout = ComputeTimeout (millisecondsTimeout, watch);
-			}
+			int first_finished = -1;
+			for (int i = 0; i < tasks.Length; ++i) {
+				var t = tasks [i];
 
-			// One task already finished
-			if (indexFirstFinished != -1)
-				return indexFirstFinished;
+				if (t == null)
+					throw new ArgumentNullException ("tasks", "the tasks argument contains a null element");
 
-			if (cancellationToken != CancellationToken.None) {
-				cancellationToken.Register (predicateEvt.Set);
-				cancellationToken.ThrowIfCancellationRequested ();
+				if (first_finished < 0 && t.IsCompleted)
+					first_finished = i;
 			}
 
-			sched.ParticipateUntil (task, predicateEvt, millisecondsTimeout);
+			if (first_finished >= 0 || tasks.Length == 0)
+				return first_finished;
+
+			using (var completed_event = new ManualResetEventSlim (false)) {
 
-			// Index update is still not done
-			if (indexFirstFinished == -1) {
-				SpinWait wait = new SpinWait ();
-				while (indexFirstFinished == -1)
-					wait.SpinOnce ();
+				foreach (var t in tasks) {
+					t.completed.Add (completed_event);
+				}
+
+				completed_event.Wait (millisecondsTimeout, cancellationToken);
+
+				for (int i = 0; i < tasks.Length; ++i) {
+					var t = tasks[i];
+					if (first_finished < 0 && t.IsCompleted)
+						first_finished = i;
+
+					t.completed.TryRemove (completed_event);
+				}
 			}
 
-			return indexFirstFinished;
+			return first_finished;
 		}
 
 		static int CheckTimeout (TimeSpan timeout)

+ 109 - 8
mcs/class/corlib/Test/System.Threading.Tasks/TaskTest.cs

@@ -52,7 +52,7 @@ namespace MonoTests.System.Threading.Tasks
 			}
 		}
 		
-		[TestAttribute]
+		[Test]
 		public void WaitAnyTest()
 		{
 			ParallelTestHelper.Repeat (delegate {
@@ -69,15 +69,94 @@ namespace MonoTests.System.Threading.Tasks
 					}
 				});
 				
-				int index = Task.WaitAny(tasks);
+				int index = Task.WaitAny(tasks, 1000);
 				
 				Assert.AreNotEqual (-1, index, "#3");
 				Assert.AreEqual (1, flag, "#1");
 				Assert.AreEqual (1, finished, "#2");
-				
-				Task.WaitAll (tasks);
 			});
 		}
+
+		[Test]
+		public void WaitAny_Empty ()
+		{
+			Assert.AreEqual (-1, Task.WaitAny (new Task[0]));
+		}
+
+		[Test]
+		public void WaitAny_Zero ()
+		{
+			Assert.AreEqual (-1, Task.WaitAny (new Task[1] { new Task (delegate { })}, 0), "#1");
+			Assert.AreEqual (-1, Task.WaitAny (new Task[1] { new Task (delegate { }) }, 20), "#1");
+		}
+
+		[Test]
+		public void WaitAny_WithNull ()
+		{
+			var tasks = new [] {
+				Task.FromResult (2),
+				null
+			};
+
+			try {
+				Task.WaitAny (tasks);
+				Assert.Fail ();
+			} catch (ArgumentException) {
+			}
+		}
+
+		[Test]
+		public void WaitAny_Cancelled ()
+		{
+			var cancelation = new CancellationTokenSource ();
+			var tasks = new Task[] {
+				new Task (delegate { }),
+				new Task (delegate { }, cancelation.Token)
+			};
+
+			cancelation.Cancel ();
+
+			Assert.AreEqual (1, Task.WaitAny (tasks, 1000), "#1");
+			Assert.IsTrue (tasks[1].IsCompleted, "#2");
+			Assert.IsTrue (tasks[1].IsCanceled, "#3");
+		}
+
+		[Test]
+		public void WaitAny_CancelledWithoutExecution ()
+		{
+			var cancelation = new CancellationTokenSource ();
+			var tasks = new Task[] {
+				new Task (delegate { }),
+				new Task (delegate { })
+			};
+
+			int res = 0;
+			var mre = new ManualResetEventSlim (false);
+			ThreadPool.QueueUserWorkItem (delegate {
+				res = Task.WaitAny (tasks, 20);
+				mre.Set ();
+			});
+
+			cancelation.Cancel ();
+			Assert.IsTrue (mre.Wait (1000), "#1");
+			Assert.AreEqual (-1, res);
+		}
+
+		[Test]
+		public void WaitAny_OneException ()
+		{
+			var mre = new ManualResetEventSlim (false);
+			var tasks = new Task[] {
+				Task.Factory.StartNew (delegate { mre.Wait (1000); }),
+				Task.Factory.StartNew (delegate { throw new ApplicationException (); })
+			};
+
+			Assert.AreEqual (1, Task.WaitAny (tasks, 1000), "#1");
+			Assert.IsFalse (tasks[0].IsCompleted, "#2");
+			Assert.IsTrue (tasks[1].IsFaulted, "#3");
+
+			mre.Set ();
+		}
 		
 		[Test]
 		public void WaitAllTest()
@@ -90,6 +169,13 @@ namespace MonoTests.System.Threading.Tasks
 			});
 		}
 
+		[Test]
+		public void WaitAll_Zero ()
+		{
+			Assert.IsFalse (Task.WaitAll (new Task[1] { new Task (delegate { }) }, 0), "#0");
+			Assert.IsFalse (Task.WaitAll (new Task[1] { new Task (delegate { }) }, 10), "#1");
+		}
+
 		[Test]
 		public void WaitAllWithExceptions ()
 		{
@@ -152,6 +238,20 @@ namespace MonoTests.System.Threading.Tasks
 			Assert.IsTrue (tasks[1].IsCanceled, "#5");
 		}
 
+		[Test]
+		public void WaitAll_StartedUnderWait ()
+		{
+			var task1 = new Task (delegate { });
+
+			ThreadPool.QueueUserWorkItem (delegate {
+				// Sleep little to let task to start and hit internal wait
+				Thread.Sleep (20);
+				task1.Start ();
+			});
+
+			Assert.IsTrue (Task.WaitAll (new [] { task1 }, 1000), "#1");
+		}
+
 		[Test]
 		public void CancelBeforeStart ()
 		{
@@ -169,7 +269,7 @@ namespace MonoTests.System.Threading.Tasks
 		}
 
 		[Test]
-		public void CancelBeforeWait ()
+		public void Wait_CancelledTask ()
 		{
 			var src = new CancellationTokenSource ();
 
@@ -346,19 +446,20 @@ namespace MonoTests.System.Threading.Tasks
 		{
 			ParallelTestHelper.Repeat (delegate {
 				bool r1 = false, r2 = false, r3 = false;
+				var mre = new ManualResetEvent (false);
 				
 				Task t = Task.Factory.StartNew(delegate {
 					Task.Factory.StartNew(delegate {
-						Thread.Sleep(50);
 						r1 = true;
+						mre.Set ();
 					}, TaskCreationOptions.AttachedToParent);
 					Task.Factory.StartNew(delegate {
-						Thread.Sleep(300);
+						Assert.IsTrue (mre.WaitOne (1000), "#0");
 						
 						r2 = true;
 					}, TaskCreationOptions.AttachedToParent);
 					Task.Factory.StartNew(delegate {
-						Thread.Sleep(150);
+						Assert.IsTrue (mre.WaitOne (1000), "#0");
 						
 						r3 = true;
 					}, TaskCreationOptions.AttachedToParent);