This repository has been archived by the owner on Oct 7, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 186
/
Copy path bsuite.py
215 lines (184 loc) · 7.9 KB
/
bsuite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Functions to load bsuite environments."""
from typing import Any, Mapping, Tuple
from bsuite import sweep
from bsuite.environments import base
from bsuite.experiments.bandit import bandit
from bsuite.experiments.bandit_noise import bandit_noise
from bsuite.experiments.bandit_scale import bandit_scale
from bsuite.experiments.cartpole import cartpole
from bsuite.experiments.cartpole_noise import cartpole_noise
from bsuite.experiments.cartpole_scale import cartpole_scale
from bsuite.experiments.cartpole_swingup import cartpole_swingup
from bsuite.experiments.catch import catch
from bsuite.experiments.catch_noise import catch_noise
from bsuite.experiments.catch_scale import catch_scale
from bsuite.experiments.deep_sea import deep_sea
from bsuite.experiments.deep_sea_stochastic import deep_sea_stochastic
from bsuite.experiments.discounting_chain import discounting_chain
from bsuite.experiments.memory_len import memory_len
from bsuite.experiments.memory_size import memory_size
from bsuite.experiments.mnist import mnist
from bsuite.experiments.mnist_noise import mnist_noise
from bsuite.experiments.mnist_scale import mnist_scale
from bsuite.experiments.mountain_car import mountain_car
from bsuite.experiments.mountain_car_noise import mountain_car_noise
from bsuite.experiments.mountain_car_scale import mountain_car_scale
from bsuite.experiments.umbrella_distract import umbrella_distract
from bsuite.experiments.umbrella_length import umbrella_length
from bsuite.logging import csv_logging
from bsuite.logging import sqlite_logging
from bsuite.logging import terminal_logging
import dm_env
import termcolor
# Internal imports.
# Registry of every bsuite experiment, keyed by experiment name. Each value is
# the callable that builds the experiment's environment: usually the experiment
# module's `load` function, except for cartpole_swingup, which maps directly to
# the `CartpoleSwingup` class. The callable receives, as keyword arguments, the
# settings defined in that experiment's sweep.py file.
EXPERIMENT_NAME_TO_ENVIRONMENT = {
    'bandit': bandit.load,
    'bandit_noise': bandit_noise.load,
    'bandit_scale': bandit_scale.load,
    'cartpole': cartpole.load,
    'cartpole_noise': cartpole_noise.load,
    'cartpole_scale': cartpole_scale.load,
    'cartpole_swingup': cartpole_swingup.CartpoleSwingup,
    'catch': catch.load,
    'catch_noise': catch_noise.load,
    'catch_scale': catch_scale.load,
    'deep_sea': deep_sea.load,
    'deep_sea_stochastic': deep_sea_stochastic.load,
    'discounting_chain': discounting_chain.load,
    'memory_len': memory_len.load,
    'memory_size': memory_size.load,
    'mnist': mnist.load,
    'mnist_noise': mnist_noise.load,
    'mnist_scale': mnist_scale.load,
    'mountain_car': mountain_car.load,
    'mountain_car_noise': mountain_car_noise.load,
    'mountain_car_scale': mountain_car_scale.load,
    'umbrella_distract': umbrella_distract.load,
    'umbrella_length': umbrella_length.load,
}
def unpack_bsuite_id(bsuite_id: str) -> Tuple[str, int]:
  """Returns the experiment name and setting index given a bsuite_id.

  Args:
    bsuite_id: An identifier made of an experiment name and an integer setting
      index joined by `sweep.SEPARATOR`, for example "catch/0".

  Returns:
    A `(experiment_name, setting_index)` tuple.

  Raises:
    ValueError: If `bsuite_id` does not split into exactly two parts on
      `sweep.SEPARATOR`, or if the setting index is not an integer.
  """
  parts = bsuite_id.split(sweep.SEPARATOR)
  # Validate explicitly instead of with `assert`: assertions are stripped when
  # Python runs with -O, which would silently accept ids like "catch/0/1".
  if len(parts) != 2:
    raise ValueError(
        f'Malformed bsuite_id "{bsuite_id}": expected an experiment name and '
        f'a setting index separated by a single "{sweep.SEPARATOR}".')
  experiment_name, setting_index_str = parts
  return experiment_name, int(setting_index_str)
def load(
    experiment_name: str,
    kwargs: Mapping[str, Any],
) -> base.Environment:
  """Returns a bsuite environment given an experiment name and settings.

  Args:
    experiment_name: A key of `EXPERIMENT_NAME_TO_ENVIRONMENT`, e.g. "catch".
    kwargs: Keyword arguments forwarded to the experiment's environment
      constructor or load function.

  Returns:
    The environment built by the registered constructor.

  Raises:
    KeyError: If `experiment_name` is not a registered experiment.
  """
  try:
    constructor = EXPERIMENT_NAME_TO_ENVIRONMENT[experiment_name]
  except KeyError:
    known = ', '.join(sorted(EXPERIMENT_NAME_TO_ENVIRONMENT))
    raise KeyError(
        f'Unknown bsuite experiment "{experiment_name}". '
        f'Known experiments: {known}.') from None
  # Invoke outside the `try` so that a KeyError raised by the constructor
  # itself is not misreported as an unknown experiment name.
  return constructor(**kwargs)
def load_from_id(bsuite_id: str) -> base.Environment:
  """Builds the environment identified by `bsuite_id` and announces it.

  The settings registered for the id in `sweep.SETTINGS` are passed to `load`
  together with the experiment name parsed from the id.

  Args:
    bsuite_id: A bsuite id such as "catch/0".

  Returns:
    The newly constructed bsuite environment.
  """
  settings = sweep.SETTINGS[bsuite_id]
  name = unpack_bsuite_id(bsuite_id)[0]
  environment = load(name, settings)
  termcolor.cprint(
      f'Loaded bsuite_id: {bsuite_id}.', color='white', attrs=['bold'])
  return environment
def load_and_record(bsuite_id: str,
                    save_path: str,
                    logging_mode: str = 'csv',
                    overwrite: bool = False) -> dm_env.Environment:
  """Returns a bsuite environment wrapped with CSV, SQLite or terminal logging.

  Args:
    bsuite_id: The bsuite id identifying the environment, e.g. "catch/0".
    save_path: For 'csv', the directory in which the CSV files are stored. For
      'sqlite', the database file path; ".db" is appended if it is missing.
      Unused for 'terminal'.
    logging_mode: One of 'csv', 'sqlite' or 'terminal'.
    overwrite: Whether to overwrite existing CSV files. Only honoured in 'csv'
      mode; a warning is printed if it is set for 'sqlite' or 'terminal'.

  Returns:
    The bsuite environment wrapped with the requested logger.

  Raises:
    ValueError: If `logging_mode` is not one of the supported modes.
  """
  if logging_mode == 'csv':
    return load_and_record_to_csv(bsuite_id, save_path, overwrite)
  elif logging_mode == 'sqlite':
    if not save_path.endswith('.db'):
      save_path += '.db'
    if overwrite:
      print('WARNING: overwrite option is ignored for SQLite logging.')
    return load_and_record_to_sqlite(bsuite_id, save_path)
  elif logging_mode == 'terminal':
    # Mirror the SQLite branch: tell the caller `overwrite` has no effect here
    # rather than dropping it silently.
    if overwrite:
      print('WARNING: overwrite option is ignored for terminal logging.')
    return load_and_record_to_terminal(bsuite_id)
  else:
    raise ValueError((f'Unrecognised logging_mode "{logging_mode}". '
                      'Must be "csv", "sqlite", or "terminal".'))
def load_and_record_to_sqlite(bsuite_id: str,
                              db_path: str) -> dm_env.Environment:
  """Returns a bsuite environment whose results are logged to SQLite.

  Stepping through the returned environment records the results needed by the
  analysis notebook into the database at `db_path`.

  To read the results back, point the provided notebook at the database file,
  or inspect them manually:

  ```python
  from bsuite.logging import sqlite_load
  results_df, sweep_vars = sqlite_load.load_bsuite('/path/to/database.db')
  ```

  Args:
    bsuite_id: The bsuite id identifying the environment to return, for
      example "catch/0" or "deep_sea/3".
    db_path: Path to the database file for this set of results. The file will
      be created if it does not already exist. When generating results from
      several processes, pass the *same* db_path for every bsuite_id.

  Returns:
    A bsuite environment determined by the bsuite_id.
  """
  env = load_from_id(bsuite_id)
  name, index = unpack_bsuite_id(bsuite_id)
  banner = f'Logging results to SQLite database in {db_path}.'
  termcolor.cprint(banner, color='yellow', attrs=['bold'])
  logger_options = dict(
      db_path=db_path,
      experiment_name=name,
      setting_index=index,
  )
  return sqlite_logging.wrap_environment(env=env, **logger_options)
def load_and_record_to_csv(bsuite_id: str,
                           results_dir: str,
                           overwrite: bool = False) -> dm_env.Environment:
  """Returns a bsuite environment whose results are written to CSV files.

  A separate CSV file is produced for each bsuite_id. To read the results
  back, point the provided notebook at `results_dir`, or inspect them manually:

  ```python
  from bsuite.logging import csv_load
  results_df, sweep_vars = csv_load.load_bsuite(results_dir)
  ```

  Args:
    bsuite_id: The bsuite id identifying the environment to return, for
      example "catch/0" or "deep_sea/3".
    results_dir: Directory in which the resulting CSV files are stored.
    overwrite: Whether existing CSV files should be overwritten if found.

  Returns:
    A bsuite environment determined by the bsuite_id.
  """
  env = load_from_id(bsuite_id)
  banner = f'Logging results to CSV file for each bsuite_id in {results_dir}.'
  termcolor.cprint(banner, color='yellow', attrs=['bold'])
  logger_options = dict(
      bsuite_id=bsuite_id,
      results_dir=results_dir,
      overwrite=overwrite,
  )
  return csv_logging.wrap_environment(env=env, **logger_options)
def load_and_record_to_terminal(bsuite_id: str) -> dm_env.Environment:
  """Loads the environment for `bsuite_id` and wraps it with terminal logging.

  Args:
    bsuite_id: The bsuite id identifying the environment, e.g. "catch/0".

  Returns:
    The environment wrapped by `terminal_logging.wrap_environment`.
  """
  environment = load_from_id(bsuite_id)
  announcement = 'Logging results to terminal.'
  termcolor.cprint(announcement, color='yellow', attrs=['bold'])
  return terminal_logging.wrap_environment(environment)