forked from pytorch/rl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
batched_envs.py
2528 lines (2262 loc) · 105 KB
/
batched_envs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import functools
import gc
import os
import weakref
from collections import OrderedDict
from copy import copy, deepcopy
from functools import wraps
from multiprocessing import connection
from multiprocessing.synchronize import Lock as MpLock
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from warnings import warn
import torch
from tensordict import (
is_tensor_collection,
LazyStackedTensorDict,
TensorDict,
TensorDictBase,
unravel_key,
)
from torch import multiprocessing as mp
from torchrl._utils import (
_check_for_faulty_process,
_make_ordinal_device,
_ProcessNoWarn,
logger as torchrl_logger,
VERBOSE,
)
from torchrl.data.tensor_specs import Composite, NonTensor
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData
from torchrl.envs.env_creator import get_env_metadata
# legacy
from torchrl.envs.libs.envpool import ( # noqa: F401
MultiThreadedEnv,
MultiThreadedEnvWrapper,
)
from torchrl.envs.utils import (
_aggregate_end_of_traj,
_sort_keys,
_update_during_reset,
clear_mpi_env_vars,
)
def _check_start(fun):
    """Decorator ensuring a batched env is running before ``fun`` executes.

    If ``self.is_closed`` is truthy, the shared tensordicts are created and the
    workers are started first. Otherwise, for :class:`ParallelEnv` instances,
    ``self._workers`` is passed to ``_check_for_faulty_process``.

    ``functools.wraps`` keeps the decorated method's name, docstring and
    ``__wrapped__`` reference, so introspection shows the original method.
    """

    @wraps(fun)
    def decorated_fun(self: BatchedEnvBase, *args, **kwargs):
        if self.is_closed:
            self._create_td()
            self._start_workers()
        elif isinstance(self, ParallelEnv):
            _check_for_faulty_process(self._workers)
        return fun(self, *args, **kwargs)

    return decorated_fun
class _dispatch_caller_parallel:
    """Callable proxy that broadcasts an ``attr`` query to every worker.

    Each call sends ``(attr, (args, kwargs))`` on every channel of
    ``parallel_env.parent_channels``. It then reads one ``(msg, result)`` reply
    per channel and returns the results in channel order.
    """

    def __init__(self, attr, parallel_env):
        self.attr = attr
        self.parallel_env = parallel_env

    def __call__(self, *args, **kwargs):
        env = self.parallel_env
        # The env object itself is never shipped. It is swapped for the "_self"
        # placeholder (presumably resolved on the worker side, not visible here).
        payload = ["_self" if arg is env else arg for arg in args]
        channels = env.parent_channels
        # Every request is sent before any reply is read, so no worker waits
        # on another worker's reply before receiving its own request.
        for channel in channels:
            channel.send((self.attr, (payload, kwargs)))
        replies = (channel.recv() for channel in channels)
        return [result for _msg, result in replies]

    def __iter__(self):
        # Lets the proxy be iterated directly (a no-argument call) when the
        # queried attribute is not a callable.
        return iter(self())
class _dispatch_caller_serial:
    """Calls every callable of a list with the same arguments.

    Returns the list of results, in the order of ``list_callable``.
    """

    # ``List`` takes a single type parameter; the previous ``List[Callable, Any]``
    # was an invalid annotation that made ``typing.get_type_hints`` raise.
    def __init__(self, list_callable: List[Callable[..., Any]]):
        self.list_callable = list_callable

    def __call__(self, *args, **kwargs):
        return [_callable(*args, **kwargs) for _callable in self.list_callable]
def lazy_property(prop: property):
    """Converts a property in a lazy property, that will call _set_properties when queried the first time.

    Only the getter is wrapped with :func:`lazy`. The setter, deleter and
    docstring of ``prop`` are carried over unchanged; previously the deleter
    and an explicit docstring were dropped.
    """
    return property(
        fget=lazy(prop.fget),
        fset=prop.fset,
        fdel=prop.fdel,
        doc=prop.__doc__,
    )
def lazy(fun):
    """Wraps a method so that ``self._set_properties()`` runs before it on demand.

    The wrapper checks ``self._properties_set`` on every call and only triggers
    ``_set_properties`` while that flag is falsy.
    """

    @wraps(fun)
    def _wrapper(self, *args, **kwargs):
        if self._properties_set:
            return fun(self, *args, **kwargs)
        self._set_properties()
        return fun(self, *args, **kwargs)

    return _wrapper
class _PEnvMeta(_EnvPostInit):
    """Metaclass able to replace a single-worker parallel env by a :class:`SerialEnv`.

    When ``serial_for_single=True`` is passed and the number of workers (keyword
    ``num_workers`` or first positional argument) is 1, a ``SerialEnv`` built
    from the same arguments is returned. In every other case the class is
    instantiated normally. ``mp_start_method`` is forwarded intact in that case.
    """

    def __call__(cls, *args, **kwargs):
        serial_for_single = kwargs.pop("serial_for_single", False)
        if serial_for_single:
            num_workers = kwargs.get("num_workers")
            if num_workers is None and args:
                num_workers = args[0]
            if num_workers == 1:
                # Non-parallel batched envs reject mp_start_method (TypeError in
                # BatchedEnvBase.__init__). Drop it only on this path so a
                # multi-worker ParallelEnv keeps the user's start method.
                kwargs.pop("mp_start_method", None)
                # We still use a serial to keep the shape unchanged
                return SerialEnv(*args, **kwargs)
        return super().__call__(*args, **kwargs)
