Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix SAC alpha optim #1192

Merged
merged 44 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
30a8763
init
vmoens May 24, 2023
e7f1dac
amend
vmoens May 24, 2023
79b5912
amend
vmoens May 24, 2023
be142ad
amend
vmoens May 24, 2023
95be960
amend
vmoens May 24, 2023
1c81ec3
amend
vmoens May 24, 2023
5e769d0
Merge remote-tracking branch 'origin/main' into fix_sac
vmoens May 24, 2023
3d8db39
amend
vmoens May 24, 2023
5080b09
amend
vmoens May 24, 2023
5418c6e
amend
vmoens May 25, 2023
f38a64c
amend
vmoens May 25, 2023
4a2fff4
amend
vmoens May 25, 2023
8efd0c4
amend
vmoens May 25, 2023
17b2808
amend
vmoens May 25, 2023
1cddbcc
amend
vmoens May 25, 2023
018b513
lint
vmoens May 25, 2023
17f4056
lint
vmoens May 25, 2023
38965d2
lint
vmoens May 25, 2023
a3c8b9e
fix
vmoens May 25, 2023
cb68bfd
fix
vmoens May 25, 2023
7ff6648
fix
vmoens May 25, 2023
86d9065
amend
vmoens May 25, 2023
4e10dc3
fix
vmoens May 25, 2023
c0a0e09
Merge remote-tracking branch 'origin/main' into fix_sac
vmoens May 26, 2023
2508cf6
Merge remote-tracking branch 'origin/main' into fix_sac
vmoens May 26, 2023
cccfea8
Merge remote-tracking branch 'origin/fix_sac' into fix_sac
vmoens May 26, 2023
c5e61f5
amend
vmoens May 26, 2023
1b50ca8
amend
vmoens May 26, 2023
7a14703
lint
vmoens May 26, 2023
c529a01
init
vmoens May 26, 2023
65b9241
Merge branch 'fix_brax' into fix_sac
vmoens May 26, 2023
68ea2c0
fix setup
vmoens May 26, 2023
a74d5c4
lint
vmoens May 26, 2023
8ff8c33
amend
vmoens May 26, 2023
364aba4
amend
vmoens May 26, 2023
cb1e975
amend
vmoens May 26, 2023
1e4b01a
amend
vmoens May 26, 2023
373dacd
Merge remote-tracking branch 'origin/main' into fix_sac
vmoens May 26, 2023
3afad13
Merge branch 'main' into fix_sac
vmoens May 26, 2023
8cf1d8a
fix
vmoens May 26, 2023
484de95
fix
vmoens May 26, 2023
c37d5e3
Merge remote-tracking branch 'origin/main' into fix_sac
vmoens May 26, 2023
556cea5
fix
vmoens May 26, 2023
c9b312a
fix
vmoens May 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
amend
  • Loading branch information
vmoens committed May 25, 2023
commit 17b28084e05d7709d36056ecfa9d08adaa39783f
1 change: 1 addition & 0 deletions torchrl/csrc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

torch::Tensor safetanh(torch::Tensor input) {
auto out = torch::tanh(input);
torch::NoGradGuard no_grad;
auto data = out.detach();
data.clamp_(-0.999999, 0.999999);
return out;
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def __init__(
min: Union[torch.Tensor, Number] = -1.0,
max: Union[torch.Tensor, Number] = 1.0,
event_dims: int = 1,
tanh_loc: bool = True,
tanh_loc: bool = False,
):
err_msg = "TanhNormal max values must be strictly greater than min values"
if isinstance(max, torch.Tensor) or isinstance(min, torch.Tensor):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
device = self.device
td_device = tensordict_reshape.to(device)

loss_actor = self._loss_actor(td_device)
if self._version == 1:
loss_qvalue, priority = self._loss_qvalue_v1(td_device)
loss_value = self._loss_value(td_device)
else:
loss_qvalue, priority = self._loss_qvalue_v2(td_device)
loss_value = None
loss_actor = self._loss_actor(td_device)
loss_alpha = self._loss_alpha(td_device)
tensordict_reshape.set(self.priority_key, priority)
if (loss_actor.shape != loss_qvalue.shape) or (
Expand Down