Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Hungarian solver from optax #598

Merged
merged 10 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add Hungarian solver
  • Loading branch information
marcocuturi committed Nov 18, 2024
commit f823c1232189ea50d7503fe8b2682f789472aa7c
9 changes: 9 additions & 0 deletions docs/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ differentiable approximations to ranks and quantile functions :cite:`cuturi:19`,
and various tools to study Gaussians with the 2-Wasserstein metric
:cite:`gelbrich:90,delon:20`, etc.

Unregularized Optimal Transport
-------------------------------
.. autosummary::
:toctree: _autosummary

unreg.hungarian
unreg.HungarianOutput


Segmented Sinkhorn
------------------
.. autosummary::
Expand Down
27 changes: 27 additions & 0 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,33 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102
return cls(*aux_data)


@jtu.register_pytree_node_class
class PNorm(TICost):
r""":math:`p`-norm between vectors.

Uses custom implementation of `norm` to avoid `NaN` values when
differentiating the norm of :math:`x-x`.

Args:
p: Power of the p-norm in :math:`[1, +\infty)`.
"""

def __init__(self, p: float):
super().__init__()
self.p = p

def h(self, z: jnp.ndarray) -> float: # noqa: D102
return mu.norm(z, self.p) / self.p

def tree_flatten(self): # noqa: D102
return (), (self.p,)

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(*aux_data)


@jtu.register_pytree_node_class
class RegTICost(TICost):
r"""Regularized translation-invariant cost.
Expand Down
24 changes: 15 additions & 9 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,25 @@ class Geometry:
cost_matrix: Cost matrix of shape ``[n, m]``.
kernel_matrix: Kernel matrix of shape ``[n, m]``.
epsilon: Regularization parameter. If ``None`` and either
``relative_epsilon = True`` or ``relative_epsilon = None``, this defaults
to the value computed in :attr:`mean_cost_matrix` / 20. If passed as a
``relative_epsilon = True`` or ``relative_epsilon = None`` or
``relative_epsilon = str`` where ``str`` can be either ``mean`` or ``std``
, this value defaults to a multiple of :attr:`std_cost_matrix`
(or :attr:`mean_cost_matrix` if ``str`` is ``mean``), where that multiple
is set as ``DEFAULT_SCALE`` in ``epsilon_scheduler.py```.
If passed as a
``float``, then the regularizer that is ultimately used is either that
``float`` value (if ``relative_epsilon = False`` or ``None``) or that
``float`` times the :attr:`mean_cost_matrix`
(if ``relative_epsilon = True``). Look for
``float`` times the :attr:`std_cost_matrix` (if
``relative_epsilon = True`` or ``relative_epsilon = `std```) or
:attr:`mean_cost_matrix` (if ``relative_epsilon = `mean```). Look for
:class:`~ott.geometry.epsilon_scheduler.Epsilon` when passed as a
scheduler.
relative_epsilon: when `False`, the parameter ``epsilon`` specifies the
value of the entropic regularization parameter. When `True`, ``epsilon``
refers to a fraction of the :attr:`std_cost_matrix`, which is computed
adaptively from data. Can also be set to ``mean`` or ``std`` to use mean
of cost matrix if necessary.
relative_epsilon: when :obj:`False`, the parameter ``epsilon`` specifies the
value of the entropic regularization parameter. When :obj:`True` or set
to a string, ``epsilon`` refers to a fraction of the
:attr:`std_cost_matrix` or :attr:`mean_cost_matrix`, which is computed
adaptively from data, depending on whether it is set to ``mean`` or
``std``.
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean', 'std' and 'max_cost'. Alternatively, a float factor can
be given to rescale the cost such that ``cost_matrix /= scale_cost``.
Expand Down
1 change: 1 addition & 0 deletions src/ott/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
sinkhorn_divergence,
sliced,
soft_sort,
unreg,
)
89 changes: 89 additions & 0 deletions src/ott/tools/unreg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import NamedTuple, Optional, Tuple

import jax.experimental.sparse as jesp
import jax.numpy as jnp

from optax import assignment

from ott.geometry import costs, geometry, pointcloud

__all__ = ["hungarian"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all = ["HungarianOutput", "hungarian", "wassdis_p"]



class HungarianOutput(NamedTuple):
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
r"""Output of the Hungarian solver.

Args:
geom: geometry object
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Geometry object.

paired_indices: Array of shape ``[2, n]``, of :math:`n` pairs
of indices, for which the optimal transport assigns mass. Namely, for each
index :math:`0 <= k < n`, if one has
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use \leq

