forked from pytorch/rl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gym_like.py
593 lines (508 loc) · 25.1 KB
/
gym_like.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import abc
import re
import warnings
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from tensordict import NonTensorData, TensorDict, TensorDictBase
from torchrl._utils import logger as torchrl_logger
from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
from torchrl.envs.common import _EnvWrapper, EnvBase
class BaseInfoDictReader(metaclass=abc.ABCMeta):
    """Abstract interface for objects copying entries of an ``info`` dict into a tensordict.

    Subclasses must provide two things:

    - :attr:`info_spec`: the specs of the entries the reader writes;
    - :meth:`__call__`: the reading step itself.
    """

    @property
    @abc.abstractmethod
    def info_spec(self) -> Dict[str, TensorSpec]:
        """Specs of the entries written by this reader."""
        raise NotImplementedError

    @abc.abstractmethod
    def __call__(
        self, info_dict: Dict[str, Any], tensordict: TensorDictBase
    ) -> TensorDictBase:
        """Copy the relevant entries of ``info_dict`` into ``tensordict``.

        Args:
            info_dict: the info dictionary produced by the wrapped environment.
            tensordict: the destination tensordict.

        Returns:
            the tensordict holding the written entries.
        """
        raise NotImplementedError
class default_info_dict_reader(BaseInfoDictReader):
    """Default info-key reader.

    Args:
        keys (list of keys, optional): If provided, the list of keys to get from
            the info dictionary. Defaults to all keys.
        spec (List[TensorSpec], Dict[str, TensorSpec] or Composite, optional):
            If a list of specs is provided, each spec will be matched to its
            correspondent key to form a :class:`torchrl.data.Composite`
            (``keys`` must then be provided too).
            If not provided, a composite spec with :class:`~torchrl.data.Unbounded`
            specs will lazily be created.
        ignore_private (bool, optional): If ``True``, private infos (starting with
            an underscore) will be ignored. Defaults to ``True``.

    In cases where keys can be directly written to a tensordict (mostly if they abide by the
    tensordict shape), one simply needs to indicate the keys to be registered during
    instantiation.

    Examples:
        >>> from torchrl.envs.libs.gym import GymWrapper
        >>> from torchrl.envs import default_info_dict_reader
        >>> reader = default_info_dict_reader(["my_info_key"])
        >>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
        >>> env = GymWrapper(gym.make("some_env-v0"))
        >>> env.set_info_dict_reader(info_dict_reader=reader)
        >>> tensordict = env.reset()
        >>> tensordict = env.rand_step(tensordict)
        >>> assert "my_info_key" in tensordict.keys()
    """

    def __init__(
        self,
        keys: List[str] | None = None,
        spec: Sequence[TensorSpec] | Dict[str, TensorSpec] | Composite | None = None,
        ignore_private: bool = True,
    ):
        self.ignore_private = ignore_private
        # True when the keys must be inferred from the first info dict read.
        self._lazy = keys is None
        self.keys = keys

        if spec is None:
            # Without keys nothing is known yet: the spec is built on the first call.
            _info_spec = (
                None
                if keys is None
                else Composite({key: Unbounded(()) for key in keys}, shape=[])
            )
        elif isinstance(spec, Composite):
            _info_spec = spec.clone()
        else:
            if keys is not None and len(spec) != len(keys):
                raise ValueError(
                    "If specifying specs for info keys with a sequence, the "
                    "length of the sequence must match the number of keys"
                )
            if isinstance(spec, dict):
                _info_spec = Composite(spec, shape=[])
            elif keys is None:
                # A bare sequence of specs cannot be matched to any key.
                raise ValueError(
                    "If specifying specs for info keys with a sequence, the "
                    "keys must be provided too."
                )
            else:
                _info_spec = Composite(
                    {key: key_spec for key, key_spec in zip(keys, spec)}, shape=[]
                )
        self._info_spec = _info_spec

    def __call__(
        self, info_dict: Dict[str, Any], tensordict: TensorDictBase
    ) -> TensorDictBase:
        """Write the selected entries of ``info_dict`` into ``tensordict``.

        On the first call of a lazy reader (``keys=None``), the keys are taken
        from ``info_dict`` (minus private ones if ``ignore_private``) and stored.
        If no spec is known yet, an :class:`~torchrl.data.Unbounded` spec is
        inferred from each written value.

        Args:
            info_dict: the info dictionary returned by the environment.
            tensordict: the tensordict to write into (modified in place).

        Returns:
            ``tensordict``, with the info entries set.

        Raises:
            KeyError: if a key is missing from ``info_dict`` and no spec is
                available to fill it in.
        """
        # ``self.keys`` is None until a lazy reader has seen its first info dict;
        # guard ``len`` so that case warns instead of raising a TypeError.
        if not isinstance(info_dict, (dict, TensorDictBase)) and (
            self.keys is None or len(self.keys)
        ):
            warnings.warn(
                f"Found an info_dict of type {type(info_dict)} "
                f"but expected type or subtype `dict`."
            )
        keys = self.keys
        if keys is None:
            # Store a list, not a live keys view that would track later
            # mutations of this particular info dict.
            keys = list(info_dict.keys())
            if self.ignore_private:
                keys = [key for key in keys if not key.startswith("_")]
            self.keys = keys
        # create an info_spec only if there is none
        info_spec = None if self.info_spec is not None else Composite()
        for key in keys:
            if key in info_dict:
                val = info_dict[key]
                # Only numpy arrays carry an object dtype worth stacking; python
                # scalars/bools have no ``dtype`` attribute at all.
                if isinstance(val, np.ndarray) and val.dtype == np.dtype("O"):
                    val = np.stack(val)
                tensordict.set(key, val)
                if info_spec is not None:
                    val = tensordict.get(key)
                    info_spec[key] = Unbounded(
                        val.shape, device=val.device, dtype=val.dtype
                    )
            elif self.info_spec is not None:
                if key in self.info_spec:
                    # Fill missing with 0s
                    tensordict.set(key, self.info_spec[key].zero())
            else:
                raise KeyError(f"The key {key} could not be found or inferred.")
        # set the info spec if there wasn't any - this should occur only once in this class
        if info_spec is not None:
            if tensordict.device is not None:
                info_spec = info_spec.to(tensordict.device)
            self._info_spec = info_spec
        return tensordict

    def reset(self):
        """Forget the registered keys and spec so they are inferred again on the next call."""
        self.keys = None
        self._info_spec = None

    @property
    def info_spec(self) -> Dict[str, TensorSpec]:
        """The composite spec of the info entries, or ``None`` if not known yet."""
        return self._info_spec
