diff --git a/test/test_collector.py b/test/test_collector.py index 9b0117e7486..4f12e445bf3 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -2585,8 +2585,15 @@ def test_unique_traj_sync(self, cat_results): buffer.extend(d) assert c._use_buffers traj_ids = buffer[:].get(("collector", "traj_ids")) - # check that we have as many trajs as expected (no skip) - assert traj_ids.unique().numel() == traj_ids.max() + 1 + # Ideally, we'd like that (sorted_traj.values == sorted_traj.indices).all() + # but in practice, one env can reach the end of the rollout and do a reset + # (which we don't want to prevent) and increment the global traj count, + # when the others have not finished yet. In that case, this traj number will never + # appear. + # sorted_traj = traj_ids.unique().sort() + # assert (sorted_traj.values == sorted_traj.indices).all() + # assert traj_ids.unique().numel() == traj_ids.max() + 1 + # check that trajs are not overlapping if stack_results: sets = [ @@ -2751,6 +2758,79 @@ def test_async(self, use_buffers): del collector +class TestCollectorRB: + @pytest.mark.skipif(not _has_gym, reason="requires gym.") + def test_collector_rb_sync(self): + env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp)) + env.set_seed(0) + rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5) + collector = SyncDataCollector( + env, + RandomPolicy(env.action_spec), + replay_buffer=rb, + total_frames=256, + frames_per_batch=16, + ) + torch.manual_seed(0) + + for c in collector: + assert c is None + rb.sample() + rbdata0 = rb[:].clone() + collector.shutdown() + if not env.is_closed: + env.close() + del collector, env + + env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp)) + env.set_seed(0) + rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5) + collector = SyncDataCollector( + env, RandomPolicy(env.action_spec), total_frames=256, frames_per_batch=16 + ) + torch.manual_seed(0) + + for i, c in enumerate(collector): + rb.extend(c) + torch.testing.assert_close( + rbdata0[:, : (i + 1) * 2]["observation"], rb[:]["observation"] + ) + assert c is not None + rb.sample() + + rbdata1 = rb[:].clone() + collector.shutdown() + if not env.is_closed: + env.close() + del collector, env + assert assert_allclose_td(rbdata0, rbdata1) + + @pytest.mark.skipif(not _has_gym, reason="requires gym.") + def test_collector_rb_multisync(self): + env = GymEnv(CARTPOLE_VERSIONED()) + env.set_seed(0) + + rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5) + rb.add(env.rand_step(env.reset())) + rb.empty() + + collector = MultiSyncDataCollector( + [lambda: env, lambda: env], + RandomPolicy(env.action_spec), + replay_buffer=rb, + total_frames=256, + frames_per_batch=16, + ) + torch.manual_seed(0) + pred_len = 0 + for c in collector: + pred_len += 16 + assert c is None + assert len(rb) == pred_len + collector.shutdown() + assert len(rb) == 256 + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_rb.py b/test/test_rb.py index 4243917c627..359b245fd9f 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -2064,13 +2064,16 @@ def exec_multiproc_rb( init=True, writer_type=TensorDictRoundRobinWriter, sampler_type=RandomSampler, + device=None, ): rb = TensorDictReplayBuffer( storage=storage_type(21), writer=writer_type(), sampler=sampler_type() ) if init: td = TensorDict( - {"a": torch.zeros(10), "next": {"reward": torch.ones(10)}}, [10] + 
{"a": torch.zeros(10), "next": {"reward": torch.ones(10)}}, + [10], + device=device, ) rb.extend(td) q0 = mp.Queue(1) @@ -2098,13 +2101,6 @@ def test_error_list(self): with pytest.raises(RuntimeError, match="Cannot share a storage of type"): self.exec_multiproc_rb(storage_type=ListStorage) - def test_error_nonshared(self): - # non shared tensor storage cannot be shared - with pytest.raises( - RuntimeError, match="The storage must be place in shared memory" - ): - self.exec_multiproc_rb(storage_type=LazyTensorStorage) - def test_error_maxwriter(self): # TensorDictMaxValueWriter cannot be shared with pytest.raises(RuntimeError, match="cannot be shared between processes"): diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index be24a06e39c..3a8686f2893 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -50,6 +50,7 @@ VERBOSE, ) from torchrl.collectors.utils import split_trajectories +from torchrl.data import ReplayBuffer from torchrl.data.tensor_specs import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import _do_nothing, EnvBase @@ -357,6 +358,8 @@ class SyncDataCollector(DataCollectorBase): use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. This isn't compatible with environments with dynamic specs. Defaults to ``True`` for envs without dynamic specs, ``False`` for others. + replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordict + but populate the buffer instead. Defaults to ``None``. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -446,6 +449,8 @@ def __init__( interruptor=None, set_truncated: bool = False, use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + **kwargs, ): from torchrl.envs.batched_envs import BatchedEnvBase @@ -472,6 +477,14 @@ def __init__( policy = RandomPolicy(env.full_action_spec) + ########################## + # Trajectory pool + self._traj_pool_val = kwargs.pop("traj_pool", None) + if kwargs: + raise TypeError( + f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}." 
+ ) + ########################## # Setting devices: # The rule is the following: @@ -538,10 +551,19 @@ def __init__( self.env: EnvBase = env del env + self.replay_buffer = replay_buffer + if self.replay_buffer is not None: + if postproc is not None: + raise TypeError("postproc must be None when a replay buffer is passed.") + if use_buffers: + raise TypeError("replay_buffer is exclusive with use_buffers.") if use_buffers is None: - use_buffers = not self.env._has_dynamic_specs + use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None self._use_buffers = use_buffers + self.replay_buffer = replay_buffer + self.closed = False + if not reset_when_done: raise ValueError("reset_when_done is deprectated.") self.reset_when_done = reset_when_done @@ -655,6 +677,13 @@ def __init__( self._frames = 0 self._iter = -1 + @property + def _traj_pool(self): + pool = getattr(self, "_traj_pool_val", None) + if pool is None: + pool = self._traj_pool_val = _TrajectoryPool() + return pool + def _make_shuttle(self): # Shuttle is a deviceless tensordict that just carried data from env to policy and policy to env with torch.no_grad(): @@ -665,9 +694,9 @@ def _make_shuttle(self): else: self._shuttle_has_no_device = False - traj_ids = torch.arange(self.n_env, device=self.storing_device).view( - self.env.batch_size - ) + traj_ids = self._traj_pool.get_traj_and_increment( + self.n_env, device=self.storing_device + ).view(self.env.batch_size) self._shuttle.set( ("collector", "traj_ids"), traj_ids, @@ -871,7 +900,15 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int: >>> out_seed = collector.set_seed(1) # out_seed = 6 """ - return self.env.set_seed(seed, static_seed=static_seed) + out = self.env.set_seed(seed, static_seed=static_seed) + return out + + def _increment_frames(self, numel): + self._frames += numel + completed = self._frames >= self.total_frames + if completed: + self.env.close() + return completed def iterator(self) -> Iterator[TensorDictBase]: """Iterates through the DataCollector. 
@@ -917,14 +954,15 @@ def cuda_check(tensor: torch.Tensor): for stream in streams: stack.enter_context(torch.cuda.stream(stream)) - total_frames = self.total_frames - while self._frames < self.total_frames: self._iter += 1 tensordict_out = self.rollout() - self._frames += tensordict_out.numel() - if self._frames >= total_frames: - self.env.close() + if tensordict_out is None: + # if a replay buffer is passed, there is no tensordict_out + # frames are updated within the rollout function + yield + continue + self._increment_frames(tensordict_out.numel()) if self.split_trajs: tensordict_out = split_trajectories( @@ -976,14 +1014,20 @@ def _update_traj_ids(self, env_output) -> None: env_output.get("next"), done_keys=self.env.done_keys ) if traj_sop.any(): + device = self.storing_device + traj_ids = self._shuttle.get(("collector", "traj_ids")) - traj_sop = traj_sop.to(self.storing_device) - traj_ids = traj_ids.clone().to(self.storing_device) - traj_ids[traj_sop] = traj_ids.max() + torch.arange( - 1, - traj_sop.sum() + 1, - device=self.storing_device, + if device is not None: + traj_ids = traj_ids.to(device) + traj_sop = traj_sop.to(device) + elif traj_sop.device != traj_ids.device: + traj_sop = traj_sop.to(traj_ids.device) + + pool = self._traj_pool + new_traj = pool.get_traj_and_increment( + traj_sop.sum(), device=traj_sop.device ) + traj_ids = traj_ids.masked_scatter(traj_sop, new_traj) self._shuttle.set(("collector", "traj_ids"), traj_ids) @torch.no_grad() @@ -1053,13 +1097,18 @@ def rollout(self) -> TensorDictBase: next_data.clear_device_() self._shuttle.set("next", next_data) - if self.storing_device is not None: - tensordicts.append( - self._shuttle.to(self.storing_device, non_blocking=True) - ) - self._sync_storage() + if self.replay_buffer is not None: + self.replay_buffer.add(self._shuttle) + if self._increment_frames(self._shuttle.numel()): + return else: - tensordicts.append(self._shuttle) + if self.storing_device is not None: + tensordicts.append( + self._shuttle.to(self.storing_device, non_blocking=True) + ) + self._sync_storage() + else: + tensordicts.append(self._shuttle) # carry over collector data without messing up devices collector_data = self._shuttle.get("collector").copy() @@ -1067,13 +1116,14 @@ def rollout(self) -> TensorDictBase: if self._shuttle_has_no_device: self._shuttle.clear_device_() self._shuttle.set("collector", collector_data) - self._update_traj_ids(env_output) if ( self.interruptor is not None and self.interruptor.collection_stopped() ): + if self.replay_buffer is not None: + return result = self._final_rollout if self._use_buffers: try: @@ -1109,6 +1159,8 @@ def rollout(self) -> TensorDictBase: self._final_rollout.ndim - 1, out=self._final_rollout, ) + elif self.replay_buffer is not None: + return else: result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) result.refine_names(..., "time") @@ -1380,6 +1432,8 @@ class _MultiDataCollector(DataCollectorBase): use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. This isn't compatible with environments with dynamic specs. Defaults to ``True`` for envs without dynamic specs, ``False`` for others. + replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordict + but populate the buffer instead. Defaults to ``None``. 
""" @@ -1415,6 +1469,7 @@ def __init__( cat_results: str | int | None = None, set_truncated: bool = False, use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, ): exploration_type = _convert_exploration_type( exploration_mode=exploration_mode, exploration_type=exploration_type @@ -1458,6 +1513,13 @@ def __init__( del storing_device, env_device, policy_device, device self._use_buffers = use_buffers + self.replay_buffer = replay_buffer + if ( + replay_buffer is not None + and hasattr(replay_buffer, "shared") + and not replay_buffer.shared + ): + replay_buffer.share() _policy_weights_dict = {} _get_weights_fn_dict = {} @@ -1694,6 +1756,8 @@ def _run_processes(self) -> None: queue_out = mp.Queue(self._queue_len) # sends data from proc to main self.procs = [] self.pipes = [] + self._traj_pool = _TrajectoryPool(lock=True) + for i, (env_fun, env_fun_kwargs) in enumerate( zip(self.create_env_fn, self.create_env_kwargs) ): @@ -1730,6 +1794,8 @@ def _run_processes(self) -> None: "interruptor": self.interruptor, "set_truncated": self.set_truncated, "use_buffers": self._use_buffers, + "replay_buffer": self.replay_buffer, + "traj_pool": self._traj_pool, } proc = _ProcessNoWarn( target=_main_async_collector, @@ -2088,10 +2154,6 @@ def iterator(self) -> Iterator[TensorDictBase]: workers_frames = [0 for _ in range(self.num_workers)] same_device = None self.out_buffer = None - last_traj_ids = [-10 for _ in range(self.num_workers)] - last_traj_ids_subs = [None for _ in range(self.num_workers)] - traj_max = -1 - traj_ids_list = [None for _ in range(self.num_workers)] preempt = self.interruptor is not None and self.preemptive_threshold < 1.0 while not all(dones) and self._frames < self.total_frames: @@ -2125,7 +2187,13 @@ def iterator(self) -> Iterator[TensorDictBase]: for _ in range(self.num_workers): new_data, j = self.queue_out.get() use_buffers = self._use_buffers - if j == 0 or not use_buffers: + if self.replay_buffer is not None: + idx = new_data + workers_frames[idx] = ( + workers_frames[idx] + self.frames_per_batch_worker + ) + continue + elif j == 0 or not use_buffers: try: data, idx = new_data self.buffers[idx] = data @@ -2167,51 +2235,25 @@ def iterator(self) -> Iterator[TensorDictBase]: if workers_frames[idx] >= self.total_frames: dones[idx] = True + if self.replay_buffer is not None: + yield + self._frames += self.frames_per_batch_worker * self.num_workers + continue + # we have to correct the traj_ids to make sure that they don't overlap # We can count the number of frames collected for free in this loop n_collected = 0 for idx in range(self.num_workers): buffer = buffers[idx] traj_ids = buffer.get(("collector", "traj_ids")) - is_last = traj_ids == last_traj_ids[idx] - # If we `cat` interrupted data, we have already filtered out - # non-valid steps. If we stack, we haven't. 
- if preempt and cat_results == "stack": - valid = buffer.get(("collector", "traj_ids")) != -1 - if valid.ndim > 2: - valid = valid.flatten(0, -2) - if valid.ndim == 2: - valid = valid.any(0) - last_traj_ids[idx] = traj_ids[..., valid][..., -1:].clone() - else: - last_traj_ids[idx] = traj_ids[..., -1:].clone() - if not is_last.all(): - traj_to_correct = traj_ids[~is_last] - traj_to_correct = ( - traj_to_correct + (traj_max + 1) - traj_to_correct.min() - ) - traj_ids = traj_ids.masked_scatter(~is_last, traj_to_correct) - # is_last can only be true if we're after the first iteration - if is_last.any(): - traj_ids = torch.where( - is_last, last_traj_ids_subs[idx].expand_as(traj_ids), traj_ids - ) - if preempt: if cat_results == "stack": mask_frames = buffer.get(("collector", "traj_ids")) != -1 - traj_ids = torch.where(mask_frames, traj_ids, -1) n_collected += mask_frames.sum().cpu() - last_traj_ids_subs[idx] = traj_ids[..., valid][..., -1:].clone() else: - last_traj_ids_subs[idx] = traj_ids[..., -1:].clone() n_collected += traj_ids.numel() else: - last_traj_ids_subs[idx] = traj_ids[..., -1:].clone() n_collected += traj_ids.numel() - traj_ids_list[idx] = traj_ids - - traj_max = max(traj_max, traj_ids.max()) if same_device is None: prev_device = None @@ -2232,9 +2274,6 @@ def iterator(self) -> Iterator[TensorDictBase]: self.out_buffer = stack( [item.cpu() for item in buffers.values()], 0 ) - self.out_buffer.set( - ("collector", "traj_ids"), torch.stack(traj_ids_list), inplace=True - ) else: if self._use_buffers is None: torchrl_logger.warning( @@ -2251,9 +2290,6 @@ def iterator(self) -> Iterator[TensorDictBase]: self.out_buffer = torch.cat( [item.cpu() for item in buffers.values()], cat_results ) - self.out_buffer.set_( - ("collector", "traj_ids"), torch.cat(traj_ids_list, cat_results) - ) except RuntimeError as err: if ( preempt @@ -2762,6 +2798,8 @@ def _main_async_collector( interruptor=None, set_truncated: bool = False, use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + traj_pool: _TrajectoryPool = None, ) -> None: pipe_parent.close() # init variables that will be cleared when closing @@ -2786,6 +2824,8 @@ def _main_async_collector( interruptor=interruptor, set_truncated=set_truncated, use_buffers=use_buffers, + replay_buffer=replay_buffer, + traj_pool=traj_pool, ) use_buffers = inner_collector._use_buffers if verbose: @@ -2848,6 +2888,21 @@ def _main_async_collector( # In that case, we skip the collected trajectory and get the message from main. This is faster than # sending the trajectory in the queue until timeout when it's never going to be received. 
continue + + if replay_buffer is not None: + try: + queue_out.put((idx, j), timeout=_TIMEOUT) + if verbose: + torchrl_logger.info(f"worker {idx} successfully sent data") + j += 1 + has_timed_out = False + continue + except queue.Full: + if verbose: + torchrl_logger.info(f"worker {idx} has timed out") + has_timed_out = True + continue + if j == 0 or not use_buffers: collected_tensordict = next_data if ( @@ -2956,3 +3011,20 @@ def _make_meta_params(param): if is_param: pd = nn.Parameter(pd, requires_grad=False) return pd + + +class _TrajectoryPool: + def __init__(self, ctx=None, lock: bool = False): + self.ctx = ctx + self._traj_id = torch.zeros((), device="cpu", dtype=torch.int).share_memory_() + if ctx is None: + self.lock = contextlib.nullcontext() if not lock else mp.RLock() + else: + self.lock = contextlib.nullcontext() if not lock else ctx.RLock() + + def get_traj_and_increment(self, n=1, device=None): + with self.lock: + v = self._traj_id.item() + out = torch.arange(v, v + n).to(device) + self._traj_id.copy_(1 + out[-1].item()) + return out diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index fafc120fe94..5ad1bb170cb 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -7,6 +7,7 @@ import collections import contextlib import json +import multiprocessing import textwrap import threading import warnings @@ -128,6 +129,8 @@ class ReplayBuffer: Defaults to ``None`` (global default generator). .. warning:: As of now, the generator has no effect on the transforms. + shared (bool, optional): whether the buffer will be shared using multiprocessing or not. + Defaults to ``False``. Examples: >>> import torch @@ -212,6 +215,7 @@ def __init__( dim_extend: int | None = None, checkpointer: "StorageCheckpointerBase" | None = None, # noqa: F821 generator: torch.Generator | None = None, + shared: bool = False, ) -> None: self._storage = storage if storage is not None else ListStorage(max_size=1_000) self._storage.attach(self) @@ -228,6 +232,9 @@ def __init__( if self._prefetch_cap: self._prefetch_executor = ThreadPoolExecutor(max_workers=self._prefetch_cap) + self.shared = shared + self.share(self.shared) + self._replay_lock = threading.RLock() self._futures_lock = threading.RLock() from torchrl.envs.transforms.transforms import ( @@ -272,6 +279,13 @@ def __init__( self._storage.checkpointer = checkpointer self.set_rng(generator=generator) + def share(self, shared: bool = True): + self.shared = shared + if self.shared: + self._write_lock = multiprocessing.Lock() + else: + self._write_lock = contextlib.nullcontext() + def set_rng(self, generator): self._rng = generator self._storage._rng = generator @@ -576,13 +590,13 @@ def add(self, data: Any) -> int: return self._add(data) def _add(self, data): - with self._replay_lock: + with self._replay_lock, self._write_lock: index = self._writer.add(data) self._sampler.add(index) return index def _extend(self, data: Sequence) -> torch.Tensor: - with self._replay_lock: + with self._replay_lock, self._write_lock: if self.dim_extend > 0: data = self._transpose(data) index = self._writer.extend(data) @@ -630,7 +644,7 @@ def update_priority( if self.dim_extend > 0 and priority.ndim > 1: priority = self._transpose(priority).flatten() # priority = priority.flatten() - with self._replay_lock: + with self._replay_lock, self._write_lock: self._sampler.update_priority(index, priority, storage=self.storage) @pin_memory_output @@ -1053,6 +1067,8 @@ class 
TensorDictReplayBuffer(ReplayBuffer): Defaults to ``None`` (global default generator). .. warning:: As of now, the generator has no effect on the transforms. + shared (bool, optional): whether the buffer will be shared using multiprocessing or not. + Defaults to ``False``. Examples: >>> import torch @@ -1392,6 +1408,8 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer): Defaults to ``None`` (global default generator). .. warning:: As of now, the generator has no effect on the transforms. + shared (bool, optional): whether the buffer will be shared using multiprocessing or not. + Defaults to ``False``. Examples: >>> import torch @@ -1466,6 +1484,7 @@ def __init__( batch_size: int | None = None, dim_extend: int | None = None, generator: torch.Generator | None = None, + shared: bool = False, ) -> None: if storage is None: storage = ListStorage(max_size=1_000) @@ -1483,6 +1502,7 @@ def __init__( batch_size=batch_size, dim_extend=dim_extend, generator=generator, + shared=shared, ) @@ -1635,6 +1655,8 @@ class ReplayBufferEnsemble(ReplayBuffer): Defaults to ``None`` (global default generator). .. warning:: As of now, the generator has no effect on the transforms. + shared (bool, optional): whether the buffer will be shared using multiprocessing or not. + Defaults to ``False``. Examples: >>> from torchrl.envs import Compose, ToTensorImage, Resize, RenameTransform @@ -1725,6 +1747,7 @@ def __init__( sample_from_all: bool = False, num_buffer_sampled: int | None = None, generator: torch.Generator | None = None, + shared: bool = False, **kwargs, ): @@ -1762,6 +1785,7 @@ def __init__( batch_size=batch_size, collate_fn=collate_fn, generator=generator, + shared=shared, **kwargs, ) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 04cc63e231d..d1bd6fbf599 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -547,15 +547,24 @@ def __getstate__(self): # check that the content is shared, otherwise tell the user we can't help storage = self._storage STORAGE_ERR = "The storage must be place in shared memory or memmapped before being shared between processes." + + # If the content is on cpu, it will be placed in shared memory. + # If it's on cuda it's already shared. + # If it's memmaped no worry in this case either. + # Only if the device is not "cpu" or "cuda" we may have a problem. 
+ def assert_is_sharable(tensor): + if tensor.device is None or tensor.device.type in ( + "cuda", + "cpu", + "meta", + ): + return + raise RuntimeError(STORAGE_ERR) + if is_tensor_collection(storage): - if not storage.is_memmap() and not storage.is_shared(): - raise RuntimeError(STORAGE_ERR) + storage.apply(assert_is_sharable) else: - if ( - not isinstance(storage, MemoryMappedTensor) - and not storage.is_shared() - ): - raise RuntimeError(STORAGE_ERR) + tree_map(storage, assert_is_sharable) return state diff --git a/torchrl/envs/custom/pendulum.py b/torchrl/envs/custom/pendulum.py index f785d1cedd9..e2007227127 100644 --- a/torchrl/envs/custom/pendulum.py +++ b/torchrl/envs/custom/pendulum.py @@ -216,6 +216,7 @@ class PendulumEnv(EnvBase): "render_fps": 30, } batch_locked = False + rng = None def __init__(self, td_params=None, seed=None, device=None): if td_params is None: @@ -224,7 +225,7 @@ def __init__(self, td_params=None, seed=None, device=None): super().__init__(device=device) self._make_spec(td_params) if seed is None: - seed = torch.empty((), dtype=torch.int64).random_().item() + seed = torch.empty((), dtype=torch.int64).random_(generator=self.rng).item() self.set_seed(seed) @classmethod @@ -354,7 +355,8 @@ def make_composite_from_td(td): return composite def _set_seed(self, seed: int): - rng = torch.manual_seed(seed) + rng = torch.Generator() + rng.manual_seed(seed) self.rng = rng @staticmethod
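
Usage note (editor's sketch): the new `replay_buffer` argument is exercised by `TestCollectorRB` above. The snippet below is a minimal usage sketch distilled from `test_collector_rb_sync`; it assumes gym and `CartPole-v1` are installed and that, as in the test, a collector with a buffer attached yields `None` and writes frames directly into the buffer.

import torch

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import SerialEnv
from torchrl.envs.libs.gym import GymEnv

env = SerialEnv(2, lambda: GymEnv("CartPole-v1"))
env.set_seed(0)
torch.manual_seed(0)

# ndim=2 lays the storage out as [env, time], matching the batched env.
rb = ReplayBuffer(storage=LazyTensorStorage(128, ndim=2), batch_size=16)

collector = SyncDataCollector(
    env,
    policy=None,          # falls back to a random policy over env.full_action_spec
    replay_buffer=rb,     # frames are written into the buffer instead of being yielded
    total_frames=128,
    frames_per_batch=16,
)

for batch in collector:
    # With a replay buffer attached the collector yields None; the data is
    # already in `rb` and can be sampled right away.
    assert batch is None
    sample = rb.sample()

collector.shutdown()
if not env.is_closed:
    env.close()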
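
Design note (editor's sketch): the diff also removes the post-hoc traj_id remapping in `_MultiDataCollector.iterator` in favour of a `_TrajectoryPool` shared between the main process and the workers. The snippet below is an illustrative, stand-alone sketch of that design in plain torch + multiprocessing (it is not the torchrl class): a scalar tensor in shared memory, guarded by a lock, hands out globally unique, monotonically increasing trajectory ids, so concurrent workers can never produce overlapping ids.

import torch
import torch.multiprocessing as mp


class SharedTrajCounter:
    def __init__(self):
        # scalar counter living in shared memory, protected by a process-safe lock
        self._count = torch.zeros((), dtype=torch.int64).share_memory_()
        self._lock = mp.RLock()

    def get_and_increment(self, n=1):
        with self._lock:
            start = self._count.item()
            out = torch.arange(start, start + n)
            self._count.fill_(start + n)
            return out


def _worker(counter, n, q):
    # each worker reserves a disjoint range of ids
    q.put(counter.get_and_increment(n).tolist())


if __name__ == "__main__":
    counter = SharedTrajCounter()
    q = mp.Queue()
    procs = [mp.Process(target=_worker, args=(counter, 4, q)) for _ in range(3)]
    for p in procs:
        p.start()
    ids = sorted(i for _ in procs for i in q.get())
    for p in procs:
        p.join()
    assert ids == list(range(12))  # no duplicates, no gaps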