Skip to content

Commit

Permalink
Add ckpt loading and accuracy metric to finetuning (#119)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #119

- Accuracy metric for finetuning
- Add checkpoint saving and best ckpt loading based on val accuracy
- Load pretrained ckpt by default in classification model
- make num gpus 1 in qnli.yaml

Test plan
python -m flava.finetune config=flava/configs/finetuning/qnli.yaml
(val acc : 0.8651)

Loaded model weights from checkpoint at /data/home/deankita/torchmultimodal/examples/flava-epoch=03-step=10000.ckpt
/data/home/deankita/miniconda/envs/flava/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:330: PossibleUserWarning: Using `DistributedSampler` with the dataloaders. During `trainer.validate()`, it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates some samples to make sure all devices have same batch size in case of uneven inputs.
  rank_zero_warn(
Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 171/171 [00:54<00:00,  3.15it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃          Validate metric           ┃            DataLoader 0            ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ validation/accuracy/classification │         0.8651315569877625         │
│  validation/losses/classification  │         0.4168359339237213         │

Test Plan: Imported from OSS

Reviewed By: ebsmothers

Differential Revision: D37444938

Pulled By: ankitade

fbshipit-source-id: b49b3dadc409f0c2e7f6567a33190f9c9c2e90ef
  • Loading branch information
ankitade authored and facebook-github-bot committed Jul 6, 2022
1 parent a57e4a2 commit 855c9ed
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 17 deletions.
4 changes: 3 additions & 1 deletion examples/flava/configs/finetuning/qnli.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ training:
_target_: flava.definitions.TrainingArguments
lightning:
max_steps: 33112
gpus: -1
gpus: 1
progress_bar_refresh_rate: 50
val_check_interval: 1000
num_sanity_val_steps: 0
Expand All @@ -16,6 +16,8 @@ training:
every_n_train_steps: 1000
save_on_train_epoch_end: true
verbose: true
monitor: validation/accuracy/classification
mode: max
lightning_load_from_checkpoint: null
seed: -1
batch_size: 32
Expand Down
25 changes: 17 additions & 8 deletions examples/flava/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from flava.data.datamodules import VLDataModule
from flava.definitions import FLAVAArguments
from flava.model import FLAVAClassificationLightningModule
from flava.utils import build_config, build_datamodule_kwargs
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from utils import build_config, build_datamodule_kwargs
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

AVAIL_GPUS = 1
SEED = -1
Expand Down Expand Up @@ -55,14 +55,23 @@ def main():
**config.model,
)

callbacks = [
LearningRateMonitor(logging_interval="step"),
]

if config.training.lightning_checkpoint is not None:
callbacks.append(
ModelCheckpoint(
**OmegaConf.to_container(config.training.lightning_checkpoint)
)
)

trainer = Trainer(
**OmegaConf.to_container(config.training.lightning),
callbacks=[
LearningRateMonitor(logging_interval="step"),
],
**OmegaConf.to_container(config.training.lightning), callbacks=callbacks
)
trainer.fit(model, datamodule=datamodule)
trainer.validate(model, datamodule=datamodule)
ckpt_path = config.training.lightning_load_from_checkpoint
trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
trainer.validate(datamodule=datamodule)


if __name__ == "__main__":
Expand Down
28 changes: 23 additions & 5 deletions examples/flava/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import torch
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy
from torchmultimodal.models.flava.flava_model import (
flava_model_for_classification,
flava_model_for_pretraining,
Expand Down Expand Up @@ -139,18 +140,33 @@ def __init__(
self.warmup_steps = warmup_steps
self.max_steps = max_steps
self.adam_betas = adam_betas
self.metrics = Accuracy()

def training_step(self, batch, batch_idx):
output = self._step(batch, batch_idx)
output, accuracy = self._step(batch, batch_idx)
self.log("train/losses/classification", output.loss, prog_bar=True, logger=True)
self.log(
"train/accuracy/classification",
accuracy,
prog_bar=True,
logger=True,
sync_dist=True,
)

return output.loss

def validation_step(self, batch, batch_idx):
output = self._step(batch, batch_idx)
output, accuracy = self._step(batch, batch_idx)
self.log(
"validation/losses/classification", output.loss, prog_bar=True, logger=True
)
self.log(
"validation/accuracy/classification",
accuracy,
prog_bar=True,
logger=True,
sync_dist=True,
)

return output.loss

Expand All @@ -164,15 +180,17 @@ def _step(self, batch, batch_idx):
else:
raise RuntimeError("Batch needs to have either or both 'image' and 'text'.")

labels = batch["labels"]
output = self.model(
image=batch.get("image", None),
text=batch.get("text", None),
required_embedding=required_embedding,
labels=batch.get("labels", None),
labels=labels,
)

# TODO: Add accuracy metric to this later.
return output
accuracy = self.metrics(output.logits, labels)

return output, accuracy

def configure_optimizers(self):
return get_optimizers_for_lightning(
Expand Down
2 changes: 1 addition & 1 deletion test/models/flava/test_flava.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def setUp(self):

@torch.no_grad()
def test_forward_classification(self):
flava = flava_model_for_classification(NUM_CLASSES)
flava = flava_model_for_classification(NUM_CLASSES, pretrained_model_key=None)
text = torch.randint(0, 30500, (2, 77), dtype=torch.long)
image = torch.rand((2, 3, 224, 224))

Expand Down
10 changes: 9 additions & 1 deletion torchmultimodal/models/flava/flava_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def flava_model_for_classification(
classifier_activation: Callable[..., nn.Module] = nn.ReLU,
classifier_normalization: Optional[Callable[..., nn.Module]] = None,
loss_fn: Optional[Callable[..., Tensor]] = None,
pretrained_model_key: Optional[str] = "flava_full",
**flava_model_kwargs: Any,
):
model = flava_model(**flava_model_kwargs)
Expand All @@ -225,7 +226,14 @@ def flava_model_for_classification(
if loss_fn is None:
loss_fn = nn.CrossEntropyLoss()

return FLAVAForClassification(model=model, classifier=classifier, loss=loss_fn)
classification_model = FLAVAForClassification(
model=model, classifier=classifier, loss=loss_fn
)
if pretrained_model_key is not None:
classification_model.load_model(
FLAVA_FOR_PRETRAINED_MAPPING[pretrained_model_key], strict=False
)
return classification_model


def to_2tuple(x):
Expand Down
3 changes: 2 additions & 1 deletion torchmultimodal/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def load_model(
pretrained_url: Optional[str],
load_state_dict: bool = True,
state_dict_key: Optional[str] = None,
strict: bool = True,
):
assert isinstance(
self, torch.nn.Module
Expand All @@ -160,7 +161,7 @@ def load_model(
state_dict = state_dict[state_dict_key]

if load_state_dict:
self.load_state_dict(state_dict)
self.load_state_dict(state_dict, strict=strict)
return state_dict


Expand Down

0 comments on commit 855c9ed

Please sign in to comment.