forked from pytorch/rl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
326 lines (285 loc) · 11.5 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import typing
from typing import Any, Callable, List, Tuple, Union
import numpy as np
import torch
from torch import Tensor
from torchrl.data.tensor_specs import (
Binary,
Categorical,
Composite,
MultiCategorical,
MultiOneHot,
OneHot,
Stacked,
StackedComposite,
TensorSpec,
)
# numpy dtype -> torch dtype. Every supported dtype carries the same name in
# both libraries, so each pair is derived from that shared name.
numpy_to_torch_dtype_dict = {
    np.dtype(_dtype_name): getattr(torch, _dtype_name)
    for _dtype_name in (
        "bool",
        "uint8",
        "int8",
        "int16",
        "int32",
        "int64",
        "float16",
        "float32",
        "float64",
        "complex64",
        "complex128",
    )
}
# Inverse lookup (torch dtype -> numpy dtype), derived from the table above.
torch_to_numpy_dtype_dict = dict(
    zip(numpy_to_torch_dtype_dict.values(), numpy_to_torch_dtype_dict.keys())
)
# Types accepted wherever a device is expected: a torch.device, a string, or
# an int (presumably a device ordinal — confirm with callers).
DEVICE_TYPING = Union[torch.device, str, int]
# The members of DEVICE_TYPING as a plain tuple, so it can be passed to
# isinstance(); spelled out by hand when typing.get_args is unavailable.
DEVICE_TYPING_ARGS = (
    typing.get_args(DEVICE_TYPING)
    if hasattr(typing, "get_args")
    else (torch.device, str, int)
)
# Union of the types used to annotate indexing arguments.
INDEX_TYPING = Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]]
# Maps either a spec class or an alias string to the canonical action-space
# name. Entries are listed as (key, canonical_name) pairs, in lookup-table order.
ACTION_SPACE_MAP = dict(
    [
        (OneHot, "one_hot"),
        (MultiOneHot, "mult_one_hot"),
        (Binary, "binary"),
        (Categorical, "categorical"),
        ("one_hot", "one_hot"),
        ("one-hot", "one_hot"),
        ("mult_one_hot", "mult_one_hot"),
        ("mult-one-hot", "mult_one_hot"),
        ("multi_one_hot", "mult_one_hot"),
        ("multi-one-hot", "mult_one_hot"),
        ("binary", "binary"),
        ("categorical", "categorical"),
        (MultiCategorical, "multi_categorical"),
        ("multi_categorical", "multi_categorical"),
        ("multi-categorical", "multi_categorical"),
        ("multi_discrete", "multi_categorical"),
        ("multi-discrete", "multi_categorical"),
    ]
)
def consolidate_spec(
    spec: Composite,
    recurse_through_entries: bool = True,
    recurse_through_stack: bool = True,
):
    """Given a TensorSpec, removes exclusive keys by adding placeholder specs.

    A lazily stacked composite spec (``StackedComposite``) may have
    "exclusive" keys: entries present in some of its stacked sub-specs but not
    in the keys shared by the stack. For every sub-spec lacking one of these
    exclusive keys, a placeholder entry built by ``_empty_like_spec`` (from the
    values that key takes in the sub-specs that do have it) is set, so all
    sub-specs end up with the same keys.

    Specs that are neither ``Composite`` nor ``StackedComposite`` are returned
    as a clone, otherwise unchanged.

    Args:
        spec (Composite): the spec to be consolidated.
        recurse_through_entries (bool): if True, the function is applied
            recursively to every entry of the spec that is itself a
            ``Composite`` or ``StackedComposite``. Default is True.
        recurse_through_stack (bool): if True and the provided spec is lazily
            stacked, the function is first applied recursively to every spec
            in its stack. Default is True.

    Returns:
        the consolidated spec. ``spec`` is cloned first and all modifications
        are applied to that clone.
    """
    spec = spec.clone()
    if not isinstance(spec, (Composite, StackedComposite)):
        return spec

    if isinstance(spec, StackedComposite):
        shared_keys = set(spec.keys())
        sub_specs = spec._specs
        # exclusive keys owned by each sub-spec, index-aligned with sub_specs
        exclusive_keys_per_spec = [set() for _ in sub_specs]
        # exclusive key -> values it takes in the sub-specs that own it.
        # Kept as an insertion-ordered dict (not iterated through a set) so the
        # order in which placeholders are inserted is deterministic run to run.
        exclusive_keys_examples = {}
        for spec_index, sub_spec in enumerate(sub_specs):  # gather exclusive keys
            if recurse_through_stack:
                sub_spec = consolidate_spec(
                    sub_spec, recurse_through_entries, recurse_through_stack
                )
                sub_specs[spec_index] = sub_spec
            for sub_spec_key in sub_spec.keys():
                if sub_spec_key not in shared_keys:  # exclusive key
                    exclusive_keys_per_spec[spec_index].add(sub_spec_key)
                    exclusive_keys_examples.setdefault(sub_spec_key, []).append(
                        sub_spec[sub_spec_key]
                    )

        # Add a placeholder for every exclusive key a sub-spec is missing.
        for sub_spec, own_exclusive_keys in zip(sub_specs, exclusive_keys_per_spec):
            for exclusive_key, examples in exclusive_keys_examples.items():
                if exclusive_key in own_exclusive_keys:
                    continue
                sub_spec.set(
                    exclusive_key,
                    _empty_like_spec(examples, sub_spec.shape),
                )

    if recurse_through_entries:
        for key, value in spec.items():
            if isinstance(value, (Composite, StackedComposite)):
                spec.set(
                    key,
                    consolidate_spec(
                        value, recurse_through_entries, recurse_through_stack
                    ),
                )
    return spec