class BatchedEnvBase(EnvBase):
"""Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely.
Those queries will return a list of length equal to the number of workers containing the
values resulting from those queries.
>>> env = ParallelEnv(3, my_env_fun)
>>> custom_attribute_list = env.custom_attribute
>>> custom_method_list = env.custom_method(*args)
Args:
num_workers: number of workers (i.e. env instances) to be deployed simultaneously;
create_env_fn (callable or list of callables): function (or list of functions) to be used for the environment
creation.
If a single task is used, a callable should be used and not a list of identical callables:
if a list of distinct callables is provided, the environment will be executed as if multiple, diverse tasks were
needed, which comes with a slight compute overhead;
Keyword Args:
create_env_kwargs (dict or list of dicts, optional): kwargs to be used with the environments being created;
share_individual_td (bool, optional): if ``True``, a different tensordict is created for every process/worker and a lazy
stack is returned.
default = None (False if single task);
shared_memory (bool): whether the returned tensordict will be placed in shared memory;
memmap (bool): whether the returned tensordict will be placed in memory map.
policy_proof (callable, optional): if provided, it'll be used to get the list of
tensors to return through the :obj:`step()` and :obj:`reset()` methods, such as :obj:`"hidden"` etc.
device (str, int, torch.device): The device of the batched environment can be passed.
If not, it is inferred from the env. In this case, it is assumed that
the device of all environments match. If it is provided, it can differ
from the sub-environment device(s). In that case, the data will be
automatically cast to the appropriate device during collection.
This can be used to speed up collection in case casting to device
introduces an overhead (e.g., numpy-based environments): by using
a ``"cuda"`` device for the batched environment but a ``"cpu"``
device for the nested environments, one can keep the overhead to a
minimum.
num_threads (int, optional): number of threads for this process.
Should be equal to one plus the number of processes launched within
each subprocess (or one if a single process is launched).
Defaults to the number of workers + 1.
This parameter has no effect for the :class:`~SerialEnv` class.
num_sub_threads (int, optional): number of threads of the subprocesses.
Defaults to 1 for safety: if none is indicated, launching multiple
workers may overload the CPU and harm performance.
This parameter has no effect for the :class:`~SerialEnv` class.
serial_for_single (bool, optional): if ``True``, creating a parallel environment
with a single worker will return a :class:`~SerialEnv` instead.
This option has no effect with :class:`~SerialEnv`. Defaults to ``False``.
non_blocking (bool, optional): if ``True``, device moves will be done using the
``non_blocking=True`` option. Defaults to ``False``.
mp_start_method (str, optional): the multiprocessing start method.
Uses the default start method if not indicated ('spawn' by default in
TorchRL if not initiated differently before first import).
To be used only with :class:`~torchrl.envs.ParallelEnv` subclasses.
use_buffers (bool, optional): whether communication between workers should
occur via circular preallocated memory buffers. Defaults to ``True`` unless
one of the environments has dynamic specs.
.. note:: Learn more about dynamic specs and environments :ref:`here <dynamic_envs>`.
.. note::
One can pass keyword arguments to each sub-environments using the following
technique: every keyword argument in :meth:`~.reset` will be passed to each
environment except for the ``list_of_kwargs`` argument which, if present,
should contain a list of the same length as the number of workers with the
worker-specific keyword arguments stored in a dictionary.
If a partial reset is queried, the element of ``list_of_kwargs`` corresponding
to sub-environments that are not reset will be ignored.
Examples:
>>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator
>>> make_env = EnvCreator(lambda: GymEnv("Pendulum-v1")) # EnvCreator ensures that the env is sharable. Optional in most cases.
>>> env = SerialEnv(2, make_env) # Makes 2 identical copies of the Pendulum env, runs them on the same process serially
>>> env = ParallelEnv(2, make_env) # Makes 2 identical copies of the Pendulum env, runs them on dedicated processes
>>> from torchrl.envs import DMControlEnv
>>> env = ParallelEnv(2, [
... lambda: DMControlEnv("humanoid", "stand"),
... lambda: DMControlEnv("humanoid", "walk")]) # Creates two independent copies of Humanoid, one that walks one that stands
>>> rollout = env.rollout(10) # executes 10 random steps in the environment
>>> rollout[0] # data for Humanoid stand
TensorDict(
fields={
action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
next: TensorDict(
fields={
com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False),
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False)
>>> rollout[1] # data for Humanoid walk
TensorDict(
fields={
action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
next: TensorDict(
fields={
com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False),
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False)
>>> # serial_for_single to avoid creating parallel envs if not necessary
>>> env = ParallelEnv(1, make_env, serial_for_single=True)
>>> assert isinstance(env, SerialEnv) # serial_for_single allows you to avoid creating parallel envs when not necessary
"""
# Class-level default for verbose logging; ``close`` logs a message when truthy.
_verbose: bool = VERBOSE
# NOTE(review): presumably attribute names that must not be forwarded to the
# wrapped sub-environments; the consumer is not visible here, confirm.
_excluded_wrapped_keys = [
    "is_closed",
    "parent_channels",
    "batch_size",
    "_dummy_env_str",
]
def __init__(
    self,
    num_workers: int,
    create_env_fn: Union[Callable[[], EnvBase], Sequence[Callable[[], EnvBase]]],
    *,
    create_env_kwargs: Union[dict, Sequence[dict]] = None,
    pin_memory: bool = False,
    share_individual_td: Optional[bool] = None,
    shared_memory: bool = True,
    memmap: bool = False,
    policy_proof: Optional[Callable] = None,
    device: Optional[DEVICE_TYPING] = None,
    allow_step_when_done: bool = False,
    num_threads: int = None,
    num_sub_threads: int = 1,
    serial_for_single: bool = False,
    non_blocking: bool = False,
    mp_start_method: str = None,
    use_buffers: bool = None,
):
    """Builds the batched env; the arguments are described in the class docstring.

    Raises:
        RuntimeError: if ``create_env_fn`` or ``create_env_kwargs`` is a sequence
            whose length differs from ``num_workers``, or if both
            ``shared_memory`` and ``memmap`` are requested.
        ValueError: if the deprecated ``pin_memory`` or ``allow_step_when_done``
            options are enabled.
        TypeError: if ``mp_start_method`` is passed to a non-parallel env.
    """
    super().__init__(device=device)
    self.serial_for_single = serial_for_single
    self.is_closed = True
    self.num_sub_threads = num_sub_threads
    self.num_threads = num_threads
    self._cache_in_keys = None
    self._use_buffers = use_buffers
    # A single callable, or a list repeating the same callable object, is one task.
    self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1)
    if callable(create_env_fn):
        create_env_fn = [create_env_fn for _ in range(num_workers)]
    elif len(create_env_fn) != num_workers:
        raise RuntimeError(
            f"num_workers and len(create_env_fn) mismatch, "
            f"got {len(create_env_fn)} and {num_workers}"
        )
    create_env_kwargs = {} if create_env_kwargs is None else create_env_kwargs
    if isinstance(create_env_kwargs, dict):
        # One independent copy per worker, so the dicts are never aliased.
        create_env_kwargs = [
            deepcopy(create_env_kwargs) for _ in range(num_workers)
        ]
    elif len(create_env_kwargs) != num_workers:
        # Same validation as for create_env_fn: fail early with a clear message
        # instead of an IndexError (or silently ignored entries) later on.
        raise RuntimeError(
            f"num_workers and len(create_env_kwargs) mismatch, "
            f"got {len(create_env_kwargs)} and {num_workers}"
        )
    self.policy_proof = policy_proof
    self.num_workers = num_workers
    self.create_env_fn = create_env_fn
    self.create_env_kwargs = create_env_kwargs
    self.pin_memory = pin_memory
    if pin_memory:
        raise ValueError("pin_memory for batched envs is deprecated")

    # if share_individual_td is None, we will assess later if the output can be stacked
    self.share_individual_td = share_individual_td
    self._share_memory = shared_memory
    self._memmap = memmap
    self.allow_step_when_done = allow_step_when_done
    if allow_step_when_done:
        raise ValueError("allow_step_when_done is deprecated")
    if self._share_memory and self._memmap:
        raise RuntimeError(
            "memmap and shared memory are mutually exclusive features."
        )
    self._batch_size = None
    self._device = (
        _make_ordinal_device(torch.device(device)) if device is not None else device
    )
    self._dummy_env_str = None
    self._seeds = None
    self.__dict__["_input_spec"] = None
    self.__dict__["_output_spec"] = None
    # self._prepare_dummy_env(create_env_fn, create_env_kwargs)
    self._properties_set = False
    self._get_metadata(create_env_fn, create_env_kwargs)
    self._non_blocking = non_blocking
    if mp_start_method is not None and not isinstance(self, ParallelEnv):
        raise TypeError(
            f"Cannot use mp_start_method={mp_start_method} with envs of type {type(self)}."
        )
    self._mp_start_method = mp_start_method
