Cleaned up a few things in the linear_quadratic model.
PiperOrigin-RevId: 500709376
Change-Id: I3a551b30834870517896088bcd8d8edfdd060b67
DeepMind Technologies Ltd authored and lanctot committed Jan 16, 2023
1 parent e75bdf1 commit ad08b8f
Showing 1 changed file with 75 additions and 52 deletions: open_spiel/python/mfg/games/linear_quadratic.py
@@ -38,7 +38,7 @@
_CROSS_Q = 0.01
_KAPPA = 0.5
_TERMINAL_COST = 1.0
_DELTA_T = 1.0 # 3.0/_HORIZON
_DELTA_T = 1.0
_N_ACTIONS_PER_SIDE = 3
_SPATIAL_BIAS = 0

@@ -52,7 +52,7 @@
"cross_q": _CROSS_Q,
"kappa": _KAPPA,
"terminal_cost": _TERMINAL_COST,
"spatial_bias": _SPATIAL_BIAS
"spatial_bias": _SPATIAL_BIAS,
}

_GAME_TYPE = pyspiel.GameType(
@@ -69,28 +69,29 @@
provides_information_state_tensor=False,
provides_observation_string=True,
provides_observation_tensor=True,
parameter_specification=_DEFAULT_PARAMS)
parameter_specification=_DEFAULT_PARAMS,
)


class MFGLinearQuadraticGame(pyspiel.Game):
"""A Mean-Field Linear QUadratic game.
A game starts by an initial chance node that select the initial state
of the player in the MFG.
Then the game sequentially alternates between:
- An action selection node (Where the player Id >= 0)
- A chance node (the player id is pyspiel.PlayerId.CHANCE)
- A Mean Field node (the player id is pyspiel.PlayerId.MEAN_FIELD)
"""A Mean-Field Linear Quadratic game.
For now, only single-population setting is covered. A game starts by an
initial chance node that selects the initial state of the player in the MFG.
Then the game sequentially alternates between:
- An action selection node (where the player id is >= 0)
- A chance node (the player id is pyspiel.PlayerId.CHANCE)
- A Mean Field node (the player id is pyspiel.PlayerId.MEAN_FIELD)
"""

# pylint:disable=dangerous-default-value
def __init__(self, params: Mapping[str, Any] = _DEFAULT_PARAMS):
self.size = params.get("size", _SIZE)
self.horizon = params.get("horizon", _HORIZON)
self.dt = params.get("dt", _DELTA_T)
self.n_actions_per_side = params.get("n_actions_per_side",
_N_ACTIONS_PER_SIDE)
self.n_actions_per_side = params.get(
"n_actions_per_side", _N_ACTIONS_PER_SIDE
)
self.volatility = params.get("volatility", _VOLATILITY)
self.mean_revert = params.get("mean_revert", _MEAN_REVERT)
self.cross_q = params.get("cross_q", _CROSS_Q)
@@ -105,7 +106,8 @@ def __init__(self, params: Mapping[str, Any] = _DEFAULT_PARAMS):
min_utility=-np.inf,
max_utility=+np.inf,
utility_sum=0.0,
max_game_length=self.horizon)
max_game_length=self.horizon,
)
super().__init__(_GAME_TYPE, game_info, params)

def new_initial_state(self):
@@ -114,8 +116,9 @@ def new_initial_state(self):

def make_py_observer(self, iig_obs_type=None, params=None):
"""Returns an object used for observing game state."""
if ((iig_obs_type is None) or
(iig_obs_type.public_info and not iig_obs_type.perfect_recall)):
if (iig_obs_type is None) or (
iig_obs_type.public_info and not iig_obs_type.perfect_recall
):
return Observer(params, self)
return IIGObserverForPublicInfoGame(iig_obs_type, params)

@@ -151,13 +154,13 @@ def __init__(self, game):

# Represents the current probability distribution over game states.
# Initialized with a uniform distribution.
self._distribution = [1. / self.size for i in range(self.size)]
self._distribution = [1.0 / self.size for i in range(self.size)]

def to_string(self):
return self.state_to_str(self.x, self.tick)

def state_to_str(self, x, tick, player_id=pyspiel.PlayerId.DEFAULT_PLAYER_ID):
"""A string that uniquely identify a triplet x, t, player_id."""
"""A string that uniquely identifies a triplet x, t, player_id."""
if self.x is None:
return "initial"

@@ -168,7 +171,8 @@ def state_to_str(self, x, tick, player_id=pyspiel.PlayerId.DEFAULT_PLAYER_ID):
elif self._player_id == pyspiel.PlayerId.CHANCE:
return "({}, {})_a_mu".format(x, tick)
raise ValueError(
"player_id is not mean field, chance or default player id.")
"player_id is not mean field, chance or default player id."
)

# OpenSpiel (PySpiel) API functions are below. This is the standard set that
# should be implemented by every perfect-information sequential-move game.
@@ -181,29 +185,38 @@ def _legal_actions(self, player):
"""Returns a list of legal actions for player and MFG nodes."""
if player == pyspiel.PlayerId.MEAN_FIELD:
return []
if (player == pyspiel.PlayerId.DEFAULT_PLAYER_ID and
player == self.current_player()):
if (
player == pyspiel.PlayerId.DEFAULT_PLAYER_ID
and player == self.current_player()
):
return list(range(self.n_actions))
raise ValueError(f"Unexpected player {player}. "
"Expected a mean field or current player 0.")
raise ValueError(
f"Unexpected player {player}. "
"Expected a mean field or current player 0."
)

def _apply_action(self, action):
"""Applies the specified action to the state."""
if self._player_id == pyspiel.PlayerId.MEAN_FIELD:
raise ValueError(
"_apply_action should not be called at a MEAN_FIELD state.")
"_apply_action should not be called at a MEAN_FIELD state."
)
self.return_value = self._rewards()

assert self._player_id == pyspiel.PlayerId.DEFAULT_PLAYER_ID or self._player_id == pyspiel.PlayerId.CHANCE
assert (
self._player_id == pyspiel.PlayerId.DEFAULT_PLAYER_ID
or self._player_id == pyspiel.PlayerId.CHANCE
)

if self.x is None:
self.x = action
self._player_id = pyspiel.PlayerId.DEFAULT_PLAYER_ID
return

if action < 0 or action >= self.n_actions:
raise ValueError("The action is between 0 and {} at any node".format(
self.n_actions))
raise ValueError(
"The action is between 0 and {} at any node".format(self.n_actions)
)

move = self.action_to_move(action)
if self._player_id == pyspiel.PlayerId.CHANCE:
@@ -212,7 +225,7 @@ def _apply_action(self, action):
self._player_id = pyspiel.PlayerId.MEAN_FIELD
self.tick += 1
elif self._player_id == pyspiel.PlayerId.DEFAULT_PLAYER_ID:
dist_mean = (self.distribution_average() - self.x)
dist_mean = self.distribution_average() - self.x
full_move = move
full_move += self.mean_revert * dist_mean
full_move *= self.dt
@@ -240,12 +253,14 @@ def chance_outcomes(self):

a = np.array(self.actions_to_position())
gaussian_vals = scipy.stats.norm.cdf(
a + 0.5, scale=self.volatility) - scipy.stats.norm.cdf(
a - 0.5, scale=self.volatility)
gaussian_vals[0] += scipy.stats.norm.cdf(
a[0] - 0.5, scale=self.volatility) - 0.0
a + 0.5, scale=self.volatility
) - scipy.stats.norm.cdf(a - 0.5, scale=self.volatility)
gaussian_vals[0] += (
scipy.stats.norm.cdf(a[0] - 0.5, scale=self.volatility) - 0.0
)
gaussian_vals[-1] += 1.0 - scipy.stats.norm.cdf(
a[-1] + 0.5, scale=self.volatility)
a[-1] + 0.5, scale=self.volatility
)
return [
(act, p) for act, p in zip(list(range(self.n_actions)), gaussian_vals)
]
@@ -273,7 +288,8 @@ def update_distribution(self, distribution):
"""
if self._player_id != pyspiel.PlayerId.MEAN_FIELD:
raise ValueError(
"update_distribution should only be called at a MEAN_FIELD state.")
"update_distribution should only be called at a MEAN_FIELD state."
)
self._distribution = distribution.copy()
self._player_id = pyspiel.PlayerId.DEFAULT_PLAYER_ID

@@ -301,25 +317,32 @@ def eta_t(self):
T = self.horizon
t = self.t

R = (K + q)**2 + (kappa - q**2)
R = (K + q) ** 2 + (kappa - q**2)
deltap = -(K + q) + math.sqrt(R)
deltam = -(K + q) - math.sqrt(R)
numerator = -(kappa - q**2) * (math.exp(
(deltap - deltam) * (T - t)) - 1) - c * (
deltap * math.exp((deltap - deltam) * (T - t)) - deltam)
denominator = (deltam * math.exp(
(deltap - deltam) * (T - t)) - deltap) - c * (
math.exp((deltap - deltam) * (T - t)) - 1)
numerator = -(kappa - q**2) * (
math.exp((deltap - deltam) * (T - t)) - 1
) - c * (deltap * math.exp((deltap - deltam) * (T - t)) - deltam)
denominator = (
deltam * math.exp((deltap - deltam) * (T - t)) - deltap
) - c * (math.exp((deltap - deltam) * (T - t)) - 1)
return numerator / denominator

def _rewards(self):
"""Reward for the player for this state."""
if self._player_id == pyspiel.PlayerId.DEFAULT_PLAYER_ID:
dist_mean = (self.distribution_average() - self.x)
dist_mean = self.distribution_average() - self.x

move = self.action_to_move(self._last_action)
action_reward = self.dt / 2 * (-move**2 + 2 * self.cross_q * move *
dist_mean - self.kappa * dist_mean**2)
action_reward = (
self.dt
/ 2
* (
-(move**2)
+ 2 * self.cross_q * move * dist_mean
- self.kappa * dist_mean**2
)
)

if self.is_terminal():
terminal_reward = -self.terminal_cost * dist_mean**2 / 2.0
@@ -330,8 +353,7 @@ def _rewards(self):

def rewards(self) -> List[float]:
"""Rewards for all players."""
# For now, only single-population (single-player) mean field games
# are supported.
# For now, only single-population mean field games are supported.
return [self._rewards()]

def _returns(self):
@@ -340,14 +362,14 @@ def _returns(self):

def returns(self) -> List[float]:
"""Returns for all players."""
# For now, only single-population (single-player) mean field games
# are supported.
# For now, only single-population mean field games are supported.
return [self._returns()]

def __str__(self):
"""A string that uniquely identify the current state."""
return self.state_to_str(
x=self.x, tick=self.tick, player_id=self._player_id)
x=self.x, tick=self.tick, player_id=self._player_id
)


class Observer:
@@ -363,7 +385,7 @@ def __init__(self, params, game):
self.dict = {
"x": self.tensor[0],
"t": self.tensor[1],
"observation": self.tensor
"observation": self.tensor,
}

def set_from(self, state, player: int):
@@ -378,7 +400,8 @@ def set_from(self, state, player: int):
if state.x is not None:
if not 0 <= state.x < self.size:
raise ValueError(
f"Expected {state} x position to be in [0, {self.size})")
f"Expected {state} x position to be in [0, {self.size})"
)
self.dict["x"] = np.array([state.x])
if not 0 <= state.t <= self.horizon:
raise ValueError(f"Expected {state} time to be in [0, {self.horizon}]")
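The class docstring in this diff describes how play alternates between an initial chance node, player nodes, chance nodes and mean-field nodes. Below is a minimal sketch of that loop using only methods visible in this diff; it is not OpenSpiel's canonical mean-field evaluation loop, and it simply keeps the population distribution uniform at every mean-field node, matching the state constructor's initialization.

import random

import pyspiel
from open_spiel.python.mfg.games import linear_quadratic

game = linear_quadratic.MFGLinearQuadraticGame()
state = game.new_initial_state()

while not state.is_terminal():
  if state.current_player() == pyspiel.PlayerId.MEAN_FIELD:
    # Mean-field node: feed back a distribution over the `size` positions.
    state.update_distribution([1.0 / game.size] * game.size)
  elif state.is_chance_node():
    # Initial-state selection or per-step noise, sampled from chance_outcomes.
    actions, probs = zip(*state.chance_outcomes())
    state.apply_action(random.choices(actions, weights=probs)[0])
  else:
    # Player node: pick any legal move uniformly at random.
    state.apply_action(random.choice(state.legal_actions()))

print(state.returns())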
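For reference, the running and terminal rewards implemented in _rewards above are the usual linear-quadratic costs. In the notation used here (ours, not the file's), x_t is the player's position, m_t the mean of the population distribution, a_t the chosen move, and \bar q, \kappa, c_T are the cross_q, kappa and terminal_cost parameters:

    r_t = \frac{\Delta t}{2}\left(-a_t^2 + 2\bar q\, a_t (m_t - x_t) - \kappa (m_t - x_t)^2\right),
    \qquad
    r_T = -\frac{c_T}{2}\,(m_T - x_T)^2 .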
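The eta_t hunk evaluates the closed-form solution of the Riccati-type equation for this game. Assuming K, q, \kappa and c in that hunk stand for the mean_revert, cross_q, kappa and terminal_cost parameters (their assignments sit in the collapsed lines), the value returned is

    R = (K+q)^2 + (\kappa - q^2), \qquad \delta^{\pm} = -(K+q) \pm \sqrt{R},

    \eta(t) = \frac{-(\kappa - q^2)\left(e^{(\delta^+ - \delta^-)(T-t)} - 1\right) - c\left(\delta^+ e^{(\delta^+ - \delta^-)(T-t)} - \delta^-\right)}{\left(\delta^- e^{(\delta^+ - \delta^-)(T-t)} - \delta^+\right) - c\left(e^{(\delta^+ - \delta^-)(T-t)} - 1\right)},

with T the horizon and t the current time.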
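The chance_outcomes hunk discretizes a centered Gaussian onto the grid of noise moves and folds both tails into the extreme moves so the masses sum to one. Below is a standalone sketch of that computation, written as a separate helper for illustration; it assumes the noise moves are the integer displacements -n..n suggested by n_actions_per_side.

import numpy as np
import scipy.stats

def noise_distribution(n_actions_per_side=3, volatility=1.0):
  # Integer displacements covered by the chance node.
  a = np.arange(-n_actions_per_side, n_actions_per_side + 1, dtype=float)
  # Mass of a centered Gaussian falling in [a - 0.5, a + 0.5) for each move.
  probs = scipy.stats.norm.cdf(a + 0.5, scale=volatility) - scipy.stats.norm.cdf(
      a - 0.5, scale=volatility
  )
  # Fold the left and right tails into the extreme moves.
  probs[0] += scipy.stats.norm.cdf(a[0] - 0.5, scale=volatility)
  probs[-1] += 1.0 - scipy.stats.norm.cdf(a[-1] + 0.5, scale=volatility)
  return list(zip(range(len(a)), probs))

# The masses sum to 1 by construction, e.g.:
print(noise_distribution())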
