This repo contains a PyTorch implementation for the paper Generative Modeling on Manifolds Through Mixture of Riemannian Diffusion Processes.
We provide official code repo for JAX implementation in riemannian-diffusion-mixture.
- Simple design of the generative process as a mixture of Riemannian bridge processes, which does not require heat kernel estimation as previous denoising approach.
- Geometrical interpretation for the mixture process as the weighted mean of tangent directions on manifolds
- Scales to higher dimensions with significantly faster training compared to previous diffusion models.
Create an environment with Python 3.9.0, and Pytorch 2.0.0. Install requirements with the following command:
pip install -r requirements.txt
conda install -c conda-forge cartopy python-kaleido
Following manifolds are supported in this repo:
- Euclidean
- Hypersphere
- Torus
- Hyperboloid
- Triangular mesh
- Special orthogonal group
To implement new manifolds, add python files that define the geometry of the manifold in /geomstats/geometry
.
Please refer to geomstats/geometry for examples.
This repo supports experiments on the following datasets:
- Protein datasets:
General
,Glycine
,Proline
, andPre-Pro
, andRNA
. - High-dimensional tori
Please refer to riemannian-diffusion-mixture for running expreiments on earth and climate science datasets
, triangular mesh datasets
, and hyperboloid datasets
.
For experiment on Protein datasets, create .tsv file in /data/top500
directory with the following command:
cd data/top500
bash batch_download.sh -f list_file.txt -p
python get_torsion_angle.py
For experiment on RNA dataset, create .tsv file in /data/rna
directory with the following command:
cd data/rna
bash batch_download.sh -f list_file.txt -p
python get_torsion_angles.py
The configurations are provided in the config/
directory in YAML
format.
CUDA_VISIBLE_DEVICES=0 python main.py -m \
experiment=<exp> \
seed=0,1,2,3,4 \
n_jobs=5 \
where <exp>
is one of the experiments in config/experiment/*.yaml
For example,
CUDA_VISIBLE_DEVICES=0 python main.py -m \
experiment=rna \
seed=0,1,2,3,4 \
n_jobs=5 \
To run experiments on high-dimensional tori, use experiment=htori
with n=$DIM
where $DIM
denotes the dimesion of the tori.
If you found the provided code with our paper useful in your work, we kindly request that you cite our work.
@inproceedings{jo2024riemannian,
author = {Jaehyeong Jo and
Sung Ju Hwang},
title = {Generative Modeling on Manifolds Through Mixture of Riemannian Diffusion Processes},
booktitle = {International Conference on Machine Learning},
year = {2024},
}
Our code builds upon geomstats. We thank Riemannian Score-Based Generative Modelling and Riemmanian Flow Matching for their works.