[Feature] Video recording in SOTA examples (#2070)
vmoens authored Apr 23, 2024
1 parent df749a3 commit 6c2e141
Showing 56 changed files with 528 additions and 174 deletions.
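
Every example touched by this commit follows the same pattern: a new video: False entry in its config, a VideoRecorder transform inserted into the test/eval environment when that flag is set, and a small dump_video helper that flushes the recorder after evaluation. Below is a minimal, self-contained sketch of that pattern; the CSVLogger, Pendulum-v1 and every variable name are illustrative choices for this sketch only, not code from the commit.

# Minimal sketch of the video-recording pattern introduced here (assumptions:
# CSVLogger, Pendulum-v1 and the names below are illustrative, not from the diff).
import torch.nn as nn
from torchrl.envs import TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.record import VideoRecorder
from torchrl.record.loggers.csv import CSVLogger

logger = CSVLogger(exp_name="video_demo")
env = TransformedEnv(GymEnv("Pendulum-v1", from_pixels=True, pixels_only=False))
# Insert the recorder first so it sees the raw "pixels" entry of each step.
env = env.insert_transform(
    0, VideoRecorder(logger, tag="rendered/Pendulum-v1", in_keys=["pixels"])
)

env.rollout(100)  # frames are buffered by the recorder during the rollout


def dump_video(module: nn.Module):
    # env.apply() visits the env and every transform it holds; any VideoRecorder
    # found along the way writes its buffered frames through the logger backend.
    if isinstance(module, VideoRecorder):
        module.dump()


env.apply(dump_video)

In the examples themselves the flag is toggled with the usual Hydra override (logger.video=True), as the new config entries further down show.
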
32 changes: 0 additions & 32 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -36,24 +36,20 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/de
optim.pretrain_gradient_steps=55 \
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
logger.backend= \
env.backend=gymnasium \
env.name=HalfCheetah-v4
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/decision_transformer/online_dt.py \
optim.pretrain_gradient_steps=55 \
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
env.backend=gymnasium \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/iql_offline.py \
optim.gradient_steps=55 \
optim.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \
optim.gradient_steps=55 \
optim.device=cuda:0 \
logger.backend=

# ==================================================================================== #
@@ -86,8 +82,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
network.device=cuda:0 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
@@ -112,7 +106,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dq
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
buffer.batch_size=10 \
device=cuda:0 \
loss.num_updates=1 \
logger.backend= \
buffer.buffer_size=120
@@ -122,7 +115,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cq
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
replay_buffer.size=120 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/redq/redq.py \
@@ -131,7 +123,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
buffer.batch_size=10 \
optim.steps_per_batch=1 \
logger.record_video=True \
@@ -143,22 +134,18 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/sa
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
network.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/discrete_sac/discrete_sac.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
collector.device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
network.device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
@@ -185,9 +172,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td
collector.frames_per_batch=16 \
collector.num_workers=4 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
collector.device=cuda:0 \
network.device=cuda:0 \
logger.mode=offline \
env.name=Pendulum-v1 \
logger.backend=
@@ -196,26 +180,20 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
optim.batch_size=10 \
collector.frames_per_batch=16 \
env.train_num_envs=2 \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.mode=offline \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/discrete_iql.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
env.train_num_envs=2 \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.mode=offline \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
env.train_num_envs=2 \
collector.device=cuda:0 \
optim.device=cuda:0 \
logger.mode=offline \
logger.backend=

@@ -238,8 +216,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
collector.device=cuda:0 \
network.device=cuda:0 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
@@ -251,7 +227,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dq
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
buffer.batch_size=10 \
device=cuda:0 \
loss.num_updates=1 \
logger.backend= \
buffer.buffer_size=120
@@ -262,7 +237,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
buffer.batch_size=10 \
collector.device=cuda:0 \
optim.steps_per_batch=1 \
logger.record_video=True \
logger.record_frames=4 \
@@ -274,29 +248,23 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
collector.frames_per_batch=16 \
env.train_num_envs=1 \
logger.mode=offline \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
logger.mode=offline \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.num_workers=2 \
collector.env_per_collector=1 \
collector.device=cuda:0 \
logger.mode=offline \
optim.batch_size=10 \
env.name=Pendulum-v1 \
network.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/mappo_ippo.py \
collector.n_iters=2 \
9 changes: 9 additions & 0 deletions sota-implementations/a2c/a2c_atari.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import hydra
from torchrl._utils import logger as torchrl_logger
from torchrl.record import VideoRecorder


@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
@@ -104,6 +105,14 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create test environment
test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True)
test_env.set_seed(0)
if cfg.logger.video:
test_env = test_env.insert_transform(
0,
VideoRecorder(
logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"]
),
)
test_env.eval()

# Main loop
18 changes: 14 additions & 4 deletions sota-implementations/a2c/a2c_mujoco.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import hydra
from torchrl._utils import logger as torchrl_logger
from torchrl.record import VideoRecorder


@hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
@@ -89,7 +90,15 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create test environment
test_env = make_env(cfg.env.env_name, device)
test_env = make_env(cfg.env.env_name, device, from_pixels=cfg.logger.video)
test_env.set_seed(0)
if cfg.logger.video:
test_env = test_env.insert_transform(
0,
VideoRecorder(
logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"]
),
)
test_env.eval()

# Main loop
@@ -178,9 +187,10 @@ def main(cfg: "DictConfig"): # noqa: F821

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
if ((i - 1) * frames_in_batch) // cfg.logger.test_interval < (
i * frames_in_batch
) // cfg.logger.test_interval:
prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval
cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval
final = collected_frames >= collector.total_frames
if prev_test_frame < cur_test_frame or final:
actor.eval()
eval_start = time.time()
test_rewards = eval_model(
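
The rewritten evaluation condition above is equivalent to the old one, with one addition: the final term forces a last evaluation (and therefore a last video dump) on the closing batch even if that batch does not cross a test_interval boundary. A toy illustration with made-up numbers:

# Toy illustration of the reworked evaluation schedule in a2c_mujoco.py
# (all numbers are made up; the real values come from the Hydra config).
frames_in_batch = 4_096
test_interval = 1_000_000
total_frames = 1_500_000
num_batches = -(-total_frames // frames_in_batch)  # ceil division -> 367 batches

eval_iters = []
collected_frames = 0
for i in range(1, num_batches + 1):
    collected_frames += frames_in_batch
    prev_test_frame = ((i - 1) * frames_in_batch) // test_interval
    cur_test_frame = (i * frames_in_batch) // test_interval
    final = collected_frames >= total_frames
    if prev_test_frame < cur_test_frame or final:
        eval_iters.append(i)

print(eval_iters)
# [245, 367]: one evaluation when the 1M-frame mark is crossed, plus the extra
# evaluation on the last batch that the new "final" check guarantees.
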
1 change: 1 addition & 0 deletions sota-implementations/a2c/config_atari.yaml
@@ -16,6 +16,7 @@ logger:
exp_name: Atari_Schulman17
test_interval: 40_000_000
num_test_episodes: 3
video: False

# Optim
optim:
1 change: 1 addition & 0 deletions sota-implementations/a2c/config_mujoco.yaml
@@ -15,6 +15,7 @@ logger:
exp_name: Mujoco_Schulman17
test_interval: 1_000_000
num_test_episodes: 5
video: False

# Optim
optim:
8 changes: 8 additions & 0 deletions sota-implementations/a2c/utils_atari.py
@@ -36,6 +36,8 @@
TanhNormal,
ValueOperator,
)
from torchrl.record import VideoRecorder


# ====================================================================
# Environment utils
@@ -201,6 +203,11 @@ def make_ppo_models(env_name):
# --------------------------------------------------------------------


def dump_video(module):
if isinstance(module, VideoRecorder):
module.dump()


def eval_model(actor, test_env, num_episodes=3):
test_rewards = []
for _ in range(num_episodes):
@@ -213,5 +220,6 @@ def eval_model(actor, test_env, num_episodes=3):
)
reward = td_test["next", "episode_reward"][td_test["next", "done"]]
test_rewards = np.append(test_rewards, reward.cpu().numpy())
test_env.apply(dump_video)
del td_test
return test_rewards.mean()
16 changes: 14 additions & 2 deletions sota-implementations/a2c/utils_mujoco.py
@@ -20,14 +20,20 @@
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.record import VideoRecorder


# ====================================================================
# Environment utils
# --------------------------------------------------------------------


def make_env(env_name="HalfCheetah-v4", device="cpu"):
env = GymEnv(env_name, device=device)
def make_env(
env_name="HalfCheetah-v4", device="cpu", from_pixels=False, pixels_only=False
):
env = GymEnv(
env_name, device=device, from_pixels=from_pixels, pixels_only=pixels_only
)
env = TransformedEnv(env)
env.append_transform(RewardSum())
env.append_transform(StepCounter())
@@ -125,6 +131,11 @@ def make_ppo_models(env_name):
# --------------------------------------------------------------------


def dump_video(module):
if isinstance(module, VideoRecorder):
module.dump()


def eval_model(actor, test_env, num_episodes=3):
test_rewards = []
for _ in range(num_episodes):
@@ -137,5 +148,6 @@ def eval_model(actor, test_env, num_episodes=3):
)
reward = td_test["next", "episode_reward"][td_test["next", "done"]]
test_rewards = np.append(test_rewards, reward.cpu().numpy())
test_env.apply(dump_video)
del td_test
return test_rewards.mean()
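
make_env now accepts from_pixels and pixels_only so the same Mujoco environment can render frames for the recorder while still exposing the state vector to the policy. A short sketch of what that combination yields (the key names are the usual GymEnv defaults, not output taken from this diff, and working Mujoco rendering is assumed):

# Sketch: from_pixels=True with pixels_only=False returns both the rendered
# frame and the low-dimensional state (assumes Mujoco rendering is available).
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("HalfCheetah-v4", from_pixels=True, pixels_only=False)
td = env.reset()
print(sorted(td.keys()))
# Expected to include "pixels" (what VideoRecorder reads via in_keys=["pixels"])
# alongside "observation" (what the actor and value networks keep consuming).
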
15 changes: 13 additions & 2 deletions sota-implementations/cql/cql_offline.py
@@ -20,6 +20,7 @@
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
dump_video,
log_metrics,
make_continuous_cql_optimizer,
make_continuous_loss,
@@ -49,16 +50,25 @@ def main(cfg: "DictConfig"): # noqa: F821
# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)
device = torch.device(cfg.optim.device)
device = cfg.optim.device
if device in ("", None):
if torch.cuda.is_available():
device = "cuda:0"
else:
device = "cpu"
device = torch.device(device)

# Create env
train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs)
train_env, eval_env = make_environment(
cfg, train_num_envs=1, eval_num_envs=cfg.logger.eval_envs, logger=logger
)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

# Create agent
model = make_cql_model(cfg, train_env, eval_env, device)
del train_env

# Create loss
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
@@ -144,6 +154,7 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward
