[BugFix] Set strict=False in tensordict.select() calls for objective classes (#2004)
albertbou92 authored Mar 12, 2024
1 parent 2b8450c commit c371266
Showing 11 changed files with 65 additions and 41 deletions.
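
For context, a minimal sketch (not part of the commit) of the tensordict.select() behavior this fix targets: with the default strict=True, selecting a key that is absent from the TensorDict raises a KeyError, while strict=False skips missing keys, so the objective classes can select their networks' in_keys even when some optional entries are not present in the input. The key names below are illustrative.

import torch
from tensordict import TensorDict

# A TensorDict holding only an observation; "action" is intentionally absent.
td = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4])

# Default behavior (strict=True): selecting a missing key raises a KeyError.
try:
    td.select("observation", "action")
except KeyError as err:
    print("strict=True raised:", err)

# With strict=False, missing keys are ignored and only present keys are kept.
sub_td = td.select("observation", "action", strict=False)
print(list(sub_td.keys()))  # ['observation']
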
8 changes: 6 additions & 2 deletions torchrl/objectives/a2c.py
@@ -411,7 +411,9 @@ def _log_probs(
raise RuntimeError(
f"tensordict stored {self.tensor_keys.action} require grad."
)
-tensordict_clone = tensordict.select(*self.actor_network.in_keys).clone()
+tensordict_clone = tensordict.select(
+    *self.actor_network.in_keys, strict=False
+).clone()
with self.actor_network_params.to_module(
self.actor_network
) if self.functional else contextlib.nullcontext():
@@ -425,7 +427,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
# TODO: if the advantage is gathered by forward, this introduces an
# overhead that we could easily reduce.
target_return = tensordict.get(self.tensor_keys.value_target)
-tensordict_select = tensordict.select(*self.critic_network.in_keys)
+tensordict_select = tensordict.select(
+    *self.critic_network.in_keys, strict=False
+)
with self.critic_network_params.to_module(
self.critic_network
) if self.functional else contextlib.nullcontext():

14 changes: 10 additions & 4 deletions torchrl/objectives/cql.py
@@ -566,7 +566,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
a_reparm = dist.rsample()
log_prob = dist.log_prob(a_reparm)

-td_q = tensordict.select(*self.qvalue_network.in_keys)
+td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
td_q.set(self.tensor_keys.action, a_reparm)
td_q = self._vmap_qvalue_networkN0(
td_q,
@@ -612,7 +612,9 @@ def filter_and_repeat(name, x):
# tensordict.del_("scale")

return (
-tensordict.select(*self.actor_network.in_keys, self.tensor_keys.action),
+tensordict.select(
+    *self.actor_network.in_keys, self.tensor_keys.action, strict=False
+),
sample_log_prob,
)

@@ -680,7 +682,9 @@ def q_loss(self, tensordict: TensorDictBase) -> Tensor:
self.target_qvalue_network_params,
)

-tensordict_pred_q = tensordict.select(*self.qvalue_network.in_keys)
+tensordict_pred_q = tensordict.select(
+    *self.qvalue_network.in_keys, strict=False
+)
q_pred = self._vmap_qvalue_networkN0(
tensordict_pred_q, self.qvalue_network_params
).get(self.tensor_keys.state_action_value)
@@ -746,7 +750,9 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tensor:
)
# select and stack input params
# q value random action
-tensordict_q_random = tensordict.select(*self.actor_network.in_keys)
+tensordict_q_random = tensordict.select(
+    *self.actor_network.in_keys, strict=False
+)

batch_size = tensordict_q_random.batch_size
batch_size = list(batch_size[:-1]) + [batch_size[-1] * self.num_random]

4 changes: 2 additions & 2 deletions torchrl/objectives/ddpg.py
@@ -302,7 +302,7 @@ def loss_actor(
tensordict: TensorDictBase,
) -> [torch.Tensor, dict]:
td_copy = tensordict.select(
-*self.actor_in_keys, *self.value_exclusive_keys
+*self.actor_in_keys, *self.value_exclusive_keys, strict=False
).detach()
with self.actor_network_params.to_module(self.actor_network):
td_copy = self.actor_network(td_copy)
@@ -318,7 +318,7 @@ def loss_value(
tensordict: TensorDictBase,
) -> Tuple[torch.Tensor, dict]:
# value loss
-td_copy = tensordict.select(*self.value_network.in_keys).detach()
+td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach()
with self.value_network_params.to_module(self.value_network):
self.value_network(td_copy)
pred_val = td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)

10 changes: 5 additions & 5 deletions torchrl/objectives/deprecated.py
@@ -330,14 +330,14 @@ def _cached_detach_qvalue_network_params(self):

def _actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
obs_keys = self.actor_network.in_keys
-tensordict_clone = tensordict.select(*obs_keys)
+tensordict_clone = tensordict.select(*obs_keys, strict=False)
with set_exploration_type(
ExplorationType.RANDOM
), self.actor_network_params.to_module(self.actor_network):
self.actor_network(tensordict_clone)

tensordict_expand = self._vmap_qvalue_networkN0(
-tensordict_clone.select(*self.qvalue_network.in_keys),
+tensordict_clone.select(*self.qvalue_network.in_keys, strict=False),
self._cached_detach_qvalue_network_params,
)
state_action_value = tensordict_expand.get("state_action_value").squeeze(-1)
@@ -352,7 +352,7 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor:

obs_keys = self.actor_network.in_keys
tensordict = tensordict.select(
"next", *obs_keys, self.tensor_keys.action
"next", *obs_keys, self.tensor_keys.action, strict=False
).clone(False)

selected_models_idx = torch.randperm(self.num_qvalue_nets)[
@@ -362,7 +362,7 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor:
selected_q_params = self.target_qvalue_network_params[selected_models_idx]

next_td = step_mdp(tensordict).select(
-*self.actor_network.in_keys
+*self.actor_network.in_keys, strict=False
) # next_observation ->
# observation
# select pseudo-action
@@ -390,7 +390,7 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor:
tensordict.set(("next", "state_value"), next_state_value)
target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
tensordict_expand = self._vmap_qvalue_networkN0(
-tensordict.select(*self.qvalue_network.in_keys),
+tensordict.select(*self.qvalue_network.in_keys, strict=False),
self.qvalue_network_params,
)
pred_val = tensordict_expand.get("state_action_value").squeeze(-1)

2 changes: 1 addition & 1 deletion torchrl/objectives/dreamer.py
@@ -403,7 +403,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:

def forward(self, fake_data) -> torch.Tensor:
lambda_target = fake_data.get("lambda_target")
-tensordict_select = fake_data.select(*self.value_model.in_keys)
+tensordict_select = fake_data.select(*self.value_model.in_keys, strict=False)
self.value_model(tensordict_select)
if self.discount_loss:
discount = self.gamma * torch.ones_like(

32 changes: 20 additions & 12 deletions torchrl/objectives/iql.py
@@ -395,7 +395,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
log_prob = dist.log_prob(tensordict[self.tensor_keys.action])

# Min Q value
-td_q = tensordict.select(*self.qvalue_network.in_keys)
+td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)

@@ -405,7 +405,9 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
)
# state value
with torch.no_grad():
-td_copy = tensordict.select(*self.value_network.in_keys).detach()
+td_copy = tensordict.select(
+    *self.value_network.in_keys, strict=False
+).detach()
with self.value_network_params.to_module(self.value_network):
self.value_network(td_copy)
value = td_copy.get(self.tensor_keys.value).squeeze(
@@ -423,11 +425,11 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:

def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
# Min Q value
-td_q = tensordict.select(*self.qvalue_network.in_keys)
+td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
# state value
-td_copy = tensordict.select(*self.value_network.in_keys)
+td_copy = tensordict.select(*self.value_network.in_keys, strict=False)
with self.value_network_params.to_module(self.value_network):
self.value_network(td_copy)
value = td_copy.get(self.tensor_keys.value).squeeze(-1)
@@ -437,13 +439,15 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:

def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
obs_keys = self.actor_network.in_keys
-tensordict = tensordict.select("next", *obs_keys, self.tensor_keys.action)
+tensordict = tensordict.select(
+    "next", *obs_keys, self.tensor_keys.action, strict=False
+)

target_value = self.value_estimator.value_estimate(
tensordict, target_params=self.target_value_network_params
).squeeze(-1)
tensordict_expand = self._vmap_qvalue_networkN0(
-tensordict.select(*self.qvalue_network.in_keys),
+tensordict.select(*self.qvalue_network.in_keys, strict=False),
self.qvalue_network_params,
)
pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze(
@@ -752,7 +756,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
log_prob = dist.log_prob(tensordict[self.tensor_keys.action])

# Min Q value
-td_q = tensordict.select(*self.qvalue_network.in_keys)
+td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
state_action_value = td_q.get(self.tensor_keys.state_action_value)
action = tensordict.get(self.tensor_keys.action)
@@ -773,7 +777,9 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
)
with torch.no_grad():
# state value
-td_copy = tensordict.select(*self.value_network.in_keys).detach()
+td_copy = tensordict.select(
+    *self.value_network.in_keys, strict=False
+).detach()
with self.value_network_params.to_module(self.value_network):
self.value_network(td_copy)
value = td_copy.get(self.tensor_keys.value).squeeze(
@@ -793,7 +799,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
# Min Q value
with torch.no_grad():
# Min Q value
-td_q = tensordict.select(*self.qvalue_network.in_keys)
+td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
state_action_value = td_q.get(self.tensor_keys.state_action_value)
action = tensordict.get(self.tensor_keys.action)
@@ -809,7 +815,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
chosen_state_action_value = (state_action_value * action).sum(-1)
min_Q, _ = torch.min(chosen_state_action_value, dim=0)
# state value
-td_copy = tensordict.select(*self.value_network.in_keys)
+td_copy = tensordict.select(*self.value_network.in_keys, strict=False)
with self.value_network_params.to_module(self.value_network):
self.value_network(td_copy)
value = td_copy.get(self.tensor_keys.value).squeeze(-1)
@@ -819,14 +825,16 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:

def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
obs_keys = self.actor_network.in_keys
-next_td = tensordict.select("next", *obs_keys, self.tensor_keys.action)
+next_td = tensordict.select(
+    "next", *obs_keys, self.tensor_keys.action, strict=False
+)
with torch.no_grad():
target_value = self.value_estimator.value_estimate(
next_td, target_params=self.target_value_network_params
).squeeze(-1)

# predict current Q value
-td_q = tensordict.select(*self.qvalue_network.in_keys)
+td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
td_q = self._vmap_qvalue_networkN0(td_q, self.qvalue_network_params)
state_action_value = td_q.get(self.tensor_keys.state_action_value)
action = tensordict.get(self.tensor_keys.action)

6 changes: 3 additions & 3 deletions torchrl/objectives/redq.py
@@ -432,7 +432,7 @@ def _qvalue_params_cat(self, selected_q_params):
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
obs_keys = self.actor_network.in_keys
tensordict_select = tensordict.clone(False).select(
"next", *obs_keys, self.tensor_keys.action
"next", *obs_keys, self.tensor_keys.action, strict=False
)
selected_models_idx = torch.randperm(self.num_qvalue_nets)[
: self.sub_sample_len
@@ -444,10 +444,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)

tensordict_actor_grad = tensordict_select.select(
-*obs_keys
+*obs_keys, strict=False
) # to avoid overwriting keys
next_td_actor = step_mdp(tensordict_select).select(
-*self.actor_network.in_keys
+*self.actor_network.in_keys, strict=False
) # next_observation ->
tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0)
# tensordict_actor = tensordict_actor.contiguous()

4 changes: 3 additions & 1 deletion torchrl/objectives/reinforce.py
@@ -413,7 +413,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
try:
target_return = tensordict.get(self.tensor_keys.value_target)
-tensordict_select = tensordict.select(*self.critic_network.in_keys)
+tensordict_select = tensordict.select(
+    *self.critic_network.in_keys, strict=False
+)
with self.critic_network_params.to_module(
self.critic_network
) if self.functional else contextlib.nullcontext():

10 changes: 5 additions & 5 deletions torchrl/objectives/sac.py
@@ -600,7 +600,7 @@ def _actor_loss(
a_reparm = dist.rsample()
log_prob = dist.log_prob(a_reparm)

-td_q = tensordict.select(*self.qvalue_network.in_keys)
+td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
td_q.set(self.tensor_keys.action, a_reparm)
td_q = self._vmap_qnetworkN0(
td_q,
@@ -719,7 +719,7 @@ def _qvalue_v2_loss(
target_value = self._compute_target_v2(tensordict)

tensordict_expand = self._vmap_qnetworkN0(
-tensordict.select(*self.qvalue_network.in_keys),
+tensordict.select(*self.qvalue_network.in_keys, strict=False),
self.qvalue_network_params,
)
pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze(
@@ -738,7 +738,7 @@ def _value_loss(
self, tensordict: TensorDictBase
) -> Tuple[Tensor, Dict[str, Tensor]]:
# value loss
-td_copy = tensordict.select(*self.value_network.in_keys).detach()
+td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach()
with self.value_network_params.to_module(self.value_network):
self.value_network(td_copy)
pred_val = td_copy.get(self.tensor_keys.value).squeeze(-1)
@@ -1194,7 +1194,7 @@ def _value_loss(
) -> Tuple[Tensor, Dict[str, Tensor]]:
target_value = self._compute_target(tensordict)
tensordict_expand = self._vmap_qnetworkN0(
-tensordict.select(*self.qvalue_network.in_keys),
+tensordict.select(*self.qvalue_network.in_keys, strict=False),
self.qvalue_network_params,
)

@@ -1236,7 +1236,7 @@ def _actor_loss(
prob = dist.probs
log_prob = dist.logits

-td_q = tensordict.select(*self.qvalue_network.in_keys)
+td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)

td_q = self._vmap_qnetworkN0(
td_q, self._cached_detached_qvalue_params # should we clone?

14 changes: 9 additions & 5 deletions torchrl/objectives/td3.py
@@ -354,11 +354,13 @@ def _cached_stack_actor_params(self):
)

def actor_loss(self, tensordict):
-tensordict_actor_grad = tensordict.select(*self.actor_network.in_keys)
+tensordict_actor_grad = tensordict.select(
+    *self.actor_network.in_keys, strict=False
+)
with self.actor_network_params.to_module(self.actor_network):
tensordict_actor_grad = self.actor_network(tensordict_actor_grad)
actor_loss_td = tensordict_actor_grad.select(
-*self.qvalue_network.in_keys
+*self.qvalue_network.in_keys, strict=False
).expand(
self.num_qvalue_nets, *tensordict_actor_grad.batch_size
) # for actor loss
@@ -389,7 +391,7 @@ def value_loss(self, tensordict):

with torch.no_grad():
next_td_actor = step_mdp(tensordict).select(
-*self.actor_network.in_keys
+*self.actor_network.in_keys, strict=False
) # next_observation ->
with self.target_actor_network_params.to_module(self.actor_network):
next_td_actor = self.actor_network(next_td_actor)
@@ -400,7 +402,9 @@ def value_loss(self, tensordict):
self.tensor_keys.action,
next_action,
)
-next_val_td = next_td_actor.select(*self.qvalue_network.in_keys).expand(
+next_val_td = next_td_actor.select(
+    *self.qvalue_network.in_keys, strict=False
+).expand(
self.num_qvalue_nets, *next_td_actor.batch_size
) # for next value estimation
next_target_q1q2 = (
@@ -420,7 +424,7 @@ def value_loss(self, tensordict):
next_target_qvalue.unsqueeze(-1),
)

-qval_td = tensordict.select(*self.qvalue_network.in_keys).expand(
+qval_td = tensordict.select(*self.qvalue_network.in_keys, strict=False).expand(
self.num_qvalue_nets,
*tensordict.batch_size,
)

2 changes: 1 addition & 1 deletion torchrl/objectives/value/advantages.py
@@ -163,7 +163,7 @@ def _call_actor_net(
log_prob_key: NestedKey,
):
# TODO: extend to handle time dimension (and vmap?)
-log_pi = actor_net(data.select(*actor_net.in_keys)).get(log_prob_key)
+log_pi = actor_net(data.select(*actor_net.in_keys, strict=False)).get(log_prob_key)
return log_pi

