-
Notifications
You must be signed in to change notification settings - Fork 326
/
advantages.py
1832 lines (1643 loc) · 76.4 KB
/
advantages.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import abc
import functools
import warnings
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from functools import wraps
from typing import Callable, List, Union
import torch
from tensordict import TensorDictBase
from tensordict.nn import (
CompositeDistribution,
dispatch,
ProbabilisticTensorDictModule,
set_skip_existing,
TensorDictModule,
TensorDictModuleBase,
)
from tensordict.nn.probabilistic import interaction_type
from tensordict.utils import NestedKey
from torch import Tensor
from torchrl._utils import RL_WARNINGS
from torchrl.envs.utils import step_mdp
from torchrl.objectives.utils import _vmap_func, hold_out_net, RANDOM_MODULE_LIST
from torchrl.objectives.value.functional import (
generalized_advantage_estimate,
td0_return_estimate,
td_lambda_return_estimate,
vec_generalized_advantage_estimate,
vec_td1_return_estimate,
vec_td_lambda_return_estimate,
vtrace_advantage_estimate,
)
try:
from torch.compiler import is_dynamo_compiling
except ImportError:
from torch._dynamo import is_compiling as is_dynamo_compiling
try:
from torch import vmap
except ImportError as err:
try:
from functorch import vmap
except ImportError:
raise ImportError(
"vmap couldn't be found. Make sure you have torch>2.0 installed."
) from err
def _self_set_grad_enabled(fun):
@wraps(fun)
def new_fun(self, *args, **kwargs):
with torch.set_grad_enabled(self.differentiable):
return fun(self, *args, **kwargs)
return new_fun
def _self_set_skip_existing(fun):
@functools.wraps(fun)
def new_func(self, *args, **kwargs):
if self.skip_existing is not None:
with set_skip_existing(self.skip_existing):
return fun(self, *args, **kwargs)
return fun(self, *args, **kwargs)
return new_func
def _call_actor_net(
    actor_net: ProbabilisticTensorDictModule,
    data: TensorDictBase,
    params: TensorDictBase,
    log_prob_key: NestedKey,
):
    """Draw a sample from ``actor_net``'s distribution and return its log-probability.

    Args:
        actor_net: probabilistic actor; its ``get_dist`` builds the distribution
            and its ``_dist_sample`` draws the sample.
        data: input data. Only the actor's ``in_keys`` that are present are
            forwarded (``strict=False``).
        params: not used by this function; kept for call-site compatibility.
        log_prob_key: not used by this function; kept for call-site compatibility.

    Returns:
        ``dist.log_prob`` evaluated on a sample drawn with the interaction type
        currently in effect.
    """
    actor_inputs = data.select(*actor_net.in_keys, strict=False)
    dist = actor_net.get_dist(actor_inputs)
    # Composite distributions get explicit log_prob options: aggregate the
    # per-component probabilities, don't write in place, and don't add the sum
    # entry. (Exact output type depends on tensordict's CompositeDistribution.)
    if isinstance(dist, CompositeDistribution):
        log_prob_kwargs = dict(
            aggregate_probabilities=True,
            inplace=False,
            include_sum=False,
        )
    else:
        log_prob_kwargs = {}
    sample = actor_net._dist_sample(dist, interaction_type=interaction_type())
    return dist.log_prob(sample, **log_prob_kwargs)
