Skip to content

Commit

Permalink
Add configs/transforms for modified LIBERO datasets
Browse files Browse the repository at this point in the history
Also add quick manual check for "_no_noops" unnorm_key in run_libero_eval.py.
  • Loading branch information
moojink committed Sep 6, 2024
1 parent d00667d commit 7a359de
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 0 deletions.
4 changes: 4 additions & 0 deletions experiments/robot/libero/run_libero_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def eval_libero(cfg: GenerateConfig) -> None:

# [OpenVLA] Check that the model contains the action un-normalization key
if cfg.model_family == "openvla":
# In some cases, the key must be manually modified (e.g. after training on a modified version of the dataset
# with the suffix "_no_noops" in the dataset name)
if cfg.unnorm_key not in model.norm_stats and f"{cfg.unnorm_key}_no_noops" in model.norm_stats:
cfg.unnorm_key = f"{cfg.unnorm_key}_no_noops"
assert cfg.unnorm_key in model.norm_stats, f"Action un-norm key {cfg.unnorm_key} not found in VLA `norm_stats`!"

# [OpenVLA] Get Hugging Face processor
Expand Down
29 changes: 29 additions & 0 deletions prismatic/vla/datasets/rlds/oxe/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,4 +641,33 @@ class ActionEncoding(IntEnum):
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
### LIBERO datasets (modified versions)
"libero_spatial_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
"libero_object_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
"libero_goal_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
"libero_10_no_noops": {
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
}
14 changes: 14 additions & 0 deletions prismatic/vla/datasets/rlds/oxe/mixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,5 +192,19 @@
"droid_wipe": [
("droid_wipe", 1.0),
],

# === LIBERO Datasets (Modified Versions) ===
"libero_spatial_no_noops": [
("libero_spatial_no_noops", 1.0),
],
"libero_object_no_noops": [
("libero_object_no_noops", 1.0),
],
"libero_goal_no_noops": [
("libero_goal_no_noops", 1.0),
],
"libero_10_no_noops": [
("libero_10_no_noops", 1.0),
],
}
# fmt: on
22 changes: 22 additions & 0 deletions prismatic/vla/datasets/rlds/oxe/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,23 @@ def tdroid_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory


def libero_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# gripper action is in -1 (open)...1 (close) --> clip to 0...1, flip --> +1 = open, 0 = close
gripper_action = trajectory["action"][:, -1:]
gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))

trajectory["action"] = tf.concat(
[
trajectory["action"][:, :6],
gripper_action,
],
axis=1,
)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -2:] # 2D gripper state
return trajectory


# === Registry ===
OXE_STANDARDIZATION_TRANSFORMS = {
"bridge_oxe": bridge_oxe_dataset_transform,
Expand Down Expand Up @@ -897,4 +914,9 @@ def tdroid_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"tdroid_cover_object_with_towel": tdroid_dataset_transform,
### DROID Finetuning datasets
"droid_wipe": droid_finetuning_transform,
### LIBERO datasets (modified versions)
"libero_spatial_no_noops": libero_dataset_transform,
"libero_object_no_noops": libero_dataset_transform,
"libero_goal_no_noops": libero_dataset_transform,
"libero_10_no_noops": libero_dataset_transform,
}

0 comments on commit 7a359de

Please sign in to comment.