Skip to content

Commit

Permalink
Fixed DDIMsampling
Browse files Browse the repository at this point in the history
  • Loading branch information
kpandey008 committed Aug 2, 2022
1 parent 2aef3ab commit 3b45564
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 11 deletions.
9 changes: 8 additions & 1 deletion main/eval/ddpm/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import sys

p = os.path.join(os.path.abspath("."), 'main')
p = os.path.join(os.path.abspath("."), "main")
sys.path.insert(1, p)

import copy
Expand Down Expand Up @@ -75,8 +75,15 @@ def sample(config):
config.evaluation.chkpt_path,
online_network=online_ddpm,
target_network=target_ddpm,
vae=None,
conditional=False,
pred_steps=n_steps,
eval_mode="sample",
resample_strategy=config.evaluation.resample_strategy,
sample_method=config.evaluation.sample_method,
sample_from=config.evaluation.sample_from,
data_norm=config.data.norm,
strict=False,
)

# Create predict dataset of latents
Expand Down
1 change: 1 addition & 0 deletions main/eval/ddpm/sample_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def sample_cond(config):
data_norm=config_ddpm.data.norm,
temp=config_ddpm.evaluation.temp,
z_cond=config_ddpm.evaluation.z_cond,
strict=True,
)

# Create predict dataset of latents
Expand Down
1 change: 0 additions & 1 deletion main/models/diffusion/spaced_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ def get_ddim_mean_cov(

def ddim_sample(self, x_t, cond=None, z_vae=None, checkpoints=[], eta=0.0):
# The sampling process goes here!
print(f"Eta: {eta}")
x = x_t
B, *_ = x_t.shape
sample_dict = {}
Expand Down
12 changes: 10 additions & 2 deletions main/models/diffusion/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,22 @@ def __init__(
# Spaced Diffusion (for spaced re-sampling)
self.spaced_diffusion = None

def forward(self, x, cond=None, z=None, n_steps=None, checkpoints=[]):
def forward(
self,
x,
cond=None,
z=None,
n_steps=None,
checkpoints=[],
resample_type="uniform",
):
sample_nw = (
self.target_network if self.sample_from == "target" else self.online_network
)
# For spaced resampling
if self.resample_strategy == "spaced":
num_steps = n_steps if n_steps is not None else self.online_network.T
indices = space_timesteps(sample_nw.T, num_steps)
indices = space_timesteps(sample_nw.T, num_steps, type=resample_type)
if self.spaced_diffusion is None:
self.spaced_diffusion = SpacedDiffusion(sample_nw, indices).to(x.device)

Expand Down
21 changes: 14 additions & 7 deletions main/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def configure_device(device):
return device


def space_timesteps(num_timesteps, desired_count):
def space_timesteps(num_timesteps, desired_count, type="uniform"):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
Expand All @@ -44,12 +44,19 @@ def space_timesteps(num_timesteps, desired_count):
process to divide up.
:return: a set of diffusion steps from the original process to use.
"""
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return range(0, num_timesteps, i)
raise ValueError(
f"cannot create exactly {desired_count} steps with an integer stride"
)
if type == "uniform":
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return range(0, num_timesteps, i)
raise ValueError(
f"cannot create exactly {desired_count} steps with an integer stride"
)
elif type == "quad":
seq = np.linspace(0, np.sqrt(num_timesteps * 0.8), desired_count) ** 2
seq = [int(s) for s in list(seq)]
return seq
else:
raise NotImplementedError


def get_dataset(name, root, image_size, norm=True, flip=False, **kwargs):
Expand Down

0 comments on commit 3b45564

Please sign in to comment.