浏览代码

[PLinq] Include Cancellation checks deeper in the pipeline processing

Jérémie Laval 15 年之前
父节点
当前提交
0da162b9f2

+ 1 - 1
mcs/class/System.Core/System.Linq.Parallel.QueryNodes/QueryGroupByNode.cs

@@ -76,7 +76,7 @@ namespace System.Linq.Parallel.QueryNodes
 		{
 			var store = new ConcurrentDictionary<TKey, ConcurrentQueue<TElement>> ();
 
-			ParallelExecuter.ProcessAndBlock (Parent, (e) => {
+			ParallelExecuter.ProcessAndBlock (Parent, (e, c) => {
 					ConcurrentQueue<TElement> queue = store.GetOrAdd (keySelector (e), (_) => new ConcurrentQueue<TElement> ());
 					queue.Enqueue (elementSelector (e));
 				});

+ 1 - 1
mcs/class/System.Core/System.Linq.Parallel/OrderingEnumerator.cs

@@ -191,7 +191,7 @@ namespace System.Linq.Parallel
 			}
 		}
 		
-		public void Add (KeyValuePair<long, T> value)
+		public void Add (KeyValuePair<long, T> value, CancellationToken token)
 		{
 			slotBucket.Add (value);
 		}

+ 14 - 22
mcs/class/System.Core/System.Linq.Parallel/ParallelExecuter.cs

@@ -79,34 +79,26 @@ namespace System.Linq.Parallel
 			return blocking ? Environment.ProcessorCount + 1 : Environment.ProcessorCount;
 		}
 
