[Feature] Dynamic specs #2143

Merged (27 commits, May 31, 2024)
Changes from 1 commit
init
vmoens committed Apr 30, 2024
commit 381d3545ea22ae94014a3070e10d6a8af3330f24
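
What the commit enables, distilled from the new tests below into a minimal usage sketch (class names and the torchrl.data import path are as they exist at this commit and may differ in later releases):

import torch
from torchrl.data import BoundedTensorSpec, UnboundedContinuousTensorSpec

# A -1 entry in a spec's shape marks a dynamic ("any size") dimension.
spec = UnboundedContinuousTensorSpec((-1, 1, 2))

# Any concrete size is accepted along the dynamic dimension.
assert spec.is_in(torch.randn(3, 1, 2))
assert spec.is_in(torch.randn(128, 1, 2))

# The dtype is still enforced: an integer tensor fails a float spec.
assert not spec.is_in(torch.randint(10, (3, 1, 2)))

# Bounded specs accept dynamic dimensions the same way.
bound = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1)
assert bound.is_in(torch.rand(3, 1, 2))
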
79 changes: 79 additions & 0 deletions test/test_specs.py
@@ -3404,6 +3404,85 @@ def test_project(self, shape, device, spectype, rand_shape, n=5):
         assert (sp != s).all()


+class TestDynamicSpec:
+    def test_all(self):
+        spec = UnboundedContinuousTensorSpec((-1, 1, 2))
+        unb = spec
+        assert spec.shape == (-1, 1, 2)
+        x = torch.randn(3, 1, 2)
+        xunb = x
+        assert spec.is_in(x)
+
+        spec = UnboundedDiscreteTensorSpec((-1, 1, 2))
+        unbd = spec
+        assert spec.shape == (-1, 1, 2)
+        x = torch.randint(10, (3, 1, 2))
+        xunbd = x
+        assert spec.is_in(x)
+
+        spec = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1)
+        bound = spec
+        assert spec.shape == (-1, 1, 2)
+        x = torch.rand((3, 1, 2))
+        xbound = x
+        assert spec.is_in(x)
+
+        spec = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4)
+        oneh = spec
+        assert spec.shape == (-1, 1, 2, 4)
+        x = torch.zeros((3, 1, 2, 4), dtype=torch.bool)
+        x[..., 0] = 1
+        xoneh = x
+        assert spec.is_in(x)
+
+        spec = DiscreteTensorSpec(shape=(-1, 1, 2), n=4)
+        disc = spec
+        assert spec.shape == (-1, 1, 2)
+        x = torch.randint(4, (3, 1, 2))
+        xdisc = x
+        assert spec.is_in(x)
+
+        spec = MultiOneHotDiscreteTensorSpec(shape=(-1, 1, 2, 7), nvec=[3, 4])
+        moneh = spec
+        assert spec.shape == (-1, 1, 2, 7)
+        x = torch.zeros((3, 1, 2, 7), dtype=torch.bool)
+        x[..., 0] = 1
+        x[..., -1] = 1
+        xmoneh = x
+        assert spec.is_in(x)
+
+        spec = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4])
+        mdisc = spec
+        assert spec.mask is None
+        assert spec.shape == (-1, 1, 2, 2)
+        x = torch.randint(3, (3, 1, 2, 2))
+        xmdisc = x
+        assert spec.is_in(x)
+
+        spec = CompositeSpec(
+            unb=unb,
+            unbd=unbd,
+            bound=bound,
+            oneh=oneh,
+            disc=disc,
+            moneh=moneh,
+            mdisc=mdisc,
+            shape=(-1, 1, 2)
+        )
+        assert spec.shape == (-1, 1, 2)
+
+        data = TensorDict({
+            "unb": xunb,
+            "unbd": xunbd,
+            "bound": xbound,
+            "oneh": xoneh,
+            "disc": xdisc,
+            "moneh": xmoneh,
+            "mdisc": xmdisc,
+        }, [3, 1, 2])
+        assert spec.is_in(data)
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
110 changes: 71 additions & 39 deletions torchrl/data/tensor_specs.py
Expand Up @@ -1512,6 +1512,13 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:

def is_in(self, val: torch.Tensor) -> bool:
if self.mask is None:
shape = torch.broadcast_shapes(_remove_neg_shapes(self.shape), val.shape)
shape_match = val.shape == shape
if not shape_match:
return False
dtype_match = val.dtype == self.dtype
if not dtype_match:
return False
return (val.sum(-1) == 1).all()
shape = self.mask.shape
shape = torch.broadcast_shapes(shape, val.shape)
@@ -1641,39 +1648,44 @@ def __init__(
             shape = torch.Size([shape])
         else:
             shape = torch.Size(list(shape))

+        if shape is not None:
+            shape_corr = _remove_neg_shapes(shape)
+        else:
+            shape_corr = None
         if high.ndimension():
-            if shape is not None and shape != high.shape:
+            if shape_corr is not None and shape_corr != high.shape_corr:
                 raise RuntimeError(err_msg)
-            shape = high.shape
-            low = low.expand(shape).clone()
+            shape = high.shape_corr
+            low = low.expand(shape_corr).clone()
         elif low.ndimension():
-            if shape is not None and shape != low.shape:
+            if shape_corr is not None and shape_corr != low.shape_corr:
                 raise RuntimeError(err_msg)
-            shape = low.shape
-            high = high.expand(shape).clone()
-        elif shape is None:
+            shape = low.shape_corr
+            high = high.expand(shape_corr).clone()
+        elif shape_corr is None:
             raise RuntimeError(err_msg)
         else:
-            low = low.expand(shape).clone()
-            high = high.expand(shape).clone()
+            low = low.expand(shape_corr).clone()
+            high = high.expand(shape_corr).clone()

         if low.numel() > high.numel():
             high = high.expand_as(low).clone()
         elif high.numel() > low.numel():
             low = low.expand_as(high).clone()
-        if shape is None:
-            shape = low.shape
+        if shape_corr is None:
+            shape = low.shape_corr
         else:
-            if isinstance(shape, float):
-                shape = torch.Size([shape])
-            elif not isinstance(shape, torch.Size):
-                shape = torch.Size(shape)
-            shape_err_msg = f"low and shape mismatch, got {low.shape} and {shape}"
-            if len(low.shape) != len(shape):
-                raise RuntimeError(shape_err_msg)
-            if not all(_s == _sa for _s, _sa in zip(shape, low.shape)):
-                raise RuntimeError(shape_err_msg)
+            if isinstance(shape_corr, float):
+                shape_corr = torch.Size([shape_corr])
+            elif not isinstance(shape_corr, torch.Size):
+                shape_corr = torch.Size(shape_corr)
+            shape_corr_err_msg = (
+                f"low and shape_corr mismatch, got {low.shape} and {shape_corr}"
+            )
+            if len(low.shape) != len(shape_corr):
+                raise RuntimeError(shape_corr_err_msg)
+            if not all(_s == _sa for _s, _sa in zip(shape_corr, low.shape)):
+                raise RuntimeError(shape_corr_err_msg)
         self.shape = shape

         super().__init__(
@@ -1844,10 +1856,18 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
         return val

     def is_in(self, val: torch.Tensor) -> bool:
+        shape = torch.broadcast_shapes(_remove_neg_shapes(self.shape), val.shape)
+        shape_match = val.shape == shape
+        if not shape_match:
+            return False
+        dtype_match = val.dtype == self.dtype
+        if not dtype_match:
+            return False
         try:
-            return (val >= self.space.low.to(val.device)).all() and (
+            within_bounds = (val >= self.space.low.to(val.device)).all() and (
                 val <= self.space.high.to(val.device)
             ).all()
+            return within_bounds
         except RuntimeError as err:
             if "The size of tensor a" in str(err):
                 warnings.warn(f"Got a shape mismatch: {str(err)}")
@@ -1959,7 +1979,7 @@ def one(self, batch_size):
         return NonTensorData(data=None, batch_size=self.shape, device=self.device)

     def is_in(self, val: torch.Tensor) -> bool:
-        shape = torch.broadcast_shapes(self.shape, val.shape)
+        shape = torch.broadcast_shapes(_remove_neg_shapes(self.shape), val.shape)
         return (
             isinstance(val, NonTensorData)
             and val.shape == shape
@@ -2079,7 +2099,7 @@ def rand(self, shape=None) -> torch.Tensor:
         return torch.empty(shape, device=self.device, dtype=self.dtype).random_()

     def is_in(self, val: torch.Tensor) -> bool:
-        shape = torch.broadcast_shapes(self.shape, val.shape)
+        shape = torch.broadcast_shapes(_remove_neg_shapes(self.shape), val.shape)
         return val.shape == shape and val.dtype == self.dtype

     def _project(self, val: torch.Tensor) -> torch.Tensor:
@@ -2178,7 +2198,7 @@ def __init__(
         self,
         shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
         device: Optional[DEVICE_TYPING] = None,
-        dtype: Optional[Union[str, torch.dtype]] = None,
+        dtype: Optional[Union[str, torch.dtype]] = torch.int64,
     ):
         if isinstance(shape, int):
             shape = torch.Size([shape])
@@ -2195,8 +2215,8 @@ def __init__(
             min_value = torch.iinfo(dtype).min
             max_value = torch.iinfo(dtype).max
         space = ContinuousBox(
-            torch.full(shape, min_value, device=device),
-            torch.full(shape, max_value, device=device),
+            torch.full(_remove_neg_shapes(shape), min_value, device=device),
+            torch.full(_remove_neg_shapes(shape), max_value, device=device),
         )

         super().__init__(
@@ -2234,7 +2254,7 @@ def rand(self, shape=None) -> torch.Tensor:
         return r.to(self.device)

     def is_in(self, val: torch.Tensor) -> bool:
-        shape = torch.broadcast_shapes(self.shape, val.shape)
+        shape = torch.broadcast_shapes(_remove_neg_shapes(self.shape), val.shape)
         return val.shape == shape and val.dtype == self.dtype

     def expand(self, *shape):
@@ -2819,6 +2839,13 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:

     def is_in(self, val: torch.Tensor) -> bool:
         if self.mask is None:
+            shape = torch.broadcast_shapes(_remove_neg_shapes(self.shape), val.shape)
+            shape_match = val.shape == shape
+            if not shape_match:
+                return False
+            dtype_match = val.dtype == self.dtype
+            if not dtype_match:
+                return False
             return (0 <= val).all() and (val < self.space.n).all()
         shape = self.mask.shape
         shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
@@ -3166,7 +3193,7 @@ def __init__(
         nvec: Union[Sequence[int], torch.Tensor, int],
         shape: Optional[torch.Size] = None,
         device: Optional[DEVICE_TYPING] = None,
-        dtype: Optional[Union[str, torch.dtype]] = torch.long,
+        dtype: Optional[Union[str, torch.dtype]] = torch.int64,
         mask: torch.Tensor | None = None,
     ):
         if not isinstance(nvec, torch.Tensor):
@@ -3185,7 +3212,7 @@ def __init__(
                 f"Got nvec.shape[-1]={sum(nvec)} and shape={shape}."
             )

-        self.nvec = self.nvec.expand(shape)
+        self.nvec = self.nvec.expand(_remove_neg_shapes(shape))

         space = BoxList.from_nvec(self.nvec)
         super(DiscreteTensorSpec, self).__init__(
@@ -3333,18 +3360,19 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:

     def is_in(self, val: torch.Tensor) -> bool:
         if self.mask is not None:
-            return all(
-                spec.is_in(_val)
-                for (_val, spec) in zip(val.unbind(-1), self._split_self())
-            )
+            vals = val.unbind(-1)
+            splits = self._split_self()
+            if not len(vals) == len(splits):
+                return False
+            return all(spec.is_in(val) for (val, spec) in zip(vals, splits))

-        if val.ndim < 1:
-            val = val.unsqueeze(0)
-        val_have_wrong_dim = (
-            self.shape != torch.Size([1])
-            and val.shape[-len(self.shape) :] != self.shape
-        )
-        if self.dtype != val.dtype or len(self.shape) > val.ndim or val_have_wrong_dim:
+        shape = _remove_neg_shapes(self.shape)
+        shape = torch.broadcast_shapes(shape, val.shape)
+        if shape != val.shape:
+            return False
+        if self.dtype != val.dtype:
             return False
         val_device = val.device
         return (
@@ -4732,3 +4760,7 @@ def _minmax_dtype(dtype):
     else:
         info = torch.iinfo(dtype)
         return info.min, info.max
+
+
+def _remove_neg_shapes(shape):
+    return torch.Size([d if d >= 0 else 1 for d in shape])
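
For intuition, a minimal standalone sketch (an illustration, not the library code) of the shape check that the updated is_in methods share: _remove_neg_shapes maps every -1 to 1 so that dynamic dimensions broadcast against any concrete size, while fixed dimensions must still line up.

import torch

def remove_neg_shapes(shape):
    # Hypothetical helper mirroring _remove_neg_shapes above: map every -1 to 1.
    return torch.Size([d if d >= 0 else 1 for d in shape])

def shape_matches(spec_shape, val_shape):
    # Mirrors the pattern used by the is_in methods in this commit:
    # broadcast the corrected spec shape against the value's shape and
    # require that the result equals the value's shape exactly.
    try:
        target = torch.broadcast_shapes(remove_neg_shapes(spec_shape), val_shape)
    except RuntimeError:
        # Incompatible fixed dimensions cannot broadcast.
        return False
    return val_shape == target

assert shape_matches((-1, 1, 2), torch.Size([3, 1, 2]))    # dynamic batch of 3
assert shape_matches((-1, 1, 2), torch.Size([128, 1, 2]))  # dynamic batch of 128
assert not shape_matches((-1, 1, 2), torch.Size([3, 1]))   # missing trailing dim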