@property
def non_blocking(self):
    """Whether device moves use ``non_blocking=True``.

    An unset value (``None``, as left by :meth:`to`) resolves to ``True``, and
    that resolved value is stored back on the instance.
    """
    if self._non_blocking is None:
        self._non_blocking = True
    return self._non_blocking
@property
def _sync_m2w(self) -> Callable:
    """The m2w sync callable (presumably main→worker), resolved once and cached.

    On a cache miss both sync values are computed by ``_find_sync_values`` and
    stored in ``self.__dict__``.
    """
    cached = self.__dict__.get("_sync_m2w_value")
    if cached is not None:
        return cached
    m2w, w2m = self._find_sync_values()
    self.__dict__.update(_sync_m2w_value=m2w, _sync_w2m_value=w2m)
    return m2w
@property
def _sync_w2m(self) -> Callable:
    """The w2m sync callable (presumably worker→main), resolved once and cached.

    On a cache miss both sync values are computed by ``_find_sync_values`` and
    stored in ``self.__dict__``.
    """
    cached = self.__dict__.get("_sync_w2m_value")
    if cached is None:
        m2w, w2m = self._find_sync_values()
        self.__dict__["_sync_m2w_value"] = m2w
        self.__dict__["_sync_w2m_value"] = w2m
        cached = w2m
    return cached
