The Corruption Process of Diffusion Models

This post walks through the corruption (degradation) process of diffusion models on the MNIST dataset using PyTorch. A single parameter controls how much noise is added to the input, i.e. how corrupted the content becomes, and the noised inputs are visualized alongside the originals for comparison.
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda
dataset = torchvision.datasets.MNIST(root=r"D:\Pycharm_Project\data\MNIST", train=True, download=True, 
                                     transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');

Input shape: torch.Size([8, 1, 28, 28])
Labels: tensor([3, 5, 1, 7, 0, 6, 8, 6])

[Figure: grid of the 8 MNIST input images]

# Control the amount of noise added to the input (i.e. how corrupted it becomes) with a single parameter
def corrupt(x, amount):
    """Add noise to the input x according to amount; this is the corruption process."""
    noise = torch.rand_like(x) # rand_like returns a tensor shaped like x with values from a uniform (0, 1) distribution; randn_like would sample a standard normal instead
    amount = amount.view(-1, 1, 1, 1) # reshape so broadcasting works
    return x*(1-amount) + noise*amount 
# Visualize the results
# Plot the input data
fig, axs = plt.subplots(2, 1, figsize=(7, 5))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')

# Add noise
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)

# Plot the noised version
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');


[Figure: input data (top) and corrupted data with amount increasing left to right (bottom)]

Training the Diffusion Model

class BasicUNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([ 
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2), 
        ])
        # Activation function
        self.act = nn.SiLU()
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            # Pass through the conv layer and activation function
            x = self.act(l(x))
            if i < 2:
                # Store the activation for the skip connection
                h.append(x) 
                # Downsample
                x = self.downscale(x) 
              
        for i, l in enumerate(self.up_layers):
            if i > 0: 
                # Upsample
                x = self.upscale(x) 
                # Add back the stored activation (skip connection)
                x += h.pop()
            x = self.act(l(x)) 
            
        return x

net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape

torch.Size([8, 1, 28, 28])
# Count the number of network parameters
sum(p.numel() for p in net.parameters())
309057
# Training parameters
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
n_epochs = 3

# Create the network
net = BasicUNet()
net.to(device)

loss_fn = nn.MSELoss()

# Optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

losses = []

# Start training
for epoch in range(n_epochs):
    for x, y in train_dataloader:
        # Get the data and prepare to corrupt it
        x = x.to(device)
        # Pick a random noise amount for each sample
        noise_amount = torch.rand(x.shape[0], device=device)
        # Corrupt the input
        noised_x = corrupt(x, noise_amount)

        # Get the model prediction
        pred = net(noised_x)

        # Compute the loss
        loss = loss_fn(pred, x)

        # Backpropagate and update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record the loss
        losses.append(loss.item())
    
    # Print the average loss over this epoch
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f'Epoch {epoch+1}/{n_epochs}: average loss {avg_loss:.5f}')
Epoch 1/3: average loss 0.02758
Epoch 2/3: average loss 0.02110
Epoch 3/3: average loss 0.01915
plt.plot(losses)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.show()

[Figure: training loss over iterations]

Comparing the Input, Corrupted, and Predicted Data

# We can also grab a batch of data, corrupt it by varying amounts, and feed it to the model to get predictions

# Visualize the model's performance on noisy inputs
# Fetch a batch of data
x, y = next(iter(train_dataloader))
x = x[:8].to(device)

# Choose corruption amounts in the range (0, 1)
amount = torch.linspace(0, 1, x.shape[0], device=device)
noised_x = corrupt(x, amount)

# Get the model predictions
with torch.no_grad():
    pred = net(noised_x).detach().cpu()

# Plot the results
fig, axs = plt.subplots(3, 1, figsize=(12,7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x.cpu())[0].clip(0,1), cmap='Greys')
axs[1].set_title('Corrupted data')
axs[1].imshow(torchvision.utils.make_grid(noised_x.cpu())[0].clip(0,1), cmap='Greys')
axs[2].set_title('Network output')
axs[2].imshow(torchvision.utils.make_grid(pred)[0].clip(0,1), cmap='Greys')

[Figure: input data, corrupted data, and network output]

The Sampling Process of Diffusion Models

Sampling scheme: start from completely random noise, check the model's prediction, and then move only a small step in the direction of that prediction.

Intuitively: feed the noisy image into the model to get a predicted output; since that prediction is slightly better than the current input, move the input a little toward it and feed the result back into the model, repeating the process.

def sample_with_step(x, n_steps):
    step_history = [x.detach().cpu()]
    pred_output_history = []

    for i in range(n_steps):
        with torch.no_grad(): 
            pred = net(x)  # Predict the denoised image
        # Save the model output
        pred_output_history.append(pred.detach().cpu())
        # How far to move toward the prediction
        mix_factor = 1/(n_steps - i) 
        x = x*(1-mix_factor) + pred*mix_factor  # Move part of the way there
        step_history.append(x.detach().cpu()) 
    
    return x, step_history, pred_output_history


n_steps = 5
# Start from completely random values
x = torch.rand(8, 1, 28, 28).to(device)
x, step_history, pred_output_history = sample_with_step(x, n_steps)

fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
axs[0,0].set_title('x (model input)')
axs[0,1].set_title('model prediction')
for i in range(n_steps):
    axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap='Greys')
    axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap='Greys')

[Figure: model input x (left) and model prediction (right) at each of the 5 sampling steps]

The sampling process can also be split into more steps to obtain higher-quality images.

n_steps = 40
# Start from completely random values
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):
    noise_amount = torch.ones((x.shape[0],)).to(device)*(1-(i/n_steps))  # decreasing noise level (not used by BasicUNet)
    with torch.no_grad():
        pred = net(x)
    mix_factor = 1/(n_steps - i)
    x = x*(1-mix_factor) + pred*mix_factor
fig, axs = plt.subplots(1, 1, figsize=(12, 12))
axs.imshow(torchvision.utils.make_grid(x.detach().cpu(),nrow=8)[0].clip(0, 1), cmap='Greys')

[Figure: 64 samples generated with 40 sampling steps]

UNet2DModel

Comparison with DDPM:

  • UNet2DModel is more advanced than BasicUNet.
  • The corruption process is handled differently.
  • The training objective is different: the model predicts the noise rather than the denoised image.
  • UNet2DModel conditions the amount of noise on the timestep, with t passed as an extra argument to the forward pass.

Compared with BasicUNet, the UNet2DModel in the Diffusers library adds the following improvements:

  • GroupNorm layers apply group normalization to the inputs of each block.
  • Dropout layers make training smoother.
  • Multiple ResNet layers per block (unless layers_per_block is set to 1).
  • Attention mechanisms (usually only used in blocks with lower input resolution).
  • Conditioning on the timestep (see the sketch below).
  • Downsampling and upsampling blocks with learnable parameters.
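As a rough illustration of the timestep conditioning mentioned above, the snippet below passes a per-sample timestep tensor to a small UNet2DModel. The block configuration mirrors the one constructed next; the batch shapes and the random timesteps are illustrative assumptions, not part of the original code.

import torch
from diffusers import UNet2DModel

# A small UNet2DModel; the timestep t is an extra argument to the forward pass.
model = UNet2DModel(
    sample_size=28, in_channels=1, out_channels=1,
    block_out_channels=(32, 64, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
)
noisy_x = torch.rand(8, 1, 28, 28)        # a batch of noisy inputs
timesteps = torch.randint(0, 1000, (8,))  # one timestep per sample
pred = model(noisy_x, timesteps).sample   # the output is wrapped; .sample unwraps the tensor
print(pred.shape)                         # torch.Size([8, 1, 28, 28])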
net = UNet2DModel(
    sample_size=28,           # Resolution of the target images
    in_channels=1,            
    out_channels=1,           
    layers_per_block=2,       # Number of ResNet layers per UNet block
    block_out_channels=(32, 64, 64), 
    down_block_types=( 
        "DownBlock2D",        # A regular downsampling block
        "AttnDownBlock2D",    # A ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ), 
    up_block_types=(
        "AttnUpBlock2D", 
        "AttnUpBlock2D",      # A ResNet upsampling block with spatial self-attention
        "UpBlock2D",          # A regular upsampling block
      ),
)

sum([p.numel() for p in net.parameters()])
# UNet2DModel has roughly 1.7 million parameters, while BasicUNet has only a bit over 300,000.
1707009
# Training dataloader
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

n_epochs = 3

net.to(device)

loss_fn = nn.MSELoss()

opt = torch.optim.Adam(net.parameters(), lr=1e-3) 

losses = []

# Start training
for epoch in range(n_epochs):

    for x, y in train_dataloader:
        # Get the data and prepare to corrupt it
        x = x.to(device)
        # Pick a random noise amount for each sample
        noise_amount = torch.rand(x.shape[0]).to(device) 
        # Corruption process
        noisy_x = corrupt(x, noise_amount) 

        # Get the model prediction (the second argument is the timestep, fixed at 0 here because our corruption does not depend on t)
        pred = net(noisy_x, 0).sample
        
        # Compute the loss
        loss = loss_fn(pred, x) 
        
        # Backpropagate and update the parameters
        opt.zero_grad()
        loss.backward()
        opt.step()

        losses.append(loss.item())
    
    # Print the average loss over this epoch
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

Finished epoch 0. Average loss for this epoch: 0.018955
Finished epoch 1. Average loss for this epoch: 0.012771
Finished epoch 2. Average loss for this epoch: 0.011652
fig, axs = plt.subplots(1, 2, figsize=(8, 3))

axs[0].plot(losses)
axs[0].set_ylim(0, 0.1)
axs[0].set_title('Loss over time')

n_steps = 100
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):
    noise_amount = torch.ones((x.shape[0], )).to(device) * (1-(i/n_steps))
    with torch.no_grad():
        pred = net(x, 0).sample
    mix_factor = 1/(n_steps - i)
    x = x*(1-mix_factor) + pred*mix_factor

axs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Generated Samples');

[Figure: loss over time (left) and generated samples (right)]

An Example of the Diffusion Model Corruption Process

The corruption process: given $x_{t-1}$ at some timestep, we obtain a slightly noisier $x_t$:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right), \qquad q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$$

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

# Add noise to a batch of images
fig, axs = plt.subplots(3, 1, figsize=(10, 6))
xb, yb = next(iter(train_dataloader))
xb = xb.to(device)[:8]
xb = xb * 2. - 1.
print('X shape', xb.shape)

# Show the clean input images
axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(), cmap='Greys')
axs[0].set_title('Clean X')

# Add noise using the scheduler
timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print('Noisy X shape', noisy_xb.shape)

# Show the noisy versions
axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1, 1),  cmap='Greys')
axs[1].set_title('Noisy X (clipped to (-1, 1))')
axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(),  cmap='Greys')
axs[2].set_title('Noisy X');

X shape torch.Size([8, 1, 28, 28])
Noisy X shape torch.Size([8, 1, 28, 28])

[Figure: clean X, noisy X clipped to (-1, 1), and noisy X without clipping]
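Composing the one-step transitions above yields a closed form for jumping straight from $x_0$ to $x_t$: $q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t) I\right)$ with $\bar\alpha_t = \prod_{s=1}^{t}(1-\beta_s)$, and this is what add_noise computes. Below is a minimal sanity check of that equivalence; it reuses xb and noise_scheduler from above and assumes the scheduler exposes an alphas_cumprod attribute, as current Diffusers versions do.

# Check by hand that add_noise matches x_t = sqrt(alpha_bar_t)*x_0 + sqrt(1 - alpha_bar_t)*noise
t = torch.tensor([10, 100, 500, 999])
x0 = xb[:4].cpu()                       # clean images, already scaled to (-1, 1)
eps = torch.randn_like(x0)

alpha_bar = noise_scheduler.alphas_cumprod[t].view(-1, 1, 1, 1)
xt_manual = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps
xt_scheduler = noise_scheduler.add_noise(x0, eps, t)
print(torch.allclose(xt_manual, xt_scheduler, atol=1e-5))  # expected: True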

  • The model predicts the noise that was added during the corruption process (a minimal sketch of such a training step follows).
  • The noise-prediction objective shifts the loss weighting toward lower noise amounts.
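For reference, here is a minimal sketch of what a single noise-prediction training step could look like, reusing net (UNet2DModel), noise_scheduler, opt and loss_fn from above. The loop actually trained in this post predicts the clean image instead, so this is a DDPM-style variant rather than the author's code; the batch x is assumed to be scaled to (-1, 1).

# One DDPM-style training step: the target is the added noise, not the clean image.
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                          (x.shape[0],), device=device)
noise = torch.randn_like(x)
noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

pred_noise = net(noisy_x, timesteps).sample  # condition the model on the real timestep
loss = loss_fn(pred_noise, noise)            # compare against the noise that was added

opt.zero_grad()
loss.backward()
opt.step()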

Further Notes

The timestep can be converted into an embedding and fed into the model at multiple points (see the sketch below).
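As a rough sketch of what such an embedding might look like, here is the standard sinusoidal timestep embedding; the exact implementation inside UNet2DModel may differ, so treat this as illustrative only.

import math
import torch

def timestep_embedding(t, dim):
    """Map integer timesteps to sinusoidal embeddings of size dim."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

emb = timestep_embedding(torch.tensor([1, 250, 999]), dim=32)
print(emb.shape)  # torch.Size([3, 32])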

Sampling starts from pure noise and iterates over a sufficiently large number of small steps, removing a little noise each time based on the model's prediction (see the sketch below).
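With a model trained on the noise-prediction objective, that iterative procedure can be written with the scheduler's own step function. A minimal sketch, assuming net has been trained to predict noise (which is not the case for the model trained above):

# Iterative denoising with DDPMScheduler.step, assuming a noise-predicting model.
sample = torch.randn(8, 1, 28, 28).to(device)    # start from pure noise
for t in noise_scheduler.timesteps:              # t = 999, 998, ..., 0
    with torch.no_grad():
        noise_pred = net(sample, t).sample       # predict the noise at step t
    sample = noise_scheduler.step(noise_pred, t, sample).prev_sample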
