regret matching and meta regret matching agents for matrix games.
PiperOrigin-RevId: 501340036
Change-Id: I1f236ef8af91afced3aad185d3280e8fbacc1bc0
Elnaz Davoodi authored and lanctot committed Jan 16, 2023
1 parent 835173e commit c284ab2
Showing 2 changed files with 164 additions and 0 deletions.
@@ -0,0 +1,118 @@
"""Meta-regret matching with self-play agents."""
from typing import List

from absl import flags
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax

from open_spiel.python.examples.meta_cfr.matrix_games import utils

FLAGS = flags.FLAGS


def opponent_best_response_strategy(utility):
opponent_action = jnp.argmin(utility, axis=-1)
opponent_strategy = jax.nn.one_hot(opponent_action, FLAGS.num_actions)
return opponent_strategy


def _mlp_forwards(mlp_hidden_sizes: List[int]) -> hk.Transformed:
  """Returns a haiku transformation of the MLP model to be used in optimizer.

  Args:
    mlp_hidden_sizes: List containing the sizes of the linear layers.

  Returns:
    Haiku transformation of the MLP network.
  """
  def forward_fn(inputs):
    mlp = hk.nets.MLP(mlp_hidden_sizes, activation=jax.nn.relu, name="mlp")
    return mlp(inputs)
  return hk.transform(forward_fn)


class OptimizerModel:
  """Optimizer model."""

  def __init__(self, learning_rate):
    self.learning_rate = learning_rate

    self.model = _mlp_forwards([64, 16, FLAGS.num_actions])

    self._net_init = self.model.init
    self.net_apply = self.model.apply

    self.opt_update, self.net_params, self.opt_state = None, None, None

  def lr_scheduler(self, init_value):
    schedule_fn = optax.polynomial_schedule(
        init_value=init_value, end_value=0.05, power=1., transition_steps=50)
    return schedule_fn

  def get_optimizer_model(self):
    schedule_fn = self.lr_scheduler(self.learning_rate)
    opt_init, self.opt_update = optax.chain(
        optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn),
        optax.scale(-self.learning_rate))
    rng = jax.random.PRNGKey(10)
    dummy_input = np.random.normal(
        loc=0, scale=10., size=(FLAGS.batch_size, 1, FLAGS.num_actions))
    self.net_params = self._net_init(rng, dummy_input)
    self.opt_state = opt_init(self.net_params)


class MetaSelfplayAgent:
  """Meta-regret matching self-play agent."""

  def __init__(self, repeats, training_epochs, data_loader):
    self.repeats = repeats
    self.training_epochs = training_epochs
    self.net_apply = None
    self.net_params = None
    self.regret_sum = None
    self.last_policy = None
    self.step = 0
    self.data_loader = data_loader

  def train(self):
    self.training_optimizer()
    self.regret_sum = jnp.zeros(shape=[FLAGS.batch_size, 1, FLAGS.num_actions])

  def initial_policy(self):
    x = self.net_apply(self.net_params, None, self.regret_sum)
    self.last_policy = jax.nn.softmax(x)
    self.step += 1
    return self.last_policy

  def next_policy(self, last_values):
    value = jnp.matmul(self.last_policy, last_values)
    current_regret = jnp.transpose(last_values, [0, 2, 1]) - value
    self.regret_sum += current_regret

    x = self.net_apply(self.net_params, None, self.regret_sum / (self.step + 1))
    self.last_policy = jax.nn.softmax(x)
    self.step += 1
    return self.last_policy

  def training_optimizer(self):
    """Meta-trains the optimizer network on batches of payoff matrices."""

    optimizer = OptimizerModel(0.01)
    optimizer.get_optimizer_model()

    for _ in range(FLAGS.num_batches):
      batch_payoff = next(self.data_loader)
      # for _ in range(self.repeats):
      grads = jax.grad(
          utils.meta_loss,
          has_aux=False)(optimizer.net_params, optimizer.net_apply,
                         batch_payoff, self.training_epochs)

      updates, optimizer.opt_state = optimizer.opt_update(
          grads, optimizer.opt_state)
      optimizer.net_params = optax.apply_updates(optimizer.net_params, updates)

    self.net_apply = optimizer.net_apply
    self.net_params = optimizer.net_params
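
For orientation, here is a minimal sketch of how the meta-learned optimizer above maps cumulative regrets to a policy, mirroring what initial_policy and next_policy do. The module name meta_selfplay_agent and the flag definitions are assumptions for illustration (the file references the num_actions and batch_size flags but does not define them); a full training driver would also need the data format expected by utils.meta_loss, which is not shown in this commit.

# Hypothetical driver sketch: the module name and the flag definitions below
# are assumptions; the OptimizerModel class and the flag names come from the
# file above.
import jax
import jax.numpy as jnp
from absl import app
from absl import flags

from open_spiel.python.examples.meta_cfr.matrix_games import meta_selfplay_agent

flags.DEFINE_integer("num_actions", 3, "Actions per player in the matrix game.")
flags.DEFINE_integer("batch_size", 8, "Number of games evaluated per batch.")
FLAGS = flags.FLAGS


def main(_):
  # Build the regret-to-logits MLP and its optimizer state.
  optimizer = meta_selfplay_agent.OptimizerModel(learning_rate=0.01)
  optimizer.get_optimizer_model()

  # Map a batch of cumulative regrets to a policy, exactly as
  # MetaSelfplayAgent.initial_policy / next_policy do after training.
  regret_sum = jnp.zeros([FLAGS.batch_size, 1, FLAGS.num_actions])
  logits = optimizer.net_apply(optimizer.net_params, None, regret_sum)
  policy = jax.nn.softmax(logits)  # shape [batch_size, 1, num_actions]
  print(policy[0])


if __name__ == "__main__":
  app.run(main)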
@@ -0,0 +1,46 @@
"""Regret matching."""
from absl import flags
import jax
import jax.numpy as jnp
import numpy as np

FLAGS = flags.FLAGS


class RegretMatchingAgent:
"""Regret matching agent."""

def __init__(self, num_actions, data_loader):
self.num_actions = num_actions
# self.regret_sum = jax.numpy.array(np.zeros(self.num_actions))
self.regret_sum = jax.numpy.array(
np.zeros(shape=[FLAGS.batch_size, 1, self.num_actions]))
self.data_loader = data_loader

def train(self):
pass

def initial_policy(self):
self.last_policy = self.regret_matching_policy(self.regret_sum)
return self.last_policy

def next_policy(self, last_values):
value = jnp.matmul(self.last_policy, last_values)
last_values = jnp.transpose(last_values, [0, 2, 1])
current_regrets = last_values - value
self.regret_sum += current_regrets
self.last_policy = self.regret_matching_policy(self.regret_sum)
return self.last_policy

def regret_matching_policy(self, regret_sum):
"""Regret matching policy."""

strategy = np.copy(regret_sum)
strategy[strategy < 0] = 0
strategy_sum = np.sum(strategy, axis=-1)
for i in range(FLAGS.batch_size):
if strategy_sum[i] > 0:
strategy[i] /= strategy_sum[i]
else:
strategy[i] = np.repeat(1 / self.num_actions, self.num_actions)
return strategy
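
As a usage illustration, here is a sketch of two of these regret matching agents playing a batch of random zero-sum matrix games against each other. The module name regret_matching_agent, the flag definitions, and the zero-sum payoff handling are assumptions made for the example; the tensor shapes ([batch_size, 1, num_actions] policies, [batch_size, num_actions, 1] action values) follow the code above.

# Hypothetical self-play loop: module name and flag definitions are assumed;
# the RegretMatchingAgent class and its expected shapes come from the file above.
import numpy as np
from absl import app
from absl import flags

from open_spiel.python.examples.meta_cfr.matrix_games import regret_matching_agent

flags.DEFINE_integer("num_actions", 3, "Actions per player.")
flags.DEFINE_integer("batch_size", 8, "Number of games played in parallel.")
FLAGS = flags.FLAGS


def main(_):
  # Row player's payoff for each (row action, column action) pair.
  payoff = np.random.uniform(
      -1., 1., size=(FLAGS.batch_size, FLAGS.num_actions, FLAGS.num_actions))

  row = regret_matching_agent.RegretMatchingAgent(
      FLAGS.num_actions, data_loader=None)
  col = regret_matching_agent.RegretMatchingAgent(
      FLAGS.num_actions, data_loader=None)

  row_policy = row.initial_policy()  # shape [batch_size, 1, num_actions]
  col_policy = col.initial_policy()

  for _ in range(100):
    # Per-action values against the opponent's current mixed strategy,
    # shaped [batch_size, num_actions, 1] as next_policy expects.
    row_values = payoff @ np.transpose(col_policy, (0, 2, 1))
    # Treat the game as zero-sum, so the column player's payoff is -payoff.
    col_values = np.transpose(-payoff, (0, 2, 1)) @ np.transpose(
        row_policy, (0, 2, 1))
    row_policy = row.next_policy(row_values)
    col_policy = col.next_policy(col_values)

  print("row policy:", row_policy[0])
  print("col policy:", col_policy[0])


if __name__ == "__main__":
  app.run(main)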
