Introduction
The paper proposes the Self-Attention Generative Adversarial Network (SAGAN), which enables attention-driven, long-range dependency modeling for image generation. Traditional convolutional GANs generate high-resolution details only as a function of spatially local points in lower-resolution feature maps. In SAGAN, details can be generated using cues from all feature locations, and the discriminator can check that features in distant portions of an image are consistent with each other. In addition, recent work has shown that generator conditioning affects GAN performance. Leveraging this insight, spectral normalization is applied to the GAN generator as well, which is found to stabilize training. On the challenging ImageNet dataset, the proposed SAGAN outperforms prior work, boosting the best published Inception score from 36.8 to 52.52 and reducing the Fréchet Inception Distance from 27.62 to 18.65. Visualizations of the attention layers show that the generator leverages neighborhoods that correspond to object shapes rather than local regions of fixed shape.
Introducing self-attention
Self-attention is inserted into the convolutional stacks of both the generator and the discriminator; the ablation table in the paper shows that placing it on the middle-to-high-resolution feature maps works best. The procedure: the input feature map is projected into three spaces (query, key, and value) with 1×1 convolutions, attention weights are computed between every pair of spatial positions, and the attended features are scaled by a learnable coefficient γ and added back to the input, as written out below.
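In equations (a sketch that matches the implementation in the Code section below, where f, g, h denote the 1×1 key/query/value convolutions and γ is a learnable scalar initialized to 0):

$$ s_{ij} = f(x_i)^{\top} g(x_j), \qquad \beta_{j,i} = \frac{\exp(s_{ij})}{\sum_{i=1}^{N}\exp(s_{ij})} $$

$$ o_j = \sum_{i=1}^{N} \beta_{j,i}\, h(x_i), \qquad y_j = \gamma\, o_j + x_j $$

Here N = W × H is the number of spatial positions, and β_{j,i} measures how much location i is attended to when synthesizing location j.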
Loss
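SAGAN is trained with the hinge version of the adversarial loss; the unconditional form implemented in the loss code further down is:

$$ L_D = -\,\mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\min(0,\, -1 + D(x))\big] \;-\; \mathbb{E}_{z \sim p_z}\big[\min(0,\, -1 - D(G(z)))\big] $$

$$ L_G = -\,\mathbb{E}_{z \sim p_z}\big[D(G(z))\big] $$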
Stabilization strategies
- Spectral normalization is used in both the generator and the discriminator
- The generator and the discriminator use different learning rates (TTUR): 0.0001 and 0.0004 respectively (see the optimizer sketch after this list)
- The Adam optimizer is used for both networks
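A minimal sketch of that optimizer setup, assuming the Adam betas (0.0, 0.9) reported in the paper; placeholder modules stand in for the SAGAN generator and discriminator defined in the Code section below:

import torch
import torch.nn as nn

G = nn.Sequential(nn.ConvTranspose2d(100, 3, 4))  # placeholder for the SAGAN Generator below
D = nn.Sequential(nn.Conv2d(3, 1, 4))             # placeholder for the SAGAN Discriminator below

# TTUR: slower generator updates, faster discriminator updates.
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.0, 0.9))
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0004, betas=(0.0, 0.9))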
Code
Spectral normalization
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter


def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """Wraps a module and rescales its weight by the largest singular value,
    estimated with power iteration, before every forward pass."""
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        # Power iteration: alternately refine the left/right singular vectors.
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma approximates the spectral norm (largest singular value) of w.
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            getattr(self.module, self.name + "_u")
            getattr(self.module, self.name + "_v")
            getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Keep the raw weight as `weight_bar` plus the power-iteration vectors,
        # and remove the original parameter so it can be recomputed each forward.
        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
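A quick usage sketch with hypothetical shapes, showing how the wrapper is applied to an individual layer, exactly as the generator and discriminator below do:

layer = SpectralNorm(nn.Conv2d(64, 128, 4, 2, 1))
x = torch.randn(8, 64, 32, 32)
y = layer(x)  # the conv weight is re-normalized by its spectral norm before this convolution runs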
Self-attention
class Self_Attn(nn.Module):
    def __init__(self, in_dim, activation):
        super(Self_Attn, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # residual weight, initialized to 0

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B X C X W X H)
        returns :
            out : self attention value + input feature
            attention : B X N X N (N is Width * Height)
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # [B, N, C//8]
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)                       # [B, C//8, N]
        energy = torch.bmm(proj_query, proj_key)                                                # [B, N, N]
        attention = self.softmax(energy)                                                        # [B, N, N]
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)                   # [B, C, N]

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # [B, C, N]
        out = out.view(m_batchsize, C, width, height)            # [B, C, W, H]

        out = self.gamma * out + x
        return out, attention
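A hypothetical shape check for the block above (batch of 4, 128 channels, 16×16 feature map):

attn = Self_Attn(in_dim=128, activation='relu')
x = torch.randn(4, 128, 16, 16)
out, attention = attn(x)
print(out.shape)        # torch.Size([4, 128, 16, 16])
print(attention.shape)  # torch.Size([4, 256, 256]), N = 16 * 16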
Generator
class Generator(nn.Module):
    def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):
        super(Generator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        repeat_num = int(np.log2(self.imsize)) - 3  # 3 for imsize = 64
        mult = 2 ** repeat_num                      # 8

        layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))
        layer1.append(nn.BatchNorm2d(conv_dim * mult))
        layer1.append(nn.ReLU())
        curr_dim = conv_dim * mult  # 64 * 8

        layer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer2.append(nn.ReLU())
        curr_dim = int(curr_dim / 2)  # 64 * 4

        layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer3.append(nn.ReLU())

        if self.imsize == 64:
            layer4 = []
            curr_dim = int(curr_dim / 2)  # 64 * 2
            layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
            layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))
            layer4.append(nn.ReLU())
            self.l4 = nn.Sequential(*layer4)
            curr_dim = int(curr_dim / 2)  # 64

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1))
        last.append(nn.Tanh())
        self.last = nn.Sequential(*last)

        self.attn1 = Self_Attn(128, 'relu')
        self.attn2 = Self_Attn(64, 'relu')

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)  # [N, 100, 1, 1]
        out = self.l1(z)           # [N, 64*8, 4, 4]
        out = self.l2(out)         # [N, 64*4, 8, 8]
        out = self.l3(out)         # [N, 64*2, 16, 16]
        out, p1 = self.attn1(out)  # self-attention on the 16x16 feature map
        out = self.l4(out)         # [N, 64, 32, 32]
        out, p2 = self.attn2(out)  # self-attention on the 32x32 feature map
        out = self.last(out)       # [N, 3, 64, 64]
        return out, p1, p2
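A hypothetical smoke test for the 64×64 configuration:

G = Generator(batch_size=4, image_size=64, z_dim=100, conv_dim=64)
z = torch.randn(4, 100)
imgs, p1, p2 = G(z)
print(imgs.shape)  # torch.Size([4, 3, 64, 64])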
Discriminator
class Discriminator(nn.Module):
    def __init__(self, batch_size=64, image_size=64, conv_dim=64):
        super(Discriminator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        layer1.append(SpectralNorm(nn.Conv2d(3, conv_dim, 4, 2, 1)))
        layer1.append(nn.LeakyReLU(0.1))
        curr_dim = conv_dim

        layer2.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer2.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        layer3.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer3.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        if self.imsize == 64:
            layer4 = []
            layer4.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
            layer4.append(nn.LeakyReLU(0.1))
            self.l4 = nn.Sequential(*layer4)
            curr_dim = curr_dim * 2

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.Conv2d(curr_dim, 1, 4))
        self.last = nn.Sequential(*last)

        self.attn1 = Self_Attn(256, 'relu')
        self.attn2 = Self_Attn(512, 'relu')

    def forward(self, x):          # [B, 3, 64, 64]
        out = self.l1(x)           # [B, 64, 32, 32]
        out = self.l2(out)         # [B, 128, 16, 16]
        out = self.l3(out)         # [B, 256, 8, 8]
        out, p1 = self.attn1(out)  # self-attention on the 8x8 feature map
        out = self.l4(out)         # [B, 512, 4, 4]
        out, p2 = self.attn2(out)  # self-attention on the 4x4 feature map
        out = self.last(out)       # [B, 1, 1, 1]
        return out.squeeze(), p1, p2
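And the matching hypothetical check for the discriminator:

D = Discriminator(batch_size=4, image_size=64, conv_dim=64)
imgs = torch.randn(4, 3, 64, 64)
score, p1, p2 = D(imgs)
print(score.shape)  # torch.Size([4]), one unbounded realness score per image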
Loss
# Snippet from the training step: self.D / self.G are the networks above,
# and tensor2var is the repo's helper that prepares the noise tensor for the GPU.

# Discriminator hinge loss
d_out_real, dr1, dr2 = self.D(real_images)
d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
fake_images, gf1, gf2 = self.G(z)
d_out_fake, df1, df2 = self.D(fake_images)
d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

d_loss = d_loss_real + d_loss_fake

# Generator hinge loss: maximize D's score on freshly sampled fakes
z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
fake_images, _, _ = self.G(z)
g_out_fake, _, _ = self.D(fake_images)
g_loss_fake = - g_out_fake.mean()
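The snippet stops at computing the losses; a minimal sketch of how the update step could be completed (the optimizer attribute names are hypothetical, mirroring the TTUR sketch earlier):

self.d_optimizer.zero_grad()
d_loss.backward()
self.d_optimizer.step()

self.g_optimizer.zero_grad()
g_loss_fake.backward()
self.g_optimizer.step()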