def _empty_like_spec(specs: List[TensorSpec], shape):
    """Build a placeholder spec that mirrors a list of same-class specs.

    Args:
        specs: the values one key takes across several specs. They must all
            share exactly the same class.
        shape: the shape of the spec that will receive the placeholder.

    Returns:
        - ``Composite`` / ``StackedComposite``: ``specs[0].empty()``.
        - ``Stacked``: a new ``Stacked`` along ``specs[0].stack_dim``, made of
          one placeholder per entry of ``specs[0]._specs`` (each built from
          that whole list, with ``shape`` minus its stack dimension).
        - any other spec: ``specs[0]`` indexed at all-zeros and expanded to a
          shape where every dim whose size differs among ``specs`` is 0. If no
          dim is 0, the dims at positions ``>= len(shape)`` are set to 0.

    Raises:
        ValueError: if ``specs`` contains specs of different classes.
    """
    first = specs[0]
    if any(other.__class__ != first.__class__ for other in specs[1:]):
        raise ValueError(
            "Found same key in lazy specs corresponding to entries with different classes"
        )

    if isinstance(first, (Composite, StackedComposite)):
        # composite values -> an empty composite with the same batch size
        return first.empty()

    if isinstance(first, Stacked):
        # lazily stacked values -> a lazy stack with the same layout as the
        # first one, so stacking it does not introduce new heterogeneous dims
        stack_dim = first.stack_dim
        inner_shape = [*shape[:stack_dim], *shape[stack_dim + 1 :]]
        return Stacked(
            *[_empty_like_spec(first._specs, inner_shape) for _ in first._specs],
            dim=stack_dim,
        )

    # plain tensor specs: zero out every dim on which the values disagree
    target_shape = list(first.shape)
    for dim_index in range(len(target_shape)):
        if any(other.shape[dim_index] != first.shape[dim_index] for other in specs):
            target_shape[dim_index] = 0

    if 0 not in target_shape:
        # homogeneous values: keep the leading len(shape) dims, zero the rest
        target_shape = [
            size if dim_index < len(shape) else 0
            for dim_index, size in enumerate(target_shape)
        ]

    origin = first[(0,) * len(first.shape)]
    return origin.expand(target_shape)
