Skip to content

Commit

Permalink
restructuring neural models + addition of OT-FM and GENOT (ott-jax#468)
Browse files Browse the repository at this point in the history
* draft of BaseSolver and UnbalancedMixin

* draft of BaseSolver and UnbalancedMixin

* [ci skip] continue flow matching implementation

* [ci skip] continue flow matching implementation

* [ci skip] add neural networks

* [ci skip] add test

* [ci skip] resolve import errors

* [ci skip] MRO not working

* [ci skip] basic test for flow matching passes

* [ci skip] add tests for FM with conditions and conditional OT with FM

* [ci skip] add genot outline

* [ci skip] restructure genot

* [ci skip] restructure genot

* [ci skip] fix transport

* [ci skip] flow matching tests passing

* [ci skip] add more tests genot

* [ci skip] add more tests genot

* [ci skip] add TimeSampler

* [ci skip] add docs for TimeSampler and Flow

* [ci skip] add docs for OTFlowMatching and replace jnp.ndarray by jax.Array

* [ci skip] change init arguments of GENOT and add docstrings to GENOT

* [ci skip] split nets into base_models and models

* [ci skip] add references

* add tests for learning the rescaling factors

* [ci skip] partially fix rescaling factor learning

* [ci skip] fix rescaling factor learning

* [ci skip] all tests passing but k_samples_per_x in genot

* k_samples_per_x working in GENOT

* [ci skip] changed dataloaders to numpy and dict return

* [ci skip] changed dataloaders to numpy and dict return

* revert jax.Array to jnp.ndarray

* move dataloader from tests to module

* add docstrings to neurcal networks

* [ci skip] adapt type of scale_cost and cost_fn

* [ci skip] clean code

* [ci skip] fix genot tests

* [ci skip] fix otfm tests

* [ci skip] fix otfm tests

* add scale cost to otfm

* incorporate feedback partially

* resolve circular import errors

* resolve a few pre-commit errors

* resolve pre-commit errors

* resolve pre-commit errors

* fix rng bug

* Update pre-commit

* fix import error

* Run linter

* replace rng jnp.ndarray type by jax.array

* replace rng jnp.ndarray type by jax.array

* fix import error

* [ci skip] start to incorporate feedback

* restructure neural module

* fix import errors

* incorporate feedback partially

* make time encoder a layer

* make conditions Optional and minor feedback

* revert faulty jax.array / jnp.ndarray conversions

* make formatting in neural nets nicer

* add description to Velocity Field

* replace time sampler class by function

* add citations

* add more references

* rename keys_model to rng

* fix tests regarding time sampling

* fix typo in tests

* rename neural_vector_field to velocity_field everywhere

* fix OTFlowMatching.transport

* fix rescaling networks

* Update src/ott/neural/flows/flows.py

Co-authored-by: nvesseron <96598529+nvesseron@users.noreply.github.com>

* Update src/ott/neural/flows/flows.py

Co-authored-by: nvesseron <96598529+nvesseron@users.noreply.github.com>

* test for scale_cost

* update test for scale_cost

* fix bug for scale_cost

* fix bug for scale_cost

* jit solve_ode in genot

* incorporate changes partially

* [ci skip] intermediate save

* [ci skip] neural base solver update

* make resamlpemixin a class

* incorporate more changes

* move noise sampling to flows

* fix bug in passing rngs in otfm

* introduce otmatcher in otfm

* [ci skip] split GENOT into GENOTLin and GENOTQuad

* remove dictionaries in OTFM and GENOT classes

* change logic in match_latent_to_data in genot

* change data loaders / data sets

* finish data loader refactoring

* Update linter

* fix bug in _resample_data`

* incorporate more changes

* add docs

* incorporate more changes

* problem with custom type

* fix scale cost bug

* fix bugs

* fux bug in unbalancedness/rescalingMlp

* unify unbalancedness step in GENOT

* change OTDataSet and OTFlowMatching to 4 data loaderes

* Fix bug in the `ConditionalOTDataset`

* Polish docs in the `flows.py`

* Update `OTFM`

* Fix small bugs in `OTFM`

* Polish layers

* Fix typo in citation

* More polish for the docs

* remove print statements and unbalancednesshandler

* remove tests

* make genot training loops more similar to otfm training loop

* adapt tests to the extent possible

* Add weights to sampling

* Start cleaning matchers

* Add conditional sampling + resampling

* Add initial quad matcher

* Improve typing

* Remove `base_solver.py`

* Add TODO

* Update datasets, fix OTFM tests

* Start cleaning GENOT

* Update GENOT

* Remove old GENOTLin/GENOTQuad

* Remove axis swapping

* Remove old todo

* Fix OTFM tests

* Remove `MLPBlock` and `RescalingMLP`

* Add forgotten license

* Remove `__post_init__` from `VF`

* Move cyclical time encoder

* Move more stuff to `utils`

* Remove `samplers.py`

* Rename `cond_dim` -> `condition_dim`

* Nicer formatting

* Fix bug when sampling from the target

* Fix another bug when sampling from the data

* Add initial test for GW

* Remove old GENOT tests

* Remove old dataloaders

* Add more todos

* add docs to dataloader

* expose args in GENOT

* add docs and adapt data_match_fn

* fix linting

* fix data loading and add genot fused tests

* genot tests passing

* adapt docs

* adapt docs

* add error message

* clean docs

* comprise genot tests

* change reference for GENOT

* add missing docstring

* Modify behaviour of `ConditionalLoader`

* Update docstring

* Clean GENOT docs

* Improve VF

* Simplify GENOT test

* Better metadata wrapper in tests

* Fix condition in GENOT test

* Add quad cond dl

* Add conf fused DL

* Polish docs

* Remove conditional loader

* Fix link in the docs

* Improve VF

* Fix GENOT test

* Polish docs

* Remove `uniform_marginals` argument

* Fix undefined variable

* Update `GENOT.transport` docs

* Add `diffrax` to `conf.py`

* Restructure files

* Fix neural init tests import

* Update `docs/`

* Update Monge Gap

* Update MetaOT and NeuralDual

* Update ICNN inits

* Fix links to neural in the docs

* Check for condition dim in VF

* Don't use activation fn in the last layer of VF

* Update assertions

* Try skipping OTFM/GENOT tests temporarily

* Be extra verbose when intalling packages

* Remove `torch` dependency

* Remove `torch` from tests in `pyproject.toml`

* [ci skip] Update docstrings

---------

Co-authored-by: lucaeyring <luca.eyring@googlemail.com>
Co-authored-by: Michal Klein <46717574+michalk8@users.noreply.github.com>
Co-authored-by: nvesseron <96598529+nvesseron@users.noreply.github.com>
Co-authored-by: Dominik Klein <dominik.klein@helmoltz-munich.de>
  • Loading branch information
5 people authored Apr 3, 2024
1 parent 94b57fa commit 713b9fc
Show file tree
Hide file tree
Showing 101 changed files with 2,730 additions and 949 deletions.
46 changes: 25 additions & 21 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,8 @@ default_stages:
- push
minimum_pre_commit_version: 3.0.0
repos:
- repo: https://github.com/google/yapf
rev: v0.40.0
hooks:
- id: yapf
additional_dependencies: [toml]
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.0
hooks:
- id: nbqa-pyupgrade
args: [--py38-plus]
- id: nbqa-black
- id: nbqa-isort
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.10.0
hooks:
- id: pretty-format-yaml
args: [--autofix, --indent, '2']
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: detect-private-key
- id: check-ast
Expand All @@ -37,13 +20,34 @@ repos:
- id: trailing-whitespace
- id: check-case-conflict
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: v0.0.285
rev: v0.2.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort
- repo: https://github.com/google/yapf
rev: v0.40.2
hooks:
- id: yapf
additional_dependencies: [toml]
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.1
hooks:
- id: nbqa-pyupgrade
args: [--py38-plus]
- id: nbqa-black
- id: nbqa-isort
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.12.0
hooks:
- id: pretty-format-yaml
args: [--autofix, --indent, '2']
- repo: https://github.com/rstcheck/rstcheck
rev: v6.1.2
rev: v6.2.0
hooks:
- id: rstcheck
additional_dependencies: [tomli]
Expand Down
9 changes: 5 additions & 4 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
import logging
from datetime import datetime

import ott
from sphinx.util import logging as sphinx_logging

import ott

# -- Project information -----------------------------------------------------
needs_sphinx = "4.0"

Expand Down Expand Up @@ -62,13 +63,13 @@
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"jaxopt": ("https://jaxopt.github.io/stable", None),
"lineax": ("https://docs.kidger.site/lineax/", None),
"flax": ("https://flax.readthedocs.io/en/latest/", None),
"scikit-sparse": ("https://scikit-sparse.readthedocs.io/en/latest/", None),
"optax": ("https://optax.readthedocs.io/en/latest/", None),
"diffrax": ("https://docs.kidger.site/diffrax/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
"pot": ("https://pythonot.github.io/", None),
"jaxopt": ("https://jaxopt.github.io/stable", None),
"optax": ("https://optax.readthedocs.io/en/latest/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
}

Expand Down
15 changes: 15 additions & 0 deletions docs/neural/datasets.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
ott.neural.datasets
===================
.. module:: ott.neural.datasets
.. currentmodule:: ott.neural

The :mod:`ott.neural.datasets` contains datasets and needed for solving
(conditional) neural optimal transport problems.

Datasets
--------
.. autosummary::
:toctree: _autosummary

datasets.OTData
datasets.OTDataset
30 changes: 3 additions & 27 deletions docs/neural/index.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
ott.neural
==========
.. module:: ott.neural
.. currentmodule:: ott.neural

In contrast to most methods presented in :mod:`ott.solvers`, which output
vectors or matrices, the goal of the :mod:`ott.neural` module is to parameterize
Expand All @@ -13,29 +12,6 @@ and solvers to estimate such neural networks.
.. toctree::
:maxdepth: 2

solvers

Models
------
.. autosummary::
:toctree: _autosummary

models.ICNN
models.MLP
models.MetaInitializer

Losses
------
.. autosummary::
:toctree: _autosummary

losses.monge_gap
losses.monge_gap_from_samples

Layers
------
.. autosummary::
:toctree: _autosummary

layers.PositiveDense
layers.PosDefPotentials
datasets
methods
networks
37 changes: 37 additions & 0 deletions docs/neural/methods.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
ott.neural.methods
==================
.. module:: ott.neural.methods
.. currentmodule:: ott.neural.methods

Monge Gap
---------
.. autosummary::
:toctree: _autosummary

monge_gap.monge_gap
monge_gap.monge_gap_from_samples
monge_gap.MongeGapEstimator

Neural Dual
-----------
.. autosummary::
:toctree: _autosummary

neuraldual.W2NeuralDual

ott.neural.methods.flows
========================
.. module:: ott.neural.methods.flows
.. currentmodule:: ott.neural.methods.flows

Flows
-----
.. autosummary::
:toctree: _autosummary

otfm.OTFlowMatching
genot.GENOT
dynamics.BaseFlow
dynamics.StraightFlow
dynamics.ConstantNoiseFlow
dynamics.BrownianBridge
33 changes: 33 additions & 0 deletions docs/neural/networks.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
ott.neural.networks
===================
.. module:: ott.neural.networks
.. currentmodule:: ott.neural.networks

Networks
--------
.. autosummary::
:toctree: _autosummary

icnn.ICNN
velocity_field.VelocityField
potentials.BasePotential
potentials.PotentialMLP
potentials.PotentialTrainState


ott.neural.networks.layers
==========================
.. module:: ott.neural.networks.layers
.. currentmodule:: ott.neural.networks.layers

Layers
------
.. autosummary::
:toctree: _autosummary

conjugate.FenchelConjugateSolver
conjugate.FenchelConjugateLBFGS
conjugate.ConjugateResults
posdef.PositiveDense
posdef.PosDefPotentials
time_encoder.cyclical_time_encoder
28 changes: 0 additions & 28 deletions docs/neural/solvers.rst

This file was deleted.

50 changes: 50 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,56 @@ @misc{huguet:2023
year = {2023},
}

@misc{eyring:23,
author = {Eyring, Luca and Klein, Dominik and Uscidda, Théo and Palla, Giovanni and Kilbertus, Niki and Akata, Zeynep and Theis, Fabian},
doi = {10.48550/arXiv.2311.15100},
eprint = {2311.15100},
eprintclass = {stat.ML},
eprinttype = {arXiv},
title = {Unbalancedness in Neural Monge Maps Improves Unpaired Domain Translation},
year = {2023},
}

@misc{klein_uscidda:23,
author = {Klein, Dominik and Uscidda, Théo and Theis, Fabian and Cuturi, Marco},
doi = {10.48550/arXiv.2310.09254},
eprint = {2310.09254},
eprintclass = {stat.ML},
eprinttype = {arXiv},
title = {Entropic (Gromov) Wasserstein Flow Matching with GENOT},
year = {2023},
}

@misc{lipman:22,
author = {Lipman, Yaron and Chen, Ricky TQ and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt},
doi = {10.48550/arXiv.2210.02747},
eprint = {2210.02747},
eprintclass = {stat.ML},
eprinttype = {arXiv},
title = {Flow matching for generative modeling},
year = {2022},
}

@misc{tong:23,
author = {Tong, Alexander and Malkin, Nikolay and Huguet, Guillaume and Zhang, Yanlei and {Rector-Brooks}, Jarrid and Fatras, Kilian and Wolf, Guy and Bengio, Yoshua},
doi = {10.48550/arXiv.2302.00482},
eprint = {2302.00482},
eprintclass = {stat.ML},
eprinttype = {arXiv},
title = {Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport},
year = {2023},
}

@misc{pooladian:23,
author = {Pooladian, Aram-Alexandre and Ben-Hamu, Heli and Domingo-Enrich, Carles and Amos, Brandon and Lipman, Yaron and Chen, Ricky},
doi = {10.48550/arXiv.2304.14772},
eprint = {2304.14772},
eprintclass = {stat.ML},
eprinttype = {arXiv},
title = {Multisample flow matching: Straightening flows with minibatch couplings},
year = {2023},
}

@article{iacono:17,
author = {Iacono, Roberto and Boyd, John P.},
url = {https://doi.org/10.1007/s10444-017-9530-3},
Expand Down
11 changes: 11 additions & 0 deletions docs/solvers/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,14 @@ Wasserstein Solver
:toctree: _autosummary

was_solver.WassersteinSolver

Utilities
---------
.. autosummary::
:toctree: _autosummary

utils.match_linear
utils.match_quadratic
utils.sample_joint
utils.sample_conditional
utils.uniform_sampler
1 change: 1 addition & 0 deletions docs/spelling/misc.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Eulerian
Utils
alg
arg
args
Expand Down
3 changes: 3 additions & 0 deletions docs/spelling/technical.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ McCann
Monge
Moreau
SGD
Schrödinger
Schur
Seidel
Sinkhorn
Expand All @@ -46,6 +47,7 @@ chromatin
collinear
covariance
covariances
dataclass
dataloaders
dataset
datasets
Expand Down Expand Up @@ -110,6 +112,7 @@ preprocess
preprocessing
proteome
prox
pytree
quantile
quantiles
quantizes
Expand Down
Loading

0 comments on commit 713b9fc

Please sign in to comment.