A simple PyTorch reimplementation of the Denoising Diffusion Probabilistic Models paper (https://arxiv.org/abs/2006.11239).

aleksandrinvictor/minDDPM

Denoising Diffusion Probabilistic Model

This repo contains a simple PyTorch reimplementation of the Denoising Diffusion Probabilistic Models paper by Ho et al. [3].

Data

Following [1], we use the Fashion MNIST dataset.
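The inference example maps samples back to pixel range with (x + 1) * 0.5, which suggests training images are scaled to [-1, 1], as in [1]. A minimal sketch of that scaling, assuming uint8 Fashion MNIST pixels in [0, 255]; to_model_range and to_pixel_range are hypothetical helpers, not functions from this repo:

```python
import torch


def to_model_range(images: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map uint8 pixels in [0, 255] to floats in [-1, 1]."""
    return images.float() / 255.0 * 2.0 - 1.0


def to_pixel_range(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map model-space values back to [0, 1] for plotting."""
    # Generated samples can overshoot [-1, 1] slightly, so clamp before display
    return ((x + 1.0) * 0.5).clamp(0.0, 1.0)


# Fake batch with the Fashion MNIST shape (N, C, H, W) = (64, 1, 28, 28)
batch = torch.randint(0, 256, (64, 1, 28, 28), dtype=torch.uint8)
x0 = to_model_range(batch)
print(x0.min().item() >= -1.0, x0.max().item() <= 1.0)  # prints: True True
```

Keeping the data centered around zero matches the scale of the standard Gaussian noise that the forward process mixes in.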

Setup

Pip

pip install -r requirements.txt

Docker

  1. docker build -t ddpm .
  2. docker run -it ddpm

Training

python src/train.py

There are a few settings you may want to specify:

  • --batch_size - batch size; set it according to the available GPU memory
  • --num_epochs - number of epochs to train the model
  • --diffusion_timesteps - number of diffusion steps T in the forward (noising) process
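For reference, a minimal sketch of one training step with the simplified objective from Ho et al. [3]: pick a random timestep, noise the clean image in closed form, and regress the added noise. It assumes a linear schedule from 1e-4 to 0.02 (the paper's endpoints) and a plain MSE loss; the schedule and loss actually used in src/train.py may differ, and make_linear_betas, q_sample and training_loss are illustrative names, not this repo's API.

```python
import torch
import torch.nn.functional as F


def make_linear_betas(timesteps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    """Hypothetical linear noise schedule beta_1..beta_T (endpoints from Ho et al.)."""
    return torch.linspace(beta_start, beta_end, timesteps)


def q_sample(x0, t, alphas_cumprod, noise):
    """Closed-form forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps."""
    abar = alphas_cumprod[t].view(-1, 1, 1, 1)  # broadcast over (C, H, W)
    return abar.sqrt() * x0 + (1.0 - abar).sqrt() * noise


def training_loss(model, x0, alphas_cumprod):
    """Simplified DDPM loss: MSE between the true and the predicted noise at a random t."""
    t = torch.randint(0, alphas_cumprod.shape[0], (x0.shape[0],), device=x0.device)
    noise = torch.randn_like(x0)
    x_t = q_sample(x0, t, alphas_cumprod, noise)
    return F.mse_loss(model(x_t, t), noise)


# Usage: T = 300 as in the inference example; the model is a stand-in, not the repo's Unet
betas = make_linear_betas(300)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
x0 = torch.rand(8, 1, 28, 28) * 2 - 1  # fake batch already scaled to [-1, 1]
loss = training_loss(lambda x_t, t: torch.zeros_like(x_t), x0, alphas_cumprod)
```

Because q_sample jumps straight to any timestep, each batch trains on a random mix of noise levels without simulating the chain step by step.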

Inference

Download a trained checkpoint from the repository's Releases page.

Run the following code:

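from src.diffusion import GaussianDiffusion, linear_beta_schedule
from src.unet import Unet
import torch

import matplotlib.pyplot as plt

# Fall back to CPU when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# The architecture must match the one used to produce the checkpoint
unet = Unet(channels=1, dim_mults=(1, 2, 4), dim=28)
# map_location lets a checkpoint saved on a GPU load on any device
checkpoint = torch.load("<checkpoint-path>", map_location=device)
unet.load_state_dict(checkpoint["model_state_dict"])
unet.to(device)
unet.eval()

# Should match the --diffusion_timesteps value used for training
timesteps = 300
diffusion = GaussianDiffusion(noise_schedule=linear_beta_schedule, timesteps=timesteps)

result = diffusion.sample(model=unet, image_size=28, batch_size=64, channels=1)

image_index = 8
# The last element of result holds the final batch; map it from [-1, 1] back to [0, 1]
image = (result[-1][image_index] + 1) * 0.5
plt.imshow(image.reshape(28, 28), cmap="gray")
plt.show()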
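diffusion.sample runs the reverse (denoising) process. A minimal sketch of standard DDPM ancestral sampling [3], written independently of src/diffusion.py: it assumes the model predicts the added noise and uses the posterior variance as the per-step noise level, and unlike the repo's sample method it returns only the final batch rather than every intermediate step. ddpm_sample is a hypothetical name.

```python
import torch


@torch.no_grad()
def ddpm_sample(model, shape, betas):
    """Hypothetical DDPM ancestral sampler for a model that predicts the added noise."""
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = torch.cat([betas.new_ones(1), alphas_cumprod[:-1]])
    # Variance of q(x_{t-1} | x_t, x_0): beta_t * (1 - abar_{t-1}) / (1 - abar_t)
    posterior_var = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

    # Start from pure Gaussian noise x_T
    x = torch.randn(shape, dtype=betas.dtype, device=betas.device)
    for t in reversed(range(betas.shape[0])):
        t_batch = torch.full((shape[0],), t, dtype=torch.long, device=x.device)
        eps = model(x, t_batch)
        # Mean of p(x_{t-1} | x_t): (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
        mean = (x - betas[t] / (1.0 - alphas_cumprod[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            x = mean + posterior_var[t].sqrt() * torch.randn_like(x)
        else:
            x = mean  # no noise is added at the final step
    return x


# Usage with a stand-in model that always predicts zero noise
betas = torch.linspace(1e-4, 0.02, 300)
samples = ddpm_sample(lambda x, t: torch.zeros_like(x), (4, 1, 28, 28), betas)
```

This loop calls the model once per timestep, so sampling cost grows linearly with T.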

You can also use DDIM sampling [4]. To do that, change the diffusion.sample call in the inference example above as follows:

from src.diffusion import SamplingMethod

result = diffusion.sample(
    model=unet,
    image_size=28,
    batch_size=64,
    channels=1,
    sampling_method=SamplingMethod.DDIM
)
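A minimal sketch of deterministic DDIM sampling (eta = 0) [4] over an evenly strided subset of timesteps, assuming the same noise-predicting model. The stride, eta and step count used by SamplingMethod.DDIM are not documented here, so ddim_sample and its num_steps parameter are hypothetical.

```python
import torch


@torch.no_grad()
def ddim_sample(model, shape, alphas_cumprod, num_steps=50):
    """Hypothetical deterministic DDIM sampler (eta = 0) for a noise-predicting model."""
    T = alphas_cumprod.shape[0]
    # Evenly strided timesteps, largest first, e.g. T=300, num_steps=50 -> 294, 288, ..., 0
    stride = max(T // num_steps, 1)
    timesteps = list(range(0, T, stride))[::-1]

    x = torch.randn(shape, dtype=alphas_cumprod.dtype, device=alphas_cumprod.device)
    for i, t in enumerate(timesteps):
        abar_t = alphas_cumprod[t]
        # abar of the next (smaller) timestep; 1.0 after the last one, i.e. a clean image
        if i + 1 < len(timesteps):
            abar_prev = alphas_cumprod[timesteps[i + 1]]
        else:
            abar_prev = alphas_cumprod.new_tensor(1.0)
        t_batch = torch.full((shape[0],), t, dtype=torch.long, device=x.device)
        eps = model(x, t_batch)
        # Predict x_0, then move to the earlier timestep without injecting fresh noise
        x0_pred = (x - (1 - abar_t).sqrt() * eps) / abar_t.sqrt()
        x = abar_prev.sqrt() * x0_pred + (1 - abar_prev).sqrt() * eps
    return x


# Usage with a stand-in model that always predicts zero noise
alphas_cumprod = torch.cumprod(1.0 - torch.linspace(1e-4, 0.02, 300), dim=0)
samples = ddim_sample(lambda x, t: torch.zeros_like(x), (4, 1, 28, 28), alphas_cumprod)
```

With eta = 0 the trajectory is fully determined by the initial noise, and skipping timesteps lets it use far fewer model calls than the DDPM loop.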

Results

Fashion MNIST dataset samples

Generated samples

References

[1] The Annotated Diffusion Model.

[2] Denoising Diffusion Probabilistic Model, in Pytorch.

[3] Denoising Diffusion Probabilistic Models (DDPM).

[4] Denoising Diffusion Implicit Models (DDIM).
