[Feature] Allow usage of a different device on main and sub-envs in ParallelEnv and SerialEnv (pytorch#1626)
vmoens authored Nov 30, 2023
1 parent 2e7f574 commit 6c27bdb
Showing 5 changed files with 237 additions and 114 deletions.
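
In short: ParallelEnv and SerialEnv now accept a device argument of their own, so the batched env can sit on a different device than the environments it wraps. A minimal sketch of the new usage (assuming a CUDA machine; the env name is only illustrative):

import torch
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

def make():
    # each worker builds its sub-env on CPU
    return GymEnv("Pendulum-v1", device="cpu")

# the batched env itself lives on CUDA; rollouts come back on that device
penv = ParallelEnv(2, make, device="cuda")
rollout = penv.rollout(10)
assert rollout.device.type == "cuda"
penv.close()

As the new test below checks, leaving device unset makes the batched env inherit the sub-envs' device instead.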
38 changes: 18 additions & 20 deletions benchmarks/ecosystem/gym_env_throughput.py
@@ -76,12 +76,12 @@ def make(envname=envname, gym_backend=gym_backend):
 # regular parallel env
 for device in avail_devices:

-    def make(envname=envname, gym_backend=gym_backend, device=device):
+    def make(envname=envname, gym_backend=gym_backend):
         with set_gym_backend(gym_backend):
-            return GymEnv(envname, device=device)
+            return GymEnv(envname, device="cpu")

     # env_make = EnvCreator(make)
-    penv = ParallelEnv(num_workers, EnvCreator(make))
+    penv = ParallelEnv(num_workers, EnvCreator(make), device=device)
     with torch.inference_mode():
         # warmup
         penv.rollout(2)
@@ -103,13 +103,13 @@ def make(envname=envname, gym_backend=gym_backend, device=device):

 for device in avail_devices:

-    def make(envname=envname, gym_backend=gym_backend, device=device):
+    def make(envname=envname, gym_backend=gym_backend):
         with set_gym_backend(gym_backend):
-            return GymEnv(envname, device=device)
+            return GymEnv(envname, device="cpu")

     env_make = EnvCreator(make)
     # penv = SerialEnv(num_workers, env_make)
-    penv = ParallelEnv(num_workers, env_make)
+    penv = ParallelEnv(num_workers, env_make, device=device)
     collector = SyncDataCollector(
         penv,
         RandomPolicy(penv.action_spec),
@@ -164,14 +164,14 @@ def make_env(
 for device in avail_devices:
     # async collector
     # + torchrl parallel env
-    def make_env(
-        envname=envname, gym_backend=gym_backend, device=device
-    ):
+    def make_env(envname=envname, gym_backend=gym_backend):
         with set_gym_backend(gym_backend):
-            return GymEnv(envname, device=device)
+            return GymEnv(envname, device="cpu")

     penv = ParallelEnv(
-        num_workers // num_collectors, EnvCreator(make_env)
+        num_workers // num_collectors,
+        EnvCreator(make_env),
+        device=device,
     )
     collector = MultiaSyncDataCollector(
         [penv] * num_collectors,
@@ -206,10 +206,9 @@ def make_env(
     envname=envname,
     num_workers=num_workers,
     gym_backend=gym_backend,
-    device=device,
 ):
     with set_gym_backend(gym_backend):
-        penv = GymEnv(envname, num_envs=num_workers, device=device)
+        penv = GymEnv(envname, num_envs=num_workers, device="cpu")
     return penv

 penv = EnvCreator(
@@ -247,14 +246,14 @@ def make_env(
 for device in avail_devices:
     # sync collector
     # + torchrl parallel env
-    def make_env(
-        envname=envname, gym_backend=gym_backend, device=device
-    ):
+    def make_env(envname=envname, gym_backend=gym_backend):
         with set_gym_backend(gym_backend):
-            return GymEnv(envname, device=device)
+            return GymEnv(envname, device="cpu")

     penv = ParallelEnv(
-        num_workers // num_collectors, EnvCreator(make_env)
+        num_workers // num_collectors,
+        EnvCreator(make_env),
+        device=device,
     )
     collector = MultiSyncDataCollector(
         [penv] * num_collectors,
@@ -289,10 +288,9 @@ def make_env(
     envname=envname,
     num_workers=num_workers,
     gym_backend=gym_backend,
-    device=device,
 ):
     with set_gym_backend(gym_backend):
-        penv = GymEnv(envname, num_envs=num_workers, device=device)
+        penv = GymEnv(envname, num_envs=num_workers, device="cpu")
     return penv

 penv = EnvCreator(
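
Every benchmark variant above follows the same pattern: the env maker pins its sub-envs to CPU, and the target device is passed exactly once, to ParallelEnv (or, for the gym-native vectorized env, straight to GymEnv). A hedged sketch of the collector variant, assuming the import paths this benchmark file used at the time of the commit and an illustrative env name:

import torch
from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend

def make(envname="CartPole-v1", gym_backend="gymnasium"):
    # sub-envs stay on CPU; ParallelEnv performs the device cast
    with set_gym_backend(gym_backend):
        return GymEnv(envname, device="cpu")

device = "cuda:0" if torch.cuda.is_available() else "cpu"
penv = ParallelEnv(4, EnvCreator(make), device=device)
collector = SyncDataCollector(
    penv,
    RandomPolicy(penv.action_spec),
    frames_per_batch=64,
    total_frames=256,
)
for data in collector:
    # batches arrive on the ParallelEnv device, not on the workers' CPU
    assert data.device.type == torch.device(device).type
collector.shutdown()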
25 changes: 15 additions & 10 deletions examples/dreamer/dreamer_utils.py
@@ -147,6 +147,7 @@ def transformed_env_constructor(
     state_dim_gsde: Optional[int] = None,
     batch_dims: Optional[int] = 0,
     obs_norm_state_dict: Optional[dict] = None,
+    ignore_device: bool = False,
 ) -> Union[Callable, EnvCreator]:
     """
     Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.
@@ -179,6 +180,7 @@ def transformed_env_constructor(
             it should be set to 1 (or the number of dims of the batch).
         obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded
             into the environment
+        ignore_device (bool, optional): if True, the device is ignored.
     """

def make_transformed_env(**kwargs) -> TransformedEnv:
@@ -189,14 +191,17 @@ def make_transformed_env(**kwargs) -> TransformedEnv:
     from_pixels = cfg.from_pixels

     if custom_env is None and custom_env_maker is None:
-        if isinstance(cfg.collector_device, str):
-            device = cfg.collector_device
-        elif isinstance(cfg.collector_device, Sequence):
-            device = cfg.collector_device[0]
+        if not ignore_device:
+            if isinstance(cfg.collector_device, str):
+                device = cfg.collector_device
+            elif isinstance(cfg.collector_device, Sequence):
+                device = cfg.collector_device[0]
+            else:
+                raise ValueError(
+                    "collector_device must be either a string or a sequence of strings"
+                )
         else:
-            raise ValueError(
-                "collector_device must be either a string or a sequence of strings"
-            )
+            device = None
     env_kwargs = {
         "env_name": env_name,
         "device": device,
@@ -252,19 +257,19 @@ def parallel_env_constructor(
         kwargs: keyword arguments for the `transformed_env_constructor` method.
     """
     batch_transform = cfg.batch_transform
+    kwargs.update({"cfg": cfg, "use_env_creator": True})
     if cfg.env_per_collector == 1:
-        kwargs.update({"cfg": cfg, "use_env_creator": True})
         make_transformed_env = transformed_env_constructor(**kwargs)
         return make_transformed_env
-    kwargs.update({"cfg": cfg, "use_env_creator": True})
     make_transformed_env = transformed_env_constructor(
-        return_transformed_envs=not batch_transform, **kwargs
+        return_transformed_envs=not batch_transform, ignore_device=True, **kwargs
     )
     parallel_env = ParallelEnv(
         num_workers=cfg.env_per_collector,
         create_env_fn=make_transformed_env,
         create_env_kwargs=None,
         pin_memory=cfg.pin_memory,
+        device=cfg.collector_device,
     )
     if batch_transform:
         kwargs.update(
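
Taken together, the two dreamer_utils.py changes move device selection out of the workers: with ignore_device=True each worker builds its env with device=None, and the ParallelEnv created right below pins everything to cfg.collector_device. A condensed sketch of the resulting control flow (cfg is the usual hydra-style config with the fields used above):

make_transformed_env = transformed_env_constructor(
    cfg=cfg,
    use_env_creator=True,
    return_transformed_envs=not cfg.batch_transform,
    ignore_device=True,  # workers no longer resolve cfg.collector_device themselves
)
parallel_env = ParallelEnv(
    num_workers=cfg.env_per_collector,
    create_env_fn=make_transformed_env,
    pin_memory=cfg.pin_memory,
    device=cfg.collector_device,  # the single place where the device is now chosen
)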
42 changes: 42 additions & 0 deletions test/test_env.py
@@ -354,6 +354,48 @@ def test_mb_env_batch_lock(self, device, seed=0):


 class TestParallel:
+    @pytest.mark.skipif(
+        not torch.cuda.device_count(), reason="No cuda device detected."
+    )
+    @pytest.mark.parametrize("parallel", [True, False])
+    @pytest.mark.parametrize("hetero", [True, False])
+    @pytest.mark.parametrize("pdevice", [None, "cpu", "cuda"])
+    @pytest.mark.parametrize("edevice", ["cpu", "cuda"])
+    @pytest.mark.parametrize("bwad", [True, False])
+    def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad):
+        if parallel:
+            cls = ParallelEnv
+        else:
+            cls = SerialEnv
+        if not hetero:
+            env = cls(
+                2, lambda: ContinuousActionVecMockEnv(device=edevice), device=pdevice
+            )
+        else:
+            env1 = lambda: ContinuousActionVecMockEnv(device=edevice)
+            env2 = lambda: TransformedEnv(ContinuousActionVecMockEnv(device=edevice))
+            env = cls(2, [env1, env2], device=pdevice)
+
+        r = env.rollout(2, break_when_any_done=bwad)
+        if pdevice is not None:
+            assert env.device.type == torch.device(pdevice).type
+            assert r.device.type == torch.device(pdevice).type
+            assert all(
+                item.device.type == torch.device(pdevice).type
+                for item in r.values(True, True)
+            )
+        else:
+            assert env.device.type == torch.device(edevice).type
+            assert r.device.type == torch.device(edevice).type
+            assert all(
+                item.device.type == torch.device(edevice).type
+                for item in r.values(True, True)
+            )
+        if parallel:
+            assert (
+                env.shared_tensordict_parent.device.type == torch.device(edevice).type
+            )

     @pytest.mark.parametrize("num_parallel_env", [1, 10])
     @pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)])
     def test_env_with_batch_size(self, num_parallel_env, env_batch_size):
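
The test pins down the resolution rule: an explicit device on the batched env wins for the env and everything it returns; with device=None the batched env inherits the sub-envs' device; and in the parallel case the shared tensordict used for inter-worker communication stays on the sub-env device regardless. A sketch of the rule using the same mock env as the test (importable from the test directory), assuming a CUDA machine:

from mocking_classes import ContinuousActionVecMockEnv
from torchrl.envs import SerialEnv

# explicit device wins: sub-envs on CPU, env and rollouts cast to CUDA
env = SerialEnv(2, lambda: ContinuousActionVecMockEnv(device="cpu"), device="cuda")
assert env.rollout(2).device.type == "cuda"

# no device given: the batched env falls back to the sub-envs' device
env = SerialEnv(2, lambda: ContinuousActionVecMockEnv(device="cpu"))
assert env.device.type == "cpu"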
Diffs for the remaining two changed files were not loaded.
