Commit
Merge pull request #63 from openvla/bridgev2_evals
Add BridgeData V2 eval script and instructions
moojink authored Aug 14, 2024
2 parents ed958be + b11518c commit 7be7a1b
Showing 7 changed files with 785 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -147,3 +147,7 @@ dmypy.json
# Caches and Datasets
cache/
data/

# Rollout videos and wandb logs
rollouts/
wandb/
56 changes: 56 additions & 0 deletions README.md
@@ -419,12 +419,68 @@ AttributeError: 'DLataset' object has no attribute 'traj_map'. Did you mean: 'fl

---

## Evaluating OpenVLA

### BridgeData V2 WidowX Evaluations

#### Setup

Clone the [BridgeData V2 WidowX controller repo](https://github.com/rail-berkeley/bridge_data_robot) and install the `widowx_envs` package:

```bash
git clone https://github.com/rail-berkeley/bridge_data_robot.git
cd bridge_data_robot
pip install -e widowx_envs
```

Additionally, install the [`edgeml`](https://github.com/youliangtan/edgeml) library:
```bash
git clone https://github.com/youliangtan/edgeml.git
cd edgeml
pip install -e .
```

Follow the instructions in the `bridge_data_robot` README to create the Bridge WidowX Docker container.

#### Launching BridgeData V2 Evaluations

There are multiple ways to run BridgeData V2 evaluations. We describe the server-client method below.

In one Terminal window (e.g., in tmux), start the WidowX Docker container:

```bash
cd bridge_data_robot
./generate_usb_config.sh
USB_CONNECTOR_CHART=$(pwd)/usb_connector_chart.yml docker compose up --build robonet
```

In a second Terminal window, run the WidowX robot server:

```bash
cd bridge_data_robot
docker compose exec robonet bash -lic "widowx_env_service --server"
```

In a third Terminal window, run the OpenVLA policy evaluation script:

```bash
cd openvla
python experiments/robot/bridge/run_bridgev2_eval.py \
--model_family openvla \
--pretrained_checkpoint openvla/openvla-7b
```
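
The remaining `GenerateConfig` fields defined in `run_bridgev2_eval.py` (e.g., `host_ip`, `port`, `max_episodes`, `max_steps`, `save_data`) can also be overridden on the command line. A rough sketch, assuming draccus's standard `--field value` parsing of the config dataclass (`<ROBOT_SERVER_IP>` is a placeholder):

```bash
# Sketch: point at a non-local robot server and save rollout data (.npz) in
# addition to the MP4 replays (flag names mirror the GenerateConfig fields).
python experiments/robot/bridge/run_bridgev2_eval.py \
  --model_family openvla \
  --pretrained_checkpoint openvla/openvla-7b \
  --host_ip <ROBOT_SERVER_IP> \
  --port 5556 \
  --max_episodes 50 \
  --max_steps 60 \
  --save_data True
```

Replay MP4s (and `.npz` rollout data when `save_data` is enabled) are written to `./rollouts/`, which the `.gitignore` change above excludes from version control.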

If you run into any problems with evaluations, please file a GitHub Issue.

---

## Repository Structure

High-level overview of repository/project file-tree:

+ `prismatic` - Package source; provides core utilities for model loading, training, data preprocessing, etc.
+ `vla-scripts/` - Core scripts for training, fine-tuning, and deploying VLAs.
+ `experiments/` - Code for evaluating OpenVLA policies in robot environments.
+ `LICENSE` - All code is made available under the MIT License; happy hacking!
+ `Makefile` - Top-level Makefile (by default, supports linting - checking & auto-fix); extend as needed.
+ `pyproject.toml` - Full project configuration details (including dependencies), as well as tool configurations.
133 changes: 133 additions & 0 deletions experiments/robot/bridge/bridgev2_utils.py
@@ -0,0 +1,133 @@
"""Utils for evaluating policies in real-world BridgeData V2 environments."""

import os
import sys
import time

import imageio
import numpy as np
import tensorflow as tf
import torch
from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs

sys.path.append(".")
from experiments.robot.bridge.widowx_env import WidowXGym

# Initialize important constants and pretty-printing mode in NumPy.
ACTION_DIM = 7
BRIDGE_PROPRIO_DIM = 7
DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S")
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
np.set_printoptions(formatter={"float": lambda x: "{0:0.2f}".format(x)})


def get_widowx_env_params(cfg):
    """Gets (mostly default) environment parameters for the WidowX environment."""
    env_params = WidowXConfigs.DefaultEnvParams.copy()
    env_params["override_workspace_boundaries"] = cfg.bounds
    env_params["camera_topics"] = cfg.camera_topics
    env_params["return_full_image"] = True
    return env_params


def get_widowx_env(cfg, model=None):
    """Get WidowX control environment."""
    # Set up the WidowX environment parameters
    env_params = get_widowx_env_params(cfg)
    start_state = np.concatenate([cfg.init_ee_pos, cfg.init_ee_quat])
    env_params["start_state"] = list(start_state)
    # Set up the WidowX client
    widowx_client = WidowXClient(host=cfg.host_ip, port=cfg.port)
    widowx_client.init(env_params)
    env = WidowXGym(
        widowx_client,
        cfg=cfg,
        blocking=cfg.blocking,
    )
    return env


def get_next_task_label(task_label):
    """Prompts the user to input the next task."""
    if task_label == "":
        user_input = ""
        while user_input == "":
            user_input = input("Enter the task name: ")
        task_label = user_input
    else:
        user_input = input("Enter the task name (or leave blank to repeat the previous task): ")
        if user_input == "":
            pass  # Do nothing -> Let task_label be the same
        else:
            task_label = user_input
    print(f"Task: {task_label}")
    return task_label


def save_rollout_video(rollout_images, idx):
    """Saves an MP4 replay of an episode."""
    os.makedirs("./rollouts", exist_ok=True)
    mp4_path = f"./rollouts/rollout-{DATE_TIME}-{idx+1}.mp4"
    video_writer = imageio.get_writer(mp4_path, fps=5)
    for img in rollout_images:
        video_writer.append_data(img)
    video_writer.close()
    print(f"Saved rollout MP4 at path {mp4_path}")


def save_rollout_data(rollout_orig_images, rollout_images, rollout_states, rollout_actions, idx):
    """
    Saves rollout data from an episode.

    Args:
        rollout_orig_images (list): Original rollout images (before preprocessing).
        rollout_images (list): Preprocessed images.
        rollout_states (list): Proprioceptive states.
        rollout_actions (list): Predicted actions.
        idx (int): Episode index.
    """
    os.makedirs("./rollouts", exist_ok=True)
    path = f"./rollouts/rollout-{DATE_TIME}-{idx+1}.npz"
    # Convert lists to numpy arrays
    orig_images_array = np.array(rollout_orig_images)
    images_array = np.array(rollout_images)
    states_array = np.array(rollout_states)
    actions_array = np.array(rollout_actions)
    # Save to a single .npz file
    np.savez(path, orig_images=orig_images_array, images=images_array, states=states_array, actions=actions_array)
    print(f"Saved rollout data at path {path}")


def resize_image(img, resize_size):
    """
    Takes a numpy array corresponding to a single image and returns the resized image as a numpy array.

    NOTE (Moo Jin): To make input images in-distribution with respect to the inputs seen at training time, we follow
    the same resizing scheme used in the Octo dataloader, which OpenVLA uses for training.
    """
    assert isinstance(resize_size, tuple)
    # Resize to the image size expected by the model
    img = tf.image.encode_jpeg(img)  # Encode as JPEG, as done in the RLDS dataset builder
    img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)  # Immediately decode back
    img = tf.image.resize(img, resize_size, method="lanczos3", antialias=True)
    img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8)
    img = img.numpy()
    return img


def get_preprocessed_image(obs, resize_size):
    """Extracts the image from the observations and preprocesses it."""
    assert isinstance(resize_size, int) or isinstance(resize_size, tuple)
    if isinstance(resize_size, int):
        resize_size = (resize_size, resize_size)
    obs["full_image"] = resize_image(obs["full_image"], resize_size)
    return obs["full_image"]


def refresh_obs(obs, env):
    """Fetches new observations from the environment and updates the current observations."""
    new_obs = env.get_observation()
    obs["full_image"] = new_obs["full_image"]
    obs["image_primary"] = new_obs["image_primary"]
    obs["proprio"] = new_obs["proprio"]
    return obs
183 changes: 183 additions & 0 deletions experiments/robot/bridge/run_bridgev2_eval.py
@@ -0,0 +1,183 @@
"""
run_bridge_eval.py
Runs a model in a real-world Bridge V2 environment.
Usage:
# OpenVLA:
python experiments/robot/bridge/run_bridge_eval.py --model_family openvla --pretrained_checkpoint openvla/openvla-7b
"""

import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Union

import draccus

# Append current directory so that interpreter can find experiments.robot
sys.path.append(".")
from experiments.robot.bridge.bridgev2_utils import (
    get_next_task_label,
    get_preprocessed_image,
    get_widowx_env,
    refresh_obs,
    save_rollout_data,
    save_rollout_video,
)
from experiments.robot.openvla_utils import get_processor
from experiments.robot.robot_utils import (
    get_action,
    get_image_resize_size,
    get_model,
)


@dataclass
class GenerateConfig:
    # fmt: off

    #################################################################################################################
    # Model-specific parameters
    #################################################################################################################
    model_family: str = "openvla"                    # Model family
    pretrained_checkpoint: Union[str, Path] = ""     # Pretrained checkpoint path
    load_in_8bit: bool = False                       # (For OpenVLA only) Load with 8-bit quantization
    load_in_4bit: bool = False                       # (For OpenVLA only) Load with 4-bit quantization

    center_crop: bool = False                        # Center crop? (if trained w/ random crop image aug)

    #################################################################################################################
    # WidowX environment-specific parameters
    #################################################################################################################
    host_ip: str = "localhost"
    port: int = 5556

    # Note: Setting initial orientation with a 30 degree offset, which makes the robot appear more natural
    init_ee_pos: List[float] = field(default_factory=lambda: [0.3, -0.09, 0.26])
    init_ee_quat: List[float] = field(default_factory=lambda: [0, -0.259, 0, -0.966])
    bounds: List[List[float]] = field(default_factory=lambda: [
            [0.1, -0.20, -0.01, -1.57, 0],
            [0.45, 0.25, 0.30, 1.57, 0],
        ]
    )

    camera_topics: List[Dict[str, str]] = field(default_factory=lambda: [{"name": "/blue/image_raw"}])

    blocking: bool = False                           # Whether to use blocking control
    max_episodes: int = 50                           # Max number of episodes to run
    max_steps: int = 60                              # Max number of timesteps per episode
    control_frequency: float = 5                     # WidowX control frequency

    #################################################################################################################
    # Utils
    #################################################################################################################
    save_data: bool = False                          # Whether to save rollout data (images, actions, etc.)

    # fmt: on


@draccus.wrap()
def eval_model_in_bridge_env(cfg: GenerateConfig) -> None:
    assert cfg.pretrained_checkpoint is not None, "cfg.pretrained_checkpoint must not be None!"
    assert not cfg.center_crop, "`center_crop` should be disabled for Bridge evaluations!"

    # [OpenVLA] Set action un-normalization key
    cfg.unnorm_key = "bridge_orig"

    # Load model
    model = get_model(cfg)

    # [OpenVLA] Get Hugging Face processor
    processor = None
    if cfg.model_family == "openvla":
        processor = get_processor(cfg)

    # Initialize the WidowX environment
    env = get_widowx_env(cfg, model)

    # Get expected image dimensions
    resize_size = get_image_resize_size(cfg)

    # Start evaluation
    task_label = ""
    episode_idx = 0
    while episode_idx < cfg.max_episodes:
        # Get task description from user
        task_label = get_next_task_label(task_label)

        # Reset environment
        obs, _ = env.reset()

        # Setup
        t = 0
        step_duration = 1.0 / cfg.control_frequency
        replay_images = []
        if cfg.save_data:
            rollout_images = []
            rollout_states = []
            rollout_actions = []

        # Start episode
        input(f"Press Enter to start episode {episode_idx+1}...")
        print("Starting episode... Press Ctrl-C to terminate episode early!")
        last_tstamp = time.time()
        while t < cfg.max_steps:
            try:
                curr_tstamp = time.time()
                if curr_tstamp > last_tstamp + step_duration:
                    print(f"t: {t}")
                    print(f"Previous step elapsed time (sec): {curr_tstamp - last_tstamp:.2f}")
                    last_tstamp = time.time()

                    # Refresh the camera image and proprioceptive state
                    obs = refresh_obs(obs, env)

                    # Save full (not preprocessed) image for replay video
                    replay_images.append(obs["full_image"])

                    # Get preprocessed image
                    obs["full_image"] = get_preprocessed_image(obs, resize_size)

                    # Query model to get action
                    action = get_action(
                        cfg,
                        model,
                        obs,
                        task_label,
                        processor=processor,
                    )

                    # [If saving rollout data] Save preprocessed image, robot state, and action
                    if cfg.save_data:
                        rollout_images.append(obs["full_image"])
                        rollout_states.append(obs["proprio"])
                        rollout_actions.append(action)

                    # Execute action
                    print("action:", action)
                    obs, _, _, _, _ = env.step(action)
                    t += 1

            except (KeyboardInterrupt, Exception) as e:
                if isinstance(e, KeyboardInterrupt):
                    print("\nCaught KeyboardInterrupt: Terminating episode early.")
                else:
                    print(f"\nCaught exception: {e}")
                break

        # Save a replay video of the episode
        save_rollout_video(replay_images, episode_idx)

        # [If saving rollout data] Save rollout data
        if cfg.save_data:
            save_rollout_data(replay_images, rollout_images, rollout_states, rollout_actions, idx=episode_idx)

        # Redo episode or continue
        if input("Enter 'r' if you want to redo the episode, or press Enter to continue: ") != "r":
            episode_idx += 1


if __name__ == "__main__":
    eval_model_in_bridge_env()