自注意力生成对抗网络(SAGAN)- 图像生成中的远距离依赖性建模

释放双眼,带上耳机,听听看~!
SAGAN是一种自注意力生成对抗网络,允许对图像生成进行注意力驱动的远距离依赖性建模。通过引入自注意力和谱归一化等稳定策略,SAGAN在ImageNet数据集上的表现更好,提高了Inception得分并降低了Fréchet Inception距离。

简介

提出了自注意力生成对抗网络(SAGAN),该网络允许对图像生成进行注意力驱动的远距离依赖性建模。传统卷积 GAN 仅根据低分辨率特征图中的空间局部点生成高分辨率细节。在 SAGAN 中,可以使用来自所有特征位置的信息来生成细节。此外,判别器可以检查与图像的远处的特征是否一致。此外,最近的研究表明,约束生成器会影响GAN性能。利用这一见解,将谱归一化应用于 GAN 生成器,并发现这稳定了训练。所提出的 SAGAN 比之前表现更好的工作,在具有挑战性的 ImageNet 数据集上,将发布的 Inception 得分从 36.8 提高到 52.52,并将 Fréchet Inception 距离从 27.62 降低到 18.65。注意力层的可视化显示,生成器利用了与对象形状相对应的邻域,而不是固定形状的局部区域。

引入自注意力

自注意力生成对抗网络(SAGAN)- 图像生成中的远距离依赖性建模

在生成器和判别器进行卷积的过程中,由表的实验表明,将自注意力加在中高维的特征图上,取得的效果比较好。过程:通过 1×1 的卷积映射到三个空间,

自注意力生成对抗网络(SAGAN)- 图像生成中的远距离依赖性建模

损失

自注意力生成对抗网络(SAGAN)- 图像生成中的远距离依赖性建模

稳定的策略

  1. 在生成器和判别器中使用谱归一化

  2. 生成器和判别器的学习率设置不一样,分别是 0.0001 和 0.0004

  3. 亚当

法典

谱归一化

def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)

class SpectralNorm(nn.Module):
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False


    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(01), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(01), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)


    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)

自注意力

class Self_Attn(nn.Module):
    def __init__(self,in_dim,activation):
        super(Self_Attn,self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        
        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax  = nn.Softmax(dim=-1) #
    def forward(self,x):
        """
            inputs :
                : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        m_batchsize, C, widthheight = x.size()
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # [B, N, C//8]
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) # [B, C//8, N]
        energy =  torch.bmm(proj_query,proj_key) # transpose check
        attention = self.softmax(energy) # [B, N, N] 
        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # [B, C, N]

        out = torch.bmm(proj_value,attention.permute(0,2,1))  # [B, C, N] 
        out = out.view(m_batchsize,C,width,height)    # [B, C, W, H]
        
        out = self.gamma*out + x
        return out,attention

生成器

class Generator(nn.Module):
    def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):
        super(Generator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        repeat_num = int(np.log2(self.imsize)) - 3 # 3
        mult = 2 ** repeat_num   # 8 
        layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))
        layer1.append(nn.BatchNorm2d(conv_dim * mult))
        layer1.append(nn.ReLU())

        curr_dim = conv_dim * mult # 64 * 8

        layer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer2.append(nn.ReLU())

        curr_dim = int(curr_dim / 2# 64 * 4

        layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer3.append(nn.ReLU())

        if self.imsize == 64:
            layer4 = []
            curr_dim = int(curr_dim / 2# 64 * 2
            layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
            layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))
            layer4.append(nn.ReLU())
            self.l4 = nn.Sequential(*layer4)
            curr_dim = int(curr_dim / 2# 64

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1))
        last.append(nn.Tanh())
        self.last = nn.Sequential(*last)

        self.attn1 = Self_Attn( 128'relu')
        self.attn2 = Self_Attn( 64,  'relu')

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 11#[N, 100, 1, 1]
        out=self.l1(z)       # [N, 64x8, 4, 4]
        out=self.l2(out)      # [N, 64x4, 8, 8]
        out=self.l3(out)      # [N, 64x2, 16, 16]
        out,p1 = self.attn1(out)
        out=self.l4(out)      # [N, 64, 32, 32]
        out,p2 = self.attn2(out)
        out=self.last(out)      # [N, 3, 64, 64]

        return out, p1, p2

判别器

class Discriminator(nn.Module):
    def __init__(self, batch_size=64, image_size=64, conv_dim=64):
        super(Discriminator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        layer1.append(SpectralNorm(nn.Conv2d(3, conv_dim, 4, 2, 1)))
        layer1.append(nn.LeakyReLU(0.1))

        curr_dim = conv_dim

        layer2.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer2.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        layer3.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer3.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        if self.imsize == 64:
            layer4 = []
            layer4.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
            layer4.append(nn.LeakyReLU(0.1))
            self.l4 = nn.Sequential(*layer4)
            curr_dim = curr_dim*2
        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.Conv2d(curr_dim, 1, 4))
        self.last = nn.Sequential(*last)

        self.attn1 = Self_Attn(256'relu')
        self.attn2 = Self_Attn(512'relu')

    def forward(self, x):    # [B, 3, 64, 64]
        out = self.l1(x)    # [B, 64, 32, 32]
        out = self.l2(out)    # [B, 128, 16, 16]
        out = self.l3(out)    # [B, 256, 8, 8]
        out,p1 = self.attn1(out)
        out=self.l4(out)    # [B, 512, 4, 4]
        out,p2 = self.attn2(out)
        out=self.last(out)    # [B, 1, 1, 1]

        return out.squeeze(), p1, p2

损失

d_out_real,dr1,dr2 = self.D(real_images)
d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
fake_images,gf1,gf2 = self.G(z)
d_out_fake,df1,df2 = self.D(fake_images)
d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
d_loss = d_loss_real + d_loss_fake

z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
fake_images,_,_ = self.G(z)
g_out_fake,_,_ = self.D(fake_images)
g_loss_fake = - g_out_fake.mean()

参考链接:

github.com/heykeetae/S…

arxiv.org/abs/1805.08…

ONE MORE THING

咪豆AI圈(Meedo)针对当前人工智能领域行业入门成本较高、碎片化信息严重、资源链接不足等痛点问题,致力于打造人工智能领域的全资源、深内容、广链接三位一体的在线科研社区平台,提供AI导航网、AI版知乎,AI知识树和AI圈子等服务,欢迎AI未来儿一起来探索(www.meedo.top/)

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

清华大学开源多模态对话模型VisualGLM-6B能解读表情包

2023-12-21 19:05:14

AI教程

ShuffleNet V2 论文的个人理解分析及实用指导思想

2023-12-21 19:09:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索