Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Video recording in SOTA examples #2070

Merged
merged 13 commits into from
Apr 23, 2024
Prev Previous commit
Next Next commit
ppo
  • Loading branch information
vmoens committed Apr 22, 2024
commit a68cad63a51cabf1cd469bf38f95c4af0a521af6
1 change: 1 addition & 0 deletions sota-implementations/ppo/config_atari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ logger:
exp_name: Atari_Schulman17
test_interval: 40_000_000
num_test_episodes: 3
video: False

# Optim
optim:
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/ppo/config_mujoco.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ logger:
exp_name: Mujoco_Schulman17
test_interval: 1_000_000
num_test_episodes: 5
video: False

# Optim
optim:
Expand Down
8 changes: 8 additions & 0 deletions sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""
import hydra
from torchrl._utils import logger as torchrl_logger
from torchrl.record import VideoRecorder


@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
Expand Down Expand Up @@ -104,9 +105,16 @@ def main(cfg: "DictConfig"): # noqa: F821
"group": cfg.logger.group_name,
},
)
logger_video = cfg.logger.video
else:
logger_video = False

# Create test environment
test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True)
if logger_video:
test_env = test_env.append_transform(
VideoRecorder(logger, tag="rendering/test", in_keys=["pixels_int"])
)
test_env.eval()

# Main loop
Expand Down
10 changes: 9 additions & 1 deletion sota-implementations/ppo/ppo_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""
import hydra
from torchrl._utils import logger as torchrl_logger
from torchrl.record import VideoRecorder


@hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
Expand Down Expand Up @@ -96,9 +97,16 @@ def main(cfg: "DictConfig"): # noqa: F821
"group": cfg.logger.group_name,
},
)
logger_video = cfg.logger.video
else:
logger_video = False

# Create test environment
test_env = make_env(cfg.env.env_name, device)
test_env = make_env(cfg.env.env_name, device, from_pixels=logger_video)
if logger_video:
test_env = test_env.append_transform(
VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
)
test_env.eval()

# Main loop
Expand Down
12 changes: 11 additions & 1 deletion sota-implementations/ppo/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
GymEnv,
NoopResetEnv,
ParallelEnv,
RenameTransform,
Resize,
RewardSum,
SignTransform,
Expand All @@ -35,6 +36,8 @@
TanhNormal,
ValueOperator,
)
from torchrl.record import VideoRecorder


# ====================================================================
# Environment utils
Expand Down Expand Up @@ -64,7 +67,8 @@ def make_parallel_env(env_name, num_envs, device, is_test=False):
device=device,
)
env = TransformedEnv(env)
env.append_transform(ToTensorImage())
env.append_transform(RenameTransform(in_keys=["pixels"], out_keys=["pixels_int"]))
env.append_transform(ToTensorImage(in_keys=["pixels_int"], out_keys=["pixels"]))
env.append_transform(GrayScale())
env.append_transform(Resize(84, 84))
env.append_transform(CatFrames(N=4, dim=-3))
Expand Down Expand Up @@ -198,6 +202,11 @@ def make_ppo_models(env_name):
# --------------------------------------------------------------------


def dump_video(module):
if isinstance(module, VideoRecorder):
dump_video.dump()


def eval_model(actor, test_env, num_episodes=3):
test_rewards = []
for _ in range(num_episodes):
Expand All @@ -208,6 +217,7 @@ def eval_model(actor, test_env, num_episodes=3):
break_when_any_done=True,
max_steps=10_000_000,
)
test_env.apply(dump_video)
reward = td_test["next", "episode_reward"][td_test["next", "done"]]
test_rewards.append(reward.cpu())
del td_test
Expand Down
12 changes: 10 additions & 2 deletions sota-implementations/ppo/utils_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.record import VideoRecorder


# ====================================================================
# Environment utils
# --------------------------------------------------------------------


def make_env(env_name="HalfCheetah-v4", device="cpu"):
env = GymEnv(env_name, device=device)
def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False):
env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False)
env = TransformedEnv(env)
env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2))
env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
Expand Down Expand Up @@ -126,6 +128,11 @@ def make_ppo_models(env_name):
# --------------------------------------------------------------------


def dump_video(module):
if isinstance(module, VideoRecorder):
dump_video.dump()


def eval_model(actor, test_env, num_episodes=3):
test_rewards = []
for _ in range(num_episodes):
Expand All @@ -138,5 +145,6 @@ def eval_model(actor, test_env, num_episodes=3):
)
reward = td_test["next", "episode_reward"][td_test["next", "done"]]
test_rewards.append(reward.cpu())
test_env.apply(dump_video)
del td_test
return torch.cat(test_rewards, 0).mean()
Loading