Skip to content

Commit

Permalink
[BugFix] Fix update in serial / parallel env (pytorch#1866)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 5, 2024
1 parent 80fc87f commit 19a920e
Show file tree
Hide file tree
Showing 13 changed files with 278 additions and 178 deletions.
19 changes: 13 additions & 6 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ def _step(
tensordict: TensorDictBase,
) -> TensorDictBase:
action = tensordict.get(self.action_key)
self.count += action.to(torch.int).to(self.device)
self.count += action.to(dtype=torch.int, device=self.device)
tensordict = TensorDict(
source={
"observation": self.count.clone(),
Expand Down Expand Up @@ -1426,10 +1426,12 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
3,
)
),
device=self.device,
)

self.unbatched_action_spec = CompositeSpec(
lazy=action_specs,
device=self.device,
)
self.unbatched_reward_spec = CompositeSpec(
{
Expand All @@ -1441,7 +1443,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
},
shape=(self.n_nested_dim,),
)
}
},
device=self.device,
)
self.unbatched_done_spec = CompositeSpec(
{
Expand All @@ -1455,7 +1458,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
},
shape=(self.n_nested_dim,),
)
}
},
device=self.device,
)

self.action_spec = self.unbatched_action_spec.expand(
Expand Down Expand Up @@ -1488,7 +1492,8 @@ def get_agent_obs_spec(self, i):
"lidar": lidar,
"vector": vector_3d,
"tensor_0": tensor_0,
}
},
device=self.device,
)
elif i == 1:
return CompositeSpec(
Expand All @@ -1497,15 +1502,17 @@ def get_agent_obs_spec(self, i):
"lidar": lidar,
"vector": vector_2d,
"tensor_1": tensor_1,
}
},
device=self.device,
)
elif i == 2:
return CompositeSpec(
{
"camera": camera,
"vector": vector_2d,
"tensor_2": tensor_2,
}
},
device=self.device,
)
else:
raise ValueError(f"Index {i} undefined for index 3")
Expand Down
18 changes: 13 additions & 5 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1675,8 +1675,12 @@ def test_maxframes_error():
@pytest.mark.parametrize("policy_device", [None, *get_available_devices()])
@pytest.mark.parametrize("env_device", [None, *get_available_devices()])
@pytest.mark.parametrize("storing_device", [None, *get_available_devices()])
@pytest.mark.parametrize("parallel", [False, True])
def test_reset_heterogeneous_envs(
policy_device: torch.device, env_device: torch.device, storing_device: torch.device
policy_device: torch.device,
env_device: torch.device,
storing_device: torch.device,
parallel,
):
if (
policy_device is not None
Expand All @@ -1686,9 +1690,13 @@ def test_reset_heterogeneous_envs(
env_device = torch.device("cpu") # explicit mapping
elif env_device is not None and env_device.type == "cuda" and policy_device is None:
policy_device = torch.device("cpu")
env1 = lambda: TransformedEnv(CountingEnv(), StepCounter(2))
env2 = lambda: TransformedEnv(CountingEnv(), StepCounter(3))
env = SerialEnv(2, [env1, env2], device=env_device)
env1 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(2))
env2 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(3))
if parallel:
cls = ParallelEnv
else:
cls = SerialEnv
env = cls(2, [env1, env2], device=env_device)
collector = SyncDataCollector(
env,
RandomPolicy(env.action_spec),
Expand All @@ -1705,7 +1713,7 @@ def test_reset_heterogeneous_envs(
assert (
data[0]["next", "truncated"].squeeze()
== torch.tensor([False, True], device=data_device).repeat(25)[:50]
).all(), data[0]["next", "truncated"][:10]
).all(), data[0]["next", "truncated"]
assert (
data[1]["next", "truncated"].squeeze()
== torch.tensor([False, False, True], device=data_device).repeat(17)[:50]
Expand Down
7 changes: 5 additions & 2 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2095,7 +2095,10 @@ def test_rollout_policy(self, batch_size, rollout_steps, count):

@pytest.mark.parametrize("batch_size", [(1, 2)])
@pytest.mark.parametrize("env_type", ["serial", "parallel"])
def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2):
@pytest.mark.parametrize("break_when_any_done", [False, True])
def test_vec_env(
self, batch_size, env_type, break_when_any_done, rollout_steps=4, n_workers=2
):
env_fun = lambda: HeterogeneousCountingEnv(batch_size=batch_size)
if env_type == "serial":
vec_env = SerialEnv(n_workers, env_fun)
Expand All @@ -2109,7 +2112,7 @@ def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2):
rollout_steps,
policy=policy,
return_contiguous=False,
break_when_any_done=False,
break_when_any_done=break_when_any_done,
)
td = dense_stack_tds(td)
for i in range(env_fun().n_nested_dim):
Expand Down
35 changes: 31 additions & 4 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CompositeSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvCreator, SerialEnv
from torchrl.envs.utils import set_exploration_type, step_mdp
from torchrl.modules import (
AdditiveGaussianWrapper,
Expand Down Expand Up @@ -1782,9 +1783,12 @@ def test_multi_consecutive(self, shape, python_based):
)

@pytest.mark.parametrize("python_based", [True, False])
def test_lstm_parallel_env(self, python_based):
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)
device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with parallel envs
lstm_module = LSTMModule(
Expand All @@ -1796,6 +1800,10 @@ def test_lstm_parallel_env(self, python_based):
device=device,
python_based=python_based,
)
if parallel:
cls = ParallelEnv
else:
cls = SerialEnv

def create_transformed_env():
primer = lstm_module.make_tensordict_primer()
Expand All @@ -1807,7 +1815,12 @@ def create_transformed_env():
env.append_transform(primer)
return env

env = ParallelEnv(
if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
]
env = cls(
create_env_fn=create_transformed_env,
num_workers=2,
)
Expand Down Expand Up @@ -2109,9 +2122,13 @@ def test_multi_consecutive(self, shape, python_based):
)

@pytest.mark.parametrize("python_based", [True, False])
def test_gru_parallel_env(self, python_based):
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)

device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with parallel envs
gru_module = GRUModule(
Expand All @@ -2134,7 +2151,17 @@ def create_transformed_env():
env.append_transform(primer)
return env

env = ParallelEnv(
if parallel:
cls = ParallelEnv
else:
cls = SerialEnv
if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
]

env = cls(
create_env_fn=create_transformed_env,
num_workers=2,
)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,7 +1077,7 @@ def rollout(self) -> TensorDictBase:

if self.storing_device is not None:
tensordicts.append(
self._shuttle.to(self.storing_device, non_blocking=False)
self._shuttle.to(self.storing_device, non_blocking=True)
)
else:
tensordicts.append(self._shuttle)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any:
# to be deprecated in v0.4
def map_device(tensor):
if tensor.device != self.device:
return tensor.to(self.device, non_blocking=False)
return tensor.to(self.device, non_blocking=True)
return tensor

if is_tensor_collection(result):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/rlhf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def get_dataloader(
)
out = TensorDictReplayBuffer(
storage=TensorStorage(data),
collate_fn=lambda x: x.as_tensor().to(device, non_blocking=False),
collate_fn=lambda x: x.as_tensor().to(device, non_blocking=True),
sampler=SamplerWithoutReplacement(drop_last=True),
batch_size=batch_size,
prefetch=prefetch,
Expand Down
Loading

0 comments on commit 19a920e

Please sign in to comment.