Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Heterogeneous Environments compatibility #1411

Merged
merged 106 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
106 commits
Select commit Hold shift + click to select a range
07b4cb0
init
vmoens Jul 6, 2023
26b6c31
amend
vmoens Jul 6, 2023
bf514b9
amend
vmoens Jul 6, 2023
5de5d0c
Merge branch 'main' into fix_compositespec
matteobettini Jul 17, 2023
a1c947e
amend
matteobettini Jul 18, 2023
51a2721
amend
matteobettini Jul 18, 2023
58b5e2d
amend
matteobettini Jul 18, 2023
d598659
amend
matteobettini Jul 18, 2023
9b6ae84
Merge branch 'main' into fix_compositespec
matteobettini Jul 18, 2023
42dafc8
amend
matteobettini Jul 18, 2023
7693fb6
fix shape in lazy stacked specs
matteobettini Jul 18, 2023
d5bd9d7
contraints in lazy stacked specs
matteobettini Jul 18, 2023
dd2ea3e
contraints in lazy stacked specs
matteobettini Jul 18, 2023
64cbe1b
len
matteobettini Jul 18, 2023
8f901fe
amend
matteobettini Jul 18, 2023
f0a3823
amend
matteobettini Jul 18, 2023
5a040c6
amend
matteobettini Jul 19, 2023
d1c839c
amend
matteobettini Jul 19, 2023
11ef80c
amend
matteobettini Jul 19, 2023
0fb3f41
amend
matteobettini Jul 19, 2023
632cff8
amend
matteobettini Jul 19, 2023
518432a
amend
matteobettini Jul 19, 2023
58a38f8
amend
matteobettini Jul 19, 2023
a533cbc
print
matteobettini Jul 19, 2023
1ff0fce
print
matteobettini Jul 19, 2023
038fd96
amend
matteobettini Jul 20, 2023
18d8aa1
amend
matteobettini Jul 20, 2023
eed9b92
Merge branch 'fix_compositespec' into het_env_test
matteobettini Jul 20, 2023
753bd64
amend
matteobettini Jul 20, 2023
ca81957
amend
matteobettini Jul 20, 2023
be898ac
Merge branch 'fix_compositespec' into het_env_test
matteobettini Jul 20, 2023
15a8b8c
amend
matteobettini Jul 21, 2023
03ad1ba
Merge branch 'het_env_test' into hetero_step_mdp
matteobettini Jul 21, 2023
caffc44
amend
matteobettini Jul 21, 2023
f26ca9f
amend
matteobettini Jul 21, 2023
e9c7299
unlazyfy
matteobettini Jul 23, 2023
e6ce009
Merge branch 'hetero_step_mdp' into fix_compositespec
matteobettini Jul 24, 2023
56c6199
unlazyfy
matteobettini Jul 24, 2023
035c452
unlazyfy
matteobettini Jul 24, 2023
4dcd16f
unlazyfy
matteobettini Jul 24, 2023
dcde4f3
typo
matteobettini Jul 24, 2023
0719729
Merge branch 'unlazyfy' into fix_compositespec
matteobettini Jul 24, 2023
16b3a39
amend
matteobettini Jul 24, 2023
8bb00cb
amend
matteobettini Jul 24, 2023
63fb017
amend
matteobettini Jul 24, 2023
02ae8e2
amend
matteobettini Jul 24, 2023
7aa1024
Merge branch 'unlazyfy' into fix_compositespec
matteobettini Jul 24, 2023
4f53955
amend
matteobettini Jul 24, 2023
03b64b0
temp
matteobettini Jul 24, 2023
984d812
fix
matteobettini Jul 24, 2023
3fae396
Merge branch 'unlazyfy' into fix_compositespec
matteobettini Jul 25, 2023
e8ffe83
amend
matteobettini Jul 25, 2023
4d83710
rename
matteobettini Jul 25, 2023
d31aeee
Merge branch 'unlazyfy' into fix_compositespec
matteobettini Jul 25, 2023
ad9bf12
fix
matteobettini Jul 25, 2023
a535bfa
Merge branch 'unlazyfy' into fix_compositespec
matteobettini Jul 25, 2023
1717594
amend
matteobettini Jul 25, 2023
01dad7a
amend
matteobettini Jul 25, 2023
bc4757f
Merge branch 'unlazyfy' into fix_compositespec
matteobettini Jul 25, 2023
710497f
refactor names
matteobettini Jul 25, 2023
99ece39
refactor names
matteobettini Jul 25, 2023
c57393c
Merge branch 'unlazyfy' into fix_compositespec
matteobettini Jul 25, 2023
961467e
amend
matteobettini Jul 25, 2023
d451c78
amend
matteobettini Jul 25, 2023
8720aed
amend
matteobettini Jul 25, 2023
325199c
amend
matteobettini Jul 25, 2023
1086a9d
typos
matteobettini Jul 25, 2023
356c48e
tests
matteobettini Jul 25, 2023
09bb27d
amend
matteobettini Jul 25, 2023
85ae904
remove unrelated
matteobettini Jul 25, 2023
c945632
remove unrelated
matteobettini Jul 25, 2023
e22fcf9
remove unrelated
matteobettini Jul 25, 2023
b9ed2ae
het envs
matteobettini Jul 25, 2023
55fb67d
remove exclusive keys test
matteobettini Jul 25, 2023
b377efa
test env
matteobettini Jul 25, 2023
6842598
step mdp test temp
matteobettini Jul 25, 2023
61fb590
amend
matteobettini Jul 26, 2023
612d552
amend
matteobettini Jul 26, 2023
4f39403
amend
matteobettini Jul 26, 2023
a50a26f
amend
matteobettini Jul 27, 2023
feb93b5
Merge branch 'fix_compositespec' into het_env
matteobettini Jul 27, 2023
a3ec6e7
amend
matteobettini Jul 27, 2023
aaa55a4
amend
matteobettini Jul 27, 2023
6ade0e3
amend
matteobettini Jul 27, 2023
4125023
amend
matteobettini Jul 27, 2023
8cbffbc
amend
matteobettini Jul 27, 2023
2c0d29b
fix
matteobettini Jul 28, 2023
ce18062
Apply suggestions from code review
matteobettini Jul 28, 2023
38c4747
Apply suggestions from code review
matteobettini Jul 28, 2023
0c37c7e
docs
matteobettini Jul 28, 2023
74b0068
docs
matteobettini Jul 28, 2023
5a958a7
Merge branch 'main' into fix_compositespec
matteobettini Jul 28, 2023
230456b
docs
matteobettini Jul 28, 2023
2f5b8a3
Merge branch 'fix_compositespec' into het_env
matteobettini Jul 28, 2023
b5c3d0b
Merge branch 'main' into het_env
matteobettini Aug 1, 2023
7e31559
update
matteobettini Aug 1, 2023
2d975ad
fix
matteobettini Aug 1, 2023
f215a04
amend
matteobettini Aug 1, 2023
1ac8a22
amend
matteobettini Aug 1, 2023
4c7a1e9
test
matteobettini Aug 1, 2023
6c873df
amend
matteobettini Aug 2, 2023
bd18dd1
Merge branch 'main' into het_env
matteobettini Aug 4, 2023
64d0ed8
empty
matteobettini Aug 4, 2023
a7e9738
Merge branch 'main' into het_env
matteobettini Aug 4, 2023
e26a74b
try block for het keys
matteobettini Aug 4, 2023
9d43ca1
comments
matteobettini Aug 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
amend
  • Loading branch information
