[Feature] Video recording in SOTA examples #2070

Merged · 13 commits · Apr 23, 2024

Commit: sac
vmoens committed Apr 23, 2024
commit de42c426964bb606a7d713de1a149e681b73a9cf

sota-implementations/ppo/utils_atari.py (2 changes: 1 addition & 1 deletion)

@@ -204,7 +204,7 @@ def make_ppo_models(env_name):

 def dump_video(module):
     if isinstance(module, VideoRecorder):
-        dump_video.dump()
+        module.dump()


 def eval_model(actor, test_env, num_episodes=3):
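
Worth spelling out why this small fix matters: the old body called .dump() on the dump_video function object itself, which raises AttributeError the first time a VideoRecorder is encountered; the recorder that needs flushing is the module argument. Because torchrl envs and transforms are nn.Module subclasses, the helper can be swept over an environment with .apply(). A minimal self-contained sketch of the pattern (the CartPole env and CSV logger are illustrative choices, not part of this PR):

from torchrl.envs import TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.record import VideoRecorder
from torchrl.record.loggers.csv import CSVLogger


def dump_video(module):
    # Visited once per submodule by env.apply(); only recorders are flushed.
    if isinstance(module, VideoRecorder):
        module.dump()


logger = CSVLogger(exp_name="demo", log_dir="./logs", video_format="pt")
env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, pixels_only=False),
    VideoRecorder(logger=logger, tag="rendering/test", in_keys=["pixels"]),
)
env.rollout(10)
env.apply(dump_video)  # writes the buffered frames through the logger
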
sota-implementations/ppo/utils_mujoco.py (2 changes: 1 addition & 1 deletion)

@@ -130,7 +130,7 @@ def make_ppo_models(env_name):

 def dump_video(module):
     if isinstance(module, VideoRecorder):
-        dump_video.dump()
+        module.dump()


 def eval_model(actor, test_env, num_episodes=3):
sota-implementations/redq/config.yaml (2 changes: 1 addition & 1 deletion)

@@ -43,11 +43,11 @@ logger:
   project_name: torchrl_example_redq
   group_name: null
   exp_name: cheetah
-  record_video: 0
   record_interval: 10
   record_frames: 10000
   mode: online
   recorder_log_keys:
+  video: False

 optim:
   optimizer: adam
sota-implementations/redq/redq.py (32 changes: 17 additions & 15 deletions)

@@ -63,18 +63,20 @@ def main(cfg: "DictConfig"):  # noqa: F821
         ]
     )

-    logger = get_logger(
-        logger_type=cfg.logger.backend,
-        logger_name="redq_logging",
-        experiment_name=exp_name,
-        wandb_kwargs={
-            "mode": cfg.logger.mode,
-            "config": dict(cfg),
-            "project": cfg.logger.project_name,
-            "group": cfg.logger.group_name,
-        },
-    )
-    video_tag = exp_name if cfg.logger.record_video else ""
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="redq_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+    else:
+        logger = ""

     key, init_env_steps, stats = None, None, None
     if not cfg.env.vecnorm and cfg.env.norm_stats:

@@ -146,7 +148,7 @@ def main(cfg: "DictConfig"):  # noqa: F821

     recorder = transformed_env_constructor(
         cfg,
-        video_tag=video_tag,
+        video_tag="rendering/test",
         norm_obs_only=True,
         obs_norm_state_dict=obs_norm_state_dict,
         logger=logger,

@@ -162,8 +164,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
         recorder.transform = create_env_fn.transform.clone()
     else:
         raise NotImplementedError(f"Unsupported env type {type(create_env_fn)}")
-    if logger is not None and video_tag:
-        recorder.insert_transform(0, VideoRecorder(logger=logger, tag=video_tag))
+    if logger is not None and cfg.logger.video:
+        recorder.insert_transform(0, VideoRecorder(logger=logger, tag="rendering/test"))

     # reset reward scaling
     for t in recorder.transform:
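
The net effect in redq.py: logger.backend now gates logger construction entirely, and recording is driven by the boolean logger.video added to the config above instead of the old record_video/video_tag pair, with the tag fixed to "rendering/test" to match the other SOTA examples. A condensed sketch of the gating (backend value illustrative; note the commit uses an empty string rather than None as the no-logger sentinel):

from torchrl.envs import TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.record import VideoRecorder
from torchrl.record.loggers import get_logger

backend = "csv"      # e.g. "wandb", "tensorboard", "csv"; empty disables logging
record_video = True  # stands in for cfg.logger.video

logger = (
    get_logger(backend, logger_name="redq_logging", experiment_name="demo")
    if backend
    else None
)

# Stand-in for the recorder env built by transformed_env_constructor:
recorder = TransformedEnv(GymEnv("CartPole-v1", from_pixels=True, pixels_only=False))
if logger is not None and record_video:
    recorder.insert_transform(0, VideoRecorder(logger=logger, tag="rendering/test"))
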
sota-implementations/redq/utils.py (3 changes: 1 addition & 2 deletions)

@@ -282,7 +282,7 @@ def make_trainer(
     rb_trainer = ReplayBufferTrainer(
         replay_buffer,
         cfg.buffer.batch_size,
-        flatten_tensordicts=False,
+        flatten_tensordicts=True,
         memmap=False,
         device=device,
     )

@@ -1044,7 +1044,6 @@ def make_replay_buffer(
         storage=LazyMemmapStorage(
             cfg.buffer.size,
             scratch_dir=cfg.buffer.scratch_dir,
-            # device=device,  # when using prefetch, this can overload the GPU memory
         ),
         sampler=sampler,
         pin_memory=device != torch.device("cpu"),
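
On the flatten_tensordicts flip: collectors hand the trainer batches shaped [n_envs, frames_per_batch], and with flatten_tensordicts=True the ReplayBufferTrainer reshapes them into a flat batch of transitions before extending the buffer, so each row of the memmap storage is a single step. A rough sketch of the surrounding construction (sizes are illustrative):

import torch
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import RandomSampler
from torchrl.trainers import ReplayBufferTrainer

device = torch.device("cpu")
replay_buffer = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(100_000, scratch_dir=None),  # frames on disk, not GPU
    sampler=RandomSampler(),
    pin_memory=device != torch.device("cpu"),
)
rb_trainer = ReplayBufferTrainer(
    replay_buffer,
    batch_size=256,
    flatten_tensordicts=True,  # store [n_envs, time] batches as flat transitions
    memmap=False,
    device=device,
)
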
sota-implementations/sac/config.yaml (3 changes: 2 additions & 1 deletion)

@@ -40,7 +40,7 @@ network:
   activation: relu
   default_policy_scale: 1.0
   scale_lb: 0.1
-  device: "cuda:0"
+  device:

 # logging
 logger:

@@ -50,3 +50,4 @@ logger:
   exp_name: ${env.name}_SAC
   mode: online
   eval_iter: 25000
+  video: False
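
Both new values are ordinary Hydra fields, so they can be flipped from the command line rather than by editing the file, e.g. `python sac.py network.device=cpu logger.video=True`. Leaving `device:` empty delegates the choice to the runtime fallback added in sac.py below.
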
sota-implementations/sac/sac.py (12 changes: 10 additions & 2 deletions)

@@ -24,6 +24,7 @@

 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
+    dump_video,
     log_metrics,
     make_collector,
     make_environment,

@@ -36,7 +37,13 @@

 @hydra.main(version_base="1.1", config_path="", config_name="config")
 def main(cfg: "DictConfig"):  # noqa: F821
-    device = torch.device(cfg.network.device)
+    device = cfg.network.device
+    if device in ("", None):
+        if torch.cuda.is_available():
+            device = torch.device("cuda:0")
+        else:
+            device = torch.device("cpu")
+    device = torch.device(device)

     # Create logger
     exp_name = generate_exp_name("SAC", cfg.logger.exp_name)

@@ -58,7 +65,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     np.random.seed(cfg.env.seed)

     # Create environments
-    train_env, eval_env = make_environment(cfg)
+    train_env, eval_env = make_environment(cfg, logger=logger)

     # Create agent
     model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device)

@@ -198,6 +205,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 auto_cast_to_device=True,
                 break_when_any_done=True,
             )
+            eval_env.apply(dump_video)
             eval_time = time.time() - eval_start
             eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
             metrics_to_log["eval/reward"] = eval_reward
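
The empty-device fallback above is small enough to inline, but the same logic reads cleanly as a standalone helper; a sketch (the helper name is ours, not the repo's):

import torch


def resolve_device(spec=None):
    # An empty or missing spec prefers CUDA when available, else CPU.
    if spec in ("", None):
        return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    return torch.device(spec)


assert resolve_device("cpu").type == "cpu"
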
sota-implementations/sac/utils.py (35 changes: 25 additions & 10 deletions)

@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import functools

 import torch
 from tensordict.nn import InteractionType, TensorDictModule

@@ -26,23 +27,28 @@
 from torchrl.modules.distributions import TanhNormal
 from torchrl.objectives import SoftUpdate
 from torchrl.objectives.sac import SACLoss
+from torchrl.record import VideoRecorder


 # ====================================================================
 # Environment utils
 # -----------------


-def env_maker(cfg, device="cpu"):
+def env_maker(cfg, device="cpu", from_pixels=False):
     lib = cfg.env.library
     if lib in ("gym", "gymnasium"):
         with set_gym_backend(lib):
             return GymEnv(
                 cfg.env.name,
                 device=device,
+                from_pixels=from_pixels,
+                pixels_only=False,
             )
     elif lib == "dm_control":
-        env = DMControlEnv(cfg.env.name, cfg.env.task)
+        env = DMControlEnv(
+            cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
+        )
         return TransformedEnv(
             env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
         )
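
The from_pixels/pixels_only pair is what lets the same env serve both the policy and the recorder: from_pixels=True adds a "pixels" entry to the output tensordict, and pixels_only=False keeps the vector "observation" alongside it. A quick check (env name illustrative; requires a renderable gym backend):

from torchrl.envs.libs.gym import GymEnv

env = GymEnv("CartPole-v1", from_pixels=True, pixels_only=False)
td = env.reset()
print(td["pixels"].shape)       # e.g. torch.Size([400, 600, 3])
print(td["observation"].shape)  # torch.Size([4])
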
@@ -63,24 +69,31 @@ def apply_env_transforms(env, max_episode_steps=1000):
     return transformed_env


-def make_environment(cfg):
+def make_environment(cfg, logger=None):
     """Make environments for training and evaluation."""
+    partial = functools.partial(env_maker, cfg=cfg)
     parallel_env = ParallelEnv(
         cfg.collector.env_per_collector,
-        EnvCreator(lambda cfg=cfg: env_maker(cfg)),
+        EnvCreator(partial),
         serial_for_single=True,
     )
     parallel_env.set_seed(cfg.env.seed)

     train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps)

+    partial = functools.partial(env_maker, cfg=cfg, from_pixels=cfg.logger.video)
+    trsf_clone = train_env.transform.clone()
+    if cfg.logger.video:
+        trsf_clone.insert(
+            0, VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
+        )
     eval_env = TransformedEnv(
         ParallelEnv(
             cfg.collector.env_per_collector,
-            EnvCreator(lambda cfg=cfg: env_maker(cfg)),
+            EnvCreator(partial),
             serial_for_single=True,
         ),
-        train_env.transform.clone(),
+        trsf_clone,
     )
     return train_env, eval_env
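
A note on the lambda-to-functools.partial swap: EnvCreator ships the constructor to ParallelEnv worker processes, and a partial over a module-level function pickles reliably where a lambda generally does not; it also makes the eval-only from_pixels flag explicit at the call site. A minimal sketch (env name illustrative):

import functools

from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv


def env_maker(name, from_pixels=False):
    return GymEnv(name, from_pixels=from_pixels, pixels_only=False)


make_train = functools.partial(env_maker, "CartPole-v1")
make_eval = functools.partial(env_maker, "CartPole-v1", from_pixels=True)

train_env = ParallelEnv(2, EnvCreator(make_train), serial_for_single=True)
eval_env = ParallelEnv(2, EnvCreator(make_eval), serial_for_single=True)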

@@ -211,13 +224,10 @@ def make_sac_agent(cfg, train_env, eval_env, device):

     # init nets
     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
-        td = eval_env.reset()
+        td = eval_env.fake_tensordict()
         td = td.to(device)
         for net in model:
             net(td)
     del td
-    eval_env.close()

     return model, model[0]
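
fake_tensordict() is the key change here: it builds a tensordict of spec-shaped dummy values without stepping (or even rendering) the simulator, so the networks can be shape-initialized without an eval_env.reset(); the env also no longer needs to be closed, which matters now that it stays alive for evaluation-time video recording. A quick illustration:

from torchrl.envs.libs.gym import GymEnv

env = GymEnv("CartPole-v1")
td = env.fake_tensordict()  # spec-shaped dummy data; no simulator interaction
print(td)  # carries "observation", "action", "next", ... matching the env specs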


@@ -298,3 +308,8 @@ def get_activation(cfg):
         return nn.LeakyReLU
     else:
         raise NotImplementedError
+
+
+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()