Skip to content

Commit

Permalink
[BugFix] Fix size-match unsqueeze deprecation (pytorch#750)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Dec 23, 2022
1 parent c2b9f9c commit 427c89d
Show file tree
Hide file tree
Showing 38 changed files with 410 additions and 195 deletions.
2 changes: 1 addition & 1 deletion .circleci/unittest/linux/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ python -c "import functorch"
pip install git+https://github.com/pytorch/torchsnapshot

# install tensordict
pip install git+https://github.com/pytorch-labs/tensordict
pip install git+https://github.com/pytorch-labs/tensordict.git

printf "* Installing torchrl\n"
python setup.py develop
2 changes: 1 addition & 1 deletion .circleci/unittest/linux_examples/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ python -c "import functorch"
pip install git+https://github.com/pytorch/torchsnapshot

# install tensordict
pip install git+https://github.com/pytorch-labs/tensordict
pip install git+https://github.com/pytorch-labs/tensordict.git

printf "* Installing torchrl\n"
python setup.py develop
2 changes: 1 addition & 1 deletion .circleci/unittest/linux_libs/scripts_brax/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ else
fi

# install tensordict
pip install git+https://github.com/pytorch-labs/tensordict
pip install git+https://github.com/pytorch-labs/tensordict.git

# smoke test
python -c "import functorch;import tensordict"
Expand Down
2 changes: 1 addition & 1 deletion .circleci/unittest/linux_libs/scripts_gym/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ else
fi

# install tensordict
pip install git+https://github.com/pytorch-labs/tensordict
pip install git+https://github.com/pytorch-labs/tensordict.git

# smoke test
python -c "import tensordict"
Expand Down
2 changes: 1 addition & 1 deletion .circleci/unittest/linux_libs/scripts_habitat/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ else
fi

# install tensordict
pip install git+https://github.com/pytorch-labs/tensordict
pip install git+https://github.com/pytorch-labs/tensordict.git

# smoke test
python -c "import functorch;import tensordict"
Expand Down
2 changes: 1 addition & 1 deletion .circleci/unittest/linux_libs/scripts_jumanji/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ else
fi

# install tensordict
pip install git+https://github.com/pytorch-labs/tensordict
pip install git+https://github.com/pytorch-labs/tensordict.git

# smoke test
python -c "import functorch;import tensordict"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ else
fi

# install tensordict
pip install git+https://github.com/pytorch-labs/tensordict
pip install git+https://github.com/pytorch-labs/tensordict.git

# smoke test
python -c "import tensordict"
Expand Down
2 changes: 1 addition & 1 deletion .circleci/unittest/linux_optdeps/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ else
fi

# install tensordict
pip install git+https://github.com/pytorch-labs/tensordict
pip install git+https://github.com/pytorch-labs/tensordict.git

# smoke test
python -c "import functorch"
Expand Down
2 changes: 1 addition & 1 deletion .circleci/unittest/linux_stable/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ else
fi

# install tensordict
pip install git+https://github.com/pytorch-labs/tensordict
pip install git+https://github.com/pytorch-labs/tensordict.git

# smoke test
python -c "import torch;import functorch"
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ It is also possible to reset some but not all of the environments:
fields={
done: Tensor(torch.Size([4, 1]), dtype=torch.bool),
pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8),
reset_workers: Tensor(torch.Size([4, 1]), dtype=torch.bool)},
reset_workers: Tensor(torch.Size([4]), dtype=torch.bool)},
batch_size=torch.Size([4]),
device=None,
is_shared=True)
Expand Down
60 changes: 58 additions & 2 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ def make_env():
)
for _data in collector:
continue
steps = _data["step_count"][..., 1:, :]
done = _data["done"][..., :-1, :]
steps = _data["step_count"][..., 1:]
done = _data["done"][..., :-1, :].squeeze(-1)
# we don't want just one done
assert done.sum() > 3
# check that after a done, the next step count is always 1
Expand Down Expand Up @@ -370,6 +370,62 @@ def make_env(seed):
del collector


@pytest.mark.parametrize("frames_per_batch", [200, 10])
@pytest.mark.parametrize("num_env", [1, 3])
@pytest.mark.parametrize("env_name", ["vec"])
def test_split_trajs(num_env, env_name, frames_per_batch, seed=5):
if num_env == 1:

def env_fn(seed):
env = MockSerialEnv(device="cpu")
env.set_seed(seed)
return env

else:

def env_fn(seed):
def make_env(seed):
env = MockSerialEnv(device="cpu")
env.set_seed(seed)
return env

env = SerialEnv(
num_workers=num_env,
create_env_fn=make_env,
create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)],
allow_step_when_done=True,
)
env.set_seed(seed)
return env

policy = make_policy(env_name)

collector = SyncDataCollector(
create_env_fn=env_fn,
create_env_kwargs={"seed": seed},
policy=policy,
frames_per_batch=frames_per_batch * num_env,
max_frames_per_traj=2000,
total_frames=20000,
device="cpu",
pin_memory=False,
reset_when_done=True,
split_trajs=True,
)
for _, d in enumerate(collector): # noqa
break

assert d.ndimension() == 2
assert d["mask"].shape == d.shape
assert d["step_count"].shape == d.shape
assert d["traj_ids"].shape == d.shape
for traj in d.unbind(0):
assert traj["traj_ids"].unique().numel() == 1
assert (traj["step_count"][1:] - traj["step_count"][:-1] == 1).all()

del collector


# TODO: design a test that ensures that collectors are interrupted even if __del__ is not called
# @pytest.mark.parametrize("should_shutdown", [True, False])
# def test_shutdown_collector(should_shutdown, num_env=3, env_name="vec", seed=40):
Expand Down
100 changes: 56 additions & 44 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

