Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Allow usage of a different device on main and sub-envs in ParallelEnv and SerialEnv #1626

Merged
merged 71 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
1f539dd
init
vmoens Oct 6, 2023
d4c16e1
amend
vmoens Oct 6, 2023
d2321aa
amend
vmoens Oct 6, 2023
565115a
amend
vmoens Oct 6, 2023
3c46136
amend
vmoens Oct 6, 2023
78cfa41
amend
vmoens Oct 6, 2023
a6bd8eb
amend
vmoens Oct 6, 2023
04d4ae7
amend
vmoens Oct 6, 2023
3e31963
Merge remote-tracking branch 'origin/main' into step_maybe_reset
vmoens Oct 10, 2023
f1b0ea4
tensordict_
vmoens Oct 10, 2023
16b3538
amend rollout logic
vmoens Oct 10, 2023
7ad0864
amend
vmoens Oct 10, 2023
bcac398
amend
vmoens Oct 10, 2023
428f8ee
inference
vmoens Oct 10, 2023
6fbd0bd
cpu -> cuda
vmoens Oct 10, 2023
02db623
checks
vmoens Oct 10, 2023
08a8f47
using pipe instead of event
vmoens Oct 10, 2023
45e64f7
amend
vmoens Oct 10, 2023
7dd4821
amend
vmoens Oct 10, 2023
e1a2206
rm cuda event
vmoens Oct 10, 2023
dc2caab
amend
vmoens Oct 10, 2023
01ffbf9
amend
vmoens Oct 10, 2023
ac76ec3
amend
vmoens Oct 10, 2023
ceab010
amend
vmoens Oct 10, 2023
f0327c9
amend
vmoens Oct 10, 2023
518b3d1
amend
vmoens Oct 10, 2023
47dd93b
amend
vmoens Oct 10, 2023
354fb6f
amend
vmoens Oct 10, 2023
53d5f9a
amend
vmoens Oct 10, 2023
78c00e8
amend
vmoens Oct 10, 2023
f63480e
amend
vmoens Oct 10, 2023
9a3631f
amend
vmoens Oct 10, 2023
2ceb438
amend
vmoens Oct 10, 2023
6ecebda
amend
vmoens Oct 10, 2023
5c613c3
amend
vmoens Oct 10, 2023
9f97e58
amend
vmoens Oct 10, 2023
9cbcbb0
amend
vmoens Oct 10, 2023
72c4163
amend
vmoens Oct 10, 2023
6f4c374
amend
vmoens Oct 10, 2023
3657b41
empty
vmoens Oct 11, 2023
9206b93
amend
vmoens Oct 11, 2023
897123f
Merge remote-tracking branch 'origin/main' into parallel_cuda_refactor
vmoens Nov 10, 2023
2bac78c
amend
vmoens Nov 10, 2023
51a8856
amend
vmoens Nov 10, 2023
7a300f5
amend
vmoens Nov 10, 2023
b25921d
amend
vmoens Nov 10, 2023
d42348d
amend
vmoens Nov 10, 2023
f3421aa
amend
vmoens Nov 10, 2023
939ece4
amend
vmoens Nov 10, 2023
082ba9a
amend
vmoens Nov 10, 2023
4fd670f
amend
vmoens Nov 10, 2023
2a773f3
amend
vmoens Nov 10, 2023
e7cb5dd
Merge remote-tracking branch 'origin/main' into parallel_cuda_refactor
vmoens Nov 27, 2023
ea88bb4
Merge remote-tracking branch 'origin/main' into parallel_cuda_refactor
vmoens Nov 27, 2023
7c04f62
amend
vmoens Nov 27, 2023
b422775
amend
vmoens Nov 27, 2023
ed0287d
amend
vmoens Nov 27, 2023
05b0c08
amend
vmoens Nov 27, 2023
33de0fc
amend
vmoens Nov 28, 2023
492a884
amend
vmoens Nov 28, 2023
ff4799d
amend
vmoens Nov 28, 2023
f25b957
amend
vmoens Nov 28, 2023
8899dbd
amend
vmoens Nov 28, 2023
65c9deb
amend
vmoens Nov 28, 2023
77c2d6b
amend
vmoens Nov 28, 2023
7928744
amend
vmoens Nov 28, 2023
e7fda36
amend
vmoens Nov 29, 2023
fb9a03a
amend
vmoens Nov 29, 2023
d73ca22
amend
vmoens Nov 29, 2023
c4d4c6b
amend
vmoens Nov 29, 2023
b2840b0
doc
vmoens Nov 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
amend
  • Loading branch information
