Skip to content

Commit

Permalink
[Feature] serial_for_single arg in batched envs (#1846)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jan 29, 2024
1 parent 9da61f2 commit 156a668
Show file tree
Hide file tree
Showing 22 changed files with 150 additions and 17 deletions.
4 changes: 3 additions & 1 deletion examples/a2c/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def make_base_env(

def make_parallel_env(env_name, num_envs, device, is_test=False):
env = ParallelEnv(
num_envs, EnvCreator(lambda: make_base_env(env_name, device=device))
num_envs,
EnvCreator(lambda: make_base_env(env_name, device=device)),
serial_for_single=True,
)
env = TransformedEnv(env)
env.append_transform(ToTensorImage())
Expand Down
2 changes: 2 additions & 0 deletions examples/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
parallel_env = ParallelEnv(
train_num_envs,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

Expand All @@ -89,6 +90,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
ParallelEnv(
eval_num_envs,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
serial_for_single=True,
),
train_env.transform.clone(),
)
Expand Down
2 changes: 2 additions & 0 deletions examples/ddpg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def make_environment(cfg):
parallel_env = ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

Expand All @@ -87,6 +88,7 @@ def make_environment(cfg):
ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
serial_for_single=True,
),
train_env.transform.clone(),
)
Expand Down
2 changes: 1 addition & 1 deletion examples/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def make_env():
return make_base_env(env_cfg)

env = make_transformed_env(
ParallelEnv(num_envs, EnvCreator(make_env)),
ParallelEnv(num_envs, EnvCreator(make_env), serial_for_single=True),
env_cfg,
obs_loc,
obs_std,
Expand Down
2 changes: 2 additions & 0 deletions examples/discrete_sac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def make_environment(cfg):
parallel_env = ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

Expand All @@ -88,6 +89,7 @@ def make_environment(cfg):
ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
serial_for_single=True,
),
train_env.transform.clone(),
)
Expand Down
6 changes: 5 additions & 1 deletion examples/distributed/collectors/single_machine/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ def gym_make():
if args.worker_parallelism == "collector" or num_workers == 1:
action_spec = make_env().action_spec
else:
make_env = ParallelEnv(num_workers, make_env)
make_env = ParallelEnv(
num_workers,
make_env,
serial_for_single=True,
)
action_spec = make_env.action_spec

if args.worker_parallelism == "collector" and num_workers > 1:
Expand Down
6 changes: 5 additions & 1 deletion examples/distributed/collectors/single_machine/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ def gym_make():
if num_workers == 1:
action_spec = make_env().action_spec
else:
make_env = ParallelEnv(num_workers, make_env)
make_env = ParallelEnv(
num_workers,
make_env,
serial_for_single=True,
)
action_spec = make_env.action_spec

collector = RPCDataCollector(
Expand Down
6 changes: 5 additions & 1 deletion examples/distributed/collectors/single_machine/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ def gym_make():
if args.worker_parallelism == "collector" or num_workers == 1:
action_spec = make_env().action_spec
else:
make_env = ParallelEnv(num_workers, make_env)
make_env = ParallelEnv(
num_workers,
make_env,
serial_for_single=True,
)
action_spec = make_env.action_spec

if args.worker_parallelism == "collector" and num_workers > 1:
Expand Down
1 change: 1 addition & 0 deletions examples/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def parallel_env_constructor(
create_env_kwargs=None,
pin_memory=cfg.pin_memory,
device=cfg.collector_device,
serial_for_single=True,
)
if batch_transform:
kwargs.update(
Expand Down
2 changes: 2 additions & 0 deletions examples/iql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
parallel_env = ParallelEnv(
train_num_envs,
EnvCreator(lambda: env_maker(cfg)),
serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

Expand All @@ -93,6 +94,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
ParallelEnv(
eval_num_envs,
EnvCreator(lambda: env_maker(cfg)),
serial_for_single=True,
),
train_env.transform.clone(),
)
Expand Down
4 changes: 3 additions & 1 deletion examples/ppo/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def make_base_env(

def make_parallel_env(env_name, num_envs, device, is_test=False):
env = ParallelEnv(
num_envs, EnvCreator(lambda: make_base_env(env_name, device=device))
num_envs,
EnvCreator(lambda: make_base_env(env_name, device=device)),
serial_for_single=True,
)
env = TransformedEnv(env)
env.append_transform(ToTensorImage())
Expand Down
1 change: 1 addition & 0 deletions examples/redq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,7 @@ def parallel_env_constructor(
num_workers=cfg.collector.env_per_collector,
create_env_fn=make_transformed_env,
create_env_kwargs=None,
serial_for_single=True,
pin_memory=False,
)
if batch_transform:
Expand Down
2 changes: 2 additions & 0 deletions examples/sac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def make_environment(cfg):
parallel_env = ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

Expand All @@ -77,6 +78,7 @@ def make_environment(cfg):
ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
serial_for_single=True,
),
train_env.transform.clone(),
)
Expand Down
2 changes: 2 additions & 0 deletions examples/td3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def make_environment(cfg):
parallel_env = ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

Expand All @@ -89,6 +90,7 @@ def make_environment(cfg):
ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
serial_for_single=True,
),
train_env.transform.clone(),
)
Expand Down
8 changes: 8 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,14 @@ def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad):
env.shared_tensordict_parent.device.type == torch.device(edevice).type
)

