Explicitly double save scalers/criterion #898
Conversation
Perhaps a dumb question, but do we need to save it as a hyperparameter? If the attribute is registered as a module, it seems that the …

I commented out the line …
I believe this works because the whole model object is saved and reloaded. I don't remember why we opted for saving a dictionary of hparams and state_dict instead of saving the whole model. I also don't know what the implications of changing this are, but it could certainly be something worth discussing.
Saving an entire model via … It seems that, upon inspection of the underlying …
I followed your suggestion and opened it as PR #915. I am unsure which approach (double saving, or checking the state_dict) is better, but we can now compare the two.
Looks like double saving is preferred over David's suggestion here, so this PR is ready for review and merge. I also updated the notebooks that showed the warning to remove it.
LGTM
self.hparams["X_d_transform"] = X_d_transform | ||
self.hparams.update( | ||
{ | ||
"message_passing": message_passing.hparams, |
self.hparams["X_d_transform"] = X_d_transform | |
self.hparams.update( | |
{ | |
"message_passing": message_passing.hparams, | |
self.hparams.update( | |
{ | |
"X_d_transform": X_d_transform, | |
"message_passing": message_passing.hparams, |
Thank you for the suggestion. I also thought about including `X_d_transform` in the update, but decided it would be better to do it separately to match what is done in the message passing base with `V_d_transform`, for example. Having it separate also helps show that it doesn't have its own hparams like message passing does. I'll go ahead and set this to merge because you approved it. If you want to discuss this suggested change further, you could open another PR with it.
I see! How about not using `hparams.update` at all? I don't think it's necessary to separate these two parts here, as it doesn't provide any new context. If you think it's worth mentioning that the message passing has its own `hparams`, it would be clearer to add a comment in the code instead of using a different style to separate them.
Summary
Currently, some small tensors are saved twice in a chemprop model file, both in `hyper_parameters` and in `state_dict`. I don't think there is a clean way to avoid double saving them, and the tensors are small enough not to worry about, so the best thing to do is to continue double saving, suppress any warnings, and be explicit that this is what we are doing.
The warning can be suppressed by adding the small `nn.Module` arguments to `self.save_hyperparameters(ignore=...)` and then manually adding them to `hparams`. #832 started this for `criterion` and `output_transform` of `_FFNPredictorBase`. The warning looks something like: "Attribute '...' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['...'])`".
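For concreteness, here is a minimal sketch of that suppression pattern with a hypothetical `ExampleModel` and a `criterion` argument (not chemprop's actual classes), assuming lightning >= 2.0:

```python
import torch.nn as nn
from lightning import pytorch as pl


class ExampleModel(pl.LightningModule):
    """Hypothetical model illustrating the explicit double-save pattern."""

    def __init__(self, hidden_dim: int = 16, criterion: nn.Module | None = None):
        super().__init__()
        criterion = criterion if criterion is not None else nn.MSELoss()
        # Listing the nn.Module argument in `ignore` silences lightning's warning ...
        self.save_hyperparameters(ignore=["criterion"])
        # ... and adding it back by hand keeps checkpoint loading self-contained,
        # at the cost of the object being pickled into hyper_parameters while any
        # parameters/buffers it has also land in the state_dict.
        self.hparams["criterion"] = criterion
        self.criterion = criterion
        self.ffn = nn.Linear(hidden_dim, 1)


model = ExampleModel()
print(model.hparams.keys())  # includes 'hidden_dim' and 'criterion'
```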
"Background:
All of this is related to the question of "How do we reload a previous model?". Closely related is the question "How do we save a model so that it can be reloaded?". There are two parts of a model that have to be saved and loaded: the architecture and the "numbers" (weights, biases, etc.).

MPNN is a `lightning.pytorch.LightningModule`, which inherits from both `lightning.pytorch.core.mixins.HyperparametersMixin` and `torch.nn.Module`. `HyperparametersMixin` takes care of saving the model architecture (by saving the input arguments when creating the model) and `Module` takes care of saving the "numbers".
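That double inheritance can be checked directly; a small sketch, assuming lightning >= 2.0:

```python
import torch.nn as nn
from lightning import pytorch as pl
from lightning.pytorch.core.mixins import HyperparametersMixin

# LightningModule combines both roles: hparams handling comes from
# HyperparametersMixin, parameter/buffer handling comes from nn.Module.
print(issubclass(pl.LightningModule, HyperparametersMixin))  # True
print(issubclass(pl.LightningModule, nn.Module))             # True
```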
I'll discuss three representative attributes of `MPNN` that are each saved and loaded differently.

message_passing
architecture
`HyperparametersMixin` has a method `save_hyperparameters` which adds all the arguments passed to `__init__` to `self.hparams`. In the case of `message_passing`, we already include it in the `self.save_hyperparameters(ignore=...)` ignore list because `message_passing` is an object that contains all the weights that are already saved in the `state_dict`; `save_hyperparameters` would just pickle the whole object, leading to the weights being saved in both `hparams` and `state_dict`.

We want to save just the instructions for how to make a new `message_passing` module with the same architecture as the saved version. It will be created with random weights, but those will then be overwritten with the correct weights when the state_dict is loaded. To do this, `_MessagePassingBase` also inherits from `HyperparametersMixin` and saves its own `hparams`. This `hparams` dict is then used as the value for the `message_passing` key in MPNN's `hparams`.
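A toy version of that arrangement, with a hypothetical `ToyMessagePassing` standing in for `_MessagePassingBase`:

```python
import torch.nn as nn
from lightning.pytorch.core.mixins import HyperparametersMixin


class ToyMessagePassing(nn.Module, HyperparametersMixin):
    """Hypothetical stand-in: hparams hold the constructor arguments only."""

    def __init__(self, d_h: int = 300, depth: int = 3):
        super().__init__()
        self.save_hyperparameters()          # records d_h and depth, not weights
        self.W = nn.Linear(d_h, d_h)
        self.depth = depth


mp = ToyMessagePassing(d_h=64, depth=2)
print(dict(mp.hparams))       # {'d_h': 64, 'depth': 2}  -> the architecture
print(list(mp.state_dict()))  # ['W.weight', 'W.bias']   -> the "numbers"

# The parent model stores mp.hparams (not mp itself) under its
# "message_passing" key; on load it rebuilds the module and restores weights.
rebuilt = ToyMessagePassing(**mp.hparams)  # random weights ...
rebuilt.load_state_dict(mp.state_dict())   # ... replaced by the saved ones
```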
weights
`Module` has an overloaded `__setattr__` which affects the behavior of setting attributes. When we do `self.message_passing = message_passing`, the message passing object gets registered as a submodule of MPNN (because it is also a `Module`) and the state_dict for MPNN will include the parameters of `message_passing`. (I think that is done here.) `message_passing` will also appear in the output of `print(model)`.

To then load an MPNN, the `message_passing` will first be created with random weights using the `hparams`, and then the state_dict is loaded, which overwrites the `message_passing` weights with the correct ones.
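A small self-contained demonstration of that registration behavior (hypothetical `Parent` class):

```python
import torch.nn as nn


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Module.__setattr__ sees that this value is itself a Module and
        # registers it as a submodule of Parent.
        self.child = nn.Linear(4, 2)


p = Parent()
print(list(p.state_dict()))  # ['child.weight', 'child.bias']
print(p)                     # the repr shows (child): Linear(in_features=4, ...)

# Loading mirrors the description above: build with random weights, then
# overwrite them from the saved state_dict.
fresh = Parent()
fresh.load_state_dict(p.state_dict())
```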
metrics
The `metrics` argument given to `MPNN` is not a `Module`. It is a `list` whose elements are `Metric`s, which are `Module`s. This means that `self.metrics = metrics` does not register metrics as a submodule of MPNN, so the numbers in the individual `Metric` `state_dict`s are not added to `MPNN`'s `state_dict`. This also means that `metrics` does not appear in `print(model)`, which seems reasonable to me. `metrics` is not included in the ignore list of `save_hyperparameters`, so the whole thing is pickled and reloaded when saving and loading a model.
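A sketch of the difference a plain `list` makes; `nn.BatchNorm1d` is used here as a stand-in for a `Metric`, since both are `Module`s with internal state:

```python
import torch.nn as nn


class ParentWithList(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list is invisible to nn.Module.__setattr__, so its
        # elements are not registered as submodules of this model.
        self.metrics = [nn.BatchNorm1d(4)]  # has parameters and buffers
        self.layer = nn.Linear(4, 2)


p = ParentWithList()
print(list(p.state_dict()))  # only ['layer.weight', 'layer.bias']
print(p)                     # 'metrics' does not appear in the repr
# (Wrapping the list in nn.ModuleList would register the elements instead.)
```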
X_d_transform
We want `X_d_transform` to be a submodule of `MPNN` so that it appears in `print(model)`, among other things. We do this by setting it as an attribute (with a check for `None`): `self.X_d_transform = X_d_transform if X_d_transform is not None else nn.Identity()`. If `X_d_transform` is not `None`, its `state_dict` (which includes the scale and mean) will be added to `MPNN`'s `state_dict`.

The question is: do we include the whole `X_d_transform` object in the `hparams` of `MPNN` or not? I think we must, because `load_from_file` of `MPNN` doesn't know to initialize a new `X_d_transform` like it did with `message_passing`. But if we do include the whole object, the scale and mean of `X_d_transform` will be duplicated in both `hparams` and `state_dict`.
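A sketch of the resulting duplication, with hypothetical `ToyScaleTransform` and `ToyMPNN` classes standing in for the real ones:

```python
import torch
import torch.nn as nn
from lightning import pytorch as pl


class ToyScaleTransform(nn.Module):
    """Hypothetical transform: mean and scale are buffers, so they land in state_dict."""

    def __init__(self, mean, scale):
        super().__init__()
        self.register_buffer("mean", torch.as_tensor(mean))
        self.register_buffer("scale", torch.as_tensor(scale))

    def forward(self, x):
        return (x - self.mean) / self.scale


class ToyMPNN(pl.LightningModule):
    def __init__(self, hidden_dim: int = 8, X_d_transform: nn.Module | None = None):
        super().__init__()
        self.save_hyperparameters(ignore=["X_d_transform"])
        # Explicit double save: the whole object goes into hparams so that
        # loading from a checkpoint can recreate it without extra wiring ...
        self.hparams["X_d_transform"] = X_d_transform
        # ... and registering it as a submodule also puts its buffers into
        # ToyMPNN's own state_dict.
        self.X_d_transform = X_d_transform if X_d_transform is not None else nn.Identity()
        self.ffn = nn.Linear(hidden_dim, 1)


model = ToyMPNN(X_d_transform=ToyScaleTransform([0.0], [2.0]))
print([k for k in model.state_dict() if k.startswith("X_d_transform")])
# ['X_d_transform.mean', 'X_d_transform.scale']
print(model.hparams["X_d_transform"])  # the same object; it gets pickled into hyper_parameters on save
```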
Our options to avoid the duplication are:
1. Make `_ScaleTransformMixin` more like `_MessagePassingBase` by having it inherit from `HyperparametersMixin` and save its input arguments in `hparams`. This seems more complicated to me.
2. Remove `scale` and `mean` from the `state_dict` by not registering them as buffers (see the sketch below). I remember that it was a conscious decision to include them in the `state_dict`, so this doesn't seem like a good solution.

These same options and questions apply to the other inputs that get this warning from lightning, including `V_d_transform`, `graph_transform`, `criterion`, and `output_transform`.
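For the second option, the difference comes down to `register_buffer` versus a plain tensor attribute; an illustrative sketch:

```python
import torch
import torch.nn as nn


class WithBuffers(nn.Module):
    def __init__(self, mean, scale):
        super().__init__()
        self.register_buffer("mean", torch.as_tensor(mean))    # saved in state_dict
        self.register_buffer("scale", torch.as_tensor(scale))


class WithoutBuffers(nn.Module):
    def __init__(self, mean, scale):
        super().__init__()
        self.mean = torch.as_tensor(mean)    # plain attributes: not in state_dict,
        self.scale = torch.as_tensor(scale)  # and not moved by .to(device) either


print(list(WithBuffers([0.0], [1.0]).state_dict()))     # ['mean', 'scale']
print(list(WithoutBuffers([0.0], [1.0]).state_dict()))  # []
```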
Other note
One potential issue with #832 that I changed: `_FFNPredictorBase` expects `output_transform` to be an `UnscaleTransform` or `None`, but `output_transform` is saved after converting `None`s to `nn.Identity`. If we ever added additional logic to check whether `output_transform` was input as `None`, this would break that.
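As an illustration of that concern (a hypothetical check, not code that exists in chemprop today):

```python
import torch.nn as nn

output_transform = None  # what the caller passed

# _FFNPredictorBase-style conversion: None becomes nn.Identity() before saving
output_transform = output_transform if output_transform is not None else nn.Identity()

# Any later logic written against the original None can no longer see it:
if output_transform is None:  # never true after the conversion above
    print("no output transform was given")
```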