fix: Fixed load_pretrained_params in PyTorch when ignoring keys #902
Conversation
# test pretrained model with different num_classes
model = classification.__dict__[arch_name](pretrained=True, num_classes=108).eval()
_test_classification(model, input_shape, output_size=(108,))
# Check that you can pretrained everything up until the last layer
maybe better: Test whether a pretrained model can be initialized down to the last layer? :)
Not sure I understand?
Compared to previous features, my idea is that we can already assume:
- models are built correctly
- pretrained with default output classes is working
What needs testing here is that the checkpoint can be loaded on a model with a different num_classes. So perhaps we could check that the bias of a given layer is indeed the one from the state_dict (a sketch of such a check is below), but I don't see much more :)
(and I'm trying to keep the unit tests from getting slower)
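For illustration, here is a minimal sketch of that kind of check; the arch name, checkpoint URL, and parameter name are hypothetical placeholders, not the actual ones shipped with doctr:

import torch
from doctr.models import classification

arch_name = "mobilenet_v3_small"  # assumption: any doctr classification arch
ckpt_url = "https://example.com/mobilenet_v3_small.pt"  # hypothetical checkpoint URL

# Build with a non-default head size, so the classification layer gets re-initialized
model = classification.__dict__[arch_name](pretrained=True, num_classes=108).eval()

# Compare a backbone parameter (hypothetical name) against the released checkpoint
state_dict = torch.hub.load_state_dict_from_url(ckpt_url, map_location="cpu")
param_name = "features.0.0.weight"  # hypothetical: any parameter outside the head
assert torch.equal(model.state_dict()[param_name], state_dict[param_name])

A single equality check like this keeps the test cheap while still verifying that the pretrained weights were actually loaded.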
@frgfm sorry, I just meant to suggest changing this comment:
# Check that you can pretrained everything up until the last layer
maybe better / clearer:
# Test whether a pretrained model can be initialized down to the last layer
😅
@frgfm And about the unit tests: I think my tests from https://github.com/mindee/doctr/pull/892/files are currently the biggest slowdown, but I think we need to test each model.
@frgfm Thanks for refactoring this, now it is more dynamic ;)
Thanks for the refacto! Only a small missing import, it seems.
Thanks!
Following up on #874, this PR introduces the following modifications:
- added an ignore_keys argument to load_pretrained_params to avoid non-strict loading of a wrongly sized state_dict (a sketch of such a mechanism follows below)
- switched the factory functions to the ignore_keys mechanism (the keys to ignore were hardcoded in the factory function, while this function is used to build models of different sizes, and thus with Linear layers being named differently)
Any feedback is welcome!
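For context, here is a minimal sketch of how such an ignore_keys argument can work. It is a sketch under assumptions about the checkpoint format (a plain state_dict reachable by URL), not doctr's actual implementation:

from typing import List, Optional

import torch
from torch import nn


def load_pretrained_params(model: nn.Module, url: str, ignore_keys: Optional[List[str]] = None) -> None:
    # Fetch the released checkpoint (assumed to be a plain state_dict)
    state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")

    if ignore_keys:
        # Drop the wrongly sized entries (e.g. the classification head),
        # then load non-strictly so every other weight is still applied
        for key in ignore_keys:
            state_dict.pop(key, None)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        # Only the ignored keys should end up missing; anything else is a real mismatch
        if set(missing) != set(ignore_keys) or unexpected:
            raise ValueError("unable to load state_dict: unexpected key mismatch")
    else:
        model.load_state_dict(state_dict)

A factory building a model with a custom num_classes could then pass the names of its head parameters, e.g. load_pretrained_params(model, url, ignore_keys=["classifier.weight", "classifier.bias"]), the key names being model-dependent.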