def check_no_exclusive_keys(spec: TensorSpec, recurse: bool = True):
    """Return True when ``spec`` has no exclusive keys.

    A ``StackedComposite`` has exclusive keys when one of its stacked specs
    has a key set different from the stack's own ``keys()``. Any other spec
    type trivially has none.

    Args:
        spec (TensorSpec): the spec to check.
        recurse (bool): if True, nested specs (stacked sub-specs and composite
            entries) are checked as well. Default is True.
    """
    if isinstance(spec, StackedComposite):
        shared_keys = set(spec.keys())
        return all(
            (not recurse or check_no_exclusive_keys(inner_spec))
            and set(inner_spec.keys()) == shared_keys
            for inner_spec in spec._specs
        )
    if isinstance(spec, Composite) and recurse:
        return all(check_no_exclusive_keys(value) for value in spec.values())
    return True
def contains_lazy_spec(spec: TensorSpec) -> bool:
    """Return True if ``spec`` is, or contains, a lazily stacked spec.

    ``Stacked`` and ``StackedComposite`` specs count as lazy. ``Composite``
    specs are searched recursively through their values.

    Args:
        spec (TensorSpec): the spec to check.
    """
    if isinstance(spec, (Stacked, StackedComposite)):
        return True
    if isinstance(spec, Composite):
        return any(contains_lazy_spec(inner_spec) for inner_spec in spec.values())
    return False
class CloudpickleWrapper:
    """A wrapper for functions that allow for serialization in multiprocessed settings.

    The wrapped callable and its bound keyword arguments are serialized with
    ``cloudpickle`` (imported lazily in ``__getstate__``) and restored with the
    standard-library ``pickle`` module in ``__setstate__``.

    Args:
        fn (Callable): the callable to wrap. Objects whose class is named
            ``EnvCreator`` are rejected with a ``RuntimeError``.
        **kwargs: keyword arguments bound to every call. On a name clash they
            override the keyword arguments supplied at call time.
    """

    def __init__(self, fn: Callable, **kwargs):
        # Checked by class name so this module does not need to import EnvCreator.
        if fn.__class__.__name__ == "EnvCreator":
            raise RuntimeError(
                "CloudpickleWrapper usage with EnvCreator class is "
                "prohibited as it breaks the transmission of shared tensors."
            )
        self.fn, self.kwargs = fn, kwargs

    def __getstate__(self):
        import cloudpickle

        payload = (self.fn, self.kwargs)
        return cloudpickle.dumps(payload)

    def __setstate__(self, ob: bytes):
        import pickle

        # NOTE(review): pickle.loads can execute arbitrary code; only bytes
        # produced by __getstate__ in a trusted process should reach here.
        fn, bound_kwargs = pickle.loads(ob)
        self.fn = fn
        self.kwargs = bound_kwargs

    def __call__(self, *args, **kwargs) -> Any:
        # bound kwargs take precedence over call-time kwargs of the same name
        return self.fn(*args, **{**kwargs, **self.kwargs})
