[Refactor] Binary spec inherits from discrete spec (pytorch#984)
matteobettini authored Mar 23, 2023
1 parent 929d07f commit 893cffc
Showing 2 changed files with 87 additions and 121 deletions.
1 change: 1 addition & 0 deletions test/test_specs.py
@@ -2061,6 +2061,7 @@ def test_to_numpy(self, shape, stack_dim):
         c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64)
         c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32)
         c = torch.stack([c1, c2], stack_dim)
+        torch.manual_seed(0)
 
         shape = list(shape)
         shape.insert(stack_dim, 2)
207 changes: 86 additions & 121 deletions torchrl/data/tensor_specs.py
@@ -1357,127 +1357,6 @@ def expand(self, *shape):
         return self.__class__(shape=shape, device=self.device, dtype=self.dtype)
 
 
-@dataclass(repr=False)
-class BinaryDiscreteTensorSpec(TensorSpec):
-    """A binary discrete tensor spec.
-    Args:
-        n (int): length of the binary vector.
-        shape (torch.Size, optional): total shape of the sampled tensors.
-            If provided, the last dimension must match n.
-        device (str, int or torch.device, optional): device of the tensors.
-        dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to torch.long.
-    Examples:
-        >>> spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool)
-        >>> print(spec.zero())
-    """
-
-    shape: torch.Size
-    space: BinaryBox
-    device: torch.device = torch.device("cpu")
-    dtype: torch.dtype = torch.float
-    domain: str = ""
-
-    # SPEC_HANDLED_FUNCTIONS = {}
-
-    def __init__(
-        self,
-        n: int,
-        shape: Optional[torch.Size] = None,
-        device: Optional[DEVICE_TYPING] = None,
-        dtype: Union[str, torch.dtype] = torch.long,
-    ):
-        dtype, device = _default_dtype_and_device(dtype, device)
-        box = BinaryBox(n)
-        if shape is None or not len(shape):
-            shape = torch.Size((n,))
-        else:
-            shape = torch.Size(shape)
-            if shape[-1] != box.n:
-                raise ValueError(
-                    f"The last value of the shape must match n for transform of type {self.__class__}. "
-                    f"Got n={box.n} and shape={shape}."
-                )
-
-        super().__init__(shape, box, device, dtype, domain="discrete")
-
-    def rand(self, shape=None) -> torch.Tensor:
-        if shape is None:
-            shape = torch.Size([])
-        shape = [*shape, *self.shape]
-        return torch.zeros(shape, device=self.device, dtype=self.dtype).bernoulli_()
-
-    def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
-        if not isinstance(index, torch.Tensor):
-            raise ValueError(
-                f"Only tensors are allowed for indexing using"
-                f" {self.__class__.__name__}.index(...)"
-            )
-        index = index.nonzero().squeeze()
-        index = index.expand((*tensor_to_index.shape[:-1], index.shape[-1]))
-        return tensor_to_index.gather(-1, index)
-
-    def is_in(self, val: torch.Tensor) -> bool:
-        return ((val == 0) | (val == 1)).all()
-
-    def expand(self, *shape):
-        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
-            shape = shape[0]
-        if any(val < 0 for val in shape):
-            raise ValueError(
-                f"{self.__class__.__name__}.extend does not support negative shapes."
-            )
-        if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
-            raise ValueError(
-                f"The last {self.ndim} of the extended shape must match the"
-                f"shape of the CompositeSpec in CompositeSpec.extend."
-            )
-        return self.__class__(
-            n=shape[-1], shape=shape, device=self.device, dtype=self.dtype
-        )
-
-    def squeeze(self, dim: int | None = None):
-        if self.shape[-1] == 1 and dim in (len(self.shape), -1, None):
-            raise ValueError(
-                "Final dimension of BinaryDiscreteTensorSpec must remain unchanged"
-            )
-        shape = _squeezed_shape(self.shape, dim)
-        if shape is None:
-            return self
-        return self.__class__(
-            n=shape[-1], shape=shape, device=self.device, dtype=self.dtype
-        )
-
-    def unsqueeze(self, dim: int):
-        if dim in (len(self.shape), -1):
-            raise ValueError(
-                "Final dimension of BinaryDiscreteTensorSpec must remain unchanged"
-            )
-        shape = _unsqueezed_shape(self.shape, dim)
-        return self.__class__(
-            n=shape[-1], shape=shape, device=self.device, dtype=self.dtype
-        )
-
-    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
-        if isinstance(dest, torch.dtype):
-            dest_dtype = dest
-            dest_device = self.device
-        else:
-            dest_dtype = self.dtype
-            dest_device = torch.device(dest)
-        if dest_device == self.device and dest_dtype == self.dtype:
-            return self
-        return self.__class__(
-            n=self.space.n, shape=self.shape, device=dest_device, dtype=dest_dtype
-        )
-
-    def clone(self) -> CompositeSpec:
-        return self.__class__(
-            n=self.space.n, shape=self.shape, device=self.device, dtype=self.dtype
-        )
-
-
 @dataclass(repr=False)
 class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec):
     """A concatenation of one-hot discrete tensor spec.
@@ -1844,6 +1723,92 @@ def clone(self) -> CompositeSpec:
         )
 
 
+@dataclass(repr=False)
+class BinaryDiscreteTensorSpec(DiscreteTensorSpec):
+    """A binary discrete tensor spec.
+
+    Args:
+        n (int): length of the binary vector.
+        shape (torch.Size, optional): total shape of the sampled tensors.
+            If provided, the last dimension must match n.
+        device (str, int or torch.device, optional): device of the tensors.
+        dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to torch.long.
+
+    Examples:
+        >>> spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool)
+        >>> print(spec.zero())
+    """
+
+    def __init__(
+        self,
+        n: int,
+        shape: Optional[torch.Size] = None,
+        device: Optional[DEVICE_TYPING] = None,
+        dtype: Union[str, torch.dtype] = torch.long,
+    ):
+        if shape is None or not len(shape):
+            shape = torch.Size((n,))
+        else:
+            shape = torch.Size(shape)
+            if shape[-1] != n:
+                raise ValueError(
+                    f"The last value of the shape must match n for spec {self.__class__}. "
+                    f"Got n={n} and shape={shape}."
+                )
+        super().__init__(n=2, shape=shape, device=device, dtype=dtype)
+
+    def expand(self, *shape):
+        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
+            shape = shape[0]
+        if any(val < 0 for val in shape):
+            raise ValueError(
+                f"{self.__class__.__name__}.expand does not support negative shapes."
+            )
+        if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
+            raise ValueError(
+                f"The last {self.ndim} dimensions of the expanded shape must match "
+                f"the shape of the {self.__class__.__name__} in expand()."
+            )
+        return self.__class__(
+            n=shape[-1], shape=shape, device=self.device, dtype=self.dtype
+        )
+
+    def squeeze(self, dim=None):
+        shape = _squeezed_shape(self.shape, dim)
+        if shape is None:
+            return self
+        return self.__class__(
+            n=shape[-1], shape=shape, device=self.device, dtype=self.dtype
+        )
+
+    def unsqueeze(self, dim: int):
+        shape = _unsqueezed_shape(self.shape, dim)
+        return self.__class__(
+            n=shape[-1], shape=shape, device=self.device, dtype=self.dtype
+        )
+
+    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "BinaryDiscreteTensorSpec":
+        if isinstance(dest, torch.dtype):
+            dest_dtype = dest
+            dest_device = self.device
+        else:
+            dest_dtype = self.dtype
+            dest_device = torch.device(dest)
+        if dest_device == self.device and dest_dtype == self.dtype:
+            return self
+        return self.__class__(
+            n=self.shape[-1], shape=self.shape, device=dest_device, dtype=dest_dtype
+        )
+
+    def clone(self) -> "BinaryDiscreteTensorSpec":
+        return self.__class__(
+            n=self.shape[-1],
+            shape=self.shape,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
 @dataclass(repr=False)
 class MultiDiscreteTensorSpec(DiscreteTensorSpec):
     """A concatenation of discrete tensor spec.
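After this refactor, BinaryDiscreteTensorSpec is a DiscreteTensorSpec constrained to two values per entry (note the super().__init__(n=2, ...) call), so sampling and membership checks are inherited from the discrete parent rather than reimplemented. The following is a minimal usage sketch assembled from the docstring example in the diff above; it is illustrative only, and the top-level torchrl.data import path and the exact inherited behaviors are assumptions, not part of this commit.

import torch
from torchrl.data import BinaryDiscreteTensorSpec  # assumed public export

# Five binary vectors of length 4, as in the docstring example.
spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool)

sample = spec.rand()              # inherited sampling: each entry is 0 or 1
assert sample.shape == torch.Size([5, 4])
assert spec.is_in(sample)         # inherited membership check for {0, 1} values

# expand() prepends batch dimensions; the last dimension must stay n=4.
batched = spec.expand(3, 5, 4)
assert batched.shape == torch.Size([3, 5, 4])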
