[Algorithm] Update TD3 Example #1523

Merged: 35 commits merged into main from the td3_benchmark branch on Oct 3, 2023.

Changes from 4 commits

Commits (35)
6339a07 - update executable (BY571, Sep 6, 2023)
9e890b3 - fix objective (BY571, Sep 7, 2023)
117c477 - fix objective (BY571, Sep 7, 2023)
d2b3ad4 - Update initial frames and general structure (BY571, Sep 12, 2023)
9c6c358 - fixes (BY571, Sep 12, 2023)
1adbff5 - Merge branch 'main' into td3_benchmark (BY571, Sep 12, 2023)
2422ef8 - naming fix (BY571, Sep 12, 2023)
0e67de2 - single step td3 (BY571, Sep 13, 2023)
1fc0847 - small fixes (BY571, Sep 14, 2023)
7a02b83 - fix (BY571, Sep 14, 2023)
243d712 - add update counter (BY571, Sep 14, 2023)
af31bd9 - naming fixes (BY571, Sep 14, 2023)
1122808 - update logging and small fixes (BY571, Sep 15, 2023)
b4df32b - no eps (BY571, Sep 18, 2023)
13f367a - update tests (BY571, Sep 19, 2023)
72ddf7e - update objective (BY571, Sep 20, 2023)
c830891 - set gym backend (BY571, Sep 20, 2023)
1a2f08e - Merge branch 'main' into td3_benchmark (vmoens, Sep 21, 2023)
4cdbb3b - update tests (BY571, Sep 21, 2023)
76dcdeb - update fix max episode steps (BY571, Sep 22, 2023)
68d4c26 - Merge branch 'main' into td3_benchmark (BY571, Sep 26, 2023)
ec8b089 - fix (BY571, Sep 27, 2023)
bcc3bc6 - fix (BY571, Sep 27, 2023)
42748e0 - amend (vmoens, Sep 28, 2023)
0052cd9 - Merge remote-tracking branch 'BY571/td3_benchmark' into td3_benchmark (vmoens, Sep 28, 2023)
e2c28c8 - amend (vmoens, Sep 28, 2023)
bb496ef - update scratch_dir, frame skip, config (BY571, Sep 28, 2023)
9b4704b - Merge branch 'main' into td3_benchmark (BY571, Oct 2, 2023)
e622bf7 - merge main (BY571, Oct 2, 2023)
57bc54a - merge main (BY571, Oct 2, 2023)
29977df - step counter (BY571, Oct 2, 2023)
854e2a2 - merge main (BY571, Oct 3, 2023)
619f2ea - small fixes (BY571, Oct 3, 2023)
8d36787 - solve logger issue (vmoens, Oct 3, 2023)
a24ab8d - reset notensordict test (vmoens, Oct 3, 2023)
4 changes: 2 additions & 2 deletions examples/td3/config.yaml
@@ -5,7 +5,7 @@ env:
exp_name: ${env.name}_TD3
library: gym
seed: 42
max_episode_steps: 5000
max_episode_steps: 1000

# collector
collector:
@@ -32,7 +32,7 @@ optim:
loss_function: l2
lr: 3.0e-4
weight_decay: 0.0
adam_eps: 1e-8
adam_eps: 1e-4
batch_size: 256
target_update_polyak: 0.995
policy_update_delay: 2
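As a side note on the optim block above, here is a minimal sketch of how such Hydra config values are typically consumed when building the two Adam optimizers for TD3. The helper name and the actor/critic parameter split are assumptions for illustration, not the verbatim code in examples/td3/td3.py. Raising adam_eps matters because Adam divides the step by sqrt(v_hat) + eps, so a larger eps caps the effective step size when the second-moment estimate is very small.

import torch

def make_optimizers(cfg, actor_params, critic_params):
    # cfg.optim.* values come straight from config.yaml:
    # lr=3.0e-4, weight_decay=0.0, adam_eps=1e-4 after this change.
    optimizer_actor = torch.optim.Adam(
        actor_params,
        lr=cfg.optim.lr,
        weight_decay=cfg.optim.weight_decay,
        eps=cfg.optim.adam_eps,
    )
    optimizer_critic = torch.optim.Adam(
        critic_params,
        lr=cfg.optim.lr,
        weight_decay=cfg.optim.weight_decay,
        eps=cfg.optim.adam_eps,
    )
    return optimizer_actor, optimizer_critic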
8 changes: 3 additions & 5 deletions examples/td3/td3.py
@@ -127,19 +127,17 @@ def main(cfg: "DictConfig"): # noqa: F821
sampled_tensordict = replay_buffer.sample().clone()

# Compute loss
loss_td = loss_module(sampled_tensordict)

actor_loss = loss_td["loss_actor"]
q_loss = loss_td["loss_qvalue"]
q_loss, *_ = loss_module.value_loss(sampled_tensordict)

# Update critic
optimizer_critic.zero_grad()
q_loss.backward(retain_graph=update_actor)
q_loss.backward()
optimizer_critic.step()
q_losses.append(q_loss.item())

# Update actor
if update_actor:
actor_loss, *_ = loss_module.actor_loss(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()
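The change above splits the loss so the critic and the actor each get their own backward pass, which removes the need for retain_graph. A hedged sketch of the resulting inner-loop body follows; it is not the verbatim example code, and update_counter plus cfg.optim.policy_update_delay are assumed from the config and the surrounding loop.

# Sketch of one optimisation step, assuming the names used in the diff.
update_counter += 1
update_actor = update_counter % cfg.optim.policy_update_delay == 0

sampled_tensordict = replay_buffer.sample().clone()

# Critic update: only the Q-value loss is computed here, so no graph has to be
# retained for a later actor backward pass.
q_loss, *_ = loss_module.value_loss(sampled_tensordict)
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
q_losses.append(q_loss.item())

# Delayed actor update, every `policy_update_delay` critic updates.
if update_actor:
    actor_loss, *_ = loss_module.actor_loss(sampled_tensordict)
    optimizer_actor.zero_grad()
    actor_loss.backward()
    optimizer_actor.step()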
93 changes: 50 additions & 43 deletions examples/td3/utils.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import tempfile
from contextlib import nullcontext

import torch

@@ -16,6 +18,7 @@
InitTracker,
ParallelEnv,
RewardSum,
StepCounter,
TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
@@ -39,51 +42,50 @@
# -----------------


def env_maker(task, device="cpu", from_pixels=False, max_episode_steps=1000):
def env_maker(
task,
device="cpu",
from_pixels=False,
):
with set_gym_backend("gym"):
return GymEnv(
task,
device=device,
from_pixels=from_pixels,
max_episode_steps=max_episode_steps,
)


def apply_env_transforms(env, reward_scaling=1.0):
def apply_env_transforms(env, max_episode_steps, reward_scaling=1.0):
transformed_env = TransformedEnv(
env,
Compose(
StepCounter(max_steps=max_episode_steps),
InitTracker(),
RewardScaling(loc=0.0, scale=reward_scaling),
DoubleToFloat("observation"),
RewardSum(),
DoubleToFloat(),
),
)
if reward_scaling != 1.0:
transformed_env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling))
transformed_env.append_transform(RewardSum())
return transformed_env


def make_environment(cfg):
"""Make environments for training and evaluation."""
parallel_env = ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(
lambda: env_maker(
task=cfg.env.name, max_episode_steps=cfg.env.max_episode_steps
)
),
EnvCreator(lambda task=cfg.env.name: env_maker(task=task)),
)
parallel_env.set_seed(cfg.env.seed)

train_env = apply_env_transforms(parallel_env)
train_env = apply_env_transforms(
parallel_env, max_episode_steps=cfg.env.max_episode_steps
)

eval_env = TransformedEnv(
ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(
lambda: env_maker(
task=cfg.env.name, max_episode_steps=cfg.env.max_episode_steps
)
),
EnvCreator(lambda task=cfg.env.name: env_maker(task=task)),
),
train_env.transform.clone(),
)
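For readers following this refactor: episode truncation is now handled by StepCounter inside apply_env_transforms() rather than by GymEnv's max_episode_steps argument. A hedged usage sketch for a single (non-parallel) environment is below; the task name "HalfCheetah-v4" and the literal values stand in for cfg.env.name, cfg.env.max_episode_steps and cfg.env.seed and are only illustrative.

# Hedged sketch of how the refactored helpers compose.
base_env = env_maker(task="HalfCheetah-v4", device="cpu")
train_env = apply_env_transforms(base_env, max_episode_steps=1000)
train_env.set_seed(42)

rollout = train_env.rollout(3)
print(rollout["step_count"])  # running counter written by StepCounter

Note also that the EnvCreator lambdas above bind the task name through a default argument (lambda task=cfg.env.name: ...), which captures the value when the lambda is created rather than when a worker later calls it.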
@@ -115,35 +117,40 @@ def make_replay_buffer(
batch_size,
prb=False,
buffer_size=1000000,
buffer_scratch_dir="/tmp/",
buffer_scratch_dir=None,
device="cpu",
prefetch=3,
):
if prb:
replay_buffer = TensorDictPrioritizedReplayBuffer(
alpha=0.7,
beta=0.5,
pin_memory=False,
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
device=device,
),
batch_size=batch_size,
)
else:
replay_buffer = TensorDictReplayBuffer(
pin_memory=False,
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
device=device,
),
batch_size=batch_size,
)
return replay_buffer
with (
tempfile.TemporaryDirectory()
if buffer_scratch_dir is None
else nullcontext(buffer_scratch_dir)
) as scratch_dir:
if prb:
replay_buffer = TensorDictPrioritizedReplayBuffer(
alpha=0.7,
beta=0.5,
pin_memory=False,
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=scratch_dir,
device=device,
),
batch_size=batch_size,
)
else:
replay_buffer = TensorDictReplayBuffer(
pin_memory=False,
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=scratch_dir,
device=device,
),
batch_size=batch_size,
)
return replay_buffer


# ====================================================================
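The scratch-dir handling above relies on a conditional context manager, which can look unusual at first glance. The standalone sketch below isolates just that pattern; the helper name resolve_scratch_dir and the example path are assumptions for illustration.

import tempfile
from contextlib import nullcontext

def resolve_scratch_dir(buffer_scratch_dir=None):
    # TemporaryDirectory() creates (and later removes) a fresh directory;
    # nullcontext() simply yields the user-supplied path unchanged, so both
    # cases can share the same `with` block.
    return (
        tempfile.TemporaryDirectory()
        if buffer_scratch_dir is None
        else nullcontext(buffer_scratch_dir)
    )

with resolve_scratch_dir() as scratch_dir:
    print("memmap storage would be created under:", scratch_dir)

with resolve_scratch_dir("/some/user/path") as scratch_dir:
    print("memmap storage would be created under:", scratch_dir)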
58 changes: 36 additions & 22 deletions torchrl/collectors/collectors.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import _pickle
import abc
import inspect
@@ -391,12 +393,12 @@ class SyncDataCollector(DataCollectorBase):
If the environment wraps multiple environments together, the number
of steps is tracked for each environment independently. Negative
values are allowed, in which case this argument is ignored.
Defaults to ``-1`` (i.e. no maximum number of steps).
Defaults to ``None`` (i.e. no maximum number of steps).
init_random_frames (int, optional): Number of frames for which the
policy is ignored before it is called. This feature is mainly
intended to be used in offline/model-based settings, where a
batch of random trajectories can be used to initialize training.
Defaults to ``-1`` (i.e. no random frames).
Defaults to ``None`` (i.e. no random frames).
reset_at_each_iter (bool, optional): Whether environments should be reset
at the beginning of a batch collection.
Defaults to ``False``.
@@ -498,12 +500,12 @@ def __init__(
total_frames: int,
device: DEVICE_TYPING = None,
storing_device: DEVICE_TYPING = None,
create_env_kwargs: Optional[dict] = None,
max_frames_per_traj: int = -1,
init_random_frames: int = -1,
create_env_kwargs: dict | None = None,
max_frames_per_traj: int | None = None,
init_random_frames: int | None = None,
reset_at_each_iter: bool = False,
postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
split_trajs: Optional[bool] = None,
postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
split_trajs: bool | None = None,
exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
exploration_mode=None,
return_same_td: bool = False,
@@ -567,7 +569,7 @@ def __init__(

self.env: EnvBase = self.env.to(self.device)
self.max_frames_per_traj = max_frames_per_traj
if self.max_frames_per_traj > 0:
if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0:
# let's check that there is no StepCounter yet
for key in self.env.output_spec.keys(True, True):
if isinstance(key, str):
@@ -823,7 +825,10 @@ def rollout(self) -> TensorDictBase:
tensordicts = []
with set_exploration_type(self.exploration_type):
for t in range(self.frames_per_batch):
if self._frames < self.init_random_frames:
if (
self.init_random_frames is not None
and self._frames < self.init_random_frames
):
self.env.rand_step(self._tensordict)
else:
self.policy(self._tensordict)
@@ -1016,12 +1021,12 @@ class _MultiDataCollector(DataCollectorBase):
If the environment wraps multiple environments together, the number
of steps is tracked for each environment independently. Negative
values are allowed, in which case this argument is ignored.
Defaults to ``-1`` (i.e. no maximum number of steps).
Defaults to ``None`` (i.e. no maximum number of steps).
init_random_frames (int, optional): Number of frames for which the
policy is ignored before it is called. This feature is mainly
intended to be used in offline/model-based settings, where a
batch of random trajectories can be used to initialize training.
Defaults to ``-1`` (i.e. no random frames).
Defaults to ``None`` (i.e. no random frames).
reset_at_each_iter (bool, optional): Whether environments should be reset
at the beginning of a batch collection.
Defaults to ``False``.
@@ -1077,8 +1082,8 @@ def __init__(
device: DEVICE_TYPING = None,
storing_device: Optional[Union[DEVICE_TYPING, Sequence[DEVICE_TYPING]]] = None,
create_env_kwargs: Optional[Sequence[dict]] = None,
max_frames_per_traj: int = -1,
init_random_frames: int = -1,
max_frames_per_traj: int | None = None,
init_random_frames: int | None = None,
reset_at_each_iter: bool = False,
postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
split_trajs: Optional[bool] = None,
@@ -1633,7 +1638,10 @@ def iterator(self) -> Iterator[TensorDictBase]:
self.update_policy_weights_()

for idx in range(self.num_workers):
if self._frames < self.init_random_frames:
if (
self.init_random_frames is not None
and self._frames < self.init_random_frames
):
msg = "continue_random"
else:
msg = "continue"
@@ -1869,7 +1877,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
self.update_policy_weights_()

for i in range(self.num_workers):
if self.init_random_frames > 0:
if self.init_random_frames is not None and self.init_random_frames > 0:
self.pipes[i].send((None, "continue_random"))
else:
self.pipes[i].send((None, "continue"))
@@ -1891,7 +1899,10 @@

# the function blocks here until the next item is asked, hence we send the message to the
# worker to keep on working in the meantime before the yield statement
if self._frames < self.init_random_frames:
if (
self.init_random_frames is not None
and self._frames < self.init_random_frames
):
msg = "continue_random"
else:
msg = "continue"
@@ -1918,7 +1929,10 @@ def reset(self, reset_idx: Optional[Sequence[bool]] = None) -> None:
raise Exception("self.queue_out is full")
if self.running:
for idx in range(self.num_workers):
if self._frames < self.init_random_frames:
if (
self.init_random_frames is not None
and self._frames < self.init_random_frames
):
self.pipes[idx].send((idx, "continue_random"))
else:
self.pipes[idx].send((idx, "continue"))
@@ -1952,14 +1966,14 @@ class aSyncDataCollector(MultiaSyncDataCollector):
environment wraps multiple environments together, the number of
steps is tracked for each environment independently. Negative
values are allowed, in which case this argument is ignored.
Default is -1 (i.e. no maximum number of steps)
Defaults to ``None`` (i.e. no maximum number of steps)
frames_per_batch (int): Time-length of a batch.
reset_at_each_iter and frames_per_batch == n_steps are equivalent configurations.
default: 200
Defaults to ``200``
init_random_frames (int): Number of frames for which the policy is ignored before it is called.
This feature is mainly intended to be used in offline/model-based settings, where a batch of random
trajectories can be used to initialize training.
default=-1 (i.e. no random frames)
Defaults to ``None`` (i.e. no random frames)
reset_at_each_iter (bool): Whether or not environments should be reset for each batch.
default=False.
postproc (callable, optional): A PostProcessor is an object that will read a batch of data and process it in a
@@ -1994,9 +2008,9 @@ def __init__(
] = None,
total_frames: Optional[int] = -1,
create_env_kwargs: Optional[dict] = None,
max_frames_per_traj: int = -1,
max_frames_per_traj: int | None = None,
frames_per_batch: int = 200,
init_random_frames: int = -1,
init_random_frames: int | None = None,
reset_at_each_iter: bool = False,
postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
split_trajs: Optional[bool] = None,
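To illustrate how the new defaults read at the call site: None (rather than -1) now means the feature is disabled. The sketch below is hedged; the environment name and frame counts are placeholders, the exact constructor signature may vary slightly across TorchRL versions, and the fallback to a random policy when policy is None is assumed from the collector docstring.

from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv

collector = SyncDataCollector(
    create_env_fn=lambda: GymEnv("Pendulum-v1"),
    policy=None,                # assumed: None falls back to a random policy over the action spec
    frames_per_batch=200,
    total_frames=1_000,
    max_frames_per_traj=None,   # None, not -1: no per-trajectory cap from the collector
    init_random_frames=None,    # None, not -1: never force random steps before using the policy
)
for batch in collector:
    print(batch.batch_size)     # one TensorDict of 200 frames per iteration
collector.shutdown()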
7 changes: 7 additions & 0 deletions torchrl/envs/transforms/transforms.py
@@ -2864,11 +2864,18 @@ def _set_in_keys(self):
self._keys_inv_unset = False
self._container.empty_cache()

def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
if self._keys_unset:
self._set_in_keys()
return super().reset(tensordict)

@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Reads the input tensordict, and for the selected keys, applies the transform."""
if self._keys_unset:
self._set_in_keys()
# if still no update
if self._keys_unset:
for in_key, data in tensordict.items(True, True):
if data.dtype == self.dtype_in:
out_key = in_key
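To see what the added reset() hook buys in practice: a keyless DoubleToFloat can now resolve its in_keys from the container specs on the first reset instead of waiting for a forward call. A hedged sketch is below; the MuJoCo task name is an assumption, and any base env that emits float64 observations would do.

from torchrl.envs import TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import DoubleToFloat

# No in_keys passed: the transform's keys stay unset until they can be inferred.
env = TransformedEnv(GymEnv("HalfCheetah-v4"), DoubleToFloat())
td = env.reset()                # reset() now triggers _set_in_keys()
print(td["observation"].dtype)  # torch.float32, cast from the float64 MuJoCo observation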