[fx] support meta tracing for aten level computation graphs like functorch. (#1536)

* [fx] support meta tracing for aten level computation graphs like functorch.

* [fx] support meta tracing for aten level computation graphs like functorch.

* [fx] remove redundant import.

* [fx] add docstring.
super-dainiu authored Sep 5, 2022
1 parent 521078f commit 7012960
Showing 6 changed files with 107 additions and 7 deletions.
7 changes: 6 additions & 1 deletion colossalai/__init__.py
@@ -1,4 +1,9 @@
+try:
+    from ._meta_registrations import *
+except:
+    import torch
+    print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
 from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch,
                          get_default_parser)
 
-__version__ = '0.0.1'
+__version__ = '0.1.9'
File renamed without changes.
2 changes: 1 addition & 1 deletion colossalai/fx/__init__.py
@@ -1,2 +1,2 @@
-from .tracer import ColoTracer
+from .tracer import ColoTracer, meta_trace
 from .graph_module import ColoGraphModule
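With this re-export, `meta_trace` can be imported directly from `colossalai.fx`, next to `ColoTracer`. A minimal usage sketch, mirroring the docstring added in `_meta_trace.py` below (the torchvision model and the input shape are illustrative only):

    >>> import torch
    >>> import torchvision.models as tm
    >>> from colossalai.fx import meta_trace
    >>> model = tm.alexnet()
    >>> # every aten op hit during forward and backward becomes a node in the graph
    >>> graph = meta_trace(model, torch.rand(1000, 3, 224, 224))
    >>> graph.print_tabular()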
5 changes: 0 additions & 5 deletions colossalai/fx/profiler/__init__.py
@@ -1,8 +1,3 @@
-try:
-    from ._meta_registrations import *
-except:
-    import torch
-    print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
 from .meta_tensor import MetaTensor
 from .registry import meta_profiler_function, meta_profiler_module
 from .profiler_function import *
1 change: 1 addition & 0 deletions colossalai/fx/tracer/__init__.py
@@ -1 +1,2 @@
 from .tracer import ColoTracer
+from ._meta_trace import meta_trace
99 changes: 99 additions & 0 deletions colossalai/fx/tracer/_meta_trace.py
@@ -0,0 +1,99 @@
import torch
from torch.fx import Node, Graph
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map


def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
    """Trace forward and backward graph with MetaTensor
    Args:
        module (torch.nn.Module): The target module for tracing.
    Returns:
        graph (torch.fx.Graph): The computation graph.
    Usage:
        >>> import torchvision.models as tm
        >>> model = tm.alexnet()
        >>> graph = meta_trace(model, torch.rand(1000, 3, 224, 224))
        >>> graph.print_tabular()
    """
    graph = Graph()
    namespace = _Namespace()

    class MetaProxy(torch.Tensor):
        """
        A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
        """

        _tensor: torch.Tensor
        _node: Node

        __slots__ = ['_tensor', '_node']

        @staticmethod
        def __new__(cls, tensor, placeholder=False, name=None):
            r = torch.Tensor._make_wrapper_subclass(
                cls,
                tensor.size(),
                strides=tensor.stride(),
                storage_offset=tensor.storage_offset(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                device='cpu',
                requires_grad=tensor.requires_grad)    # deceive the frontend for aten selections
            r._tensor = tensor
            if placeholder:
                if name is None:
                    name = 'input'
                r._node = graph.create_node('placeholder',
                                            'placeholder', (graph._root,),
                                            name=namespace.create_name(name, tensor))
            # ...the real tensor is held as an element on the tensor.
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

            def unwrap(x):
                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
                    x = MetaProxy(x)
                return x._tensor.to('meta') if isinstance(x, MetaProxy) else x

            def get_node(x):
                if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
                    x = MetaProxy(x, placeholder=True, name='weight')
                return x if not hasattr(x, '_node') else x._node

            args_node = tree_map(get_node, args)
            kwargs_node = tree_map(get_node, kwargs)
            node = graph.create_node('call_function', func, args_node, kwargs_node)

            args = tree_map(unwrap, args)
            kwargs = tree_map(unwrap, kwargs)

            # run aten for backend=CPU but actually on backend=Meta
            out = func(*args, **kwargs)

            # Now, we want to continue propagating this tensor, so we rewrap Tensors in
            # our custom tensor subclass
            def wrap(x):
                return MetaProxy(x) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x

            def set_node(x):
                x._node = node

            out = tree_map(wrap, out)
            tree_map(set_node, out)

            return out

    def wrap(x):
        return MetaProxy(x, True) if isinstance(x, torch.Tensor) else x

    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

    module(*args, **kwargs).sum().backward()
    return graph
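Because `__torch_dispatch__` records every dispatched aten call as a `call_function` node, and `meta_trace` ends with `.sum().backward()`, the returned graph covers both the forward and the gradient computation. A small sketch under those assumptions (the `nn.Linear` toy module and input shape are arbitrary, and the exact aten ops printed depend on the PyTorch version):

    import torch
    from colossalai.fx import meta_trace

    model = torch.nn.Linear(4, 2)
    graph = meta_trace(model, torch.rand(8, 4))
    # placeholder nodes for the input and parameters, then aten-level call_function nodes
    for node in graph.nodes:
        print(node.op, node.target)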
