# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn


def check_finite(tensor: torch.Tensor):
    """Raise an error if a tensor has non-finite elements."""
    if not tensor.isfinite().all():
        raise ValueError("Encountered a non-finite tensor.")
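
# A minimal usage sketch (not part of the original file): check_finite is a
# cheap guard against NaN/inf values propagating through a pipeline.
#
#     >>> check_finite(torch.tensor([1.0, 2.0]))      # passes silently
#     >>> check_finite(torch.tensor([float("nan")]))  # raises ValueError
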
def _init_first(fun):
    """Decorate a method so that ``self._init()`` runs once before the first call."""

    def new_fun(self, *args, **kwargs):
        if not self.initialized:
            self._init()
        return fun(self, *args, **kwargs)

    return new_fun
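
# Hedged usage sketch (``LazyBuffer`` is a hypothetical class, not part of this
# file): the decorator assumes the owning object exposes an ``initialized``
# flag and an ``_init()`` method, so lazy state is built on first access.
#
#     >>> class LazyBuffer:
#     ...     def __init__(self):
#     ...         self.initialized = False
#     ...     def _init(self):
#     ...         self.data = torch.zeros(4)
#     ...         self.initialized = True
#     ...     @_init_first
#     ...     def read(self):
#     ...         return self.data
#     >>> LazyBuffer().read()  # _init() runs before the first read
#     tensor([0., 0., 0., 0.])
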
class _set_missing_tolerance:
    """Context manager to temporarily change a transform's missing-value tolerance."""

    def __init__(self, transform, mode):
        self.transform = transform
        self.mode = mode

    def __enter__(self):
        # Remember the current tolerance so it can be restored on exit.
        self.exit_mode = self.transform.missing_tolerance
        if self.mode != self.exit_mode:
            self.transform.set_missing_tolerance(self.mode)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.mode != self.exit_mode:
            self.transform.set_missing_tolerance(self.exit_mode)
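
# Hedged usage sketch (``transform`` and the call inside the block are
# stand-ins for any object exposing ``missing_tolerance`` /
# ``set_missing_tolerance``): inside the ``with`` block the tolerance is
# switched to ``mode``; on exit the previous value is restored, even if the
# body raises.
#
#     >>> with _set_missing_tolerance(transform, True):
#     ...     td = transform.reset(td)  # hypothetical call; missing keys tolerated here
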
def _get_reset(reset_key, tensordict):
    """Fetch (or synthesize) the boolean reset mask stored under ``reset_key``."""
    _reset = tensordict.get(reset_key, None)
    # The reset key must be unraveled already (a string or a tuple of strings).
    parent_td = (
        tensordict.get(reset_key[:-1], None)
        if isinstance(reset_key, tuple)
        else tensordict
    )
    if parent_td is None:
        # Fall back to the root tensordict if the nested one wasn't found.
        parent_td = tensordict
    if _reset is None:
        # No reset entry: assume every element of the batch must be reset.
        _reset = torch.ones(
            (),
            dtype=torch.bool,
            device=parent_td.device,
        ).expand(parent_td.batch_size)
    if _reset.ndim > parent_td.ndim:
        # Collapse extra trailing dims: reset if any sub-entry requests it.
        _reset = _reset.flatten(parent_td.ndim, -1).any(-1)
    return _reset
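
# Hedged usage sketch (assumed tensordict layout): given a TensorDict with
# batch_size [4] and no "_reset" entry stored, a full reset mask is
# synthesized; a mask with extra trailing dims would be collapsed via any(-1).
#
#     >>> from tensordict import TensorDict
#     >>> td = TensorDict({}, batch_size=[4])
#     >>> _get_reset("_reset", td)
#     tensor([True, True, True, True])
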
def _stateless_param(param):
    """Return a data-less ("meta"-device) copy of ``param``, preserving its Parameter status."""
    is_param = isinstance(param, nn.Parameter)
    param = param.data.to("meta")
    if is_param:
        return nn.Parameter(param, requires_grad=False)
    return param
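
# Hedged usage sketch: moving a parameter's storage to the "meta" device keeps
# shape/dtype metadata but no data, a common trick for building stateless
# functional modules.
#
#     >>> p = nn.Parameter(torch.randn(3))
#     >>> _stateless_param(p).device
#     device(type='meta')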