Commit
add allow_ops_in_compiled_graph, and corresponding test
arogozhnikov committed Apr 17, 2023
1 parent 03f2806 commit 32fa76e
Showing 2 changed files with 65 additions and 2 deletions.
20 changes: 19 additions & 1 deletion einops/_torch_specific.py
@@ -7,7 +7,7 @@
Importantly, whole lib is designed so that you can't use it
"""

import warnings
from typing import Dict, List

import torch
@@ -82,3 +82,21 @@ def apply_for_scriptable_torch(recipe: TransformRecipe, tensor: torch.Tensor, re
    if len(added_axes) > 0:
        tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes)
    return backend.reshape(tensor, final_shapes)


def allow_ops_in_compiled_graph():
    try:
        from torch._dynamo import allow_in_graph
    except ImportError:
        warnings.warn("allow_ops_in_compiled_graph failed to import torch._dynamo: ensure pytorch >=2.0", ImportWarning)
        return

    from .einops import rearrange, reduce, repeat, einsum
    from .packing import pack, unpack

    allow_in_graph(rearrange)
    allow_in_graph(reduce)
    allow_in_graph(repeat)
    allow_in_graph(einsum)
    allow_in_graph(pack)
    allow_in_graph(unpack)
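
For context, a minimal usage sketch (not part of this commit; PyTorch >= 2.0 is assumed and the function name flatten is just illustrative). Calling allow_ops_in_compiled_graph() once lets torch.compile keep einops calls inside a single compiled graph:

import torch
from einops import rearrange
from einops._torch_specific import allow_ops_in_compiled_graph

allow_ops_in_compiled_graph()  # register einops operations with torch._dynamo

@torch.compile(fullgraph=True)  # fullgraph=True raises if the graph would break
def flatten(x):
    return rearrange(x, 'b c h w -> b (c h w)')

print(flatten(torch.zeros([2, 3, 4, 5])).shape)  # torch.Size([2, 60])
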
47 changes: 46 additions & 1 deletion tests/test_other.py
@@ -5,6 +5,7 @@

import numpy
import pytest
import torch

from parameterized import parameterized, parameterized_class

@@ -13,7 +14,7 @@
import einops.parsing
from einops._backends import AbstractBackend
from einops.einops import rearrange, parse_shape, _optimize_transformation
from . import collect_test_backends
from . import collect_test_backends, is_backend_tested

__author__ = "Alex Rogozhnikov"

@@ -226,3 +227,47 @@ def test_is_float_type():
            if "chainer" in backend.framework_name and not is_float:
                continue  # chainer doesn't allow non-floating tensors
            assert backend.is_float_type(input) == is_float, (dtype, backend, input.dtype)


def test_torch_compile():
    """
    Ensures that allow_ops_in_compiled_graph allows compiling einops operations into a single graph.
    Additionally we ensure that the compilation cache works properly
    (verified by changing shapes and patterns).
    We also check that pack/unpack can still be handled despite the variable number of inputs/outputs.
    """
    if not is_backend_tested('torch'):
        pytest.skip()
    from torch import nn
    from einops import repeat, reduce, pack, unpack, einsum
    from einops._torch_specific import allow_ops_in_compiled_graph

    allow_ops_in_compiled_graph()

    class TestClass(nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x_abc, suffix=''):
            a, b, c = x_abc.shape

            def suf(pattern):
                parts = pattern.split()
                return ' '.join([p if p[-1] not in 'acd' else p + suffix for p in parts])

            x_abcd = repeat(x_abc, suf('a b c -> a b c 4'))
            x_abc = reduce(x_abcd, suf('a b c d -> a b c'), 'min')
            x_abdc, ps = pack([x_abc] * (2 + len(suffix)), suf('a b * c'))
            x_array = unpack(rearrange(x_abdc, suf('a b d c -> (a b ) 1 c d')), ps, 'ab one1 c *')
            x1 = x_array[0] + len(x_array)
            x1 = rearrange(x1, suf('(a b ) 1 c -> a b c'), b=b)
            addition = einsum(x_abc, x_abcd, suf('a b c , a b c d -> d'))[0]
            return x1 + addition

    original = TestClass()
    compiled = torch.compile(original, fullgraph=True)
    for size in [10, 20, 40]:
        x = torch.zeros([size, size + 1, size + 2])
        for suffix in ['', 'suf1', 'other_suffix']:
            result1 = original(x, suffix)
            result2 = compiled(x, suffix)
            assert torch.allclose(result1, result2)
