diff --git a/mypy/typeops.py b/mypy/typeops.py index 7f530d13d4e2..221976ce02b3 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -1022,6 +1022,14 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool """ typ = get_proper_type(typ) if isinstance(typ, Instance): + if ( + typ.type.is_enum + and name in ("__eq__", "__ne__") + and any(base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in typ.type.mro) + ): + # IntEnum and StrEnum values have non-straightfoward equality, so treat them + # as if they had custom __eq__ and __ne__ + return True method = typ.type.get(name) if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)): if method.node.info: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 21a1523580c2..169523e61be7 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2130,3 +2130,79 @@ else: [typing fixtures/typing-medium.pyi] [builtins fixtures/ops.pyi] + +[case testNarrowingWithIntEnum] +# mypy: strict-equality +from __future__ import annotations +from typing import Any +from enum import IntEnum, StrEnum + +class IE(IntEnum): + X = 1 + Y = 2 + +def f1(x: int) -> None: + if x == IE.X: + reveal_type(x) # N: Revealed type is "builtins.int" + else: + reveal_type(x) # N: Revealed type is "builtins.int" + if x != IE.X: + reveal_type(x) # N: Revealed type is "builtins.int" + else: + reveal_type(x) # N: Revealed type is "builtins.int" + +def f2(x: IE) -> None: + if x == 1: + reveal_type(x) # N: Revealed type is "__main__.IE" + else: + reveal_type(x) # N: Revealed type is "__main__.IE" + +def f3(x: object) -> None: + if x == IE.X: + reveal_type(x) # N: Revealed type is "builtins.object" + else: + reveal_type(x) # N: Revealed type is "builtins.object" + +def f4(x: int | Any) -> None: + if x == IE.X: + reveal_type(x) # N: Revealed type is "Union[builtins.int, Any]" + else: + reveal_type(x) # N: Revealed type is "Union[builtins.int, Any]" + +def f5(x: int) -> None: + if x is IE.X: + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]" + else: + reveal_type(x) # N: Revealed type is "builtins.int" + if x is not IE.X: + reveal_type(x) # N: Revealed type is "builtins.int" + else: + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]" +[builtins fixtures/primitives.pyi] + +[case testNarrowingWithStrEnum] +# mypy: strict-equality +from enum import StrEnum + +class SE(StrEnum): + A = 'a' + B = 'b' + +def f1(x: str) -> None: + if x == SE.A: + reveal_type(x) # N: Revealed type is "builtins.str" + else: + reveal_type(x) # N: Revealed type is "builtins.str" + +def f2(x: SE) -> None: + if x == 'a': + reveal_type(x) # N: Revealed type is "__main__.SE" + else: + reveal_type(x) # N: Revealed type is "__main__.SE" + +def f3(x: object) -> None: + if x == SE.A: + reveal_type(x) # N: Revealed type is "builtins.object" + else: + reveal_type(x) # N: Revealed type is "builtins.object" +[builtins fixtures/primitives.pyi]