-
Notifications
You must be signed in to change notification settings - Fork 763
/
Copy pathpushdown_predicates.py
209 lines (168 loc) · 7.6 KB
/
pushdown_predicates.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify
def pushdown_predicates(expression, dialect=None):
    """
    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: handed unchanged to `pushdown` for every condition that is processed

    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    root = build_scope(expression)
    if not root:
        return expression

    scope_ref_count = root.ref_count()

    # Walk the scopes in the reverse of their traversal order.
    for scope in reversed(list(root.traverse())):
        select = scope.expression
        joins = select.args.get("joins") or []
        where = select.args.get("where")

        if where:
            join_index = {join.alias_or_name: position for position, join in enumerate(joins)}

            # a right join can only push down to itself and not the source FROM table
            where_sources = scope.selected_sources
            for alias, (node, source) in where_sources.items():
                ancestor = node.find_ancestor(exp.Join, exp.From)
                if isinstance(ancestor, exp.Join) and ancestor.side == "RIGHT":
                    where_sources = {alias: (node, source)}
                    break

            pushdown(where.this, where_sources, scope_ref_count, dialect, join_index)

        # joins should only pushdown into itself, not to other joins
        # so we limit the selected sources to only itself
        for join in joins:
            name = join.alias_or_name
            if name not in scope.selected_sources:
                continue

            own_source = {name: scope.selected_sources[name]}
            pushdown(join.args.get("on"), own_source, scope_ref_count, dialect)

    return expression
def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify `condition` in place and push its predicates down into `sources`.

    Args:
        condition: the condition to push down; nothing happens when it is falsy.
        sources: mapping handed on to `pushdown_cnf` / `pushdown_dnf`.
        scope_ref_count: handed on to `pushdown_cnf` / `pushdown_dnf`.
        dialect: passed to `simplify`.
        join_index: optional mapping handed on to `pushdown_cnf` only.
    """
    if not condition:
        return

    # Swap the condition for its simplified form inside the tree and keep working on that.
    simplified = simplify(condition, dialect=dialect)
    condition = condition.replace(simplified)

    # The DNF path is only taken for conditions that are normalized as DNF but not as CNF.
    is_cnf_like = normalized(condition) or not normalized(condition, dnf=True)
    connector = exp.And if is_cnf_like else exp.Or

    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    if is_cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
        return

    pushdown_dnf(predicates, sources, scope_ref_count)
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.

    A predicate that gets pushed is replaced by TRUE at its original position and is
    attached to the target join's ON clause, or to the target select's WHERE / HAVING.
    """
    join_positions = join_index or {}

    for predicate in predicates:
        targets = nodes_for_predicate(predicate, sources, scope_ref_count)

        for node in targets.values():
            if isinstance(node, exp.Join):
                join_name = node.alias_or_name
                other_tables = exp.column_table_names(predicate, join_name)
                position = join_positions[join_name]

                # Don't push the predicate if it references tables that appear in later joins
                blocked = any(
                    join_positions.get(table, -1) >= position for table in other_tables
                )
                if blocked:
                    continue

                predicate.replace(exp.true())
                node.on(predicate, copy=False)
                break

            if not isinstance(node, exp.Select):
                continue

            predicate.replace(exp.true())
            inner_predicate = replace_aliases(node, predicate)

            # Predicates that contain an aggregate function go to HAVING, the rest to WHERE.
            attach = node.having if find_in_scope(inner_predicate, exp.AggFunc) else node.where
            attach(inner_predicate, copy=False)
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    The disjunction of all blocks is attached to a table's node only when every block
    can be pushed to that node. Pushing the OR of just some blocks would be stricter
    than the original filter and could drop rows that the skipped block accepts.

    Args:
        predicates: the OR-ed blocks of the condition.
        sources: mapping handed on to `nodes_for_predicate`.
        scope_ref_count: handed on to `nodes_for_predicate`.

    Returns:
        None. Target joins / selects are modified in place.
    """
    predicates = list(predicates)
    if not predicates:
        return

    # find all the tables that can be pushdown too
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be push down
    pushdown_tables = set.intersection(
        *(set(exp.column_table_names(predicate)) for predicate in predicates)
    )
    if not pushdown_tables:
        return

    # Resolve the candidate nodes of every block once instead of once per table.
    nodes_per_predicate = [
        nodes_for_predicate(predicate, sources, scope_ref_count) for predicate in predicates
    ]

    # pushdown all predicates to their respective nodes
    for table in sorted(pushdown_tables):
        # Every block has to be pushable to this table, otherwise the pushed-down
        # disjunction would be incomplete.
        if not all(table in nodes for nodes in nodes_per_predicate):
            continue

        node = nodes_per_predicate[-1][table]

        condition = predicates[0]
        for predicate in predicates[1:]:
            condition = exp.or_(condition, predicate)

        # Attach exactly once per table.
        if isinstance(node, exp.Join):
            node.on(condition, copy=False)
        elif isinstance(node, exp.Select):
            inner_predicate = replace_aliases(node, condition)
            if find_in_scope(inner_predicate, exp.AggFunc):
                node.having(inner_predicate, copy=False)
            else:
                node.where(inner_predicate, copy=False)
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find, for each table referenced by `predicate`, the node the predicate can be pushed into.

    Args:
        predicate: the predicate whose column tables are inspected.
        sources: mapping of table name to a `(node, source)` pair.
        scope_ref_count: mapping indexed by `id(source)`; a select is only eligible while
            its count is below 2.

    Returns:
        A dict of table name to join / select node. An empty dict is returned as soon as
        a table resolves to a join whose side is set to something other than "RIGHT", or
        to a non-table source whose enclosing WITH is recursive.
    """
    targets = {}
    tables = exp.column_table_names(predicate)
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        if not isinstance(node, exp.Select) or len(tables) != 1:
            continue

        # We can't push down window expressions
        has_window_expression = any(
            projection for projection in node.selects if projection.find(exp.Window)
        )

        if node.args.get("group"):
            continue

        # we can't push down predicates to select statements if they are referenced in
        # multiple places.
        if scope_ref_count[id(source)] >= 2:
            continue

        if has_window_expression:
            continue

        targets[table] = node

    return targets
def replace_aliases(source, predicate):
    """
    Rewrite the columns of `predicate` in terms of the projections of `source`.

    Every column whose name matches a projection of `source` is replaced by a copy of
    that projection (for an aliased projection, a copy of the aliased expression).

    Args:
        source: an expression exposing `selects`.
        predicate: the predicate to transform.

    Returns:
        The result of `predicate.transform` with the substitution applied.
    """
    projections = {}
    for select in source.selects:
        is_alias = isinstance(select, exp.Alias)
        key = select.alias if is_alias else select.name
        projections[key] = select.this if is_alias else select

    def _substitute(node):
        if not isinstance(node, exp.Column):
            return node
        if node.name not in projections:
            return node
        return projections[node.name].copy()

    return predicate.transform(_substitute)