Commit

[Feature] Port test_fake_tensordict to torchrl (pytorch#731)
vmoens authored Dec 8, 2022
1 parent 36d2018 commit bb3c271
Showing 4 changed files with 66 additions and 46 deletions.
12 changes: 11 additions & 1 deletion docs/source/reference/envs.rst
@@ -111,12 +111,20 @@ It is also possible to reset some but not all of the environments:
is_shared=True)
- A note on performance: launching a :obj:`ParallelEnv` can take quite some time
+ *A note on performance*: launching a :obj:`ParallelEnv` can take quite some time
as it requires launching as many Python instances as there are processes. Due to
the time that it takes to run :obj:`import torch` (and other imports), starting the
parallel env can be a bottleneck. This is why, for instance, TorchRL tests are so slow.
Once the environment is launched, a great speedup should be observed.
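
For instance, creating a parallel env over several copies of a Gym environment
can be sketched as follows (the worker count and env name are illustrative):

    >>> from torchrl.envs import ParallelEnv
    >>> from torchrl.envs.libs.gym import GymEnv
    >>> # launches 4 worker processes, each holding one Pendulum instance
    >>> env = ParallelEnv(4, lambda: GymEnv("Pendulum-v1"))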

+ Another thing to take into consideration is that :obj:`ParallelEnv`s (as well as data collectors)
+ will create data buffers based on the environment specs to pass data from one process
+ to another. This means that a misspecified spec (input, observation or reward) will
+ cause a failure at runtime, as the data cannot be written to the preallocated buffer.
+ In general, an environment should be tested using the :obj:`check_env_specs`
+ test function before being used in a :obj:`ParallelEnv`. This function will raise
+ an assertion error whenever the preallocated buffer and the collected data mismatch.
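
A minimal sanity check along these lines (a sketch, assuming the gym backend
and the Pendulum env are available):

    >>> from torchrl.envs.libs.gym import GymEnv
    >>> from torchrl.envs.utils import check_env_specs
    >>> env = GymEnv("Pendulum-v1")
    >>> check_env_specs(env)  # raises an AssertionError on any spec/rollout mismatch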

We also offer the :obj:`SerialEnv` class that enjoys the exact same API but is executed
serially. This is mostly useful for testing purposes, when one wants to assess the
behaviour of a :obj:`ParallelEnv` without launching the subprocesses.
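
Under the same sketch assumptions as above, swapping in the serial version is a
one-line change:

    >>> from torchrl.envs import SerialEnv
    >>> # identical constructor signature, but no subprocesses are spawned
    >>> env = SerialEnv(4, lambda: GymEnv("Pendulum-v1"))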
@@ -210,6 +218,7 @@ in the environment. The keys to be included in this inverse transform are passed
TensorDictPrimer
R3MTransform
VIPTransform
+ VIPRewardTransform

Helpers
-------
@@ -223,6 +232,7 @@ Helpers
get_available_libraries
set_exploration_mode
exploration_mode
+ check_env_specs

Domain-specific
---------------
40 changes: 0 additions & 40 deletions test/_utils_internal.py
@@ -12,7 +12,6 @@
import pytest
import torch.cuda
from torchrl._utils import implement_for, seed_generator
- from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import _has_gym

# Specified for test_utils.py
@@ -70,45 +69,6 @@ def generate_seeds(seed, repeat):
return seeds


- def _test_fake_tensordict(env: EnvBase):
-     fake_tensordict = env.fake_tensordict().flatten_keys(".")
-     real_tensordict = env.rollout(3).flatten_keys(".")
-
-     keys1 = set(fake_tensordict.keys())
-     keys2 = set(real_tensordict.keys())
-     assert keys1 == keys2
-     fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1)
-     fake_tensordict = fake_tensordict.expand(*real_tensordict.shape)
-     fake_tensordict = fake_tensordict.to_tensordict()
-     assert (
-         fake_tensordict.apply(lambda x: torch.zeros_like(x))
-         == real_tensordict.apply(lambda x: torch.zeros_like(x))
-     ).all()
-     for key in keys2:
-         assert fake_tensordict[key].shape == real_tensordict[key].shape
-
-     # test dtypes
-     for key, value in real_tensordict.unflatten_keys(".").items():
-         _check_dtype(key, value, env.observation_spec, env.input_spec)
-
-
- def _check_dtype(key, value, obs_spec, input_spec):
-     if key in {"reward", "done"}:
-         return
-     elif key == "next":
-         for _key, _value in value.items():
-             _check_dtype(_key, _value, obs_spec, input_spec)
-         return
-     elif key in input_spec.keys(yield_nesting_keys=True):
-         assert input_spec[key].is_in(value), (input_spec[key], value)
-         return
-     elif key in obs_spec.keys(yield_nesting_keys=True):
-         assert obs_spec[key].is_in(value), (input_spec[key], value)
-         return
-     else:
-         raise KeyError(key)


# Decorator to retry upon certain Exceptions.
def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
def deco_retry(f):
10 changes: 5 additions & 5 deletions test/test_libs.py
@@ -9,7 +9,6 @@
import pytest
import torch
from _utils_internal import (
- _test_fake_tensordict,
get_available_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
@@ -25,6 +24,7 @@
from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper
from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv
+ from torchrl.envs.utils import check_env_specs

if _has_gym:
import gym
@@ -136,7 +136,7 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
from_pixels=from_pixels,
pixels_only=pixels_only,
)
- _test_fake_tensordict(env)
+ check_env_specs(env)


@implement_for("gym", None, "0.26")
@@ -243,7 +243,7 @@ def test_faketd(self, env_name, task, frame_skip, from_pixels, pixels_only):
from_pixels=from_pixels,
pixels_only=pixels_only,
)
- _test_fake_tensordict(env)
+ check_env_specs(env)


@pytest.mark.skipif(
@@ -337,7 +337,7 @@ class TestHabitat:
def test_habitat(self, envname):
env = HabitatEnv(envname)
rollout = env.rollout(3)
- _test_fake_tensordict(env)
+ check_env_specs(env)


@pytest.mark.skipif(not _has_jumanji, reason="jumanji not installed")
@@ -375,7 +375,7 @@ def test_jumanji_batch_size(self, envname, batch_size):
def test_jumanji_spec_rollout(self, envname, batch_size):
env = JumanjiEnv(envname, batch_size=batch_size)
env.set_seed(0)
- _test_fake_tensordict(env)
+ check_env_specs(env)

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_jumanji_consistency(self, envname, batch_size):
50 changes: 50 additions & 0 deletions torchrl/envs/utils.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import pkg_resources
+ import torch
from tensordict.nn.probabilistic import ( # noqa
interaction_mode as exploration_mode,
set_interaction_mode as set_exploration_mode,
@@ -151,3 +152,52 @@ def _check_dmlab():
# "screeps": None, # https://github.com/screeps/screeps
# "ml-agents": None,
}


+ def check_env_specs(env):
+     """Tests an environment's specs against the results of a short rollout.
+
+     This test function should be used as a sanity check for an env wrapped with
+     torchrl's EnvBase subclasses: any discrepancy between the expected data and
+     the data collected should raise an assertion error.
+
+     A broken environment spec will likely make it impossible to use parallel
+     environments.
+     """
+     fake_tensordict = env.fake_tensordict().flatten_keys(".")
+     real_tensordict = env.rollout(3).flatten_keys(".")
+
+     keys1 = set(fake_tensordict.keys())
+     keys2 = set(real_tensordict.keys())
+     assert keys1 == keys2
+     fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1)
+     fake_tensordict = fake_tensordict.expand(*real_tensordict.shape)
+     fake_tensordict = fake_tensordict.to_tensordict()
+     # compare zeroed copies: this checks matching keys and shapes without
+     # depending on the (random) rollout values
+     assert (
+         fake_tensordict.apply(lambda x: torch.zeros_like(x))
+         == real_tensordict.apply(lambda x: torch.zeros_like(x))
+     ).all()
+     for key in keys2:
+         assert fake_tensordict[key].shape == real_tensordict[key].shape
+
+     # test dtypes: check every collected value against its spec
+     for key, value in real_tensordict.unflatten_keys(".").items():
+         _check_dtype(key, value, env.observation_spec, env.input_spec)
+
+
+ def _check_dtype(key, value, obs_spec, input_spec):
+     if key in {"reward", "done"}:
+         return
+     elif key == "next":
+         # recurse into the nested "next" tensordict
+         for _key, _value in value.items():
+             _check_dtype(_key, _value, obs_spec, input_spec)
+         return
+     elif key in input_spec.keys(yield_nesting_keys=True):
+         assert input_spec[key].is_in(value), (input_spec[key], value)
+         return
+     elif key in obs_spec.keys(yield_nesting_keys=True):
+         assert obs_spec[key].is_in(value), (obs_spec[key], value)
+         return
+     else:
+         raise KeyError(key)
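
As an aside, the zeroed comparison in :obj:`check_env_specs` discards the rollout
values so that the equality test only exercises matching keys and shapes; dtypes
and bounds are checked separately against the specs via :obj:`is_in`. A standalone
illustration of the idea (hypothetical tensors):

    >>> import torch
    >>> a = torch.rand(3)
    >>> b = torch.randn(3)
    >>> bool((torch.zeros_like(a) == torch.zeros_like(b)).all())
    True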
