Skip to content

Commit

Permalink
[Doc] Fix advantage examples (#1600)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Oct 4, 2023
1 parent a43612a commit 9ccae47
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torchrl/objectives/value/advantages.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def forward(
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated)
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
"""
if tensordict.batch_dims < 1:
Expand Down Expand Up @@ -743,7 +743,7 @@ def forward(
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated)
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
"""
if tensordict.batch_dims < 1:
Expand Down Expand Up @@ -955,7 +955,7 @@ def forward(
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated)
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
"""
if tensordict.batch_dims < 1:
Expand Down Expand Up @@ -1198,7 +1198,7 @@ def forward(
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated)
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
"""
if tensordict.batch_dims < 1:
Expand Down

0 comments on commit 9ccae47

Please sign in to comment.