Merge pull request openvla#1 from siddk/update-batch-size
Fix Batch Size (to 32)
siddk authored Feb 27, 2024
2 parents 7ad1de9 + 0a639c9 commit 1a40225
Showing 3 changed files with 16 additions and 8 deletions.
prismatic/conf/vla.py (14 changes: 9 additions & 5 deletions)

@@ -58,7 +58,7 @@ class VLAConfig(ChoiceRegistry):
 # === [8 GPU] Base VLA =>> LLaVa (Reproduction) + Bridge ===
 @dataclass
 class Exp_LLaVa15_Bridge(VLAConfig):
-    vla_id: str = "llava-repro+bridge+7b"
+    vla_id: str = "reproduction-llava-v15+mx-bridge"
     base_vlm: Union[str, Path] = "reproduction-llava-v15+7b"
     freeze_vision_backbone: bool = False

@@ -71,8 +71,8 @@ class Exp_LLaVa15_Bridge(VLAConfig):
     max_steps: Optional[int] = None

     expected_world_size: int = 8
-    global_batch_size: int = 128
-    per_device_batch_size: int = 16
+    global_batch_size: int = 256
+    per_device_batch_size: int = 32

     learning_rate: float = 5e-6
     weight_decay: float = 0.0

@@ -96,6 +96,10 @@ class Exp_DINOSigLIP_384px_Bridge(Exp_LLaVa15_Bridge):
     vla_id: str = "prism-dinosiglip+mx-bridge"
     base_vlm: Union[str, Path] = "prism-dinosiglip+7b"

+    # Note =>> Unfrozen DINOSigLIP OOMs w/ Per-Device Batch Size of 32!
+    global_batch_size: int = 192
+    per_device_batch_size: int = 24
+

 # === [8 GPU] Frozen Vision Backbone =>> DINO-SigLIP @ 384px + Bridge ===
 @dataclass

@@ -114,7 +118,7 @@ class Exp_SigLIP_224px_Bridge_RT1(Exp_LLaVa15_Bridge):
     data_mix: str = "bridge_rt_1"

     expected_world_size: int = 16
-    global_batch_size: int = 256
+    global_batch_size: int = 512


 # === [32 GPU] Bridge + RT-1 =>> Frozen DINO-SigLIP @ 384px + [Bridge, RT-1] ===

@@ -127,7 +131,7 @@ class Exp_FreezeVIT_DINOSigLIP_384px_Bridge_RT1(Exp_LLaVa15_Bridge):
     data_mix: str = "bridge_rt_1"

     expected_world_size: int = 32
-    global_batch_size: int = 512
+    global_batch_size: int = 1024


 # === Define a VLA Registry Enum for Reference & Validation ===
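Every updated config keeps global_batch_size equal to expected_world_size times per_device_batch_size (the RT-1 mixtures inherit the per-device batch size of 32 from Exp_LLaVa15_Bridge), so none of them implies gradient accumulation. The sketch below only checks that arithmetic; the helper function is illustrative, not code from the repository.

```python
# Sanity check (illustrative, not from the repo) of the batch-size arithmetic implied by
# the updated configs: with no gradient accumulation, the global batch size equals the
# number of GPUs times the per-device batch size.

def grad_accumulation_steps(global_bsz: int, world_size: int, per_device_bsz: int) -> int:
    """Accumulation steps needed to reach `global_bsz` on `world_size` GPUs (hypothetical helper)."""
    assert global_bsz % (world_size * per_device_bsz) == 0, "global batch size must divide evenly"
    return global_bsz // (world_size * per_device_bsz)

# Values taken from the configs changed in this diff:
print(grad_accumulation_steps(256, 8, 32))    # Exp_LLaVa15_Bridge                         -> 1
print(grad_accumulation_steps(192, 8, 24))    # Exp_DINOSigLIP_384px_Bridge                -> 1
print(grad_accumulation_steps(512, 16, 32))   # Exp_SigLIP_224px_Bridge_RT1                -> 1
print(grad_accumulation_steps(1024, 32, 32))  # Exp_FreezeVIT_DINOSigLIP_384px_Bridge_RT1  -> 1
```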
vla-scripts/sagemaker/launch.py (4 changes: 2 additions & 2 deletions)

@@ -36,13 +36,13 @@
 class LaunchConfig:
     # fmt: off
     job_name: str = "sk-openvla"                  # Base Name for Job in Sagemaker Dashboard
-    instance_count: int = 4                       # Number of Nodes for Multi-Node Training
+    instance_count: int = 1                       # Number of Nodes for Multi-Node Training
     instance_type: str = "ml.p4de.24xlarge"       # Instance Type (default: p4de.24xlarge)
     instance_n_gpus: int = 8                      # Number of GPUs per Instance

     # OpenVLA Training Parameters
     vla_type: str = (                             # Unique VLA ID (specifies config)
-        VLARegistry.FREEZE_DINOSIGLIP_384PX_MX_BRIDGE_RT1.vla_id
+        VLARegistry.LLAVA_REPRO_MX_BRIDGE.vla_id
     )

     # Updated Paths for Data / Runs (on Sagemaker Volume)
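The node count drops from 4 to 1 because the launch now targets the LLaVa-reproduction Bridge config, which expects a world size of 8, i.e. a single 8-GPU p4de instance. A minimal consistency check along those lines, written for illustration only (the variables mirror the two configs; there is no such standalone check in the repository):

```python
# Illustrative check (not from the repo): the SageMaker launch config must supply as many
# GPUs as the selected VLA config expects. Switching vla_type to LLAVA_REPRO_MX_BRIDGE
# targets Exp_LLaVa15_Bridge, whose expected_world_size is 8.

instance_count = 1        # updated default in this diff
instance_n_gpus = 8       # GPUs per ml.p4de.24xlarge instance
expected_world_size = 8   # Exp_LLaVa15_Bridge ("reproduction-llava-v15+mx-bridge")

assert instance_count * instance_n_gpus == expected_world_size, (
    "Launch config and VLA config disagree on world size"
)
```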
vla-scripts/train.py (6 changes: 5 additions & 1 deletion)

@@ -107,7 +107,11 @@ def train(cfg: TrainConfig) -> None:

     # Configure Unique Run Name & Save Directory
     vla_id = cfg.vla.vla_id
-    cfg.run_id = f"{vla_id}+x{cfg.seed}" if cfg.run_id is None else cfg.run_id
+    cfg.run_id = (
+        f"{vla_id}+n{cfg.vla.expected_world_size // 8}+b{cfg.per_device_batch_size}+x{cfg.seed}"
+        if cfg.run_id is None
+        else cfg.run_id
+    )

     # Start =>> Build Directories and Set Randomness
     overwatch.info('"Do or do not; there is no try."', ctx_level=1)
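The auto-generated run name now encodes the node count and per-device batch size alongside the VLA ID and seed. A rough illustration of what the new format resolves to, using example values (the seed below is arbitrary, not a repo default):

```python
# New run_id format from this diff: <vla_id>+n<node count>+b<per-device batch>+x<seed>,
# where node count is derived as expected_world_size // 8.

vla_id = "reproduction-llava-v15+mx-bridge"
expected_world_size = 8
per_device_batch_size = 32
seed = 7  # example seed, chosen arbitrarily

run_id = f"{vla_id}+n{expected_world_size // 8}+b{per_device_batch_size}+x{seed}"
print(run_id)  # -> reproduction-llava-v15+mx-bridge+n1+b32+x7
```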
