Skip to content

Commit

Permalink
Merge pull request openvla#1 from siddk-tri/debug-sagemaker
Browse files Browse the repository at this point in the history
Update VLA Sagemaker Image
  • Loading branch information
siddk authored Mar 15, 2024
2 parents 20732d2 + 6f2f8e8 commit 4a8ad67
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 12 deletions.
9 changes: 6 additions & 3 deletions scripts/sagemaker/vlm-training.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# ===
# Prismatic VLM Sagemaker Dockerfile
# => Base Image :: Python 3.10 & Pytorch 2.1.0
# => Base Image :: Python 3.10 & Pytorch 2.2.0
# ===
FROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.1.0-gpu-py310-cu121-ubuntu20.04-sagemaker
FROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.2.0-gpu-py310-cu121-ubuntu20.04-sagemaker

# Sane Defaults
RUN apt-get update
Expand All @@ -29,9 +29,12 @@ RUN apt-get update && apt-get install -y \
libsdl2-2.0-0 \
python-pygame

# IMPORTANT :: Uninstall & Reinstall Torch (Sagemaker CPU Core Bug)
RUN pip install --upgrade pip
RUN pip uninstall -y torch
RUN pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121

# Install Prismatic Python Dependencies (`pip`) + Sagemaker
RUN pip install --upgrade pip
RUN pip install \
accelerate>=0.25.0 \
draccus@git+https://github.com/dlwh/draccus \
Expand Down
8 changes: 3 additions & 5 deletions vla-scripts/sagemaker/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class LaunchConfig:

# OpenVLA Training Parameters
vla_type: str = ( # Unique VLA ID (specifies config)
VLARegistry.LLAVA_REPRO_MX_BRIDGE.vla_id
VLARegistry.LR_2E5_SIGLIP_224PX_ICY_MX_BRIDGE.vla_id
)

# Updated Paths for Data / Runs (on Sagemaker Volume)
Expand Down Expand Up @@ -96,7 +96,6 @@ def launch(cfg: LaunchConfig) -> None:
base_job_name=cfg.job_name,
instance_count=cfg.instance_count,
instance_type=cfg.instance_type if not cfg.debug else "local_gpu",
volume_size=100,
entry_point=cfg.entry_point,
image_uri=cfg.image_uri,
hyperparameters=hyperparameters,
Expand All @@ -109,13 +108,12 @@ def launch(cfg: LaunchConfig) -> None:
sagemaker_session=sagemaker_session,
subnets=SUBNETS,
security_group_ids=SECURITY_GROUP_IDS,
checkpoint_s3_uri=S3_LOG_PATH,
output_path=S3_LOG_PATH,
keep_alive_period_in_seconds=3600,
max_run=60 * 60 * 24 * cfg.max_days,
distribution={"torch_distributed": {"enabled": True}},
disable_profiler=True,
)
estimator.fit(inputs={"training": train_fs})
estimator.fit(inputs={"training": train_fs if not cfg.debug else "file:///mnt/fsx/"})


if __name__ == "__main__":
Expand Down
10 changes: 7 additions & 3 deletions vla-scripts/sagemaker/vla-training.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# ===
# OpenVLA Sagemaker Dockerfile
# => Base Image :: Python 3.10 & Pytorch 2.1.0
# => Base Image :: Python 3.10 & Pytorch 2.2.0
# ===
FROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.1.0-gpu-py310-cu121-ubuntu20.04-sagemaker
FROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.2.0-gpu-py310-cu121-ubuntu20.04-sagemaker

# Sane Defaults
RUN apt-get update
Expand Down Expand Up @@ -30,8 +30,12 @@ RUN apt-get update && apt-get install -y \
python-pygame


# Install Prismatic + VLA Python Dependencies (`pip`) + Sagemaker
# IMPORTANT :: Uninstall & Reinstall Torch (Sagemaker CPU Core Bug)
RUN pip install --upgrade pip
RUN pip uninstall -y torch
RUN pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121

# Install Prismatic + VLA Python Dependencies (`pip`) + Sagemaker
RUN pip install \
accelerate>=0.25.0 \
draccus@git+https://github.com/dlwh/draccus \
Expand Down
2 changes: 1 addition & 1 deletion vla-scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class TrainConfig:

# VLAConfig (`prismatic/conf/vla.py`); override with --vla.type `VLARegistry.<VLA>.vla_id`
vla: VLAConfig = field(
default_factory=VLAConfig.get_choice_class(VLARegistry.FREEZE_SIGLIP_224PX_MX_BRIDGE.vla_id)
default_factory=VLAConfig.get_choice_class(VLARegistry.LLAVA_REPRO_MX_BRIDGE.vla_id)
)

# Directory Paths
Expand Down

0 comments on commit 4a8ad67

Please sign in to comment.