Config for lightning checkpoints (#75)
Summary: Pull Request resolved: #75

Test Plan: Imported from OSS

Reviewed By: ankitade

Differential Revision: D37008469

Pulled By: IvanKobzarev

fbshipit-source-id: 28d9bbefc899124302c391b8edf82b1bcca521ca
IvanKobzarev authored and facebook-github-bot committed Jun 10, 2022
Commit f028280 (1 parent: 63427b6)
Showing 5 changed files with 43 additions and 6 deletions.
examples/flava/configs/finetuning/qnli.yaml (8 additions, 0 deletions)
```diff
@@ -8,6 +8,14 @@ training:
   progress_bar_refresh_rate: 50
   val_check_interval: 1000
   num_sanity_val_steps: 0
+  lightning_checkpoint:
+    dirpath: "."
+    filename: flava-{epoch:02d}-{step}
+    save_last: true
+    every_n_train_steps: 1000
+    save_on_train_epoch_end: true
+    verbose: true
+  lightning_load_from_checkpoint: null
   seed: -1
   batch_size: 32
   num_workers: 4
```
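
The lightning_checkpoint mapping is passed verbatim as keyword arguments to Lightning's ModelCheckpoint callback (see the train.py change below), so its keys must match that callback's signature. A minimal sketch of the equivalent direct construction, using the values from this config:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Equivalent to the YAML above; each key is a ModelCheckpoint kwarg.
checkpoint_callback = ModelCheckpoint(
    dirpath=".",                          # directory checkpoints are written to
    filename="flava-{epoch:02d}-{step}",  # expands to e.g. flava-epoch=01-step=2000.ckpt
    save_last=True,                       # also maintain a rolling last.ckpt
    every_n_train_steps=1000,             # save every 1000 training steps
    save_on_train_epoch_end=True,         # additionally save at each epoch end
    verbose=True,
)
```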
examples/flava/configs/finetuning/rendered_sst2.yaml (8 additions, 0 deletions)
```diff
@@ -8,6 +8,14 @@ training:
   progress_bar_refresh_rate: 50
   val_check_interval: 100
   num_sanity_val_steps: 0
+  lightning_checkpoint:
+    dirpath: "."
+    filename: flava-{epoch:02d}-{step}
+    save_last: true
+    every_n_train_steps: 1000
+    save_on_train_epoch_end: true
+    verbose: true
+  lightning_load_from_checkpoint: null
   seed: -1
   batch_size: 32
   num_workers: 4
```
examples/flava/configs/pretraining/debug.yaml (8 additions, 0 deletions)
```diff
@@ -7,6 +7,14 @@ training:
   progress_bar_refresh_rate: 50
   val_check_interval: 10000
   num_sanity_val_steps: 0
+  lightning_checkpoint:
+    dirpath: "."
+    filename: flava-{epoch:02d}-{step}
+    save_last: true
+    every_n_train_steps: 1000
+    save_on_train_epoch_end: true
+    verbose: true
+  lightning_load_from_checkpoint: null
   seed: -1
   batch_size: 8
   num_workers: 4
```
examples/flava/definitions.py (2 additions, 0 deletions)
```diff
@@ -66,6 +66,8 @@ class TrainingDatasetsInfo:
 class TrainingArguments:
     # Any lightning args to be pushed here
     lightning: Dict[str, Any] = field(default=dict)
+    lightning_checkpoint: Optional[Dict[str, Any]] = None
+    lightning_load_from_checkpoint: Optional[str] = None
     seed: int = -1
     batch_size: int = 8
     num_workers: int = 4
```
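
Both new TrainingArguments fields are optional and default to None, so existing configs that omit them keep validating and running unchanged. A hedged sketch of that behaviour under OmegaConf structured configs (the repo's build_config helper, not shown here, is assumed to perform a merge along these lines):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

from omegaconf import OmegaConf


@dataclass
class TrainingArguments:
    lightning: Dict[str, Any] = field(default_factory=dict)
    lightning_checkpoint: Optional[Dict[str, Any]] = None
    lightning_load_from_checkpoint: Optional[str] = None
    seed: int = -1


schema = OmegaConf.structured(TrainingArguments)
# A YAML config that omits the new keys still merges cleanly; both fields
# stay None, so train.py skips the checkpoint callback and the resume path.
merged = OmegaConf.merge(schema, OmegaConf.create({"seed": 42}))
assert merged.lightning_checkpoint is None
assert merged.lightning_load_from_checkpoint is None
```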
examples/flava/train.py (17 additions, 6 deletions)
```diff
@@ -10,7 +10,7 @@
 from model import FLAVAPreTrainingLightningModule
 from omegaconf import OmegaConf
 from pytorch_lightning import seed_everything, Trainer
-from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 from utils import build_config, build_datamodule_kwargs
 
 
@@ -53,15 +53,26 @@ def main():
         **config.model,
     )
 
+    callbacks = [
+        LearningRateMonitor(logging_interval="step"),
+        MultimodalEvalCallback(imagenet_datamodule=imagenet_datamodule),
+    ]
+
+    if config.training.lightning_checkpoint is not None:
+        callbacks.append(
+            ModelCheckpoint(
+                **OmegaConf.to_container(config.training.lightning_checkpoint)
+            )
+        )
+
     trainer = Trainer(
         **OmegaConf.to_container(config.training.lightning),
-        callbacks=[
-            LearningRateMonitor(logging_interval="step"),
-            MultimodalEvalCallback(imagenet_datamodule=imagenet_datamodule),
-        ],
+        callbacks=callbacks,
         strategy="ddp",
     )
-    trainer.fit(model, datamodule=datamodule)
+    ckpt_path = config.training.lightning_load_from_checkpoint
+
+    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
     trainer.validate(model, datamodule=datamodule)
 
 
```
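
The net effect: checkpointing is opt-in via config, and resuming reuses Lightning's ckpt_path argument to Trainer.fit, where None (the default when lightning_load_from_checkpoint is null) means start from scratch, exactly the previous behaviour. A self-contained sketch of those semantics with a toy module (the module, data, and paths are stand-ins, not the repo's code):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class ToyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).pow(2).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


loader = DataLoader(TensorDataset(torch.randn(32, 4)), batch_size=8)
trainer = Trainer(max_epochs=1)

# ckpt_path=None trains from scratch, which is the old behaviour and what
# the default lightning_load_from_checkpoint: null produces.
trainer.fit(ToyModule(), loader, ckpt_path=None)

# Resuming from a saved checkpoint restores model weights, optimizer state,
# and step counters before continuing (the path shown is illustrative):
# trainer.fit(ToyModule(), loader, ckpt_path="checkpoints/last.ckpt")
```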
