SqlResolver.cs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Data.Linq;
  namespace System.Data.Linq.SqlClient {

      // Resolves references to columns/expressions defined in other scopes.
      //
      // Every SqlColumnRef in the tree is looked up along the chain of scopes visible at
      // the reference site. When the referenced column is defined inside a nested
      // projection, SqlBubbler re-projects it through each intermediate SELECT row and the
      // reference is rewritten to point at the column that is visible from the reference.
      internal class SqlResolver {
          // A single visitor is created once and reused by every call to Resolve.
          readonly Visitor visitor;

          internal SqlResolver() {
              this.visitor = new Visitor();
          }

          // Rewrites the column references in 'node' and returns the result of the visit.
          internal SqlNode Resolve(SqlNode node) {
              return this.visitor.Visit(node);
          }

          // Name of a column for use in error messages (Text in DEBUG builds, Name otherwise).
          private static string GetColumnName(SqlColumn c) {
#if DEBUG
              return c.Text;
#else
              return c.Name;
#endif
          }

          // Replaces each column reference with one that is reachable from the current scope.
          class Visitor : SqlScopedVisitor {
              readonly SqlBubbler bubbler;

              internal Visitor() {
                  this.bubbler = new SqlBubbler();
              }

              // Throws ColumnReferencedIsNotInScope when no visible scope defines the column.
              internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
                  SqlColumnRef result = this.BubbleUp(cref);
                  if (result == null) {
                      throw Error.ColumnReferencedIsNotInScope(GetColumnName(cref.Column));
                  }
                  return result;
              }

              // Walks outward from the current scope and searches each non-null source.
              // The first hit wins. The original reference is kept when the column found is
              // the one it already points at; otherwise a new reference to the found column
              // is returned. Returns null when no scope defines the column.
              private SqlColumnRef BubbleUp(SqlColumnRef cref) {
                  for (Scope s = this.CurrentScope; s != null; s = s.ContainingScope) {
                      if (s.Source == null) {
                          continue;
                      }
                      SqlColumn found = this.bubbler.BubbleUp(cref.Column, s.Source);
                      if (found != null) {
                          return found != cref.Column ? new SqlColumnRef(found) : cref;
                      }
                  }
                  return null;
              }
          }

          // Visitor that tracks which source nodes are visible at each point of the walk.
          // CurrentScope.Source is searched first, then each ContainingScope in turn.
          // Every override that changes CurrentScope restores it in a finally block. This
          // ensures an exception thrown mid-walk (e.g. an unresolvable column) does not
          // leave the visitor stuck in a nested scope for the next visit.
          internal class SqlScopedVisitor : SqlVisitor {
              internal Scope CurrentScope;

              // One link of the scope chain. Source is null for a scope that only marks a
              // nesting level; Containing is the enclosing scope (null at the root).
              internal class Scope {
                  readonly SqlNode source;
                  readonly Scope containing;

                  internal Scope(SqlNode source, Scope containing) {
                      this.source = source;
                      this.containing = containing;
                  }

                  internal SqlNode Source {
                      get { return this.source; }
                  }

                  internal Scope ContainingScope {
                      get { return this.containing; }
                  }
              }

              internal SqlScopedVisitor() {
                  this.CurrentScope = new Scope(null, null);
              }

              // A sub-select opens a new nesting level whose parent is the current scope.
              internal override SqlExpression VisitSubSelect(SqlSubSelect ss) {
                  Scope save = this.CurrentScope;
                  this.CurrentScope = new Scope(null, this.CurrentScope);
                  try {
                      base.VisitSubSelect(ss);
                  }
                  finally {
                      this.CurrentScope = save;
                  }
                  return ss;
              }

              internal override SqlSelect VisitSelect(SqlSelect select) {
                  // FROM is visited before this select's own sources are put in scope
                  select.From = (SqlSource)this.Visit(select.From);

                  Scope save = this.CurrentScope;
                  try {
                      // the remaining clauses see FROM in place of the current scope's source
                      this.CurrentScope = new Scope(select.From, this.CurrentScope.ContainingScope);
                      select.Where = this.VisitExpression(select.Where);
                      for (int i = 0, n = select.GroupBy.Count; i < n; i++) {
                          select.GroupBy[i] = this.VisitExpression(select.GroupBy[i]);
                      }
                      select.Having = this.VisitExpression(select.Having);
                      for (int i = 0, n = select.OrderBy.Count; i < n; i++) {
                          select.OrderBy[i].Expression = this.VisitExpression(select.OrderBy[i].Expression);
                      }
                      select.Top = this.VisitExpression(select.Top);
                      select.Row = (SqlRow)this.Visit(select.Row);

                      // selection must be able to see its own projection
                      this.CurrentScope = new Scope(select, this.CurrentScope.ContainingScope);
                      select.Selection = this.VisitExpression(select.Selection);
                  }
                  finally {
                      this.CurrentScope = save;
                  }
                  return select;
              }

              internal override SqlStatement VisitInsert(SqlInsert sin) {
                  Scope save = this.CurrentScope;
                  this.CurrentScope = new Scope(sin, this.CurrentScope.ContainingScope);
                  try {
                      base.VisitInsert(sin);
                  }
                  finally {
                      this.CurrentScope = save;
                  }
                  return sin;
              }

              // The update's Select (not the statement itself) is the searchable source.
              internal override SqlStatement VisitUpdate(SqlUpdate sup) {
                  Scope save = this.CurrentScope;
                  this.CurrentScope = new Scope(sup.Select, this.CurrentScope.ContainingScope);
                  try {
                      base.VisitUpdate(sup);
                  }
                  finally {
                      this.CurrentScope = save;
                  }
                  return sup;
              }

              internal override SqlStatement VisitDelete(SqlDelete sd) {
                  Scope save = this.CurrentScope;
                  this.CurrentScope = new Scope(sd, this.CurrentScope.ContainingScope);
                  try {
                      base.VisitDelete(sd);
                  }
                  finally {
                      this.CurrentScope = save;
                  }
                  return sd;
              }

              // Visit results are assigned back (as in VisitSelect) so that a rewritten
              // expression, e.g. a condition that is itself a column reference, is kept.
              internal override SqlSource VisitJoin(SqlJoin join) {
                  Scope save = this.CurrentScope;
                  try {
                      switch (join.JoinType) {
                          case SqlJoinType.CrossApply:
                          case SqlJoinType.OuterApply: {
                              // APPLY: the right side is visited with the left side in scope
                              join.Left = (SqlSource)this.Visit(join.Left);
                              Scope leftScope = new Scope(join.Left, this.CurrentScope.ContainingScope);
                              this.CurrentScope = new Scope(null, leftScope);
                              join.Right = (SqlSource)this.Visit(join.Right);
                              Scope rightScope = new Scope(join.Right, leftScope);
                              this.CurrentScope = new Scope(null, rightScope);
                              join.Condition = this.VisitExpression(join.Condition);
                              break;
                          }
                          default: {
                              // ordinary join: neither side sees the other; the condition sees both
                              join.Left = (SqlSource)this.Visit(join.Left);
                              join.Right = (SqlSource)this.Visit(join.Right);
                              this.CurrentScope = new Scope(null, new Scope(join.Right, new Scope(join.Left, this.CurrentScope.ContainingScope)));
                              join.Condition = this.VisitExpression(join.Condition);
                              break;
                          }
                      }
                  }
                  finally {
                      this.CurrentScope = save;
                  }
                  return join;
              }
          }

          // finds location of expression definition and re-projects that value all the
          // way to the outermost projection
          internal class SqlBubbler : SqlVisitor {
              SqlColumn match;   // originating column being searched for
              SqlColumn found;   // column at the current level that exposes 'match', or null

              internal SqlBubbler() {
              }

              // Searches 'source' for the column that 'col' originates from. Returns the column
              // it was found (and re-projected) through, or null when 'source' does not define it.
              internal SqlColumn BubbleUp(SqlColumn col, SqlNode source) {
                  this.match = this.GetOriginatingColumn(col);
                  this.found = null;
                  this.Visit(source);
                  return this.found;
              }

              // Follows columns whose expression is a plain column reference back to the
              // first column that is not such an alias.
              internal SqlColumn GetOriginatingColumn(SqlColumn col) {
                  SqlColumn current = col;
                  SqlColumnRef cref = current.Expression as SqlColumnRef;
                  while (cref != null) {
                      current = cref.Column;
                      cref = current.Expression as SqlColumnRef;
                  }
                  return current;
              }

              // Records 'c' as the definition of 'match'; a second definition is an error.
              private void RecordFound(SqlColumn c) {
                  if (this.found != null) {
                      throw Error.ColumnIsDefinedInMultiplePlaces(GetColumnName(this.match));
                  }
                  this.found = c;
              }

              internal override SqlRow VisitRow(SqlRow row) {
                  foreach (SqlColumn c in row.Columns) {
                      if (this.RefersToColumn(c, this.match)) {
                          this.RecordFound(c);
                          break;
                      }
                  }
                  return row;
              }

              internal override SqlTable VisitTable(SqlTable tab) {
                  foreach (SqlColumn c in tab.Columns) {
                      if (c == this.match) {
                          this.RecordFound(c);
                          break;
                      }
                  }
                  return tab;
              }

              // For APPLY joins the right side is searched only when the left side missed.
              internal override SqlSource VisitJoin(SqlJoin join) {
                  switch (join.JoinType) {
                      case SqlJoinType.CrossApply:
                      case SqlJoinType.OuterApply: {
                          this.Visit(join.Left);
                          if (this.found == null) {
                              this.Visit(join.Right);
                          }
                          break;
                      }
                      default: {
                          this.Visit(join.Left);
                          this.Visit(join.Right);
                          break;
                      }
                  }
                  return join;
              }

              internal override SqlExpression VisitTableValuedFunctionCall(SqlTableValuedFunctionCall fc) {
                  foreach (SqlColumn c in fc.Columns) {
                      if (c == this.match) {
                          this.RecordFound(c);
                          break;
                      }
                  }
                  return fc;
              }

              // Makes 'found' visible in 'row'. It reuses an existing column of the row that
              // refers to it, or else appends a new column referencing it under 'name'.
              // Either way 'found' is updated to the column in 'row'.
              private void ForceLocal(SqlRow row, string name) {
                  // check to see if it already exists locally
                  foreach (SqlColumn c in row.Columns) {
                      if (this.RefersToColumn(c, this.found)) {
                          this.found = c;
                          return;
                      }
                  }
                  // need to put this in the local projection list to bubble it up
                  SqlColumn local = new SqlColumn(this.found.ClrType, this.found.SqlType, name, this.found.MetaMember, new SqlColumnRef(this.found), row.SourceExpression);
                  row.Columns.Add(local);
                  this.found = local;
              }

              // does the column happen to be listed in the group-by clause?
              private bool IsFoundInGroup(SqlSelect select) {
                  foreach (SqlExpression exp in select.GroupBy) {
                      if (this.RefersToColumn(exp, this.found) || this.RefersToColumn(exp, this.match)) {
                          return true;
                      }
                  }
                  return false;
              }

              internal override SqlSelect VisitSelect(SqlSelect select) {
                  // look in this projection
                  this.Visit(select.Row);
                  if (this.found != null) {
                      return select;
                  }

                  // look in upstream projections
                  this.Visit(select.From);
                  if (this.found == null) {
                      return select;
                  }

                  // bubble it up through this select's projection
                  if (select.IsDistinct && !this.match.IsConstantColumn) {
                      throw Error.ColumnIsNotAccessibleThroughDistinct(GetColumnName(this.match));
                  }
                  if (select.GroupBy.Count != 0 && !this.IsFoundInGroup(select)) {
                      // found it, but it is hidden behind the group-by
                      throw Error.ColumnIsNotAccessibleThroughGroupBy(GetColumnName(this.match));
                  }
                  this.ForceLocal(select.Row, this.found.Name);
                  return select;
              }
          }
      }
  }