
Add Hungarian algorithm for the linear assignment problem. #1083

Conversation

@carlosgmartin (Contributor)

@fabianp (Member) commented Oct 2, 2024

Thanks for starting this, @carlosgmartin!

At a high level this looks good to me, but it would be important to have an example of this method in the gallery (https://optax.readthedocs.io/en/latest/gallery.html) to showcase its usage.

@vroulet (Collaborator) left a comment

Thanks for integrating this. Here are a few comments.

Review comments were left on optax/__init__.py, optax/assignment/_hungarian_algorithm.py, and optax/assignment/_hungarian_algorithm_test.py (since resolved).
@carlosgmartin force-pushed the linear_assignment_problem branch from fa0af8f to 2558ef9 on October 3, 2024 at 18:56.
@carlosgmartin (Contributor, Author)

@fabianp @vroulet Done.

@carlosgmartin force-pushed the linear_assignment_problem branch from 2558ef9 to 0f1df6a on October 3, 2024 at 18:58.
@vroulet (Collaborator) left a comment

Thanks again! A few minor remaining comments.

Review comments were left on optax/assignment/_hungarian_algorithm_test.py and optax/assignment/_hungarian_algorithm.py (since resolved).
@carlosgmartin force-pushed the linear_assignment_problem branch from 0f1df6a to b45bf72 on October 4, 2024 at 22:51, then from b45bf72 to 2c41bad at 22:53.
@carlosgmartin (Contributor, Author)

@vroulet Updated.

@vroulet (Collaborator) left a comment

Great work, looks good to me.

The copybara-service bot merged commit 06ce57a into google-deepmind:main on Oct 8, 2024 (7 of 8 checks passed).
@carlosgmartin deleted the linear_assignment_problem branch on October 8, 2024 at 22:02.
@carlosgmartin (Contributor, Author) commented Oct 11, 2024

Perhaps it's worth mentioning somewhere in the docs that this can be used to compute the Wasserstein distance between empirical distributions:

Code:

```python
import argparse

import jax
import optax
from jax import numpy as jnp, random
from matplotlib import collections, pyplot as plt, rcParams


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--points", type=int, default=200)
    p.add_argument("--power", type=float, default=2.0)
    p.add_argument("--markersize", type=float, default=4.0)
    p.add_argument("--linewidth", type=float, default=1.0)
    return p.parse_args()


def get_wasserstein_distance(x, y, power=2):
    # Pairwise p-th-power distances between the two point clouds.
    displacements = x[:, None] - y[None, :]
    pow_distances = (jnp.abs(displacements) ** power).sum(-1)
    # An optimal assignment minimizes the total (hence mean) transport cost.
    i, j = optax.assignment.hungarian_algorithm(pow_distances)
    distance = pow_distances[i, j].mean() ** (1 / power)
    return distance, (i, j)


def estimate_wasserstein_distance(key, sample_fn_1, sample_fn_2, n_samples):
    # Monte Carlo estimate: draw a batch from each distribution and
    # compute the Wasserstein distance between the empirical measures.
    keys = random.split(key)
    x = jax.vmap(sample_fn_1)(random.split(keys[0], n_samples))
    y = jax.vmap(sample_fn_2)(random.split(keys[1], n_samples))
    distance, _ = get_wasserstein_distance(x, y)
    return distance


def main():
    args = parse_args()

    key = random.key(args.seed)

    keys = random.split(key)
    x = random.normal(keys[0], (args.points, 2))
    y = random.normal(keys[1], (args.points, 2)) + jnp.array([0.2, 0.0])

    distance, (i, j) = get_wasserstein_distance(x, y, args.power)

    fig, ax = plt.subplots(constrained_layout=True)

    ax.scatter(*x.T, s=args.markersize**2, label="facility", edgecolor="none")
    ax.scatter(*y.T, s=args.markersize**2, label="client", edgecolor="none")

    # Draw a line segment for each matched pair (x[i[k]], y[j[k]]).
    data = jnp.stack((x[i], y[j]), 1)
    lc = collections.LineCollection(
        list(data),
        linewidth=args.linewidth,
        color="lightgrey",
        zorder=-10,
        label="assignment",
    )
    ax.add_collection(lc)

    ax.set(title=f"Wasserstein {args.power}-distance: {distance:g}")

    ax.legend()

    rcParams["savefig.dpi"] = 300
    plt.show()


if __name__ == "__main__":
    main()
```

This, in turn, can be used to estimate the Wasserstein distance between arbitrary distributions, by computing the Wasserstein distance between large batches of samples drawn from each.

@vroulet (Collaborator) commented Oct 11, 2024

Yes, that's clearly worth it. :)
Ideally we would like your effort to be widely adopted, so comments like these are good.
Also, I was wondering: your implementation is neat, but aren't there other implementations that are potentially more TPU-friendly? (Not very important, but it may be worth mentioning if that's the case. I'm not an expert on this; packages for optimal transport may have some clever tricks.)

@carlosgmartin (Contributor, Author)

@vroulet Right, there are faster specialized algorithms for approximating the Wasserstein distance (see OTT-JAX).
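For context (not part of this PR), the faster approximate approach usually means entropy-regularized optimal transport solved with Sinkhorn iterations, as in OTT-JAX. A minimal sketch with uniform marginals follows; the function name, `epsilon`, and the iteration count are illustrative choices, not any library's API:

```python
import jax.numpy as jnp


def sinkhorn_cost(cost, epsilon=0.05, num_iters=200):
    # Entropy-regularized OT between uniform marginals via Sinkhorn iterations.
    n, m = cost.shape
    a = jnp.full(n, 1.0 / n)  # source marginal
    b = jnp.full(m, 1.0 / m)  # target marginal
    K = jnp.exp(-cost / epsilon)  # Gibbs kernel
    u = jnp.ones(n)
    v = jnp.ones(m)
    for _ in range(num_iters):
        # Alternate scaling so the plan's marginals match a and b.
        v = b / (K.T @ u)
        u = a / (K @ v)
    plan = u[:, None] * K * v[None, :]
    return (plan * cost).sum(), plan
```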

Perhaps I can append the example above to the LAP gallery entry?

@JTT94 commented Nov 19, 2024

Not a Contribution, but how does this compare to `hungarian_tpu_matcher` from https://github.com/google-research/scenic/blob/main/scenic/model_lib/matchers/hungarian_jax.py?

@carlosgmartin (Contributor, Author) commented Nov 20, 2024

@JTT94 I ran a comparison. Interestingly, it looks like that implementation is both faster and yields a smaller jaxpr than the current optax implementation.

Example:

```text
running on a batch of cost matrices of size (100000, 20, 21)
lax.map runtime for optax version: 47.6521
lax.map runtime for scenic version: 7.84279
vmap runtime for optax version: 37.7756
vmap runtime for scenic version: 17.987
```

This was on a CPU (Apple M3 Max). I encourage someone to try the comparison on a GPU.

@fabianp @vroulet Thoughts? If this implementation is indeed better, I could create a PR that replaces the current optax implementation with it.
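For anyone reproducing such a comparison, a hedged sketch of a timing harness (the `benchmark` helper and repeat count are illustrative, not from the thread): it jit-compiles once outside the timed region, then averages wall-clock time over a few runs, synchronizing after each call so JAX's asynchronous dispatch is actually measured.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, repeats=3):
    # Compile once outside the timed region.
    jitted = jax.jit(fn)
    jitted(*args).block_until_ready()  # warm-up / compilation
    t0 = time.perf_counter()
    for _ in range(repeats):
        jitted(*args).block_until_ready()  # sync before reading the clock
    return (time.perf_counter() - t0) / repeats
```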

@vroulet (Collaborator) commented Nov 20, 2024

Hello @carlosgmartin,
Yes, that would be great. Your implementation can be kept for tests if you want. I think your implementation is the standard one, so it's good to have, but ideally we want the fastest implementation on accelerators.

@carlosgmartin (Contributor, Author)

@vroulet Submitted a PR: #1140.
