Improved Denoising Diffusion Probabilistic Models (DDPM) and Its Application in Image Generation

释放双眼,带上耳机,听听看~!
2021年OpenAI发表的论文《Improved Denoising Diffusion Probabilistic Models》对DDPM算法进行了改进,采用余弦函数的噪声Schedule,代码开源且使用PyTorch框架。本文介绍了其训练和采样脚本,并讨论了改进算法相对于原始DDPM的优势。

基于扩散模型生成图片的算法DDPM于2020年被提出。2021年OpenAI发表的论文《Improved Denoising Diffusion Probabilistic Models》,对DDPM算法进行改进。

Improved DDPM

改进

噪声Schedule采用余弦函数

原始DDPM算法,使用公式xt=αˉtx0+1−αˉtϵx_t=sqrt{bar{alpha}_t}x_0+sqrt{1-bar{alpha}_t}epsilon计算第tt步正向扩散后带噪声的图片xtx_t,公式中的αˉt=∏i=1tαibar{alpha}_t=prod_{i=1}^{t}{alpha_i}βt=1−αtbeta_t=1-alpha_t,即βt=1−αˉtαˉt−1beta_t=1-frac{bar{alpha}_t}{bar{alpha}_{t-1}}βtbeta_t表示每步噪声的大小,原始DDPM算法,令βtbeta_ttt线性增长,从β1=10−4beta_1=10^{-4}增长到βT=0.02beta_T=0.02,改进的DDPM算法,令αˉtbar{alpha}_ttt的变化采用余弦函数:

f(t)=cos⁡(t/T+s1+s⋅π2)2αˉt=f(t)f(0)βt=clip(1−αˉtαˉt−1,0.999)begin{aligned}
f(t)&=cos{left(frac{t/T+s}{1+s}cdotfrac{pi}{2}right)^2}
bar{alpha}_t&=frac{f(t)}{f(0)}
beta_t&=text{clip}(1-frac{bar{alpha}_t}{bar{alpha}_{t-1}},0.999)
end{aligned}

两种噪声Schedule下,αˉtbar{alpha}_ttt的变化曲线如图1所示,相比线性函数,余弦函数的αˉtbar{alpha}_t下降相对较平缓,因而βtbeta_t相对较小,加噪相对较慢,不会过快地对原始图片加入过多的噪声。

Improved Denoising Diffusion Probabilistic Models (DDPM) and Its Application in Image Generation

Improved DDPM的代码开源,代码地址是:github.com/openai/impr…,其深度学习框架采用PyTorch。训练和采样分别执行以下脚本:

# 训练脚本
python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

# 采样脚本
python scripts/image_sample.py --model_path /path/to/model.pt $MODEL_FLAGS $DIFFUSION_FLAGS

其中,MODEL_FLAGS、DIFFUSION_FLAGS、TRAIN_FLAGS分别表示模型结构(U-Net)、扩散过程和训练的配置,而基线模型(DDPM)的配置如下:

MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 128"

如果采用余弦函数作为噪声Schedule,可以将DIFFUSION_FLAGS中的noise_schedule设置为cosine,而相应的计算噪声βtbeta_t的代码在improved_diffusion/gaussian_diffusion.py中,如下:

def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

# 入参alpha_bar即公式推导中的f(t) 
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)        

对方差进行学习

原始DDPM算法,满足高斯分布的概率密度函数pθ(xt−1∣xt)p_theta(x_{t-1}|x_t)中的方差Σθ(xt,t)Sigma_theta(x_t,t)被固定为常量σt2Isigma_t^2mathbf{I},其中σt2sigma_t^2直接取值βtbeta_t。改进的DDPM算法通过模型对Σθ(xt,t)Sigma_theta(x_t,t)也进行了学习,这样可以使用更少的步数、且生成更高质量的图片。

Improved Denoising Diffusion Probabilistic Models (DDPM) and Its Application in Image Generation