def _find_sync_values(self):
    """Returns the m2w and w2m sync values, in that order.

    No synchronization is needed in these cases, and ``_do_nothing`` is
    returned for both:
    - buffers are disabled;
    - moves are blocking;
    - the env has no device;
    - the buffers already live on the env device.

    Otherwise the worker device is resolved, scanning the per-worker
    tensordicts if the parent has no single device. A cuda/mps sync function
    is then picked for the cpu<->cuda and cpu<->mps combinations.
    """
    if not self._use_buffers:
        return _do_nothing, _do_nothing
    # Simplest case: everything is on the same device
    worker_device = self.shared_tensordict_parent.device
    main_device = self.device
    if not self.non_blocking:
        return _do_nothing, _do_nothing
    if worker_device == main_device or main_device is None:
        # even if they're both None, there is no device-to-device movement
        return _do_nothing, _do_nothing

    if worker_device is None:
        # The parent td has no single device: look for any leaf living
        # somewhere other than the main device.
        found_other_device = False

        def _visit(leaf):
            nonlocal found_other_device
            if hasattr(leaf, "device"):
                found_other_device = found_other_device or (
                    leaf.device != main_device
                )

        for worker_td in self.shared_tensordicts:
            worker_td.apply(_visit, filter_empty=True)

        if not found_other_device:
            worker_device = main_device
        elif torch.cuda.is_available():
            worker_device = (
                torch.device("cpu")
                if main_device.type == "cuda"
                else torch.device("cuda")
            )
        elif torch.backends.mps.is_available():
            worker_device = (
                torch.device("cpu")
                if main_device.type == "mps"
                else torch.device("mps")
            )
        else:
            raise RuntimeError("Did not find a valid worker device")

    if worker_device is None or main_device is None:
        return _do_nothing, _do_nothing
    route = (worker_device.type, main_device.type)
    if route == ("cuda", "cpu"):
        return _do_nothing, _cuda_sync(worker_device)
    if route == ("mps", "cpu"):
        return _mps_sync(worker_device), _mps_sync(worker_device)
    if route == ("cpu", "cuda"):
        return _cuda_sync(main_device), _do_nothing
    if route == ("cpu", "mps"):
        return _mps_sync(main_device), _mps_sync(main_device)
    return _do_nothing, _do_nothing
def __getstate__(self):
    """Returns a shallow copy of ``__dict__`` with the cached sync callables cleared.

    The ``_sync_m2w`` / ``_sync_w2m`` properties recompute them on first use
    after unpickling.
    """
    state = dict(self.__dict__)
    state.update(_sync_m2w_value=None, _sync_w2m_value=None)
    return state
@property
def _has_dynamic_specs(self):
    """``True`` exactly when ``self._use_buffers`` is falsy (buffers disabled)."""
    return False if self._use_buffers else True
def _get_metadata(
    self, create_env_fn: List[Callable], create_env_kwargs: List[Dict]
):
    """Collects the sub-environment metadata and derives the batched-env settings.

    Sets ``self.meta_data``: one metadata expanded to ``num_workers`` for a
    single task, or one cloned metadata per worker otherwise. It then resolves
    ``self._use_buffers`` and ``self.share_individual_td``, and finally calls
    ``_set_properties``.

    An explicit ``use_buffers=False`` is honored in both the single- and
    multi-task cases. Otherwise buffers are used unless a sub-environment
    reports dynamic specs; a warning is emitted when an explicit
    ``use_buffers=True`` has to be overridden.

    Raises:
        ValueError: if ``share_individual_td=False`` was requested but the
            multi-task tensordicts cannot be stacked.
    """
    if self._single_task:
        # if EnvCreator, the metadata are already there
        meta_data: EnvMetaData = get_env_metadata(
            create_env_fn[0], create_env_kwargs[0]
        )
        self.meta_data = meta_data.expand(
            *(self.num_workers, *meta_data.batch_size)
        )
        if self._use_buffers is not False:
            _use_buffers = not self.meta_data.has_dynamic_specs
            if self._use_buffers and not _use_buffers:
                warn(
                    "A value of use_buffers=True was passed but this is incompatible "
                    "with the list of environments provided. Turning use_buffers to False."
                )
            self._use_buffers = _use_buffers
        if self.share_individual_td is None:
            self.share_individual_td = False
    else:
        self.meta_data: List[EnvMetaData] = [
            get_env_metadata(create_env_fn[i], create_env_kwargs[i]).clone()
            for i in range(len(create_env_fn))
        ]
        if self.share_individual_td is not True:
            share_individual_td = not _stackable(
                *[md.tensordict for md in self.meta_data]
            )
            if share_individual_td and self.share_individual_td is False:
                raise ValueError(
                    "share_individual_td=False was provided but share_individual_td must "
                    "be True to accommodate non-stackable tensors."
                )
            self.share_individual_td = share_individual_td
        # Same guard as the single-task branch: an explicit use_buffers=False
        # must not be silently turned back on.
        if self._use_buffers is not False:
            _use_buffers = all(
                not metadata.has_dynamic_specs for metadata in self.meta_data
            )
            if self._use_buffers and not _use_buffers:
                warn(
                    "A value of use_buffers=True was passed but this is incompatible "
                    "with the list of environments provided. Turning use_buffers to False."
                )
            self._use_buffers = _use_buffers

    self._set_properties()
