Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Dynamic specs #2143

Merged
merged 27 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
amend
  • Loading branch information
vmoens committed May 29, 2024
commit 577512397366bd4d07099486b13b2c3554e265c8
125 changes: 125 additions & 0 deletions examples/envs/gym_conversion_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This script gives some examples of gym environment conversion with Dict, Tuple and Sequence spaces.
"""

import gymnasium as gym
from gymnasium import spaces

from torchrl.envs import GymWrapper

action_space = spaces.Discrete(2)


class BaseEnv(gym.Env):
def step(self, action):
return self.observation_space.sample(), 1, False, False, {}

def reset(self, **kwargs):
return self.observation_space.sample(), {}


class SimpleEnv(BaseEnv):
def __init__(self):
self.observation_space = spaces.Box(-1, 1, (2,))
self.action_space = action_space


gym.register("SimpleEnv-v0", entry_point=SimpleEnv)


class SimpleEnvWithDict(BaseEnv):
def __init__(self):
self.observation_space = spaces.Dict(
obs0=spaces.Box(-1, 1, (2,)), obs1=spaces.Box(-1, 1, (3,))
)
self.action_space = action_space


gym.register("SimpleEnvWithDict-v0", entry_point=SimpleEnvWithDict)


class SimpleEnvWithTuple(BaseEnv):
def __init__(self):
self.observation_space = spaces.Tuple(
(spaces.Box(-1, 1, (2,)), spaces.Box(-1, 1, (3,)))
)
self.action_space = action_space


gym.register("SimpleEnvWithTuple-v0", entry_point=SimpleEnvWithTuple)


class SimpleEnvWithSequence(BaseEnv):
def __init__(self):
self.observation_space = spaces.Sequence(
spaces.Box(-1, 1, (2,)),
# Only stack=True is currently allowed
stack=True,
)
self.action_space = action_space


gym.register("SimpleEnvWithSequence-v0", entry_point=SimpleEnvWithSequence)


class SimpleEnvWithSequenceOfTuple(BaseEnv):
def __init__(self):
self.observation_space = spaces.Sequence(
spaces.Tuple(
(
spaces.Box(-1, 1, (2,)),
spaces.Box(-1, 1, (3,)),
)
),
# Only stack=True is currently allowed
stack=True,
)
self.action_space = action_space


gym.register(
"SimpleEnvWithSequenceOfTuple-v0", entry_point=SimpleEnvWithSequenceOfTuple
)


class SimpleEnvWithTupleOfSequences(BaseEnv):
def __init__(self):
self.observation_space = spaces.Tuple(
(
spaces.Sequence(
spaces.Box(-1, 1, (2,)),
# Only stack=True is currently allowed
stack=True,
),
spaces.Sequence(
spaces.Box(-1, 1, (3,)),
# Only stack=True is currently allowed
stack=True,
),
)
)
self.action_space = action_space


gym.register(
"SimpleEnvWithTupleOfSequences-v0", entry_point=SimpleEnvWithTupleOfSequences
)

if __name__ == "__main__":
for envname in [
"SimpleEnv",
"SimpleEnvWithDict",
"SimpleEnvWithTuple",
"SimpleEnvWithSequence",
"SimpleEnvWithSequenceOfTuple",
"SimpleEnvWithTupleOfSequences",
]:
print("\n\nEnv =", envname)
env = gym.make(envname + "-v0")
env_torchrl = GymWrapper(env)
print(env_torchrl.rollout(10, return_contiguous=False))
62 changes: 61 additions & 1 deletion test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def get_gym_pixel_wrapper():

_has_pytree = True
try:
from torch.utils._pytree import tree_flatten
from torch.utils._pytree import tree_flatten, tree_map
except ImportError:
_has_pytree = False
IS_OSX = platform == "darwin"
Expand Down Expand Up @@ -313,6 +313,66 @@ def test_gym_spec_cast(self, categorical):
assert spec == recon
assert recon.shape == spec.shape

# @pytest.mark.parametrize("order", ["seq_tuple", "tuple_seq"])
@pytest.mark.parametrize("order", ["tuple_seq"])
def test_gym_spec_cast_tuple_sequential(self, order):
if order == "seq_tuple":
space = gym_backend("spaces").Dict(
feature=gym_backend("spaces").Sequence(
gym_backend("spaces").Tuple(
(
gym_backend("spaces").Box(-1, 1, shape=(2, 2)),
gym_backend("spaces").Box(-1, 1, shape=(1, 2)),
)
),
)
)
elif order == "tuple_seq":
space = gym_backend("spaces").Dict(
feature=gym_backend("spaces").Tuple(
(
gym_backend("spaces").Sequence(
gym_backend("spaces").Box(-1, 1, shape=(2, 2)), stack=True
),
gym_backend("spaces").Sequence(
gym_backend("spaces").Box(-1, 1, shape=(1, 2)), stack=True
),
),
)
)
else:
raise NotImplementedError
sample = space.sample()
partial_tree_map = functools.partial(
tree_map, is_leaf=lambda x: isinstance(x, (tuple, torch.Tensor))
)

def stack_tuples(item):
if isinstance(item, tuple):
try:
return torch.stack(
[partial_tree_map(stack_tuples, x) for x in item]
)
except RuntimeError:
item = [partial_tree_map(stack_tuples, x) for x in item]
try:
return torch.nested.nested_tensor(item)
except RuntimeError:
return tuple(item)
return torch.as_tensor(item)

sample_pt = partial_tree_map(stack_tuples, sample)
# sample_pt = torch.utils._pytree.tree_map(lambda x: torch.stack(list(x)), sample_pt, is_leaf=lambda x: isinstance(x, tuple))
spec = _gym_to_torchrl_spec_transform(space)
rand = spec.rand()
assert spec.is_in(rand)
assert rand in spec
assert sample_pt in spec, (rand, sample_pt)
space_recon = _torchrl_to_gym_spec_transform(spec)
assert space_recon == space, (space_recon, space)
rand_numpy = rand.numpy()
assert space.contains(rand_numpy)

_BACKENDS = [None]
if _has_gymnasium:
_BACKENDS += ["gymnasium"]
Expand Down
12 changes: 8 additions & 4 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,10 @@ def test_discrete_conversion(n, device, shape):
assert categorical.to_one_hot_spec() == one_hot
assert one_hot.to_categorical_spec() == categorical

assert categorical.is_in(one_hot.to_categorical(one_hot.rand(shape)))
assert one_hot.is_in(categorical.to_one_hot(categorical.rand(shape)))
categorical_recon = one_hot.to_categorical(one_hot.rand(shape))
assert categorical.is_in(categorical_recon), (categorical, categorical_recon)
one_hot_recon = categorical.to_one_hot(categorical.rand(shape))
assert one_hot.is_in(one_hot_recon), (one_hot, one_hot_recon)


@pytest.mark.parametrize("ns", [[5], [5, 2, 3], [4, 5, 1, 3]])
Expand All @@ -338,8 +340,10 @@ def test_multi_discrete_conversion(ns, shape, device):
assert categorical.to_one_hot_spec() == one_hot
assert one_hot.to_categorical_spec() == categorical

assert categorical.is_in(one_hot.to_categorical(one_hot.rand(shape)))
assert one_hot.is_in(categorical.to_one_hot(categorical.rand(shape)))
categorical_recon = one_hot.to_categorical(one_hot.rand(shape))
assert categorical.is_in(categorical_recon), (categorical, categorical_recon)
one_hot_recon = categorical.to_one_hot(categorical.rand(shape))
assert one_hot.is_in(one_hot_recon), (one_hot, one_hot_recon)


@pytest.mark.parametrize("is_complete", [True, False])
Expand Down
Loading
Loading