From 106368ffe627aada70ebb33ecf545d0f193c7757 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Oct 2023 14:47:02 +0100 Subject: [PATCH] [Feature] Make advantages compatible with Terminated, Truncated, Done (#1581) Co-authored-by: Skander Moalla <37197319+skandermoalla@users.noreply.github.com> Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> --- examples/multiagent/iql.py | 11 +- examples/multiagent/maddpg_iddpg.py | 11 +- examples/multiagent/mappo_ippo.py | 16 +- examples/multiagent/sac.py | 13 +- examples/multiagent/utils/utils.py | 41 ++ test/test_cost.py | 764 +++++++++++++++++++----- torchrl/objectives/a2c.py | 12 +- torchrl/objectives/cql.py | 9 +- torchrl/objectives/ddpg.py | 14 +- torchrl/objectives/deprecated.py | 7 + torchrl/objectives/dqn.py | 23 +- torchrl/objectives/dreamer.py | 5 + torchrl/objectives/iql.py | 12 +- torchrl/objectives/multiagent/qmixer.py | 8 + torchrl/objectives/ppo.py | 14 +- torchrl/objectives/redq.py | 11 +- torchrl/objectives/reinforce.py | 11 +- torchrl/objectives/sac.py | 23 +- torchrl/objectives/td3.py | 11 +- torchrl/objectives/value/advantages.py | 138 +++-- torchrl/objectives/value/functional.py | 376 ++++++++---- torchrl/objectives/value/utils.py | 10 +- 22 files changed, 1203 insertions(+), 337 deletions(-) create mode 100644 examples/multiagent/utils/utils.py diff --git a/examples/multiagent/iql.py b/examples/multiagent/iql.py index 4d36614f199..351f5c3730e 100644 --- a/examples/multiagent/iql.py +++ b/examples/multiagent/iql.py @@ -21,6 +21,7 @@ from torchrl.modules.models.multiagent import MultiAgentMLP from torchrl.objectives import DQNLoss, SoftUpdate, ValueEstimators from utils.logging import init_logging, log_evaluation, log_training +from utils.utils import DoneTransform def rendering_callback(env, td): @@ -111,6 +112,7 @@ def train(cfg: "DictConfig"): # noqa: F821 storing_device=cfg.train.device, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, + postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys), ) replay_buffer = TensorDictReplayBuffer( @@ -125,6 +127,8 @@ def train(cfg: "DictConfig"): # noqa: F821 action=env.action_key, value=("agents", "chosen_action_value"), reward=env.reward_key, + done=("agents", "done"), + terminated=("agents", "terminated"), ) loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma) target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau) @@ -144,13 +148,6 @@ def train(cfg: "DictConfig"): # noqa: F821 sampling_time = time.time() - sampling_start - tensordict_data.set( - ("next", "done"), - tensordict_data.get(("next", "done")) - .unsqueeze(-1) - .expand(tensordict_data.get(("next", env.reward_key)).shape), - ) # We need to expand the done to match the reward shape - current_frames = tensordict_data.numel() total_frames += current_frames data_view = tensordict_data.reshape(-1) diff --git a/examples/multiagent/maddpg_iddpg.py b/examples/multiagent/maddpg_iddpg.py index 0b1cb4079e8..9301f8a63f2 100644 --- a/examples/multiagent/maddpg_iddpg.py +++ b/examples/multiagent/maddpg_iddpg.py @@ -26,6 +26,7 @@ from torchrl.modules.models.multiagent import MultiAgentMLP from torchrl.objectives import DDPGLoss, SoftUpdate, ValueEstimators from utils.logging import init_logging, log_evaluation, log_training +from utils.utils import DoneTransform def rendering_callback(env, td): @@ -133,6 +134,7 @@ def train(cfg: "DictConfig"): # noqa: F821 storing_device=cfg.train.device, 
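# Note on the conventions assumed throughout this patch: `terminated` flags a true
# end of the Markov process (the bootstrap value V(s') must be discarded), while
# `done` flags the end of a trajectory for any reason, i.e. done = terminated | truncated.
# Value estimators therefore need both signals: `done` tells them where a trajectory
# stops, `terminated` tells them where not to bootstrap. The DoneTransform postproc
# used in these examples expands those flags to the reward shape so the multi-agent
# losses can consume them directly.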
frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, + postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys), ) replay_buffer = TensorDictReplayBuffer( @@ -147,6 +149,8 @@ def train(cfg: "DictConfig"): # noqa: F821 loss_module.set_keys( state_action_value=("agents", "state_action_value"), reward=env.reward_key, + done=("agents", "done"), + terminated=("agents", "terminated"), ) loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma) target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau) @@ -170,13 +174,6 @@ def train(cfg: "DictConfig"): # noqa: F821 sampling_time = time.time() - sampling_start - tensordict_data.set( - ("next", "done"), - tensordict_data.get(("next", "done")) - .unsqueeze(-1) - .expand(tensordict_data.get(("next", env.reward_key)).shape), - ) # We need to expand the done to match the reward shape - current_frames = tensordict_data.numel() total_frames += current_frames data_view = tensordict_data.reshape(-1) diff --git a/examples/multiagent/mappo_ippo.py b/examples/multiagent/mappo_ippo.py index 6be5240935f..c2e46174e92 100644 --- a/examples/multiagent/mappo_ippo.py +++ b/examples/multiagent/mappo_ippo.py @@ -22,6 +22,7 @@ from torchrl.modules.models.multiagent import MultiAgentMLP from torchrl.objectives import ClipPPOLoss, ValueEstimators from utils.logging import init_logging, log_evaluation, log_training +from utils.utils import DoneTransform def rendering_callback(env, td): @@ -126,6 +127,7 @@ def train(cfg: "DictConfig"): # noqa: F821 storing_device=cfg.train.device, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, + postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys), ) replay_buffer = TensorDictReplayBuffer( @@ -142,7 +144,12 @@ def train(cfg: "DictConfig"): # noqa: F821 entropy_coef=cfg.loss.entropy_eps, normalize_advantage=False, ) - loss_module.set_keys(reward=env.reward_key, action=env.action_key) + loss_module.set_keys( + reward=env.reward_key, + action=env.action_key, + done=("agents", "done"), + terminated=("agents", "terminated"), + ) loss_module.make_value_estimator( ValueEstimators.GAE, gamma=cfg.loss.gamma, lmbda=cfg.loss.lmbda ) @@ -165,13 +172,6 @@ def train(cfg: "DictConfig"): # noqa: F821 sampling_time = time.time() - sampling_start - tensordict_data.set( - ("next", "done"), - tensordict_data.get(("next", "done")) - .unsqueeze(-1) - .expand(tensordict_data.get(("next", env.reward_key)).shape), - ) # We need to expand the done to match the reward shape - with torch.no_grad(): loss_module.value_estimator( tensordict_data, diff --git a/examples/multiagent/sac.py b/examples/multiagent/sac.py index e9aea20e282..6fc063c2411 100644 --- a/examples/multiagent/sac.py +++ b/examples/multiagent/sac.py @@ -23,6 +23,7 @@ from torchrl.modules.models.multiagent import MultiAgentMLP from torchrl.objectives import DiscreteSACLoss, SACLoss, SoftUpdate, ValueEstimators from utils.logging import init_logging, log_evaluation, log_training +from utils.utils import DoneTransform def rendering_callback(env, td): @@ -179,6 +180,7 @@ def train(cfg: "DictConfig"): # noqa: F821 storing_device=cfg.train.device, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, + postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys), ) replay_buffer = TensorDictReplayBuffer( @@ -198,6 +200,8 @@ def train(cfg: "DictConfig"): # noqa: F821 state_action_value=("agents", 
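# In these multiagent examples the done/terminated entries handed to set_keys live
# under the "agents" nesting: the DoneTransform postproc attached to each collector
# (defined in examples/multiagent/utils/utils.py below) writes
# ("next", "agents", "done") and ("next", "agents", "terminated") next to the
# per-agent reward, so the loss and its value estimator read flags whose shape
# matches the reward.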
"state_action_value"), action=env.action_key, reward=env.reward_key, + done=("agents", "done"), + terminated=("agents", "terminated"), ) else: loss_module = DiscreteSACLoss( @@ -211,6 +215,8 @@ def train(cfg: "DictConfig"): # noqa: F821 action_value=("agents", "action_value"), action=env.action_key, reward=env.reward_key, + done=("agents", "done"), + terminated=("agents", "terminated"), ) loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma) @@ -235,13 +241,6 @@ def train(cfg: "DictConfig"): # noqa: F821 sampling_time = time.time() - sampling_start - tensordict_data.set( - ("next", "done"), - tensordict_data.get(("next", "done")) - .unsqueeze(-1) - .expand(tensordict_data.get(("next", env.reward_key)).shape), - ) # We need to expand the done to match the reward shape - current_frames = tensordict_data.numel() total_frames += current_frames data_view = tensordict_data.reshape(-1) diff --git a/examples/multiagent/utils/utils.py b/examples/multiagent/utils/utils.py new file mode 100644 index 00000000000..d21bafdf691 --- /dev/null +++ b/examples/multiagent/utils/utils.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from tensordict import unravel_key +from torchrl.envs import Transform + + +def swap_last(source, dest): + source = unravel_key(source) + dest = unravel_key(dest) + if isinstance(source, str): + if isinstance(dest, str): + return dest + return dest[-1] + if isinstance(dest, str): + return source[:-1] + (dest,) + return source[:-1] + (dest[-1],) + + +class DoneTransform(Transform): + """Expands the 'done' entries (incl. terminated) to match the reward shape. + + Can be appended to a replay buffer or a collector. 
+ """ + + def __init__(self, reward_key, done_keys): + super().__init__() + self.reward_key = reward_key + self.done_keys = done_keys + + def forward(self, tensordict): + for done_key in self.done_keys: + new_name = swap_last(self.reward_key, done_key) + tensordict.set( + ("next", new_name), + tensordict.get(("next", done_key)) + .unsqueeze(-1) + .expand(tensordict.get(("next", self.reward_key)).shape), + ) + return tensordict diff --git a/test/test_cost.py b/test/test_cost.py index f7333f4a0ba..c3d4e0b8086 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -351,6 +351,7 @@ def _create_mock_data_dqn( action = torch.argmax(action, -1, keepdim=False) reward = torch.randn(batch, 1) done = torch.zeros(batch, 1, dtype=torch.bool) + terminated = torch.zeros(batch, 1, dtype=torch.bool) td = TensorDict( batch_size=(batch,), source={ @@ -358,6 +359,7 @@ def _create_mock_data_dqn( "next": { "observation": next_obs, "done": done, + "terminated": terminated, "reward": reward, }, action_key: action, @@ -395,6 +397,7 @@ def _create_seq_mock_data_dqn( # action_value = action_value.unsqueeze(-1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) if action_spec_type == "categorical": action_value = torch.max(action_value, -1, keepdim=True)[0] @@ -409,6 +412,7 @@ def _create_seq_mock_data_dqn( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -555,6 +559,7 @@ def test_dqn_tensordict_keys(self, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test(loss_fn, default_keys=default_keys) @@ -565,6 +570,7 @@ def test_dqn_tensordict_keys(self, td_est): "value_target": ("value_target", ("value_target", "nested")), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -671,7 +677,10 @@ def test_distributional_dqn( @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) - def test_dqn_notensordict(self, observation_key, reward_key, done_key): + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_dqn_notensordict( + self, observation_key, reward_key, done_key, terminated_key + ): n_obs = 3 n_action = 4 action_spec = OneHotDiscreteTensorSpec(n_action) @@ -683,18 +692,20 @@ def test_dqn_notensordict(self, observation_key, reward_key, done_key): in_keys=[observation_key], ) dqn_loss = DQNLoss(actor) - dqn_loss.set_keys(reward=reward_key, done=done_key) + dqn_loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) # define data observation = torch.randn(n_obs) next_observation = torch.randn(n_obs) action = action_spec.rand() next_reward = torch.randn(1) next_done = torch.zeros(1, dtype=torch.bool) + next_terminated = torch.zeros(1, dtype=torch.bool) kwargs = { observation_key: observation, f"next_{observation_key}": next_observation, f"next_{reward_key}": next_reward, f"next_{done_key}": next_done, + f"next_{terminated_key}": next_terminated, "action": action, } 
td = TensorDict(kwargs, []).unflatten_keys("_") @@ -719,6 +730,7 @@ def test_distributional_dqn_tensordict_keys(self): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", "steps_to_next_obs": "steps_to_next_obs", } @@ -851,6 +863,7 @@ def _create_mock_data_dqn( reward = torch.randn(*batch, 1, device=device) done = torch.zeros(*batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(*batch, 1, dtype=torch.bool, device=device) td = TensorDict( { "agents": TensorDict( @@ -872,6 +885,7 @@ def _create_mock_data_dqn( "state": next_state, "reward": reward, "done": done, + "terminated": terminated, }, batch_size=batch, device=device, @@ -1050,6 +1064,7 @@ def test_qmix_tensordict_keys(self, td_est): "action": ("agents", "action"), "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test(loss_fn, default_keys=default_keys) @@ -1060,6 +1075,7 @@ def test_qmix_tensordict_keys(self, td_est): "value_target": ("value_target", ("value_target", "nested")), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -1153,6 +1169,7 @@ def test_mixer_keys( "state": torch.zeros(32, 64, 64, 3), "reward": torch.zeros(32, 1), "done": torch.zeros(32, 1, dtype=torch.bool), + "terminated": torch.zeros(32, 1, dtype=torch.bool), }, [32], ), @@ -1255,10 +1272,12 @@ def _create_mock_common_layer_setup( "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -1283,6 +1302,7 @@ def _create_mock_data_ddpg( device="cpu", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -1294,6 +1314,7 @@ def _create_mock_data_ddpg( reward = torch.randn(batch, 1, device=device) state = torch.randn(batch, state_dim, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -1303,6 +1324,7 @@ def _create_mock_data_ddpg( "observation": next_obs, "state": state, done_key: done, + terminated_key: terminated, reward_key: reward, }, "action": action, @@ -1322,6 +1344,7 @@ def _create_seq_mock_data_ddpg( device="cpu", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -1339,6 +1362,7 @@ def _create_seq_mock_data_ddpg( reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -1349,6 +1373,7 @@ def _create_seq_mock_data_ddpg( "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "state": next_state.masked_fill_(~mask.unsqueeze(-1), 0.0), done_key: done, + terminated_key: terminated, reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -1656,6 +1681,7 @@ def 
test_ddpg_tensordict_keys(self, td_est): default_keys = { "reward": "reward", "done": "done", + "terminated": "terminated", "state_action_value": "state_action_value", "priority": "td_error", } @@ -1676,6 +1702,7 @@ def test_ddpg_tensordict_keys(self, td_est): "state_action_value": ("value", "state_action_value_test"), "reward": ("reward", "reward2"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -1691,12 +1718,15 @@ def test_ddpg_tensordict_run(self, td_est): "priority": "td_error_test", "reward": "reward_test", "done": ("done", "test"), + "terminated": ("terminated", "test"), } actor = self._create_mock_actor() value = self._create_mock_value(out_keys=[tensor_keys["state_action_value"]]) td = self._create_mock_data_ddpg( - reward_key="reward_test", done_key=("done", "test") + reward_key="reward_test", + done_key=("done", "test"), + terminated_key=("terminated", "test"), ) loss_fn = DDPGLoss( actor, @@ -1724,6 +1754,7 @@ def test_ddpg_notensordict(self): "observation": td.get("observation"), "next_reward": td.get(("next", "reward")), "next_done": td.get(("next", "done")), + "next_terminated": td.get(("next", "terminated")), "next_observation": td.get(("next", "observation")), "action": td.get("action"), "state": td.get("state"), @@ -1835,10 +1866,12 @@ def _create_mock_common_layer_setup( "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -1872,6 +1905,7 @@ def _create_mock_data_td3( observation_key="observation", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -1882,6 +1916,7 @@ def _create_mock_data_td3( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -1889,6 +1924,7 @@ def _create_mock_data_td3( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -1912,6 +1948,7 @@ def _create_seq_mock_data_td3( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -1921,6 +1958,7 @@ def _create_seq_mock_data_td3( "observation": next_obs * mask.to(obs.dtype), "reward": reward * mask.to(obs.dtype), "done": done, + "terminated": terminated, }, "collector": {"mask": mask}, "action": action * mask.to(obs.dtype), @@ -2293,6 +2331,7 @@ def test_td3_tensordict_keys(self, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -2311,6 +2350,7 @@ def test_td3_tensordict_keys(self, td_est): "state_action_value": ("value", "state_action_value_test"), "reward": ("reward", 
"reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -2344,22 +2384,29 @@ def test_constructor(self, spec, bounds): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) - def test_td3_notensordict(self, observation_key, reward_key, done_key): + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_td3_notensordict( + self, observation_key, reward_key, done_key, terminated_key + ): torch.manual_seed(self.seed) actor = self._create_mock_actor(in_keys=[observation_key]) qvalue = self._create_mock_value( observation_key=observation_key, out_keys=["state_action_value"] ) td = self._create_mock_data_td3( - observation_key=observation_key, reward_key=reward_key, done_key=done_key + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, ) loss = TD3Loss(actor, qvalue, action_spec=actor.spec) - loss.set_keys(reward=reward_key, done=done_key) + loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) kwargs = { observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), "action": td.get("action"), } @@ -2492,10 +2539,12 @@ def _create_mock_common_layer_setup( "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -2532,6 +2581,7 @@ def _create_mock_data_sac( observation_key="observation", action_key="action", done_key="done", + terminated_key="terminated", reward_key="reward", ): # create a tensordict @@ -2543,6 +2593,7 @@ def _create_mock_data_sac( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -2550,6 +2601,7 @@ def _create_mock_data_sac( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -2573,6 +2625,7 @@ def _create_seq_mock_data_sac( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -2581,6 +2634,7 @@ def _create_seq_mock_data_sac( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -3090,6 +3144,7 @@ def test_sac_tensordict_keys(self, td_est, version): "log_prob": "_log_prob", "reward": "reward", "done": "done", + 
"terminated": "terminated", } self.tensordict_keys_test( @@ -3109,6 +3164,7 @@ def test_sac_tensordict_keys(self, td_est, version): "value": ("value", "state_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -3116,8 +3172,9 @@ def test_sac_tensordict_keys(self, td_est, version): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) def test_sac_notensordict( - self, action_key, observation_key, reward_key, done_key, version + self, action_key, observation_key, reward_key, done_key, terminated_key, version ): torch.manual_seed(self.seed) td = self._create_mock_data_sac( @@ -3125,6 +3182,7 @@ def test_sac_notensordict( observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor( @@ -3145,13 +3203,19 @@ def test_sac_notensordict( qvalue_network=qvalue, value_network=value, ) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { action_key: td.get(action_key), observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -3259,6 +3323,7 @@ def _create_mock_data_sac( observation_key="observation", action_key="action", done_key="done", + terminated_key="terminated", reward_key="reward", ): # create a tensordict @@ -3276,6 +3341,7 @@ def _create_mock_data_sac( action = (action_value == action_value.max(-1, True)[0]).to(torch.long) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -3283,6 +3349,7 @@ def _create_mock_data_sac( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -3311,6 +3378,7 @@ def _create_seq_mock_data_sac( reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -3319,6 +3387,7 @@ def _create_seq_mock_data_sac( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -3632,6 +3701,7 @@ def test_discrete_sac_tensordict_keys(self, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( loss_fn, @@ -3651,6 +3721,7 @@ def test_discrete_sac_tensordict_keys(self, td_est): "value": ("value", "state_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", 
"test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -3658,8 +3729,9 @@ def test_discrete_sac_tensordict_keys(self, td_est): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) def test_discrete_sac_notensordict( - self, action_key, observation_key, reward_key, done_key + self, action_key, observation_key, reward_key, done_key, terminated_key ): torch.manual_seed(self.seed) td = self._create_mock_data_sac( @@ -3667,6 +3739,7 @@ def test_discrete_sac_notensordict( observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor( @@ -3681,13 +3754,19 @@ def test_discrete_sac_notensordict( qvalue_network=qvalue, num_actions=actor.spec[action_key].space.n, ) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { action_key: td.get(action_key), observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -3801,10 +3880,12 @@ def _create_mock_common_layer_setup( "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -3881,6 +3962,7 @@ def _create_mock_data_redq( action_key="action", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -3891,6 +3973,7 @@ def _create_mock_data_redq( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -3898,6 +3981,7 @@ def _create_mock_data_redq( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -3921,6 +4005,7 @@ def _create_seq_mock_data_redq( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -3929,6 +4014,7 @@ def _create_seq_mock_data_redq( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -4497,6 +4583,7 @@ def test_redq_tensordict_keys(self, td_est): "state_action_value": "state_action_value", "reward": "reward", "done": "done", + "terminated": 
"terminated", } self.tensordict_keys_test( loss_fn, @@ -4515,6 +4602,7 @@ def test_redq_tensordict_keys(self, td_est): "value": ("value", "state_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -4522,9 +4610,10 @@ def test_redq_tensordict_keys(self, td_est): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) @pytest.mark.parametrize("deprec", [True, False]) def test_redq_notensordict( - self, action_key, observation_key, reward_key, done_key, deprec + self, action_key, observation_key, reward_key, done_key, terminated_key, deprec ): torch.manual_seed(self.seed) td = self._create_mock_data_redq( @@ -4532,6 +4621,7 @@ def test_redq_notensordict( observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor( @@ -4552,13 +4642,19 @@ def test_redq_notensordict( actor_network=actor, qvalue_network=qvalue, ) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { action_key: td.get(action_key), observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -4657,6 +4753,7 @@ def _create_mock_data_cql( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -4664,6 +4761,7 @@ def _create_mock_data_cql( "next": { "observation": next_obs, "done": done, + "terminated": terminated, "reward": reward, }, "action": action, @@ -4687,6 +4785,7 @@ def _create_seq_mock_data_cql( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -4695,6 +4794,7 @@ def _create_seq_mock_data_cql( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -5129,6 +5229,7 @@ def _create_mock_data_ppo( action_key="action", reward_key="reward", done_key="done", + terminated_key="terminated", sample_log_prob_key="sample_log_prob", ): # create a tensordict @@ -5140,6 +5241,7 @@ def _create_mock_data_ppo( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ 
@@ -5147,6 +5249,7 @@ def _create_mock_data_ppo( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -5179,6 +5282,7 @@ def _create_seq_mock_data_ppo( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 @@ -5189,6 +5293,7 @@ def _create_seq_mock_data_ppo( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -5522,6 +5627,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -5540,6 +5646,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est): "value": ("value", value_key), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -5644,6 +5751,7 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) def test_ppo_notensordict( self, loss_class, @@ -5652,6 +5760,7 @@ def test_ppo_notensordict( observation_key, reward_key, done_key, + terminated_key, ): torch.manual_seed(self.seed) td = self._create_mock_data_ppo( @@ -5660,6 +5769,7 @@ def test_ppo_notensordict( sample_log_prob_key=sample_log_prob_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor(observation_key=observation_key) @@ -5670,6 +5780,7 @@ def test_ppo_notensordict( action=action_key, reward=reward_key, done=done_key, + terminated=terminated_key, sample_log_prob=sample_log_prob_key, ) @@ -5679,6 +5790,7 @@ def test_ppo_notensordict( sample_log_prob_key: td.get(sample_log_prob_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size, names=["time"]).unflatten_keys("_") @@ -5781,10 +5893,12 @@ def _create_mock_common_layer_setup( "action": torch.randn(*batch, n_act), "sample_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -5820,6 +5934,7 @@ def _create_seq_mock_data_a2c( observation_key="observation", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -5833,6 +5948,7 @@ def _create_seq_mock_data_a2c( action = torch.randn(batch, T, 
action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 @@ -5843,6 +5959,7 @@ def _create_seq_mock_data_a2c( "next": { observation_key: next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), done_key: done, + terminated_key: terminated, reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -6080,6 +6197,7 @@ def test_a2c_tensordict_keys(self, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -6098,6 +6216,7 @@ def test_a2c_tensordict_keys(self, td_est): "value": ("value", "value_state_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -6112,12 +6231,14 @@ def test_a2c_tensordict_keys_run(self, device): action_key = "action_test" reward_key = "reward_test" done_key = ("done", "test") + terminated_key = ("terminated", "test") td = self._create_seq_mock_data_a2c( device=device, action_key=action_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor(device=device) @@ -6134,6 +6255,7 @@ def test_a2c_tensordict_keys_run(self, device): value=value_key, reward=reward_key, done=done_key, + terminated=terminated_key, ) loss_fn = A2CLoss(actor, value, loss_critic_type="l2") loss_fn.set_keys( @@ -6143,6 +6265,7 @@ def test_a2c_tensordict_keys_run(self, device): action=action_key, reward=reward_key, done=done_key, + terminated=done_key, ) advantage(td) @@ -6179,7 +6302,10 @@ def test_a2c_tensordict_keys_run(self, device): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) - def test_a2c_notensordict(self, action_key, observation_key, reward_key, done_key): + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_a2c_notensordict( + self, action_key, observation_key, reward_key, done_key, terminated_key + ): torch.manual_seed(self.seed) actor = self._create_mock_actor(observation_key=observation_key) @@ -6189,16 +6315,23 @@ def test_a2c_notensordict(self, action_key, observation_key, reward_key, done_ke observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) loss = A2CLoss(actor, value) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { observation_key: td.get(observation_key), f"next_{observation_key}": td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), action_key: td.get(action_key), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -6289,6 +6422,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est "observation": torch.randn(batch, n_obs), "reward": torch.randn(batch, 1), "done": torch.zeros(batch, 1, dtype=torch.bool), 
+ "terminated": torch.zeros(batch, 1, dtype=torch.bool), }, "action": torch.randn(batch, n_act), }, @@ -6372,6 +6506,7 @@ def test_reinforce_tensordict_keys(self, td_est): "sample_log_prob": "sample_log_prob", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -6395,6 +6530,7 @@ def test_reinforce_tensordict_keys(self, td_est): "value": ("value", "state_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -6427,10 +6563,12 @@ def _create_mock_common_layer_setup( "action": torch.randn(*batch, n_act), "sample_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -6527,8 +6665,9 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) def test_reinforce_notensordict( - self, action_key, observation_key, reward_key, done_key + self, action_key, observation_key, reward_key, done_key, terminated_key ): torch.manual_seed(self.seed) n_obs = 3 @@ -6547,19 +6686,26 @@ def test_reinforce_notensordict( spec=UnboundedContinuousTensorSpec(n_act), ) loss = ReinforceLoss(actor=actor_net, critic=value_net) - loss.set_keys(reward=reward_key, done=done_key, action=action_key) + loss.set_keys( + reward=reward_key, + done=done_key, + action=action_key, + terminated=terminated_key, + ) observation = torch.randn(batch, n_obs) action = torch.randn(batch, n_act) next_reward = torch.randn(batch, 1) next_observation = torch.randn(batch, n_obs) next_done = torch.zeros(batch, 1, dtype=torch.bool) + next_terminated = torch.zeros(batch, 1, dtype=torch.bool) kwargs = { action_key: action, observation_key: observation, f"next_{reward_key}": next_reward, f"next_{done_key}": next_done, + f"next_{terminated_key}": next_terminated, f"next_{observation_key}": next_observation, } td = TensorDict(kwargs, [batch]).unflatten_keys("_") @@ -6600,6 +6746,9 @@ def _create_world_model_data( ), "reward": torch.randn(batch_size, temporal_length, 1), "done": torch.zeros(batch_size, temporal_length, dtype=torch.bool), + "terminated": torch.zeros( + batch_size, temporal_length, dtype=torch.bool + ), }, "action": torch.randn(batch_size, temporal_length, 64), }, @@ -7024,6 +7173,7 @@ def test_dreamer_actor_tensordict_keys(self, td_est, device): "reward": "reward", "value": "state_value", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( loss_fn, @@ -7529,10 +7679,12 @@ def _create_mock_common_layer_setup( "action": torch.randn(*batch, n_act), "sample_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -7579,6 +7731,7 @@ def _create_mock_data_iql( 
observation_key="observation", action_key="action", done_key="done", + terminated_key="terminated", reward_key="reward", ): # create a tensordict @@ -7590,6 +7743,7 @@ def _create_mock_data_iql( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -7597,6 +7751,7 @@ def _create_mock_data_iql( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -7620,6 +7775,7 @@ def _create_seq_mock_data_iql( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -7628,6 +7784,7 @@ def _create_seq_mock_data_iql( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -8085,6 +8242,7 @@ def test_iql_tensordict_keys(self, td_est): "value": "state_value", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -8104,6 +8262,7 @@ def test_iql_tensordict_keys(self, td_est): key_mapping = { "value": ("value", "value_test"), "done": ("done", "done_test"), + "terminated": ("terminated", "terminated_test"), "reward": ("reward", ("reward", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -8112,13 +8271,17 @@ def test_iql_tensordict_keys(self, td_est): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) - def test_iql_notensordict(self, action_key, observation_key, reward_key, done_key): + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_iql_notensordict( + self, action_key, observation_key, reward_key, done_key, terminated_key + ): torch.manual_seed(self.seed) td = self._create_mock_data_iql( action_key=action_key, observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor(observation_key=observation_key) @@ -8130,13 +8293,19 @@ def test_iql_notensordict(self, action_key, observation_key, reward_key, done_ke value = self._create_mock_value(observation_key=observation_key) loss = IQLLoss(actor_network=actor, qvalue_network=qvalue, value_network=value) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { action_key: td.get(action_key), observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -8454,24 +8623,77 @@ class TestValues: @pytest.mark.parametrize("gamma", [0.1, 0.5, 0.99]) @pytest.mark.parametrize("lmbda", 
[0.1, 0.5, 0.99]) @pytest.mark.parametrize("N", [(3,), (7, 3)]) - @pytest.mark.parametrize("T", [3, 5, 200]) + @pytest.mark.parametrize("T", [200, 5, 3]) # @pytest.mark.parametrize("random_gamma,rolling_gamma", [[True, False], [True, True], [False, None]]) @pytest.mark.parametrize("random_gamma,rolling_gamma", [[False, None]]) def test_tdlambda(self, device, gamma, lmbda, N, T, random_gamma, rolling_gamma): torch.manual_seed(0) - done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool).bernoulli_(0.1) + done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone().bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) - next_state_value = torch.randn(*N, T, 1, device=device) if random_gamma: gamma = torch.rand_like(reward) * gamma + next_state_value = torch.cat( + [state_value[..., 1:, :], torch.randn_like(state_value[..., -1:, :])], -2 + ) r1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done, rolling_gamma + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) r2 = td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done, rolling_gamma + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + ) + r3, *_ = vec_generalized_advantage_estimate( + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + ) + torch.testing.assert_close(r3, r2, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(r1, r3, rtol=1e-4, atol=1e-4) + + # test when v' is not v from next step (not working with gae) + next_state_value = torch.randn_like(next_state_value) + r1 = vec_td_lambda_advantage_estimate( + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + ) + r2 = td_lambda_advantage_estimate( + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) @@ -8488,7 +8710,9 @@ def test_tdlambda_multi( torch.manual_seed(0) D = feature_dim time_dim = -1 - len(D) - done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool).bernoulli_(0.1) + done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool) + terminated = done.clone().bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, *D, device=device) state_value = torch.randn(*N, T, *D, device=device) next_state_value = torch.randn(*N, T, *D, device=device) @@ -8501,8 +8725,9 @@ def test_tdlambda_multi( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) r2 = td_lambda_advantage_estimate( @@ -8511,8 +8736,9 @@ def test_tdlambda_multi( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) if len(D) == 2: @@ -8524,8 +8750,9 @@ def test_tdlambda_multi( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], - rolling_gamma, + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + 
rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8541,8 +8768,9 @@ def test_tdlambda_multi( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], - rolling_gamma, + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8559,8 +8787,9 @@ def test_tdlambda_multi( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], - rolling_gamma, + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8575,8 +8804,9 @@ def test_tdlambda_multi( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], - rolling_gamma, + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8596,7 +8826,9 @@ def test_tdlambda_multi( def test_td1(self, device, gamma, N, T, random_gamma, rolling_gamma): torch.manual_seed(0) - done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool).bernoulli_(0.1) + done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone().bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -8604,10 +8836,22 @@ def test_td1(self, device, gamma, N, T, random_gamma, rolling_gamma): gamma = torch.rand_like(reward) * gamma r1 = vec_td1_advantage_estimate( - gamma, state_value, next_state_value, reward, done, rolling_gamma + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) r2 = td1_advantage_estimate( - gamma, state_value, next_state_value, reward, done, rolling_gamma + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) @@ -8624,7 +8868,9 @@ def test_td1_multi( D = feature_dim time_dim = -1 - len(D) - done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool).bernoulli_(0.1) + done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool) + terminated = done.clone().bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, *D, device=device) state_value = torch.randn(*N, T, *D, device=device) next_state_value = torch.randn(*N, T, *D, device=device) @@ -8636,8 +8882,9 @@ def test_td1_multi( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) r2 = td1_advantage_estimate( @@ -8645,8 +8892,9 @@ def test_td1_multi( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) if len(D) == 2: @@ -8657,8 +8905,9 @@ def test_td1_multi( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], - rolling_gamma, + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8673,8 +8922,9 @@ def test_td1_multi( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - 
done[..., i : i + 1, j], - rolling_gamma, + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8690,8 +8940,9 @@ def test_td1_multi( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], - rolling_gamma, + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8705,8 +8956,9 @@ def test_td1_multi( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], - rolling_gamma, + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8724,22 +8976,36 @@ def test_td1_multi( @pytest.mark.parametrize("N", [(1,), (3,), (7, 3)]) @pytest.mark.parametrize("T", [200, 5, 3]) @pytest.mark.parametrize("dtype", [torch.float, torch.double]) - @pytest.mark.parametrize("has_done", [True, False]) + @pytest.mark.parametrize("has_done", [False, True]) def test_gae(self, device, gamma, lmbda, N, T, dtype, has_done): torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone() if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device, dtype=dtype) state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) next_state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) r1 = vec_generalized_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) r2 = generalized_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) @@ -8764,8 +9030,10 @@ def test_gae_param_as_tensor( T = 200 done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone() if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device, dtype=dtype) state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) next_state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) @@ -8785,10 +9053,22 @@ def test_gae_param_as_tensor( lmbda_vec = lmbda r1 = vec_generalized_advantage_estimate( - gamma_vec, lmbda_vec, state_value, next_state_value, reward, done + gamma_vec, + lmbda_vec, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) r2 = generalized_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) @@ -8809,8 +9089,10 @@ def test_gae_multidim( torch.manual_seed(0) done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool) + terminated = done.clone() if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, *D, device=device, dtype=dtype) state_value = torch.randn(*N, T, *D, device=device, dtype=dtype) next_state_value = torch.randn(*N, T, *D, device=device, 
dtype=dtype) @@ -8821,7 +9103,8 @@ def test_gae_multidim( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, time_dim=time_dim, ) r2 = generalized_advantage_estimate( @@ -8830,7 +9113,8 @@ def test_gae_multidim( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, time_dim=time_dim, ) if len(D) == 2: @@ -8841,7 +9125,8 @@ def test_gae_multidim( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], time_dim=-2, ) for i in range(D[0]) @@ -8854,7 +9139,8 @@ def test_gae_multidim( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + done=done[..., i : i + 1, j], time_dim=-2, ) for i in range(D[0]) @@ -8868,7 +9154,8 @@ def test_gae_multidim( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], time_dim=-2, ) for i in range(D[0]) @@ -8880,7 +9167,8 @@ def test_gae_multidim( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], time_dim=-2, ) for i in range(D[0]) @@ -8901,7 +9189,7 @@ def test_gae_multidim( @pytest.mark.parametrize("gamma", [0.5, 0.99, 0.1]) @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99]) @pytest.mark.parametrize("N", [(3,), (7, 3)]) - @pytest.mark.parametrize("T", [3, 5, 200]) + @pytest.mark.parametrize("T", [200, 5, 3]) @pytest.mark.parametrize("has_done", [True, False]) def test_tdlambda_tensor_gamma(self, device, gamma, lmbda, N, T, has_done): """Tests vec_td_lambda_advantage_estimate against itself with @@ -8911,32 +9199,61 @@ def test_tdlambda_tensor_gamma(self, device, gamma, lmbda, N, T, has_done): torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) gamma_tensor = torch.full((*N, T, 1), gamma, device=device) - + # if len(N) == 2: + # print(terminated[4, 0, -10:]) + # print(done[4, 0, -10:]) v1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_tensor, lmbda, state_value, next_state_value, reward, done + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory gamma_tensor[..., -1, :] = 0.0 v1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_tensor, lmbda, state_value, next_state_value, reward, 
done + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -8962,8 +9279,10 @@ def test_tdlambda_tensor_gamma_single_element( torch.manual_seed(0) done = torch.zeros(*N, T, F, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, F, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, F, device=device) state_value = torch.randn(*N, T, F, device=device) next_state_value = torch.randn(*N, T, F, device=device) @@ -8981,22 +9300,47 @@ def test_tdlambda_tensor_gamma_single_element( lmbda_vec = lmbda v1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_vec, lmbda_vec, state_value, next_state_value, reward, done + gamma_vec, + lmbda_vec, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory v1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_vec, lmbda_vec, state_value, next_state_value, reward, done + gamma_vec, + lmbda_vec, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9014,8 +9358,10 @@ def test_td1_tensor_gamma(self, device, gamma, N, T, has_done): torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -9023,23 +9369,44 @@ def test_td1_tensor_gamma(self, device, gamma, N, T, has_done): gamma_tensor = torch.full((*N, T, 1), gamma, device=device) v1 = vec_td1_advantage_estimate( - gamma, state_value, next_state_value, reward, done + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td1_advantage_estimate( - gamma_tensor, state_value, next_state_value, reward, done + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory gamma_tensor[..., -1, :] = 0.0 v1 = vec_td1_advantage_estimate( - gamma, state_value, next_state_value, reward, done + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td1_advantage_estimate( - gamma_tensor, state_value, next_state_value, reward, done + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, 
atol=1e-4) @@ -9061,8 +9428,10 @@ def test_vectdlambda_tensor_gamma( torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -9070,23 +9439,48 @@ def test_vectdlambda_tensor_gamma( gamma_tensor = torch.full((*N, T, 1), gamma, device=device) v1 = td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_tensor, lmbda, state_value, next_state_value, reward, done + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory gamma_tensor[..., -1, :] = 0.0 v1 = td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_tensor, lmbda, state_value, next_state_value, reward, done + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9107,28 +9501,55 @@ def test_vectd1_tensor_gamma( torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) gamma_tensor = torch.full((*N, T, 1), gamma, device=device) - v1 = td1_advantage_estimate(gamma, state_value, next_state_value, reward, done) + v1 = td1_advantage_estimate( + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + ) v2 = vec_td1_advantage_estimate( - gamma_tensor, state_value, next_state_value, reward, done + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory gamma_tensor[..., -1, :] = 0.0 - v1 = td1_advantage_estimate(gamma, state_value, next_state_value, reward, done) + v1 = td1_advantage_estimate( + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + ) v2 = vec_td1_advantage_estimate( - gamma_tensor, state_value, next_state_value, reward, done + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9150,8 +9571,10 @@ def test_vectdlambda_rand_gamma( torch.manual_seed(seed) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, 
device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -9165,8 +9588,9 @@ def test_vectdlambda_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) if rolling_gamma is False and not done[..., 1:, :][done[..., :-1, :]].all(): # if a not-done follows a done, then rolling_gamma=False cannot be used @@ -9179,8 +9603,24 @@ def test_vectdlambda_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + ) + return + elif rolling_gamma is False: + with pytest.raises( + NotImplementedError, match=r"The vectorized version of TD" + ): + vec_td_lambda_advantage_estimate( + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) return v2 = vec_td_lambda_advantage_estimate( @@ -9189,8 +9629,9 @@ def test_vectdlambda_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9210,8 +9651,10 @@ def test_vectd1_rand_gamma( torch.manual_seed(seed) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -9224,10 +9667,14 @@ def test_vectd1_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) - if rolling_gamma is False and not done[..., 1:, :][done[..., :-1, :]].all(): + if ( + rolling_gamma is False + and not terminated[..., 1:, :][terminated[..., :-1, :]].all() + ): # if a not-done follows a done, then rolling_gamma=False cannot be used with pytest.raises( NotImplementedError, match="When using rolling_gamma=False" @@ -9237,8 +9684,23 @@ def test_vectd1_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + ) + return + elif rolling_gamma is False: + with pytest.raises( + NotImplementedError, match="The vectorized version of TD" + ): + vec_td1_advantage_estimate( + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) return v2 = vec_td1_advantage_estimate( @@ -9246,8 +9708,9 @@ def test_vectd1_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9299,8 +9762,10 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): lmbda = torch.rand([]).item() - done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) - done[..., T // 2 - 1, :] = 1 + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated[..., T // 2 - 1, :] = 1 + done = 
terminated.clone() + done[..., -1, :] = 1 reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) @@ -9315,8 +9780,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) v1a = td_lambda_advantage_estimate( gamma_tensor[..., : T // 2, :], @@ -9324,8 +9790,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], - rolling_gamma, + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], + rolling_gamma=rolling_gamma, ) v1b = td_lambda_advantage_estimate( gamma_tensor[..., T // 2 :, :], @@ -9333,8 +9800,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], - rolling_gamma, + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(v1, torch.cat([v1a, v1b], -2), rtol=1e-4, atol=1e-4) @@ -9348,8 +9816,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) return v2 = vec_td_lambda_advantage_estimate( @@ -9358,8 +9827,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) v2a = vec_td_lambda_advantage_estimate( gamma_tensor[..., : T // 2, :], @@ -9367,8 +9837,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], - rolling_gamma, + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], + rolling_gamma=rolling_gamma, ) v2b = vec_td_lambda_advantage_estimate( gamma_tensor[..., T // 2 :, :], @@ -9376,8 +9847,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], - rolling_gamma, + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9389,22 +9861,17 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("N", [(3,), (3, 7)]) @pytest.mark.parametrize("T", [3, 5, 200]) - def test_successive_traj_tdadv( - self, - device, - N, - T, - ): + def test_successive_traj_tdadv(self, device, N, T): """Tests td_lambda_advantage_estimate against vec_td_lambda_advantage_estimate with gamma being a random tensor """ torch.manual_seed(0) - lmbda = torch.rand([]).item() - + # for td0, a done that is not terminated has no effect done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) done[..., T // 2 - 1, :] = 1 + terminated = done.clone() reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) @@ -9418,21 +9885,24 @@ def test_successive_traj_tdadv( state_value, next_state_value, reward, - done, + done=done, + 
terminated=terminated, ) v1a = td0_advantage_estimate( gamma_tensor[..., : T // 2, :], state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], ) v1b = td0_advantage_estimate( gamma_tensor[..., T // 2 :, :], state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], ) torch.testing.assert_close(v1, torch.cat([v1a, v1b], -2), rtol=1e-4, atol=1e-4) @@ -9453,8 +9923,10 @@ def test_successive_traj_gae( lmbda = torch.rand([]).item() - done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) - done[..., T // 2 - 1, :] = 1 + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated[..., T // 2 - 1, :] = 1 + done = terminated.clone() + done[..., -1, :] = 1 reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) @@ -9469,7 +9941,8 @@ def test_successive_traj_gae( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, )[0] v1a = generalized_advantage_estimate( gamma_tensor, @@ -9477,7 +9950,8 @@ def test_successive_traj_gae( state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], )[0] v1b = generalized_advantage_estimate( gamma_tensor, @@ -9485,7 +9959,8 @@ def test_successive_traj_gae( state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], )[0] torch.testing.assert_close(v1, torch.cat([v1a, v1b], -2), rtol=1e-4, atol=1e-4) @@ -9495,7 +9970,8 @@ def test_successive_traj_gae( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, )[0] v2a = vec_generalized_advantage_estimate( gamma_tensor, @@ -9503,7 +9979,8 @@ def test_successive_traj_gae( state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], )[0] v2b = vec_generalized_advantage_estimate( gamma_tensor, @@ -9511,7 +9988,8 @@ def test_successive_traj_gae( state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], )[0] torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) torch.testing.assert_close(v2, torch.cat([v2a, v2b], -2), rtol=1e-4, atol=1e-4) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index ea2c715d927..bb7b9014f0d 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -97,6 +97,7 @@ class A2CLoss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": action, ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -114,7 +115,7 @@ class A2CLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. 
In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor and critic. + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and critic. The return value is a tuple of tensors in the following order: ``["loss_objective"]`` + ``["loss_critic"]`` if critic_coef is not None @@ -148,6 +149,7 @@ class A2CLoss(LossModule): ... observation = torch.randn(*batch, n_obs), ... action = spec.rand(batch), ... next_done = torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated = torch.zeros(*batch, 1, dtype=torch.bool), ... next_reward = torch.randn(*batch, 1), ... next_observation = torch.randn(*batch, n_obs)) >>> loss_obj.backward() @@ -161,6 +163,7 @@ class A2CLoss(LossModule): ... observation = torch.randn(*batch, n_obs), ... action = spec.rand(batch), ... next_done = torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated = torch.zeros(*batch, 1, dtype=torch.bool), ... next_reward = torch.randn(*batch, 1), ... next_observation = torch.randn(*batch, n_obs)) >>> loss_obj.backward() @@ -187,6 +190,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. """ advantage: NestedKey = "advantage" @@ -195,6 +201,7 @@ class _AcceptedKeys: action: NestedKey = "action" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator: ValueEstimators = ValueEstimators.GAE @@ -251,6 +258,7 @@ def in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor.in_keys, *[("next", key) for key in self.actor.in_keys], ] @@ -282,6 +290,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) def reset(self) -> None: @@ -389,5 +398,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value_target": self.tensor_keys.value_target, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index b24d4498106..249166a6bd2 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -126,6 +126,7 @@ class CQLLoss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": action, ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -145,7 +146,7 @@ class CQLLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor, value, and qvalue network. 
+ ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network. The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_qvalue", "loss_alpha", "loss_alpha_prime", "alpha", "entropy"]``. @@ -184,6 +185,7 @@ class CQLLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -197,6 +199,7 @@ class CQLLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -229,6 +232,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -392,6 +396,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): @@ -431,6 +436,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) @@ -448,6 +454,7 @@ def in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index d67db859713..d72afb09f7b 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -69,6 +69,7 @@ class DDPGLoss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": spec.rand(batch), ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -88,7 +89,7 @@ class DDPGLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["next_reward", "next_done"]`` + in_keys of the actor_network and value_network. + ``["next_reward", "next_done", "next_terminated"]`` + in_keys of the actor_network and value_network. The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_value", "pred_value", "target_value", "pred_value_max", "target_value_max"]`` @@ -117,6 +118,7 @@ class DDPGLoss(LossModule): ... observation=torch.randn(n_obs), ... action=spec.rand(), ... next_done=torch.zeros(1, dtype=torch.bool), + ... next_terminated=torch.zeros(1, dtype=torch.bool), ... next_observation=torch.randn(n_obs), ... 
next_reward=torch.randn(1)) >>> loss_actor.backward() @@ -130,6 +132,7 @@ class DDPGLoss(LossModule): ... observation=torch.randn(n_obs), ... action=spec.rand(), ... next_done=torch.zeros(1, dtype=torch.bool), + ... next_terminated=torch.zeros(1, dtype=torch.bool), ... next_observation=torch.randn(n_obs), ... next_reward=torch.randn(1)) >>> loss_actor.backward() @@ -154,6 +157,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. """ @@ -161,6 +167,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator: ValueEstimators = ValueEstimators.TD0 @@ -232,6 +239,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.state_action_value, reward=self._tensor_keys.reward, done=self._tensor_keys.done, + terminated=self._tensor_keys.terminated, ) self._set_in_keys() @@ -239,6 +247,7 @@ def _set_in_keys(self): in_keys = { unravel_key(("next", self.tensor_keys.reward)), unravel_key(("next", self.tensor_keys.done)), + unravel_key(("next", self.tensor_keys.terminated)), *self.actor_in_keys, *[unravel_key(("next", key)) for key in self.actor_in_keys], *self.value_network.in_keys, @@ -264,7 +273,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: a priority to items in the tensordict. Args: - tensordict (TensorDictBase): a tensordict with keys ["done", "reward"] and the in_keys of the actor + tensordict (TensorDictBase): a tensordict with keys ["done", "terminated", "reward"] and the in_keys of the actor and value networks. Returns: @@ -360,6 +369,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.state_action_value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 02b82ff430c..696efbdc650 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -109,6 +109,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
""" action: NestedKey = "action" @@ -118,6 +121,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() delay_actor: bool = False @@ -248,6 +252,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -264,6 +269,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, @@ -434,6 +440,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 527af5bf481..225d5d553bd 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -71,6 +71,7 @@ class DQNLoss(LossModule): ... "action": spec.rand(batch), ... ("next", "observation"): torch.randn(*batch, n_obs), ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1) ... }, batch) >>> loss(data) @@ -84,7 +85,7 @@ class DQNLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["observation", "next_observation", "action", "next_reward", "next_done"]``, + ``["observation", "next_observation", "action", "next_reward", "next_done", "next_terminated"]``, and a single loss value is returned. Examples: @@ -103,11 +104,13 @@ class DQNLoss(LossModule): >>> action = action_spec.rand() >>> next_reward = torch.randn(1) >>> next_done = torch.zeros(1, dtype=torch.bool) + >>> next_terminated = torch.zeros(1, dtype=torch.bool) >>> loss_val = dqn_loss( ... observation=observation, ... next_observation=next_observation, ... next_reward=next_reward, ... next_done=next_done, + ... next_terminated=next_terminated, ... action=action) """ @@ -137,6 +140,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
""" advantage: NestedKey = "advantage" @@ -147,6 +153,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -212,6 +219,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -220,6 +228,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.value_network.in_keys, *[("next", key) for key in self.value_network.in_keys], ] @@ -260,6 +269,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) @@ -272,7 +282,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: Args: tensordict (TensorDictBase): a tensordict with keys ["action"] and the in_keys of - the value network (observations, "done", "reward" in a "next" tensordict). + the value network (observations, "done", "terminated", "reward" in a "next" tensordict). Returns: a tensor containing the DQN loss. @@ -363,6 +373,8 @@ class _AcceptedKeys: Defaults to ``"reward"``. done (NestedKey): The input tensordict key where the the flag if a trajectory is done is expected. Defaults to ``"done"``. + terminated (NestedKey): The input tensordict key where the the flag if a trajectory is done is expected. + Defaults to ``"terminated"``. steps_to_next_obs (NestedKey): The input tensordict key where the steps_to_next_obs is exptected. Defaults to ``"steps_to_next_obs"``. """ @@ -372,6 +384,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" steps_to_next_obs: NestedKey = "steps_to_next_obs" default_keys = _AcceptedKeys() @@ -442,6 +455,7 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: action = tensordict.get(self.tensor_keys.action) reward = tensordict.get(("next", self.tensor_keys.reward)) done = tensordict.get(("next", self.tensor_keys.done)) + terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done) steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, 1) discount = self.gamma**steps_to_next_obs @@ -489,12 +503,13 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: # Tz = R^n + (γ^n)z (accounting for terminal states) if isinstance(discount, torch.Tensor): discount = discount.to("cpu") - done = done.to("cpu") + # done = done.to("cpu") + terminated = terminated.to("cpu") reward = reward.to("cpu") support = support.to("cpu") pns_a = pns_a.to("cpu") - Tz = reward + (1 - done.to(reward.dtype)) * discount * support + Tz = reward + (1 - terminated.to(reward.dtype)) * discount * support if Tz.shape != torch.Size([batch_size, atoms]): raise RuntimeError( "Tz shape must be torch.Size([batch_size, atoms]), " diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index c4e02ff4a5a..7bdfde573fa 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -216,12 +216,15 @@ class _AcceptedKeys: Will be used for the underlying value estimator. 
Defaults to ``"state_value"``. done (NestedKey): The input tensordict key where the flag if a trajectory is done is expected ("next", done). Defaults to ``"done"``. + terminated (NestedKey): The input tensordict key where the flag if a + trajectory is terminated is expected ("next", terminated). Defaults to ``"terminated"``. """ belief: NestedKey = "belief" reward: NestedKey = "reward" value: NestedKey = "state_value" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TDLambda @@ -297,11 +300,13 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: def lambda_target(self, reward: torch.Tensor, value: torch.Tensor) -> torch.Tensor: done = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device) + terminated = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device) input_tensordict = TensorDict( { ("next", self.tensor_keys.reward): reward, ("next", self.tensor_keys.value): value, ("next", self.tensor_keys.done): done, + ("next", self.tensor_keys.terminated): terminated, }, [], ) diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 6ffff97c66a..966550e21e5 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -103,6 +103,7 @@ class IQLLoss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": action, ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -120,7 +121,7 @@ class IQLLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor, value, and qvalue network + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_qvalue", "loss_value", "entropy"]``. @@ -163,6 +164,7 @@ class IQLLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -177,6 +179,7 @@ class IQLLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -206,6 +209,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
""" value: NestedKey = "state_value" @@ -215,6 +221,7 @@ class _AcceptedKeys: state_action_value: NestedKey = "state_action_value" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -307,6 +314,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, @@ -336,6 +344,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -490,5 +499,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index f09fa3c0e03..00106571744 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -120,6 +120,7 @@ class QMixerLoss(LossModule): ... "state": torch.zeros(32, 64, 64, 3), ... "reward": torch.zeros(32, 1), ... "done": torch.zeros(32, 1, dtype=torch.bool), + ... "terminated": torch.zeros(32, 1, dtype=torch.bool), ... }, ... [32], ... ), @@ -162,6 +163,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. """ advantage: NestedKey = "advantage" @@ -173,6 +177,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -260,6 +265,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.global_value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -268,6 +274,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.global_value_network.in_keys, *[("next", key) for key in self.global_value_network.in_keys], ] @@ -312,6 +319,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.global_value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 06ccea7ff30..63d59e8210c 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -144,11 +144,11 @@ class PPOLoss(LossModule): >>> loss = PPOLoss(actor, value) >>> batch = [2, ] >>> action = spec.rand(batch) - >>> data = TensorDict({ - ... "observation": torch.randn(*batch, n_obs), + >>> data = TensorDict({"observation": torch.randn(*batch, n_obs), ... 
"action": action, ... "sample_log_prob": torch.randn_like(action[..., 1]), ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -166,7 +166,7 @@ class PPOLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "sample_log_prob", "next_reward", "next_done"]`` + in_keys of the actor and value network. + ``["action", "sample_log_prob", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and value network. The return value is a tuple of tensors in the following order: ``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set + ``"loss_critic"`` if critic_coef is not None. @@ -204,6 +204,7 @@ class PPOLoss(LossModule): ... action=action, ... sampleLogProb=torch.randn_like(action[..., 1]) / 10, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_reward=torch.randn(*batch, 1), ... next_observation=torch.randn(*batch, n_obs)) >>> loss_objective.backward() @@ -233,6 +234,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. """ advantage: NestedKey = "advantage" @@ -242,6 +246,7 @@ class _AcceptedKeys: action: NestedKey = "action" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.GAE @@ -304,6 +309,7 @@ def _set_in_keys(self): self.tensor_keys.sample_log_prob, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor.in_keys, *[("next", key) for key in self.actor.in_keys], *self.critic.in_keys, @@ -343,6 +349,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -471,6 +478,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value_target": self.tensor_keys.value_target, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index afafcbfd446..dd64a4bc033 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -123,6 +123,7 @@ class REDQLoss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": action, ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... 
}, batch) @@ -145,7 +146,7 @@ class REDQLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor and qvalue network + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy", "state_action_value_actor", "action_log_prob_actor", "next.state_value", "target_value",]``. @@ -186,6 +187,7 @@ class REDQLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_reward=torch.randn(*batch, 1), ... next_observation=torch.randn(*batch, n_obs)) >>> loss_actor.backward() @@ -214,6 +216,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. """ action: NestedKey = "action" @@ -223,6 +228,7 @@ class _AcceptedKeys: state_action_value: NestedKey = "state_action_value" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() delay_actor: bool = False @@ -377,6 +383,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -393,6 +400,7 @@ def _set_in_keys(self): self.tensor_keys.sample_log_prob, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, @@ -612,5 +620,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 7c314bace36..93910f1eebf 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -98,6 +98,7 @@ class ReinforceLoss(LossModule): ... "observation": torch.randn(batch, n_obs), ... "reward": torch.randn(batch, 1), ... "done": torch.zeros(batch, 1, dtype=torch.bool), + ... "terminated": torch.zeros(batch, 1, dtype=torch.bool), ... }, ... "action": torch.randn(batch, n_act), ... }, [batch]) @@ -113,7 +114,7 @@ class ReinforceLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor and critic network + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and critic network The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_value"]``. 
Examples: @@ -141,6 +142,7 @@ class ReinforceLoss(LossModule): ... next_observation=torch.randn(batch, n_obs), ... next_reward=torch.randn(batch, 1), ... next_done=torch.zeros(batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(batch, 1, dtype=torch.bool), ... action=torch.randn(batch, n_act),) >>> loss_actor.backward() @@ -169,6 +171,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. """ advantage: NestedKey = "advantage" @@ -178,6 +183,7 @@ class _AcceptedKeys: action: NestedKey = "action" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.GAE @@ -241,6 +247,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -249,6 +256,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.critic.in_keys, @@ -341,5 +349,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value_target": self.tensor_keys.value_target, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index de4908d1335..09ca452fa19 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -137,6 +137,7 @@ class SACLoss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": action, ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -156,7 +157,7 @@ class SACLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor, value, and qvalue network. + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network. The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]`` + ``"loss_value"`` if version one is used. @@ -199,6 +200,7 @@ class SACLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -212,6 +214,7 @@ class SACLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... 
next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -240,6 +243,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. """ action: NestedKey = "action" @@ -249,6 +255,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -426,6 +433,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self.tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -471,6 +479,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) @@ -487,6 +496,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, @@ -831,6 +841,7 @@ class DiscreteSACLoss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": action, ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -850,7 +861,7 @@ class DiscreteSACLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor and qvalue network. + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network. The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]`` @@ -894,6 +905,7 @@ class DiscreteSACLoss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_observation=torch.zeros(*batch, n_obs), ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() @@ -918,6 +930,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
""" action: NestedKey = "action" @@ -926,6 +941,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" log_prob: NestedKey = "log_prob" default_keys = _AcceptedKeys() @@ -1046,6 +1062,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -1054,6 +1071,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, @@ -1269,5 +1287,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value_target": "value_target", "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 68d63fbaa47..72ffa64a4f2 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -109,6 +109,7 @@ class TD3Loss(LossModule): ... "observation": torch.randn(*batch, n_obs), ... "action": action, ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) @@ -128,7 +129,7 @@ class TD3Loss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done"]`` + in_keys of the actor and qvalue network + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_qvalue", "pred_value", "state_action_value_actor", "next_state_value", "target_value",]``. @@ -162,6 +163,7 @@ class TD3Loss(LossModule): ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), ... next_reward=torch.randn(*batch, 1), ... next_observation=torch.randn(*batch, n_obs)) >>> loss_actor.backward() @@ -187,6 +189,9 @@ class _AcceptedKeys: done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
""" action: NestedKey = "action" @@ -194,6 +199,7 @@ class _AcceptedKeys: priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -313,6 +319,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.state_action_value, reward=self.tensor_keys.reward, done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, ) self._set_in_keys() @@ -321,6 +328,7 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, @@ -504,5 +512,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "value": self.tensor_keys.state_action_value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 31c8c291c5b..db056d5ac4d 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -171,12 +171,14 @@ class _AcceptedKeys: Will be used for the underlying value estimator. Defaults to ``"advantage"``. value_target (NestedKey): The input tensordict key where the target state value is written to. Will be used for the underlying value estimator Defaults to ``"value_target"``. - value_key (NestedKey): The input tensordict key where the state value is expected. + value (NestedKey): The input tensordict key where the state value is expected. Will be used for the underlying value estimator. Defaults to ``"state_value"``. - reward_key (NestedKey): The input tensordict key where the reward is written to. + reward (NestedKey): The input tensordict key where the reward is written to. Defaults to ``"reward"``. - done_key (NestedKey): The key in the input TensorDict that indicates + done (NestedKey): The key in the input TensorDict that indicates whether a trajectory is done. Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Defaults to ``"terminated"``. steps_to_next_obs_key (NestedKey): The key in the input tensordict that indicates the number of steps to the next observation. Defaults to ``"steps_to_next_obs"``. @@ -187,6 +189,7 @@ class _AcceptedKeys: value: NestedKey = "state_value" reward: NestedKey = "reward" done: NestedKey = "done" + terminated: NestedKey = "terminated" steps_to_next_obs: NestedKey = "steps_to_next_obs" default_keys = _AcceptedKeys() @@ -212,6 +215,10 @@ def reward_key(self): def done_key(self): return self.tensor_keys.done + @property + def terminated_key(self): + return self.tensor_keys.terminated + @property def steps_to_next_obs_key(self): return self.tensor_keys.steps_to_next_obs @@ -230,10 +237,14 @@ def forward( Args: tensordict (TensorDictBase): A TensorDict containing the data - (an observation key, "action", ("next", "reward"), ("next", "done") and "next" tensordict state - as returned by the environment) necessary to compute the value estimates and the TDEstimate. - The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are - the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). 
+ (an observation key, ``"action"``, ``("next", "reward")``, + ``("next", "done")``, ``("next", "terminated")``, + and ``"next"`` tensordict state as returned by the environment) + necessary to compute the value estimates and the TDEstimate. + The data passed to this module should be structured as + :obj:`[*B, T, *F]` where :obj:`B` are + the batch size, :obj:`T` the time dimension and :obj:`F` the + feature dimension(s). The tensordict must have shape ``[*B, T]``. params (TensorDictBase, optional): A nested TensorDict containing the params to be passed to the functional value network module. target_params (TensorDictBase, optional): A nested TensorDict containing the @@ -302,6 +313,7 @@ def in_keys(self): + [ ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), ] + [("next", in_key) for in_key in self.value_network.in_keys] ) @@ -483,10 +495,14 @@ def forward( Args: tensordict (TensorDictBase): A TensorDict containing the data - (an observation key, "action", ("next", "reward"), ("next", "done") and "next" tensordict state - as returned by the environment) necessary to compute the value estimates and the TDEstimate. - The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are - the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). + (an observation key, ``"action"``, ``("next", "reward")``, + ``("next", "done")``, ``("next", "terminated")``, and ``"next"`` + tensordict state as returned by the environment) necessary to + compute the value estimates and the TDEstimate. + The data passed to this module should be structured as + :obj:`[*B, T, *F]` where :obj:`B` are + the batch size, :obj:`T` the time dimension and :obj:`F` the + feature dimension(s). The tensordict must have shape ``[*B, T]``. params (TensorDictBase, optional): A nested TensorDict containing the params to be passed to the functional value network module. 
target_params (TensorDictBase, optional): A nested TensorDict containing the @@ -507,7 +523,8 @@ def forward( >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward}}, [1, 10]) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "terminated": terminated, "reward": reward}}, [1, 10]) >>> _ = module(tensordict) >>> assert "advantage" in tensordict.keys() @@ -524,7 +541,8 @@ def forward( >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated) """ if tensordict.batch_dims < 1: @@ -587,8 +605,13 @@ def value_estimate( next_value = self._next_value(tensordict, target_params, kwargs=kwargs) done = tensordict.get(("next", self.tensor_keys.done)) + terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done) value_target = td0_return_estimate( - gamma=gamma, next_state_value=next_value, reward=reward, done=done + gamma=gamma, + next_state_value=next_value, + reward=reward, + done=done, + terminated=terminated, ) return value_target @@ -674,10 +697,13 @@ def forward( Args: tensordict (TensorDictBase): A TensorDict containing the data - (an observation key, "action", ("next", "reward"), ("next", "done") and "next" tensordict state - as returned by the environment) necessary to compute the value estimates and the TDEstimate. - The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are + (an observation key, ``"action"``, ``("next", "reward")``, + ``("next", "done")``, ``("next", "terminated")``, + and ``"next"`` tensordict state as returned by the environment) + necessary to compute the value estimates and the TDEstimate. + The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). + The tensordict must have shape ``[*B, T]``. params (TensorDictBase, optional): A nested TensorDict containing the params to be passed to the functional value network module. 
target_params (TensorDictBase, optional): A nested TensorDict containing the @@ -698,7 +724,8 @@ def forward( >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward}}, [1, 10]) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward, "terminated": terminated}}, [1, 10]) >>> _ = module(tensordict) >>> assert "advantage" in tensordict.keys() @@ -715,7 +742,8 @@ def forward( >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated) """ if tensordict.batch_dims < 1: @@ -779,8 +807,14 @@ def value_estimate( next_value = self._next_value(tensordict, target_params, kwargs=kwargs) done = tensordict.get(("next", self.tensor_keys.done)) + terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done) value_target = vec_td1_return_estimate( - gamma, next_value, reward, done, time_dim=tensordict.ndim - 1 + gamma, + next_value, + reward, + done=done, + terminated=terminated, + time_dim=tensordict.ndim - 1, ) return value_target @@ -873,10 +907,13 @@ def forward( Args: tensordict (TensorDictBase): A TensorDict containing the data - (an observation key, "action", ("next", "reward"), ("next", "done") and "next" tensordict state - as returned by the environment) necessary to compute the value estimates and the TDLambdaEstimate. - The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are + (an observation key, ``"action"``, ``("next", "reward")``, + ``("next", "done")``, ``("next", "terminated")``, + and ``"next"`` tensordict state as returned by the environment) + necessary to compute the value estimates and the TDLambdaEstimate. + The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). + The tensordict must have shape ``[*B, T]``. params (TensorDictBase, optional): A nested TensorDict containing the params to be passed to the functional value network module. 
             target_params (TensorDictBase, optional): A nested TensorDict containing the
@@ -898,7 +935,8 @@ def forward(
             >>> obs, next_obs = torch.randn(2, 1, 10, 3)
             >>> reward = torch.randn(1, 10, 1)
             >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
-            >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward}}, [1, 10])
+            >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward, "terminated": terminated}}, [1, 10])
             >>> _ = module(tensordict)
             >>> assert "advantage" in tensordict.keys()
@@ -916,7 +954,8 @@ def forward(
             >>> obs, next_obs = torch.randn(2, 1, 10, 3)
             >>> reward = torch.randn(1, 10, 1)
             >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
-            >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs)
+            >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated)
 
         """
         if tensordict.batch_dims < 1:
@@ -980,13 +1019,26 @@ def value_estimate(
         next_value = self._next_value(tensordict, target_params, kwargs=kwargs)
         done = tensordict.get(("next", self.tensor_keys.done))
+        terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
         if self.vectorized:
             val = vec_td_lambda_return_estimate(
-                gamma, lmbda, next_value, reward, done, time_dim=tensordict.ndim - 1
+                gamma,
+                lmbda,
+                next_value,
+                reward,
+                done=done,
+                terminated=terminated,
+                time_dim=tensordict.ndim - 1,
             )
         else:
             val = td_lambda_return_estimate(
-                gamma, lmbda, next_value, reward, done, time_dim=tensordict.ndim - 1
+                gamma,
+                lmbda,
+                next_value,
+                reward,
+                done=done,
+                terminated=terminated,
+                time_dim=tensordict.ndim - 1,
             )
         return val
@@ -1096,10 +1148,13 @@ def forward(
         Args:
             tensordict (TensorDictBase): A TensorDict containing the data
-                (an observation key, "action", "reward", "done" and "next" tensordict state
-                as returned by the environment) necessary to compute the value estimates and the GAE.
-                The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are
+                (an observation key, ``"action"``, ``("next", "reward")``,
+                ``("next", "done")``, ``("next", "terminated")``,
+                and ``"next"`` tensordict state as returned by the environment)
+                necessary to compute the value estimates and the GAE.
+                The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are
                 the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
+                The tensordict must have shape ``[*B, T]``.
             params (TensorDictBase, optional): A nested TensorDict containing the params
                 to be passed to the functional value network module.
             target_params (TensorDictBase, optional): A nested TensorDict containing the
@@ -1122,7 +1177,8 @@ def forward(
             >>> obs, next_obs = torch.randn(2, 1, 10, 3)
             >>> reward = torch.randn(1, 10, 1)
             >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
-            >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward}, [1, 10])
+            >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward, "terminated": terminated}, [1, 10])
             >>> _ = module(tensordict)
             >>> assert "advantage" in tensordict.keys()
@@ -1141,7 +1197,8 @@ def forward(
             >>> obs, next_obs = torch.randn(2, 1, 10, 3)
             >>> reward = torch.randn(1, 10, 1)
             >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
-            >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs)
+            >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated)
 
         """
         if tensordict.batch_dims < 1:
@@ -1178,6 +1235,7 @@ def forward(
         next_value = tensordict.get(("next", self.tensor_keys.value))
         done = tensordict.get(("next", self.tensor_keys.done))
+        terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
         if self.vectorized:
             adv, value_target = vec_generalized_advantage_estimate(
                 gamma,
@@ -1185,7 +1243,8 @@ def forward(
                 value,
                 next_value,
                 reward,
-                done,
+                done=done,
+                terminated=terminated,
                 time_dim=tensordict.ndim - 1,
             )
         else:
@@ -1195,7 +1254,8 @@ def forward(
                 value,
                 next_value,
                 reward,
-                done,
+                done=done,
+                terminated=terminated,
                 time_dim=tensordict.ndim - 1,
             )
@@ -1254,8 +1314,16 @@ def value_estimate(
         value = tensordict.get(self.tensor_keys.value)
         next_value = tensordict.get(("next", self.tensor_keys.value))
         done = tensordict.get(("next", self.tensor_keys.done))
+        terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
         _, value_target = vec_generalized_advantage_estimate(
-            gamma, lmbda, value, next_value, reward, done, time_dim=tensordict.ndim - 1
+            gamma,
+            lmbda,
+            value,
+            next_value,
+            reward,
+            done=done,
+            terminated=terminated,
+            time_dim=tensordict.ndim - 1,
         )
         return value_target
diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py
index ccd0966bbf6..318ba09d02c 100644
--- a/torchrl/objectives/value/functional.py
+++ b/torchrl/objectives/value/functional.py
@@ -2,8 +2,11 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
 import math
+
+import warnings
 from functools import wraps
 from typing import Optional, Tuple, Union
@@ -51,7 +54,9 @@ def _transpose_time(fun):
     )
 
     @wraps(fun)
-    def transposed_fun(*args, time_dim=-2, **kwargs):
+    def transposed_fun(*args, **kwargs):
+        time_dim = kwargs.pop("time_dim", -2)
+
         def transpose_tensor(tensor):
             if (
                 not isinstance(tensor, (torch.Tensor, MemmapTensor))
@@ -77,7 +82,7 @@ def transpose_tensor(tensor):
         if time_dim != -2:
             args, single_dim = zip(*(transpose_tensor(arg) for arg in args))
             single_dim = any(single_dim)
-            for k, item in kwargs.items():
+            for k, item in list(kwargs.items()):
                 item, sd = transpose_tensor(item)
                 single_dim = single_dim or sd
                 kwargs[k] = item
@@ -116,6 +121,7 @@ def generalized_advantage_estimate(
     next_state_value: torch.Tensor,
     reward: torch.Tensor,
     done: torch.Tensor,
+    terminated: torch.Tensor | None = None,
     time_dim: int = -2,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Generalized advantage estimate of a trajectory.
@@ -129,27 +135,38 @@ def generalized_advantage_estimate(
         state_value (Tensor): value function result with old_state input.
         next_state_value (Tensor): value function result with new_state input.
         reward (Tensor): reward of taking actions in the environment.
-        done (Tensor): boolean flag for end of episode.
+        done (Tensor): boolean flag for end of trajectory.
+        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
+            if not provided.
         time_dim (int): dimension where the time is unrolled. Defaults to -2.
 
     All tensors (values, reward and done) must have shape
     ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
 
     """
-    if not (next_state_value.shape == state_value.shape == reward.shape == done.shape):
+    if terminated is None:
+        terminated = done
+    if not (
+        next_state_value.shape
+        == state_value.shape
+        == reward.shape
+        == done.shape
+        == terminated.shape
+    ):
         raise RuntimeError(SHAPE_ERR)
     dtype = next_state_value.dtype
     device = state_value.device
 
     not_done = (~done).int()
+    not_terminated = (~terminated).int()
     *batch_size, time_steps, lastdim = not_done.shape
     advantage = torch.empty(
         *batch_size, time_steps, lastdim, device=device, dtype=dtype
     )
     prev_advantage = 0
-    gnotdone = gamma * not_done
-    delta = reward + (gnotdone * next_state_value) - state_value
-    discount = lmbda * gnotdone
+    g_not_terminated = gamma * not_terminated
+    delta = reward + (g_not_terminated * next_state_value) - state_value
+    discount = lmbda * gamma * not_done
     for t in reversed(range(time_steps)):
         prev_advantage = advantage[..., t, :] = delta[..., t, :] + (
             prev_advantage * discount[..., t, :]
@@ -187,6 +203,7 @@ def _fast_vec_gae(
     state_value: torch.Tensor,
     next_state_value: torch.Tensor,
     done: torch.Tensor,
+    terminated: torch.Tensor,
     gamma: float,
     lmbda: float,
     thr: float = 1e-7,
@@ -200,7 +217,8 @@ def _fast_vec_gae(
         reward (torch.Tensor): a [*B, T, F] tensor containing rewards
         state_value (torch.Tensor): a [*B, T, F] tensor containing state values (value function)
        next_state_value (torch.Tensor): a [*B, T, F] tensor containing next state values (value function)
-        done (torch.Tensor): a [B, T] boolean tensor containing the done states
+        done (torch.Tensor): a [B, T] boolean tensor containing the done states.
+        terminated (torch.Tensor): a [B, T] boolean tensor containing the terminated states.
         gamma (scalar): the gamma decay (trajectory discount)
         lmbda (scalar): the lambda decay (exponential mean discount)
         thr (float): threshold for the filter. Below this limit, components will ignored.
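The net effect of the change above is that the bootstrap term of the TD error is masked by ``terminated``, while ``done`` only cuts the discounted accumulation of advantages. A quick numerical check of that behaviour, assuming the new signature of ``generalized_advantage_estimate`` introduced here (values in the comments are approximate):

    import torch

    from torchrl.objectives.value.functional import generalized_advantage_estimate

    T = 5
    reward = torch.ones(1, T, 1)
    state_value = torch.zeros(1, T, 1)
    next_state_value = torch.ones(1, T, 1)  # the critic predicts a value of 1 everywhere

    done = torch.zeros(1, T, 1, dtype=torch.bool)
    done[0, -1] = True  # the trajectory stops at t = T - 1 in both cases below

    # case 1: the episode truly terminates; case 2: it is only truncated
    adv_term, target_term = generalized_advantage_estimate(
        0.99, 0.95, state_value, next_state_value, reward, done, terminated=done
    )
    adv_trunc, target_trunc = generalized_advantage_estimate(
        0.99, 0.95, state_value, next_state_value, reward, done,
        terminated=torch.zeros_like(done),
    )

    # last step: the terminated target drops the bootstrap term (r = 1.0),
    # the truncated target keeps it (r + gamma * V(s') = 1.99)
    print(target_term[0, -1], target_trunc[0, -1])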
@@ -213,13 +231,14 @@ def _fast_vec_gae( # _gen_num_per_traj and _split_and_pad_sequence need # time dimension at last position done = done.transpose(-2, -1) + terminated = terminated.transpose(-2, -1) reward = reward.transpose(-2, -1) state_value = state_value.transpose(-2, -1) next_state_value = next_state_value.transpose(-2, -1) gammalmbda = gamma * lmbda - not_done = (~done).int() - td0 = reward + not_done * gamma * next_state_value - state_value + not_terminated = (~terminated).int() + td0 = reward + not_terminated * gamma * next_state_value - state_value num_per_traj = _get_num_per_traj(done) td0_flat, mask = _split_and_pad_sequence(td0, num_per_traj, return_mask=True) @@ -246,6 +265,7 @@ def vec_generalized_advantage_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, time_dim: int = -2, ) -> Tuple[torch.Tensor, torch.Tensor]: """Vectorized Generalized advantage estimate of a trajectory. @@ -259,23 +279,33 @@ def vec_generalized_advantage_estimate( state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. time_dim (int): dimension where the time is unrolled. Defaults to -2. All tensors (values, reward and done) must have shape ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not ( + next_state_value.shape + == state_value.shape + == reward.shape + == done.shape + == terminated.shape + ): raise RuntimeError(SHAPE_ERR) dtype = state_value.dtype - not_done = (~done).to(dtype) - *batch_size, time_steps, lastdim = not_done.shape + *batch_size, time_steps, lastdim = terminated.shape value = gamma * lmbda if isinstance(value, torch.Tensor) and value.numel() > 1: # create tensor while ensuring that gradients are passed + not_done = (~done).to(dtype) gammalmbdas = not_done * value else: # when gamma and lmbda are scalars, use fast_vec_gae implementation @@ -284,6 +314,7 @@ def vec_generalized_advantage_estimate( state_value=state_value, next_state_value=next_state_value, done=done, + terminated=terminated, gamma=gamma, lmbda=lmbda, ) @@ -299,7 +330,8 @@ def vec_generalized_advantage_estimate( first_below_thr = torch.where(first_below_thr)[0][0].item() gammalmbdas = gammalmbdas[..., :first_below_thr, :] - td0 = reward + not_done * gamma * next_state_value - state_value + not_terminated = (~terminated).to(dtype) + td0 = reward + not_terminated * gamma * next_state_value - state_value if len(batch_size) > 1: td0 = td0.flatten(0, len(batch_size) - 1) @@ -336,6 +368,7 @@ def td0_advantage_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """TD(0) advantage estimate of a trajectory. @@ -346,15 +379,25 @@ def td0_advantage_estimate( state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. 
+ done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. All tensors (values, reward and done) must have shape ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not ( + next_state_value.shape + == state_value.shape + == reward.shape + == done.shape + == terminated.shape + ): raise RuntimeError(SHAPE_ERR) - returns = td0_return_estimate(gamma, next_state_value, reward, done) + returns = td0_return_estimate(gamma, next_state_value, reward, terminated) advantage = returns - state_value return advantage @@ -363,8 +406,11 @@ def td0_return_estimate( gamma: float, next_state_value: torch.Tensor, reward: torch.Tensor, - done: torch.Tensor, + terminated: torch.Tensor, + *, + done: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + # noqa: D417 """TD(0) discounted return estimate of a trajectory. Also known as bootstrapped Temporal Difference or one-step return. @@ -375,16 +421,24 @@ def td0_return_estimate( must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor reward (Tensor): reward of taking actions in the environment. must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor - done (Tensor): boolean flag for end of episode. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. + + Keyword Args: + done (Tensor): Deprecated. Use ``terminated`` instead. All tensors (values, reward and done) must have shape ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if not (next_state_value.shape == reward.shape == done.shape): + if done is not None: + warnings.warn( + "done for td0_return_estimate is deprecated. Pass ``terminated`` instead." + ) + if not (next_state_value.shape == reward.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) - not_done = (~done).int() - advantage = reward + gamma * not_done * next_state_value + not_terminated = (~terminated).int() + advantage = reward + gamma * not_terminated * next_state_value return advantage @@ -399,6 +453,7 @@ def td1_return_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, rolling_gamma: bool = None, time_dim: int = -2, ) -> torch.Tensor: @@ -408,7 +463,9 @@ def td1_return_estimate( gamma (scalar): exponential mean discount. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -436,9 +493,12 @@ def td1_return_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. 
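With the new signature, ``td0_return_estimate`` takes ``terminated`` as the positional flag and only keeps ``done`` as a deprecated keyword for backward compatibility. A small usage sketch, assuming the signature introduced above:

    import torch

    from torchrl.objectives.value.functional import td0_return_estimate

    next_value = torch.tensor([[1.0], [1.0]])
    reward = torch.tensor([[0.5], [0.5]])
    terminated = torch.tensor([[False], [True]])

    # target = r + gamma * (1 - terminated) * V(s')
    target = td0_return_estimate(0.9, next_value, reward, terminated)
    print(target)  # approximately [[1.4], [0.5]]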
""" - if not (next_state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) not_done = (~done).int() + not_terminated = (~terminated).int() returns = torch.empty_like(next_state_value) @@ -456,19 +516,29 @@ def td1_return_estimate( "rolling_gamma=False is expected only with time-sensitive gamma values" ) + done_but_not_terminated = (done & ~terminated).int() if rolling_gamma: - gamma = gamma * not_done + gamma = gamma * not_terminated g = next_state_value[..., -1, :] for i in reversed(range(T)): - g = returns[..., i, :] = reward[..., i, :] + gamma[..., i, :] * g + # if not done (and hence not terminated), get the bootstrapped value + # if done but not terminated, get nex_val + # if terminated, take nothing (gamma = 0) + dnt = done_but_not_terminated[..., i, :] + g = returns[..., i, :] = reward[..., i, :] + gamma[..., i, :] * ( + (1 - dnt) * g + dnt * next_state_value[..., i, :] + ) else: for k in range(T): - g = next_state_value[..., -1, :] + g = 0 _gamma = gamma[..., k, :] - nd = not_done + nd = not_terminated _gamma = _gamma.unsqueeze(-2) * nd for i in reversed(range(k, T)): - g = reward[..., i, :] + _gamma[..., i, :] * g + dnt = done_but_not_terminated[..., i, :] + g = reward[..., i, :] + _gamma[..., i, :] * ( + (1 - dnt) * g + dnt * next_state_value[..., i, :] + ) returns[..., k, :] = g return returns @@ -479,6 +549,7 @@ def td1_advantage_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, rolling_gamma: bool = None, time_dim: int = -2, ) -> torch.Tensor: @@ -489,7 +560,9 @@ def td1_advantage_estimate( state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -517,12 +590,26 @@ def td1_advantage_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not ( + next_state_value.shape + == state_value.shape + == reward.shape + == done.shape + == terminated.shape + ): raise RuntimeError(SHAPE_ERR) if not state_value.shape == next_state_value.shape: raise RuntimeError("shape of state_value and next_state_value must match") returns = td1_return_estimate( - gamma, next_state_value, reward, done, rolling_gamma, time_dim=time_dim + gamma, + next_state_value, + reward, + done, + terminated=terminated, + rolling_gamma=rolling_gamma, + time_dim=time_dim, ) advantage = returns - state_value return advantage @@ -533,7 +620,8 @@ def vec_td1_return_estimate( gamma, next_state_value, reward, - done, + done: torch.Tensor, + terminated: torch.Tensor | None = None, rolling_gamma: Optional[bool] = None, time_dim: int = -2, ): @@ -543,7 +631,9 @@ def vec_td1_return_estimate( gamma (scalar, Tensor): exponential mean discount. If tensor-valued, next_state_value (Tensor): value function result with new_state input. 
reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -576,6 +666,7 @@ def vec_td1_return_estimate( next_state_value=next_state_value, reward=reward, done=done, + terminated=terminated, rolling_gamma=rolling_gamma, lmbda=1, time_dim=time_dim, @@ -587,7 +678,8 @@ def vec_td1_advantage_estimate( state_value, next_state_value, reward, - done, + done: torch.Tensor, + terminated: torch.Tensor | None = None, rolling_gamma: bool = None, time_dim: int = -2, ): @@ -598,7 +690,9 @@ def vec_td1_advantage_estimate( state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -626,11 +720,25 @@ def vec_td1_advantage_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not ( + next_state_value.shape + == state_value.shape + == reward.shape + == done.shape + == terminated.shape + ): raise RuntimeError(SHAPE_ERR) return ( vec_td1_return_estimate( - gamma, next_state_value, reward, done, rolling_gamma, time_dim=time_dim + gamma, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + time_dim=time_dim, ) - state_value ) @@ -648,6 +756,7 @@ def td_lambda_return_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, rolling_gamma: bool = None, time_dim: int = -2, ) -> torch.Tensor: @@ -658,7 +767,9 @@ def td_lambda_return_estimate( lmbda (scalar): trajectory discount. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -686,23 +797,26 @@ def td_lambda_return_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. 
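For the lambda-return the same convention applies: with the implementation that follows, a ``done`` without ``terminated`` (a truncation) still bootstraps from ``next_state_value``, whereas a terminated step does not. A short check, assuming the new signature (numbers in the comment are approximate):

    import torch

    from torchrl.objectives.value.functional import td_lambda_return_estimate

    T = 4
    reward = torch.ones(1, T, 1)
    next_value = 2 * torch.ones(1, T, 1)
    done = torch.zeros(1, T, 1, dtype=torch.bool)
    done[0, -1] = True

    ret_terminated = td_lambda_return_estimate(
        0.99, 0.95, next_value, reward, done, terminated=done
    )
    ret_truncated = td_lambda_return_estimate(
        0.99, 0.95, next_value, reward, done, terminated=torch.zeros_like(done)
    )

    # last step: 1.0 when terminated vs 1 + 0.99 * 2 = 2.98 when only truncated
    print(ret_terminated[0, -1], ret_truncated[0, -1])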
""" - if not (next_state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) - not_done = (~done).int() + not_terminated = (~terminated).int() returns = torch.empty_like(next_state_value) + next_state_value = next_state_value * not_terminated *batch, T, lastdim = returns.shape # if gamma is not a tensor of the same shape as other inputs, we use rolling_gamma = True single_gamma = False - if not (isinstance(gamma, torch.Tensor) and gamma.shape == not_done.shape): + if not (isinstance(gamma, torch.Tensor) and gamma.shape == done.shape): single_gamma = True gamma = torch.full_like(next_state_value, gamma) single_lambda = False - if not (isinstance(lmbda, torch.Tensor) and lmbda.shape == not_done.shape): + if not (isinstance(lmbda, torch.Tensor) and lmbda.shape == done.shape): single_lambda = True lmbda = torch.full_like(next_state_value, lmbda) @@ -712,26 +826,28 @@ def td_lambda_return_estimate( raise RuntimeError( "rolling_gamma=False is expected only with time-sensitive gamma or lambda values" ) - if rolling_gamma: - gamma = gamma * not_done g = next_state_value[..., -1, :] for i in reversed(range(T)): + dn = done[..., i, :].int() + nv = next_state_value[..., i, :] + lmd = lmbda[..., i, :] + # if done, the bootstrapped gain is the next value, otherwise it's the + # value we computed during the previous iter + g = g * (1 - dn) + nv * dn g = returns[..., i, :] = reward[..., i, :] + gamma[..., i, :] * ( - (1 - lmbda[..., i, :]) * next_state_value[..., i, :] - + lmbda[..., i, :] * g + (1 - lmd) * nv + lmd * g ) else: for k in range(T): g = next_state_value[..., -1, :] _gamma = gamma[..., k, :] _lambda = lmbda[..., k, :] - nd = not_done - _gamma = _gamma.unsqueeze(-2) * nd for i in reversed(range(k, T)): - g = reward[..., i, :] + _gamma[..., i, :] * ( - (1 - _lambda) * next_state_value[..., i, :] + _lambda * g - ) + dn = done[..., i, :].int() + nv = next_state_value[..., i, :] + g = g * (1 - dn) + nv * dn + g = reward[..., i, :] + _gamma * ((1 - _lambda) * nv + _lambda * g) returns[..., k, :] = g return returns @@ -744,6 +860,7 @@ def td_lambda_advantage_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, rolling_gamma: bool = None, time_dim: int = -2, ) -> torch.Tensor: @@ -755,7 +872,9 @@ def td_lambda_advantage_estimate( state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -783,12 +902,27 @@ def td_lambda_advantage_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. 
""" - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not ( + next_state_value.shape + == state_value.shape + == reward.shape + == done.shape + == terminated.shape + ): raise RuntimeError(SHAPE_ERR) if not state_value.shape == next_state_value.shape: raise RuntimeError("shape of state_value and next_state_value must match") returns = td_lambda_return_estimate( - gamma, lmbda, next_state_value, reward, done, rolling_gamma, time_dim=time_dim + gamma, + lmbda, + next_state_value, + reward, + done, + terminated=terminated, + rolling_gamma=rolling_gamma, + time_dim=time_dim, ) advantage = returns - state_value return advantage @@ -800,6 +934,7 @@ def _fast_td_lambda_return_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor, thr: float = 1e-7, ): """Fast vectorized TD lambda return estimate. @@ -812,7 +947,8 @@ def _fast_td_lambda_return_estimate( lmbda (scalar): the lambda decay (exponential mean discount) next_state_value (torch.Tensor): a [*B, T, F] tensor containing next state values (value function) reward (torch.Tensor): a [*B, T, F] tensor containing rewards - done (torch.Tensor): a [B, T] boolean tensor containing the done states + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for end of episode. thr (float): threshold for the filter. Below this limit, components will ignored. Defaults to 1e-7. @@ -822,23 +958,25 @@ def _fast_td_lambda_return_estimate( """ device = reward.device done = done.transpose(-2, -1) + terminated = terminated.transpose(-2, -1) reward = reward.transpose(-2, -1) next_state_value = next_state_value.transpose(-2, -1) + # the only valid next states are those where the trajectory does not terminate + next_state_value = (~terminated).int() * next_state_value + gamma_tensor = torch.tensor([gamma], device=device) gammalmbda = gamma_tensor * lmbda - not_done = (~done).int() num_per_traj = _get_num_per_traj(done) - nvalue_ndone = not_done * next_state_value - t = nvalue_ndone * gamma_tensor * (1 - lmbda) + reward - v3 = torch.zeros_like(t, device=device) - v3[..., -1] = nvalue_ndone[..., -1].clone() + done = done.clone() + done[..., -1] = 1 + not_done = (~done).int() - t_flat, mask = _split_and_pad_sequence( - t + v3 * gammalmbda, num_per_traj, return_mask=True - ) + t = reward + next_state_value * gamma_tensor * (1 - not_done * lmbda) + + t_flat, mask = _split_and_pad_sequence(t, num_per_traj, return_mask=True) gammalmbdas = _geom_series_like(t_flat[0], gammalmbda, thr=thr) @@ -855,6 +993,7 @@ def vec_td_lambda_return_estimate( next_state_value, reward, done, + terminated: torch.Tensor | None = None, rolling_gamma: Optional[bool] = None, time_dim: int = -2, ): @@ -868,7 +1007,9 @@ def vec_td_lambda_return_estimate( must be a [Batch x TimeSteps x 1] tensor reward (Tensor): reward of taking actions in the environment. must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -896,7 +1037,9 @@ def vec_td_lambda_return_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. 
""" - if not (next_state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) gamma_thr = 1e-7 @@ -916,6 +1059,7 @@ def _is_scalar(tensor): next_state_value=next_state_value, reward=reward, done=done, + terminated=terminated, thr=gamma_thr, ) @@ -930,16 +1074,18 @@ def _is_scalar(tensor): """Vectorized version of td_lambda_advantage_estimate""" device = reward.device not_done = (~done).int() + not_terminated = (~terminated).int().transpose(-2, -1).unsqueeze(-2) + if len(batch): + not_terminated = not_terminated.flatten(0, len(batch)) + next_state_value = next_state_value * not_terminated if rolling_gamma is None: rolling_gamma = True - if rolling_gamma: - gamma = gamma * not_done - gammas = _make_gammas_tensor(gamma, T, rolling_gamma) - if not rolling_gamma: - done_follows_done = done[..., 1:, :][done[..., :-1, :]].all() - if not done_follows_done: + terminated_follows_terminated = terminated[..., 1:, :][ + terminated[..., :-1, :] + ].all() + if not terminated_follows_terminated: raise NotImplementedError( "When using rolling_gamma=False and vectorized TD(lambda) with time-dependent gamma, " "make sure that conseducitve trajectories are separated as different batch " @@ -948,46 +1094,47 @@ def _is_scalar(tensor): "consider using the non-vectorized version of the return computation or splitting " "your trajectories." ) - else: - gammas[..., 1:, :] = gammas[..., 1:, :] * not_done.view(-1, 1, T, 1) - gammas_cp = torch.cumprod(gammas, -2) - - lambdas = torch.ones(T + 1, 1, device=device) - lambdas[1:] = lmbda - lambdas_cp = torch.cumprod(lambdas, -2) - - gammas = gammas[..., 1:, :] - lambdas = lambdas[1:] - - dec = gammas_cp * lambdas_cp - if rolling_gamma in (None, True): + if rolling_gamma: + # Make the coefficient table + gammas = _make_gammas_tensor(gamma * not_done, T, rolling_gamma) + gammas_cp = torch.cumprod(gammas, -2) + lambdas = torch.ones(T + 1, 1, device=device) + lambdas[1:] = lmbda + lambdas_cp = torch.cumprod(lambdas, -2) + lambdas = lambdas[1:] + dec = gammas_cp * lambdas_cp + + gammas = _make_gammas_tensor(gamma, T, rolling_gamma) + gammas = gammas[..., 1:, :] if gammas.ndimension() == 4 and gammas.shape[1] > 1: gammas = gammas[:, :1] if lambdas.ndimension() == 4 and lambdas.shape[1] > 1: lambdas = lambdas[:, :1] - v3 = (gammas * lambdas).squeeze(-1) * next_state_value + + not_done = not_done.transpose(-2, -1).unsqueeze(-2) + if len(batch): + not_done = not_done.flatten(0, len(batch)) + # lambdas = lambdas * not_done + + v3 = (gammas * lambdas).squeeze(-1) * next_state_value * not_done v3[..., :-1] = 0 out = _custom_conv1d( - reward + (gammas * (1 - lambdas)).squeeze(-1) * next_state_value + v3, dec + reward + + gammas.squeeze(-1) + * next_state_value + * (1 - lambdas.squeeze(-1) * not_done) + + v3, + dec, ) + return out.view(*batch, lastdim, T).transpose(-2, -1) else: - v1 = _custom_conv1d(reward, dec) - - if gammas.ndimension() == 4 and gammas.shape[1] > 1: - gammas = gammas[:, :, :1].transpose(1, 2) - if lambdas.ndimension() == 4 and lambdas.shape[1] > 1: - lambdas = lambdas[:, :, :1].transpose(1, 2) - - v2 = _custom_conv1d( - next_state_value * not_done.view_as(next_state_value), - dec * (gammas * (1 - lambdas)).transpose(1, 2), + raise NotImplementedError( + "The vectorized version of TD(lambda) with rolling_gamma=False is currently not available. " + "To use this feature, use the non-vectorized version of TD(lambda). 
You can expect " + "good speed improvements by decorating the function with torch.compile!" ) - v3 = next_state_value * not_done.view_as(next_state_value) - v3[..., :-1] = 0 - v3 = _custom_conv1d(v3, dec * (gammas * lambdas).transpose(1, 2)) - return (v1 + v2 + v3).view(*batch, lastdim, T).transpose(-2, -1) def vec_td_lambda_advantage_estimate( @@ -997,6 +1144,7 @@ def vec_td_lambda_advantage_estimate( next_state_value, reward, done, + terminated: torch.Tensor | None = None, rolling_gamma: bool = None, time_dim: int = -2, ): @@ -1008,7 +1156,9 @@ def vec_td_lambda_advantage_estimate( state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. + done (Tensor): boolean flag for end of trajectory. + terminated (Tensor): boolean flag for the end of episode. Defaults to ``done`` + if not provided. rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma if a gamma tensor is tied to a single event: gamma = [g1, g2, g3, g4] @@ -1036,7 +1186,15 @@ def vec_td_lambda_advantage_estimate( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + if terminated is None: + terminated = done + if not ( + next_state_value.shape + == state_value.shape + == reward.shape + == done.shape + == terminated.shape + ): raise RuntimeError(SHAPE_ERR) return ( vec_td_lambda_return_estimate( @@ -1044,8 +1202,9 @@ def vec_td_lambda_advantage_estimate( lmbda, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) - state_value @@ -1069,7 +1228,8 @@ def reward2go( Args: reward (torch.Tensor): A tensor containing the rewards received at each time step over multiple trajectories. - done (torch.Tensor): A tensor with done (or truncated) states. + done (Tensor): boolean flag for end of episode. Differs from + truncated, where the episode did not end but was interrupted. gamma (float, optional): The discount factor to use for computing the discounted cumulative sum of rewards. Defaults to 1.0. time_dim (int): dimension where the time is unrolled. Defaults to -2. diff --git a/torchrl/objectives/value/utils.py b/torchrl/objectives/value/utils.py index b5e9ce73319..e8e610af122 100644 --- a/torchrl/objectives/value/utils.py +++ b/torchrl/objectives/value/utils.py @@ -191,20 +191,20 @@ def _flatten_batch(tensor): return tensor.flatten(0, -1) -def _get_num_per_traj(dones_and_truncated): +def _get_num_per_traj(done): """Because we mark the end of each batch with a truncated signal, we can concatenate them. Args: - dones_and_truncated (torch.Tensor): A done or truncated mark of shape [*B, T] + done (torch.Tensor): A done or truncated mark of shape [*B, T] Returns: A list of integers representing the number of steps in each trajectory """ - dones_and_truncated = dones_and_truncated.clone() - dones_and_truncated[..., -1] = True + done = done.clone() + done[..., -1] = True # TODO: find a way of copying once only, eg not using reshape - num_per_traj = torch.where(dones_and_truncated.reshape(-1))[0] + 1 + num_per_traj = torch.where(done.reshape(-1))[0] + 1 num_per_traj[1:] = num_per_traj[1:] - num_per_traj[:-1] return num_per_traj
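The contract of ``_get_num_per_traj`` is easiest to see on a toy mask; the snippet below re-runs the same computation standalone (it does not import the private helper):

    import torch

    # [B=2, T=5] done/truncated mask: batch 0 ends an episode at t=2, batch 1 has
    # no episode end and is closed by the final truncation mark.
    done = torch.zeros(2, 5, dtype=torch.bool)
    done[0, 2] = True

    done = done.clone()
    done[..., -1] = True
    num_per_traj = torch.where(done.reshape(-1))[0] + 1
    num_per_traj[1:] = num_per_traj[1:] - num_per_traj[:-1]
    print(num_per_traj)  # tensor([3, 2, 5])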