Skip to content

Commit

Permalink
[Refactor] Defaults split_trajs to False (pytorch#947)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 10, 2023
1 parent eb9a37d commit 2de55cb
Show file tree
Hide file tree
Showing 10 changed files with 343 additions and 206 deletions.
33 changes: 19 additions & 14 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,9 +635,9 @@ def env_fn(seed):
with pytest.raises(AssertionError):
assert_allclose_td(b1, b2)

if num_env == 1:
# rollouts collected through DataCollector are padded using pad_sequence, which introduces a first dimension
rollout1a = rollout1a.unsqueeze(0)
# if num_env == 1:
# # rollouts collected through DataCollector are padded using pad_sequence, which introduces a first dimension
# rollout1a = rollout1a.unsqueeze(0)
assert (
rollout1a.batch_size == b1.batch_size
), f"got batch_size {rollout1a.batch_size} and {b1.batch_size}"
Expand Down Expand Up @@ -690,12 +690,12 @@ def make_frames_per_batch(frames_per_batch):
data1 = []
for d in collector1:
data1.append(d)
count += d.shape[1]
count += d.shape[-1]
if count > max_frames_per_traj:
break

data1 = torch.cat(data1, 1)
data1 = data1[:, :max_frames_per_traj]
data1 = torch.cat(data1, d.ndim - 1)
data1 = data1[..., :max_frames_per_traj]

collector1.shutdown()
del collector1
Expand All @@ -715,12 +715,12 @@ def make_frames_per_batch(frames_per_batch):
data10 = []
for d in collector10:
data10.append(d)
count += d.shape[1]
count += d.shape[-1]
if count > max_frames_per_traj:
break

data10 = torch.cat(data10, 1)
data10 = data10[:, :max_frames_per_traj]
data10 = torch.cat(data10, data1.ndim - 1)
data10 = data10[..., :max_frames_per_traj]

collector10.shutdown()
del collector10
Expand All @@ -740,14 +740,14 @@ def make_frames_per_batch(frames_per_batch):
data20 = []
for d in collector20:
data20.append(d)
count += d.shape[1]
count += d.shape[-1]
if count > max_frames_per_traj:
break

collector20.shutdown()
del collector20
data20 = torch.cat(data20, 1)
data20 = data20[:, :max_frames_per_traj]
data20 = torch.cat(data20, data1.ndim - 1)
data20 = data20[..., :max_frames_per_traj]

assert_allclose_td(data1, data20)
assert_allclose_td(data10, data20)
Expand Down Expand Up @@ -932,7 +932,10 @@ def make_env():
)
@pytest.mark.parametrize("init_random_frames", [0, 50])
@pytest.mark.parametrize("explicit_spec", [True, False])
def test_collector_output_keys(collector_class, init_random_frames, explicit_spec):
@pytest.mark.parametrize("split_trajs", [True, False])
def test_collector_output_keys(
collector_class, init_random_frames, explicit_spec, split_trajs
):
from torchrl.envs.libs.gym import GymEnv

out_features = 1
Expand Down Expand Up @@ -979,6 +982,7 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe
"total_frames": total_frames,
"frames_per_batch": frames_per_batch,
"init_random_frames": init_random_frames,
"split_trajs": split_trajs,
}

if collector_class is not SyncDataCollector:
Expand All @@ -995,7 +999,6 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe
"collector",
"hidden1",
"hidden2",
("collector", "mask"),
("next", "hidden1"),
("next", "hidden2"),
("next", "observation"),
Expand All @@ -1005,6 +1008,8 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe
"observation",
("collector", "traj_ids"),
}
if split_trajs:
keys.add(("collector", "mask"))
b = next(iter(collector))

assert set(b.keys(True)) == keys
Expand Down
12 changes: 6 additions & 6 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9):
actor, gamma=gamma, loss_function="l2", delay_value=delay_value
)

ms = MultiStep(gamma=gamma, n_steps_max=n).to(device)
ms = MultiStep(gamma=gamma, n_steps=n).to(device)
ms_td = ms(td.clone())

