
New neural OT solver ENOT #503

Merged (26 commits, Jun 28, 2024)
Conversation

nazarblch (Contributor)

This pull request includes a new method (solver) for neural OT training, called ENOT. It solves the dual OT problem in d-dimensional Euclidean spaces using a specific expectile regularization on the Kantorovich potentials.

The method is similar to W2NeuralDual (already implemented in the repository) but extends it with support for different cost functions and eliminates the need for computationally intensive fine-tuning with conjugate optimizers.

Our paper arXiv:2403.03777 provides a detailed description of the method and its experimental evaluation. ENOT outperforms previous state-of-the-art approaches on the Wasserstein-2 benchmark tasks by a large margin (up to a 3-fold improvement in quality and up to a 10-fold improvement in runtime).

We have also included a detailed tutorial in this pull request, demonstrating how our method works. The source code with the ENOT implementation and the tutorial with applications to some 2D datasets are available at the following links:
https://github.com/nazarblch/ott/blob/dev/src/ott/neural/solvers/expectile_neural_dual.py
https://github.com/nazarblch/ott/blob/dev/docs/tutorials/ENOT.ipynb
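For reviewers skimming the thread, below is a minimal sketch of the expectile loss that gives the method its name; how it enters the full ENOT objective follows the paper, and the function below is an illustrative assumption, not the PR's actual API.

import jax.numpy as jnp

def expectile_loss(diff: jnp.ndarray, tau: float = 0.9) -> jnp.ndarray:
  # Asymmetric squared loss: positive residuals are weighted by tau,
  # negative ones by (1 - tau); tau = 0.5 recovers the plain MSE.
  weight = jnp.where(diff > 0, tau, 1.0 - tau)
  return jnp.mean(weight * diff ** 2)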



codecov bot commented Mar 19, 2024

Codecov Report

Attention: Patch coverage is 3.22581% with 180 lines in your changes missing coverage. Please review.

Project coverage is 89.09%. Comparing base (2147cbe) to head (c23102a).
Report is 36 commits behind head on main.

Files with missing lines                          Patch %   Lines
src/ott/neural/methods/expectile_neural_dual.py   0.00%     177 Missing ⚠️
src/ott/neural/networks/potentials.py             66.66%    3 Missing ⚠️

@@            Coverage Diff             @@
##             main     #503      +/-   ##
==========================================
- Coverage   91.30%   89.09%   -2.21%     
==========================================
  Files          69       70       +1     
  Lines        7242     7427     +185     
  Branches     1018     1051      +33     
==========================================
+ Hits         6612     6617       +5     
- Misses        479      659     +180     
  Partials      151      151              
Files with missing lines                          Coverage Δ
src/ott/neural/networks/potentials.py             86.95% <66.66%> (-3.21%) ⬇️
src/ott/neural/methods/expectile_neural_dual.py   0.00% <0.00%> (ø)

@michalk8 (Collaborator)

Hi @nazarblch, thanks a lot for the contribution! I will give it a review in the upcoming days!

@michalk8 michalk8 self-requested a review March 20, 2024 18:31
@michalk8 michalk8 added the enhancement New feature or request label Mar 20, 2024
@marcocuturi (Contributor) left a comment


Thanks a lot @nazarblch!! These are just a few very basic comments. I will check the paper and code, and Michal will also start a review.

@michalk8 (Collaborator) commented Apr 9, 2024

Hi @nazarblch, I started looking into this PR; could you please adapt it to the new structure introduced in #468?

@michalk8 (Collaborator) left a comment


@nazarblch thanks a lot for the contribution! I still need to finish looking at the implementations of _loss_fn and _euclidean_loss_fn, plus the notebook, but I have left the other comments.


  return loss, (dual_loss, amor_loss, w_dist)

def _euclidean_loss_fn(
@michalk8 (Collaborator):

Possibly a naive question, but why is DotCost treated differently here? Shouldn't it be equivalent (at least in theory) to _loss_fn with the DotCost? Or is this a more numerically stable implementation?

@nazarblch (Contributor, Author):

In theory, when we use DotCost, we optimize the dual potentials with a negative sign and add the squared norms to the Wasserstein distance at the end (see expression 19 in the ENOT paper). I have merged _loss_fn and _euclidean_loss_fn, using conditionals on the sign of the dual potentials and adding the norms to the logged w_dist value.
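A minimal numeric check of the algebra behind this (an illustration, not the PR's code): for the squared Euclidean cost, ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, so minimizing the transport cost and maximizing the correlation differ only by a sign flip and the squared-norm terms added at the end.

import jax.numpy as jnp

x = jnp.array([[0.0, 1.0], [2.0, -1.0]])
y = jnp.array([[1.0, 1.0], [0.0, 0.5]])

sq_cost = jnp.sum((x - y) ** 2, axis=-1)  # ||x - y||^2 per pair
corr = jnp.sum(x * y, axis=-1)            # <x, y> per pair
sq_norms = jnp.sum(x ** 2, axis=-1) + jnp.sum(y ** 2, axis=-1)
assert jnp.allclose(sq_cost, sq_norms - 2.0 * corr)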

@nazarblch (Contributor, Author)

> @nazarblch thanks a lot for the contribution! Still need to finish looking at _loss_fn and _euclidean_loss_fn's implementation + the notebook, but left other comments.

This is a short confirmation that I have read the comments and am working on them.

@michalk8 michalk8 self-requested a review May 28, 2024 13:15
conjugate_cost = jnp.dot if corr else cost_fn

def g_cost_conjugate(x: jnp.ndarray) -> jnp.ndarray:
  y_hat = jax.lax.stop_gradient(transport(x))
@michalk8 (Collaborator):

Rather than capturing transport and conjugate_cost, I would suggest:

def g_cost_conjugate(x: jnp.ndarray) -> jnp.ndarray:
  if is_bidirectional and not corr:
    y_hat = cost_fn.twist_operator(x, grad_f(x), False)
  else:
    y_hat = grad_f(x)
  y_hat = jax.lax.stop_gradient(y_hat)

  return -g(y_hat) + (jnp.dot(x, y_hat) if corr else cost_fn(x, y_hat))

@nazarblch (Contributor, Author):

I have inserted the code above. However, there doesn't seem to be much difference.
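As a side note on the twist_operator branch in the suggestion above: for translation-invariant costs it recovers the conjugate point from the potential's gradient. A hedged toy illustration for the halved squared Euclidean cost, where it reduces to the classic identity-minus-gradient map (the potential f below is an assumption, not the trained network):

import jax
import jax.numpy as jnp

def f(x: jnp.ndarray) -> jnp.ndarray:
  # Toy stand-in for the trained neural Kantorovich potential.
  return 0.25 * jnp.sum(x ** 2)

grad_f = jax.grad(f)
x = jnp.array([1.0, -2.0])
y_hat = x - grad_f(x)  # conjugate point for c(x, y) = 0.5 * ||x - y||^2

The practical gain of the suggested version is that the branching on corr and is_bidirectional is explicit in one place rather than baked into captured closures.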

@nazarblch (Contributor, Author)

@michalk8, I have corrected the code and updated the repository. Also, we have uploaded a new version of the paper to arXiv with additional image-to-image translation experiments.

@michalk8 michalk8 self-requested a review June 27, 2024 15:10
@michalk8 (Collaborator) left a comment


Thanks a lot for your contribution @nazarblch, I just briefly polished some of the docs. LGTM!

@michalk8 michalk8 merged commit 8aa6810 into ott-jax:main Jun 28, 2024
9 of 10 checks passed
Labels: enhancement (New feature or request)
4 participants