From d65f194ab91055d1a104163dce93d5580fec2d28 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 21 Nov 2024 00:58:28 +0800 Subject: [PATCH] [dynamo] support `operator.attrgetter` and `operator.itemgetter` (#141122) Pull Request resolved: https://github.com/pytorch/pytorch/pull/141122 Approved by: https://github.com/Skylion007, https://github.com/jansel --- test/dynamo/test_functions.py | 38 +++++++++++ torch/_dynamo/polyfills/__init__.py | 1 + torch/_dynamo/polyfills/loader.py | 1 + torch/_dynamo/polyfills/operator.py | 97 +++++++++++++++++++++++++++++ 4 files changed, 137 insertions(+) create mode 100644 torch/_dynamo/polyfills/operator.py diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 54342ac50e8914..984c5fc1d853ed 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -3071,6 +3071,44 @@ def fn(): opt_fn = torch.compile(fn, fullgraph=True) self.assertEqual(opt_fn(), fn()) + def test_attrgetter(self): + for attrs in ( + ("shape",), + ("data.shape",), + ("device", "shape"), + ("device", "shape", "data.shape"), + ): + with self.subTest(attrs=attrs): + + def fn(x, y): + getter = operator.attrgetter(*attrs) + return getter(x), getter(y) + + opt_fn = torch.compile(fullgraph=True)(fn) + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + self.assertEqual(opt_fn(x, y), fn(x, y)) + + def test_itemgetter(self): + for items in ( + (0,), + (slice(1, 3),), + (0, 1), + (slice(1, 3), 0, 1), + ): + with self.subTest(items=items): + + def fn(x, y): + getter = operator.itemgetter(*items) + return getter(x), getter(y) + + opt_fn = torch.compile(fullgraph=True)(fn) + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + self.assertEqual(opt_fn(x, y), fn(x, y)) + def gen_random_range_args(self): args_count = random.randint(1, 3) args = [random.randint(-10, 10) for _ in range(args_count)] diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 5abd52c17640e6..fc30d5a759ed6d 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -21,6 +21,7 @@ builtins as builtins, functools as functools, itertools as itertools, + operator as operator, os as os, sys as sys, ) diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index 24478e1b5a0f96..c67a5d907cfeb9 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -16,6 +16,7 @@ "builtins", "functools", "itertools", + "operator", "os", "sys", ) diff --git a/torch/_dynamo/polyfills/operator.py b/torch/_dynamo/polyfills/operator.py new file mode 100644 index 00000000000000..bf84895bdd013d --- /dev/null +++ b/torch/_dynamo/polyfills/operator.py @@ -0,0 +1,97 @@ +""" +Python polyfills for operator +""" + +from __future__ import annotations + +import operator +from typing import Any, Callable, overload, TypeVar +from typing_extensions import TypeVarTuple, Unpack + +from ..decorators import substitute_in_graph + + +# Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`) +__all__ = ["attrgetter", "itemgetter"] + + +_T = TypeVar("_T") +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") +_Ts = TypeVarTuple("_Ts") +_U = TypeVar("_U") +_U1 = TypeVar("_U1") +_U2 = TypeVar("_U2") +_Us = TypeVarTuple("_Us") + + +@overload +def attrgetter(attr: str, /) -> Callable[[Any], _U]: + ... + + +@overload +def attrgetter( + attr1: str, attr2: str, /, *attrs: str +) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: + ... + + +# Reference: https://docs.python.org/3/library/operator.html#operator.attrgetter +@substitute_in_graph(operator.attrgetter, is_embedded_type=True) # type: ignore[arg-type,misc] +def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]: + if len(attrs) == 0: + raise TypeError("attrgetter expected 1 argument, got 0") + + if any(not isinstance(attr, str) for attr in attrs): + raise TypeError("attribute name must be a string") + + def resolve_attr(obj: Any, attr: str) -> Any: + for name in attr.split("."): + obj = getattr(obj, name) + return obj + + if len(attrs) == 1: + attr = attrs[0] + + def getter(obj: Any) -> Any: + return resolve_attr(obj, attr) + + else: + + def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] + return tuple(resolve_attr(obj, attr) for attr in attrs) + + return getter + + +@overload +def itemgetter(item: _T, /) -> Callable[[Any], _U]: + ... + + +@overload +def itemgetter( + item1: _T1, item2: _T2, /, *items: Unpack[_Ts] +) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: + ... + + +# Reference: https://docs.python.org/3/library/operator.html#operator.itemgetter +@substitute_in_graph(operator.itemgetter, is_embedded_type=True) # type: ignore[arg-type,misc] +def itemgetter(*items: Any) -> Callable[[Any], Any | tuple[Any, ...]]: + if len(items) == 0: + raise TypeError("itemgetter expected 1 argument, got 0") + + if len(items) == 1: + item = items[0] + + def getter(obj: Any) -> Any: + return obj[item] + + else: + + def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] + return tuple(obj[item] for item in items) + + return getter