-
Notifications
You must be signed in to change notification settings - Fork 763
/
Copy pathpushdown_projections.py
141 lines (110 loc) · 5.24 KB
/
pushdown_projections.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from collections import defaultdict
from sqlglot import alias, exp
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
# Sentinel value that means an outer query selects ALL columns of a scope.
# It is only ever tested with `in` (e.g. `SELECT_ALL in parent_selections`), so it
# can live in the same sets as plain column-name strings: a fresh `object()` is
# hashed/compared by identity and can never collide with a real column name.
SELECT_ALL = object()
# Selection to use if selection list is empty
def default_selection(is_agg: bool) -> exp.Alias:
    """
    Build the placeholder projection used when a SELECT would otherwise be empty.

    Args:
        is_agg: whether the SELECT being pruned contained an aggregate. If so the
            placeholder is ``MAX(1)``; otherwise it is the constant ``1``.

    Returns:
        exp.Alias: the placeholder expression, aliased as ``_``.
    """
    if not is_agg:
        return alias("1", "_")
    return alias(exp.Max(this=exp.Literal.number(1)), "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused columns projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: schema (normalized with ``ensure_schema``) used to resolve the
            source table of columns when a pruned star projection has to be
            replaced by explicit columns
        remove_unused_selections (bool): remove selects that are unused

    Returns:
        sqlglot.Expression: the same ``expression`` object, after its scopes
        have been rewritten
    """
    schema = ensure_schema(schema)
    # Number of column aliases declared on a source (``AS t(a, b)``); that many
    # leading projections of the source scope are kept positionally.
    source_column_alias_count = {}
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        # A scope nobody pushed anything into (e.g. the root) keeps every column.
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.SetOperation):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.expression.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.expression.selects):
                if scope.expression.args.get("by_name"):
                    # Branches are aligned by column name, so the right branch
                    # needs exactly the names required from the left branch.
                    referenced_columns[right] = referenced_columns[left]
                else:
                    # Branches are aligned by position: for every left column that
                    # is needed, the right column at the same index is needed too.
                    # Kept as a set like every other entry so membership is O(1),
                    # duplicates collapse, and `.update` keeps working on it.
                    referenced_columns[right] = {
                        right.expression.selects[i].alias_or_name
                        for i, select in enumerate(left.expression.selects)
                        if SELECT_ALL in parent_selections
                        or select.alias_or_name in parent_selections
                    }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            # Sources of a star projection are never given an entry here, so they
            # fall back to {SELECT_ALL} when their own scope is processed.
            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    # With pivots present, push SELECT_ALL so no source column is pruned.
                    columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression
def _remove_unused_selections(scope, parent_selections, schema, alias_count):
    """
    Drop projections of ``scope``'s SELECT that no outer query references.

    The SELECT list is replaced via ``select(..., append=False, copy=False)``.

    Args:
        scope: Scope whose ``expression`` is the SELECT to prune.
        parent_selections: column names referenced by outer queries; if it
            contains ``SELECT_ALL`` every projection is kept.
        schema: schema handed to ``Resolver`` when a dropped star projection is
            replaced by explicit columns.
        alias_count: number of leading projections to keep unconditionally
            (the count of column aliases declared on this source).
    """
    order = scope.expression.args.get("order")

    if order:
        # Assume columns without a qualified table are references to output columns
        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
    else:
        order_refs = set()

    new_selections = []
    removed = False
    star = False
    is_agg = False

    select_all = SELECT_ALL in parent_selections

    for selection in scope.expression.selects:
        name = selection.alias_or_name

        if select_all or name in parent_selections or name in order_refs or alias_count > 0:
            new_selections.append(selection)
            alias_count -= 1
        else:
            if selection.is_star:
                star = True
            removed = True

        # Track aggregates (kept or not) so an emptied SELECT gets the aggregate
        # placeholder from default_selection instead of a plain constant.
        if not is_agg and selection.find(exp.AggFunc):
            is_agg = True

    if star:
        # A star projection was dropped: add back only the columns the parent
        # needs that are not already selected explicitly.
        resolver = Resolver(scope, schema)
        names = {s.alias_or_name for s in new_selections}

        for name in sorted(parent_selections):
            if name not in names:
                # Record the name so a duplicate in parent_selections cannot add
                # the same aliased column twice.
                names.add(name)
                new_selections.append(
                    alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
                )

    # If there are no remaining selections, just select a single constant
    if not new_selections:
        new_selections.append(default_selection(is_agg))

    scope.expression.select(*new_selections, append=False, copy=False)

    # The projection list changed, so cached scope data is stale.
    if removed:
        scope.clear_cache()