Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Jumanji envs #674

Merged
merged 49 commits into from
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
2f1fac8
init
vmoens Nov 2, 2022
df4760d
[jumanji] add support for seeding and batch_size
Nov 7, 2022
e6f4b96
init
vmoens Nov 7, 2022
294dbf0
[Feature] Nested composite spec (#654)
vmoens Nov 7, 2022
d8bdb05
data conversion between Jumanji and TorchRL
Nov 8, 2022
080fb47
Merge branch 'main' into refactor_next_bis
vmoens Nov 8, 2022
fce326b
[Feature] Move `transform.forward` to `transform.step` (#660)
vmoens Nov 10, 2022
d569512
feedback from vmoens
Nov 11, 2022
428c1d3
linter fix
Nov 11, 2022
b3b36cd
amend
vmoens Nov 11, 2022
16ef495
amend
vmoens Nov 11, 2022
47da10f
fix dtype
Nov 11, 2022
4617bb6
amend
vmoens Nov 11, 2022
a466faf
Merge branch 'jumanji' of github.com:yingchenlin/rl into jumanji
Nov 11, 2022
ed9ccbc
fixing key names
vmoens Nov 11, 2022
b6902ae
fixing key names
vmoens Nov 11, 2022
aa02748
fix tensor shape
Nov 11, 2022
9765655
Merge branch 'keyname_fix' into jumanji
vmoens Nov 11, 2022
8396a32
tests workflow
vmoens Nov 11, 2022
a9d3b2f
Merge branch 'keyname_fix' into refactor_next_bis
vmoens Nov 11, 2022
8431af2
Merge branch 'refactor_next_bis', remote-tracking branch 'origin/refa…
vmoens Nov 11, 2022
710cfc1
t checkout Merge branch 'main' of github.com:pytorch/rl
vmoens Nov 11, 2022
3d75f9d
Merge remote-tracking branch 'yingchenlin/jumanji' into jumanji
vmoens Nov 11, 2022
3837d3c
Merge branch 'main' into jumanji
vmoens Nov 11, 2022
f905f27
Merge remote-tracking branch 'origin/main' into jumanji
vmoens Nov 11, 2022
b31f588
bug fix
Nov 14, 2022
9afb313
Merge branch 'main' into jumanji
vmoens Nov 15, 2022
485fc97
Merge branch 'main' into jumanji
vmoens Nov 15, 2022
3220b0e
amend
vmoens Nov 15, 2022
01683c9
Merge branch 'main' into jumanji
vmoens Nov 16, 2022
73bb98b
amend
vmoens Nov 16, 2022
3513800
amend
vmoens Nov 16, 2022
00d85e5
cleanup
vmoens Nov 16, 2022
0c9cd8a
cleanup
vmoens Nov 16, 2022
6dc0a94
bf
vmoens Nov 18, 2022
6249102
bf
vmoens Nov 18, 2022
e242a1f
bf
vmoens Nov 18, 2022
704fd80
bf
vmoens Nov 18, 2022
b887286
lint
vmoens Nov 18, 2022
037c4f4
Merge branch 'main' into jumanji
vmoens Nov 18, 2022
020404e
amend
vmoens Nov 18, 2022
51757ba
refactor
Nov 18, 2022
b14a029
Merge branch 'jumanji' of github.com:yingchenlin/rl into jumanji
Nov 18, 2022
a6bd1f5
refactor
Nov 18, 2022
b58ee1b
refactor
Nov 19, 2022
60f0bec
amend
vmoens Nov 19, 2022
d845598
amend
vmoens Nov 19, 2022
24c293f
amend
vmoens Nov 19, 2022
4a56a86
amend
vmoens Nov 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
bf
  • Loading branch information
vmoens committed Nov 18, 2022
commit 6dc0a943eaf018b3f07929925562d43040ac3419
11 changes: 6 additions & 5 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ def _test_fake_tensordict(env: EnvBase):
fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1)
fake_tensordict = fake_tensordict.expand(*real_tensordict.shape)
fake_tensordict = fake_tensordict.to_tensordict()
fake_tensordict.zero_()
real_tensordict.zero_()
assert (fake_tensordict == real_tensordict).all()
assert (
fake_tensordict.apply(lambda x: torch.zeros_like(x))
== real_tensordict.apply(lambda x: torch.zeros_like(x))
).all()
for key in keys2:
assert fake_tensordict[key].shape == real_tensordict[key].shape

Expand All @@ -69,10 +70,10 @@ def _check_dtype(key, value, obs_spec, input_spec):
_check_dtype(_key, _value, obs_spec, input_spec)
return
elif key in input_spec.keys(yield_nesting_keys=True):
assert input_spec[key].is_in(value)
assert input_spec[key].is_in(value), (input_spec[key], value)
return
elif key in obs_spec.keys(yield_nesting_keys=True):
assert obs_spec[key].is_in(value)
assert obs_spec[key].is_in(value), (input_spec[key], value)
return
else:
raise KeyError(key)
Expand Down
6 changes: 3 additions & 3 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,11 +646,11 @@ def to(self, device: DEVICE_TYPING) -> EnvBase:
def fake_tensordict(self) -> TensorDictBase:
"""Returns a fake tensordict with key-value pairs that match in shape, device and dtype what can be expected during an environment rollout."""
input_spec = self.input_spec
fake_input = input_spec.zero(self.batch_size)
fake_input = input_spec.rand(self.batch_size)
observation_spec = self.observation_spec
fake_obs = observation_spec.zero(self.batch_size)
fake_obs = observation_spec.rand(self.batch_size)
reward_spec = self.reward_spec
fake_reward = reward_spec.zero(self.batch_size)
fake_reward = reward_spec.rand(self.batch_size)
fake_td = TensorDict(
{
**fake_obs,
Expand Down