
Commit 72341e6: [auto-chunk] support extramsa (#3) (#2504)
Authored Jan 20, 2023. Parent: 0f02b8c

8 files changed: +283 −54 lines
 

colossalai/autochunk/estimate_memory.py (+2 −7)
@@ -6,12 +6,7 @@
 
 from colossalai.fx.profiler import activation_size, parameter_size
 
-from .utils import (
-    delete_free_var_from_last_use,
-    find_idx_by_name,
-    get_node_shape,
-    is_non_compute_node_except_placeholder,
-)
+from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node
 
 
 class EstimateMemory(object):
@@ -240,7 +235,7 @@ def estimate_chunk_inference_mem(
             elif node.op == "output":
                 continue
             # no change for non compute node
-            elif is_non_compute_node_except_placeholder(node):
+            elif is_non_memory_node(node):
                 act_memory_peak_log.append(act_memory)
             # node is a compute op
             # calculate tmp, output node and delete node memory
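
Note: is_non_memory_node, defined in the utils.py diff below, is a looser predicate than the old is_non_compute_node_except_placeholder: getitem and output nodes are now also treated as holding no new activation memory, so they pass through with act_memory unchanged.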

colossalai/autochunk/trace_flow.py (+32 −11)
@@ -118,16 +118,34 @@ def check_index_duplicate(self, chunk_infos, return_dim=False):
 
     def _assgin_single_node_flow(
         self,
-        arg_node,
-        start_idx,
-        end_idx,
-        cur_node_dim,
-        cur_node_compute,
-        cur_node_source,
-        cur_node_fix_dim,
-        all_node_info,
-        next_node_list,
-    ):
+        arg_node: Node,
+        start_idx: int,
+        end_idx: int,
+        cur_node_dim: int,
+        cur_node_compute: Dict,
+        cur_node_source: Dict,
+        cur_node_fix_dim: List,
+        all_node_info: Dict,
+        next_node_list: List,
+    ) -> bool:
+        """
+        Given the current node and one of its arg nodes, find the arg node's
+        chunk dim and fix dim.
+
+        Args:
+            arg_node (Node): input node
+            start_idx (int): chunk region start
+            end_idx (int): chunk region end
+            cur_node_dim (int): current node chunk dim
+            cur_node_compute (Dict): current node compute dict
+            cur_node_source (Dict): current node source dict
+            cur_node_fix_dim (List): current node fix dim
+            all_node_info (Dict): all node chunk info in the chunk region
+            next_node_list (List)
+
+        Returns:
+            bool: True if this node can be added to the flow, False otherwise.
+        """
         arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
         # arg in chunk range or be inputs
         if not (start_idx <= arg_idx < end_idx):
@@ -142,6 +160,9 @@ def _assgin_single_node_flow(
                 arg_dim = None
             else:
                 arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
+                # chunk dim should be None if shape size is 1
+                if get_node_shape(arg_node)[arg_dim] == 1:
+                    arg_dim = None
         else:
             arg_dim = None

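The new size-1 guard matches broadcast semantics: a dimension of size 1 is expanded against the other operand rather than sliced, so it cannot serve as a chunk dim. A minimal PyTorch sketch of the behavior being accounted for (example tensors are made up):

    import torch

    a = torch.ones(4, 1)    # dim 1 has size 1 and will broadcast
    b = torch.ones(4, 8)

    full = a + b            # broadcasting expands a's dim 1 to 8

    # chunking b along dim 1 still needs the *whole* of a each time,
    # so a's dim 1 is not a valid chunk dim and is set to None
    chunks = [a + b[:, i:i + 4] for i in range(0, 8, 4)]
    assert torch.cat(chunks, dim=1).equal(full)
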
@@ -184,7 +205,7 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx):
 
         # get all valid args
         arg_list = []
-        for arg in cur_node.args:
+        for arg in cur_node.all_input_nodes:
             if type(arg) != type(cur_node):
                 continue
             if is_non_compute_node(arg):
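
Switching from cur_node.args to cur_node.all_input_nodes sidesteps nested argument structures: args can hold lists, constants, and slices, while all_input_nodes yields exactly the flattened Node inputs. A small torch.fx illustration (the module here is invented for demonstration):

    import torch
    import torch.fx


    class M(torch.nn.Module):

        def forward(self, x):
            return torch.cat([x, x], dim=1) + 1


    gm = torch.fx.symbolic_trace(M())
    for node in gm.graph.nodes:
        if node.op == "call_function":
            # args may contain a list and an int; all_input_nodes only Nodes
            print(node.name, node.args, node.all_input_nodes)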

colossalai/autochunk/trace_indice.py (+52 −4)
@@ -432,6 +432,38 @@ def _assign_ones_like_indice(self, node: Node, node_idx: int):
         """
         self._assign_all_indice(node, node_idx)
 
+    def _assign_cat_indice(self, node: Node, node_idx: int):
+        """
+        Assign indice for cat op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        nodes_in = flat_list(node.args[0])
+        self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
+        for n in nodes_in[1:]:
+            self._mark_computation_from_node(n, node)
+        cat_dim = node.kwargs["dim"]
+        self._del_dim(node_idx, cat_dim)
+        self._add_dim(node_idx, cat_dim)
+
+    def _assign_sum_indice(self, node: Node, node_idx: int):
+        """
+        Assign indice for sum op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        nodes_in = flat_list(node.args[0])
+        self._add_dim(node_idx, 0)
+        self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
+        for n in nodes_in[1:]:
+            self._mark_computation_from_node(n, node)
+        cat_dim = node.kwargs["dim"]
+        self._del_dim(node_idx, cat_dim)
+
     def _assign_getitem_indice(self, node: Node, node_idx: int):
         """
         Assign indice for getitem.
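
The two handlers encode the shape algebra of the ops: cat keeps the rank but makes the concatenated dim a fresh index (hence _del_dim followed by _add_dim on the same dim), while sum removes the reduced dim entirely. The underlying tensor semantics, for reference:

    import torch

    x = torch.ones(2, 3)
    y = torch.ones(2, 5)

    # cat: rank unchanged, but dim 1 is a new index of size 3 + 5
    print(torch.cat([x, y], dim=1).shape)          # torch.Size([2, 8])

    # sum: the reduced dim disappears from the output
    print(torch.ones(2, 3, 4).sum(dim=1).shape)    # torch.Size([2, 4])
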
@@ -442,7 +474,16 @@ def _assign_getitem_indice(self, node: Node, node_idx: int):
             node_idx (int)
         """
         node_args = flat_list(node.args[1:])
-        if not any(i == str(node_arg) for i in ["None", "Ellipsis"] for node_arg in node_args):
+        flag = False
+        for node_arg in node_args:
+            node_arg_str = str(node_arg)
+            if any(i == node_arg_str for i in ["None", "Ellipsis"]):
+                flag = True
+                break
+            if "slice" in node_arg_str:
+                flag = True
+                break
+        if flag == False:
             return
 
         # node args should be like [Ellipsis, slice(start, step, end), None]
@@ -461,8 +502,11 @@ def _assign_getitem_indice(self, node: Node, node_idx: int):
                 shape_gap = len(node_shape) - len(node_args) + 1
                 origin_idx_count += shape_gap
                 new_idx_count += shape_gap
-            # slice(None, None, None) means all indexes, doesn't support other slice
-            elif "slice(None, None, None)" == node_arg_str:
+            # slice(None, None, None) means all indexes
+            elif "slice" in node_arg_str:
+                if "slice(None, None, None)" != node_arg_str:
+                    self._del_dim(node_idx, new_idx_count)
+                    self._add_dim(node_idx, new_idx_count)
                 origin_idx_count += 1
                 new_idx_count += 1
             # None means a new dim
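
With this change, getitem is traced for any slice argument, not only the trivial slice(None, None, None): a non-trivial slice keeps the dimension but changes its extent, so its old index is deleted and a fresh one added at the same position. The indexing semantics being modeled:

    import torch

    x = torch.ones(2, 3, 4)

    print(x[:, :, :].shape)      # torch.Size([2, 3, 4])   ':' keeps dim and index
    print(x[:, 0:2, :].shape)    # torch.Size([2, 2, 4])   slice keeps dim, new index
    print(x[..., None].shape)    # torch.Size([2, 3, 4, 1]) None adds a dim
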
@@ -565,7 +609,7 @@ def trace_indice(self):
                 self._assign_view_reshape_indice(node, idx)
             elif "unsqueeze" in node.name:
                 self._assign_unsqueeze_indice(node, idx)
-            elif any(i in node.name for i in ["to", "contiguous"]):
+            elif any(i in node.name for i in ["to", "contiguous", "clone"]):
                 self._assgin_no_change_indice(node, idx)
             elif "new_ones" in node.name:
                 self._assign_ones_like_indice(node, idx)
@@ -574,6 +618,8 @@
         elif node.op == "call_function":
             if "linear" in node.name:
                 self._assign_linear_indice(node, idx)
+            elif "cat" in node.name:
+                self._assign_cat_indice(node, idx)
             elif "matmul" in node.name:
                 self._assign_matmul_indice(node, idx)
             elif "softmax" in node.name:
@@ -586,6 +632,8 @@
                 self._assign_dropout_indice(node, idx)
             elif "einsum" in node.name:
                 self._assign_einsum_indice(node, idx)
+            elif "sum" in node.name:
+                self._assign_sum_indice(node, idx)
             elif "layer_norm" in node.name:
                 self._assign_layernorm_indice(node, idx)
             elif "getitem" in node.name:

colossalai/autochunk/utils.py (+16 −4)
@@ -3,10 +3,12 @@
 from torch.fx.node import Node
 
 
-def flat_list(inputs):
+def flat_list(inputs: Any) -> List:
     """
     flat a list by recursion
     """
+    if not (isinstance(inputs, list) or isinstance(inputs, set) or isinstance(inputs, tuple)):
+        return [inputs]
     res = []
     for i in inputs:
         if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
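
The added guard lets callers pass a bare object, e.g. node.args[0] when it is a single node, and still get a list back; previously iterating a non-iterable input would raise. A standalone sketch of the resulting behavior (re-implemented here for illustration):

    from typing import Any, List


    def flat_list(inputs: Any) -> List:
        """Flatten nested lists/sets/tuples by recursion; wrap anything else."""
        if not isinstance(inputs, (list, set, tuple)):
            return [inputs]
        res = []
        for i in inputs:
            if isinstance(i, (list, set, tuple)):
                res.extend(flat_list(i))
            else:
                res.append(i)
        return res


    print(flat_list(7))               # [7]  (scalar input no longer raises)
    print(flat_list([1, (2, [3])]))   # [1, 2, 3]
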
@@ -16,7 +18,7 @@ def flat_list(inputs):
     return res
 
 
-def find_first_tensor_arg(node):
+def find_first_tensor_arg(node: Node) -> Node:
     """
     Find the first input tensor arg for a node
     """
@@ -26,24 +28,34 @@ def find_first_tensor_arg(node):
     raise RuntimeError()
 
 
-def is_non_compute_node(node):
+def is_non_compute_node(node: Node) -> bool:
     if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
         return True
     if "getitem" in node.name:
         node_args = flat_list(node.args[1:])
         for node_arg in node_args:
             if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
                 return False
+            if "slice" in str(node_arg):
+                return False
         return True
     return False
 
 
-def get_node_shape(node):
+def get_node_shape(node: Node) -> List:
     if hasattr(node.meta["tensor_meta"], "shape"):
         return node.meta["tensor_meta"].shape
     return None
 
 
+def is_non_memory_node(node: Node) -> bool:
+    if "getitem" in node.name:
+        return True
+    if "output" in node.op:
+        return True
+    return is_non_compute_node(node)
+
+
 def is_non_compute_node_except_placeholder(node):
     if "placeholder" in node.op:
         return False
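
is_non_memory_node widens is_non_compute_node: getitem and output nodes never introduce new activation memory, so the estimator can skip them. A toy classification check with a stand-in node class (real callers pass torch.fx.Node; the compute check is simplified here):

    from dataclasses import dataclass


    @dataclass
    class FakeNode:    # stand-in for torch.fx.Node, illustration only
        op: str
        name: str


    def is_non_memory_node(node) -> bool:
        if "getitem" in node.name:
            return True
        if "output" in node.op:
            return True
        # simplified stand-in for is_non_compute_node
        return any(i in node.op for i in ["placeholder", "get_attr", "output"])


    print(is_non_memory_node(FakeNode("call_function", "getitem_3")))    # True
    print(is_non_memory_node(FakeNode("output", "output")))              # True
    print(is_non_memory_node(FakeNode("call_function", "matmul_1")))     # False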

tests/test_autochunk/test_evoformer_codegen.py (+1 −1)
@@ -130,7 +130,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
         },
     )
     graph.set_codegen(codegen)
-    gm = ColoGraphModule(model, graph)
+    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
     gm.recompile()
 
     # assert we have inserted chunk
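
Passing ckpt_codegen=False presumably prevents ColoGraphModule from installing its activation-checkpoint code generator, which would otherwise override the chunk codegen attached via graph.set_codegen above.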