def test_serial_for_single(self):
env = ParallelEnv(1, ContinuousActionVecMockEnv, serial_for_single=True)
assert isinstance(env, SerialEnv)
env = ParallelEnv(1, ContinuousActionVecMockEnv)
assert isinstance(env, ParallelEnv)
env = ParallelEnv(2, ContinuousActionVecMockEnv, serial_for_single=True)
assert isinstance(env, ParallelEnv)

@pytest.mark.parametrize("num_parallel_env", [1, 10])
@pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)])
def test_env_with_batch_size(self, num_parallel_env, env_batch_size):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def ctx_factory():

if inspect.isclass(func):
raise RuntimeError(
"Cannot decorate classes; it is ambiguous whether or not only the "
"Cannot decorate classes; it is ambiguous whether only the "
"constructor or all methods should have the context manager applied; "
"additionally, decorating a class at definition-site will prevent "
"use of the identifier as a conventional type. "
Expand Down
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1993,7 +1993,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
This feature is mainly intended to be used in offline/model-based settings, where a batch of random
trajectories can be used to initialize training.
Defaults to ``None`` (i.e. no random frames)
reset_at_each_iter (bool): Whether or not environments should be reset for each batch.
reset_at_each_iter (bool): whether environments should be reset for each batch.
default=False.
postproc (callable, optional): A PostProcessor is an object that will read a batch of data and process it in a
useful format for training.
Expand Down
105 changes: 100 additions & 5 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, VERBOSE
from torchrl.data.tensor_specs import CompositeSpec
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.common import _EnvPostInit, EnvBase
from torchrl.envs.env_creator import get_env_metadata

# legacy
Expand Down Expand Up @@ -104,6 +104,19 @@ def new_fun(self, *args, **kwargs):
return new_fun


class _PEnvMeta(_EnvPostInit):
def __call__(cls, *args, **kwargs):
serial_for_single = kwargs.pop("serial_for_single", False)
if serial_for_single:
num_workers = kwargs.get("num_workers", None)
if num_workers is None:
num_workers = args[0]
if num_workers == 1:
# We still use a serial to keep the shape unchanged
return SerialEnv(*args, **kwargs)
return super().__call__(*args, **kwargs)


class _BatchedEnv(EnvBase):
"""Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely.
Expand All @@ -120,12 +133,14 @@ class _BatchedEnv(EnvBase):
If a single task is used, a callable should be used and not a list of identical callables:
if a list of callable is provided, the environment will be executed as if multiple, diverse tasks were
needed, which comes with a slight compute overhead;
Keyword Args:
create_env_kwargs (dict or list of dicts, optional): kwargs to be used with the environments being created;
share_individual_td (bool, optional): if ``True``, a different tensordict is created for every process/worker and a lazy
stack is returned.
default = None (False if single task);
shared_memory (bool): whether or not the returned tensordict will be placed in shared memory;
memmap (bool): whether or not the returned tensordict will be placed in memory map.
shared_memory (bool): whether the returned tensordict will be placed in shared memory;
memmap (bool): whether the returned tensordict will be placed in memory map.
policy_proof (callable, optional): if provided, it'll be used to get the list of
tensors to return through the :obj:`step()` and :obj:`reset()` methods, such as :obj:`"hidden"` etc.
device (str, int, torch.device): The device of the batched environment can be passed.
Expand All @@ -147,7 +162,84 @@ class _BatchedEnv(EnvBase):
Defaults to 1 for safety: if none is indicated, launching multiple
workers may charge the cpu load too much and harm performance.
This parameter has no effect for the :class:`~SerialEnv` class.
serial_for_single (bool, optional): if ``True``, creating a parallel environment
with a single worker will return a :class:`~SerialEnv` instead.
This option has no effect with :class:`~SerialEnv`. Defaults to ``False``.
Examples:
>>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator
>>> make_env = EnvCreator(lambda: GymEnv("Pendulum-v1")) # EnvCreator ensures that the env is sharable. Optional in most cases.
>>> env = SerialEnv(2, make_env) # Makes 2 identical copies of the Pendulum env, runs them on the same process serially
>>> env = ParallelEnv(2, make_env) # Makes 2 identical copies of the Pendulum env, runs them on dedicated processes
>>> from torchrl.envs import DMControlEnv
>>> env = ParallelEnv(2, [
... lambda: DMControlEnv("humanoid", "stand"),
... lambda: DMControlEnv("humanoid", "walk")]) # Creates two independent copies of Humanoid, one that walks one that stands
>>> r = env.rollout(10) # executes 10 random steps in the environment
>>> r[0] # data for Humanoid stand
TensorDict(
fields={
action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
next: TensorDict(
fields={
com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False),
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False)
>>> r[1] # data for Humanoid walk
TensorDict(
fields={
action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
next: TensorDict(
fields={
com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False),
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False)
>>> env = ParallelEnv(1, make_env, serial_for_single=True)
>>> assert isinstance(env, SerialEnv) # serial_for_single allows you to avoid creating parallel envs when not necessary
"""

_verbose: bool = VERBOSE
Expand All @@ -162,6 +254,7 @@ def __init__(
self,
num_workers: int,
create_env_fn: Union[Callable[[], EnvBase], Sequence[Callable[[], EnvBase]]],
*,
create_env_kwargs: Union[dict, Sequence[dict]] = None,
pin_memory: bool = False,
share_individual_td: Optional[bool] = None,
Expand All @@ -172,8 +265,10 @@ def __init__(
allow_step_when_done: bool = False,
num_threads: int = None,
num_sub_threads: int = 1,
serial_for_single: bool = False,
):
super().__init__(device=device)
self.serial_for_single = serial_for_single
self.is_closed = True
if num_threads is None:
num_threads = num_workers + 1 # 1 more thread for this proc
Expand Down Expand Up @@ -760,7 +855,7 @@ def to(self, device: DEVICE_TYPING):
return self


class ParallelEnv(_BatchedEnv):
class ParallelEnv(_BatchedEnv, metaclass=_PEnvMeta):
"""Creates one environment per process.
TensorDicts are passed via shared memory or memory map.
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def auto_register_info_dict(self):
within the tensordict.
This method returns a (possibly transformed) environment where we make sure that
the :func:`torchrl.envs.utils.check_env_specs` succeeds, whether or not
the :func:`torchrl.envs.utils.check_env_specs` succeeds, whether
the info is filled at reset time.
This method requires running a few iterations in the environment to
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class DQNLoss(LossModule):
delay_value (bool, optional): whether to duplicate the value network
into a new target value network to
create a DQN with a target network. Default is ``False``.
double_dqn (bool, optional): whether or not to use Double DQN, as described in
double_dqn (bool, optional): whether to use Double DQN, as described in
https://arxiv.org/abs/1509.06461. Defaults to ``False``.
action_space (str or TensorSpec, optional): Action space. Must be one of
``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
Expand Down
Loading

0 comments on commit 156a668

Please sign in to comment.