Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
restructuring neural models + addition of OT-FM and GENOT (ott-jax#468)
* draft of BaseSolver and UnbalancedMixin * draft of BaseSolver and UnbalancedMixin * [ci skip] continue flow matching implementation * [ci skip] continue flow matching implementation * [ci skip] add neural networks * [ci skip] add test * [ci skip] resolve import errors * [ci skip] MRO not working * [ci skip] basic test for flow matching passes * [ci skip] add tests for FM with conditions and conditional OT with FM * [ci skip] add genot outline * [ci skip] restructure genot * [ci skip] restructure genot * [ci skip] fix transport * [ci skip] flow matching tests passing * [ci skip] add more tests genot * [ci skip] add more tests genot * [ci skip] add TimeSampler * [ci skip] add docs for TimeSampler and Flow * [ci skip] add docs for OTFlowMatching and replace jnp.ndarray by jax.Array * [ci skip] change init arguments of GENOT and add docstrings to GENOT * [ci skip] split nets into base_models and models * [ci skip] add references * add tests for learning the rescaling factors * [ci skip] partially fix rescaling factor learning * [ci skip] fix rescaling factor learning * [ci skip] all tests passing but k_samples_per_x in genot * k_samples_per_x working in GENOT * [ci skip] changed dataloaders to numpy and dict return * [ci skip] changed dataloaders to numpy and dict return * revert jax.Array to jnp.ndarray * move dataloader from tests to module * add docstrings to neurcal networks * [ci skip] adapt type of scale_cost and cost_fn * [ci skip] clean code * [ci skip] fix genot tests * [ci skip] fix otfm tests * [ci skip] fix otfm tests * add scale cost to otfm * incorporate feedback partially * resolve circular import errors * resolve a few pre-commit errors * resolve pre-commit errors * resolve pre-commit errors * fix rng bug * Update pre-commit * fix import error * Run linter * replace rng jnp.ndarray type by jax.array * replace rng jnp.ndarray type by jax.array * fix import error * [ci skip] start to incorporate feedback * restructure neural module * fix import errors * incorporate feedback partially * make time encoder a layer * make conditions Optional and minor feedback * revert faulty jax.array / jnp.ndarray conversions * make formatting in neural nets nicer * add description to Velocity Field * replace time sampler class by function * add citations * add more references * rename keys_model to rng * fix tests regarding time sampling * fix typo in tests * rename neural_vector_field to velocity_field everywhere * fix OTFlowMatching.transport * fix rescaling networks * Update src/ott/neural/flows/flows.py Co-authored-by: nvesseron <96598529+nvesseron@users.noreply.github.com> * Update src/ott/neural/flows/flows.py Co-authored-by: nvesseron <96598529+nvesseron@users.noreply.github.com> * test for scale_cost * update test for scale_cost * fix bug for scale_cost * fix bug for scale_cost * jit solve_ode in genot * incorporate changes partially * [ci skip] intermediate save * [ci skip] neural base solver update * make resamlpemixin a class * incorporate more changes * move noise sampling to flows * fix bug in passing rngs in otfm * introduce otmatcher in otfm * [ci skip] split GENOT into GENOTLin and GENOTQuad * remove dictionaries in OTFM and GENOT classes * change logic in match_latent_to_data in genot * change data loaders / data sets * finish data loader refactoring * Update linter * fix bug in _resample_data` * incorporate more changes * add docs * incorporate more changes * problem with custom type * fix scale cost bug * fix bugs * fux bug in unbalancedness/rescalingMlp * unify unbalancedness step in GENOT * change OTDataSet and OTFlowMatching to 4 data loaderes * Fix bug in the `ConditionalOTDataset` * Polish docs in the `flows.py` * Update `OTFM` * Fix small bugs in `OTFM` * Polish layers * Fix typo in citation * More polish for the docs * remove print statements and unbalancednesshandler * remove tests * make genot training loops more similar to otfm training loop * adapt tests to the extent possible * Add weights to sampling * Start cleaning matchers * Add conditional sampling + resampling * Add initial quad matcher * Improve typing * Remove `base_solver.py` * Add TODO * Update datasets, fix OTFM tests * Start cleaning GENOT * Update GENOT * Remove old GENOTLin/GENOTQuad * Remove axis swapping * Remove old todo * Fix OTFM tests * Remove `MLPBlock` and `RescalingMLP` * Add forgotten license * Remove `__post_init__` from `VF` * Move cyclical time encoder * Move more stuff to `utils` * Remove `samplers.py` * Rename `cond_dim` -> `condition_dim` * Nicer formatting * Fix bug when sampling from the target * Fix another bug when sampling from the data * Add initial test for GW * Remove old GENOT tests * Remove old dataloaders * Add more todos * add docs to dataloader * expose args in GENOT * add docs and adapt data_match_fn * fix linting * fix data loading and add genot fused tests * genot tests passing * adapt docs * adapt docs * add error message * clean docs * comprise genot tests * change reference for GENOT * add missing docstring * Modify behaviour of `ConditionalLoader` * Update docstring * Clean GENOT docs * Improve VF * Simplify GENOT test * Better metadata wrapper in tests * Fix condition in GENOT test * Add quad cond dl * Add conf fused DL * Polish docs * Remove conditional loader * Fix link in the docs * Improve VF * Fix GENOT test * Polish docs * Remove `uniform_marginals` argument * Fix undefined variable * Update `GENOT.transport` docs * Add `diffrax` to `conf.py` * Restructure files * Fix neural init tests import * Update `docs/` * Update Monge Gap * Update MetaOT and NeuralDual * Update ICNN inits * Fix links to neural in the docs * Check for condition dim in VF * Don't use activation fn in the last layer of VF * Update assertions * Try skipping OTFM/GENOT tests temporarily * Be extra verbose when intalling packages * Remove `torch` dependency * Remove `torch` from tests in `pyproject.toml` * [ci skip] Update docstrings --------- Co-authored-by: lucaeyring <luca.eyring@googlemail.com> Co-authored-by: Michal Klein <46717574+michalk8@users.noreply.github.com> Co-authored-by: nvesseron <96598529+nvesseron@users.noreply.github.com> Co-authored-by: Dominik Klein <dominik.klein@helmoltz-munich.de>
- Loading branch information