
[Bug] mmrazor.engine.runner.quantization_loops.QATValLoop calls after_val_epoch hook twice with different keys, causing mmengine.hooks.checkpoint_hook._save_best_checkpoint() to fail with KeyError for the save_best config #637

Open
@elisa-aleman

Description

Describe the bug

During QAT training of models with publicly available config files, such as RTMPose-tiny, I ran into this issue. The original config file sets:

default_hooks.checkpoint.save_best='coco/AP'

This works normally in non-quantized training.

However, when the QAT config inherits this file through _base_, mmrazor.engine.runner.quantization_loops.QATValLoop calls the after_val_epoch hook twice with different metric keys, as seen here:

    def run(self) -> dict:
        """Launch validation."""
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch, self.runner.model)

        # compute metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
        qat_metrics = dict()
        for key, value in metrics.items():
            qat_key = 'qat.' + key
            ori_key = 'original.' + key
            qat_metrics[qat_key] = value
            self.runner.message_hub.log_scalars.pop(f'val/{ori_key}', None)

        self.runner.call_hook('after_val_epoch', metrics=qat_metrics)

        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch, self.architecture)

        # compute metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
        qat_metrics = dict()
        for key, value in metrics.items():
            qat_key = 'qat.' + key
            ori_key = 'original.' + key
            qat_metrics[ori_key] = value
            self.runner.message_hub.log_scalars.pop(f'val/{qat_key}', None)

        self.runner.call_hook('after_val_epoch', metrics=qat_metrics)

        self.runner.call_hook('after_val')
        return qat_metrics

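This causes mmengine.hooks.checkpoint_hook._save_best_checkpoint() to fail with a KeyError unless save_best is overwritten with 'auto', because neither metrics dict contains the unprefixed key 'coco/AP': the first call passes 'qat.coco/AP' and the second passes 'original.coco/AP'.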
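A minimal, library-free sketch of the key mismatch. It only mimics the two dicts built in run() and the metrics[key_indicator] lookup shown in the traceback; build_metric_dicts and lookup_best_score are hypothetical helpers, the metric value is a placeholder, and for simplicity both dicts reuse the same evaluator output.

```python
def build_metric_dicts(metrics):
    """Mimic the two dicts QATValLoop.run() passes to after_val_epoch."""
    qat_metrics = {'qat.' + key: value for key, value in metrics.items()}
    original_metrics = {'original.' + key: value for key, value in metrics.items()}
    return qat_metrics, original_metrics


def lookup_best_score(metrics, key_indicator):
    """Hypothetical stand-in for `key_score = metrics[key_indicator]`."""
    return metrics[key_indicator]


# Placeholder evaluator output; the number is illustrative only.
evaluated = {'coco/AP': 0.5}
qat_metrics, original_metrics = build_metric_dicts(evaluated)

# save_best='coco/AP' matches neither prefixed dict.
for metrics in (qat_metrics, original_metrics):
    try:
        lookup_best_score(metrics, 'coco/AP')
    except KeyError as err:
        print('KeyError:', err, 'available keys:', sorted(metrics))
```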

Edit: 'auto' does not fully solve it either, because only the first key encountered is set as key_indicator (from this line).
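A toy model of the 'auto' behavior described in the edit above, assuming (as described) that the key indicator is resolved once from the first metrics dict and then indexed directly, as in the traceback. ToyBestCheckpointTracker is hypothetical and is not mmengine's CheckpointHook; the scores are placeholders.

```python
class ToyBestCheckpointTracker:
    """Toy model (not mmengine code) of 'auto' resolving the key on the first call."""

    def __init__(self, save_best='auto'):
        self.key_indicator = save_best
        self.best_score = None

    def after_val_epoch(self, metrics):
        if self.key_indicator == 'auto':
            # Resolved once, from whichever metrics dict arrives first.
            self.key_indicator = next(iter(metrics))
        score = metrics[self.key_indicator]  # raises KeyError on a mismatch
        if self.best_score is None or score > self.best_score:
            self.best_score = score


tracker = ToyBestCheckpointTracker()
tracker.after_val_epoch({'qat.coco/AP': 0.5})  # key resolves to 'qat.coco/AP'
try:
    # Second call of the same validation epoch uses 'original.' keys.
    tracker.after_val_epoch({'original.coco/AP': 0.6})
except KeyError as err:
    print('second call failed:', err)
```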

However, some fringe cases may need a setting other than 'auto' anyway. I suggest that the code either attempt to save the best checkpoint only once, or temporarily add the 'qat.' and 'original.' prefixes to the configured key_indicators each time the hook is called and remove them afterwards.
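The second option above (temporarily prefixing the key indicators) could look roughly like this sketch. It is written against a hypothetical ToyCheckpointHook that only has a key_indicators list; mmengine's real CheckpointHook may track additional per-key state (such as the comparison rule), so this illustrates the idea rather than a drop-in patch.

```python
from contextlib import contextmanager


class ToyCheckpointHook:
    """Hypothetical stand-in for a checkpoint hook with configured key_indicators."""

    def __init__(self, key_indicators):
        self.key_indicators = list(key_indicators)
        self.seen_scores = []

    def after_val_epoch(self, metrics):
        for key in self.key_indicators:
            self.seen_scores.append((key, metrics[key]))


@contextmanager
def prefixed_key_indicators(hook, prefix):
    """Temporarily add `prefix` to every key indicator, restoring them on exit."""
    original = hook.key_indicators
    hook.key_indicators = [prefix + key for key in original]
    try:
        yield hook
    finally:
        # Restore even if the hook raises, so the next call sees the configured keys.
        hook.key_indicators = original


hook = ToyCheckpointHook(['coco/AP'])
with prefixed_key_indicators(hook, 'qat.'):
    hook.after_val_epoch({'qat.coco/AP': 0.5})
with prefixed_key_indicators(hook, 'original.'):
    hook.after_val_epoch({'original.coco/AP': 0.6})
print(hook.key_indicators, hook.seen_scores)
```

Both calls now succeed, so a real fix would still have to decide which of the two scores drives the "best" checkpoint.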

Error message

Traceback (most recent call last):
...
  File ".../site-packages/mmrazor/engine/runner/quantization_loops.py", line 266, in run
    self.runner.call_hook('after_val_epoch', metrics=qat_metrics)
  File ".../site-packages/mmengine/runner/runner.py", line 1839, in call_hook
    getattr(hook, fn_name)(self, **kwargs)
  File ".../site-packages/mmengine/hooks/checkpoint_hook.py", line 361, in after_val_epoch
    self._save_best_checkpoint(runner, metrics)
  File ".../site-packages/mmengine/hooks/checkpoint_hook.py", line 505, in _save_best_checkpoint
    key_score = metrics[key_indicator]
                ~~~~~~~^^^^^^^^^^^^^^^

Suggested fix

Edit: I changed the fix so that the best checkpoint is selected from the QAT metrics rather than the original architecture metrics.

    @@ -255,12 +255,12 @@ class QATValLoop(ValLoop):
                 self.run_iter(idx, data_batch, self.runner.model)

             # compute metrics
    +        self.runner.logger.info('Evaluating QAT model')
             metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
             qat_metrics = dict()
             for key, value in metrics.items():
    -            qat_key = 'qat.' + key
                 ori_key = 'original.' + key
    -            qat_metrics[qat_key] = value
    +            qat_metrics[key] = value
                 self.runner.message_hub.log_scalars.pop(f'val/{ori_key}', None)

             self.runner.call_hook('after_val_epoch', metrics=qat_metrics)
    @@ -271,15 +271,10 @@ class QATValLoop(ValLoop):
                 self.run_iter(idx, data_batch, self.architecture)

             # compute metrics
    +        self.runner.logger.info('Evaluating original model architecture')
             metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
    -        qat_metrics = dict()
             for key, value in metrics.items():
    -            qat_key = 'qat.' + key
    -            ori_key = 'original.' + key
    -            qat_metrics[ori_key] = value
    -            self.runner.message_hub.log_scalars.pop(f'val/{qat_key}', None)
    -
    -        self.runner.call_hook('after_val_epoch', metrics=qat_metrics)
    +            self.runner.message_hub.log_scalars.pop(f'val/{key}', None)

             self.runner.call_hook('after_val')
             return qat_metrics

The only drawback is that the after_val_epoch hook is no longer called for the original architecture model (self.architecture), so some of its logs are incomplete.
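A library-free toy model of the patched control flow, to check that the hook is now called once with unprefixed keys so that save_best='coco/AP' can match. patched_val_epoch and its callable arguments are hypothetical stand-ins for QATValLoop.run(), the evaluator, runner.call_hook and message_hub.log_scalars; nothing here is mmrazor's actual API, and the scores are placeholders.

```python
def patched_val_epoch(evaluate_qat, evaluate_original, call_hook, log_scalars):
    """Toy model of the patched QATValLoop.run() metric handling (no mmengine)."""
    metrics = evaluate_qat()
    qat_metrics = {}
    for key, value in metrics.items():
        qat_metrics[key] = value  # unprefixed, so save_best='coco/AP' matches
        log_scalars.pop(f'val/original.{key}', None)
    call_hook('after_val_epoch', metrics=qat_metrics)

    metrics = evaluate_original()
    for key in metrics:
        log_scalars.pop(f'val/{key}', None)
    # No after_val_epoch call for the original architecture (the drawback noted above).
    call_hook('after_val')
    return qat_metrics


calls = []


def record_hook(name, **kwargs):
    """Record hook calls instead of dispatching them."""
    calls.append((name, kwargs))


result = patched_val_epoch(
    evaluate_qat=lambda: {'coco/AP': 0.5},
    evaluate_original=lambda: {'coco/AP': 0.6},
    call_hook=record_hook,
    log_scalars={},
)
print(result, [name for name, _ in calls])
```

The recorded calls show after_val_epoch receiving the unprefixed QAT metrics exactly once, followed by after_val, which also makes the drawback visible: the original architecture's metrics never reach the hook.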
