[Bug] mmrazor.engine.runner.quantization_loops.QATValLoop calls after_val_epoch hook twice with different keys, causing mmengine.hooks.checkpoint_hook._save_best_checkpoint() to fail with KeyError for the save_best config #637
Description
Describe the bug
While running QAT training on a model whose config files are publicly available (RTMPose-tiny, for example), I ran into this issue. The original config file contains:
default_hooks.checkpoint.save_best='coco/AP'
This works normally in non-quantized training.
However, when the QAT config inherits from that _base_, mmrazor.engine.runner.quantization_loops.QATValLoop calls the after_val_epoch hook twice with different keys, as seen here:
    def run(self) -> dict:
        """Launch validation."""
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch, self.runner.model)

        # compute metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
        qat_metrics = dict()
        for key, value in metrics.items():
            qat_key = 'qat.' + key
            ori_key = 'original.' + key
            qat_metrics[qat_key] = value
            self.runner.message_hub.log_scalars.pop(f'val/{ori_key}', None)

        self.runner.call_hook('after_val_epoch', metrics=qat_metrics)

        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch, self.architecture)

        # compute metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
        qat_metrics = dict()
        for key, value in metrics.items():
            qat_key = 'qat.' + key
            ori_key = 'original.' + key
            qat_metrics[ori_key] = value
            self.runner.message_hub.log_scalars.pop(f'val/{qat_key}', None)

        self.runner.call_hook('after_val_epoch', metrics=qat_metrics)

        self.runner.call_hook('after_val')
        return qat_metrics
This causes mmengine.hooks.checkpoint_hook._save_best_checkpoint() to fail unless save_best is overwritten with 'auto'.
Edit: even with 'auto', only the first key that occurs is set as key_indicator, from this line.
However, one might need a setting other than 'auto' in some fringe circumstances. I suggest that the code attempt to save the best checkpoint only once, or that the prefixes 'qat.' and 'original.' each be added temporarily to the keys in key_indicators and then removed each time the hook is called.
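The second suggestion above can be sketched in plain Python. This is only an illustration of the idea, not mmengine or mmrazor code: `with_prefix` and `scores_for_pass` are hypothetical helpers, and the metric values are placeholders.

```python
from typing import Dict, List, Sequence


def with_prefix(key_indicators: Sequence[str], prefix: str) -> List[str]:
    """Return new indicator names with `prefix` prepended; the input is not modified."""
    return [prefix + key for key in key_indicators]


def scores_for_pass(metrics: Dict[str, float],
                    key_indicators: Sequence[str],
                    prefix: str) -> Dict[str, float]:
    """Look up each configured indicator under its temporarily prefixed name.

    The result is keyed by the unprefixed name, so the configured
    save_best value never has to change between the two passes.
    """
    prefixed = with_prefix(key_indicators, prefix)
    return {key: metrics[p_key] for key, p_key in zip(key_indicators, prefixed)}


# Placeholder values, not real results.
key_indicators = ['coco/AP']
qat_pass = {'qat.coco/AP': 0.5, 'qat.coco/AR': 0.6}
original_pass = {'original.coco/AP': 0.7, 'original.coco/AR': 0.8}

qat_scores = scores_for_pass(qat_pass, key_indicators, 'qat.')
original_scores = scores_for_pass(original_pass, key_indicators, 'original.')
```

Because the prefix is applied to a copy, the configured indicators stay untouched after each call, which is the "added temporarily and then removed" behavior described above.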
Reproduces the error - error message
Traceback (most recent call last):
  ...
  File ".../site-packages/mmrazor/engine/runner/quantization_loops.py", line 266, in run
    self.runner.call_hook('after_val_epoch', metrics=qat_metrics)
  File ".../site-packages/mmengine/runner/runner.py", line 1839, in call_hook
    getattr(hook, fn_name)(self, **kwargs)
  File ".../site-packages/mmengine/hooks/checkpoint_hook.py", line 361, in after_val_epoch
    self._save_best_checkpoint(runner, metrics)
  File ".../site-packages/mmengine/hooks/checkpoint_hook.py", line 505, in _save_best_checkpoint
    key_score = metrics[key_indicator]
                ~~~~~~~^^^^^^^^^^^^^^^
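The failing lookup can be reproduced without mmengine or mmrazor installed. The sketch below only mimics the key renaming from QATValLoop.run() and the `metrics[key_indicator]` lookup from the traceback; `prefix_metrics` and `get_key_score` are hypothetical stand-ins, and the metric value is a placeholder.

```python
from typing import Dict


def prefix_metrics(metrics: Dict[str, float], prefix: str) -> Dict[str, float]:
    """Mimic the renaming in QATValLoop.run(): every key gets a prefix."""
    return {prefix + key: value for key, value in metrics.items()}


def get_key_score(metrics: Dict[str, float], key_indicator: str) -> float:
    """Stand-in for the `metrics[key_indicator]` lookup shown in the traceback."""
    return metrics[key_indicator]


save_best = 'coco/AP'       # value inherited from the base config
raw = {'coco/AP': 0.0}      # placeholder evaluator output
qat_metrics = prefix_metrics(raw, 'qat.')

try:
    get_key_score(qat_metrics, save_best)
    lookup_failed = False
except KeyError:
    # The dict only contains 'qat.coco/AP', so the configured key is missing.
    lookup_failed = True
```

The same lookup would fail on the second pass as well, where the only key is 'original.coco/AP'.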
Suggested fix
Edit: I changed the fix so that the best checkpoint is chosen from the QAT metrics rather than from the original architecture's metrics.
@@ -255,12 +255,12 @@ class QATValLoop(ValLoop):
             self.run_iter(idx, data_batch, self.runner.model)
 
         # compute metrics
+        self.runner.logger.info(f'Evaluating QAT model')
         metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
         qat_metrics = dict()
         for key, value in metrics.items():
-            qat_key = 'qat.' + key
             ori_key = 'original.' + key
-            qat_metrics[qat_key] = value
+            qat_metrics[key] = value
             self.runner.message_hub.log_scalars.pop(f'val/{ori_key}', None)
 
         self.runner.call_hook('after_val_epoch', metrics=qat_metrics)
@@ -271,15 +271,10 @@ class QATValLoop(ValLoop):
             self.run_iter(idx, data_batch, self.architecture)
 
         # compute metrics
+        self.runner.logger.info(f'Evaluating original model architecture')
         metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
-        qat_metrics = dict()
         for key, value in metrics.items():
-            qat_key = 'qat.' + key
-            ori_key = 'original.' + key
-            qat_metrics[ori_key] = value
-            self.runner.message_hub.log_scalars.pop(f'val/{qat_key}', None)
-
-        self.runner.call_hook('after_val_epoch', metrics=qat_metrics)
+            self.runner.message_hub.log_scalars.pop(f'val/{key}', None)
 
         self.runner.call_hook('after_val')
         return qat_metrics
The only remaining issue is that the after_val_epoch hook is no longer called for the second pass (the one that evaluates self.architecture), so some of the logs are incomplete.
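One possible way around the incomplete logs, which I have not tested against mmrazor, would be to collect the metrics of both passes into a single dict and call after_val_epoch only once. A minimal sketch of the merging step, where `merge_val_metrics` is a hypothetical helper and the values are placeholders:

```python
from typing import Dict


def merge_val_metrics(qat_raw: Dict[str, float],
                      original_raw: Dict[str, float]) -> Dict[str, float]:
    """Combine both evaluation passes into one metrics dict.

    QAT metrics keep their plain names, so a configured save_best key
    such as 'coco/AP' still resolves; metrics of the original
    architecture are stored under an 'original.' prefix.
    """
    merged = dict(qat_raw)
    for key, value in original_raw.items():
        merged['original.' + key] = value
    return merged


# Placeholder values, not real results.
merged = merge_val_metrics({'coco/AP': 0.5}, {'coco/AP': 0.7})
```

With a dict like this, a single after_val_epoch call would see the QAT score under the configured key and the original score under its prefixed key. Whether the other hooks behave correctly with only one after_val_epoch call per validation is an assumption that would need checking.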