具体如何学习Σθ(xt,t)Sigma_theta(x_t,t)呢?DDPM的论文已推导Σθ(xt,t)Sigma_theta(x_t,t)的取值在βtbeta_tβ~ttilde{beta}_t之间,而β~t:=1−αˉt−11−αˉtβttilde{beta}_t:=frac{1-bar{alpha}_{t-1}}{1-bar{alpha}_t}beta_t,图2表示了β~t/βttilde{beta}_t/beta_ttt之间的关系,从中可见,除了t=0t=0外,其他tt取值下,βtbeta_tβ~ttilde{beta}_t近似相等,所以原始DDPM算法将方差直接取值βtbeta_t,而改进的DDPM算法设计了中间向量vv,由模型预测vv,并将Σθ(xt,t)Sigma_theta(x_t,t)表示为:

Σθ(xt,t)=exp⁡(vlog⁡βt+(1−v)log⁡β~t)Sigma_theta(x_t,t)=exp(vlog{beta_t}+(1-v)log{tilde{beta}_t})

原始DDPM算法的损失函数为:

Lsimple:=Et∼[1,T],x0∼q(x0),ϵ∼N(0,I)[∥ϵ−ϵθ(xt,t)∥2]L_{text{simple}}:=E_{tsim[1,T],x_0sim q(x_0),epsilonsimmathcal{N}(0,mathbf{I})}[parallelepsilon-epsilon_theta(x_t,t)parallel^2]

其中并不包含Σθ(xt,t)Sigma_theta(x_t,t),因此改进的DDPM算法设计了新的损失函数为:

Lhybrid=Lsimple+λLvlbL_text{hybrid}=L_text{simple}+lambda L_text{vlb}

LvlbL_text{vlb}的定义如下:

Lvlb:=L0+L1+⋯+LT−1+LTwhere L0:=−log⁡pθ(x0∣x1)Lt−1:=DKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt)) for 2≤t≤TLT:=DKL(q(xT∣x0)∥p(xT))begin{aligned}
L_{text{vlb}}&:=L_0+L_1+cdots+L_{T-1}+L_T
text{where }L_0&:=-log{p_theta(x_0|x_1)}
L_{t-1}&:=D_{text{KL}}(q(x_{t-1}|x_{t},x_0)parallel p_theta(x_{t-1}|x_{t}))text{ for } 2le tle T
L_T&:=D_{text{KL}}(q(x_T|x_0)parallel p(x_T))
end{aligned}

损失函数中λlambda被设置为0.001,以减少LvlbL_text{vlb}LsimpleL_text{simple}的影响。另外梯度更新时,LvlbL_text{vlb}部分不更新涉及μθ(xt,t)mu_theta(x_t,t)的参数,只更新涉及Σθ(xt,t)Sigma_theta(x_t,t)的参数。
如果需要对方差进行学习,可以将训练脚本参数MODEL_FLAGS中的learn_sigma设置为True,这样,模型结构(U-Net)的输出维度增加,增加的部分作为vv,相关代码在improved_diffusion/script_util.py的create_model方法中,如下:

    return UNetModel(
        in_channels=3,
        model_channels=num_channels,
        # 模型结构(U-Net)的输出维度增加,增加的部分作为v
        out_channels=(3 if not learn_sigma else 6),
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
    )

同时,反向扩散生成图片时,由模型预测方差的代码在improved_diffusion/gaussian_diffusion.py的p_mean_variance方法中,如下:

        # 模型预测
        model_output = model(x, self._scale_timesteps(t), **model_kwargs)

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            # 按列拆分模型输出,后半部分作为v
            model_output, model_var_values = th.split(model_output, C, dim=1)
            if self.model_var_type == ModelVarType.LEARNED:
                model_log_variance = model_var_values
                model_variance = th.exp(model_log_variance)
            else:
                # 按公式exp(v·log(β_t)+(1-v)·log(β_t))计算方差 
                min_log = _extract_into_tensor(
                    self.posterior_log_variance_clipped, t, x.shape
                )
                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
                # The model_var_values is [-1, 1] for [min_var, max_var].
                frac = (model_var_values + 1) / 2
                model_log_variance = frac * max_log + (1 - frac) * min_log
                model_variance = th.exp(model_log_variance)

而代码如何调整损失函数的计算,在下一节介绍。

训练时采用Importance Sampling

Improved Denoising Diffusion Probabilistic Models (DDPM) and Its Application in Image Generation

Improved Denoising Diffusion Probabilistic Models (DDPM) and Its Application in Image Generation

