[Feature] Video recording in SOTA examples #2070

Merged: 13 commits, Apr 23, 2024
Commit: iql
vmoens committed Apr 9, 2024
commit 573b1ed7475a9096cf169369c831c8364caa53b1
11 changes: 10 additions & 1 deletion sota-implementations/iql/discrete_iql.py
@@ -24,6 +24,7 @@
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
+dump_video,
log_metrics,
make_collector,
make_discrete_iql_model,
@@ -57,13 +58,20 @@ def main(cfg: "DictConfig"): # noqa: F821
# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)
-device = torch.device(cfg.optim.device)
+device = cfg.optim.device
+if device in ("", None):
+    if torch.cuda.is_available():
+        device = "cuda:0"
+    else:
+        device = "cpu"
+device = torch.device(device)

# Create environments
train_env, eval_env = make_environment(
cfg,
cfg.env.train_num_envs,
cfg.env.eval_num_envs,
+logger=logger,
)

# Create replay buffer
@@ -186,6 +194,7 @@ def main(cfg: "DictConfig"): # noqa: F821
auto_cast_to_device=True,
break_when_any_done=True,
)
+eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
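Note on the device change above: instead of requiring an explicit optim.device, the script now falls back to cuda:0 when CUDA is available and to cpu otherwise whenever the config leaves the device empty or null. The same seven lines are repeated in iql_offline.py and iql_online.py below; a minimal sketch of the pattern as a shared helper (the helper name is hypothetical and not part of this PR):

import torch


def resolve_device(device_cfg):
    # Hypothetical helper mirroring the fallback added to each script in this PR.
    if device_cfg in ("", None):
        device_cfg = "cuda:0" if torch.cuda.is_available() else "cpu"
    return torch.device(device_cfg)


# Inside main(), each script could then use:
# device = resolve_device(cfg.optim.device)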
3 changes: 2 additions & 1 deletion sota-implementations/iql/discrete_iql.yaml
@@ -28,6 +28,7 @@ logger:
eval_steps: 200
mode: online
eval_iter: 1000
+video: False

# replay buffer
replay_buffer:
@@ -38,7 +39,7 @@
# optimization
optim:
utd_ratio: 1
-device: cuda:0
+device: null
lr: 3e-4
weight_decay: 0.0
batch_size: 256
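The two config changes above make recording opt-in (video: False) and leave the device unset by default (device: null), deferring to the runtime fallback shown earlier. Assuming the script keeps its usual Hydra entry point, both defaults can be overridden from the command line, roughly:

# Illustrative invocation; the override syntax assumes the script's Hydra config,
# which this diff does not show in full.
python discrete_iql.py logger.video=True optim.device=cuda:0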
16 changes: 14 additions & 2 deletions sota-implementations/iql/iql_offline.py
@@ -22,6 +22,7 @@
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
+dump_video,
log_metrics,
make_environment,
make_iql_model,
@@ -54,10 +55,20 @@ def main(cfg: "DictConfig"): # noqa: F821
# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)
-device = torch.device(cfg.optim.device)
+device = cfg.optim.device
+if device in ("", None):
+    if torch.cuda.is_available():
+        device = "cuda:0"
+    else:
+        device = "cpu"
+device = torch.device(device)

# Create env
-train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs)
+train_env, eval_env = make_environment(
+    cfg,
+    cfg.logger.eval_envs,
+    logger=logger,
+)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
@@ -123,6 +134,7 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
+eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward
if logger is not None:
11 changes: 10 additions & 1 deletion sota-implementations/iql/iql_online.py
@@ -24,6 +24,7 @@
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
+dump_video,
log_metrics,
make_collector,
make_environment,
@@ -57,13 +58,20 @@ def main(cfg: "DictConfig"): # noqa: F821
# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)
-device = torch.device(cfg.optim.device)
+device = cfg.optim.device
+if device in ("", None):
+    if torch.cuda.is_available():
+        device = "cuda:0"
+    else:
+        device = "cpu"
+device = torch.device(device)

# Create environments
train_env, eval_env = make_environment(
cfg,
cfg.env.train_num_envs,
cfg.env.eval_num_envs,
+logger=logger,
)

# Create replay buffer
@@ -184,6 +192,7 @@ def main(cfg: "DictConfig"): # noqa: F821
auto_cast_to_device=True,
break_when_any_done=True,
)
+eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
3 changes: 2 additions & 1 deletion sota-implementations/iql/offline_config.yaml
@@ -17,6 +17,7 @@ logger:
eval_steps: 1000
mode: online
eval_envs: 5
+video: False

# replay buffer
replay_buffer:
@@ -25,7 +26,7 @@

# optimization
optim:
-device: cuda:0
+device: null
lr: 3e-4
weight_decay: 0.0
gradient_steps: 50000
3 changes: 2 additions & 1 deletion sota-implementations/iql/online_config.yaml
@@ -28,6 +28,7 @@ logger:
eval_steps: 200
mode: online
eval_iter: 1000
+video: False

# replay buffer
replay_buffer:
@@ -38,7 +39,7 @@
# optimization
optim:
utd_ratio: 1
-device: cuda:0
+device: null
lr: 3e-4
weight_decay: 0.0
batch_size: 256
29 changes: 22 additions & 7 deletions sota-implementations/iql/utils.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+import functools

import torch.nn
import torch.optim
from tensordict.nn import InteractionType, TensorDictModule
@@ -39,6 +41,7 @@
ValueOperator,
)
from torchrl.objectives import DiscreteIQLLoss, HardUpdate, IQLLoss, SoftUpdate
+from torchrl.record import VideoRecorder

from torchrl.trainers.helpers.models import ACTIVATIONS

@@ -48,16 +51,17 @@
# -----------------


-def env_maker(cfg, device="cpu"):
+def env_maker(cfg, device="cpu", from_pixels=False):
lib = cfg.env.backend
if lib in ("gym", "gymnasium"):
with set_gym_backend(lib):
return GymEnv(
-cfg.env.name,
-device=device,
+cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False
)
elif lib == "dm_control":
-env = DMControlEnv(cfg.env.name, cfg.env.task)
+env = DMControlEnv(
+    cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
+)
return TransformedEnv(
env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
)
@@ -79,25 +83,31 @@ def apply_env_transforms(
return transformed_env


-def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
+def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None):
"""Make environments for training and evaluation."""
+maker = functools.partial(env_maker, cfg)
parallel_env = ParallelEnv(
train_num_envs,
-EnvCreator(lambda: env_maker(cfg)),
+EnvCreator(maker),
serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

train_env = apply_env_transforms(parallel_env)

+maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video)
eval_env = TransformedEnv(
ParallelEnv(
eval_num_envs,
-EnvCreator(lambda: env_maker(cfg)),
+EnvCreator(maker),
serial_for_single=True,
),
train_env.transform.clone(),
)
+if cfg.logger.video:
+    eval_env.insert_transform(
+        0, VideoRecorder(logger, tag="rendered", in_keys=["pixels"])
+    )
return train_env, eval_env


@@ -417,3 +427,8 @@ def log_metrics(logger, metrics, step):
if logger is not None:
for metric_name, metric_value in metrics.items():
logger.log_scalar(metric_name, metric_value, step)


+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
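
Taken together: env_maker renders pixels when asked, make_environment inserts a VideoRecorder transform into the evaluation env when cfg.logger.video is set, and the training scripts call eval_env.apply(dump_video) after each evaluation rollout to flush the recorded frames to the logger. A minimal, self-contained sketch of that pattern outside the IQL scripts (the CSVLogger backend and the CartPole environment are illustrative assumptions, not part of this PR):

from torchrl.envs import TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.record import VideoRecorder
from torchrl.record.loggers.csv import CSVLogger


def dump_video(module):
    # Same helper as above: flush frames whenever the visited module is a VideoRecorder.
    if isinstance(module, VideoRecorder):
        module.dump()


logger = CSVLogger(exp_name="iql_video_demo")
# from_pixels=True / pixels_only=False make the env emit a "pixels" entry alongside
# the regular observations, which the recorder consumes.
env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, pixels_only=False),
    VideoRecorder(logger, tag="rendered", in_keys=["pixels"]),
)
env.rollout(50)
# Environments are nn.Modules, so .apply() visits every transform; this is what
# eval_env.apply(dump_video) does in the scripts after each evaluation rollout.
env.apply(dump_video)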