Skip to content

Commit

Permalink
[Doc] Add documentation for masks in tensor specs (#2289)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <vmoens@meta.com>
  • Loading branch information
kurtamohler and vmoens authored Jul 11, 2024
1 parent 8e43ac8 commit f764c02
Showing 1 changed file with 94 additions and 0 deletions.
94 changes: 94 additions & 0 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,8 @@ class OneHotDiscreteTensorSpec(TensorSpec):
discrete outcomes are sampled from an arbitrary set, whose
elements will be mapped in a register to a series of unique
one-hot binary vectors).
mask (torch.Tensor or None): mask some of the possible outcomes when a
sample is taken. See :meth:`~.update_mask` for more information.
"""

Expand Down Expand Up @@ -1368,6 +1370,25 @@ def n(self):
return self.space.n

def update_mask(self, mask):
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
The mask can also be set during initialization of the spec.
Args:
mask (torch.Tensor or None): boolean mask. If None, the mask is
disabled. Otherwise, the shape of the mask must be expandable to
the shape of the spec. ``False`` masks an outcome and ``True``
leaves the outcome unmasked. If all of the possible outcomes are
masked, then an error is raised when a sample is taken.
Examples:
>>> mask = torch.tensor([True, False, False])
>>> ts = OneHotDiscreteTensorSpec(3, (2, 3,), dtype=torch.int64, mask=mask)
>>> # All but one of the three possible outcomes are masked
>>> ts.rand()
tensor([[1, 0, 0],
[1, 0, 0]])
"""
if mask is not None:
try:
mask = mask.expand(self._safe_shape)
Expand Down Expand Up @@ -2516,6 +2537,8 @@ class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec):
device (str, int or torch.device, optional): device of
the tensors.
dtype (str or torch.dtype, optional): dtype of the tensors.
mask (torch.Tensor or None): mask some of the possible outcomes when a
sample is taken. See :meth:`~.update_mask` for more information.
Examples:
>>> ts = MultiOneHotDiscreteTensorSpec((3,2,3))
Expand Down Expand Up @@ -2564,6 +2587,28 @@ def __init__(
self.update_mask(mask)

def update_mask(self, mask):
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
The mask can also be set during initialization of the spec.
Args:
mask (torch.Tensor or None): boolean mask. If None, the mask is
disabled. Otherwise, the shape of the mask must be expandable to
the shape of the spec. ``False`` masks an outcome and ``True``
leaves the outcome unmasked. If all of the possible outcomes are
masked, then an error is raised when a sample is taken.
Examples:
>>> mask = torch.tensor([True, False, False,
... True, True])
>>> ts = MultiOneHotDiscreteTensorSpec((3, 2), (2, 5), dtype=torch.int64, mask=mask)
>>> # All but one of the three possible outcomes for the first
>>> # one-hot group are masked, but neither of the two possible
>>> # outcomes for the second one-hot group are masked.
>>> ts.rand()
tensor([[1, 0, 0, 0, 1],
[1, 0, 0, 1, 0]])
"""
if mask is not None:
try:
mask = mask.expand(*self._safe_shape)
Expand Down Expand Up @@ -2900,6 +2945,8 @@ class DiscreteTensorSpec(TensorSpec):
shape: (torch.Size, optional): shape of the variable, default is "torch.Size([])".
device (str, int or torch.device, optional): device of the tensors.
dtype (str or torch.dtype, optional): dtype of the tensors.
mask (torch.Tensor or None): mask some of the possible outcomes when a
sample is taken. See :meth:`~.update_mask` for more information.
"""

Expand Down Expand Up @@ -2933,6 +2980,25 @@ def n(self):
return self.space.n

def update_mask(self, mask):
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
The mask can also be set during initialization of the spec.
Args:
mask (torch.Tensor or None): boolean mask. If None, the mask is
disabled. Otherwise, the shape of the mask must be expandable to
the shape of the equivalent one-hot spec. ``False`` masks an
outcome and ``True`` leaves the outcome unmasked. If all of the
possible outcomes are masked, then an error is raised when a
sample is taken.
Examples:
>>> mask = torch.tensor([True, False, True])
>>> ts = DiscreteTensorSpec(3, (10,), dtype=torch.int64, mask=mask)
>>> # One of the three possible outcomes is masked
>>> ts.rand()
tensor([0, 2, 2, 0, 2, 0, 2, 2, 0, 2])
"""
if mask is not None:
try:
mask = mask.expand(_remove_neg_shapes(*self.shape, self.space.n))
Expand Down Expand Up @@ -3315,6 +3381,8 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec):
dtype (str or torch.dtype, optional): dtype of the tensors.
remove_singleton (bool, optional): if ``True``, singleton samples (of size [1])
will be squeezed. Defaults to ``True``.
mask (torch.Tensor or None): mask some of the possible outcomes when a
sample is taken. See :meth:`~.update_mask` for more information.
Examples:
>>> ts = MultiDiscreteTensorSpec((3, 2, 3))
Expand Down Expand Up @@ -3361,6 +3429,32 @@ def __init__(
self.remove_singleton = remove_singleton

def update_mask(self, mask):
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
The mask can also be set during initialization of the spec.
Args:
mask (torch.Tensor or None): boolean mask. If None, the mask is
disabled. Otherwise, the shape of the mask must be expandable to
the shape of the equivalent one-hot spec. ``False`` masks an
outcome and ``True`` leaves the outcome unmasked. If all of the
possible outcomes are masked, then an error is raised when a
sample is taken.
Examples:
>>> mask = torch.tensor([False, False, True,
... True, True])
>>> ts = MultiDiscreteTensorSpec((3, 2), (5, 2,), dtype=torch.int64, mask=mask)
>>> # All but one of the three possible outcomes for the first
>>> # group are masked, but neither of the two possible
>>> # outcomes for the second group are masked.
>>> ts.rand()
tensor([[2, 1],
[2, 0],
[2, 1],
[2, 1],
[2, 0]])
"""
if mask is not None:
try:
mask = mask.expand(_remove_neg_shapes(*self.shape[:-1], mask.shape[-1]))
Expand Down

0 comments on commit f764c02

Please sign in to comment.