SqlCaseSimplifier.cs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Data.Linq;
  4. using System.Data.Linq.Provider;
  5. using System.Diagnostics.CodeAnalysis;
  6. namespace System.Data.Linq.SqlClient {
  7. /// <summary>
  8. /// SQL with CASE statements is harder to read. This visitor attempts to reduce CASE
  9. /// statements to equivalent (but easier to read) logic.
  10. /// </summary>
  11. internal class SqlCaseSimplifier {
  /// <summary>
  /// Entry point: runs the CASE-simplifying visitor over <paramref name="node"/>.
  /// </summary>
  /// <param name="node">Root of the SQL tree handed to the visitor.</param>
  /// <param name="sql">Factory the visitor uses to build replacement nodes.</param>
  /// <returns>Whatever the visitor's <c>Visit</c> call returns for <paramref name="node"/>.</returns>
  internal static SqlNode Simplify(SqlNode node, SqlFactory sql) {
      Visitor simplifier = new Visitor(sql);
      return simplifier.Visit(node);
  }
  15. class Visitor : SqlVisitor {
  // Factory used to build replacement expressions. It is assigned once in the
  // constructor and never reassigned, hence readonly.
  readonly SqlFactory sql;

  /// <summary>
  /// Creates a visitor that builds its rewritten nodes with <paramref name="sql"/>.
  /// </summary>
  /// <param name="sql">Factory stored for use by the rewrite helpers.</param>
  internal Visitor(SqlFactory sql) {
      this.sql = sql;
  }
  /// <summary>
  /// Replace equals and not equals against a simple CASE by pushing the comparison
  /// into every branch and evaluating it there:
  ///
  /// | CASE XXX             |              CASE XXX                            CASE XXX
  /// |   WHEN AAA THEN MMMM | != RRRR ===>   WHEN AAA THEN (MMMM != RRRR) ==>    WHEN AAA THEN true
  /// |   WHEN BBB THEN NNNN |                WHEN BBB THEN (NNNN != RRRR)        WHEN BBB THEN false
  /// |   etc.               |                etc.                                etc.
  /// |   ELSE OOOO          |                ELSE (OOOO != RRRR)                 ELSE true
  /// | END                  |              END                                 END
  ///
  /// Where MMMM, NNNN and RRRR are constants.
  ///
  /// The rewrite only applies to EQ, NE, EQ2V and NE2V when one operand is a simple
  /// CASE whose THEN values are all Value nodes and the other operand is a Value node.
  /// Anything else is passed to the base visitor.
  /// </summary>
  /// <param name="bo">The binary operator node being visited.</param>
  /// <returns>The distributed CASE when the pattern matches; otherwise the base visitor's result.</returns>
  internal override SqlExpression VisitBinaryOperator(SqlBinary bo) {
      SqlNodeType op = bo.NodeType;
      bool isEqualityOperator =
          op == SqlNodeType.EQ ||
          op == SqlNodeType.NE ||
          op == SqlNodeType.EQ2V ||
          op == SqlNodeType.NE2V;

      if (isEqualityOperator) {
          SqlExpression lhs = bo.Left;
          SqlExpression rhs = bo.Right;

          // CASE on the left, constant on the right.
          if (lhs.NodeType == SqlNodeType.SimpleCase && rhs.NodeType == SqlNodeType.Value) {
              SqlSimpleCase leftCase = (SqlSimpleCase)lhs;
              if (AreCaseWhenValuesConstant(leftCase)) {
                  return this.DistributeOperatorIntoCase(op, leftCase, rhs);
              }
          }

          // Constant on the left, CASE on the right.
          if (rhs.NodeType == SqlNodeType.SimpleCase && lhs.NodeType == SqlNodeType.Value) {
              SqlSimpleCase rightCase = (SqlSimpleCase)rhs;
              if (AreCaseWhenValuesConstant(rightCase)) {
                  return this.DistributeOperatorIntoCase(op, rightCase, lhs);
              }
          }
      }

      return base.VisitBinaryOperator(bo);
  }
  /// <summary>
  /// Reports whether every WHEN of a simple CASE produces a Value node.
  /// </summary>
  /// <param name="sc">The simple CASE whose WHEN values are inspected.</param>
  /// <returns>
  /// <c>false</c> as soon as a WHEN value whose node type is not Value is found;
  /// <c>true</c> otherwise (including when there are no WHENs).
  /// </returns>
  [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
  internal bool AreCaseWhenValuesConstant(SqlSimpleCase sc) {
      for (int index = 0; index < sc.Whens.Count; index++) {
          SqlNodeType valueKind = sc.Whens[index].Value.NodeType;
          if (valueKind != SqlNodeType.Value) {
              return false;
          }
      }
      return true;
  }
  /// <summary>
  /// Helper for VisitBinaryOperator. Builds the new CASE with distributed values:
  /// each WHEN keeps its match expression, and its value is replaced by the boolean
  /// outcome of comparing the original value with <paramref name="expr"/>.
  /// </summary>
  /// <param name="nt">The comparison being distributed; must be EQ, NE, EQ2V or NE2V.</param>
  /// <param name="sc">The simple CASE receiving the comparison.</param>
  /// <param name="expr">The operand compared against each WHEN value (evaluated via Eval).</param>
  /// <returns>The result of visiting the newly built boolean CASE.</returns>
  private SqlExpression DistributeOperatorIntoCase(SqlNodeType nt, SqlSimpleCase sc, SqlExpression expr) {
      // Validate the operator and record whether it tests for equality or inequality.
      bool testsEquality;
      switch (nt) {
          case SqlNodeType.EQ:
          case SqlNodeType.EQ2V:
              testsEquality = true;
              break;
          case SqlNodeType.NE:
          case SqlNodeType.NE2V:
              testsEquality = false;
              break;
          default:
              throw Error.ArgumentOutOfRange("nt");
      }

      object comparand = Eval(expr);
      List<SqlExpression> distributedValues = new List<SqlExpression>();
      List<SqlExpression> matchExpressions = new List<SqlExpression>();

      foreach (SqlWhen clause in sc.Whens) {
          matchExpressions.Add(clause.Match);
          object clauseValue = Eval(clause.Value);
          bool valuesAgree = clause.Value.SqlType.AreValuesEqual(clauseValue, comparand);
          // EQ-style operators yield true when the values agree; NE-style when they differ.
          distributedValues.Add(sql.ValueFromObject(testsEquality == valuesAgree, false, sc.SourceExpression));
      }

      SqlExpression distributedCase = sql.Case(typeof(bool), sc.Expression, matchExpressions, distributedValues, sc.SourceExpression);
      return this.VisitExpression(distributedCase);
  }
  /// <summary>
  /// Visits a simple CASE and tries, in order, to rewrite it as:
  /// a single value (when only one WHEN survives), a boolean expression
  /// (bool-typed CASE with literal values), or a CASE with the WHENs that
  /// duplicate the ELSE value removed. If none applies the (mutated) CASE is returned.
  /// </summary>
  /// <param name="c">The simple CASE node; its expression and WHENs are visited in place.</param>
  /// <returns>The rewritten expression, or <paramref name="c"/> when no rewrite applies.</returns>
  internal override SqlExpression VisitSimpleCase(SqlSimpleCase c) {
      c.Expression = this.VisitExpression(c.Expression);
      int compareWhen = 0;
      // Find the ELSE if it exists (a WHEN with a null Match); otherwise the first WHEN is the comparand.
      for (int i = 0, n = c.Whens.Count; i < n; i++) {
          if (c.Whens[i].Match == null) {
              compareWhen = i;
              break;
          }
      }
      c.Whens[compareWhen].Match = VisitExpression(c.Whens[compareWhen].Match);
      c.Whens[compareWhen].Value = VisitExpression(c.Whens[compareWhen].Value);
      // Compare each other when value to the compare when
      List<SqlWhen> newWhens = new List<SqlWhen>();
      // The comparand WHEN is appended to newWhens below, and the boolean rewrite casts
      // every entry's value to SqlValue. Seed the flag from the comparand so a non-literal
      // comparand value disables that rewrite instead of causing an InvalidCastException.
      SqlExpression compareValue = c.Whens[compareWhen].Value;
      bool allValuesLiteral = compareValue != null && compareValue.NodeType == SqlNodeType.Value;
      for (int i = 0, n = c.Whens.Count; i < n; i++) {
          if (compareWhen != i) {
              SqlWhen when = c.Whens[i];
              when.Match = this.VisitExpression(when.Match);
              when.Value = this.VisitExpression(when.Value);
              // Keep only WHENs whose value differs from the comparand's value.
              if (!SqlComparer.AreEqual(c.Whens[compareWhen].Value, when.Value)) {
                  newWhens.Add(when);
              }
              allValuesLiteral = allValuesLiteral && when.Value.NodeType == SqlNodeType.Value;
          }
      }
      newWhens.Add(c.Whens[compareWhen]);
      // Did everything reduce to a single CASE?
      SqlExpression rewrite = TryToConsolidateAllValueExpressions(newWhens.Count, c.Whens[compareWhen].Value);
      if (rewrite != null)
          return rewrite;
      // Can it be a conjunction (or disjunction) of clauses?
      rewrite = TryToWriteAsSimpleBooleanExpression(c.ClrType, c.Expression, newWhens, allValuesLiteral);
      if (rewrite != null)
          return rewrite;
      // Can any WHEN clauses be reduced to fall into the ELSE clause?
      rewrite = TryToWriteAsReducedCase(c.ClrType, c.Expression, newWhens, c.Whens[compareWhen].Match, c.Whens.Count);
      if (rewrite != null)
          return rewrite;
      return c;
  }
  /// <summary>
  /// When there is exactly one when clause left in the CASE:
  ///
  ///   CASE XXX
  ///     WHEN AAA THEN YYY   ===>   YYY
  ///   END
  ///
  /// the whole CASE is reduced to that value.
  /// </summary>
  /// <param name="valueCount">Number of WHEN clauses remaining.</param>
  /// <param name="value">The value to return when only one clause remains.</param>
  /// <returns><paramref name="value"/> when <paramref name="valueCount"/> is 1; otherwise <c>null</c>.</returns>
  [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
  private SqlExpression TryToConsolidateAllValueExpressions(int valueCount, SqlExpression value) {
      return valueCount == 1 ? value : null;
  }
  /// <summary>
  /// For CASE statements which represent boolean values:
  ///
  ///   CASE XXX
  ///     WHEN AAA THEN true    ===>   (XXX==AAA) || (XXX==BBB)
  ///     WHEN BBB THEN true
  ///     ELSE false
  ///   END
  ///
  /// Also,
  ///
  ///   CASE XXX
  ///     WHEN AAA THEN false   ===>   (XXX!=AAA) AND (XXX!=BBB)
  ///     WHEN BBB THEN false
  ///     ELSE true
  ///   END
  ///
  /// The CASE is reduced to a conjunction or disjunction of equality or inequality tests.
  /// When the discriminator is not known to be non-null and an ELSE is present, an
  /// IS NULL (ELSE true) or IS NOT NULL (ELSE false) test is appended as well.
  /// </summary>
  /// <param name="caseType">CLR type of the CASE; the rewrite only applies to <c>bool</c>.</param>
  /// <param name="discriminator">The expression the CASE switches on.</param>
  /// <param name="newWhens">The surviving WHEN clauses; an entry with a null Match is the ELSE.</param>
  /// <param name="allValuesLiteral">Caller's assertion that every WHEN value is a literal; each value is cast to SqlValue.</param>
  /// <returns>The accumulated boolean expression, or <c>null</c> when the rewrite does not apply.</returns>
  private SqlExpression TryToWriteAsSimpleBooleanExpression(Type caseType, SqlExpression discriminator, List<SqlWhen> newWhens, bool allValuesLiteral) {
      if (caseType != typeof(bool) || !allValuesLiteral) {
          return null;
      }

      bool? holdsNull = SqlExpressionNullability.CanBeNull(discriminator);
      SqlExpression combined = null;
      bool? elseValue = null;

      foreach (SqlWhen clause in newWhens) {
          SqlValue literal = (SqlValue)clause.Value; // Relies on allValuesLiteral.
          bool thenValue = (bool)literal.Value;      // Relies on caseType==typeof(bool).

          if (clause.Match == null) {
              // This is the ELSE: remember its value, it contributes no comparison.
              elseValue = thenValue;
              continue;
          }

          if (thenValue) {
              combined = sql.OrAccumulate(combined, sql.Binary(SqlNodeType.EQ, discriminator, clause.Match));
          }
          else {
              combined = sql.AndAccumulate(combined, sql.Binary(SqlNodeType.NE, discriminator, clause.Match));
          }
      }

      // Unless the discriminator is known to be non-null, reproduce the ELSE fallback for NULL.
      bool mayHoldNull = holdsNull != false;
      if (mayHoldNull && elseValue.HasValue) {
          if (elseValue.Value) {
              combined = sql.OrAccumulate(combined, sql.Unary(SqlNodeType.IsNull, discriminator, discriminator.SourceExpression));
          }
          else {
              combined = sql.AndAccumulate(combined, sql.Unary(SqlNodeType.IsNotNull, discriminator, discriminator.SourceExpression));
          }
      }

      return combined;
  }
  /// <summary>
  /// Remove any WHEN clauses which have the same value as ELSE.
  ///
  ///   CASE XXX                       CASE XXX
  ///     WHEN AAA THEN YYY    ===>      WHEN AAA THEN YYY
  ///     WHEN BBB THEN ZZZ              WHEN CCC THEN YYY
  ///     WHEN CCC THEN YYY              ELSE ZZZ
  ///     ELSE ZZZ                     END
  ///   END
  ///
  /// The clauses have already been filtered by the caller; this method only builds
  /// the smaller CASE when that filtering removed something and the comparand was the ELSE.
  /// </summary>
  /// <param name="caseType">CLR type given to the new CASE.</param>
  /// <param name="discriminator">The expression the CASE switches on; also supplies the source expression.</param>
  /// <param name="newWhens">The WHEN clauses to keep.</param>
  /// <param name="elseCandidate">Match expression of the comparand WHEN; <c>null</c> means it is the ELSE.</param>
  /// <param name="originalWhenCount">Number of WHEN clauses before filtering.</param>
  /// <returns>A new reduced <c>SqlSimpleCase</c>, or <c>null</c> when no reduction applies.</returns>
  [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
  private SqlExpression TryToWriteAsReducedCase(Type caseType, SqlExpression discriminator, List<SqlWhen> newWhens, SqlExpression elseCandidate, int originalWhenCount) {
      // Some whens were dropped because they matched the comparand's value...
      bool someWhensDropped = newWhens.Count != originalWhenCount;
      // ...and the comparand is the ELSE (its Match is null).
      bool comparandIsElse = elseCandidate == null;

      if (someWhensDropped && comparandIsElse) {
          // Everything equivalent to ELSE has been eliminated; emit the smaller CASE.
          return new SqlSimpleCase(caseType, discriminator, newWhens, discriminator.SourceExpression);
      }
      return null;
  }
  218. }
  219. }
  220. }