From 1e49c9d86859a04d9e017ec722d595011ace9c49 Mon Sep 17 00:00:00 2001
From: Helen Ngo
Date: Tue, 17 Dec 2024 17:40:20 -0800
Subject: [PATCH] ADLR/megatron-lm!2478 - Fix accidental inference pipelining
 when it should be disabled

---
 megatron/inference/text_generation/forward_step.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/megatron/inference/text_generation/forward_step.py b/megatron/inference/text_generation/forward_step.py
index 0a89936ed2..aaa518fad4 100644
--- a/megatron/inference/text_generation/forward_step.py
+++ b/megatron/inference/text_generation/forward_step.py
@@ -6,11 +6,10 @@
 
 import torch
 
+from megatron.core import InferenceParams, mpu
 from megatron.training import get_args
-from megatron.core import mpu, InferenceParams
-from .communication import (
-    send_to_next_pipeline_rank,
-    recv_from_prev_pipeline_rank_)
+
+from .communication import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank
 
 
 class ForwardStep:
@@ -46,7 +45,7 @@ def __call__(self, tokens, position_ids, attention_mask, recv_buffer_seq_length=
         # This runs only if current_batch_x_seqlen > args.inference_batch_times_seqlen_threshold
         # and requires setting args.pipeline_model_parallel > 1. The batch will be split into
         # smaller microbatches to be pipelined through the stages.
-        if self.pipeline_size_larger_than_one:
+        if self.pipeline_size_larger_than_one and self.pipelining_batch_x_seqlen != -1:
             seq_len = tokens.size(1) if recv_buffer_seq_length is None else recv_buffer_seq_length
             current_batch_x_seqlen = tokens.size(0) * seq_len
             if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
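
Why the extra guard is needed, in a minimal standalone sketch: when pipelining is disabled, pipelining_batch_x_seqlen is assumed to carry the sentinel value -1 (inferred from the new check in the second hunk), so the old comparison current_batch_x_seqlen >= -1 was always true and micro-batched pipelining ran even though it should have been off. The helper should_pipeline below is hypothetical and only mirrors the patched condition; it is not Megatron-LM source.

# Minimal sketch, assuming -1 means "pipelining disabled" (inferred from the
# new `!= -1` check). `should_pipeline` is a hypothetical helper, not the
# Megatron-LM implementation.
def should_pipeline(batch_size: int, seq_len: int,
                    pipeline_size_larger_than_one: bool,
                    pipelining_batch_x_seqlen: int) -> bool:
    """Return True if the forward step would be split into micro-batches."""
    if not (pipeline_size_larger_than_one and pipelining_batch_x_seqlen != -1):
        return False
    current_batch_x_seqlen = batch_size * seq_len
    # Without the `!= -1` guard, this comparison is always True when the
    # threshold is -1, which accidentally enabled pipelining.
    return current_batch_x_seqlen >= pipelining_batch_x_seqlen

# Threshold disabled (-1): pipelining stays off. Threshold set: pipelining
# kicks in once batch_size * seq_len reaches the threshold.
assert should_pipeline(8, 512, True, -1) is False
assert should_pipeline(8, 512, True, 1024) is True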