// MethodCallConverter.cs
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using System.Linq.Expressions;
  5. using System.Reflection;
  6. using System.Data.Linq;
  7. using System.Data.Linq.Provider;
  8. using System.Data.Linq.Mapping;
  9. namespace System.Data.Linq.SqlClient {
  10. // convert special method calls and member accesses into known sql nodes
  11. internal static class PreBindDotNetConverter {
  12. internal static SqlNode Convert(SqlNode node, SqlFactory sql, MetaModel model) {
  13. return new Visitor(sql, model).Visit(node);
  14. }
  15. internal static bool CanConvert(SqlNode node) {
  16. SqlBinary bo = node as SqlBinary;
  17. if (bo != null && (IsCompareToValue(bo) || IsVbCompareStringEqualsValue(bo))) {
  18. return true;
  19. }
  20. SqlMember sm = node as SqlMember;
  21. if (sm != null && IsSupportedMember(sm)) {
  22. return true;
  23. }
  24. SqlMethodCall mc = node as SqlMethodCall;
  25. if (mc != null && (IsSupportedMethod(mc) || IsSupportedVbHelperMethod(mc))) {
  26. return true;
  27. }
  28. return false;
  29. }
  30. private static bool IsCompareToValue(SqlBinary bo) {
  31. if (IsComparison(bo.NodeType)
  32. && bo.Left.NodeType == SqlNodeType.MethodCall
  33. && bo.Right.NodeType == SqlNodeType.Value) {
  34. SqlMethodCall call = (SqlMethodCall)bo.Left;
  35. return IsCompareToMethod(call) || IsCompareMethod(call);
  36. }
  37. return false;
  38. }
  39. private static bool IsCompareToMethod(SqlMethodCall call) {
  40. return !call.Method.IsStatic && call.Method.Name == "CompareTo" && call.Arguments.Count == 1 && call.Method.ReturnType == typeof(int);
  41. }
  42. private static bool IsCompareMethod(SqlMethodCall call) {
  43. return call.Method.IsStatic && call.Method.Name == "Compare" && call.Arguments.Count > 1 && call.Method.ReturnType == typeof(int);
  44. }
  45. private static bool IsComparison(SqlNodeType nodeType) {
  46. switch (nodeType) {
  47. case SqlNodeType.EQ:
  48. case SqlNodeType.NE:
  49. case SqlNodeType.LT:
  50. case SqlNodeType.LE:
  51. case SqlNodeType.GT:
  52. case SqlNodeType.GE:
  53. case SqlNodeType.EQ2V:
  54. case SqlNodeType.NE2V:
  55. return true;
  56. default:
  57. return false;
  58. }
  59. }
  60. private static bool IsVbCompareStringEqualsValue(SqlBinary bo) {
  61. return IsComparison(bo.NodeType)
  62. && bo.Left.NodeType == SqlNodeType.MethodCall
  63. && bo.Right.NodeType == SqlNodeType.Value
  64. && IsVbCompareString((SqlMethodCall)bo.Left);
  65. }
  66. private static bool IsVbCompareString(SqlMethodCall call) {
  67. return call.Method.IsStatic &&
  68. call.Method.DeclaringType.FullName == "Microsoft.VisualBasic.CompilerServices.Operators" &&
  69. call.Method.Name == "CompareString";
  70. }
  71. private static bool IsSupportedVbHelperMethod(SqlMethodCall mc) {
  72. return IsVbIIF(mc);
  73. }
  74. private static bool IsVbIIF(SqlMethodCall mc) {
  75. return mc.Method.IsStatic &&
  76. mc.Method.DeclaringType.FullName == "Microsoft.VisualBasic.Interaction" && mc.Method.Name == "IIf";
  77. }
  78. private static bool IsSupportedMember(SqlMember m) {
  79. return IsNullableHasValue(m) || IsNullableHasValue(m);
  80. }
  81. private static bool IsNullableValue(SqlMember m) {
  82. return TypeSystem.IsNullableType(m.Expression.ClrType) && m.Member.Name == "Value";
  83. }
  84. private static bool IsNullableHasValue(SqlMember m) {
  85. return TypeSystem.IsNullableType(m.Expression.ClrType) && m.Member.Name == "HasValue";
  86. }
  87. private static bool IsSupportedMethod(SqlMethodCall mc) {
  88. if (mc.Method.IsStatic) {
  89. switch (mc.Method.Name) {
  90. case "op_Equality":
  91. case "op_Inequality":
  92. case "op_LessThan":
  93. case "op_LessThanOrEqual":
  94. case "op_GreaterThan":
  95. case "op_GreaterThanOrEqual":
  96. case "op_Multiply":
  97. case "op_Division":
  98. case "op_Subtraction":
  99. case "op_Addition":
  100. case "op_Modulus":
  101. case "op_BitwiseAnd":
  102. case "op_BitwiseOr":
  103. case "op_ExclusiveOr":
  104. case "op_UnaryNegation":
  105. case "op_OnesComplement":
  106. case "op_False":
  107. return true;
  108. case "Equals":
  109. return mc.Arguments.Count == 2;
  110. case "Concat":
  111. return mc.Method.DeclaringType == typeof(string);
  112. }
  113. }
  114. else {
  115. return mc.Method.Name == "Equals" && mc.Arguments.Count == 1 ||
  116. mc.Method.Name == "GetType" && mc.Arguments.Count == 0;
  117. }
  118. return false;
  119. }
  120. private class Visitor : SqlVisitor {
  121. SqlFactory sql;
  122. MetaModel model;
  123. internal Visitor(SqlFactory sql, MetaModel model) {
  124. this.sql = sql;
  125. this.model = model;
  126. }
  127. internal override SqlExpression VisitBinaryOperator(SqlBinary bo) {
  128. if (IsCompareToValue(bo)) {
  129. SqlMethodCall call = (SqlMethodCall)bo.Left;
  130. if (IsCompareToMethod(call)) {
  131. int iValue = System.Convert.ToInt32(this.Eval(bo.Right), Globalization.CultureInfo.InvariantCulture);
  132. bo = this.MakeCompareTo(call.Object, call.Arguments[0], bo.NodeType, iValue) ?? bo;
  133. }
  134. else if (IsCompareMethod(call)) {
  135. int iValue = System.Convert.ToInt32(this.Eval(bo.Right), Globalization.CultureInfo.InvariantCulture);
  136. bo = this.MakeCompareTo(call.Arguments[0], call.Arguments[1], bo.NodeType, iValue) ?? bo;
  137. }
  138. }
  139. else if (IsVbCompareStringEqualsValue(bo)) {
  140. SqlMethodCall call = (SqlMethodCall)bo.Left;
  141. int iValue = System.Convert.ToInt32(this.Eval(bo.Right), Globalization.CultureInfo.InvariantCulture);
  142. //in VB, comparing a string with Nothing means comparing with ""
  143. SqlValue strValue = call.Arguments[1] as SqlValue;
  144. if (strValue != null && strValue.Value == null) {
  145. SqlValue emptyStr = new SqlValue(strValue.ClrType, strValue.SqlType, String.Empty, strValue.IsClientSpecified, strValue.SourceExpression);
  146. bo = this.MakeCompareTo(call.Arguments[0], emptyStr, bo.NodeType, iValue) ?? bo;
  147. }
  148. else {
  149. bo = this.MakeCompareTo(call.Arguments[0], call.Arguments[1], bo.NodeType, iValue) ?? bo;
  150. }
  151. }
  152. return base.VisitBinaryOperator(bo);
  153. }
  154. private SqlBinary MakeCompareTo(SqlExpression left, SqlExpression right, SqlNodeType op, int iValue) {
  155. if (iValue == 0) {
  156. return sql.Binary(op, left, right);
  157. }
  158. else if (op == SqlNodeType.EQ || op == SqlNodeType.EQ2V) {
  159. switch (iValue) {
  160. case -1:
  161. return sql.Binary(SqlNodeType.LT, left, right);
  162. case 1:
  163. return sql.Binary(SqlNodeType.GT, left, right);
  164. }
  165. }
  166. return null;
  167. }
  168. private SqlExpression CreateComparison(SqlExpression a, SqlExpression b, Expression source) {
  169. SqlExpression lower = sql.Binary(SqlNodeType.LT, a, b);
  170. SqlExpression equal = sql.Binary(SqlNodeType.EQ2V, a, b);
  171. return sql.SearchedCase(
  172. new SqlWhen[] {
  173. new SqlWhen(lower, sql.ValueFromObject(-1, false, source)),
  174. new SqlWhen(equal, sql.ValueFromObject(0, false, source)),
  175. },
  176. sql.ValueFromObject(1, false, source), source
  177. );
  178. }
  179. internal override SqlNode VisitMember(SqlMember m) {
  180. m.Expression = this.VisitExpression(m.Expression);
  181. if (IsNullableValue(m)) {
  182. return sql.UnaryValueOf(m.Expression, m.SourceExpression);
  183. }
  184. else if (IsNullableHasValue(m)) {
  185. return sql.Unary(SqlNodeType.IsNotNull, m.Expression, m.SourceExpression);
  186. }
  187. return m;
  188. }
  189. internal override SqlExpression VisitMethodCall(SqlMethodCall mc) {
  190. mc.Object = this.VisitExpression(mc.Object);
  191. for (int i = 0, n = mc.Arguments.Count; i < n; i++) {
  192. mc.Arguments[i] = this.VisitExpression(mc.Arguments[i]);
  193. }
  194. if (mc.Method.IsStatic) {
  195. if (mc.Method.Name == "Equals" && mc.Arguments.Count == 2) {
  196. return sql.Binary(SqlNodeType.EQ2V, mc.Arguments[0], mc.Arguments[1], mc.Method);
  197. }
  198. else if (mc.Method.DeclaringType == typeof(string) && mc.Method.Name == "Concat") {
  199. SqlClientArray arr = mc.Arguments[0] as SqlClientArray;
  200. List<SqlExpression> exprs = null;
  201. if (arr != null) {
  202. exprs = arr.Expressions;
  203. }
  204. else {
  205. exprs = mc.Arguments;
  206. }
  207. if (exprs.Count == 0) {
  208. return sql.ValueFromObject("", false, mc.SourceExpression);
  209. }
  210. else {
  211. SqlExpression sum;
  212. if (exprs[0].SqlType.IsString || exprs[0].SqlType.IsChar) {
  213. sum = exprs[0];
  214. }
  215. else {
  216. sum = sql.ConvertTo(typeof(string), exprs[0]);
  217. }
  218. for (int i = 1; i < exprs.Count; i++) {
  219. if (exprs[i].SqlType.IsString || exprs[i].SqlType.IsChar) {
  220. sum = sql.Concat(sum, exprs[i]);
  221. }
  222. else {
  223. sum = sql.Concat(sum, sql.ConvertTo(typeof(string), exprs[i]));
  224. }
  225. }
  226. return sum;
  227. }
  228. }
  229. else if (IsVbIIF(mc)) {
  230. return TranslateVbIIF(mc);
  231. }
  232. else {
  233. switch (mc.Method.Name) {
  234. case "op_Equality":
  235. return sql.Binary(SqlNodeType.EQ, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  236. case "op_Inequality":
  237. return sql.Binary(SqlNodeType.NE, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  238. case "op_LessThan":
  239. return sql.Binary(SqlNodeType.LT, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  240. case "op_LessThanOrEqual":
  241. return sql.Binary(SqlNodeType.LE, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  242. case "op_GreaterThan":
  243. return sql.Binary(SqlNodeType.GT, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  244. case "op_GreaterThanOrEqual":
  245. return sql.Binary(SqlNodeType.GE, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  246. case "op_Multiply":
  247. return sql.Binary(SqlNodeType.Mul, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  248. case "op_Division":
  249. return sql.Binary(SqlNodeType.Div, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  250. case "op_Subtraction":
  251. return sql.Binary(SqlNodeType.Sub, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  252. case "op_Addition":
  253. return sql.Binary(SqlNodeType.Add, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  254. case "op_Modulus":
  255. return sql.Binary(SqlNodeType.Mod, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  256. case "op_BitwiseAnd":
  257. return sql.Binary(SqlNodeType.BitAnd, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  258. case "op_BitwiseOr":
  259. return sql.Binary(SqlNodeType.BitOr, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  260. case "op_ExclusiveOr":
  261. return sql.Binary(SqlNodeType.BitXor, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
  262. case "op_UnaryNegation":
  263. return sql.Unary(SqlNodeType.Negate, mc.Arguments[0], mc.Method, mc.SourceExpression);
  264. case "op_OnesComplement":
  265. return sql.Unary(SqlNodeType.BitNot, mc.Arguments[0], mc.Method, mc.SourceExpression);
  266. case "op_False":
  267. return sql.Unary(SqlNodeType.Not, mc.Arguments[0], mc.Method, mc.SourceExpression);
  268. }
  269. }
  270. }
  271. else {
  272. if (mc.Method.Name == "Equals" && mc.Arguments.Count == 1) {
  273. return sql.Binary(SqlNodeType.EQ, mc.Object, mc.Arguments[0]);
  274. }
  275. else if (mc.Method.Name == "GetType" && mc.Arguments.Count == 0) {
  276. MetaType mt = TypeSource.GetSourceMetaType(mc.Object, this.model);
  277. if (mt.HasInheritance) {
  278. Type discriminatorType = mt.Discriminator.Type;
  279. SqlDiscriminatorOf discriminatorOf = new SqlDiscriminatorOf(mc.Object, discriminatorType, this.sql.TypeProvider.From(discriminatorType), mc.SourceExpression);
  280. return this.VisitExpression(sql.DiscriminatedType(discriminatorOf, mt));
  281. }
  282. return this.VisitExpression(sql.StaticType(mt, mc.SourceExpression));
  283. }
  284. }
  285. return mc;
  286. }
  287. private SqlExpression TranslateVbIIF(SqlMethodCall mc) {
  288. //Check to see if the types can be implicitly converted from one to another.
  289. if (mc.Arguments[1].ClrType == mc.Arguments[2].ClrType) {
  290. List<SqlWhen> whens = new List<SqlWhen>(1);
  291. whens.Add(new SqlWhen(mc.Arguments[0], mc.Arguments[1]));
  292. SqlExpression @else = mc.Arguments[2];
  293. while (@else.NodeType == SqlNodeType.SearchedCase) {
  294. SqlSearchedCase sc = (SqlSearchedCase)@else;
  295. whens.AddRange(sc.Whens);
  296. @else = sc.Else;
  297. }
  298. return sql.SearchedCase(whens.ToArray(), @else, mc.SourceExpression);
  299. }
  300. else {
  301. throw Error.IifReturnTypesMustBeEqual(mc.Arguments[1].ClrType.Name, mc.Arguments[2].ClrType.Name);
  302. }
  303. }
  304. internal override SqlExpression VisitTreat(SqlUnary t) {
  305. t.Operand = this.VisitExpression(t.Operand);
  306. Type treatType = t.ClrType;
  307. Type originalType = model.GetMetaType(t.Operand.ClrType).InheritanceRoot.Type;
  308. // .NET nullability rules are that typeof(int)==typeof(int?). Let's be consistent with that:
  309. treatType = TypeSystem.GetNonNullableType(treatType);
  310. originalType = TypeSystem.GetNonNullableType(originalType);
  311. if (treatType == originalType) {
  312. return t.Operand;
  313. }
  314. else if (treatType.IsAssignableFrom(originalType)) {
  315. t.Operand.SetClrType(treatType);
  316. return t.Operand;
  317. }
  318. else if (!treatType.IsAssignableFrom(originalType) && !originalType.IsAssignableFrom(treatType)) {
  319. if (!treatType.IsInterface && !originalType.IsInterface) { // You can't tell when there's an interface involved.
  320. // We statically know the TREAT will result in NULL.
  321. return sql.TypedLiteralNull(treatType, t.SourceExpression);
  322. }
  323. }
  324. //return base.VisitTreat(t);
  325. return t;
  326. }
  327. }
  328. }
  329. }