Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Jumanji envs #674

Merged
merged 49 commits into from
Nov 19, 2022
Merged
Changes from 1 commit
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
2f1fac8
init
vmoens Nov 2, 2022
df4760d
[jumanji] add support for seeding and batch_size
Nov 7, 2022
e6f4b96
init
vmoens Nov 7, 2022
294dbf0
[Feature] Nested composite spec (#654)
vmoens Nov 7, 2022
d8bdb05
data conversion between Jumanji and TorchRL
Nov 8, 2022
080fb47
Merge branch 'main' into refactor_next_bis
vmoens Nov 8, 2022
fce326b
[Feature] Move `transform.forward` to `transform.step` (#660)
vmoens Nov 10, 2022
d569512
feedback from vmoens
Nov 11, 2022
428c1d3
linter fix
Nov 11, 2022
b3b36cd
amend
vmoens Nov 11, 2022
16ef495
amend
vmoens Nov 11, 2022
47da10f
fix dtype
Nov 11, 2022
4617bb6
amend
vmoens Nov 11, 2022
a466faf
Merge branch 'jumanji' of github.com:yingchenlin/rl into jumanji
Nov 11, 2022
ed9ccbc
fixing key names
vmoens Nov 11, 2022
b6902ae
fixing key names
vmoens Nov 11, 2022
aa02748
fix tensor shape
Nov 11, 2022
9765655
Merge branch 'keyname_fix' into jumanji
vmoens Nov 11, 2022
8396a32
tests workflow
vmoens Nov 11, 2022
a9d3b2f
Merge branch 'keyname_fix' into refactor_next_bis
vmoens Nov 11, 2022
8431af2
Merge branch 'refactor_next_bis', remote-tracking branch 'origin/refa…
vmoens Nov 11, 2022
710cfc1
t checkout Merge branch 'main' of github.com:pytorch/rl
vmoens Nov 11, 2022
3d75f9d
Merge remote-tracking branch 'yingchenlin/jumanji' into jumanji
vmoens Nov 11, 2022
3837d3c
Merge branch 'main' into jumanji
vmoens Nov 11, 2022
f905f27
Merge remote-tracking branch 'origin/main' into jumanji
vmoens Nov 11, 2022
b31f588
bug fix
Nov 14, 2022
9afb313
Merge branch 'main' into jumanji
vmoens Nov 15, 2022
485fc97
Merge branch 'main' into jumanji
vmoens Nov 15, 2022
3220b0e
amend
vmoens Nov 15, 2022
01683c9
Merge branch 'main' into jumanji
vmoens Nov 16, 2022
73bb98b
amend
vmoens Nov 16, 2022
3513800
amend
vmoens Nov 16, 2022
00d85e5
cleanup
vmoens Nov 16, 2022
0c9cd8a
cleanup
vmoens Nov 16, 2022
6dc0a94
bf
vmoens Nov 18, 2022
6249102
bf
vmoens Nov 18, 2022
e242a1f
bf
vmoens Nov 18, 2022
704fd80
bf
vmoens Nov 18, 2022
b887286
lint
vmoens Nov 18, 2022
037c4f4
Merge branch 'main' into jumanji
vmoens Nov 18, 2022
020404e
amend
vmoens Nov 18, 2022
51757ba
refactor
Nov 18, 2022
b14a029
Merge branch 'jumanji' of github.com:yingchenlin/rl into jumanji
Nov 18, 2022
a6bd1f5
refactor
Nov 18, 2022
b58ee1b
refactor
Nov 19, 2022
60f0bec
amend
vmoens Nov 19, 2022
d845598
amend
vmoens Nov 19, 2022
24c293f
amend
vmoens Nov 19, 2022
4a56a86
amend
vmoens Nov 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
cleanup
  • Loading branch information
vmoens committed Nov 16, 2022
commit 00d85e54324543e3efd9e287aa084fdb8cbd8038
54 changes: 36 additions & 18 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from tensordict.tensordict import TensorDictBase
from torchrl._utils import seed_generator
from torchrl.envs import EnvBase
from torchrl.data import CompositeSpec


# Specified for test_utils.py
Expand Down Expand Up @@ -63,26 +64,43 @@ def _test_fake_tensordict(env: EnvBase):


def _check_dtype(key, value, obs_spec, input_spec):
if isinstance(value, TensorDictBase) and key == "next":
if key in {"reward", "done"}:
return
elif key == "next":
for _key, _value in value.items():
_check_dtype(_key, _value, obs_spec, input_spec=None)
elif isinstance(value, TensorDictBase) and key in obs_spec.keys():
for _key, _value in value.items():
_check_dtype(_key, _value, obs_spec=obs_spec[key], input_spec=None)
elif isinstance(value, TensorDictBase) and key in input_spec.keys():
for _key, _value in value.items():
_check_dtype(_key, _value, obs_spec=None, input_spec=input_spec[key])
_check_dtype(_key, _value, obs_spec, input_spec)
return
elif key in input_spec.keys(yield_nesting_keys=True):
assert input_spec[key].is_in(value)
return
elif key in obs_spec.keys(yield_nesting_keys=True):
assert obs_spec[key].is_in(value)
return
else:
if obs_spec is not None and key in obs_spec.keys():
assert (
obs_spec[key].dtype is value.dtype
), f"{obs_spec[key].dtype} vs {value.dtype} for {key}"
elif input_spec is not None and key in input_spec.keys():
assert (
input_spec[key].dtype is value.dtype
), f"{input_spec[key].dtype} vs {value.dtype} for {key}"
else:
assert key in {"done", "reward"}, (key, obs_spec, input_spec)
raise KeyError(key)
#
# if isinstance(value, TensorDictBase) and key == "next":
# for _key, _value in value.items():
# _check_dtype(_key, _value, obs_spec, input_spec=input_spec)
# elif isinstance(value, TensorDictBase) and isinstance(obs_spec, CompositeSpec) and key in obs_spec.keys():
# for _key, _value in value.items():
# _check_dtype(_key, _value, obs_spec=obs_spec[key], input_spec=None)
# elif isinstance(value, TensorDictBase) and isinstance(input_spec, CompositeSpec) and key in input_spec.keys():
# for _key, _value in value.items():
# _check_dtype(_key, _value, obs_spec=None, input_spec=input_spec[key])
# else:
# if isinstance(obs_spec, CompositeSpec) and key in obs_spec.keys():
# assert (
# obs_spec[key].dtype is value.dtype
# ), f"{obs_spec[key].dtype} vs {value.dtype} for {key}"
# assert obs_spec[key].is_in(value)
# elif isinstance(input_spec, CompositeSpec) and key in input_spec.keys():
# assert (
# input_spec[key].dtype is value.dtype
# ), f"{input_spec[key].dtype} vs {value.dtype} for {key}"
# assert input_spec[key].is_in(value)
# else:
# assert key in {"done", "reward"}, (key, value, obs_spec, input_spec)


# Decorator to retry upon certain Exceptions.
Expand Down