ADLR/megatron-lm!2547 - Standardize NCCL option passing in Megatron Core
afarjallah-nv authored and ko3n1g committed Jan 31, 2025
1 parent f8887ce commit b4076c7
Showing 1 changed file with 17 additions and 9 deletions.
megatron/core/parallel_state.py: 26 changes (17 additions & 9 deletions)
@@ -137,6 +137,14 @@ def get_nccl_options(pg_name, nccl_comm_cfgs):
         nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get('cga_cluster_size', 4)
         nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get('max_ctas', 32)
         nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get('min_ctas', 1)
+        if 'net_name' in nccl_comm_cfgs[pg_name]:
+            nccl_options.config.net_name = nccl_comm_cfgs[pg_name].get('net_name')
+            # verify net_name value
+            if nccl_options.config.net_name.lower() not in ['ib', 'socket']:
+                raise RuntimeError(
+                    f"net_name ({nccl_options.config.net_name}) is not supported."
+                    f"Accepted values: 'IB' or 'socket'."
+                )
         return nccl_options
     else:
         return None
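
For context, a minimal sketch of the kind of nccl_comm_cfgs mapping this function consumes. Assumptions: a PyTorch build with NCCL support, and a plain dict such as one loaded from the YAML file passed via --nccl-communicator-config-path; the group names and option values below are illustrative only.

from megatron.core.parallel_state import get_nccl_options

# Hypothetical per-group NCCL settings; keys are the pg_name strings that
# parallel_state.py looks up when it creates each process group.
nccl_comm_cfgs = {
    'tp': {'cga_cluster_size': 2, 'max_ctas': 16, 'min_ctas': 1},
    # 'net_name' must be 'IB' or 'socket' (case-insensitive); any other value
    # now raises RuntimeError instead of being handed to NCCL unchecked.
    'ep_dp': {'net_name': 'IB'},
}

opts = get_nccl_options('ep_dp', nccl_comm_cfgs)       # Options with config.net_name == 'IB'
assert get_nccl_options('pp', nccl_comm_cfgs) is None  # group absent from the config -> defaults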
@@ -799,7 +807,7 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):
         intra_partial_data_parallel_group_with_cp = create_group(
             intra_partial_data_parallel_ranks_with_cp,
             timeout=timeout,
-            pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs),
+            pg_options=get_nccl_options('intra_dp_cp', nccl_comm_cfgs),
             group_desc='INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP',
         )
         intra_partial_data_parallel_group_with_cp_gloo = create_group(
@@ -825,7 +833,7 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):
         inter_partial_data_parallel_group_with_cp = create_group(
             inter_partial_data_parallel_ranks_with_cp,
             timeout=timeout,
-            pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs),
+            pg_options=get_nccl_options('inter_dp_cp', nccl_comm_cfgs),
             group_desc='INTER_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP',
         )

@@ -878,7 +886,7 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):
             ranks,
             context_parallel_size,
             hierarchical_context_parallel_sizes,
-            get_nccl_options('cp', nccl_comm_cfgs),
+            get_nccl_options('hcp', nccl_comm_cfgs),
         )

     # Build the model-parallel groups.
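
The options built above are what the NCCL backend ultimately receives when each of these groups is created. A sketch of that hand-off, assuming create_group() forwards its arguments to torch.distributed.new_group(); the ranks and timeout are illustrative, and nccl_comm_cfgs/get_nccl_options are the ones from the sketch above.

from datetime import timedelta
import torch.distributed as dist

# Sketch only: assumes torch.distributed has already been initialized with the
# NCCL backend. A key missing from the config yields pg_options=None, so the
# group falls back to the backend defaults.
cp_group = dist.new_group(
    ranks=[0, 1, 2, 3],                                  # illustrative ranks only
    timeout=timedelta(minutes=10),
    pg_options=get_nccl_options('hcp', nccl_comm_cfgs),  # None here -> default options
)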
@@ -959,7 +967,7 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):
         group = create_group(
             position_embedding_ranks,
             timeout=timeout,
-            pg_options=get_nccl_options('embd', nccl_comm_cfgs),
+            pg_options=get_nccl_options('pos_embd', nccl_comm_cfgs),
             group_desc='POSITION_EMBEDDING_GROUP',
         )
         if rank in position_embedding_ranks:
@@ -1012,7 +1020,7 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):
     for ranks in generator_wrapper('ep', is_expert=True):
         group = create_group(
             ranks,
-            pg_options=get_nccl_options('exp', nccl_comm_cfgs),
+            pg_options=get_nccl_options('ep', nccl_comm_cfgs),
             group_desc='EXPERT_MODEL_PARALLEL_GROUP',
         )
         if rank in ranks:
@@ -1027,7 +1035,7 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):
         group = create_group(
             ranks,
             timeout=timeout,
-            pg_options=get_nccl_options('tp', nccl_comm_cfgs),
+            pg_options=get_nccl_options('ep_tp', nccl_comm_cfgs),
             group_desc='EXPERT_TENSOR_PARALLEL_GROUP',
         )
         if rank in ranks:
@@ -1042,7 +1050,7 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):
         group = create_group(
             ranks,
             timeout=timeout,
-            pg_options=get_nccl_options('tp_exp', nccl_comm_cfgs),
+            pg_options=get_nccl_options('tp_ep_mp', nccl_comm_cfgs),
             group_desc='EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP',
         )
         if rank in ranks:
@@ -1057,7 +1065,7 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):
         group = create_group(
             ranks,
             timeout=timeout,
-            pg_options=get_nccl_options('mp', nccl_comm_cfgs),
+            pg_options=get_nccl_options('tp_ep_pp', nccl_comm_cfgs),
             group_desc='EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP',
         )
         if rank in ranks:
@@ -1073,7 +1081,7 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):
         group = create_group(
             ranks,
             timeout=timeout,
-            pg_options=get_nccl_options('dp', nccl_comm_cfgs),
+            pg_options=get_nccl_options('ep_dp', nccl_comm_cfgs),
             group_desc='EXPERT_DATA_PARALLEL_GROUP',
         )
         group_gloo = create_group(
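Taken together, the renames in this commit give every process group its own lookup key in the communicator config instead of reusing a key that belongs to an unrelated group. A hypothetical config that touches each renamed key; the option values are placeholders, and in practice the same keys would appear at the top level of the YAML file passed via --nccl-communicator-config-path.

# Hypothetical settings keyed by the standardized group names; the comments
# note the key each group was previously looked up under.
nccl_comm_cfgs = {
    'intra_dp_cp': {'cga_cluster_size': 4},  # was 'dp_cp'
    'inter_dp_cp': {'net_name': 'IB'},       # was 'dp_cp'
    'hcp': {'max_ctas': 8},                  # hierarchical context parallel, was 'cp'
    'pos_embd': {'min_ctas': 1},             # position embedding, was 'embd'
    'ep': {'cga_cluster_size': 2},           # expert model parallel, was 'exp'
    'ep_tp': {'max_ctas': 16},               # expert tensor parallel, was 'tp'
    'tp_ep_mp': {'max_ctas': 16},            # expert tensor + model parallel, was 'tp_exp'
    'tp_ep_pp': {'max_ctas': 16},            # expert tensor-model-pipeline parallel, was 'mp'
    'ep_dp': {'net_name': 'socket'},         # expert data parallel, was 'dp'
}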
