Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Jumanji envs #674

Merged
merged 49 commits into from
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
2f1fac8
init
vmoens Nov 2, 2022
df4760d
[jumanji] add support for seeding and batch_size
Nov 7, 2022
e6f4b96
init
vmoens Nov 7, 2022
294dbf0
[Feature] Nested composite spec (#654)
vmoens Nov 7, 2022
d8bdb05
data conversion between Jumanji and TorchRL
Nov 8, 2022
080fb47
Merge branch 'main' into refactor_next_bis
vmoens Nov 8, 2022
fce326b
[Feature] Move `transform.forward` to `transform.step` (#660)
vmoens Nov 10, 2022
d569512
feedback from vmoens
Nov 11, 2022
428c1d3
linter fix
Nov 11, 2022
b3b36cd
amend
vmoens Nov 11, 2022
16ef495
amend
vmoens Nov 11, 2022
47da10f
fix dtype
Nov 11, 2022
4617bb6
amend
vmoens Nov 11, 2022
a466faf
Merge branch 'jumanji' of github.com:yingchenlin/rl into jumanji
Nov 11, 2022
ed9ccbc
fixing key names
vmoens Nov 11, 2022
b6902ae
fixing key names
vmoens Nov 11, 2022
aa02748
fix tensor shape
Nov 11, 2022
9765655
Merge branch 'keyname_fix' into jumanji
vmoens Nov 11, 2022
8396a32
tests workflow
vmoens Nov 11, 2022
a9d3b2f
Merge branch 'keyname_fix' into refactor_next_bis
vmoens Nov 11, 2022
8431af2
Merge branch 'refactor_next_bis', remote-tracking branch 'origin/refa…
vmoens Nov 11, 2022
710cfc1
t checkout Merge branch 'main' of github.com:pytorch/rl
vmoens Nov 11, 2022
3d75f9d
Merge remote-tracking branch 'yingchenlin/jumanji' into jumanji
vmoens Nov 11, 2022
3837d3c
Merge branch 'main' into jumanji
vmoens Nov 11, 2022
f905f27
Merge remote-tracking branch 'origin/main' into jumanji
vmoens Nov 11, 2022
b31f588
bug fix
Nov 14, 2022
9afb313
Merge branch 'main' into jumanji
vmoens Nov 15, 2022
485fc97
Merge branch 'main' into jumanji
vmoens Nov 15, 2022
3220b0e
amend
vmoens Nov 15, 2022
01683c9
Merge branch 'main' into jumanji
vmoens Nov 16, 2022
73bb98b
amend
vmoens Nov 16, 2022
3513800
amend
vmoens Nov 16, 2022
00d85e5
cleanup
vmoens Nov 16, 2022
0c9cd8a
cleanup
vmoens Nov 16, 2022
6dc0a94
bf
vmoens Nov 18, 2022
6249102
bf
vmoens Nov 18, 2022
e242a1f
bf
vmoens Nov 18, 2022
704fd80
bf
vmoens Nov 18, 2022
b887286
lint
vmoens Nov 18, 2022
037c4f4
Merge branch 'main' into jumanji
vmoens Nov 18, 2022
020404e
amend
vmoens Nov 18, 2022
51757ba
refactor
Nov 18, 2022
b14a029
Merge branch 'jumanji' of github.com:yingchenlin/rl into jumanji
Nov 18, 2022
a6bd1f5
refactor
Nov 18, 2022
b58ee1b
refactor
Nov 19, 2022
60f0bec
amend
vmoens Nov 19, 2022
d845598
amend
vmoens Nov 19, 2022
24c293f
amend
vmoens Nov 19, 2022
4a56a86
amend
vmoens Nov 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feedback from vmoens
  • Loading branch information
Ying-Chen Lin committed Nov 11, 2022
commit d569512198d5b98acbb1c0bcf9b61abd6b48c16f
18 changes: 15 additions & 3 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,9 @@ def test_habitat(self, envname):


@pytest.mark.skipif(not _has_jumanji, reason="jumanji not installed")
@pytest.mark.parametrize("envname", ["Snake-6x6-v0", "TSP50-v0"])
class TestJumanji:
@pytest.mark.parametrize("envname", ["Snake-6x6-v0"])

def test_jumanji_seeding(self, envname):
final_seed = []
tdreset = []
Expand All @@ -341,17 +342,28 @@ def test_jumanji_seeding(self, envname):
assert_allclose_td(*tdreset)
assert_allclose_td(*tdrollout)

@pytest.mark.parametrize("batch_size", [(), (2,), (2, 3)])
@pytest.mark.parametrize("envname", ["Snake-6x6-v0"])
@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_jumanji_batch_size(self, envname, batch_size):
    """Check that reset and rollout outputs report the requested batch size."""
    jumanji_env = JumanjiEnv(envname, batch_size=batch_size)
    jumanji_env.set_seed(0)
    reset_td = jumanji_env.reset()
    rollout_td = jumanji_env.rollout(max_steps=50)
    # The environment is released before the checks; only the collected
    # tensordicts are inspected below.
    jumanji_env.close()
    del jumanji_env
    assert reset_td.batch_size == batch_size
    # The last rollout dimension is dropped before comparing (presumably the
    # step axis added by rollout — confirm against EnvBase.rollout).
    assert rollout_td.batch_size[:-1] == batch_size

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_jumanji_spec_rollout(self, envname, batch_size):
    """Check that a rollout matches the layout of the env's fake tensordict.

    The fake tensordict is expanded along a new trailing dimension to the
    rollout's shape, the rollout is zeroed in place, and both are compared
    element-wise.
    """
    env = JumanjiEnv(envname, batch_size=batch_size)
    try:
        env.set_seed(0)
        tdrollout = env.rollout(max_steps=50)
        fake_td = (
            env.fake_tensordict().unsqueeze(-1).expand(*tdrollout.shape).contiguous()
        )
    finally:
        # Release the environment even if the rollout fails, as the sibling
        # batch-size test does with env.close().
        env.close()
    # Zero the rollout so only structure/shape/dtype differences can make the
    # comparison fail (presumably fake_tensordict() is zero-filled — confirm).
    tdrollout.zero_()
    assert (tdrollout == fake_td).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
70 changes: 36 additions & 34 deletions torchrl/envs/libs/jumanji.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _jumanji_to_torchrl_spec_transform(
)
return CompositeSpec(**new_spec)
else:
raise NotImplementedError(type(spec))
raise TypeError(f"Unsupported spec type {type(spec)}")


def _jumanji_to_torchrl_obs_spec_transform(
Expand All @@ -91,7 +91,7 @@ def _jumanji_to_torchrl_obs_spec_transform(
elif isinstance(spec, jumanji.specs.Spec):
return CompositeSpec(**{f"next_{k}": v for k, v in new_spec.items()})
else:
raise NotImplementedError(type(spec))
raise TypeError(f"Unsupported spec type {type(spec)}")


def _jumanji_to_torchrl_state_spec_transform(
Expand Down Expand Up @@ -119,7 +119,7 @@ def _jumanji_to_torchrl_state_spec_transform(
}
)
else:
raise NotImplementedError(type(state))
raise TypeError(f"Unsupported state type {type(state)}")


def _jumanji_to_torchrl_input_spec_transform(
Expand All @@ -130,13 +130,14 @@ def _jumanji_to_torchrl_input_spec_transform(
categorical_action_encoding: bool = True,
) -> TensorSpec:
state_dict = _jumanji_to_torchrl_data_transform(state, device=device)
input_spec = _jumanji_to_torchrl_state_spec_transform(
state_dict, dtype, device, categorical_action_encoding
)
action_spec = _jumanji_to_torchrl_spec_transform(
action_spec, dtype, device, categorical_action_encoding
input_spec = CompositeSpec(
state=_jumanji_to_torchrl_state_spec_transform(
state_dict, dtype, device, categorical_action_encoding
),
action=_jumanji_to_torchrl_spec_transform(
action_spec, dtype, device, categorical_action_encoding
)
)
input_spec["action"] = action_spec
return input_spec


Expand All @@ -151,17 +152,18 @@ def _jumanji_to_torchrl_data_transform(val, device):
if val.dtype == np.uint64:
val = val.astype(np.int64)
return torch.tensor(val, device=device)
if isinstance(val, tuple) and hasattr(val, "_fields"): # named tuples
elif isinstance(val, tuple) and hasattr(val, "_fields"): # named tuples
return {
k: _jumanji_to_torchrl_data_transform(v, device=device)
for k, v in zip(val._fields, val)
}
if hasattr(val, "__dict__"):
elif hasattr(val, "__dict__"):
vmoens marked this conversation as resolved.
Show resolved Hide resolved
return {
k: _jumanji_to_torchrl_data_transform(v, device=device)
for k, v in val.__dict__.items()
}
raise TypeError(f"Unsupported data type {type(val)}")
else:
raise TypeError(f"Unsupported data type {type(val)}")


def _torchrl_to_jumanji_state_transform(tensordict: TensorDict, env):
vmoens marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -287,9 +289,8 @@ def _build_env(

def _make_specs(self, env: "jumanji.env.Environment") -> None: # noqa: F821
# generate a sample state object to build state spec from.
seed = int.from_bytes(np.random.bytes(7), byteorder="big", signed=False)
self.set_seed(seed)
state, _ = env.reset(self.key)
key = jax.random.PRNGKey(0)
state, _ = env.reset(key)

self._input_spec = _jumanji_to_torchrl_input_spec_transform(
env.action_spec(), state, device=self.device
Expand All @@ -308,14 +309,23 @@ def _check_kwargs(self, kwargs: Dict):
if not isinstance(env, (jumanji.env.Environment,)):
raise TypeError("env is not of type 'jumanji.env.Environment'.")

def _init_env(self) -> Optional[int]:
def _init_env(self):
pass

def _set_seed(self, seed):
if seed is None:
raise Exception("Jumanji requires an integer seed.")
self.key = jax.random.PRNGKey(seed)

def read_state(self, state):
    """Convert a Jumanji state to torch data and encode it with the state spec.

    The state is first passed through ``_jumanji_to_torchrl_data_transform``
    on ``self.device``; the result is then encoded by the ``"state"`` entry
    of ``self.input_spec`` and returned.
    """
    converted = _jumanji_to_torchrl_data_transform(state, device=self.device)
    return self.input_spec["state"].encode(converted)

def read_obs(self, obs):
    """Convert a Jumanji observation to torch data, then defer to the parent.

    The observation goes through ``_jumanji_to_torchrl_data_transform`` on
    ``self.device`` and the converted value is handed to the parent class's
    ``read_obs``, whose result is returned unchanged.
    """
    converted_obs = _jumanji_to_torchrl_data_transform(obs, device=self.device)
    return super().read_obs(converted_obs)

def _step(self, tensordict: TensorDictBase) -> TensorDictBase:

state = _torchrl_to_jumanji_state_transform(tensordict.get("state"), self._env)
Expand All @@ -328,16 +338,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
state = self._reshape(state)
timestep = self._reshape(timestep)

state_dict = _jumanji_to_torchrl_data_transform(state, device=self.device)
obs_dict = self.read_obs(
_jumanji_to_torchrl_data_transform(timestep.observation, device=self.device)
)
reward = self.read_reward(
reward,
_jumanji_to_torchrl_data_transform(timestep.reward, device=self.device),
)
done = _jumanji_to_torchrl_data_transform(
timestep.step_type == self.lib.types.StepType.LAST, device=self.device
state_dict = self.read_state(state)
obs_dict = self.read_obs(timestep.observation)
reward = self.read_reward(reward, np.asarray(timestep.reward))
done = torch.tensor(
np.asarray(timestep.step_type == self.lib.types.StepType.LAST)
)

self._is_done = done
Expand All @@ -349,7 +354,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
)
tensordict_out.set("reward", reward)
tensordict_out.set("done", done)
tensordict_out.set("state", state_dict)
tensordict_out["state"] = state_dict

return tensordict_out

Expand All @@ -362,11 +367,9 @@ def _reset(
state = self._reshape(state)
timestep = self._reshape(timestep)

state_dict = _jumanji_to_torchrl_data_transform(state, device=self.device)
obs_dict = self.read_obs(
_jumanji_to_torchrl_data_transform(timestep.observation, device=self.device)
)
done = torch.zeros(self.batch_size, dtype=torch.bool, device=self.device)
state_dict = self.read_state(state)
obs_dict = self.read_obs(timestep.observation)
done = torch.zeros(self.batch_size, dtype=torch.bool)

self._is_done = done

Expand All @@ -376,7 +379,7 @@ def _reset(
device=self.device,
)
tensordict_out.set("done", done)
tensordict_out.set("state", state_dict)
tensordict_out["state"] = state_dict

return tensordict_out

Expand All @@ -397,7 +400,6 @@ class JumanjiEnv(JumanjiWrapper):
>>> td = env.rand_step()
>>> print(td)
>>> print(env.available_envs)

"""

def __init__(self, env_name, **kwargs):
Expand Down