Skip to content

Commit

Permalink
[Feature] Dynamic specs (#2143)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 31, 2024
1 parent 765952a commit 8934ed0
Show file tree
Hide file tree
Showing 13 changed files with 1,782 additions and 398 deletions.
125 changes: 125 additions & 0 deletions examples/envs/gym_conversion_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This script gives some examples of gym environment conversion with Dict, Tuple and Sequence spaces.
"""

import gymnasium as gym
from gymnasium import spaces

from torchrl.envs import GymWrapper

# Discrete two-action space assigned as ``action_space`` by every example env below.
action_space = spaces.Discrete(2)


class BaseEnv(gym.Env):
    """Stub environment that emits random samples of its observation space.

    Subclasses are expected to set ``observation_space`` (and
    ``action_space``) in their ``__init__``; this base class only reads
    ``self.observation_space``.
    """

    def step(self, action):
        """Ignore ``action`` and return a freshly sampled transition.

        Returns:
            A 5-tuple ``(observation, reward, terminated, truncated, info)``
            with a constant reward of ``1``, both flags ``False`` and an
            empty info dict.
        """
        observation = self.observation_space.sample()
        reward = 1
        terminated = False
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Return a sampled observation and an empty info dict.

        All keyword arguments are accepted and ignored.
        """
        observation = self.observation_space.sample()
        info = {}
        return observation, info


class SimpleEnv(BaseEnv):
    """Example env with a single flat ``Box`` observation of shape ``(2,)``."""

    def __init__(self):
        box_obs = spaces.Box(-1, 1, (2,))
        self.observation_space = box_obs
        self.action_space = action_space


gym.register("SimpleEnv-v0", entry_point=SimpleEnv)


class SimpleEnvWithDict(BaseEnv):
    """Example env whose observation is a ``Dict`` of two ``Box`` spaces.

    The entries are ``obs0`` (shape ``(2,)``) and ``obs1`` (shape ``(3,)``).
    """

    def __init__(self):
        first_box = spaces.Box(-1, 1, (2,))
        second_box = spaces.Box(-1, 1, (3,))
        self.observation_space = spaces.Dict(obs0=first_box, obs1=second_box)
        self.action_space = action_space


gym.register("SimpleEnvWithDict-v0", entry_point=SimpleEnvWithDict)


class SimpleEnvWithTuple(BaseEnv):
    """Example env whose observation is a ``Tuple`` of two ``Box`` spaces.

    The members have shapes ``(2,)`` and ``(3,)`` respectively.
    """

    def __init__(self):
        first_box = spaces.Box(-1, 1, (2,))
        second_box = spaces.Box(-1, 1, (3,))
        self.observation_space = spaces.Tuple((first_box, second_box))
        self.action_space = action_space


gym.register("SimpleEnvWithTuple-v0", entry_point=SimpleEnvWithTuple)


class SimpleEnvWithSequence(BaseEnv):
    """Example env whose observation is a stacked ``Sequence`` of ``Box(2,)``."""

    def __init__(self):
        element_space = spaces.Box(-1, 1, (2,))
        # Only stack=True is currently allowed
        self.observation_space = spaces.Sequence(element_space, stack=True)
        self.action_space = action_space


gym.register("SimpleEnvWithSequence-v0", entry_point=SimpleEnvWithSequence)


class SimpleEnvWithSequenceOfTuple(BaseEnv):
    """Example env whose observation is a stacked ``Sequence`` of ``Tuple`` spaces.

    Each sequence element is a ``Tuple`` of two ``Box`` spaces with shapes
    ``(2,)`` and ``(3,)``.
    """

    def __init__(self):
        first_box = spaces.Box(-1, 1, (2,))
        second_box = spaces.Box(-1, 1, (3,))
        element_space = spaces.Tuple((first_box, second_box))
        # Only stack=True is currently allowed
        self.observation_space = spaces.Sequence(element_space, stack=True)
        self.action_space = action_space


gym.register(
    "SimpleEnvWithSequenceOfTuple-v0", entry_point=SimpleEnvWithSequenceOfTuple
)


class SimpleEnvWithTupleOfSequences(BaseEnv):
    """Example env whose observation is a ``Tuple`` of two stacked ``Sequence`` spaces.

    The first sequence holds ``Box`` elements of shape ``(2,)``, the second
    holds ``Box`` elements of shape ``(3,)``.
    """

    def __init__(self):
        # Only stack=True is currently allowed
        first_sequence = spaces.Sequence(spaces.Box(-1, 1, (2,)), stack=True)
        # Only stack=True is currently allowed
        second_sequence = spaces.Sequence(spaces.Box(-1, 1, (3,)), stack=True)
        self.observation_space = spaces.Tuple((first_sequence, second_sequence))
        self.action_space = action_space


gym.register(
    "SimpleEnvWithTupleOfSequences-v0", entry_point=SimpleEnvWithTupleOfSequences
)

if __name__ == "__main__":
    # Build each registered example env, wrap it with GymWrapper and print a
    # 10-step rollout so the converted data structure can be inspected.
    for envname in [
        "SimpleEnv",
        "SimpleEnvWithDict",
        "SimpleEnvWithTuple",
        "SimpleEnvWithSequence",
        "SimpleEnvWithSequenceOfTuple",
        "SimpleEnvWithTupleOfSequences",
    ]:
        print("\n\nEnv =", envname)
        env = gym.make(envname + "-v0")
        env_torchrl = GymWrapper(env)
        try:
            # NOTE(review): return_contiguous=False is presumably needed
            # because Sequence observations can differ in length from step
            # to step -- confirm against the GymWrapper docs.
            print(env_torchrl.rollout(10, return_contiguous=False))
        finally:
            # Release the wrapped env even if the rollout raises, so one
            # failing example does not leak it into the next iteration.
            env_torchrl.close()
58 changes: 56 additions & 2 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1881,8 +1881,8 @@ def _step(self, tensordict):
lazy[obskey[1:]] += expand_right(
self.count[..., i, 0], lazy[obskey[1:]].shape
).clone()
td.update(self.output_spec["full_done_spec"].zero())
td.update(self.output_spec["full_reward_spec"].zero())
td.update(self.full_done_spec.zero())
td.update(self.full_reward_spec.zero())

assert td.batch_size == self.batch_size
return td
Expand All @@ -1896,3 +1896,57 @@ def _reset(self, tensordict=None):
reset_td.update(self.full_done_spec.zero())
assert reset_td.batch_size == self.batch_size
return reset_td


class EnvWithDynamicSpec(EnvBase):
    """Mock env whose observation grows along one dimension at every step.

    The observation spec is declared with shape ``(3, -1, 2)``. When the
    internal counter equals ``c`` the emitted observation has shape
    ``(3, c + 1, 2)`` and every entry equals ``c``.

    Args:
        max_count: counter value at which ``_step`` starts flagging done.
    """

    def __init__(self, max_count=5):
        super().__init__(batch_size=())
        self.observation_spec = CompositeSpec(
            observation=UnboundedContinuousTensorSpec(shape=(3, -1, 2)),
        )
        self.action_spec = BoundedTensorSpec(low=-1, high=1, shape=(2,))
        # The three end-of-episode flags share the same boolean spec layout.
        self.full_done_spec = CompositeSpec(
            **{
                flag: BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool)
                for flag in ("done", "terminated", "truncated")
            }
        )
        self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float)
        self.count = 0
        self.max_count = max_count

    def _observation_at_count(self):
        """Build a TensorDict holding the observation for the current counter."""
        obs_dtype = self.observation_spec["observation"].dtype
        obs_shape = (3, self.count + 1, 2)
        return TensorDict(
            {"observation": torch.full(obs_shape, self.count, dtype=obs_dtype)}
        )

    def _reset(self, tensordict=None):
        """Reset the counter and return the initial observation with zeroed done flags."""
        self.count = 0
        reset_td = self._observation_at_count()
        # NOTE(review): this reads ``done_spec`` while ``_step`` reads
        # ``full_done_spec``; presumably equivalent for a multi-key done
        # spec -- confirm.
        reset_td.update(self.done_spec.zero())
        return reset_td

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        """Advance the counter and return observation, done flags and zero reward.

        The input ``tensordict`` is not read.
        """
        self.count += 1
        is_done = self.count >= self.max_count
        step_td = self._observation_at_count()
        # OR-ing the zeroed flags with the python bool sets every flag to ``is_done``.
        done_td = self.full_done_spec.zero() | is_done
        reward_td = self.full_reward_spec.zero()
        step_td.update(done_td)
        step_td.update(reward_td)
        return step_td

    def _set_seed(self, seed: Optional[int]):
        """Record ``seed`` on the instance and return it unchanged."""
        self.manual_seed = seed
        return seed