vmoens committed Nov 28, 2023
commit 8899dbd31701463d67521a677fa73042018acfb1
10 changes: 7 additions & 3 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,16 +381,20 @@ def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad):
assert env.device.type == torch.device(pdevice).type
assert r.device.type == torch.device(pdevice).type
assert all(
item.device.type == torch.device(pdevice).type for item in r.values(True, True)
item.device.type == torch.device(pdevice).type
for item in r.values(True, True)
)
else:
assert env.device.type == torch.device(edevice).type
assert r.device.type == torch.device(edevice).type
assert all(
item.device.type == torch.device(edevice).type for item in r.values(True, True)
item.device.type == torch.device(edevice).type
for item in r.values(True, True)
)
if parallel:
assert env.shared_tensordict_parent.device.type == torch.device(edevice).type
assert (
env.shared_tensordict_parent.device.type == torch.device(edevice).type
)

@pytest.mark.parametrize("num_parallel_env", [1, 10])
@pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)])
Expand Down
210 changes: 17 additions & 193 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,20 +744,11 @@ def _start_workers(self) -> None:

self.parent_channels = []
self._workers = []
if self.shared_tensordict_parent.device.type == "cuda":
func = _run_worker_pipe_cuda
self._cuda_stream = torch.cuda.Stream(self.shared_tensordict_parent.device)
self._cuda_events = [
torch.cuda.Event(interprocess=True) for _ in range(_num_workers)
]
self._events = None
kwargs = [{"cuda_event": self._cuda_events[i]} for i in range(_num_workers)]
else:
func = _run_worker_pipe_shared_mem
self._cuda_stream = None
self._cuda_events = None
self._events = [ctx.Event() for _ in range(_num_workers)]
kwargs = [{"mp_event": self._events[i]} for i in range(_num_workers)]
func = _run_worker_pipe_shared_mem
self._cuda_stream = None
self._cuda_events = None
self._events = [ctx.Event() for _ in range(_num_workers)]
kwargs = [{"mp_event": self._events[i]} for i in range(_num_workers)]
with clear_mpi_env_vars():
for idx in range(_num_workers):
if self._verbose:
Expand Down Expand Up @@ -860,17 +851,10 @@ def step_and_maybe_reset(
for i in range(self.num_workers):
self.parent_channels[i].send(("step_and_maybe_reset", None))

if self._events is not None:
# CPU case
for i in range(self.num_workers):
event = self._events[i]
event.wait()
event.clear()
else:
# CUDA case
for i in range(self.num_workers):
event = self._cuda_events[i]
self._cuda_stream.wait_event(event)
for i in range(self.num_workers):
event = self._events[i]
event.wait()
event.clear()

# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
Expand Down Expand Up @@ -915,17 +899,10 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
for i in range(self.num_workers):
self.parent_channels[i].send(("step", None))

if self._events is not None:
# CPU case
for i in range(self.num_workers):
event = self._events[i]
event.wait()
event.clear()
else:
# CUDA case
for i in range(self.num_workers):
event = self._cuda_events[i]
self._cuda_stream.wait_event(event)
for i in range(self.num_workers):
event = self._events[i]
event.wait()
event.clear()

# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
Expand Down Expand Up @@ -993,17 +970,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
channel.send(out)
workers.append(i)

if self._events is not None:
# CPU case
for i in workers:
event = self._events[i]
event.wait()
event.clear()
else:
# CUDA case
for i in workers:
event = self._cuda_events[i]
self._cuda_stream.wait_event(event)
for i in workers:
event = self._events[i]
event.wait()
event.clear()

selected_output_keys = self._selected_reset_keys_filt
device = self.device
Expand Down Expand Up @@ -1153,7 +1123,6 @@ def _run_worker_pipe_shared_mem(
"env_fun_kwargs must be empty if an environment is passed to a process."
)
env = env_fun
env_device = env.device

i = -1
initialized = False
Expand Down Expand Up @@ -1275,151 +1244,6 @@ def _run_worker_pipe_shared_mem(
child_pipe.send(("_".join([cmd, "done"]), None))


def _run_worker_pipe_cuda(
parent_pipe: connection.Connection,
child_pipe: connection.Connection,
env_fun: Union[EnvBase, Callable],
env_fun_kwargs: Dict[str, Any],
cuda_event: torch.cuda.Event = None,
shared_tensordict: TensorDictBase = None,
_selected_input_keys=None,
_selected_reset_keys=None,
_selected_step_keys=None,
has_lazy_inputs: bool = False,
verbose: bool = False,
) -> None:
parent_pipe.close()
pid = os.getpid()
if not isinstance(env_fun, EnvBase):
env = env_fun(**env_fun_kwargs)
else:
if env_fun_kwargs:
raise RuntimeError(
"env_fun_kwargs must be empty if an environment is passed to a process."
)
env = env_fun
del env_fun
env_device = env.device

stream = torch.cuda.Stream(env_device)
with torch.cuda.StreamContext(stream):
# we check if the devices mismatch. This tells us that the data need
# to be cast onto the right device before any op
env_device_cpu = env_device.type == "cpu"
i = -1
initialized = False

child_pipe.send("started")

while True:
try:
cmd, data = child_pipe.recv()
except EOFError as err:
raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err
if cmd == "seed":
if not initialized:
raise RuntimeError("call 'init' before closing")
# torch.manual_seed(data)
# np.random.seed(data)
new_seed = env.set_seed(data[0], static_seed=data[1])
child_pipe.send(("seeded", new_seed))

elif cmd == "init":
if verbose:
print(f"initializing {pid}")
if initialized:
raise RuntimeError("worker already initialized")
i = 0
next_shared_tensordict = shared_tensordict.get("next")
shared_tensordict = shared_tensordict.clone(False)
del shared_tensordict["next"]

if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()):
raise RuntimeError(
"tensordict must be placed in shared memory (share_memory_() or memmap_())"
)
initialized = True

elif cmd == "reset":
if verbose:
print(f"resetting worker {pid}")
if not initialized:
raise RuntimeError("call 'init' before resetting")
cur_td = env._reset(tensordict=data)
shared_tensordict.update_(cur_td)
stream.record_event(cuda_event)
stream.synchronize()

elif cmd == "step":
if not initialized:
raise RuntimeError("called 'init' before step")
i += 1
env_input = shared_tensordict
next_td = env._step(env_input)
next_shared_tensordict.update_(next_td)
stream.record_event(cuda_event)
stream.synchronize()

elif cmd == "step_and_maybe_reset":
if not initialized:
raise RuntimeError("called 'init' before step")
i += 1
env_input = shared_tensordict
td, root_next_td = env.step_and_maybe_reset(env_input)
next_shared_tensordict.update_(td.get("next"))
shared_tensordict.update_(root_next_td)
stream.record_event(cuda_event)
stream.synchronize()

elif cmd == "close":
del shared_tensordict, data
if not initialized:
raise RuntimeError("call 'init' before closing")
env.close()
del env
stream.record_event(cuda_event)
stream.synchronize()
child_pipe.close()
if verbose:
print(f"{pid} closed")
break

elif cmd == "load_state_dict":
env.load_state_dict(data)
stream.record_event(cuda_event)
stream.synchronize()

elif cmd == "state_dict":
state_dict = _recursively_strip_locks_from_state_dict(env.state_dict())
msg = "state_dict"
child_pipe.send((msg, state_dict))

else:
err_msg = f"{cmd} from env"
try:
attr = getattr(env, cmd)
if callable(attr):
args, kwargs = data
args_replace = []
for _arg in args:
if isinstance(_arg, str) and _arg == "_self":
continue
else:
args_replace.append(_arg)
result = attr(*args_replace, **kwargs)
else:
result = attr
except Exception as err:
raise AttributeError(
f"querying {err_msg} resulted in an error."
) from err
if cmd not in ("to"):
child_pipe.send(("_".join([cmd, "done"]), result))
else:
# don't send env through pipe
child_pipe.send(("_".join([cmd, "done"]), None))


def _update_cuda(t_dest, t_source):
if t_source is None:
return
Expand Down
Loading