class ValueEstimatorBase(TensorDictModuleBase):
    """An abstract parent class for value function modules.

    Its :meth:`ValueEstimatorBase.forward` method will compute the value (given
    by the value network) and the value estimate (given by the value estimator)
    as well as the advantage and write these values in the output tensordict.

    If only the value estimate is needed, :meth:`ValueEstimatorBase.value_estimate`
    should be used instead.

    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
        default values.

        Attributes:
            advantage (NestedKey): The input tensordict key where the advantage is written to.
                Will be used for the underlying value estimator. Defaults to ``"advantage"``.
            value_target (NestedKey): The input tensordict key where the target state value is written to.
                Will be used for the underlying value estimator. Defaults to ``"value_target"``.
            value (NestedKey): The input tensordict key where the state value is expected.
                Will be used for the underlying value estimator. Defaults to ``"state_value"``.
            reward (NestedKey): The input tensordict key where the reward is expected.
                Defaults to ``"reward"``.
            done (NestedKey): The key in the input TensorDict that indicates
                whether a trajectory is done. Defaults to ``"done"``.
            terminated (NestedKey): The key in the input TensorDict that indicates
                whether a trajectory is terminated. Defaults to ``"terminated"``.
            steps_to_next_obs (NestedKey): The key in the input tensordict
                that indicates the number of steps to the next observation.
                Defaults to ``"steps_to_next_obs"``.
            sample_log_prob (NestedKey): The key in the input tensordict that
                indicates the log probability of the sampled action. Defaults to ``"sample_log_prob"``.
        """

        advantage: NestedKey = "advantage"
        value_target: NestedKey = "value_target"
        value: NestedKey = "state_value"
        reward: NestedKey = "reward"
        done: NestedKey = "done"
        terminated: NestedKey = "terminated"
        steps_to_next_obs: NestedKey = "steps_to_next_obs"
        sample_log_prob: NestedKey = "sample_log_prob"

    default_keys = _AcceptedKeys()
    value_network: Union[TensorDictModule, Callable]
    # Lazily resolved by the ``vmap_randomness`` property (or set explicitly
    # via ``set_vmap_randomness``).
    _vmap_randomness = None

    @property
    def advantage_key(self):
        return self.tensor_keys.advantage

    @property
    def value_key(self):
        return self.tensor_keys.value

    @property
    def value_target_key(self):
        return self.tensor_keys.value_target

    @property
    def reward_key(self):
        return self.tensor_keys.reward

    @property
    def done_key(self):
        return self.tensor_keys.done

    @property
    def terminated_key(self):
        return self.tensor_keys.terminated

    @property
    def steps_to_next_obs_key(self):
        return self.tensor_keys.steps_to_next_obs

    @property
    def sample_log_prob_key(self):
        return self.tensor_keys.sample_log_prob

    @abc.abstractmethod
    def forward(
        self,
        tensordict: TensorDictBase,
        *,
        params: TensorDictBase | None = None,
        target_params: TensorDictBase | None = None,
    ) -> TensorDictBase:
        """Computes the advantage estimate given the data in tensordict.

        If a functional module is provided, a nested TensorDict containing the parameters
        (and if relevant the target parameters) can be passed to the module.

        Args:
            tensordict (TensorDictBase): A TensorDict containing the data
                (an observation key, ``"action"``, ``("next", "reward")``,
                ``("next", "done")``, ``("next", "terminated")``,
                and ``"next"`` tensordict state as returned by the environment)
                necessary to compute the value estimates and the TDEstimate.
                The data passed to this module should be structured as
                :obj:`[*B, T, *F]` where :obj:`B` are
                the batch size, :obj:`T` the time dimension and :obj:`F` the
                feature dimension(s). The tensordict must have shape ``[*B, T]``.

        Keyword Args:
            params (TensorDictBase, optional): A nested TensorDict containing the params
                to be passed to the functional value network module.
            target_params (TensorDictBase, optional): A nested TensorDict containing the
                target params to be passed to the functional value network module.

        Returns:
            An updated TensorDict with an advantage and a value_error keys as defined in the constructor.
        """
        ...

    def __init__(
        self,
        *,
        value_network: TensorDictModule,
        shifted: bool = False,
        differentiable: bool = False,
        skip_existing: bool | None = None,
        advantage_key: NestedKey = None,
        value_target_key: NestedKey = None,
        value_key: NestedKey = None,
        device: torch.device | None = None,
    ):
        """Store the configuration shared by all value estimators.

        Keyword Args:
            value_network: value operator used to retrieve value estimates.
            shifted: if ``True``, value and next value are computed in a single
                network call (see ``_call_value_nets``).
            differentiable: grad mode used by methods decorated with
                ``_self_set_grad_enabled``.
            skip_existing: value passed to ``set_skip_existing`` by methods
                decorated with ``_self_set_skip_existing``; ``None`` leaves the
                ambient setting untouched.
            advantage_key, value_target_key, value_key: removed from the
                constructor; passing anything other than ``None`` raises
                ``RuntimeError`` (use :meth:`set_keys` instead).
            device: device on which buffers are created. Defaults to
                ``torch.get_default_device()`` when available, else CPU.

        Raises:
            RuntimeError: if one of the removed ``*_key`` arguments is passed.
        """
        super().__init__()
        if device is None:
            # ``torch.get_default_device`` does not exist on older torch versions.
            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
        # this is saved for tracking only and should not be used to cast anything else than buffers during
        # init.
        self._device = device
        self._tensor_keys = None
        self.differentiable = differentiable
        self.skip_existing = skip_existing
        # Written straight into __dict__ so nn.Module does not register the value
        # network as a submodule of the estimator.
        self.__dict__["value_network"] = value_network
        self.dep_keys = {}
        self.shifted = shifted

        if advantage_key is not None:
            raise RuntimeError(
                "Setting 'advantage_key' via constructor is deprecated, use .set_keys(advantage_key='some_key') instead.",
            )
        if value_target_key is not None:
            raise RuntimeError(
                "Setting 'value_target_key' via constructor is deprecated, use .set_keys(value_target_key='some_key') instead.",
            )
        if value_key is not None:
            raise RuntimeError(
                "Setting 'value_key' via constructor is deprecated, use .set_keys(value_key='some_key') instead.",
            )

    @property
    def tensor_keys(self) -> _AcceptedKeys:
        """The active ``_AcceptedKeys`` instance, initialised with defaults on first access."""
        if self._tensor_keys is None:
            self.set_keys()
        return self._tensor_keys

    @tensor_keys.setter
    def tensor_keys(self, value):
        # Validate against the _AcceptedKeys class itself (previously this used
        # ``type(self._AcceptedKeys)``, i.e. ``type``, which rejected every real
        # instance), and store where the getter reads from (previously
        # ``self._keys``, which the getter never consulted).
        if not isinstance(value, self._AcceptedKeys):
            raise ValueError("value must be an instance of _AcceptedKeys")
        self._tensor_keys = value

    @property
    def in_keys(self):
        """Keys read by the estimator.

        The value network's ``in_keys`` come first, then the ``"next"`` reward,
        done and terminated keys, then the value network's ``in_keys`` nested
        under ``"next"``. Returns ``[]`` if the value network has no ``in_keys``
        attribute (e.g. it is ``None`` or a plain callable).
        """
        try:
            # list() so that a tuple of in_keys can be concatenated as well.
            network_in_keys = list(self.value_network.in_keys)
        except AttributeError:
            # value network does not have an `in_keys` attribute
            return []
        return (
            network_in_keys
            + [
                ("next", self.tensor_keys.reward),
                ("next", self.tensor_keys.done),
                ("next", self.tensor_keys.terminated),
            ]
            + [("next", in_key) for in_key in network_in_keys]
        )

    @property
    def out_keys(self):
        """Keys written by the estimator: the advantage and the value target."""
        return [
            self.tensor_keys.advantage,
            self.tensor_keys.value_target,
        ]

    def set_keys(self, **kwargs) -> None:
        """Set tensordict key names.

        Keyword Args:
            **kwargs: accepted key name (an ``_AcceptedKeys`` field such as
                ``value`` or ``advantage``) mapped to the NestedKey to use.

        Raises:
            ValueError: if a key value is ``None`` or not a ``str``/``tuple``.
            KeyError: if a key name is not accepted, or if ``value`` is set to a
                key that the value network does not list in its ``out_keys``.
        """
        for key, value in kwargs.items():
            # None must be checked first: it is not a str/tuple either, so the
            # type check below would otherwise make this branch unreachable.
            if value is None:
                raise ValueError("tensordict keys cannot be None")
            if not isinstance(value, (str, tuple)):
                raise ValueError(
                    f"key name must be of type NestedKey (Union[str, Tuple[str]]) but got {type(value)}"
                )
            if key not in self._AcceptedKeys.__dict__:
                raise KeyError(
                    f"{key} is not an accepted tensordict key for advantages"
                )
            if (
                key == "value"
                and hasattr(self.value_network, "out_keys")
                and (value not in self.value_network.out_keys)
            ):
                raise KeyError(
                    f"value key '{value}' not found in value network out_keys {self.value_network.out_keys}"
                )
        if self._tensor_keys is None:
            conf = asdict(self.default_keys)
            conf.update(self.dep_keys)
        else:
            conf = asdict(self._tensor_keys)
        conf.update(kwargs)
        self._tensor_keys = self._AcceptedKeys(**conf)

    def value_estimate(
        self,
        tensordict,
        target_params: TensorDictBase | None = None,
        next_value: torch.Tensor | None = None,
        **kwargs,
    ):
        """Gets a value estimate, usually used as a target value for the value network.

        If the state value key is present under ``tensordict.get(("next", self.tensor_keys.value))``
        then this value will be used without recurring to the value network.

        Args:
            tensordict (TensorDictBase): the tensordict containing the data to
                read.
            target_params (TensorDictBase, optional): A nested TensorDict containing the
                target params to be passed to the functional value network module.
            next_value (torch.Tensor, optional): the value of the next state
                or state-action pair. Exclusive with ``target_params``.
            **kwargs: the keyword arguments to be passed to the value network.

        Returns: a tensor corresponding to the state value.

        """
        raise NotImplementedError

    @property
    def is_functional(self):
        # legacy
        return False

    @property
    def is_stateless(self):
        # legacy
        return False

    def _next_value(self, tensordict, target_params, kwargs):
        """Compute the value of the next step by running the value network on ``step_mdp(tensordict)``.

        The network runs under ``hold_out_net`` when no ``target_params`` are
        given, otherwise with ``target_params`` loaded into it. If there is no
        value network, the value already stored in the stepped tensordict is
        returned.
        """
        # NOTE(review): ``kwargs`` is accepted but not forwarded to the value
        # network, despite what ``value_estimate``'s docstring says.
        step_td = step_mdp(tensordict, keep_other=False)
        if self.value_network is not None:
            with hold_out_net(
                self.value_network
            ) if target_params is None else target_params.to_module(self.value_network):
                self.value_network(step_td)
        next_value = step_td.get(self.tensor_keys.value)
        return next_value

    @property
    def vmap_randomness(self):
        """Randomness mode passed to ``vmap`` when calling the value network.

        Unless set explicitly via :meth:`set_vmap_randomness`, this resolves (and
        caches) to ``"different"`` when compiling with dynamo, or when any
        ``nn.Module`` stored in the instance ``__dict__`` contains a module from
        ``RANDOM_MODULE_LIST``. Otherwise it resolves to ``"error"``.
        """
        if self._vmap_randomness is None:
            if is_dynamo_compiling():
                self._vmap_randomness = "different"
                return "different"
            has_random_module = any(
                isinstance(module, RANDOM_MODULE_LIST)
                for val in self.__dict__.values()
                if isinstance(val, torch.nn.Module)
                for module in val.modules()
            )
            self._vmap_randomness = "different" if has_random_module else "error"
        return self._vmap_randomness

    def set_vmap_randomness(self, value):
        """Override the randomness mode used by ``vmap`` (see :attr:`vmap_randomness`)."""
        self._vmap_randomness = value

    def _get_time_dim(self, time_dim: int | None, data: TensorDictBase):
        """Resolve the time dimension of ``data`` as a non-negative index.

        Priority: the explicit ``time_dim`` argument, then a ``time_dim``
        attribute on ``self``, then the first dimension named ``"time"``, then
        the last dimension. Negative values are taken relative to ``data.ndim``.
        """
        if time_dim is not None:
            if time_dim < 0:
                time_dim = data.ndim + time_dim
            return time_dim
        time_dim_attr = getattr(self, "time_dim", None)
        if time_dim_attr is not None:
            if time_dim_attr < 0:
                time_dim_attr = data.ndim + time_dim_attr
            return time_dim_attr
        if data._has_names():
            for i, name in enumerate(data.names):
                if name == "time":
                    return i
        return data.ndim - 1

    def _call_value_nets(
        self,
        data: TensorDictBase,
        params: TensorDictBase,
        next_params: TensorDictBase,
        single_call: bool,
        value_key: NestedKey,
        detach_next: bool,
        vmap_randomness: str = "error",
        *,
        value_net: TensorDictModuleBase | None = None,
    ):
        """Compute the values at ``t`` and ``t+1`` and write them into ``data``.

        With ``single_call=True``, the current and ``"next"`` inputs are
        concatenated along the time dimension and the network is called once.
        This requires ``next_params`` to be ``None`` or the same object as
        ``params``. Otherwise, both inputs are stacked along a new leading dim
        and the network is ``vmap``-ed over it, with ``params`` and
        ``next_params`` stacked likewise when given.

        Returns:
            ``(value, next_value)``. ``next_value`` is detached if
            ``detach_next``, but only after both have been written to ``data``
            under ``value_key`` and ``("next", value_key)``.

        Raises:
            ValueError: if ``single_call`` and distinct ``params``/``next_params``
                are given, or if only one of ``params``/``next_params`` is given
                when ``single_call`` is ``False``.
        """
        if value_net is None:
            value_net = self.value_network
        in_keys = value_net.in_keys
        if single_call:
            for i, name in enumerate(data.names):
                if name == "time":
                    ndim = i + 1
                    break
            else:
                ndim = None
            if ndim is not None:
                # get data at t and last of t+1
                idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),)
                idx = (slice(None),) * (ndim - 1) + (slice(None, -1),)
                idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),)
                data_in = torch.cat(
                    [
                        data.select(*in_keys, value_key, strict=False),
                        data.get("next").select(*in_keys, value_key, strict=False)[
                            idx0
                        ],
                    ],
                    ndim - 1,
                )
            else:
                if RL_WARNINGS:
                    warnings.warn(
                        "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. "
                        "This warning can be turned off by setting the environment variable RL_WARNINGS to False."
                    )
                # No named time dim: concatenate the whole "next" block after
                # the current one along the last dim and split it back by length.
                ndim = data.ndim
                idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),)
                idx_ = (slice(None),) * (ndim - 1) + (
                    slice(data.shape[ndim - 1], None),
                )
                data_in = torch.cat(
                    [
                        data.select(*in_keys, value_key, strict=False),
                        data.get("next").select(*in_keys, value_key, strict=False),
                    ],
                    ndim - 1,
                )

            # next_params should be None or be identical to params
            if next_params is not None and next_params is not params:
                raise ValueError(
                    "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed."
                )
            if params is not None:
                with params.to_module(value_net):
                    value_est = value_net(data_in).get(value_key)
            else:
                value_est = value_net(data_in).get(value_key)
            value, value_ = value_est[idx], value_est[idx_]
        else:
            data_in = torch.stack(
                [
                    data.select(*in_keys, value_key, strict=False),
                    data.get("next").select(*in_keys, value_key, strict=False),
                ],
                0,
            )
            if (params is not None) ^ (next_params is not None):
                raise ValueError(
                    "params and next_params must be either both provided or not."
                )
            elif params is not None:
                params_stack = torch.stack([params, next_params], 0).contiguous()
                data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)(
                    data_in, params_stack
                )
            else:
                data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in)
            value_est = data_out.get(value_key)
            value, value_ = value_est[0], value_est[1]
        data.set(value_key, value)
        data.set(("next", value_key), value_)
        if detach_next:
            value_ = value_.detach()
        return value, value_
