Add CFR with fully reaching probability algorithms.
JialianLee committed Mar 22, 2020
1 parent f811002 commit fc8bd86
Showing 8 changed files with 1,607 additions and 4 deletions.
89 changes: 87 additions & 2 deletions open_spiel/python/algorithms/cfr.py
@@ -31,6 +31,7 @@
import numpy as np

from open_spiel.python import policy
from open_spiel.python.algorithms import best_response as pyspiel_best_response
import pyspiel


@@ -178,6 +179,8 @@ def __init__(self, game, alternating_updates, linear_averaging,
self._linear_averaging = linear_averaging
self._alternating_updates = alternating_updates
self._regret_matching_plus = regret_matching_plus
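    # Bookkeeping: number of game-tree nodes visited and each player's expected
    # return accumulated across iterations (used by get_regret).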
self.nodes_touched = 0
self.cumulative_rewards = np.zeros(self._num_players)

def _initialize_info_state_nodes(self, state):
"""Initializes info_state_nodes.
@@ -276,7 +279,7 @@ def _compute_counterfactual_regret_for_player(self, state, policies,
state_value += action_prob * self._compute_counterfactual_regret_for_player(
new_state, policies, new_reach_probabilities, player)
return state_value

current_player = state.current_player()
info_state = state.information_state_string(current_player)

@@ -287,7 +290,10 @@ def _compute_counterfactual_regret_for_player(self, state, policies,
# occurring, so the returned value is not impacting the parent node value.
if all(reach_probabilities[:-1] == 0):
return np.zeros(self._num_players)
# if reach_probabilities[1 - current_player] == 0:
# return np.zeros(self._num_players)

self.nodes_touched += 1
state_value = np.zeros(self._num_players)

# The utilities of the children states are computed recursively. As the
@@ -351,6 +357,75 @@ def _get_infostate_policy(self, info_state_str):
action: prob_vec[action] for action in info_state_node.legal_actions
}

def _compute_best_response_policy(self, player):
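    """Stores in `self.br` a best response to the current average policy.

    If `player` is None, a best response is computed for every player and
    stored in a dict keyed by player id.
    """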
# current_policy = policy.PolicyFromCallable(
# self._game,
# lambda state: self._get_infostate_policy(state.information_state_string(
# )))
current_policy = self.average_policy()
if player is not None:
self.br = pyspiel_best_response.BestResponsePolicy(self._game, player, current_policy,
self._root_node)
else:
self.br = {}
for p in range(self._num_players):
        self.br[p] = pyspiel_best_response.BestResponsePolicy(
            self._game, p, current_policy, self._root_node)

def _calculate_util(self, state, policies, reach_probabilities):
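    """Recursively computes the vector of expected returns for all players.

    If `policies` is None, the solver's current policy is used at every
    decision node; branches that no player can reach are skipped.
    """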
if state.is_terminal():
return np.asarray(state.returns())

if state.is_chance_node():
state_value = 0.0
for action, action_prob in state.chance_outcomes():
assert action_prob > 0
new_state = state.child(action)
new_reach_probabilities = reach_probabilities.copy()
new_reach_probabilities[-1] *= action_prob
        state_value += action_prob * self._calculate_util(
            new_state, policies, new_reach_probabilities)
return state_value

current_player = state.current_player()
info_state = state.information_state_string(current_player)

    # No need to continue on this history branch: no player can reach it under
    # the policies being evaluated. The value returned here is not used in
    # practice, because some action on the path to this state has probability 0,
    # so the returned value does not impact any ancestor node's value.
if all(reach_probabilities[:-1] == 0):
return np.zeros(self._num_players)
# if reach_probabilities[1 - current_player] == 0:
# return np.zeros(self._num_players)

self.nodes_touched += 1
state_value = np.zeros(self._num_players)

    # The expected utilities of the children states are computed recursively and
    # accumulated, weighted by the probability of each action under the policy
    # being evaluated.

# children_utilities = {}

info_state_node = self._info_state_nodes[info_state]
if policies is None:
info_state_policy = self._get_infostate_policy(info_state)
else:
info_state_policy = policies[current_player](info_state)
for action in state.legal_actions():
action_prob = info_state_policy.get(action, 0.)
new_state = state.child(action)
new_reach_probabilities = reach_probabilities.copy()
new_reach_probabilities[current_player] *= action_prob
state_value += action_prob * self._calculate_util(
new_state,
policies=policies,
reach_probabilities=new_reach_probabilities)
return state_value


def _regret_matching(cumulative_regrets, legal_actions):
"""Returns an info state policy by applying regret-matching.
@@ -443,7 +518,17 @@ def evaluate_and_update_policy(self):
if self._regret_matching_plus:
_apply_regret_matching_plus_reset(self._info_state_nodes)
_update_current_policy(self._current_policy, self._info_state_nodes)

self.cumulative_rewards += self._calculate_util(self._root_node,
policies=None,
reach_probabilities=np.ones(self._game.num_players() + 1))

def get_regret(self):
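    """Returns the sum over players of best-response values against the average
    policy, minus the iteration-averaged sum of expected returns."""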
reg = 0.0
for player in range(self._num_players):
self._compute_best_response_policy(player)
reg += self.br.value(self._root_node)
reg -= sum(self.cumulative_rewards) / self._iteration
return reg

class CFRPlusSolver(_CFRSolver):
"""CFR+ implementation.
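
A minimal usage sketch for the additions above, assuming a CFRPlusSolver on Kuhn poker; the game, solver choice, and iteration count here are illustrative assumptions, not part of this commit:

import pyspiel
from open_spiel.python.algorithms import cfr

game = pyspiel.load_game("kuhn_poker")
solver = cfr.CFRPlusSolver(game)
for _ in range(100):
  solver.evaluate_and_update_policy()
# New in this commit: a regret estimate via best responses to the average
# policy, and a counter of game-tree nodes visited so far.
print("regret estimate:", solver.get_regret())
print("nodes touched:", solver.nodes_touched)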