-		internal static Task[] Process<TSource, TElement> (QueryBaseNode<TSource> node, Action<TElement> call,
+		internal static Task[] Process<TSource, TElement> (QueryBaseNode<TSource> node,
+		                                                   Action<TElement, CancellationToken> call,
 		                                                   Func<QueryBaseNode<TSource>, QueryOptions, IList<IEnumerable<TElement>>> acquisitionFunc,
 		                                                   QueryOptions options)
 		{
 			return Process<TSource, TElement> (node, call, acquisitionFunc, null, options);
 		}
-		
-		internal static Task[] Process<TSource, TElement> (QueryBaseNode<TSource> node, Action<TElement> call,
-		                                                   Func<QueryBaseNode<TSource>, QueryOptions, IList<IEnumerable<TElement>>> acquisitionFunc,
-		                                                   Action endAction,
-		                                                   QueryOptions options)
-		{
-			return Process<TSource, TElement> (node,
-			                                   (e, i) => call (e),
-			                                   acquisitionFunc,
-			                                   endAction == null ? ((Action<int>)null) : (i) => endAction (),
-			                                   options);
-		}
 
-		internal static Task[] Process<TSource, TElement> (QueryBaseNode<TSource> node, Action<TElement, int> call,
+		internal static Task[] Process<TSource, TElement> (QueryBaseNode<TSource> node,
+		                                                   Action<TElement, CancellationToken> call,
 		                                                   Func<QueryBaseNode<TSource>, QueryOptions, IList<IEnumerable<TElement>>> acquisitionFunc,
-		                                                   Action<int> endAction,
+		                                                   Action endAction,
 		                                                   QueryOptions options)
 		{
 			IList<IEnumerable<TElement>> enumerables = acquisitionFunc (node, options);
 
 			Task[] tasks = new Task[enumerables.Count];
-			
+			CancellationTokenSource src
+				= CancellationTokenSource.CreateLinkedTokenSource (options.ImplementerToken, options.Token);
+
 			for (int i = 0; i < tasks.Length; i++) {
 				int index = i;
 				tasks[i] = Task.Factory.StartNew (() => {
@@ -117,17 +109,17 @@ namespace System.Linq.Parallel
 						if (options.Token.IsCancellationRequested)
 							throw new OperationCanceledException (options.Token);
 
-						call (item, index);
+						call (item, src.Token);
 					}
 					if (endAction != null)
-						endAction (index);
+						endAction ();
 				  }, options.Token);
 			}
 
 			return tasks;
 		}
 
-		internal static void ProcessAndBlock<T> (QueryBaseNode<T> node, Action<T> call)
+		internal static void ProcessAndBlock<T> (QueryBaseNode<T> node, Action<T, CancellationToken> call)
 		{
 			QueryOptions options = CheckQuery (node, true);
 
@@ -135,7 +127,7 @@ namespace System.Linq.Parallel
 			Task.WaitAll (tasks, options.Token);
 		}
 
-		internal static Action ProcessAndCallback<T> (QueryBaseNode<T> node, Action<T> call,
+		internal static Action ProcessAndCallback<T> (QueryBaseNode<T> node, Action<T, CancellationToken> call,
 		                                              Action callback, QueryOptions options)
 		{
 			Task[] tasks = Process (node, call, (n, o) => n.GetEnumerables (o), options);
@@ -145,8 +137,8 @@ namespace System.Linq.Parallel
 			return () => Task.WaitAll (tasks, options.Token);
 		}
 
-		internal static Action ProcessAndCallback<T> (QueryBaseNode<T> node, Action<KeyValuePair<long, T>, int> call,
-		                                              Action<int> endAction,
+		internal static Action ProcessAndCallback<T> (QueryBaseNode<T> node, Action<KeyValuePair<long, T>, CancellationToken> call,
+		                                              Action endAction,
 		                                              Action callback, QueryOptions options)
 		{
 			Task[] tasks = Process (node, call, (n, o) => n.GetOrderedEnumerables (o), endAction, options);

+ 4 - 4
mcs/class/System.Core/System.Linq.Parallel/ParallelQueryEnumerator.cs

@@ -55,8 +55,8 @@ namespace System.Linq.Parallel
 			// Launch adding to the buffer asynchronously via Tasks
 			if (options.BehindOrderGuard.Value) {
 				waitAction = ParallelExecuter.ProcessAndCallback (node,
-				                                                  (e, i) => ordEnumerator.Add (e),
-				                                                  (i) => ordEnumerator.EndParticipation (),
+				                                                  ordEnumerator.Add,
+				                                                  ordEnumerator.EndParticipation,
 				                                                  ordEnumerator.Stop,
 				                                                  options);
 			} else {
@@ -80,8 +80,8 @@ namespace System.Linq.Parallel
 					buffer = new BlockingCollection<T> (DefaultBufferSize);
 				}
 
-				IEnumerable<T> source = buffer.GetConsumingEnumerable (options.Token);
-
+				var src = CancellationTokenSource.CreateLinkedTokenSource (options.Token, options.ImplementerToken);
+				IEnumerable<T> source = buffer.GetConsumingEnumerable (src.Token);
 				loader = source.GetEnumerator ();
 			} else {
 				loader = ordEnumerator = new OrderingEnumerator<T> (options.PartitionCount);

+ 22 - 11
mcs/class/System.Core/System.Linq/ParallelEnumerable.cs

@@ -355,7 +355,7 @@ namespace System.Linq
 			if (action == null)
 				throw new ArgumentNullException ("action");
 
-			ParallelExecuter.ProcessAndBlock (source.Node, action);
+			ParallelExecuter.ProcessAndBlock (source.Node, (e, c) => action (e));
 		}
 		#endregion
 
@@ -463,12 +463,17 @@ namespace System.Linq
 			ParallelQuery<TSource> innerQuery = source.WithImplementerToken (src);
 
 			bool result = true;
-			innerQuery.ForAll ((e) => {
-				if (!predicate (e)) {
-					result = false;
-					src.Cancel ();
-				}
-			});
+			try {
+				innerQuery.ForAll ((e) => {
+						if (!predicate (e)) {
+							result = false;
+							src.Cancel ();
+						}
+					});
+			} catch (OperationCanceledException e) {
+				if (e.CancellationToken != src.Token)
+					throw e;
+			}
 
 			return result;
 		}
@@ -537,10 +542,15 @@ namespace System.Linq
 
 			bool result = true;
 
-			innerQuery.ForAll ((value) => {
-				result = false;
-				source.Cancel ();
-			});
+			try {
+				innerQuery.ForAll ((value) => {
+						result = false;
+						source.Cancel ();
+					});
+			} catch (OperationCanceledException e) {
+				if (e.CancellationToken != source.Token)
+					throw e;
+			}
 
 			return result;
 		}
@@ -2060,6 +2070,7 @@ namespace System.Linq
 			
 			TSource result = enumerator.Current;
 			src.Cancel ();
+			enumerator.Dispose ();
 			
 			return result;
 		}