class TD0Estimator(ValueEstimatorBase):
    """Temporal Difference (TD(0)) estimate of advantage function.

    AKA bootstrapped temporal difference or 1-step return.

    Keyword Args:
        gamma (scalar): exponential mean discount.
        value_network (TensorDictModule): value operator used to retrieve
            the value estimates.
        shifted (bool, optional): if ``True``, the value and next value are
            estimated with a single call to the value network. This is faster
            but is only valid whenever (1) the ``"next"`` value is shifted by
            only one time step (which is not the case with multi-step value
            estimation, for instance) and (2) when the parameters used at time
            ``t`` and ``t+1`` are identical (which is not the case when target
            parameters are to be used). Defaults to ``False``.
        average_rewards (bool, optional): if ``True``, rewards will be standardized
            (zero mean, unit std) before the TD is computed, and the standardized
            rewards are written back into the input tensordict.
        differentiable (bool, optional): if ``True``, gradients are propagated through
            the computation of the value function. Default is ``False``.

            .. note::
              The proper way to make the function call non-differentiable is to
              decorate it in a `torch.no_grad()` context manager/decorator or
              pass detached parameters for functional modules.

        skip_existing (bool, optional): if ``True``, the value network will skip
            modules which outputs are already present in the tensordict.
            Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
            is not affected.
        advantage_key (str or tuple of str, optional): [Deprecated] the key of
            the advantage entry. Passing a value raises a ``RuntimeError``; use
            :meth:`set_keys` instead. Defaults to ``"advantage"``.
        value_target_key (str or tuple of str, optional): [Deprecated] the key
            of the value target entry. Passing a value raises a ``RuntimeError``;
            use :meth:`set_keys` instead. Defaults to ``"value_target"``.
        value_key (str or tuple of str, optional): [Deprecated] the value key to
            read from the input tensordict. Passing a value raises a
            ``RuntimeError``; use :meth:`set_keys` instead. Defaults to ``"state_value"``.
        device (torch.device, optional): the device where the buffers will be instantiated.
            Defaults to ``torch.get_default_device()``.

    """

    def __init__(
        self,
        *,
        gamma: float | torch.Tensor,
        value_network: TensorDictModule,
        shifted: bool = False,
        average_rewards: bool = False,
        differentiable: bool = False,
        advantage_key: NestedKey = None,
        value_target_key: NestedKey = None,
        value_key: NestedKey = None,
        skip_existing: bool | None = None,
        device: torch.device | None = None,
    ):
        super().__init__(
            value_network=value_network,
            differentiable=differentiable,
            shifted=shifted,
            advantage_key=advantage_key,
            value_target_key=value_target_key,
            value_key=value_key,
            skip_existing=skip_existing,
            device=device,
        )
        self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
        self.average_rewards = average_rewards

    @_self_set_skip_existing
    @_self_set_grad_enabled
    @dispatch
    def forward(
        self,
        tensordict: TensorDictBase,
        *,
        params: TensorDictBase | None = None,
        target_params: TensorDictBase | None = None,
    ) -> TensorDictBase:
        """Computes the TD(0) advantage given the data in tensordict.

        If a functional module is provided, a nested TensorDict containing the parameters
        (and if relevant the target parameters) can be passed to the module.

        Args:
            tensordict (TensorDictBase): A TensorDict containing the data
                (an observation key, ``"action"``, ``("next", "reward")``,
                ``("next", "done")``, ``("next", "terminated")``, and ``"next"``
                tensordict state as returned by the environment) necessary to
                compute the value estimates and the TDEstimate.
                The data passed to this module should be structured as
                :obj:`[*B, T, *F]` where :obj:`B` are
                the batch size, :obj:`T` the time dimension and :obj:`F` the
                feature dimension(s). The tensordict must have shape ``[*B, T]``.

        Keyword Args:
            params (TensorDictBase, optional): A nested TensorDict containing the params
                to be passed to the functional value network module. They are
                detached before use.
            target_params (TensorDictBase, optional): A nested TensorDict containing the
                target params to be passed to the functional value network module.
                If omitted while ``params`` is given, a shallow clone of the
                (detached) ``params`` is used.

        Returns:
            An updated TensorDict with an advantage and a value_error keys as defined in the constructor.

        Raises:
            RuntimeError: if ``tensordict`` has no batch dimension.

        Examples:
            >>> from tensordict import TensorDict
            >>> from tensordict.nn import TensorDictModule
            >>> from torch import nn
            >>> from torchrl.objectives.value import TD0Estimator
            >>> value_net = TensorDictModule(
            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
            ... )
            >>> module = TD0Estimator(
            ...     gamma=0.98,
            ...     value_network=value_net,
            ... )
            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
            >>> reward = torch.randn(1, 10, 1)
            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
            >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
            >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "terminated": terminated, "reward": reward}}, [1, 10])
            >>> _ = module(tensordict)
            >>> assert "advantage" in tensordict.keys()

        The module supports non-tensordict (i.e. unpacked tensordict) inputs too:

        Examples:
            >>> value_net = TensorDictModule(
            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
            ... )
            >>> module = TD0Estimator(
            ...     gamma=0.98,
            ...     value_network=value_net,
            ... )
            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
            >>> reward = torch.randn(1, 10, 1)
            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
            >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
            >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)

        """
        if tensordict.batch_dims < 1:
            # The two literals were previously concatenated without a space
            # ("...gottensordict.batch_size").
            raise RuntimeError(
                "Expected input tensordict to have at least one dimension, got "
                f"tensordict.batch_size = {tensordict.batch_size}"
            )

        if self.is_stateless and params is None:
            raise RuntimeError(
                "Expected params to be passed to advantage module but got none."
            )
        if self.value_network is not None:
            if params is not None:
                params = params.detach()
                if target_params is None:
                    target_params = params.clone(False)
            with hold_out_net(self.value_network) if (
                params is None and target_params is None
            ) else nullcontext():
                # we may still need to pass gradient, but we don't want to assign grads to
                # value net params
                value, next_value = self._call_value_nets(
                    data=tensordict,
                    params=params,
                    next_params=target_params,
                    single_call=self.shifted,
                    value_key=self.tensor_keys.value,
                    detach_next=True,
                    vmap_randomness=self.vmap_randomness,
                )
        else:
            # No value network: values must already be present in the input.
            value = tensordict.get(self.tensor_keys.value)
            next_value = tensordict.get(("next", self.tensor_keys.value))

        value_target = self.value_estimate(tensordict, next_value=next_value)
        tensordict.set(self.tensor_keys.advantage, value_target - value)
        tensordict.set(self.tensor_keys.value_target, value_target)
        return tensordict

    def value_estimate(
        self,
        tensordict,
        target_params: TensorDictBase | None = None,
        next_value: torch.Tensor | None = None,
        **kwargs,
    ):
        """Compute the TD(0) return target via ``td0_return_estimate``.

        ``gamma`` is moved to the reward's device if needed and raised to the
        power ``steps_to_next_obs`` when that entry is present. If
        ``next_value`` is ``None``, it is computed with the value network (see
        ``_next_value``). ``terminated`` falls back to ``done`` when absent.
        """
        reward = tensordict.get(("next", self.tensor_keys.reward))
        device = reward.device
        if self.gamma.device != device:
            self.gamma = self.gamma.to(device)
        gamma = self.gamma
        steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
        if steps_to_next_obs is not None:
            gamma = gamma ** steps_to_next_obs.view_as(reward)

        if self.average_rewards:
            reward = reward - reward.mean()
            reward = reward / reward.std().clamp_min(1e-5)
            tensordict.set(
                ("next", self.tensor_keys.reward), reward
            )  # we must update the rewards if they are used later in the code
        if next_value is None:
            next_value = self._next_value(tensordict, target_params, kwargs=kwargs)

        done = tensordict.get(("next", self.tensor_keys.done))
        terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
        value_target = td0_return_estimate(
            gamma=gamma,
            next_state_value=next_value,
            reward=reward,
            done=done,
            terminated=terminated,
        )
        return value_target