def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None:
    """Updates the kwargs of each environment given a dictionary or a list of dictionaries.

    A single dict is merged into every worker's kwargs. A list is paired
    element-wise with the workers' kwargs; ``zip`` stops at the shorter of
    the two.

    Args:
        kwargs (dict or list of dict): new kwargs to use with the environments
    """
    if not isinstance(kwargs, dict):
        for current, new in zip(self.create_env_kwargs, kwargs):
            current.update(new)
        return
    for current in self.create_env_kwargs:
        current.update(kwargs)
def _get_in_keys_to_exclude(self, tensordict):
    """Returns the input-spec keys that are also keys of ``tensordict``.

    The list is computed on the first call only (intersecting
    ``input_spec.keys(True)`` with ``tensordict.keys(True, True)``). Later
    calls return the cached list whatever ``tensordict`` is passed.
    """
    cached = self._cache_in_keys
    if cached is None:
        spec_keys = set(self.input_spec.keys(True))
        cached = self._cache_in_keys = list(
            spec_keys.intersection(tensordict.keys(True, True))
        )
    return cached
def _set_properties(self):
    """Fills batch size, device, specs and the template tensordict from ``self.meta_data``.

    ``self._properties_set`` is set to ``True`` up front. Methods wrapped with
    ``lazy`` call ``_set_properties`` only while that flag is falsy, so the lazy
    ``input_spec`` read at the end does not re-enter this method.

    Raises:
        RuntimeError: if a sub-spec of the input/output specs contains an
            empty ``Composite``.
        ValueError: in the multi-task case, if no device was passed and the
            sub-environments report more than one device.
    """
    cls = type(self)

    def _check_for_empty_spec(specs: Composite):
        # Walk every nested entry of the five sub-specs (a missing sub-spec
        # defaults to an empty Composite) and reject empty Composite entries.
        # ``specs`` is returned unchanged when none are found.
        for subspec in (
            "full_state_spec",
            "full_action_spec",
            "full_done_spec",
            "full_reward_spec",
            "full_observation_spec",
        ):
            for key, spec in reversed(
                list(specs.get(subspec, default=Composite()).items(True))
            ):
                if isinstance(spec, Composite) and spec.is_empty():
                    raise RuntimeError(
                        f"The environment passed to {cls.__name__} has empty specs in {key}. Consider using "
                        f"torchrl.envs.transforms.RemoveEmptySpecs to remove the empty specs."
                    )
        return specs

    meta_data = self.meta_data
    self._properties_set = True
    if self._single_task:
        # A single metadata object (already expanded to num_workers) describes
        # every worker.
        self._batch_size = meta_data.batch_size
        device = meta_data.device
        if self._device is None:
            self._device = device
        input_spec = _check_for_empty_spec(meta_data.specs["input_spec"].to(device))
        output_spec = _check_for_empty_spec(
            meta_data.specs["output_spec"].to(device)
        )

        self.action_spec = input_spec["full_action_spec"]
        self.state_spec = input_spec["full_state_spec"]
        self.observation_spec = output_spec["full_observation_spec"]
        self.reward_spec = output_spec["full_reward_spec"]
        self.done_spec = output_spec["full_done_spec"]

        self._dummy_env_str = meta_data.env_str
        # NOTE(review): _env_tensordict aliases meta_data.tensordict (no copy),
        # so the device edits below presumably also mutate the metadata; confirm.
        self._env_tensordict = meta_data.tensordict
        if device is None:  # In other cases, the device will be mapped later
            # No common device: clear it, then move each entry to the device
            # recorded for its key in meta_data.device_map.
            self._env_tensordict.clear_device_()
            device_map = meta_data.device_map

            def map_device(key, value, device_map=device_map):
                return value.to(device_map[key])

            self._env_tensordict.named_apply(
                map_device, nested_keys=True, filter_empty=True
            )

        self._batch_locked = meta_data.batch_locked
    else:
        # One metadata object per worker: specs and tensordicts are stacked
        # along a new leading dim of size num_workers.
        self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size])
        devices = set()
        for _meta_data in meta_data:
            device = _meta_data.device
            devices.add(device)
        if self._device is None:
            if len(devices) > 1:
                raise ValueError(
                    f"The device wasn't passed to {type(self)}, but more than one device was found in the sub-environments. "
                    f"Please indicate a device to be used for collection."
                )
            device = list(devices)[0]
            self._device = device

        input_spec = []
        for md in meta_data:
            input_spec.append(_check_for_empty_spec(md.specs["input_spec"]))
        input_spec = torch.stack(input_spec, 0)
        output_spec = []
        for md in meta_data:
            output_spec.append(_check_for_empty_spec(md.specs["output_spec"]))
        output_spec = torch.stack(output_spec, 0)

        self.action_spec = input_spec["full_action_spec"]
        self.state_spec = input_spec["full_state_spec"]

        self.observation_spec = output_spec["full_observation_spec"]
        self.reward_spec = output_spec["full_reward_spec"]
        self.done_spec = output_spec["full_done_spec"]

        # Only the first worker's description is used for the repr.
        self._dummy_env_str = str(meta_data[0])
        if self.share_individual_td:
            # Lazy stack: per-worker tensordicts are kept as-is.
            self._env_tensordict = LazyStackedTensorDict.lazy_stack(
                [meta_data.tensordict for meta_data in meta_data], 0
            )
        else:
            self._env_tensordict = torch.stack(
                [meta_data.tensordict for meta_data in meta_data], 0
            )
        self._batch_locked = meta_data[0].batch_locked
    self.has_lazy_inputs = contains_lazy_spec(self.input_spec)
