add 10b experiment to flava native and fix checkpoint wrapper (#309)
Summary:
Adds a 10B experiment config to the FLAVA native training script and fixes activation checkpointing, which broke because the non-reentrant implementation does not support the kwargs taken by TransformerEncoderLayer; checkpointing now uses the reentrant implementation and is applied before FSDP wrapping.
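
A condensed sketch of that change (the full version is in the `examples/flava/native/train.py` diff below; the import path assumes the PyTorch 1.12-era checkpoint-wrapper prototype API, so treat it as illustrative rather than exact):

```python
from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing_wrapper,
    checkpoint_wrapper,
)
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer


def apply_flava_activation_checkpointing(model):
    """Checkpoint every TransformerEncoderLayer in place, before FSDP wrapping.

    The reentrant implementation is used because the non-reentrant variant
    does not support the kwargs taken by TransformerEncoderLayer.forward.
    """
    apply_activation_checkpointing_wrapper(
        model,
        checkpoint_wrapper_fn=partial(
            checkpoint_wrapper,
            offload_to_cpu=False,
            checkpoint_impl=CheckpointImpl.REENTRANT,
        ),
        check_fn=lambda submodule: isinstance(submodule, TransformerEncoderLayer),
    )
```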

Pull Request resolved: #309

Test Plan:
`torchrun --nproc_per_node=8 -m flava.native.train config=flava/native/configs/10b.yaml`

Fixes #{issue number}

Reviewed By: ankitade

Differential Revision: D39563955

Pulled By: edward-io

fbshipit-source-id: 93d1003c6a238e5f756581ca9507501edf2aa4df
edward-io authored and facebook-github-bot committed Sep 16, 2022
1 parent e5e9d4f commit 53ab78b
Showing 8 changed files with 194 additions and 34 deletions.
5 changes: 5 additions & 0 deletions examples/flava/native/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
3 changes: 1 addition & 2 deletions examples/flava/native/configs/1.8b.yaml
@@ -23,8 +23,7 @@ training:
half_precision_format: "bfloat16" # or float16
enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision

activation_checkpointing: False
activation_checkpointing_reentrant: False # false for non-reentrant
activation_checkpointing: True

datasets:
_target_: flava.definitions.TrainingDatasetsInfo
80 changes: 80 additions & 0 deletions examples/flava/native/configs/10b.yaml
@@ -0,0 +1,80 @@
training:
  strategy: fsdp # can be changed to ddp or fsdp
  seed: 1337

  batch_size: 8
  num_workers: 4
  prefetch_factor: 3

  optimizer:
    learning_rate: 1e-3
    adam_eps: 1e-8
    adam_weight_decay: 0.1
    adam_betas: [0.9, 0.999]

  warmup_steps: 10000
  max_steps: 100000

  validation_steps: 5000
  log_interval: 10

  enable_tf32: True
  enable_amp: True
  half_precision_format: "bfloat16" # or float16
  enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision

  activation_checkpointing: True

datasets:
  _target_: flava.definitions.TrainingDatasetsInfo
  selected:
    - image
    - vl
    - text
  image:
    _target_: flava.definitions.TrainingSingleDatasetInfo
    train:
      - _target_: flava.definitions.HFDatasetInfo
        key: imagenet-1k
        subset: default
  text:
    _target_: flava.definitions.TrainingSingleDatasetInfo
    train:
      - _target_: flava.definitions.HFDatasetInfo
        key: wikitext
        subset: wikitext-103-raw-v1
    datamodule_extra_kwargs:
      text_columns: ["text"]
  vl:
    _target_: flava.definitions.TrainingSingleDatasetInfo
    train:
      - _target_: flava.definitions.HFDatasetInfo
        key: red_caps
        subset: backpacking
        rename_columns:
          - ["caption", "text"]
    val:
      - _target_: flava.definitions.HFDatasetInfo
        key: red_caps
        subset: backpacking
        rename_columns:
          - ["caption", "text"]
        split_key_mapping:
          validation: train


model:
  image_num_hidden_layers: 64
  image_hidden_size: 2048
  image_intermediate_size: 10240
  image_num_attention_heads: 16

  text_num_hidden_layers: 64
  text_hidden_size: 2048
  text_intermediate_size: 10240
  text_num_attention_heads: 16

  multimodal_num_hidden_layers: 40
  multimodal_hidden_size: 2048
  multimodal_intermediate_size: 10240
  multimodal_num_attention_heads: 16
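
As a rough, illustrative sanity check on this config (not part of the commit), the parameter count implied by the hidden sizes, FFN sizes, and layer counts above can be estimated with the usual 4·h² attention plus 2·h·ffn feed-forward approximation per layer:

```python
def approx_transformer_params(num_layers: int, hidden: int, ffn: int) -> int:
    """Rough per-stack estimate: attention projections (~4*h^2) plus the
    feed-forward block (~2*h*ffn); embeddings, norms, and heads are ignored."""
    return num_layers * (4 * hidden**2 + 2 * hidden * ffn)


# 10b.yaml: 64 image + 64 text + 40 multimodal layers, hidden 2048, FFN 10240
total = sum(approx_transformer_params(n, 2048, 10240) for n in (64, 64, 40))
print(f"~{total / 1e9:.1f}B parameters")  # ~9.9B, consistent with the "10b" name
```

The same arithmetic applied to the 4.8b.yaml config further down (48/48/24 layers, hidden size 1664, FFN size 8192) lands near 4.6B.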
1 change: 0 additions & 1 deletion examples/flava/native/configs/2.7b.yaml
@@ -24,7 +24,6 @@ training:
enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision

activation_checkpointing: True
activation_checkpointing_reentrant: False # false for non-reentrant

datasets:
_target_: flava.definitions.TrainingDatasetsInfo
79 changes: 79 additions & 0 deletions examples/flava/native/configs/4.8b.yaml
@@ -0,0 +1,79 @@
training:
  strategy: fsdp # can be changed to ddp or fsdp
  seed: 1337

  batch_size: 12
  num_workers: 4
  prefetch_factor: 3

  optimizer:
    learning_rate: 1e-3
    adam_eps: 1e-8
    adam_weight_decay: 0.1
    adam_betas: [0.9, 0.999]

  warmup_steps: 10000
  max_steps: 100000

  validation_steps: 5000
  log_interval: 10

  enable_tf32: True
  enable_amp: True
  half_precision_format: "bfloat16" # or float16
  enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision

  activation_checkpointing: True

datasets:
  _target_: flava.definitions.TrainingDatasetsInfo
  selected:
    - image
    - vl
    - text
  image:
    _target_: flava.definitions.TrainingSingleDatasetInfo
    train:
      - _target_: flava.definitions.HFDatasetInfo
        key: imagenet-1k
        subset: default
  text:
    _target_: flava.definitions.TrainingSingleDatasetInfo
    train:
      - _target_: flava.definitions.HFDatasetInfo
        key: wikitext
        subset: wikitext-103-raw-v1
    datamodule_extra_kwargs:
      text_columns: ["text"]
  vl:
    _target_: flava.definitions.TrainingSingleDatasetInfo
    train:
      - _target_: flava.definitions.HFDatasetInfo
        key: red_caps
        subset: backpacking
        rename_columns:
          - ["caption", "text"]
    val:
      - _target_: flava.definitions.HFDatasetInfo
        key: red_caps
        subset: backpacking
        rename_columns:
          - ["caption", "text"]
        split_key_mapping:
          validation: train

model:
  image_num_hidden_layers: 48
  image_hidden_size: 1664
  image_intermediate_size: 8192
  image_num_attention_heads: 16

  text_num_hidden_layers: 48
  text_hidden_size: 1664
  text_intermediate_size: 8192
  text_num_attention_heads: 16

  multimodal_num_hidden_layers: 24
  multimodal_hidden_size: 1664
  multimodal_intermediate_size: 8192
  multimodal_num_attention_heads: 16
3 changes: 1 addition & 2 deletions examples/flava/native/configs/900m.yaml
@@ -23,8 +23,7 @@ training:
half_precision_format: "bfloat16" # or float16
enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision

activation_checkpointing: False
activation_checkpointing_reentrant: False # false for non-reentrant
activation_checkpointing: True

datasets:
_target_: flava.definitions.TrainingDatasetsInfo
56 changes: 28 additions & 28 deletions examples/flava/native/train.py
@@ -51,6 +51,8 @@
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torchmultimodal.models.flava.image_encoder import ImageTransformer
from torchmultimodal.models.flava.text_encoder import BERTTextEncoder
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer
from torchmultimodal.modules.losses.flava import FLAVAPretrainingLossOutput

@@ -141,11 +143,29 @@ def create_model(self) -> torch.nn.Module:
f"size: {get_model_size_gb(model):.3} GB"
)

model = model.to(self.device)
print0(f"after moving to cuda: {torch.cuda.memory_allocated()/1024**3:.3} GB")
if self.config.training.activation_checkpointing:
check_fn = lambda submodule: isinstance(submodule, TransformerEncoderLayer)

non_reentrant_wrapper = partial(
checkpoint_wrapper,
offload_to_cpu=False,
checkpoint_impl=CheckpointImpl.REENTRANT,
)
apply_activation_checkpointing_wrapper(
model,
checkpoint_wrapper_fn=non_reentrant_wrapper,
check_fn=check_fn,
)

if strategy == "ddp":
# TODO do we have to do this in FSDP too? see https://github.com/pytorch/pytorch/issues/75478
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = model.to(self.device)

print0(
f"after moving to cuda: {torch.cuda.memory_allocated()/1024**3:.3} GB"
)

model = DDP(
model,
device_ids=[self.rank],
@@ -165,36 +185,17 @@ def create_model(self) -> torch.nn.Module:
model = FSDP(
model,
mixed_precision=mp,
device_id=self.device,
auto_wrap_policy=partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerEncoderLayer},
transformer_layer_cls={
TransformerEncoderLayer,
ImageTransformer,
BERTTextEncoder,
},
),
)

if self.config.training.activation_checkpointing:
# note: activation checkpointing wrapper currently is faulty
# - non-reentrant does not support kwargs in TransformerEncoderLayer
# - memory reduction from checkpointing is less than expected

check_fn = lambda submodule: isinstance(
submodule, TransformerEncoderLayer
)
if self.config.training.activation_checkpointing_reentrant:
checkpoint_impl = CheckpointImpl.REENTRANT
else:
checkpoint_impl = CheckpointImpl.NO_REENTRANT

non_reentrant_wrapper = partial(
checkpoint_wrapper,
offload_to_cpu=False,
checkpoint_impl=checkpoint_impl,
)
apply_activation_checkpointing_wrapper(
model,
checkpoint_wrapper_fn=non_reentrant_wrapper,
check_fn=check_fn,
)

print0(f"after FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB")

else:
@@ -373,7 +374,6 @@ def validate(self):
print0(f"step {self.steps} EVAL loss: {norm_validation_loss:.4}")

def imagenet_validate(self):
# not working due to an FSDP issue
print0("imagenet validation")
with torch.no_grad():
with torch.cuda.amp.autocast(
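
The other notable `train.py` change visible above is the FSDP auto-wrap policy: whole image and text encoders (`ImageTransformer`, `BERTTextEncoder`) are now wrap units alongside individual `TransformerEncoderLayer`s. A minimal sketch of that wrapping step, with the model, mixed-precision policy, and device assumed to come from the trainer as in the diff:

```python
from functools import partial

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torchmultimodal.models.flava.image_encoder import ImageTransformer
from torchmultimodal.models.flava.text_encoder import BERTTextEncoder
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer


def wrap_flava_with_fsdp(model, mixed_precision, device):
    """Shard the (already checkpoint-wrapped) FLAVA model with FSDP,
    treating whole encoders as well as transformer layers as wrap units."""
    return FSDP(
        model,
        mixed_precision=mixed_precision,
        device_id=device,
        auto_wrap_policy=partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                TransformerEncoderLayer,
                ImageTransformer,
                BERTTextEncoder,
            },
        ),
    )
```

This mirrors the ordering in the diff: activation checkpointing is applied first, then FSDP wraps the checkpointed modules.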
1 change: 0 additions & 1 deletion examples/flava/native/utils.py
@@ -123,7 +123,6 @@ def _accuracy(output, target, topk=(1,)):
]


@rank0_only
def run_imagenet_zero_shot(model, dataloader, device, text_transform, *args, **kwargs):
print0("Starting ImageNet Zero-Shot Eval")
print0("Building classifier")
