Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Refactor unit tests of RuntimeInfoHook #810

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Refactor RuntimeInfoHook
  • Loading branch information
HAOCHENYE committed Dec 10, 2022
commit 27b582543a7efe76a7301349a37763911eba86cb
142 changes: 66 additions & 76 deletions tests/test_hooks/test_runtime_info_hook.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import copy
from unittest.mock import Mock

import torch.nn as nn
from torch.optim import SGD

from mmengine.hooks import RuntimeInfoHook
from mmengine.logging import MessageHub
from mmengine.optim import OptimWrapper, OptimWrapperDict
from mmengine.testing import RunnerTestCase


class TestRuntimeInfoHook(TestCase):
class TestRuntimeInfoHook(RunnerTestCase):

def test_before_train(self):
    """``before_train`` should dump the runner's epoch/iter counters and,
    when the train dataset exposes one, its ``metainfo`` into the
    runner's message hub."""

    class DatasetWithoutMetainfo:
        """Toy dataset exposing only a length and no ``metainfo``."""

        def __len__(self):
            return 12

    cfg = copy.deepcopy(self.epoch_based_cfg)
    cfg.train_dataloader.dataset.type = DatasetWithoutMetainfo
    runner = self.build_runner(cfg)
    hook = self._get_runtime_info_hook(runner)
    hook.before_train(runner)
    # Counters come from the freshly built runner (2 epochs over 12
    # samples — max_iters of 8 presumably follows from the configured
    # batch size; confirm against ``epoch_based_cfg``).
    self.assertEqual(runner.message_hub.get_info('epoch'), 0)
    self.assertEqual(runner.message_hub.get_info('iter'), 0)
    self.assertEqual(runner.message_hub.get_info('max_epochs'), 2)
    self.assertEqual(runner.message_hub.get_info('max_iters'), 8)

    # Without ``metainfo`` on the dataset nothing must be recorded.
    with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
        runner.message_hub.get_info('dataset_meta')

    class DatasetWithMetainfo(DatasetWithoutMetainfo):
        """Toy dataset that additionally carries a ``metainfo`` attr."""

        metainfo = dict()

    cfg.train_dataloader.dataset.type = DatasetWithMetainfo
    runner = self.build_runner(cfg)
    hook.before_train(runner)
    self.assertEqual(runner.message_hub.get_info('dataset_meta'), dict())

def test_before_train_epoch(self):
    """``before_train_epoch`` should mirror the train loop's epoch
    counter into the message hub."""
    cfg = copy.deepcopy(self.epoch_based_cfg)
    runner = self.build_runner(cfg)
    # Poke the loop's private counter directly; the hook must read it
    # back through ``runner``.
    runner.train_loop._epoch = 9
    hook = self._get_runtime_info_hook(runner)
    hook.before_train_epoch(runner)
    self.assertEqual(runner.message_hub.get_info('epoch'), 9)

def test_before_train_iter(self):
    """``before_train_iter`` should record the iteration counter and the
    learning rate(s), covering single-optimizer, invalid ``get_lr`` and
    multi-optimizer setups."""
    # --- single optimizer ---
    cfg = copy.deepcopy(self.epoch_based_cfg)
    lr = cfg.optim_wrapper.optimizer.lr
    runner = self.build_runner(cfg)
    runner.train_loop._iter = 9  # set iter and build optimizer
    runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper)
    hook = self._get_runtime_info_hook(runner)
    hook.before_train_iter(runner, batch_idx=2, data_batch=None)
    self.assertEqual(runner.message_hub.get_info('iter'), 9)
    self.assertEqual(
        runner.message_hub.get_scalar('train/lr').current(), lr)

    # --- ``get_lr`` returning a non-dict must be rejected ---
    # NOTE(review): the body of this ``with`` block was folded in the
    # diff view; reconstructed from the asserted message — confirm
    # against the merged file.
    with self.assertRaisesRegex(AssertionError,
                                'runner.optim_wrapper.get_lr()'):
        runner.optim_wrapper = Mock()
        runner.optim_wrapper.get_lr = Mock(return_value='error type')
        hook.before_train_iter(runner, batch_idx=2, data_batch=None)

    # --- multiple optimizers ---
    model = nn.Linear(1, 1)
    optim1 = SGD(model.parameters(), lr=0.01)
    optim2 = SGD(model.parameters(), lr=0.02)
    optim_wrapper1 = OptimWrapper(optim1)
    optim_wrapper2 = OptimWrapper(optim2)
    optim_wrapper_dict = OptimWrapperDict(
        key1=optim_wrapper1, key2=optim_wrapper2)
    runner.optim_wrapper = optim_wrapper_dict
    hook.before_train_iter(runner, batch_idx=2, data_batch=None)
    # Each wrapper's lr is logged under its key in the dict.
    self.assertEqual(
        runner.message_hub.get_scalar('train/key1.lr').current(), 0.01)
    self.assertEqual(
        runner.message_hub.get_scalar('train/key2.lr').current(), 0.02)

def test_after_train_iter(self):
    """``after_train_iter`` should forward each entry of ``outputs``
    (e.g. losses) into the message hub as a train scalar."""
    cfg = copy.deepcopy(self.epoch_based_cfg)
    runner = self.build_runner(cfg)
    hook = self._get_runtime_info_hook(runner)
    hook.after_train_iter(
        runner, batch_idx=2, data_batch=None, outputs={'loss_cls': 1.111})
    self.assertEqual(
        runner.message_hub.get_scalar('train/loss_cls').current(), 1.111)

def test_after_val_epoch(self):
    """``after_val_epoch`` should log each metric under the ``val/``
    prefix in the message hub."""
    cfg = copy.deepcopy(self.epoch_based_cfg)
    runner = self.build_runner(cfg)
    hook = self._get_runtime_info_hook(runner)
    hook.after_val_epoch(runner, metrics={'acc': 0.8})
    self.assertEqual(
        runner.message_hub.get_scalar('val/acc').current(), 0.8)

def test_after_test_epoch(self):
    """``after_test_epoch`` should log each metric under the ``test/``
    prefix in the message hub."""
    cfg = copy.deepcopy(self.epoch_based_cfg)
    runner = self.build_runner(cfg)
    hook = self._get_runtime_info_hook(runner)
    hook.after_test_epoch(runner, metrics={'acc': 0.8})
    self.assertEqual(
        runner.message_hub.get_scalar('test/acc').current(), 0.8)

def _get_runtime_info_hook(self, runner):
    """Return the ``RuntimeInfoHook`` registered on ``runner``.

    Args:
        runner: A built runner whose ``hooks`` list is searched.

    Raises:
        AssertionError: If no ``RuntimeInfoHook`` is registered — failing
            fast here beats the ``AttributeError`` a silent ``None``
            return would cause at the call site.
    """
    for registered_hook in runner.hooks:
        if isinstance(registered_hook, RuntimeInfoHook):
            return registered_hook
    raise AssertionError(
        'No RuntimeInfoHook is registered in runner.hooks')