Add Hungarian algorithm for the linear assignment problem. #1083
Conversation
Force-pushed from bbd8386 to fa0af8f.
Thanks for starting this @carlosgmartin! At a high level this looks good to me, but it would be important to have an example of this method in the gallery (https://optax.readthedocs.io/en/latest/gallery.html) to showcase its usage.
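For reference, a minimal sketch of what such a gallery example might look like, using the `optax.assignment.hungarian_algorithm` API exercised later in this thread (the cost values are illustrative, not from the PR):

```python
import jax.numpy as jnp
import optax

# cost[i, j] is the cost of assigning worker i to job j.
cost = jnp.array([
    [8.0, 4.0, 7.0],
    [5.0, 2.0, 3.0],
    [9.0, 6.0, 7.0],
])

# Returns row and column indices of an optimal (minimum-cost) assignment.
i, j = optax.assignment.hungarian_algorithm(cost)
print("assignment:", list(zip(i.tolist(), j.tolist())))
print("total cost:", cost[i, j].sum())
```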
Thanks for integrating this! Here are a few comments.
Force-pushed from fa0af8f to 2558ef9.
Force-pushed from 2558ef9 to 0f1df6a.
Thanks again! A few minor remaining comments
Force-pushed from 0f1df6a to b45bf72.
Force-pushed from b45bf72 to 2c41bad.
@vroulet Updated.
Great work, looks good to me
Merged commit 06ce57a into google-deepmind:main.
Perhaps it's worth mentioning somewhere in the docs that this can be used to compute the Wasserstein distance between empirical distributions:

```python
import argparse

import jax
import optax
from jax import numpy as jnp, random
from matplotlib import collections, pyplot as plt, rcParams


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--points", type=int, default=200)
    p.add_argument("--power", type=float, default=2.0)
    p.add_argument("--markersize", type=float, default=4.0)
    p.add_argument("--linewidth", type=float, default=1.0)
    return p.parse_args()


def get_wasserstein_distance(x, y, power=2):
    # Pairwise cost matrix: p-th power of the distance between each x_i, y_j.
    displacements = x[:, None] - y[None, :]
    pow_distances = (jnp.abs(displacements) ** power).sum(-1)
    # Optimal matching between the two point clouds.
    i, j = optax.assignment.hungarian_algorithm(pow_distances)
    distance = pow_distances[i, j].mean() ** (1 / power)
    return distance, (i, j)


def estimate_wasserstein_distance(key, sample_fn_1, sample_fn_2, n_samples):
    # Monte Carlo estimate: match a batch of samples from each distribution.
    keys = random.split(key)
    x = jax.vmap(sample_fn_1)(random.split(keys[0], n_samples))
    y = jax.vmap(sample_fn_2)(random.split(keys[1], n_samples))
    distance, _ = get_wasserstein_distance(x, y)
    return distance


def main():
    args = parse_args()
    key = random.key(args.seed)
    keys = random.split(key)
    x = random.normal(keys[0], (args.points, 2))
    y = random.normal(keys[1], (args.points, 2)) + jnp.array([0.2, 0.0])
    distance, (i, j) = get_wasserstein_distance(x, y, args.power)

    fig, ax = plt.subplots(constrained_layout=True)
    ax.scatter(*x.T, s=args.markersize**2, label="facility", edgecolor="none")
    ax.scatter(*y.T, s=args.markersize**2, label="client", edgecolor="none")
    # Draw a line between each matched pair of points.
    data = jnp.stack((x[i], y[j]), 1)
    lc = collections.LineCollection(
        list(data),
        linewidth=args.linewidth,
        color="lightgrey",
        zorder=-10,
        label="assignment",
    )
    ax.add_collection(lc)
    ax.set(title=f"Wasserstein {args.power}-distance: {distance:g}")
    ax.legend()
    rcParams["savefig.dpi"] = 300
    plt.show()


if __name__ == "__main__":
    main()
```

This in turn can be used to estimate the Wasserstein distance between arbitrary distributions, by taking the Wasserstein distance between large batches of samples from each.
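For instance, a hypothetical invocation of the `estimate_wasserstein_distance` helper above (assuming the imports from the script; the distributions and sample count are illustrative):

```python
# Estimate the 2-Wasserstein distance between a standard normal and a
# normal shifted by (1, 0), from 1000 samples of each.
d = estimate_wasserstein_distance(
    random.key(0),
    lambda k: random.normal(k, (2,)),
    lambda k: random.normal(k, (2,)) + jnp.array([1.0, 0.0]),
    n_samples=1000,
)
print(d)
```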
Yes, that's clearly worth it :)
@vroulet Right, there are faster specialized algorithms for approximating the Wasserstein distance (see OTT-JAX). Perhaps I can append the example above to the LAP gallery entry?
Not a Contribution, but how does this compare to
@JTT94 I ran a comparison. Interestingly, it looks like that implementation is both faster and yields a smaller jaxpr than the current optax implementation. Example:
This was on a CPU (Apple M3 Max). I encourage someone to try the comparison on a GPU. @fabianp @vroulet Thoughts? If this implementation is indeed better, I could create a PR that replaces the current optax implementation with it.
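For context, a minimal sketch (not from the original thread) of how such a timing and jaxpr-size comparison could be set up for one implementation; the alternative implementation mentioned above would be measured the same way:

```python
import timeit

import jax
import optax

cost = jax.random.uniform(jax.random.key(0), (500, 500))
fn = optax.assignment.hungarian_algorithm

# Rough proxy for traced-program complexity: length of the printed jaxpr.
print("jaxpr size:", len(str(jax.make_jaxpr(fn)(cost))))

jitted = jax.jit(fn)
jax.block_until_ready(jitted(cost))  # first call compiles
t = timeit.timeit(lambda: jax.block_until_ready(jitted(cost)), number=10)
print("seconds per call:", t / 10)
```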
Hello @carlosgmartin, |
#954