论文进一步发现,将损失函数替换为LvlbL_text{vlb}LhybridL_text{hybrid}后,损失函数取值随着训练迭代变化的曲线比较波动,不易收敛,如图3所示。LvlbL_text{vlb}包含多项,每项对应一个步数tt,且每项的取值量纲差别较大,如图4所示,而原始DDPM算法训练时随机采样步数tt,因此论文认为每次训练迭代随机采样步数tt、并进而计算量纲差别大的LtL_t导致LvlbL_text{vlb}LhybridL_text{hybrid}的波动。论文通过训练时采用Importance Sampling来解决上述波动问题。Importance Sampling中,LvlbL_text{vlb}可表示为以下公式:

Lvlb=Et∼pt[Ltpt],where pt∝E[Lt2] and ∑pt=1L_text{vlb}=E_{tsim p_t}left[frac{L_t}{p_t}right],text{where }p_tproptosqrt{Eleft[L_t^2right]}text{ and }sum{p_t}=1

E[Lt2]Eleft[L_t^2right]无法提前求解,且在训练过程中会变化,因此,论文对LvlbL_text{vlb}的每一项LtL_t保留最新的10个取值,并在训练过程中动态更新。训练初期,仍是随机采样步数tt,直至所有的LtL_t均有10个取值,再采用Importance Sampling。从图3可以看出,经过Importance Sampling后的LvlbL_text{vlb}随着训练迭代变化的曲线比较平滑,且损失最小。
如果需要使用Importance Sampling后的LvlbL_text{vlb}作为损失函数,可以将训练脚本参数DIFFUSION_FLAGS中的use_kl设置为True、TRAIN_FLAGS中的schedule_sampler设置为loss-second-moment。Importance Sampling的相关代码在improved_diffusion/resample.py的LossSecondMomentResampler类中,如下:

class LossSecondMomentResampler(LossAwareSampler):
    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)

    def weights(self):
        # 训练初期,仍是随机采样步数
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        # 根据L_t的历史值计算p_t
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        return (self._loss_counts == self.history_per_term).all()

使用LvlbL_text{vlb}作为损失函数的相关代码在improved_diffusion/gaussian_diffusion.py的_vb_terms_bpd方法中,如下:

    def _vb_terms_bpd(
        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Get a term for the variational lower-bound.

        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.

        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
        )
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = th.where((t == 0), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"]}

效果

针对上述改进,论文在ImageNet 64×64和CIFAR-10这两个数据集上分别进行消融实验以验证各改进的有效性,如图5和图6所示。

Improved Denoising Diffusion Probabilistic Models (DDPM) and Its Application in Image Generation

Improved Denoising Diffusion Probabilistic Models (DDPM) and Its Application in Image Generation

其中,有效性指标使用了NLL和FID。NLL(Negative Log Likelihood)等价于损失函数LvlbL_text{vlb},NLL越小,说明生成图像与真实图像的分布越接近。FID(Fréchet Inception Distance)是另一种用于图像生成质量评估的指标,它可以评估生成图像与真实图像之间的相似度。FID指标的计算方法是使用Inception-v3模型对生成图像和真实图像进行特征提取,并计算两个特征分布之间的Fréchet距离。FID越小,说明生成图像与真实图像越相似。从实验结果上看,噪声Schedule采用余弦函数、对方差进行学习并且训练时损失函数采用Importance Sampling后的LvlbL_text{vlb},NLL最低,但FID较高,而噪声Schedule采用余弦函数、对方差进行学习并且训练时损失函数采用LhybridL_text{hybrid},在NLL、FID上都能取得较小的值。

Improved Denoising Diffusion Probabilistic Models (DDPM) and Its Application in Image Generation

另外,论文还和其他基于似然预估的模型进行了对比实验,如图7所示。优化后的DDPM虽然在NLL和FID上还不是SOTA,但相对也是较优的效果,仅次于基于Transformer的网络结构。

加速

DDPM在生成图片时需要从完全噪声开始执行多步降噪操作,而每步操作均需要将当前步带噪声的图片作为输入由模型预测噪声,导致生成图片需要较多的步骤和计算量。论文也采用了《Denoising Diffusion Implicit Models》提出的采样方法——DDIM,减少步数。

参考文献

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

掘力计划第23期:AIGC的应用和创新技术沙龙精彩回顾

2023-12-20 18:57:14

AI教程

如何用labelme快速制作人脸关键点数据集

2023-12-20 19:01:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索