[fx] support meta tracing for aten level computation graphs like functorch. #1536

Merged · 5 commits · Sep 5, 2022
Changes from 4 commits
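For context on the "like functorch" part of the title: PyTorch's `make_fx` (available from `torch.fx.experimental.proxy_tensor` in the PyTorch 1.12 era) also lowers a Python callable to an aten-level `torch.fx` graph by running it under proxy tensors. The sketch below is purely illustrative and not part of this PR; the `meta_trace` added here differs in that it records both the forward and the backward aten calls of an `nn.Module` without materializing real data.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def fwd(x):
    # A toy function standing in for a model's forward pass.
    return torch.nn.functional.linear(x, torch.ones(8, 16)).sum()

# make_fx runs `fwd` and records every aten op it dispatches.
gm = make_fx(fwd)(torch.rand(4, 16))
gm.graph.print_tabular()   # the resulting GraphModule's nodes are aten-level ops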
7 changes: 6 additions & 1 deletion colossalai/__init__.py
@@ -1,4 +1,9 @@
+try:
+    from ._meta_registrations import *
+except:
+    import torch
+    print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
 from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch,
                          get_default_parser)
 
-__version__ = '0.0.1'
+__version__ = '0.1.9'
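For context, `_meta_registrations` fills in meta-device ("shape only") kernels for aten ops that the installed PyTorch does not yet cover, which is why the import is wrapped in a try/except: the registrations can clash with an older or newer PyTorch. Below is a minimal sketch of this style of registration, assuming the `torch.library.Library` API; the op and kernel are illustrative and not taken from this PR (registering an op that already has a Meta kernel would be rejected).

import torch
from torch.library import Library

# Register kernels on the Meta dispatch key of the aten namespace.
meta_lib = Library("aten", "IMPL", "Meta")

def meta_relu(self: torch.Tensor) -> torch.Tensor:
    # A meta kernel only propagates shape/dtype metadata; it never touches real data.
    return torch.empty_like(self, device="meta")

meta_lib.impl("relu", meta_relu)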
2 changes: 1 addition & 1 deletion colossalai/fx/__init__.py
@@ -1,2 +1,2 @@
-from .tracer import ColoTracer
+from .tracer import ColoTracer, meta_trace
 from .graph_module import ColoGraphModule
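With this re-export, aten-level tracing sits next to the existing module-level tracer in the public API; a hypothetical import looks like:

from colossalai.fx import ColoGraphModule, ColoTracer, meta_trace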
5 changes: 0 additions & 5 deletions colossalai/fx/profiler/__init__.py
@@ -1,8 +1,3 @@
-try:
-    from ._meta_registrations import *
-except:
-    import torch
-    print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
 from .meta_tensor import MetaTensor
 from .registry import meta_profiler_function, meta_profiler_module
 from .profiler_function import *
1 change: 1 addition & 0 deletions colossalai/fx/tracer/__init__.py
@@ -1 +1,2 @@
 from .tracer import ColoTracer
+from ._meta_trace import meta_trace
93 changes: 93 additions & 0 deletions colossalai/fx/tracer/_meta_trace.py
@@ -0,0 +1,93 @@
import torch
from torch.fx import Node, Graph
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map


def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
"""Trace forward and backward graph with MetaTensor

Args:
module (torch.nn.Module): The target module for tracing.

Returns:
graph (torch.fx.Graph): The computation graph.
"""
graph = Graph()
namespace = _Namespace()

class MetaProxy(torch.Tensor):
"""
A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
"""

_tensor: torch.Tensor
_node: Node

__slots__ = ['_tensor', '_node']

@staticmethod
def __new__(cls, tensor, placeholder=False, name=None):
r = torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
strides=tensor.stride(),
storage_offset=tensor.storage_offset(),
dtype=tensor.dtype,
layout=tensor.layout,
device='cpu',
requires_grad=tensor.requires_grad) # deceive the frontend for aten selections
r._tensor = tensor
if placeholder:
if name is None:
name = 'input'
r._node = graph.create_node('placeholder',
'placeholder', (graph._root,),
name=namespace.create_name(name, tensor))
# ...the real tensor is held as an element on the tensor.
return r

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

def unwrap(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
x = MetaProxy(x)
return x._tensor.to('meta') if isinstance(x, MetaProxy) else x

def get_node(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
x = MetaProxy(x, placeholder=True, name='weight')
return x if not hasattr(x, '_node') else x._node

args_node = tree_map(get_node, args)
kwargs_node = tree_map(get_node, kwargs)
node = graph.create_node('call_function', func, args_node, kwargs_node)

args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)

# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)

# Now, we want to continue propagating this tensor, so we rewrap Tensors in
# our custom tensor subclass
def wrap(x):
return MetaProxy(x) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x

def set_node(x):
x._node = node

out = tree_map(wrap, out)
tree_map(set_node, out)

return out

def wrap(x):
return MetaProxy(x, True) if isinstance(x, torch.Tensor) else x

args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)

module(*args, **kwargs).sum().backward()
return graph
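A minimal usage sketch of the new tracer, assuming a toy model; the shapes, the meta-device input, and the `print_tabular` call are illustrative and not part of this PR.

import torch
import torch.nn as nn

from colossalai.fx import meta_trace

model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 8))
data = torch.rand(4, 16, device='meta')   # meta tensors carry shape/dtype but no storage

# Records the aten ops issued by both the forward pass and autograd's backward pass.
graph = meta_trace(model, data)
graph.print_tabular()                     # inspect the captured aten-level nodes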