read training strategy from config (#66)
Summary:
* Read training strategy from config file
* Add `strategy: ddp` to all existing config files
* The strategy can also be passed from the command line, e.g. `python train.py config=configs/pretraining/debug.yaml training.lightning.strategy=ddp`
* Priority: command-line argument > config file (see the sketch below)
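
A minimal sketch of the priority rule described above, assuming the config and CLI overrides are merged with OmegaConf (the values and dict layout here are illustrative, not the actual `train.py` code):

```python
# Sketch: CLI override takes priority over the value in the config file.
from omegaconf import OmegaConf

# Pretend this was loaded from configs/pretraining/debug.yaml
base = OmegaConf.create({"training": {"lightning": {"strategy": "ddp"}}})

# Dotlist-style overrides as they would arrive from the terminal, e.g.
#   python train.py ... training.lightning.strategy=deepspeed
cli = OmegaConf.from_dotlist(["training.lightning.strategy=deepspeed"])

# Later sources win in OmegaConf.merge, so the CLI argument overrides the file.
config = OmegaConf.merge(base, cli)
print(config.training.lightning.strategy)  # -> deepspeed
```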

Pull Request resolved: #66

Test Plan:
Logged `self.trainer.strategy` in `FLAVAPreTrainingLightningModule`; see the sketch after the two tests below.

test 1
* command: `python train.py config=configs/pretraining/debug.yaml`
* config: `strategy: ddp`
* result: `DDPStrategy`

test 2
* command: `python train.py config=configs/pretraining/debug.yaml training.lightning.strategy=deepspeed`
* config: `strategy: ddp`
* result: `DeepSpeedStrategy`
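
A rough sketch of the logging check used in the tests above; only `self.trainer.strategy` and the module name come from this test plan, while the hook and logger choices are assumptions:

```python
# Hypothetical sketch: log which strategy Lightning resolved at train start.
import logging

from pytorch_lightning import LightningModule

logger = logging.getLogger(__name__)


class FLAVAPreTrainingLightningModule(LightningModule):
    def on_train_start(self) -> None:
        # Prints e.g. DDPStrategy or DeepSpeedStrategy, depending on config/CLI
        logger.info("strategy: %s", type(self.trainer.strategy).__name__)
```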

Reviewed By: ebsmothers

Differential Revision: D37124420

Pulled By: katrina433

fbshipit-source-id: a02cafba7ad784963a0c27ad543256fd5daaa417
katrina433 authored and facebook-github-bot committed Jun 13, 2022
1 parent 27a0d2e commit 3beffd9
Showing 4 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions examples/flava/configs/finetuning/qnli.yaml
@@ -8,6 +8,7 @@ training:
progress_bar_refresh_rate: 50
val_check_interval: 1000
num_sanity_val_steps: 0
+ strategy: ddp
lightning_checkpoint:
dirpath: "."
filename: flava-{epoch:02d}-{step}
1 change: 1 addition & 0 deletions examples/flava/configs/finetuning/rendered_sst2.yaml
@@ -8,6 +8,7 @@ training:
progress_bar_refresh_rate: 50
val_check_interval: 100
num_sanity_val_steps: 0
+ strategy: ddp
lightning_checkpoint:
dirpath: "."
filename: flava-{epoch:02d}-{step}
1 change: 1 addition & 0 deletions examples/flava/configs/pretraining/debug.yaml
@@ -7,6 +7,7 @@ training:
progress_bar_refresh_rate: 50
val_check_interval: 10000
num_sanity_val_steps: 0
+ strategy: ddp
lightning_checkpoint:
dirpath: "."
filename: flava-{epoch:02d}-{step}
1 change: 0 additions & 1 deletion examples/flava/train.py
@@ -68,7 +68,6 @@ def main():
trainer = Trainer(
**OmegaConf.to_container(config.training.lightning),
callbacks=callbacks,
strategy="ddp",
)
ckpt_path = config.training.lightning_load_from_checkpoint

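
For reference, a minimal sketch of the Trainer construction after this change, with the strategy flowing in from the config instead of being hard-coded (the config path and the empty callbacks list are placeholders, not the exact repo code):

```python
# Sketch: `strategy` now arrives via the config dict rather than a hard-coded kwarg.
from omegaconf import OmegaConf
from pytorch_lightning import Trainer

config = OmegaConf.load("examples/flava/configs/pretraining/debug.yaml")  # assumed path

trainer = Trainer(
    **OmegaConf.to_container(config.training.lightning),  # includes strategy: ddp
    callbacks=[],  # real callbacks elided for brevity
)
```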
