[BugFix] Adaptable non-blocking for mps and non cuda device in batched-envs #1900

Merged · 4 commits · Feb 12, 2024
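Background for the changes in this PR (general PyTorch behaviour, not taken from this diff): `non_blocking=True` only yields an asynchronous copy for a CUDA target fed from pinned host memory; on CPU or MPS there is no equivalent fast path, which is presumably why the flag becomes device-dependent here instead of being hard-coded. A minimal sketch:

```python
import torch

# Generic PyTorch illustration (not from this diff): a non-blocking copy is
# only asynchronous for a CUDA target with a pinned (page-locked) host source;
# on CPU the flag is accepted but changes nothing.
x = torch.randn(1024, 1024)
if torch.cuda.is_available():
    x = x.pin_memory()                    # page-locked buffer enables async H2D
    y = x.to("cuda", non_blocking=True)   # may return before the copy completes
    torch.cuda.synchronize()              # wait before relying on y's contents
else:
    y = x.to("cpu", non_blocking=True)    # same-device move; flag has no effect
```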
8 changes: 4 additions & 4 deletions torchrl/collectors/collectors.py
@@ -577,7 +577,7 @@ def __init__(
reset_when_done: bool = True,
interruptor=None,
):
-from torchrl.envs.batched_envs import _BatchedEnv
+from torchrl.envs.batched_envs import BatchedEnvBase

self.closed = True

@@ -591,7 +591,7 @@
else:
env = create_env_fn
if create_env_kwargs:
-if not isinstance(env, _BatchedEnv):
+if not isinstance(env, BatchedEnvBase):
raise RuntimeError(
"kwargs were passed to SyncDataCollector but they can't be set "
f"on environment of type {type(create_env_fn)}."
@@ -1201,11 +1201,11 @@ def state_dict(self) -> OrderedDict:
`"env_state_dict"`.

"""
-from torchrl.envs.batched_envs import _BatchedEnv
+from torchrl.envs.batched_envs import BatchedEnvBase

if isinstance(self.env, TransformedEnv):
env_state_dict = self.env.transform.state_dict()
-elif isinstance(self.env, _BatchedEnv):
+elif isinstance(self.env, BatchedEnvBase):
env_state_dict = self.env.state_dict()
else:
env_state_dict = OrderedDict()
64 changes: 40 additions & 24 deletions torchrl/envs/batched_envs.py
@@ -48,7 +48,7 @@


def _check_start(fun):
-def decorated_fun(self: _BatchedEnv, *args, **kwargs):
+def decorated_fun(self: BatchedEnvBase, *args, **kwargs):
if self.is_closed:
self._create_td()
self._start_workers()
@@ -121,7 +121,7 @@ def __call__(cls, *args, **kwargs):
return super().__call__(*args, **kwargs)


-class _BatchedEnv(EnvBase):
+class BatchedEnvBase(EnvBase):
"""Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely.

Those queries will return a list of length equal to the number of workers containing the
@@ -169,6 +169,9 @@ class _BatchedEnv(EnvBase):
serial_for_single (bool, optional): if ``True``, creating a parallel environment
with a single worker will return a :class:`~SerialEnv` instead.
This option has no effect with :class:`~SerialEnv`. Defaults to ``False``.
+non_blocking (bool, optional): if ``True``, device moves will be done using the
+    ``non_blocking=True`` option. Defaults to ``True`` for batched environments
+    on cuda devices, and ``False`` otherwise.

Examples:
>>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator
@@ -179,8 +182,8 @@ class _BatchedEnv(EnvBase):
>>> env = ParallelEnv(2, [
... lambda: DMControlEnv("humanoid", "stand"),
... lambda: DMControlEnv("humanoid", "walk")]) # Creates two independent copies of Humanoid, one that walks one that stands
->>> r = env.rollout(10) # executes 10 random steps in the environment
->>> r[0] # data for Humanoid stand
+>>> rollout = env.rollout(10) # executes 10 random steps in the environment
+>>> rollout[0] # data for Humanoid stand
TensorDict(
fields={
action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
@@ -211,7 +214,7 @@ class _BatchedEnv(EnvBase):
batch_size=torch.Size([10]),
device=cpu,
is_shared=False)
->>> r[1] # data for Humanoid walk
+>>> rollout[1] # data for Humanoid walk
TensorDict(
fields={
action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
@@ -242,6 +245,7 @@ class _BatchedEnv(EnvBase):
batch_size=torch.Size([10]),
device=cpu,
is_shared=False)
+>>> # serial_for_single to avoid creating parallel envs if not necessary
>>> env = ParallelEnv(1, make_env, serial_for_single=True)
>>> assert isinstance(env, SerialEnv) # serial_for_single allows you to avoid creating parallel envs when not necessary
"""
@@ -270,6 +274,7 @@ def __init__(
num_threads: int = None,
num_sub_threads: int = 1,
serial_for_single: bool = False,
+non_blocking: bool = False,
):
super().__init__(device=device)
self.serial_for_single = serial_for_single
@@ -327,6 +332,15 @@ def __init__(
# self._prepare_dummy_env(create_env_fn, create_env_kwargs)
self._properties_set = False
self._get_metadata(create_env_fn, create_env_kwargs)
+self._non_blocking = non_blocking
+
+@property
+def non_blocking(self):
+    nb = self._non_blocking
+    if nb is None:
+        nb = self.device is not None and self.device.type == "cuda"
+        self._non_blocking = nb
+    return nb

def _get_metadata(
self, create_env_fn: List[Callable], create_env_kwargs: List[Dict]
@@ -654,6 +668,7 @@ def start(self) -> None:
self._start_workers()

def to(self, device: DEVICE_TYPING):
+self._non_blocking = None
device = torch.device(device)
if device == self.device:
return self
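The new property plus the reset in `to()` give a lazily resolved, per-device default. A self-contained toy class (hypothetical, mirroring rather than importing the code above) that shows the intended behaviour:

```python
import torch

class _ToyBatchedEnv:
    """Toy sketch of the flag handling in this diff: the constructor stores the
    user's value, the property resolves a missing (None) value from the device
    type, and to() clears the cache so the default is recomputed."""

    def __init__(self, device=None, non_blocking: bool = False):
        self.device = torch.device(device) if device is not None else None
        self._non_blocking = non_blocking

    @property
    def non_blocking(self) -> bool:
        nb = self._non_blocking
        if nb is None:
            nb = self.device is not None and self.device.type == "cuda"
            self._non_blocking = nb
        return nb

    def to(self, device):
        self._non_blocking = None  # re-resolve for the new device
        self.device = torch.device(device)
        return self

env = _ToyBatchedEnv(device="cpu")
assert env.non_blocking is False   # non-CUDA device: blocking copies
env.to("cuda")                     # builds a device object only; no GPU needed
assert env.non_blocking is True    # CUDA device: resolves to non-blocking copies
```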
@@ -675,10 +690,10 @@ def to(self, device: DEVICE_TYPING):
return self


-class SerialEnv(_BatchedEnv):
+class SerialEnv(BatchedEnvBase):
"""Creates a series of environments in the same process."""

-__doc__ += _BatchedEnv.__doc__
+__doc__ += BatchedEnvBase.__doc__

_share_memory = False

@@ -769,7 +784,9 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
else:
env_device = _env.device
if env_device != self.device and env_device is not None:
-tensordict_ = tensordict_.to(env_device, non_blocking=True)
+tensordict_ = tensordict_.to(
+    env_device, non_blocking=self.non_blocking
+)
else:
tensordict_ = tensordict_.clone(False)
else:
@@ -798,7 +815,7 @@ def select_and_clone(name, tensor):
if device is None:
out = out.clear_device_()
else:
-out = out.to(device, non_blocking=True)
+out = out.to(device, non_blocking=self.non_blocking)
return out

def _reset_proc_data(self, tensordict, tensordict_reset):
@@ -819,7 +836,9 @@ def _step(
# There may be unexpected keys, such as "_reset", that we should comfortably ignore here.
env_device = self._envs[i].device
if env_device != self.device and env_device is not None:
-data_in = tensordict_in[i].to(env_device, non_blocking=True)
+data_in = tensordict_in[i].to(
+    env_device, non_blocking=self.non_blocking
+)
else:
data_in = tensordict_in[i]
out_td = self._envs[i]._step(data_in)
@@ -839,7 +858,7 @@ def select_and_clone(name, tensor):
if device is None:
out = out.clear_device_()
elif out.device != device:
-out = out.to(device, non_blocking=True)
+out = out.to(device, non_blocking=self.non_blocking)
return out

def __getattr__(self, attr: str) -> Any:
@@ -885,14 +904,14 @@ def to(self, device: DEVICE_TYPING):
return self


-class ParallelEnv(_BatchedEnv, metaclass=_PEnvMeta):
+class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
"""Creates one environment per process.

TensorDicts are passed via shared memory or memory map.

"""

-__doc__ += _BatchedEnv.__doc__
+__doc__ += BatchedEnvBase.__doc__
__doc__ += """

.. warning::
@@ -1167,14 +1186,14 @@ def step_and_maybe_reset(
tensordict_ = tensordict_.clone()
elif device is not None:
next_td = next_td._fast_apply(
-lambda x: x.to(device, non_blocking=True)
+lambda x: x.to(device, non_blocking=self.non_blocking)
if x.device != device
else x.clone(),
device=device,
filter_empty=True,
)
tensordict_ = tensordict_._fast_apply(
-lambda x: x.to(device, non_blocking=True)
+lambda x: x.to(device, non_blocking=self.non_blocking)
if x.device != device
else x.clone(),
device=device,
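The `_fast_apply` lambdas above follow one pattern: move a leaf tensor if it lives on the wrong device, otherwise clone it so the caller never aliases shared buffers. A standalone equivalent (hypothetical helper name) for clarity:

```python
import torch

def _move_or_clone(x: torch.Tensor, device: torch.device, non_blocking: bool) -> torch.Tensor:
    # Transfer when the tensor sits on another device; otherwise clone so the
    # returned tensor does not alias the worker's (possibly shared) storage.
    if x.device != device:
        return x.to(device, non_blocking=non_blocking)
    return x.clone()

out = _move_or_clone(torch.ones(3), torch.device("cpu"), non_blocking=False)
```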
@@ -1239,7 +1258,7 @@ def select_and_clone(name, tensor):
if device is None:
out.clear_device_()
else:
-out = out.to(device, non_blocking=True)
+out = out.to(device, non_blocking=self.non_blocking)
return out

@_check_start
@@ -1325,7 +1344,7 @@ def select_and_clone(name, tensor):
if device is None:
out.clear_device_()
else:
-out = out.to(device, non_blocking=True)
+out = out.to(device, non_blocking=self.non_blocking)
return out

@_check_start
Expand Down Expand Up @@ -1644,12 +1663,9 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
child_pipe.send(("_".join([cmd, "done"]), None))


-def _update_cuda(t_dest, t_source):
-    if t_source is None:
-        return
-    t_dest.copy_(t_source.pin_memory(), non_blocking=True)
-    return


def _filter_empty(tensordict):
return tensordict.select(*tensordict.keys(True, True))


+# Create an alias for possible imports
+_BatchedEnv = BatchedEnvBase
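The trailing alias keeps the old private name importable. A quick check one could run after this change (assuming the module path `torchrl.envs.batched_envs` is unchanged, as the diff suggests):

```python
# Both names should resolve to the same class object, so downstream code that
# still imports the old private name keeps working.
from torchrl.envs.batched_envs import BatchedEnvBase, _BatchedEnv

assert _BatchedEnv is BatchedEnvBase
```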