Skip to content

Commit

Permalink
Don't use equality to narrow when value is IntEnum/StrEnum (#17866)
Browse files Browse the repository at this point in the history
IntEnum/StrEnum values compare equal to the corresponding int/str
values, which breaks the logic we use for narrowing based on equality to
a literal value. Special case IntEnum/StrEnum to avoid the incorrect
behavior.

Fix #17860.
  • Loading branch information
JukkaL authored Oct 2, 2024
1 parent 329e38e commit aa7733a
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
8 changes: 8 additions & 0 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,14 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool
"""
typ = get_proper_type(typ)
if isinstance(typ, Instance):
if (
typ.type.is_enum
and name in ("__eq__", "__ne__")
and any(base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in typ.type.mro)
):
# IntEnum and StrEnum values have non-straightfoward equality, so treat them
# as if they had custom __eq__ and __ne__
return True
method = typ.type.get(name)
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
if method.node.info:
Expand Down
76 changes: 76 additions & 0 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -2130,3 +2130,79 @@ else:

[typing fixtures/typing-medium.pyi]
[builtins fixtures/ops.pyi]

[case testNarrowingWithIntEnum]
# mypy: strict-equality
from __future__ import annotations
from typing import Any
from enum import IntEnum, StrEnum

class IE(IntEnum):
X = 1
Y = 2

def f1(x: int) -> None:
if x == IE.X:
reveal_type(x) # N: Revealed type is "builtins.int"
else:
reveal_type(x) # N: Revealed type is "builtins.int"
if x != IE.X:
reveal_type(x) # N: Revealed type is "builtins.int"
else:
reveal_type(x) # N: Revealed type is "builtins.int"

def f2(x: IE) -> None:
if x == 1:
reveal_type(x) # N: Revealed type is "__main__.IE"
else:
reveal_type(x) # N: Revealed type is "__main__.IE"

def f3(x: object) -> None:
if x == IE.X:
reveal_type(x) # N: Revealed type is "builtins.object"
else:
reveal_type(x) # N: Revealed type is "builtins.object"

def f4(x: int | Any) -> None:
if x == IE.X:
reveal_type(x) # N: Revealed type is "Union[builtins.int, Any]"
else:
reveal_type(x) # N: Revealed type is "Union[builtins.int, Any]"

def f5(x: int) -> None:
if x is IE.X:
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]"
else:
reveal_type(x) # N: Revealed type is "builtins.int"
if x is not IE.X:
reveal_type(x) # N: Revealed type is "builtins.int"
else:
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]"
[builtins fixtures/primitives.pyi]

[case testNarrowingWithStrEnum]
# mypy: strict-equality
from enum import StrEnum

class SE(StrEnum):
A = 'a'
B = 'b'

def f1(x: str) -> None:
if x == SE.A:
reveal_type(x) # N: Revealed type is "builtins.str"
else:
reveal_type(x) # N: Revealed type is "builtins.str"

def f2(x: SE) -> None:
if x == 'a':
reveal_type(x) # N: Revealed type is "__main__.SE"
else:
reveal_type(x) # N: Revealed type is "__main__.SE"

def f3(x: object) -> None:
if x == SE.A:
reveal_type(x) # N: Revealed type is "builtins.object"
else:
reveal_type(x) # N: Revealed type is "builtins.object"
[builtins fixtures/primitives.pyi]

0 comments on commit aa7733a

Please sign in to comment.