Skip to content

Commit

Permalink
Writing tests for transforms (pytorch#79)
Browse files Browse the repository at this point in the history
Merging even though one test fails as it is likely due to functorch changes in the main branch
  • Loading branch information
vmoens authored Apr 20, 2022
1 parent 8042232 commit fa7bce2
Show file tree
Hide file tree
Showing 18 changed files with 734 additions and 217 deletions.
40 changes: 29 additions & 11 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def custom_prop(self):

class DiscreteActionVecMockEnv(_MockEnv):
size = 7
observation_spec = NdUnboundedContinuousTensorSpec(shape=torch.Size([size]))
observation_spec = CompositeSpec(
next_observation=NdUnboundedContinuousTensorSpec(shape=torch.Size([size]))
)
action_spec = OneHotDiscreteTensorSpec(7)
reward_spec = UnboundedContinuousTensorSpec()
from_pixels = False
Expand All @@ -94,7 +96,9 @@ def _get_out_obs(self, obs):
def _reset(self, tensordict: _TensorDict) -> _TensorDict:
self.counter += 1
state = torch.zeros(self.size) + self.counter
tensordict = tensordict.select().set(self.out_key, self._get_out_obs(state))
tensordict = tensordict.select().set(
"next_" + self.out_key, self._get_out_obs(state)
)
tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool))
return tensordict

Expand Down Expand Up @@ -123,7 +127,9 @@ def _step(

class ContinuousActionVecMockEnv(_MockEnv):
size = 7
observation_spec = NdUnboundedContinuousTensorSpec(shape=torch.Size([size]))
observation_spec = CompositeSpec(
next_observation=NdUnboundedContinuousTensorSpec(shape=torch.Size([size]))
)
action_spec = NdBoundedTensorSpec(-1, 1, (7,))
reward_spec = UnboundedContinuousTensorSpec()
from_pixels = False
Expand All @@ -138,15 +144,19 @@ def _get_out_obs(self, obs):

def _reset(self, tensordict: _TensorDict) -> _TensorDict:
self.counter += 1
self.step_count = 0
state = torch.zeros(self.size) + self.counter
tensordict = tensordict.select().set(self.out_key, self._get_out_obs(state))
tensordict = tensordict.select().set(
"next_" + self.out_key, self._get_out_obs(state)
)
tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool))
return tensordict

def _step(
self,
tensordict: _TensorDict,
) -> _TensorDict:
self.step_count += 1
tensordict = tensordict.to(self.device)
a = tensordict.get("action")
assert not self.is_done, "trying to execute step in done env"
Expand Down Expand Up @@ -186,12 +196,14 @@ def __call__(self, tensordict):


class DiscreteActionConvMockEnv(DiscreteActionVecMockEnv):
observation_spec = NdUnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7]))
observation_spec = CompositeSpec(
next_pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7]))
)
action_spec = OneHotDiscreteTensorSpec(7)
reward_spec = UnboundedContinuousTensorSpec()
from_pixels = True

out_key = "observation_pixels"
out_key = "pixels"

def _get_out_obs(self, obs):
obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(0)
Expand All @@ -202,7 +214,9 @@ def _get_in_obs(self, obs):


class DiscreteActionConvMockEnvNumpy(DiscreteActionConvMockEnv):
observation_spec = NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3]))
observation_spec = CompositeSpec(
next_pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3]))
)
from_pixels = True

def _get_out_obs(self, obs):
Expand All @@ -218,12 +232,14 @@ def _obs_step(self, obs, a):


class ContinuousActionConvMockEnv(ContinuousActionVecMockEnv):
observation_spec = NdUnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7]))
observation_spec = CompositeSpec(
next_pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7]))
)
action_spec = NdBoundedTensorSpec(-1, 1, (7,))
reward_spec = UnboundedContinuousTensorSpec()
from_pixels = True

out_key = "observation_pixels"
out_key = "pixels"

def _get_out_obs(self, obs):
obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(0)
Expand All @@ -234,7 +250,9 @@ def _get_in_obs(self, obs):


class ContinuousActionConvMockEnvNumpy(ContinuousActionConvMockEnv):
observation_spec = NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3]))
observation_spec = CompositeSpec(
next_pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3]))
)
from_pixels = True

def _get_out_obs(self, obs):
Expand All @@ -250,7 +268,7 @@ def _obs_step(self, obs, a):


class DiscreteActionConvPolicy(DiscreteActionVecPolicy):
in_keys = ["observation_pixels"]
in_keys = ["pixels"]
out_keys = ["action"]

