Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor RuntimeInfoHook
Browse files Browse the repository at this point in the history
HAOCHENYE committed Dec 10, 2022
1 parent b0d046a commit db9fae5
Showing 1 changed file with 66 additions and 76 deletions.
142 changes: 66 additions & 76 deletions tests/test_hooks/test_runtime_info_hook.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import copy
from unittest.mock import Mock

import torch.nn as nn
from torch.optim import SGD

from mmengine.hooks import RuntimeInfoHook
from mmengine.logging import MessageHub
from mmengine.optim import OptimWrapper, OptimWrapperDict
from mmengine.testing import RunnerTestCase


class TestRuntimeInfoHook(RunnerTestCase):
    """Unit tests for ``RuntimeInfoHook``.

    Each test builds a real ``Runner`` from ``self.epoch_based_cfg``
    (provided by ``RunnerTestCase``) and checks that the hook mirrors
    runtime state — epoch/iter counters, learning rate, losses and
    metrics — into ``runner.message_hub``.
    """

    def test_before_train(self):
        """``before_train`` seeds counters and optional dataset metainfo."""

        class DatasetWithoutMetainfo:
            """Minimal dataset exposing only ``__len__``."""

            def __len__(self):
                return 12

        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.train_dataloader.dataset.type = DatasetWithoutMetainfo
        runner = self.build_runner(cfg)
        hook = self._get_runtime_info_hook(runner)
        hook.before_train(runner)
        # Fresh runner: counters start at 0; limits come from the cfg
        # (2 epochs over a 12-sample dataset -> 8 total iters here).
        self.assertEqual(runner.message_hub.get_info('epoch'), 0)
        self.assertEqual(runner.message_hub.get_info('iter'), 0)
        self.assertEqual(runner.message_hub.get_info('max_epochs'), 2)
        self.assertEqual(runner.message_hub.get_info('max_iters'), 8)

        # A dataset without a ``metainfo`` attribute records nothing.
        with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
            runner.message_hub.get_info('dataset_meta')

        class DatasetWithMetainfo(DatasetWithoutMetainfo):
            """Dataset carrying a (here empty) ``metainfo`` attribute."""

            metainfo = dict()

        cfg.train_dataloader.dataset.type = DatasetWithMetainfo
        runner = self.build_runner(cfg)
        hook.before_train(runner)
        self.assertEqual(runner.message_hub.get_info('dataset_meta'), dict())

    def test_before_train_epoch(self):
        """``before_train_epoch`` publishes the loop's current epoch."""
        cfg = copy.deepcopy(self.epoch_based_cfg)
        runner = self.build_runner(cfg)
        runner.train_loop._epoch = 9
        hook = self._get_runtime_info_hook(runner)
        hook.before_train_epoch(runner)
        self.assertEqual(runner.message_hub.get_info('epoch'), 9)

    def test_before_train_iter(self):
        """``before_train_iter`` publishes the iter count and lr(s)."""
        # --- single optimizer ---
        cfg = copy.deepcopy(self.epoch_based_cfg)
        lr = cfg.optim_wrapper.optimizer.lr
        runner = self.build_runner(cfg)
        runner.train_loop._iter = 9  # set iter and build optimizer
        runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper)
        hook = self._get_runtime_info_hook(runner)
        hook.before_train_iter(runner, batch_idx=2, data_batch=None)
        self.assertEqual(runner.message_hub.get_info('iter'), 9)
        self.assertEqual(
            runner.message_hub.get_scalar('train/lr').current(), lr)

        # An optim wrapper whose ``get_lr`` does not return a dict must
        # trigger the hook's assertion.
        # NOTE(review): the diff hunk hides the exact setup at this point;
        # a bare ``Mock`` (whose ``get_lr()`` returns a Mock, not a dict)
        # matches the asserted message — confirm against the upstream file.
        runner.optim_wrapper = Mock()
        with self.assertRaisesRegex(AssertionError,
                                    'runner.optim_wrapper.get_lr()'):
            hook.before_train_iter(runner, batch_idx=2, data_batch=None)

        # --- multiple optimizers ---
        model = nn.Linear(1, 1)
        optim1 = SGD(model.parameters(), lr=0.01)
        optim2 = SGD(model.parameters(), lr=0.02)
        optim_wrapper1 = OptimWrapper(optim1)
        optim_wrapper2 = OptimWrapper(optim2)
        optim_wrapper_dict = OptimWrapperDict(
            key1=optim_wrapper1, key2=optim_wrapper2)
        runner.optim_wrapper = optim_wrapper_dict
        hook.before_train_iter(runner, batch_idx=2, data_batch=None)
        # Each wrapped optimizer's lr is logged under its dict key.
        self.assertEqual(
            runner.message_hub.get_scalar('train/key1.lr').current(), 0.01)
        self.assertEqual(
            runner.message_hub.get_scalar('train/key2.lr').current(), 0.02)

    def test_after_train_iter(self):
        """``after_train_iter`` logs each loss entry as a train scalar."""
        cfg = copy.deepcopy(self.epoch_based_cfg)
        runner = self.build_runner(cfg)
        hook = self._get_runtime_info_hook(runner)
        hook.after_train_iter(
            runner, batch_idx=2, data_batch=None, outputs={'loss_cls': 1.111})
        self.assertEqual(
            runner.message_hub.get_scalar('train/loss_cls').current(), 1.111)

    def test_after_val_epoch(self):
        """``after_val_epoch`` logs metrics under the ``val/`` prefix."""
        cfg = copy.deepcopy(self.epoch_based_cfg)
        runner = self.build_runner(cfg)
        hook = self._get_runtime_info_hook(runner)
        hook.after_val_epoch(runner, metrics={'acc': 0.8})
        self.assertEqual(
            runner.message_hub.get_scalar('val/acc').current(), 0.8)

    def test_after_test_epoch(self):
        """``after_test_epoch`` logs metrics under the ``test/`` prefix."""
        cfg = copy.deepcopy(self.epoch_based_cfg)
        runner = self.build_runner(cfg)
        hook = self._get_runtime_info_hook(runner)
        hook.after_test_epoch(runner, metrics={'acc': 0.8})
        self.assertEqual(
            runner.message_hub.get_scalar('test/acc').current(), 0.8)

    def _get_runtime_info_hook(self, runner):
        """Return the ``RuntimeInfoHook`` registered on ``runner``.

        Fails the test explicitly instead of silently returning ``None``
        when the hook is missing, so a misconfigured runner is reported
        at the lookup rather than as an ``AttributeError`` later.
        """
        for hook in runner.hooks:
            if isinstance(hook, RuntimeInfoHook):
                return hook
        self.fail('RuntimeInfoHook is not registered in runner.hooks')

0 comments on commit db9fae5

Please sign in to comment.