diff --git a/open_spiel/python/algorithms/nash_averaging.py b/open_spiel/python/algorithms/nash_averaging.py index 39b2d4b578..36ad24fffb 100644 --- a/open_spiel/python/algorithms/nash_averaging.py +++ b/open_spiel/python/algorithms/nash_averaging.py @@ -82,8 +82,11 @@ def nash_averaging(game, eps=0.0, a_v_a=True): # game does not have to be symmetric m, n = p_mat[0].shape - a_mat = np.block([[np.zeros(shape=(m, m)), p_mat[0]], - [-p_mat[0].T, np.zeros(shape=(n, n))]]) + min_payoffs = np.min(p_mat[0], axis=1).reshape((m, 1)) + max_payoffs = np.max(p_mat[0], axis=1).reshape((m, 1)) + std_p_mat = (p_mat[0] - min_payoffs)/(max_payoffs-min_payoffs) + a_mat = np.block([[np.zeros(shape=(m, m)), std_p_mat], + [-std_p_mat.T, np.zeros(shape=(n, n))]]) maxent_nash = np.array(_max_entropy_symmetric_nash(a_mat, eps=eps)) pa, pe = maxent_nash[:m], maxent_nash[m:] - return (pa, pe), (p_mat[0].dot(pe), -p_mat[0].T.dot(pa)) + return (pa, pe), (std_p_mat.dot(pe), -std_p_mat.T.dot(pa))