[BugFix] Fix missing ("next", "observation") key in dispatch of losses (
Blonck authored Jun 6, 2023
1 parent 61f6915 commit 9467036
Showing 4 changed files with 46 additions and 38 deletions.
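The losses' dispatch path accepts flat keyword arguments in place of a TensorDict; the flat name next_observation maps onto the nested ("next", "observation") entry that these diffs add to each loss's in_keys. Below is a minimal sketch of that key-flattening convention, modeled on the updated tests (tensor shapes and feature sizes are illustrative assumptions, not values from this commit).

import torch
from tensordict import TensorDict

batch = [2]
kwargs = {
    "observation": torch.randn(*batch, 3),
    "action": torch.randn(*batch, 4),
    "next_reward": torch.randn(*batch, 1),
    "next_done": torch.zeros(*batch, 1, dtype=torch.bool),
    "next_observation": torch.randn(*batch, 3),  # the key this commit adds to dispatch
}
# unflatten_keys("_") turns "next_observation" into ("next", "observation"),
# mirroring how the keyword form maps onto the nested TensorDict layout.
td = TensorDict(kwargs, batch).unflatten_keys("_")
assert ("next", "observation") in td.keys(include_nested=True)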
4 changes: 3 additions & 1 deletion test/test_cost.py
@@ -979,12 +979,12 @@ def test_ddpg_notensordict(self):
value = self._create_mock_value()
td = self._create_mock_data_ddpg()
loss = DDPGLoss(actor, value)
loss.make_value_estimator(ValueEstimators.TD1)

kwargs = {
"observation": td.get("observation"),
"next_reward": td.get(("next", "reward")),
"next_done": td.get(("next", "done")),
"next_observation": td.get(("next", "observation")),
"action": td.get("action"),
}
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
@@ -3623,6 +3623,7 @@ def test_a2c_notensordict(self, action_key, observation_key, reward_key, done_ke

kwargs = {
observation_key: td.get(observation_key),
f"next_{observation_key}": td.get(observation_key),
f"next_{reward_key}": td.get(("next", reward_key)),
f"next_{done_key}": td.get(("next", done_key)),
action_key: td.get(action_key),
@@ -4675,6 +4676,7 @@ def test_iql_notensordict(self, action_key, observation_key, reward_key, done_ke
observation_key: td.get(observation_key),
f"next_{reward_key}": td.get(("next", reward_key)),
f"next_{done_key}": td.get(("next", done_key)),
f"next_{observation_key}": td.get(("next", observation_key)),
}
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")

54 changes: 28 additions & 26 deletions torchrl/objectives/a2c.py
@@ -97,6 +97,7 @@ class A2CLoss(LossModule):
... "action": action,
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
... ("next", "observation"): torch.randn(*batch, n_obs),
... }, batch)
>>> loss(data)
TensorDict(
@@ -142,13 +143,13 @@ class A2CLoss(LossModule):
... in_keys=["observation"])
>>> loss = A2CLoss(actor, value, loss_critic_type="l2")
>>> batch = [2, ]
>>> loss_val = loss(
>>> loss_obj, loss_critic, entropy, loss_entropy = loss(
... observation = torch.randn(*batch, n_obs),
... action = spec.rand(batch),
... next_done = torch.zeros(*batch, 1, dtype=torch.bool),
... next_reward = torch.randn(*batch, 1))
>>> loss_val
(tensor(1.7593, grad_fn=<MeanBackward0>), tensor(0.2344, grad_fn=<MeanBackward0>), tensor(1.5480), tensor(-0.0155, grad_fn=<MulBackward0>))
... next_reward = torch.randn(*batch, 1),
... next_observation = torch.randn(*batch, n_obs))
>>> loss_obj.backward()
"""

@dataclass
Expand Down Expand Up @@ -227,6 +228,29 @@ def __init__(
self.gamma = gamma
self.loss_critic_type = loss_critic_type

@property
def in_keys(self):
keys = [
self.tensor_keys.action,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
*self.actor.in_keys,
*[("next", key) for key in self.actor.in_keys],
]
if self.critic_coef:
keys.extend(self.critic.in_keys)
return list(set(keys))

@property
def out_keys(self):
outs = ["loss_objective"]
if self.critic_coef:
outs.append("loss_critic")
if self.entropy_bonus:
outs.append("entropy")
outs.append("loss_entropy")
return outs

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
@@ -289,28 +313,6 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
)
return self.critic_coef * loss_value

@property
def in_keys(self):
keys = [
self.tensor_keys.action,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
]
keys.extend(self.actor.in_keys)
if self.critic_coef:
keys.extend(self.critic.in_keys)
return list(set(keys))

@property
def out_keys(self):
outs = ["loss_objective"]
if self.critic_coef:
outs.append("loss_critic")
if self.entropy_bonus:
outs.append("entropy")
outs.append("loss_entropy")
return outs

@dispatch()
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = tensordict.clone(False)
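For illustration, the standalone sketch below replays the corrected in_keys logic with hypothetical key names; the *[("next", key) for key in ...] expansion is what contributes the previously missing ("next", "observation") entry (assuming an actor and critic that both read "observation").

# Hypothetical actor/critic input keys; mirrors the relocated A2CLoss.in_keys property.
actor_in_keys = ["observation"]
critic_in_keys = ["observation"]
keys = [
    "action",
    ("next", "reward"),
    ("next", "done"),
    *actor_in_keys,
    *[("next", key) for key in actor_in_keys],
    *critic_in_keys,
]
print(sorted(set(keys), key=str))
# expected: [('next', 'done'), ('next', 'observation'), ('next', 'reward'), 'action', 'observation']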
12 changes: 7 additions & 5 deletions torchrl/objectives/ddpg.py
@@ -66,6 +66,7 @@ class DDPGLoss(LossModule):
... "action": spec.rand(batch),
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
... ("next", "observation"): torch.randn(*batch, n_obs),
... }, batch)
>>> loss(data)
TensorDict(
@@ -108,13 +109,13 @@ class DDPGLoss(LossModule):
... module=module,
... in_keys=["observation", "action"])
>>> loss = DDPGLoss(actor, value)
>>> loss_val = loss(
>>> loss_actor, loss_value, pred_value, target_value, pred_value_max, target_value_max = loss(
... observation=torch.randn(n_obs),
... action=spec.rand(),
... next_done=torch.zeros(1, dtype=torch.bool),
... next_observation=torch.randn(n_obs),
... next_reward=torch.randn(1))
>>> loss_val
(tensor(-0.8247, grad_fn=<MeanBackward0>), tensor(1.3344, grad_fn=<MeanBackward0>), tensor(0.6193), tensor(1.7744), tensor(0.6193), tensor(1.7744))
>>> loss_actor.backward()
"""

@@ -202,9 +203,10 @@ def in_keys(self):
keys = [
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
*self.actor_in_keys,
*[("next", key) for key in self.actor_in_keys],
*self.value_network.in_keys,
]
keys += self.value_network.in_keys
keys += self.actor_in_keys
keys = list(set(keys))
return keys

14 changes: 8 additions & 6 deletions torchrl/objectives/iql.py
@@ -98,6 +98,7 @@ class IQLLoss(LossModule):
... "action": action,
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
... ("next", "observation"): torch.randn(*batch, n_obs),
... }, batch)
>>> loss(data)
TensorDict(
@@ -152,13 +153,13 @@ class IQLLoss(LossModule):
>>> loss = IQLLoss(actor, qvalue, value)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_val = loss(
>>> loss_actor, loss_qvalue, loss_value, entropy = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
... next_observation=torch.zeros(*batch, n_obs),
... next_reward=torch.randn(*batch, 1))
>>> loss_val
(tensor(1.4535, grad_fn=<MeanBackward0>), tensor(0.8389, grad_fn=<MeanBackward0>), tensor(0.3406, grad_fn=<MeanBackward0>), tensor(3.3441))
>>> loss_actor.backward()
"""

@@ -269,10 +270,11 @@ def in_keys(self):
self.tensor_keys.action,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
*self.actor_network.in_keys,
*[("next", key) for key in self.actor_network.in_keys],
*self.qvalue_network.in_keys,
*self.value_network.in_keys,
]
keys.extend(self.actor_network.in_keys)
keys.extend(self.qvalue_network.in_keys)
keys.extend(self.value_network.in_keys)

return list(set(keys))

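The IQL variant follows the same pattern with three networks; the small sketch below, using hypothetical key names, shows that the final set() keeps shared inputs such as "observation" only once while still adding "next" variants for the actor's inputs.

# Hypothetical network input keys; only the actor's inputs get "next" variants.
actor_keys = ["observation"]
qvalue_keys = ["observation", "action"]
value_keys = ["observation"]
keys = [
    "action",
    ("next", "reward"),
    ("next", "done"),
    *actor_keys,
    *[("next", k) for k in actor_keys],
    *qvalue_keys,
    *value_keys,
]
print(sorted(set(keys), key=str))
# "observation" and "action" each appear once despite being listed by several networks.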

1 comment on commit 9467036

@github-actions

⚠️ Performance Alert ⚠️

A possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous benchmark result by more than the alert threshold of 2.

Benchmark suite | Current: 9467036 | Previous: 61f6915 | Ratio
benchmarks/test_objectives_benchmarks.py::test_values[td0_return_estimate-False-False] | 2726.7636652941505 iter/sec (stddev: 0.00011459219585227928) | 5668.670730799731 iter/sec (stddev: 0.000024439708272042735) | 2.08
benchmarks/test_objectives_benchmarks.py::test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 206.0166184936221 iter/sec (stddev: 0.0005706787796717586) | 415.13363476293216 iter/sec (stddev: 0.00012745330443058114) | 2.02
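For reference, the Ratio column is consistent with previous/current iterations per second (an assumption about how github-action-benchmark reports it); checking the first row:

previous, current = 5668.670730799731, 2726.7636652941505
print(round(previous / current, 2))  # 2.08, above the alert threshold of 2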

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