def _process_action_space_spec(action_space, spec):
    """Reconcile an ``action_space`` and an action ``spec`` into a consistent pair.

    Args:
        action_space: ``None``, a ``TensorSpec`` instance, or any key accepted
            by ``_find_action_space`` (e.g. an alias string such as
            ``"one_hot"``). A ``Composite`` is rejected.
        spec: ``None``, an action ``TensorSpec``, or a ``Composite`` holding an
            action entry (the top-level ``"action"`` key, otherwise the first
            nested leaf key whose last element is ``"action"``).

    Returns:
        an ``(action_space, spec)`` tuple. ``action_space`` is the canonical
        name returned by ``_find_action_space``. ``spec`` is the spec used for
        matching (``action_space`` itself if it was a ``TensorSpec`` and no
        spec was given); when ``spec`` was a ``Composite``, the original
        composite is returned rather than its extracted action leaf.

    Raises:
        KeyError: if ``spec`` is a ``Composite`` with no action entry.
        ValueError: if ``action_space`` is a ``Composite``; if it is a
            ``TensorSpec`` that is not the (extracted) ``spec`` object; if the
            two resolve to different action spaces; if both are ``None``; or if
            ``_find_action_space`` cannot resolve the action space.
    """
    original_spec = spec
    composite_spec = False
    if isinstance(spec, Composite):
        # this will break whenever our action is more complex than a single tensor
        try:
            if "action" in spec.keys():
                _key = "action"
            else:
                # otherwise use the first nested leaf key ending in "action"
                for _key in spec.keys(True, True):
                    if isinstance(_key, tuple) and _key[-1] == "action":
                        break
                else:
                    raise KeyError
            spec = spec[_key]
            composite_spec = True
        except KeyError as err:
            # chain explicitly so the traceback shows the lookup that failed
            raise KeyError(
                "action could not be found in the spec. Make sure "
                "you pass a spec that is either a native action spec or a composite action spec "
                "with a leaf 'action' entry. Otherwise, simply remove the spec and use the action_space only."
            ) from err
    if action_space is not None:
        if isinstance(action_space, Composite):
            raise ValueError("action_space cannot be of type Composite.")
        if (
            spec is not None
            and isinstance(action_space, TensorSpec)
            and action_space is not spec
        ):
            raise ValueError(
                "Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match."
            )
        if isinstance(action_space, TensorSpec):
            spec = action_space
        action_space = _find_action_space(action_space)
        # check that the spec and action_space match
        if spec is not None and _find_action_space(spec) != action_space:
            raise ValueError(
                f"The action spec and the action space do not match: got action_space={action_space} and spec={spec}."
            )
    elif spec is not None:
        action_space = _find_action_space(spec)
    else:
        raise ValueError(
            "Neither action_space nor spec was defined. The action space cannot be inferred."
        )
    if composite_spec:
        # hand back the composite the caller passed, not the extracted leaf
        spec = original_spec
    return action_space, spec
def _find_action_space(action_space):
    """Resolve an action spec, spec class or alias to its canonical action-space name.

    Args:
        action_space: a ``TensorSpec`` instance (looked up by its exact
            class), a ``Composite`` holding an action entry (the top-level
            ``"action"`` key, otherwise the first nested leaf key whose last
            element is ``"action"``), or any key of ``ACTION_SPACE_MAP`` such
            as a spec class or an alias string like ``"one-hot"``.

    Returns:
        the value ``ACTION_SPACE_MAP`` associates with the resolved key
        (e.g. ``"one_hot"``).

    Raises:
        KeyError: if a ``Composite`` without any action entry is passed.
        ValueError: if the resolved action space is not a key of
            ``ACTION_SPACE_MAP`` (including unhashable inputs).
    """
    if isinstance(action_space, TensorSpec):
        if isinstance(action_space, Composite):
            if "action" in action_space.keys():
                _key = "action"
            else:
                # otherwise use the first nested leaf key ending in "action"
                for _key in action_space.keys(True, True):
                    if isinstance(_key, tuple) and _key[-1] == "action":
                        break
                else:
                    raise KeyError("action could not be found in the composite spec.")
            action_space = action_space[_key]
        # spec instances are looked up by their exact class
        action_space = type(action_space)
    try:
        return ACTION_SPACE_MAP[action_space]
    except (KeyError, TypeError) as err:
        # TypeError: unhashable inputs (e.g. a list) cannot be dict keys, so
        # they cannot be valid action spaces either; report them uniformly.
        raise ValueError(
            f"action_space was not specified/not compatible and could not be retrieved from the value network. Got action_space={action_space}."
        ) from err