[autoparallel] adapt autoparallel with new analyzer (hpcaitech#3261)
* [autoparallel] adapt autoparallel with new analyzer

* fix all node handler tests

* polish

* polish
YuliangLiu0306 authored Mar 30, 2023
1 parent e78a1e9 commit fee2af8
Showing 36 changed files with 481 additions and 386 deletions.
5 changes: 1 addition & 4 deletions colossalai/_analyzer/_subclasses/_meta_registration.py
@@ -446,10 +446,7 @@ def meta_index_Tensor(self, indices):
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
scale_grad_by_freq):
return new((num_weights, grad_output.size(-1)),
dtype=grad_output.dtype,
device=grad_output.device,
layout=grad_output.layout)
return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)

# ============================== Dropout ===========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
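For context on the change above: embedding_dense_backward produces a weight gradient of shape (num_weights, embedding_dim), and a meta implementation only needs to materialize that shape and dtype on the meta device. A minimal standalone sketch of that shape contract (the helper name is illustrative, not the registered kernel):

import torch

def meta_embedding_dense_backward_sketch(grad_output: torch.Tensor, num_weights: int) -> torch.Tensor:
    # Only shape and dtype matter for shape propagation; the buffer lives on 'meta'.
    return torch.empty((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, device='meta')

grad_output = torch.empty(4, 8, 16, device='meta')    # (batch, seq, embedding_dim)
assert meta_embedding_dense_backward_sketch(grad_output, num_weights=100).shape == (100, 16)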
58 changes: 53 additions & 5 deletions colossalai/_analyzer/fx/passes/shape_prop.py
@@ -51,7 +51,10 @@ def _normalize_tuple(x):


def _current_device(module):
return next(module.parameters()).device
try:
return next(module.parameters()).device
except StopIteration:
return torch.device('cpu')


@compatibility(is_backward_compatible=False)
@@ -120,15 +123,18 @@ def _convert_meta(t: torch.Tensor):
return t.to('meta')

if isinstance(elem, MetaTensor):
if getattr(self, '_is_param', False):
return torch.nn.Parameter(_convert_meta(elem._tensor))
return _convert_meta(elem._tensor)

elif isinstance(elem, torch.Tensor):
if isinstance(elem, torch.nn.Parameter):
return torch.nn.Parameter(_convert_meta(elem))
return _convert_meta(elem)

else:
return elem

# unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r)
@@ -149,7 +155,11 @@ def _convert_meta(t: torch.Tensor):
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
tuple(v for v in kwargs.values() if is_pure_tensor(v))

n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r)) # align with SPMD
# align with SPMD
if isinstance(r, (tuple, list)):
n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))
else:
n._meta_data = unwrap_fn(r)

n_info.global_ctx = self.global_hook.ctx
n_info.curr_ctx = self.global_hook.ctx.copy()
@@ -175,10 +185,48 @@ def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[st
Return
Any: The value returned by the function invocation
"""
convert_to_param = False
if target in (torch.transpose, torch.reshape) and isinstance(args[0], torch.nn.parameter.Parameter):
convert_to_param = True
if target in self._custom_dispatch_func:
return self._custom_dispatch_func[target](*args, **kwargs)
res = self._custom_dispatch_func[target](*args, **kwargs)
else:
res = super().call_function(target, args, kwargs)
if convert_to_param:
return torch.nn.Parameter(res)
else:
return res

def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the result.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
Any: The value returned by the method invocation
"""
# args[0] is the `self` object for this method call
self_obj, *args_tail = args

target_method = getattr(self_obj.__class__, target)

convert_to_parameter = False
if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
args[0], torch.nn.parameter.Parameter):
convert_to_parameter = True
# Execute the method and return the result
assert isinstance(target, str)
res = getattr(self_obj, target)(*args_tail, **kwargs)
if convert_to_parameter:
return torch.nn.Parameter(res)
else:
return super().call_function(target, args, kwargs)
return res

def propagate(self, *args, device=None):
"""
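Two of the changes above are easy to reproduce in isolation. First, next(module.parameters()) raises StopIteration for parameterless modules, which is what the new _current_device fallback guards against. Second, tensor ops such as view or transpose on an nn.Parameter return a plain Tensor, which is why the propagator rewraps such results as parameters. A small self-contained check of both facts, using only stock PyTorch:

import torch
import torch.nn as nn

def current_device(module: nn.Module) -> torch.device:
    # Mirrors the fallback added above: parameterless modules default to CPU.
    try:
        return next(module.parameters()).device
    except StopIteration:
        return torch.device('cpu')

print(current_device(nn.Linear(4, 4)))    # device of the weight
print(current_device(nn.ReLU()))          # cpu fallback, no parameters to inspect

p = nn.Parameter(torch.ones(2, 2))
assert not isinstance(p.view(-1), nn.Parameter)    # view returns a plain Tensor, not a Parameter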
114 changes: 36 additions & 78 deletions colossalai/_analyzer/fx/tracer/bias_addition.py
@@ -21,111 +21,69 @@ def linear_impl(input, weight, bias=None):


@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
def conv1d_impl(input, weight, **kwargs):
bias = kwargs.get('bias', None)
if bias is None:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
return F.conv1d(input, weight, **kwargs)
else:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1))
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv1d(input, weight, **kwargs) + bias.reshape((-1, 1))


@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
def conv2d_impl(input, weight, **kwargs):
bias = kwargs.get('bias', None)
if bias is None:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
return F.conv2d(input, weight, **kwargs)
else:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1))
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv2d(input, weight, **kwargs) + bias.reshape((-1, 1, 1))


@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
def conv3d_impl(input, weight, **kwargs):
bias = kwargs.get('bias', None)
if bias is None:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
return F.conv3d(input, weight, **kwargs)
else:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1, 1))
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))


@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
def conv_transpose1d_impl(input,
weight,
bias=None,
stride=_single(1),
padding=_single(0),
output_padding=_single(0),
groups=1,
dilation=_single(1)):
def conv_transpose1d_impl(input, weight, **kwargs):
bias = kwargs.get('bias', None)
if bias is None:
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
return F.conv_transpose1d(input, weight, **kwargs)
else:
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1))
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))


@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
def conv_transpose2d_impl(input,
weight,
bias=None,
stride=_pair(1),
padding=_pair(0),
output_padding=_pair(0),
groups=1,
dilation=_pair(1)):
def conv_transpose2d_impl(input, weight, **kwargs):
bias = kwargs.get('bias', None)
if bias is None:
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
return F.conv_transpose2d(input, weight, **kwargs)
else:
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1))
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))


@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
def conv_transpose3d_impl(input,
weight,
bias=None,
stride=_triple(1),
padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1)):
def conv_transpose3d_impl(input, weight, **kwargs):
bias = kwargs.get('bias', None)
if bias is None:
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
return F.conv_transpose3d(input, weight, **kwargs)
else:
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1, 1))
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))


@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
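The rewritten implementations above all follow the same bias-addition split: run the underlying op without its bias, then add the bias back as an explicit, broadcastable tensor so the addition becomes a separate node in the traced graph. A standalone sketch of the idea for conv2d (hypothetical helper, not the registered tracer impl), assuming only stock PyTorch:

import torch
import torch.nn.functional as F

def conv2d_bias_split_sketch(input, weight, bias=None, **kwargs):
    out = F.conv2d(input, weight, bias=None, **kwargs)    # convolution without bias
    if bias is None:
        return out
    return out + bias.reshape((-1, 1, 1))                 # explicit bias add, broadcast over H and W

x = torch.randn(2, 3, 8, 8)
w = torch.randn(6, 3, 3, 3)
b = torch.randn(6)
assert torch.allclose(conv2d_bias_split_sketch(x, w, b, padding=1),
                      F.conv2d(x, w, b, padding=1), atol=1e-6)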
24 changes: 19 additions & 5 deletions colossalai/auto_parallel/meta_profiler/metainfo.py
@@ -70,14 +70,28 @@ def target(self, target: Callable) -> None:
if self._strategy is not None and self._target is not None:
self.compute_metainfo()

def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec):
"""
Compute sharded opdata based on the given data and sharding spec.
"""
return OperationData(name=operation_data.name,
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape)

if isinstance(sharding_spec, ShardingSpec):
op_data = OperationData(name=operation_data.name,
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape)
elif isinstance(sharding_spec, (list, tuple)):
data = operation_data.data
assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
assert len(data) == len(sharding_spec), f"Length of data and sharding spec should be the same."
sharded_data = []
for d, s in zip(data, sharding_spec):
sharded_data.append(torch.zeros(s.get_sharded_shape_per_device(), device="meta"))
op_data = OperationData(name=operation_data.name, data=sharded_data, type=operation_data.type)
else:
raise ValueError(f"Sharding spec should be ShardingSpec or list, but got {type(sharding_spec)}.")

return op_data

def compute_metainfo(self):
"""
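The meta tensors built above use the sharded shape per device, i.e. the global shape with each sharded dimension divided by the sizes of the mesh axes it is partitioned over. A rough arithmetic sketch with a hypothetical helper (not ColossalAI's ShardingSpec API):

import torch

def sharded_shape_per_device(global_shape, dim_partition, mesh_shape):
    # dim_partition maps a tensor dim to the mesh axes it is sharded over
    shape = list(global_shape)
    for dim, mesh_axes in dim_partition.items():
        for axis in mesh_axes:
            shape[dim] //= mesh_shape[axis]
    return tuple(shape)

# a (1024, 1024) weight on a 2 x 2 device mesh
assert sharded_shape_per_device((1024, 1024), {0: [0]}, (2, 2)) == (512, 1024)       # sharded on one mesh axis
assert sharded_shape_per_device((1024, 1024), {0: [0, 1]}, (2, 2)) == (256, 1024)    # sharded on both mesh axes
meta_weight = torch.zeros((256, 1024), device='meta')    # what the metainfo pass would allocate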
11 changes: 6 additions & 5 deletions colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -387,12 +387,13 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
# This stream is created for overlaping the communication and computation.
reduction_stream = torch.cuda.Stream()

def _add_hook_for_grad_communication(node, param):
def _add_hook_for_grad_communication(node, param, name=None):

comm_actions = node.best_strategy.communication_actions

def _filter_param_to_hook(node, op_data, comm_action):
if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK:
def _filter_param_to_hook(node, op_data, comm_action, name):

if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
return True
if node.op == 'get_attr' and isinstance(
node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
@@ -402,7 +403,7 @@ def _filter_param_to_hook(node, op_data, comm_action):
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters
if _filter_param_to_hook(node, operation_data, comm_action):
if _filter_param_to_hook(node, operation_data, comm_action, name=name):

def wrapper(param, comm_spec, stream, overlap):

@@ -442,7 +443,7 @@ def _shard_param(param, target_sharding_spec):
param = _shard_param(param, target_sharding_spec)

setattr(target_module, name, param)
_add_hook_for_grad_communication(node, param)
_add_hook_for_grad_communication(node, param, name)

sharded_buffer_dict = {}
# apply the sharding spec of buffers
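The hook machinery in this pass boils down to registering a gradient hook on each sharded parameter that launches its reduction on a side stream, so communication can overlap with the rest of the backward pass. A generic, heavily simplified sketch of that pattern (assumes torch.distributed is initialized and CUDA is available; names are illustrative, not the pass's actual helpers):

import torch
import torch.distributed as dist

def attach_grad_allreduce_hook(param: torch.nn.Parameter, group, stream: torch.cuda.Stream, overlap: bool):
    def hook(grad):
        default_stream = torch.cuda.current_stream()
        stream.wait_stream(default_stream)        # make sure grad is fully produced first
        with torch.cuda.stream(stream):
            dist.all_reduce(grad, group=group)    # reduction runs on the side stream
        if not overlap:
            torch.cuda.current_stream().wait_stream(stream)
        # real code must still synchronize the side stream before the optimizer step
        return grad
    param.register_hook(hook)

register_hook fires as soon as the gradient for that parameter is ready, which is what lets the reduction start before the rest of the backward pass finishes.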
@@ -81,7 +81,10 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
generator = BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)
# addbmm will shrink the first batch dim
generator.squeeze_batch_dim = True
generators.append(generator)
return generators

def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
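The squeeze is needed because torch.addbmm reduces over the batch dimension: its operands are 3-D but its output is 2-D, unlike torch.bmm. A quick shape check with plain PyTorch:

import torch

inp = torch.randn(4, 6)            # (n, p), the added term
b1 = torch.randn(10, 4, 5)         # (b, n, m)
b2 = torch.randn(10, 5, 6)         # (b, m, p)
assert torch.addbmm(inp, b1, b2).shape == (4, 6)     # batch dim summed away
assert torch.bmm(b1, b2).shape == (10, 4, 6)         # bmm keeps the batch dim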
@@ -776,10 +776,6 @@ def validate(self) -> bool:
bias_op_data = self.op_data['bias']
assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2

if self.op_data['output'].data.dim() == 2:
# addbmm will shrink the first batch dim
self.squeeze_batch_dim = True

def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
self.op_data['output'].data.shape)
2 changes: 1 addition & 1 deletion colossalai/fx/_meta_regist_12.py
@@ -386,7 +386,7 @@ def meta_local_scalar_dense(self: torch.Tensor):
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
result_type = torch.result_type(self, other)
return torch.empty_like(self, dtype=result_type)
return torch.empty_like(condition + self + other, dtype=result_type)


@register_meta(aten.index.Tensor)
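The new meta implementation uses condition + self + other purely to obtain the broadcast output shape: torch.where broadcasts all three operands, so returning a tensor shaped like self alone was wrong whenever the shapes differed. A quick check with ordinary tensors:

import torch

cond = torch.zeros(4, 1, dtype=torch.bool)
a = torch.zeros(1, 5)
b = torch.zeros(4, 5)
assert torch.where(cond, a, b).shape == (4, 5)    # broadcast of all three operands
assert (cond + a + b).shape == (4, 5)             # same broadcast shape, cheap to compute on meta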
@@ -1,22 +1,20 @@
from faulthandler import disable
from functools import partial
from xml.dom import WrongDocumentErr

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from typing_extensions import Self

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
@@ -96,7 +94,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
meta_arg_names=meta_arg_names,
node_type='bias_module')

tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %m1 : torch.Tensor [#users=1] = placeholder[target=m1]
@@ -109,6 +107,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
# return add
graph = tracer.trace(model, meta_args=meta_args_for_tracer)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args_for_tracer.values())
# [input_1, m1, m2, addmm, output]
node_list = list(graph.nodes)
linear_node = node_list[4]
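Condensed, the updated test follows the new-analyzer workflow: trace with the analyzer's ColoTracer (bias_addition_split=True), wrap the graph in ColoGraphModule, and run shape_prop_pass over the same meta arguments before inspecting nodes. A sketch of that sequence with an illustrative model (module and shapes are placeholders, not the test's):

import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer

model = nn.Linear(8, 8)
meta_args = {'input': torch.empty(4, 8, device='meta')}

tracer = ColoTracer(bias_addition_split=True)      # split bias additions into separate nodes
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())           # annotate every node with meta data

for node in gm.graph.nodes:
    print(node.op, node.name, getattr(node, '_meta_data', None))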
