Skip to content

Commit

Permalink
[BugFix] Clone memmap tensors on regular tensors and other replay buffer improvements (pytorch#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Aug 8, 2022
1 parent 2f57154 commit c61ae7b
Show file tree
Hide file tree
Showing 21 changed files with 215 additions and 166 deletions.
6 changes: 3 additions & 3 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def test_update_weights(use_async):
policy_state_dict = policy.state_dict()
for worker in range(3):
for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
torch.testing.assert_allclose(
torch.testing.assert_close(
state_dict[f"worker{worker}"]["policy_state_dict"][k],
policy_state_dict[k].cpu(),
)
Expand All @@ -534,7 +534,7 @@ def test_update_weights(use_async):
for worker in range(3):
for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
with pytest.raises(AssertionError):
torch.testing.assert_allclose(
torch.testing.assert_close(
state_dict[f"worker{worker}"]["policy_state_dict"][k],
policy_state_dict[k].cpu(),
)
Expand All @@ -547,7 +547,7 @@ def test_update_weights(use_async):
policy_state_dict = policy.state_dict()
for worker in range(3):
for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
torch.testing.assert_allclose(
torch.testing.assert_close(
state_dict[f"worker{worker}"]["policy_state_dict"][k],
policy_state_dict[k].cpu(),
)
Expand Down
10 changes: 5 additions & 5 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def test_parallel_env_seed(
max_steps=10, auto_reset=False, tensordict=td0_serial
).contiguous()
key = "pixels" if "pixels" in td_serial else "observation"
torch.testing.assert_allclose(
torch.testing.assert_close(
td_serial[:, 0].get("next_" + key), td_serial[:, 1].get(key)
)

Expand All @@ -443,7 +443,7 @@ def test_parallel_env_seed(
td_parallel = env_parallel.rollout(
max_steps=10, auto_reset=False, tensordict=td0_parallel
).contiguous()
torch.testing.assert_allclose(
torch.testing.assert_close(
td_parallel[:, :-1].get("next_" + key), td_parallel[:, 1:].get(key)
)

Expand Down Expand Up @@ -809,13 +809,13 @@ def test_seed():
torch.manual_seed(0)
rollout2 = env2.rollout(max_steps=30)

torch.testing.assert_allclose(
torch.testing.assert_close(
rollout1["observation"][1:], rollout1["next_observation"][:-1]
)
torch.testing.assert_allclose(
torch.testing.assert_close(
rollout2["observation"][1:], rollout2["next_observation"][:-1]
)
torch.testing.assert_allclose(rollout1["observation"], rollout2["observation"])
torch.testing.assert_close(rollout1["observation"], rollout2["observation"])


@pytest.mark.parametrize("keep_other", [True, False])
Expand Down
4 changes: 2 additions & 2 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ def test_gsde(
action1 = module(td).get("action")
action2 = actor(td).get("action")
if gSDE or exploration_mode == "mode":
torch.testing.assert_allclose(action1, action2)
torch.testing.assert_close(action1, action2)
else:
with pytest.raises(AssertionError):
torch.testing.assert_allclose(action1, action2)
torch.testing.assert_close(action1, action2)


@pytest.mark.parametrize(
Expand Down
16 changes: 8 additions & 8 deletions test/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ def test_ddpg_maker(device, from_pixels, gsde, exploration):
tsf_loc = actor.module[-1].module.transform(td.get("loc"))
if exploration == "random":
with pytest.raises(AssertionError):
torch.testing.assert_allclose(td.get("action"), tsf_loc)
torch.testing.assert_close(td.get("action"), tsf_loc)
else:
torch.testing.assert_allclose(td.get("action"), tsf_loc)
torch.testing.assert_close(td.get("action"), tsf_loc)

value(td)
expected_keys += ["state_action_value"]
Expand Down Expand Up @@ -260,9 +260,9 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):

if exploration == "random":
with pytest.raises(AssertionError):
torch.testing.assert_allclose(td_clone.get("action"), tsf_loc)
torch.testing.assert_close(td_clone.get("action"), tsf_loc)
else:
torch.testing.assert_allclose(td_clone.get("action"), tsf_loc)
torch.testing.assert_close(td_clone.get("action"), tsf_loc)

value = actor_value.get_value_operator()
expected_keys = [
Expand Down Expand Up @@ -354,9 +354,9 @@ def test_sac_make(device, gsde, tanh_loc, from_pixels, exploration):
tsf_loc = actor.module[-1].module.transform(td_clone.get("loc"))
if exploration == "random":
with pytest.raises(AssertionError):
torch.testing.assert_allclose(td_clone.get("action"), tsf_loc)
torch.testing.assert_close(td_clone.get("action"), tsf_loc)
else:
torch.testing.assert_allclose(td_clone.get("action"), tsf_loc)
torch.testing.assert_close(td_clone.get("action"), tsf_loc)

try:
_assert_keys_match(td_clone, expected_keys)
Expand Down Expand Up @@ -472,9 +472,9 @@ def test_redq_make(device, from_pixels, gsde, exploration):
tsf_loc = actor.module[-1].module.transform(td.get("loc"))
if exploration == "random":
with pytest.raises(AssertionError):
torch.testing.assert_allclose(td.get("action"), tsf_loc)
torch.testing.assert_close(td.get("action"), tsf_loc)
else:
torch.testing.assert_allclose(td.get("action"), tsf_loc)
torch.testing.assert_close(td.get("action"), tsf_loc)

qvalue(td)
expected_keys = [
Expand Down
4 changes: 2 additions & 2 deletions test/test_memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ def test_memmap_ownership_2pass(value):
assert m1._has_ownership + m2._has_ownership + m3._has_ownership == 1


def test_memmap_clone():
def test_memmap_new():
t = torch.tensor([1])
m1 = MemmapTensor(t)
m2 = m1.clone()
m2 = MemmapTensor(m1)
assert isinstance(m2, MemmapTensor)
assert m2.filename != m1.filename
assert m2.filename == m2.file.name
Expand Down
12 changes: 5 additions & 7 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def test_noisy(layer_class, device, seed=0):
layer.reset_noise()
y2 = layer(x)
y3 = layer(x)
torch.testing.assert_allclose(y2, y3)
torch.testing.assert_close(y2, y3)
with pytest.raises(AssertionError):
torch.testing.assert_allclose(y1, y2)
torch.testing.assert_close(y1, y2)


@pytest.mark.parametrize("device", get_available_devices())
Expand Down Expand Up @@ -152,13 +152,11 @@ def test_actorcritic(device):
td_policy = policy_op(td.clone())
value_op = op.get_value_operator()
td_value = value_op(td)
torch.testing.assert_allclose(td_total.get("action"), td_policy.get("action"))
torch.testing.assert_allclose(
torch.testing.assert_close(td_total.get("action"), td_policy.get("action"))
torch.testing.assert_close(
td_total.get("sample_log_prob"), td_policy.get("sample_log_prob")
)
torch.testing.assert_allclose(
td_total.get("state_value"), td_value.get("state_value")
)
torch.testing.assert_close(td_total.get("state_value"), td_value.get("state_value"))

value_params = set(
list(op.get_value_operator().parameters()) + list(op.module[0].parameters())
Expand Down
2 changes: 1 addition & 1 deletion test/test_postprocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_multistep(n, key, device, T=11):
assert ((next_obs == true_next_obs) | terminated[:, (1 + ms.n_steps_max) :]).all()

# test gamma computation
torch.testing.assert_allclose(
torch.testing.assert_close(
ms_tensordict.get("gamma"), ms.gamma ** ms_tensordict.get("steps_to_next_obs")
)

Expand Down
8 changes: 2 additions & 6 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,12 @@ def test_prb(priority_key, contiguous, device):
rb.update_priority(s)
s = rb.sample(5)
assert (val == s.get("a")).sum() >= 1
torch.testing.assert_allclose(
td2[idx0].get("a").view(1), s.get("a").unique().view(1)
)
torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1))

# test updating values of original td
td2.set_("a", torch.ones_like(td2.get("a")))
s = rb.sample(5)
torch.testing.assert_allclose(
td2[idx0].get("a").view(1), s.get("a").unique().view(1)
)
torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1))


@pytest.mark.parametrize("stack", [False, True])
Expand Down
15 changes: 9 additions & 6 deletions test/test_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,22 +200,25 @@ def test_memmap(idx, dtype, large_scale=False):
print("\nTesting writing to TD")
for i in range(2):
t0 = time.time()
td_sm[idx].update_(td_to_copy)
sub_td_sm = td_sm.get_sub_tensordict(idx)
sub_td_sm.update_(td_to_copy)
if i == 1:
print(f"sm td: {time.time() - t0:4.4f} sec")
torch.testing.assert_allclose(td_sm[idx].get("a"), td_to_copy.get("a"))
torch.testing.assert_close(sub_td_sm.get("a"), td_to_copy.get("a"))

t0 = time.time()
td_memmap[idx].update_(td_to_copy)
sub_td_sm = td_memmap.get_sub_tensordict(idx)
sub_td_sm.update_(td_to_copy)
if i == 1:
print(f"memmap td: {time.time() - t0:4.4f} sec")
torch.testing.assert_allclose(td_memmap[idx].get("a"), td_to_copy.get("a"))
torch.testing.assert_close(sub_td_sm.get("a"), td_to_copy.get("a"))

t0 = time.time()
td_saved[idx].update_(td_to_copy)
sub_td_sm = td_saved.get_sub_tensordict(idx)
sub_td_sm.update_(td_to_copy)
if i == 1:
print(f"saved td: {time.time() - t0:4.4f} sec")
torch.testing.assert_allclose(td_saved[idx].get("a"), td_to_copy.get("a"))
torch.testing.assert_close(sub_td_sm.get("a"), td_to_copy.get("a"))


if __name__ == "__main__":
Expand Down
47 changes: 26 additions & 21 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_tensordict_set(device):
# test set_at_ with dtype casting
x = torch.randn(6, dtype=torch.double, device=device)
td.set_at_("key2", x, (2, 2)) # robust to dtype casting
torch.testing.assert_allclose(td.get("key2")[2, 2], x.to(torch.float))
torch.testing.assert_close(td.get("key2")[2, 2], x.to(torch.float))

td.set("key1", torch.zeros(4, 5, dtype=torch.double, device=device), inplace=True)
assert (td.get("key1") == 0).all()
Expand Down Expand Up @@ -136,13 +136,13 @@ def test_tensordict_indexing(device):
batch_size=[3, 4],
)
td[0].set_("key1", x)
torch.testing.assert_allclose(td.get("key1")[0], x)
torch.testing.assert_allclose(td.get("key1")[0], td[0].get("key1"))
torch.testing.assert_close(td.get("key1")[0], x)
torch.testing.assert_close(td.get("key1")[0], td[0].get("key1"))

y = torch.randn(3, 5, device=device)
td[:, 0].set_("key1", y)
torch.testing.assert_allclose(td.get("key1")[:, 0], y)
torch.testing.assert_allclose(td.get("key1")[:, 0], td[:, 0].get("key1"))
torch.testing.assert_close(td.get("key1")[:, 0], y)
torch.testing.assert_close(td.get("key1")[:, 0], td[:, 0].get("key1"))


@pytest.mark.parametrize("device", get_available_devices())
Expand Down Expand Up @@ -209,8 +209,8 @@ def test_mask_td(device):
mask_list = [False, True, False, True]

td_masked2 = td[mask_list, 0]
torch.testing.assert_allclose(td.get("key1")[mask_list, 0], td_masked2.get("key1"))
torch.testing.assert_allclose(td.get("key2")[mask_list, 0], td_masked2.get("key2"))
torch.testing.assert_close(td.get("key1")[mask_list, 0], td_masked2.get("key1"))
torch.testing.assert_close(td.get("key2")[mask_list, 0], td_masked2.get("key2"))


@pytest.mark.parametrize("device", get_available_devices())
Expand Down Expand Up @@ -482,9 +482,9 @@ def test_savedtensordict(device):
]
ss = stack_td(ss_list, 0)
assert ss_list[1] is ss[1]
torch.testing.assert_allclose(ss_list[1].get("a"), vals[1])
torch.testing.assert_allclose(ss_list[1].get("a"), ss[1].get("a"))
torch.testing.assert_allclose(ss[1].get("a"), ss.get("a")[1])
torch.testing.assert_close(ss_list[1].get("a"), vals[1])
torch.testing.assert_close(ss_list[1].get("a"), ss[1].get("a"))
torch.testing.assert_close(ss[1].get("a"), ss.get("a")[1])
assert ss.get("a").device == device


Expand Down Expand Up @@ -1087,26 +1087,26 @@ def test_rename_key(self, td_name, device) -> None:
a = a._tensor
if isinstance(z, MemmapTensor):
z = z._tensor
torch.testing.assert_allclose(a, z)
torch.testing.assert_close(a, z)

new_z = torch.randn_like(z)
if td_name in ("sub_td", "sub_td2"):
td.set_("z", new_z)
else:
td.set("z", new_z)

torch.testing.assert_allclose(new_z, td.get("z"))
torch.testing.assert_close(new_z, td.get("z"))

new_z = torch.randn_like(z)
td.set_("z", new_z)
torch.testing.assert_allclose(new_z, td.get("z"))
torch.testing.assert_close(new_z, td.get("z"))

def test_set_nontensor(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)(device)
r = torch.randn_like(td.get("a"))
td.set("numpy", r.cpu().numpy())
torch.testing.assert_allclose(td.get("numpy"), r)
torch.testing.assert_close(td.get("numpy"), r)

@pytest.mark.parametrize(
"actual_index,expected_index",
Expand Down Expand Up @@ -1152,7 +1152,7 @@ def test_setitem(self, td_name, device, idx):
pytest.mark.skip("cannot index tensor with desired index")
return

td_clone = td[idx].clone().zero_()
td_clone = td[idx].to_tensordict().zero_()
td[idx] = td_clone
assert (td[idx].get("a") == 0).all()

Expand Down Expand Up @@ -1555,7 +1555,6 @@ def test_batchsize_reset():

# test that lazy tds return an exception
td_stack = stack_td([TensorDict({"a": torch.randn(3)}, [3]) for _ in range(2)])
td_stack.to_tensordict().batch_size = [2]
with pytest.raises(
RuntimeError,
match=re.escape(
Expand All @@ -1564,9 +1563,10 @@ def test_batchsize_reset():
),
):
td_stack.batch_size = [2]
td_stack.to_tensordict().batch_size = [2]

td = TensorDict({"a": torch.randn(3, 4)}, [3, 4])
subtd = td[:, torch.tensor([1, 2])]
subtd = td.get_sub_tensordict((slice(None), torch.tensor([1, 2])))
with pytest.raises(
RuntimeError,
match=re.escape(
Expand Down Expand Up @@ -1793,23 +1793,23 @@ def _driver_func(tensordict, tensordict_unbind):
is_done = parents[i].recv()
assert is_done == "done"
new_a = tensordict.get("a").clone().contiguous()
torch.testing.assert_allclose(a_prev - 1, new_a)
torch.testing.assert_close(a_prev - 1, new_a)

a_prev = tensordict.get("a").clone().contiguous()
for i in range(2):
parents[i].send(("update", i))
is_done = parents[i].recv()
assert is_done == "done"
new_a = tensordict.get("a").clone().contiguous()
torch.testing.assert_allclose(a_prev + 1, new_a)
torch.testing.assert_close(a_prev + 1, new_a)

for i in range(2):
parents[i].send(("close", None))
procs[i].join()


@pytest.mark.parametrize(
"td_type", ["contiguous", "stack", "saved", "memmap", "memmap_stack"]
"td_type", ["memmap", "memmap_stack", "contiguous", "stack", "saved"]
)
def test_mp(td_type):
tensordict = TensorDict(
Expand Down Expand Up @@ -1840,7 +1840,12 @@ def test_mp(td_type):
)
else:
raise NotImplementedError
_driver_func(tensordict, tensordict.unbind(0))
_driver_func(
tensordict,
(tensordict.get_sub_tensordict(0), tensordict.get_sub_tensordict(1))
# tensordict,
# tensordict.unbind(0),
)


def test_saved_delete():
Expand Down
2 changes: 1 addition & 1 deletion torchrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def seed_generator(seed):


class KeyDependentDefaultDict(collections.defaultdict):
def __init__(self, fun=lambda x: x):
def __init__(self, fun):
self.fun = fun
super().__init__()

Expand Down
Loading

0 comments on commit c61ae7b

Please sign in to comment.