SequenceQuery.cs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Collections.ObjectModel;
  5. using System.Linq.Expressions;
  6. using System.Reflection;
  7. using System.Text;
  8. // Include Silverlight's managed resources
  9. #if SILVERLIGHT
  10. using System.Core;
  11. #endif //SILVERLIGHT
  12. namespace System.Linq {
  // Non-generic base of EnumerableQuery<T>. Must remain public for Silverlight.
  public abstract class EnumerableQuery {
      // Expression tree that represents this query.
      internal abstract Expression Expression { get; }

      // In-memory sequence wrapped by this query; may be null when the query is
      // defined by an expression tree instead of a sequence.
      internal abstract IEnumerable Enumerable { get; }

      // Creates an EnumerableQuery<elementType> that wraps the given sequence.
      internal static IQueryable Create(Type elementType, IEnumerable sequence) {
          return CreateClosedQuery(elementType, sequence);
      }

      // Creates an EnumerableQuery<elementType> defined by the given expression tree.
      internal static IQueryable Create(Type elementType, Expression expression) {
          return CreateClosedQuery(elementType, expression);
      }

      // Closes EnumerableQuery<> over elementType and invokes its one-argument
      // constructor with ctorArg; Activator picks the constructor overload from
      // the runtime type of ctorArg.
      private static IQueryable CreateClosedQuery(Type elementType, object ctorArg) {
          Type queryType = typeof(EnumerableQuery<>).MakeGenericType(elementType);
          object[] ctorArgs = new object[] { ctorArg };
#if SILVERLIGHT
          object query = Activator.CreateInstance(queryType, ctorArgs);
#else
          BindingFlags ctorFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
          object query = Activator.CreateInstance(queryType, ctorFlags, null, ctorArgs, null);
#endif //SILVERLIGHT
          return (IQueryable)query;
      }
  }
  // Must remain public for Silverlight.
  // A queryable that runs against LINQ to Objects and acts as its own
  // IQueryProvider. It either wraps an in-memory sequence (its expression is a
  // constant node referring back to this instance) or is defined by an
  // expression tree that is rewritten to Enumerable calls and compiled the first
  // time it is enumerated.
  public class EnumerableQuery<T> : EnumerableQuery, IOrderedQueryable<T>, IQueryable, IQueryProvider, IEnumerable<T>, IEnumerable {
      // Expression tree representing this query.
      Expression expression;
      // Backing sequence. For expression-defined queries this stays null until
      // GetEnumerator evaluates the expression and caches the result here.
      IEnumerable<T> enumerable;

      // The query is its own provider.
      IQueryProvider IQueryable.Provider {
          get { return this; }
      }

      // Must remain public for Silverlight.
      // Wraps an in-memory sequence; the query's expression is a constant node
      // holding this instance.
      public EnumerableQuery(IEnumerable<T> enumerable) {
          this.enumerable = enumerable;
          this.expression = Expression.Constant(this);
      }

      // Must remain public for Silverlight.
      // Defines the query by an expression tree, which is evaluated on first enumeration.
      public EnumerableQuery(Expression expression) {
          this.expression = expression;
      }

      internal override Expression Expression {
          get { return this.expression; }
      }

      internal override IEnumerable Enumerable {
          get { return this.enumerable; }
      }

      Expression IQueryable.Expression {
          get { return this.expression; }
      }

      Type IQueryable.ElementType {
          get { return typeof(T); }
      }

      // Builds a query for an expression whose type is an IQueryable<X> (as located
      // by TypeHelper.FindGenericType); the new query's element type is X.
      // Throws when expression is null or no IQueryable<> is found for its type.
      IQueryable IQueryProvider.CreateQuery(Expression expression) {
          if (expression == null)
              throw Error.ArgumentNull("expression");
          Type iqType = TypeHelper.FindGenericType(typeof(IQueryable<>), expression.Type);
          if (iqType == null)
              throw Error.ArgumentNotValid("expression");
          return EnumerableQuery.Create(iqType.GetGenericArguments()[0], expression);
      }

      // Strongly typed CreateQuery: expression's type must be assignable to IQueryable<S>.
      IQueryable<S> IQueryProvider.CreateQuery<S>(Expression expression) {
          if (expression == null)
              throw Error.ArgumentNull("expression");
          if (!typeof(IQueryable<S>).IsAssignableFrom(expression.Type)) {
              throw Error.ArgumentNotValid("expression");
          }
          return new EnumerableQuery<S>(expression);
      }

      // Security note: this was baselined as Safe for the Mix demo so that the
      // interface could stay transparent. Marking it critical (the original
      // annotation when porting to Silverlight) would violate FxCop security rules
      // unless the interface were critical too, and transparent code could not
      // reach it anyway while AsQueryable() was not exposed. That assertion no
      // longer holds: AsQueryable() is public again, and the security fallout
      // still needs to be re-examined.
      //
      // Executes expression and returns its (boxed) result. Throws when expression is null.
      object IQueryProvider.Execute(Expression expression) {
          if (expression == null)
              throw Error.ArgumentNull("expression");
          // EnumerableExecutor.Create closes EnumerableExecutor<> over expression.Type itself.
          return EnumerableExecutor.Create(expression).ExecuteBoxed();
      }

      // Same security note as the non-generic Execute above.
      // Executes expression as a value of type S; its type must be assignable to S.
      S IQueryProvider.Execute<S>(Expression expression) {
          if (expression == null)
              throw Error.ArgumentNull("expression");
          if (!typeof(S).IsAssignableFrom(expression.Type))
              throw Error.ArgumentNotValid("expression");
          return new EnumerableExecutor<S>(expression).Execute();
      }

      IEnumerator IEnumerable.GetEnumerator() {
          return this.GetEnumerator();
      }

      IEnumerator<T> IEnumerable<T>.GetEnumerator() {
          return this.GetEnumerator();
      }

      // Enumerates the backing sequence. An expression-defined query is first
      // rewritten to LINQ to Objects, compiled and evaluated; the resulting
      // sequence is cached in this.enumerable for later enumerations.
      private IEnumerator<T> GetEnumerator() {
          if (this.enumerable == null) {
              EnumerableRewriter rewriter = new EnumerableRewriter();
              Expression body = rewriter.Visit(this.expression);
              Expression<Func<IEnumerable<T>>> f = Expression.Lambda<Func<IEnumerable<T>>>(body, (IEnumerable<ParameterExpression>)null);
              this.enumerable = f.Compile()();
          }
          return this.enumerable.GetEnumerator();
      }

      // A query that wraps a sequence is described by that sequence ("null" when
      // there is none); otherwise it is described by its expression tree.
      public override string ToString() {
          ConstantExpression c = this.expression as ConstantExpression;
          if (c != null && c.Value == this) {
              if (this.enumerable != null)
                  return this.enumerable.ToString();
              return "null";
          }
          return this.expression.ToString();
      }
  }
  // Must remain public for Silverlight.
  // Non-generic base of EnumerableExecutor<T>, used when the result type of an
  // expression is only known at runtime.
  public abstract class EnumerableExecutor {
      // Runs the expression and returns its result as an object.
      internal abstract object ExecuteBoxed();

      // Creates an EnumerableExecutor<expression.Type> for the given expression.
      internal static EnumerableExecutor Create(Expression expression) {
          Type executorType = typeof(EnumerableExecutor<>).MakeGenericType(expression.Type);
          object[] ctorArgs = new object[] { expression };
#if SILVERLIGHT
          object executor = Activator.CreateInstance(executorType, ctorArgs);
#else
          BindingFlags ctorFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
          object executor = Activator.CreateInstance(executorType, ctorFlags, null, ctorArgs, null);
#endif //SILVERLIGHT
          return (EnumerableExecutor)executor;
      }
  }
  // Must remain public for Silverlight.
  // Evaluates an expression tree that yields a T by rewriting it to LINQ to
  // Objects and compiling it; the compiled delegate is cached after first use.
  public class EnumerableExecutor<T> : EnumerableExecutor {
      // Expression to evaluate.
      Expression expression;
      // Compiled form of expression; built on the first call to Execute.
      Func<T> func;

      // Must remain public for Silverlight.
      public EnumerableExecutor(Expression expression) {
          this.expression = expression;
      }

      internal override object ExecuteBoxed() {
          return this.Execute();
      }

      // Evaluates the expression, compiling it first if this is the first call.
      internal T Execute() {
          if (this.func == null) {
              this.func = CompileRewritten(this.expression);
          }
          return this.func();
      }

      // Rewrites source to LINQ to Objects and compiles it into a parameterless delegate.
      private static Func<T> CompileRewritten(Expression source) {
          Expression rewritten = new EnumerableRewriter().Visit(source);
          Expression<Func<T>> lambda = Expression.Lambda<Func<T>>(rewritten, (IEnumerable<ParameterExpression>)null);
          return lambda.Compile();
      }
  }
  // Rewrites an expression tree written against Queryable so that it can run as
  // LINQ to Objects. Queryable method calls are rebound to the matching Enumerable
  // methods. Constants holding an EnumerableQuery are replaced by the sequence it
  // wraps, or by that query's own rewritten expression. Lambda bodies and
  // parameters are returned as-is without being visited.
  internal class EnumerableRewriter : OldExpressionVisitor {
      internal EnumerableRewriter() {
      }

      // Rebuilds a call whose receiver or arguments changed during rewriting,
      // rebinding it to a compatible method when the original no longer fits.
      internal override Expression VisitMethodCall(MethodCallExpression m) {
          Expression obj = this.Visit(m.Object);
          ReadOnlyCollection<Expression> args = this.VisitExpressionList(m.Arguments);

          // Nothing beneath this call changed, so the original node is still valid.
          if (obj == m.Object && args == m.Arguments)
              return m;

          MethodInfo method = m.Method;
          Type[] typeArgs = method.IsGenericMethod ? method.GetGenericArguments() : null;

          // The current method still accepts the rewritten receiver and arguments.
          if ((method.IsStatic || method.DeclaringType.IsAssignableFrom(obj.Type))
              && ArgsMatch(method, args, typeArgs)) {
              return Expression.Call(obj, method, args);
          }

          MethodInfo target;
          if (method.DeclaringType == typeof(Queryable)) {
              // Convert the Queryable method to its Enumerable counterpart.
              target = FindEnumerableMethod(method.Name, args, typeArgs);
          }
          else {
              // Rebind to a same-named static method on the declaring type, with the
              // same public/non-public visibility, that accepts the new arguments.
              BindingFlags flags = BindingFlags.Static | (method.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic);
              target = FindMethod(method.DeclaringType, method.Name, args, typeArgs, flags);
          }
          // Unwrap quoted arguments where the new method's parameters do not take them.
          args = this.FixupQuotedArgs(target, args);
          return Expression.Call(obj, target, args);
      }

      // Adapts each argument to the corresponding parameter of mi via
      // FixupQuotedExpression. The original collection is returned when no
      // argument changes.
      private ReadOnlyCollection<Expression> FixupQuotedArgs(MethodInfo mi, ReadOnlyCollection<Expression> argList) {
          ParameterInfo[] pis = mi.GetParameters();
          if (pis.Length > 0) {
              // Allocated only once the first argument actually changes.
              List<Expression> newArgs = null;
              for (int i = 0, n = pis.Length; i < n; i++) {
                  Expression arg = FixupQuotedExpression(pis[i].ParameterType, argList[i]);
                  if (newArgs == null && arg != argList[i]) {
                      newArgs = new List<Expression>(argList.Count);
                      for (int j = 0; j < i; j++) {
                          newArgs.Add(argList[j]);
                      }
                  }
                  if (newArgs != null) {
                      newArgs.Add(arg);
                  }
              }
              if (newArgs != null)
                  argList = newArgs.ToReadOnlyCollection();
          }
          return argList;
      }

      // Makes expression fit a parameter of the given type. It strips Quote nodes
      // until the expression is assignable. It also rebuilds an array initializer
      // whose element type is Expression<TDelegate> into an array of TDelegate,
      // fixing up each element. If neither applies, the original expression is
      // returned unchanged.
      private Expression FixupQuotedExpression(Type type, Expression expression) {
          Expression expr = expression;
          while (true) {
              if (type.IsAssignableFrom(expr.Type))
                  return expr;
              if (expr.NodeType != ExpressionType.Quote)
                  break;
              expr = ((UnaryExpression)expr).Operand;
          }
          // The loop only exits here when expr is not assignable to type.
          if (type.IsArray && expr.NodeType == ExpressionType.NewArrayInit) {
              Type strippedType = StripExpression(expr.Type);
              if (type.IsAssignableFrom(strippedType)) {
                  Type elementType = type.GetElementType();
                  NewArrayExpression na = (NewArrayExpression)expr;
                  List<Expression> exprs = new List<Expression>(na.Expressions.Count);
                  for (int i = 0, n = na.Expressions.Count; i < n; i++) {
                      exprs.Add(this.FixupQuotedExpression(elementType, na.Expressions[i]));
                  }
                  expression = Expression.NewArrayInit(elementType, exprs);
              }
          }
          return expression;
      }

      // Lambdas are left intact; their bodies are not rewritten.
      internal override Expression VisitLambda(LambdaExpression lambda) {
          return lambda;
      }

      private static Type GetPublicType(Type t)
      {
          // If we create a constant explicitly typed to be a private nested type,
          // such as Lookup<,>.Grouping or a compiler-generated iterator class, then
          // we cannot use the expression tree in a context which has only execution
          // permissions. We should endeavour to translate constants into
          // new constants which have public types.
          if (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Lookup<,>.Grouping))
              return typeof(IGrouping<,>).MakeGenericType(t.GetGenericArguments());
          if (!t.IsNestedPrivate)
              return t;
          foreach (Type iType in t.GetInterfaces())
          {
              if (iType.IsGenericType && iType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
                  return iType;
          }
          if (typeof(IEnumerable).IsAssignableFrom(t))
              return typeof(IEnumerable);
          return t;
      }

      // Replaces a constant EnumerableQuery with the sequence it wraps, typed with a
      // public type, or with its rewritten expression when it wraps no sequence.
      internal override Expression VisitConstant(ConstantExpression c) {
          EnumerableQuery sq = c.Value as EnumerableQuery;
          if (sq != null) {
              if (sq.Enumerable != null)
              {
                  Type t = GetPublicType(sq.Enumerable.GetType());
                  return Expression.Constant(sq.Enumerable, t);
              }
              return this.Visit(sq.Expression);
          }
          return c;
      }

      internal override Expression VisitParameter(ParameterExpression p) {
          return p;
      }

      // Public static Enumerable methods grouped by name. Built lazily without a
      // lock, so concurrent first calls may each build it; the last write wins.
      private static volatile ILookup<string, MethodInfo> _seqMethods;

      // Finds the public static Enumerable method named name that accepts args,
      // closed over typeArgs when they are given. Throws when none matches.
      static MethodInfo FindEnumerableMethod(string name, ReadOnlyCollection<Expression> args, params Type[] typeArgs) {
          if (_seqMethods == null) {
              _seqMethods = typeof(Enumerable).GetMethods(BindingFlags.Static|BindingFlags.Public).ToLookup(m => m.Name);
          }
          MethodInfo mi = _seqMethods[name].FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
          if (mi == null)
              throw Error.NoMethodOnTypeMatchingArguments(name, typeof(Enumerable));
          if (typeArgs != null)
              return mi.MakeGenericMethod(typeArgs);
          return mi;
      }

      // Finds the first method on type selected by flags that is named name and
      // accepts args, closed over typeArgs when they are given. Throws when no
      // method has that name or none of them matches.
      internal static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<Expression> args, Type[] typeArgs, BindingFlags flags) {
          MethodInfo[] methods = type.GetMethods(flags).Where(m => m.Name == name).ToArray();
          if (methods.Length == 0)
              throw Error.NoMethodOnType(name, type);
          MethodInfo mi = methods.FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
          if (mi == null)
              throw Error.NoMethodOnTypeMatchingArguments(name, type);
          if (typeArgs != null)
              return mi.MakeGenericMethod(typeArgs);
          return mi;
      }

      // Returns true if m, closed over typeArgs when it is generic, can take args.
      // An argument also matches when it is a Quote whose operand fits the
      // parameter, or when StripExpression of its type fits the parameter.
      private static bool ArgsMatch(MethodInfo m, ReadOnlyCollection<Expression> args, Type[] typeArgs) {
          ParameterInfo[] mParams = m.GetParameters();
          if (mParams.Length != args.Count)
              return false;
          if (!m.IsGenericMethod && typeArgs != null && typeArgs.Length > 0) {
              return false;
          }
          if (!m.IsGenericMethodDefinition && m.IsGenericMethod && m.ContainsGenericParameters) {
              m = m.GetGenericMethodDefinition();
          }
          if (m.IsGenericMethodDefinition) {
              if (typeArgs == null || typeArgs.Length == 0)
                  return false;
              if (m.GetGenericArguments().Length != typeArgs.Length)
                  return false;
              m = m.MakeGenericMethod(typeArgs);
              mParams = m.GetParameters();
          }
          for (int i = 0, n = args.Count; i < n; i++) {
              Type parameterType = mParams[i].ParameterType;
              if (parameterType == null)
                  return false;
              if (parameterType.IsByRef)
                  parameterType = parameterType.GetElementType();
              Expression arg = args[i];
              if (!parameterType.IsAssignableFrom(arg.Type)) {
                  if (arg.NodeType == ExpressionType.Quote) {
                      arg = ((UnaryExpression)arg).Operand;
                  }
                  if (!parameterType.IsAssignableFrom(arg.Type) &&
                      !parameterType.IsAssignableFrom(StripExpression(arg.Type))) {
                      return false;
                  }
              }
          }
          return true;
      }

      // For an array type, returns an array with the same rank whose element type
      // has Expression<TDelegate> replaced by TDelegate. Other element types are
      // kept as-is.
      //
      // Non-array types are returned unchanged. FixupQuotedExpression only unwraps
      // Quote nodes and array initializers, so stripping a bare Expression<TDelegate>
      // here would let ArgsMatch accept an argument that could not then be passed.
      private static Type StripExpression(Type type) {
          if (!type.IsArray)
              return type;
          Type elementType = type.GetElementType();
          Type eType = TypeHelper.FindGenericType(typeof(Expression<>), elementType);
          if (eType != null)
              elementType = eType.GetGenericArguments()[0];
          int rank = type.GetArrayRank();
          return (rank == 1) ? elementType.MakeArrayType() : elementType.MakeArrayType(rank);
      }
  }
  347. }