Skip to content

Commit

Permalink
[BugFix] Fix mp_start_method for ParallelEnv with single_for_serial (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 11, 2024
1 parent 358475a commit 2b8450c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
19 changes: 14 additions & 5 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,13 +447,22 @@ def test_parallel_devices(
env.shared_tensordict_parent.device.type == torch.device(edevice).type
)

def test_serial_for_single(self, maybe_fork_ParallelEnv):
env = ParallelEnv(1, ContinuousActionVecMockEnv, serial_for_single=True)
@pytest.mark.parametrize("start_method", [None, "fork"])
def test_serial_for_single(self, maybe_fork_ParallelEnv, start_method):
env = ParallelEnv(
1,
ContinuousActionVecMockEnv,
serial_for_single=True,
mp_start_method=start_method,
)
assert isinstance(env, SerialEnv)
env = maybe_fork_ParallelEnv(1, ContinuousActionVecMockEnv)
env = ParallelEnv(1, ContinuousActionVecMockEnv, mp_start_method=start_method)
assert isinstance(env, ParallelEnv)
env = maybe_fork_ParallelEnv(
2, ContinuousActionVecMockEnv, serial_for_single=True
env = ParallelEnv(
2,
ContinuousActionVecMockEnv,
serial_for_single=True,
mp_start_method=start_method,
)
assert isinstance(env, ParallelEnv)

Expand Down
2 changes: 2 additions & 0 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def __call__(cls, *args, **kwargs):
serial_for_single = kwargs.pop("serial_for_single", False)
if serial_for_single:
num_workers = kwargs.get("num_workers", None)
# Remove start method from kwargs
kwargs.pop("mp_start_method", None)
if num_workers is None:
num_workers = args[0]
if num_workers == 1:
Expand Down

0 comments on commit 2b8450c

Please sign in to comment.