Skip to content

Commit

Permalink
fix (#6588)
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay authored Aug 2, 2023
1 parent caaf09d commit 38a2be0
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion llm/llama/modeling_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer

from paddlenlp.transformers import PretrainedModel, ScatterOp
from paddlenlp.transformers import PretrainedModel
from paddlenlp.transformers.llama.modeling import (
LlamaConfig,
LlamaDecoderLayer,
Expand Down Expand Up @@ -94,6 +94,8 @@ def forward(self, args):
input_ids, attention_mask, position_ids = parse_args(args)
input_embeds = self.embed_tokens(input_ids)
if self.sequence_parallel:
from paddlenlp.transformers import ScatterOp

input_embeds = paddle.transpose(input_embeds, perm=[1, 0, 2])
input_embeds = ScatterOp.apply(input_embeds)

Expand Down

0 comments on commit 38a2be0

Please sign in to comment.