[model checkpoint] updated checkpoint hook (hpcaitech#598)
kurisusnowdeng authored Apr 1, 2022
1 parent 77ad24b commit 28b515d
Showing 2 changed files with 14 additions and 92 deletions.
colossalai/trainer/hooks/__init__.py (8 changes: 4 additions & 4 deletions)
@@ -1,12 +1,12 @@
 from ._base_hook import BaseHook
-from ._checkpoint_hook import LoadCheckpointHook, SaveCheckpointHook
+from ._checkpoint_hook import SaveCheckpointHook
 from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook,
                         TensorboardHook)
 from ._lr_scheduler_hook import LRSchedulerHook
 from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook
 
 __all__ = [
-    'BaseHook', 'MetricHook', 'LoadCheckpointHook', 'SaveCheckpointHook', 'LossHook', 'AccuracyHook',
-    'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook', 'LRSchedulerHook',
-    'ThroughputHook', 'LogMetricByStepHook'
+    'BaseHook', 'MetricHook', 'LossHook', 'AccuracyHook', 'LogMetricByEpochHook', 'TensorboardHook',
+    'LogTimingByEpochHook', 'LogMemoryByEpochHook', 'LRSchedulerHook', 'ThroughputHook', 'LogMetricByStepHook',
+    'SaveCheckpointHook'
 ]
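
Usage note (not part of the commit): after this change, LoadCheckpointHook can no longer be imported from colossalai.trainer.hooks, so a training script should only construct SaveCheckpointHook. A minimal sketch of the new import surface; the checkpoint path and the idea of collecting hooks into a list for the trainer are illustrative assumptions, not taken from this diff:

from colossalai.trainer import hooks

# Only SaveCheckpointHook remains exported from colossalai.trainer.hooks after this commit.
# interval and checkpoint_dir match the constructor shown in the diff below; the path is hypothetical.
save_hook = hooks.SaveCheckpointHook(interval=1, checkpoint_dir='./model.ckpt')
hook_list = [save_hook]  # hooks are typically gathered into a list and handed to the trainer
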
colossalai/trainer/hooks/_checkpoint_hook.py (98 changes: 10 additions & 88 deletions)
@@ -1,14 +1,11 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-import os.path as osp
 from colossalai.logging import get_dist_logger
 
 from colossalai.registry import HOOKS
 from colossalai.trainer.hooks import BaseHook
-from colossalai.utils import is_dp_rank_0
-from colossalai.utils.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
-from colossalai.utils.checkpointing import save_checkpoint, load_checkpoint
+from colossalai.utils.checkpointing import save_checkpoint
 from ._lr_scheduler_hook import LRSchedulerHook


@@ -17,9 +14,8 @@ class SaveCheckpointHook(BaseHook):
     """Saves the model by interval in training process.
 
     Args:
-        interval (int, optional): Saving interval, defaults to 1.
-        checkpoint_dir (str, optional): Directory of saving checkpoint, defaults to None.
-        suffix (str, optional): Saving suffix of the file, defaults to ''.
+        interval (int, optional): Number of epochs between saving the checkpoint, defaults to 1.
+        checkpoint_dir (str, optional): File name to save the checkpoint, defaults to None.
         priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
             defaults to 10. If different hooks share same priority, the order of printing would
             depend on the hooks order in the hook list.
@@ -28,19 +24,17 @@ class SaveCheckpointHook(BaseHook):
     def __init__(self,
                  interval: int = 1,
                  checkpoint_dir: str = None,
-                 suffix: str = '',
                  priority: int = 10):
         super().__init__(priority=priority)
         self.interval = interval
         self.checkpoint_dir = checkpoint_dir
-        self.suffix = suffix
         self.logger = get_dist_logger()
 
         # get lr scheduler from the LRSchedulerHook before train
         self._lr_scheduler = None
 
     def after_hook_is_attached(self, trainer):
-        # check if lr scheduler is present in LRSchedulerHook
+        # get lr scheduler if exists
         for hook in trainer.hooks:
             if isinstance(hook, LRSchedulerHook):
                 self._lr_scheduler = hook.lr_scheduler
@@ -51,82 +45,10 @@ def after_train_epoch(self, trainer):
         """
         # save by interval
         if trainer.cur_epoch % self.interval == 0:
-            # only gpus with data parallel rank equals to 0 write to the disk
-            if is_dp_rank_0():
-                save_path = get_checkpoint_path(self.checkpoint_dir,
-                                                trainer.cur_epoch,
-                                                suffix=self.suffix)
-
-                save_checkpoint(save_path,
-                                trainer.cur_epoch,
-                                trainer.engine.model,
-                                trainer.engine.optimizer,
-                                self._lr_scheduler)
-                self.logger.info(
-                    f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])
-
-
-@HOOKS.register_module
-class LoadCheckpointHook(BaseHook):
-    """Loads the model before training process.
-
-    Args:
-        checkpoint_dir (str, optional): Directory of saving checkpoint, defaults to None.
-        epoch (str, optional): Loading checkpoint of setting epoch numbers, defaults to -1.
-            Epoch equals to -1 means choosing the latest checkpoint.
-        finetune (bool, optional): Whether allows to load a part of the model, defaults to False.
-        strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict` of the checkpoint
-            match the names of parameters and buffers in model, defaults to False.
-        suffix (str, optional): Suffix of checkpoint file path, defaults to ''.
-        priority (int, optional): Priority in the printing, hooks with small priority will be printed in front,
-            defaults to 0. If different hooks share same priority, the order of printing would
-            depend on the hooks order in the hook list.
-    """
-
-    def __init__(self,
-                 checkpoint_dir: str = None,
-                 epoch: int = -1,
-                 finetune: bool = False,
-                 strict: bool = False,
-                 suffix: str = '',
-                 priority: int = 0) -> None:
-        super().__init__(priority=priority)
-        self.epoch = epoch
-        self.checkpoint_dir = checkpoint_dir
-        self.finetune = finetune
-        self.suffix = suffix
-        self.strict = strict
-        self.logger = get_dist_logger()
-
-    def before_train(self, trainer):
-        """Loads parameters to the model before training.
-        """
-        # check if lr scheduler is present in LRSchedulerHook
-        lr_scheduler = None
-        for hook in trainer.hooks:
-            if isinstance(hook, LRSchedulerHook):
-                lr_scheduler = hook.lr_scheduler
-                break
-
-        # use latest checkpoint if epoch = -1
-        if self.epoch == -1:
-            path = get_latest_checkpoint_path(self.checkpoint_dir, suffix=self.suffix)
-        else:
-            path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch, suffix=self.suffix)
-
-        if osp.exists(path):
-            last_epoch, _ = load_checkpoint(path,
-                                            trainer.engine.model,
-                                            trainer.engine.optimizer,
-                                            lr_scheduler,
-                                            finetune=self.finetune,
-                                            strict=self.strict)
-            if self.finetune:
-                trainer.cur_epoch = 0
-            else:
-                trainer.cur_epoch = last_epoch
-
+            save_checkpoint(self.checkpoint_dir,
+                            trainer.cur_epoch,
+                            trainer.engine.model,
+                            trainer.engine.optimizer,
+                            self._lr_scheduler)
             self.logger.info(
-                f'loaded checkpoint from {path}', ranks=[0])
-        else:
-            raise FileNotFoundError(f'checkpoint is not found at {path}')
+                f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])
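
Usage note (not part of the commit): with LoadCheckpointHook deleted, resuming from a checkpoint has to happen outside the hook system. A minimal sketch, assuming colossalai.utils.checkpointing.load_checkpoint keeps the call signature visible in the removed code above; the path, model, and optimizer are placeholders:

import os.path as osp

from colossalai.utils.checkpointing import load_checkpoint


def resume_if_possible(path, model, optimizer, lr_scheduler=None):
    # Restore model/optimizer/scheduler state and return the epoch to resume from,
    # mirroring what the deleted LoadCheckpointHook.before_train used to do.
    if not osp.exists(path):
        raise FileNotFoundError(f'checkpoint is not found at {path}')
    last_epoch, _ = load_checkpoint(path, model, optimizer, lr_scheduler, strict=False)
    return last_epoch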