def state_dict(self) -> OrderedDict:
    """Not supported by batched environments; always raises ``NotImplementedError``."""
    raise NotImplementedError()

def load_state_dict(self, state_dict: OrderedDict) -> None:
    """Not supported by batched environments; always raises ``NotImplementedError``."""
    raise NotImplementedError()
# Re-expose EnvBase's properties through lazy_property: reading any of them
# first calls self._set_properties() while self._properties_set is falsy.
batch_size = lazy_property(EnvBase.batch_size)
device = lazy_property(EnvBase.device)
input_spec = lazy_property(EnvBase.input_spec)
output_spec = lazy_property(EnvBase.output_spec)
def _create_td(self) -> None:
    """Creates self.shared_tensordict_parent, a TensorDict used to store the most recent observations.

    Also computes the key bookkeeping: ``_non_tensor_keys``, the
    ``_env_input_keys`` / ``_env_output_keys`` / ``_env_obs_keys`` lists and
    the ``_selected_*`` key sets. It then places the buffers in shared memory
    or a memory map when requested, locks the per-worker tensordicts and
    caches their keys. Does nothing when ``self._use_buffers`` is falsy.

    Raises:
        RuntimeError: if the template tensordict's leading dim is not
            ``num_workers``, or (non-individual case only) if
            ``share_memory_()`` / ``memmap_()`` did not take effect.
    """
    if not self._use_buffers:
        return
    shared_tensordict_parent = self._env_tensordict.clone()
    if self._env_tensordict.shape[0] != self.num_workers:
        raise RuntimeError(
            "batched environment base tensordict has the wrong shape"
        )

    # Non-tensor keys
    # (every key whose spec is a NonTensor, across action/state/obs/reward/done)
    non_tensor_keys = []
    for spec in (
        self.full_action_spec,
        self.full_state_spec,
        self.full_observation_spec,
        self.full_reward_spec,
        self.full_done_spec,
    ):
        for key, _spec in spec.items(True, True):
            if isinstance(_spec, NonTensor):
                non_tensor_keys.append(key)
    self._non_tensor_keys = non_tensor_keys

    if self._single_task:
        # NOTE(review): unlike the multi-task branch, non-tensor keys are not
        # filtered out of these key lists; confirm this is intended.
        self._env_input_keys = sorted(
            list(self.input_spec["full_action_spec"].keys(True, True))
            + list(self.state_spec.keys(True, True)),
            key=_sort_keys,
        )
        self._env_output_keys = []
        self._env_obs_keys = []
        for key in self.output_spec["full_observation_spec"].keys(True, True):
            self._env_output_keys.append(key)
            self._env_obs_keys.append(key)
        self._env_output_keys += self.reward_keys + self.done_keys
    else:
        # this is only possible if _single_task=False
        # Union the keys over every worker's specs, then drop non-tensor keys.
        env_input_keys = set()
        for meta_data in self.meta_data:
            if meta_data.specs["input_spec", "full_state_spec"] is not None:
                env_input_keys = env_input_keys.union(
                    meta_data.specs["input_spec", "full_state_spec"].keys(
                        True, True
                    )
                )
            env_input_keys = env_input_keys.union(
                meta_data.specs["input_spec", "full_action_spec"].keys(True, True)
            )
        env_output_keys = set()
        env_obs_keys = set()
        for meta_data in self.meta_data:
            env_obs_keys = env_obs_keys.union(
                key
                for key in meta_data.specs["output_spec"][
                    "full_observation_spec"
                ].keys(True, True)
            )
            env_output_keys = env_output_keys.union(
                meta_data.specs["output_spec"]["full_observation_spec"].keys(
                    True, True
                )
            )
        env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys)
        env_obs_keys = [
            key for key in env_obs_keys if key not in self._non_tensor_keys
        ]
        env_input_keys = [
            key for key in env_input_keys if key not in self._non_tensor_keys
        ]
        env_output_keys = [
            key for key in env_output_keys if key not in self._non_tensor_keys
        ]
        self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
        self._env_input_keys = sorted(env_input_keys, key=_sort_keys)
        self._env_output_keys = sorted(env_output_keys, key=_sort_keys)

    reset_keys = self.reset_keys
    self._selected_keys = (
        set(self._env_output_keys)
        .union(self._env_input_keys)
        .union(self._env_obs_keys)
        .union(set(self.done_keys))
    )
    self._selected_keys = self._selected_keys.union(reset_keys)

    # input keys
    self._selected_input_keys = {unravel_key(key) for key in self._env_input_keys}
    # output keys after reset
    self._selected_reset_keys = {
        unravel_key(key) for key in self._env_obs_keys + self.done_keys + reset_keys
    }
    # output keys after reset, filtered
    self._selected_reset_keys_filt = {
        unravel_key(key) for key in self._env_obs_keys + self.done_keys
    }
    # output keys after step
    self._selected_step_keys = {unravel_key(key) for key in self._env_output_keys}

    if not self.share_individual_td:
        # One stacked buffer: drop non-tensor entries and keep the selected
        # root keys plus ("next", key) for every output key.
        shared_tensordict_parent = shared_tensordict_parent.filter_non_tensor_data()
        shared_tensordict_parent = shared_tensordict_parent.select(
            *self._selected_keys,
            *(unravel_key(("next", key)) for key in self._env_output_keys),
            strict=False,
        )
        self.shared_tensordict_parent = shared_tensordict_parent
    else:
        # Multi-task: we share tensordict that *may* have different keys
        shared_tensordict_parent = [
            tensordict.select(
                *self._selected_keys,
                *(unravel_key(("next", key)) for key in self._env_output_keys),
                strict=False,
            ).filter_non_tensor_data()
            for tensordict in shared_tensordict_parent
        ]
        shared_tensordict_parent = LazyStackedTensorDict.lazy_stack(
            shared_tensordict_parent,
            0,
        )
        self.shared_tensordict_parent = shared_tensordict_parent

    if self.share_individual_td:
        if not isinstance(self.shared_tensordict_parent, LazyStackedTensorDict):
            # NOTE(review): the branch above always builds a lazy stack when
            # share_individual_td is truthy, so this looks unreachable; confirm.
            self.shared_tensordicts = [
                td.clone() for td in self.shared_tensordict_parent.unbind(0)
            ]
            self.shared_tensordict_parent = LazyStackedTensorDict.lazy_stack(
                self.shared_tensordicts, 0
            )
        else:
            # Multi-task: we share tensordict that *may* have different keys
            # LazyStacked already stores this so we don't need to do anything
            self.shared_tensordicts = self.shared_tensordict_parent
        if self._share_memory:
            self.shared_tensordict_parent.share_memory_()
        elif self._memmap:
            self.shared_tensordict_parent.memmap_()
    else:
        if self._share_memory:
            self.shared_tensordict_parent.share_memory_()
            if not self.shared_tensordict_parent.is_shared():
                raise RuntimeError("share_memory_() failed")
        elif self._memmap:
            self.shared_tensordict_parent.memmap_()
            if not self.shared_tensordict_parent.is_memmap():
                raise RuntimeError("memmap_() failed")
        self.shared_tensordicts = self.shared_tensordict_parent.unbind(0)
    for td in self.shared_tensordicts:
        td.lock_()

    # we cache all the keys of the shared parent td for future use. This is
    # safe since the td is locked.
    self._cache_shared_keys = set(self.shared_tensordict_parent.keys(True, True))

    # Cached views: the "next" sub-tensordict, and the root without "next" and
    # the reset keys.
    self._shared_tensordict_parent_next = self.shared_tensordict_parent.get("next")
    self._shared_tensordict_parent_root = self.shared_tensordict_parent.exclude(
        "next", *self.reset_keys
    )
