Merge branch 'fix_dataloader_save' into 'main'
Fix dataloader save state

See merge request ADLR/megatron-lm!2580
jaredcasper committed Jan 28, 2025
2 parents 5cf351f + ba8231f commit 3d1554d
Showing 1 changed file with 8 additions and 8 deletions.
megatron/training/checkpointing.py
@@ -374,7 +374,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
         tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel, expert_rank=expert_rank, return_base_dir=return_base_dir)
 
     # Save dataloader state if the dataloader supports it (currently only Megatron Energon).
-    save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None))
+    maybe_save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None))
 
     # Save distributed optimizer's custom parameter state.
     if (
@@ -562,7 +562,7 @@ def remove_iter_ckpts(_iter_ckpts):
     remove_iter_ckpts(rm_iter_ckpts)
 
 
-def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
+def maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path):
     """Saves dataloader state if the dataloader supports it.
 
     Currently, this is only used by Megatron Energon dataloader (multimodal) to store its state at a
@@ -577,13 +577,13 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
         iteration (int): Current iteration.
         dataloader_save_path (str): Path where the dataloader state is saved.
     """
-    # If no dataloader or saving path is provided, then exit early.
-    if train_iterator is None or dataloader_save_path is None:
+    # If no dataloader or saving path is provided, exit early; otherwise, raise an error.
+    if train_iterator is None or dataloader_save_path is None or dataloader_save_path == "":
         return
 
-    # If dataloader doesn't support saving state, exit early.
-    if not hasattr(train_iterator, "save_state"):
-        return
+    # If the dataloader doesn't support saving state, raise an error.
+    if not hasattr(train_iterator.iterable, "save_state"):
+        raise RuntimeError(f"Could not find a save_state for the train_iterator of type {type(train_iterator)}")
 
     # Save dataloader state for each data parallel rank only once.
     first_rank = mpu.is_pipeline_first_stage(ignore_virtual=True) and mpu.get_tensor_model_parallel_rank() == 0
@@ -592,7 +592,7 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
 
     dp_rank = mpu.get_data_parallel_rank()
     print(f"saving dataloader checkpoint at iteration {iteration} to {dataloader_save_path}")
-    train_dataloader_state_dict = train_iterator.save_state()
+    train_dataloader_state_dict = train_iterator.iterable.save_state()
     data_state_save_path = get_checkpoint_name(
         dataloader_save_path, iteration,
         basename=f'train_dataloader_dprank{dp_rank:03d}.pt'
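
For readers skimming the diff, here is a minimal, self-contained sketch of the behavioral change, written against plain Python. The old save_dataloader_state silently returned when the iterator lacked save_state; the new maybe_save_dataloader_state looks for save_state on the wrapped .iterable and raises a RuntimeError if it is missing. EnergonLikeLoader, TrainIterator, and sketch_maybe_save_dataloader_state below are hypothetical stand-ins for illustration, not Megatron-LM code.

# Hypothetical stand-ins (not Megatron-LM classes) used to illustrate the new check.
class EnergonLikeLoader:
    """Pretends to be a dataloader that can checkpoint itself (e.g. Megatron Energon)."""
    def save_state(self):
        return {"sample_index": 1234}

class TrainIterator:
    """Pretends to be the iterator wrapper whose underlying loader is exposed as `.iterable`."""
    def __init__(self, iterable):
        self.iterable = iterable

def sketch_maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path):
    # Exit early if there is nothing to save or nowhere to save it.
    if train_iterator is None or dataloader_save_path is None or dataloader_save_path == "":
        return None
    # New behavior: fail loudly instead of silently skipping the save.
    if not hasattr(train_iterator.iterable, "save_state"):
        raise RuntimeError(
            f"Could not find a save_state for the train_iterator of type {type(train_iterator)}"
        )
    return train_iterator.iterable.save_state()

state = sketch_maybe_save_dataloader_state(TrainIterator(EnergonLikeLoader()), 100, "/tmp/dataloader_ckpt")
print(state)  # {'sample_index': 1234}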
