Forked from TRI-ML/prismatic-vlms

Merge pull request #63 from openvla/bridgev2_evals
Add BridgeData V2 eval script and instructions

Showing 7 changed files with 785 additions and 0 deletions.

.gitignore
@@ -147,3 +147,7 @@ dmypy.json
 # Caches and Datasets
 cache/
 data/
+
+# Rollout videos and wandb logs
+rollouts/
+wandb/

experiments/robot/bridge/bridgev2_utils.py (new file, +133 lines)

"""Utils for evaluating policies in real-world BridgeData V2 environments.""" | ||
|
||
import os | ||
import sys | ||
import time | ||
|
||
import imageio | ||
import numpy as np | ||
import tensorflow as tf | ||
import torch | ||
from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs | ||
|
||
sys.path.append(".") | ||
from experiments.robot.bridge.widowx_env import WidowXGym | ||
|
||
# Initialize important constants and pretty-printing mode in NumPy. | ||
ACTION_DIM = 7 | ||
BRIDGE_PROPRIO_DIM = 7 | ||
DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S") | ||
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") | ||
np.set_printoptions(formatter={"float": lambda x: "{0:0.2f}".format(x)}) | ||
|
||
|
||
def get_widowx_env_params(cfg): | ||
"""Gets (mostly default) environment parameters for the WidowX environment.""" | ||
env_params = WidowXConfigs.DefaultEnvParams.copy() | ||
env_params["override_workspace_boundaries"] = cfg.bounds | ||
env_params["camera_topics"] = cfg.camera_topics | ||
env_params["return_full_image"] = True | ||
return env_params | ||
|
||
|
||
def get_widowx_env(cfg, model=None): | ||
"""Get WidowX control environment.""" | ||
# Set up the WidowX environment parameters | ||
env_params = get_widowx_env_params(cfg) | ||
start_state = np.concatenate([cfg.init_ee_pos, cfg.init_ee_quat]) | ||
env_params["start_state"] = list(start_state) | ||
# Set up the WidowX client | ||
widowx_client = WidowXClient(host=cfg.host_ip, port=cfg.port) | ||
widowx_client.init(env_params) | ||
env = WidowXGym( | ||
widowx_client, | ||
cfg=cfg, | ||
blocking=cfg.blocking, | ||
) | ||
return env | ||
|
||
|
||
def get_next_task_label(task_label): | ||
"""Prompt the user to input the next task.""" | ||
if task_label == "": | ||
user_input = "" | ||
while user_input == "": | ||
user_input = input("Enter the task name: ") | ||
task_label = user_input | ||
else: | ||
user_input = input("Enter the task name (or leave blank to repeat the previous task): ") | ||
if user_input == "": | ||
pass # Do nothing -> Let task_label be the same | ||
else: | ||
task_label = user_input | ||
print(f"Task: {task_label}") | ||
return task_label | ||
|
||
|
||
def save_rollout_video(rollout_images, idx): | ||
"""Saves an MP4 replay of an episode.""" | ||
os.makedirs("./rollouts", exist_ok=True) | ||
mp4_path = f"./rollouts/rollout-{DATE_TIME}-{idx+1}.mp4" | ||
video_writer = imageio.get_writer(mp4_path, fps=5) | ||
for img in rollout_images: | ||
video_writer.append_data(img) | ||
video_writer.close() | ||
print(f"Saved rollout MP4 at path {mp4_path}") | ||
|
||
|
||
def save_rollout_data(rollout_orig_images, rollout_images, rollout_states, rollout_actions, idx): | ||
""" | ||
Saves rollout data from an episode. | ||
Args: | ||
rollout_orig_images (list): Original rollout images (before preprocessing). | ||
rollout_images (list): Preprocessed images. | ||
rollout_states (list): Proprioceptive states. | ||
rollout_actions (list): Predicted actions. | ||
idx (int): Episode index. | ||
""" | ||
os.makedirs("./rollouts", exist_ok=True) | ||
path = f"./rollouts/rollout-{DATE_TIME}-{idx+1}.npz" | ||
# Convert lists to numpy arrays | ||
orig_images_array = np.array(rollout_orig_images) | ||
images_array = np.array(rollout_images) | ||
states_array = np.array(rollout_states) | ||
actions_array = np.array(rollout_actions) | ||
# Save to a single .npz file | ||
np.savez(path, orig_images=orig_images_array, images=images_array, states=states_array, actions=actions_array) | ||
print(f"Saved rollout data at path {path}") | ||
|
||
|
||
def resize_image(img, resize_size): | ||
""" | ||
Takes numpy array corresponding to a single image and returns resized image as numpy array. | ||
NOTE (Moo Jin): To make input images in distribution with respect to the inputs seen at training time, we follow | ||
the same resizing scheme used in the Octo dataloader, which OpenVLA uses for training. | ||
""" | ||
assert isinstance(resize_size, tuple) | ||
# Resize to image size expected by model | ||
img = tf.image.encode_jpeg(img) # Encode as JPEG, as done in RLDS dataset builder | ||
img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8) # Immediately decode back | ||
img = tf.image.resize(img, resize_size, method="lanczos3", antialias=True) | ||
img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8) | ||
img = img.numpy() | ||
return img | ||
|
||
|
||
def get_preprocessed_image(obs, resize_size): | ||
"""Extracts image from observations and preprocesses it.""" | ||
assert isinstance(resize_size, int) or isinstance(resize_size, tuple) | ||
if isinstance(resize_size, int): | ||
resize_size = (resize_size, resize_size) | ||
obs["full_image"] = resize_image(obs["full_image"], resize_size) | ||
return obs["full_image"] | ||
|
||
|
||
def refresh_obs(obs, env): | ||
"""Fetches new observations from the environment and updates the current observations.""" | ||
new_obs = env.get_observation() | ||
obs["full_image"] = new_obs["full_image"] | ||
obs["image_primary"] = new_obs["image_primary"] | ||
obs["proprio"] = new_obs["proprio"] | ||
return obs |
experiments/robot/bridge/run_bridge_eval.py (new file, +183 lines)

""" | ||
run_bridge_eval.py | ||
Runs a model in a real-world Bridge V2 environment. | ||
Usage: | ||
# OpenVLA: | ||
python experiments/robot/bridge/run_bridge_eval.py --model_family openvla --pretrained_checkpoint openvla/openvla-7b | ||
""" | ||
|
||
import sys | ||
import time | ||
from dataclasses import dataclass, field | ||
from pathlib import Path | ||
from typing import Dict, List, Union | ||
|
||
import draccus | ||
|
||
# Append current directory so that interpreter can find experiments.robot | ||
sys.path.append(".") | ||
from experiments.robot.bridge.bridgev2_utils import ( | ||
get_next_task_label, | ||
get_preprocessed_image, | ||
get_widowx_env, | ||
refresh_obs, | ||
save_rollout_data, | ||
save_rollout_video, | ||
) | ||
from experiments.robot.openvla_utils import get_processor | ||
from experiments.robot.robot_utils import ( | ||
get_action, | ||
get_image_resize_size, | ||
get_model, | ||
) | ||
|
||
|
||
@dataclass | ||
class GenerateConfig: | ||
# fmt: off | ||
|
||
################################################################################################################# | ||
# Model-specific parameters | ||
################################################################################################################# | ||
model_family: str = "openvla" # Model family | ||
pretrained_checkpoint: Union[str, Path] = "" # Pretrained checkpoint path | ||
load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization | ||
load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization | ||
|
||
center_crop: bool = False # Center crop? (if trained w/ random crop image aug) | ||
|
||
################################################################################################################# | ||
# WidowX environment-specific parameters | ||
################################################################################################################# | ||
host_ip: str = "localhost" | ||
port: int = 5556 | ||
|
||
# Note: Setting initial orientation with a 30 degree offset, which makes the robot appear more natural | ||
init_ee_pos: List[float] = field(default_factory=lambda: [0.3, -0.09, 0.26]) | ||
init_ee_quat: List[float] = field(default_factory=lambda: [0, -0.259, 0, -0.966]) | ||
bounds: List[List[float]] = field(default_factory=lambda: [ | ||
[0.1, -0.20, -0.01, -1.57, 0], | ||
[0.45, 0.25, 0.30, 1.57, 0], | ||
] | ||
) | ||
|
||
camera_topics: List[Dict[str, str]] = field(default_factory=lambda: [{"name": "/blue/image_raw"}]) | ||
|
||
blocking: bool = False # Whether to use blocking control | ||
max_episodes: int = 50 # Max number of episodes to run | ||
max_steps: int = 60 # Max number of timesteps per episode | ||
control_frequency: float = 5 # WidowX control frequency | ||
|
||
################################################################################################################# | ||
# Utils | ||
################################################################################################################# | ||
save_data: bool = False # Whether to save rollout data (images, actions, etc.) | ||
|
||
# fmt: on | ||
|
||
|
||
@draccus.wrap() | ||
def eval_model_in_bridge_env(cfg: GenerateConfig) -> None: | ||
assert cfg.pretrained_checkpoint is not None, "cfg.pretrained_checkpoint must not be None!" | ||
assert not cfg.center_crop, "`center_crop` should be disabled for Bridge evaluations!" | ||
|
||
# [OpenVLA] Set action un-normalization key | ||
cfg.unnorm_key = "bridge_orig" | ||
|
||
# Load model | ||
model = get_model(cfg) | ||
|
||
# [OpenVLA] Get Hugging Face processor | ||
processor = None | ||
if cfg.model_family == "openvla": | ||
processor = get_processor(cfg) | ||
|
||
# Initialize the WidowX environment | ||
env = get_widowx_env(cfg, model) | ||
|
||
# Get expected image dimensions | ||
resize_size = get_image_resize_size(cfg) | ||
|
||
# Start evaluation | ||
task_label = "" | ||
episode_idx = 0 | ||
while episode_idx < cfg.max_episodes: | ||
# Get task description from user | ||
task_label = get_next_task_label(task_label) | ||
|
||
# Reset environment | ||
obs, _ = env.reset() | ||
|
||
# Setup | ||
t = 0 | ||
step_duration = 1.0 / cfg.control_frequency | ||
replay_images = [] | ||
if cfg.save_data: | ||
rollout_images = [] | ||
rollout_states = [] | ||
rollout_actions = [] | ||
|
||
# Start episode | ||
input(f"Press Enter to start episode {episode_idx+1}...") | ||
print("Starting episode... Press Ctrl-C to terminate episode early!") | ||
last_tstamp = time.time() | ||
while t < cfg.max_steps: | ||
try: | ||
curr_tstamp = time.time() | ||
if curr_tstamp > last_tstamp + step_duration: | ||
print(f"t: {t}") | ||
print(f"Previous step elapsed time (sec): {curr_tstamp - last_tstamp:.2f}") | ||
last_tstamp = time.time() | ||
|
||
# Refresh the camera image and proprioceptive state | ||
obs = refresh_obs(obs, env) | ||
|
||
# Save full (not preprocessed) image for replay video | ||
replay_images.append(obs["full_image"]) | ||
|
||
# Get preprocessed image | ||
obs["full_image"] = get_preprocessed_image(obs, resize_size) | ||
|
||
# Query model to get action | ||
action = get_action( | ||
cfg, | ||
model, | ||
obs, | ||
task_label, | ||
processor=processor, | ||
) | ||
|
||
# [If saving rollout data] Save preprocessed image, robot state, and action | ||
if cfg.save_data: | ||
rollout_images.append(obs["full_image"]) | ||
rollout_states.append(obs["proprio"]) | ||
rollout_actions.append(action) | ||
|
||
# Execute action | ||
print("action:", action) | ||
obs, _, _, _, _ = env.step(action) | ||
t += 1 | ||
|
||
except (KeyboardInterrupt, Exception) as e: | ||
if isinstance(e, KeyboardInterrupt): | ||
print("\nCaught KeyboardInterrupt: Terminating episode early.") | ||
else: | ||
print(f"\nCaught exception: {e}") | ||
break | ||
|
||
# Save a replay video of the episode | ||
save_rollout_video(replay_images, episode_idx) | ||
|
||
# [If saving rollout data] Save rollout data | ||
if cfg.save_data: | ||
save_rollout_data(replay_images, rollout_images, rollout_states, rollout_actions, idx=episode_idx) | ||
|
||
# Redo episode or continue | ||
if input("Enter 'r' if you want to redo the episode, or press Enter to continue: ") != "r": | ||
episode_idx += 1 | ||
|
||
|
||
if __name__ == "__main__": | ||
eval_model_in_bridge_env() |
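
Because eval_model_in_bridge_env is wrapped with @draccus.wrap(), any field of GenerateConfig can be overridden from the command line. A sketch of a fuller invocation, assuming draccus's standard field-to-flag mapping (the host IP below is a placeholder for a real WidowX server):

# Run against a remote WidowX server with blocking control, saving rollout data:
python experiments/robot/bridge/run_bridge_eval.py \
  --model_family openvla \
  --pretrained_checkpoint openvla/openvla-7b \
  --host_ip 192.168.1.42 \
  --blocking True \
  --save_data True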