
Model Saving & Reloading #84

Closed
mariusmerkle opened this issue Feb 12, 2021 · 3 comments
@mariusmerkle

Hi,

I am studying how transfer learning can enhance the training of physics-informed neural networks. NeuroDiffEq sparked my interest, and I was wondering whether it is possible to

  1. save a trained model, i.e. the parameters of the network and its architecture

  2. reload the saved model and continue training from that non-random state.

@shuheng-liu
Member

shuheng-liu commented Feb 12, 2021

Hi Marius. Yes, it's easy to do that as long as you know the basic usage of PyTorch. For saving and loading models, here's a useful link. In the context of neurodiffeq, you want to perform the following steps:

  1. Make sure you have the model class `MyNetwork` saved in some file `model.py`. Note that `MyNetwork` must be a subclass of `torch.nn.Module`.
  2. Create one or more (depending on the number of functions you are solving for) `MyNetwork` instances:
     ```python
     my_nets = [MyNetwork(...), MyNetwork(...), ...]
     ```
  3. Instantiate your solver and pass in your model(s):
     ```python
     solver = Solver1D(
         ...
         nets=my_nets,
     )
     ```
  4. Do the training and get your networks. Currently, neurodiffeq doesn't make a copy of the networks passed to it, so `solver.nets` is the same object as the `my_nets` created earlier:
     ```python
     solver.fit(max_epochs=xxx, ...)
     my_nets = solver.nets  # you can skip this step if you still have access to `my_nets` created earlier
     ```
  5. Save your models:
     ```python
     torch.save({f'net_{i}': net.state_dict() for i, net in enumerate(my_nets)}, YOUR_MODEL_PATH)
     ```
  6. In another script, instantiate networks with exactly the same architecture and load the saved weights:
     ```python
     loaded_nets = [MyNetwork(...), MyNetwork(...), ...]
     checkpoint = torch.load(YOUR_MODEL_PATH)
     for i, net in enumerate(loaded_nets):
         net.load_state_dict(checkpoint[f'net_{i}'])
     ```
  7. Redo steps 3–4, but pass `loaded_nets` instead of `my_nets`.
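
The steps above can be sketched end-to-end in plain PyTorch. Here `MyNetwork` is a stand-in fully-connected network and `MODEL_PATH` is a hypothetical file name; in practice you would train the nets through your `Solver1D` between creation and saving:

```python
import torch
import torch.nn as nn

# Stand-in for the user's MyNetwork; any nn.Module subclass works.
class MyNetwork(nn.Module):
    def __init__(self, hidden=32):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(1, hidden), nn.Tanh(), nn.Linear(hidden, 1)
        )

    def forward(self, x):
        return self.layers(x)

MODEL_PATH = "nets.pt"  # hypothetical path

# Save: one state_dict per network, keyed by index
my_nets = [MyNetwork(), MyNetwork()]
torch.save(
    {f"net_{i}": net.state_dict() for i, net in enumerate(my_nets)},
    MODEL_PATH,
)

# Load: rebuild the same architectures, then restore the weights
loaded_nets = [MyNetwork(), MyNetwork()]
checkpoint = torch.load(MODEL_PATH)
for i, net in enumerate(loaded_nets):
    net.load_state_dict(checkpoint[f"net_{i}"])
```

After loading, passing `loaded_nets` to a fresh solver via `nets=loaded_nets` resumes training from the saved, non-random state.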

@mariusmerkle
Author

That sounds great! And it is possible to use both the Adam optimiser and L-BFGS/L-BFGS-B, right?

@shuheng-liu
Member

Most optimizers are currently supported, except LBFGS, which is a little tricky (see #83). Luckily, a solution was proposed just now, though we still need to run the tests.

I'm not familiar with L-BFGS-B, but it appears that this optimizer has not been implemented in PyTorch (see here). So currently, you can't use L-BFGS-B without implementing it yourself.
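
Setting neurodiffeq aside, a minimal plain-PyTorch sketch shows what makes `torch.optim.LBFGS` different from Adam: its `step()` requires a closure, because the optimizer may re-evaluate the loss several times per step (the toy regression target here is purely illustrative):

```python
import torch
import torch.nn as nn

# Toy problem: fit y = 2x + 1 with a single linear layer
net = nn.Linear(1, 1)
x = torch.linspace(0, 1, 16).reshape(-1, 1)
y = 2 * x + 1

opt = torch.optim.LBFGS(net.parameters(), lr=0.1, max_iter=20)

def closure():
    # LBFGS calls this closure (possibly multiple times per step)
    # to re-evaluate the loss during its line search.
    opt.zero_grad()
    loss = ((net(x) - y) ** 2).mean()
    loss.backward()
    return loss

for _ in range(10):
    loss = opt.step(closure)  # unlike Adam, step() takes the closure
```

Adam's `step()` takes no such closure, which is why wrapping both optimizer families behind one training loop takes extra care.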
