Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support for GRU #1586

Merged
merged 37 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
cd70b16
init
vmoens Jul 28, 2023
4126035
init
vmoens Jul 28, 2023
888fee5
amend
vmoens Aug 10, 2023
5fcb42c
init
vmoens Aug 11, 2023
871589e
amend
vmoens Aug 11, 2023
858da29
amend
vmoens Aug 11, 2023
473b200
Merge branch 'main' into event_mp
vmoens Aug 31, 2023
871bdc7
amend
vmoens Aug 31, 2023
bb352e5
amend
vmoens Aug 31, 2023
28cb428
amend
vmoens Aug 31, 2023
26b53bc
amend
vmoens Aug 31, 2023
77872a6
amend
vmoens Aug 31, 2023
c7ed82e
amend
vmoens Aug 31, 2023
fd089d1
lint
vmoens Aug 31, 2023
d6f304a
fixes
vmoens Aug 31, 2023
df0210e
amend
vmoens Aug 31, 2023
78a06d5
amend
vmoens Aug 31, 2023
45615c3
amend
vmoens Aug 31, 2023
2138931
Merge branch 'event_mp' into fix_lstm_penv
vmoens Aug 31, 2023
bc3abd2
amend
vmoens Aug 31, 2023
e0d81ef
tmp
vmoens Aug 31, 2023
b839208
Merge branch 'main' into fix_lstm_penv
vmoens Sep 1, 2023
9fc95fc
amend
vmoens Sep 1, 2023
8f3ed5e
amend
vmoens Sep 1, 2023
ff4bb70
Merge remote-tracking branch 'origin/main' into fix_lstm_penv
vmoens Sep 1, 2023
e6fa755
amend
vmoens Sep 1, 2023
38f74a2
amend
vmoens Sep 1, 2023
b568c3b
amend
vmoens Sep 1, 2023
71d1076
amend
vmoens Sep 1, 2023
5f7885e
init
vmoens Sep 1, 2023
6b46000
Merge branch 'main' into gru
vmoens Oct 1, 2023
b0434d9
Merge remote-tracking branch 'origin/main' into gru
vmoens Oct 4, 2023
b053bee
amend
vmoens Oct 4, 2023
274daef
Merge remote-tracking branch 'origin/main' into gru
vmoens Oct 5, 2023
c40e3db
Update test/mocking_classes.py
vmoens Oct 5, 2023
cec7987
Merge branch 'gru' of github.com:pytorch/rl into gru
vmoens Oct 5, 2023
da7e173
amend
vmoens Oct 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
amend
  • Loading branch information
vmoens committed Aug 11, 2023
commit 858da29bfab4bbbf0ca219d122a248da8e209565
4 changes: 2 additions & 2 deletions torchrl/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
{}, batch_size=self.shared_tensordict_parent.shape, device=self.device
)
for key in self._selected_reset_keys:
if key != "_reset":
if key != ("_reset",):
_set_single_key(self.shared_tensordict_parent, out, key, clone=True)
return out
else:
Expand Down Expand Up @@ -879,7 +879,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
{}, batch_size=self.shared_tensordict_parent.shape, device=self.device
)
for key in self._selected_reset_keys:
if key != "_reset":
if key != ("_reset",):
_set_single_key(self.shared_tensordict_parent, out, key, clone=True)
return out
else:
Expand Down
13 changes: 5 additions & 8 deletions tutorials/sphinx-tutorials/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,16 +241,13 @@ def _step(tensordict):
new_th = th + new_thdot * dt
reward = -costs.view(*tensordict.shape, 1)
done = torch.zeros_like(reward, dtype=torch.bool)
# The output must be written in a ``"next"`` entry
out = TensorDict(
{
"next": {
"th": new_th,
"thdot": new_thdot,
"params": tensordict["params"],
"reward": reward,
"done": done,
}
"th": new_th,
"thdot": new_thdot,
"params": tensordict["params"],
"reward": reward,
"done": done,
},
tensordict.shape,
)
Expand Down