def _start_workers(self) -> None:
"""Starts the various envs."""
raise NotImplementedError
def __repr__(self) -> str:
if self._dummy_env_str is None:
self._dummy_env_str = self._set_properties()
return (
f"{self.__class__.__name__}("
f"\n\tenv={self._dummy_env_str}, "
f"\n\tbatch_size={self.batch_size})"
)
def close(self) -> None:
    """Shut down the sub-environments and mark this batched env as closed.

    Clears the cached input/output specs, delegates worker teardown to
    ``self._shutdown_workers()`` and finally adjusts torch's CPU thread
    count.

    Raises:
        RuntimeError: if the environment is already closed.
    """
    if self.is_closed:
        raise RuntimeError("trying to close a closed environment")
    if self._verbose:
        torchrl_logger.info(f"closing {self.__class__.__name__}")
    # Drop the cached specs and flag the properties as unset, presumably so
    # they are rebuilt on next access — NOTE(review): confirm with _set_properties.
    self.__dict__["_input_spec"] = None
    self.__dict__["_output_spec"] = None
    self._properties_set = False
    # Order matters: ``_shutdown_workers`` runs while ``is_closed`` is still
    # False (SerialEnv's implementation only closes its sub-envs in that case).
    self._shutdown_workers()
    self.is_closed = True
    import torchrl

    # Add ``num_workers`` threads back to torch's pool, capped at
    # ``torchrl._THREAD_POOL_INIT`` (presumably the thread count recorded when
    # torchrl was imported — NOTE(review): confirm).
    num_threads = min(
        torchrl._THREAD_POOL_INIT, torch.get_num_threads() + self.num_workers
    )
    torch.set_num_threads(num_threads)
def _shutdown_workers(self) -> None:
raise NotImplementedError
def _set_seed(self, seed: Optional[int]):
"""This method is not used in batched envs."""
pass
@lazy
def start(self) -> None:
    """Build the shared tensordicts and launch the workers.

    Raises:
        RuntimeError: if the environment is not closed (already started).
    """
    already_running = not self.is_closed
    if already_running:
        raise RuntimeError("trying to start a environment that is not closed.")
    # Buffers are created before the workers are started.
    self._create_td()
    self._start_workers()