# from torchrl.data.postprocs.utils import expand_as_right
from tensordict.tensordict import assert_allclose_td, TensorDict
from tensordict.utils import expand_as_right
from torch import autograd, nn
from torchrl.data import (
CompositeSpec,
Expand Down Expand Up @@ -253,20 +252,22 @@ def _create_seq_mock_data_dqn(
if action_spec_type == "categorical":
action_value = torch.max(action_value, -1, keepdim=True)[0]
action = torch.argmax(action, -1, keepdim=True)
# action_value = action_value.unsqueeze(-1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs * mask.to(obs.dtype),
"next": {"observation": next_obs * mask.to(obs.dtype)},
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
},
"done": done,
"mask": mask,
"reward": reward * mask.to(obs.dtype),
"action": action * mask.to(obs.dtype),
"action_value": action_value
* expand_as_right(mask.to(obs.dtype).squeeze(-1), action_value),
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action_value": action_value.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
)
return td
Expand Down Expand Up @@ -488,16 +489,18 @@ def _create_seq_mock_data_ddpg(
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs * mask.to(obs.dtype),
"next": {"observation": next_obs * mask.to(obs.dtype)},
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
},
"done": done,
"mask": mask,
"reward": reward * mask.to(obs.dtype),
"action": action * mask.to(obs.dtype),
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
device=device,
)
Expand Down Expand Up @@ -726,16 +729,18 @@ def _create_seq_mock_data_sac(
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = torch.ones(batch, T, dtype=torch.bool, device=device)
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs * mask.to(obs.dtype),
"next": {"observation": next_obs * mask.to(obs.dtype)},
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
},
"done": done,
"mask": mask,
"reward": reward * mask.to(obs.dtype),
"action": action * mask.to(obs.dtype),
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
device=device,
)
Expand Down Expand Up @@ -1129,16 +1134,18 @@ def _create_seq_mock_data_redq(
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs * mask.to(obs.dtype),
"next": {"observation": next_obs * mask.to(obs.dtype)},
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
},
"done": done,
"mask": mask,
"reward": reward * mask.to(obs.dtype),
"action": action * mask.to(obs.dtype),
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
device=device,
)
Expand Down Expand Up @@ -1543,7 +1550,7 @@ def _create_mock_data_ppo(
"done": done,
"reward": reward,
"action": action,
"sample_log_prob": torch.randn_like(action[..., :1]) / 10,
"sample_log_prob": torch.randn_like(action[..., 1]) / 10,
},
device=device,
)
Expand All @@ -1564,23 +1571,25 @@ def _create_seq_mock_data_ppo(
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = torch.ones(batch, T, dtype=torch.bool, device=device)
params_mean = torch.randn_like(action) / 10
params_scale = torch.rand_like(action) / 10
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs * mask.to(obs.dtype),
"next": {"observation": next_obs * mask.to(obs.dtype)},
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
},
"done": done,
"mask": mask,
"reward": reward * mask.to(obs.dtype),
"action": action * mask.to(obs.dtype),
"sample_log_prob": torch.randn_like(action[..., :1])
/ 10
* mask.to(obs.dtype),
"loc": params_mean * mask.to(obs.dtype),
"scale": params_scale * mask.to(obs.dtype),
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
"sample_log_prob": (torch.randn_like(action[..., 1]) / 10).masked_fill_(
~mask, 0.0
),
"loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0),
"scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
device=device,
)
Expand Down Expand Up @@ -1835,23 +1844,26 @@ def _create_seq_mock_data_a2c(
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
params_mean = torch.randn_like(action) / 10
params_scale = torch.rand_like(action) / 10
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs * mask.to(obs.dtype),
"next": {"observation": next_obs * mask.to(obs.dtype)},
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0)
},
"done": done,
"mask": mask,
"reward": reward * mask.to(obs.dtype),
"action": action * mask.to(obs.dtype),
"sample_log_prob": torch.randn_like(action[..., :1])
/ 10
* mask.to(obs.dtype),
"loc": params_mean * mask.to(obs.dtype),
"scale": params_scale * mask.to(obs.dtype),
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
"sample_log_prob": torch.randn_like(action[..., 1]).masked_fill_(
~mask, 0.0
)
/ 10,
"loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0),
"scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
device=device,
)
Expand Down
8 changes: 6 additions & 2 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,9 @@ def test_multitask(self):
env2 = DMControlEnv("humanoid", "walk")
env2_obs_keys = list(env2.observation_spec.keys())

assert len(env1_obs_keys)
assert len(env2_obs_keys)

def env1_maker():
return TransformedEnv(
DMControlEnv("humanoid", "stand"),
Expand Down Expand Up @@ -449,6 +452,7 @@ def env2_maker():
)

env = ParallelEnv(2, [env1_maker, env2_maker])
# env = SerialEnv(2, [env1_maker, env2_maker])
assert not env._single_task

td = env.rollout(10, return_contiguous=False)
Expand Down Expand Up @@ -497,7 +501,7 @@ def test_parallel_env(
td1 = env_parallel.step(td)

td_reset = TensorDict(
source={"reset_workers": torch.zeros(N, 1, dtype=torch.bool).bernoulli_()},
source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()},
batch_size=[
N,
],
Expand Down Expand Up @@ -581,7 +585,7 @@ def test_parallel_env_with_policy(
td1 = env_parallel.step(td)

td_reset = TensorDict(
source={"reset_workers": torch.zeros(N, 1, dtype=torch.bool).bernoulli_()},
source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()},
batch_size=[
N,
],
Expand Down
Loading

0 comments on commit 427c89d

Please sign in to comment.