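DDPM, an algorithm that generates images with a diffusion model, was proposed in 2020. In 2021, OpenAI published the paper "Improved Denoising Diffusion Probabilistic Models", which improves on the DDPM algorithm.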
Improved DDPM
Improvements
Cosine noise schedule
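The original DDPM algorithm computes the noisy image $x_t$ after $t$ steps of forward diffusion as $x_t=\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon$, where $\bar{\alpha}_t=\prod_{i=1}^{t}\alpha_i$ and $\beta_t=1-\alpha_t$, i.e. $\beta_t=1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}$. Here $\beta_t$ is the amount of noise added at each step. The original DDPM algorithm lets $\beta_t$ grow linearly with $t$, from $\beta_1=10^{-4}$ to $\beta_T=0.02$. The improved DDPM algorithm instead makes $\bar{\alpha}_t$ follow a cosine function of $t$: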
$$
\begin{aligned}
f(t)&=\cos\left(\frac{t/T+s}{1+s}\cdot\frac{\pi}{2}\right)^2\\
\bar{\alpha}_t&=\frac{f(t)}{f(0)}\\
\beta_t&=\text{clip}\left(1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}},0.999\right)
\end{aligned}
$$
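Figure 1 shows how $\bar{\alpha}_t$ varies with $t$ under the two noise schedules. Compared with the linear schedule, $\bar{\alpha}_t$ under the cosine schedule falls more gently, so $\beta_t$ stays comparatively small, noise is added more slowly, and the original image is not swamped with noise too early.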
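The difference between the two schedules can be checked numerically. The sketch below is a minimal standard-library illustration, not the repository's code; the helper names (`linear_betas`, `alpha_bar_from_betas`, `cosine_alpha_bar`) are made up for this example, and `T = 1000` with `s = 0.008` are assumed values (the linear endpoints are the ones quoted above).

```python
import math


def linear_betas(num_steps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced beta_t from beta_start to beta_end."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def alpha_bar_from_betas(betas):
    """Cumulative product of (1 - beta_t), i.e. alpha_bar_t for t = 1..T."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def cosine_alpha_bar(num_steps, s=0.008):
    """alpha_bar_t = f(t) / f(0) with f(t) = cos((t/T + s) / (1 + s) * pi/2) ** 2."""
    def f(t):
        return math.cos((t / num_steps + s) / (1 + s) * math.pi / 2) ** 2
    return [f(t) / f(0) for t in range(1, num_steps + 1)]


T = 1000  # assumed number of diffusion steps for this illustration
linear = alpha_bar_from_betas(linear_betas(T))
cosine = cosine_alpha_bar(T)

# Index t - 1 holds alpha_bar_t because the lists start at t = 1.
for t in (100, 250, 500, 750):
    print(f"t={t:4d}  linear={linear[t - 1]:.4f}  cosine={cosine[t - 1]:.4f}")
```

At the printed steps the cosine schedule keeps more of the original signal than the linear one, which is the behaviour described above.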
The Improved DDPM code is open source at github.com/openai/impr… and is built on PyTorch. Training and sampling are run with the following scripts:
# Training script
python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
# Sampling script
python scripts/image_sample.py --model_path /path/to/model.pt $MODEL_FLAGS $DIFFUSION_FLAGS
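Here MODEL_FLAGS, DIFFUSION_FLAGS and TRAIN_FLAGS configure the model architecture (U-Net), the diffusion process and training, respectively. The baseline model (DDPM) uses the following configuration: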
MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 128"
To use the cosine noise schedule, set noise_schedule to cosine in DIFFUSION_FLAGS. The code that computes the noise $\beta_t$ is in improved_diffusion/gaussian_diffusion.py:
import math

import numpy as np


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.
    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


# The alpha_bar argument is f(t) from the formulas above
# (with t already normalized to [0, 1] and s = 0.008).
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
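As a sanity check on the discretization, the loop can be restated without NumPy and its output multiplied back together. The sketch below is an illustration under assumptions: `cosine_f` and `discretize` are names invented here, and `T = 1000` is an arbitrary choice. It checks that the running product of $(1-\beta_t)$ reproduces $f(t)/f(0)$ and shows which steps hit the `max_beta` clip.

```python
import math


def cosine_f(t, s=0.008):
    """f(t) on normalized time t in [0, 1], matching the "cosine" lambda."""
    return math.cos((t + s) / (1 + s) * math.pi / 2) ** 2


def discretize(num_steps, alpha_bar, max_beta=0.999):
    """Pure-Python restatement of the loop in betas_for_alpha_bar."""
    betas = []
    for i in range(num_steps):
        t1 = i / num_steps
        t2 = (i + 1) / num_steps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return betas


T = 1000  # assumed number of diffusion steps for this illustration
betas = discretize(T, cosine_f)

# Rebuild alpha_bar_t from the betas and compare it with f(t) / f(0).
prod = 1.0
max_err = 0.0
for i, beta in enumerate(betas[:-1]):  # the final step is clipped, so skip it
    prod *= 1.0 - beta
    target = cosine_f((i + 1) / T) / cosine_f(0.0)
    max_err = max(max_err, abs(prod - target))

# Steps whose beta was capped at max_beta.
clipped = [i for i, b in enumerate(betas) if b >= 0.999]
print("max |cumprod - f(t)/f(0)| =", max_err)
print("clipped step indices:", clipped)
```

With these settings only the very last step is clipped, because $f(t)$ reaches (numerically almost) zero at $t=T$ and the unclipped $\beta_T$ would be 1.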
Learning the variance
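In the original DDPM algorithm, the variance $\Sigma_\theta(x_t,t)$ of the Gaussian density $p_\theta(x_{t-1}|x_t)$ is fixed to the constant $\sigma_t^2\mathbf{I}$, with $\sigma_t^2$ set directly to $\beta_t$. The improved DDPM algorithm has the model learn $\Sigma_\theta(x_t,t)$ as well, which makes it possible to generate higher-quality images in fewer steps.
How is $\Sigma_\theta(x_t,t)$ learned? The DDPM paper already showed that $\Sigma_\theta(x_t,t)$ lies between $\beta_t$ and $\tilde{\beta}_t$, where $\tilde{\beta}_t:=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t$. Figure 2 plots $\tilde{\beta}_t/\beta_t$ against $t$: except near $t=0$, $\beta_t$ and $\tilde{\beta}_t$ are approximately equal, which is why the original DDPM algorithm simply sets the variance to $\beta_t$. The improved DDPM algorithm introduces an intermediate vector $v$, has the model predict $v$, and expresses $\Sigma_\theta(x_t,t)$ as: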
$$\Sigma_\theta(x_t,t)=\exp\left(v\log\beta_t+(1-v)\log\tilde{\beta}_t\right)$$
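A small numeric sketch may help to see what this interpolation does. The values of $\bar{\alpha}_{t-1}$ and $\bar{\alpha}_t$ below are purely illustrative (not taken from the paper), and the function names are invented for this example; $v$ is treated as a scalar in $[0,1]$.

```python
import math


def posterior_beta(beta_t, alpha_bar_prev, alpha_bar_t):
    """beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t."""
    return (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * beta_t


def interpolated_variance(v, beta_t, beta_tilde_t):
    """Sigma = exp(v * log(beta_t) + (1 - v) * log(beta_tilde_t))."""
    return math.exp(v * math.log(beta_t) + (1.0 - v) * math.log(beta_tilde_t))


# Illustrative numbers only.
alpha_bar_prev, alpha_bar_t = 0.90, 0.88
beta_t = 1.0 - alpha_bar_t / alpha_bar_prev
beta_tilde_t = posterior_beta(beta_t, alpha_bar_prev, alpha_bar_t)

# v = 0 gives beta_tilde_t, v = 1 gives beta_t, values in between
# interpolate geometrically (linearly in log space).
for v in (0.0, 0.5, 1.0):
    sigma = interpolated_variance(v, beta_t, beta_tilde_t)
    print(f"v={v:.1f}  variance={sigma:.6f}")
```

Interpolating in log space keeps the predicted variance positive and inside the range bounded by $\tilde{\beta}_t$ and $\beta_t$ for any $v$ in $[0,1]$.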
The loss function of the original DDPM algorithm is:
$$L_{\text{simple}}:=E_{t\sim[1,T],x_0\sim q(x_0),\epsilon\sim\mathcal{N}(0,\mathbf{I})}\left[\|\epsilon-\epsilon_\theta(x_t,t)\|^2\right]$$
It does not contain $\Sigma_\theta(x_t,t)$, so the improved DDPM algorithm defines a new loss function:
$$L_\text{hybrid}=L_\text{simple}+\lambda L_\text{vlb}$$
where $L_\text{vlb}$ is defined as:
$$
\begin{aligned}
L_{\text{vlb}}&:=L_0+L_1+\cdots+L_{T-1}+L_T\\
\text{where }L_0&:=-\log p_\theta(x_0|x_1)\\
L_{t-1}&:=D_{\text{KL}}(q(x_{t-1}|x_t,x_0)\parallel p_\theta(x_{t-1}|x_t))\text{ for }2\le t\le T\\
L_T&:=D_{\text{KL}}(q(x_T|x_0)\parallel p(x_T))
\end{aligned}
$$
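In the loss function, $\lambda$ is set to 0.001 so that $L_\text{vlb}$ does not overwhelm $L_\text{simple}$. In addition, during the gradient update the $L_\text{vlb}$ term passes no gradient through $\mu_\theta(x_t,t)$; it only updates the parameters through $\Sigma_\theta(x_t,t)$.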
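To make the $L_{t-1}$ terms concrete: both distributions in the KL divergence are Gaussian, so the divergence has a closed form in their means and log-variances. The sketch below is a scalar, standard-library illustration with hypothetical numbers; it is not the repository's implementation, and it cannot show the stop-gradient on $\mu_\theta$, which needs an autodiff framework.

```python
import math


def normal_kl(mean1, logvar1, mean2, logvar2):
    """KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))) for scalars, in nats."""
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + math.exp(logvar1 - logvar2)
        + (mean1 - mean2) ** 2 * math.exp(-logvar2)
    )


def hybrid_loss(l_simple, l_vlb, lam=0.001):
    """L_hybrid = L_simple + lambda * L_vlb."""
    return l_simple + lam * l_vlb


# Hypothetical scalar values standing in for one pixel at one timestep.
true_mean, true_logvar = 0.30, math.log(0.018)    # q(x_{t-1} | x_t, x_0)
model_mean, model_logvar = 0.25, math.log(0.022)  # p_theta(x_{t-1} | x_t)

kl_nats = normal_kl(true_mean, true_logvar, model_mean, model_logvar)
kl_bits = kl_nats / math.log(2.0)  # convert nats to bits

eps, eps_pred = 0.7, 0.6  # hypothetical true and predicted noise
l_simple = (eps - eps_pred) ** 2
print("KL (bits):", kl_bits)
print("L_hybrid :", hybrid_loss(l_simple, kl_bits))
```

With $\lambda=0.001$ the KL term contributes only a small correction to the total, which is the intent of the weighting.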
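To learn the variance, set learn_sigma to True in the MODEL_FLAGS training-script argument. The output dimension of the model (U-Net) then grows, and the extra part serves as $v$. The relevant code is in the create_model function of improved_diffusion/script_util.py: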
return UNetModel(
    in_channels=3,
    model_channels=num_channels,
    # The U-Net output dimension doubles; the extra channels serve as v
    out_channels=(3 if not learn_sigma else 6),
    num_res_blocks=num_res_blocks,
    attention_resolutions=tuple(attention_ds),
    dropout=dropout,
    channel_mult=channel_mult,
    num_classes=(NUM_CLASSES if class_cond else None),
    use_checkpoint=use_checkpoint,
    num_heads=num_heads,
    num_heads_upsample=num_heads_upsample,
    use_scale_shift_norm=use_scale_shift_norm,
)
Likewise, when images are generated by reverse diffusion, the code in which the model predicts the variance is in the p_mean_variance method of improved_diffusion/gaussian_diffusion.py:
B, C = x.shape[:2]  # batch size and channel count of the input
# Model prediction
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
    assert model_output.shape == (B, C * 2, *x.shape[2:])
    # Split the model output along the channel dimension; the second half is v
    model_output, model_var_values = th.split(model_output, C, dim=1)
    if self.model_var_type == ModelVarType.LEARNED:
        model_log_variance = model_var_values
        model_variance = th.exp(model_log_variance)
    else:
        # Compute the variance as exp(v·log(β_t) + (1-v)·log(β̃_t))
        min_log = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x.shape
        )
        max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
        # The model_var_values is [-1, 1] for [min_var, max_var].
        frac = (model_var_values + 1) / 2
        model_log_variance = frac * max_log + (1 - frac) * min_log
        model_variance = th.exp(model_log_variance)
How the code adjusts the loss computation is covered in the next section.
Importance sampling during training
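The paper further finds that after switching the loss to $L_\text{vlb}$ or $L_\text{hybrid}$, the loss curve over training iterations is noisy and hard to converge, as shown in Figure 3. $L_\text{vlb}$ consists of many terms, one per step $t$, and the magnitudes of these terms differ widely, as shown in Figure 4. The original DDPM algorithm samples the step $t$ uniformly at random during training, so the paper attributes the noise in $L_\text{vlb}$ or $L_\text{hybrid}$ to sampling $t$ at random in each iteration and then evaluating $L_t$ terms of very different magnitude. The paper addresses this by using importance sampling during training. Under importance sampling, $L_\text{vlb}$ can be written as: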
$$L_\text{vlb}=E_{t\sim p_t}\left[\frac{L_t}{p_t}\right],\text{ where }p_t\propto\sqrt{E\left[L_t^2\right]}\text{ and }\sum p_t=1$$
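$E\left[L_t^2\right]$ cannot be computed in advance and changes during training, so the paper keeps the 10 most recent values of each term $L_t$ of $L_\text{vlb}$ and updates them dynamically as training proceeds. Early in training, $t$ is still sampled uniformly at random until every $L_t$ has 10 values; importance sampling is used from then on. Figure 3 shows that with importance sampling the $L_\text{vlb}$ curve over training iterations is much smoother and reaches the lowest loss.
To use the importance-sampled $L_\text{vlb}$ as the loss, set use_kl to True in the DIFFUSION_FLAGS training-script argument and schedule_sampler to loss-second-moment in TRAIN_FLAGS. The importance-sampling code is in the LossSecondMomentResampler class of improved_diffusion/resample.py: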
import numpy as np


# LossAwareSampler is the base class defined in the same file (resample.py).
class LossSecondMomentResampler(LossAwareSampler):
    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        # The original code uses dtype=np.int, an alias removed in NumPy 1.24;
        # np.int64 is used here so the snippet runs on current NumPy.
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)

    def weights(self):
        # Early in training, timesteps are still sampled uniformly
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        # Compute p_t from the history of L_t values
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        return (self._loss_counts == self.history_per_term).all()
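The reweighting idea can be illustrated without NumPy or PyTorch. The sketch below is an assumption-laden toy: the function names, the four hypothetical loss values and the $1/(T\,p_t)$ weight convention are chosen for this example and are not presented as the repository's sampler interface. It shows that drawing $t$ from $p_t$ and weighting each loss by $1/(T\,p_t)$ leaves the expected loss equal to the plain average over timesteps.

```python
import math
import random


def second_moment_weights(loss_history, uniform_prob=0.001):
    """p_t proportional to sqrt(mean(L_t^2)), mixed with a small uniform floor."""
    raw = [math.sqrt(sum(l * l for l in hist) / len(hist)) for hist in loss_history]
    total = sum(raw)
    n = len(raw)
    return [(w / total) * (1 - uniform_prob) + uniform_prob / n for w in raw]


def sample_timesteps(probs, batch_size, rng):
    """Draw timesteps from probs and return them with weights 1 / (T * p_t)."""
    n = len(probs)
    ts = rng.choices(range(n), weights=probs, k=batch_size)
    return ts, [1.0 / (n * probs[t]) for t in ts]


# Hypothetical per-timestep losses with very different magnitudes.
true_losses = [50.0, 5.0, 0.5, 0.05]
history = [[l] * 10 for l in true_losses]  # 10 stored values per timestep
probs = second_moment_weights(history)

# Expected value of the reweighted loss under p_t versus the uniform average.
expected = sum(p * (1.0 / (len(probs) * p)) * l for p, l in zip(probs, true_losses))
uniform_mean = sum(true_losses) / len(true_losses)

rng = random.Random(0)
ts, ws = sample_timesteps(probs, batch_size=8, rng=rng)
print("p_t:", [round(p, 4) for p in probs])
print("expected reweighted loss:", expected, "uniform mean:", uniform_mean)
print("sampled timesteps:", ts)
```

Timesteps with large losses are drawn more often but down-weighted, so the estimate stays unbiased while its variance drops. The small uniform floor keeps every $p_t$ strictly positive.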
The code that uses $L_\text{vlb}$ as the loss is in the _vb_terms_bpd method of improved_diffusion/gaussian_diffusion.py:
def _vb_terms_bpd(
    self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
    """
    Get a term for the variational lower-bound.
    The resulting units are bits (rather than nats, as one might expect).
    This allows for comparison to other papers.
    :return: a dict with the following keys:
             - 'output': a shape [N] tensor of NLLs or KLs.
             - 'pred_xstart': the x_0 predictions.
    """
    true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
        x_start=x_start, x_t=x_t, t=t
    )
    out = self.p_mean_variance(
        model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
    )
    kl = normal_kl(
        true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
    )
    kl = mean_flat(kl) / np.log(2.0)
    decoder_nll = -discretized_gaussian_log_likelihood(
        x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
    )
    assert decoder_nll.shape == x_start.shape
    decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
    # At the first timestep return the decoder NLL,
    # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
    output = th.where((t == 0), decoder_nll, kl)
    return {"output": output, "pred_xstart": out["pred_xstart"]}
Results
The paper runs ablation experiments on two datasets, ImageNet 64×64 and CIFAR-10, to verify the effectiveness of each improvement, as shown in Figures 5 and 6.
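The metrics are NLL and FID. NLL (negative log-likelihood) corresponds to the loss $L_\text{vlb}$; the lower the NLL, the closer the distribution of generated images is to that of real images. FID (Fréchet Inception Distance) is another metric for assessing image-generation quality; it measures how similar generated images are to real ones. FID is computed by extracting features from generated and real images with an Inception-v3 model and taking the Fréchet distance between the two feature distributions. The lower the FID, the more similar the generated images are to the real ones. In the experiments, the combination of a cosine noise schedule, learned variance and the importance-sampled $L_\text{vlb}$ training loss gives the lowest NLL but a relatively high FID, whereas the combination of a cosine noise schedule, learned variance and the $L_\text{hybrid}$ training loss achieves low values on both NLL and FID.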
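For reference, when both feature distributions are modeled as Gaussians, the Fréchet distance has a closed form. The symbols $\mu_r,\Sigma_r$ (mean and covariance of the real-image features) and $\mu_g,\Sigma_g$ (the same for generated images) are notation introduced here for illustration and do not appear in the text above:

```latex
\mathrm{FID}=\lVert\mu_r-\mu_g\rVert_2^2+\operatorname{Tr}\left(\Sigma_r+\Sigma_g-2\left(\Sigma_r\Sigma_g\right)^{1/2}\right)
```

The first term compares the feature means and the second compares the covariances, so the distance is zero only when both match.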
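The paper also compares against other likelihood-based models, as shown in Figure 7. The improved DDPM is not yet state of the art on NLL and FID, but its results are comparatively strong, second only to Transformer-based architectures.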
Speeding up sampling
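To generate an image, DDPM starts from pure noise and runs many denoising steps, and each step feeds the current noisy image to the model to predict the noise, so generation takes many steps and a lot of computation. The paper also uses DDIM, the sampling method proposed in "Denoising Diffusion Implicit Models", to reduce the number of steps.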
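The reason DDIM can skip steps is easiest to see in its deterministic update (the $\eta=0$ case). The sketch below is a scalar toy under stated assumptions: `ddim_step` is a name invented here, the "image" is a single number, and a perfect noise predictor is assumed in place of a trained model.

```python
import math


def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update (eta = 0) from step t to an earlier step."""
    # Estimate x_0 from the noisy sample and the predicted noise.
    x0_pred = (x_t - math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)
    # Re-noise the estimate to the earlier noise level, reusing the same noise.
    return math.sqrt(alpha_bar_prev) * x0_pred + math.sqrt(1.0 - alpha_bar_prev) * eps_pred


# Hypothetical scalar "image" and noise; the predictor is assumed to be perfect.
x0, eps = 0.8, -0.3
alpha_bar_t, alpha_bar_prev = 0.2, 0.6  # a jump over many intermediate steps
x_t = math.sqrt(alpha_bar_t) * x0 + math.sqrt(1.0 - alpha_bar_t) * eps

x_prev = ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev)
target = math.sqrt(alpha_bar_prev) * x0 + math.sqrt(1.0 - alpha_bar_prev) * eps
print("DDIM result:", x_prev, "expected noisy sample:", target)
```

Because the update depends only on $\bar{\alpha}$ at the two endpoints, the sampler can move between non-adjacent timesteps. With a real model the noise prediction is imperfect, so larger jumps trade quality for speed.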