# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reinforcement Learning (RL) Environment for Open Spiel.
This module wraps Open Spiel Python interface providing an RL-friendly API. It
covers both turn-based and simultaneous move games. Interactions between agents
and the underlying game occur mostly through the `reset` and `step` methods,
which return a `TimeStep` structure (see its docstrings for more info).
The following example illustrates the interaction dynamics. Consider a 2-player
Kuhn Poker (turn-based game). Agents have access to the `observations` (a dict)
field from `TimeSpec`, containing the following members:
* `info_state`: list containing the game information state for each player. The
size of the list always correspond to the number of players. E.g.:
[[0, 1, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0]].
* `legal_actions`: list containing legal action ID lists (one for each player).
E.g.: [[0, 1], [0]], which corresponds to actions 0 and 1 being valid for
player 0 (the 1st player) and action 0 being valid for player 1 (2nd player).
* `current_player`: zero-based integer representing the player to make a move.
At each `step` call, the environment expects a singleton list with the action
(as it's a turn-based game), e.g.: [1]. This (zero-based) action must correspond
to the player specified at `current_player`. The game (which is at decision
node) will process the action and take as many steps necessary to cover chance
nodes, halting at a new decision or final node. Finally, a new `TimeStep`is
returned to the agent.
Simultaneous-move games follow analogous dynamics. The only differences is the
environment expects a list of actions, one per player. Note the `current_player`
field is "irrelevant" here, admitting a constant value defined in spiel.h, which
defaults to -2 (module level constant `SIMULTANEOUS_PLAYER_ID`).
See open_spiel/python/examples/rl_example.py for example usages.
"""
import collections
import enum

from absl import logging
import numpy as np

import pyspiel

SIMULTANEOUS_PLAYER_ID = pyspiel.PlayerId.SIMULTANEOUS


class TimeStep(
    collections.namedtuple(
        "TimeStep", ["observations", "rewards", "discounts", "step_type"])):
  """Returned with every call to `step` and `reset`.

  A `TimeStep` contains the data emitted by a game at each step of
  interaction. A `TimeStep` holds an `observations` dict (each entry
  containing one value per player), associated lists of `rewards` and
  `discounts`, and a `step_type`.

  The first `TimeStep` in a sequence will have `StepType.FIRST`. The final
  `TimeStep` will have `StepType.LAST`. All other `TimeStep`s in a sequence
  will have `StepType.MID`.

  Attributes:
    observations: a dict containing per-player observation lists.
    rewards: A list of scalars (one per player), or `None` if `step_type` is
      `StepType.FIRST`, i.e. at the start of a sequence.
    discounts: A list of discount values in the range `[0, 1]` (one per
      player), or `None` if `step_type` is `StepType.FIRST`.
    step_type: A `StepType` enum value.
  """

  __slots__ = ()

  def first(self):
    return self.step_type == StepType.FIRST

  def mid(self):
    return self.step_type == StepType.MID

  def last(self):
    return self.step_type == StepType.LAST

  def is_simultaneous_move(self):
    return self.observations["current_player"] == SIMULTANEOUS_PLAYER_ID

  def current_player(self):
    return self.observations["current_player"]


class StepType(enum.Enum):
  """Defines the status of a `TimeStep` within a sequence."""

  FIRST = 0  # Denotes the first `TimeStep` in a sequence.
  MID = 1  # Denotes any `TimeStep` in a sequence that is not FIRST or LAST.
  LAST = 2  # Denotes the last `TimeStep` in a sequence.

  def first(self):
    return self is StepType.FIRST

  def mid(self):
    return self is StepType.MID

  def last(self):
    return self is StepType.LAST


# Global pyspiel members
def registered_games():
  return pyspiel.registered_games()


class ChanceEventSampler(object):
  """Default sampler for external chance events."""

  def __init__(self, seed=None):
    self.seed(seed)

  def seed(self, seed=None):
    self._rng = np.random.RandomState(seed)

  def __call__(self, state):
    """Sample a chance event in the given state."""
    actions, probs = zip(*state.chance_outcomes())
    return self._rng.choice(actions, p=probs)
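
# Example (illustrative sketch): seeding makes chance events reproducible.
# `Environment.seed` forwards to this sampler, so the two forms below are
# equivalent:
#
#   sampler = ChanceEventSampler(seed=42)
#   env = Environment("kuhn_poker", chance_event_sampler=sampler)
#   # ... or, after construction:
#   env.seed(42)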


class ObservationType(enum.Enum):
  """Defines what kind of observation to use."""

  OBSERVATION = 0  # Use observation_tensor
  INFORMATION_STATE = 1  # Use information_state_tensor
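
# Example (illustrative): for games exposing both tensors (e.g. leduc_poker),
# the INFORMATION_STATE default can be overridden explicitly:
#
#   env = Environment("leduc_poker",
#                     observation_type=ObservationType.OBSERVATION)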


class Environment(object):
  """Open Spiel reinforcement learning environment class."""

  def __init__(self,
               game,
               discount=1.0,
               chance_event_sampler=None,
               observation_type=None,
               include_full_state=False,
               mfg_distribution=None,
               mfg_population=None,
               enable_legality_check=False,
               **kwargs):
    """Constructor.

    Args:
      game: [string, pyspiel.Game] Open Spiel game name or game instance.
      discount: float, discount used in non-initial steps. Defaults to 1.0.
      chance_event_sampler: optional callable (e.g. a `ChanceEventSampler`
        instance) used to sample external chance events; must also expose a
        `seed` method.
      observation_type: what kind of observation to use. If not specified,
        will default to INFORMATION_STATE unless the game doesn't provide it.
      include_full_state: whether or not to include the full serialized
        OpenSpiel state in the observations (sometimes useful for debugging).
      mfg_distribution: the distribution over states if the game is a mean
        field game.
      mfg_population: The Mean Field Game population to consider.
      enable_legality_check: Check the legality of the move before stepping.
      **kwargs: dict, additional settings passed to the Open Spiel game.
    """
    self._chance_event_sampler = chance_event_sampler or ChanceEventSampler()
    self._include_full_state = include_full_state
    self._mfg_distribution = mfg_distribution
    self._mfg_population = mfg_population
    self._enable_legality_check = enable_legality_check

    if isinstance(game, str):
      if kwargs:
        game_settings = dict(kwargs)
        logging.info("Using game settings: %s", game_settings)
        self._game = pyspiel.load_game(game, game_settings)
      else:
        logging.info("Using game string: %s", game)
        self._game = pyspiel.load_game(game)
    else:  # pyspiel.Game or API-compatible object.
      logging.info("Using game instance: %s", game.get_type().short_name)
      self._game = game

    self._num_players = self._game.num_players()
    self._state = None
    self._should_reset = True

    # Discount returned at non-initial steps.
    self._discounts = [discount] * self._num_players

    # Determine what observation type to use.
    if observation_type is None:
      if self._game.get_type().provides_information_state_tensor:
        observation_type = ObservationType.INFORMATION_STATE
      else:
        observation_type = ObservationType.OBSERVATION

    # Check that the requested observation type is supported.
    if observation_type == ObservationType.OBSERVATION:
      if not self._game.get_type().provides_observation_tensor:
        raise ValueError(f"observation_tensor not supported by {game}")
    elif observation_type == ObservationType.INFORMATION_STATE:
      if not self._game.get_type().provides_information_state_tensor:
        raise ValueError(f"information_state_tensor not supported by {game}")
    self._use_observation = (observation_type == ObservationType.OBSERVATION)

    if self._game.get_type().dynamics == pyspiel.GameType.Dynamics.MEAN_FIELD:
      assert mfg_distribution is not None
      assert mfg_population is not None
      assert 0 <= mfg_population < self._num_players

  def seed(self, seed=None):
    self._chance_event_sampler.seed(seed)

  def get_time_step(self):
    """Returns a `TimeStep` without updating the environment.

    Returns:
      A `TimeStep` namedtuple containing:
        observations: a dict containing one observation per player, each
          matching `observation_spec()`.
        rewards: list of rewards at this timestep, or None if step_type is
          `StepType.FIRST`.
        discounts: list of discounts in the range [0, 1], or None if
          step_type is `StepType.FIRST`.
        step_type: A `StepType` value.
    """
    observations = {
        "info_state": [],
        "legal_actions": [],
        "current_player": [],
        "serialized_state": []
    }
    rewards = []
    step_type = StepType.LAST if self._state.is_terminal() else StepType.MID
    self._should_reset = step_type == StepType.LAST

    cur_rewards = self._state.rewards()
    for player_id in range(self.num_players):
      rewards.append(cur_rewards[player_id])
      observations["info_state"].append(
          self._state.observation_tensor(player_id) if self._use_observation
          else self._state.information_state_tensor(player_id))
      observations["legal_actions"].append(self._state.legal_actions(player_id))
    observations["current_player"] = self._state.current_player()
    discounts = self._discounts
    if step_type == StepType.LAST:
      # When the game is in a terminal state set the discount to 0.
      discounts = [0. for _ in discounts]

    if self._include_full_state:
      observations["serialized_state"] = pyspiel.serialize_game_and_state(
          self._game, self._state)

    # For gym environments
    if hasattr(self._state, "last_info"):
      observations["info"] = self._state.last_info
    return TimeStep(
        observations=observations,
        rewards=rewards,
        discounts=discounts,
        step_type=step_type)

  def _check_legality(self, actions):
    if self.is_turn_based:
      legal_actions = self._state.legal_actions()
      if actions[0] not in legal_actions:
        raise RuntimeError(f"step() called on illegal action {actions[0]}")
    else:
      for p in range(len(actions)):
        legal_actions = self._state.legal_actions(p)
        if legal_actions and actions[p] not in legal_actions:
          raise RuntimeError(f"step() by player {p} called on illegal "
                             f"action: {actions[p]}")

  def step(self, actions):
    """Updates the environment according to `actions` and returns a `TimeStep`.

    If the environment returned a `TimeStep` with `StepType.LAST` at the
    previous step, this call to `step` will start a new sequence and `actions`
    will be ignored.

    This method will also start a new sequence if called after the environment
    has been constructed and `reset` has not been called. Again, in this case
    `actions` will be ignored.

    Args:
      actions: a list containing one action per player, following
        specifications defined in `action_spec()`.

    Returns:
      A `TimeStep` namedtuple containing:
        observations: a dict containing one observation per player, each
          matching `observation_spec()`.
        rewards: list of rewards at this timestep, or None if step_type is
          `StepType.FIRST`.
        discounts: list of discounts in the range [0, 1], or None if
          step_type is `StepType.FIRST`.
        step_type: A `StepType` value.
    """
    assert len(actions) == self.num_actions_per_step, (
        "Invalid number of actions! Expected {}".format(
            self.num_actions_per_step))
    if self._should_reset:
      return self.reset()

    if self._enable_legality_check:
      self._check_legality(actions)

    if self.is_turn_based:
      self._state.apply_action(actions[0])
    else:
      self._state.apply_actions(actions)
    self._sample_external_events()

    return self.get_time_step()

  def reset(self):
    """Starts a new sequence and returns the first `TimeStep` of this sequence.

    Returns:
      A `TimeStep` namedtuple containing:
        observations: a dict containing one observation per player, each
          matching `observation_spec()`.
        rewards: list of rewards at this timestep, or None if step_type is
          `StepType.FIRST`.
        discounts: list of discounts in the range [0, 1], or None if
          step_type is `StepType.FIRST`.
        step_type: A `StepType` value.
    """
    self._should_reset = False
    if (self._game.get_type().dynamics == pyspiel.GameType.Dynamics.MEAN_FIELD
        and self._num_players > 1):
      self._state = self._game.new_initial_state_for_population(
          self._mfg_population)
    else:
      self._state = self._game.new_initial_state()
    self._sample_external_events()

    observations = {
        "info_state": [],
        "legal_actions": [],
        "current_player": [],
        "serialized_state": []
    }
    for player_id in range(self.num_players):
      observations["info_state"].append(
          self._state.observation_tensor(player_id) if self._use_observation
          else self._state.information_state_tensor(player_id))
      observations["legal_actions"].append(self._state.legal_actions(player_id))
    observations["current_player"] = self._state.current_player()

    if self._include_full_state:
      observations["serialized_state"] = pyspiel.serialize_game_and_state(
          self._game, self._state)

    return TimeStep(
        observations=observations,
        rewards=None,
        discounts=None,
        step_type=StepType.FIRST)

  def _sample_external_events(self):
    """Sample chance events until we get to a decision node."""
    while self._state.is_chance_node() or (
        self._state.current_player() == pyspiel.PlayerId.MEAN_FIELD):
      if self._state.is_chance_node():
        outcome = self._chance_event_sampler(self._state)
        self._state.apply_action(outcome)
      if self._state.current_player() == pyspiel.PlayerId.MEAN_FIELD:
        dist_to_register = self._state.distribution_support()
        dist = [
            self._mfg_distribution.value_str(str_state, default_value=0.0)
            for str_state in dist_to_register
        ]
        self._state.update_distribution(dist)

  def observation_spec(self):
    """Defines the observation provided by the environment for each player.

    Each dict member contains its expected structure and shape. E.g. for
    Kuhn Poker: {"info_state": (6,), "legal_actions": (2,),
    "current_player": (), "serialized_state": ()}.

    Returns:
      A specification dict describing the observation fields and shapes.
    """
    return dict(
        info_state=tuple([
            self._game.observation_tensor_size() if self._use_observation
            else self._game.information_state_tensor_size()
        ]),
        legal_actions=(self._game.num_distinct_actions(),),
        current_player=(),
        serialized_state=(),
    )

  def action_spec(self):
    """Defines per-player action specifications.

    Specifications include action boundaries and their data type.
    E.g. for Kuhn Poker: {"num_actions": 2, "min": 0, "max": 1, "dtype": int}.

    Returns:
      A specification dict containing per-player action properties.
    """
    return dict(
        num_actions=self._game.num_distinct_actions(),
        min=0,
        max=self._game.num_distinct_actions() - 1,
        dtype=int,
    )
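
  # Example (illustrative sketch): these specs are commonly used to size an
  # agent's network inputs and outputs, e.g.:
  #
  #   env = Environment("kuhn_poker")
  #   info_state_size = env.observation_spec()["info_state"][0]
  #   num_actions = env.action_spec()["num_actions"]
  #   # ...build a model mapping info_state_size inputs to num_actions logits.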

  # Environment properties
  @property
  def use_observation(self):
    """Returns whether the environment is using the game's observation.

    If false, it is using the game's information state.
    """
    return self._use_observation

  # Game properties
  @property
  def name(self):
    return self._game.get_type().short_name

  @property
  def num_players(self):
    return self._game.num_players()

  @property
  def num_actions_per_step(self):
    return 1 if self.is_turn_based else self.num_players

  # New RL calls for more advanced use cases (e.g. search + RL).
  @property
  def is_turn_based(self):
    return ((self._game.get_type().dynamics
             == pyspiel.GameType.Dynamics.SEQUENTIAL) or
            (self._game.get_type().dynamics
             == pyspiel.GameType.Dynamics.MEAN_FIELD))

  @property
  def max_game_length(self):
    return self._game.max_game_length()

  @property
  def is_chance_node(self):
    return self._state.is_chance_node()

  @property
  def game(self):
    return self._game

  def set_state(self, new_state):
    """Updates the game state."""
    assert new_state.get_game() == self.game, (
        "State must have been created by the same game.")
    self._state = new_state

  @property
  def get_state(self):
    return self._state

  @property
  def mfg_distribution(self):
    return self._mfg_distribution

  def update_mfg_distribution(self, mfg_distribution):
    """Updates the distribution over the states of the mean field game."""
    assert (
        self._game.get_type().dynamics == pyspiel.GameType.Dynamics.MEAN_FIELD)
    self._mfg_distribution = mfg_distribution
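

# Simultaneous-move games (illustrative sketch): `step` takes one action per
# player, and `current_player` is SIMULTANEOUS_PLAYER_ID throughout:
#
#   env = Environment("matrix_rps")  # one-shot rock-paper-scissors
#   time_step = env.reset()
#   assert time_step.is_simultaneous_move()
#   time_step = env.step([0, 1])  # player 0: Rock, player 1: Paper
#   print(time_step.rewards)  # [-1.0, 1.0]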