with _check_td_steady(ms_td):
Expand All @@ -351,7 +351,7 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9):
with torch.no_grad():
loss = loss_fn(td)
if n == 0:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
assert_allclose_td(td, ms_td.select(*td.keys(True, True)))
_loss = sum([item for _, item in loss.items()])
_loss_ms = sum([item for _, item in loss_ms.items()])
assert (
Expand Down Expand Up @@ -635,7 +635,7 @@ def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9):
delay_value=delay_value,
)

ms = MultiStep(gamma=gamma, n_steps_max=n).to(device)
ms = MultiStep(gamma=gamma, n_steps=n).to(device)
ms_td = ms(td.clone())
with _check_td_steady(ms_td):
loss_ms = loss_fn(ms_td)
Expand Down Expand Up @@ -853,7 +853,7 @@ def test_td3_batcher(
delay_actor=delay_actor,
)

ms = MultiStep(gamma=gamma, n_steps_max=n).to(device)
ms = MultiStep(gamma=gamma, n_steps=n).to(device)

td_clone = td.clone()
ms_td = ms(td_clone)
Expand Down Expand Up @@ -1226,7 +1226,7 @@ def test_sac_batcher(
**kwargs,
)

ms = MultiStep(gamma=gamma, n_steps_max=n).to(device)
ms = MultiStep(gamma=gamma, n_steps=n).to(device)

td_clone = td.clone()
ms_td = ms(td_clone)
Expand Down Expand Up @@ -1717,7 +1717,7 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9):
delay_qvalue=delay_qvalue,
)

ms = MultiStep(gamma=gamma, n_steps_max=n).to(device)
ms = MultiStep(gamma=gamma, n_steps=n).to(device)

td_clone = td.clone()
ms_td = ms(td_clone)
Expand Down
168 changes: 153 additions & 15 deletions test/test_postprocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,24 @@ def test_multistep(n, key, device, T=11):
)

# assert that done at last step is similar to unterminated traj
assert (ms_tensordict.get("gamma")[4] == ms_tensordict.get("gamma")[0]).all()
assert (
ms_tensordict.get(("next", key))[4] == ms_tensordict.get(("next", key))[0]
).all()
assert (
ms_tensordict.get("steps_to_next_obs")[4]
== ms_tensordict.get("steps_to_next_obs")[0]
).all()
torch.testing.assert_close(
ms_tensordict.get("gamma")[4], ms_tensordict.get("gamma")[0]
)
torch.testing.assert_close(
ms_tensordict.get(("next", key))[4], ms_tensordict.get(("next", key))[0]
)
torch.testing.assert_close(
ms_tensordict.get("steps_to_next_obs")[4],
ms_tensordict.get("steps_to_next_obs")[0],
)

# check that next obs is properly replaced, or that it is terminated
next_obs = ms_tensordict.get(key)[:, (1 + ms.n_steps_max) :]
true_next_obs = ms_tensordict.get(("next", key))[:, : -(1 + ms.n_steps_max)]
next_obs = ms_tensordict.get(key)[:, (1 + ms.n_steps) :]
true_next_obs = ms_tensordict.get(("next", key))[:, : -(1 + ms.n_steps)]
terminated = ~ms_tensordict.get("nonterminal")
assert ((next_obs == true_next_obs) | terminated[:, (1 + ms.n_steps_max) :]).all()
assert (
(next_obs == true_next_obs).all(-1) | terminated[:, (1 + ms.n_steps) :]
).all()

