[BugFix] Avoid reshape(-1) for inputs to objectives modules (#2494)
Co-authored-by: Vincent Moens <vmoens@meta.com>
kurtamohler and vmoens authored Oct 15, 2024
1 parent 4860674 commit 4736fac
Showing 4 changed files with 38 additions and 87 deletions.
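For context, each of the touched loss modules previously flattened the incoming tensordict before computing its loss terms and then wrote the results back into the original batch shape. Below is a minimal, self-contained sketch of that pattern and of the new behavior; `compute_losses` is a toy stand-in for the per-module loss calls (q_loss, actor_loss, ...), not a TorchRL function.

import torch
from tensordict import TensorDict

def compute_losses(td):
    # Toy stand-in: average an "obs" entry over its feature dim, keeping batch dims.
    return td["obs"].mean(-1)

td = TensorDict({"obs": torch.randn(4, 3, 5)}, batch_size=[4, 3])

# Old pattern (removed by this commit): flatten to a 1-D batch, compute, write back.
shape = td.shape
td_flat = td.reshape(-1)             # extra copy; the (4, 3) batch structure is lost
loss_flat = compute_losses(td_flat)  # shape: torch.Size([12])

# New pattern: the loss terms consume the tensordict with its native batch shape.
loss = compute_losses(td)            # shape: torch.Size([4, 3])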
56 changes: 22 additions & 34 deletions torchrl/objectives/cql.py
@@ -514,32 +514,21 @@ def out_keys(self, values):
 
     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        shape = None
-        if tensordict.ndimension() > 1:
-            shape = tensordict.shape
-            tensordict_reshape = tensordict.reshape(-1)
-        else:
-            tensordict_reshape = tensordict
-
-        q_loss, metadata = self.q_loss(tensordict_reshape)
-        cql_loss, cql_metadata = self.cql_loss(tensordict_reshape)
+        q_loss, metadata = self.q_loss(tensordict)
+        cql_loss, cql_metadata = self.cql_loss(tensordict)
         if self.with_lagrange:
-            alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(
-                tensordict_reshape
-            )
+            alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(tensordict)
             metadata.update(alpha_prime_metadata)
-        loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict_reshape)
-        loss_actor, actor_metadata = self.actor_loss(tensordict_reshape)
+        loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict)
+        loss_actor, actor_metadata = self.actor_loss(tensordict)
         loss_alpha, alpha_metadata = self.alpha_loss(actor_metadata)
         metadata.update(bc_metadata)
         metadata.update(cql_metadata)
         metadata.update(actor_metadata)
         metadata.update(alpha_metadata)
-        tensordict_reshape.set(
+        tensordict.set(
             self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
         )
-        if shape:
-            tensordict.update(tensordict_reshape.view(shape))
         out = {
             "loss_actor": loss_actor,
             "loss_actor_bc": loss_actor_bc,
@@ -682,7 +671,9 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
                 )
                 # take max over actions
                 state_action_value = state_action_value.reshape(
-                    self.num_qvalue_nets, tensordict.shape[0], self.num_random, -1
+                    torch.Size(
+                        [self.num_qvalue_nets, *tensordict.shape, self.num_random, -1]
+                    )
                 ).max(-2)[0]
                 # take min over qvalue nets
                 next_state_value = state_action_value.min(0)[0]
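The `_get_value_v` change follows from the same goal: `tensordict.shape[0]` only described a single, flat batch dimension, whereas splicing the full batch shape into the target size keeps the reshape correct for any number of leading batch dims. A small sketch with made-up sizes (2 Q-value nets, batch shape (4, 3), 10 random actions), for illustration only:

import torch

num_qvalue_nets, batch_shape, num_random = 2, (4, 3), 10
state_action_value = torch.randn(num_qvalue_nets * 4 * 3 * num_random, 1)

# Splice the full batch shape into the target size, then reduce over the
# num_random axis (dim -2), mirroring reshape(...).max(-2)[0] in the diff.
reshaped = state_action_value.reshape(
    torch.Size([num_qvalue_nets, *batch_shape, num_random, -1])
)
print(reshaped.shape)             # torch.Size([2, 4, 3, 10, 1])
print(reshaped.max(-2)[0].shape)  # torch.Size([2, 4, 3, 1])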
@@ -739,14 +730,13 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
                 "This could be caused by calling cql_loss method before q_loss method."
             )
 
-        random_actions_tensor = (
-            torch.FloatTensor(
-                tensordict.shape[0] * self.num_random,
+        random_actions_tensor = pred_q1.new_empty(
+            (
+                *tensordict.shape[:-1],
+                tensordict.shape[-1] * self.num_random,
                 tensordict[self.tensor_keys.action].shape[-1],
             )
-            .uniform_(-1, 1)
-            .to(tensordict.device)
-        )
+        ).uniform_(-1, 1)
         curr_actions_td, curr_log_pis = self._get_policy_actions(
             tensordict.copy(),
             self.actor_network_params,
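The new allocation idiom also skips a device round-trip: `Tensor.new_empty` creates the buffer with `pred_q1`'s dtype and device, and `uniform_(-1, 1)` fills it in place, whereas `torch.FloatTensor(...).uniform_(...).to(...)` first builds a CPU float32 tensor and then copies it. A short sketch with illustrative sizes (not taken from the module):

import torch

pred_q1 = torch.randn(12)  # stand-in for the predicted Q-values

# new_empty inherits pred_q1's dtype and device, so no extra .to(device) copy is needed.
random_actions = pred_q1.new_empty((4, 3 * 10, 6)).uniform_(-1, 1)
print(random_actions.shape, random_actions.dtype, random_actions.device)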
@@ -833,31 +823,31 @@ def filter_and_repeat(name, x):
                 q_new[0] - new_log_pis.detach().unsqueeze(-1),
                 q_curr[0] - curr_log_pis.detach().unsqueeze(-1),
             ],
-            1,
+            -1,
         )
         cat_q2 = torch.cat(
             [
                 q_random[1] - random_density,
                 q_new[1] - new_log_pis.detach().unsqueeze(-1),
                 q_curr[1] - curr_log_pis.detach().unsqueeze(-1),
             ],
-            1,
+            -1,
         )
 
         min_qf1_loss = (
-            torch.logsumexp(cat_q1 / self.temperature, dim=1)
+            torch.logsumexp(cat_q1 / self.temperature, dim=-1)
             * self.min_q_weight
             * self.temperature
         )
         min_qf2_loss = (
-            torch.logsumexp(cat_q2 / self.temperature, dim=1)
+            torch.logsumexp(cat_q2 / self.temperature, dim=-1)
             * self.min_q_weight
             * self.temperature
         )
 
         # Subtract the log likelihood of data
-        cql_q1_loss = min_qf1_loss - pred_q1 * self.min_q_weight
-        cql_q2_loss = min_qf2_loss - pred_q2 * self.min_q_weight
+        cql_q1_loss = min_qf1_loss.flatten() - pred_q1 * self.min_q_weight
+        cql_q2_loss = min_qf2_loss.flatten() - pred_q2 * self.min_q_weight
 
         # write cql losses in tensordict for alpha prime loss
         tensordict.set(self.tensor_keys.cql_q1_loss, cql_q1_loss)
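With unflattened inputs the concatenated Q-values now sit in the last dimension, so both the `torch.cat` and the `logsumexp` reduce over `dim=-1`; `dim=1` was only correct when the batch had been flattened to one dimension. A minimal shape check (sizes are illustrative):

import torch

temperature = 1.0
cat_q1 = torch.randn(4, 3, 30)  # batch shape (4, 3), 30 stacked candidate Q-values

# dim=1 would reduce over a batch dimension here; dim=-1 reduces over the
# concatenated Q-values regardless of how many batch dims precede them.
min_qf1 = torch.logsumexp(cat_q1 / temperature, dim=-1)
print(min_qf1.shape)  # torch.Size([4, 3])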
@@ -1080,9 +1070,9 @@ def __init__(
         self.loss_function = loss_function
         if action_space is None:
             # infer from value net
-            try:
+            if hasattr(value_network, "action_space"):
                 action_space = value_network.spec
-            except AttributeError:
+            else:
                 # let's try with action_space then
                 try:
                     action_space = value_network.action_space
@@ -1205,8 +1195,6 @@ def value_loss(
         with torch.no_grad():
             td_error = (pred_val_index - target_value).pow(2)
             td_error = td_error.unsqueeze(-1)
-        if tensordict.device is not None:
-            td_error = td_error.to(tensordict.device)
 
         tensordict.set(
             self.tensor_keys.priority,
15 changes: 3 additions & 12 deletions torchrl/objectives/crossq.py
@@ -495,23 +495,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         To see what keys are expected in the input tensordict and what keys are expected as output, check the
         class's `"in_keys"` and `"out_keys"` attributes.
         """
-        shape = None
-        if tensordict.ndimension() > 1:
-            shape = tensordict.shape
-            tensordict_reshape = tensordict.reshape(-1)
-        else:
-            tensordict_reshape = tensordict
-
-        loss_qvalue, value_metadata = self.qvalue_loss(tensordict_reshape)
-        loss_actor, metadata_actor = self.actor_loss(tensordict_reshape)
+        loss_qvalue, value_metadata = self.qvalue_loss(tensordict)
+        loss_actor, metadata_actor = self.actor_loss(tensordict)
         loss_alpha = self.alpha_loss(log_prob=metadata_actor["log_prob"])
-        tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"])
+        tensordict.set(self.tensor_keys.priority, value_metadata["td_error"])
         if loss_actor.shape != loss_qvalue.shape:
             raise RuntimeError(
                 f"Losses shape mismatch: {loss_actor.shape} and {loss_qvalue.shape}"
             )
-        if shape:
-            tensordict.update(tensordict_reshape.view(shape))
         entropy = -metadata_actor["log_prob"]
         out = {
             "loss_actor": loss_actor,
20 changes: 5 additions & 15 deletions torchrl/objectives/iql.py
@@ -373,16 +373,9 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
 
     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        shape = None
-        if tensordict.ndimension() > 1:
-            shape = tensordict.shape
-            tensordict_reshape = tensordict.reshape(-1)
-        else:
-            tensordict_reshape = tensordict
-
-        loss_actor, metadata = self.actor_loss(tensordict_reshape)
-        loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict_reshape)
-        loss_value, metadata_value = self.value_loss(tensordict_reshape)
+        loss_actor, metadata = self.actor_loss(tensordict)
+        loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict)
+        loss_value, metadata_value = self.value_loss(tensordict)
         metadata.update(metadata_qvalue)
         metadata.update(metadata_value)
 
@@ -392,13 +385,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             raise RuntimeError(
                 f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}"
             )
-        tensordict_reshape.set(
+        tensordict.set(
             self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
         )
-        if shape:
-            tensordict.update(tensordict_reshape.view(shape))
-
-        entropy = -tensordict_reshape.get(self.tensor_keys.log_prob).detach()
+        entropy = -tensordict.get(self.tensor_keys.log_prob).detach()
         out = {
             "loss_actor": loss_actor,
             "loss_qvalue": loss_qvalue,
34 changes: 8 additions & 26 deletions torchrl/objectives/sac.py
@@ -577,30 +577,21 @@ def out_keys(self, values):
 
     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        shape = None
-        if tensordict.ndimension() > 1:
-            shape = tensordict.shape
-            tensordict_reshape = tensordict.reshape(-1)
-        else:
-            tensordict_reshape = tensordict
-
         if self._version == 1:
-            loss_qvalue, value_metadata = self._qvalue_v1_loss(tensordict_reshape)
-            loss_value, _ = self._value_loss(tensordict_reshape)
+            loss_qvalue, value_metadata = self._qvalue_v1_loss(tensordict)
+            loss_value, _ = self._value_loss(tensordict)
         else:
-            loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict_reshape)
+            loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict)
             loss_value = None
-        loss_actor, metadata_actor = self._actor_loss(tensordict_reshape)
+        loss_actor, metadata_actor = self._actor_loss(tensordict)
         loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"])
-        tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"])
+        tensordict.set(self.tensor_keys.priority, value_metadata["td_error"])
         if (loss_actor.shape != loss_qvalue.shape) or (
             loss_value is not None and loss_actor.shape != loss_value.shape
         ):
             raise RuntimeError(
                 f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}"
             )
-        if shape:
-            tensordict.update(tensordict_reshape.view(shape))
         entropy = -metadata_actor["log_prob"]
         out = {
             "loss_actor": loss_actor,
@@ -1158,26 +1149,17 @@ def in_keys(self, values):
 
     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        shape = None
-        if tensordict.ndimension() > 1:
-            shape = tensordict.shape
-            tensordict_reshape = tensordict.reshape(-1)
-        else:
-            tensordict_reshape = tensordict
-
-        loss_value, metadata_value = self._value_loss(tensordict_reshape)
-        loss_actor, metadata_actor = self._actor_loss(tensordict_reshape)
+        loss_value, metadata_value = self._value_loss(tensordict)
+        loss_actor, metadata_actor = self._actor_loss(tensordict)
         loss_alpha = self._alpha_loss(
             log_prob=metadata_actor["log_prob"],
         )
 
-        tensordict_reshape.set(self.tensor_keys.priority, metadata_value["td_error"])
+        tensordict.set(self.tensor_keys.priority, metadata_value["td_error"])
         if loss_actor.shape != loss_value.shape:
             raise RuntimeError(
                 f"Losses shape mismatch: {loss_actor.shape}, and {loss_value.shape}"
             )
-        if shape:
-            tensordict.update(tensordict_reshape.view(shape))
         entropy = -metadata_actor["log_prob"]
         out = {
             "loss_actor": loss_actor,

1 comment on commit 4736fac

@github-actions


⚠️ Performance Alert ⚠️

A possible performance regression was detected for the benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous result by more than the alert threshold of 2.

Benchmark suite: benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400]
Current (4736fac):  33.51988179966991 iter/sec (stddev: 0.18139682304628116)
Previous (4860674): 213.54621168616438 iter/sec (stddev: 0.0008217465351374432)
Ratio: 6.37

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
