
Extension of the Monge Gap to learn conditional optimal transport maps


Conditional Monge Gap

License: MIT

Conditional Monge Gap extends the Monge Gap to estimate optimal transport maps conditioned on arbitrary context vectors. It uses a two-step training procedure that combines an encoder-decoder architecture with an OT estimator. The model is applied to 4i and scRNA-seq datasets.
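As a rough illustration of the quantity being extended, the sketch below estimates the Monge gap of a map T on a finite sample: the average displacement cost c(x_i, T(x_i)) minus the optimal transport cost between the sample and its image. It assumes squared Euclidean cost and uniform weights, and approximates the OT term with a plain log-domain Sinkhorn solver with entropic regularization `epsilon`. The function names (`sinkhorn_plan`, `monge_gap`) are hypothetical; this is not the package's own code.

```python
import numpy as np


def logsumexp(a: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(a))) along `axis`."""
    a_max = np.max(a, axis=axis, keepdims=True)
    out = a_max + np.log(np.sum(np.exp(a - a_max), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def sinkhorn_plan(x: np.ndarray, y: np.ndarray, epsilon: float = 0.01, n_iters: int = 500):
    """Entropic OT plan between uniform point clouds x (n, d) and y (m, d)
    under squared Euclidean cost, via log-domain Sinkhorn iterations."""
    n, m = x.shape[0], y.shape[0]
    cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # (n, m)
    log_a = np.full(n, -np.log(n))  # uniform source weights
    log_b = np.full(m, -np.log(m))  # uniform target weights
    f, g = np.zeros(n), np.zeros(m)  # dual potentials
    for _ in range(n_iters):
        # Alternately rescale so that rows sum to a and columns sum to b.
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
    plan = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
    return plan, cost


def monge_gap(x: np.ndarray, t_x: np.ndarray, epsilon: float = 0.01, n_iters: int = 500) -> float:
    """Empirical Monge gap of a map T on samples x, given t_x = T(x):
    mean displacement cost minus the OT cost between {x_i} and {T(x_i)}."""
    displacement = np.mean(np.sum((x - t_x) ** 2, axis=-1))
    plan, cost = sinkhorn_plan(x, t_x, epsilon, n_iters)
    ot_cost = np.sum(plan * cost)  # transport cost of the entropic plan
    return float(displacement - ot_cost)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(64, 2))
    # Identity map: already optimal, so the gap is close to zero.
    print("identity:", monge_gap(x, x))
    # Reversing the sample order leaves the point cloud unchanged but moves
    # every point, so the displacement is wasted and the gap is large.
    print("shuffle:", monge_gap(x, x[::-1]))
```

The shuffle case is what the gap penalizes: the pushforward distribution is matched, yet the map is far from the cheapest way to reach it. Because of the entropic regularization, the estimate can dip slightly below zero (at convergence, by at most about epsilon·log n).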

Installation from PyPI

You can install this package as follows:

pip install cmonge

Development setup & installation

The package environment is managed by Poetry. The code was tested with Python 3.10.

pip install poetry
git clone git@github.com:AI4SCR/conditional-monge.git
cd conditional-monge
poetry install -v

If the installation was successful, you can run the tests with pytest:

poetry run pytest  # run the tests inside the Poetry-managed environment

Data

The preprocessed versions of the Sciplex3 and 4i datasets can be downloaded here.

Example usage

An example config is provided in configs/conditional-monge-sciplex.yml. To train an autoencoder model:

from pathlib import Path

from cmonge.datasets.conditional_loader import ConditionalDataModule
from cmonge.trainers.ae_trainer import AETrainerModule
from cmonge.utils import load_config


config_path = Path("configs/conditional-monge-sciplex.yml")
config = load_config(config_path)
config.data.ae = True

# Build the data module from the data and condition sections of the config
datamodule = ConditionalDataModule(config.data, config.condition)
ae_trainer = AETrainerModule(config.ae)

# Train the autoencoder, then evaluate it on the same data module
ae_trainer.train(datamodule)
ae_trainer.evaluate(datamodule)

To train a conditional Monge model:

from pathlib import Path

from cmonge.datasets.conditional_loader import ConditionalDataModule
from cmonge.trainers.conditional_monge_trainer import ConditionalMongeTrainer
from cmonge.utils import load_config

config_path = Path("configs/conditional-monge-sciplex.yml")
logger_path = Path("logs")
config = load_config(config_path)

# Build the data module from the data and condition sections of the config
datamodule = ConditionalDataModule(config.data, config.condition)
trainer = ConditionalMongeTrainer(jobid=1, logger_path=logger_path, config=config.model, datamodule=datamodule)

# Train the conditional map, then evaluate it on the same data module
trainer.train(datamodule)
trainer.evaluate(datamodule)

Loading older checkpoints

If you want to load model weights from older checkpoints (cmonge-{moa, rdkit}-ood or cmonge-{moa, rdkit}-homogeneous), check out the tag cmonge_checkpoint_loading:

git checkout cmonge_checkpoint_loading

Citation

If you use the package, please cite:

@inproceedings{
  harsanyi2024learning,
  title={Learning Drug Perturbations via Conditional Map Estimators},
  author={Benedek Harsanyi and Marianna Rapsomaniki and Jannis Born},
  booktitle={ICLR 2024 Workshop on Machine Learning for Genomics Explorations},
  year={2024},
  url={https://openreview.net/forum?id=FE7lRuwmfI}
}
