
Commit 3ada59c

For exploitability descent, changes indices to come from outside of the graph.

PiperOrigin-RevId: 271998981
Change-Id: Ia46f7fd6871061ee3ac63859e45ac2bf48f3a3c5
DeepMind Technologies Ltd authored and open_spiel@google.com committed Sep 30, 2019
1 parent 61075f4 commit 3ada59c
Showing 1 changed file with 1 addition and 6 deletions.
7 changes: 1 addition & 6 deletions open_spiel/python/algorithms/exploitability_descent.py
@@ -109,7 +109,7 @@ def loss(self, policy_values):
     loss_per_state = -tf.reduce_sum(policy_values * advantage, axis=-1)
     return nash_conv, tf.reduce_sum(loss_per_state * cf_reach_probabilities)
 
-  def minibatch_loss(self, policy_values, minibatch_size=128):
+  def minibatch_loss(self, policy_values, indices):
     """Returns the exploitability descent loss given a policy for a subset."""
 
     evaluate_policy = _create_policy_evaluator(self.tabular_policy,
@@ -120,11 +120,6 @@ def minibatch_loss(self, policy_values, minibatch_size=128):
     advantage = q_values - tf.stop_gradient(baseline)
 
     # We now randomly select a minibatch from the data to propagate our loss on.
-    # This is done with replacement.
-    indices = tf.random.uniform((minibatch_size,),
-                                maxval=tf.gather(tf.shape(policy_values), 0),
-                                dtype=tf.dtypes.int32,
-                                name="random_indices")
     policy_values = tf.gather(policy_values, indices)
     advantage = tf.gather(advantage, indices)
     cf_reach_probabilities = tf.gather(cf_reach_probabilities, indices)
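The practical effect of this change: minibatch indices are no longer sampled inside the TensorFlow graph with tf.random.uniform (which could only sample with replacement); the caller now builds the indices tensor outside the graph and passes it to minibatch_loss. Below is a minimal, self-contained sketch (TF1-style, matching the 2019 code) of that pattern; it uses stand-in tensors and a stand-in loss rather than the solver's real API, and the values of num_states, num_actions, and minibatch_size are illustrative assumptions.

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

num_states, num_actions, minibatch_size = 1000, 3, 128
policy_values = tf.constant(
    np.random.rand(num_states, num_actions), dtype=tf.float32)

# Old pattern (removed by this commit): indices sampled inside the graph,
# necessarily with replacement:
#   indices = tf.random.uniform((minibatch_size,),
#                               maxval=tf.shape(policy_values)[0],
#                               dtype=tf.int32)

# New pattern: the caller owns the indices and feeds them in at run time.
indices = tf.placeholder(shape=[minibatch_size], dtype=tf.int32, name="indices")
minibatch = tf.gather(policy_values, indices)
loss = tf.reduce_sum(tf.square(minibatch))  # stand-in loss, for illustration only

with tf.Session() as session:
  # The sampling strategy is now a caller-side decision, e.g. without
  # replacement, which the removed in-graph sampling could not express:
  batch = np.random.choice(num_states, size=minibatch_size, replace=False)
  session.run(loss, feed_dict={indices: batch})

In the module itself, the gathered tensors are policy_values, advantage, and cf_reach_probabilities as shown in the diff above, and the caller would feed a fresh set of indices at each optimization step.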

0 comments on commit 3ada59c
