Skip to content

Commit

Permalink
[BugFix] Fix examples (#1173)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 21, 2023
1 parent 98cafa5 commit 3c8197b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion torchrl/objectives/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class DQNLoss(LossModule):
"""

default_value_estimator = ValueEstimators.TDLambda
default_value_estimator = ValueEstimators.TD0

def __init__(
self,
Expand Down
8 changes: 4 additions & 4 deletions torchrl/objectives/value/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ def transpose_tensor(tensor):
or tensor.numel() <= 1
):
return tensor, False
if time_dim < 0:
timedim = tensor.ndim + time_dim
if time_dim >= 0:
timedim = time_dim - tensor.ndim
else:
timedim = time_dim
if timedim < 0 or timedim >= tensor.ndim:
if timedim < -tensor.ndim or timedim >= 0:
raise RuntimeError(ERROR.format(tensor.shape, timedim))
if tensor.ndim >= 2:
single_dim = False
tensor = tensor.transpose(timedim, -2)
elif tensor.ndim == 1 and timedim == 0:
elif tensor.ndim == 1 and timedim == -1:
single_dim = True
tensor = tensor.unsqueeze(-1)
else:
Expand Down

0 comments on commit 3c8197b

Please sign in to comment.