Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transfer learning example #1564

Merged
merged 7 commits into from
May 2, 2020

Conversation

jbschiratti
Copy link
Contributor

What does this PR do?

Addresses issue #514. Following up on this discussion, this PR proposes to add a self-contained example which shows how a pretrained network (such as ResNet50) can be fine tuned within a LightningModule.

PR review

Anyone in the community is free to review the PR 🙂

@mergify mergify bot requested a review from a team April 22, 2020 17:15
@Borda Borda added the example label Apr 22, 2020
Copy link
Member

@Borda Borda left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pls add argparse to be able to run with diff params
pls use Napoleon docs style https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html

@mergify mergify bot requested review from a team April 22, 2020 21:04
@hcjghr
Copy link

hcjghr commented Apr 23, 2020

Hi @jbschiratti

Thanks for such a nice example. I'm rather new in the field so hopefully my question will not be too off base. I was just going over the code and I noticed that in your example the BatchNorm layers will always remain in training mode (as train_bn is always set to self.hparams.train_bn when calling the freeze function). That is even when performing validation or evaluation. I understand the code potentially allows for the BN layers to be set to eval (if the train_bn=False) but I am just wondering if there is a specific reason why do you always leave BN in train mode? Why not have them train in the training stage and eval in the validation/testing?

Just to clerify I'm not arguing it should be different, I'm just asking for the reasoning behind it.

@Borda
Copy link
Member

Borda commented Apr 23, 2020

Thanks for such a nice example. I'm rather new in the field so hopefully my question will not be too off base. I was just going over the code and I noticed ...

Thank you for your interest and help with this addition, may you please use review tab of this PR to write your comments directly to the sections you are talking about... it will make the discussion clearer and a bit more concrete :]

@jbschiratti
Copy link
Contributor Author

@hcjghr Thank you for spotting this. It's was a bug! The way I see it, in the evaluation loop when model.eval() is called, the BatchNorm layers (as well as the other layers) should be in eval mode (training=False). This is what I had in mind and it is now fixed.

@Borda I fixed the docstrings and added argparse. Thank you for the comments!

@Borda
Copy link
Member

Borda commented Apr 23, 2020

pls add note to changelog 🐰

@Borda Borda added the waiting on author Waiting on user action, correction, or update label Apr 25, 2020
@mergify
Copy link
Contributor

mergify bot commented Apr 26, 2020

This pull request is now in conflict... :(

@codecov
Copy link

codecov bot commented Apr 27, 2020

Codecov Report

Merging #1564 into master will not change coverage.
The diff coverage is n/a.

@@          Coverage Diff           @@
##           master   #1564   +/-   ##
======================================
  Coverage      88%     88%           
======================================
  Files          69      69           
  Lines        4133    4133           
======================================
  Hits         3656    3656           
  Misses        477     477           

@mergify mergify bot requested a review from a team April 28, 2020 23:13
@jbschiratti
Copy link
Contributor Author

Thanks @awaelchli for the review and the comments!

@mergify mergify bot requested a review from a team April 29, 2020 11:38
@awaelchli
Copy link
Contributor

awaelchli commented Apr 29, 2020

I noticed that the downloaded dataset is not ignored in version control. Could we maybe redirect it to a subfolder datasets and add a .gitignore in domain templates folder?

@jbschiratti
Copy link
Contributor Author

This is strange because the context manager

with TemporaryDirectory(dir=hparams.root_data_path) as tmp_dir:
    ...

should delete the temporary folder in which the data is downloaded.

@awaelchli
Copy link
Contributor

ah ok, so it is also supposed to do that on keyboard interrupt? Maybe it's because I'm on Windows currently.

@jbschiratti
Copy link
Contributor Author

I tried to stop the script with CTRL+C during the 1st epoch and the temporary folder was deleted (on Linux). But I cannot guarantee this always works.

@Borda Borda requested a review from awaelchli April 29, 2020 11:56
@awaelchli
Copy link
Contributor

Can now also confirm it works fine on Linux, so it's just a Windows thing, so I guess we can keep it like that.

Copy link
Contributor

@awaelchli awaelchli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, minimal and clean. I like it very much.

@mergify mergify bot requested a review from a team April 29, 2020 13:25
def loss(self, labels, logits):
return self.loss_func(input=logits, target=labels)

def train(self, mode=True):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the mode for? could it be more descriptive?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See https://github.com/pytorch/pytorch/blob/d37a4861b8a5eed3d9a1340484d1efb0f48aa59e/torch/nn/modules/module.py#L1067. This line overrides the train method of the Pytorch module. I will add a docstring specifying what mode does.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are right, we may rename it...
doe you have suggestion about a better name? @PyTorchLightning/core-contributors

Copy link
Contributor Author

@jbschiratti jbschiratti Apr 29, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure we can rename it. In the evaluation loop (L330), model.train() is called and here, model refers (if I am not mistaken) to the LightningModule. We want to override this train method to ensure that, at the end of this evaluation loop when model.train() is called, some parameters (in specific layers) remain frozen (that is, with requires_grad=False) if needed.

@staticmethod
def add_model_specific_args(parent_parser):
parser = argparse.ArgumentParser(parents=[parent_parser])
parser.add_argument('--backbone',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use add_argparse_args so we limit duplication and add just the new/needed for a model

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by "limit code duplication", you want me to remove this line?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean remove lines which are generated from Trainer arguments... does it make sense?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to a temporary directory.
"""

with TemporaryDirectory(dir=hparams.root_data_path) as tmp_dir:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we want to keep the output folder

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The folder in which the data was downloaded is deleted after the experiment. If you think we should leave the data untouched after the example has run, I can make another PR to fix this :-)

@mergify mergify bot requested a review from a team April 29, 2020 15:28
@williamFalcon
Copy link
Contributor

@jbschiratti this is super cool.
Why don't we move this to https://github.com/PyTorchLightning/pytorch-lightning-bolts?

@Borda
Copy link
Member

Borda commented Apr 30, 2020

I would keep it here in examples....

@mergify
Copy link
Contributor

mergify bot commented May 1, 2020

This pull request is now in conflict... :(

@Borda Borda force-pushed the fine_tuning_example branch from 7a1e5e3 to 493296e Compare May 1, 2020 19:05
@williamFalcon williamFalcon merged commit fafe5d6 into Lightning-AI:master May 2, 2020
@jbschiratti
Copy link
Contributor Author

Thank you @williamFalcon 👍

@Borda Borda added this to the 0.7.6 milestone May 4, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
example waiting on author Waiting on user action, correction, or update
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants