# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Union

import pkg_resources
from torch.autograd.grad_mode import _DecoratorContextManager

from torchrl.data.tensordict.tensordict import TensorDictBase

AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set}


class classproperty(property):
    def __get__(self, cls, owner):
        return classmethod(self.fget).__get__(None, owner)()
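

# A minimal usage sketch of ``classproperty`` (``Config`` and ``n_envs`` below are
# hypothetical names used only for illustration, not part of this module): the
# descriptor routes attribute access on the class itself through the decorated
# function, so the value can be read without instantiating the class.
#
# >>> class Config:
# ...     _n_envs = 4
# ...
# ...     @classproperty
# ...     def n_envs(cls):
# ...         return cls._n_envs
# >>> Config.n_envs
# 4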


def step_tensordict(
    tensordict: TensorDictBase,
    next_tensordict: TensorDictBase = None,
    keep_other: bool = True,
    exclude_reward: bool = True,
    exclude_done: bool = True,
    exclude_action: bool = True,
) -> TensorDictBase:
    """
    Given a tensordict retrieved after a step, returns a new tensordict with the 'next_' prefixes removed,
    i.e. every `'next_some_other_string'` key is renamed to `'some_other_string'`.

    Args:
        tensordict (TensorDictBase): tensordict with keys to be renamed
        next_tensordict (TensorDictBase, optional): destination tensordict
        keep_other (bool, optional): if True, all keys that do not start with `'next_'` will be kept.
            Default is True.
        exclude_reward (bool, optional): if True, the `"reward"` key will be discarded
            from the resulting tensordict.
            Default is True.
        exclude_done (bool, optional): if True, the `"done"` key will be discarded
            from the resulting tensordict.
            Default is True.
        exclude_action (bool, optional): if True, the `"action"` key will be discarded
            from the resulting tensordict.
            Default is True.

    Returns:
        A new tensordict (or next_tensordict) with the "next_*" keys renamed without the "next_" prefix.

    Examples:
        This function allows for this kind of loop to be used:
        >>> td_out = []
        >>> env = make_env()
        >>> policy = make_policy()
        >>> td = env.reset()
        >>> for i in range(max_steps):
        ...     td = env.step(td)
        ...     next_td = step_tensordict(td)
        ...     assert next_td is not td  # make sure that keys are not overwritten
        ...     td_out.append(td)
        ...     td = next_td
        >>> td_out = torch.stack(td_out, 0)
        >>> print(td_out)  # should contain keys 'observation', 'next_observation', 'action', 'reward', 'done' or similar

    """
    other_keys = []
    prohibited = set()
    if exclude_done:
        prohibited.add("done")
    else:
        other_keys.append("done")
    if exclude_reward:
        prohibited.add("reward")
    else:
        other_keys.append("reward")
    if exclude_action:
        prohibited.add("action")
    else:
        other_keys.append("action")

    keys = [key for key in tensordict.keys() if key.startswith("next_")]
    if len(keys) == 0:
        raise RuntimeError(
            "There was no key starting with 'next_' in the provided TensorDict: ",
            tensordict,
        )
    new_keys = [key[5:] for key in keys]
    prohibited = prohibited.union(keys).union(new_keys)
    if keep_other:
        other_keys = [key for key in tensordict.keys() if key not in prohibited]

    select_tensordict = tensordict.select(*other_keys, *keys)
    for new_key, key in zip(new_keys, keys):
        select_tensordict.rename_key(key, new_key, safe=True)

    if next_tensordict is not None:
        return next_tensordict.update(select_tensordict)
    else:
        return select_tensordict
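

# A minimal usage sketch of ``step_tensordict`` as comments (assuming the
# ``TensorDict(source, batch_size=...)`` constructor exported by ``torchrl.data``
# at this point in the codebase): the values stored under ``"next_observation"``
# become the ``"observation"`` entry of the returned tensordict, which can then be
# fed back to the policy at the following step.
#
# >>> import torch
# >>> from torchrl.data import TensorDict
# >>> td = TensorDict(
# ...     {"observation": torch.zeros(4), "next_observation": torch.ones(4)},
# ...     batch_size=[],
# ... )
# >>> out = step_tensordict(td)
# >>> assert (out["observation"] == 1).all()
# >>> assert "next_observation" not in out.keys()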


def get_available_libraries():
    """
    Returns:
        all the supported libraries
    """
    return SUPPORTED_LIBRARIES


def _check_gym():
    """
    Returns:
        True if the gym library is installed
    """
    return "gym" in AVAILABLE_LIBRARIES


def _check_gym_atari():
    """
    Returns:
        True if the gym library is installed and atari envs can be found.
    """
    if not _check_gym():
        return False
    return "atari-py" in AVAILABLE_LIBRARIES


def _check_mario():
    """
    Returns:
        True if the "gym-super-mario-bros" library is installed.
    """
    return "gym-super-mario-bros" in AVAILABLE_LIBRARIES


def _check_dmcontrol():
    """
    Returns:
        True if the "dm-control" library is installed.
    """
    return "dm-control" in AVAILABLE_LIBRARIES


def _check_dmlab():
    """
    Returns:
        True if the "deepmind-lab" library is installed.
    """
    return "deepmind-lab" in AVAILABLE_LIBRARIES


SUPPORTED_LIBRARIES = {
    "gym": _check_gym(),  # OpenAI
    "gym[atari]": _check_gym_atari(),
    "vizdoom": None,  # 1.2k, https://github.com/mwydmuch/ViZDoom
    "ml-agents": None,
    # 11.5k, unity, https://github.com/Unity-Technologies/ml-agents
    "pysc2": None,  # 7.3k, DM, https://github.com/deepmind/pysc2
    "deepmind_lab": _check_dmlab(),
    # 6.5k DM, https://github.com/deepmind/lab, https://github.com/deepmind/lab/tree/master/python/pip_package
    "serpent.ai": None,  # 6k, https://github.com/SerpentAI/SerpentAI
    "gfootball": None,  # 2.8k G, https://github.com/google-research/football
    "dm_control": _check_dmcontrol(),
    # 2.3k DM, https://github.com/deepmind/dm_control
    "habitat": None,
    # 1.2k FB, https://github.com/facebookresearch/habitat-sim
    "meta-world": None,  # 500, https://github.com/rlworkgroup/metaworld
    "minerl": None,  # 300, https://github.com/minerllabs/minerl
    "multi-agent-emergence-environments": None,
    # 1.2k, OpenAI, https://github.com/openai/multi-agent-emergence-environments
    "openspiel": None,  # 2.8k, DM, https://github.com/deepmind/open_spiel
    "procgen": None,  # 500, OpenAI, https://github.com/openai/procgen
    "pybullet": None,  # 641, https://github.com/benelot/pybullet-gym
    "realworld_rl_suite": None,
    # 250, G, https://github.com/google-research/realworldrl_suite
    "rlcard": None,  # 1.4k, https://github.com/datamllab/rlcard
    "screeps": None,  # 2.3k https://github.com/screeps/screeps
    "gym-super-mario-bros": _check_mario(),
}
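
# A minimal sketch (comments only): ``get_available_libraries`` simply returns the
# mapping above, so the checked backends that are importable on the current
# machine can be listed by keeping the truthy entries. The exact result depends on
# the local environment.
#
# >>> available = {k: v for k, v in get_available_libraries().items() if v}
# >>> "gym" in available  # True only when gym is installed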

EXPLORATION_MODE = None


class set_exploration_mode(_DecoratorContextManager):
    """
    Sets the exploration mode of all ProbabilisticTDModules to the desired mode.

    Args:
        mode (str): mode to use when the policy is being called.

    Examples:
        >>> policy = Actor(action_spec, module=network, default_interaction_mode="mode")
        >>> env.rollout(policy=policy, max_steps=100)  # rollout with the "mode" interaction mode
        >>> with set_exploration_mode("random"):
        ...     env.rollout(policy=policy, max_steps=100)  # rollout with the "random" interaction mode

    """

    def __init__(self, mode: str = "mode"):
        super().__init__()
        self.mode = mode

    def __enter__(self) -> None:
        global EXPLORATION_MODE
        self.prev = EXPLORATION_MODE
        EXPLORATION_MODE = self.mode

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global EXPLORATION_MODE
        EXPLORATION_MODE = self.prev


def exploration_mode() -> Union[str, None]:
    """Returns the exploration mode currently set."""
    return EXPLORATION_MODE
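

# A minimal sketch of how the module-level mode interacts with the context manager
# (pure Python, no environment or policy needed): the mode is stored in the global
# ``EXPLORATION_MODE`` and restored to its previous value on exit.
#
# >>> exploration_mode() is None
# True
# >>> with set_exploration_mode("random"):
# ...     exploration_mode()
# 'random'
# >>> exploration_mode() is None
# True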