Skip to content

Commit

Permalink
[Test] Remove import of test class (pytorch#1549)
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <matbet@meta.com>
  • Loading branch information
matteobettini authored Sep 20, 2023
1 parent bf02209 commit f13cd77
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 100 deletions.
88 changes: 87 additions & 1 deletion test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import torch.cuda

from tensordict import tensorclass
from tensordict import tensorclass, TensorDict
from torchrl._utils import implement_for, seed_generator

from torchrl.envs import MultiThreadedEnv, ObservationNorm
Expand Down Expand Up @@ -347,3 +347,89 @@ def rollout_consistency_assertion(
assert (
(r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) > 1e-1
).all(), (r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1)


def check_rollout_consistency_multikey_env(td: TensorDict, max_steps: int):
index_batch_size = (0,) * (len(td.batch_size) - 1)

# Check done and reset for root
observation_is_max = td["next", "observation"][..., 0, 0, 0] == max_steps + 1
next_is_done = td["next", "done"][index_batch_size][:-1].squeeze(-1)
assert (td["next", "done"][observation_is_max]).all()
assert (~td["next", "done"][~observation_is_max]).all()
# Obs after done is 0
assert (td["observation"][index_batch_size][1:][next_is_done] == 0).all()
# Obs after not done is previous obs
assert (
td["observation"][index_batch_size][1:][~next_is_done]
== td["next", "observation"][index_batch_size][:-1][~next_is_done]
).all()
# Check observation and reward update with count action for root
action_is_count = td["action"].long().argmax(-1).to(torch.bool)
assert (
td["next", "observation"][action_is_count]
== td["observation"][action_is_count] + 1
).all()
assert (td["next", "reward"][action_is_count] == 1).all()
# Check observation and reward do not update with no-count action for root
assert (
td["next", "observation"][~action_is_count]
== td["observation"][~action_is_count]
).all()
assert (td["next", "reward"][~action_is_count] == 0).all()

# Check done and reset for nested_1
observation_is_max = td["next", "nested_1", "observation"][..., 0] == max_steps + 1
next_is_done = td["next", "nested_1", "done"][index_batch_size][:-1].squeeze(-1)
assert (td["next", "nested_1", "done"][observation_is_max]).all()
assert (~td["next", "nested_1", "done"][~observation_is_max]).all()
# Obs after done is 0
assert (
td["nested_1", "observation"][index_batch_size][1:][next_is_done] == 0
).all()
# Obs after not done is previous obs
assert (
td["nested_1", "observation"][index_batch_size][1:][~next_is_done]
== td["next", "nested_1", "observation"][index_batch_size][:-1][~next_is_done]
).all()
# Check observation and reward update with count action for nested_1
action_is_count = td["nested_1"]["action"].to(torch.bool)
assert (
td["next", "nested_1", "observation"][action_is_count]
== td["nested_1", "observation"][action_is_count] + 1
).all()
assert (td["next", "nested_1", "gift"][action_is_count] == 1).all()
# Check observation and reward do not update with no-count action for nested_1
assert (
td["next", "nested_1", "observation"][~action_is_count]
== td["nested_1", "observation"][~action_is_count]
).all()
assert (td["next", "nested_1", "gift"][~action_is_count] == 0).all()

# Check done and reset for nested_2
observation_is_max = td["next", "nested_2", "observation"][..., 0] == max_steps + 1
next_is_done = td["next", "nested_2", "done"][index_batch_size][:-1].squeeze(-1)
assert (td["next", "nested_2", "done"][observation_is_max]).all()
assert (~td["next", "nested_2", "done"][~observation_is_max]).all()
# Obs after done is 0
assert (
td["nested_2", "observation"][index_batch_size][1:][next_is_done] == 0
).all()
# Obs after not done is previous obs
assert (
td["nested_2", "observation"][index_batch_size][1:][~next_is_done]
== td["next", "nested_2", "observation"][index_batch_size][:-1][~next_is_done]
).all()
# Check observation and reward update with count action for nested_2
action_is_count = td["nested_2"]["azione"].squeeze(-1).to(torch.bool)
assert (
td["next", "nested_2", "observation"][action_is_count]
== td["nested_2", "observation"][action_is_count] + 1
).all()
assert (td["next", "nested_2", "reward"][action_is_count] == 1).all()
# Check observation and reward do not update with no-count action for nested_2
assert (
td["next", "nested_2", "observation"][~action_is_count]
== td["nested_2", "observation"][~action_is_count]
).all()
assert (td["next", "nested_2", "reward"][~action_is_count] == 0).all()
10 changes: 7 additions & 3 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
import numpy as np
import pytest
import torch
from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED
from _utils_internal import (
check_rollout_consistency_multikey_env,
generate_seeds,
PENDULUM_VERSIONED,
PONG_VERSIONED,
)
from mocking_classes import (
ContinuousActionVecMockEnv,
CountingBatchedEnv,
Expand All @@ -29,7 +34,6 @@
from tensordict.nn import TensorDictModule
from tensordict.tensordict import assert_allclose_td, TensorDict

from test_env import TestMultiKeyEnvs
from torch import nn
from torchrl._utils import prod, seed_generator
from torchrl.collectors import aSyncDataCollector, SyncDataCollector
Expand Down Expand Up @@ -1582,7 +1586,7 @@ def test_collector(self, batch_size, frames_per_batch, max_steps, seed=1):
ccollector.shutdown()
for done_key in env.done_keys:
assert _replace_last(done_key, "_reset") not in _td.keys(True, True)
TestMultiKeyEnvs.check_rollout_consistency(_td, max_steps=max_steps)
check_rollout_consistency_multikey_env(_td, max_steps=max_steps)

def test_multi_collector_consistency(
self, seed=1, frames_per_batch=20, batch_dim=10
Expand Down
99 changes: 3 additions & 96 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from _utils_internal import (
_make_envs,
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
Expand Down Expand Up @@ -2019,100 +2020,6 @@ def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2):

@pytest.mark.parametrize("seed", [0])
class TestMultiKeyEnvs:
@staticmethod
def check_rollout_consistency(td: TensorDict, max_steps: int):
index_batch_size = (0,) * (len(td.batch_size) - 1)

# Check done and reset for root
observation_is_max = td["next", "observation"][..., 0, 0, 0] == max_steps + 1
next_is_done = td["next", "done"][index_batch_size][:-1].squeeze(-1)
assert (td["next", "done"][observation_is_max]).all()
assert (~td["next", "done"][~observation_is_max]).all()
# Obs after done is 0
assert (td["observation"][index_batch_size][1:][next_is_done] == 0).all()
# Obs after not done is previous obs
assert (
td["observation"][index_batch_size][1:][~next_is_done]
== td["next", "observation"][index_batch_size][:-1][~next_is_done]
).all()
# Check observation and reward update with count action for root
action_is_count = td["action"].long().argmax(-1).to(torch.bool)
assert (
td["next", "observation"][action_is_count]
== td["observation"][action_is_count] + 1
).all()
assert (td["next", "reward"][action_is_count] == 1).all()
# Check observation and reward do not update with no-count action for root
assert (
td["next", "observation"][~action_is_count]
== td["observation"][~action_is_count]
).all()
assert (td["next", "reward"][~action_is_count] == 0).all()

# Check done and reset for nested_1
observation_is_max = (
td["next", "nested_1", "observation"][..., 0] == max_steps + 1
)
next_is_done = td["next", "nested_1", "done"][index_batch_size][:-1].squeeze(-1)
assert (td["next", "nested_1", "done"][observation_is_max]).all()
assert (~td["next", "nested_1", "done"][~observation_is_max]).all()
# Obs after done is 0
assert (
td["nested_1", "observation"][index_batch_size][1:][next_is_done] == 0
).all()
# Obs after not done is previous obs
assert (
td["nested_1", "observation"][index_batch_size][1:][~next_is_done]
== td["next", "nested_1", "observation"][index_batch_size][:-1][
~next_is_done
]
).all()
# Check observation and reward update with count action for nested_1
action_is_count = td["nested_1"]["action"].to(torch.bool)
assert (
td["next", "nested_1", "observation"][action_is_count]
== td["nested_1", "observation"][action_is_count] + 1
).all()
assert (td["next", "nested_1", "gift"][action_is_count] == 1).all()
# Check observation and reward do not update with no-count action for nested_1
assert (
td["next", "nested_1", "observation"][~action_is_count]
== td["nested_1", "observation"][~action_is_count]
).all()
assert (td["next", "nested_1", "gift"][~action_is_count] == 0).all()

# Check done and reset for nested_2
observation_is_max = (
td["next", "nested_2", "observation"][..., 0] == max_steps + 1
)
next_is_done = td["next", "nested_2", "done"][index_batch_size][:-1].squeeze(-1)
assert (td["next", "nested_2", "done"][observation_is_max]).all()
assert (~td["next", "nested_2", "done"][~observation_is_max]).all()
# Obs after done is 0
assert (
td["nested_2", "observation"][index_batch_size][1:][next_is_done] == 0
).all()
# Obs after not done is previous obs
assert (
td["nested_2", "observation"][index_batch_size][1:][~next_is_done]
== td["next", "nested_2", "observation"][index_batch_size][:-1][
~next_is_done
]
).all()
# Check observation and reward update with count action for nested_2
action_is_count = td["nested_2"]["azione"].squeeze(-1).to(torch.bool)
assert (
td["next", "nested_2", "observation"][action_is_count]
== td["nested_2", "observation"][action_is_count] + 1
).all()
assert (td["next", "nested_2", "reward"][action_is_count] == 1).all()
# Check observation and reward do not update with no-count action for nested_2
assert (
td["next", "nested_2", "observation"][~action_is_count]
== td["nested_2", "observation"][~action_is_count]
).all()
assert (td["next", "nested_2", "reward"][~action_is_count] == 0).all()

@pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)])
@pytest.mark.parametrize("rollout_steps", [1, 5])
@pytest.mark.parametrize("max_steps", [2, 5])
Expand All @@ -2121,7 +2028,7 @@ def test_rollout(self, batch_size, rollout_steps, max_steps, seed):
policy = MultiKeyCountingEnvPolicy(full_action_spec=env.action_spec)
td = env.rollout(rollout_steps, policy=policy)
torch.manual_seed(seed)
self.check_rollout_consistency(td, max_steps=max_steps)
check_rollout_consistency_multikey_env(td, max_steps=max_steps)

@pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)])
@pytest.mark.parametrize("rollout_steps", [5])
Expand All @@ -2148,7 +2055,7 @@ def test_parallel(
rollout_steps,
policy=policy,
)
self.check_rollout_consistency(td, max_steps=max_steps)
check_rollout_consistency_multikey_env(td, max_steps=max_steps)


@pytest.mark.parametrize(
Expand Down

0 comments on commit f13cd77

Please sign in to comment.