forked from google-deepmind/open_spiel
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
utils and evaluations for regret matching, and meta regret matching a…
…gents. PiperOrigin-RevId: 501340246 Change-Id: I97c0271b99afe8b6d3a472b4d013afececad1fd6
- Loading branch information
Showing
2 changed files
with
166 additions
and
0 deletions.
There are no files selected for viewing
98 changes: 98 additions & 0 deletions
98
open_spiel/python/examples/meta_cfr/matrix_games/evaluation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
"""Evaluation.""" | ||
|
||
from absl import flags | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
|
||
@jax.jit | ||
def compute_best_response_strategy(utility): | ||
actions_count = utility.shape[-1] | ||
opponent_action = jnp.argmin(utility, axis=-1) | ||
opponent_strategy = jax.nn.one_hot(opponent_action, actions_count) | ||
return opponent_strategy | ||
|
||
|
||
@jax.jit | ||
def compute_values_against_best_response(strategy, payoff): | ||
utility = jnp.matmul(strategy, payoff) | ||
br_strategy = compute_best_response_strategy(utility) | ||
return jnp.matmul(payoff, jnp.transpose(br_strategy)) | ||
|
||
|
||
def evaluate_against_best_response(agent, payoff_batch, steps_count): | ||
"""Evaluation against best response agent. | ||
Args: | ||
agent: Agent model. | ||
payoff_batch: Payoff matrix. | ||
steps_count: Number of steps. | ||
""" | ||
current_policy = agent.initial_policy() | ||
values = jax.vmap(compute_values_against_best_response)(current_policy, | ||
payoff_batch) | ||
for step in range(steps_count): | ||
current_policy = agent.next_policy(values) | ||
values = jax.vmap(compute_values_against_best_response)(current_policy, | ||
payoff_batch) | ||
values = jnp.transpose(values, [0, 1, 2]) | ||
value = jnp.matmul(current_policy, values) | ||
|
||
for i in range(value.shape[0]): | ||
print(step, np.mean(np.asarray(value[i]))) | ||
|
||
|
||
def compute_regrets(payoff_batch, strategy_x, strategy_y): | ||
values_y = -jnp.matmul(strategy_x, payoff_batch) | ||
values_x = jnp.transpose( | ||
jnp.matmul(payoff_batch, jnp.transpose(strategy_y, [0, 2, 1])), [0, 2, 1]) | ||
value_x = jnp.matmul( | ||
jnp.matmul(strategy_x, payoff_batch), | ||
jnp.transpose(strategy_y, [0, 2, 1])) | ||
value_y = -value_x | ||
regrets_x = values_x - value_x | ||
regrets_y = values_y - value_y | ||
return regrets_x, regrets_y | ||
|
||
|
||
def evaluate_in_selfplay(agent_x, agent_y, payoff_batch, steps_count): | ||
"""Evalute in selfplay. | ||
Args: | ||
agent_x: First agent. | ||
agent_y: Second agent. | ||
payoff_batch: Payoff matrix. | ||
steps_count: Number of steps. | ||
""" | ||
payoff_batch_size = payoff_batch.shape[0] | ||
|
||
regret_sum_x = np.zeros(shape=[payoff_batch_size, 1, FLAGS.num_actions]) | ||
regret_sum_y = np.zeros(shape=[payoff_batch_size, 1, FLAGS.num_actions]) | ||
strategy_x = agent_x.initial_policy() | ||
strategy_y = agent_y.initial_policy() | ||
|
||
regrets_x, regrets_y = compute_regrets(payoff_batch, strategy_x, strategy_y) | ||
regret_sum_x += regrets_x | ||
regret_sum_y += regrets_y | ||
for s in range(steps_count): | ||
values_y = -jnp.matmul(strategy_x, payoff_batch) | ||
values_x = jnp.transpose( | ||
jnp.matmul(payoff_batch, jnp.transpose(strategy_y, [0, 2, 1])), | ||
[0, 2, 1]) | ||
|
||
values_x = jnp.transpose(values_x, [0, 2, 1]) | ||
values_y = jnp.transpose(values_y, [0, 2, 1]) | ||
strategy_x = agent_x.next_policy(values_x) | ||
strategy_y = agent_y.next_policy(values_y) | ||
|
||
regrets_x, regrets_y = compute_regrets(payoff_batch, strategy_x, strategy_y) | ||
regret_sum_x += regrets_x | ||
regret_sum_y += regrets_y | ||
print( | ||
jnp.mean( | ||
jnp.max( | ||
jnp.concatenate([regret_sum_x, regret_sum_y], axis=2), | ||
axis=[1, 2]) / (s + 1))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
"""Utility functions for meta learning for regret minimization.""" | ||
|
||
from absl import flags | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
|
||
def meta_loss(opt_params, net_apply, payoff, steps): | ||
|
||
"""Returns the meta learning loss value. | ||
Args: | ||
opt_params: Optimizer parameters. | ||
net_apply: Apply function. | ||
payoff: Payoff matrix. | ||
steps: Number of steps. | ||
Returns: | ||
Accumulated loss value over number of steps. | ||
""" | ||
regret_sum_x = np.zeros(shape=[FLAGS.batch_size, 1, FLAGS.num_actions]) | ||
regret_sum_y = np.zeros(shape=[FLAGS.batch_size, 1, FLAGS.num_actions]) | ||
total_loss = 0 | ||
step = 0 | ||
|
||
@jax.jit | ||
def scan_body(carry, x): | ||
nonlocal regret_sum_x | ||
nonlocal regret_sum_y | ||
regret_sum_x, regret_sum_y, current_step, total_loss = carry | ||
x = net_apply(opt_params, None, regret_sum_x / (current_step + 1)) | ||
y = net_apply(opt_params, None, regret_sum_y / (current_step + 1)) | ||
|
||
strategy_x = jax.nn.softmax(x) | ||
strategy_y = jnp.transpose(jax.nn.softmax(y), [0, 2, 1]) | ||
|
||
values_x = jnp.matmul(payoff, strategy_y) # val_x = payoff * st_y | ||
values_y = -jnp.matmul(strategy_x, payoff) # val_y = -1 * payoff * st_x | ||
|
||
value_x = jnp.matmul(jnp.matmul(strategy_x, payoff), strategy_y) | ||
value_y = -value_x | ||
|
||
curren_regret_x = values_x - value_x | ||
curren_regret_y = values_y - value_y | ||
curren_regret_x = jnp.transpose(curren_regret_x, [0, 2, 1]) | ||
|
||
regret_sum_x += curren_regret_x | ||
regret_sum_y += curren_regret_y | ||
|
||
current_loss = jnp.mean(jnp.max( | ||
jax.numpy.concatenate([curren_regret_x, curren_regret_y], axis=2), | ||
axis=[1, 2]), axis=-1) | ||
total_loss += current_loss | ||
current_step += 1 | ||
return (regret_sum_x, regret_sum_y, current_step, total_loss), None | ||
|
||
(regret_sum_x, regret_sum_y, step, total_loss), _ = jax.lax.scan( | ||
scan_body, | ||
(regret_sum_x, regret_sum_y, step, total_loss), | ||
None, | ||
length=steps, | ||
) | ||
|
||
return total_loss |