:math:`i := \text{paired_indices}[0, k]` and
:math:`j := \text{paired_indices}[1, k]`, then point :math:`i` in
the first geometry sends mass to point :math:`j` in the second.
"""
geom: geometry.Geometry
paired_indices: Optional[jnp.ndarray] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would remove the optional and default none.


@property
def matrix(self) -> jesp.BCOO:
"""``[n, n]`` transport matrix in sparse format, with ``n`` NNZ entries."""
n, _ = self.geom.shape
unit_mass = jnp.ones((n,)) / n
indices = self.paired_indices.swapaxes(0, 1)
return jesp.BCOO((unit_mass, indices), shape=(n, n))


def hungarian(geom: geometry.Geometry) -> Tuple[jnp.ndarray, HungarianOutput]:
"""Solve assignment problem using Hungarian as implemented in optax.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:mod:`optax`

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also please add this to intersphinx_mapping in conf.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't it there already?

intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    "jax": ("https://jax.readthedocs.io/en/latest/", None),
    "jaxopt": ("https://jaxopt.github.io/stable", None),
    "lineax": ("https://docs.kidger.site/lineax/", None),
    "flax": ("https://flax.readthedocs.io/en/latest/", None),
    "optax": ("https://optax.readthedocs.io/en/latest/", None),
    "diffrax": ("https://docs.kidger.site/diffrax/", None),
    "scipy": ("https://docs.scipy.org/doc/scipy/", None),
    "pot": ("https://pythonot.github.io/", None),
    "matplotlib": ("https://matplotlib.org/stable/", None),
}


Args:
geom: (square) geometry object.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would say

Geometry of shape ``[n, n]``.


Returns:
The value of the unregularized OT problem, along with an output
object listing relevant information on outputs.
"""
n, m = geom.shape
assert n == m, f"Hungarian can only match same # of points, got {n} and {m}."
i, j = assignment.hungarian_algorithm(geom.cost_matrix)

hungarian_out = HungarianOutput(geom=geom, paired_indices=jnp.stack((i, j)))
return jnp.sum(geom.cost_matrix[i, j]) / n, hungarian_out


def wassdis_p(x: jnp.ndarray, y: jnp.ndarray, p: float = 2.0) -> float:
"""Convenience wrapper on `hungarian` to get :term:`Wasserstein distance`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would remove the convenience wrapper part and rephrase as

Compute the :term:`Wasserstein distance` using the Hungarian algoritm.

or similar.


Uses :func:`~ott.tools.unreg.hungarian` to solve the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use :func:`hungarian`.

:term:`Kantorovich problem` between two point clouds of the same
size to compute a :term:`Wasserstein distance` estimator.

Note:
At the moment, only supports point clouds of the same size to be easily
cast as an optimal matching problem.

Args:
x: ``[n,d]`` point cloud
y: ``[n,d]`` point cloud of the same size
p: order of the Wasserstein distance, non-negative float.

Returns:
The p-Wasserstein distance between these point clouds.hungarian
"""
geom = pointcloud.PointCloud(x, y, cost_fn=costs.PNorm(p))
cost, _ = hungarian(geom)
return cost ** 1. / p
60 changes: 60 additions & 0 deletions tests/tools/unreg_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import pytest

import jax
import jax.numpy as jnp
import numpy as np

from ott.geometry import costs, pointcloud
from ott.solvers import linear
from ott.tools import unreg


class TestHungarian:

@pytest.mark.parametrize("cost_fn", [costs.PNormP(1.3), None])
def test_matches_sink(self, rng: jax.Array, cost_fn: Optional[costs.CostFn]):
n, m, dim = 12, 12, 5
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would remove m.

rng1, rng2 = jax.random.split(rng, 2)
x, y = gen_data(rng1, n, m, dim)
geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=.0005)
cost_hung, out_hung = unreg.hungarian(geom)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add jax.jit.

out_sink = linear.solve(geom)
np.testing.assert_allclose(
out_sink.primal_cost, cost_hung, rtol=1e-3, atol=1e-3
)
np.testing.assert_allclose(
out_sink.matrix, out_hung.matrix.todense(), rtol=1e-3, atol=1e-3
)

@pytest.mark.parametrize("p", [1.3, 2.3])
def test_wass(self, rng: jax.Array, p: float):
n, m, dim = 12, 12, 5
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would remove m

rng1, rng2 = jax.random.split(rng, 2)
x, y = gen_data(rng1, n, m, dim)
geom = pointcloud.PointCloud(x, y, cost_fn=costs.PNorm(p=p))
cost_hung, _ = unreg.hungarian(geom)
w_p = unreg.wassdis_p(x, y, p)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add jax.jit.

np.testing.assert_allclose(w_p, cost_hung ** 1. / p, rtol=1e-3, atol=1e-3)


def gen_data(rng: jax.Array, n: int, m: int,
dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
rngs = jax.random.split(rng, 4)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Split into 2 keys please.

x = jax.random.uniform(rngs[0], (n, dim))
y = jax.random.uniform(rngs[1], (m, dim))
return x, y