def _get_in_obs(self, tensordict):
Expand Down
5 changes: 2 additions & 3 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,9 +513,8 @@ def test_excluded_keys(collector_class, exclude):
pytest.skip("defining _exclude_private_keys is not possible")
make_env = lambda: ContinuousActionVecMockEnv()
dummy_env = make_env()
policy_module = nn.Linear(
dummy_env.observation_spec.shape[-1], dummy_env.action_spec.shape[-1]
)
obs_spec = dummy_env.observation_spec["next_observation"]
policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1])
policy = Actor(policy_module, spec=dummy_env.action_spec)
policy_explore = OrnsteinUhlenbeckProcessWrapper(policy)

Expand Down
8 changes: 7 additions & 1 deletion test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def test_env_seed(env_name, frame_skip, seed=0):
assert_allclose_td(td0a, td0c.select(*td0a.keys()))
with pytest.raises(AssertionError):
assert_allclose_td(td1a, td1c)
env.close()


@pytest.mark.skipif(not _has_gym, reason="no gym")
Expand Down Expand Up @@ -144,6 +145,7 @@ def test_rollout(env_name, frame_skip, seed=0):
rollout3 = env.rollout(n_steps=100)
with pytest.raises(AssertionError):
assert_allclose_td(rollout1, rollout3)
env.close()


def _make_envs(env_name, frame_skip, transformed, N):
Expand Down Expand Up @@ -225,7 +227,7 @@ def test_parallel_env_seed(env_name, frame_skip, transformed):
torch.manual_seed(0)

td_serial = env_serial.rollout(n_steps=10, auto_reset=False).contiguous()
key = "observation_pixels" if "observation_pixels" in td_serial else "observation"
key = "pixels" if "pixels" in td_serial else "observation"
torch.testing.assert_allclose(
td_serial[:, 0].get("next_" + key), td_serial[:, 1].get(key)
)
Expand All @@ -245,6 +247,9 @@ def test_parallel_env_seed(env_name, frame_skip, transformed):
assert_allclose_td(td_serial[:, 0], td_parallel[:, 0]) # first step
assert_allclose_td(td_serial[:, 1], td_parallel[:, 1]) # second step
assert_allclose_td(td_serial, td_parallel)
env_parallel.close()
env_serial.close()
env0.close()


@pytest.mark.skipif(not _has_gym, reason="no gym")
Expand All @@ -259,6 +264,7 @@ def test_parallel_env_shutdown():
assert env.is_closed
env.reset()
assert not env.is_closed
env.close()


@pytest.mark.parametrize("parallel", [True, False])
Expand Down
10 changes: 5 additions & 5 deletions test/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_dqn_maker(device, noisy, distributional, from_pixels):

expected_keys = ["done", "action", "action_value"]
if from_pixels:
expected_keys += ["observation_pixels"]
expected_keys += ["pixels"]
else:
expected_keys += ["observation_vector"]

Expand Down Expand Up @@ -104,7 +104,7 @@ def test_ddpg_maker(device, from_pixels):
actor(td)
expected_keys = ["done", "action"]
if from_pixels:
expected_keys += ["observation_pixels"]
expected_keys += ["pixels"]
else:
expected_keys += ["observation_vector"]

Expand Down Expand Up @@ -156,7 +156,7 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde):
actor = actor_value.get_policy_operator()
expected_keys = [
"done",
"observation_pixels" if len(from_pixels) else "observation_vector",
"pixels" if len(from_pixels) else "observation_vector",
"action_dist_param_0",
"action_dist_param_1",
"action",
Expand All @@ -178,7 +178,7 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde):
value = actor_value.get_value_operator()
expected_keys = [
"done",
"observation_pixels" if len(from_pixels) else "observation_vector",
"pixels" if len(from_pixels) else "observation_vector",
"state_value",
]
if shared_mapping:
Expand Down Expand Up @@ -232,7 +232,7 @@ def test_sac_make(device, gsde, tanh_loc, from_pixels):
actor(td_clone)
expected_keys = [
"done",
"observation_pixels" if len(from_pixels) else "observation_vector",
"pixels" if len(from_pixels) else "observation_vector",
"action",
]
if len(gsde):
Expand Down
4 changes: 1 addition & 3 deletions test/test_postprocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@


@pytest.mark.parametrize("n", range(13))
@pytest.mark.parametrize(
"key", ["observation", "observation_pixels", "observation_whatever"]
)
@pytest.mark.parametrize("key", ["observation", "pixels", "observation_whatever"])
def test_multistep(n, key, T=11):
torch.manual_seed(0)

Expand Down
Loading

0 comments on commit fa7bce2

Please sign in to comment.