当前位置:首页> AI教程> 深度学习中的扩散模型及不同schedule原理解析

深度学习中的扩散模型及不同schedule原理解析

释放双眼,带上耳机,听听看~!
本文深度解析了深度学习中的扩散模型以及不同schedule原理,包括线性schedule和余弦schedule的代码分析和对应关系,适合对深度学习感兴趣的读者阅读。

本文适用人群:

  • 知道什么是扩散模型

  • 了解不用schedule原理

碎碎念

这不能算重复造轮子

事情起因是这样的,我今天想换一个schedule试一下。我是做文本的,在文本的schedule中,我觉得应该试一下sqrt schedule。但是我很确定我现在用的代码里没有这个schedule。

深度学习中的扩散模型及不同schedule原理解析

我知道这个代码的出处,然后我想,这是2022年的版本了,现在去看看原作者仓库是不是可以抄到新代码。所以我就跑github看了。

深度学习中的扩散模型及不同schedule原理解析

很不幸,作者只是加了一个sigmoid_beta_schedule

然后我转念一想,提出sqrt schedule的论文中也有 cosine schedule 和 linear schedule,所以我为什么不用人家源码呢?

很遗憾,人家的代码是自己写的,和我抄的代码根本不是一套体系的。所以只好根据人家实现方法,再按照我原来的代码格式重写一个了。

深度学习中的扩散模型及不同schedule原理解析

代码分析

linear

两端代码的对应方式如下:

Diffusion LM:

    if schedule_name == "linear":
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return np.linspace(beta_start, beta_end, timesteps, dtype=np.float64)

DDPM:

def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

这俩代码的对应关系简单明了,就是numpy的版本改pytorch版。

cosine

DDPM:

def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

Diffusion LM:

    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
        
    def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)

这个diffusion lm 的写法就稍微有点绕了。

我们先分析DDPM的写法:

  1. def cosine_beta_schedule(timesteps, s=0.008)::这是函数的定义,它接受两个参数,timesteps 表示时间步数,s 是一个可选参数,用于调整计算中的偏差。

  2. steps = timesteps + 1:这一行计算了 steps 变量,它表示生成的权重序列的长度,通常比输入的时间步数多1,因为权重是在时间步之间计算的。

  3. x = torch.linspace(0, timesteps, steps, dtype=torch.float64):这一行创建了一个名为 x 的张量,其中包含从0到 timesteps 的等间隔值,数量为 steps 个。dtype=torch.float64 指定了张量的数据类型为双精度浮点数。

  4. alphas_cumprod 的计算:这一行计算了一组权重值 alphas_cumprod。具体计算步骤如下:

    • (x / timesteps):首先,将 x 中的每个值除以 timesteps,得到一个范围在0到1之间的归一化值。
    • ((x / timesteps) + s) / (1 + s):然后,将这些归一化值加上参数 s,并除以 (1 + s),这样可以将范围调整到 (s / (1 + s)) 到 1 之间。
    • math.pi * 0.5:接着,将上述值乘以 π/2,将范围映射到0到π/2之间。
    • torch.cos(...):然后,计算这些值的余弦值,并取平方。这将使得这些值在0到1之间波动,具有余弦形状。
    • alphas_cumprod / alphas_cumprod[0]:最后,将这些余弦值除以第一个值,以确保它们从1开始,并随时间步数而变化。
  5. betas 的计算:这一行计算了一组 betas 值,它们表示权重之间的差异。具体计算步骤如下:

    • alphas_cumprod[1:]:首先,选择 alphas_cumprod 中除第一个元素之外的所有元素。
    • alphas_cumprod[:-1]:然后,选择 alphas_cumprod 中除最后一个元素之外的所有元素。
    • 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]):接着,计算这两个选择的元素之间的差异,并将其赋值给 betas
  6. return torch.clip(betas, 0, 0.999):最后,返回 betas 张量,但经过截断,将小于0的值截断为0,将大于0.999的值截断为0.999,以确保权重在0到0.999之间,通常用于加权操作。

接下来我们看DDPM的写法:

因为写的比较绕,所以把这两个函数合并一下。修改之后是对的嗷,看一下除了输出一个是numpy数组一个是torch tensor,别的没区别。

接下来分析代码:

深度学习中的扩散模型及不同schedule原理解析

深度学习中的扩散模型及不同schedule原理解析

标注了一下对应关系,torch.linspace其实就是对应的for循环。计算cosine值的函数也是固定。

x / timesteps对应的就是 t2

计算alphas_cumprod的步骤就是alpha_bar(t)

min()的过程就是对应最后那个torch.clip()的截断操作。

这里DDPM版本更胜一筹的地方在于s作为一个超参数可以自己调节。但是diffusion lm中是写死的。

写!

深度学习中的扩散模型及不同schedule原理解析

def sqrt_beta_schedule(timesteps):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = 1 - np.sqrt(x / timesteps + 0.0001)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

一次成功。今天抄代码又成功了。

深度学习中的扩散模型及不同schedule原理解析

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

Adobe AI新功能Generative Fill实测:PS革命性改变者!

2023-11-22 18:36:14

AI教程

ChatGPT进入多模态时代,Midjourney图像识别功能解析

2023-11-22 18:43:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索