Skip to content

Commit

Permalink
[Test] Check dtypes of envs (pytorch#666)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 11, 2022
1 parent ecedcf1 commit 278e9be
Show file tree
Hide file tree
Showing 5 changed files with 311 additions and 249 deletions.
30 changes: 30 additions & 0 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import pytest
import torch.cuda
from torchrl._utils import seed_generator
from torchrl.data import CompositeSpec
from torchrl.data.tensordict.tensordict import TensorDictBase
from torchrl.envs import EnvBase


Expand Down Expand Up @@ -54,6 +56,34 @@ def _test_fake_tensordict(env: EnvBase):
for key in keys2:
assert fake_tensordict[key].shape == real_tensordict[key].shape

# test dtypes
for key, value in real_tensordict.unflatten_keys(".").items():
_check_dtype(key, value, env.observation_spec, env.input_spec)


def _check_dtype(key, value, obs_spec, input_spec):
if key.startswith("next_"):
return
if isinstance(value, TensorDictBase):
for _key, _value in value.items():
if isinstance(obs_spec, CompositeSpec) and "next_" + key in obs_spec.keys():
_check_dtype(_key, _value, obs_spec["next_" + key], input_spec=None)
elif isinstance(input_spec, CompositeSpec) and key in input_spec.keys():
_check_dtype(_key, _value, obs_spec=None, input_spec=input_spec[key])
else:
raise KeyError(f"key '{_key}' is unknown.")
else:
if obs_spec is not None and "next_" + key in obs_spec.keys():
assert (
obs_spec["next_" + key].dtype is value.dtype
), f"{obs_spec['next_' + key].dtype} vs {value.dtype} for {key}"
elif input_spec is not None and key in input_spec.keys():
assert (
input_spec[key].dtype is value.dtype
), f"{input_spec[key].dtype} vs {value.dtype} for {key}"
else:
assert key in {"done", "reward"}, (key, obs_spec, input_spec)


# Decorator to retry upon certain Exceptions.
def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
Expand Down
131 changes: 0 additions & 131 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,8 @@
MockSerialEnv,
)
from packaging import version
from scipy.stats import chisquare
from torch import nn
from torchrl.data.tensor_specs import (
BoundedTensorSpec,
DiscreteTensorSpec,
MultOneHotDiscreteTensorSpec,
NdBoundedTensorSpec,
OneHotDiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
Expand Down Expand Up @@ -917,132 +912,6 @@ def env_fn2(seed):
env2.close()


class TestSpec:
@pytest.mark.parametrize(
"action_spec_cls", [OneHotDiscreteTensorSpec, DiscreteTensorSpec]
)
def test_discrete_action_spec_reconstruct(self, action_spec_cls):
torch.manual_seed(0)
action_spec = action_spec_cls(10)

actions_tensors = [action_spec.rand() for _ in range(10)]
actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors]
actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy]
assert all(
[(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)]
)

actions_numpy = [int(np.random.randint(0, 10, (1,))) for a in actions_tensors]
actions_tensors = [action_spec.encode(a) for a in actions_numpy]
actions_numpy_2 = [action_spec.to_numpy(a) for a in actions_tensors]
assert all([(a1 == a2) for a1, a2 in zip(actions_numpy, actions_numpy_2)])

def test_mult_discrete_action_spec_reconstruct(self):
torch.manual_seed(0)
action_spec = MultOneHotDiscreteTensorSpec((10, 5))

actions_tensors = [action_spec.rand() for _ in range(10)]
actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors]
actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy]
assert all(
[(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)]
)

actions_numpy = [
np.concatenate(
[np.random.randint(0, 10, (1,)), np.random.randint(0, 5, (1,))], 0
)
for a in actions_tensors
]
actions_tensors = [action_spec.encode(a) for a in actions_numpy]
actions_numpy_2 = [action_spec.to_numpy(a) for a in actions_tensors]
assert all([(a1 == a2).all() for a1, a2 in zip(actions_numpy, actions_numpy_2)])

def test_one_hot_discrete_action_spec_rand(self):
torch.manual_seed(0)
action_spec = OneHotDiscreteTensorSpec(10)

sample = torch.stack([action_spec.rand() for _ in range(10000)], 0)

sample_list = sample.argmax(-1)
sample_list = list([sum(sample_list == i).item() for i in range(10)])
assert chisquare(sample_list).pvalue > 0.1

sample = action_spec.to_numpy(sample)
sample = [sum(sample == i) for i in range(10)]
assert chisquare(sample).pvalue > 0.1

def test_categorical_action_spec_rand(self):
torch.manual_seed(0)
action_spec = DiscreteTensorSpec(10)

sample = torch.stack([action_spec.rand() for _ in range(10000)], 0)

sample_list = sample[:, 0]
sample_list = list([sum(sample_list == i).item() for i in range(10)])
assert chisquare(sample_list).pvalue > 0.1

sample = action_spec.to_numpy(sample)
sample = [sum(sample == i) for i in range(10)]
assert chisquare(sample).pvalue > 0.1

def test_mult_discrete_action_spec_rand(self):
torch.manual_seed(0)
ns = (10, 5)
N = 100000
action_spec = MultOneHotDiscreteTensorSpec((10, 5))

actions_tensors = [action_spec.rand() for _ in range(10)]
actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors]
actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy]
assert all(
[(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)]
)

sample = np.stack(
[action_spec.to_numpy(action_spec.rand()) for _ in range(N)], 0
)
assert sample.shape[0] == N
assert sample.shape[1] == 2
assert sample.ndim == 2, f"found shape: {sample.shape}"

sample0 = sample[:, 0]
sample_list = list([sum(sample0 == i) for i in range(ns[0])])
assert chisquare(sample_list).pvalue > 0.1

sample1 = sample[:, 1]
sample_list = list([sum(sample1 == i) for i in range(ns[1])])
assert chisquare(sample_list).pvalue > 0.1

def test_categorical_action_spec_encode(self):
action_spec = DiscreteTensorSpec(10)

projected = action_spec.project(
torch.tensor([-100, -1, 0, 1, 9, 10, 100], dtype=torch.long)
)
assert (
projected == torch.tensor([0, 0, 0, 1, 9, 9, 9], dtype=torch.long)
).all()

projected = action_spec.project(
torch.tensor([-100.0, -1.0, 0.0, 1.0, 9.0, 10.0, 100.0], dtype=torch.float)
)
assert (
projected == torch.tensor([0, 0, 0, 1, 9, 9, 9], dtype=torch.long)
).all()

def test_bounded_rand(self):
spec = BoundedTensorSpec(-3, 3)
sample = torch.stack([spec.rand() for _ in range(100)])
assert (-3 <= sample).all() and (3 >= sample).all()

def test_ndbounded_shape(self):
spec = NdBoundedTensorSpec(-3, 3 * torch.ones(10, 5), shape=[10, 5])
sample = torch.stack([spec.rand() for _ in range(100)], 0)
assert (-3 <= sample).all() and (3 >= sample).all()
assert sample.shape == torch.Size([100, 10, 5])


@pytest.mark.skipif(not _has_gym, reason="no gym")
def test_seed():
torch.manual_seed(0)
Expand Down
Loading

0 comments on commit 278e9be

Please sign in to comment.