Official implementation of ParamReL: Learning Parameter Space Representation via Progressively Encoding Bayesian Flow Networks

amasawa/ParamReL

ParamReL: Learning Parameter Space Representation via Progressively Encoding Bayesian Flow Networks

This is the official code release for ParamReL by Zhangkai Wu, Xuhui Fan, Jin Li, Zhilin Zhao, Hui Chen, and Longbing Cao.

Setup

pip install accelerate matplotlib omegaconf rich neptune

Train

For Discrete Data

# train BFN-based model (2 GPUs)
accelerate launch --num_processes 2 --multi_gpu trainBFN.py config_file=BFNconfigs/mnist_discrete.yaml
# or, with the default accelerate configuration
accelerate launch trainBFN.py config_file=BFNconfigs/mnist_discrete.yaml
# train infoBFN-based model (2 GPUs)
accelerate launch --num_processes 2 --multi_gpu trainInfoBFN.py config_file=infoBFNconfigs/mnist_infoBFN.yaml
# or, with the default accelerate configuration
accelerate launch trainInfoBFN.py config_file=infoBFNconfigs/mnist_infoBFN.yaml
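As background for the discrete configurations, the following is a minimal NumPy sketch of the discrete-data Bayesian update described in the original Bayesian Flow Networks paper (Graves et al., 2023). It is not code from this repository. It assumes the input distribution holds per-variable class probabilities `theta` of shape (D, K), draws a sender sample y ~ N(alpha (K e_x - 1), alpha K I), and sets theta' proportional to exp(y) * theta. The function name and the fixed per-step `alpha` are illustrative assumptions.

```python
import numpy as np


def discrete_bayesian_update(theta, x, alpha, rng):
    """One BFN Bayesian update for categorical data (sketch after Graves et al., 2023).

    theta: (D, K) array of per-variable class probabilities (input distribution).
    x:     (D,) int array of true class indices.
    alpha: accuracy of this sender sample (alpha > 0).
    """
    K = theta.shape[-1]
    e_x = np.eye(K)[x]                          # one-hot encoding of x, shape (D, K)
    mean = alpha * (K * e_x - 1.0)              # sender mean alpha * (K e_x - 1)
    y = mean + np.sqrt(alpha * K) * rng.normal(size=theta.shape)  # sender sample
    # theta' is proportional to exp(y) * theta; work in log space for stability
    logits = y + np.log(theta)
    logits -= logits.max(axis=-1, keepdims=True)
    new_theta = np.exp(logits)
    return new_theta / new_theta.sum(axis=-1, keepdims=True)
```

Starting from a uniform `theta` and applying the update repeatedly concentrates probability on the true classes. In the paper, training does not use a fixed per-step `alpha`; the accumulated accuracy follows the schedule beta(t) = beta(1) t^2 for discrete data.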

For Continuous Data

# train BFN-based model on CIFAR-10
accelerate launch trainBFN.py config_file=BFNconfigs/cifar10_continuous_256bins.yaml

# train infoBFN-based model on CIFAR-10
accelerate launch trainInfoBFN.py config_file=infoBFNconfigs/cifar10_continuous_256bins.yaml

# train infoBFN-based model on CelebA (default accelerate configuration, or 2 GPUs)
accelerate launch trainInfoBFN.py config_file=infoBFNconfigs/celeba_continuous_256bins.yaml
accelerate launch --num_processes 2 --multi_gpu trainInfoBFN.py config_file=infoBFNconfigs/celeba_continuous_256bins.yaml
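For continuous data (and, in the BFN paper, discretised data such as the `256bins` image configurations appear to be), the input distribution is a Gaussian N(mu, 1/rho) per dimension. The sketch below shows the Bayesian update from Graves et al. (2023) under that assumption: a sender sample y ~ N(x, 1/alpha) is drawn, precisions add, and the new mean is the precision-weighted average. It is illustrative NumPy, not this repository's implementation, and the function names are hypothetical.

```python
import numpy as np


def continuous_bayesian_update(mu, rho, x, alpha, rng):
    """One BFN Bayesian update for continuous data (sketch after Graves et al., 2023).

    Input distribution: N(mu, 1/rho) per dimension.
    Sender sample:      y ~ N(x, 1/alpha), a noisy observation of the data x.
    """
    y = x + rng.normal(size=x.shape) / np.sqrt(alpha)
    new_rho = rho + alpha                       # precisions add
    new_mu = (mu * rho + y * alpha) / new_rho   # precision-weighted mean
    return new_mu, new_rho


def gamma(t, sigma_1):
    """gamma(t) = beta(t) / (1 + beta(t)) with beta(t) = sigma_1**(-2t) - 1,
    which simplifies to 1 - sigma_1**(2t)."""
    return 1.0 - sigma_1 ** (2.0 * t)
```

With the paper's prior (mu = 0, rho = 1) and schedule beta(t) = sigma_1^(-2t) - 1, the Bayesian flow mean at time t is Gaussian around gamma(t) x. `sigma_1` is a hyperparameter; the value used by these configs is not stated in this README.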

Representation Learning Test

Sample and CleanFID

# extract raw data
accelerate launch extract.py config_file=infoBFNconfigs/extract.yaml
# sample from a trained model
accelerate launch gen.py config_file=./infoBFNconfigs/celeba_continuous_Nobin.yaml
# compute FID with CleanFID
python fid.py
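The FID reported in this step is, at its core, the Fréchet distance between two Gaussians fitted to image features: ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)). The sketch below computes that statistic from precomputed feature arrays using NumPy only. It assumes the Inception features have already been extracted (that step, and any CleanFID-specific image preprocessing, is omitted), and the helper names are hypothetical rather than taken from `fid.py`.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Symmetric square root of sigma1 via its eigendecomposition
    w1, v1 = np.linalg.eigh(sigma1)
    s1_half = (v1 * np.sqrt(np.clip(w1, 0.0, None))) @ v1.T
    # Tr((sigma1 sigma2)^(1/2)) equals the sum of square roots of the eigenvalues
    # of the symmetric PSD matrix s1_half @ sigma2 @ s1_half
    m = s1_half @ sigma2 @ s1_half
    w = np.linalg.eigvalsh((m + m.T) / 2.0)
    tr_covmean = np.sqrt(np.clip(w, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def fid_from_features(feats1, feats2):
    """feats*: (N, F) feature arrays for real and generated images."""
    mu1, mu2 = feats1.mean(axis=0), feats2.mean(axis=0)
    s1 = np.cov(feats1, rowvar=False)
    s2 = np.cov(feats2, rowvar=False)
    return frechet_distance(mu1, s1, mu2, s2)
```

Using an eigendecomposition avoids a general (possibly complex-valued) matrix square root while giving the same trace term.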


# draw samples from a trained CelebA checkpoint and save them to ./celebaCon100.pt
python sample.py seed=1 config_file=./infoBFNconfigs/celeba_continuous_256bins.yaml load_model=./nciCkpts/celeba1.pt samples_shape="[4, 64, 64, 3]" n_steps=100 a_dim=32 save_file=./celebaCon100.pt
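Every command here passes settings as `key=value` overrides on top of a YAML config file. The repository depends on OmegaConf, which provides `OmegaConf.from_cli()` for this; the standard-library sketch below only illustrates how such overrides are typically interpreted, for example `samples_shape="[4, 64, 64, 3]"` becoming a list (apparently batch, height, width, channels) and `n_steps=100` becoming an integer. `parse_overrides` and `apply_overrides` are hypothetical helpers, not the repository's loader.

```python
import ast


def parse_overrides(args):
    """Turn ["key=value", ...] into a dict, evaluating Python-style literals.

    Hypothetical helper: numbers and lists are evaluated, anything else
    (paths, bare names) is kept as a string.
    """
    overrides = {}
    for arg in args:
        key, sep, raw = arg.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {arg!r}")
        try:
            value = ast.literal_eval(raw)   # ints, floats, lists, ...
        except (ValueError, SyntaxError):
            value = raw                     # paths and other bare strings
        overrides[key] = value
    return overrides


def apply_overrides(config, overrides):
    """Return a copy of a flat config dict with the overrides applied."""
    merged = dict(config)
    merged.update(overrides)
    return merged
```

The shell removes the quotes around `"[4, 64, 64, 3]"`, so the program receives `samples_shape=[4, 64, 64, 3]` as a single argument.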

# render a saved sample tensor to a PNG image
python -c "import torch; from data import batch_to_images; batch_to_images(torch.load('./celebadis10000.pt')).savefig('celebadis10000.png')"
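The one-liner above turns a saved batch of samples into a figure. As a rough illustration of what such a conversion involves, the NumPy sketch below tiles a channels-last batch of shape (N, H, W, C) into a single grid image. `tile_batch` is a hypothetical helper, not the repository's `batch_to_images` (which returns an object with `savefig`, so it likely builds a matplotlib figure); a torch tensor would first need `.numpy()`.

```python
import numpy as np


def tile_batch(batch, ncols):
    """Arrange a (N, H, W, C) batch into one (nrows*H, ncols*W, C) grid image.

    Cells without an image (when N is not a multiple of ncols) stay zero.
    """
    n, h, w, c = batch.shape
    nrows = -(-n // ncols)                      # ceil(n / ncols)
    grid = np.zeros((nrows * h, ncols * w, c), dtype=batch.dtype)
    for i in range(n):
        r, col = divmod(i, ncols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = batch[i]
    return grid
```

For a batch shaped like `samples_shape="[4, 64, 64, 3]"`, `tile_batch(batch, ncols=2)` yields a 128 x 128 x 3 image that can be passed to an image writer.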

Latent Quality

Disentanglement

Latent Traversal
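A latent traversal decodes a sequence of codes that differ in a single latent coordinate while all other coordinates stay fixed, so the effect of that one dimension can be inspected visually. The sketch below only builds such codes; it assumes a latent vector of size `a_dim` (32 in the sampling command above) and a separate decoder that maps codes to images, which is not shown. `traversal_codes` is a hypothetical helper.

```python
import numpy as np


def traversal_codes(z, dim, values):
    """Copies of latent z with coordinate `dim` swept over `values`.

    z:      (a_dim,) base latent code.
    values: 1-D sequence of values to place in coordinate `dim`.
    Returns an array of shape (len(values), a_dim).
    """
    z = np.asarray(z, dtype=float)
    codes = np.repeat(z[None, :], len(values), axis=0)
    codes[:, dim] = values
    return codes
```

Stacking `traversal_codes(z, d, values)` for several dimensions `d` gives one row per latent dimension, which is the usual layout of a traversal figure.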
