Skip to content

Commit

Permalink
[BugFix] Update cql docstring example (#1951)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
  • Loading branch information
BY571 and vmoens authored Feb 22, 2024
1 parent b28bbfe commit 49032ca
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class CQLLoss(LossModule):
actor_network (ProbabilisticActor): stochastic actor
qvalue_network (TensorDictModule): Q(s, a) parametric model.
This module typically outputs a ``"state_action_value"`` entry.
Keyword args:
loss_function (str, optional): loss function to be used with
the value function loss. Default is `"smooth_l1"`.
Expand Down Expand Up @@ -127,8 +128,9 @@ class CQLLoss(LossModule):
alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_actor_bc: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_alpha_prime: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
Expand Down Expand Up @@ -169,10 +171,10 @@ class CQLLoss(LossModule):
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> loss = CQLLoss(actor, qvalue, value)
>>> loss = CQLLoss(actor, qvalue)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_actor, loss_qvalue, _, _, _, _ = loss(
>>> loss_actor, loss_actor_bc, loss_qvalue, loss_cql, *_ = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
Expand All @@ -185,7 +187,7 @@ class CQLLoss(LossModule):
method.
Examples:
>>> loss.select_out_keys('loss_actor', 'loss_qvalue')
>>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
>>> loss_actor, loss_qvalue = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
Expand Down Expand Up @@ -471,10 +473,11 @@ def out_keys(self):
"loss_qvalue",
"loss_cql",
"loss_alpha",
"loss_alpha_prime",
"alpha",
"entropy",
]
if self.with_lagrange:
keys.append("loss_alpha_prime")
self._out_keys = keys
return self._out_keys

Expand Down Expand Up @@ -886,8 +889,9 @@ class DiscreteCQLLoss(LossModule):
Examples:
>>> from torchrl.modules import MLP
>>> from torchrl.modules import MLP, QValueActor
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torchrl.objectives import DiscreteCQLLoss
>>> n_obs, n_act = 4, 3
>>> value_net = MLP(in_features=n_obs, out_features=n_act)
>>> spec = OneHotDiscreteTensorSpec(n_act)
Expand All @@ -905,8 +909,11 @@ class DiscreteCQLLoss(LossModule):
>>> loss(data)
TensorDict(
fields={
loss: Tensor(shape=torch.Size([]), device=cuda:0, dtype=torch.float32, is_shared=True),
loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
td_error: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
Expand Down

0 comments on commit 49032ca

Please sign in to comment.