
fix: Fixed load_pretrained_params in PyTorch when ignoring keys #902

Merged (11 commits) on Apr 28, 2022
refactor: Improved checks
FG Fernandez committed Apr 27, 2022
commit 1bdedd2d34cd405a38ba77774157568b90e7ff7f
doctr/models/utils/pytorch.py (2 additions, 2 deletions)
@@ -48,8 +48,8 @@ def load_pretrained_params(
             for key in ignore_keys:
                 state_dict.pop(key)
             missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
-            if len(missing_keys) != len(ignore_keys) or len(unexpected_keys) > 0:
-                raise AssertionError("unable to load state_dict")
+            if set(missing_keys) != set(ignore_keys) or len(unexpected_keys) > 0:
+                raise ValueError("unable to load state_dict, due to non-matching keys.")
         else:
             # Load weights
             model.load_state_dict(state_dict)
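
For context, a minimal sketch of why the new check is stricter, using a toy nn.Linear module rather than docTR's actual loading path: comparing lengths would also pass if an unrelated key happened to go missing, whereas comparing sets guarantees that the missing keys are exactly the ignored ones.

    from torch import nn

    # Toy stand-in for a pretrained checkpoint; not docTR's actual models.
    model = nn.Linear(4, 2)
    state_dict = {k: v.clone() for k, v in model.state_dict().items()}

    # Drop the keys we want to re-initialize (e.g. a classification head).
    ignore_keys = ["bias"]
    for key in ignore_keys:
        state_dict.pop(key)

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

    # len(missing_keys) == len(ignore_keys) would also hold if a different key
    # were missing; comparing the sets ensures they are exactly the ignored ones.
    if set(missing_keys) != set(ignore_keys) or len(unexpected_keys) > 0:
        raise ValueError("unable to load state_dict, due to non-matching keys.")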