vmoens committed Jul 6, 2023
commit 26b6c31f196f14ac12cb63f40e04b14f6eca852d
63 changes: 63 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2428,6 +2428,69 @@ def test_to_numpy(self):
with pytest.raises(AssertionError):
c.to_numpy(td_fail, safe=True)

def test_unsqueeze(self):
c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3))
c2 = CompositeSpec(
a=BoundedTensorSpec(-1, 1, shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
c = torch.stack([c1, c2], 1)
for unsq in range(-2, 3):
cu = c.unsqueeze(unsq)
shape = list(c.shape)
new_unsq = unsq if unsq >= 0 else c.ndim + unsq + 1
shape.insert(new_unsq, 1)
assert cu.shape == torch.Size(shape)
cus = cu.squeeze(unsq)
assert c.shape == cus.shape, unsq
assert cus == c

assert c.squeeze().shape == torch.Size([2, 3])

specs = [
CompositeSpec(
{
"observation_0": UnboundedContinuousTensorSpec(
shape=torch.Size([128, 128, 3]),
device="cpu",
dtype=torch.float32,
)
}
),
CompositeSpec(
{
"observation_1": UnboundedContinuousTensorSpec(
shape=torch.Size([128, 128, 3]),
device="cpu",
dtype=torch.float32,
)
}
),
CompositeSpec(
{
"observation_2": UnboundedContinuousTensorSpec(
shape=torch.Size([128, 128, 3]),
device="cpu",
dtype=torch.float32,
)
}
),
CompositeSpec(
{
"observation_3": UnboundedContinuousTensorSpec(
shape=torch.Size([4]), device="cpu", dtype=torch.float32
)
}
),
]

c = torch.stack(specs, dim=0)
cu = c.unsqueeze(0)
assert cu.shape == torch.Size([1, 4])
cus = cu.squeeze(0)
assert cus == c


# MultiDiscreteTensorSpec: Pending resolution of https://github.com/pytorch/pytorch/issues/100080.
@pytest.mark.parametrize(
Expand Down
72 changes: 70 additions & 2 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3125,7 +3125,14 @@ def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> N
pass

def __eq__(self, other):
pass
if not isinstance(other, LazyStackedCompositeSpec):
return False
if len(self._specs) != len(other._specs):
return False
for _spec1, _spec2 in zip(self._specs, other._specs):
if _spec1 != _spec2:
return False
return True

def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
if safe is None:
Expand Down Expand Up @@ -3164,9 +3171,13 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
) -> KeysView:
return self._specs[0].keys(
keys = self._specs[0].keys(
include_nested=include_nested, leaves_only=leaves_only
)
keys = set(keys)
for spec in self._specs[1:]:
keys = keys.intersection(spec.keys(include_nested, leaves_only))
return sorted(keys, key=str)

def project(self, val: TensorDictBase) -> TensorDictBase:
raise NotImplementedError
Expand Down Expand Up @@ -3231,6 +3242,63 @@ def set(self, name, spec):
)
self._specs[name] = spec

def unsqueeze(self, dim: int):
if dim < 0:
new_dim = dim + len(self.shape) + 1
else:
new_dim = dim
if new_dim > len(self.shape) or new_dim < 0:
raise ValueError(f"Cannot unsqueeze along dim {dim}.")
new_stack_dim = self.dim if self.dim < new_dim else self.dim + 1
if new_dim > self.dim:
# unsqueeze 2, stack is on 1 => unsqueeze 1, stack along 1
new_stack_dim = self.dim
new_dim = new_dim - 1
else:
# unsqueeze 0, stack is on 1 => unsqueeze 0, stack on 1
new_stack_dim = self.dim + 1
return LazyStackedCompositeSpec(
*[spec.unsqueeze(new_dim) for spec in self._specs], dim=new_stack_dim
)

def squeeze(self, dim: int=None):
if dim is None:
size = self.shape
if len(size) == 1 or size.count(1) == 0:
return self
first_singleton_dim = size.index(1)

squeezed_dict = self.squeeze(first_singleton_dim)
return squeezed_dict.squeeze(dim=None)

if dim < 0:
new_dim = self.ndim + dim
else:
new_dim = dim

if self.shape and (new_dim >= self.ndim or new_dim < 0):
raise RuntimeError(
f"squeezing is allowed for dims comprised between 0 and "
f"spec.ndim only. Got dim={dim} and shape"
f"={self.shape}."
)

if new_dim >= self.ndim or self.shape[new_dim] != 1:
return self

if new_dim == self.dim:
return self._specs[0]
if new_dim > self.dim:
# squeeze 2, stack is on 1 => squeeze 1, stack along 1
new_stack_dim = self.dim
new_dim = new_dim - 1
else:
# squeeze 0, stack is on 1 => squeeze 0, stack on 1
new_stack_dim = self.dim - 1
return LazyStackedCompositeSpec(
*[spec.squeeze(new_dim) for spec in self._specs], dim=new_stack_dim
)


# for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]:
@TensorSpec.implements_for_spec(torch.stack)
Expand Down