pip install -r requirements.txt
- 下载 CIFAR-10 数据集 到
datasets/
文件夹,保持目录结构为datasets/cifar-10-batches-py
- 训练模型:
python ddpm/train.py
- 采样图像:
python ddpm/sample.py
- 计算IS与FID:
python ddpm/metrics.py
DDPM 是一个基于马尔可夫链的生成模型。如上图所示,它通过一个前向过程(Forward Process)逐步向数据中添加高斯噪声,最终得到纯噪声,然后通过一个反向过程(Reverse Process)从噪声中逐步恢复出数据。
前向加噪的过程很简单,而我们需要根据前向过程的公式,推导得出模型反向去噪过程的公式,作为我们训练模型的目标。最后训练的时候,我们只需要反向去噪过程的公式就可以训练模型。
在前向过程中,我们需要向
其中,
对应于原文公式(2)(文章中使用分布的形式来表示):
令
其中,
对应于原文公式(4):
当
在去噪过程中,我们需要根据当前时刻的
其中,
这也就是原文的公式(6)~(7)(将
其中,
尽管 DDPM 的推导过程比较复杂,但实际上我们推理时想要采样得到图片,所要用到的公式只有逆向去噪过程中得到
我们训练模型的目标就是让模型预测的噪声
我们使用 CIFAR-10 数据集进行训练。首先下载数据到 datasets/
文件夹,并保持目录结构为 datasets/cifar-10-batches-py
。
接着,我们定义一个数据转换器来处理数据,对于训练集,我们进行随机水平翻转,并缩放到
train_data_transform = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ToTensor(), # 将数据缩放到[0, 1]范围
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 将数据缩放到[-1, 1]范围
])
test_data_transform = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
接着我们使用 torchvision.datasets.CIFAR10
来加载数据,并创建 DataLoader
来处理数据,详见 ddpm/dataloader.py
。
完成以上步骤后,我们可以写一个函数来可视化数据。将转换后的数据,从
def show_tensor_image(image):
reverse_transforms = transforms.Compose([
transforms.Lambda(lambda t: (t + 1) / 2), # 将数据从[-1, 1]缩放到[0, 1]范围
transforms.Lambda(lambda t: t.permute(1, 2, 0)), # 将通道顺序从CHW改为HWC
transforms.Lambda(lambda t: t * 255.), # 将数据缩放到[0, 255]范围
transforms.Lambda(lambda t: t.numpy().astype(np.uint8)), # 将数据转换为uint8类型
transforms.ToPILImage(), # 将数据转换为PIL图像格式
])
# 如果图像是批次数据,则取第一个图像
if len(image.shape) == 4:
image = image[0, :, :, :]
return reverse_transforms(image)
我们可以通过 python ddpm/dataloader.py
来检查数据是否正确加载。
DDPM 模型需要的输入包括噪声图像
首先,我们定义一个时间嵌入层,它负责将时间信息注入到特征中,将时间步
其中,
class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device
# 将维度分成两半,分别用于sin和cos
half_dim = self.dim // 2
# 计算不同频率的指数衰减
embeddings = math.log(10000) / (half_dim - 1)
# 生成频率序列
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
# 将时间步与频率序列相乘
embeddings = time[:, None] * embeddings[None, :]
# 拼接sin和cos得到最终的嵌入向量
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
接着,我们定义一个 U-Net 的基本模块 Block,包含时间嵌入、上/下采样功能。第一次卷积扩展通道数,然后加入时间嵌入,接着进行第二次卷积,融合特征信息,最后进行上/下采样。这里我们采用简化版的 U-Net,没有使用原论文中带有注意力机制的模型。
class Block(nn.Module):
def __init__(self, in_channels, out_channels, time_emb_dim, up=False):
super().__init__()
self.time_mlp = nn.Linear(time_emb_dim, out_channels)
if up:
self.conv1 = nn.Conv2d(2 * in_channels, out_channels, kernel_size=3, padding=1) # 由于 U-Net 的残差连接,上采样时会 concat 之前的特征,输入通道数需要翻倍
self.transform = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
else:
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.transform = nn.Conv2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.bnorm1 = nn.BatchNorm2d(out_channels)
self.bnorm2 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def forward(self, x, t):
# 第一次卷积
h = self.bnorm1(self.relu(self.conv1(x)))
# 时间嵌入
time_emb = self.relu(self.time_mlp(t))
# 将时间信息注入特征图
h = h + time_emb[..., None, None]
# 第二次卷积
h = self.bnorm2(self.relu(self.conv2(h)))
# 上采样或下采样
return self.transform(h)
最后,我们将多个 Block 组合起来,形成一个 U-Net 模型。模型首先会按照通道数 3->64->128->256->512->1024 的变化顺序进行下采样,此时通道数逐渐增加,特征图尺寸逐渐减小。然后按照通道数 1024->512->256->128->64->3 的变化顺序(这里是输出通道数,输入通道数因为残差连接会翻倍)进行上采样,此时通道数逐渐减少,特征图尺寸逐渐增加。每一层都会加入时间步信息,最后输出与输入图像尺寸相同的预测噪声。
class SimpleUnet(nn.Module):
def __init__(self):
super().__init__()
image_channels = 3
down_channels = (64, 128, 256, 512, 1024)
up_channels = (1024, 512, 256, 128, 64)
out_dim = 3
time_emb_dim = 32
# 时间嵌入层
self.time_embed = nn.Sequential(
SinusoidalPositionEmbeddings(time_emb_dim),
nn.Linear(time_emb_dim, time_emb_dim),
)
# 输入层、下采样层、上采样层和输出层
self.input = nn.Conv2d(image_channels, down_channels[0], kernel_size=3, padding=1)
self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i + 1], time_emb_dim) for i in range(len(down_channels) - 1)])
self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True) for i in range(len(up_channels) - 1)])
self.output = nn.Conv2d(up_channels[-1], out_dim, kernel_size=3, padding=1)
def forward(self, x, time_step):
# 时间步嵌入
t = self.time_embed(time_step)
# 初始卷积
x = self.input(x)
# UNet前向传播:先下采样收集特征,再上采样恢复分辨率
residual_stack = []
for down in self.downs:
x = down(x, t)
residual_stack.append(x)
for up in self.ups:
residual_x = residual_stack.pop()
x = torch.cat((x, residual_x), dim=1)
x = up(x, t)
return self.output(x)
通过执行 python ddpm/unet.py
我们可以打印 U-Net 模型每一层输入和输出的形状,检查模型结构是否正确:
Input shape: torch.Size([1, 3, 32, 32])
Time embedding shape: torch.Size([1, 32])
After input conv shape: torch.Size([1, 64, 32, 32])
Downsampling process:
Down block 1 output shape: torch.Size([1, 128, 16, 16])
Down block 2 output shape: torch.Size([1, 256, 8, 8])
Down block 3 output shape: torch.Size([1, 512, 4, 4])
Down block 4 output shape: torch.Size([1, 1024, 2, 2])
Upsampling process:
Concatenated input shape before up block 1: torch.Size([1, 2048, 2, 2])
Up block 1 output shape: torch.Size([1, 512, 4, 4])
Concatenated input shape before up block 2: torch.Size([1, 1024, 4, 4])
Up block 2 output shape: torch.Size([1, 256, 8, 8])
Concatenated input shape before up block 3: torch.Size([1, 512, 8, 8])
Up block 3 output shape: torch.Size([1, 128, 16, 16])
Concatenated input shape before up block 4: torch.Size([1, 256, 16, 16])
Up block 4 output shape: torch.Size([1, 64, 32, 32])
Final output shape: torch.Size([1, 3, 32, 32])
首先我们需要定义一个噪声调度器,用于控制加噪过程,生成不同时间步的噪声图像。根据上面给出的公式,我们可以用代码对其进行实现。在前向过程中,我们需要定义变量 register_buffer
来定义变量,这样这些变量就会自动与模型参数一起保存和加载。
class NoiseScheduler(nn.Module):
def __init__(self, beta_start=0.0001, beta_end=0.02, num_steps=1000):
super().__init__()
self.beta_start = beta_start
self.beta_end = beta_end
self.num_steps = num_steps
# β_t: 线性噪声调度
self.register_buffer('betas', torch.linspace(beta_start, beta_end, num_steps))
# α_t = 1 - β_t
self.register_buffer('alphas', 1.0 - self.betas)
# α_bar_t = ∏(1-β_i) from i=1 to t
self.register_buffer('alpha_bar', torch.cumprod(self.alphas, dim=0))
# sqrt(α_bar_t)
self.register_buffer('sqrt_alpha_bar', torch.sqrt(self.alpha_bar))
# sqrt(1-α_bar_t)
self.register_buffer('sqrt_one_minus_alpha_bar', torch.sqrt(1.0 - self.alpha_bar))
由于我们是对一个 batch 的图像进行训练,而且还需要将这些变量与图像张量进行运算,而目前我们定义的变量都是一维张量,所以需要对公式中的变量的维度进行调整,以适应不同张量的维度。因此,我们在 NoiseScheduler
类中定义了 get
方法,用于获取指定时间步的变量值并调整形状,其中 var
为变量张量,t
为时间步,x_shape
为目标形状。
def get(self, var, t, x_shape):
# 从变量张量中收集指定时间步的值
out = var[t]
# 调整形状为[batch_size, 1, 1, 1],以便进行广播
return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
然后我们便可以实现加噪过程,根据公式: get
方法获取。
def add_noise(self, x, t):
# 获取时间步t对应的sqrt(α_bar_t)
sqrt_alpha_bar = self.get(self.sqrt_alpha_bar, t, x.shape)
# 获取时间步t对应的sqrt(1-α_bar_t)
sqrt_one_minus_alpha_bar = self.get(self.sqrt_one_minus_alpha_bar, t, x.shape)
# 从标准正态分布采样噪声 ε ~ N(0,I)
noise = torch.randn_like(x)
# 实现前向扩散过程: x_t = sqrt(α_bar_t) * x_0 + sqrt(1-α_bar_t) * ε
return sqrt_alpha_bar * x + sqrt_one_minus_alpha_bar * noise, noise
通过执行 python ddpm/diffusion.py
我们可以绘制图像逐步加噪的过程。
最后,我们可以实现完整的训练流程了。其步骤为:
- 随机采样时间步
t
- 对图像添加噪声,获得带噪声的图像和噪声
- 使用模型预测噪声
- 计算预测噪声和真实噪声之间的MSE损失
- 反向传播和优化
通过执行 python ddpm/train.py
我们可以训练模型,并保存模型参数,在 RTX-4090 上训练 200 个 epoch 需要 2 小时左右。
采样过程的思路为,从标准正态分布中采样初始噪声,然后逐步去噪,从 NoiseScheduler
类中定义这些系数:
# α_bar_(t-1)
self.register_buffer('alpha_bar_prev', torch.cat([torch.tensor([1.0]), self.alpha_bar[:-1]], dim=0))
# 1/sqrt(α_t)
self.register_buffer('sqrt_recip_alphas', torch.sqrt(1.0 / self.alphas))
# 1/sqrt(α_bar_t)
self.register_buffer('sqrt_recip_alphas_bar', torch.sqrt(1.0 / self.alpha_bar))
# sqrt(1/α_bar_t - 1)
self.register_buffer('sqrt_recipm1_alphas_bar', torch.sqrt(1.0 / self.alpha_bar - 1))
# 后验分布方差 σ_t^2
self.register_buffer('posterior_var', self.betas * (1.0 - self.alpha_bar_prev) / (1.0 - self.alpha_bar))
# 后验分布均值系数1: β_t * sqrt(α_bar_(t-1))/(1-α_bar_t)
self.register_buffer('posterior_mean_coef1', self.betas * torch.sqrt(self.alpha_bar_prev) / (1.0 - self.alpha_bar))
# 后验分布均值系数2: (1-α_bar_(t-1)) * sqrt(α_t)/(1-α_bar_t)
self.register_buffer('posterior_mean_coef2', (1.0 - self.alpha_bar_prev) * torch.sqrt(self.alphas) / (1.0 - self.alpha_bar))
采样过程的具体流程为(依照原文公式):
- 从标准正态分布采样初始噪声
$x_T \sim \mathcal{N}(0,I)$ - 从
$t=T$ 到$t=0$ ,不断迭代循环,执行以下步骤: - 根据当前时间步
$t$ 和当前的样本图片$x_t$ ,通过模型计算预测噪声$\epsilon_\theta(x_t,t)$ - 计算
$x_0$ 的预测值:$x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}x_t - \sqrt{\frac{1}{\bar{\alpha}_t}-1}\epsilon _\theta(x_t,t)$ - 计算后验分布均值:
$\mu_\theta(x_t,t) = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x_0 + \frac{\sqrt{\alpha _t}(1-\bar{\alpha} _{t-1})}{1-\bar{\alpha}_t}x_t$ - 计算后验分布方差的对数值:
$\log(\sigma_t^2) = \log(\tilde{\beta}_t) = \log(\frac{\beta_t(1-\bar{\alpha} _{t-1})}{1-\bar{\alpha}_t})$ - 如果当前时间步
$t>0$ ,则从后验分布中采样:$x_{t-1} = \mu_\theta(x_t,t) + \sigma_t\epsilon, \epsilon \sim \mathcal{N}(0,I)$ - 如果当前时间步
$t=0$ ,则直接使用均值作为生成结果:$x_0 = \mu_\theta(x_t,t)$
def sample(model, scheduler, num_samples, size, device="cpu"):
model.eval()
with torch.no_grad():
# 从标准正态分布采样初始噪声 x_T ~ N(0,I)
x_t = torch.randn(num_samples, *size).to(device)
# 逐步去噪,从t=T到t=0
for t in tqdm(reversed(range(scheduler.num_steps)), desc="Sampling"):
# 构造时间步batch
t_batch = torch.tensor([t] * num_samples).to(device)
# 获取采样需要的系数
sqrt_recip_alpha_bar = scheduler.get(scheduler.sqrt_recip_alphas_bar, t_batch, x_t.shape)
sqrt_recipm1_alpha_bar = scheduler.get(scheduler.sqrt_recipm1_alphas_bar, t_batch, x_t.shape)
posterior_mean_coef1 = scheduler.get(scheduler.posterior_mean_coef1, t_batch, x_t.shape)
posterior_mean_coef2 = scheduler.get(scheduler.posterior_mean_coef2, t_batch, x_t.shape)
# 预测噪声 ε_θ(x_t,t)
predicted_noise = model(x_t, t_batch)
# 计算x_0的预测值: x_0 = 1/sqrt(α_bar_t) * x_t - sqrt(1/α_bar_t-1) * ε_θ(x_t,t)
_x_0 = sqrt_recip_alpha_bar * x_t - sqrt_recipm1_alpha_bar * predicted_noise
# 计算后验分布均值 μ_θ(x_t,t)
model_mean = posterior_mean_coef1 * _x_0 + posterior_mean_coef2 * x_t
# 计算后验分布方差的对数值 log(σ_t^2)
model_log_var = scheduler.get(torch.log(torch.cat([scheduler.posterior_var[1:2], scheduler.betas[1:]])), t_batch, x_t.shape)
if t > 0:
# t>0时从后验分布采样: x_t-1 = μ_θ(x_t,t) + σ_t * z, z~N(0,I)
noise = torch.randn_like(x_t).to(device)
x_t = model_mean + torch.exp(0.5 * model_log_var) * noise
else:
# t=0时直接使用均值作为生成结果
x_t = model_mean
# 将最终结果裁剪到[-1,1]范围
x_0 = torch.clamp(x_t, -1.0, 1.0)
return x_0
通过执行 python ddpm/sample.py
我们可以加载训练好的模型生成图像,并保存为图片。
Inception Score (IS) 和 Fréchet Inception Distance (FID) 是评估生成图像质量的两个重要指标,我们使用 IS 分数和 FID 分数来评估生成图像的质量。
首先加载预训练好的 Inception 模型,并获取真实图像和生成图像的特征。
class InceptionStatistics:
def __init__(self, device='cuda'):
self.device = device
# 加载预训练的Inception v3模型
self.model = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1, transform_input=False)
self.model.fc = nn.Identity() # 移除最后的全连接层
self.model = self.model.to(device)
self.model.eval()
# 设置图像预处理
self.preprocess = transforms.Compose([
transforms.Resize(299),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
@torch.no_grad()
def get_features(self, images):
"""获取Inception特征"""
features = []
probs = []
# 将图像处理为299x299大小
images = self.preprocess(images)
# 批量处理图像
dataset = TensorDataset(images)
dataloader = DataLoader(dataset, batch_size=32)
for (batch,) in tqdm(dataloader):
batch = batch.to(self.device)
# 获取特征和logits
feature = self.model(batch)
prob = F.softmax(feature, dim=1)
features.append(feature.cpu().numpy())
probs.append(prob.cpu().numpy())
features = np.concatenate(features, axis=0)
probs = np.concatenate(probs, axis=0)
return features, probs
IS 分数通过预训练的 Inception v3 网络评估生成图像的质量和多样性。计算公式为:
其中:
-
$p_g$ 是生成器的分布 -
$p(y|x)$ 是 Inception 模型对图像 x 的类别预测概率 -
$p(y)$ 是所有生成图像的平均类别分布 - KL 是 KL 散度,用于衡量两个概率分布之间的差距,计算公式为
$KL(p|q) = \sum_{i=1}^{n} p(i) \log \frac{p(i)}{q(i)}$
IS 分数越高说明:
- 每张生成图像的类别预测越清晰(质量好)
- 不同图像的类别分布越分散(多样性高)
具体步骤为:
- 将所有图像分成 batch
- 对每组计算:
- 计算边缘分布
$p(y)$ ,即对当前 batch 的$p(y|x)$ 取平均 - 计算 KL 散度
- 取指数
- 计算边缘分布
- 返回所有组得分的均值和标准差
def calculate_inception_score(probs, splits=10):
# 存储每个split的IS分数
scores = []
# 计算每个split的大小
split_size = probs.shape[0] // splits
# 对每个split进行计算
for i in tqdm(range(splits)):
# 获取当前split的概率分布
part = probs[i * split_size:(i + 1) * split_size]
# 计算KL散度: KL(p(y|x) || p(y))
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, axis=0), 0)))
# 对每个样本的KL散度求平均
kl = np.mean(np.sum(kl, axis=1))
# 计算exp(KL)并添加到scores列表
scores.append(np.exp(kl))
# 返回所有split的IS分数的均值和标准差
return np.mean(scores), np.std(scores)
FID 分数通过比较真实图像和生成图像在 Inception 特征空间的分布来评估生成质量。计算公式为:
其中:
-
$\mu_r, \mu_g$ 分别是真实图像和生成图像特征的均值 -
$\Sigma_r, \Sigma_g$ 分别是真实图像和生成图像特征的协方差矩阵 -
$Tr$ 表示矩阵的迹
FID 分数越低说明生成图像的特征分布越接近真实图像分布,生成质量越好。
具体步骤为:
- 分别对真实图像和生成图像:
- 通过 Inception 模型提取特征
- 计算特征的均值向量和协方差矩阵
- 计算均值向量之间的欧氏距离
- 计算协方差矩阵的平方根项
- 计算最终的 FID 分数
def calculate_fid(real_features, fake_features):
# 计算真实图像和生成图像特征的均值向量和协方差矩阵
mu1, sigma1 = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
mu2, sigma2 = fake_features.mean(axis=0), np.cov(fake_features, rowvar=False)
# 计算均值向量之间的欧几里得距离的平方
ssdiff = np.sum((mu1 - mu2) ** 2)
# 计算协方差矩阵的平方根项:(Σ_r Σ_f)^(1/2)
covmean = linalg.sqrtm(sigma1.dot(sigma2)) # 耗时较长
# 如果结果包含复数,取其实部
if np.iscomplexobj(covmean):
covmean = covmean.real
# 计算最终的FID分数
fid = ssdiff + np.trace(sigma1 + sigma2 - 2 * covmean)
return fid
通过执行 python ddpm/metrics.py
我们可以计算 IS 分数和 FID 分数,我们使用训练好的 diffusion 模型生成 10000 张图片,并用 CIFAR-10 数据集作为真实图像数据集,来计算 IS 分数和 FID 分数。
虽然不能通过单张图片来计算 IS 和 FID 分数,但是我们可以直观地看一下 IS、FID 分数不同的两个模型所生成的图片有什么样的效果。 以下是本项目所使用的简化版 U-Net 模型(epochs=2000, IS=1.12, FID=41.63),以及 DDPM 原文所使用的带有 Attention 结构的 U-Net 模型(epochs=2000, IS=1.17, FID=14.10)在本项目的训练框架下,分别训练 200 和 2000 个 epochs 后生成的图像。
可以看出,随着训练的进行,以及模型性能的提升,具有更高的 IS 分数和更低的 FID 分数的模型产生图像能够更好地分辨出具体类别,细节也更接近真实 CIFAR-10 数据集中的图像。