122 changes: 88 additions & 34 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
DiscreteActionConvPolicy,
DiscreteActionVecMockEnv,
DiscreteActionVecPolicy,
EnvWithDynamicSpec,
HeterogeneousCountingEnv,
HeterogeneousCountingEnvPolicy,
MockSerialEnv,
Expand Down Expand Up @@ -559,17 +560,20 @@ def env_fn(seed):
total_frames=20000,
device="cpu",
)
for i, d in enumerate(collector):
if i == 0:
b1 = d
elif i == 1:
b2 = d
else:
break
assert d.names[-1] == "time"
with pytest.raises(AssertionError):
assert_allclose_td(b1, b2)
collector.shutdown()
try:
assert collector._use_buffers
for i, d in enumerate(collector):
if i == 0:
b1 = d
elif i == 1:
b2 = d
else:
break
assert d.names[-1] == "time"
with pytest.raises(AssertionError):
assert_allclose_td(b1, b2)
finally:
collector.shutdown()

ccollector = aSyncDataCollector(
create_env_fn=env_fn,
Expand All @@ -586,14 +590,19 @@ def env_fn(seed):
b2c = d
else:
break
assert d.names[-1] == "time"
with pytest.raises(AssertionError):
assert_allclose_td(b1c, b2c)

assert_allclose_td(b1c, b1)
assert_allclose_td(b2c, b2)
try:
assert ccollector._use_buffers
assert d.names[-1] == "time"

with pytest.raises(AssertionError):
assert_allclose_td(b1c, b2c)

ccollector.shutdown()
assert_allclose_td(b1c, b1)
assert_allclose_td(b2c, b2)
finally:
ccollector.shutdown()
del ccollector


@pytest.mark.skipif(not _has_gym, reason="gym library is not installed")
Expand Down Expand Up @@ -789,12 +798,8 @@ def make_env(seed):


@pytest.mark.parametrize("num_env", [1, 2])
@pytest.mark.parametrize(
"env_name",
[
"vec",
],
) # 1226: for efficiency, we just test vec, not "conv"
# 1226: for efficiency, we just test vec, not "conv"
@pytest.mark.parametrize("env_name", ["vec"])
def test_collector_batch_size(
num_env, env_name, seed=100, num_workers=2, frames_per_batch=20
):
Expand Down Expand Up @@ -943,10 +948,12 @@ def env_fn(seed):
env.set_seed(seed)
rollout1b = env.rollout(policy=policy, max_steps=50, auto_reset=True)
rollout2 = env.rollout(policy=policy, max_steps=50, auto_reset=True)
assert_allclose_td(rollout1a, rollout1b)
with pytest.raises(AssertionError):
assert_allclose_td(rollout1a, rollout2)
env.close()
try:
assert_allclose_td(rollout1a, rollout1b)
with pytest.raises(AssertionError):
assert_allclose_td(rollout1a, rollout2)
finally:
env.close()

collector = SyncDataCollector(
create_env_fn=env_fn,
Expand All @@ -960,17 +967,19 @@ def env_fn(seed):
collector_iter = iter(collector)
b1 = next(collector_iter)
b2 = next(collector_iter)
with pytest.raises(AssertionError):
assert_allclose_td(b1, b2)

# if num_env == 1:
# # rollouts collected through DataCollector are padded using pad_sequence, which introduces a first dimension
# rollout1a = rollout1a.unsqueeze(0)
assert (
rollout1a.batch_size == b1.batch_size
), f"got batch_size {rollout1a.batch_size} and {b1.batch_size}"
assert_allclose_td(rollout1a, b1.select(*rollout1a.keys(True, True)))
collector.shutdown()
try:
with pytest.raises(AssertionError):
assert_allclose_td(b1, b2)
assert (
rollout1a.batch_size == b1.batch_size
), f"got batch_size {rollout1a.batch_size} and {b1.batch_size}"
assert_allclose_td(rollout1a, b1.select(*rollout1a.keys(True, True)))
finally:
collector.shutdown()


@pytest.mark.parametrize("num_env", [1, 2])
Expand Down Expand Up @@ -1072,6 +1081,7 @@ def env_fn(seed):

collector20.shutdown()
del collector20

data20 = torch.cat(data20, data1.ndim - 1)
data20 = data20[..., :max_frames_per_traj]

Expand Down Expand Up @@ -1414,6 +1424,7 @@ def env_fn(seed):
device=device,
storing_device=storing_device,
)
assert collector._use_buffers
batch = next(collector.iterator())
assert batch.device == torch.device(storing_device)
collector.shutdown()
Expand Down Expand Up @@ -1826,6 +1837,7 @@ def test_set_truncated(collector_cls):
break
finally:
collector.shutdown()
del collector


class TestNestedEnvsCollector:
Expand Down Expand Up @@ -2566,6 +2578,7 @@ def test_unique_traj_sync(self, cat_results):
try:
for d in c:
buffer.extend(d)
assert c._use_buffers
traj_ids = buffer[:].get(("collector", "traj_ids"))
# check that we have as many trajs as expected (no skip)
assert traj_ids.unique().numel() == traj_ids.max() + 1
Expand All @@ -2587,6 +2600,47 @@ def test_unique_traj_sync(self, cat_results):
del c


class TestDynamicEnvs:
    """Collector tests against an env whose observation shape changes per step.

    Each test checks that every collected batch is a
    ``LazyStackedTensorDict`` whose last dimension is named ``"time"``.
    Collectors are always shut down in a ``finally`` clause so that a failing
    assertion does not leave envs or worker processes running.
    """

    def test_dynamic_sync_collector(self):
        """Single-process collector fed with an env instance."""
        env = EnvWithDynamicSpec()
        policy = RandomPolicy(env.action_spec)
        collector = SyncDataCollector(
            env, policy, frames_per_batch=20, total_frames=100
        )
        try:
            for data in collector:
                assert isinstance(data, LazyStackedTensorDict)
                assert data.names[-1] == "time"
        finally:
            collector.shutdown()
            del collector

    def test_dynamic_multisync_collector(self):
        """Multi-process synchronous collector with buffers explicitly disabled."""
        env = EnvWithDynamicSpec
        policy = RandomPolicy(env().action_spec)
        collector = MultiSyncDataCollector(
            [env],
            policy,
            frames_per_batch=20,
            total_frames=100,
            use_buffers=False,
            cat_results="stack",
        )
        try:
            for data in collector:
                assert isinstance(data, LazyStackedTensorDict)
                assert data.names[-1] == "time"
        finally:
            collector.shutdown()
            del collector

    def test_dynamic_multiasync_collector(self):
        """Multi-process asynchronous collector with default buffer settings."""
        env = EnvWithDynamicSpec
        policy = RandomPolicy(env().action_spec)
        # NOTE(review): ``use_buffers`` is left unset here, presumably so the
        # collector has to detect the dynamic spec on its own -- confirm.
        collector = MultiaSyncDataCollector(
            [env],
            policy,
            frames_per_batch=20,
            total_frames=100,
        )
        try:
            for data in collector:
                assert isinstance(data, LazyStackedTensorDict)
                assert data.names[-1] == "time"
        finally:
            collector.shutdown()
            del collector


if __name__ == "__main__":
    # Arguments argparse does not recognise are forwarded to pytest unchanged.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst", *unknown]
    pytest.main(pytest_argv)
Loading

0 comments on commit 8934ed0

Please sign in to comment.