Skip to content

Commit

Permalink
[Feature] static_seed flag for envs, vectorized envs and collectors (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Sep 1, 2022
1 parent a61c8a5 commit 86de26f
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 28 deletions.
8 changes: 6 additions & 2 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ def __init__(self, seed: int = 100):
def maxstep(self):
return 100

def set_seed(self, seed: int) -> int:
def set_seed(self, seed: int, static_seed=False) -> int:
self.seed = seed
self.counter = seed % 17 # make counter a small number
if static_seed:
return seed
return seed_generator(seed)

def custom_fun(self):
Expand All @@ -96,11 +98,13 @@ def __init__(self, device):
self.reward_spec = NdUnboundedContinuousTensorSpec((1,))
self.is_closed = False

def set_seed(self, seed: int) -> int:
def set_seed(self, seed: int, static_seed: bool = False) -> int:
assert seed >= 1
self.seed = seed
self.counter = seed % 17 # make counter a small number
self.max_val = max(self.counter + 100, self.counter * 2)
if static_seed:
return seed
return seed_generator(seed)

def _step(self, tensordict):
Expand Down
14 changes: 11 additions & 3 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,8 @@ def make_frames_per_batch(frames_per_batch):


@pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv")
def test_collector_vecnorm_envcreator():
@pytest.mark.parametrize("static_seed", [True, False])
def test_collector_vecnorm_envcreator(static_seed):
"""
High level test of the following pipeline:
(1) Design a function that creates an environment with VecNorm
Expand All @@ -470,12 +471,19 @@ def test_collector_vecnorm_envcreator():
)

init_seed = 0
new_seed = c.set_seed(init_seed)
new_seed = c.set_seed(init_seed, static_seed=static_seed)
if static_seed:
assert new_seed == init_seed
else:
assert new_seed != init_seed

seed = init_seed
for i in range(num_envs * num_data_collectors):
seed = seed_generator(seed)
assert new_seed == seed
if not static_seed:
assert new_seed == seed
else:
assert new_seed != seed

c_iter = iter(c)
next(c_iter)
Expand Down
11 changes: 8 additions & 3 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,14 +505,17 @@ def test_parallel_env_with_policy(
@pytest.mark.parametrize("frame_skip", [4, 1])
@pytest.mark.parametrize("transformed_in", [False, True])
@pytest.mark.parametrize("transformed_out", [True, False])
@pytest.mark.parametrize("static_seed", [True, False])
def test_parallel_env_seed(
self, env_name, frame_skip, transformed_in, transformed_out
self, env_name, frame_skip, transformed_in, transformed_out, static_seed
):
env_parallel, env_serial, _ = _make_envs(
env_name, frame_skip, transformed_in, transformed_out, 5
)

out_seed_serial = env_serial.set_seed(0)
out_seed_serial = env_serial.set_seed(0, static_seed=static_seed)
if static_seed:
assert out_seed_serial == 0
td0_serial = env_serial.reset()
torch.manual_seed(0)

Expand All @@ -524,7 +527,9 @@ def test_parallel_env_seed(
td_serial[:, 0].get("next_" + key), td_serial[:, 1].get(key)
)

out_seed_parallel = env_parallel.set_seed(0)
out_seed_parallel = env_parallel.set_seed(0, static_seed=static_seed)
if static_seed:
assert out_seed_serial == 0
td0_parallel = env_parallel.reset()

torch.manual_seed(0)
Expand Down
2 changes: 1 addition & 1 deletion test/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class MockingOptim:
class MockingCollector:
called_update_policy_weights_ = False

def set_seed(self, seed):
def set_seed(self, seed, **kwargs):
return seed

def update_policy_weights_(self):
Expand Down
17 changes: 11 additions & 6 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
raise NotImplementedError

@abc.abstractmethod
def set_seed(self, seed: int) -> int:
def set_seed(self, seed: int, static_seed: bool = False) -> int:
raise NotImplementedError

@abc.abstractmethod
Expand Down Expand Up @@ -323,11 +323,13 @@ def __init__(
self._has_been_done = None
self._exclude_private_keys = True

def set_seed(self, seed: int) -> int:
def set_seed(self, seed: int, static_seed: bool = False) -> int:
"""Sets the seeds of the environments stored in the DataCollector.
Args:
seed (int): integer representing the seed to be used for the environment.
static_seed(bool, optional): if True, the seed is not incremented.
Defaults to False
Returns:
Output seed. This is useful when more than one environment is contained in the DataCollector, as the
Expand All @@ -340,7 +342,7 @@ def set_seed(self, seed: int) -> int:
>>> out_seed = collector.set_seed(1) # out_seed = 6
"""
return self.env.set_seed(seed)
return self.env.set_seed(seed, static_seed=static_seed)

def iterator(self) -> Iterator[TensorDictBase]:
"""Iterates through the DataCollector.
Expand Down Expand Up @@ -828,11 +830,13 @@ def _shutdown_main(self) -> None:
for pipe in self.pipes:
pipe.close()

def set_seed(self, seed: int) -> int:
def set_seed(self, seed: int, static_seed: bool = False) -> int:
"""Sets the seeds of the environments stored in the DataCollector.
Args:
seed: integer representing the seed to be used for the environment.
static_seed (bool, optional): if True, the seed is not incremented.
Defaults to False
Returns:
Output seed. This is useful when more than one environment is
Expand All @@ -849,7 +853,7 @@ def set_seed(self, seed: int) -> int:
"""
_check_for_faulty_process(self.procs)
for idx in range(self.num_workers):
self.pipes[idx].send((seed, "seed"))
self.pipes[idx].send(((seed, static_seed), "seed"))
new_seed, msg = self.pipes[idx].recv()
if msg != "seeded":
raise RuntimeError(f"Expected msg='seeded', got {msg}")
Expand Down Expand Up @@ -1355,7 +1359,8 @@ def _main_async_collector(
continue

elif msg == "seed":
new_seed = dc.set_seed(data_in)
data_in, static_seed = data_in
new_seed = dc.set_seed(data_in, static_seed=static_seed)
torch.manual_seed(data_in)
np.random.seed(data_in)
pipe_child.send((new_seed, "seeded"))
Expand Down
14 changes: 9 additions & 5 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,12 +379,14 @@ def reset(
def numel(self) -> int:
return prod(self.batch_size)

def set_seed(self, seed: int) -> int:
def set_seed(self, seed: int, static_seed: bool = False) -> int:
"""Sets the seed of the environment and returns the next seed to be used (
which is the input seed if a single environment is present)
Args:
seed: integer
seed (int): seed to be set
static_seed (bool, optional): if True, the seed is not incremented.
Defaults to False
Returns:
integer representing the "next seed": i.e. the seed that should be
Expand All @@ -394,7 +396,7 @@ def set_seed(self, seed: int) -> int:
if seed is not None:
torch.manual_seed(seed)
self._set_seed(seed)
if seed is not None:
if seed is not None and not static_seed:
new_seed = seed_generator(seed)
seed = new_seed
return seed
Expand Down Expand Up @@ -738,11 +740,13 @@ def close(self) -> None:
except AttributeError:
pass

def set_seed(self, seed: Optional[int] = None) -> Optional[int]:
def set_seed(
self, seed: Optional[int] = None, static_seed: bool = False
) -> Optional[int]:
if seed is not None:
torch.manual_seed(seed)
self._set_seed(seed)
if seed is not None:
if seed is not None and not static_seed:
new_seed = seed_generator(seed)
seed = new_seed
return seed
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict_out = self.transform(tensordict_out)
return tensordict_out

def set_seed(self, seed: int) -> int:
def set_seed(self, seed: int, static_seed: bool = False) -> int:
"""Set the seeds of the environment"""
return self.base_env.set_seed(seed)
return self.base_env.set_seed(seed, static_seed=static_seed)

def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
out_tensordict = self.base_env.reset(execute_step=False, **kwargs)
Expand Down
10 changes: 5 additions & 5 deletions torchrl/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,9 +592,9 @@ def _shutdown_workers(self) -> None:
del self._envs

@_check_start
def set_seed(self, seed: int) -> int:
def set_seed(self, seed: int, static_seed: bool = False) -> int:
for i, env in enumerate(self._envs):
new_seed = env.set_seed(seed)
new_seed = env.set_seed(seed, static_seed=static_seed)
seed = new_seed
return seed

Expand Down Expand Up @@ -789,10 +789,10 @@ def _shutdown_workers(self) -> None:
del self.parent_channels

@_check_start
def set_seed(self, seed: int) -> int:
def set_seed(self, seed: int, static_seed: bool = False) -> int:
self._seeds = []
for channel in self.parent_channels:
channel.send(("seed", seed))
channel.send(("seed", (seed, static_seed)))
self._seeds.append(seed)
msg, new_seed = channel.recv()
if msg != "seeded":
Expand Down Expand Up @@ -940,7 +940,7 @@ def _run_worker_pipe_shared_mem(
raise RuntimeError("call 'init' before closing")
# torch.manual_seed(data)
# np.random.seed(data)
new_seed = env.set_seed(data)
new_seed = env.set_seed(data[0], static_seed=data[1])
child_pipe.send(("seeded", new_seed))

elif cmd == "init":
Expand Down
2 changes: 1 addition & 1 deletion torchrl/trainers/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def load_from_file(self, file: Union[str, pathlib.Path]) -> Trainer:
return self

def set_seed(self):
seed = self.collector.set_seed(self.seed)
seed = self.collector.set_seed(self.seed, static_seed=False)
torch.manual_seed(seed)
np.random.seed(seed)

Expand Down

0 comments on commit 86de26f

Please sign in to comment.