class TD1Estimator(ValueEstimatorBase):
r""":math:`\infty`-Temporal Difference (TD(1)) estimate of advantage function.
Keyword Args:
gamma (scalar): exponential mean discount.
value_network (TensorDictModule): value operator used to retrieve the value estimates.
average_rewards (bool, optional): if ``True``, rewards will be standardized
before the TD is computed.
differentiable (bool, optional): if ``True``, gradients are propagated through
the computation of the value function. Default is ``False``.
.. note::
The proper way to make the function call non-differentiable is to
decorate it in a `torch.no_grad()` context manager/decorator or
pass detached parameters for functional modules.
skip_existing (bool, optional): if ``True``, the value network will skip
modules which outputs are already present in the tensordict.
Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
is not affected.
advantage_key (str or tuple of str, optional): [Deprecated] the key of
the advantage entry. Defaults to ``"advantage"``.
value_target_key (str or tuple of str, optional): [Deprecated] the key
of the advantage entry. Defaults to ``"value_target"``.
value_key (str or tuple of str, optional): [Deprecated] the value key to
read from the input tensordict. Defaults to ``"state_value"``.
shifted (bool, optional): if ``True``, the value and next value are
estimated with a single call to the value network. This is faster
but is only valid whenever (1) the ``"next"`` value is shifted by
only one time step (which is not the case with multi-step value
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
device (torch.device, optional): the device where the buffers will be instantiated.
Defaults to ``torch.get_default_device()``.
time_dim (int, optional): the dimension corresponding to the time
in the input tensordict. If not provided, defaults to the dimension
markes with the ``"time"`` name if any, and to the last dimension
otherwise. Can be overridden during a call to
:meth:`~.value_estimate`.
Negative dimensions are considered with respect to the input
tensordict.
"""
def __init__(
self,
*,
gamma: float | torch.Tensor,
value_network: TensorDictModule,
average_rewards: bool = False,
differentiable: bool = False,
skip_existing: bool | None = None,
advantage_key: NestedKey = None,
value_target_key: NestedKey = None,
value_key: NestedKey = None,
shifted: bool = False,
device: torch.device | None = None,
time_dim: int | None = None,
):
super().__init__(
value_network=value_network,
differentiable=differentiable,
advantage_key=advantage_key,
value_target_key=value_target_key,
value_key=value_key,
shifted=shifted,
skip_existing=skip_existing,
device=device,
)
self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
self.average_rewards = average_rewards
self.time_dim = time_dim
@_self_set_skip_existing
@_self_set_grad_enabled
@dispatch
def forward(
self,
tensordict: TensorDictBase,
*,
params: TensorDictBase | None = None,
target_params: TensorDictBase | None = None,
) -> TensorDictBase:
"""Computes the TD(1) advantage given the data in tensordict.
If a functional module is provided, a nested TensorDict containing the parameters
(and if relevant the target parameters) can be passed to the module.
Args:
tensordict (TensorDictBase): A TensorDict containing the data
(an observation key, ``"action"``, ``("next", "reward")``,
``("next", "done")``, ``("next", "terminated")``,
and ``"next"`` tensordict state as returned by the environment)
necessary to compute the value estimates and the TDEstimate.
The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are
the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
The tensordict must have shape ``[*B, T]``.
Keyword Args:
params (TensorDictBase, optional): A nested TensorDict containing the params
to be passed to the functional value network module.
target_params (TensorDictBase, optional): A nested TensorDict containing the
target params to be passed to the functional value network module.
Returns:
An updated TensorDict with an advantage and a value_error keys as defined in the constructor.
Examples:
>>> from tensordict import TensorDict
>>> value_net = TensorDictModule(
... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
... )
>>> module = TDEstimate(
... gamma=0.98,
... value_network=value_net,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward, "terminated": terminated}}, [1, 10])
>>> _ = module(tensordict)
>>> assert "advantage" in tensordict.keys()
The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
Examples:
>>> value_net = TensorDictModule(
... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
... )
>>> module = TDEstimate(
... gamma=0.98,
... value_network=value_net,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
"""
if tensordict.batch_dims < 1:
raise RuntimeError(
"Expected input tensordict to have at least one dimensions, got"
f"tensordict.batch_size = {tensordict.batch_size}"
)
if self.is_stateless and params is None:
raise RuntimeError(
"Expected params to be passed to advantage module but got none."
)
if self.value_network is not None:
if params is not None:
params = params.detach()
if target_params is None:
target_params = params.clone(False)
with hold_out_net(self.value_network) if (
params is None and target_params is None
) else nullcontext():
# we may still need to pass gradient, but we don't want to assign grads to
# value net params
value, next_value = self._call_value_nets(
data=tensordict,
params=params,
next_params=target_params,
single_call=self.shifted,
value_key=self.tensor_keys.value,
detach_next=True,
vmap_randomness=self.vmap_randomness,
)
else:
value = tensordict.get(self.tensor_keys.value)
next_value = tensordict.get(("next", self.tensor_keys.value))
value_target = self.value_estimate(tensordict, next_value=next_value)
tensordict.set(self.tensor_keys.advantage, value_target - value)
tensordict.set(self.tensor_keys.value_target, value_target)
return tensordict
def value_estimate(
    self,
    tensordict,
    target_params: TensorDictBase | None = None,
    next_value: torch.Tensor | None = None,
    time_dim: int | None = None,
    **kwargs,
):
    """Compute the value target with ``vec_td1_return_estimate``.

    Args:
        tensordict (TensorDictBase): input data. The reward, done and
            terminated entries are read from its ``"next"`` sub-tensordict;
            ``terminated`` falls back to ``done`` when it is absent.
        target_params (TensorDictBase, optional): parameters forwarded to
            ``self._next_value`` when ``next_value`` is not supplied.
        next_value (torch.Tensor, optional): precomputed next-state value.
            When ``None`` it is obtained from ``self._next_value``.
        time_dim (int, optional): time dimension, resolved through
            ``self._get_time_dim``.
        **kwargs: forwarded to ``self._next_value`` as ``kwargs=kwargs``.

    Returns:
        The tensor produced by ``vec_td1_return_estimate``.

    Side effects:
        ``self.gamma`` is moved onto the reward's device when needed. When
        ``self.average_rewards`` is set, the standardized reward is written
        back into ``tensordict`` under ``("next", <reward key>)``.
    """
    reward_key = ("next", self.tensor_keys.reward)
    rewards = tensordict.get(reward_key)

    # Keep the discount buffer on the same device as the rewards.
    if self.gamma.device != rewards.device:
        self.gamma = self.gamma.to(rewards.device)
    discount = self.gamma

    # If a per-sample step count is stored, the discount becomes gamma ** steps.
    n_steps = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
    if n_steps is not None:
        discount = discount ** n_steps.view_as(rewards)

    if self.average_rewards:
        centered = rewards - rewards.mean()
        rewards = centered / centered.std().clamp_min(1e-5)
        # Persist the standardized rewards, since later code may read them back.
        tensordict.set(reward_key, rewards)

    if next_value is None:
        next_value = self._next_value(tensordict, target_params, kwargs=kwargs)

    done = tensordict.get(("next", self.tensor_keys.done))
    terminated = tensordict.get(
        ("next", self.tensor_keys.terminated), default=done
    )
    return vec_td1_return_estimate(
        discount,
        next_value,
        rewards,
        done=done,
        terminated=terminated,
        time_dim=self._get_time_dim(time_dim, tensordict),
    )
