utils and evaluations for regret matching, and meta regret matching agents.

PiperOrigin-RevId: 501340246
Change-Id: I97c0271b99afe8b6d3a472b4d013afececad1fd6
Elnaz Davoodi authored and lanctot committed Jan 16, 2023
1 parent c284ab2 commit c5d8f2e
Showing 2 changed files with 166 additions and 0 deletions.
98 changes: 98 additions & 0 deletions open_spiel/python/examples/meta_cfr/matrix_games/evaluation.py
@@ -0,0 +1,98 @@
"""Evaluation."""

from absl import flags
import jax
import jax.numpy as jnp
import numpy as np

FLAGS = flags.FLAGS


@jax.jit
def compute_best_response_strategy(utility):
  """Returns a one-hot opponent strategy that minimizes the player's utility."""
  actions_count = utility.shape[-1]
  opponent_action = jnp.argmin(utility, axis=-1)
  opponent_strategy = jax.nn.one_hot(opponent_action, actions_count)
  return opponent_strategy


@jax.jit
def compute_values_against_best_response(strategy, payoff):
  """Returns the player's action values when the opponent best-responds."""
  utility = jnp.matmul(strategy, payoff)
  br_strategy = compute_best_response_strategy(utility)
  return jnp.matmul(payoff, jnp.transpose(br_strategy))


def evaluate_against_best_response(agent, payoff_batch, steps_count):
"""Evaluation against best response agent.
Args:
agent: Agent model.
payoff_batch: Payoff matrix.
steps_count: Number of steps.
"""
current_policy = agent.initial_policy()
values = jax.vmap(compute_values_against_best_response)(current_policy,
payoff_batch)
for step in range(steps_count):
current_policy = agent.next_policy(values)
values = jax.vmap(compute_values_against_best_response)(current_policy,
payoff_batch)
values = jnp.transpose(values, [0, 1, 2])
value = jnp.matmul(current_policy, values)

for i in range(value.shape[0]):
print(step, np.mean(np.asarray(value[i])))


def compute_regrets(payoff_batch, strategy_x, strategy_y):
  """Returns the instantaneous regrets of both players for a batch of games."""
  # Per-action values of each player against the opponent's current strategy.
  values_y = -jnp.matmul(strategy_x, payoff_batch)
  values_x = jnp.transpose(
      jnp.matmul(payoff_batch, jnp.transpose(strategy_y, [0, 2, 1])), [0, 2, 1])
  # Expected value of the current joint strategy (zero-sum, so y's is -x's).
  value_x = jnp.matmul(
      jnp.matmul(strategy_x, payoff_batch),
      jnp.transpose(strategy_y, [0, 2, 1]))
  value_y = -value_x
  regrets_x = values_x - value_x
  regrets_y = values_y - value_y
  return regrets_x, regrets_y


def evaluate_in_selfplay(agent_x, agent_y, payoff_batch, steps_count):
"""Evalute in selfplay.
Args:
agent_x: First agent.
agent_y: Second agent.
payoff_batch: Payoff matrix.
steps_count: Number of steps.
"""
payoff_batch_size = payoff_batch.shape[0]

regret_sum_x = np.zeros(shape=[payoff_batch_size, 1, FLAGS.num_actions])
regret_sum_y = np.zeros(shape=[payoff_batch_size, 1, FLAGS.num_actions])
strategy_x = agent_x.initial_policy()
strategy_y = agent_y.initial_policy()

regrets_x, regrets_y = compute_regrets(payoff_batch, strategy_x, strategy_y)
regret_sum_x += regrets_x
regret_sum_y += regrets_y
for s in range(steps_count):
values_y = -jnp.matmul(strategy_x, payoff_batch)
values_x = jnp.transpose(
jnp.matmul(payoff_batch, jnp.transpose(strategy_y, [0, 2, 1])),
[0, 2, 1])

values_x = jnp.transpose(values_x, [0, 2, 1])
values_y = jnp.transpose(values_y, [0, 2, 1])
strategy_x = agent_x.next_policy(values_x)
strategy_y = agent_y.next_policy(values_y)

regrets_x, regrets_y = compute_regrets(payoff_batch, strategy_x, strategy_y)
regret_sum_x += regrets_x
regret_sum_y += regrets_y
print(
jnp.mean(
jnp.max(
jnp.concatenate([regret_sum_x, regret_sum_y], axis=2),
axis=[1, 2]) / (s + 1)))
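
A minimal usage sketch for evaluation.py, not part of this commit: a hypothetical UniformAgent stands in for the regret-matching and meta regret-matching agents, exposing the initial_policy()/next_policy() interface the evaluators expect; the batch size of 4 and 3 actions are illustrative, and the functions above are assumed to be in scope.

import jax
import jax.numpy as jnp


class UniformAgent:
  """Hypothetical stand-in agent that always plays the uniform strategy."""

  def __init__(self, batch_size, num_actions):
    self._policy = jnp.ones([batch_size, 1, num_actions]) / num_actions

  def initial_policy(self):
    return self._policy

  def next_policy(self, values):
    del values  # A real agent would regret-match on these values.
    return self._policy


payoff_batch = jax.random.uniform(
    jax.random.PRNGKey(0), shape=(4, 3, 3), minval=-1., maxval=1.)
evaluate_against_best_response(UniformAgent(4, 3), payoff_batch, steps_count=5)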
68 changes: 68 additions & 0 deletions open_spiel/python/examples/meta_cfr/matrix_games/utils.py
@@ -0,0 +1,68 @@
"""Utility functions for meta learning for regret minimization."""

from absl import flags
import jax
import jax.numpy as jnp
import numpy as np

FLAGS = flags.FLAGS


def meta_loss(opt_params, net_apply, payoff, steps):
  """Returns the meta learning loss value.

  Args:
    opt_params: Optimizer parameters.
    net_apply: Apply function.
    payoff: Payoff matrix.
    steps: Number of steps.

  Returns:
    Accumulated loss value over number of steps.
  """
  regret_sum_x = np.zeros(shape=[FLAGS.batch_size, 1, FLAGS.num_actions])
  regret_sum_y = np.zeros(shape=[FLAGS.batch_size, 1, FLAGS.num_actions])
  # Float scalar so the scan carry keeps a consistent dtype with current_loss.
  total_loss = jnp.zeros([])
  step = 0

  @jax.jit
  def scan_body(carry, x):
    del x  # Unused: the scan has no per-step inputs; all state is in the carry.
    regret_sum_x, regret_sum_y, current_step, total_loss = carry
    # Regret matching through the meta-learned network: map the average
    # regrets to logits, then softmax them into strategies.
    x_logits = net_apply(opt_params, None, regret_sum_x / (current_step + 1))
    y_logits = net_apply(opt_params, None, regret_sum_y / (current_step + 1))

    strategy_x = jax.nn.softmax(x_logits)
    strategy_y = jnp.transpose(jax.nn.softmax(y_logits), [0, 2, 1])

    values_x = jnp.matmul(payoff, strategy_y)  # val_x = payoff * st_y
    values_y = -jnp.matmul(strategy_x, payoff)  # val_y = -1 * payoff * st_x

    value_x = jnp.matmul(jnp.matmul(strategy_x, payoff), strategy_y)
    value_y = -value_x

    current_regret_x = values_x - value_x
    current_regret_y = values_y - value_y
    current_regret_x = jnp.transpose(current_regret_x, [0, 2, 1])

    regret_sum_x += current_regret_x
    regret_sum_y += current_regret_y

    current_loss = jnp.mean(
        jnp.max(
            jnp.concatenate([current_regret_x, current_regret_y], axis=2),
            axis=(1, 2)),
        axis=-1)
    total_loss += current_loss
    current_step += 1
    return (regret_sum_x, regret_sum_y, current_step, total_loss), None

  (regret_sum_x, regret_sum_y, step, total_loss), _ = jax.lax.scan(
      scan_body,
      (regret_sum_x, regret_sum_y, step, total_loss),
      None,
      length=steps,
  )

  return total_loss
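
A sketch of how meta_loss could be wired up and differentiated, not part of this commit and built on assumptions: a Haiku MLP stands in for the meta-learner network, and the _forward helper, flag defaults, and shapes below are purely illustrative; the example's own training script is what actually defines the network, flags, and optimizer.

import haiku as hk
import jax
import jax.numpy as jnp
from absl import flags

# Illustrative flag values; the real binary defines and parses its own flags.
if 'num_actions' not in flags.FLAGS:
  flags.DEFINE_integer('num_actions', 3, 'Number of actions.')
if 'batch_size' not in flags.FLAGS:
  flags.DEFINE_integer('batch_size', 4, 'Batch size.')
flags.FLAGS.mark_as_parsed()


def _forward(regrets):
  # Maps average regrets [batch, 1, num_actions] to logits of the same shape.
  return hk.nets.MLP([16, flags.FLAGS.num_actions])(regrets)


net = hk.transform(_forward)  # net.apply(params, rng, x) matches net_apply.
rng = jax.random.PRNGKey(0)
dummy = jnp.zeros([flags.FLAGS.batch_size, 1, flags.FLAGS.num_actions])
params = net.init(rng, dummy)

payoff = jax.random.uniform(
    jax.random.PRNGKey(1),
    shape=(flags.FLAGS.batch_size, flags.FLAGS.num_actions,
           flags.FLAGS.num_actions),
    minval=-1., maxval=1.)

loss = meta_loss(params, net.apply, payoff, 8)
grads = jax.grad(meta_loss)(params, net.apply, payoff, 8)  # d(loss)/d(params)

Because the per-step regret computation runs inside jax.lax.scan, the accumulated loss stays differentiable with respect to the network parameters, which is what allows a gradient step like the one sketched above.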