def to(self, device: DEVICE_TYPING):
    """Move the batched environment to ``device`` and return ``self``.

    ``_non_blocking`` is always reset. If the (ordinal) target device differs
    from the current one, the stored device is updated, the cached
    ``_sync_m2w_value``/``_sync_w2m_value`` entries are cleared, and any
    cached input/output specs are moved to the new device.
    """
    self._non_blocking = None
    target = _make_ordinal_device(torch.device(device))
    if target == self.device:
        return self
    self._device = target
    attrs = self.__dict__
    attrs["_sync_m2w_value"] = None
    attrs["_sync_w2m_value"] = None
    for spec_name in ("_input_spec", "_output_spec"):
        spec = attrs[spec_name]
        if spec is not None:
            attrs[spec_name] = spec.to(target)
    return self
def _reset_proc_data(self, tensordict, tensordict_reset):
# since we call `reset` directly, all the postproc has been completed
if tensordict is not None:
if isinstance(tensordict_reset, LazyStackedTensorDict) and not isinstance(
tensordict, LazyStackedTensorDict
):
tensordict = LazyStackedTensorDict(*tensordict.unbind(0))
return _update_during_reset(tensordict_reset, tensordict, self.reset_keys)
return tensordict_reset
def add_truncated_keys(self):
    """Reject adding truncated keys at the batched level.

    Truncation entries must be added on each nested environment via
    ``sub_env.add_truncated_keys()``; this method therefore always raises.

    Raises:
        RuntimeError: unconditionally.
    """
    message = (
        "Cannot add truncated keys to a batched environment. Please add these entries to "
        "the nested environments by calling sub_env.add_truncated_keys()"
    )
    raise RuntimeError(message)
class SerialEnv(BatchedEnvBase):
"""Creates a series of environments in the same process."""
__doc__ += BatchedEnvBase.__doc__
_share_memory = False
def _start_workers(self) -> None:
_num_workers = self.num_workers
self._envs = []
weakref_set = set()
for idx in range(_num_workers):
env = self.create_env_fn[idx](**self.create_env_kwargs[idx])
# We want to avoid having the same env multiple times
# so we try to deepcopy it if needed. If we can't, we make
# the user aware that this isn't a very good idea
wr = weakref.ref(env)
if wr in weakref_set:
try:
env = deepcopy(env)
except Exception:
warn(
"Deepcopying the env failed within SerialEnv "
"but more than one copy of the same env was found. "
"This is a dangerous situation if your env keeps track "
"of some variables (e.g., state) in-place. "
"We'll use the same copy of the environment be beaware that "
"this may have important, unwanted issues for stateful "
"environments!"
)
weakref_set.add(wr)
self._envs.append(env)
self.is_closed = False
@_check_start
def state_dict(self) -> OrderedDict:
    """Gather each sub-environment's ``state_dict()`` under a ``"worker{idx}"`` key."""
    return OrderedDict(
        (f"worker{idx}", sub_env.state_dict())
        for idx, sub_env in enumerate(self._envs)
    )
@_check_start
def load_state_dict(self, state_dict: OrderedDict) -> None:
    """Load state into every sub-environment.

    ``state_dict`` is either keyed by ``"worker{idx}"`` (the layout produced
    by ``state_dict()``) or, when it has no ``"worker0"`` entry, treated as a
    single state that is handed to every worker.
    """
    if "worker0" in state_dict:
        per_worker = state_dict
    else:
        per_worker = OrderedDict(
            (f"worker{idx}", state_dict) for idx in range(self.num_workers)
        )
    for idx, sub_env in enumerate(self._envs):
        sub_env.load_state_dict(per_worker[f"worker{idx}"])
def _shutdown_workers(self) -> None:
if not self.is_closed:
for env in self._envs:
env.close()
del self._envs
@_check_start
def set_seed(
    self, seed: Optional[int] = None, static_seed: bool = False
) -> Optional[int]:
    """Seed the sub-environments one after another.

    Each env receives the seed returned by the previous one; the seed
    returned by the last env (or the input seed if there are no envs) is
    returned.
    """
    current = seed
    for sub_env in self._envs:
        current = sub_env.set_seed(current, static_seed=static_seed)
    return current
@_check_start
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
list_of_kwargs = kwargs.pop("list_of_kwargs", [kwargs] * self.num_workers)
if kwargs is not list_of_kwargs[0] and kwargs:
# this means that kwargs had more than one element and that a list was provided
for elt in list_of_kwargs:
elt.update(kwargs)
if tensordict is not None:
needs_resetting = _aggregate_end_of_traj(
tensordict, reset_keys=self.reset_keys
)
if needs_resetting.ndim > 2:
needs_resetting = needs_resetting.flatten(1, needs_resetting.ndim - 1)
if needs_resetting.ndim > 1:
needs_resetting = needs_resetting.any(-1)
elif not needs_resetting.ndim:
needs_resetting = needs_resetting.expand((self.num_workers,))
tensordict = tensordict.unbind(0)
else:
needs_resetting = torch.ones(
(self.num_workers,), device=self.device, dtype=torch.bool
)
out_tds = None
if not self._use_buffers or self._non_tensor_keys:
out_tds = [None] * self.num_workers
tds = []
for i, _env in enumerate(self._envs):
if not needs_resetting[i]:
if out_tds is not None and tensordict is not None:
out_tds[i] = tensordict[i].exclude(*self._envs[i].reset_keys)
continue
if tensordict is not None:
tensordict_ = tensordict[i]
if tensordict_.is_empty():
tensordict_ = None
else: