Skip to content

Commit

Permalink
SubTensorDict set method (pytorch#143)
Browse files Browse the repository at this point in the history
* init

* Update test_shared.py
  • Loading branch information
vmoens authored May 17, 2022
1 parent c06a538 commit 49c2ea9
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 10 deletions.
2 changes: 1 addition & 1 deletion test/test_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_memmap(idx, dtype, large_scale=False):

td_to_copy = td[idx].contiguous()
for k in td_to_copy.keys():
td_to_copy.set(k, torch.ones_like(td_to_copy.get(k)))
td_to_copy.set_(k, torch.ones_like(td_to_copy.get(k)))

print("\nTesting writing to TD")
for i in range(2):
Expand Down
50 changes: 41 additions & 9 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,36 @@ def test_tensordict_indexing(device):
def test_subtensordict_construction(device):
torch.manual_seed(1)
td = TensorDict({}, batch_size=(4, 5))
td.set("key1", torch.randn(4, 5, 1, device=device))
td.set("key2", torch.randn(4, 5, 6, dtype=torch.double, device=device))
val1 = torch.randn(4, 5, 1, device=device)
val2 = torch.randn(4, 5, 6, dtype=torch.double, device=device)
val1_copy = val1.clone()
val2_copy = val2.clone()
td.set("key1", val1)
td.set("key2", val2)
std1 = td.get_sub_tensordict(2)
std2 = std1.get_sub_tensordict(2)
std_control = td.get_sub_tensordict((2, 2))
idx = (2, 2)
std_control = td.get_sub_tensordict(idx)
assert (std_control.get("key1") == std2.get("key1")).all()
assert (std_control.get("key2") == std2.get("key2")).all()

# write values
std_control.set("key1", torch.randn(1, device=device))
std_control.set("key2", torch.randn(6, device=device, dtype=torch.double))
with pytest.raises(RuntimeError, match="is prohibited for existing tensors"):
std_control.set("key1", torch.randn(1, device=device))
with pytest.raises(RuntimeError, match="is prohibited for existing tensors"):
std_control.set("key2", torch.randn(6, device=device, dtype=torch.double))

subval1 = torch.randn(1, device=device)
subval2 = torch.randn(6, device=device, dtype=torch.double)
std_control.set_("key1", subval1)
std_control.set_("key2", subval2)
assert (val1_copy[idx] != subval1).all()
assert (td.get("key1")[idx] == subval1).all()
assert (td.get("key1")[1, 1] == val1_copy[1, 1]).all()

assert (val2_copy[idx] != subval2).all()
assert (td.get("key2")[idx] == subval2).all()
assert (td.get("key2")[1, 1] == val2_copy[1, 1]).all()

assert (std_control.get("key1") == std2.get("key1")).all()
assert (std_control.get("key2") == std2.get("key2")).all()
Expand Down Expand Up @@ -851,7 +870,10 @@ def test_unsqueeze(self, td_name, squeeze_dim):
td = getattr(self, td_name)
td_unsqueeze = torch.unsqueeze(td, dim=squeeze_dim)
tensor = torch.ones_like(td.get("a").unsqueeze(squeeze_dim))
td_unsqueeze.set("a", tensor)
if td_name == "sub_td":
td_unsqueeze.set_("a", tensor)
else:
td_unsqueeze.set("a", tensor)
assert (td_unsqueeze.get("a") == tensor).all()
assert (td.get("a") == tensor.squeeze(squeeze_dim)).all()
assert td_unsqueeze.squeeze(squeeze_dim) is td
Expand All @@ -876,7 +898,10 @@ def test_squeeze(self, td_name, squeeze_dim=-1):
td_squeeze = torch.squeeze(td, dim=-1)
tensor_squeeze_dim = td.batch_dims + squeeze_dim
tensor = torch.ones_like(td.get("a").squeeze(tensor_squeeze_dim))
td_squeeze.set("a", tensor)
if td_name == "sub_td":
td_squeeze.set_("a", tensor)
else:
td_squeeze.set("a", tensor)
assert (td_squeeze.get("a") == tensor).all()
assert (td.get("a") == tensor.unsqueeze(tensor_squeeze_dim)).all()
assert td_squeeze.unsqueeze(squeeze_dim) is td
Expand All @@ -902,7 +927,10 @@ def test_view(self, td_name):
tensor = td.get("a")
tensor = tensor.view(-1, tensor.numel() // np.prod(td.batch_size))
tensor = torch.ones_like(tensor)
td_view.set("a", tensor)
if td_name == "sub_td":
td_view.set_("a", tensor)
else:
td_view.set("a", tensor)
assert (td_view.get("a") == tensor).all()
assert (td.get("a") == tensor.view(td.get("a").shape)).all()
assert td_view.view(td.shape) is td
Expand Down Expand Up @@ -960,7 +988,11 @@ def test_rename_key(self, td_name) -> None:
torch.testing.assert_allclose(a, z)

new_z = torch.randn_like(z)
td.set("z", new_z)
if td_name == "sub_td":
td.set_("z", new_z)
else:
td.set("z", new_z)

torch.testing.assert_allclose(new_z, td.get("z"))

new_z = torch.randn_like(z)
Expand Down
5 changes: 5 additions & 0 deletions torchrl/data/tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2174,6 +2174,11 @@ def set(
raise RuntimeError("Cannot modify immutable TensorDict")
if inplace and key in self.keys():
return self.set_(key, tensor)
elif key in self.keys():
raise RuntimeError(
"Calling `SubTensorDict.set(key, value, inplace=False)` is prohibited for existing tensors. "
"Consider calling `SubTensorDict.set_(...)` or cloning your tensordict first."
)

tensor = self._process_tensor(
tensor, check_device=False, check_tensor_shape=False
Expand Down

0 comments on commit 49c2ea9

Please sign in to comment.