-
Notifications
You must be signed in to change notification settings - Fork 23.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
…1122) Pull Request resolved: #141122 Approved by: https://github.com/Skylion007, https://github.com/jansel
- Loading branch information
1 parent
fb529c2
commit d65f194
Showing
4 changed files
with
137 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
"builtins", | ||
"functools", | ||
"itertools", | ||
"operator", | ||
"os", | ||
"sys", | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
""" | ||
Python polyfills for operator | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import operator | ||
from typing import Any, Callable, overload, TypeVar | ||
from typing_extensions import TypeVarTuple, Unpack | ||
|
||
from ..decorators import substitute_in_graph | ||
|
||
|
||
# Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`) | ||
__all__ = ["attrgetter", "itemgetter"] | ||
|
||
|
||
_T = TypeVar("_T") | ||
_T1 = TypeVar("_T1") | ||
_T2 = TypeVar("_T2") | ||
_Ts = TypeVarTuple("_Ts") | ||
_U = TypeVar("_U") | ||
_U1 = TypeVar("_U1") | ||
_U2 = TypeVar("_U2") | ||
_Us = TypeVarTuple("_Us") | ||
|
||
|
||
@overload | ||
def attrgetter(attr: str, /) -> Callable[[Any], _U]: | ||
... | ||
|
||
|
||
@overload | ||
def attrgetter( | ||
attr1: str, attr2: str, /, *attrs: str | ||
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: | ||
... | ||
|
||
|
||
# Reference: https://docs.python.org/3/library/operator.html#operator.attrgetter | ||
@substitute_in_graph(operator.attrgetter, is_embedded_type=True) # type: ignore[arg-type,misc] | ||
def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]: | ||
if len(attrs) == 0: | ||
raise TypeError("attrgetter expected 1 argument, got 0") | ||
|
||
if any(not isinstance(attr, str) for attr in attrs): | ||
raise TypeError("attribute name must be a string") | ||
|
||
def resolve_attr(obj: Any, attr: str) -> Any: | ||
for name in attr.split("."): | ||
obj = getattr(obj, name) | ||
return obj | ||
|
||
if len(attrs) == 1: | ||
attr = attrs[0] | ||
|
||
def getter(obj: Any) -> Any: | ||
return resolve_attr(obj, attr) | ||
|
||
else: | ||
|
||
def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] | ||
return tuple(resolve_attr(obj, attr) for attr in attrs) | ||
|
||
return getter | ||
|
||
|
||
@overload | ||
def itemgetter(item: _T, /) -> Callable[[Any], _U]: | ||
... | ||
|
||
|
||
@overload | ||
def itemgetter( | ||
item1: _T1, item2: _T2, /, *items: Unpack[_Ts] | ||
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: | ||
... | ||
|
||
|
||
# Reference: https://docs.python.org/3/library/operator.html#operator.itemgetter | ||
@substitute_in_graph(operator.itemgetter, is_embedded_type=True) # type: ignore[arg-type,misc] | ||
def itemgetter(*items: Any) -> Callable[[Any], Any | tuple[Any, ...]]: | ||
if len(items) == 0: | ||
raise TypeError("itemgetter expected 1 argument, got 0") | ||
|
||
if len(items) == 1: | ||
item = items[0] | ||
|
||
def getter(obj: Any) -> Any: | ||
return obj[item] | ||
|
||
else: | ||
|
||
def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] | ||
return tuple(obj[item] for item in items) | ||
|
||
return getter |