diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 0006213cd27..ae5b58a06a0 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1326,6 +1326,8 @@ class OneHotDiscreteTensorSpec(TensorSpec): discrete outcomes are sampled from an arbitrary set, whose elements will be mapped in a register to a series of unique one-hot binary vectors). + mask (torch.Tensor or None): mask some of the possible outcomes when a + sample is taken. See :meth:`~.update_mask` for more information. """ @@ -1368,6 +1370,25 @@ def n(self): return self.space.n def update_mask(self, mask): + """Sets a mask to prevent some of the possible outcomes when a sample is taken. + + The mask can also be set during initialization of the spec. + + Args: + mask (torch.Tensor or None): boolean mask. If None, the mask is + disabled. Otherwise, the shape of the mask must be expandable to + the shape of the spec. ``False`` masks an outcome and ``True`` + leaves the outcome unmasked. If all of the possible outcomes are + masked, then an error is raised when a sample is taken. + + Examples: + >>> mask = torch.tensor([True, False, False]) + >>> ts = OneHotDiscreteTensorSpec(3, (2, 3,), dtype=torch.int64, mask=mask) + >>> # All but one of the three possible outcomes are masked + >>> ts.rand() + tensor([[1, 0, 0], + [1, 0, 0]]) + """ if mask is not None: try: mask = mask.expand(self._safe_shape) @@ -2516,6 +2537,8 @@ class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. + mask (torch.Tensor or None): mask some of the possible outcomes when a + sample is taken. See :meth:`~.update_mask` for more information. Examples: >>> ts = MultiOneHotDiscreteTensorSpec((3,2,3)) @@ -2564,6 +2587,28 @@ def __init__( self.update_mask(mask) def update_mask(self, mask): + """Sets a mask to prevent some of the possible outcomes when a sample is taken. + + The mask can also be set during initialization of the spec. + + Args: + mask (torch.Tensor or None): boolean mask. If None, the mask is + disabled. Otherwise, the shape of the mask must be expandable to + the shape of the spec. ``False`` masks an outcome and ``True`` + leaves the outcome unmasked. If all of the possible outcomes are + masked, then an error is raised when a sample is taken. + + Examples: + >>> mask = torch.tensor([True, False, False, + ... True, True]) + >>> ts = MultiOneHotDiscreteTensorSpec((3, 2), (2, 5), dtype=torch.int64, mask=mask) + >>> # All but one of the three possible outcomes for the first + >>> # one-hot group are masked, but neither of the two possible + >>> # outcomes for the second one-hot group are masked. + >>> ts.rand() + tensor([[1, 0, 0, 0, 1], + [1, 0, 0, 1, 0]]) + """ if mask is not None: try: mask = mask.expand(*self._safe_shape) @@ -2900,6 +2945,8 @@ class DiscreteTensorSpec(TensorSpec): shape: (torch.Size, optional): shape of the variable, default is "torch.Size([])". device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. + mask (torch.Tensor or None): mask some of the possible outcomes when a + sample is taken. See :meth:`~.update_mask` for more information. """ @@ -2933,6 +2980,25 @@ def n(self): return self.space.n def update_mask(self, mask): + """Sets a mask to prevent some of the possible outcomes when a sample is taken. + + The mask can also be set during initialization of the spec. + + Args: + mask (torch.Tensor or None): boolean mask. If None, the mask is + disabled. Otherwise, the shape of the mask must be expandable to + the shape of the equivalent one-hot spec. ``False`` masks an + outcome and ``True`` leaves the outcome unmasked. If all of the + possible outcomes are masked, then an error is raised when a + sample is taken. + + Examples: + >>> mask = torch.tensor([True, False, True]) + >>> ts = DiscreteTensorSpec(3, (10,), dtype=torch.int64, mask=mask) + >>> # One of the three possible outcomes is masked + >>> ts.rand() + tensor([0, 2, 2, 0, 2, 0, 2, 2, 0, 2]) + """ if mask is not None: try: mask = mask.expand(_remove_neg_shapes(*self.shape, self.space.n)) @@ -3315,6 +3381,8 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec): dtype (str or torch.dtype, optional): dtype of the tensors. remove_singleton (bool, optional): if ``True``, singleton samples (of size [1]) will be squeezed. Defaults to ``True``. + mask (torch.Tensor or None): mask some of the possible outcomes when a + sample is taken. See :meth:`~.update_mask` for more information. Examples: >>> ts = MultiDiscreteTensorSpec((3, 2, 3)) @@ -3361,6 +3429,32 @@ def __init__( self.remove_singleton = remove_singleton def update_mask(self, mask): + """Sets a mask to prevent some of the possible outcomes when a sample is taken. + + The mask can also be set during initialization of the spec. + + Args: + mask (torch.Tensor or None): boolean mask. If None, the mask is + disabled. Otherwise, the shape of the mask must be expandable to + the shape of the equivalent one-hot spec. ``False`` masks an + outcome and ``True`` leaves the outcome unmasked. If all of the + possible outcomes are masked, then an error is raised when a + sample is taken. + + Examples: + >>> mask = torch.tensor([False, False, True, + ... True, True]) + >>> ts = MultiDiscreteTensorSpec((3, 2), (5, 2,), dtype=torch.int64, mask=mask) + >>> # All but one of the three possible outcomes for the first + >>> # group are masked, but neither of the two possible + >>> # outcomes for the second group are masked. + >>> ts.rand() + tensor([[2, 1], + [2, 0], + [2, 1], + [2, 1], + [2, 0]]) + """ if mask is not None: try: mask = mask.expand(_remove_neg_shapes(*self.shape[:-1], mask.shape[-1]))