Skip to content

Commit

Permalink
[NFC] polish _checkpoint_hook.py code style (hpcaitech#1722)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gy-Lu authored and FrankLeeeee committed Oct 19, 2022
1 parent b38efe4 commit 730f88f
Showing 1 changed file with 5 additions and 14 deletions.
19 changes: 5 additions & 14 deletions colossalai/trainer/hooks/_checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,32 +50,23 @@ def after_hook_is_attached(self, trainer):
break
self.model = self.model if self.model is not None else trainer.engine.model


def after_train_iter(self, trainer, output, label, loss):
"""Saves the model after a training iter.
"""
# save by interval
if self.save_by_iter and trainer.cur_step % self.interval == 0:
save_checkpoint(self.checkpoint_dir,
trainer.cur_epoch,
self.model,
trainer.engine.optimizer,
save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer,
self._lr_scheduler)
self.logger.info(
f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}', ranks=[0])
self.logger.info(f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}',
ranks=[0])
else:
pass


def after_train_epoch(self, trainer):
"""Saves the model after a training epoch.
"""
# save by interval
if trainer.cur_epoch % self.interval == 0:
save_checkpoint(self.checkpoint_dir,
trainer.cur_epoch,
self.model,
trainer.engine.optimizer,
save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer,
self._lr_scheduler)
self.logger.info(
f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])
self.logger.info(f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])

0 comments on commit 730f88f

Please sign in to comment.