Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] serial_for_single arg in batched envs #1846

Merged
merged 2 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
init
  • Loading branch information
vmoens committed Jan 29, 2024
commit 19c379fd61353bdb5fe00c1d0fe37484c823a8e3
2 changes: 1 addition & 1 deletion examples/a2c/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def make_base_env(

def make_parallel_env(env_name, num_envs, device, is_test=False):
env = ParallelEnv(
num_envs, EnvCreator(lambda: make_base_env(env_name, device=device))
num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)), serial_for_single=True,
)
env = TransformedEnv(env)
env.append_transform(ToTensorImage())
Expand Down
4 changes: 2 additions & 2 deletions examples/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
"""Make environments for training and evaluation."""
parallel_env = ParallelEnv(
train_num_envs,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
EnvCreator(lambda cfg=cfg: env_maker(cfg)), serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

Expand All @@ -88,7 +88,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
eval_env = TransformedEnv(
ParallelEnv(
eval_num_envs,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
EnvCreator(lambda cfg=cfg: env_maker(cfg)), serial_for_single=True,
),
train_env.transform.clone(),
)
Expand Down
4 changes: 2 additions & 2 deletions examples/ddpg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def make_environment(cfg):
"""Make environments for training and evaluation."""
parallel_env = ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
EnvCreator(lambda cfg=cfg: env_maker(cfg)), serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

Expand All @@ -86,7 +86,7 @@ def make_environment(cfg):
eval_env = TransformedEnv(
ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
EnvCreator(lambda cfg=cfg: env_maker(cfg)), serial_for_single=True,
),
train_env.transform.clone(),
)
Expand Down
2 changes: 1 addition & 1 deletion examples/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def make_env():
return make_base_env(env_cfg)

env = make_transformed_env(
ParallelEnv(num_envs, EnvCreator(make_env)),
ParallelEnv(num_envs, EnvCreator(make_env), serial_for_single=True),
env_cfg,
obs_loc,
obs_std,
Expand Down
4 changes: 2 additions & 2 deletions examples/discrete_sac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def make_environment(cfg):
"""Make environments for training and evaluation."""
parallel_env = ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
EnvCreator(lambda cfg=cfg: env_maker(cfg)), serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

Expand All @@ -87,7 +87,7 @@ def make_environment(cfg):
eval_env = TransformedEnv(
ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
EnvCreator(lambda cfg=cfg: env_maker(cfg)), serial_for_single=True,
),
train_env.transform.clone(),
)
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/single_machine/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def gym_make():
if args.worker_parallelism == "collector" or num_workers == 1:
action_spec = make_env().action_spec
else:
make_env = ParallelEnv(num_workers, make_env)
make_env = ParallelEnv(num_workers, make_env, serial_for_single=True,)
action_spec = make_env.action_spec

if args.worker_parallelism == "collector" and num_workers > 1:
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/single_machine/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def gym_make():
if num_workers == 1:
action_spec = make_env().action_spec
else:
make_env = ParallelEnv(num_workers, make_env)
make_env = ParallelEnv(num_workers, make_env, serial_for_single=True,)
action_spec = make_env.action_spec

collector = RPCDataCollector(
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/single_machine/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def gym_make():
if args.worker_parallelism == "collector" or num_workers == 1:
action_spec = make_env().action_spec
else:
make_env = ParallelEnv(num_workers, make_env)
make_env = ParallelEnv(num_workers, make_env, serial_for_single=True,)
action_spec = make_env.action_spec

if args.worker_parallelism == "collector" and num_workers > 1:
Expand Down
2 changes: 1 addition & 1 deletion examples/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def parallel_env_constructor(
create_env_fn=make_transformed_env,
create_env_kwargs=None,
pin_memory=cfg.pin_memory,
device=cfg.collector_device,
device=cfg.collector_device, serial_for_single=True,
)
if batch_transform:
kwargs.update(
Expand Down
4 changes: 2 additions & 2 deletions examples/iql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
"""Make environments for training and evaluation."""
parallel_env = ParallelEnv(
train_num_envs,
EnvCreator(lambda: env_maker(cfg)),
EnvCreator(lambda: env_maker(cfg)), serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

Expand All @@ -92,7 +92,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1):
eval_env = TransformedEnv(
ParallelEnv(
eval_num_envs,
EnvCreator(lambda: env_maker(cfg)),
EnvCreator(lambda: env_maker(cfg)), serial_for_single=True,
),
train_env.transform.clone(),
)
Expand Down
2 changes: 1 addition & 1 deletion examples/ppo/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def make_base_env(

def make_parallel_env(env_name, num_envs, device, is_test=False):
env = ParallelEnv(
num_envs, EnvCreator(lambda: make_base_env(env_name, device=device))
num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)), serial_for_single=True,
)
env = TransformedEnv(env)
env.append_transform(ToTensorImage())
Expand Down
2 changes: 1 addition & 1 deletion examples/redq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def parallel_env_constructor(
parallel_env = ParallelEnv(
num_workers=cfg.collector.env_per_collector,
create_env_fn=make_transformed_env,
create_env_kwargs=None,
create_env_kwargs=None, serial_for_single=True,
pin_memory=False,
)
if batch_transform:
Expand Down
4 changes: 2 additions & 2 deletions examples/sac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def make_environment(cfg):
"""Make environments for training and evaluation."""
parallel_env = ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
EnvCreator(lambda cfg=cfg: env_maker(cfg)), serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

Expand All @@ -76,7 +76,7 @@ def make_environment(cfg):
eval_env = TransformedEnv(
ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
EnvCreator(lambda cfg=cfg: env_maker(cfg)), serial_for_single=True,
),
train_env.transform.clone(),
)
Expand Down
4 changes: 2 additions & 2 deletions examples/td3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def make_environment(cfg):
"""Make environments for training and evaluation."""
parallel_env = ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
EnvCreator(lambda cfg=cfg: env_maker(cfg)), serial_for_single=True,
)
parallel_env.set_seed(cfg.env.seed)