# test gamma computation
torch.testing.assert_close(
Expand All @@ -90,10 +94,144 @@ def test_multistep(n, key, device, T=11):
!= ms_tensordict.get(("next", "original_reward"))
).any()
else:
assert (
ms_tensordict.get(("next", "reward"))
== ms_tensordict.get(("next", "original_reward"))
).all()
torch.testing.assert_close(
ms_tensordict.get(("next", "reward")),
ms_tensordict.get(("next", "original_reward")),
)


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize(
"batch_size",
[
[
4,
],
[],
[
1,
],
[2, 3],
],
)
@pytest.mark.parametrize(
"T",
[
10,
1,
2,
],
)
@pytest.mark.parametrize(
"obs_dim",
[
[
1,
],
[],
],
)
@pytest.mark.parametrize("unsq_reward", [True, False])
@pytest.mark.parametrize("last_done", [True, False])
@pytest.mark.parametrize("n_steps", [3, 1, 0])
def test_mutistep_cattrajs(
batch_size, T, obs_dim, unsq_reward, last_done, device, n_steps
):
# tests multi-step in the presence of consecutive trajectories.
obs = torch.randn(*batch_size, T + 1, *obs_dim)
reward = torch.rand(*batch_size, T)
action = torch.rand(*batch_size, T)
done = torch.zeros(*batch_size, T + 1, dtype=torch.bool)
done[..., T // 2] = 1
if last_done:
done[..., -1] = 1
if unsq_reward:
reward = reward.unsqueeze(-1)
done = done.unsqueeze(-1)

td = TensorDict(
{
"obs": obs[..., :-1] if not obs_dim else obs[..., :-1, :],
"action": action,
"done": done[..., :-1] if not unsq_reward else done[..., :-1, :],
"next": {
"obs": obs[..., 1:] if not obs_dim else obs[..., 1:, :],
"done": done[..., 1:] if not unsq_reward else done[..., 1:, :],
"reward": reward,
},
},
batch_size=[*batch_size, T],
device=device,
)
ms = MultiStep(0.98, n_steps)
tdm = ms(td)
if n_steps == 0:
# n_steps = 0 has no effect
for k in td["next"].keys():
assert (tdm["next", k] == td["next", k]).all()
else:
next_obs = []
obs = td["next", "obs"]
done = td["next", "done"]
if obs_dim:
obs = obs.squeeze(-1)
if unsq_reward:
done = done.squeeze(-1)
for t in range(T):
idx = t + n_steps
while (done[..., t:idx].any() and idx > t) or idx > done.shape[-1] - 1:
idx = idx - 1
next_obs.append(obs[..., idx])
true_next_obs = tdm.get(("next", "obs"))
if obs_dim:
true_next_obs = true_next_obs.squeeze(-1)
next_obs = torch.stack(next_obs, -1)
assert (next_obs == true_next_obs).all()


@pytest.mark.parametrize("unsq_reward", [True, False])
def test_unusual_done(unsq_reward):
batch_size = [10, 3]
T = 10
obs_dim = [
1,
]
last_done = True
device = torch.device("cpu")
n_steps = 3

obs = torch.randn(*batch_size, T + 1, 5, *obs_dim)
reward = torch.rand(*batch_size, T, 5)
action = torch.rand(*batch_size, T, 5)
done = torch.zeros(*batch_size, T + 1, 5, dtype=torch.bool)
done[..., T // 2, :] = 1
if last_done:
done[..., -1, :] = 1
if unsq_reward:
reward = reward.unsqueeze(-1)
done = done.unsqueeze(-1)

td = TensorDict(
{
"obs": obs[..., :-1, :] if not obs_dim else obs[..., :-1, :, :],
"action": action,
"done": done[..., :-1, :] if not unsq_reward else done[..., :-1, :, :],
"next": {
"obs": obs[..., 1:, :] if not obs_dim else obs[..., 1:, :, :],
"done": done[..., 1:, :] if not unsq_reward else done[..., 1:, :, :],
"reward": reward,
},
},
batch_size=[*batch_size, T],
device=device,
)
ms = MultiStep(0.98, n_steps)
if unsq_reward:
with pytest.raises(RuntimeError, match="tensordict shape must be compatible"):
_ = ms(td)
else:
# we just check that it runs
_ = ms(td)


class TestSplits:
Expand Down
Loading

0 comments on commit 2de55cb

Please sign in to comment.