Merge branch 'skierat/finalize_local_checkpointing' into 'main'
Finalize local checkpointing support

See merge request ADLR/megatron-lm!2217
ericharper committed Jan 30, 2025
2 parents fb591c7 + cef5154 commit 2b4b680
Showing 21 changed files with 1,135 additions and 492 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.ci.lts
@@ -75,4 +75,4 @@ RUN --mount=type=secret,id=JET_INDEX_URLS \
JET_INDEX_URLS=$(cat /run/secrets/JET_INDEX_URLS) && \
pip install "jet-client~=2.0" jet-api --upgrade $JET_INDEX_URLS
ENV PATH="$PATH:/opt/jet/bin"
###
###
148 changes: 71 additions & 77 deletions megatron/core/dist_checkpointing/exchange_utils.py
@@ -6,7 +6,6 @@
from collections import defaultdict
from functools import reduce
from itertools import zip_longest
from time import time
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast

import numpy as np
@@ -15,7 +14,7 @@
from .core import CheckpointingException
from .dict_utils import nested_values
from .mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .utils import _sharded_tensor_shard_id, _ShardId
from .utils import _sharded_tensor_shard_id, _ShardId, debug_time

# TODO: remove TE references once the TE bug is fixed
# Check if Transformer Engine has Float8Tensor class
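
The hunk above replaces the ad-hoc `from time import time` import with a shared `debug_time` helper from `.utils`, which the rest of this diff uses both as a decorator (`@debug_time(..., logger)`) and as a context manager (`with debug_time(...):`). Its real implementation lives in megatron/core/dist_checkpointing/utils.py and is not part of this diff; a minimal sketch of a helper supporting both call patterns, under that assumption only, could look like:

import logging
import time
from contextlib import ContextDecorator

class debug_time(ContextDecorator):
    # Sketch only: logs elapsed wall-clock time for a function or block.
    # Subclassing ContextDecorator lets one object serve both patterns seen
    # in this diff: @debug_time("name", logger) and `with debug_time("name"):`.
    # The actual Megatron implementation may differ (e.g. rank-0-only logging,
    # as the removed inline code did).
    def __init__(self, name, logger=None):
        self.name = name
        self.logger = logger or logging.getLogger(__name__)

    def __enter__(self):
        self._start = time.time()
        return self

    def __exit__(self, *exc):
        self.logger.debug(f'{self.name} took {time.time() - self._start:.4f}s')
        return False  # never suppress exceptions
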
@@ -52,7 +51,6 @@ class ShardDistribution(NamedTuple):
identifier to the original ShardedTensor
all_ranks_for_shard (Dict[_ShardId, List[int]]): specifies which ranks
need a given shard in a given parallelization group
"""

main_rank_for_shard: Dict[_ShardId, int]
@@ -237,6 +235,7 @@ def determine_main_replica_uniform_distribution(


@torch.no_grad()
@debug_time(f"exchange_loaded_tensors_gather_rounds", logger)
def exchange_loaded_tensors_gather_rounds(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
Expand Down Expand Up @@ -276,76 +275,75 @@ def exchange_loaded_tensors_gather_rounds(
# Group by dtype so that we all_gather tensors of the same dtype
for dtype in sorted(set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str):

start = time()
# shards_by_rank maps rank to tensors loaded by this rank
shards_by_rank: List[List[torch.Tensor]] = [
[] for _ in range(torch.distributed.get_world_size(group=parallelization_group))
]
for shard_id, rank in main_rank_for_shard.items():
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f'When there is only 1 rank that needs a given shard,'
f' it should be the loading rank.'
f' Got: needs [{all_ranks_for_shard[shard_id][0]}]'
f' vs loads [{main_rank_for_shard[shard_id]}]'
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1`
# case, e.g. P2P exchange. Currently handling this case saves most of the
# work though.
continue
if shard_to_metadata[shard_id].dtype == dtype:
shards_by_rank[rank].append(shard_id)

# Transpose `shards_by_rank` to form exchange rounds
shards_by_round = zip_longest(*shards_by_rank, fillvalue=None)
for round_idx, round_shard_ids in enumerate(shards_by_round):
round_tensors = []
orig_devices = {}
for rank, shard_id in enumerate(round_shard_ids):
if shard_id is None:
# if there is no more useful data, the given rank will exchange an empty tensor
local_ten = torch.empty(0, dtype=dtype, device='cuda')
orig_device = None
else:
assert isinstance(shard_id, tuple), type(shard_id)
if rank == local_rank:
assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys())
orig_device = all_loaded_tensors[shard_id]
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
local_ten = all_loaded_tensors[shard_id]
with debug_time(f"dtype_{dtype}"):
# shards_by_rank maps rank to tensors loaded by this rank
shards_by_rank: List[List[torch.Tensor]] = [
[] for _ in range(torch.distributed.get_world_size(group=parallelization_group))
]
for shard_id, rank in main_rank_for_shard.items():
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f'When there is only 1 rank that needs a given shard,'
f' it should be the loading rank.'
f' Got: needs [{all_ranks_for_shard[shard_id][0]}]'
f' vs loads [{main_rank_for_shard[shard_id]}]'
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1`
# case, e.g. P2P exchange. Currently handling this case saves most of the
# work though.
continue
if shard_to_metadata[shard_id].dtype == dtype:
shards_by_rank[rank].append(shard_id)

# Transpose `shards_by_rank` to form exchange rounds
shards_by_round = zip_longest(*shards_by_rank, fillvalue=None)
for round_idx, round_shard_ids in enumerate(shards_by_round):
round_tensors = []
orig_devices = {}
for rank, shard_id in enumerate(round_shard_ids):
if shard_id is None:
# if there is no more useful data, the given rank will exchange an empty tensor
local_ten = torch.empty(0, dtype=dtype, device='cuda')
orig_device = None
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
local_ten = local_ten.from_float8()
all_loaded_tensors[shard_id] = local_ten

round_tensors.append(local_ten)
if orig_device is not None:
orig_devices[shard_id] = orig_device

torch.distributed.all_gather(
list(round_tensors),
round_tensors[local_rank],
group=parallelization_group,
async_op=False,
)

# Move tensors back to CPU if they were originally on CPU
for shard_id, orig_device in orig_devices.items():
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device)
assert isinstance(shard_id, tuple), type(shard_id)
if rank == local_rank:
assert shard_id in all_loaded_tensors, (
shard_id,
all_loaded_tensors.keys(),
)
orig_device = all_loaded_tensors[shard_id]
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
local_ten = all_loaded_tensors[shard_id]
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
local_ten = local_ten.from_float8()
all_loaded_tensors[shard_id] = local_ten

round_tensors.append(local_ten)
if orig_device is not None:
orig_devices[shard_id] = orig_device

torch.distributed.all_gather(
list(round_tensors),
round_tensors[local_rank],
group=parallelization_group,
async_op=False,
)

del round_tensors # remove tensor references
# Move tensors back to CPU if they were originally on CPU
for shard_id, orig_device in orig_devices.items():
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device)

end = time()
if torch.distributed.get_rank() == 0:
logger.debug(f'{dtype} exchange rounds all_gather schedule took {end - start}s')
del round_tensors # remove tensor references

return all_loaded_tensors
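
For reference, the round-forming trick in the hunk above: transposing the per-rank shard-id lists with `zip_longest` yields one tuple per exchange round, padding ranks that have run out of shards with `None` (those ranks contribute an empty tensor to the all_gather for that round). A self-contained illustration with made-up shard ids:

from itertools import zip_longest

# Hypothetical shard ids; in the real code these are _ShardId tuples.
shards_by_rank = [
    ['a0', 'a1', 'a2'],  # shards loaded by rank 0
    ['b0'],              # shards loaded by rank 1
    ['c0', 'c1'],        # shards loaded by rank 2
]
for round_idx, round_shard_ids in enumerate(zip_longest(*shards_by_rank, fillvalue=None)):
    print(round_idx, round_shard_ids)
# 0 ('a0', 'b0', 'c0')
# 1 ('a1', None, 'c1')
# 2 ('a2', None, None)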

@@ -397,6 +395,7 @@ def exchange_loaded_tensors_gather_object(


@torch.no_grad()
@debug_time("exchange_loaded_tensors_broadcast", logger)
def exchange_loaded_tensors_broadcast(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
@@ -427,8 +426,6 @@ def exchange_loaded_tensors_broadcast(

all_loaded_tensors = dict(loaded_tensors)

start = time()

for idx, (shard_id, rank) in enumerate(main_rank_for_shard.items()):
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
@@ -475,17 +472,13 @@
all_loaded_tensors[shard_id] = local_ten.to(orig_device)
del local_ten

end = time()
if torch.distributed.get_rank() == 0:
logger.debug(f'exchange broadcast schedule took {end - start}s')

return all_loaded_tensors
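
The timing cleanup in this function mirrors the gather-rounds change: the removed `start = time()` / `logger.debug(...)` pair is superseded by the `@debug_time` decorator added above. For context, the core of the broadcast algorithm is one collective per shard, rooted at the shard's main (loading) rank. A simplified sketch of that pattern, assuming an initialized process group (the helper name and exact call site are assumptions, not code from this commit):

import torch

def broadcast_shard_sketch(local_ten, main_rank, group):
    # torch.distributed.broadcast takes a global rank as `src`, so the
    # group-local main rank must be translated first.
    global_src = torch.distributed.get_global_rank(group, main_rank)
    # In-place: the source rank sends its tensor, everyone else receives.
    torch.distributed.broadcast(local_ten, src=global_src, group=group)
    return local_ten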


def exchange_by_distribution(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution = None,
shard_distribution: ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
exchange_algo='broadcast',
) -> Dict[_ShardId, torch.Tensor]:
@@ -508,6 +501,7 @@
previously loaded tensors (from `loaded_tensors` input)
"""

assert shard_distribution is not None, 'Expecting distribution to perform exchange'
if exchange_algo == 'gather_object':
exchange_fn = exchange_loaded_tensors_gather_object
elif exchange_algo == 'gather_rounds':
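
Two changes land in `exchange_by_distribution`: `shard_distribution` loses its `= None` default and becomes a required positional parameter, and a new assert rejects `None` explicitly instead of failing later. The tail of the dispatch is folded out of this view; based on the two visible branches, its overall shape is presumably as below (the 'broadcast' branch and the error message are assumptions):

assert shard_distribution is not None, 'Expecting distribution to perform exchange'
if exchange_algo == 'gather_object':
    exchange_fn = exchange_loaded_tensors_gather_object
elif exchange_algo == 'gather_rounds':
    exchange_fn = exchange_loaded_tensors_gather_rounds
elif exchange_algo == 'broadcast':
    exchange_fn = exchange_loaded_tensors_broadcast
else:
    raise NotImplementedError(f'Unknown exchange algorithm: {exchange_algo}')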
2 changes: 1 addition & 1 deletion megatron/core/dist_checkpointing/serialization.py
@@ -25,7 +25,7 @@
StateDict,
apply_factory_merges,
)
from .state_dict_transformation import load_preprocess, save_preprocess
from .state_dict_utils import load_preprocess, save_preprocess
from .strategies.async_utils import AsyncRequest
from .strategies.base import (
AsyncSaveShardedStrategy,
(Diffs for the remaining 18 changed files omitted.)
