Support hybrid_parallel_topo_order for auto parallel Llama (PaddlePaddle#8011)

* Support hybrid_parallel_topo_order for auto parallel Llama

* Set order in hybrid_configs

* Update get_mesh_with_dim

* Update loss for CI baseline

* Fix CI errors

* Update loss

* Update loss

* Update loss

* Update loss

* Update loss
From00 authored Mar 8, 2024
1 parent 93aa4bc commit b504a73
Showing 6 changed files with 78 additions and 45 deletions.
5 changes: 5 additions & 0 deletions llm/llama/auto_parallel/run_pretrain_auto.py
@@ -404,13 +404,18 @@ def init_seed(seed: int = 1234, args=None):
     else:
         assert not args.use_hybrid_parallel and args.enable_auto_parallel
         if dist.get_world_size() > 1:
+            if args.hybrid_parallel_topo_order is None or args.hybrid_parallel_topo_order == "pp_first":
+                order = ["pp", "dp", "sharding", "mp", "sep"]
+            elif args.hybrid_parallel_topo_order == "sharding_first":
+                order = ["dp", "sharding", "pp", "mp", "sep"]
             topo = Topology(
                 dist.get_rank(),
                 dist.get_world_size(),
                 dp_degree=args.data_parallel_degree,
                 pp_degree=args.pipeline_parallel_degree,
                 mp_degree=args.tensor_parallel_degree,
                 sharding_degree=1,  # auto_parallel's sharding is not orthogonal with dp, mp and pp
+                order=order,
             )

             global_seed, local_seed, random_seed = _get_distributed_seeds(args.seed, topo)
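Note: the same five added lines appear again in run_pretrain_auto_static.py below. As a rough, hedged illustration of the new argument (not part of the commit), the sketch that follows builds a Topology the way the script now does. The single-node 8-rank dp2/pp2/mp2 degrees, the rank-0 viewpoint, and the import path paddlenlp.ops.distributed.utils.topo are assumptions chosen for the example, and it requires a PaddleNLP build that includes this patch.

from paddlenlp.ops.distributed.utils.topo import Topology

# Assumed example values: one 8-GPU node split as dp=2, pp=2, mp=2.
topo_order = "pp_first"  # what --hybrid_parallel_topo_order would carry
order = (
    ["pp", "dp", "sharding", "mp", "sep"]
    if topo_order in (None, "pp_first")
    else ["dp", "sharding", "pp", "mp", "sep"]
)
topo = Topology(
    device_rank=0,
    world_size=8,
    dp_degree=2,
    pp_degree=2,
    mp_degree=2,
    sharding_degree=1,  # auto_parallel's sharding is folded into dp/mp/pp here
    order=order,
)
print(topo)  # group membership for dp/pp/mp now follows the chosen axis order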
5 changes: 5 additions & 0 deletions llm/llama/auto_parallel/run_pretrain_auto_static.py
@@ -414,13 +414,18 @@ def init_seed(seed: int = 1234, args=None):
     else:
         assert not args.use_hybrid_parallel and args.enable_auto_parallel
         if dist.get_world_size() > 1:
+            if args.hybrid_parallel_topo_order is None or args.hybrid_parallel_topo_order == "pp_first":
+                order = ["pp", "dp", "sharding", "mp", "sep"]
+            elif args.hybrid_parallel_topo_order == "sharding_first":
+                order = ["dp", "sharding", "pp", "mp", "sep"]
             topo = Topology(
                 dist.get_rank(),
                 dist.get_world_size(),
                 dp_degree=args.data_parallel_degree,
                 pp_degree=args.pipeline_parallel_degree,
                 mp_degree=args.tensor_parallel_degree,
                 sharding_degree=1,  # auto_parallel's sharding is not orthogonal with dp, mp and pp
+                order=order,
             )

             global_seed, local_seed, random_seed = _get_distributed_seeds(args.seed, topo)
71 changes: 42 additions & 29 deletions paddlenlp/ops/distributed/utils/topo.py
@@ -21,42 +21,55 @@

 class Topology:
     def __init__(
-        self, device_rank, world_size, dp_degree=None, pp_degree=1, sharding_degree=1, mp_degree=1, sep_degree=1
+        self,
+        device_rank,
+        world_size,
+        dp_degree=None,
+        pp_degree=1,
+        sharding_degree=1,
+        mp_degree=1,
+        sep_degree=1,
+        order=["dp", "pp", "sharding", "mp", "sep"],
     ):
-        arr = np.arange(0, dp_degree * pp_degree * sharding_degree * mp_degree * sep_degree).reshape(
-            [dp_degree, pp_degree, sharding_degree, mp_degree, sep_degree]
-        )
-
-        dp_rank, pp_rank, sharding_rank, mp_rank, sep_rank = np.where(arr == device_rank)
-        dp_rank = dp_rank[0]
-        pp_rank = pp_rank[0]
-        sharding_rank = sharding_rank[0]
-        mp_rank = mp_rank[0]
-        sep_rank = sep_rank[0]
-
-        self.world = GroupInfo(size=world_size, rank=device_rank, world=list(range(0, world_size)))
+        assert set(order) == {"dp", "pp", "sharding", "mp", "sep"}, f"Illegal order : {order}"
+        self.order = order

-        sep_world = arr[dp_rank, pp_rank, sharding_rank, mp_rank, :]
-        self.sep_info = GroupInfo(size=len(sep_world), rank=sep_rank, world=sep_world.tolist())
+        degree_map = {
+            "dp": dp_degree,
+            "pp": pp_degree,
+            "sharding": sharding_degree,
+            "mp": mp_degree,
+            "sep": sep_degree,
+        }
+        shape = [degree_map[key] for key in self.order]

-        mp_world = arr[dp_rank, pp_rank, sharding_rank, :, sep_rank]
-        self.mp_info = GroupInfo(size=len(mp_world), rank=mp_rank, world=mp_world.tolist())
+        arr = np.arange(0, dp_degree * pp_degree * sharding_degree * mp_degree * sep_degree).reshape(shape)
+        ranks = [rank[0] for rank in np.where(arr == device_rank)]

-        sharding_world = arr[dp_rank, pp_rank, :, mp_rank, sep_rank]
-        self.sharding_info = GroupInfo(size=len(sharding_world), rank=sharding_rank, world=sharding_world.tolist())
-
-        pp_world = arr[dp_rank, :, sharding_rank, mp_rank, sep_rank]
-        self.pp_info = GroupInfo(size=len(pp_world), rank=pp_rank, world=pp_world.tolist())
-
-        dp_world = arr[:, pp_rank, sharding_rank, mp_rank, sep_rank]
-        self.dp_info = GroupInfo(size=len(dp_world), rank=dp_rank, world=dp_world.tolist())
+        self.world = GroupInfo(size=world_size, rank=device_rank, world=list(range(0, world_size)))
+        worlds = []
+        for i in range(len(ranks)):
+            indexs = tuple(ranks[:i] + [slice(None)] + ranks[(i + 1) :])
+            worlds.append(arr[indexs])
+
+        for i, key in enumerate(self.order):
+            if key == "dp":
+                self.dp_info = GroupInfo(size=len(worlds[i]), rank=ranks[i], world=worlds[i].tolist())
+            elif key == "pp":
+                self.pp_info = GroupInfo(size=len(worlds[i]), rank=ranks[i], world=worlds[i].tolist())
+            elif key == "sharding":
+                self.sharding_info = GroupInfo(size=len(worlds[i]), rank=ranks[i], world=worlds[i].tolist())
+            elif key == "mp":
+                self.mp_info = GroupInfo(size=len(worlds[i]), rank=ranks[i], world=worlds[i].tolist())
+            elif key == "sep":
+                self.sep_info = GroupInfo(size=len(worlds[i]), rank=ranks[i], world=worlds[i].tolist())

         self.is_last = self.pp_info.rank == self.pp_info.size - 1

         data_arr = np.arange(0, dp_degree * sharding_degree).reshape([dp_degree, sharding_degree])
-        data_arr = np.expand_dims(data_arr, axis=1).repeat(pp_degree, axis=1)
-        data_arr = np.expand_dims(data_arr, axis=3).repeat(mp_degree, axis=3)
-        data_arr = np.expand_dims(data_arr, axis=4).repeat(sep_degree, axis=4)
+        for i, key in enumerate(self.order):
+            if key != "dp" and key != "sharding":
+                data_arr = np.expand_dims(data_arr, axis=i).repeat(degree_map[key], axis=i)

         self.data_info = GroupInfo(
             size=int(self.dp_info.size * self.sharding_info.size),
@@ -68,4 +81,4 @@ def __init__(
         self.data_inner_times = self.world.size // self.data_info.size

     def __repr__(self):
-        return f"dp_info:\n\t {self.dp_info}, \npp_info:\n\t {self.pp_info}, \nsharding_info:\n\t {self.sharding_info}, \nmp_info:\n\t {self.mp_info}, \nsep_info:\n\t {self.sep_info}\ndata_info:\n\t {self.data_info}"
+        return f"dp_info:\n\t {self.dp_info}, \npp_info:\n\t {self.pp_info}, \nsharding_info:\n\t {self.sharding_info}, \nmp_info:\n\t {self.mp_info}, \nsep_info:\n\t {self.sep_info}, \ndata_info:\n\t {self.data_info}, \norder:\n\t {self.order}"
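Note: the rewritten constructor reduces all five parallel dimensions to a single rule: lay the flat rank range out on a grid whose axis order follows order, then every communication group is one slice of that grid. The standalone numpy sketch below (illustration only, not part of the commit; the 8-rank dp=2, pp=2, mp=2 layout is an assumed example) shows how the two supported orders change which ranks land in each group.

import numpy as np

degrees = {"dp": 2, "pp": 2, "sharding": 1, "mp": 2, "sep": 1}

def groups_for(device_rank, order):
    # Reshape global ranks into a grid whose axes follow `order`,
    # then take one slice per axis through this rank's position.
    arr = np.arange(np.prod(list(degrees.values()))).reshape([degrees[k] for k in order])
    ranks = [r[0] for r in np.where(arr == device_rank)]
    out = {}
    for i, key in enumerate(order):
        index = tuple(ranks[:i] + [slice(None)] + ranks[i + 1:])
        out[key] = arr[index].tolist()
    return out

# pp_first: the pipeline axis varies slowest, so rank 0's pp peers are far apart.
print(groups_for(0, ["pp", "dp", "sharding", "mp", "sep"]))
# {'pp': [0, 4], 'dp': [0, 2], 'sharding': [0], 'mp': [0, 1], 'sep': [0]}

# sharding_first: dp/sharding vary slowest instead, so the pp group becomes [0, 2].
print(groups_for(0, ["dp", "sharding", "pp", "mp", "sep"]))
# {'dp': [0, 4], 'sharding': [0], 'pp': [0, 2], 'mp': [0, 1], 'sep': [0]}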
22 changes: 16 additions & 6 deletions paddlenlp/trainer/training_args.py
@@ -931,6 +931,10 @@ def __post_init__(self):
             self.pipeline_parallel_degree = -1
             self.sep_parallel_degree = -1

+        if self.hybrid_parallel_topo_order is None:
+            self.hybrid_parallel_topo_order = "pp_first"
+        assert self.hybrid_parallel_topo_order in ["pp_first", "sharding_first"]
+
         if self.use_hybrid_parallel and self.enable_auto_parallel:
             self.use_hybrid_parallel = False

@@ -1058,10 +1062,6 @@ def __post_init__(self):
                        "by current version of Paddle. Please try latest develop Paddle."
                    )

-            if self.hybrid_parallel_topo_order is None:
-                self.hybrid_parallel_topo_order = "pp_first"
-            assert self.hybrid_parallel_topo_order in ["pp_first", "sharding_first"]
-
             def is_segment_parallel_supported():
                 import inspect

@@ -1317,17 +1317,27 @@ def is_segment_parallel_supported():
                     recompute.refined_ops_patterns.append(eval(pattern))

             self.strategy = strategy
-            order = ["dp", "pp", "mp"]
-            degree = [self.data_parallel_degree, self.pipeline_parallel_degree, self.tensor_parallel_degree]
+            if self.hybrid_parallel_topo_order == "pp_first":
+                order = ["pp", "dp", "mp"]
+                degree = [self.pipeline_parallel_degree, self.data_parallel_degree, self.tensor_parallel_degree]
+            elif self.hybrid_parallel_topo_order == "sharding_first":
+                order = ["dp", "pp", "mp"]
+                degree = [self.data_parallel_degree, self.pipeline_parallel_degree, self.tensor_parallel_degree]
             mesh_dims = list(zip(order, degree))
             fleet.auto.create_mesh(mesh_dims)

             # init hcg for communication in trainer
+            if self.hybrid_parallel_topo_order == "pp_first":
+                order = ["pp", "dp", "sharding", "sep", "mp"]
+            elif self.hybrid_parallel_topo_order == "sharding_first":
+                order = ["dp", "sharding", "pp", "sep", "mp"]
+
             strategy = fleet.DistributedStrategy()
             strategy.hybrid_configs = {
                 "dp_degree": self.data_parallel_degree,
                 "mp_degree": self.tensor_parallel_degree,
                 "pp_degree": self.pipeline_parallel_degree,
+                "order": order,
             }
             fleet.init(is_collective=True, strategy=strategy)
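Note: for quick reference, the sketch below (illustration only; dp=mp=pp=2 are assumed example degrees and topo_layout is a made-up helper) replays the branch above to show the mesh axes and the hybrid-communication-group order each flag value produces.

def topo_layout(topo_order, dp=2, mp=2, pp=2):
    # Mirror the selection logic added to TrainingArguments.__post_init__ above.
    if topo_order == "pp_first":
        mesh_dims = [("pp", pp), ("dp", dp), ("mp", mp)]
        hcg_order = ["pp", "dp", "sharding", "sep", "mp"]
    elif topo_order == "sharding_first":
        mesh_dims = [("dp", dp), ("pp", pp), ("mp", mp)]
        hcg_order = ["dp", "sharding", "pp", "sep", "mp"]
    else:
        raise ValueError(f"unknown hybrid_parallel_topo_order: {topo_order}")
    return mesh_dims, hcg_order

print(topo_layout("pp_first"))        # ([('pp', 2), ('dp', 2), ('mp', 2)], ['pp', 'dp', 'sharding', 'sep', 'mp'])
print(topo_layout("sharding_first"))  # ([('dp', 2), ('pp', 2), ('mp', 2)], ['dp', 'sharding', 'pp', 'sep', 'mp'])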
2 changes: 1 addition & 1 deletion paddlenlp/transformers/llama/modeling_auto.py
@@ -85,7 +85,7 @@ def swiglu(x, y=None):
 def get_mesh(pp_idx=0):
     mesh = fleet.auto.get_mesh()
     if "pp" in mesh.dim_names:
-        mesh = mesh.get_mesh_with_dim("pp")[pp_idx]
+        mesh = mesh.get_mesh_with_dim("pp", pp_idx)
     return mesh
18 changes: 9 additions & 9 deletions scripts/distribute/ci_case_auto.sh
@@ -965,7 +965,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.42011845
+    loss_base=9.42011833
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1031,7 +1031,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.44299495
+    loss_base=9.44299471
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1093,7 +1093,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() {
         --data_impl "mmap" \
         --enable_auto_parallel 1 \
         >>${log_path}/$FUNCNAME 2>&1
-    loss=`cat $case_log_dir/workerlog.2 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    loss=`cat $case_log_dir/workerlog.4 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
@@ -1161,7 +1161,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
         --data_impl "mmap" \
         --enable_auto_parallel 1 \
         >>${log_path}/$FUNCNAME 2>&1
-    loss=`cat $case_log_dir/workerlog.3 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    loss=`cat $case_log_dir/workerlog.4 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
@@ -1230,7 +1230,7 @@ function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
         --data_impl "mmap" \
         --enable_auto_parallel 1 \
         >>${log_path}/$FUNCNAME 2>&1
-    loss=`cat $case_log_dir/workerlog.3 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    loss=`cat $case_log_dir/workerlog.4 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
@@ -1301,7 +1301,7 @@ function llama_dygraph_auto_bs8_fp32_DP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.52781677
+    loss_base=9.53389835
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1368,7 +1368,7 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.40659046
+    loss_base=9.39066124
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1435,7 +1435,7 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.38319206
+    loss_base=9.38235474
    ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1503,7 +1503,7 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.38341904
+    loss_base=9.38257694
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
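Note on the workerlog index changes above: with the new pp-first layout the pipeline axis becomes the outermost mesh dimension, so the last pipeline stage moves to a different global rank. The sketch below is an illustration built on two assumptions, namely that the loss line is printed by the last pipeline stage of the dp0/mp0 replica and that workerlog.N corresponds to global rank N on a single node; it reproduces the move from workerlog.2 to workerlog.4 for the 8-rank DP2-MP2-PP2 case.

import numpy as np

def last_pp_stage_rank(order, degrees):
    # Lay global ranks out on a grid whose axis order matches the mesh order,
    # then pick index 0 on every axis except the last index along "pp".
    arr = np.arange(int(np.prod(list(degrees.values())))).reshape([degrees[k] for k in order])
    index = tuple(-1 if k == "pp" else 0 for k in order)
    return int(arr[index])

degrees = {"dp": 2, "pp": 2, "mp": 2}
print(last_pp_stage_rank(["dp", "pp", "mp"], degrees))  # 2 -> old baselines read workerlog.2
print(last_pp_stage_rank(["pp", "dp", "mp"], degrees))  # 4 -> updated baselines read workerlog.4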
