Skip to content

Commit

Permalink
[BugFix] Update iql docstring example (#1950)
Browse files: browse the repository at this point in the history
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
  • Loading branch information
BY571 and vmoens authored Feb 22, 2024
1 parent 49032ca commit 15876b8
Showing 1 changed file with 45 additions and 49 deletions.
94 changes: 45 additions & 49 deletions torchrl/objectives/iql.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -73,20 +73,22 @@ class IQLLoss(LossModule):
... in_keys=["loc", "scale"],
... spec=spec,
... distribution_class=TanhNormal)
>>> class ValueClass(nn.Module):
>>> class QValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> module = nn.Linear(n_obs, 1)
>>> value = ValueOperator(
... module=module,
... in_keys=["observation"])
>>> qvalue = SafeModule(
... QValueClass(),
... in_keys=["observation", "action"],
... out_keys=["state_action_value"],
... )
>>> value = SafeModule(
... nn.Linear(n_obs, 1),
... in_keys=["observation"],
... out_keys=["state_value"],
... )
>>> loss = IQLLoss(actor, qvalue, value)
>>> batch = [2, ]
>>> action = spec.rand(batch)
Expand Down Expand Up @@ -134,20 +136,22 @@ class IQLLoss(LossModule):
... in_keys=["loc", "scale"],
... spec=spec,
... distribution_class=TanhNormal)
>>> class ValueClass(nn.Module):
>>> class QValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> module = nn.Linear(n_obs, 1)
>>> value = ValueOperator(
... module=module,
... in_keys=["observation"])
>>> qvalue = SafeModule(
... QValueClass(),
... in_keys=["observation", "action"],
... out_keys=["state_action_value"],
... )
>>> value = SafeModule(
... nn.Linear(n_obs, 1),
... in_keys=["observation"],
... out_keys=["state_value"],
... )
>>> loss = IQLLoss(actor, qvalue, value)
>>> batch = [2, ]
>>> action = spec.rand(batch)
Expand All @@ -165,7 +169,7 @@ class IQLLoss(LossModule):
method.
Examples:
>>> loss.select_out_keys('loss_actor', 'loss_qvalue')
>>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
>>> loss_actor, loss_qvalue = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
Expand Down Expand Up @@ -495,7 +499,7 @@ class DiscreteIQLLoss(IQLLoss):
Args:
actor_network (ProbabilisticActor): stochastic actor
qvalue_network (TensorDictModule): Q(s) parametric model
qvalue_network (TensorDictModule): Q(s, a) parametric model.
value_network (TensorDictModule, optional): V(s) parametric model.
Keyword Args:
Expand Down Expand Up @@ -526,34 +530,33 @@ class DiscreteIQLLoss(IQLLoss):
>>> import torch
>>> from torch import nn
>>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper
>>> from torchrl.modules.distributions.discrete import OneHotCategorical
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.iql import DiscreteIQLLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = OneHotDiscreteTensorSpec(n_act)
>>> module = TensorDictModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
>>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
>>> actor = ProbabilisticActor(
... module=module,
... in_keys=["logits"],
... out_keys=["action"],
... spec=spec,
... distribution_class=OneHotCategorical)
>>> qvalue = TensorDictModule(
... nn.Linear(n_obs),
>>> qvalue = SafeModule(
... nn.Linear(n_obs, n_act),
... in_keys=["observation"],
... out_keys=["state_action_value"],
... )
>>> value = TensorDictModule(
... nn.Linear(n_obs),
>>> value = SafeModule(
... nn.Linear(n_obs, 1),
... in_keys=["observation"],
... out_keys=["state_value"],
... )
>>> loss = DiscreteIQLLoss(actor, qvalue, value)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> action = spec.rand(batch).long()
>>> data = TensorDict({
... "observation": torch.randn(*batch, n_obs),
... "action": action,
Expand Down Expand Up @@ -585,40 +588,33 @@ class DiscreteIQLLoss(IQLLoss):
>>> import torch
>>> from torch import nn
>>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper
>>> from torchrl.modules.distributions.discrete import OneHotCategorical
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.iql import DiscreteIQLLoss
>>> from tensordict import TensorDict
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
>>> spec = OneHotDiscreteTensorSpec(n_act)
>>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["logits"])
>>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
>>> actor = ProbabilisticActor(
... module=module,
... in_keys=["logits"],
... out_keys=["action"],
... spec=spec,
... distribution_class=OneHotCategorical)
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs, n_act)
... def forward(self, obs):
... return self.linear(obs)
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation'])
>>> module = nn.Linear(n_obs, 1)
>>> value = ValueOperator(
... module=module,
... in_keys=["observation"])
>>> qvalue = SafeModule(
... nn.Linear(n_obs, n_act),
... in_keys=["observation"],
... out_keys=["state_action_value"],
... )
>>> value = SafeModule(
... nn.Linear(n_obs, 1),
... in_keys=["observation"],
... out_keys=["state_value"],
... )
>>> loss = DiscreteIQLLoss(actor, qvalue, value)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> action = spec.rand(batch).long()
>>> loss_actor, loss_qvalue, loss_value, entropy = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
Expand All @@ -633,7 +629,7 @@ class DiscreteIQLLoss(IQLLoss):
method.
Examples:
>>> loss.select_out_keys('loss_actor', 'loss_qvalue', 'loss_value')
>>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue', 'loss_value')
>>> loss_actor, loss_qvalue, loss_value = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
Expand Down

0 comments on commit 15876b8

Please sign in to comment.