[Algorithm] Update TD3 Example #1523

Merged: 35 commits, Oct 3, 2023
Changes from 4 commits

Commits (35):

6339a07  update executable (BY571, Sep 6, 2023)
9e890b3  fix objective (BY571, Sep 7, 2023)
117c477  fix objective (BY571, Sep 7, 2023)
d2b3ad4  Update initial frames and general structure (BY571, Sep 12, 2023)
9c6c358  fixes (BY571, Sep 12, 2023)
1adbff5  Merge branch 'main' into td3_benchmark (BY571, Sep 12, 2023)
2422ef8  naming fix (BY571, Sep 12, 2023)
0e67de2  single step td3 (BY571, Sep 13, 2023)
1fc0847  small fixes (BY571, Sep 14, 2023)
7a02b83  fix (BY571, Sep 14, 2023)
243d712  add update counter (BY571, Sep 14, 2023)
af31bd9  naming fixes (BY571, Sep 14, 2023)
1122808  update logging and small fixes (BY571, Sep 15, 2023)
b4df32b  no eps (BY571, Sep 18, 2023)
13f367a  update tests (BY571, Sep 19, 2023)
72ddf7e  update objective (BY571, Sep 20, 2023)
c830891  set gym backend (BY571, Sep 20, 2023)
1a2f08e  Merge branch 'main' into td3_benchmark (vmoens, Sep 21, 2023)
4cdbb3b  update tests (BY571, Sep 21, 2023)
76dcdeb  update fix max episode steps (BY571, Sep 22, 2023)
68d4c26  Merge branch 'main' into td3_benchmark (BY571, Sep 26, 2023)
ec8b089  fix (BY571, Sep 27, 2023)
bcc3bc6  fix (BY571, Sep 27, 2023)
42748e0  amend (vmoens, Sep 28, 2023)
0052cd9  Merge remote-tracking branch 'BY571/td3_benchmark' into td3_benchmark (vmoens, Sep 28, 2023)
e2c28c8  amend (vmoens, Sep 28, 2023)
bb496ef  update scratch_dir, frame skip, config (BY571, Sep 28, 2023)
9b4704b  Merge branch 'main' into td3_benchmark (BY571, Oct 2, 2023)
e622bf7  merge main (BY571, Oct 2, 2023)
57bc54a  merge main (BY571, Oct 2, 2023)
29977df  step counter (BY571, Oct 2, 2023)
854e2a2  merge main (BY571, Oct 3, 2023)
619f2ea  small fixes (BY571, Oct 3, 2023)
8d36787  solve logger issue (vmoens, Oct 3, 2023)
a24ab8d  reset notensordict test (vmoens, Oct 3, 2023)

2 changes: 1 addition & 1 deletion examples/td3/config.yaml
@@ -6,6 +6,7 @@ env:
   library: gym
   frame_skip: 1
   seed: 42
+  max_episode_steps: 1_000_000

 # collector
 collector:
@@ -14,7 +15,6 @@ collector:
   init_env_steps: 1000
   frames_per_batch: 1000
   max_frames_per_traj: 1000
-  async_collection: 1
   collector_device: cpu
   env_per_collector: 1
   num_workers: 1
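
For a quick sanity check of the edited config, it can be loaded with OmegaConf (the library that Hydra, which drives the example's entry point, uses under the hood). This is just a sketch; the path is the repository-relative one from the header above:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("examples/td3/config.yaml")
    assert cfg.env.max_episode_steps == 1_000_000   # new field under env
    assert "async_collection" not in cfg.collector  # removed from collector
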
11 changes: 6 additions & 5 deletions examples/td3/td3.py
@@ -153,16 +153,17 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 replay_buffer.update_priority(sampled_tensordict)

         training_time = time.time() - training_start
-        episode_rewards = tensordict["next", "episode_reward"][
+        episode_end = (
             tensordict["next", "done"]
-        ]
+            if tensordict["next", "done"].any()
+            else tensordict["next", "truncated"]
+        )
+        episode_rewards = tensordict["next", "episode_reward"][episode_end]

         # Logging
         metrics_to_log = {}
         if len(episode_rewards) > 0:
-            episode_length = tensordict["next", "step_count"][
-                tensordict["next", "done"]
-            ]
+            episode_length = tensordict["next", "step_count"][episode_end]
             metrics_to_log["train/reward"] = episode_rewards.mean().item()
             metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
                 episode_length
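
For readers skimming the diff: the new block picks episode boundaries from the "done" flags when any episode in the batch actually terminated, and otherwise falls back to the "truncated" flags set by the step counter, so reward logging keeps working for time-limited episodes. A standalone sketch of that masking logic (toy tensors, not the script's tensordict):

    import torch

    # Toy flags for a batch of 3 transitions: nothing terminated, one episode
    # was cut short by the step counter.
    done = torch.tensor([False, False, False])
    truncated = torch.tensor([False, True, False])
    episode_reward = torch.tensor([1.0, 5.0, 2.0])

    # Same selection rule as in the diff above.
    episode_end = done if done.any() else truncated
    episode_rewards = episode_reward[episode_end]  # tensor([5.])
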
24 changes: 20 additions & 4 deletions examples/td3/utils.py
@@ -41,9 +41,15 @@
 # -----------------


-def env_maker(task, device="cpu"):
+def env_maker(
+    task, device="cpu", max_episode_steps=1000
+):
     with set_gym_backend("gym"):
-        return GymEnv(task, device=device)
+        return GymEnv(
+            task,
+            device=device,
+            max_episode_steps=max_episode_steps,
+        )


 def apply_env_transforms(env, reward_scaling=1.0):
@@ -63,7 +69,11 @@ def make_environment(cfg):
     """Make environments for training and evaluation."""
     parallel_env = ParallelEnv(
         cfg.collector.env_per_collector,
-        EnvCreator(lambda task=cfg.env.name: env_maker(task=task)),
+        EnvCreator(
+            lambda task=cfg.env.name, max_episode_steps=cfg.env.max_episode_steps: env_maker(
+                task=task, max_episode_steps=max_episode_steps
+            )
+        ),
     )
     parallel_env.set_seed(cfg.env.seed)

@@ -72,7 +82,13 @@ def make_environment(cfg):
     eval_env = TransformedEnv(
         ParallelEnv(
             cfg.collector.env_per_collector,
-            EnvCreator(lambda task=cfg.env.name: env_maker(task=task)),
+            EnvCreator(
+                lambda
+                task=cfg.env.name,
+                max_episode_steps=cfg.env.max_episode_steps: env_maker(
+                    task=task, max_episode_steps=max_episode_steps
+                )
+            ),
         ),
         train_env.transform.clone(),
     )
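
A side note on the lambdas passed to EnvCreator above: the config values are supplied as default arguments, which Python evaluates once, at the point where the lambda is defined, so the callable can later be invoked with no arguments in each worker and still carry the intended task and max_episode_steps. A minimal sketch of that binding behaviour (plain Python, names are illustrative):

    max_episode_steps = 1_000_000
    make_env = lambda steps=max_episode_steps: ("HalfCheetah-v4", steps)

    max_episode_steps = 5  # later rebinding does not affect the callable
    assert make_env() == ("HalfCheetah-v4", 1_000_000)
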
6 changes: 4 additions & 2 deletions test/test_cost.py
@@ -2370,8 +2370,10 @@ def test_td3_notensordict(self, observation_key, reward_key, done_key):
         loss_val_td = loss(td)
         torch.manual_seed(0)
         loss_val = loss(**kwargs)
-        for i, key in enumerate(loss_val_td.keys()):
-            torch.testing.assert_close(loss_val_td.get(key), loss_val[i])
+        for i in loss_val:
+            assert i in loss_val_td.values(), f"{i} not in {loss_val_td.values()}"
+        # for i, key in enumerate(loss_val_td.keys()):
+        #     torch.testing.assert_close(loss_val_td.get(key), loss_val[i])
Comment from the PR author:

This is dangerous: the keys in the tensordict are ordered by name, but the output tuple loss_val is not. So for now I'm just checking that every value in the loss_val tuple is also present in loss_val_td.

    # actor metadata
    metadata = {
        "state_action_value_actor": state_action_value_actor.mean().detach(),
    }
    # value metadata
    metadata = {
        "td_error": td_error,
        "next_state_value": next_target_qvalue.mean().detach(),
        "pred_value": current_qvalue.mean().detach(),
        "target_value": target_value.mean().detach(),
    }
    # out tensordict
    td_out = TensorDict(
        source={
            "loss_actor": loss_actor,
            "loss_qvalue": loss_qval,
            **metadata_actor,
            **metadata_value,
        },
        batch_size=[],
    )

loss_val will come back in that order: (loss_actor, loss_qvalue, state_action_value_actor, next_state_value, pred_value, target_value).
However, since the items in the tensordict are ordered by key, the output tensordict actually has this order:
(loss_actor, loss_qvalue, next_state_value, pred_value, state_action_value_actor, target_value)

Reply from a contributor:

dispatch returns the keys in the order of out_keys, so this is predictable; we can just do:

            for i, key in enumerate(loss.out_keys):
                torch.testing.assert_close(loss_val_td.get(key), loss_val[i])

Does that solve the problem?
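
A small self-contained illustration of the mismatch discussed in this thread, using toy values rather than the real TD3Loss outputs: a tuple emitted in out_keys order will not line up with keys iterated from the tensordict if those come back sorted by name, but indexing the tensordict by out_keys sidesteps the ordering question entirely.

    import torch
    from tensordict import TensorDict

    # Toy stand-ins for the loss outputs (values are arbitrary).
    out_keys = ["loss_actor", "loss_qvalue", "state_action_value_actor",
                "next_state_value", "pred_value", "target_value"]
    loss_val = tuple(torch.tensor(float(i)) for i in range(len(out_keys)))
    loss_val_td = TensorDict(dict(zip(out_keys, loss_val)), batch_size=[])

    # Robust comparison: pair each tuple entry with its out_key instead of
    # relying on the iteration order of loss_val_td.keys().
    for i, key in enumerate(out_keys):
        torch.testing.assert_close(loss_val_td.get(key), loss_val[i])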

         # test select
         loss.select_out_keys("loss_actor", "loss_qvalue")
         torch.manual_seed(0)
6 changes: 3 additions & 3 deletions torchrl/objectives/td3.py
@@ -362,7 +362,7 @@ def actor_loss(self, tensordict):
         state_action_value_actor = (
             self._vmap_qvalue_network00(
                 actor_loss_td,
-                self.qvalue_network_params,
+                self._cached_detach_qvalue_network_params,
             )
             .get(self.tensor_keys.state_action_value)
             .squeeze(-1)
@@ -446,8 +446,8 @@ def value_loss(self, tensordict):
         )
         metadata = {
             "td_error": td_error,
-            "pred_value": current_qvalue.mean().detach(),
             "next_state_value": next_target_qvalue.mean().detach(),
+            "pred_value": current_qvalue.mean().detach(),
             "target_value": target_value.mean().detach(),
         }

@@ -456,8 +456,8 @@
     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict_save = tensordict
-        loss_qval, metadata_value = self.value_loss(tensordict)
         loss_actor, metadata_actor = self.actor_loss(tensordict)
+        loss_qval, metadata_value = self.value_loss(tensordict_save)
         tensordict_save.set(
             self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0]
         )
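
On the first change in this file: actor_loss now evaluates the critic with a cached, detached copy of the q-value parameters, so backpropagating the actor loss cannot push gradients into the critic. A standalone PyTorch sketch of that pattern (not TorchRL code; module shapes are illustrative):

    import torch
    from torch import nn
    from torch.nn import functional as F

    actor = nn.Linear(3, 4)   # maps observation -> action
    critic = nn.Linear(7, 1)  # maps (observation, action) -> Q-value

    obs = torch.randn(8, 3)
    action = actor(obs)

    # Detached copies of the critic parameters: Q still depends on the action,
    # but the critic weights receive no gradient from the actor loss.
    w, b = (p.detach() for p in critic.parameters())
    q = F.linear(torch.cat([obs, action], dim=-1), w, b)

    loss_actor = -q.mean()
    loss_actor.backward()

    assert critic.weight.grad is None     # untouched by the actor update
    assert actor.weight.grad is not None  # actor still learns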