Implementation of the paper *Multimodal Adversarially Learned Inference with Factorized Discriminators* (arXiv).
Create a new environment and install the following packages:
```bash
conda install python=3.8 pytorch=1.7 cudatoolkit=10.2 torchvision=0.8 torchaudio=0.7 tensorboard \
    scipy matplotlib nltk=3.6 gensim=3.8 scikit-image=0.18 -c pytorch
```
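After installing, a quick sanity check can confirm the environment matches the pinned versions (a minimal sketch; the expected version strings follow from the packages above):

```python
import torch
import torchvision

# Verify the pinned versions and that CUDA is visible.
print(torch.__version__)          # expected: 1.7.x
print(torchvision.__version__)    # expected: 0.8.x
print(torch.cuda.is_available())  # True if the cudatoolkit 10.2 build can see a GPU
```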
Train on MultiMNIST (4 modalities):

```bash
python train_mmali_multimnist.py --max_iter 250000 \
    --style_dim 10 \
    --latent_dim 20 \
    --lambda_unimodal 0.1 \
    --n_modalities 4 \
    --name jsd2_mod4
```
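For intuition, the flags above imply a factorized latent space: each of the 4 modalities gets its own 10-dim style code, while a single 20-dim content code is shared across modalities. A minimal sketch of that layout (the concatenation order is an assumption for illustration, not necessarily the repo's convention):

```python
import torch

# Hypothetical joint latent: 4 modality-specific style codes + 1 shared content code.
n_modalities, style_dim, latent_dim, batch = 4, 10, 20, 8
z = torch.randn(batch, n_modalities * style_dim + latent_dim)

styles = z[:, : n_modalities * style_dim].view(batch, n_modalities, style_dim)
content = z[:, n_modalities * style_dim :]  # shared across all modalities
print(styles.shape, content.shape)          # torch.Size([8, 4, 10]) torch.Size([8, 20])
```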
Train on MNIST-SVHN:

```bash
python train_mmali_mnist_svhn.py --max_iter 250000 \
    --style_dim 10 \
    --latent_dim 20 \
    --use_all \
    --lambda_unimodal 0.1 \
    --lambda_x_rec 0.05 \
    --lambda_s_rec 0.05 \
    --lambda_c_rec 0.05 \
    --joint_rec \
    --name exp_mnist_svhn
```
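The `--lambda_*` flags weight auxiliary loss terms on top of the adversarial objective. A minimal sketch of how such a weighted sum could look (placeholder tensors only; not the repo's actual training code):

```python
import torch

# Placeholder values standing in for the real loss terms.
adv_loss = torch.tensor(1.0)       # factorized-discriminator adversarial loss
unimodal_loss = torch.tensor(1.0)  # per-modality term, weighted by --lambda_unimodal
x_rec_loss = torch.tensor(1.0)     # data reconstruction, --lambda_x_rec
s_rec_loss = torch.tensor(1.0)     # style-latent reconstruction, --lambda_s_rec
c_rec_loss = torch.tensor(1.0)     # content-latent reconstruction, --lambda_c_rec

total = (adv_loss
         + 0.1 * unimodal_loss
         + 0.05 * x_rec_loss
         + 0.05 * s_rec_loss
         + 0.05 * c_rec_loss)
```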
The datasets will be downloaded to /tmp/data and the results will be saved to /tmp/exps.
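Since TensorBoard is installed above, training curves can be inspected with `tensorboard --logdir /tmp/exps` (assuming the experiments write their event files there).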
For the CUB caption-image experiment:

- Download the CUB dataset.
- Download the preprocessed char-CNN-RNN text embeddings for birds (from Joint-GAN).
- Run `python make_cub_img_ft.py` to extract image embeddings using ResNet-101 (a sketch of this step follows the list).
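A minimal sketch of that extraction step, assuming torchvision's pretrained ResNet-101 with the classification head replaced by an identity (the actual `make_cub_img_ft.py` may differ in preprocessing and batching; the image path is hypothetical):

```python
import torch
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image

# Pretrained ResNet-101 as a fixed feature extractor.
resnet = models.resnet101(pretrained=True)
resnet.fc = torch.nn.Identity()  # keep the 2048-dim pooled features
resnet.eval()

preprocess = T.Compose([
    T.Resize(256), T.CenterCrop(224), T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = Image.open('path/to/bird.jpg').convert('RGB')  # hypothetical path
with torch.no_grad():
    feat = resnet(preprocess(img).unsqueeze(0))      # shape: (1, 2048)
```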
Then train:

```bash
python train_mmali_cap_img.py --max_iter 250000 \
    --style_dim 32 \
    --latent_dim 64 \
    --lambda_unimodal 1.0 \
    --lambda_x_rec 1.0 \
    --name exp_cap_img
```
Use the pretrained model to decode sentence features.
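For illustration, one simple way to map generated sentence features back to text is nearest-neighbor retrieval against the bank of char-CNN-RNN caption embeddings; this is only a hedged sketch with fabricated placeholder data, and the repo's pretrained decoder may work quite differently:

```python
import torch
import torch.nn.functional as F

# Hypothetical embedding bank and matching caption strings.
bank = torch.randn(1000, 1024)                    # placeholder caption embeddings
captions = [f'caption {i}' for i in range(1000)]  # placeholder caption text
query = torch.randn(1, 1024)                      # a generated sentence feature

# Retrieve the caption whose embedding is most similar to the query.
sims = F.cosine_similarity(bank, query.expand_as(bank))
print(captions[sims.argmax().item()])
```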