当前位置:首页> AI教程> 卷积块注意力模块(CBAM)在卷积神经网络中的应用示例

卷积块注意力模块(CBAM)在卷积神经网络中的应用示例

释放双眼,带上耳机,听听看~!
本文介绍了如何在卷积神经网络中使用卷积块注意力模块(CBAM),并提供了对Alexnet添加CBAM模块的示例,以增强网络的表现能力和性能。

前言

  在2018年提出了了卷积块注意力模块(CBAM),提出了一种简单而有效的前馈卷积神经网络注意力模块,它可以对一个中间特征图在CA&SA(通道&空间)上单独进行推理,然后将注意力结果图与输入特征图融合处理。在本文中为大家带来一个简单的小示例,以Alexnet为CNN基础模板,对Alexnet添加CBAM模块,希望对大家有帮助。

理论基础

  我们对卷积网络中添加注意力机制模块希望达到对核心特征进行多注意,对非核心特征过滤进而达到网络的表现能力增强的效果。由于卷积的运算是通过混合了通道和空间信息进行提取信息,那么我们的注意力模块对通道和空间进行“注意”,达到目的。
卷积块注意力模块(CBAM)在卷积神经网络中的应用示例

通道(CA)模块

  如下图所示,通道子模块利用共享网络的最大池输出和平均池输出;空间子模块利用类似的两个输出,沿通道轴汇集并将它们转发到卷积层。
卷积块注意力模块(CBAM)在卷积神经网络中的应用示例


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        print("in_planes=", in_planes, "in_planes // ratio = ", in_planes // ratio)
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
                                nn.ReLU(),
                                nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))

        out = avg_out + max_out
        return self.sigmoid(out)

空间(SA)模块

  SA模块利用特征的空间间关系生成空间注意图。与通道注意不同,空间注意所关注的“在哪里”是一个信息部分,与通道注意相呼应。为了计算空间注意力,我们首先沿通道轴应用平均池和最大池操作,这样可以有效地突出显示信息域。
卷积块注意力模块(CBAM)在卷积神经网络中的应用示例

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

实操简介

  综上所述,我们已经知道了CA&SA模块的应用位置,我们在本文中采用的卷积模板为较为经典的Alexnet网络,大家可自行参考【基础实操】借用torch自带网络进行训练自己的图像数据中介绍到的训练方法,这里我们将延续该篇中的方式,仅对Alexnet部分进行修改匹配。

  在Alexnet中定义ca&sa模块,对self.features中的最后两层注释,便于由卷积层接CA&SA模块。具体的添加如下所示:

class AlexNet(nn.Module):

    def __init__(self, num_classes: int = 1000, planes=256, downsample=None) -> None:
        super(AlexNet, self).__init__()

        self.ca = ChannelAttention(in_planes=256)
        self.sa = SpatialAttention()
        self.downsample = downsample
        self.bn2 = nn.BatchNorm2d(planes)

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            # nn.ReLU(inplace=True),
            # nn.MaxPool2d(kernel_size=3, stride=2),

        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),

            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.features(x)
        out = self.bn2(out)
        out = self.ca(out) * out
        out = self.sa(out) * out
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out

结语

  使用CBMA模块可以保持较小的计算开销,能够得到显著的性能提升,这是我们在学习和工程上愿意看到的结果,也是推动项目进展的有利方法。由于本次写的匆忙,本文中还有阐述不清楚的地方,望各位路过的大神,不吝啬赐教!

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

实战多模态视频检索项目分享

2023-12-16 12:17:14

AI教程

ControlNet的论文原文解读和代码实现

2023-12-16 12:28:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索