Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[autoparallel] add reshape handler v2 and fix some previous bug #1683

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[autoparallel] add reshape handler v2 and fix some previous bug
  • Loading branch information
YuliangLiu0306 committed Oct 10, 2022
commit 9c86e39616ab0601ec7cc949d6c52d85e80d9eaf
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector
from ..strategy import ReshapeGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry
import operator

__all__ = ['ReshapeHandler']


@operator_registry.register(torch.reshape)
@operator_registry.register(torch.Tensor.permute)
class ReshapeHandler(NodeHandler):
"""
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""

def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators

def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

mapping = {"input": physical_input_operand, "output": physical_output}

return mapping
7 changes: 4 additions & 3 deletions colossalai/auto_parallel/solver/strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from .unary_elementwise_generator import UnaryElementwiseGenerator
from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
from .layer_norm_generator import LayerNormGenerator
from .reshape_generator import ReshapeGenerator

__all__ = [
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
'LayerNormGenerator'
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator',
'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'ReshapeGenerator'
]
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def validate(self) -> bool:
assert input_op_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'

def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
def update_compute_cost(self, strategy: ShardingStrategy_V2):
'''
Compute the computation cost per device with this specific strategy.

Expand All @@ -62,9 +62,9 @@ def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
backward_compute_cost += bias_compute_cost
total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
return compute_cost
strategy.compute_cost = compute_cost

def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
def update_memory_cost(self, strategy: ShardingStrategy_V2):
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def validate(self) -> bool:
assert input_op_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'

def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
def update_compute_cost(self, strategy: ShardingStrategy_V2):
'''
Compute the computation cost per device with this specific strategy.

Expand Down Expand Up @@ -67,9 +67,9 @@ def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
total_compute_cost = forward_compute_cost + backward_compute_cost

compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
return compute_cost
strategy.compute_cost = compute_cost

def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
def update_memory_cost(self, strategy: ShardingStrategy_V2):
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ def has_bias(self):
def validate(self) -> bool:
return super().validate()

def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
return TrainCycleItem(fwd=10, bwd=10, total=20)
def update_compute_cost(self, strategy: ShardingStrategy_V2):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost

def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
def update_memory_cost(self, strategy: ShardingStrategy_V2):
'''
Compute the memory cost per device with this specific strategy.
'''
Expand Down Expand Up @@ -59,7 +60,6 @@ def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
return super().update_memory_cost(strategy)


class TensorStrategyGenerator(GetItemStrategyGenerator):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def has_bias(self):
def validate(self) -> bool:
return super().validate()

def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
def update_compute_cost(self, strategy: ShardingStrategy_V2):
'''
Compute the computation cost per device with this specific strategy.

Expand Down Expand Up @@ -52,9 +52,9 @@ def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
backward_compute_cost += bias_compute_cost
total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
return compute_cost
strategy.compute_cost = compute_cost

def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
def update_memory_cost(self, strategy: ShardingStrategy_V2):
'''
Compute the memory cost per device with this specific strategy.
'''
Expand Down Expand Up @@ -103,6 +103,9 @@ def _generate_strategy_with_dim_partition(self, dim_partition):
total_mesh_dim_list = []
for mesh_dim_list in dim_partition.values():
total_mesh_dim_list.extend(mesh_dim_list)
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
communication_action_mapping = {}

other_comm_spec = self.get_communication_spec(
Expand Down
100 changes: 100 additions & 0 deletions colossalai/auto_parallel/solver/strategy/reshape_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
import copy

__all__ = ['ReshapeGenerator']


class ReshapeGenerator(FollowingStrategyGenerator):
"""
ReshapeGenerator which deals with the sharding strategies of Reshape Op, such as torch.Tensor.permute.
"""

def validate(self) -> bool:
return super().validate()

def update_compute_cost(self, strategy: ShardingStrategy_V2):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost

def update_memory_cost(self, strategy: ShardingStrategy_V2):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}

backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost

def generate(self):
strategy_list = []
# For reshape function, to keep the computing correctness we keep the sharding
# spec of input is fully replicated. In addition, we will keep the output in
# replica status and let the successor node choose the way to resharding the
# output node. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for reshape function.
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
if isinstance(self.op_data["output"].data, tuple):
dim_partition_dict_for_output = [{} for _ in range(len(self.op_data["output"].data))]
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'

total_mesh_dim_list = []
for mesh_dim_list in dim_partition_dict_for_input.values():
total_mesh_dim_list.extend(mesh_dim_list)
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]

input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list)
communication_action_mapping["input"] = input_comm_spec
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)

for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)

return strategy_list
13 changes: 10 additions & 3 deletions colossalai/auto_parallel/solver/strategy/strategy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,16 @@ def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]):
for op_data_name, dim_partition_dict in mapping.items():
if op_data_name in self.op_data:
op_data = self.op_data[op_data_name]
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=op_data.logical_shape,
dim_partition_dict=dim_partition_dict)
if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor):
sharding_spec = []
for output, dim_partition_dict_element in zip(op_data.data, dim_partition_dict):
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=output.shape,
dim_partition_dict=dim_partition_dict_element)
else:
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=op_data.logical_shape,
dim_partition_dict=dim_partition_dict)
results[op_data_name] = sharding_spec
return results

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
def validate(self) -> bool:
return super().validate()

def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
return TrainCycleItem(fwd=10, bwd=10, total=20)
def update_compute_cost(self, strategy: ShardingStrategy_V2):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost

def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
def update_memory_cost(self, strategy: ShardingStrategy_V2):
'''
Compute the memory cost per device with this specific strategy.
'''
Expand Down Expand Up @@ -49,7 +50,6 @@ def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
return super().update_memory_cost(strategy)

def generate(self):
strategy_list = []
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
from colossalai.auto_parallel.solver.op_handler.reshape_handler_v2 import ReshapeHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


class ReshapeModel(nn.Module):

def __init__(self):
super().__init__()

def forward(self, input, other):
conv_node = nn.functional.conv2d(input, other)
reshape_node = conv_node.view(2, -1)
return reshape_node


def test_reshape_handler():
model = ReshapeModel()
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
# return view
graph = tracer.trace(model,
meta_args={
"input": torch.rand(4, 4, 64, 64).to('meta'),
"other": torch.rand(4, 16, 3, 3).to('meta'),
})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)

mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
conv_mod_node = list(graph.nodes)[2]
reshape_node = list(graph.nodes)[3]
reshape_strategies_vector = StrategiesVector(reshape_node)
conv_strategies_vector = StrategiesVector(conv_mod_node)

# build handler
conv_handler = ConvFunctionHandler(node=conv_mod_node,
device_mesh=device_mesh,
strategies_vector=conv_strategies_vector)
conv_handler.register_strategy()
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
reshape_handler = ReshapeHandler(node=reshape_node,
device_mesh=device_mesh,
strategies_vector=reshape_strategies_vector)

reshape_handler.register_strategy()

# check operation data mapping
mapping = reshape_handler.get_operation_data_mapping()

for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.data is not None

assert mapping['input'].name == "conv2d"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62])

assert mapping['output'].name == "view"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([2, 30752])
assert mapping['output'].type == OperationDataType.OUTPUT

# reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
assert len(reshape_strategies_vector) == len(conv_strategies_vector)


if __name__ == '__main__':
test_reshape_handler()