[Bug] (suggested fix) mmrazor.models.algorithms.quantization.mm_architecture.MMArchitectureQuant.sync_qparams()
fails if there are modules present in other modes but not in forward mode='tensor'
#634
Description
Describe the bug
In models where there are modules that exist only in mode 'predict' or 'loss' but not in 'tensor', the following code fails with a KeyError when looking through the state dict of the tensor-mode model. For example, if one model has duplicates but the other doesn't.
mmrazor.models.algorithms.quantization.mm_architecture.MMArchitectureQuant.sync_qparams()
#L124--L148
def traverse(module, prefix):
    for name, child in module._modules.items():
        if module is None:
            continue
        child_name = f'{prefix}{name}'
        if isinstance(child, FakeQuantizeBase):
            for name, param in child.named_parameters():
                param_name = f'{child_name}.{name}'
                src_param = src_state_dict[param_name]  ## Here
                if src_param.shape == param.shape:
                    param.data.copy_(src_param)
                else:
                    requirs_grad = param.requires_grad
                    param.requires_grad = False
                    param.resize_(src_param.shape)
                    param.requires_grad = requirs_grad
                    param.data.copy_(src_param)
            for name, buffer in child.named_buffers():
                buffer_name = f'{child_name}.{name}'
                src_buffer = src_state_dict[buffer_name]  # here
                if src_buffer.shape == buffer.shape:
                    buffer.data.copy_(src_buffer)
                else:
                    buffer.resize_(src_buffer.shape)
                    buffer.data.copy_(src_buffer)
Additional Context
I have been trying to quantize the mmpose.TopdownPoseEstimator, applying fixes for torch 2.0.0 incompatibility suggested in mmrazor #632, a fix for nn.Parameters inside TopdownPoseEstimator not being traced in mmrazor #633, and a fix on mmpose.TopdownPoseEstimator untraceable methods in mmpose #3012.
Because a flip input inversion test is added to the predict forward graph, there are not only duplicate modules but also duplicate loose (leaf) activation_post_process_xyz numbered modules, and both make the syncing fail.
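To find out which entries would trigger the KeyError before sync_qparams() runs, the key sets of the two state dicts can be compared. This is a minimal sketch, assuming plain collections of key names (in practice they would come from something like `qmodels[mode].state_dict().keys()`). The helper name `keys_missing_from_source` and the sample keys are hypothetical and only mirror the two failure cases reported here.

```python
def keys_missing_from_source(src_keys, dst_keys):
    """Return keys present in the destination mode but absent from the source mode."""
    src = set(src_keys)
    return sorted(k for k in dst_keys if k not in src)


# Hypothetical key sets: 'tensor' is the source mode, 'predict' has extra nodes.
tensor_keys = [
    "backbone.stem.0.conv.weight_fake_quant.fake_quant_enabled",
    "activation_post_process_0.fake_quant_enabled",
]
predict_keys = tensor_keys + [
    # duplicated module created by the extra (flipped) forward pass
    "backbone.stem.0.conv_dup1.weight_fake_quant.fake_quant_enabled",
    # loose leaf activation node that only exists in 'predict'
    "activation_post_process_123.fake_quant_enabled",
]

missing = keys_missing_from_source(tensor_keys, predict_keys)
for key in missing:
    print(f"in 'predict' but not in 'tensor': {key}")
```

Every key this reports is one that a direct `src_state_dict[...]` lookup would fail on.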
Reproduces the error - code sample
I cannot currently provide the configuration, but the executing code is this:
from mmrazor.models.algorithms.quantization.mm_architecture import MMArchitectureQuant
from mmengine import Config

cfg = Config.fromfile('qat_rtmpose-t_8xb256-420e_coco-256x192.py')
qtopdown = MMArchitectureQuant(
    data_preprocessor=cfg.data_preprocessor,
    architecture=cfg.architecture,
    quantizer=cfg.model.quantizer,
    input_shapes=cfg.model.input_shapes
)
Reproduces the problem - error message
Traceback (most recent call last):
File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 91, in __init__
self.sync_qparams('tensor')
File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 156, in sync_qparams
.....redacted
File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 143, in traverse
src_buffer = src_state_dict[buffer_name]
~~~~~~~~~~~~~~~~~~~~~~~~~~~
KeyError: 'backbone.stem.0.conv_dup1.weight_fake_quant.fake_quant_enabled'
And while patching that:
Traceback (most recent call last):
File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 91, in __init__
self.sync_qparams('tensor')
File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 156, in sync_qparams
.....redacted
File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 157, in traverse
raise KeyError(f"{buffer_name} in mode '{mode}' but not found in source mode '{tensor}', sync_qparams() failed.")
KeyError: "activation_post_process_123.fake_quant_enabled in mode 'predict' but not found in source mode 'tensor', sync_qparams() failed."
Post related information - suggested fix
*EDIT: while this fix allows syncing of nodes that aren't in other modes, it causes failures in model deployment later down the line.
For duplicate modules, I figure one can copy the state_dict element with the non-suffixed name, but I don't have a suggestion for non-existent modules yet.
For activation post-processing leaf nodes, I can skip most of the copying, since a lot of it is reset in MMArchitectureQuant.__init__().
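The duplicate-name fallback described above can be isolated into a small helper. This is a sketch rather than the library's API: the function names `strip_dup_suffix` and `lookup_with_dup_fallback` are hypothetical, and a plain dict with float values stands in for the source state_dict of tensors.

```python
def strip_dup_suffix(key):
    """Drop a '_dup<N>' suffix from every dotted section of a state_dict key."""
    return ".".join(section.split("_dup")[0] for section in key.split("."))


def lookup_with_dup_fallback(src_state_dict, key):
    """Look up `key`; if absent, retry with '_dup' suffixes removed.

    Returns None when neither name exists, so the caller decides whether
    to skip the entry or raise.
    """
    value = src_state_dict.get(key)
    if value is None and "_dup" in key:
        value = src_state_dict.get(strip_dup_suffix(key))
    return value


# Stand-in for the source-mode state_dict (floats instead of tensors).
src_state_dict = {"backbone.stem.0.conv.weight_fake_quant.scale": 0.5}

found = lookup_with_dup_fallback(
    src_state_dict, "backbone.stem.0.conv_dup1.weight_fake_quant.scale")
absent = lookup_with_dup_fallback(
    src_state_dict, "activation_post_process_123.scale")
print(found, absent)
```

Returning None instead of raising keeps the policy for truly missing nodes (skip or fail) in the caller, which is the part that remains unresolved for nodes with no counterpart in the source mode.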
mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -121,7 +121,7 @@ class MMArchitectureQuant(BaseAlgorithm):
         in some subtle ways, so we need to sync them here.
         """
-        def traverse(module, prefix):
+        def traverse(module, prefix, mode, src_mode):
             for name, child in module._modules.items():
                 if module is None:
                     continue
@@ -129,7 +129,14 @@ class MMArchitectureQuant(BaseAlgorithm):
                 if isinstance(child, FakeQuantizeBase):
                     for name, param in child.named_parameters():
                         param_name = f'{child_name}.{name}'
-                        src_param = src_state_dict[param_name]
+                        src_param = src_state_dict.get(param_name)
+                        if '_dup' in param_name and src_param is None:
+                            param_name = '.'.join([section.split('_dup')[0] for section in param_name.split('.')])
+                            src_param = src_state_dict.get(param_name)
+                        if src_param is None:
+                            print(src_state_dict)
+                            print(child)
+                            raise KeyError(f"{param_name} in mode: '{mode}' but not found in source mode: '{src_mode}', sync_qparams() failed.")
                         if src_param.shape == param.shape:
                             param.data.copy_(src_param)
                         else:
@@ -138,22 +145,42 @@ class MMArchitectureQuant(BaseAlgorithm):
                             param.resize_(src_param.shape)
                             param.requires_grad = requirs_grad
                             param.data.copy_(src_param)
+                    # These are either reset after sync_qparams() is called, or are left as default (eps)
+                    # so there's no need to sync them if there's not a match
+                    skip_buffer_sync = [
+                        "fake_quant_enabled",
+                        "observer_enabled",
+                        "scale",
+                        "zero_point",
+                        "min_val",
+                        "max_val",
+                        "eps",
+                    ]
                     for name, buffer in child.named_buffers():
                         buffer_name = f'{child_name}.{name}'
-                        src_buffer = src_state_dict[buffer_name]
+                        src_buffer = src_state_dict.get(buffer_name)
+                        if '_dup' in buffer_name and src_buffer is None:
+                            buffer_name = '.'.join([section.split('_dup')[0] for section in buffer_name.split('.')])
+                            src_buffer = src_state_dict.get(buffer_name)
+                        if any([s in buffer_name for s in skip_buffer_sync]) and src_buffer is None:
+                            continue
+                            src_buffer = torch.tensor([1], dtype=torch.uint8)
+                        if src_buffer is None:
+                            print(src_state_dict)
+                            print(child)
+                            raise KeyError(f"{buffer_name} in mode: '{mode}' but not found in source mode: '{src_mode}', sync_qparams() failed.")
                         if src_buffer.shape == buffer.shape:
                             buffer.data.copy_(src_buffer)
                         else:
                             buffer.resize_(src_buffer.shape)
                             buffer.data.copy_(src_buffer)
                 else:
-                    traverse(child, f'{child_name}.')
+                    traverse(child, f'{child_name}.', mode, src_mode)
         src_state_dict = self.qmodels[src_mode].state_dict()
         for mode in self.forward_modes:
             if mode == src_mode:
                 continue
-            traverse(self.qmodels[mode], '')
+            traverse(self.qmodels[mode], '', mode, src_mode)
     def _get_rewriter_context_in_mmdeploy(self, deploy_cfg):
         """Get rewriter context in mmdeploy according to the deploy related
Activity
elisa-aleman commented on Apr 10, 2024
Added more context and suggested a fix
elisa-aleman commented on Apr 18, 2024
After trying to deploy the quantized model, I realized the suggested fix might be unnecessary and cause further issues, since mmdeploy/tools/deploy.py will force model.architecture.test_cfg.flip_test=False for pose estimators. That means there would be extra weights in the quantized state_dict, causing the model deploy to fail.
I then tried:
python /tools/train.py \
    ${qat_topdown_cgf} \
    --cfg-options \
    model.architecture.test_cfg.flip_test=False \
    --work-dir /path/here/
But the model still fails to sync without my patch.
elisa-aleman commented on Apr 22, 2024
I realized that sync_qparams() is also called with the loss mode as the source mode during the training loop, so my previous fix actually removes any progress made during training. I suggest this new fix that doesn't reset fake weight values if not found, although I've yet to finish deploying this model, so it's subject to change.
mmrazor.models.algorithms.mm_architecture.MMArchitectureQuant.get_deploy_model() fails if predict mode lacks nodes from the model.quantizer.tracer.skipped_methods configuration, but the architecture quantizer.prepare(fp32_model) has these nodes. #642
elisa-aleman commented on Aug 11, 2024
After some fixing, the solution to this issue is to refactor the model so that FX tracing is possible in all modes up until the wrapped methods that differ in each mode. As long as the only difference in tracing comes after the .forward() method, the syncing won't fail.