瀏覽代碼

[PLinq] Add implementation of Take/TakeWhile operators (including corresponding query node and unit tests)

Jérémie Laval 15 年之前
父節點
當前提交
f00afa22f2

+ 147 - 0
mcs/class/System.Core/System.Linq.Parallel.QueryNodes/QueryHeadWorkerNode.cs

@@ -0,0 +1,147 @@
+//
+// QueryHeadWorkerNode.cs
+//
+// Author:
+//       Jérémie "Garuma" Laval <[email protected]>
+//
+// Copyright (c) 2010 Jérémie "Garuma" Laval
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+#if NET_4_0
+using System;
+using System.Threading;
+using System.Collections;
+using System.Collections.Generic;
+
+namespace System.Linq.Parallel.QueryNodes
+{
+	/* This is the QueryNode used by the Take and TakeWhile operators.
+	 * It represents operators that preferably work on the head elements of a query and that can stop
+	 * providing elements early, once the element being processed lies beyond a limit in a specific
+	 * order defined by the instance (e.g. simple numerical order when working on indexes).
+	 */
+	internal class QueryHeadWorkerNode<TSource> : QueryStreamNode<TSource, TSource>
+	{
+		/* Number of head elements to keep. Only meaningful when no predicate was supplied,
+		 * i.e. when the node was created through the count-based constructor.
+		 */
+		readonly int count;
+		/* Predicate receiving an element and its index. It stays null when the node was created
+		 * through the count-based constructor; the rest of the class uses that to tell the two modes apart.
+		 */
+		readonly Func<TSource, int, bool> predicate;
+
+		/* Creates a node keeping the first `count` elements of `parent` (Take). */
+		internal QueryHeadWorkerNode (QueryBaseNode<TSource> parent, int count)
+			: base (parent, false)
+		{
+			this.count = count;
+		}
+
+		/* Creates a node keeping elements according to `predicate` (TakeWhile).
+		 * `indexed` is forwarded as-is to the base class.
+		 */
+		internal QueryHeadWorkerNode (QueryBaseNode<TSource> parent, Func<TSource, int, bool> predicate, bool indexed)
+			: base (parent, indexed)
+		{
+			this.predicate = predicate;
+		}
+
+		/* The requested element count in count mode, null in predicate mode. */
+		internal int? Count {
+			get {
+				return predicate == null ? count : (int?)null;
+			}
+		}
+
+		/* Sequential fallback: defers to the regular Enumerable.Take/TakeWhile on the parent sequence. */
+		internal override IEnumerable<TSource> GetSequential ()
+		{
+			IEnumerable<TSource> parent = Parent.GetSequential ();
+
+			return predicate == null ? parent.Take (count) : parent.TakeWhile (predicate);
+		}
+
+		public override void Visit (INodeVisitor visitor)
+		{
+			visitor.Visit (this);
+		}
+
+		/* Each partition is cut at the first element for which the predicate, called with the
+		 * element and its key narrowed to int, returns false; the keys are then dropped.
+		 * NOTE(review): this dereferences `predicate` unconditionally, so it presumably is never
+		 * reached in count mode (that constructor passes `false` to the base) -- confirm.
+		 */
+		internal override IList<IEnumerable<TSource>> GetEnumerablesIndexed (QueryOptions options)
+		{
+			return Parent.GetOrderedEnumerables (options)
+				.Select ((i) => i.TakeWhile ((e) => predicate (e.Value, (int)e.Key)).Select ((e) => e.Value))
+				.ToList ();
+		}
+
+		/* Applies the selector built by GetSelector to every partition of the parent. */
+		internal override IList<IEnumerable<TSource>> GetEnumerablesNonIndexed (QueryOptions options)
+		{
+			return Parent.GetEnumerables (options)
+				.Select (GetSelector (count))
+				.ToList ();
+		}
+
+		/* Builds the per-partition filter used by GetEnumerablesNonIndexed.
+		 * In count mode the single captured counter `c` is shared by every partition the returned
+		 * delegate is applied to, and is decremented atomically so that at most `c` elements pass overall.
+		 * In predicate mode no index is available, so -1 is handed to the predicate.
+		 */
+		Func<IEnumerable<TSource>, IEnumerable<TSource>> GetSelector (int c)
+		{
+			if (predicate == null)
+				return (i) => i.TakeWhile ((e) => c > 0 && Interlocked.Decrement (ref c) >= 0);
+			else
+				return (i) => i.TakeWhile ((e) => predicate (e, -1));
+		}
+
+		/* Wraps every ordered partition of the parent with GetEnumerableInternal. */
+		internal override IList<IEnumerable<KeyValuePair<long, TSource>>> GetOrderedEnumerables (QueryOptions options)
+		{
+			return Parent.GetOrderedEnumerables (options)
+				.Select ((i) => GetEnumerableInternal (i, options))
+				.ToList ();
+		}
+
+		/* Filters one keyed partition against a "stop point" (gapIndex).
+		 * In count mode the stop point is `count` itself. In predicate mode it starts at long.MaxValue
+		 * and is lowered to the key of any element rejected by the predicate.
+		 * Elements whose key is at or beyond the stop point are never yielded; elements below it are
+		 * still tested and can lower the stop point further. The stop point is local to this partition.
+		 */
+		IEnumerable<KeyValuePair<long, TSource>> GetEnumerableInternal (IEnumerable<KeyValuePair<long, TSource>> source, QueryOptions options)
+		{
+			long gapIndex = predicate == null ? count : long.MaxValue;
+
+			foreach (KeyValuePair<long, TSource> kvp in source) {
+				if (kvp.Key >= gapIndex) {
+					/* Stopping the enumeration altogether is only valid if the partitioner
+					 * ensures items are ordered in each partition (valid w/ default partitioners).
+					 */
+					if (options.PartitionerSettings.Item2)
+						break;
+
+					/* Otherwise smaller keys may still follow: drop this element and keep
+					 * scanning instead of letting it through.
+					 */
+					continue;
+				}
+
+				/* Here kvp.Key < gapIndex, so a rejected element always lowers the stop point. */
+				if (predicate != null && !predicate (kvp.Value, (int)kvp.Key)) {
+					gapIndex = kvp.Key;
+					continue;
+				}
+
+				yield return kvp;
+			}
+		}
+	}
+
+}
+
+#endif

