Bugfix for PSRO on symmetric, simultaneous-move games.
The fix changes the RL oracle sampling logic to support simultaneous-move games, corrects the symmetrization logic of PSRO, and switches from pyspiel.load_game_as_turn_based to pyspiel.load_game so that the games' symmetries are preserved.
The aggregation step does not yet support simultaneous moves, so exploitability computation has been removed in those cases.

Fixes: google-deepmind#175
PiperOrigin-RevId: 303162118
Change-Id: I38738009012442be699e5f4ca355f7e78b2ef715
DeepMind Technologies Ltd authored and tewalds committed Mar 30, 2020
1 parent 35c9f3b commit a679985
Showing 5 changed files with 58 additions and 47 deletions.
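
For context on the loading change mentioned above, here is a minimal sketch (not part of this commit) of the difference between the two loaders: pyspiel.load_game keeps a simultaneous-move game's native dynamics, while pyspiel.load_game_as_turn_based wraps it into a sequential game, hiding the structure and symmetry PSRO relies on here. "goofspiel" is used purely as an example of a symmetric, simultaneous-move game.

import pyspiel

params = {"players": pyspiel.GameParameter(2)}
simultaneous_game = pyspiel.load_game("goofspiel", params)
turn_based_game = pyspiel.load_game_as_turn_based("goofspiel", params)

# The native loader keeps SIMULTANEOUS dynamics; the wrapper reports
# SEQUENTIAL because each joint move is split into per-player turns.
print(simultaneous_game.get_type().dynamics)  # Dynamics.SIMULTANEOUS
print(turn_based_game.get_type().dynamics)    # Dynamics.SEQUENTIAL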
7 changes: 7 additions & 0 deletions open_spiel/python/algorithms/policy_aggregator.py
@@ -77,6 +77,13 @@ def action_probabilities(self, state, player_id=None):
      supplied state.
    """
    state_key = self._state_key(state, player_id=player_id)
    if state.is_simultaneous_node():
      # The policy aggregator does not yet support simultaneous-move nodes.
      # The lines below are one step in that direction.
      result = []
      for player_pol in self._policies:
        result.append(player_pol[state_key])
      return result
    if player_id is None:
      player_id = state.current_player()
    return self._policies[player_id][state_key]
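
A usage sketch (illustrative only, not from this commit) of the new branch: at a simultaneous-move node the aggregated policy now returns one {action: probability} dict per player rather than a single dict. print_aggregated_probs is a hypothetical helper name; aggr_policy and state are assumed to come from PolicyAggregator.aggregate and game.new_initial_state respectively.

def print_aggregated_probs(aggr_policy, state):
  """Prints aggregated action probabilities, handling simultaneous nodes."""
  if state.is_simultaneous_node():
    # One {action: probability} dict per player is returned at these nodes.
    for player_id, probs in enumerate(aggr_policy.action_probabilities(state)):
      print("Player {}: {}".format(player_id, probs))
  else:
    print("Player {}: {}".format(state.current_player(),
                                 aggr_policy.action_probabilities(state)))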
30 changes: 16 additions & 14 deletions open_spiel/python/algorithms/psro_v2/psro_v2.py
@@ -290,22 +290,13 @@ def update_agents(self):
    The resulting policies are appended to self._new_policies.
    """

    used_policies, used_indexes = self._training_strategy_selector(
        self, self._number_policies_selected)

    if self.symmetric_game:
      self._policies = self._game_num_players * self._policies
      self._num_players = self._game_num_players

    sample_strategy, total_policies, probabilities_of_playing_policies = self.get_policies_and_strategies(
    )

    if self.symmetric_game:
      # In a symmetric game, only one population is kept. The below lines
      # therefore make PSRO consider only the first player during training,
      # since both players are identical.
      self._policies = [self._policies[0]]
      self._num_players = 1
    (sample_strategy,
     total_policies,
     probabilities_of_playing_policies) = self.get_policies_and_strategies()

    # Contains the training parameters of all trained oracles.
    # This is a list (Size num_players) of list (Size num_new_policies[player]),
@@ -317,7 +308,6 @@ def update_agents(self):
        currently_used_policies = used_policies[current_player]
        current_indexes = used_indexes[current_player]
      else:
        # Rectifying training will not work for joint strategies.
        currently_used_policies = [
            joint_policy[current_player] for joint_policy in used_policies
        ]
@@ -340,6 +330,11 @@ def update_agents(self):
        }
        training_parameters[current_player].append(new_parameter)

    if self.symmetric_game:
      self._policies = self._game_num_players * self._policies
      self._num_players = self._game_num_players
      training_parameters = [training_parameters[0]]

    # List of List of new policies (One list per player)
    self._new_policies = self._oracle(
        self._game,
@@ -348,6 +343,13 @@ def update_agents(self):
        using_joint_strategies=self._rectify_training or
        not self.sample_from_marginals)

    if self.symmetric_game:
      # In a symmetric game, only one population is kept. The below lines
      # therefore make PSRO consider only the first player during training,
      # since both players are identical.
      self._policies = [self._policies[0]]
      self._num_players = 1

  def update_empirical_gamestate(self, seed=None):
    """Given new agents in _new_policies, update meta_games through simulations.
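
To make the reordered symmetric-game handling above easier to follow, here is a standalone sketch of the expand/collapse pattern it implements; train_symmetric and oracle_fn are hypothetical names, not part of the OpenSpiel API.

def train_symmetric(shared_population, num_players, oracle_fn):
  """Hypothetical sketch of the symmetric-game expand/collapse pattern."""
  # Expand: every player temporarily "owns" the same shared population, so
  # the oracle and meta-strategies see a full num_players profile.
  policies = num_players * [shared_population]
  # Train; only player 0's output matters since all players are identical.
  new_policies_per_player = oracle_fn(policies)
  # Collapse: keep a single population again, extended by player 0's policies.
  return shared_population + new_policies_per_player[0]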
36 changes: 22 additions & 14 deletions open_spiel/python/algorithms/psro_v2/rl_oracle.py
@@ -108,20 +108,28 @@ def sample_episode(self, unused_time_step, agents, is_evaluation=False):
    time_step = self._env.reset()
    cumulative_rewards = 0.0
    while not time_step.last():
      player_id = time_step.observations["current_player"]

      # is_evaluation is a boolean that, when False, lets policies train. The
      # setting of PSRO requires that all policies be static aside from those
      # being trained by the oracle. is_evaluation could be used to prevent
      # policies from training, yet we have opted for adding a frozen attribute
      # that prevents policies from training, for all values of is_evaluation.
      # Since all policies returned by the oracle are frozen before being
      # returned, only currently-trained policies can effectively learn.
      agent_output = agents[player_id].step(
          time_step, is_evaluation=is_evaluation)
      action_list = [agent_output.action]
      time_step = self._env.step(action_list)
      cumulative_rewards += np.array(time_step.rewards)
      if time_step.is_simultaneous_move():
        action_list = []
        for agent in agents:
          output = agent.step(time_step, is_evaluation=is_evaluation)
          action_list.append(output.action)
        time_step = self._env.step(action_list)
        cumulative_rewards += np.array(time_step.rewards)
      else:
        player_id = time_step.observations["current_player"]

        # is_evaluation is a boolean that, when False, lets policies train. The
        # setting of PSRO requires that all policies be static aside from those
        # being trained by the oracle. is_evaluation could be used to prevent
        # policies from training, yet we have opted for adding frozen attributes
        # that prevent policies from training, for all values of is_evaluation.
        # Since all policies returned by the oracle are frozen before being
        # returned, only currently-trained policies can effectively learn.
        agent_output = agents[player_id].step(
            time_step, is_evaluation=is_evaluation)
        action_list = [agent_output.action]
        time_step = self._env.step(action_list)
        cumulative_rewards += np.array(time_step.rewards)

    if not is_evaluation:
      for agent in agents:
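
For context on the new sampling branch, here is a self-contained sketch of an episode loop over an rl_environment that handles both sequential and simultaneous nodes with uniformly random actions. "goofspiel" is only an example game; any registered simultaneous-move game that exposes the observations rl_environment expects should behave the same way.

import numpy as np
from open_spiel.python import rl_environment

env = rl_environment.Environment("goofspiel")
time_step = env.reset()
while not time_step.last():
  if time_step.is_simultaneous_move():
    # All players act at once: one action per player in a single env.step().
    actions = [
        np.random.choice(time_step.observations["legal_actions"][player])
        for player in range(env.num_players)
    ]
  else:
    # Only the current player acts.
    current = time_step.observations["current_player"]
    actions = [
        np.random.choice(time_step.observations["legal_actions"][current])
    ]
  time_step = env.step(actions)
print(time_step.rewards)  # One final reward per player.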
7 changes: 0 additions & 7 deletions open_spiel/python/algorithms/psro_v2/rl_policy.py
@@ -70,12 +70,6 @@ def action_probabilities(self, state, player_id=None):
    cur_player = state.current_player()
    legal_actions = state.legal_actions(cur_player)

    cur_player = state.current_player()
    legal_actions = state.legal_actions(cur_player)

    cur_player = state.current_player()
    legal_actions = state.legal_actions(cur_player)

    step_type = rl_environment.StepType.LAST if state.is_terminal(
    ) else rl_environment.StepType.MID

@@ -98,7 +92,6 @@
        step_type=rl_environment.StepType.FIRST)
    # pylint: enable=protected-access

    # pylint: enable=protected-access
    p = self._policy.step(time_step, is_evaluation=True).probs
    prob_dict = {action: p[action] for action in legal_actions}
    return prob_dict
25 changes: 13 additions & 12 deletions open_spiel/python/examples/psro_v2_example.py
@@ -273,17 +273,19 @@ def gpsro_looper(env, oracle, agents):
      print("Meta game : {}".format(meta_game))
      print("Probabilities : {}".format(meta_probabilities))

    aggregator = policy_aggregator.PolicyAggregator(env.game)
    aggr_policies = aggregator.aggregate(
        range(FLAGS.n_players), policies, meta_probabilities)
    # The following lines only work for sequential games for the moment.
    if env.game.get_type().dynamics == pyspiel.GameType.Dynamics.SEQUENTIAL:
      aggregator = policy_aggregator.PolicyAggregator(env.game)
      aggr_policies = aggregator.aggregate(
          range(FLAGS.n_players), policies, meta_probabilities)

    exploitabilities, expl_per_player = exploitability.nash_conv(
        env.game, aggr_policies, return_only_nash_conv=False)
      exploitabilities, expl_per_player = exploitability.nash_conv(
          env.game, aggr_policies, return_only_nash_conv=False)

    _ = print_policy_analysis(policies, env.game, FLAGS.verbose)
    if FLAGS.verbose:
      print("Exploitabilities : {}".format(exploitabilities))
      print("Exploitabilities per player : {}".format(expl_per_player))
      _ = print_policy_analysis(policies, env.game, FLAGS.verbose)
      if FLAGS.verbose:
        print("Exploitabilities : {}".format(exploitabilities))
        print("Exploitabilities per player : {}".format(expl_per_player))


def main(argv):
@@ -292,9 +294,8 @@ def main(argv):

  np.random.seed(FLAGS.seed)

  game = pyspiel.load_game_as_turn_based(FLAGS.game_name,
                                         {"players": pyspiel.GameParameter(
                                             FLAGS.n_players)})
  game = pyspiel.load_game(FLAGS.game_name,
                           {"players": pyspiel.GameParameter(FLAGS.n_players)})
  env = rl_environment.Environment(game)

  # Initialize oracle and agents
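
A sketch consolidating the guarded block above into a reusable helper; maybe_nash_conv is a hypothetical name, not an OpenSpiel function. It returns None for non-sequential games, since the aggregator does not yet support simultaneous moves.

import pyspiel
from open_spiel.python.algorithms import exploitability
from open_spiel.python.algorithms import policy_aggregator

def maybe_nash_conv(game, policies, meta_probabilities, n_players):
  """Returns (nash_conv, per-player improvements), or None if unsupported."""
  if game.get_type().dynamics != pyspiel.GameType.Dynamics.SEQUENTIAL:
    # The policy aggregator does not yet handle simultaneous-move games.
    return None
  aggregator = policy_aggregator.PolicyAggregator(game)
  aggr_policies = aggregator.aggregate(
      range(n_players), policies, meta_probabilities)
  return exploitability.nash_conv(
      game, aggr_policies, return_only_nash_conv=False)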
