Skip to content

Commit

Permalink
Ensure subqueries used as CASE targets are not aliased.
Browse files Browse the repository at this point in the history
  • Loading branch information
coleifer committed Jun 17, 2024
1 parent 415384d commit 4c4b788
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
15 changes: 13 additions & 2 deletions peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -1827,6 +1827,16 @@ def __sql__(self, ctx):
return ctx.literal(self.window._alias or 'w')


class _InFunction(Node):
def __init__(self, node, in_function=True):
self.node = node
self.in_function = in_function

def __sql__(self, ctx):
with ctx(in_function=self.in_function):
return ctx.sql(self.node)


class Case(ColumnBase):
def __init__(self, predicate, expression_tuples, default=None):
self.predicate = predicate
Expand All @@ -1838,9 +1848,10 @@ def __sql__(self, ctx):
if self.predicate is not None:
clauses.append(self.predicate)
for expr, value in self.expression_tuples:
clauses.extend((SQL('WHEN'), expr, SQL('THEN'), value))
clauses.extend((SQL('WHEN'), expr,
SQL('THEN'), _InFunction(value)))
if self.default is not None:
clauses.extend((SQL('ELSE'), self.default))
clauses.extend((SQL('ELSE'), _InFunction(self.default)))
clauses.append(SQL('END'))
with ctx(in_function=False):
return ctx.sql(NodeList(clauses))
Expand Down
14 changes: 14 additions & 0 deletions tests/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,6 +1802,20 @@ def test_case_subquery(self):
'CASE WHEN ("t1"."id" IN (SELECT "t1"."id" FROM "n" AS "t1")) '
'THEN ? ELSE ? END) FROM "n" AS "t1"'), [1, 0])

case = Case(None, [
(Name.id < 5, Name.select(fn.SUM(Name.id))),
(Name.id > 5, Name.select(fn.COUNT(Name.name)).distinct())],
Name.select(fn.MAX(Name.id)))
q = Name.select(Name.name, case.alias('magic'))
self.assertSQL(q, (
'SELECT "t1"."name", CASE '
'WHEN ("t1"."id" < ?) '
'THEN (SELECT SUM("t1"."id") FROM "n" AS "t1") '
'WHEN ("t1"."id" > ?) '
'THEN (SELECT DISTINCT COUNT("t1"."name") FROM "n" AS "t1") '
'ELSE (SELECT MAX("t1"."id") FROM "n" AS "t1") END AS "magic" '
'FROM "n" AS "t1"'), [5, 5])



class TestSelectFeatures(BaseTestCase):
Expand Down

0 comments on commit 4c4b788

Please sign in to comment.