Skip to content

Commit

Permalink
[BugFix] Fix CompositeSpec.to_numpy method (pytorch#931)
Browse files Browse the repository at this point in the history
  • Loading branch information
riiswa authored Feb 21, 2023
1 parent 5bd4b4f commit 726dc42
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
8 changes: 8 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,14 @@ def test_is_in(self, is_complete, device, dtype):
r = ts.rand()
assert ts.is_in(r)

def test_to_numpy(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
for _ in range(100):
r = ts.rand()
for key, value in ts.to_numpy(r).items():
spec = ts[key]
assert (spec.to_numpy(r[key]) == value).all()

@pytest.mark.parametrize("shape", [[], [3]])
def test_project(self, is_complete, device, dtype, shape):
ts = self._composite_spec(is_complete, device, dtype)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1947,7 +1947,7 @@ def clone(self) -> CompositeSpec:
)

def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
return {key: self[key]._to_numpy(val) for key, val in val.items()}
return {key: self[key].to_numpy(val) for key, val in val.items()}

def zero(self, shape=None) -> TensorDictBase:
if shape is None:
Expand Down

0 comments on commit 726dc42

Please sign in to comment.