Merge branch 'helenn-pipeline-parallel-fix-flash-decode' into 'main'
Fix accidental inference pipelining when it should be disabled

See merge request ADLR/megatron-lm!2478
jaredcasper committed Dec 18, 2024
2 parents d995e9c + 1e49c9d commit 584e4f9
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions megatron/inference/text_generation/forward_step.py
@@ -6,11 +6,10 @@
 
 import torch
 
+from megatron.core import InferenceParams, mpu
 from megatron.training import get_args
-from megatron.core import mpu, InferenceParams
-from .communication import (
-    send_to_next_pipeline_rank,
-    recv_from_prev_pipeline_rank_)
+
+from .communication import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank
 
 
 class ForwardStep:
@@ -46,7 +45,7 @@ def __call__(self, tokens, position_ids, attention_mask, recv_buffer_seq_length=
         # This runs only if current_batch_x_seqlen > args.inference_batch_times_seqlen_threshold
         # and requires setting args.pipeline_model_parallel > 1. The batch will be split into
         # smaller microbatches to be pipelined through the stages.
-        if self.pipeline_size_larger_than_one:
+        if self.pipeline_size_larger_than_one and self.pipelining_batch_x_seqlen != -1:
             seq_len = tokens.size(1) if recv_buffer_seq_length is None else recv_buffer_seq_length
             current_batch_x_seqlen = tokens.size(0) * seq_len
             if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
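
For context, the previous guard only checked self.pipeline_size_larger_than_one, so a threshold of -1 (its "disabled" value) still triggered pipelining: any positive tokens.size(0) * seq_len compares >= -1. Below is a minimal sketch of the corrected condition, not the Megatron-LM source; the helper name should_pipeline is illustrative, and it assumes self.pipelining_batch_x_seqlen is populated from args.inference_batch_times_seqlen_threshold, where -1 means pipelining is turned off.

# Minimal sketch of the guard logic after this commit (illustration only,
# not the actual ForwardStep implementation).

def should_pipeline(pipeline_size_larger_than_one: bool,
                    pipelining_batch_x_seqlen: int,
                    batch_size: int,
                    seq_len: int) -> bool:
    """Decide whether inference should be split into pipelined micro-batches."""
    # The added `!= -1` check: before the fix, a disabled threshold of -1
    # still passed, because batch_size * seq_len is always >= -1.
    if pipeline_size_larger_than_one and pipelining_batch_x_seqlen != -1:
        current_batch_x_seqlen = batch_size * seq_len
        return current_batch_x_seqlen >= pipelining_batch_x_seqlen
    return False

# With the threshold at its disabled value, the fixed guard no longer pipelines.
assert should_pipeline(True, -1, batch_size=8, seq_len=512) is False
# With a real threshold, a large batch-times-seqlen still pipelines as before.
assert should_pipeline(True, 2048, batch_size=8, seq_len=512) is True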
