// SqlDeflator.cs (21 KB)

  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics.CodeAnalysis;
  4. namespace System.Data.Linq.SqlClient {
  5. internal class SqlDeflator {
  // Deflation passes, run by Deflate in declaration order. Each is created once by the
  // constructor and never reassigned afterwards, so the fields are readonly.
  readonly SqlValueDeflator vDeflator;
  readonly SqlColumnDeflator cDeflator;
  readonly SqlAliasDeflator aDeflator;
  readonly SqlTopSelectDeflator tsDeflator;
  readonly SqlDuplicateColumnDeflator dupColumnDeflator;

  // Creates one instance of each deflation pass.
  internal SqlDeflator() {
      this.vDeflator = new SqlValueDeflator();
      this.cDeflator = new SqlColumnDeflator();
      this.aDeflator = new SqlAliasDeflator();
      this.tsDeflator = new SqlTopSelectDeflator();
      this.dupColumnDeflator = new SqlDuplicateColumnDeflator();
  }
  // Simplifies a SQL tree by running the deflation passes in sequence. Each pass
  // consumes the previous pass's output:
  //   1. inline literal values into the top-level projection,
  //   2. prune unreferenced projection-list columns,
  //   3. remove trivial aliases and empty outer/apply join sides,
  //   4. collapse a top-level select that only reprojects its sub-select,
  //   5. drop duplicate GROUP BY / ORDER BY expressions.
  // Returns the (possibly replaced) root node.
  // NOTE(review): some passes keep per-instance maps (referenced columns, removed aliases)
  // that are never cleared. A SqlDeflator therefore presumably should not be reused across
  // trees — confirm with callers.
  internal SqlNode Deflate(SqlNode node) {
      SqlNode withLiterals = this.vDeflator.Visit(node);
      SqlNode withoutUnusedColumns = this.cDeflator.Visit(withLiterals);
      SqlNode withoutTrivialAliases = this.aDeflator.Visit(withoutUnusedColumns);
      SqlNode withoutTopReprojection = this.tsDeflator.Visit(withoutTrivialAliases);
      return this.dupColumnDeflator.Visit(withoutTopReprojection);
  }
  // Pass 1: removes references to literal values. A column reference in the top-level
  // projection is replaced by the SqlValue it ultimately resolves to, following any chain
  // of column-to-column references.
  class SqlValueDeflator : SqlVisitor {
      SelectionDeflator sDeflator;
      // Starts out true; nothing in this class ever sets it to false (see VisitSubSelect).
      bool isTopLevel = true;

      internal SqlValueDeflator() {
          this.sDeflator = new SelectionDeflator();
      }

      // Rewrites only select.Selection. The select's sources and clauses are not visited,
      // so nested selects are never reached through this method.
      internal override SqlSelect VisitSelect(SqlSelect select) {
          if (!this.isTopLevel) {
              return select;
          }
          select.Selection = this.sDeflator.VisitExpression(select.Selection);
          return select;
      }

      // NOTE(review): isTopLevel is saved and restored here but never set to false, so a
      // select reached through a sub-select is still treated as top level. The save/restore
      // pattern suggests "this.isTopLevel = false;" was intended — confirm before changing.
      internal override SqlExpression VisitSubSelect(SqlSubSelect ss) {
          bool wasTopLevel = this.isTopLevel;
          try {
              return base.VisitSubSelect(ss);
          }
          finally {
              this.isTopLevel = wasTopLevel;
          }
      }

      // Replaces each column reference whose definition chain ends in a SqlValue with that
      // SqlValue; all other column references are left untouched.
      class SelectionDeflator : SqlVisitor {
          internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
              SqlValue value = this.GetLiteralValue(cref);
              return value != null ? (SqlExpression)value : cref;
          }

          // Follows Column.Expression through consecutive ColumnRef nodes. Returns the
          // expression at the end of the chain if it is a SqlValue, otherwise null.
          private SqlValue GetLiteralValue(SqlExpression expr) {
              SqlExpression current = expr;
              while (current != null) {
                  if (current.NodeType != SqlNodeType.ColumnRef) {
                      break;
                  }
                  current = ((SqlColumnRef)current).Column.Expression;
              }
              return current as SqlValue;
          }
      }
  }
  // Pass 2: removes unreferenced items from a select's projection list (select.Row.Columns).
  // References are collected while walking outside-in. A select visits its own Selection
  // (top level only), its kept columns, TOP, ORDER BY, HAVING, GROUP BY and WHERE before its
  // FROM source. Columns of an inner select that an outer clause uses are therefore marked
  // before the inner select is pruned.
  class SqlColumnDeflator : SqlVisitor {
      // Columns seen in some SqlColumnRef so far; used as a set (key == value).
      readonly Dictionary<SqlNode, SqlNode> referenceMap;
      // True until the outermost select has been entered. Only that select's Selection
      // expression is scanned for references.
      bool isTopLevel;
      // When true, the select being entered must keep all of its columns. It is set for
      // both arms of a UNION and for scalar sub-selects.
      bool forceReferenceAll;
      readonly SqlAggregateChecker aggregateChecker;

      internal SqlColumnDeflator() {
          this.referenceMap = new Dictionary<SqlNode, SqlNode>();
          this.aggregateChecker = new SqlAggregateChecker();
          this.isTopLevel = true;
      }

      // Records that cref.Column is referenced.
      internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
          this.referenceMap[cref.Column] = cref.Column;
          return cref;
      }

      // A scalar sub-select is not top level. Its projected column is its value, so none of
      // its columns may be pruned.
      internal override SqlExpression VisitScalarSubSelect(SqlSubSelect ss) {
          bool saveIsTopLevel = this.isTopLevel;
          this.isTopLevel = false;
          bool saveForceReferenceAll = this.forceReferenceAll;
          this.forceReferenceAll = true;
          try {
              return base.VisitScalarSubSelect(ss);
          }
          finally {
              this.isTopLevel = saveIsTopLevel;
              this.forceReferenceAll = saveForceReferenceAll;
          }
      }

      // An EXISTS sub-select is not top level. Unlike a scalar sub-select, it does not force
      // its columns to be kept.
      internal override SqlExpression VisitExists(SqlSubSelect ss) {
          bool saveIsTopLevel = this.isTopLevel;
          this.isTopLevel = false;
          try {
              return base.VisitExists(ss);
          }
          finally {
              this.isTopLevel = saveIsTopLevel;
          }
      }

      // Neither arm of a UNION may lose columns, because pruning each arm independently could
      // leave them with mismatched column lists. The flag is restored in a finally block, as
      // in the other state-saving overrides, so an exception cannot leave it stuck at true.
      internal override SqlNode VisitUnion(SqlUnion su) {
          bool saveForceReferenceAll = this.forceReferenceAll;
          this.forceReferenceAll = true;
          try {
              su.Left = this.Visit(su.Left);
              su.Right = this.Visit(su.Right);
          }
          finally {
              this.forceReferenceAll = saveForceReferenceAll;
          }
          return su;
      }

      // Prunes unreferenced columns from select.Row, then visits the remaining clauses and,
      // last, the FROM source. Inner selects therefore see every reference made at this level.
      internal override SqlSelect VisitSelect(SqlSelect select) {
          // Whether *this* select must keep everything is decided by the caller's flag.
          // Inside this select's clauses the flag starts out cleared.
          bool saveForceReferenceAll = this.forceReferenceAll;
          this.forceReferenceAll = false;
          bool saveIsTopLevel = this.isTopLevel;
          try {
              if (this.isTopLevel) {
                  // top-level projection references columns!
                  select.Selection = this.VisitExpression(select.Selection);
              }
              this.isTopLevel = false;
              // Walk from last to first so RemoveAt(i) never shifts a column still to be examined.
              for (int i = select.Row.Columns.Count - 1; i >= 0; i--) {
                  SqlColumn c = select.Row.Columns[i];
                  bool safeToRemove =
                      !saveForceReferenceAll
                      && !this.referenceMap.ContainsKey(c)
                      // don't remove anything from a distinct select (except maybe a literal value) since it would change the meaning of the comparison
                      && !select.IsDistinct
                      // don't remove an aggregate expression that may be the only expression that forces the grouping (since it would change the cardinality of the results)
                      && !(select.GroupBy.Count == 0 && this.aggregateChecker.HasAggregates(c.Expression));
                  if (safeToRemove) {
                      select.Row.Columns.RemoveAt(i);
                  }
                  else {
                      // A kept column's expression may itself reference other columns.
                      this.VisitExpression(c.Expression);
                  }
              }
              select.Top = this.VisitExpression(select.Top);
              for (int i = select.OrderBy.Count - 1; i >= 0; i--) {
                  select.OrderBy[i].Expression = this.VisitExpression(select.OrderBy[i].Expression);
              }
              select.Having = this.VisitExpression(select.Having);
              for (int i = select.GroupBy.Count - 1; i >= 0; i--) {
                  select.GroupBy[i] = this.VisitExpression(select.GroupBy[i]);
              }
              select.Where = this.VisitExpression(select.Where);
              select.From = this.VisitSource(select.From);
          }
          finally {
              this.isTopLevel = saveIsTopLevel;
              this.forceReferenceAll = saveForceReferenceAll;
          }
          return select;
      }

      // Visits the condition first, then the right source, then the left. References made by
      // the condition (and, presumably, by an APPLY's right side into the left side) are thus
      // recorded before either side is pruned.
      internal override SqlSource VisitJoin(SqlJoin join) {
          join.Condition = this.VisitExpression(join.Condition);
          join.Right = this.VisitSource(join.Right);
          join.Left = this.VisitSource(join.Left);
          return join;
      }

      // Only a link's key expressions count as references; its expansion is not visited.
      internal override SqlNode VisitLink(SqlLink link) {
          // don't visit expansion...
          for (int i = 0, n = link.KeyExpressions.Count; i < n; i++) {
              link.KeyExpressions[i] = this.VisitExpression(link.KeyExpressions[i]);
          }
          return link;
      }
  }
  // Collects column equalities implied by equality predicates (EQ / EQ2V between two column
  // references, possibly combined with AND). Sources are join conditions and the WHERE
  // clauses of selects within a source tree. Two column references can then be recognized
  // as equivalent.
  // NOTE(review): equalities are recorded for every join type, including outer joins, where
  // one side may be NULL while the other is not — confirm this is acceptable for callers.
  class SqlColumnEqualizer : SqlVisitor {
      // Root column -> the column it was most recently equated with. Entries are written in
      // both directions, and a later equality for the same column overwrites an earlier one.
      Dictionary<SqlColumn, SqlColumn> map;

      internal SqlColumnEqualizer() {
      }

      // Discards any previous map and rebuilds it from the joins and selects under scope.
      internal void BuildEqivalenceMap(SqlSource scope) {
          this.map = new Dictionary<SqlColumn, SqlColumn>();
          this.Visit(scope);
      }

      // True when e1 and e2 are structurally equal, or when both are column references whose
      // root columns were equated with each other in the map.
      // BuildEqivalenceMap must have been called first.
      internal bool AreEquivalent(SqlExpression e1, SqlExpression e2) {
          if (SqlComparer.AreEqual(e1, e2)) {
              return true;
          }
          SqlColumnRef leftRef = e1 as SqlColumnRef;
          SqlColumnRef rightRef = e2 as SqlColumnRef;
          if (leftRef == null || rightRef == null) {
              return false;
          }
          SqlColumn leftRoot = leftRef.GetRootColumn();
          SqlColumn rightRoot = rightRef.GetRootColumn();
          SqlColumn partner;
          if (!this.map.TryGetValue(leftRoot, out partner)) {
              return false;
          }
          return partner == rightRoot;
      }

      internal override SqlSource VisitJoin(SqlJoin join) {
          base.VisitJoin(join);
          if (join.Condition != null) {
              this.CheckJoinCondition(join.Condition);
          }
          return join;
      }

      internal override SqlSelect VisitSelect(SqlSelect select) {
          base.VisitSelect(select);
          if (select.Where != null) {
              this.CheckJoinCondition(select.Where);
          }
          return select;
      }

      // Recurses through AND nodes and records every "column = column" equality it finds.
      // Any other node type is ignored.
      [SuppressMessage("Microsoft.Performance", "CA1800:DoNotCastUnnecessarily", Justification="[....]: Cast is dependent on node type and casts do not happen unecessarily in a single code path.")]
      private void CheckJoinCondition(SqlExpression expr) {
          SqlNodeType nodeType = expr.NodeType;
          if (nodeType == SqlNodeType.And) {
              SqlBinary conjunction = (SqlBinary)expr;
              this.CheckJoinCondition(conjunction.Left);
              this.CheckJoinCondition(conjunction.Right);
          }
          else if (nodeType == SqlNodeType.EQ || nodeType == SqlNodeType.EQ2V) {
              SqlBinary equality = (SqlBinary)expr;
              SqlColumnRef leftRef = equality.Left as SqlColumnRef;
              SqlColumnRef rightRef = equality.Right as SqlColumnRef;
              if (leftRef == null || rightRef == null) {
                  return;
              }
              SqlColumn leftRoot = leftRef.GetRootColumn();
              SqlColumn rightRoot = rightRef.GetRootColumn();
              this.map[leftRoot] = rightRoot;
              this.map[rightRoot] = leftRoot;
          }
      }

      // Sub-selects are not descended into, so their predicates contribute nothing.
      internal override SqlExpression VisitSubSelect(SqlSubSelect ss) {
          return ss;
      }
  }
  // Pass 3: removes redundant/trivial aliases.
  // - An alias over a select with no clauses, a source made only of aliases/joins, and a
  //   projection of plain column references is replaced by that select's FROM source.
  // - The right side of a LEFT OUTER / CROSS APPLY / OUTER APPLY join is dropped when it is
  //   an empty select.
  // Column references into a removed alias are redirected to the underlying column reference.
  class SqlAliasDeflator : SqlVisitor {
      // Aliases removed so far; used as a set (key == value).
      Dictionary<SqlAlias, SqlAlias> removedMap;

      internal SqlAliasDeflator() {
          this.removedMap = new Dictionary<SqlAlias, SqlAlias>();
      }

      // A reference to a whole removed alias cannot be redirected, so it is an error.
      internal override SqlExpression VisitAliasRef(SqlAliasRef aref) {
          if (this.removedMap.ContainsKey(aref.Alias)) {
              throw Error.InvalidReferenceToRemovedAliasDuringDeflation();
          }
          return aref;
      }

      // Redirects a reference to a column of a removed alias to that column's own
      // expression, when the expression is a column reference.
      internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
          SqlAlias owner = cref.Column.Alias;
          if (owner == null || !this.removedMap.ContainsKey(owner)) {
              return cref;
          }
          SqlColumnRef inner = cref.Column.Expression as SqlColumnRef;
          if (inner != null && inner.ClrType != cref.ClrType) {
              //The following code checks for cases where there are differences between the type returned
              //by a ColumnRef and the column that refers to it. This situation can occur when conversions
              //are optimized out of the SQL node tree. As mentioned in the SetClrType comments this is not
              //an operation that can have adverse effects and should only be used in limited cases, such as
              //this one.
              inner.SetClrType(cref.ClrType);
              return this.VisitColumnRef(inner);
          }
          // NOTE(review): when the column's expression is not a SqlColumnRef (e.g. it is null,
          // which HasTrivialProjection allows), this returns null — confirm callers cope.
          return inner;
      }

      // Visits the source. If the result is an alias over a trivial select, the alias is
      // recorded as removed and replaced by the select's FROM source.
      internal override SqlSource VisitSource(SqlSource node) {
          SqlSource visited = (SqlSource)this.Visit(node);
          SqlAlias alias = visited as SqlAlias;
          if (alias == null) {
              return visited;
          }
          SqlSelect select = alias.Node as SqlSelect;
          if (select == null || !this.IsTrivialSelect(select)) {
              return visited;
          }
          this.removedMap[alias] = alias;
          return select.From;
      }

      // Inner and cross joins are kept as-is, because reducing either side would affect the
      // cardinality of the results. A left outer / cross apply / outer apply join may reduce
      // to its left side when the right side is an empty select.
      internal override SqlSource VisitJoin(SqlJoin join) {
          base.VisitJoin(join);
          bool rightMayBeDropped =
              join.JoinType == SqlJoinType.LeftOuter ||
              join.JoinType == SqlJoinType.CrossApply ||
              join.JoinType == SqlJoinType.OuterApply;
          if (rightMayBeDropped && this.HasEmptySource(join.Right)) {
              SqlAlias right = (SqlAlias)join.Right;
              this.removedMap[right] = right;
              return join.Left;
          }
          return join;
      }

      // A select is trivial when it has no ORDER BY, GROUP BY, HAVING, TOP, DISTINCT or
      // WHERE, and both its source and its projection are trivial.
      private bool IsTrivialSelect(SqlSelect select) {
          bool hasClauses =
              select.OrderBy.Count != 0 ||
              select.GroupBy.Count != 0 ||
              select.Having != null ||
              select.Top != null ||
              select.IsDistinct ||
              select.Where != null;
          return !hasClauses && this.HasTrivialSource(select.From) && this.HasTrivialProjection(select);
      }

      // True for an alias, or for a join tree whose leaves are all aliases.
      private bool HasTrivialSource(SqlSource node) {
          SqlJoin join = node as SqlJoin;
          if (join == null) {
              return node is SqlAlias;
          }
          return this.HasTrivialSource(join.Left) && this.HasTrivialSource(join.Right);
      }

      // True when every column expression is either null or a column reference.
      [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
      private bool HasTrivialProjection(SqlSelect select) {
          foreach (SqlColumn column in select.Row.Columns) {
              SqlExpression definition = column.Expression;
              if (definition == null) {
                  continue;
              }
              if (definition.NodeType != SqlNodeType.ColumnRef) {
                  return false;
              }
          }
          return true;
      }

      // True for an alias over a select with no columns, FROM, WHERE, GROUP BY, HAVING or
      // ORDER BY.
      [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
      private bool HasEmptySource(SqlSource node) {
          SqlAlias alias = node as SqlAlias;
          SqlSelect select = (alias == null) ? null : (alias.Node as SqlSelect);
          if (select == null) {
              return false;
          }
          return select.Row.Columns.Count == 0
              && select.From == null
              && select.Where == null
              && select.GroupBy.Count == 0
              && select.Having == null
              && select.OrderBy.Count == 0;
      }
  }
  // Pass 5: removes duplicate columns from the GROUP BY and ORDER BY lists.
  // - A GROUP BY expression equal to an earlier one is dropped.
  // - An ORDER BY expression is dropped when it is equal to an earlier one, or equivalent to
  //   it through a column equality that SqlColumnEqualizer finds in the FROM clause.
  // The later occurrence is the one removed. An equal later sort key can never break a tie
  // left by the earlier one.
  class SqlDuplicateColumnDeflator : SqlVisitor {
      SqlColumnEqualizer equalizer = new SqlColumnEqualizer();

      // Visits FROM, WHERE, GROUP BY, HAVING, ORDER BY, TOP, the row and the selection, in
      // that order. The GROUP BY and ORDER BY lists are de-duplicated right after visiting.
      internal override SqlSelect VisitSelect(SqlSelect select) {
          select.From = this.VisitSource(select.From);
          select.Where = this.VisitExpression(select.Where);

          int groupCount = select.GroupBy.Count;
          for (int g = 0; g < groupCount; g++) {
              select.GroupBy[g] = this.VisitExpression(select.GroupBy[g]);
          }
          // Walk from the end so a removal never shifts an index still to be examined.
          for (int later = select.GroupBy.Count - 1; later > 0; later--) {
              bool repeated = false;
              for (int earlier = later - 1; earlier >= 0 && !repeated; earlier--) {
                  repeated = SqlComparer.AreEqual(select.GroupBy[later], select.GroupBy[earlier]);
              }
              if (repeated) {
                  select.GroupBy.RemoveAt(later);
              }
          }

          select.Having = this.VisitExpression(select.Having);

          int orderCount = select.OrderBy.Count;
          for (int o = 0; o < orderCount; o++) {
              select.OrderBy[o].Expression = this.VisitExpression(select.OrderBy[o].Expression);
          }
          if (select.OrderBy.Count > 0) {
              this.equalizer.BuildEqivalenceMap(select.From);
              for (int later = select.OrderBy.Count - 1; later > 0; later--) {
                  bool redundant = false;
                  for (int earlier = later - 1; earlier >= 0 && !redundant; earlier--) {
                      redundant = this.equalizer.AreEquivalent(select.OrderBy[later].Expression, select.OrderBy[earlier].Expression);
                  }
                  if (redundant) {
                      select.OrderBy.RemoveAt(later);
                  }
              }
          }

          select.Top = this.VisitExpression(select.Top);
          select.Row = (SqlRow)this.Visit(select.Row);
          select.Selection = this.VisitExpression(select.Selection);
          return select;
      }
  }
  // Pass 4: if the top-level select is simply a reprojection of an aliased sub-select, it is
  // removed and the sub-select is returned in its place.
  // - A reprojection has no clauses, and every column is a plain column reference.
  // - Non-empty outer column names are pushed down onto the inner columns.
  // - The outer Selection is rewritten in terms of the inner columns.
  // VisitSelect does not call the base visitor, so selects nested inside the top-level
  // select are never examined.
  class SqlTopSelectDeflator : SqlVisitor {
      internal override SqlSelect VisitSelect(SqlSelect select) {
          if (!this.IsTrivialSelect(select)) {
              return select;
          }
          SqlSelect aselect = (SqlSelect)((SqlAlias)select.From).Node;
          // Map each outer column to the inner column reference that defines it, so the outer
          // Selection can be re-expressed against the inner select.
          Dictionary<SqlColumn, SqlColumnRef> map = new Dictionary<SqlColumn, SqlColumnRef>();
          foreach (SqlColumn c in select.Row.Columns) {
              // HasTrivialProjection guarantees a non-null column reference here.
              SqlColumnRef cref = (SqlColumnRef)c.Expression;
              map.Add(c, cref);
              // push the interesting column names down (non null)
              // NOTE(review): if two outer columns reference the same inner column, the last
              // non-empty name wins.
              if (!string.IsNullOrEmpty(c.Name)) {
                  cref.Column.Name = c.Name;
              }
          }
          aselect.Selection = new ColumnMapper(map).VisitExpression(select.Selection);
          return aselect;
      }

      // A select is trivial when it has no ORDER BY, GROUP BY, HAVING, TOP, DISTINCT or
      // WHERE, reads from a single aliased select, and projects only column references.
      private bool IsTrivialSelect(SqlSelect select) {
          if (select.OrderBy.Count != 0 ||
              select.GroupBy.Count != 0 ||
              select.Having != null ||
              select.Top != null ||
              select.IsDistinct ||
              select.Where != null)
          {
              return false;
          }
          return this.HasTrivialSource(select.From) && this.HasTrivialProjection(select);
      }

      // True when the source is an alias over a select.
      [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
      private bool HasTrivialSource(SqlSource node) {
          SqlAlias alias = node as SqlAlias;
          if (alias == null) {
              return false;
          }
          return alias.Node is SqlSelect;
      }

      // True when every projected column is a plain column reference. A column with no
      // expression disqualifies the select. VisitSelect casts each expression to SqlColumnRef
      // and dereferences it, so a null expression would throw or map the column to null.
      [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
      private bool HasTrivialProjection(SqlSelect select) {
          foreach (SqlColumn c in select.Row.Columns) {
              if (c.Expression == null || c.Expression.NodeType != SqlNodeType.ColumnRef) {
                  return false;
              }
          }
          return true;
      }

      // Replaces references to the removed outer select's columns with the inner column
      // references that defined them. Other column references are returned unchanged.
      class ColumnMapper : SqlVisitor {
          Dictionary<SqlColumn, SqlColumnRef> map;

          internal ColumnMapper(Dictionary<SqlColumn, SqlColumnRef> map) {
              this.map = map;
          }

          internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
              SqlColumnRef mapped;
              if (this.map.TryGetValue(cref.Column, out mapped)) {
                  return mapped;
              }
              return cref;
          }
      }
  }
  445. }
  446. }