SqlUnionizer.cs 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using System.Diagnostics.CodeAnalysis;
  5. namespace System.Data.Linq.SqlClient {
  /// <summary>
  /// Walks a SQL node tree and, for every select whose source is an aliased
  /// union of two selects, re-sorts the column lists of both union branches
  /// by the ordinals assigned during the walk.
  /// </summary>
  internal class SqlUnionizer {
      /// <summary>
      /// Runs the unionizing visitor over <paramref name="node"/>.
      /// </summary>
      /// <param name="node">Root of the tree to process.</param>
      /// <returns>The node produced by visiting <paramref name="node"/>.</returns>
      internal static SqlNode Unionize(SqlNode node) {
          return new Visitor().Visit(node);
      }

      /// <summary>
      /// Visitor that performs the per-select column reordering.
      /// </summary>
      class Visitor : SqlVisitor {
          /// <summary>
          /// Visits <paramref name="select"/> through the base visitor and then,
          /// if its source is an alias over a union whose two sides are both
          /// selects, assigns ordinals to the columns of those two selects and
          /// sorts each column list by ordinal.
          /// </summary>
          /// <param name="select">The select being visited; modified in place.</param>
          /// <returns>The same <paramref name="select"/> instance.</returns>
          internal override SqlSelect VisitSelect(SqlSelect select) {
              // Process nested nodes first; the value returned by the base call is not used.
              base.VisitSelect(select);

              // enforce exact ordering of columns in union selects
              SqlUnion union = GetUnion(select.From);
              if (union == null) {
                  return select;
              }

              SqlSelect left = union.Left as SqlSelect;
              SqlSelect right = union.Right as SqlSelect;
              if (left == null || right == null) {
                  return select;
              }

              int outerCount = select.Row.Columns.Count;

              // preset ordinals to high values (so any unreachable column definition is ordered last)
              PresetOrdinals(left, outerCount);
              PresetOrdinals(right, outerCount);

              // next assign ordinals to all direct columns in subselects
              for (int i = 0; i < outerCount; i++) {
                  SqlExprSet exprSet = select.Row.Columns[i].Expression as SqlExprSet;
                  if (exprSet == null) {
                      continue;
                  }
                  for (int e = 0, en = exprSet.Expressions.Count; e < en; e++) {
                      SqlColumnRef columnRef = exprSet.Expressions[e] as SqlColumnRef;
                      // NOTE(review): this guard compares the index inside the expression set
                      // with the number of outer columns, so the ordinal is only overwritten
                      // when that index is at or beyond the outer column count. That looks
                      // unintended, but the behavior is kept unchanged here - confirm the
                      // intent before altering it.
                      if (columnRef != null && e >= outerCount) {
                          columnRef.Column.Ordinal = i;
                      }
                  }
              }

              // next sort columns in left & right subselects
              // CompareTo is used instead of subtraction so that widely separated
              // ordinals cannot overflow and flip the sign of the comparison.
              Comparison<SqlColumn> byOrdinal = (x, y) => x.Ordinal.CompareTo(y.Ordinal);
              left.Row.Columns.Sort(byOrdinal);
              right.Row.Columns.Sort(byOrdinal);

              return select;
          }

          /// <summary>
          /// Gives every column of <paramref name="subselect"/> the ordinal
          /// <paramref name="firstOrdinal"/> plus its current position in the column list.
          /// </summary>
          /// <param name="subselect">Select whose row columns are renumbered.</param>
          /// <param name="firstOrdinal">Ordinal given to the column at position zero.</param>
          private static void PresetOrdinals(SqlSelect subselect, int firstOrdinal) {
              for (int i = 0, n = subselect.Row.Columns.Count; i < n; i++) {
                  subselect.Row.Columns[i].Ordinal = firstOrdinal + i;
              }
          }

          /// <summary>
          /// Returns the union wrapped by <paramref name="source"/> when the source
          /// is an alias whose node is a union.
          /// </summary>
          /// <param name="source">The source to inspect; may be null.</param>
          /// <returns>The aliased <see cref="SqlUnion"/>, or null when the source is
          /// not an alias or the alias does not wrap a union.</returns>
          private static SqlUnion GetUnion(SqlSource source) {
              SqlAlias alias = source as SqlAlias;
              if (alias == null) {
                  return null;
              }
              // "as" already yields null when the aliased node is not a union.
              return alias.Node as SqlUnion;
          }
      }
  }
  58. }