- using System;
- using System.Collections.Generic;
- using System.Data.Linq;
- using System.Data.Linq.Provider;
- using System.Diagnostics.CodeAnalysis;
- namespace System.Data.Linq.SqlClient {
- /// <summary>
- /// SQL with CASE statements is harder to read. This visitor attempts to reduce CASE
- /// statements to equivalent (but easier to read) logic.
- /// </summary>
- internal class SqlCaseSimplifier {
- internal static SqlNode Simplify(SqlNode node, SqlFactory sql) {
- return new Visitor(sql).Visit(node);
- }
- class Visitor : SqlVisitor {
- SqlFactory sql;
- internal Visitor(SqlFactory sql) {
- this.sql = sql;
- }
- /// <summary>
- /// Replace equals and not equals:
- ///
- /// | CASE XXX             |              CASE XXX                           CASE XXX
- /// |   WHEN AAA THEN MMMM | != RRRR ===>   WHEN AAA THEN (MMMM != RRRR) ==>   WHEN AAA THEN true
- /// |   WHEN BBB THEN NNNN |                WHEN BBB THEN (NNNN != RRRR)       WHEN BBB THEN false
- /// |   etc.               |                etc.                               etc.
- /// |   ELSE OOOO          |                ELSE (OOOO != RRRR)                ELSE true
- /// | END                  |              END                                END
- ///
- /// Where MMMM, NNNN, OOOO and RRRR are constants.
- /// </summary>
- internal override SqlExpression VisitBinaryOperator(SqlBinary bo) {
- switch (bo.NodeType) {
- case SqlNodeType.EQ:
- case SqlNodeType.NE:
- case SqlNodeType.EQ2V:
- case SqlNodeType.NE2V:
- if (bo.Left.NodeType == SqlNodeType.SimpleCase &&
- bo.Right.NodeType == SqlNodeType.Value &&
- AreCaseWhenValuesConstant((SqlSimpleCase)bo.Left)) {
- return this.DistributeOperatorIntoCase(bo.NodeType, (SqlSimpleCase)bo.Left, bo.Right);
- }
- else if (bo.Right.NodeType == SqlNodeType.SimpleCase &&
- bo.Left.NodeType==SqlNodeType.Value &&
- AreCaseWhenValuesConstant((SqlSimpleCase)bo.Right)) {
- return this.DistributeOperatorIntoCase(bo.NodeType, (SqlSimpleCase)bo.Right, bo.Left);
- }
- break;
- }
- return base.VisitBinaryOperator(bo);
- }
- /// <summary>
- /// Checks whether every WHEN (and ELSE) value of the SqlSimpleCase is a constant (SqlNodeType.Value) node.
- /// </summary>
- [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
- internal bool AreCaseWhenValuesConstant(SqlSimpleCase sc) {
- foreach (SqlWhen when in sc.Whens) {
- if (when.Value.NodeType != SqlNodeType.Value) {
- return false;
- }
- }
- return true;
- }
- /// <summary>
- /// Helper for VisitBinaryOperator. Builds the new case with distributed values.
- /// </summary>
- private SqlExpression DistributeOperatorIntoCase(SqlNodeType nt, SqlSimpleCase sc, SqlExpression expr) {
- if (nt!=SqlNodeType.EQ && nt!=SqlNodeType.NE && nt!=SqlNodeType.EQ2V && nt!=SqlNodeType.NE2V)
- throw Error.ArgumentOutOfRange("nt");
- object val = Eval(expr);
- List<SqlExpression> values = new List<SqlExpression>();
- List<SqlExpression> matches = new List<SqlExpression>();
- foreach(SqlWhen when in sc.Whens) {
- matches.Add(when.Match);
- object whenVal = Eval(when.Value);
- bool eq = when.Value.SqlType.AreValuesEqual(whenVal, val);
- values.Add(sql.ValueFromObject((nt==SqlNodeType.EQ || nt==SqlNodeType.EQ2V) == eq, false, sc.SourceExpression));
- }
- return this.VisitExpression(sql.Case(typeof(bool), sc.Expression, matches, values, sc.SourceExpression));
- }
- /// <summary>
- /// Simplifies a simple CASE. The ELSE clause (or the first WHEN if there is no ELSE) is used
- /// as the comparand: every other WHEN whose value equals the comparand's value is dropped,
- /// and the remaining clauses are then offered to each rewrite below in turn.
- /// </summary>
- internal override SqlExpression VisitSimpleCase(SqlSimpleCase c) {
- c.Expression = this.VisitExpression(c.Expression);
- int compareWhen = 0;
- // Find the ELSE (a WHEN with a null Match) if it exists; otherwise fall back to the first WHEN.
- for (int i = 0, n = c.Whens.Count; i < n; i++) {
- if (c.Whens[i].Match == null) {
- compareWhen = i;
- break;
- }
- }
- c.Whens[compareWhen].Match = VisitExpression(c.Whens[compareWhen].Match);
- c.Whens[compareWhen].Value = VisitExpression(c.Whens[compareWhen].Value);
- // Compare the value of every other WHEN to the value of the comparand WHEN.
- List<SqlWhen> newWhens = new List<SqlWhen>();
- bool allValuesLiteral = true;
- for (int i = 0, n = c.Whens.Count; i < n; i++) {
- if (compareWhen != i) {
- SqlWhen when = c.Whens[i];
- when.Match = this.VisitExpression(when.Match);
- when.Value = this.VisitExpression(when.Value);
- if (!SqlComparer.AreEqual(c.Whens[compareWhen].Value, when.Value)) {
- newWhens.Add(when);
- }
- allValuesLiteral = allValuesLiteral && when.Value.NodeType == SqlNodeType.Value;
- }
- }
- newWhens.Add(c.Whens[compareWhen]);
- // Did everything reduce to a single WHEN?
- SqlExpression rewrite = TryToConsolidateAllValueExpressions(newWhens.Count, c.Whens[compareWhen].Value);
- if (rewrite != null)
- return rewrite;
- // Can it be a conjunction (or disjunction) of clauses?
- rewrite = TryToWriteAsSimpleBooleanExpression(c.ClrType, c.Expression, newWhens, allValuesLiteral);
- if (rewrite != null)
- return rewrite;
- // Can any WHEN clauses be reduced to fall into the ELSE clause?
- rewrite = TryToWriteAsReducedCase(c.ClrType, c.Expression, newWhens, c.Whens[compareWhen].Match, c.Whens.Count);
- if (rewrite != null)
- return rewrite;
- return c;
- }
- /// <summary>
- /// When there is exactly one when clause in the CASE:
- ///
- /// CASE XXX
- /// WHEN AAA THEN YYY ===> YYY
- /// END
- ///
- /// Then, just reduce it to the value.
- /// </summary>
- [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
- private SqlExpression TryToConsolidateAllValueExpressions(int valueCount, SqlExpression value) {
- if (valueCount == 1) {
- return value;
- }
- return null;
- }
- /// <summary>
- /// For CASE statements which represent boolean values:
- ///
- /// CASE XXX
- /// WHEN AAA THEN true ===> (XXX==AAA) || (XXX==BBB)
- /// WHEN BBB THEN true
- /// ELSE false
- /// etc.
- /// END
- ///
- /// Also,
- ///
- /// CASE XXX
- /// WHEN AAA THEN false ===> (XXX!=AAA) && (XXX!=BBB)
- /// WHEN BBB THEN false
- /// ELSE true
- /// etc.
- /// END
- ///
- /// These reduce to a conjunction or disjunction of equality or inequality.
- /// The possibility of NULL in XXX is taken into account.
- /// </summary>
- private SqlExpression TryToWriteAsSimpleBooleanExpression(Type caseType, SqlExpression discriminator, List<SqlWhen> newWhens, bool allValuesLiteral) {
- SqlExpression rewrite = null;
- if (caseType == typeof(bool) && allValuesLiteral) {
- bool? holdsNull = SqlExpressionNullability.CanBeNull(discriminator);
- // If the discriminator can't hold a NULL, we don't need the special fallback
- // that CASE-ELSE gives and can just construct a boolean operation.
- // If it can, a NULL check is accumulated after the loop.
- bool? whenValue = null;
- for (int i = 0; i < newWhens.Count; ++i) {
- SqlValue lit = (SqlValue)newWhens[i].Value; // Must be SqlValue because of allValuesLiteral.
- bool value = (bool)lit.Value; // Must be bool because of caseType==typeof(bool).
- if (newWhens[i].Match != null) { // Skip the ELSE
- if (value) {
- rewrite = sql.OrAccumulate(rewrite, sql.Binary(SqlNodeType.EQ, discriminator, newWhens[i].Match));
- }
- else {
- rewrite = sql.AndAccumulate(rewrite, sql.Binary(SqlNodeType.NE, discriminator, newWhens[i].Match));
- }
- }
- else {
- whenValue = value;
- }
- }
- // If the discriminator could possibly hold NULL, a NULL falls through to the ELSE, so account for the ELSE value.
- if (holdsNull != false && whenValue != null) {
- if (whenValue == true) {
- rewrite = sql.OrAccumulate(rewrite, sql.Unary(SqlNodeType.IsNull, discriminator, discriminator.SourceExpression));
- }
- else {
- rewrite = sql.AndAccumulate(rewrite, sql.Unary(SqlNodeType.IsNotNull, discriminator, discriminator.SourceExpression));
- }
- }
- }
- return rewrite;
- }
- /// <summary>
- /// Remove any WHEN clauses which have the same value as ELSE.
- ///
- /// CASE XXX                   CASE XXX
- ///   WHEN AAA THEN YYY  ===>    WHEN AAA THEN YYY
- ///   WHEN BBB THEN ZZZ          WHEN CCC THEN YYY
- ///   WHEN CCC THEN YYY          ELSE ZZZ
- ///   ELSE ZZZ                 END
- /// END
- ///
- /// </summary>
- [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
- private SqlExpression TryToWriteAsReducedCase(Type caseType, SqlExpression discriminator, List<SqlWhen> newWhens, SqlExpression elseCandidate, int originalWhenCount) {
- if (newWhens.Count != originalWhenCount) {
- // Some WHENs had the same value as the comparand and were dropped.
- if (elseCandidate == null) {
- // -and- the comparand is the ELSE (its Match is null).
- // In this case, simplify the CASE to eliminate everything equivalent to ELSE.
- return new SqlSimpleCase(caseType, discriminator, newWhens, discriminator.SourceExpression);
- }
- }
- return null;
- }
- }
- }
- }