forked from google-deepmind/open_spiel
Commit
Regret matching and meta-regret matching agents for matrix games.
PiperOrigin-RevId: 501340036 Change-Id: I1f236ef8af91afced3aad185d3280e8fbacc1bc0
Showing 2 changed files with 164 additions and 0 deletions.
open_spiel/python/examples/meta_cfr/matrix_games/meta_selfplay_agent.py (118 additions, 0 deletions)
@@ -0,0 +1,118 @@
"""Meta-regret matching with self-play agents.""" | ||
from typing import List | ||
|
||
from absl import flags | ||
import haiku as hk | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
import optax | ||
|
||
from open_spiel.python.examples.meta_cfr.matrix_games import utils | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
|
||
def opponent_best_response_strategy(utility):
  """Returns the opponent's best response: a one-hot strategy on the utility-minimizing action."""
  opponent_action = jnp.argmin(utility, axis=-1)
  opponent_strategy = jax.nn.one_hot(opponent_action, FLAGS.num_actions)
  return opponent_strategy


def _mlp_forwards(mlp_hidden_sizes: List[int]) -> hk.Transformed:
  """Returns a haiku transformation of the MLP model used by the optimizer.

  Args:
    mlp_hidden_sizes: List containing the sizes of the linear layers.

  Returns:
    Haiku transformation of the MLP network.
  """
  def forward_fn(inputs):
    mlp = hk.nets.MLP(mlp_hidden_sizes, activation=jax.nn.relu, name="mlp")
    return mlp(inputs)
  return hk.transform(forward_fn)


class OptimizerModel:
  """Optimizer model."""

  def __init__(self, learning_rate):
    self.learning_rate = learning_rate

    self.model = _mlp_forwards([64, 16, FLAGS.num_actions])

    self._net_init = self.model.init
    self.net_apply = self.model.apply

    self.opt_update, self.net_params, self.opt_state = None, None, None

  def lr_scheduler(self, init_value):
    schedule_fn = optax.polynomial_schedule(
        init_value=init_value, end_value=0.05, power=1., transition_steps=50)
    return schedule_fn

  def get_optimizer_model(self):
    schedule_fn = self.lr_scheduler(self.learning_rate)
    opt_init, self.opt_update = optax.chain(
        optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn),
        optax.scale(-self.learning_rate))
    rng = jax.random.PRNGKey(10)
    dummy_input = np.random.normal(
        loc=0, scale=10., size=(FLAGS.batch_size, 1, FLAGS.num_actions))
    self.net_params = self._net_init(rng, dummy_input)
    self.opt_state = opt_init(self.net_params)


class MetaSelfplayAgent:
  """Meta player."""

  def __init__(self, repeats, training_epochs, data_loader):
    self.repeats = repeats
    self.training_epochs = training_epochs
    self.net_apply = None
    self.net_params = None
    self.regret_sum = None
    self.step = 0
    self.data_loader = data_loader

  def train(self):
    self.training_optimizer()
    self.regret_sum = jnp.zeros(shape=[FLAGS.batch_size, 1, FLAGS.num_actions])

  def initial_policy(self):
    x = self.net_apply(self.net_params, None, self.regret_sum)
    self.last_policy = jax.nn.softmax(x)
    self.step += 1
    return self.last_policy

  def next_policy(self, last_values):
    value = jnp.matmul(self.last_policy, last_values)
    current_regret = jnp.transpose(last_values, [0, 2, 1]) - value
    self.regret_sum += current_regret

    # Scale the cumulative regret by the step count before feeding it to the
    # meta network.
    x = self.net_apply(self.net_params, None, self.regret_sum / (self.step + 1))
    self.last_policy = jax.nn.softmax(x)
    self.step += 1
    return self.last_policy

  def training_optimizer(self):
    """Training optimizer."""

    optimizer = OptimizerModel(0.01)
    optimizer.get_optimizer_model()

    for _ in range(FLAGS.num_batches):
      batch_payoff = next(self.data_loader)
      # for _ in range(self.repeats):
      grads = jax.grad(
          utils.meta_loss,
          has_aux=False)(optimizer.net_params, optimizer.net_apply,
                         batch_payoff, self.training_epochs)

      updates, optimizer.opt_state = optimizer.opt_update(
          grads, optimizer.opt_state)
      optimizer.net_params = optax.apply_updates(optimizer.net_params, updates)

    self.net_apply = optimizer.net_apply
    self.net_params = optimizer.net_params
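The "meta" step in the agent above is that the policy is not the usual positive-part normalization of cumulative regret, but the softmax of an MLP applied to the (scaled) cumulative regret. Below is a minimal sketch of just that mapping, with made-up shapes and an untrained network; everything here is illustrative and not part of this commit.

import haiku as hk
import jax
import jax.numpy as jnp

batch_size, num_actions = 4, 3  # Illustrative values; the real code reads these from flags.


def _forward(regrets):
  # Same architecture family as _mlp_forwards above: an MLP ending in num_actions logits.
  return hk.nets.MLP([64, 16, num_actions], activation=jax.nn.relu)(regrets)


model = hk.transform(_forward)
regret_sum = jnp.zeros([batch_size, 1, num_actions])
params = model.init(jax.random.PRNGKey(0), regret_sum)

logits = model.apply(params, None, regret_sum)  # Mirrors net_apply(net_params, None, regret_sum).
policy = jax.nn.softmax(logits)                 # Shape [batch_size, 1, num_actions]; rows sum to 1.

In the committed code this mapping is trained through utils.meta_loss, whereas the sketch leaves the network untrained and only shows the regrets-to-policy data flow.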
open_spiel/python/examples/meta_cfr/matrix_games/regret_matching_agent.py (46 additions, 0 deletions)
@@ -0,0 +1,46 @@
"""Regret matching.""" | ||
from absl import flags | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
|
||
class RegretMatchingAgent:
  """Regret matching agent."""

  def __init__(self, num_actions, data_loader):
    self.num_actions = num_actions
    # self.regret_sum = jax.numpy.array(np.zeros(self.num_actions))
    self.regret_sum = jax.numpy.array(
        np.zeros(shape=[FLAGS.batch_size, 1, self.num_actions]))
    self.data_loader = data_loader

  def train(self):
    pass

  def initial_policy(self):
    self.last_policy = self.regret_matching_policy(self.regret_sum)
    return self.last_policy

  def next_policy(self, last_values):
    # Instantaneous regret: per-action values minus the expected value of the
    # current policy, accumulated into the running regret sum.
    value = jnp.matmul(self.last_policy, last_values)
    last_values = jnp.transpose(last_values, [0, 2, 1])
    current_regrets = last_values - value
    self.regret_sum += current_regrets
    self.last_policy = self.regret_matching_policy(self.regret_sum)
    return self.last_policy

  def regret_matching_policy(self, regret_sum):
    """Regret matching policy."""

    # Positive-part normalization of cumulative regrets; fall back to the
    # uniform strategy when no action has positive regret.
    strategy = np.copy(regret_sum)
    strategy[strategy < 0] = 0
    strategy_sum = np.sum(strategy, axis=-1)
    for i in range(FLAGS.batch_size):
      if strategy_sum[i] > 0:
        strategy[i] /= strategy_sum[i]
      else:
        strategy[i] = np.repeat(1 / self.num_actions, self.num_actions)
    return strategy
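For context, one possible way to exercise RegretMatchingAgent in self-play on a batch of random zero-sum matrix games is sketched below. The flag definitions, the payoff sampler, and the loop are assumptions made for illustration (in the experiment they would come from the main script and its data loader); they are not part of this commit.

import numpy as np
from absl import flags
import jax.numpy as jnp

from open_spiel.python.examples.meta_cfr.matrix_games.regret_matching_agent import RegretMatchingAgent

flags.DEFINE_integer("batch_size", 8, "Assumed flag; normally set by the experiment's main script.")
flags.DEFINE_integer("num_actions", 3, "Assumed flag; normally set by the experiment's main script.")
FLAGS = flags.FLAGS
FLAGS(["sketch"])  # Parse flags manually since this sketch does not run under absl.app.

# A batch of random zero-sum games: payoffs[b, i, j] is the row player's payoff.
payoffs = jnp.array(np.random.uniform(
    -1., 1., size=[FLAGS.batch_size, FLAGS.num_actions, FLAGS.num_actions]))

row = RegretMatchingAgent(FLAGS.num_actions, data_loader=None)
col = RegretMatchingAgent(FLAGS.num_actions, data_loader=None)
row_policy = row.initial_policy()  # Shape [batch_size, 1, num_actions].
col_policy = col.initial_policy()

for _ in range(1000):
  # Per-action values as column vectors of shape [batch_size, num_actions, 1],
  # which is the layout next_policy expects.
  row_values = jnp.matmul(payoffs, jnp.transpose(col_policy, [0, 2, 1]))
  col_values = jnp.matmul(-jnp.transpose(payoffs, [0, 2, 1]),
                          jnp.transpose(row_policy, [0, 2, 1]))
  row_policy = row.next_policy(row_values)
  col_policy = col.next_policy(col_values)

Note that regret matching converges to equilibrium in the average of the policies over iterations; the loop above tracks only the last iterates, so an evaluation script would typically accumulate and average them.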