Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Failed to remove the previous best checkpoints #1086

Merged
merged 5 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[Fix] Only reserve one best checkpoint
  • Loading branch information
HAOCHENYE committed Apr 19, 2023
commit 5ac94dabb82a715dbec4f983b17f9a967f8ff5ed
8 changes: 4 additions & 4 deletions mmengine/hooks/checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,9 +479,9 @@ def _save_best_checkpoint(self, runner, metrics) -> None:
runner.message_hub.update_info(best_score_key, best_score)

if best_ckpt_path and \
self.file_client.isfile(best_ckpt_path) and \
self.file_backend.isfile(best_ckpt_path) and \
is_main_process():
self.file_client.remove(best_ckpt_path)
self.file_backend.remove(best_ckpt_path)
runner.logger.info(
f'The previous best checkpoint {best_ckpt_path} '
'is removed')
Expand All @@ -490,13 +490,13 @@ def _save_best_checkpoint(self, runner, metrics) -> None:
# Replace illegal characters for filename with `_`
best_ckpt_name = best_ckpt_name.replace('/', '_')
if len(self.key_indicators) == 1:
self.best_ckpt_path = self.file_client.join_path( # type: ignore # noqa: E501
self.best_ckpt_path = self.file_backend.join_path( # type: ignore # noqa: E501
self.out_dir, best_ckpt_name)
runner.message_hub.update_info(runtime_best_ckpt_key,
self.best_ckpt_path)
else:
self.best_ckpt_path_dict[
key_indicator] = self.file_client.join_path( # type: ignore # noqa: E501
key_indicator] = self.file_backend.join_path( # type: ignore # noqa: E501
self.out_dir, best_ckpt_name)
runner.message_hub.update_info(
runtime_best_ckpt_key,
Expand Down
2 changes: 1 addition & 1 deletion mmengine/runner/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def save_checkpoint(checkpoint,
file_backend = get_file_backend(
filename, backend_args=backend_args)
else:
file_backend = file_client
file_backend = file_client.client

with io.BytesIO() as f:
torch.save(checkpoint, f)
Expand Down
8 changes: 7 additions & 1 deletion mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2191,7 +2191,13 @@ def save_checkpoint(
checkpoint['param_schedulers'].append(state_dict)

self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
save_checkpoint(checkpoint, filepath)
if file_client_args is not None:
save_checkpoint(
checkpoint, filepath, file_client_args=file_client_args)
elif backend_args is not None:
save_checkpoint(checkpoint, filepath, backend_args=backend_args)
else:
save_checkpoint(checkpoint, filepath)

@master_only
def dump_config(self) -> None:
Expand Down
9 changes: 9 additions & 0 deletions mmengine/testing/runner_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import logging
import os
import shutil
import tempfile
import time
from unittest import TestCase
Expand Down Expand Up @@ -184,3 +185,11 @@ def setup_dist_env(self):
os.environ['RANK'] = self.dist_cfg['RANK']
os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE']
os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK']

def clear_work_dir(self):
for filename in os.listdir(self.temp_dir.name):
filepath = os.path.join(self.temp_dir.name, filename)
if os.path.isfile(filepath):
os.remove(filepath)
else:
shutil.rmtree(filepath)
51 changes: 51 additions & 0 deletions tests/test_hooks/test_checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import os.path as osp
import re
import sys
from unittest.mock import MagicMock, patch

import torch
from parameterized import parameterized
Expand Down Expand Up @@ -312,6 +314,55 @@ def test_after_val_epoch(self):
self.assertFalse(
osp.isfile(osp.join(runner.work_dir, 'last_checkpoint')))

# There should only one best checkpoint be reserved
# dist backend
for by_epoch, cfg in [(True, self.epoch_based_cfg),
(False, self.iter_based_cfg)]:
self.clear_work_dir()
cfg = copy.deepcopy(cfg)
runner = self.build_runner(cfg)
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=by_epoch, save_best='acc')
checkpoint_hook.before_train(runner)
checkpoint_hook.after_val_epoch(runner, metrics)
all_files = os.listdir(runner.work_dir)
best_ckpts = [
file for file in all_files if file.startswith('best')
]
self.assertTrue(len(best_ckpts) == 1)

# petrel backend
# TODO use real petrel oss bucket to test
petrel_client = MagicMock()
sys.modules['petrel_client'] = petrel_client
for by_epoch, cfg in [(True, self.epoch_based_cfg),
(False, self.iter_based_cfg)]:
isfile = MagicMock(return_value=True)
self.clear_work_dir()
with patch('mmengine.fileio.backends.PetrelBackend.put') as put_mock, \
patch('mmengine.fileio.backends.PetrelBackend.remove') as remove_mock, \
patch('mmengine.fileio.backends.PetrelBackend.isfile') as isfile: # noqa: E501
cfg = copy.deepcopy(cfg)
runner = self.build_runner(cfg)
metrics = dict(acc=0.5)
petrel_client.client.Client = MagicMock(
return_value=petrel_client)
checkpoint_hook = CheckpointHook(
interval=2,
by_epoch=by_epoch,
save_best='acc',
backend_args=dict(backend='petrel'))
checkpoint_hook.before_train(runner)
checkpoint_hook.after_val_epoch(runner, metrics)
put_mock.assert_called_once()
metrics['acc'] += 0.1
runner.train_loop._epoch += 1
runner.train_loop._iter += 1
checkpoint_hook.after_val_epoch(runner, metrics)
isfile.assert_called_once()
remove_mock.assert_called_once()
sys.modules.pop('petrel_client')

def test_after_train_epoch(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg)
Expand Down