class TDLambdaEstimator(ValueEstimatorBase):
r"""TD(:math:`\lambda`) estimate of advantage function.
Args:
gamma (scalar): exponential mean discount.
lmbda (scalar): trajectory discount.
value_network (TensorDictModule): value operator used to retrieve the value estimates.
average_rewards (bool, optional): if ``True``, rewards will be standardized
before the TD is computed.
differentiable (bool, optional): if ``True``, gradients are propagated through
the computation of the value function. Default is ``False``.
.. note::
The proper way to make the function call non-differentiable is to
decorate it in a `torch.no_grad()` context manager/decorator or
pass detached parameters for functional modules.
vectorized (bool, optional): whether to use the vectorized version of the
lambda return. Default is `True`.
skip_existing (bool, optional): if ``True``, the value network will skip
modules whose outputs are already present in the tensordict.
Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
is not affected.
advantage_key (str or tuple of str, optional): [Deprecated] the key of
the advantage entry. Defaults to ``"advantage"``.
value_target_key (str or tuple of str, optional): [Deprecated] the key
of the value target entry. Defaults to ``"value_target"``.
value_key (str or tuple of str, optional): [Deprecated] the value key to
read from the input tensordict. Defaults to ``"state_value"``.
shifted (bool, optional): if ``True``, the value and next value are
estimated with a single call to the value network. This is faster
but is only valid whenever (1) the ``"next"`` value is shifted by
only one time step (which is not the case with multi-step value
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
device (torch.device, optional): the device where the buffers will be instantiated.
Defaults to ``torch.get_default_device()``.
time_dim (int, optional): the dimension corresponding to the time
in the input tensordict. If not provided, defaults to the dimension
marked with the ``"time"`` name if any, and to the last dimension
otherwise. Can be overridden during a call to
:meth:`~.value_estimate`.
Negative dimensions are considered with respect to the input
tensordict.
"""
def __init__(
self,
*,
gamma: float | torch.Tensor,
lmbda: float | torch.Tensor,
value_network: TensorDictModule,
average_rewards: bool = False,
differentiable: bool = False,
vectorized: bool = True,
skip_existing: bool | None = None,
advantage_key: NestedKey = None,
value_target_key: NestedKey = None,
value_key: NestedKey = None,
shifted: bool = False,
device: torch.device | None = None,