class GymLikeEnv(_EnvWrapper):
    """Base wrapper for environments exposing a gym-like ``reset``/``step`` API.

    Its behavior is similar to gym environments in what common methods
    (specifically reset and step) are expected to do.

    A :obj:`GymLikeEnv` has a :obj:`.step()` method with the following signature:

        ``env.step(action: np.ndarray) -> Tuple[Union[np.ndarray, dict], double, bool, *info]``

    where the outputs are the observation, reward and done state respectively.
    In this implementation, the info output is discarded (but specific keys can be read
    by updating info_dict_reader, see :meth:`~.set_info_dict_reader` method).

    By default, the first output is written at the "observation" key-value pair in the
    output tensordict, unless the first output is a dictionary. In that case, each
    observation output will be put at the corresponding :obj:`f"{key}"` location for
    each :obj:`f"{key}"` of the dictionary.

    It is also expected that env.reset() returns an observation similar to the one
    observed after a step is completed.
    """

    # Readers applied, in registration order, to the info dict of step/reset.
    _info_dict_reader: List[BaseInfoDictReader]

    @classmethod
    def __new__(cls, *args, **kwargs):
        # ``_batch_locked=True`` is forwarded to the parent constructor, and
        # each instance starts with its own empty list of info-dict readers.
        self = super().__new__(cls, *args, _batch_locked=True, **kwargs)
        self._info_dict_reader = []
        return self

    def read_action(self, action):
        """Reads the action obtained from the input TensorDict and transforms it in the format expected by the contained environment.

        Args:
            action (Tensor or TensorDict): an action to be taken in the environment

        Returns: an action in a format compatible with the contained environment.
        """
        return self.action_spec.to_numpy(action, safe=False)

    def read_done(
        self,
        terminated: bool | None = None,
        truncated: bool | None = None,
        done: bool | None = None,
    ) -> Tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor, bool]:
        """Done state reader.

        In torchrl, a ``"done"`` signal means that a trajectory has reached its end,
        either because it has been interrupted or because it is terminated.
        Truncated means the episode has been interrupted early.
        Terminated means the task is finished, the episode is completed.

        Args:
            terminated (np.ndarray, boolean or other format): completion state
                obtained from the environment.
                ``"terminated"`` equates to ``"termination"`` in gymnasium:
                the signal that the environment has reached the end of the
                episode, any data coming after this should be considered as nonsensical.
                Defaults to ``None``.
            truncated (bool or None): early truncation signal.
                Defaults to ``None``.
            done (bool or None): end-of-trajectory signal.
                This should be the fallback value of envs which do not specify
                if the ``"done"`` entry points to a ``"terminated"`` or
                ``"truncated"``.
                Defaults to ``None``.

        Returns: a tuple with 4 values,

            - a terminated state (tensor, or ``None`` if ``terminated`` was not given),
            - a truncated state (tensor, or ``None`` if ``truncated`` was not given),
            - a done state (tensor),
            - a boolean value indicating whether the frame_skip loop should be broken.

        Raises:
            ValueError: if none of ``terminated``, ``truncated`` and ``done`` is given.
        """
        if done is None:
            if truncated is None:
                done = terminated
            elif terminated is None:
                # Only truncation is known: it is the whole end-of-trajectory signal.
                done = truncated
            else:
                done = truncated | terminated
        if done is None:
            raise ValueError(
                "At least one of terminated, truncated or done must be provided."
            )
        do_break = done.any() if not isinstance(done, bool) else done
        if isinstance(done, bool):
            # Python booleans are wrapped so the resulting tensors have shape [1].
            done = [done]
            if terminated is not None:
                terminated = [terminated]
            if truncated is not None:
                truncated = [truncated]
        # torch.as_tensor(None) raises, so absent signals are returned as None;
        # ``_step`` handles a None terminated/truncated explicitly.
        return (
            torch.as_tensor(terminated) if terminated is not None else None,
            torch.as_tensor(truncated) if truncated is not None else None,
            torch.as_tensor(done),
            do_break.any() if not isinstance(do_break, bool) else do_break,
        )

    def read_reward(self, reward):
        """Reads the reward and maps it to the reward space.

        An integer ``0`` (the initial value of the accumulated reward) maps to
        ``reward_spec.zero()``. A ``None`` reward is reported as NaN with the
        shape of the reward spec.

        Args:
            reward (torch.Tensor or TensorDict): reward to be mapped.
        """
        if reward is not None:
            if isinstance(reward, int) and reward == 0:
                return self.reward_spec.zero()
            reward = self.reward_spec.encode(reward, ignore_device=True)
        if reward is None:
            reward = torch.tensor(np.nan).expand(self.reward_spec.shape)
        return reward

    def read_obs(
        self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray]
    ) -> Dict[str, Any]:
        """Reads an observation from the environment and returns an observation compatible with the output TensorDict.

        A mapping is copied (the environment's own object is left untouched) and
        each entry is encoded with the matching observation spec. Any other
        observation is encoded with the first leaf of ``observation_spec``.

        Args:
            observations (observation under a format dictated by the inner env): observation to be read.

        Raises:
            RuntimeError: if a non-mapping observation is read and
                ``observation_spec`` has no leaf.
        """
        if isinstance(observations, Mapping):
            # Work on a shallow copy: the inner env may reuse its dict across calls.
            observations = dict(observations)
            if "state" in observations and "observation" not in observations:
                # we rename "state" in "observation" as "observation" is the conventional name
                # for single observation in torchrl.
                # naming it 'state' will result in envs that have a different name for the state vector
                # when queried with and without pixels
                observations["observation"] = observations.pop("state")
            for key, val in observations.items():
                spec = self.observation_spec[key]
                if isinstance(spec, NonTensor):
                    observations[key] = NonTensorData(val)
                else:
                    observations[key] = spec.encode(val, ignore_device=True)
            return observations

        for key, spec in self.observation_spec.items(True, True):
            # we don't check that there is only one spec because obs spec also
            # contains the data spec of the info dict.
            return {key: spec.encode(observations, ignore_device=True)}
        raise RuntimeError("Could not find any element in observation_spec.")

    def _apply_info_dict_readers(self, info_dict, tensordict_out):
        """Run every registered reader; a reader returning ``None`` keeps the current tensordict."""
        for info_dict_reader in self.info_dict_reader:
            out = info_dict_reader(info_dict, tensordict_out)
            if out is not None:
                tensordict_out = out
        return tensordict_out

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        action = tensordict.get(self.action_key)
        if self._convert_actions_to_numpy:
            action = self.read_action(action)

        # Repeat the action ``wrapper_frame_skip`` times, summing rewards and
        # stopping early as soon as the trajectory ends.
        reward = 0
        for _ in range(self.wrapper_frame_skip):
            (
                obs,
                _reward,
                terminated,
                truncated,
                done,
                info_dict,
            ) = self._output_transform(self._env.step(action))

            if _reward is not None:
                reward = reward + _reward

            terminated, truncated, done, do_break = self.read_done(
                terminated=terminated, truncated=truncated, done=done
            )
            if do_break:
                break

        reward = self.read_reward(reward)
        obs_dict = self.read_obs(obs)
        obs_dict[self.reward_key] = reward

        # if truncated/terminated is not in the keys, we just don't pass it even if it
        # is defined.
        if terminated is None:
            terminated = done
        if truncated is not None:
            obs_dict["truncated"] = truncated
        obs_dict["done"] = done
        obs_dict["terminated"] = terminated

        validated = self.validated
        if not validated:
            tensordict_out = TensorDict(obs_dict, batch_size=tensordict.batch_size)
            if validated is None:
                # check if any value has to be recast to something else. If not, we can safely
                # build the tensordict without running checks
                self.validated = all(
                    val is tensordict_out.get(key)
                    for key, val in TensorDict(obs_dict, []).items(True, True)
                )
        else:
            tensordict_out = TensorDict._new_unsafe(
                obs_dict,
                batch_size=tensordict.batch_size,
            )
        if self.device is not None:
            tensordict_out = tensordict_out.to(self.device)

        if self.info_dict_reader and (info_dict is not None):
            if not isinstance(info_dict, dict):
                warnings.warn(
                    f"Expected info to be a dictionary but got a {type(info_dict)} with values {str(info_dict)[:100]}."
                )
            else:
                tensordict_out = self._apply_info_dict_readers(
                    info_dict, tensordict_out
                )
        return tensordict_out

    @property
    def validated(self):
        """Whether step outputs can be packed without checks (``None`` until checked once)."""
        return self.__dict__.get("_validated", None)

    @validated.setter
    def validated(self, value):
        self.__dict__["_validated"] = value

    def _reset(
        self, tensordict: Optional[TensorDictBase] = None, **kwargs
    ) -> TensorDictBase:
        # ``kwargs`` are forwarded to the inner env's ``reset``.
        obs, info = self._reset_output_transform(self._env.reset(**kwargs))

        source = self.read_obs(obs)

        tensordict_out = TensorDict._new_unsafe(
            source=source,
            batch_size=self.batch_size,
        )
        if self.info_dict_reader and info is not None:
            tensordict_out = self._apply_info_dict_readers(info, tensordict_out)
        elif info is None and self.info_dict_reader:
            # populate the reset with the items we have not seen from info
            for key, item in self.observation_spec.items(True, True):
                if key not in tensordict_out.keys(True, True):
                    tensordict_out[key] = item.zero()
        if self.device is not None:
            tensordict_out = tensordict_out.to(self.device)
        return tensordict_out

    @abc.abstractmethod
    def _output_transform(
        self, step_outputs_tuple: Tuple
    ) -> Tuple[
        Any,
        float | np.ndarray,
        bool | np.ndarray | None,
        bool | np.ndarray | None,
        bool | np.ndarray | None,
        dict,
    ]:
        """A method to read the output of the env step.

        Must return a tuple: (obs, reward, terminated, truncated, done, info).
        If only one end-of-trajectory is passed, it is interpreted as ``"truncated"``.
        An attempt to retrieve ``"truncated"`` from the info dict is also undertaken.
        If 2 are passed (like in gymnasium), we interpret them as ``"terminated",
        "truncated"`` (``"truncated"`` meaning that the trajectory has been
        interrupted early), and ``"done"`` is the union of the two,
        ie. the unspecified end-of-trajectory signal.

        These three concepts have different usage:

        - ``"terminated"`` indicates the final stage of a Markov Decision
          Process. It means that one should not pay attention to the
          upcoming observations (eg., in value functions) as they should be
          regarded as not valid.
        - ``"truncated"`` means that the environment has reached a stage where
          we decided to stop the collection for some reason but the next
          observation should not be discarded. If it were not for this
          arbitrary decision, the collection could have proceeded further.
        - ``"done"`` is either one or the other. It is to be interpreted as
          "a reset should be called before the next step is undertaken".
        """
        ...

    @abc.abstractmethod
    def _reset_output_transform(self, reset_outputs_tuple: Tuple) -> Tuple:
        """Split the inner env's ``reset`` output into ``(obs, info)``."""
        ...

    def set_info_dict_reader(
        self,
        info_dict_reader: BaseInfoDictReader | None = None,
        ignore_private: bool = True,
    ) -> GymLikeEnv:
        """Sets an info_dict_reader function.

        This function should take as input an
        info_dict dictionary and the tensordict returned by the step function, and
        write values in an ad-hoc manner from one to the other.

        Args:
            info_dict_reader (Callable[[Dict], TensorDict], optional): a callable
                taking an input dictionary and output tensordict as arguments.
                This function should modify the tensordict in-place. If none is
                provided, :class:`~torchrl.envs.gym_like.default_info_dict_reader`
                will be used.
            ignore_private (bool, optional): If ``True``, private infos (starting with
                an underscore) will be ignored. Defaults to ``True``.

        Returns: the same environment with the dict_reader registered.

        .. note::
          Automatically registering an info_dict reader should be done via
          :meth:`~.auto_register_info_dict`, which will ensure that the env
          specs are properly constructed.

        Examples:
            >>> from torchrl.envs import default_info_dict_reader
            >>> from torchrl.envs.libs.gym import GymWrapper
            >>> reader = default_info_dict_reader(["my_info_key"])
            >>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
            >>> env = GymWrapper(gym.make("some_env-v0")).set_info_dict_reader(info_dict_reader=reader)
            >>> tensordict = env.reset()
            >>> tensordict = env.rand_step(tensordict)
            >>> assert "my_info_key" in tensordict.keys()
        """
        if info_dict_reader is None:
            info_dict_reader = default_info_dict_reader(ignore_private=ignore_private)
        self.info_dict_reader.append(info_dict_reader)
        if isinstance(info_dict_reader, BaseInfoDictReader):
            # if we have a BaseInfoDictReader, we know what the specs will be
            # In other cases (eg, RoboHive) we will need to figure it out empirically.
            if (
                isinstance(info_dict_reader, default_info_dict_reader)
                and info_dict_reader.info_spec is None
            ):
                torchrl_logger.info(
                    "The info_dict_reader does not have specs. The only way to palliate to this issue automatically "
                    "is to run a dummy rollout and gather the specs automatically. "
                    "To silence this message, provide the specs directly to your spec reader."
                )
                # Gym does not guarantee that reset passes all info
                self.reset()
                info_dict_reader.reset()
                self.rand_step()
                self.reset()

            self.observation_spec.update(info_dict_reader.info_spec)

        return self

    def auto_register_info_dict(
        self,
        ignore_private: bool = True,
        *,
        info_dict_reader: BaseInfoDictReader | None = None,
    ) -> EnvBase:
        """Automatically registers the info dict and appends :class:`~torchrl.envs.transforms.TensorDictPrimer` instances if needed.

        If no info_dict_reader is provided, it is assumed that all the information contained in the info dict can
        be registered as numerical values within the tensordict.

        This method returns a (possibly transformed) environment where we make sure that
        the :func:`torchrl.envs.utils.check_env_specs` succeeds, whether or not
        the info is filled at reset time.

        .. note:: This method requires running a few iterations in the environment to
          manually check that the behavior matches expectations.

        Args:
            ignore_private (bool, optional): If ``True``, private infos (starting with
                an underscore) will be ignored. Defaults to ``True``.

        Keyword Args:
            info_dict_reader (BaseInfoDictReader, optional): the info_dict_reader, if it is known in advance.
                Unlike :meth:`~.set_info_dict_reader`, this method will create the primers necessary to get
                :func:`~torchrl.envs.utils.check_env_specs` to run.

        Raises:
            RuntimeError: if an info-dict reader is already registered.

        Examples:
            >>> from torchrl.envs import GymEnv
            >>> env = GymEnv("HalfCheetah-v4")
            >>> # registers the info dict reader
            >>> env.auto_register_info_dict()
            GymEnv(env=HalfCheetah-v4, batch_size=torch.Size([]), device=cpu)
            >>> env.rollout(3)
            TensorDict(
                fields={
                    action: Tensor(shape=torch.Size([3, 6]), device=cpu, dtype=torch.float32, is_shared=False),
                    done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    next: TensorDict(
                        fields={
                            done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            observation: Tensor(shape=torch.Size([3, 17]), device=cpu, dtype=torch.float64, is_shared=False),
                            reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                            reward_ctrl: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
                            reward_run: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
                            terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            x_position: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
                            x_velocity: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False)},
                        batch_size=torch.Size([3]),
                        device=cpu,
                        is_shared=False),
                    observation: Tensor(shape=torch.Size([3, 17]), device=cpu, dtype=torch.float64, is_shared=False),
                    reward_ctrl: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
                    reward_run: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
                    terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    x_position: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
                    x_velocity: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False)},
                batch_size=torch.Size([3]),
                device=cpu,
                is_shared=False)
        """
        from torchrl.envs import check_env_specs, TensorDictPrimer, TransformedEnv

        if self.info_dict_reader:
            raise RuntimeError("The environment already has an info-dict reader.")
        self.set_info_dict_reader(
            ignore_private=ignore_private, info_dict_reader=info_dict_reader
        )
        try:
            check_env_specs(self)
            return self
        except (AssertionError, RuntimeError) as err:
            # These failures mean some info keys are missing at reset time:
            # a primer fills them in from the reader's spec.
            patterns = [
                "The keys of the specs and data do not match",
                "The sets of keys in the tensordicts to stack are exclusive",
            ]
            for pattern in patterns:
                if re.search(pattern, str(err)):
                    result = TransformedEnv(
                        self, TensorDictPrimer(self.info_dict_reader[0].info_spec)
                    )
                    check_env_specs(result)
                    return result
            raise

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
        )

    @property
    def info_dict_reader(self):
        """The list of registered info-dict readers."""
        return self._info_dict_reader