-
Notifications
You must be signed in to change notification settings - Fork 327
/
collectors.py
3143 lines (2839 loc) · 141 KB
/
collectors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import _pickle
import abc
import contextlib
import functools
import os
import queue
import sys
import time
import typing
import warnings
from collections import defaultdict, OrderedDict
from copy import deepcopy
from multiprocessing import connection, queues
from multiprocessing.managers import SyncManager
from textwrap import indent
from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from tensordict import (
LazyStackedTensorDict,
TensorDict,
TensorDictBase,
TensorDictParams,
)
from tensordict.base import NO_DEFAULT
from tensordict.nn import CudaGraphModule, TensorDictModule
from tensordict.utils import Buffer
from torch import multiprocessing as mp
from torch.nn import Parameter
from torch.utils.data import IterableDataset
from torchrl._utils import (
_check_for_faulty_process,
_ends_with,
_make_ordinal_device,
_ProcessNoWarn,
_replace_last,
accept_remote_rref_udf_invocation,
compile_with_warmup,
logger as torchrl_logger,
prod,
RL_WARNINGS,
VERBOSE,
)
from torchrl.collectors.utils import split_trajectories
from torchrl.data import ReplayBuffer
from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, EnvBase
from torchrl.envs.env_creator import EnvCreator
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
_aggregate_end_of_traj,
_make_compatible_policy,
ExplorationType,
RandomPolicy,
set_exploration_type,
)
# ``torch.compiler.cudagraph_mark_step_begin`` only exists in recent PyTorch
# releases; older versions get a stub that fails loudly if it is ever called.
try:
    from torch.compiler import cudagraph_mark_step_begin
except ImportError:

    def cudagraph_mark_step_begin():
        """Stand-in used when ``torch.compiler.cudagraph_mark_step_begin`` cannot be imported.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("cudagraph_mark_step_begin not implemented.")
_TIMEOUT: float = 1.0
INSTANTIATE_TIMEOUT: int = 20
# Should be several orders of magnitude smaller than the time spent collecting a trajectory.
_MIN_TIMEOUT: float = 1e-3
# Maximum number of times a Dataloader worker can time out while waiting on its queue.
# Can be overridden through the ``MAX_IDLE_COUNT`` environment variable.
_MAX_IDLE_COUNT: int = int(os.getenv("MAX_IDLE_COUNT", "1000"))
# Exploration type used by collectors when none (or a falsy value) is provided.
DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM
# True when running on macOS.
_is_osx: bool = sys.platform.startswith("darwin")
class _Interruptor:
    """Lock-protected on/off switch telling a collector whether it should keep collecting.

    The switch starts in the "collecting" state. Every read and write of the
    underlying flag is performed while holding ``self._lock``, a
    ``torch.multiprocessing`` lock.
    """

    # interrupter vs interruptor: google trends seems to indicate that "or" is more
    # widely used than "er" even if my IDE complains about that...
    def __init__(self):
        self._collect = True
        self._lock = mp.Lock()

    def _set_collect(self, flag):
        # Single locked write shared by start_collection / stop_collection.
        with self._lock:
            self._collect = flag

    def start_collection(self):
        """Put the switch back in the "collecting" state."""
        self._set_collect(True)

    def stop_collection(self):
        """Put the switch in the "stopped" state."""
        self._set_collect(False)

    def collection_stopped(self):
        """Return ``True`` if the switch is currently in the "stopped" state."""
        with self._lock:
            stopped = self._collect is False
        return stopped
class _InterruptorManager(SyncManager):
    """``SyncManager`` subclass used to share an ``_Interruptor`` between processes.

    It defines no behaviour of its own; ``_Interruptor`` is registered on it
    right after the class definition.
    """


_InterruptorManager.register("_Interruptor", _Interruptor)
def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict:
    """Maps the tensors to CPU through a nested dictionary.

    Values that are :class:`~collections.OrderedDict` instances are processed
    recursively, :class:`torch.Tensor` values are replaced by ``value.cpu()``
    and every other value is kept unchanged. Key order is preserved.

    Args:
        dictionary (OrderedDict): the (possibly nested) mapping to convert.

    Returns:
        a new ``OrderedDict`` with the same keys as ``dictionary``.
    """
    # Build the result key by key instead of ``OrderedDict(**{...})``: keyword
    # unpacking only accepts string keys and raises TypeError on e.g. int keys.
    result = OrderedDict()
    for key, item in dictionary.items():
        if isinstance(item, OrderedDict):
            item = recursive_map_to_cpu(item)
        elif isinstance(item, torch.Tensor):
            item = item.cpu()
        result[key] = item
    return result
class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
    """Base class for data collectors.

    Subclasses must implement :meth:`shutdown`, :meth:`iterator`,
    :meth:`set_seed`, :meth:`state_dict` and :meth:`load_state_dict`.
    Iterating over a collector delegates to :meth:`iterator`.
    """

    _iterator = None
    total_frames: int
    requested_frames_per_batch: int
    frames_per_batch: int
    trust_policy: bool
    compiled_policy: bool
    cudagraphed_policy: bool

    def _get_policy_and_device(
        self,
        policy: Callable[[Any], Any] | None = None,
        observation_spec: TensorSpec = None,
        policy_device: Any = NO_DEFAULT,
        env_maker: Any | None = None,
        env_maker_kwargs: dict | None = None,
    ) -> Tuple[TensorDictModule, Union[None, Callable[[], dict]]]:
        """Util method to get a policy and its device given the collector __init__ inputs.

        We want to copy the policy and then move the data there, not call policy.to(device).

        Args:
            policy (TensorDictModule, optional): a policy to be used
            observation_spec (TensorSpec, optional): spec of the observations
            policy_device (torch.device, optional): the device where the policy should be placed.
                Defaults to self.policy_device
            env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair.
            env_maker_kwargs (a dict, optional): the env_maker function kwargs.

        Returns:
            a ``(policy, get_weights_fn)`` tuple. ``get_weights_fn`` is ``None``
            when the policy was returned as-is; when a device copy of the policy
            was made, it is a callable returning ``TensorDict.from_module`` of the
            original policy.
        """
        if policy_device is NO_DEFAULT:
            policy_device = self.policy_device

        if not self.trust_policy:
            env = getattr(self, "env", None)
            policy = _make_compatible_policy(
                policy,
                observation_spec,
                env=env,
                env_maker=env_maker,
                env_maker_kwargs=env_maker_kwargs,
            )
        if not policy_device:
            return policy, None

        if isinstance(policy, nn.Module):
            param_and_buf = TensorDict.from_module(policy, as_module=True)
        else:
            # Because we want to reach the warning
            param_and_buf = TensorDict()

        i = -1
        for p in param_and_buf.values(True, True):
            i += 1
            if p.device != policy_device:
                # Then we need casting
                break
        else:
            # No parameter needed casting (or there were none at all).
            if i == -1 and not self.trust_policy:
                # We trust that the policy device is adequate
                warnings.warn(
                    "A policy device was provided but no parameter/buffer could be found in "
                    "the policy. Casting to policy_device is therefore impossible. "
                    "The collector will trust that the devices match. To suppress this "
                    "warning, set `trust_policy=True` when building the collector."
                )
            return policy, None

        def map_weight(
            weight,
            policy_device=policy_device,
        ):
            # Move a single weight to policy_device (CPU weights already in place
            # are moved to shared memory), preserving its Parameter/Buffer nature.
            is_param = isinstance(weight, Parameter)
            is_buffer = isinstance(weight, Buffer)
            weight = weight.data
            if weight.device != policy_device:
                weight = weight.to(policy_device)
            elif weight.device.type in ("cpu",):
                weight = weight.share_memory_()
            if is_param:
                weight = Parameter(weight, requires_grad=False)
            elif is_buffer:
                weight = Buffer(weight)
            return weight

        # Create a stateless policy, then populate this copy with params on device
        get_original_weights = functools.partial(TensorDict.from_module, policy)
        with param_and_buf.to("meta").to_module(policy):
            policy = deepcopy(policy)

        param_and_buf.apply(
            map_weight,
            filter_empty=False,
        ).to_module(policy)
        return policy, get_original_weights

    def update_policy_weights_(
        self, policy_weights: Optional[TensorDictBase] = None
    ) -> None:
        """Updates the policy weights if the policy of the data collector and the trained policy live on different devices.

        The new values are written in place into ``self.policy_weights``. When
        ``policy_weights`` is ``None``, ``self.get_weights_fn()`` is used instead
        if available; otherwise nothing happens.

        Args:
            policy_weights (TensorDictBase, optional): if provided, a TensorDict containing
                the weights of the policy to be used for the update.
        """
        if policy_weights is not None:
            self.policy_weights.data.update_(policy_weights)
        elif self.get_weights_fn is not None:
            self.policy_weights.data.update_(self.get_weights_fn())

    def __iter__(self) -> Iterator[TensorDictBase]:
        yield from self.iterator()

    def next(self):
        """Return the next batch of data, or ``None`` once the collector is exhausted.

        The device of the returned batch is cleared before it is returned.
        """
        try:
            if self._iterator is None:
                self._iterator = iter(self)
            out = next(self._iterator)
            # if any, we don't want the device ref to be passed in distributed settings
            out.clear_device_()
            return out
        except StopIteration:
            return None

    @abc.abstractmethod
    def shutdown(self):
        raise NotImplementedError

    @abc.abstractmethod
    def iterator(self) -> Iterator[TensorDictBase]:
        raise NotImplementedError

    @abc.abstractmethod
    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        raise NotImplementedError

    @abc.abstractmethod
    def state_dict(self) -> OrderedDict:
        raise NotImplementedError

    @abc.abstractmethod
    def load_state_dict(self, state_dict: OrderedDict) -> None:
        raise NotImplementedError

    def _read_compile_kwargs(self, compile_policy, cudagraph_policy):
        # Any value other than False/None enables the feature; a mapping is also
        # forwarded as kwargs to the compile / CudaGraphModule wrapper.
        self.compiled_policy = compile_policy not in (False, None)
        self.cudagraphed_policy = cudagraph_policy not in (False, None)
        self.compiled_policy_kwargs = (
            {} if not isinstance(compile_policy, typing.Mapping) else compile_policy
        )
        self.cudagraphed_policy_kwargs = (
            {} if not isinstance(cudagraph_policy, typing.Mapping) else cudagraph_policy
        )

    def __repr__(self) -> str:
        string = f"{self.__class__.__name__}()"
        return string

    def __class_getitem__(self, index):
        raise NotImplementedError

    def __len__(self) -> int:
        """Number of batches the collector yields (``total_frames / requested_frames_per_batch``, rounded up).

        Raises:
            RuntimeError: for non-terminating collectors, i.e. when ``total_frames``
                is non-positive or infinite (endless collectors store
                ``float("inf")``).
        """
        total_frames = self.total_frames
        # Infinite total_frames must be rejected explicitly: ``inf > 0`` holds and
        # the float arithmetic below would make len() fail with a TypeError.
        if 0 < total_frames < float("inf"):
            # ceil(total_frames / requested_frames_per_batch)
            return int(-(total_frames // -self.requested_frames_per_batch))
        raise RuntimeError("Non-terminating collectors do not have a length")
@accept_remote_rref_udf_invocation
class SyncDataCollector(DataCollectorBase):
"""Generic data collector for RL problems. Requires an environment constructor and a policy.
Args:
create_env_fn (Callable): a callable that returns an instance of
:class:`~torchrl.envs.EnvBase` class.
policy (Callable): Policy to be executed in the environment.
Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
If ``None`` is provided, the policy used will be a
:class:`~torchrl.collectors.RandomPolicy` instance with the environment
``action_spec``.
Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
This is the recommended usage of the collector.
Other callables are accepted too:
If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
Keyword Args:
frames_per_batch (int): A keyword-only argument representing the total
number of elements in a batch.
total_frames (int): A keyword-only argument representing the total
number of frames returned by the collector
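during its lifespan. If ``total_frames`` is not divisible by
``frames_per_batch``, a warning is issued (when ``RL_WARNINGS`` is enabled)
and the last batch is still collected in full, so a few additional frames are returned.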
Endless collectors can be created by passing ``total_frames=-1``.
Defaults to ``-1`` (endless collector).
device (int, str or torch.device, optional): The generic device of the
collector. The ``device`` args fills any non-specified device: if
``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
``env_device`` is not specified, its value will be set to ``device``.
Defaults to ``None`` (No default device).
storing_device (int, str or torch.device, optional): The device on which
the output :class:`~tensordict.TensorDict` will be stored.
If ``device`` is passed and ``storing_device`` is ``None``, it will
default to the value indicated by ``device``.
For long trajectories, it may be necessary to store the data on a different
device than the one where the policy and env are executed.
Defaults to ``None`` (the output tensordict isn't on a specific device,
leaf tensors sit on the device where they were created).
env_device (int, str or torch.device, optional): The device on which
the environment should be cast (or executed if that functionality is
supported). If not specified and the env has a non-``None`` device,
``env_device`` will default to that value. If ``device`` is passed
and ``env_device=None``, it will default to ``device``. If the value
as such specified of ``env_device`` differs from ``policy_device``
and one of them is not ``None``, the data will be cast to ``env_device``
before being passed to the env (i.e., passing different devices to
policy and env is supported). Defaults to ``None``.
policy_device (int, str or torch.device, optional): The device on which
the policy should be cast.
If ``device`` is passed and ``policy_device=None``, it will default
to ``device``. If the value as such specified of ``policy_device``
differs from ``env_device`` and one of them is not ``None``,
the data will be cast to ``policy_device`` before being passed to
the policy (i.e., passing different devices to policy and env is
supported). Defaults to ``None``.
create_env_kwargs (dict, optional): Dictionary of kwargs for
``create_env_fn``.
max_frames_per_traj (int, optional): Maximum steps per trajectory.
Note that a trajectory can span across multiple batches (unless
``reset_at_each_iter`` is set to ``True``, see below).
Once a trajectory reaches ``max_frames_per_traj`` steps, the environment is reset.
If the environment wraps multiple environments together, the number
of steps is tracked for each environment independently. Negative
values are allowed, in which case this argument is ignored.
Defaults to ``None`` (i.e., no maximum number of steps).
init_random_frames (int, optional): Number of frames for which the
policy is ignored before it is called. This feature is mainly
intended to be used in offline/model-based settings, where a
batch of random trajectories can be used to initialize training.
If provided, it will be rounded up to the closest multiple of frames_per_batch.
Defaults to ``None`` (i.e. no random frames).
reset_at_each_iter (bool, optional): Whether environments should be reset
at the beginning of a batch collection.
Defaults to ``False``.
postproc (Callable, optional): A post-processing transform, such as
a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
instance.
Defaults to ``None``.
split_trajs (bool, optional): Boolean indicating whether the resulting
TensorDict should be split according to the trajectories.
See :func:`~torchrl.collectors.utils.split_trajectories` for more
information.
Defaults to ``False``.
exploration_type (ExplorationType, optional): interaction mode to be used when
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
return_same_td (bool, optional): if ``True``, the same TensorDict
will be returned at each iteration, with its values
updated. This feature should be used cautiously: if the same
tensordict is added to a replay buffer for instance,
the whole content of the buffer will be identical.
Default is ``False``.
interruptor (_Interruptor, optional):
An _Interruptor object that can be used from outside the class to control rollout collection.
The _Interruptor class has methods ``start_collection`` and ``stop_collection``, which allow to implement
strategies such as preemptively stopping rollout collection.
Defaults to ``None``.
set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
Truncated keys can be set through ``env.add_truncated_keys``.
Defaults to ``False``.
use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
This isn't compatible with environments with dynamic specs, nor with ``replay_buffer``.
Defaults to ``True`` for envs without dynamic specs when no ``replay_buffer`` is passed,
``False`` otherwise.
replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordict
but populate the buffer instead. Defaults to ``None``.
trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be assumed
to be compatible with the collector. This defaults to ``True`` for
:class:`~torchrl.envs.utils.RandomPolicy` instances and CudaGraphModules, and ``False`` otherwise.
compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled
using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it
will be used to compile the policy.
cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
If a dictionary of kwargs is passed, it will be used to wrap the policy.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = SyncDataCollector(
... create_env_fn=env_maker,
... policy=policy,
... total_frames=2000,
... max_frames_per_traj=50,
... frames_per_batch=200,
... init_random_frames=-1,
... reset_at_each_iter=False,
... device="cpu",
... storing_device="cpu",
... )
>>> for i, data in enumerate(collector):
... if i == 2:
... print(data)
... break
TensorDict(
fields={
action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
collector: TensorDict(
fields={
traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([200]),
device=cpu,
is_shared=False),
done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([200]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([200]),
device=cpu,
is_shared=False)
>>> del collector
The collector delivers batches of data that are marked with a ``"time"``
dimension.
Examples:
>>> assert data.names[-1] == "time"
"""
def __init__(
    self,
    create_env_fn: Union[
        EnvBase, "EnvCreator", Sequence[Callable[[], EnvBase]]  # noqa: F821
    ],  # noqa: F821
    policy: Optional[
        Union[
            TensorDictModule,
            Callable[[TensorDictBase], TensorDictBase],
        ]
    ] = None,
    *,
    frames_per_batch: int,
    total_frames: int = -1,
    device: DEVICE_TYPING = None,
    storing_device: DEVICE_TYPING = None,
    policy_device: DEVICE_TYPING = None,
    env_device: DEVICE_TYPING = None,
    create_env_kwargs: dict | None = None,
    max_frames_per_traj: int | None = None,
    init_random_frames: int | None = None,
    reset_at_each_iter: bool = False,
    postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
    split_trajs: bool | None = None,
    exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
    return_same_td: bool = False,
    reset_when_done: bool = True,
    interruptor=None,
    set_truncated: bool = False,
    use_buffers: bool | None = None,
    replay_buffer: ReplayBuffer | None = None,
    trust_policy: bool = None,
    compile_policy: bool | Dict[str, Any] | None = None,
    cudagraph_policy: bool | Dict[str, Any] | None = None,
    **kwargs,
):
    """Build the collector: create the env, resolve devices, prepare the policy and rollout buffers.

    See the class docstring for the meaning of each argument.

    Raises:
        RuntimeError: if ``create_env_kwargs`` are given with an env instance that
            is not a batched env, or if a device type has no known sync function.
        TypeError: for unknown keyword arguments, or when ``replay_buffer`` is
            combined with ``postproc`` or ``use_buffers``.
        ValueError: if ``reset_when_done`` is falsy, or if ``max_frames_per_traj``
            is set while the env already exposes a ``step_count`` key.
    """
    from torchrl.envs.batched_envs import BatchedEnvBase

    self.closed = True

    if create_env_kwargs is None:
        create_env_kwargs = {}
    if not isinstance(create_env_fn, EnvBase):
        env = create_env_fn(**create_env_kwargs)
    else:
        env = create_env_fn
        if create_env_kwargs:
            # kwargs can only be forwarded to an already-built env if it is batched.
            if not isinstance(env, BatchedEnvBase):
                raise RuntimeError(
                    "kwargs were passed to SyncDataCollector but they can't be set "
                    f"on environment of type {type(create_env_fn)}."
                )
            env.update_kwargs(create_env_kwargs)

    if policy is None:
        policy = RandomPolicy(env.full_action_spec)
    if trust_policy is None:
        trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule))
    self.trust_policy = trust_policy
    self._read_compile_kwargs(compile_policy, cudagraph_policy)

    ##########################
    # Trajectory pool
    self._traj_pool_val = kwargs.pop("traj_pool", None)
    if kwargs:
        raise TypeError(
            f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}."
        )

    ##########################
    # Setting devices:
    # The rule is the following:
    # - If no device is passed, all devices are assumed to work OOB.
    #   The tensordict used for output is not on any device (ie, actions and observations
    #   can be on a different device).
    # - If the ``device`` is passed, it is used for all devices (storing, env and policy)
    #   unless overridden by another kwarg.
    # - The rest of the kwargs control the respective device.
    storing_device, policy_device, env_device = self._get_devices(
        storing_device=storing_device,
        policy_device=policy_device,
        env_device=env_device,
        device=device,
    )

    def _sync_fn(dev):
        # Pick the synchronization callable used for a given device.
        # Nothing to sync for an unspecified device; CUDA handles sync itself.
        if dev is None or dev.type == "cuda":
            return _do_nothing
        if torch.cuda.is_available():
            return torch.cuda.synchronize
        if torch.backends.mps.is_available() and hasattr(torch, "mps"):
            # Will break for older PT versions which don't have torch.mps
            return torch.mps.synchronize
        if dev.type == "cpu":
            return _do_nothing
        raise RuntimeError("Non supported device")

    self.storing_device = storing_device
    self._sync_storage = _sync_fn(self.storing_device)
    self.env_device = env_device
    self._sync_env = _sync_fn(self.env_device)
    self.policy_device = policy_device
    self._sync_policy = _sync_fn(self.policy_device)
    self.device = device

    # Check if we need to cast things from device to device
    # If the policy has a None device and the env too, no need to cast (we don't know
    # and assume the user knows what she's doing).
    # If the devices match we're happy too.
    # Only if the values differ we need to cast
    self._cast_to_policy_device = self.policy_device != self.env_device

    self.env: EnvBase = env
    del env
    self.replay_buffer = replay_buffer
    if self.replay_buffer is not None:
        if postproc is not None:
            raise TypeError("postproc must be None when a replay buffer is passed.")
        if use_buffers:
            raise TypeError("replay_buffer is exclusive with use_buffers.")
    if use_buffers is None:
        use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None
    self._use_buffers = use_buffers
    self.closed = False

    if not reset_when_done:
        raise ValueError("reset_when_done is deprecated.")
    self.reset_when_done = reset_when_done
    self.n_env = self.env.batch_size.numel()

    self.policy, self.get_weights_fn = self._get_policy_and_device(
        policy=policy,
        observation_spec=self.env.observation_spec,
    )
    if isinstance(self.policy, nn.Module):
        self.policy_weights = TensorDict.from_module(self.policy, as_module=True)
    else:
        self.policy_weights = TensorDict()

    if self.compiled_policy:
        self.policy = compile_with_warmup(self.policy, **self.compiled_policy_kwargs)
    if self.cudagraphed_policy:
        self.policy = CudaGraphModule(self.policy, **self.cudagraphed_policy_kwargs)

    if self.env_device:
        self.env: EnvBase = self.env.to(self.env_device)
    elif self.env.device is not None:
        # we did not receive an env device, we use the device of the env
        self.env_device = self.env.device

    # If the storing device is not the same as the policy device, we have
    # no guarantee that the "next" entry from the policy will be on the
    # same device as the collector metadata.
    self._cast_to_env_device = self._cast_to_policy_device or (
        self.env.device != self.storing_device
    )

    # 0 means "no per-trajectory step limit".
    self.max_frames_per_traj = (
        int(max_frames_per_traj) if max_frames_per_traj is not None else 0
    )
    if self.max_frames_per_traj > 0:
        # let's check that there is no StepCounter yet
        for key in self.env.output_spec.keys(True, True):
            if isinstance(key, str):
                key = (key,)
            if "step_count" in key:
                raise ValueError(
                    "A 'step_count' key is already present in the environment "
                    "and the 'max_frames_per_traj' argument may conflict with "
                    "a 'StepCounter' that has already been set. "
                    "Possible solutions: Set max_frames_per_traj to 0 or "
                    "remove the StepCounter limit from the environment transforms."
                )
        self.env = TransformedEnv(
            self.env, StepCounter(max_steps=self.max_frames_per_traj)
        )

    if total_frames is None or total_frames < 0:
        total_frames = float("inf")
    else:
        remainder = total_frames % frames_per_batch
        if remainder != 0 and RL_WARNINGS:
            warnings.warn(
                f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
                f"This means {frames_per_batch - remainder} additional frames will be collected. "
                "To silence this message, set the environment variable RL_WARNINGS to False."
            )
    self.total_frames = (
        int(total_frames) if total_frames != float("inf") else total_frames
    )
    self.reset_at_each_iter = reset_at_each_iter
    self.init_random_frames = (
        int(init_random_frames) if init_random_frames not in (None, -1) else 0
    )
    if (
        init_random_frames not in (-1, None, 0)
        and init_random_frames % frames_per_batch != 0
        and RL_WARNINGS
    ):
        warnings.warn(
            f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), "
            f"this results in more init_random_frames than requested"
            f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch}). "
            "To silence this message, set the environment variable RL_WARNINGS to False."
        )

    self.postproc = postproc
    if (
        self.postproc is not None
        and hasattr(self.postproc, "to")
        and self.storing_device
    ):
        self.postproc.to(self.storing_device)
    if frames_per_batch % self.n_env != 0 and RL_WARNINGS:
        warnings.warn(
            f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
            f"this results in more frames_per_batch per iteration than requested"
            f" ({-(-frames_per_batch // self.n_env) * self.n_env}). "
            "To silence this message, set the environment variable RL_WARNINGS to False."
        )
    # Frames per sub-environment, rounded up (ceil division).
    self.frames_per_batch = -(-frames_per_batch // self.n_env)
    self.requested_frames_per_batch = self.frames_per_batch * self.n_env
    self.exploration_type = (
        exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE
    )
    self.return_same_td = return_same_td
    self.set_truncated = set_truncated

    self._make_shuttle()
    if self._use_buffers:
        self._make_final_rollout()
    self._set_truncated_keys()

    if split_trajs is None:
        split_trajs = False
    self.split_trajs = split_trajs
    self._exclude_private_keys = True
    self.interruptor = interruptor
    self._frames = 0
    self._iter = -1
@property
def _traj_pool(self):
    """Trajectory-id pool of this collector, created on first access if none was provided."""
    existing = getattr(self, "_traj_pool_val", None)
    if existing is not None:
        return existing
    self._traj_pool_val = _TrajectoryPool()
    return self._traj_pool_val
def _make_shuttle(self):
    """Create ``self._shuttle``, the tensordict carrying data from env to policy and back.

    The shuttle starts from ``self.env.reset()`` and receives fresh trajectory
    ids under ``("collector", "traj_ids")``.
    """
    with torch.no_grad():
        self._shuttle = self.env.reset()
    # When the policy and env devices differ (or the env device is unknown),
    # the shuttle must not be bound to a device.
    drop_device = bool(
        (self.policy_device != self.env_device) or self.env_device is None
    )
    self._shuttle_has_no_device = drop_device
    if drop_device:
        self._shuttle.clear_device_()

    fresh_ids = self._traj_pool.get_traj_and_increment(
        self.n_env, device=self.storing_device
    )
    self._shuttle.set(("collector", "traj_ids"), fresh_ids.view(self.env.batch_size))
def _make_final_rollout(self):
    """Pre-allocate ``self._final_rollout``, the zeroed storage template for a batch.

    The template starts from ``self.env.fake_tensordict()`` and is extended with
    the keys the policy writes. Those keys come from ``self.policy.spec`` when that
    spec exists and has no ``None`` entries. Otherwise they are discovered by
    running the policy once on a copy of ``self._shuttle``. The result is then
    given a trailing ``"time"`` dimension of length ``self.frames_per_batch``, is
    zeroed, and receives an int64 ``("collector", "traj_ids")`` entry.

    Side effects:
        Sets ``self._final_rollout``, ``self._policy_output_keys`` and
        ``self._env_output_keys``. Reads ``self._shuttle`` when the policy has to
        be probed.

    Raises:
        RuntimeError: if the probed policy output contains a
            ``LazyStackedTensorDict`` with exclusive keys.
    """
    with torch.no_grad():
        self._final_rollout = self.env.fake_tensordict()

    # If storing device is not None, we use this to cast the storage.
    # If it is None and the env and policy are on the same device,
    # the storing device is already the same as those, so we don't need
    # to consider this use case.
    # In all other cases, we can't really put a device on the storage,
    # since at least one data source has a device that is not clear.
    if self.storing_device:
        self._final_rollout = self._final_rollout.to(
            self.storing_device, non_blocking=True
        )
    else:
        # erase all devices
        self._final_rollout.clear_device_()

    # If the policy has a valid spec, we use it.
    # NOTE: this stays a set on the spec branch but becomes a list on the probe
    # branch below.
    self._policy_output_keys = set()
    if (
        hasattr(self.policy, "spec")
        and self.policy.spec is not None
        and all(v is not None for v in self.policy.spec.values(True, True))
    ):
        if any(
            key not in self._final_rollout.keys(isinstance(key, tuple))
            for key in self.policy.spec.keys(True, True)
        ):
            # if policy spec is non-empty, all the values are not None and the keys
            # match the out_keys we assume the user has given all relevant information
            # the policy could have more keys than the env:
            policy_spec = self.policy.spec
            if policy_spec.ndim < self._final_rollout.ndim:
                policy_spec = policy_spec.expand(self._final_rollout.shape)
            for key, spec in policy_spec.items(True, True):
                self._policy_output_keys.add(key)
                if key in self._final_rollout.keys(True):
                    continue
                self._final_rollout.set(key, spec.zero())

    else:
        # otherwise, we perform a small number of steps with the policy to
        # determine the relevant keys with which to pre-populate _final_rollout.
        # This is the safest thing to do if the spec has None fields or if there is
        # no spec at all.
        # See #505 for additional context.
        self._final_rollout.update(self._shuttle.copy())
        with torch.no_grad():
            policy_input = self._shuttle.copy()
            if self.policy_device:
                policy_input = policy_input.to(self.policy_device)
            # we cast to policy device, we'll deal with the device later
            policy_input_copy = policy_input.copy()
            policy_input_clone = (
                policy_input.clone()
            )  # to test if values have changed in-place
            if self.compiled_policy:
                # presumably required before each call of a compiled (CUDA-graph) policy
                cudagraph_mark_step_begin()
            policy_output = self.policy(policy_input)

            # check that we don't have exclusive keys, because they don't appear in keys
            def check_exclusive(val):
                if (
                    isinstance(val, LazyStackedTensorDict)
                    and val._has_exclusive_keys
                ):
                    raise RuntimeError(
                        "LazyStackedTensorDict with exclusive keys are not permitted in collectors. "
                        "Consider using a placeholder for missing keys."
                    )

            policy_output._fast_apply(
                check_exclusive, call_on_nested=True, filter_empty=True
            )

            # Use apply, because it works well with lazy stacks
            # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit
            # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has
            # changed them here).
            # This will cause a failure to update entries when policy and env device mismatch and
            # casting is necessary.
            def filter_policy(name, value_output, value_input, value_input_clone):
                # Keep an output entry if it is new, or if the policy produced a
                # different tensor whose device or values differ from the input.
                # The negation is parenthesized so it applies element-wise: the key
                # is kept when *any* element changed. Written as ``~x.any()``, the
                # ``~`` would negate the result of ``.any()``, so the key would be
                # kept only when *every* element changed.
                if (value_input is None) or (
                    (value_output is not value_input)
                    and (
                        value_output.device != value_input_clone.device
                        or (~torch.isclose(value_output, value_input_clone)).any()
                    )
                ):
                    return value_output

            filtered_policy_output = policy_output.apply(
                filter_policy,
                policy_input_copy,
                policy_input_clone,
                default=None,
                filter_empty=True,
                named=True,
            )
            self._policy_output_keys = list(
                self._policy_output_keys.union(
                    set(filtered_policy_output.keys(True, True))
                )
            )
            self._final_rollout.update(
                policy_output.select(*self._policy_output_keys)
            )
            del filtered_policy_output, policy_output, policy_input

    _env_output_keys = []
    for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]:
        _env_output_keys += list(self.env.output_spec[spec].keys(True, True))
    self._env_output_keys = _env_output_keys
    # Add a trailing time dimension of length frames_per_batch and zero the storage.
    self._final_rollout = (
        self._final_rollout.unsqueeze(-1)
        .expand(*self.env.batch_size, self.frames_per_batch)
        .clone()
        .zero_()
    )

    # in addition to outputs of the policy, we add traj_ids to
    # _final_rollout which will be collected during rollout
    self._final_rollout.set(
        ("collector", "traj_ids"),
        torch.zeros(
            *self._final_rollout.batch_size,
            dtype=torch.int64,
            device=self.storing_device,
        ),
    )
    self._final_rollout.refine_names(..., "time")
def _set_truncated_keys(self):
    """Store in ``self._truncated_keys`` the env done-keys ending with ``"truncated"``.

    The list is left empty unless ``self.set_truncated`` is truthy.

    Raises:
        RuntimeError: if ``self.set_truncated`` is truthy but none of
            ``self.env.done_keys`` ends with ``"truncated"``.
    """
    self._truncated_keys = []
    if not self.set_truncated:
        return
    matches = [key for key in self.env.done_keys if _ends_with(key, "truncated")]
    if not matches:
        raise RuntimeError(
            "set_truncated was set to True but no truncated key could be found "
            "in the environment. Make sure the truncated keys are properly set using "
            "`env.add_truncated_keys()` before passing the env to the collector."
        )
    self._truncated_keys = matches
@classmethod
def _get_devices(
    cls,
    *,
    storing_device: torch.device,
    policy_device: torch.device,
    env_device: torch.device,
    device: torch.device,
):
    """Resolve the storing, policy and env devices of the collector.

    ``device`` is the generic fallback. Any of ``storing_device``,
    ``policy_device`` and ``env_device`` that is unspecified (``None`` or an empty
    string) inherits it. Every resolved value is passed through
    ``_make_ordinal_device``. If the storing device is still ``None`` and the env
    and policy devices are equal, the storing device is set to the env device.

    Returns:
        tuple: ``(storing_device, policy_device, env_device)``.
    """

    def _as_device(value, fallback):
        # Only ``None`` or an empty string mean "not specified". A plain truthiness
        # test would also treat the integer ordinal ``0`` (a valid device index,
        # e.g. ``torch.device(0)``) as unspecified and silently swap in ``fallback``.
        if value is None or (isinstance(value, str) and not value):
            return fallback
        return torch.device(value)

    # An unspecified generic device is passed through unchanged.
    device = _make_ordinal_device(_as_device(device, device))
    # Each specific device falls back to the (already normalized) generic device.
    storing_device = _make_ordinal_device(_as_device(storing_device, device))
    policy_device = _make_ordinal_device(_as_device(policy_device, device))
    env_device = _make_ordinal_device(_as_device(env_device, device))
    # With no storing device, store on the env/policy device when both agree.
    if storing_device is None and (env_device == policy_device):
        storing_device = env_device
    return storing_device, policy_device, env_device
# for RPC
def next(self):
    """Return the next collected batch by delegating to the parent class.

    The override adds no behavior. Per the ``# for RPC`` note it exists for RPC
    use (presumably so the method is found directly on this class; confirm).
    """
    batch = super().next()
    return batch
# for RPC
def update_policy_weights_(
    self, policy_weights: Optional[TensorDictBase] = None
) -> None:
    """Update the policy weights through the parent-class implementation.

    Per the ``# for RPC`` note, this override exists for RPC use and adds no
    behavior of its own.

    Args:
        policy_weights (TensorDictBase, optional): forwarded positionally and
            unchanged to the parent's ``update_policy_weights_``.
    """
    parent = super()
    parent.update_policy_weights_(policy_weights)
def set_seed(self, seed: int, static_seed: bool = False) -> int:
    """Seed the environment(s) held by the DataCollector.

    Args:
        seed (int): the seed handed to the environment.
        static_seed (bool, optional): if ``True``, the seed is not incremented.
            Defaults to ``False``.

    Returns:
        The seed returned by the environment's ``set_seed``. This is useful when
        the DataCollector holds more than one environment, since the seed is then
        incremented for each of them. The returned value is the seed of the last
        environment.

    Examples:
        >>> from torchrl.envs import ParallelEnv
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
        >>> env_fn = lambda: GymEnv("Pendulum-v1")
        >>> env_fn_parallel = ParallelEnv(6, env_fn)
        >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
        >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100)
        >>> out_seed = collector.set_seed(1)  # out_seed = 6
    """
    return self.env.set_seed(seed, static_seed=static_seed)
def _increment_frames(self, numel):
self._frames += numel