[BugFix] Adaptable non-blocking for mps and non cuda device in batched-envs (pytorch#1900)
vmoens authored Feb 12, 2024
1 parent 1647fa4 commit 6f6c896
Showing 2 changed files with 44 additions and 28 deletions.
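
The diff below renames the private _BatchedEnv base class to BatchedEnvBase and replaces hard-coded non_blocking=True device transfers with a flag that adapts to the environment's device. A minimal usage sketch of the new argument follows; it is not part of the diff, assumes a gym backend providing Pendulum-v1, and make_env is a placeholder factory:

from torchrl.envs import GymEnv, ParallelEnv

def make_env():
    # placeholder factory; assumes a gym backend with Pendulum-v1 installed
    return GymEnv("Pendulum-v1")

if __name__ == "__main__":
    # non_blocking can now be passed explicitly; when left unset, the default
    # described in the new docstring is True on cuda devices and False otherwise.
    env = ParallelEnv(2, make_env, non_blocking=False)
    rollout = env.rollout(10)
    print(rollout)
    env.close()
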
8 changes: 4 additions & 4 deletions torchrl/collectors/collectors.py
@@ -577,7 +577,7 @@ def __init__(
reset_when_done: bool = True,
interruptor=None,
):
- from torchrl.envs.batched_envs import _BatchedEnv
+ from torchrl.envs.batched_envs import BatchedEnvBase

self.closed = True

@@ -591,7 +591,7 @@ def __init__(
else:
env = create_env_fn
if create_env_kwargs:
- if not isinstance(env, _BatchedEnv):
+ if not isinstance(env, BatchedEnvBase):
raise RuntimeError(
"kwargs were passed to SyncDataCollector but they can't be set "
f"on environment of type {type(create_env_fn)}."
@@ -1201,11 +1201,11 @@ def state_dict(self) -> OrderedDict:
`"env_state_dict"`.
"""
- from torchrl.envs.batched_envs import _BatchedEnv
+ from torchrl.envs.batched_envs import BatchedEnvBase

if isinstance(self.env, TransformedEnv):
env_state_dict = self.env.transform.state_dict()
- elif isinstance(self.env, _BatchedEnv):
+ elif isinstance(self.env, BatchedEnvBase):
env_state_dict = self.env.state_dict()
else:
env_state_dict = OrderedDict()
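
The only change in this file is the rename: code that needs to special-case batched environments, as the collector does before applying create_env_kwargs, now checks against BatchedEnvBase. A small, hypothetical downstream check in the same spirit (assumes a gym backend; not part of the diff):

from torchrl.envs import GymEnv, SerialEnv
from torchrl.envs.batched_envs import BatchedEnvBase

# SerialEnv and ParallelEnv both derive from BatchedEnvBase after this change.
env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
assert isinstance(env, BatchedEnvBase)
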
64 changes: 40 additions & 24 deletions torchrl/envs/batched_envs.py
@@ -48,7 +48,7 @@


def _check_start(fun):
- def decorated_fun(self: _BatchedEnv, *args, **kwargs):
+ def decorated_fun(self: BatchedEnvBase, *args, **kwargs):
if self.is_closed:
self._create_td()
self._start_workers()
@@ -121,7 +121,7 @@ def __call__(cls, *args, **kwargs):
return super().__call__(*args, **kwargs)


- class _BatchedEnv(EnvBase):
+ class BatchedEnvBase(EnvBase):
"""Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely.
Those queries will return a list of length equal to the number of workers containing the
@@ -169,6 +169,9 @@ class _BatchedEnv(EnvBase):
serial_for_single (bool, optional): if ``True``, creating a parallel environment
with a single worker will return a :class:`~SerialEnv` instead.
This option has no effect with :class:`~SerialEnv`. Defaults to ``False``.
+ non_blocking (bool, optional): if ``True``, device moves will be done using the
+     ``non_blocking=True`` option. Defaults to ``True`` for batched environments
+     on cuda devices, and ``False`` otherwise.
Examples:
>>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator
@@ -179,8 +182,8 @@ class _BatchedEnv(EnvBase):
>>> env = ParallelEnv(2, [
... lambda: DMControlEnv("humanoid", "stand"),
... lambda: DMControlEnv("humanoid", "walk")]) # Creates two independent copies of Humanoid, one that walks one that stands
- >>> r = env.rollout(10) # executes 10 random steps in the environment
- >>> r[0] # data for Humanoid stand
+ >>> rollout = env.rollout(10) # executes 10 random steps in the environment
+ >>> rollout[0] # data for Humanoid stand
TensorDict(
fields={
action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
@@ -211,7 +214,7 @@ class _BatchedEnv(EnvBase):
batch_size=torch.Size([10]),
device=cpu,
is_shared=False)
- >>> r[1] # data for Humanoid walk
+ >>> rollout[1] # data for Humanoid walk
TensorDict(
fields={
action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
@@ -242,6 +245,7 @@ class _BatchedEnv(EnvBase):
batch_size=torch.Size([10]),
device=cpu,
is_shared=False)
+ >>> # serial_for_single to avoid creating parallel envs if not necessary
>>> env = ParallelEnv(1, make_env, serial_for_single=True)
>>> assert isinstance(env, SerialEnv) # serial_for_single allows you to avoid creating parallel envs when not necessary
"""
@@ -270,6 +274,7 @@ def __init__(
num_threads: int = None,
num_sub_threads: int = 1,
serial_for_single: bool = False,
+ non_blocking: bool = False,
):
super().__init__(device=device)
self.serial_for_single = serial_for_single
@@ -327,6 +332,15 @@ def __init__(
# self._prepare_dummy_env(create_env_fn, create_env_kwargs)
self._properties_set = False
self._get_metadata(create_env_fn, create_env_kwargs)
+ self._non_blocking = non_blocking
+
+ @property
+ def non_blocking(self):
+     nb = self._non_blocking
+     if nb is None:
+         nb = self.device is not None and self.device.type == "cuda"
+         self._non_blocking = nb
+     return nb

def _get_metadata(
self, create_env_fn: List[Callable], create_env_kwargs: List[Dict]
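
The non_blocking property added above resolves the flag lazily: an explicit value is kept, while a None value (which to() sets further down) falls back to a device-based default. A standalone restatement of that fallback with a hypothetical helper name, for illustration only:

import torch

def _default_non_blocking(device) -> bool:
    # Mirrors the fallback in the non_blocking property: asynchronous copies are
    # only used by default on cuda devices; mps and cpu stay blocking.
    device = torch.device(device) if device is not None else None
    return device is not None and device.type == "cuda"

assert _default_non_blocking("cuda:0") is True
assert _default_non_blocking("mps") is False
assert _default_non_blocking(None) is False
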
@@ -654,6 +668,7 @@ def start(self) -> None:
self._start_workers()

def to(self, device: DEVICE_TYPING):
+ self._non_blocking = None
device = torch.device(device)
if device == self.device:
return self
@@ -675,10 +690,10 @@ def to(self, device: DEVICE_TYPING):
return self


- class SerialEnv(_BatchedEnv):
+ class SerialEnv(BatchedEnvBase):
"""Creates a series of environments in the same process."""

- __doc__ += _BatchedEnv.__doc__
+ __doc__ += BatchedEnvBase.__doc__

_share_memory = False

@@ -769,7 +784,9 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
else:
env_device = _env.device
if env_device != self.device and env_device is not None:
- tensordict_ = tensordict_.to(env_device, non_blocking=True)
+ tensordict_ = tensordict_.to(
+     env_device, non_blocking=self.non_blocking
+ )
else:
tensordict_ = tensordict_.clone(False)
else:
@@ -798,7 +815,7 @@ def select_and_clone(name, tensor):
if device is None:
out = out.clear_device_()
else:
- out = out.to(device, non_blocking=True)
+ out = out.to(device, non_blocking=self.non_blocking)
return out

def _reset_proc_data(self, tensordict, tensordict_reset):
@@ -819,7 +836,9 @@ def _step(
# There may be unexpected keys, such as "_reset", that we should comfortably ignore here.
env_device = self._envs[i].device
if env_device != self.device and env_device is not None:
- data_in = tensordict_in[i].to(env_device, non_blocking=True)
+ data_in = tensordict_in[i].to(
+     env_device, non_blocking=self.non_blocking
+ )
else:
data_in = tensordict_in[i]
out_td = self._envs[i]._step(data_in)
@@ -839,7 +858,7 @@ def select_and_clone(name, tensor):
if device is None:
out = out.clear_device_()
elif out.device != device:
- out = out.to(device, non_blocking=True)
+ out = out.to(device, non_blocking=self.non_blocking)
return out

def __getattr__(self, attr: str) -> Any:
@@ -885,14 +904,14 @@ def to(self, device: DEVICE_TYPING):
return self


- class ParallelEnv(_BatchedEnv, metaclass=_PEnvMeta):
+ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
"""Creates one environment per process.
TensorDicts are passed via shared memory or memory map.
"""

- __doc__ += _BatchedEnv.__doc__
+ __doc__ += BatchedEnvBase.__doc__
__doc__ += """
.. warning::
@@ -1167,14 +1186,14 @@ def step_and_maybe_reset(
tensordict_ = tensordict_.clone()
elif device is not None:
next_td = next_td._fast_apply(
- lambda x: x.to(device, non_blocking=True)
+ lambda x: x.to(device, non_blocking=self.non_blocking)
if x.device != device
else x.clone(),
device=device,
filter_empty=True,
)
tensordict_ = tensordict_._fast_apply(
- lambda x: x.to(device, non_blocking=True)
+ lambda x: x.to(device, non_blocking=self.non_blocking)
if x.device != device
else x.clone(),
device=device,
@@ -1239,7 +1258,7 @@ def select_and_clone(name, tensor):
if device is None:
out.clear_device_()
else:
- out = out.to(device, non_blocking=True)
+ out = out.to(device, non_blocking=self.non_blocking)
return out

@_check_start
@@ -1325,7 +1344,7 @@ def select_and_clone(name, tensor):
if device is None:
out.clear_device_()
else:
- out = out.to(device, non_blocking=True)
+ out = out.to(device, non_blocking=self.non_blocking)
return out

@_check_start
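
The recurring change in the hunks above swaps the hard-coded non_blocking=True for the adaptive self.non_blocking flag when moving tensordicts between devices; the commit title flags the unconditional non-blocking copies as problematic on mps and other non-cuda devices. A minimal sketch of the resulting transfer pattern, with a hypothetical helper name and assuming the tensordict package:

import torch
from tensordict import TensorDict

def _move_to(td: TensorDict, device: torch.device, non_blocking: bool) -> TensorDict:
    # Same pattern as the batched-env hunks: skip the copy when the data is already
    # on the target device, otherwise let the caller decide whether the copy may be
    # asynchronous (non_blocking) or must block.
    if td.device == device:
        return td.clone()
    return td.to(device, non_blocking=non_blocking)
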
@@ -1644,12 +1663,9 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
child_pipe.send(("_".join([cmd, "done"]), None))


- def _update_cuda(t_dest, t_source):
-     if t_source is None:
-         return
-     t_dest.copy_(t_source.pin_memory(), non_blocking=True)
-     return


def _filter_empty(tensordict):
return tensordict.select(*tensordict.keys(True, True))


+ # Create an alias for possible imports
+ _BatchedEnv = BatchedEnvBase
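
Because the old private name is re-exported as a module-level alias, imports written against earlier versions keep working. A quick compatibility check (illustrative only):

from torchrl.envs.batched_envs import BatchedEnvBase, _BatchedEnv

# The alias keeps pre-existing imports and isinstance checks working unchanged.
assert _BatchedEnv is BatchedEnvBase
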
