Skip to content

Commit

Permalink
[dynamo] support operator.attrgetter and operator.itemgetter (#14…
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored and pytorchmergebot committed Nov 21, 2024
1 parent fb529c2 commit d65f194
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 0 deletions.
38 changes: 38 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3071,6 +3071,44 @@ def fn():
opt_fn = torch.compile(fn, fullgraph=True)
self.assertEqual(opt_fn(), fn())

def test_attrgetter(self):
for attrs in (
("shape",),
("data.shape",),
("device", "shape"),
("device", "shape", "data.shape"),
):
with self.subTest(attrs=attrs):

def fn(x, y):
getter = operator.attrgetter(*attrs)
return getter(x), getter(y)

opt_fn = torch.compile(fullgraph=True)(fn)

x = torch.randn(3, 4)
y = torch.randn(3, 4)
self.assertEqual(opt_fn(x, y), fn(x, y))

def test_itemgetter(self):
for items in (
(0,),
(slice(1, 3),),
(0, 1),
(slice(1, 3), 0, 1),
):
with self.subTest(items=items):

def fn(x, y):
getter = operator.itemgetter(*items)
return getter(x), getter(y)

opt_fn = torch.compile(fullgraph=True)(fn)

x = torch.randn(3, 4)
y = torch.randn(3, 4)
self.assertEqual(opt_fn(x, y), fn(x, y))

def gen_random_range_args(self):
args_count = random.randint(1, 3)
args = [random.randint(-10, 10) for _ in range(args_count)]
Expand Down
1 change: 1 addition & 0 deletions torch/_dynamo/polyfills/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
builtins as builtins,
functools as functools,
itertools as itertools,
operator as operator,
os as os,
sys as sys,
)
Expand Down
1 change: 1 addition & 0 deletions torch/_dynamo/polyfills/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"builtins",
"functools",
"itertools",
"operator",
"os",
"sys",
)
Expand Down
97 changes: 97 additions & 0 deletions torch/_dynamo/polyfills/operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
Python polyfills for operator
"""

from __future__ import annotations

import operator
from typing import Any, Callable, overload, TypeVar
from typing_extensions import TypeVarTuple, Unpack

from ..decorators import substitute_in_graph


# Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`)
__all__ = ["attrgetter", "itemgetter"]


_T = TypeVar("_T")
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_Ts = TypeVarTuple("_Ts")
_U = TypeVar("_U")
_U1 = TypeVar("_U1")
_U2 = TypeVar("_U2")
_Us = TypeVarTuple("_Us")


@overload
def attrgetter(attr: str, /) -> Callable[[Any], _U]:
...


@overload
def attrgetter(
attr1: str, attr2: str, /, *attrs: str
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]:
...


# Reference: https://docs.python.org/3/library/operator.html#operator.attrgetter
@substitute_in_graph(operator.attrgetter, is_embedded_type=True) # type: ignore[arg-type,misc]
def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]:
if len(attrs) == 0:
raise TypeError("attrgetter expected 1 argument, got 0")

if any(not isinstance(attr, str) for attr in attrs):
raise TypeError("attribute name must be a string")

def resolve_attr(obj: Any, attr: str) -> Any:
for name in attr.split("."):
obj = getattr(obj, name)
return obj

if len(attrs) == 1:
attr = attrs[0]

def getter(obj: Any) -> Any:
return resolve_attr(obj, attr)

else:

def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc]
return tuple(resolve_attr(obj, attr) for attr in attrs)

return getter


@overload
def itemgetter(item: _T, /) -> Callable[[Any], _U]:
...


@overload
def itemgetter(
item1: _T1, item2: _T2, /, *items: Unpack[_Ts]
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]:
...


# Reference: https://docs.python.org/3/library/operator.html#operator.itemgetter
@substitute_in_graph(operator.itemgetter, is_embedded_type=True) # type: ignore[arg-type,misc]
def itemgetter(*items: Any) -> Callable[[Any], Any | tuple[Any, ...]]:
if len(items) == 0:
raise TypeError("itemgetter expected 1 argument, got 0")

if len(items) == 1:
item = items[0]

def getter(obj: Any) -> Any:
return obj[item]

else:

def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc]
return tuple(obj[item] for item in items)

return getter

0 comments on commit d65f194

Please sign in to comment.