Skip to content

Commit

Permalink
[FLAVA]Use boolean pretrained flag to load ckpts
Browse files Browse the repository at this point in the history
ghstack-source-id: 9ea01321467e8e3bc8191969f7b09a4d1fd3cd1d
Pull Request resolved: #365
  • Loading branch information
ankitade committed Nov 8, 2022
1 parent f4eacdd commit f3fc18f
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 20 deletions.
4 changes: 2 additions & 2 deletions examples/flava/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ python train.py config=flava/configs/pretraining/debug.yaml training.lightning.m
Similarly, let's say you want to use a pretrained model for your pretraining/finetuning.

```
python -m flava.train config=configs/pretraining/debug.yaml model.pretrained_model_key=flava_full
python -m flava.train config=configs/pretraining/debug.yaml model.pretrained=True
```

### Full Pretraining
Expand All @@ -64,7 +64,7 @@ python -m flava.train config=configs/pretraining/debug.yaml model.pretrained_mod
Similarly to pretraining, finetuning can be launched by following command:

```
python finetune.py config=configs/finetuning/qnli.yaml model.pretrained_model_key=flava_full
python finetune.py config=configs/finetuning/qnli.yaml model.pretrained=True
```

### Linear Probe
Expand Down
3 changes: 1 addition & 2 deletions examples/flava/coco_zero_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def transform(image, target):

def collator(batch):
texts = []
print(batch[0][0]["image"])
images = torch.stack([x[0]["image"] for x in batch], dim=0)
texts = torch.cat([torch.LongTensor(x[1]["input_ids"]) for x in batch], dim=0)
return images, texts
Expand All @@ -61,7 +60,7 @@ def main():
dataset = CocoCaptions(
root=args.data_root, annFile=args.annotations, transforms=transform
)
flava = flava_model(pretrained_model_key="flava_full")
flava = flava_model(pretrained=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
flava = flava.to(device)
Expand Down
3 changes: 2 additions & 1 deletion examples/flava/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ class TrainingArguments:

@dataclass
class ModelArguments:
pretrained_model_key: Optional[str] = None
# pretrained_model_key: Optional[str] = None
pretrained: bool = False


@dataclass
Expand Down
8 changes: 3 additions & 5 deletions tests/models/flava/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def inputs_model(self, image_input, text_input):
@pytest.fixture
def classification_model(self):
def get_model():
flava = flava_model_for_classification(
num_classes=3, pretrained_model_key="flava_full"
)
flava = flava_model_for_classification(num_classes=3, pretrained=True)
flava.eval()
return flava

Expand All @@ -82,7 +80,7 @@ def get_model():
@pytest.fixture
def pretraining_model(self):
def get_model():
flava = flava_model_for_pretraining(pretrained_model_key="flava_full")
flava = flava_model_for_pretraining(pretrained=True)
flava.eval()
return flava

Expand All @@ -91,7 +89,7 @@ def get_model():
@pytest.fixture
def model(self):
def get_model():
flava = flava_model(pretrained_model_key="flava_full")
flava = flava_model(pretrained=True)
flava.eval()
return flava

Expand Down
2 changes: 1 addition & 1 deletion tests/models/flava/test_flava.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def pretraining_inputs(self):
def test_forward_classification(self, classification_inputs):
text, image, labels = classification_inputs

flava = flava_model_for_classification(NUM_CLASSES, pretrained_model_key=None)
flava = flava_model_for_classification(NUM_CLASSES, pretrained=False)
flava.eval()

# Test multimodal scenario
Expand Down
19 changes: 10 additions & 9 deletions torchmultimodal/models/flava/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def flava_model(
multimodal_layer_norm_eps: float = 1e-12,
# projection
text_and_image_proj_size: int = 768,
pretrained_model_key: Optional[str] = None,
pretrained: bool = False,
**kwargs: Any,
) -> FLAVAModel:
image_encoder = flava_image_encoder(
Expand Down Expand Up @@ -505,15 +505,15 @@ def flava_model(
image_projection=image_projection,
)

if pretrained_model_key is not None:
flava.load_model(FLAVA_MODEL_MAPPING[pretrained_model_key])
if pretrained:
flava.load_model(FLAVA_MODEL_MAPPING["flava_full"])

return flava


def flava_model_for_pretraining(
codebook_image_size: int = 112,
pretrained_model_key: Optional[str] = None,
pretrained: bool = False,
**flava_model_kwargs: Any,
# TODO: Add parameters for loss here
) -> FLAVAForPreTraining:
Expand All @@ -528,8 +528,8 @@ def flava_model_for_pretraining(
loss=losses,
)

if pretrained_model_key is not None:
flava.load_model(FLAVA_FOR_PRETRAINED_MAPPING[pretrained_model_key])
if pretrained:
flava.load_model(FLAVA_FOR_PRETRAINED_MAPPING["flava_full"])

return flava

Expand All @@ -542,7 +542,7 @@ def flava_model_for_classification(
classifier_activation: Callable[..., nn.Module] = nn.ReLU,
classifier_normalization: Optional[Callable[..., nn.Module]] = None,
loss_fn: Optional[Callable[..., Tensor]] = None,
pretrained_model_key: Optional[str] = "flava_full",
pretrained: bool = True,
**flava_model_kwargs: Any,
) -> FLAVAForClassification:

Expand All @@ -561,9 +561,10 @@ def flava_model_for_classification(
classification_model = FLAVAForClassification(
model=model, classifier=classifier, loss=loss_fn
)
if pretrained_model_key is not None:

if pretrained:
classification_model.load_model(
FLAVA_FOR_PRETRAINED_MAPPING[pretrained_model_key], strict=False
FLAVA_FOR_PRETRAINED_MAPPING["flava_full"], strict=False
)
return classification_model

Expand Down

0 comments on commit f3fc18f

Please sign in to comment.