Skip to content

Commit

Permalink
[Refactor] Box device (pytorch#881)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jan 30, 2023
1 parent c5493ec commit 336dc98
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 10 deletions.
11 changes: 11 additions & 0 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import os
import time
from functools import wraps
Expand Down Expand Up @@ -104,3 +105,13 @@ def dtype_fixture():
torch.set_default_dtype(torch.double)
yield dtype
torch.set_default_dtype(dtype)


@contextlib.contextmanager
def set_global_var(module, var_name, value):
old_value = getattr(module, var_name)
setattr(module, var_name, value)
try:
yield
finally:
setattr(module, var_name, old_value)
17 changes: 13 additions & 4 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import numpy as np
import pytest
import torch
from _utils_internal import get_available_devices
import torchrl.data.tensor_specs
from _utils_internal import get_available_devices, set_global_var
from scipy.stats import chisquare
from tensordict.tensordict import TensorDict, TensorDictBase
from torchrl.data.tensor_specs import (
Expand Down Expand Up @@ -57,8 +58,11 @@ def test_discrete(cls):
ts.encode(torch.tensor([5]))
ts.encode(torch.tensor(5).numpy())
ts.encode(9)
with pytest.raises(AssertionError):
with pytest.raises(AssertionError), set_global_var(
torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True
):
ts.encode(torch.tensor([11])) # out of bounds
assert not torchrl.data.tensor_specs._CHECK_SPEC_ENCODE
assert ts.is_in(r)
assert (ts.encode(ts.to_numpy(r)) == r).all()

Expand Down Expand Up @@ -114,10 +118,15 @@ def test_ndbounded(dtype, shape):
ts.encode(lb + torch.rand(10) * (ub - lb))
ts.encode((lb + torch.rand(10) * (ub - lb)).numpy())
assert (ts.encode(ts.to_numpy(r)) == r).all()
with pytest.raises(AssertionError):
with pytest.raises(AssertionError), set_global_var(
torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True
):
ts.encode(torch.rand(10) + 3) # out of bounds
with pytest.raises(AssertionError):
with pytest.raises(AssertionError), set_global_var(
torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True
):
ts.to_numpy(torch.rand(10) + 3) # out of bounds
assert not torchrl.data.tensor_specs._CHECK_SPEC_ENCODE


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])
Expand Down
30 changes: 26 additions & 4 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@

INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, List]

_NO_CHECK_SPEC_ENCODE = get_binary_env_var("NO_CHECK_SPEC_ENCODE")
# By default, we do not check that an obs is in the domain. THis should be done when validating the env beforehand
_CHECK_SPEC_ENCODE = get_binary_env_var("CHECK_SPEC_ENCODE")


_DEFAULT_SHAPE = torch.Size((1,))

Expand Down Expand Up @@ -108,8 +110,28 @@ def clone(self) -> DiscreteBox:
class ContinuousBox(Box):
"""A continuous box of values, in between a minimum and a maximum."""

minimum: torch.Tensor
maximum: torch.Tensor
_minimum: torch.Tensor
_maximum: torch.Tensor
device: torch.device = None

# We store the tensors on CPU to avoid overloading CUDA with tensors that are rarely used.
@property
def minimum(self):
return self._minimum.to(self.device)

@property
def maximum(self):
return self._maximum.to(self.device)

@minimum.setter
def minimum(self, value):
self.device = value.device
self._minimum = value.cpu()

@maximum.setter
def maximum(self, value):
self.device = value.device
self._maximum = value.cpu()

def __post_init__(self):
self.minimum = self.minimum.clone()
Expand Down Expand Up @@ -257,7 +279,7 @@ def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
f"Shape mismatch: the value has shape {val.shape} which "
f"is incompatible with the spec shape {self.shape}."
)
if not _NO_CHECK_SPEC_ENCODE:
if _CHECK_SPEC_ENCODE:
self.assert_is_in(val)
return val

Expand Down
3 changes: 2 additions & 1 deletion torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ class EnvBase(nn.Module, metaclass=abc.ABCMeta):
- run_type_checks (bool): if True, the observation and reward dtypes
will be compared against their respective spec and an exception
will be raised if they don't match.
Defaults to False.
Methods:
step (TensorDictBase -> TensorDictBase): step in the environment
Expand All @@ -226,7 +227,7 @@ def __init__(
device: DEVICE_TYPING = "cpu",
dtype: Optional[Union[torch.dtype, np.dtype]] = None,
batch_size: Optional[torch.Size] = None,
run_type_checks: bool = True,
run_type_checks: bool = False,
):
super().__init__()
if device is not None:
Expand Down
4 changes: 3 additions & 1 deletion torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2738,7 +2738,9 @@ def transform_observation_spec(
dtype=torch.int64,
device=observation_spec.device,
)
observation_spec["step_count"].space.minimum = 0
observation_spec["step_count"].space.minimum = (
observation_spec["step_count"].space.minimum * 0
)
return observation_spec


Expand Down

0 comments on commit 336dc98

Please sign in to comment.