New neural OT solver ENOT #503
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #503      +/-   ##
==========================================
- Coverage   91.30%   89.09%    -2.21%
==========================================
  Files          69       70        +1
  Lines        7242     7427      +185
  Branches     1018     1051       +33
==========================================
+ Hits         6612     6617        +5
- Misses        479      659      +180
  Partials      151      151
Hi @nazarblch, thanks a lot for the contribution! Will give it a review in the upcoming days!
Thanks a lot @nazarblch!! These are just a few very basic comments. I will check the paper + code, and Michal will also start a review.
Hi @nazarblch, I started looking into this PR, could you please adapt to the new structure introduced in #468?
@nazarblch thanks a lot for the contribution! Still need to finish looking at the implementation of `_loss_fn` and `_euclidean_loss_fn` + the notebook, but left other comments.
    return loss, (dual_loss, amor_loss, w_dist)

  def _euclidean_loss_fn(
Possibly a naive question, but why is `DotCost` treated differently here? Shouldn't it be equivalent (at least in theory) to `_loss_fn` with the `DotCost`? Or is this a more numerically stable implementation?
In theory, when we use `DotCost`, we optimize the dual potentials with a negative sign and add the squared norms to the Wasserstein distance at the end (ref. expression 19 in the ENOT paper). I have merged `_loss_fn` and `_euclidean_loss_fn` using `if` operators for the sign of the dual potentials, and the norms are added to the `w_dist` logging value.
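A minimal sketch of the merged objective described in this reply, assuming hypothetical names (`f` and `g` for the dual potential networks, `corr` for the `DotCost`/correlation case); the actual implementation in `expectile_neural_dual.py` may differ:

```python
import jax.numpy as jnp


def merged_dual_loss(f, g, x, y, corr: bool):
  """Sketch only: flip the sign of the dual potentials for `DotCost` and
  add the squared norms back to the logged Wasserstein distance.

  Uses the identity ||x - y||^2 / 2 = ||x||^2 / 2 + ||y||^2 / 2 - <x, y>,
  so the correlation (DotCost) form of the dual differs from the
  squared-Euclidean form by a sign flip plus the squared-norm terms.
  """
  sign = -1.0 if corr else 1.0  # negative sign of the potentials for DotCost
  dual_loss = sign * (jnp.mean(f(x)) + jnp.mean(g(y)))
  w_dist = dual_loss
  if corr:
    # Add the squared norms at the end so that the logged `w_dist`
    # corresponds again to the squared-Euclidean transport cost.
    w_dist += 0.5 * jnp.mean(jnp.sum(x ** 2, axis=-1))
    w_dist += 0.5 * jnp.mean(jnp.sum(y ** 2, axis=-1))
  return dual_loss, w_dist
```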
This is a short confirmation that I have read the comments and am working on them.
    conjugate_cost = jnp.dot if corr else cost_fn

    def g_cost_conjugate(x: jnp.ndarray) -> jnp.ndarray:
      y_hat = jax.lax.stop_gradient(transport(x))
Rather than capturing the `transport` and `conjugate_cost`, would suggest:

def g_cost_conjugate(x: jnp.ndarray) -> jnp.ndarray:
  if is_bidirectional and not corr:
    y_hat = cost_fn.twist_operator(x, grad_f(x), False)
  else:
    y_hat = grad_f(x)
  y_hat = jax.lax.stop_gradient(y_hat)
  return -g(y_hat) + (jnp.dot(x, y_hat) if corr else cost_fn(x, y_hat))
I have inserted the code above. However, there doesn't seem to be much difference.
@michalk8, I have corrected the code and updated the repository. We have also uploaded a new version of the paper to arXiv with additional image-to-image translation experiments.
Thanks a lot for your contribution @nazarblch, I just briefly polished some of the docs. LGTM!
This pull request includes a new method (solver) for Neural OT training, called ENOT. It solves the dual OT problem in d-dimensional Euclidean spaces using a specific expectile regularization on the Kantorovich potentials.
The method is similar to W2NeuralDual (already implemented in the repository) but extends its functionality with the ability to use different cost functions, and eliminates the need for computationally intensive fine-tuning with conjugate optimizers.
Our paper arXiv:2403.03777 provides a detailed description of the method and its experimental evaluation. ENOT outperforms previous state-of-the-art approaches on the Wasserstein-2 benchmark tasks by a large margin (up to a 3-fold improvement in quality and up to a 10-fold improvement in runtime).
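As background for the expectile regularization mentioned above, the standard τ-expectile (asymmetric least-squares) loss is shown below; this is general background only, and the precise way it enters the ENOT objective is given in the paper:

$$
\ell_\tau(u) = \bigl|\tau - \mathbf{1}[u < 0]\bigr|\, u^2, \qquad \tau \in (0, 1).
$$

For τ > 1/2, positive residuals are penalized more heavily than negative ones; ENOT applies a loss of this form as a regularizer on the Kantorovich potentials.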
We have also included a detailed tutorial in this pull request demonstrating how our method works. Please find the source code of the ENOT implementation and the tutorial with applications to some 2D datasets at the following links:
https://github.com/nazarblch/ott/blob/dev/src/ott/neural/solvers/expectile_neural_dual.py
https://github.com/nazarblch/ott/blob/dev/docs/tutorials/ENOT.ipynb
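A minimal usage sketch of the new solver, following the file linked above; the class name, constructor arguments, and call signature in the commented lines are assumptions for illustration (mirroring how `W2NeuralDual` is typically called), not the confirmed API:

```python
# Hypothetical usage sketch; the solver-related names below are assumptions,
# not the confirmed OTT-JAX API.
import jax
import jax.numpy as jnp

# Two toy 2D point clouds: a standard Gaussian source and a shifted target.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1024, 2))
y = jax.random.normal(key_y, (1024, 2)) + jnp.array([3.0, 0.0])

# Assumed import path, following src/ott/neural/solvers/expectile_neural_dual.py:
# from ott.neural.solvers import expectile_neural_dual
#
# solver = expectile_neural_dual.ExpectileNeuralDual(dim_data=2)   # assumed constructor
# potentials = solver(iter([x]), iter([y]), iter([x]), iter([y]))  # assumed call signature
# w2_estimate = potentials.distance(x, y)                          # assumed, mirrors W2NeuralDual
```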