-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add Hungarian solver from optax #598
Changes from 1 commit
f823c12
39e56cb
07ceaac
4fd262d
8ac3ea5
74a3ed7
61f13eb
ab0139b
b2d6973
02193ef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,4 +20,5 @@ | |
sinkhorn_divergence, | ||
sliced, | ||
soft_sort, | ||
unreg, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Copyright OTT-JAX | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import NamedTuple, Optional, Tuple | ||
|
||
import jax.experimental.sparse as jesp | ||
import jax.numpy as jnp | ||
|
||
from optax import assignment | ||
|
||
from ott.geometry import costs, geometry, pointcloud | ||
|
||
__all__ = ["hungarian"] | ||
|
||
|
||
class HungarianOutput(NamedTuple): | ||
michalk8 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
r"""Output of the Hungarian solver. | ||
|
||
Args: | ||
geom: geometry object | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Geometry object. |
||
paired_indices: Array of shape ``[2, n]``, of :math:`n` pairs | ||
of indices, for which the optimal transport assigns mass. Namely, for each | ||
index :math:`0 <= k < n`, if one has | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use |
||
:math:`i := \text{paired_indices}[0, k]` and | ||
:math:`j := \text{paired_indices}[1, k]`, then point :math:`i` in | ||
the first geometry sends mass to point :math:`j` in the second. | ||
""" | ||
geom: geometry.Geometry | ||
paired_indices: Optional[jnp.ndarray] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would remove the optional and default none. |
||
|
||
@property | ||
def matrix(self) -> jesp.BCOO: | ||
"""``[n, n]`` transport matrix in sparse format, with ``n`` NNZ entries.""" | ||
n, _ = self.geom.shape | ||
unit_mass = jnp.ones((n,)) / n | ||
indices = self.paired_indices.swapaxes(0, 1) | ||
return jesp.BCOO((unit_mass, indices), shape=(n, n)) | ||
|
||
|
||
def hungarian(geom: geometry.Geometry) -> Tuple[jnp.ndarray, HungarianOutput]: | ||
"""Solve assignment problem using Hungarian as implemented in optax. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. :mod:`optax` There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also please add this to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isn't it there already?
|
||
|
||
Args: | ||
geom: (square) geometry object. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would say
|
||
|
||
Returns: | ||
The value of the unregularized OT problem, along with an output | ||
object listing relevant information on outputs. | ||
""" | ||
n, m = geom.shape | ||
assert n == m, f"Hungarian can only match same # of points, got {n} and {m}." | ||
i, j = assignment.hungarian_algorithm(geom.cost_matrix) | ||
|
||
hungarian_out = HungarianOutput(geom=geom, paired_indices=jnp.stack((i, j))) | ||
return jnp.sum(geom.cost_matrix[i, j]) / n, hungarian_out | ||
|
||
|
||
def wassdis_p(x: jnp.ndarray, y: jnp.ndarray, p: float = 2.0) -> float: | ||
"""Convenience wrapper on `hungarian` to get :term:`Wasserstein distance`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would remove the
or similar. |
||
|
||
Uses :func:`~ott.tools.unreg.hungarian` to solve the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use :func:`hungarian`. |
||
:term:`Kantorovich problem` between two point clouds of the same | ||
size to compute a :term:`Wasserstein distance` estimator. | ||
|
||
Note: | ||
At the moment, only supports point clouds of the same size to be easily | ||
cast as an optimal matching problem. | ||
|
||
Args: | ||
x: ``[n,d]`` point cloud | ||
y: ``[n,d]`` point cloud of the same size | ||
p: order of the Wasserstein distance, non-negative float. | ||
|
||
Returns: | ||
The p-Wasserstein distance between these point clouds.hungarian | ||
""" | ||
geom = pointcloud.PointCloud(x, y, cost_fn=costs.PNorm(p)) | ||
cost, _ = hungarian(geom) | ||
return cost ** 1. / p |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Copyright OTT-JAX | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Optional, Tuple | ||
|
||
import pytest | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
from ott.geometry import costs, pointcloud | ||
from ott.solvers import linear | ||
from ott.tools import unreg | ||
|
||
|
||
class TestHungarian: | ||
|
||
@pytest.mark.parametrize("cost_fn", [costs.PNormP(1.3), None]) | ||
def test_matches_sink(self, rng: jax.Array, cost_fn: Optional[costs.CostFn]): | ||
n, m, dim = 12, 12, 5 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would remove |
||
rng1, rng2 = jax.random.split(rng, 2) | ||
x, y = gen_data(rng1, n, m, dim) | ||
geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=.0005) | ||
cost_hung, out_hung = unreg.hungarian(geom) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add |
||
out_sink = linear.solve(geom) | ||
np.testing.assert_allclose( | ||
out_sink.primal_cost, cost_hung, rtol=1e-3, atol=1e-3 | ||
) | ||
np.testing.assert_allclose( | ||
out_sink.matrix, out_hung.matrix.todense(), rtol=1e-3, atol=1e-3 | ||
) | ||
|
||
@pytest.mark.parametrize("p", [1.3, 2.3]) | ||
def test_wass(self, rng: jax.Array, p: float): | ||
n, m, dim = 12, 12, 5 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would remove |
||
rng1, rng2 = jax.random.split(rng, 2) | ||
x, y = gen_data(rng1, n, m, dim) | ||
geom = pointcloud.PointCloud(x, y, cost_fn=costs.PNorm(p=p)) | ||
cost_hung, _ = unreg.hungarian(geom) | ||
w_p = unreg.wassdis_p(x, y, p) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add |
||
np.testing.assert_allclose(w_p, cost_hung ** 1. / p, rtol=1e-3, atol=1e-3) | ||
|
||
|
||
def gen_data(rng: jax.Array, n: int, m: int, | ||
dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]: | ||
rngs = jax.random.split(rng, 4) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Split into 2 keys please. |
||
x = jax.random.uniform(rngs[0], (n, dim)) | ||
y = jax.random.uniform(rngs[1], (m, dim)) | ||
return x, y |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
all = ["HungarianOutput", "hungarian", "wassdis_p"]