Commit
[BugFix] Fix parallel env data passing on cuda (pytorch#1024)
vmoens authored Apr 18, 2023
1 parent 55976e4 commit cbb0c2f
Showing 5 changed files with 194 additions and 19 deletions.
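For context, here is a minimal sketch (not part of this commit) of the invariant the fix restores: a ParallelEnv whose workers hold CUDA envs should produce the same rollout as a SerialEnv built from the same constructors. The `make_env` and `policy` arguments below are placeholders, standing in for any device-aware env constructor (such as the CountingEnv mock used in the new tests) and any TensorDict-compatible policy.

# Illustrative sketch, not part of this commit: the invariant the fix restores.
import torch
from torchrl.envs import ParallelEnv, SerialEnv


def parallel_matches_serial(make_env, policy, n_workers=2):
    # Same env constructors, run once in worker processes and once in-process.
    env_p = ParallelEnv(n_workers, make_env)
    env_s = SerialEnv(n_workers, make_env)
    with torch.no_grad():
        r_p = env_p.rollout(100, policy=policy, break_when_any_done=False)
        r_s = env_s.rollout(100, policy=policy, break_when_any_done=False)
    # Before this fix, r_p could contain stale data when the envs lived on CUDA,
    # because the workers' results never reached the parent's buffers.
    return (r_p == r_s).all()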
18 changes: 13 additions & 5 deletions test/mocking_classes.py
@@ -907,14 +907,17 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
                     1,
                 ),
                 dtype=torch.int32,
+                device=self.device,
             ),
             shape=self.batch_size,
+            device=self.device,
         )
         self.reward_spec = UnboundedContinuousTensorSpec(
             (
                 *self.batch_size,
                 1,
-            )
+            ),
+            device=self.device,
         )
         self.done_spec = DiscreteTensorSpec(
             2,
@@ -923,14 +926,19 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
                 *self.batch_size,
                 1,
             ),
+            device=self.device,
         )
         self.input_spec = CompositeSpec(
-            action=BinaryDiscreteTensorSpec(n=1, shape=[*self.batch_size, 1]),
+            action=BinaryDiscreteTensorSpec(
+                n=1, shape=[*self.batch_size, 1], device=self.device
+            ),
             shape=self.batch_size,
+            device=self.device,
         )

-        self.count = torch.zeros(
-            (*self.batch_size, 1), device=self.device, dtype=torch.int
+        self.register_buffer(
+            "count",
+            torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
         )

     def _set_seed(self, seed: Optional[int]):
@@ -956,7 +964,7 @@ def _step(
         tensordict: TensorDictBase,
     ) -> TensorDictBase:
         action = tensordict.get("action")
-        self.count += action.to(torch.int)
+        self.count += action.to(torch.int).to(self.device)
         tensordict = TensorDict(
             source={
                 "observation": self.count.clone(),
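A side note on the register_buffer change above: TorchRL envs are nn.Modules, so a registered buffer follows the env when the worker process calls env.to(device), whereas a plain tensor attribute stays on its original device. A minimal sketch of that behaviour (the Counter module below is made up for illustration):

# Illustrative sketch, not part of this commit: why the counter becomes a buffer.
import torch
from torch import nn


class Counter(nn.Module):  # hypothetical stand-in for the mock env's counter state
    def __init__(self):
        super().__init__()
        self.plain = torch.zeros(1, dtype=torch.int)  # ordinary attribute
        self.register_buffer("count", torch.zeros(1, dtype=torch.int))


if torch.cuda.is_available():
    counter = Counter().to("cuda:0")
    print(counter.plain.device)  # cpu: left behind by .to()
    print(counter.count.device)  # cuda:0: moved together with the module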
151 changes: 149 additions & 2 deletions test/test_env.py
@@ -11,6 +11,7 @@
 import pytest
 import torch
 import yaml
+
 from _utils_internal import (
     _make_envs,
     CARTPOLE_VERSIONED,
@@ -32,17 +33,19 @@
 from packaging import version
 from tensordict.tensordict import assert_allclose_td, TensorDict
 from torch import nn
+
+from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
 from torchrl.data.tensor_specs import (
     CompositeSpec,
     OneHotDiscreteTensorSpec,
     UnboundedContinuousTensorSpec,
 )
-from torchrl.envs import CatTensors, DoubleToFloat, EnvCreator
+from torchrl.envs import CatTensors, DoubleToFloat, EnvCreator, ParallelEnv, SerialEnv
 from torchrl.envs.gym_like import default_info_dict_reader
 from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
 from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper
 from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
 from torchrl.envs.utils import check_env_specs, make_composite_from_td, step_mdp
-from torchrl.envs.vec_env import ParallelEnv, SerialEnv
 from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator
 from torchrl.modules.tensordict_module import WorldModelWrapper
+
@@ -1133,6 +1136,150 @@ def test_make_spec_from_td():
         assert val.dtype is spec[key].dtype


+@pytest.mark.skipif(not torch.cuda.device_count(), reason="No cuda device")
+class TestConcurrentEnvs:
+    """Concurrent parallel envs on multiple procs can interfere."""
+
+    class Policy(nn.Module):
+        in_keys = []
+        out_keys = ["action"]
+
+        def __init__(self, spec):
+            super().__init__()
+            self.spec = spec
+
+        def forward(self, tensordict):
+            tensordict.set("action", self.spec["action"].zero() + 1)
+            return tensordict
+
+    @staticmethod
+    def main_penv(j, q=None):
+        device = "cpu" if not torch.cuda.device_count() else "cuda:0"
+        n_workers = 1
+        env_p = ParallelEnv(
+            n_workers,
+            [
+                lambda i=i: CountingEnv(i, device=device)
+                for i in range(j, j + n_workers)
+            ],
+        )
+        env_s = SerialEnv(
+            n_workers,
+            [
+                lambda i=i: CountingEnv(i, device=device)
+                for i in range(j, j + n_workers)
+            ],
+        )
+        spec = env_p.action_spec
+        policy = TestConcurrentEnvs.Policy(CompositeSpec(action=spec.to(device)))
+        N = 10
+        r_p = []
+        r_s = []
+        for _ in range(N):
+            with torch.no_grad():
+                r_p.append(env_s.rollout(100, break_when_any_done=False, policy=policy))
+                r_s.append(env_p.rollout(100, break_when_any_done=False, policy=policy))
+
+        if (torch.stack(r_p).contiguous() == torch.stack(r_s).contiguous()).all():
+            if q is not None:
+                q.put("passed")
+            else:
+                pass
+        else:
+            if q is not None:
+                q.put("failed")
+            else:
+                raise RuntimeError()
+
+    @staticmethod
+    def main_collector(j, q=None):
+        device = "cpu" if not torch.cuda.device_count() else "cuda:0"
+        N = 10
+        n_workers = 1
+        make_envs = [
+            lambda i=i: CountingEnv(i, device=device) for i in range(j, j + n_workers)
+        ]
+        spec = make_envs[0]().action_spec
+        policy = TestConcurrentEnvs.Policy(CompositeSpec(action=spec))
+        collector = MultiSyncDataCollector(
+            make_envs,
+            policy,
+            frames_per_batch=n_workers * 100,
+            total_frames=N * n_workers * 100,
+        )
+        single_collectors = [
+            SyncDataCollector(
+                make_envs[i](),
+                policy,
+                frames_per_batch=n_workers * 100,
+                total_frames=N * n_workers * 100,
+            )
+            for i in range(n_workers)
+        ]
+        collector = iter(collector)
+        single_collectors = [iter(sc) for sc in single_collectors]
+
+        r_p = []
+        r_s = []
+        for _ in range(N):
+            with torch.no_grad():
+                r_p.append(next(collector))
+                r_s.append(torch.cat([next(sc) for sc in single_collectors]))
+
+        if (torch.stack(r_p).contiguous() == torch.stack(r_s).contiguous()).all():
+            if q is not None:
+                q.put("passed")
+            else:
+                pass
+        else:
+            if q is not None:
+                q.put("failed")
+            else:
+                raise RuntimeError()
+
+    @pytest.mark.parametrize("nproc", [1, 3])
+    def test_mp_concurrent(self, nproc):
+        if nproc == 1:
+            self.main_penv(3)
+        else:
+            from torch import multiprocessing as mp
+
+            q = mp.Queue(3)
+            ps = []
+            try:
+                for k in range(3, 10, 3):
+                    p = mp.Process(target=type(self).main_penv, args=(k, q))
+                    ps.append(p)
+                    p.start()
+                for _ in range(3):
+                    msg = q.get(timeout=100)
+                    assert msg == "passed"
+            finally:
+                for p in ps:
+                    p.join()
+
+    @pytest.mark.parametrize("nproc", [1, 3])
+    def test_mp_collector(self, nproc):
+        if nproc == 1:
+            self.main_collector(3)
+        else:
+            from torch import multiprocessing as mp
+
+            q = mp.Queue(3)
+            ps = []
+            try:
+                for k in range(3, 10, 3):
+                    p = mp.Process(target=type(self).main_collector, args=(k, q))
+                    ps.append(p)
+                    p.start()
+                for _ in range(3):
+                    msg = q.get(timeout=100)
+                    assert msg == "passed"
+            finally:
+                for p in ps:
+                    p.join(timeout=2)
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
4 changes: 2 additions & 2 deletions torchrl/envs/common.py
@@ -663,8 +663,8 @@ def rollout(
         """
         try:
             policy_device = next(policy.parameters()).device
-        except AttributeError:
-            policy_device = "cpu"
+        except (StopIteration, AttributeError):
+            policy_device = self.device

         env_device = self.device

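The widened except clause above covers parameter-free policies: next(policy.parameters()) raises StopIteration rather than AttributeError when the module has no parameters, and in that case the env's own device is a more sensible fallback than a hard-coded "cpu". A quick illustration (the ConstantPolicy class is made up):

# Illustrative sketch, not part of this commit.
from torch import nn


class ConstantPolicy(nn.Module):  # hypothetical policy with no learnable parameters
    def forward(self, tensordict):
        return tensordict


try:
    next(ConstantPolicy().parameters())
except StopIteration:
    print("no parameters: rollout() now falls back to the env's device")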
3 changes: 2 additions & 1 deletion torchrl/envs/libs/vmas.py
@@ -13,14 +13,15 @@
 from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
 from torchrl.envs.utils import _selective_unsqueeze

+IMPORT_ERR = None
 try:
     import vmas

     _has_vmas = True

 except ImportError as err:
     _has_vmas = False
-    IMPORT_ERR = str(err)
+    IMPORT_ERR = err

 __all__ = ["VmasWrapper", "VmasEnv"]

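Keeping the ImportError object (rather than its string) preserves the original traceback, so code that later needs vmas can chain it. Roughly, and only as an illustration of the intent:

# Illustrative sketch, not part of this commit.
if not _has_vmas:
    raise ImportError("vmas could not be imported") from IMPORT_ERR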
37 changes: 28 additions & 9 deletions torchrl/envs/vec_env.py
@@ -731,11 +731,13 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:

         # keys = set()
         for i in range(self.num_workers):
-            msg, _ = self.parent_channels[i].recv()
+            msg, data = self.parent_channels[i].recv()
             if msg != "step_result":
                 raise RuntimeError(
                     f"Expected 'step_result' but received {msg} from worker {i}"
                 )
+            if data is not None:
+                self.shared_tensordicts[i].update_(data)
         # We must pass a clone of the tensordict, as the values of this tensordict
         # will be modified in-place at further steps
         return self.shared_tensordict_parent.select(*self._selected_step_keys).clone()
@@ -824,9 +826,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
         for i, channel in enumerate(self.parent_channels):
             if not _reset[i].any():
                 continue
-            cmd_in, _ = channel.recv()
+            cmd_in, data = channel.recv()
             if cmd_in != "reset_obs":
                 raise RuntimeError(f"received cmd {cmd_in} instead of reset_obs")
+            if data is not None:
+                self.shared_tensordicts[i].update_(data)
+
         return self.shared_tensordict_parent.select(*self._selected_reset_keys).clone()

     def __reduce__(self):
@@ -913,16 +918,16 @@ def _run_worker_pipe_shared_mem(
             "env_fun_kwargs must be empty if an environment is passed to a process."
         )
         env = env_fun
+    is_cuda = torch.device(device).type == "cuda"
     env = env.to(device)
+
     i = -1
     initialized = False

     # make sure that process can be closed
     tensordict = None
     _td = None

-    reset_keys = None
-
     while True:
         try:
             cmd, data = child_pipe.recv()
@@ -962,8 +967,18 @@ def _run_worker_pipe_shared_mem(
                 _td.del_("_reset")
             if pin_memory:
                 _td.pin_memory()
-            tensordict.update_(_td.select(*tensordict.keys(True, True), strict=False))
-            child_pipe.send(("reset_obs", reset_keys))
+            if not is_cuda:
+                tensordict.update_(
+                    _td.select(*tensordict.keys(True, True), strict=False)
+                )
+                child_pipe.send(("reset_obs", None))
+            else:
+                child_pipe.send(
+                    (
+                        "reset_obs",
+                        _td.select(*tensordict.keys(True, True), strict=False),
+                    )
+                )

         elif cmd == "step":
             if not initialized:
@@ -978,10 +993,14 @@ def _run_worker_pipe_shared_mem(
             _td = env._step(_td)
             if pin_memory:
                 _td.pin_memory()
-            tensordict.update_(_td.select("next"))
             msg = "step_result"
-            data = (msg, None)
-            child_pipe.send(data)
+            if not is_cuda:
+                tensordict.update_(_td.select("next"))
+                data = (msg, None)
+                child_pipe.send(data)
+            else:
+                data = (msg, _td.select("next"))
+                child_pipe.send(data)

         elif cmd == "close":
             del tensordict, _td, data
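Taken together, the worker- and parent-side changes above implement a single pattern: on CPU the worker keeps writing step/reset results into the shared-memory tensordict and sends no payload, while on CUDA the selected sub-tensordict is sent through the pipe and the parent merges it itself instead of relying on in-place updates of the shared buffers. A condensed sketch of the parent-side half (function name and arguments are made up):

# Illustrative sketch, not part of this commit.
def receive_step_result(parent_channel, shared_tensordict, worker_idx):
    msg, data = parent_channel.recv()
    if msg != "step_result":
        raise RuntimeError(
            f"Expected 'step_result' but received {msg} from worker {worker_idx}"
        )
    if data is not None:
        # CUDA path: the worker shipped its "next" sub-tensordict over the pipe.
        shared_tensordict.update_(data)
    # CPU path (data is None): the worker already wrote into shared memory in place.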
