Skip to content

Commit

Permalink
Merge 79c8d73 into d6f41bc
Browse files Browse the repository at this point in the history
  • Loading branch information
songyuc authored Nov 18, 2022
2 parents d6f41bc + 79c8d73 commit d1b3614
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 0 deletions.
34 changes: 34 additions & 0 deletions mmengine/hooks/hook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence, Union

from mmengine import is_method_overridden

DATA_BATCH = Optional[Union[dict, tuple, list]]


Expand All @@ -11,6 +13,13 @@ class Hook:
"""

priority = 'NORMAL'
stages = ('before_run', 'after_load_checkpoint', 'before_train',
'before_train_epoch', 'before_train_iter', 'after_train_iter',
'after_train_epoch', 'before_val', 'before_val_epoch',
'before_val_iter', 'after_val_iter', 'after_val_epoch',
'after_val', 'before_save_checkpoint', 'after_train',
'before_test', 'before_test_epoch', 'before_test_iter',
'after_test_iter', 'after_test_epoch', 'after_test', 'after_run')

def before_run(self, runner) -> None:
"""All subclasses should override this method, if they need any
Expand Down Expand Up @@ -416,3 +425,28 @@ def is_last_train_iter(self, runner) -> bool:
bool: Whether current iteration is the last train iteration.
"""
return runner.iter + 1 == runner.max_iters

def get_triggered_stages(self) -> list:
trigger_stages = set()
for stage in Hook.stages:
if is_method_overridden(stage, Hook, self):
trigger_stages.add(stage)

# some methods will be triggered in multi stages
# use this dict to map method to stages.
method_stages_map = {
'_before_epoch':
['before_train_epoch', 'before_val_epoch', 'before_test_epoch'],
'_after_epoch':
['after_train_epoch', 'after_val_epoch', 'after_test_epoch'],
'_before_iter':
['before_train_iter', 'before_val_iter', 'before_test_iter'],
'_after_iter':
['after_train_iter', 'after_val_iter', 'after_test_iter'],
}

for method, map_stages in method_stages_map.items():
if is_method_overridden(method, Hook, self):
trigger_stages.update(map_stages)

return list(trigger_stages)
26 changes: 26 additions & 0 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,9 @@ def __init__(
self._hooks: List[Hook] = []
# register hooks to `self._hooks`
self.register_hooks(default_hooks, custom_hooks)
# log hooks information
self.logger.info(f'Hooks will be executed in the following '
f'order:\n{self.get_hooks_info()}')

# dump `cfg` to `work_dir`
self.dump_config()
Expand Down Expand Up @@ -1576,6 +1579,29 @@ def build_log_processor(

return log_processor # type: ignore

def get_hooks_info(self) -> str:
# Get hooks info in each stage
stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
for hook in self.hooks:
try:
priority = Priority(hook.priority).name # type: ignore
except ValueError:
priority = hook.priority # type: ignore
classname = hook.__class__.__name__
hook_info = f'({priority:<12}) {classname:<35}'
for trigger_stage in hook.get_triggered_stages():
stage_hook_map[trigger_stage].append(hook_info)

stage_hook_infos = []
for stage in Hook.stages:
hook_infos = stage_hook_map[stage]
if len(hook_infos) > 0:
info = f'{stage}:\n'
info += '\n'.join(hook_infos)
info += '\n -------------------- '
stage_hook_infos.append(info)
return '\n'.join(stage_hook_infos)

def load_or_resume(self) -> None:
"""load or resume checkpoint."""
if self._has_loaded:
Expand Down
16 changes: 16 additions & 0 deletions tests/test_runner/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2306,3 +2306,19 @@ def test_build_runner(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_runner2'
assert isinstance(RUNNERS.build(cfg), Runner)

def test_get_hooks_info(self):
# test get_hooks_info() function
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_get_hooks_info_from_test_runner_py'
cfg.runner_type = 'Runner'
runner = RUNNERS.build(cfg)
self.assertIsInstance(runner, Runner)
target_str = ('after_train_iter:\n'
'(VERY_HIGH ) RuntimeInfoHook \n'
'(NORMAL ) IterTimerHook \n'
'(BELOW_NORMAL) LoggerHook \n'
'(LOW ) ParamSchedulerHook \n'
'(VERY_LOW ) CheckpointHook \n')
self.assertIn(target_str, runner.get_hooks_info(),
'target string is not in logged hooks information.')

0 comments on commit d1b3614

Please sign in to comment.