
Commit

Internal change to algorithms.
PiperOrigin-RevId: 477154449
Change-Id: I544aa226f7ca8fccc889402c7ffd53c7fd3c27fc
sgirgin authored and lanctot committed Sep 29, 2022
1 parent 88bef5f commit 2b8c5da
Showing 4 changed files with 201 additions and 65 deletions.
1 change: 1 addition & 0 deletions docs/algorithms.md
@@ -24,6 +24,7 @@ External sampling Monte Carlo CFR | Tabular | [Lanctot et
Fixed Strategy Iteration CFR (FSICFR) | Tabular | [Neller & Hnath '11](https://cupola.gettysburg.edu/csfac/2/) | <font color="orange"><b>~</b></font>
Mean-field Fictitious Play for MFG | Tabular | [Perrin et al. '20](https://arxiv.org/abs/2007.03458) | <font color="orange"><b>~</b></font>
Online Mirror Descent for MFG | Tabular | [Perolat et al. '21](https://arxiv.org/abs/2103.00623) | <font color="orange"><b>~</b></font>
Munchausen Online Mirror Descent for MFG | Tabular | [Lauriere et al. '22](https://arxiv.org/pdf/2203.11973) | <font color="orange"><b>~</b></font>
Outcome sampling Monte Carlo CFR | Tabular | [Lanctot et al. '09](http://mlanctot.info/files/papers/nips09mccfr.pdf), [Lanctot '13](http://mlanctot.info/files/papers/PhD_Thesis_MarcLanctot.pdf) | ![](_static/green_circ10.png "green circle")
Policy Iteration | Tabular | [Sutton & Barto '18](http://incompleteideas.net/book/the-book-2nd.html) | ![](_static/green_circ10.png "green circle")
Q-learning | Tabular | [Sutton & Barto '18](http://incompleteideas.net/book/the-book-2nd.html) | ![](_static/green_circ10.png "green circle")
135 changes: 70 additions & 65 deletions open_spiel/python/mfg/algorithms/mirror_descent.py
@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mirror Descent (https://arxiv.org/pdf/2103.00623.pdf)."""
from typing import Optional

from typing import Dict, List, Optional

import numpy as np

from open_spiel.python import policy as policy_std
from open_spiel.python import policy as policy_lib
from open_spiel.python.mfg import value
from open_spiel.python.mfg.algorithms import distribution
import pyspiel
@@ -29,48 +30,54 @@ def softmax_projection(logits):
return [l / norm_exp for l in exp_l]


class ProjectedPolicy(policy_std.Policy):
class ProjectedPolicy(policy_lib.Policy):
"""Project values on the policy simplex."""

def __init__(self, game, player_ids,
cumulative_state_value: value.ValueFunction):
def __init__(
self,
game: pyspiel.Game,
player_ids: List[int],
state_value: value.ValueFunction,
):
"""Initializes the projected policy.
Args:
game: The game to analyze.
player_ids: list of player ids for which this policy applies; each should
be in the range 0..game.num_players()-1.
cumulative_state_value: The cumulative state value to project.
state_value: The (cumulative) state value to project.
"""
super(ProjectedPolicy, self).__init__(game, player_ids)
self._cumulative_state_value = cumulative_state_value
self._state_value = state_value

def cumulative_value(self, state, action=None):
def value(self, state: pyspiel.State, action: Optional[int] = None) -> float:
if action is None:
return self._cumulative_state_value(
return self._state_value(
state.observation_string(pyspiel.PlayerId.DEFAULT_PLAYER_ID))
else:
new_state = state.child(action)
return state.rewards()[0] + self._cumulative_state_value(
return state.rewards()[0] + self._state_value(
new_state.observation_string(pyspiel.PlayerId.DEFAULT_PLAYER_ID))

def action_probabilities(self, state, player_id=None):
action_logit = [(a, self.cumulative_value(state, action=a))
for a in state.legal_actions()]
def action_probabilities(self,
state: pyspiel.State,
player_id: Optional[int] = None) -> Dict[int, float]:
del player_id
action_logit = [
(a, self.value(state, action=a)) for a in state.legal_actions()
]
action, logit = zip(*action_logit)
prob = softmax_projection(logit)
action_prob = zip(action, prob)
return dict(action_prob)
return dict(zip(action, softmax_projection(logit)))


class MirrorDescent(object):
"""The mirror descent algorithm."""

def __init__(self,
game,
game: pyspiel.Game,
state_value: Optional[value.ValueFunction] = None,
lr=0.01,
root_state=None):
lr: float = 0.01,
root_state: Optional[pyspiel.State] = None):
"""Initializes mirror descent.
Args:
@@ -85,7 +92,7 @@ def __init__(self,
self._root_states = game.new_initial_states()
else:
self._root_states = [root_state]
self._policy = policy_std.UniformRandomPolicy(game)
self._policy = policy_lib.UniformRandomPolicy(game)
self._distribution = distribution.DistributionPolicy(game, self._policy)
self._md_step = 0
self._lr = lr
@@ -94,69 +101,67 @@ def __init__(self,
state_value if state_value else value.TabularValueFunction(game))
self._cumulative_state_value = value.TabularValueFunction(game)

def eval_state(self, state, learning_rate):
"""Evaluate the value of a state and update the cumulative sum."""
state_str = state.observation_string(pyspiel.PlayerId.DEFAULT_PLAYER_ID)
if self._state_value.has(state_str):
return self._state_value(state_str)
elif state.is_terminal():
self._state_value.set_value(
state_str,
state.rewards()[state.mean_field_population()])
self._cumulative_state_value.add_value(
state_str, learning_rate * self._state_value(state_str))
return self._state_value(state_str)
elif state.current_player() == pyspiel.PlayerId.CHANCE:
self._state_value.set_value(state_str, 0.0)
def get_state_value(self, state: pyspiel.State,
learning_rate: float) -> float:
"""Returns the value of the state."""
if state.is_terminal():
return state.rewards()[state.mean_field_population()]

if state.current_player() == pyspiel.PlayerId.CHANCE:
v = 0.0
for action, prob in state.chance_outcomes():
new_state = state.child(action)
self._state_value.add_value(
state_str, prob * self.eval_state(new_state, learning_rate))
self._cumulative_state_value.add_value(
state_str, learning_rate * self._state_value(state_str))
return self._state_value(state_str)
elif state.current_player() == pyspiel.PlayerId.MEAN_FIELD:
v += prob * self.eval_state(new_state, learning_rate)
return v

if state.current_player() == pyspiel.PlayerId.MEAN_FIELD:
dist_to_register = state.distribution_support()
dist = [
self._distribution.value_str(str_state, 0.0)
for str_state in dist_to_register
]
new_state = state.clone()
new_state.update_distribution(dist)
self._state_value.set_value(
state_str,
state.rewards()[state.mean_field_population()] +
self.eval_state(new_state, learning_rate))
self._cumulative_state_value.add_value(
state_str, learning_rate * self._state_value(state_str))
return self._state_value(state_str)
else:
assert int(state.current_player()) >= 0, "The player id should be >= 0"
v = 0.0
for action, prob in self._policy.action_probabilities(state).items():
new_state = state.child(action)
v += prob * self.eval_state(new_state, learning_rate)
self._state_value.set_value(
state_str,
state.rewards()[state.mean_field_population()] + v)
self._cumulative_state_value.add_value(
state_str, learning_rate * self._state_value(state_str))
return self._state_value(state_str)
return (state.rewards()[state.mean_field_population()] +
self.eval_state(new_state, learning_rate))

def iteration(self, learning_rate=None):
"""an iteration of Mirror Descent."""
assert int(state.current_player()) >= 0, "The player id should be >= 0"
v = 0.0
for action, prob in self._policy.action_probabilities(state).items():
new_state = state.child(action)
v += prob * self.eval_state(new_state, learning_rate)
return state.rewards()[state.mean_field_population()] + v

def eval_state(self, state: pyspiel.State, learning_rate: float) -> float:
"""Evaluate the value of a state and update the cumulative sum."""
state_str = state.observation_string(pyspiel.PlayerId.DEFAULT_PLAYER_ID)
# Return the already calculated value if present.
if self._state_value.has(state_str):
return self._state_value(state_str)
# Otherwise, calculate the value of the state.
v = self.get_state_value(state, learning_rate)
self._state_value.set_value(state_str, v)
# Update the cumulative value of the state.
self._cumulative_state_value.add_value(state_str, learning_rate * v)
return v

def get_projected_policy(self) -> policy_lib.Policy:
"""Returns the projected policy."""
return ProjectedPolicy(self._game, list(range(self._game.num_players())),
self._cumulative_state_value)

def iteration(self, learning_rate: Optional[float] = None):
"""An iteration of Mirror Descent."""
self._md_step += 1
# TODO(sertan): Fix me.
self._state_value = value.TabularValueFunction(self._game)
for state in self._root_states:
self.eval_state(state, learning_rate if learning_rate else self._lr)
self._policy = ProjectedPolicy(self._game,
list(range(self._game.num_players())),
self._cumulative_state_value)
self._policy = self.get_projected_policy()
self._distribution = distribution.DistributionPolicy(
self._game, self._policy)

def get_policy(self):
def get_policy(self) -> policy_lib.Policy:
return self._policy

@property
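The refactor above keeps the public surface of MirrorDescent unchanged (iteration, get_policy); only the per-state evaluation is split into get_state_value plus a caching eval_state, so subclasses can override the projection step. A minimal usage sketch (not part of the commit), assuming the Python crowd-modelling MFG is available, mirroring the pattern of the new test below:

```python
import pyspiel

from open_spiel.python.mfg import value
from open_spiel.python.mfg.algorithms import mirror_descent
from open_spiel.python.mfg.algorithms import nash_conv
from open_spiel.python.mfg.games import crowd_modelling  # pylint: disable=unused-import

game = pyspiel.load_game('python_mfg_crowd_modelling')
# Online mirror descent with a tabular value function and the default lr=0.01.
md = mirror_descent.MirrorDescent(game, value.TabularValueFunction(game))
for _ in range(10):
  md.iteration()  # Evaluates states, accumulates values, re-projects the policy.
# Exploitability of the resulting policy.
print(nash_conv.NashConv(game, md.get_policy()).nash_conv())
```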
86 changes: 86 additions & 0 deletions open_spiel/python/mfg/algorithms/munchausen_mirror_descent.py
@@ -0,0 +1,86 @@
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Munchausen Online Mirror Descent."""

from typing import Dict, List, Optional

import numpy as np

from open_spiel.python import policy as policy_lib
from open_spiel.python.mfg import value
from open_spiel.python.mfg.algorithms import mirror_descent
import pyspiel


class ProjectedPolicyMunchausen(mirror_descent.ProjectedPolicy):
"""Project values on the policy simplex."""

def __init__(
self,
game: pyspiel.Game,
player_ids: List[int],
state_value: value.ValueFunction,
learning_rate: float,
policy: policy_lib.Policy,
):
"""Initializes the projected policy.
Args:
game: The game to analyze.
player_ids: list of player ids for which this policy applies; each should
be in the range 0..game.num_players()-1.
state_value: The state value to project.
learning_rate: The learning rate.
policy: The policy to project.
"""
super().__init__(game, player_ids, state_value)
self._learning_rate = learning_rate
self._policy = policy

def action_probabilities(self,
state: pyspiel.State,
player_id: Optional[int] = None) -> Dict[int, float]:
del player_id
action_logit = [
(a, self._learning_rate * self.value(state, action=a) + np.log(p))
for a, p in self._policy.action_probabilities(state).items()
]
action, logit = zip(*action_logit)
return dict(zip(action, mirror_descent.softmax_projection(logit)))


class MunchausenMirrorDescent(mirror_descent.MirrorDescent):
"""Munchausen Online Mirror Descent algorithm.
This algorithm is equivalent to the online mirror descent algorithm but
instead of summing value functions, it directly computes the cumulative
Q-function using a penalty with respect to the previous policy.
"""

def eval_state(self, state: pyspiel.State, learning_rate: float):
"""Evaluate the value of a state."""
state_str = state.observation_string(pyspiel.PlayerId.DEFAULT_PLAYER_ID)
# Return the already calculated value if present.
if self._state_value.has(state_str):
return self._state_value(state_str)
# Otherwise, calculate the value of the state.
v = self.get_state_value(state, learning_rate)
self._state_value.set_value(state_str, v)
return v

def get_projected_policy(self) -> policy_lib.Policy:
"""Returns the projected policy."""
return ProjectedPolicyMunchausen(self._game,
list(range(self._game.num_players())),
self._state_value, self._lr, self._policy)
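The key difference from plain OMD is visible in ProjectedPolicyMunchausen.action_probabilities: instead of projecting a cumulative value sum, it projects learning_rate * Q(s, a) plus the log-probability of the previous policy. Because log pi already carries the previous iterations' regularized logits (up to the softmax normalization constant), this single-step penalty accumulates the Q-function implicitly, which is what the class docstring refers to. A small numeric sketch of that projection step, with made-up values not taken from any game:

```python
import numpy as np

from open_spiel.python.mfg.algorithms import mirror_descent

lr = 0.01
q_values = [1.0, 0.5, -0.2]   # Illustrative Q(s, a) from the state value.
prev_probs = [0.5, 0.3, 0.2]  # Previous policy pi(a | s).
# Munchausen logits: learning_rate * Q(s, a) + log pi(a | s).
logits = [lr * q + np.log(p) for q, p in zip(q_values, prev_probs)]
new_probs = mirror_descent.softmax_projection(logits)
# For a small learning rate the projected policy stays close to prev_probs,
# nudged toward the higher-valued actions.
print(new_probs)
```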
44 changes: 44 additions & 0 deletions open_spiel/python/mfg/algorithms/munchausen_mirror_descent_test.py
@@ -0,0 +1,44 @@
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Munchausen Online Mirror Descent."""

from absl.testing import absltest
from absl.testing import parameterized

from open_spiel.python.mfg import value
from open_spiel.python.mfg.algorithms import munchausen_mirror_descent
from open_spiel.python.mfg.algorithms import nash_conv
from open_spiel.python.mfg.games import crowd_modelling # pylint: disable=unused-import
import pyspiel


class MunchausenMirrorDescentTest(parameterized.TestCase):

@parameterized.named_parameters(('python', 'python_mfg_crowd_modelling'),
('cpp', 'mfg_crowd_modelling'))
def test_run(self, name):
"""Checks if the algorithm works."""
game = pyspiel.load_game(name)
md = munchausen_mirror_descent.MunchausenMirrorDescent(
game, value.TabularValueFunction(game))
for _ in range(10):
md.iteration()
md_policy = md.get_policy()
nash_conv_md = nash_conv.NashConv(game, md_policy)

self.assertAlmostEqual(nash_conv_md.nash_conv(), 2.27366, places=5)


if __name__ == '__main__':
absltest.main()