Expand All @@ -88,7 +88,7 @@ def make_environment(cfg):
eval_env = TransformedEnv(
ParallelEnv(
cfg.collector.env_per_collector,
EnvCreator(lambda cfg=cfg: env_maker(cfg)),
EnvCreator(lambda cfg=cfg: env_maker(cfg)), serial_for_single=True,
),
train_env.transform.clone(),
)
Expand Down
11 changes: 11 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,17 @@ def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad):
env.shared_tensordict_parent.device.type == torch.device(edevice).type
)

def test_serial_for_single(self):
    """Check which class ``ParallelEnv(...)`` yields depending on ``serial_for_single``.

    * one worker with ``serial_for_single=True`` -> a ``SerialEnv``;
    * one worker without the flag -> a regular ``ParallelEnv``;
    * several workers with the flag -> still a ``ParallelEnv``.
    """

    def _close_if_started(env):
        # ``is_closed`` is initialised to True in the batched-env constructor,
        # i.e. the env has not been started yet at this point; only close envs
        # that were actually started so teardown never trips on an idle env.
        if not env.is_closed:
            env.close()

    # The env class itself is the creator: calling it builds an env instance.
    # (A ``lambda: ContinuousActionVecMockEnv`` would return the class, not an env.)
    env = ParallelEnv(1, ContinuousActionVecMockEnv, serial_for_single=True)
    assert isinstance(env, SerialEnv)
    _close_if_started(env)

    env = ParallelEnv(1, ContinuousActionVecMockEnv)
    assert isinstance(env, ParallelEnv)
    _close_if_started(env)

    env = ParallelEnv(2, ContinuousActionVecMockEnv, serial_for_single=True)
    assert isinstance(env, ParallelEnv)
    _close_if_started(env)

@pytest.mark.parametrize("num_parallel_env", [1, 10])
@pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)])
def test_env_with_batch_size(self, num_parallel_env, env_batch_size):
Expand Down
25 changes: 22 additions & 3 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, VERBOSE
from torchrl.data.tensor_specs import CompositeSpec
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.common import EnvBase, _EnvPostInit
from torchrl.envs.env_creator import get_env_metadata

# legacy
Expand Down Expand Up @@ -103,6 +103,18 @@ def new_fun(self, *args, **kwargs):

return new_fun

class _PEnvMeta(_EnvPostInit):
    """Metaclass intercepting construction to honour ``serial_for_single``.

    When the class is called with ``serial_for_single=True`` and a single
    worker is requested (``num_workers`` given by keyword or as the first
    positional argument), a :class:`SerialEnv` built from the very same
    arguments is returned instead. In every other case construction is
    delegated to the parent metaclass and the resulting instance is returned.
    """

    def __call__(cls, *args, **kwargs):
        if kwargs.get("serial_for_single", False):
            num_workers = kwargs.get("num_workers", None)
            # Fall back to the first positional argument; if there is none,
            # let the regular constructor report the missing argument.
            if num_workers is None and args:
                num_workers = args[0]
            if num_workers == 1:
                # We still use a serial env to keep the shape unchanged
                return SerialEnv(*args, **kwargs)
        # The instance must be returned, otherwise calling the class yields None.
        return super().__call__(*args, **kwargs)


class _BatchedEnv(EnvBase):
"""Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely.
Expand All @@ -120,6 +132,8 @@ class _BatchedEnv(EnvBase):
If a single task is used, a callable should be used and not a list of identical callables:
if a list of callable is provided, the environment will be executed as if multiple, diverse tasks were
needed, which comes with a slight compute overhead;

Keyword Args:
create_env_kwargs (dict or list of dicts, optional): kwargs to be used with the environments being created;
share_individual_td (bool, optional): if ``True``, a different tensordict is created for every process/worker and a lazy
stack is returned.
Expand Down Expand Up @@ -147,7 +161,9 @@ class _BatchedEnv(EnvBase):
Defaults to 1 for safety: if none is indicated, launching multiple
workers may overload the CPU and harm performance.
This parameter has no effect for the :class:`~SerialEnv` class.

serial_for_single (bool, optional): if ``True``, creating a parallel environment
with a single worker will return a :class:`~SerialEnv` instead.
This option has no effect with :class:`~SerialEnv`. Defaults to ``False``.
"""

_verbose: bool = VERBOSE
Expand All @@ -162,6 +178,7 @@ def __init__(
self,
num_workers: int,
create_env_fn: Union[Callable[[], EnvBase], Sequence[Callable[[], EnvBase]]],
*,
create_env_kwargs: Union[dict, Sequence[dict]] = None,
pin_memory: bool = False,
share_individual_td: Optional[bool] = None,
Expand All @@ -172,8 +189,10 @@ def __init__(
allow_step_when_done: bool = False,
num_threads: int = None,
num_sub_threads: int = 1,
serial_for_single: bool = False,
):
super().__init__(device=device)
self.serial_for_single = serial_for_single
self.is_closed = True
if num_threads is None:
num_threads = num_workers + 1 # 1 more thread for this proc
Expand Down Expand Up @@ -760,7 +779,7 @@ def to(self, device: DEVICE_TYPING):
return self


class ParallelEnv(_BatchedEnv):
class ParallelEnv(_BatchedEnv, metaclass=_PEnvMeta):
"""Creates one environment per process.

TensorDicts are passed via shared memory or memory map.
Expand Down
Loading