Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Aug 6, 2024
1 parent a00fc64 commit c4ea7aa
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 14 deletions.
11 changes: 7 additions & 4 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import argparse
import contextlib
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -3681,7 +3682,7 @@ def test_unbounded(self):
assert isinstance(unbounded_continuous, UnboundedContinuousTensorSpec)
assert not isinstance(unbounded_continuous, UnboundedDiscreteTensorSpec)

with pytest.warns(None):
with warnings.catch_warnings():
unbounded_continuous = UnboundedContinuous()

with pytest.warns(
Expand All @@ -3694,18 +3695,20 @@ def test_unbounded(self):
assert isinstance(unbounded_discrete, UnboundedDiscreteTensorSpec)
assert not isinstance(unbounded_discrete, UnboundedContinuousTensorSpec)

with pytest.warns(None):
with warnings.catch_warnings():
unbounded_discrete = UnboundedDiscrete()

# What if we mess with dtypes?
unbounded_continuous_fake = UnboundedContinuousTensorSpec(dtype=torch.int32)
with pytest.warns(DeprecationWarning):
unbounded_continuous_fake = UnboundedContinuousTensorSpec(dtype=torch.int32)
assert isinstance(unbounded_continuous_fake, Unbounded)
assert not isinstance(unbounded_continuous_fake, UnboundedContinuous)
assert not isinstance(unbounded_continuous_fake, UnboundedContinuousTensorSpec)
assert isinstance(unbounded_continuous_fake, UnboundedDiscrete)
assert isinstance(unbounded_continuous_fake, UnboundedDiscreteTensorSpec)

unbounded_discrete_fake = UnboundedDiscreteTensorSpec(dtype=torch.float32)
with pytest.warns(DeprecationWarning):
unbounded_discrete_fake = UnboundedDiscreteTensorSpec(dtype=torch.float32)
assert isinstance(unbounded_discrete_fake, Unbounded)
assert isinstance(unbounded_discrete_fake, UnboundedContinuous)
assert isinstance(unbounded_discrete_fake, UnboundedContinuousTensorSpec)
Expand Down
10 changes: 8 additions & 2 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8463,8 +8463,14 @@ def test_independent_reward_specs_from_shared_env(self):
assert t2_reward_spec.space.low == -2
assert t2_reward_spec.space.high == 2

assert base_env.reward_spec.space.low == -np.inf
assert base_env.reward_spec.space.high == np.inf
assert (
base_env.reward_spec.space.low
== torch.finfo(base_env.reward_spec.dtype).low
)
assert (
base_env.reward_spec.space.high
== torch.finfo(base_env.reward_spec.dtype).high
)

def test_allow_done_after_reset(self):
base_env = ContinuousActionVecMockEnv(allow_done_after_reset=True)
Expand Down
33 changes: 25 additions & 8 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1723,6 +1723,7 @@ def __call__(cls, *args, **kwargs):
instance.__class__ = BoundedDiscrete
return instance


@dataclass(repr=False)
class Bounded(TensorSpec, metaclass=_BoundedMeta):
"""A bounded continuous tensor spec.
Expand Down Expand Up @@ -2095,6 +2096,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING):
dtype=self.dtype,
)


class BoundedContinuous(Bounded):
"""A specialized version of :class:`torchrl.data.Bounded` with continuous space."""

Expand All @@ -2105,9 +2107,12 @@ def __init__(
shape: Optional[Union[torch.Size, int]] = None,
device: Optional[DEVICE_TYPING] = None,
dtype: Optional[Union[torch.dtype, str]] = None,
domain: str="continuous",
domain: str = "continuous",
):
super().__init__(low=low, high=high, shape=shape, device=device, dtype=dtype, domain=domain)
super().__init__(
low=low, high=high, shape=shape, device=device, dtype=dtype, domain=domain
)


class BoundedDiscrete(Bounded):
"""A specialized version of :class:`torchrl.data.Bounded` with discrete space."""
Expand All @@ -2119,9 +2124,16 @@ def __init__(
shape: Optional[Union[torch.Size, int]] = None,
device: Optional[DEVICE_TYPING] = None,
dtype: Optional[Union[torch.dtype, str]] = None,
domain: str="discrete",
domain: str = "discrete",
):
super().__init__(low=low, high=high, shape=shape, device=device, dtype=dtype, domain=domain, )
super().__init__(
low=low,
high=high,
shape=shape,
device=device,
dtype=dtype,
domain=domain,
)


def _is_nested_list(index, notuple=False):
Expand Down Expand Up @@ -2252,6 +2264,7 @@ def unbind(self, dim: int = 0):
for i in range(self.shape[dim])
)


class _UnboundedMeta(abc.ABCMeta):
def __call__(cls, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
Expand All @@ -2261,6 +2274,7 @@ def __call__(cls, *args, **kwargs):
instance.__class__ = UnboundedDiscrete
return instance


@dataclass(repr=False)
class Unbounded(TensorSpec, metaclass=_UnboundedMeta):
"""An unbounded tensor spec.
Expand Down Expand Up @@ -2419,12 +2433,12 @@ def __eq__(self, other):
return super().__eq__(other)



class UnboundedContinuous(Unbounded):
"""A specialized version of :class:`torchrl.data.Unbounded` with continuous space."""

...


class UnboundedDiscrete(Unbounded):
"""A specialized version of :class:`torchrl.data.Unbounded` with discrete space."""

Expand All @@ -2437,6 +2451,7 @@ def __init__(
):
super().__init__(shape=shape, device=device, dtype=dtype, **kwargs)


@dataclass(repr=False)
class MultiOneHot(OneHot):
"""A concatenation of one-hot discrete tensor spec.
Expand Down Expand Up @@ -5141,8 +5156,10 @@ class BinaryDiscreteTensorSpec(Binary, metaclass=_LegacySpecMeta):

...


_BoundedLegacyMeta = type("_BoundedLegacyMeta", (_LegacySpecMeta, _BoundedMeta), {})


class BoundedTensorSpec(Bounded, metaclass=_BoundedLegacyMeta):
"""Deprecated version of :class:`torchrl.data.Bounded`."""

Expand All @@ -5161,8 +5178,9 @@ def __instancecheck__(cls, instance):
)



class UnboundedContinuousTensorSpec(Unbounded, metaclass=_LegacyUnboundedContinuousMetaclass):
class UnboundedContinuousTensorSpec(
Unbounded, metaclass=_LegacyUnboundedContinuousMetaclass
):
"""Deprecated version of :class:`torchrl.data.Unbounded` with continuous space."""

...
Expand All @@ -5173,7 +5191,6 @@ def __instancecheck__(cls, instance):
return isinstance(instance, Unbounded) and instance.domain == "discrete"



_LegacyUnboundedDiscreteMetaclass = type(
"_LegacyUnboundedDiscreteMetaclass",
(_UnboundedDiscreteMetaclass, _LegacySpecMeta),
Expand Down

0 comments on commit c4ea7aa

Please sign in to comment.