Skip to content

Commit

Permalink
fix loading model with kwargs (#2387)
Browse files Browse the repository at this point in the history
* test

* fix

* fix
  • Loading branch information
Borda authored Jun 27, 2020
1 parent e82d9cd commit 51711c2
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Fixes # (issue)

- [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
- [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/.github/CONTRIBUTING.md), Pull Request section?
- [ ] Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you create a separate PR for every change.
- [ ] Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change.
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?
- [ ] Did you verify new and existing tests pass locally with your changes?
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed lost compatibility with custom datatypes implementing `.to` ([#2335](https://github.com/PyTorchLightning/pytorch-lightning/pull/2335))

- Fixed loading model with kwargs ([#2387](https://github.com/PyTorchLightning/pytorch-lightning/pull/2387))

## [0.8.1] - 2020-06-19

### Fixed
Expand Down
20 changes: 12 additions & 8 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def load_from_checkpoint(
return model

@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
# pass in the values we saved automatically
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
model_args = {}
Expand All @@ -184,19 +184,23 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)

args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
init_args_name = inspect.signature(cls).parameters.keys()
cls_spec = inspect.getfullargspec(cls.__init__)
kwargs_identifier = cls_spec.varkw
cls_init_args_name = inspect.signature(cls).parameters.keys()

if args_name == 'kwargs':
cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
kwargs.update(**cls_kwargs)
# in case the class cannot take any extra argument filter only the possible
if not kwargs_identifier:
model_args = {k: v for k, v in model_args.items() if k in cls_init_args_name}
cls_kwargs.update(**model_args)
elif args_name:
if args_name in init_args_name:
kwargs.update({args_name: model_args})
if args_name in cls_init_args_name:
cls_kwargs.update({args_name: model_args})
else:
args = (model_args, ) + args
cls_args = (model_args,) + cls_args

# load the state_dict on the model automatically
model = cls(*args, **kwargs)
model = cls(*cls_args, **cls_kwargs)
model.load_state_dict(checkpoint['state_dict'])

# give model a chance to load something
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
checkpoint = {
'epoch': self.current_epoch + 1,
'global_step': self.global_step + 1,
'pytorch-ligthning_version': pytorch_lightning.__version__,
'pytorch-lightning_version': pytorch_lightning.__version__,
}

if not weights_only:
Expand Down
26 changes: 13 additions & 13 deletions tests/base/model_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,19 @@ class EvalModelTemplate(
>>> model = EvalModelTemplate()
"""

def __init__(self,
*args,
drop_prob: float = 0.2,
batch_size: int = 32,
in_features: int = 28 * 28,
learning_rate: float = 0.001 * 8,
optimizer_name: str = 'adam',
data_root: str = PATH_DATASETS,
out_features: int = 10,
hidden_dim: int = 1000,
b1: float = 0.5,
b2: float = 0.999,
**kwargs) -> object:
def __init__(
self,
drop_prob: float = 0.2,
batch_size: int = 32,
in_features: int = 28 * 28,
learning_rate: float = 0.001 * 8,
optimizer_name: str = 'adam',
data_root: str = PATH_DATASETS,
out_features: int = 10,
hidden_dim: int = 1000,
b1: float = 0.5,
b2: float = 0.999
):
# init superclass
super().__init__()
self.save_hyperparameters()
Expand Down
19 changes: 9 additions & 10 deletions tests/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,16 +275,15 @@ def test_collect_init_arguments(tmpdir, cls):
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['batch_size'] == 179

# verify that model loads correctly
# TODO: uncomment and get it pass
# model = cls.load_from_checkpoint(raw_checkpoint_path)
# assert model.hparams.batch_size == 179
#
# if isinstance(model, AggSubClassEvalModel):
# assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss)
#
# if isinstance(model, DictConfSubClassEvalModel):
# assert isinstance(model.hparams.dict_conf, Container)
# assert model.hparams.dict_conf == 'anything'
model = cls.load_from_checkpoint(raw_checkpoint_path)
assert model.hparams.batch_size == 179

if isinstance(model, AggSubClassEvalModel):
assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss)

if isinstance(model, DictConfSubClassEvalModel):
assert isinstance(model.hparams.dict_conf, Container)
assert model.hparams.dict_conf['my_param'] == 'anything'

# verify that we can overwrite whatever we want
model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
Expand Down
5 changes: 2 additions & 3 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,8 @@ def test_end(self, outputs):


def test_tbd_remove_in_v1_0_0_model_hooks():
hparams = EvalModelTemplate.get_default_hparams()

model = ModelVer0_6(hparams)
model = ModelVer0_6()

with pytest.deprecated_call(match='v1.0'):
trainer = Trainer(logger=False)
Expand All @@ -143,7 +142,7 @@ def test_tbd_remove_in_v1_0_0_model_hooks():
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
assert result == {'val_loss': torch.tensor(0.6)}

model = ModelVer0_7(hparams)
model = ModelVer0_7()

with pytest.deprecated_call(match='v1.0'):
trainer = Trainer(logger=False)
Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test_suggestion_with_non_finite_values(tmpdir):
""" Test that non-finite values does not alter results """

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(hparams)
model = EvalModelTemplate(**hparams)

# logger file to get meta
trainer = Trainer(
Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def configure_optimizers(self):
return config

hparams = EvalModelTemplate.get_default_hparams()
model = CurrentModel(hparams)
model = CurrentModel(**hparams)

# fit model
trainer = Trainer(
Expand Down

0 comments on commit 51711c2

Please sign in to comment.