[Feature] Video recording in SOTA examples #2070

Merged (13 commits, Apr 23, 2024)
init
vmoens committed Apr 9, 2024
commit 4250d4518a82909ba77c89f66da2f1f72c706793
2 changes: 1 addition & 1 deletion README.md
@@ -315,7 +315,7 @@ And it is `functorch` and `torch.compile` compatible!
```
</details>

Check our [distributed collector examples](sota-implementations/distributed/collectors) to
Check our [distributed collector examples](examples/distributed/collectors) to
learn more about ultra-fast data collection with TorchRL.

- efficient<sup>(2)</sup> and generic<sup>(1)</sup> [replay buffers](torchrl/data/replay_buffers/replay_buffers.py) with modularized storage:
9 changes: 9 additions & 0 deletions sota-implementations/a2c/a2c_atari.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import hydra
from torchrl._utils import logger as torchrl_logger
from torchrl.record import VideoRecorder


@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
@@ -104,6 +105,14 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create test environment
test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True)
test_env.set_seed(0)
if cfg.logger.video:
test_env = test_env.insert_transform(
0,
VideoRecorder(
logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"]
),
)
test_env.eval()

# Main loop
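The same pattern recurs across the scripts touched by this PR: when `logger.video` is enabled, a `VideoRecorder` transform is inserted at position 0 of the test environment so it sees the raw `pixels` entry, and its buffer is flushed after evaluation. A minimal standalone sketch of that pattern follows; the CSV logger backend and the CartPole environment are assumptions made for the example, not part of the diff.

```python
# Minimal sketch of the recording pattern introduced by this PR.
# The logger backend and env name below are illustrative assumptions.
from torchrl.envs import TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.record import VideoRecorder
from torchrl.record.loggers import get_logger

logger = get_logger("csv", logger_name="./logs", experiment_name="video_demo")

# The env must render pixels for the recorder to have anything to capture.
test_env = TransformedEnv(GymEnv("CartPole-v1", from_pixels=True, pixels_only=False))
# Insert the recorder first so it receives the untouched "pixels" entry.
test_env.insert_transform(
    0, VideoRecorder(logger, tag="rendered/CartPole-v1", in_keys=["pixels"])
)

test_env.rollout(100)          # frames accumulate inside the transform
test_env.transform[0].dump()   # flush the accumulated frames to the logger
```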
18 changes: 14 additions & 4 deletions sota-implementations/a2c/a2c_mujoco.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import hydra
from torchrl._utils import logger as torchrl_logger
from torchrl.record import VideoRecorder


@hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
@@ -89,7 +90,15 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create test environment
test_env = make_env(cfg.env.env_name, device)
test_env = make_env(cfg.env.env_name, device, from_pixels=cfg.logger.video)
test_env.set_seed(0)
if cfg.logger.video:
test_env = test_env.insert_transform(
0,
VideoRecorder(
logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"]
),
)
test_env.eval()

# Main loop
@@ -178,9 +187,10 @@ def main(cfg: "DictConfig"): # noqa: F821

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
if ((i - 1) * frames_in_batch) // cfg.logger.test_interval < (
i * frames_in_batch
) // cfg.logger.test_interval:
prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval
cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval
final = collected_frames >= collector.total_frames
if prev_test_frame < cur_test_frame or final:
actor.eval()
eval_start = time.time()
test_rewards = eval_model(
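The refactored test condition names the two crossing counters and additionally forces an evaluation on the very last batch. A small self-contained check of that arithmetic; the numbers below are made up for illustration and do not come from the configs in this PR.

```python
# Standalone illustration of the interval-crossing evaluation trigger.
frames_in_batch = 64
test_interval = 1_000
total_frames = 10_048

for i in range(1, total_frames // frames_in_batch + 1):
    collected_frames = i * frames_in_batch
    prev_test_frame = ((i - 1) * frames_in_batch) // test_interval
    cur_test_frame = (i * frames_in_batch) // test_interval
    final = collected_frames >= total_frames
    if prev_test_frame < cur_test_frame or final:
        # Fires once each time a multiple of test_interval is crossed,
        # and once more on the very last batch.
        print(f"evaluate at iteration {i} ({collected_frames} frames)")
```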
1 change: 1 addition & 0 deletions sota-implementations/a2c/config_atari.yaml
@@ -16,6 +16,7 @@ logger:
exp_name: Atari_Schulman17
test_interval: 40_000_000
num_test_episodes: 3
video: False

# Optim
optim:
1 change: 1 addition & 0 deletions sota-implementations/a2c/config_mujoco.yaml
@@ -15,6 +15,7 @@ logger:
exp_name: Mujoco_Schulman17
test_interval: 1_000_000
num_test_episodes: 5
video: False

# Optim
optim:
8 changes: 8 additions & 0 deletions sota-implementations/a2c/utils_atari.py
@@ -36,6 +36,8 @@
TanhNormal,
ValueOperator,
)
from torchrl.record import VideoRecorder


# ====================================================================
# Environment utils
@@ -201,6 +203,11 @@ def make_ppo_models(env_name):
# --------------------------------------------------------------------


def dump_video(module):
if isinstance(module, VideoRecorder):
module.dump()


def eval_model(actor, test_env, num_episodes=3):
test_rewards = []
for _ in range(num_episodes):
@@ -213,5 +220,6 @@ def eval_model(actor, test_env, num_episodes=3):
)
reward = td_test["next", "episode_reward"][td_test["next", "done"]]
test_rewards = np.append(test_rewards, reward.cpu().numpy())
test_env.apply(dump_video)
del td_test
return test_rewards.mean()
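`test_env.apply(dump_video)` works because TorchRL environments are `nn.Module`s: `apply` recursively visits every registered sub-module, including each transform of a `TransformedEnv`, so any `VideoRecorder` in the stack gets its buffer flushed. A small sketch of the helper; the commented usage line assumes a `test_env` built with a `VideoRecorder`, as in the scripts above.

```python
# Sketch: dump_video is handed every sub-module by nn.Module.apply();
# only VideoRecorder instances react to it.
from torchrl.record import VideoRecorder


def dump_video(module):
    if isinstance(module, VideoRecorder):
        module.dump()


# Given a test_env carrying a VideoRecorder transform (see above):
# test_env.apply(dump_video)
```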
16 changes: 14 additions & 2 deletions sota-implementations/a2c/utils_mujoco.py
@@ -20,14 +20,20 @@
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.record import VideoRecorder


# ====================================================================
# Environment utils
# --------------------------------------------------------------------


def make_env(env_name="HalfCheetah-v4", device="cpu"):
env = GymEnv(env_name, device=device)
def make_env(
env_name="HalfCheetah-v4", device="cpu", from_pixels=False, pixels_only=False
):
env = GymEnv(
env_name, device=device, from_pixels=from_pixels, pixels_only=pixels_only
)
env = TransformedEnv(env)
env.append_transform(RewardSum())
env.append_transform(StepCounter())
@@ -125,6 +131,11 @@ def make_ppo_models(env_name):
# --------------------------------------------------------------------


def dump_video(module):
if isinstance(module, VideoRecorder):
module.dump()


def eval_model(actor, test_env, num_episodes=3):
test_rewards = []
for _ in range(num_episodes):
@@ -137,5 +148,6 @@ def eval_model(actor, test_env, num_episodes=3):
)
reward = td_test["next", "episode_reward"][td_test["next", "done"]]
test_rewards = np.append(test_rewards, reward.cpu().numpy())
test_env.apply(dump_video)
del td_test
return test_rewards.mean()
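The new `from_pixels`/`pixels_only` arguments let the same helper serve both training (state observations only, no rendering cost) and video evaluation (states plus rendered frames). A quick sketch of what the two flags change in the rollout output; the environment name is the helper's default and the check itself is illustrative.

```python
# Illustrative check of the from_pixels / pixels_only flags on a GymEnv.
from torchrl.envs.libs.gym import GymEnv

# Default: state observations only, nothing is rendered.
env = GymEnv("HalfCheetah-v4")
print("pixels" in env.rollout(2).keys())  # False

# from_pixels=True with pixels_only=False keeps the state observation
# and adds a "pixels" entry that a VideoRecorder can consume.
env = GymEnv("HalfCheetah-v4", from_pixels=True, pixels_only=False)
rollout = env.rollout(2)
print("pixels" in rollout.keys(), "observation" in rollout.keys())  # True True
```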
15 changes: 13 additions & 2 deletions sota-implementations/cql/cql_offline.py
@@ -20,6 +20,7 @@
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
dump_video,
log_metrics,
make_continuous_cql_optimizer,
make_continuous_loss,
@@ -49,16 +50,25 @@ def main(cfg: "DictConfig"): # noqa: F821
# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)
device = torch.device(cfg.optim.device)
device = cfg.optim.device
if device in ("", None):
if torch.cuda.is_available():
device = "cuda:0"
else:
device = "cpu"
device = torch.device(device)

# Create env
train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs)
train_env, eval_env = make_environment(
cfg, train_num_envs=1, eval_num_envs=cfg.logger.eval_envs, logger=logger
)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

# Create agent
model = make_cql_model(cfg, train_env, eval_env, device)
del train_env

# Create loss
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
@@ -144,6 +154,7 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward

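The config now ships `device: null` and lets the script pick a sensible default at runtime. The selection logic, pulled out here as a tiny helper for clarity; the helper name is illustrative, the scripts inline this logic.

```python
# Sketch of the device-resolution logic added to the CQL scripts.
from typing import Optional

import torch


def resolve_device(device: Optional[str]) -> torch.device:
    # An empty or null config entry falls back to CUDA when available.
    if device in ("", None):
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    return torch.device(device)


print(resolve_device(None))   # cuda:0 on a GPU machine, otherwise cpu
print(resolve_device("cpu"))  # cpu
```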
23 changes: 18 additions & 5 deletions sota-implementations/cql/cql_online.py
@@ -23,6 +23,7 @@
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
dump_video,
log_metrics,
make_collector,
make_continuous_cql_optimizer,
@@ -54,13 +55,20 @@ def main(cfg: "DictConfig"): # noqa: F821
# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)
device = torch.device(cfg.optim.device)
device = cfg.optim.device
if device in ("", None):
if torch.cuda.is_available():
device = "cuda:0"
else:
device = "cpu"
device = torch.device(device)

# Create env
train_env, eval_env = make_environment(
cfg,
cfg.env.train_num_envs,
cfg.env.eval_num_envs,
logger=logger,
)

# Create replay buffer
@@ -99,12 +107,12 @@ def main(cfg: "DictConfig"): # noqa: F821
* cfg.optim.utd_ratio
)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
eval_rollout_steps = cfg.collector.max_frames_per_traj
evaluation_interval = cfg.logger.log_interval
eval_rollout_steps = cfg.logger.eval_steps

sampling_start = time.time()
for tensordict in collector:
for i, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
pbar.update(tensordict.numel())
# update weights of the inference policy
@@ -191,7 +199,11 @@ def main(cfg: "DictConfig"): # noqa: F821
metrics_to_log["train/training_time"] = training_time

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:

prev_test_frame = ((i - 1) * frames_per_batch) // evaluation_interval
cur_test_frame = (i * frames_per_batch) // evaluation_interval
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
@@ -202,6 +214,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
eval_env.apply(dump_video)
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time

4 changes: 3 additions & 1 deletion sota-implementations/cql/offline_config.yaml
@@ -13,10 +13,12 @@ logger:
project_name: torchrl_example_cql
group_name: null
exp_name: cql_${replay_buffer.dataset}
# eval iter in gradient steps
eval_iter: 5000
eval_steps: 1000
mode: online
eval_envs: 5
video: False

# replay buffer
replay_buffer:
@@ -25,7 +27,7 @@ replay_buffer:

# optimization
optim:
device: cuda:0
device: null
actor_lr: 3e-4
critic_lr: 3e-4
weight_decay: 0.0
6 changes: 3 additions & 3 deletions sota-implementations/cql/online_config.yaml
@@ -26,9 +26,9 @@ logger:
group_name: null
exp_name: cql_${env.name}
log_interval: 5000 # record interval in frames
eval_steps: 1000
mode: online
eval_iter: 1000
eval_steps: 1000
video: False

# Buffer
replay_buffer:
@@ -39,7 +39,7 @@ replay_buffer:
# Optimization
optim:
utd_ratio: 1
device: cuda:0
device: null
actor_lr: 3e-4
critic_lr: 3e-4
weight_decay: 0.0
30 changes: 23 additions & 7 deletions sota-implementations/cql/utils.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools

import torch.nn
import torch.optim
from tensordict.nn import TensorDictModule, TensorDictSequential
@@ -37,6 +39,7 @@
ValueOperator,
)
from torchrl.objectives import CQLLoss, DiscreteCQLLoss, SoftUpdate
from torchrl.record import VideoRecorder

from torchrl.trainers.helpers.models import ACTIVATIONS

@@ -45,16 +48,17 @@
# -----------------


def env_maker(cfg, device="cpu"):
def env_maker(cfg, device="cpu", from_pixels=False):
lib = cfg.env.backend
if lib in ("gym", "gymnasium"):
with set_gym_backend(lib):
return GymEnv(
cfg.env.name,
device=device,
cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False
)
elif lib == "dm_control":
env = DMControlEnv(cfg.env.name, cfg.env.task)
env = DMControlEnv(
cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
)
return TransformedEnv(
env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
)
@@ -75,25 +79,32 @@ def apply_env_transforms(
return transformed_env


def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None):
"""Make environments for training and evaluation."""
maker = functools.partial(env_maker, cfg)
parallel_env = ParallelEnv(
train_num_envs,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
EnvCreator(maker),
serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

train_env = apply_env_transforms(parallel_env)

maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video)
eval_env = TransformedEnv(
ParallelEnv(
eval_num_envs,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
EnvCreator(maker),
serial_for_single=True,
),
train_env.transform.clone(),
)
eval_env.set_seed(0)
if cfg.logger.video:
eval_env = eval_env.insert_transform(
0, VideoRecorder(logger=logger, tag="rendered", in_keys=["pixels"])
)
return train_env, eval_env


@@ -373,3 +384,8 @@ def log_metrics(logger, metrics, step):
if logger is not None:
for metric_name, metric_value in metrics.items():
logger.log_scalar(metric_name, metric_value, step)


def dump_video(module):
if isinstance(module, VideoRecorder):
module.dump()
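Putting the pieces of `make_environment` together: the env factory is wrapped once with `functools.partial` and reused through `EnvCreator`, the eval env clones the training transforms, renders pixels only when recording is requested, and gets the `VideoRecorder` inserted ahead of the cloned transforms. A condensed sketch of that flow; the environment name, logger backend and the single `RewardSum` transform are assumptions for illustration.

```python
# Condensed sketch of the train/eval environment construction in utils.py.
import functools

from torchrl.envs import EnvCreator, ParallelEnv, RewardSum, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.record import VideoRecorder
from torchrl.record.loggers import get_logger


def env_maker(env_name, from_pixels=False):
    return GymEnv(env_name, from_pixels=from_pixels, pixels_only=False)


logger = get_logger("csv", logger_name="./logs", experiment_name="cql_video_demo")
record_video = True

# Training env: no pixels, so collection pays no rendering cost.
train_env = TransformedEnv(
    ParallelEnv(
        1,
        EnvCreator(functools.partial(env_maker, "Pendulum-v1")),
        serial_for_single=True,
    ),
    RewardSum(),
)

# Eval env: same transforms (cloned), pixels only if we are recording.
eval_maker = functools.partial(env_maker, "Pendulum-v1", from_pixels=record_video)
eval_env = TransformedEnv(
    ParallelEnv(1, EnvCreator(eval_maker), serial_for_single=True),
    train_env.transform.clone(),
)
if record_video:
    # Recorder goes first so it sees the raw pixels before other transforms.
    eval_env.insert_transform(
        0, VideoRecorder(logger=logger, tag="rendered", in_keys=["pixels"])
    )
```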