
don't save Modules in hparams #915

Closed

Conversation

KnathanM
Member

Description

Related to #898, this PR addresses the fact that we double-save nn.Modules in both the hparams and the state_dict. It follows David's idea here.

The main downside of this PR is that it requires considerably more code than continuing to double-save. The main upside is that model files will be a little smaller. Roughly, the idea looks like the sketch below.
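This is only a minimal sketch; the class and argument names (ExampleModel, hidden_dim, task_weights) are illustrative, not Chemprop's actual API:

```python
import torch
from torch import nn
from lightning import pytorch as pl


class ExampleModel(pl.LightningModule):
    """Keep only plain values in hparams; let Module/tensor weights live solely in the state_dict."""

    def __init__(self, hidden_dim: int = 300, task_weights=None):
        super().__init__()
        # exclude the tensor from hparams so it isn't pickled there;
        # it will round-trip through the state_dict instead
        self.save_hyperparameters(ignore=["task_weights"])
        task_weights = torch.ones(1) if task_weights is None else task_weights
        self.register_buffer("task_weights", task_weights)
        self.ffn = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        return self.ffn(x) * self.task_weights
```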

Other notes

metrics in MPNN is a list of Modules, so the whole list is pickled and unpickled on save and load. This differs from the other Modules, like criterion, which are rebuilt each time a model is loaded.

The default task_weights for criterion is an array, so I don't need to check whether task_weights is in the state_dict. For the other cases, the default is None/Identity, so I need to check whether corresponding parameters exist in the state_dict and, if so, build a placeholder Module with zeros (see the sketch below).
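Concretely, the check looks something like this; the checkpoint layout, the "V_d_transform.mean" key, and the ScaleTransform import path are assumptions for the example, since the real names depend on how the submodules are registered:

```python
import numpy as np
import torch
from torch import nn

from chemprop.nn.transforms import ScaleTransform  # import path assumed

state_dict = torch.load("model.ckpt", map_location="cpu")["state_dict"]

if "V_d_transform.mean" in state_dict:
    n = state_dict["V_d_transform.mean"].shape[-1]
    # build a placeholder filled with zeros; load_state_dict overwrites the values
    V_d_transform = ScaleTransform(np.zeros(n), np.zeros(n))
else:
    V_d_transform = nn.Identity()
```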

While working on this PR, I found that for multicomponent message passing, if a single block is shared between two components, the weights for that block appear twice in the state_dict. I don't know whether the weights just have two keys, or whether they are truly saved twice in the model file. I'm not sure if this is something we should try to change.
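For what it's worth, a toy check outside of Chemprop suggests the two keys alias the same underlying storage, in which case torch.save should write the data only once; I haven't confirmed this on an actual multicomponent checkpoint:

```python
import torch
from torch import nn

# register the same block under two names and inspect the state_dict
shared = nn.Linear(4, 4)
blocks = nn.ModuleList([shared, shared])

sd = blocks.state_dict()
print(list(sd.keys()))  # ['0.weight', '0.bias', '1.weight', '1.bias']
print(sd["0.weight"].data_ptr() == sd["1.weight"].data_ptr())  # True: same storage
```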

Comment on lines +243 to +244
.squeeze(0)
.numpy()
Member Author


_ScaleTransformMixin unsqueezes and converts to a torch tensor. If I don't include .numpy() here, the line scale = torch.cat([torch.ones(pad), torch.tensor(scale, dtype=torch.float)]) emits this warning:

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
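For reference, a minimal repro of that warning outside of Chemprop:

```python
import torch

scale = torch.rand(3)

t1 = torch.tensor(scale, dtype=torch.float)          # emits the copy-construct UserWarning
t2 = torch.tensor(scale.numpy(), dtype=torch.float)  # no warning
```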

@davidegraff
Contributor

Woof, on second thought, I don't think my idea was a good one. I don't think it's that much of a problem to double-save like 32 floating point numbers.

@KnathanM
Member Author

Okay, sounds like we will go with #898 for saving the scalers.

Questions:
I added some unit tests here. Are those still worth merging in? They test whether the scalers are saved and loaded correctly, which is independent of the mechanics of how we save and load them.

I reorganized some of load_from_checkpoint and load_from_file by having the hyperparameters and state dict loaded in load_submodules. This made it possible to remove the separate load_from_file in MulticomponentMPNN. Is this worth merging in? A rough sketch of the shape is below.
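Illustrative sketch only; this isn't the exact diff, and the load_submodules signature here is just a placeholder:

```python
import torch
from lightning import pytorch as pl


class MPNN(pl.LightningModule):
    # ... existing __init__, forward, etc. ...

    @classmethod
    def load_submodules(cls, hparams: dict, state_dict: dict) -> dict:
        # rebuild criterion, transforms, etc. from hparams plus the state_dict
        ...

    @classmethod
    def load_from_file(cls, model_path, map_location=None, strict=True):
        d = torch.load(model_path, map_location=map_location)
        hparams, state_dict = d["hyper_parameters"], d["state_dict"]
        submodules = cls.load_submodules(hparams, state_dict)
        model = cls(**hparams, **submodules)
        model.load_state_dict(state_dict, strict=strict)
        return model
```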

KnathanM mentioned this pull request Jul 5, 2024
@KnathanM
Member Author

KnathanM commented Jul 5, 2024

I moved the tests to #955

I decided not to move the changes to load_from_file into another PR because I think having MPNN and MulticomponentMPNN each keep their own load_from_file may help people debug in the future.

KnathanM closed this Jul 5, 2024