【New Feature】Support Ulysses sequence parallelism (#187)
Add DeepSpeed-Ulysses sequence parallelism. `context-parallel` and
`ulysses-parallel` can be enabled together, or either one alone, to train
long-sequence models.

Set `ulysses_sp_parallel_size` in the config file:
```
system:
  tensor_model_parallel_size: 1
  pipeline_model_parallel_size: 1
  ulysses_sp_parallel_size: 2
```
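
For example, to combine context parallelism with Ulysses parallelism, a config along these lines should work. This is a sketch rather than part of the commit: it assumes the context-parallel size is set with a `context_parallel_size` key mirroring the Megatron argument name, and the product of all parallel sizes must still divide the world size.
```
system:
  tensor_model_parallel_size: 1
  pipeline_model_parallel_size: 1
  context_parallel_size: 2
  ulysses_sp_parallel_size: 2
```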

Reference:
Jiarui Fang and Shangchun Zhao. 2024. USP: A Unified Sequence
Parallelism Approach for Long Context Generative AI.
https://doi.org/10.48550/arXiv.2405.07719

---------

Co-authored-by: lizhiyu01 <zyli@baai.ac.cn>
heavyrain-lzy and lizhiyu01 authored Aug 22, 2024
1 parent 57cbd0a commit f98ee1e
Showing 15 changed files with 608 additions and 85 deletions.
3 changes: 2 additions & 1 deletion flagscale/train/hetero/parallel_context.py
@@ -175,6 +175,7 @@ def __init__(
dp=data_parallel_size,
pp=pipeline_model_parallel_size,
cp=context_parallel_size,
usp=1,
order=order,
)

@@ -465,7 +466,7 @@ def build_all_process_meshes(self):
expert_model_parallel_size=ep,
nccl_communicator_config_path=self._args.nccl_communicator_config_path,
distributed_timeout_minutes=self._args.distributed_timeout_minutes,
-order='tp-cp-ep-dp-pp' if not self._args.use_tp_pp_dp_mapping else 'tp-pp-dp',
+order='tp-usp-cp-ep-dp-pp' if not self._args.use_tp_pp_dp_mapping else 'tp-pp-dp',
offset=accumulated_world_size,
rank_mapper=self._rank_mapper,
)
11 changes: 9 additions & 2 deletions flagscale/train/train_aquila.py
@@ -28,6 +28,7 @@
from megatron.core.utils import StragglerDetector
from megatron.core.transformer.spec_utils import import_module
from megatron.training.utils import (
get_batch_on_this_ulysses_sp_rank,
get_batch_on_this_cp_rank,
get_batch_on_this_tp_rank,
)
@@ -114,7 +115,10 @@ def get_batch(data_iterator):

# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)


# slice batch along sequence dimension for ulysses sequence parallelism
batch = get_batch_on_this_ulysses_sp_rank(batch)

return batch.values()


@@ -138,6 +142,9 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
total_tokens = loss_mask.sum()
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

if args.ulysses_sp_parallel_size > 1:
torch.distributed.all_reduce(loss, group=mpu.get_ulysses_sp_parallel_group())

if args.context_parallel_size > 1:
torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())

@@ -155,7 +162,7 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):

local_num_tokens = loss[1].clone().detach().to(torch.int)
return (
-loss[0] * args.context_parallel_size,
+loss[0] * args.context_parallel_size * args.ulysses_sp_parallel_size,
local_num_tokens,
{'lm loss': (reporting_loss[0], reporting_loss[1])},
)
245 changes: 245 additions & 0 deletions flagscale/train/transformer/ulysses_sp_attention.py
@@ -0,0 +1,245 @@
import copy
from typing import Union, Any, Tuple
from dataclasses import dataclass

import torch
import torch.distributed

from megatron.core import parallel_state
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.attention import SelfAttention
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb


def post_all2all(input, scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):

if batch_dim_idx == 0:
# b, s, n, h
if scatter_idx < 2:
output = input.permute(1, 2, 0, 3, 4).contiguous()
output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head,
head_dim).contiguous()
else:
output = input.permute(1, 0, 2, 3, 4).contiguous()
output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size,
head_dim).contiguous()
else:
# s, b, n, h
if scatter_idx < 2:
output = input.permute(1, 2, 0, 3, 4).contiguous()
output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head,
head_dim).contiguous()
else:
output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous()
return output


def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group):
seq_world_size = parallel_state.get_ulysses_sp_parallel_world_size()
if batch_dim_idx == 0:
# b, s, hc, h
if scatter_idx < 2: # all_to_all for output or backward
bs, global_seq_len, num_local_head, head_dim = input.shape
input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head,
head_dim]).contiguous()
input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()
else:
bs, local_seq_len, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size,
head_dim]).contiguous()
input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
else:
# s, b, hc, h
if scatter_idx < 2: # all_to_all for output or backward
global_seq_len, bs, num_local_head, head_dim = input.shape
input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head,
head_dim]).contiguous()
else:
local_seq_len, bs, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size,
head_dim]).contiguous()
input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()

output = torch.empty_like(input_t)
torch.distributed.all_to_all_single(output, input_t, group=group)

if scatter_idx < 2:
res = post_all2all(output, scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head,
head_dim)
else:
res = post_all2all(output, scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head,
head_dim)
return res


class _SeqAllToAll(torch.autograd.Function):

@staticmethod
def forward(ctx: Any,
group,
input: torch.Tensor,
scatter_idx: int = 0,
gather_idx: int = 2,
batch_dim_idx: int = 1) -> torch.Tensor:
ctx.group = group
ctx.scatter_idx = scatter_idx
ctx.gather_idx = gather_idx
ctx.batch_dim_idx = batch_dim_idx
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group)
return res

@staticmethod
def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None, None]:

return (None,
_SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx),
None, None, None)



@dataclass
class USPSelfAttentionSubmodules:
linear_qkv: Union[ModuleSpec, type] = None
core_attention: Union[ModuleSpec, type] = None
linear_proj: Union[ModuleSpec, type] = None
q_layernorm: Union[ModuleSpec, type] = None
k_layernorm: Union[ModuleSpec, type] = None


class USPSelfAttention(SelfAttention):
def __init__(
self,
config: TransformerConfig,
submodules: USPSelfAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
):
super().__init__(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
)
self.usp_size = parallel_state.get_ulysses_sp_parallel_world_size()
self.usp_group = parallel_state.get_ulysses_sp_parallel_group()
te_attn_config = copy.deepcopy(config)
assert config.num_attention_heads % self.usp_size == 0, \
f"num_attention_heads[{config.num_attention_heads}] must be divisible by usp_size[{self.usp_size}]"
assert config.num_query_groups % self.usp_size == 0, \
f"num_query_groups[{config.num_query_groups}] must be divisible by usp_size[{self.usp_size}]"
te_attn_config.num_attention_heads = config.num_attention_heads // self.usp_size
te_attn_config.num_query_groups = config.num_query_groups // self.usp_size

self.core_attention = build_module(
submodules.core_attention,
config=te_attn_config,
layer_number=self.layer_number,
attn_mask_type=self.attn_mask_type,
attention_type=self.attention_type,
)

def forward(
self,
hidden_states,
attention_mask,
key_value_states=None,
inference_params=None,
rotary_pos_emb=None,
packed_seq_params=None,
):
# hidden_states: [sq, b, h]

# For self attention we just duplicate the rotary_pos_emb if it isn't already
if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = (rotary_pos_emb,) * 2

# =====================
# Query, Key, and Value
# =====================
# Get the query, key and value tensors based on the type of attention -
# self or cross attn.
query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)

# ===================================================
# Adjust key, value, and rotary_pos_emb for inference
# ===================================================
key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
inference_params, key, value, rotary_pos_emb
)

if packed_seq_params is not None:
query = query.squeeze(1)
key = key.squeeze(1)
value = value.squeeze(1)

# ================================================
# relative positional embedding (rotary embedding)
# ================================================
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb

if packed_seq_params is not None:
cu_seqlens_q = packed_seq_params.cu_seqlens_q
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
else:
cu_seqlens_q = cu_seqlens_kv = None
query = apply_rotary_pos_emb(
query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q,
)
key = apply_rotary_pos_emb(
key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv,
)

query = _SeqAllToAll.apply(self.usp_group, query, 2, 0)
key = _SeqAllToAll.apply(self.usp_group, key, 2, 0)
value = _SeqAllToAll.apply(self.usp_group, value, 2, 0)

# ==================================
# core attention computation
# ==================================

if self.checkpoint_core_attention and self.training:
core_attn_out = self._checkpointed_attention_forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
packed_seq_params=packed_seq_params,
)
else:
core_attn_out = self.core_attention(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
packed_seq_params=packed_seq_params,
)

# ================================================
# scatter out along the sequence dimension(0) and gather along the head dimension(2)
# ================================================

core_attn_out = core_attn_out.view(query.shape)
core_attn_out = _SeqAllToAll.apply(self.usp_group, core_attn_out, 0, 2)
core_attn_out = core_attn_out.view(*core_attn_out.shape[:2], -1)

# =================
# Output. [sq, b, h]
# =================

if packed_seq_params is not None:
# reshape to same output shape as unpacked case
# (t, np, hn) -> (t, b=1, h=np*hn)
# t is the pack size = sum (sq_i)
# note that batch is a dummy dimension in the packed case
core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)

output, bias = self.linear_proj(core_attn_out)

return output, bias
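
To make the data movement in `USPSelfAttention.forward` easier to follow, here is a minimal single-process sketch (not part of the commit; `usp` and the shapes are arbitrary assumptions) of the layout transforms that `single_all_to_all` and `post_all2all` apply around the all-to-all for the `[s, b, n, h]` layout. The all-to-all itself only exchanges the leading `usp` chunks across ranks and preserves the tensor shape, so it is omitted here.
```
# Illustrative sketch of the Ulysses layout transforms; in the real code the
# shape-preserving torch.distributed.all_to_all_single sits between the two
# reshape steps of each direction.
import torch

usp = 2                                  # assumed Ulysses SP world size
local_seq, bs, heads, hd = 4, 1, 8, 16   # per-rank [s, b, n, h] before attention

# Pre-attention (scatter_idx=2, gather_idx=0): split the head dim into usp groups
# and move the group index to the front so it can be scattered across ranks ...
q = torch.randn(local_seq, bs, heads, hd)
q_t = q.reshape(local_seq, bs, usp, heads // usp, hd).permute(2, 0, 1, 3, 4).contiguous()
# ... then fold that leading dim into the sequence dim: every rank now sees the
# full sequence but only heads // usp of the attention heads.
q_out = q_t.reshape(usp * local_seq, bs, heads // usp, hd)
print(tuple(q.shape), "->", tuple(q_out.shape))       # (4, 1, 8, 16) -> (8, 1, 4, 16)

# Post-attention (scatter_idx=0, gather_idx=2): the inverse transform scatters the
# sequence back and gathers the heads, restoring the original local layout.
o_t = q_out.reshape(usp, local_seq, bs, heads // usp, hd)
o_back = o_t.permute(1, 2, 0, 3, 4).reshape(local_seq, bs, heads, hd)
print(tuple(q_out.shape), "->", tuple(o_back.shape))  # (8, 1, 4, 16) -> (4, 1, 8, 16)
```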
@@ -173,7 +173,7 @@ def allocate_buffers_for_parameters(
# Allocate the param+grad buffers for dense params' grads.
self.buffers = allocate_buffers_for_parameters(
dense_params,
-parallel_state.get_data_parallel_group(with_context_parallel=True),
+parallel_state.get_data_parallel_group(with_context_parallel=True, with_ulysses_sp_parallel=True),
gradient_scaling_factor=gradient_scaling_factor,
)

3 changes: 3 additions & 0 deletions megatron/megatron/core/model_parallel_config.py
@@ -38,6 +38,9 @@ class ModelParallelConfig:

context_parallel_size: int = 1
"""Splits network input along sequence dimension across GPU ranks."""

ulysses_sp_parallel_size: int = 1
"""Splits network input along sequence dimension across GPU ranks using deepspeed-ulysses method."""

expert_model_parallel_size: int = 1
"""Distributes Moe Experts across sub data parallel dimension."""
@@ -44,6 +44,22 @@ def get_pos_emb_on_this_cp_rank(pos_emb, seq_dim):
pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :])
return pos_emb

def get_pos_emb_on_this_usp_rank(pos_emb, seq_dim):
'''
Slice rotary position embeddings along the sequence dimension and select
the partition belonging to the current Ulysses SP rank.
'''
usp_size = parallel_state.get_ulysses_sp_parallel_world_size()
usp_rank = parallel_state.get_ulysses_sp_parallel_rank()
usp_idx = torch.tensor(
[usp_rank], device="cpu", pin_memory=True
).cuda(non_blocking=True)
pos_emb = pos_emb.view(
*pos_emb.shape[:seq_dim], usp_size, -1, *pos_emb.shape[(seq_dim + 1) :]
)
pos_emb = pos_emb.index_select(seq_dim, usp_idx)
pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :])
return pos_emb


class RotaryEmbedding(nn.Module):
"""Rotary Embedding for language model.
@@ -113,6 +129,8 @@ def forward(self, max_seq_len: int, offset: int = 0) -> Tensor:
if parallel_state.get_context_parallel_world_size() > 1:
# slice rotary_pos_emb along sequence dimension and select the partition of the current CP rank
emb = get_pos_emb_on_this_cp_rank(emb, 0)
if parallel_state.get_ulysses_sp_parallel_world_size() > 1:
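# slice rotary_pos_emb along sequence dimension and select the partition of the current USP rank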
emb = get_pos_emb_on_this_usp_rank(emb, 0)
return emb

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
@@ -148,7 +166,7 @@ def get_rotary_seq_len(
if transformer_config.sequence_parallel:
rotary_seq_len *= parallel_state.get_tensor_model_parallel_world_size()

-rotary_seq_len *= transformer_config.context_parallel_size
+rotary_seq_len *= transformer_config.context_parallel_size * transformer_config.ulysses_sp_parallel_size

return rotary_seq_len

7 changes: 5 additions & 2 deletions megatron/megatron/core/models/gpt/gpt_layer_specs.py
@@ -1,5 +1,6 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from megatron.core import parallel_state
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
@@ -12,6 +13,8 @@
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

from flagscale.train.transformer.ulysses_sp_attention import USPSelfAttention

try:
from megatron.core.transformer.custom_layers.transformer_engine import (
TEColumnParallelGroupedLinear,
@@ -53,7 +56,7 @@ def get_gpt_layer_with_transformer_engine_spec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
-module=SelfAttention,
+module=SelfAttention if parallel_state.get_ulysses_sp_parallel_world_size() <= 1 else USPSelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
@@ -85,7 +88,7 @@ def get_gpt_layer_local_spec(
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
-module=SelfAttention,
+module=SelfAttention if parallel_state.get_ulysses_sp_parallel_world_size() <= 1 else USPSelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,