Support loading distributed checkpoints for FSDP in Trainer (Lightning-AI#18358)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
awaelchli and carmocca authored Aug 23, 2023
1 parent 1867bd7 commit fc6f43f
Showing 4 changed files with 204 additions and 70 deletions.
12 changes: 11 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -103,7 +103,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added validation of user input for `devices` and `num_nodes` when running with `SLURM` or `TorchElastic` ([#18292](https://github.com/Lightning-AI/lightning/pull/18292))


- Added support for saving sharded checkpoints with FSDP via `FSDPStrategy(state_dict_type="sharded")` ([#18364](https://github.com/Lightning-AI/lightning/pull/18364))
- Added support for saving checkpoints with either a full or a sharded state dict via `FSDPStrategy(state_dict_type="full"|"sharded")`, as sketched below ([#18364](https://github.com/Lightning-AI/lightning/pull/18364))


- Added support for loading sharded/distributed checkpoints in FSDP ([#18358](https://github.com/Lightning-AI/lightning/pull/18358))
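
A minimal usage sketch of the `state_dict_type` option described in the entries above. It assumes two GPUs are available; the toy module, data, and hyperparameters are placeholders for illustration, not part of this change:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.strategies import FSDPStrategy


class ToyModel(LightningModule):
    """Placeholder module, only here to make the sketch self-contained."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# "sharded": every rank writes its own shard of the model/optimizer state into a
# checkpoint directory; "full": rank 0 saves one consolidated checkpoint file.
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy=FSDPStrategy(state_dict_type="sharded"),
    max_steps=4,
)
data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
trainer.fit(ToyModel(), train_dataloaders=data)
```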


### Changed
@@ -152,6 +155,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Broadcast and reduction of tensors with XLA-based strategies now preserve the input's device ([#18275](https://github.com/Lightning-AI/lightning/pull/18275))


- The `FSDPStrategy` now loads checkpoints after the `configure_model`/`configure_sharded_model` hook, as illustrated below ([#18358](https://github.com/Lightning-AI/lightning/pull/18358))


- `FSDPStrategy.load_optimizer_state_dict` and `FSDPStrategy.load_model_state_dict` are now no-ops ([#18358](https://github.com/Lightning-AI/lightning/pull/18358))
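
A hedged illustration of what the ordering change above enables: parameters created in `configure_model` (or the legacy `configure_sharded_model`) are wrapped and sharded by FSDP first, and the checkpoint is restored only afterwards. The module name and layer sizes below are made up:

```python
from torch import nn
from lightning.pytorch import LightningModule


class LargeModel(LightningModule):
    """Hypothetical module that defers parameter creation to `configure_model`."""

    def __init__(self):
        super().__init__()
        self.model = None  # created later, so the full model is never built eagerly

    def configure_model(self):
        # Runs before checkpoint loading: FSDP wraps/shards the parameters created
        # here, and the (full or sharded) checkpoint is restored into them afterwards.
        if self.model is None:
            self.model = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=1024, nhead=8), num_layers=12
            )
```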


### Deprecated

- Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/deepspeed.py
@@ -860,7 +860,7 @@ def load(module: torch.nn.Module, prefix: str = "") -> None:
load(self.lightning_module, prefix="")

def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the states in `load_checkpoint()`
# Override to do nothing, the deepspeed engine already loaded the states in `load_checkpoint()`
pass

@classmethod
139 changes: 104 additions & 35 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -34,6 +34,9 @@
_get_sharded_state_dict_context,
_init_cpu_offload,
_init_sharding_strategy,
_is_full_checkpoint,
_is_sharded_checkpoint,
_load_raw_module_state,
_METADATA_FILENAME,
_optimizer_has_flat_params,
_setup_activation_checkpointing,
@@ -221,6 +224,14 @@ def mixed_precision_config(self) -> Optional["MixedPrecision"]:
def distributed_sampler_kwargs(self) -> Dict:
return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank}

@property
def restore_checkpoint_after_setup(self) -> bool:
    # Defer checkpoint loading until the model has been set up and wrapped by FSDP
    return True

@property
def lightning_restore_optimizer(self) -> bool:
    # Lightning should not restore optimizer states generically; this strategy
    # restores them itself in `load_checkpoint`
    return False

def setup_environment(self) -> None:
log.debug(f"{self.__class__.__name__}: setting up distributed...")
reset_seed()
@@ -451,6 +462,10 @@ def lightning_module_state_dict(self) -> Dict[str, Any]:
with state_dict_ctx:
return self.model.state_dict()

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# Override to do nothing, FSDP already loaded the states in `load_checkpoint()`
pass

def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
if not _TORCH_GREATER_EQUAL_2_0:
rank_zero_warn("FSDP in Lightning with PyTorch < 2.0 does not support saving the optimizer state.")
@@ -478,40 +493,8 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")

def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
if not _TORCH_GREATER_EQUAL_2_0:
rank_zero_warn("FSDP in Lightning with PyTorch < 2.0 does not support loading the optimizer state.")
return

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import OptimStateKeyType

optimizer_states = checkpoint.get("optimizer_states")

# If the optimizer states are not present, we don't need to do anything (backward compatibility)
if optimizer_states is None:
return

if len(self.optimizers) != len(optimizer_states):
raise RuntimeError(
f"You have configured {len(self.optimizers)} optimizers but the checkpoint contains"
f" {len(optimizer_states)} optimizers to load. Please resume training with the same number"
" of optimizers or edit the checkpoint manually to remove states."
)

assert self.model is not None

# rank0_only should be false because we need to load the optimizer state on all ranks
with _get_full_state_dict_context(self.model, world_size=self.world_size, rank0_only=False):
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
# convert the optimizer state to the format expected by FSDP
opt_state = FSDP.rekey_optim_state_dict(opt_state, OptimStateKeyType.PARAM_NAME, self.model)

opt_state = FSDP.optim_state_dict_to_load(
optim_state_dict=opt_state,
model=self.model,
optim=optimizer,
)
optimizer.load_state_dict(opt_state)
# Override to do nothing, FSDP already loaded the states in `load_checkpoint()`
pass

def save_checkpoint(
self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
@@ -542,5 +525,91 @@ def save_checkpoint(

if self.global_rank == 0:
torch.save(checkpoint, path / _METADATA_FILENAME)
else:
elif self._state_dict_type == "full":
return super().save_checkpoint(checkpoint=checkpoint, filepath=path)
else:
raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")

def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
# broadcast the path from rank 0 to ensure all the states are loaded from a common path
path = Path(self.broadcast(checkpoint_path))

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

assert self.model is not None
assert self.lightning_module is not None

if _is_sharded_checkpoint(path):
from torch.distributed.checkpoint import FileSystemReader, load_state_dict
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict

state_dict_ctx = _get_sharded_state_dict_context(self.model)
reader = FileSystemReader(path=path)

with state_dict_ctx:
module_state = {"model": self.model.state_dict()}
load_state_dict(module_state, reader)
self.model.load_state_dict(module_state["model"])

if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
# the optimizer states must be loaded separately
for idx, optim in enumerate(self.optimizers):
optim_key = f"optimizer_{idx}"
optim_state = load_sharded_optimizer_state_dict(
model_state_dict=module_state["model"],
optimizer_key=optim_key,
storage_reader=reader,
)
flattened_osd = FSDP.optim_state_dict_to_load(
optim_state_dict=optim_state[optim_key],
model=self.model,
optim=optim,
)
optim.load_state_dict(flattened_osd)

# Load metadata (anything not a module or optimizer)
metadata = torch.load(path / _METADATA_FILENAME)
return metadata

if _is_full_checkpoint(path):
# TODO: Support lazy-loading here (see Fabric)
checkpoint = torch.load(path, map_location="cpu")
_load_raw_module_state(checkpoint["state_dict"], world_size=self.world_size, module=self.model)

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import OptimStateKeyType

optimizer_states = checkpoint.get("optimizer_states")
if optimizer_states is None or self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
# If the optimizer states are not present, we don't need to do anything (backward compatibility)
return checkpoint
if not _TORCH_GREATER_EQUAL_2_0:
rank_zero_warn("FSDP in Lightning with PyTorch < 2.0 does not support loading the optimizer state.")
return checkpoint
if len(self.optimizers) != len(optimizer_states):
raise RuntimeError(
f"You have configured {len(self.optimizers)} optimizers but the checkpoint contains"
f" {len(optimizer_states)} optimizers to load. Please resume training with the same number"
" of optimizers or edit the checkpoint manually to remove states."
)

# rank0_only should be false because we need to load the optimizer state on all ranks
with _get_full_state_dict_context(self.model, world_size=self.world_size, rank0_only=False):
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
if isinstance(list(opt_state["state"].keys())[0], int):
# Handling the case where the optimizer state is saved from a normal optimizer
opt_state = FSDP.rekey_optim_state_dict(opt_state, OptimStateKeyType.PARAM_NAME, self.model)

opt_state = FSDP.optim_state_dict_to_load(
optim_state_dict=opt_state,
model=self.model,
optim=optimizer,
)
optimizer.load_state_dict(opt_state)

return checkpoint

raise ValueError(
f"The path {str(path)!r} does not point to a valid checkpoint. Make sure the path points to either a"
" directory with FSDP checkpoint shards, or a single file with a full checkpoint."
)
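
For context, a hedged sketch of how the new `load_checkpoint` is reached from user code: the same `ckpt_path` argument of `Trainer.fit` may point either at a sharded checkpoint (a directory of shards plus the metadata file written in `save_checkpoint` above) or at a single full-checkpoint file, and the branches above select the matching loading path. The module and path below are purely illustrative:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy

trainer = Trainer(devices=2, strategy=FSDPStrategy(state_dict_type="sharded"))

# Either a sharded checkpoint directory or a consolidated single-file checkpoint
# can be passed here; `FSDPStrategy.load_checkpoint` detects the format from the path.
trainer.fit(model, ckpt_path="path/to/checkpoint")  # `model` is a placeholder LightningModule
```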