+ 3 - 4
mcs/class/System.Core/System.Linq/ParallelEnumerable.cs

@@ -906,13 +906,12 @@ namespace System.Linq
 		#endregion
 
 		#region Take
-		// TODO : introduce some early break up here, use ImplementerToken
 		public static ParallelQuery<TSource> Take<TSource> (this ParallelQuery<TSource> source, int count)
 		{
 			if (source == null)
 				throw new ArgumentNullException ("source");
 
-			return source.Where ((e, i) => i < count);
+			return new ParallelQuery<TSource> (new QueryHeadWorkerNode<TSource> (source.Node, count));
 		}
 
 		public static ParallelQuery<TSource> TakeWhile<TSource> (this ParallelQuery<TSource> source,
@@ -923,7 +922,7 @@ namespace System.Linq
 			if (predicate == null)
 				throw new ArgumentNullException ("predicate");
 
-			return source.Where ((e) => predicate (e));
+			return new ParallelQuery<TSource> (new QueryHeadWorkerNode<TSource> (source.Node, (e, _) => predicate (e), false));
 		}
 
 		public static ParallelQuery<TSource> TakeWhile<TSource> (this ParallelQuery<TSource> source,
@@ -934,7 +933,7 @@ namespace System.Linq
 			if (predicate == null)
 				throw new ArgumentNullException ("predicate");
 
-			return source.Where ((e, i) => predicate (e, i));
+			return new ParallelQuery<TSource> (new QueryHeadWorkerNode<TSource> (source.Node, predicate, true));
 		}
 		#endregion
 

+ 29 - 13
mcs/class/System.Core/Test/System.Linq/ParallelEnumerableTests.cs

@@ -313,7 +313,7 @@ namespace MonoTests.System.Linq
 			AssertAreSame (result, data.AsParallel ().AsOrdered ().SkipWhile (i => i < 3));
 		}
 
-		[Test, Ignore]
+		[Test]
 		public void TestTake ()
 		{
 			int [] data = {0, 1, 2, 3, 4, 5};
@@ -322,7 +322,7 @@ namespace MonoTests.System.Linq
 			AssertAreSame (result, data.AsParallel ().AsOrdered ().Take (3));
 		}
 
-		[Test, Ignore]
+		[Test]
 		public void TestTakeWhile ()
 		{
 			int [] data = {0, 1, 2, 3, 4, 5};
@@ -527,20 +527,36 @@ namespace MonoTests.System.Linq
 			});
 		}
 		
-		[TestAttribute, Ignore]
+		[TestAttribute]
 		public void TakeTestCase()
 		{
 			ParallelTestHelper.Repeat (() => {
-				ParallelQuery<int> async = baseEnumerable.AsParallel().Take(2000);
-				IEnumerable<int> sync = baseEnumerable.Take(2000);
-				
-				AreEquivalent(sync, async, 1);
-				
-				async = baseEnumerable.AsParallel().Take(100);
-				sync = baseEnumerable.Take(100);
-			
-				AreEquivalent(sync, async, 2);
-			}, 20);
+					ParallelQuery<int> async = baseEnumerable.AsParallel().AsOrdered ().Take(2000);
+					IEnumerable<int> sync = baseEnumerable.Take(2000);
+
+					AreEquivalent(sync, async, 1);
+
+					async = baseEnumerable.AsParallel().AsOrdered ().Take(100);
+					sync = baseEnumerable.Take(100);
+
+					AreEquivalent(sync, async, 2);
+				}, 20);
+		}
+
+		[TestAttribute]
+		public void UnorderedTakeTestCase()
+		{
+			// No AsOrdered () here: only the number of returned elements is compared
+			// against the sequential Take, not which elements come back.
+			ParallelTestHelper.Repeat (() => {
+					int expected = baseEnumerable.Take (2000).Count ();
+					int actual = baseEnumerable.AsParallel ().Take (2000).Count ();
+					Assert.AreEqual (expected, actual, "#1");
+
+					expected = baseEnumerable.Take (100).Count ();
+					actual = baseEnumerable.AsParallel ().Take (100).Count ();
+					Assert.AreEqual (expected, actual, "#2");
+				}, 20);
 		}
 		
 		[Test]

+ 1 - 0
mcs/class/System.Core/net_4_0_System.Core.dll.sources

@@ -247,6 +247,7 @@ System.Linq.Parallel.QueryNodes/QuerySetNode.cs
 System.Linq.Parallel.QueryNodes/QueryReverseNode.cs
 System.Linq.Parallel.QueryNodes/SetInclusion.cs
 System.Linq.Parallel.QueryNodes/WrapHelper.cs
+System.Linq.Parallel.QueryNodes/QueryHeadWorkerNode.cs
 System.Linq.Parallel/INodeVisitor.cs
 System.Linq.Parallel/IVisitableNode.cs
 System.Linq.Parallel/QueryCheckerVisitor.cs