Skip to content

Commit

Permalink
correct issues about symmetric games in psro_v2
Browse files Browse the repository at this point in the history
  • Loading branch information
rezunli96 committed Mar 9, 2023
1 parent c269d4a commit b5b6dc8
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
8 changes: 5 additions & 3 deletions open_spiel/python/algorithms/psro_v2/abstract_meta_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ def iteration(self, seed=None):
def update_meta_strategies(self):
  """Recomputes meta-strategy probabilities with the meta-strategy method.

  In a symmetric game every player shares the same marginal distribution,
  so only the first player's probabilities are kept (length-1 list).
  """
  self._meta_strategy_probabilities = self._meta_strategy_method(self)
  if self.symmetric_game:
    # Keep a single marginal: all players are interchangeable.
    self._meta_strategy_probabilities = [
        self._meta_strategy_probabilities[0]]

def update_agents(self):
  """Trains new best responses against the current meta-strategies.

  Abstract stub: subclasses must override.

  Raises:
    NotImplementedError: always, in this base class.
  """
  # Bug fix: the original `return NotImplementedError(...)` silently handed
  # callers an exception *instance* instead of signalling the error.
  raise NotImplementedError("update_agents not implemented.")
Expand Down Expand Up @@ -233,14 +234,15 @@ def get_meta_strategies(self):
def get_meta_game(self):
  """Returns the meta game matrix.

  Returns:
    A list with one payoff array per player (defensive copies). For a
    symmetric game the single stored matrix is replicated
    `_game_num_players` times so callers always see a full-length list.
  """
  if self.symmetric_game:
    games = self._game_num_players * self._meta_games
  else:
    games = self._meta_games
  # np.copy so callers cannot mutate the solver's internal state.
  return [np.copy(game) for game in games]

def get_policies(self):
  """Returns the players' policies.

  For a symmetric game the single stored per-player list is repeated
  `_game_num_players` times so the result has one entry per player.
  """
  if not self.symmetric_game:
    return self._policies
  # Notice that the following line returns N references to the same policy.
  # This might not be correct for certain applications,
  # e.g., a DQN BR oracle with player_id information.
  return self._game_num_players * self._policies

Expand Down
32 changes: 27 additions & 5 deletions open_spiel/python/algorithms/psro_v2/psro_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,18 @@ def __init__(self,
**kwargs)

def _initialize_policy(self, initial_policies):
self._policies = [[] for k in range(self._num_players)]
self._new_policies = [([initial_policies[k]] if initial_policies else
[policy.UniformRandomPolicy(self._game)])
for k in range(self._num_players)]
if self.symmetric_game:
self._policies = [[]]
# Notice that the following line returns N references to the same policy
# This might not be correct for certain applications.
# E.g., a DQN BR oracle with player_id information
self._new_policies = [([initial_policies[0]] if initial_policies else
[policy.UniformRandomPolicy(self._game)])]
else:
self._policies = [[] for _ in range(self._num_players)]
self._new_policies = [([initial_policies[k]] if initial_policies else
[policy.UniformRandomPolicy(self._game)])
for k in range(self._num_players)]

def _initialize_game_state(self):
effective_payoff_size = self._game_num_players
Expand Down Expand Up @@ -211,14 +219,18 @@ def update_meta_strategies(self):
meta-probabilities.
"""
if self.symmetric_game:
# Notice that the following line returns N references to the same policy
# This might not be correct for certain applications.
# E.g., a DQN BR oracle with player_id information
self._policies = self._policies * self._game_num_players

self._meta_strategy_probabilities, self._non_marginalized_probabilities = (
self._meta_strategy_method(solver=self, return_joint=True))

if self.symmetric_game:
self._policies = [self._policies[0]]
self._meta_strategy_probabilities = [self._meta_strategy_probabilities[0]]
self._meta_strategy_probabilities = [
self._meta_strategy_probabilities[0]]

def get_policies_and_strategies(self):
"""Returns current policy sampler, policies and meta-strategies of the game.
Expand Down Expand Up @@ -330,6 +342,9 @@ def update_agents(self):
training_parameters[current_player].append(new_parameter)

if self.symmetric_game:
# Notice that the following line returns N references to the same policy
# This might not be correct for certain applications.
# E.g., a DQN BR oracle with player_id information
self._policies = self._game_num_players * self._policies
self._num_players = self._game_num_players
training_parameters = [training_parameters[0]]
Expand Down Expand Up @@ -366,6 +381,9 @@ def update_empirical_gamestate(self, seed=None):
# Switch to considering the game as a symmetric game where players have
# the same policies & new policies. This allows the empirical gamestate
# update to function normally.
# Notice that the following line returns N references to the same policy
# This might not be correct for certain applications.
# E.g., a DQN BR oracle with player_id information
self._policies = self._game_num_players * self._policies
self._new_policies = self._game_num_players * self._new_policies
self._num_players = self._game_num_players
Expand Down Expand Up @@ -428,6 +446,7 @@ def update_empirical_gamestate(self, seed=None):
# TODO(author4): This update uses ~2**(n_players-1) * sims_per_entry
# samples to estimate each payoff table entry. This should be
# brought to sims_per_entry to coincide with expected behavior.

utility_estimates = self.sample_episodes(estimated_policies,
self._sims_per_entry)

Expand Down Expand Up @@ -471,6 +490,9 @@ def get_policies(self):
policies = self._policies
if self.symmetric_game:
# For compatibility reasons, return list of expected length.
# Notice that the following line returns N references to the same policy
# This might not be correct for certain applications.
# E.g., a DQN BR oracle with player_id information
policies = self._game_num_players * self._policies
return policies

Expand Down

0 comments on commit b5b6dc8

Please sign in to comment.