[Feature] Video recording in SOTA examples #2070

Merged: 13 commits, Apr 23, 2024
32 changes: 0 additions & 32 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -36,24 +36,20 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/de
optim.pretrain_gradient_steps=55 \
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
logger.backend= \
env.backend=gymnasium \
env.name=HalfCheetah-v4
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/decision_transformer/online_dt.py \
optim.pretrain_gradient_steps=55 \
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
env.backend=gymnasium \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/iql_offline.py \
optim.gradient_steps=55 \
optim.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \
optim.gradient_steps=55 \
optim.device=cuda:0 \
logger.backend=

# ==================================================================================== #
@@ -86,8 +82,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
network.device=cuda:0 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
@@ -112,7 +106,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dq
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
buffer.batch_size=10 \
device=cuda:0 \
loss.num_updates=1 \
logger.backend= \
buffer.buffer_size=120
@@ -122,7 +115,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cq
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
replay_buffer.size=120 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/redq/redq.py \
@@ -131,7 +123,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
buffer.batch_size=10 \
optim.steps_per_batch=1 \
logger.record_video=True \
@@ -143,22 +134,18 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/sa
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
network.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/discrete_sac/discrete_sac.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
collector.device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
network.device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
@@ -185,9 +172,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td
collector.frames_per_batch=16 \
collector.num_workers=4 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
collector.device=cuda:0 \
network.device=cuda:0 \
logger.mode=offline \
env.name=Pendulum-v1 \
logger.backend=
@@ -196,26 +180,20 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
optim.batch_size=10 \
collector.frames_per_batch=16 \
env.train_num_envs=2 \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.mode=offline \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/discrete_iql.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
env.train_num_envs=2 \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.mode=offline \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
env.train_num_envs=2 \
collector.device=cuda:0 \
optim.device=cuda:0 \
logger.mode=offline \
logger.backend=

@@ -238,8 +216,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
collector.device=cuda:0 \
network.device=cuda:0 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
@@ -251,7 +227,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dq
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
buffer.batch_size=10 \
device=cuda:0 \
loss.num_updates=1 \
logger.backend= \
buffer.buffer_size=120
@@ -262,7 +237,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
buffer.batch_size=10 \
collector.device=cuda:0 \
optim.steps_per_batch=1 \
logger.record_video=True \
logger.record_frames=4 \
@@ -274,29 +248,23 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
collector.frames_per_batch=16 \
env.train_num_envs=1 \
logger.mode=offline \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
logger.mode=offline \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.num_workers=2 \
collector.env_per_collector=1 \
collector.device=cuda:0 \
logger.mode=offline \
optim.batch_size=10 \
env.name=Pendulum-v1 \
network.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/mappo_ippo.py \
collector.n_iters=2 \
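Note on the `device=cuda:0` / `collector.device=cuda:0` / `network.device=cuda:0` overrides removed above: the CI script no longer pins every example run to the GPU. Instead, the training scripts pick a device themselves, as the `cql_offline.py` hunk further down in this diff shows. A minimal sketch of that fallback (the helper name is illustrative; `cql_offline.py` inlines the check rather than calling a shared function):

```python
import torch

def resolve_device(device):
    # Mirrors the fallback added in cql_offline.py: use the configured
    # device if one is given, otherwise prefer CUDA and fall back to CPU.
    if device in ("", None):
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    return torch.device(device)
```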
9 changes: 9 additions & 0 deletions sota-implementations/a2c/a2c_atari.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import hydra
from torchrl._utils import logger as torchrl_logger
from torchrl.record import VideoRecorder


@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
@@ -104,6 +105,14 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create test environment
test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True)
test_env.set_seed(0)
if cfg.logger.video:
test_env = test_env.insert_transform(
0,
VideoRecorder(
logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"]
),
)
test_env.eval()

# Main loop
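For context, here is a self-contained sketch of the recording pattern the hunk above wires into the A2C test env. Everything in it is illustrative rather than taken from the script: the script reuses its experiment logger and env factory, whereas this sketch builds a `CSVLogger` and a CartPole env purely to show the transform in isolation (writing mp4 assumes a video backend such as torchvision is available, and `from_pixels=True` needs a working Gym renderer):

```python
import torch
from torchrl.envs import GymEnv, TransformedEnv
from torchrl.record import VideoRecorder
from torchrl.record.loggers.csv import CSVLogger

# A file-based logger; the A2C script uses whatever cfg.logger selects.
logger = CSVLogger(exp_name="eval_videos", log_dir="./logs", video_format="mp4")

# from_pixels=True adds a "pixels" entry for the recorder to consume.
env = TransformedEnv(GymEnv("CartPole-v1", from_pixels=True, pixels_only=False))
recorder = VideoRecorder(logger, tag="rendered/CartPole-v1", in_keys=["pixels"])
env.append_transform(recorder)

with torch.no_grad():
    env.rollout(max_steps=100)  # frames are buffered by the transform
recorder.dump()                 # flush the buffered frames to the logger
```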
18 changes: 14 additions & 4 deletions sota-implementations/a2c/a2c_mujoco.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import hydra
from torchrl._utils import logger as torchrl_logger
from torchrl.record import VideoRecorder


@hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
@@ -89,7 +90,15 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create test environment
test_env = make_env(cfg.env.env_name, device)
test_env = make_env(cfg.env.env_name, device, from_pixels=cfg.logger.video)
test_env.set_seed(0)
if cfg.logger.video:
test_env = test_env.insert_transform(
0,
VideoRecorder(
logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"]
),
)
test_env.eval()

# Main loop
@@ -178,9 +187,10 @@ def main(cfg: "DictConfig"): # noqa: F821

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
if ((i - 1) * frames_in_batch) // cfg.logger.test_interval < (
i * frames_in_batch
) // cfg.logger.test_interval:
prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval
cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval
final = collected_frames >= collector.total_frames
if prev_test_frame < cur_test_frame or final:
actor.eval()
eval_start = time.time()
test_rewards = eval_model(
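The rewritten condition above triggers an evaluation whenever a `test_interval` boundary is crossed between consecutive batches, and additionally on the final batch. A small sketch with illustrative numbers makes the behaviour concrete:

```python
# Illustrative values only; the real ones come from the Hydra config.
frames_in_batch, test_interval, total_frames = 16, 32, 48

collected_frames = 0
for i in range(1, total_frames // frames_in_batch + 1):
    collected_frames += frames_in_batch
    prev_test_frame = ((i - 1) * frames_in_batch) // test_interval
    cur_test_frame = (i * frames_in_batch) // test_interval
    final = collected_frames >= total_frames
    if prev_test_frame < cur_test_frame or final:
        print(f"evaluate after batch {i}")  # fires for batches 2 and 3
```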
1 change: 1 addition & 0 deletions sota-implementations/a2c/config_atari.yaml
@@ -16,6 +16,7 @@ logger:
exp_name: Atari_Schulman17
test_interval: 40_000_000
num_test_episodes: 3
video: False

# Optim
optim:
1 change: 1 addition & 0 deletions sota-implementations/a2c/config_mujoco.yaml
@@ -15,6 +15,7 @@ logger:
exp_name: Mujoco_Schulman17
test_interval: 1_000_000
num_test_episodes: 5
video: False

# Optim
optim:
8 changes: 8 additions & 0 deletions sota-implementations/a2c/utils_atari.py
@@ -36,6 +36,8 @@
TanhNormal,
ValueOperator,
)
from torchrl.record import VideoRecorder


# ====================================================================
# Environment utils
@@ -201,6 +203,11 @@ def make_ppo_models(env_name):
# --------------------------------------------------------------------


def dump_video(module):
if isinstance(module, VideoRecorder):
module.dump()


def eval_model(actor, test_env, num_episodes=3):
test_rewards = []
for _ in range(num_episodes):
@@ -213,5 +220,6 @@ def eval_model(actor, test_env, num_episodes=3):
)
reward = td_test["next", "episode_reward"][td_test["next", "done"]]
test_rewards = np.append(test_rewards, reward.cpu().numpy())
test_env.apply(dump_video)
del td_test
return test_rewards.mean()
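`eval_model` now ends with `test_env.apply(dump_video)`. Because a `TransformedEnv` (like every TorchRL env and transform) is an `nn.Module`, `apply` visits each submodule, so any `VideoRecorder` sitting in the transform stack gets flushed, while envs without a recorder are unaffected. A rough equivalent, as a sketch:

```python
# Sketch of what test_env.apply(dump_video) amounts to: walk the module
# tree and flush any VideoRecorder found in it.
for module in test_env.modules():
    if isinstance(module, VideoRecorder):
        module.dump()
```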
16 changes: 14 additions & 2 deletions sota-implementations/a2c/utils_mujoco.py
@@ -20,14 +20,20 @@
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.record import VideoRecorder


# ====================================================================
# Environment utils
# --------------------------------------------------------------------


def make_env(env_name="HalfCheetah-v4", device="cpu"):
env = GymEnv(env_name, device=device)
def make_env(
env_name="HalfCheetah-v4", device="cpu", from_pixels=False, pixels_only=False
):
env = GymEnv(
env_name, device=device, from_pixels=from_pixels, pixels_only=pixels_only
)
env = TransformedEnv(env)
env.append_transform(RewardSum())
env.append_transform(StepCounter())
@@ -125,6 +131,11 @@ def make_ppo_models(env_name):
# --------------------------------------------------------------------


def dump_video(module):
if isinstance(module, VideoRecorder):
module.dump()


def eval_model(actor, test_env, num_episodes=3):
test_rewards = []
for _ in range(num_episodes):
@@ -137,5 +148,6 @@ def eval_model(actor, test_env, num_episodes=3):
)
reward = td_test["next", "episode_reward"][td_test["next", "done"]]
test_rewards = np.append(test_rewards, reward.cpu().numpy())
test_env.apply(dump_video)
del td_test
return test_rewards.mean()
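With `from_pixels=True` and `pixels_only=False`, the Gym env returns the rendered frame alongside the usual state vector, so the actor keeps its low-dimensional input while the `VideoRecorder` reads the `"pixels"` key. A quick sketch (assumes MuJoCo and an offscreen renderer are installed; shapes are indicative):

```python
# Uses the make_env defined above; rendering support is assumed.
env = make_env("HalfCheetah-v4", "cpu", from_pixels=True, pixels_only=False)
td = env.reset()
print(td["pixels"].shape)       # rendered RGB frame, e.g. [H, W, 3]
print(td["observation"].shape)  # state vector consumed by the policy
```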
15 changes: 13 additions & 2 deletions sota-implementations/cql/cql_offline.py
@@ -20,6 +20,7 @@
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
dump_video,
log_metrics,
make_continuous_cql_optimizer,
make_continuous_loss,
@@ -49,16 +50,25 @@ def main(cfg: "DictConfig"): # noqa: F821
# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)
device = torch.device(cfg.optim.device)
device = cfg.optim.device
if device in ("", None):
if torch.cuda.is_available():
device = "cuda:0"
else:
device = "cpu"
device = torch.device(device)

# Create env
train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs)
train_env, eval_env = make_environment(
cfg, train_num_envs=1, eval_num_envs=cfg.logger.eval_envs, logger=logger
)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

# Create agent
model = make_cql_model(cfg, train_env, eval_env, device)
del train_env

# Create loss
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
@@ -144,6 +154,7 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward
