Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor hooks unittest #946

Merged
merged 7 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions mmengine/hooks/ema_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ def __init__(self,
assert not (begin_iter != 0 and begin_epoch != 0), (
'`begin_iter` and `begin_epoch` should not be both set.')
assert begin_iter >= 0, (
f'begin_iter must larger than 0, but got begin: {begin_iter}')
'`begin_iter` must larger than or equal to 0, '
f'but got begin_iter: {begin_iter}')
assert begin_epoch >= 0, (
f'begin_epoch must larger than 0, but got begin: {begin_epoch}')
'`begin_epoch` must larger than or equal to 0, '
f'but got begin_epoch: {begin_epoch}')
self.begin_iter = begin_iter
self.begin_epoch = begin_epoch
# If `begin_epoch` and `begin_iter` are not set, `EMAHook` will be
Expand Down Expand Up @@ -80,12 +82,14 @@ def before_train(self, runner) -> None:
"""
if self.enabled_by_epoch:
assert self.begin_epoch <= runner.max_epochs, (
'self.begin_epoch should be smaller than runner.max_epochs: '
f'{runner.max_epochs}, but got begin: {self.begin_epoch}')
'self.begin_epoch should be smaller than or equal to '
f'runner.max_epochs: {runner.max_epochs}, but got '
f'begin_epoch: {self.begin_epoch}')
else:
assert self.begin_iter <= runner.max_iters, (
'self.begin_iter should be smaller than runner.max_iters: '
f'{runner.max_iters}, but got begin: {self.begin_iter}')
'self.begin_iter should be smaller than or equal to '
f'runner.max_iters: {runner.max_iters}, but got '
f'begin_iter: {self.begin_iter}')

def after_train_iter(self,
runner,
Expand Down
19 changes: 5 additions & 14 deletions mmengine/hooks/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def before_run(self, runner) -> None:
runner (Runner): The runner of the training, validation or testing
process.
"""
pass

def after_run(self, runner) -> None:
"""All subclasses should override this method, if they need any
Expand All @@ -39,7 +38,6 @@ def after_run(self, runner) -> None:
runner (Runner): The runner of the training, validation or testing
process.
"""
pass

def before_train(self, runner) -> None:
"""All subclasses should override this method, if they need any
Expand All @@ -48,7 +46,6 @@ def before_train(self, runner) -> None:
Args:
runner (Runner): The runner of the training process.
"""
pass

def after_train(self, runner) -> None:
"""All subclasses should override this method, if they need any
Expand All @@ -57,7 +54,6 @@ def after_train(self, runner) -> None:
Args:
runner (Runner): The runner of the training process.
"""
pass

def before_val(self, runner) -> None:
"""All subclasses should override this method, if they need any
Expand All @@ -66,7 +62,6 @@ def before_val(self, runner) -> None:
Args:
runner (Runner): The runner of the validation process.
"""
pass

def after_val(self, runner) -> None:
"""All subclasses should override this method, if they need any
Expand All @@ -75,7 +70,6 @@ def after_val(self, runner) -> None:
Args:
runner (Runner): The runner of the validation process.
"""
pass

def before_test(self, runner) -> None:
"""All subclasses should override this method, if they need any
Expand All @@ -84,7 +78,6 @@ def before_test(self, runner) -> None:
Args:
runner (Runner): The runner of the testing process.
"""
pass

def after_test(self, runner) -> None:
"""All subclasses should override this method, if they need any
Expand All @@ -93,7 +86,6 @@ def after_test(self, runner) -> None:
Args:
runner (Runner): The runner of the testing process.
"""
pass

def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
"""All subclasses should override this method, if they need any
Expand All @@ -104,7 +96,6 @@ def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
process.
checkpoint (dict): Model's checkpoint.
"""
pass

def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
"""All subclasses should override this method, if they need any
Expand All @@ -115,7 +106,6 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
process.
checkpoint (dict): Model's checkpoint.
"""
pass

def before_train_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
Expand Down Expand Up @@ -300,7 +290,6 @@ def _before_epoch(self, runner, mode: str = 'train') -> None:
process.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass

def _after_epoch(self, runner, mode: str = 'train') -> None:
"""All subclasses should override this method, if they need any
Expand All @@ -311,7 +300,6 @@ def _after_epoch(self, runner, mode: str = 'train') -> None:
process.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass

def _before_iter(self,
runner,
Expand All @@ -328,7 +316,6 @@ def _before_iter(self,
data_batch (dict or tuple or list, optional): Data from dataloader.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass

def _after_iter(self,
runner,
Expand All @@ -347,7 +334,6 @@ def _after_iter(self,
outputs (dict or Sequence, optional): Outputs from model.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass

def every_n_epochs(self, runner, n: int) -> bool:
"""Test whether current epoch can be evenly divided by n.
Expand Down Expand Up @@ -427,6 +413,11 @@ def is_last_train_iter(self, runner) -> bool:
return runner.iter + 1 == runner.max_iters

def get_triggered_stages(self) -> list:
"""Get all triggered stages with method name of the hook.

Returns:
list: List of triggered stages.
"""
trigger_stages = set()
for stage in Hook.stages:
if is_method_overridden(stage, Hook, self):
Expand Down
Loading