Unofficial PyTorch Implementation of Novel View Synthesis with Diffusion Models.
As the JAX code provided by the authors is not runnable, we fixed it into runnable JAX code while following the authors' intent as described in the paper. Please compare the authors' code with our fixed version to see the changes we made.
The PyTorch implementation is in xunet.py. Feel free to open an issue if the implementation is not consistent with the original JAX code.
Visit the SRN repository, download chairs_train.zip and cars_train.zip, and extract the downloaded files into /data/. Here we use 90% of the training data for training and 10% as the validation set.
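This section does not say whether the 90/10 split is made over objects or over individual views. The sketch below assumes an object-level split with a fixed seed, so that no object appears in both sets; `split_objects` and its parameters are hypothetical and not taken from train.py.

```python
import random


def split_objects(object_ids, val_fraction=0.1, seed=0):
    """Hypothetical helper: deterministic train/validation split over object IDs."""
    ids = sorted(object_ids)          # sort first so the result does not depend on directory order
    random.Random(seed).shuffle(ids)  # local RNG, leaves the global random state untouched
    n_val = int(round(len(ids) * val_fraction))
    return ids[n_val:], ids[:n_val]   # (train_ids, val_ids)


train_ids, val_ids = split_objects([f"obj{i:04d}" for i in range(100)])
print(len(train_ids), len(val_ids))  # 90 10
```

Splitting by object rather than by view keeps validation views of unseen objects, which is usually what novel view synthesis evaluation cares about.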
We include a pickle file that lists the available view PNG files for each object.
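A sketch of how such an index could be built and stored, assuming the SRN layout `<object_id>/rgb/<view>.png` and a dict that maps each object ID to its sorted PNG filenames. The structure of the bundled pickle file may differ, and `build_view_index`, `save_index`, and `load_index` are hypothetical names.

```python
import os
import pickle


def build_view_index(root):
    """Hypothetical helper: map each object folder under `root` to its sorted view PNG names.

    Assumes the SRN layout <root>/<object_id>/rgb/<view>.png.
    """
    index = {}
    for obj_id in sorted(os.listdir(root)):
        rgb_dir = os.path.join(root, obj_id, "rgb")
        if not os.path.isdir(rgb_dir):
            continue  # skip stray files and objects without an rgb/ folder
        views = sorted(f for f in os.listdir(rgb_dir) if f.endswith(".png"))
        if views:
            index[obj_id] = views
    return index


def save_index(index, path):
    with open(path, "wb") as f:
        pickle.dump(index, f)


def load_index(path):
    with open(path, "rb") as f:
        return pickle.load(f)
```

Precomputing the index avoids listing thousands of directories every time the data loader starts.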
To train from scratch,
python train.py
To continue training,
python train.py --transfer ./results/shapenet_SRN_car/1235678
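What `--transfer` restores is not described here. As a generic illustration only, resuming a PyTorch run usually means saving and reloading the model weights, the optimizer state, and the step counter; the checkpoint keys and the helper names below are hypothetical and may not match train.py.

```python
import torch
import torch.nn as nn


def save_checkpoint(path, model: nn.Module, optimizer: torch.optim.Optimizer, step: int):
    # Hypothetical checkpoint layout: everything needed to resume, in one dict.
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step}, path)


def load_checkpoint(path, model: nn.Module, optimizer: torch.optim.Optimizer) -> int:
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])  # restores e.g. Adam moment estimates
    return ckpt["step"]  # continue the step counter from here
```

Restoring the optimizer state matters for Adam-style optimizers; reloading only the weights restarts the moment estimates from zero.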
To generate novel views of a target object,
python sample.py --model trained_model.pt --target ./data/SRN/cars_train/a4d535e1b1d3c153ff23af07d9064736
We use 256 diffusion steps during sampling, which takes around 10 minutes per view.
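The 256 steps discretize a noise schedule running from pure noise to a clean image. As an illustration only, the sketch below assumes a cosine log-SNR schedule with `logsnr_min=-20` and `logsnr_max=20`; these values and the function names are assumptions for this sketch, not read from sample.py.

```python
import math


def logsnr_schedule_cosine(t, logsnr_min=-20.0, logsnr_max=20.0):
    """Assumed cosine log-SNR schedule: t=0 -> logsnr_max (clean), t=1 -> logsnr_min (noise)."""
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return -2.0 * math.log(math.tan(a * t + b))


def sampling_schedule(num_steps=256):
    """Return (logsnr, alpha, sigma) per step, from pure noise (t=1) down to clean (t=0)."""
    schedule = []
    for i in range(num_steps + 1):
        t = 1.0 - i / num_steps
        logsnr = logsnr_schedule_cosine(t)
        alpha = math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))  # alpha^2 = sigmoid(logsnr)
        sigma = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))   # sigma^2 = sigmoid(-logsnr)
        schedule.append((logsnr, alpha, sigma))
    return schedule
```

At roughly 10 minutes per view, each of the 256 denoising steps takes about 600 / 256 ≈ 2.3 seconds on average; fewer steps would sample faster at some cost in quality.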
We trained on the SRN Cars dataset for 101K steps (about 120 hours) on 8 x RTX 3090 GPUs with a batch size of 128 and an image size of 64 x 64. Due to memory constraints, we were not able to test the original authors' configuration with an image size of 128 x 128.
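Some back-of-the-envelope numbers derived from the run above, assuming the global batch of 128 is split evenly across the 8 GPUs (data parallelism is an assumption here) and treating 101K as exactly 101,000 steps:

```python
def training_stats(total_steps=101_000, hours=120, num_gpus=8, global_batch=128):
    """Derived figures for the SRN Cars run; assumes an even data-parallel batch split."""
    return {
        "per_gpu_batch": global_batch // num_gpus,        # 128 / 8 = 16 images per GPU
        "steps_per_hour": total_steps / hours,            # about 842 steps per hour
        "seconds_per_step": hours * 3600 / total_steps,   # about 4.3 s per step
        "images_seen": total_steps * global_batch,        # 12,928,000 training examples
        "pixel_ratio_128_vs_64": (128 * 128) / (64 * 64), # 128 x 128 has 4x the pixels
    }
```

The 4x pixel count per image is why the 128 x 128 configuration needs substantially more activation memory at the same batch size.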
- Add trained model.
- Add evaluation code.
- Get performance similar to that reported in the paper.
- Implement EMA decay (not yet implemented).
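An exponential moving average (EMA) of the weights is listed as not implemented. A minimal sketch of how it is commonly added to a PyTorch training loop, assuming a decay of 0.9999 (a typical value, not taken from the paper or this repository); the `EMA` class is hypothetical and not part of this repo.

```python
import copy

import torch
import torch.nn as nn


class EMA:
    """Hypothetical helper: keeps an exponential moving average of a model's weights."""

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()  # separate copy used for sampling
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: nn.Module):
        # shadow = decay * shadow + (1 - decay) * current
        for ema_p, p in zip(self.shadow.parameters(), model.parameters()):
            ema_p.mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)
        # Buffers (e.g. normalization statistics) are copied rather than averaged.
        for ema_b, b in zip(self.shadow.buffers(), model.buffers()):
            ema_b.copy_(b)
```

Call `ema.update(model)` after each `optimizer.step()`, and sample with `ema.shadow` instead of the raw model; diffusion models are often noticeably better when sampled from EMA weights.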