
Explicitly double save scalers/criterion #898

Merged
7 commits merged into chemprop:main on Jun 18, 2024

Conversation

@KnathanM (Member) commented Jun 2, 2024

Summary

Currently, some small tensors are saved twice in a chemprop model file: once in hyper_parameters and once in state_dict. I don't think there is a clean way to avoid the double save, and the tensors are small enough not to worry about, so the best thing to do is to keep double saving them, suppress the warnings, and be explicit that this is what we are doing.

The warning can be suppressed by adding the small nn.Module arguments to self.save_hyperparameters(ignore=...) and then manually adding them to hparams. #832 started this for criterion and output_transform of _FFNPredictorBase.

The warning looks something like "Attribute '...' is an instance of nn.Module and is already saved during checkpointing. It is recommended to ignore them using self.save_hyperparameters(ignore=['...'])"
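
For concreteness, the pattern looks roughly like the sketch below. The class and argument names are illustrative stand-ins, not the actual chemprop code:

```python
from lightning.pytorch.core.mixins import HyperparametersMixin
from torch import nn


class ExampleFFN(nn.Module, HyperparametersMixin):
    """Toy stand-in for a predictor module, not the real chemprop class."""

    def __init__(self, hidden_dim: int = 300, criterion: nn.Module | None = None):
        super().__init__()
        # Keep the nn.Module argument out of the automatic hparams capture so
        # Lightning does not warn that it is "already saved during checkpointing"...
        self.save_hyperparameters(ignore=["criterion"])
        # ...then put it back into hparams explicitly, accepting the small
        # duplication with state_dict, so the model can later be rebuilt from
        # the saved dictionary.
        self.hparams["criterion"] = criterion

        self.criterion = criterion if criterion is not None else nn.MSELoss()
        self.ffn = nn.Linear(hidden_dim, 1)
```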

Background:

All of this is related to the question "How do we reload a previous model?". Closely related is the question "How do we save a model so that it can be reloaded?". There are two parts of a model that have to be saved and loaded: the architecture and the "numbers" (weights, biases, etc.).

MPNN is a lightning.pytorch.LightningModule which inherits from both lightning.pytorch.core.mixins.HyperparametersMixin and torch.nn.Module. HyperparametersMixin takes care of saving the model architecture (by saving the input arguments when creating the model) and Module takes care of saving the "numbers".

I'll discuss three representative attributes of MPNN that are each saved and loaded differently.

message_passing

architecture

HyperparametersMixin has a method save_hyperparameters which adds all the arguments passed to __init__ to self.hparams. In the case of message_passing, we already include it in the self.save_hyperparameters(ignore=...) list because message_passing is an object that contains all the weights, which are already saved in the state_dict. save_hyperparameters would just pickle the whole object, leading to the weights being saved in both hparams and state_dict.

We want to just save the instructions for how to make a new message_passing module with the same architecture as the saved version. It will be created with random weights, but then those will be overwritten with the correct weights when the state_dict is loaded. To do this _MessagePassingBase also inherits from HyperparametersMixin and saves its own hparams. This hparams dict is then used as the value for the message_passing key in MPNN's hparams.
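
A rough sketch of this nested-hparams pattern, using simplified stand-ins for the real chemprop classes (the names and arguments here are illustrative):

```python
from lightning.pytorch.core.mixins import HyperparametersMixin
from torch import nn


class TinyMessagePassing(nn.Module, HyperparametersMixin):
    def __init__(self, d_h: int = 300):
        super().__init__()
        self.save_hyperparameters()  # records {"d_h": ...}, not the weights
        self.W = nn.Linear(d_h, d_h)


class TinyMPNN(nn.Module, HyperparametersMixin):
    def __init__(self, message_passing: TinyMessagePassing, batch_norm: bool = False):
        super().__init__()
        # Ignore the module itself: its weights already end up in TinyMPNN's
        # state_dict once it is assigned as a submodule below...
        self.save_hyperparameters(ignore=["message_passing"])
        # ...so hparams only stores the instructions for rebuilding it.
        self.hparams["message_passing"] = message_passing.hparams
        self.message_passing = message_passing
        self.bn = nn.BatchNorm1d(message_passing.hparams["d_h"]) if batch_norm else nn.Identity()
```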

weights

Module has an overloaded __setattr__ which changes the behavior of setting attributes. When we do self.message_passing = message_passing, the message passing object gets registered as a submodule of MPNN (because it is also a Module), and the state_dict for MPNN will include the parameters of message_passing. (I think that is done here.) message_passing will also appear in the output of print(model).

To then load an MPNN, message_passing is first created with random weights from the hparams, and then the state_dict is loaded, which overwrites the message_passing weights with the correct ones.
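
Continuing the sketch above, the save/reload round trip might look roughly like this. The "hyper_parameters"/"state_dict" dictionary layout and the file name are assumptions for illustration; the real chemprop save_model/load_from_file code may differ:

```python
import torch

# Save the architecture instructions (hparams) and the numbers (state_dict).
model = TinyMPNN(TinyMessagePassing(d_h=256))
torch.save(
    {"hyper_parameters": dict(model.hparams), "state_dict": model.state_dict()},
    "tiny_mpnn.pt",
)

# Reload: rebuild message_passing with random weights from its saved hparams,
# then load_state_dict overwrites those weights with the saved ones.
ckpt = torch.load("tiny_mpnn.pt", weights_only=False)
hparams = ckpt["hyper_parameters"]
message_passing = TinyMessagePassing(**hparams["message_passing"])
reloaded = TinyMPNN(message_passing, batch_norm=hparams["batch_norm"])
reloaded.load_state_dict(ckpt["state_dict"])
```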

metrics

The metrics argument given to MPNN is not a Module. It is a list whose elements are Metrics, which are Modules. This means that self.metrics = metrics does not register metrics as a submodule of MPNN, so the numbers in the individual Metric state_dicts are not added to MPNN's state_dict. It also means that metrics does not appear in print(model), which seems reasonable to me. metrics is not included in the ignore list of save_hyperparameters, so the whole list is pickled and reloaded when saving and loading a model.
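
As a quick illustration of that registration behavior, using plain torch modules rather than chemprop classes:

```python
import torch
from torch import nn


class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        self.as_list = [nn.Linear(2, 2)]                        # plain list: not registered
        self.as_module_list = nn.ModuleList([nn.Linear(2, 2)])  # registered as a submodule


holder = Holder()
# Only the ModuleList's parameters appear; the plain list's do not.
print(list(holder.state_dict().keys()))
# ['as_module_list.0.weight', 'as_module_list.0.bias']
print(holder)  # the plain list is also absent from the module repr
```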

X_d_transform

We want X_d_transform to be a submodule of MPNN so that it appears in print(model), among other things. We do this by setting it as an attribute (with a check for None): self.X_d_transform = X_d_transform if X_d_transform is not None else nn.Identity(). If X_d_transform is not None, its state_dict (which includes the scale and mean) will be added to MPNN's state_dict.

The question is: do we include the whole X_d_transform object in the hparams of MPNN or not? I think we must, because load_from_file of MPNN doesn't know how to initialize a new X_d_transform the way it does with message_passing. But if we do include the whole object, the scale and mean of X_d_transform will be duplicated in both hparams and state_dict.

Our options are:

  1. Duplicate and suppress warnings. I favor this because the duplicated arrays are small.
  2. Make _ScaleTransformMixin more like _MessagePassingBase by having it inherit from HyperparametersMixin and save its input arguments in hparams (a rough sketch follows this list). This seems more complicated to me.
  3. Exclude scale and mean from the state_dict by not registering them as buffers. I remember it was a conscious decision to include them in the state_dict, so this doesn't seem like a good solution.
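
For reference, option 2 might look roughly like the sketch below; the class name and details are hypothetical, and this is not what the PR adopts:

```python
import torch
from lightning.pytorch.core.mixins import HyperparametersMixin
from torch import nn


class SketchScaleTransform(nn.Module, HyperparametersMixin):
    """Hypothetical scale transform that records its own hparams."""

    def __init__(self, mean: list[float], scale: list[float]):
        super().__init__()
        # The raw mean/scale lists land in hparams, so the parent model could
        # store just this small dict (as it does for message_passing) instead
        # of pickling the whole module.
        self.save_hyperparameters()
        self.register_buffer("mean", torch.tensor(mean, dtype=torch.float))
        self.register_buffer("scale", torch.tensor(scale, dtype=torch.float))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return (X - self.mean) / self.scale
```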

These same options and questions apply to the other inputs that get this warning from Lightning, including "V_d_transform", "graph_transform", "criterion", and "output_transform".

Other note

One potential issue with #832 that I changed: _FFNPredictorBase expects output_transform to be an UnscaleTransform or None, but output_transform was being saved to hparams after converting None to nn.Identity. If we ever add additional logic that checks whether output_transform was passed in as None, saving the converted value would break it.
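
A tiny illustration of the concern, with a hypothetical check that does not exist in chemprop today:

```python
from torch import nn

output_transform = None
# If None is replaced with nn.Identity before being stored in hparams...
stored = output_transform if output_transform is not None else nn.Identity()

# ...a later check for "was it passed as None?" can no longer tell the difference.
print(stored is None)                   # False
print(isinstance(stored, nn.Identity))  # True
```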

@KnathanM KnathanM requested a review from davidegraff June 2, 2024 03:49
@KnathanM KnathanM added this to the v2.0.x milestone Jun 4, 2024
@davidegraff (Contributor)

Perhaps a dumb question: but do we need to save it as a hyperparameter? If the attribute is registered as a module, it seems that the torch.load() mechanism will take care of everything for us because the underlying tensors in the *Scale classes are registered into the state dict.

@KnathanM (Member, Author)

> Perhaps a dumb question: but do we need to save it as a hyperparameter? If the attribute is registered as a module, it seems that the torch.load() mechanism will take care of everything for us because the underlying tensors in the *Scale classes are registered into the state dict.

I commented out the line self.hparams["X_d_transform"] = X_d_transform in models.model.py and, as you suggested, this code works:

```python
from chemprop import models, nn
import torch

t = nn.transforms.ScaleTransform([0, 1, 2, 3, 4], [0, 1, 2, 3, 4])
chemprop_model = models.MPNN(
    nn.BondMessagePassing(), nn.MeanAggregation(), nn.RegressionFFN(), X_d_transform=t
)

torch.save(chemprop_model, "mymodel.pt")
samemodel = torch.load("mymodel.pt")
```

I believe this works because the whole model object is saved and reloaded. Our save_model and load_model functions in models.utils.py don't do this, though: they save and load a dictionary containing the hparams and state_dict and rebuild the model, instead of reloading the whole model object. Rebuilding an MPNN requires that X_d_transform exists first. We can't rebuild X_d_transform because it doesn't have hparams, so we have to reload it by saving the whole transform object in the dictionary.

I don't remember why we opted for saving a dictionary of hparams and state_dict instead of saving the whole model. I also don't know what the implications of changing this are. But it could certainly be something worth discussing.

@davidegraff (Contributor)

> I don't remember why we opted for saving a dictionary of hparams and state_dict instead of saving the whole model. I also don't know what the implications of changing this are. But it could certainly be something worth discussing.

Saving an entire model via torch.save() is not recommended, so that's why we opted to separate learnable parameters from architecture in our serialization format. Though I'll agree that our decision still leaves something to be desired in ease-of-use and portability.

It seems that, upon inspection of the underlying *Transform sources, any initialized transform can be used to load in another's state dict. So perhaps the solution here is to:

  1. inspect whether the loaded state dictionary contains a key 'X_d_transform' (the same goes for the input graph transform inside the message_passing attribute)
  2. if so, build a Transform with dummy input tensors, because these will just be overwritten by the load_state_dict() method later on (rough sketch below).
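
A rough sketch of that approach follows. The buffer names ("mean"/"scale"), the "X_d_transform." key prefix, and the checkpoint layout are assumptions here; PR #915 contains the actual implementation:

```python
import torch
from chemprop import nn as cpnn

ckpt = torch.load("mymodel.pt", weights_only=False)  # hypothetical checkpoint dict
state_dict = ckpt["state_dict"]

if any(key.startswith("X_d_transform.") for key in state_dict):
    # Build the transform with dummy values of the right size; load_state_dict
    # will overwrite them with the saved mean and scale afterwards.
    n = state_dict["X_d_transform.mean"].numel()
    X_d_transform = cpnn.transforms.ScaleTransform([0.0] * n, [1.0] * n)
else:
    X_d_transform = None
```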

@KnathanM (Member, Author)

I followed your suggestion and opened it as PR #915. I am unsure which approach (double saving, or checking the state_dict) is better, but we can now compare the two.

@KnathanM (Member, Author)

Looks like double saving is preferred over David's suggestion here, so this PR is ready for review and merge. I also updated the notebooks that showed this warning so it no longer appears.

@shihchengli (Contributor) left a comment

LGTM

Comment on lines +75 to 78:

```python
self.hparams["X_d_transform"] = X_d_transform
self.hparams.update(
    {
        "message_passing": message_passing.hparams,
```

Contributor

Suggested change:

```diff
-self.hparams["X_d_transform"] = X_d_transform
-self.hparams.update(
-    {
-        "message_passing": message_passing.hparams,
+self.hparams.update(
+    {
+        "X_d_transform": X_d_transform,
+        "message_passing": message_passing.hparams,
```

@KnathanM (Member, Author)

Thank you for the suggestion. I also thought about including X_d_transform in the update, but decided it would be better to set it separately to match what is done in the message passing base with V_d_transform, for example. Having it separate also helps show that it doesn't have its own hparams the way message passing does. I'll go ahead and set this to merge because you approved it. If you want to discuss this suggested change further, you could open another PR with it.

Contributor

I see! How about not using hparams.update at all? I don't think it's necessary to separate these two parts here, as it doesn't provide any new context. If you think it's worth mentioning that the message passing has its own hparams, it would be clearer to add a comment in the code instead of using a different style to separate them.

@KnathanM KnathanM enabled auto-merge (squash) June 18, 2024 17:40
@KnathanM KnathanM merged commit dae0343 into chemprop:main Jun 18, 2024
13 checks passed
@KnathanM KnathanM deleted the stop-warning-double-save branch October 7, 2024 17:03