# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

from torchrl.data.tensor_specs import Composite
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.model_based import ModelBasedEnvBase
from torchrl.envs.transforms.transforms import Transform


class DreamerEnv(ModelBasedEnvBase):
    """Dreamer simulation environment."""

    def __init__(
        self,
        world_model: TensorDictModule,
        prior_shape: Tuple[int, ...],
        belief_shape: Tuple[int, ...],
        obs_decoder: Optional[TensorDictModule] = None,
        device: DEVICE_TYPING = "cpu",
        batch_size: Optional[torch.Size] = None,
    ):
        super(DreamerEnv, self).__init__(
            world_model, device=device, batch_size=batch_size
        )
        self.obs_decoder = obs_decoder
        self.prior_shape = prior_shape
        self.belief_shape = belief_shape

    def set_specs_from_env(self, env: EnvBase):
        """Sets the specs of this environment from those of the given environment."""
        super().set_specs_from_env(env)
        self.action_spec = self.action_spec.to(self.device)
        self.state_spec = Composite(
            state=self.observation_spec["state"],
            belief=self.observation_spec["belief"],
            shape=env.batch_size,
        )

    def _reset(self, tensordict=None, **kwargs) -> TensorDict:
        """Draws a random initial tensordict from the specs, or clones the one provided."""
        batch_size = tensordict.batch_size if tensordict is not None else []
        device = tensordict.device if tensordict is not None else self.device
        if tensordict is None:
            # Sample a random starting state and belief for the imagined rollout.
            td = self.state_spec.rand(shape=batch_size)
            # why don't we reuse actions taken at those steps?
            td.set("action", self.action_spec.rand(shape=batch_size))
            td[("next", "reward")] = self.reward_spec.rand(shape=batch_size)
            td.update(self.observation_spec.rand(shape=batch_size))
            if device is not None:
                td = td.to(device, non_blocking=True)
                # A non-blocking copy must be synchronized before the data is read.
                if torch.cuda.is_available() and device.type == "cpu":
                    torch.cuda.synchronize()
                elif torch.backends.mps.is_available():
                    torch.mps.synchronize()
        else:
            td = tensordict.clone()
        return td

    def decode_obs(self, tensordict: TensorDict, compute_latents=False) -> TensorDict:
        """Decodes the latent variables into observations with the ``obs_decoder``."""
        if self.obs_decoder is None:
            raise ValueError("No observation decoder provided")
        if compute_latents:
            # Refresh the latent variables with the world model before decoding.
            tensordict = self.world_model(tensordict)
        return self.obs_decoder(tensordict)


class DreamerDecoder(Transform):
    """A transform to record the decoded observations in Dreamer.

    Examples:
        >>> model_based_env = DreamerEnv(...)
        >>> model_based_env_eval = model_based_env.append_transform(DreamerDecoder())
    """

    def _call(self, tensordict):
        return self.parent.base_env.obs_decoder(tensordict)

    def _reset(self, tensordict, tensordict_reset):
        return self._call(tensordict_reset)

    def transform_observation_spec(self, observation_spec):
        return observation_spec
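

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module's API): ``_reset`` above relies
# on ``Composite.rand()`` to draw a random initial TensorDict when none is
# supplied. The snippet below reproduces that pattern in isolation; the
# 30-dim state and 200-dim belief shapes are hypothetical, chosen only for
# demonstration.
if __name__ == "__main__":
    from torchrl.data.tensor_specs import Unbounded

    state_spec = Composite(
        state=Unbounded(shape=(30,)),
        belief=Unbounded(shape=(200,)),
    )
    # Sampling with a leading shape yields a batch of imagined starting points,
    # mirroring ``self.state_spec.rand(shape=batch_size)`` in DreamerEnv._reset.
    td = state_spec.rand(shape=torch.Size([4]))
    print(td["state"].shape, td["belief"].shape)
    # -> torch.Size([4, 30]) torch.Size([4, 200])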