深度学习蒸馏两种蒸馏方式详解

释放双眼,带上耳机,听听看~!
本文将详细解释深度学习中的蒸馏机制,包括基础蒸馏、自适应蒸馏等方法,并提供实际案例和代码实现,帮助读者更好地理解和应用蒸馏机制。

关键词:轻量化 、 蒸馏 、 基础蒸馏 、 负责蒸馏 、 自适应蒸馏 、自适应蒸馏

导读

  深度学习中的蒸馏机制是一种有效的模型压缩方法,可以将大型神经网络压缩为小型神经网络,同时可以最大限度的保持原始网络的性能。在实际应用中,蒸馏机制可以帮助我们在资源受限的设备上部署更加高效的神经网络模型,从而在保证准确性的前提下提升计算效率。本文将会深入探讨蒸馏机制的工作原理、应用场景和实现方法,并为读者提供实际案例和代码实现,帮助大家更好地理解和应用蒸馏机制。

前言

  在本文中的实验数据集为花朵数据集,该数据集类别数经过筛选总共为5类,数据集划分为训练集与测试集,训练:测试 = 500 : 150。

深度学习蒸馏两种蒸馏方式详解

深度学习蒸馏两种蒸馏方式详解
   在网络的选择方面,我们选择高精度的大模型为Vgg11,小模型为轻量级的
shufflenet_v2_x0_5。这两个模型在torchvisionmodels中都可以找到,有现成的,不需要我们自己手撸代码。这两个权重的大小分别为:

        深度学习蒸馏两种蒸馏方式详解

实操蒸馏

   常见的蒸馏有四大类:基础蒸馏、负责蒸馏、激励蒸馏和自适应蒸馏,下面我们会对这四大类蒸馏进行实操演示。

基础模块

   在实操之前,我们有一些代码是公共的,这里包括数据集的获取,网络的定义及调用,可以构建一个【init.py】文件进行定义基础公共代码部分:

import torch
import torchvision
from torch.utils.data import DataLoader
from shufflenetv2 import shufflenet_v2_x0_5
from vgg import vgg11
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_Data():
    train_root = 'D:/Data_ALL/datas/train/'
    test_root = 'D:/Data_ALL/datas/test/'
    # 将文件夹的内容载入dataset
    train_dataset = torchvision.datasets.ImageFolder(root=train_root, transform=torchvision.transforms.ToTensor())
    test_dataset = torchvision.datasets.ImageFolder(root=test_root, transform=torchvision.transforms.ToTensor())

    train_dataloader = DataLoader(train_dataset, batch_size=12,shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=8)

    train_num = len(train_dataset)
    test_num = len(test_dataset)
    return train_dataloader, test_dataloader, train_num, test_num

def get_net():
    model_t = vgg11(pretrained=False, num_classes=5).to(device)
    model_s = shufflenet_v2_x0_5(pretrained=False, num_classes=5).to(device)
    return model_t, model_s


def Soomth_temp(init_temp, global_step, decay_steps=0.07):

    temp = init_temp - (global_step * decay_steps)
    if temp < 1:
        temp = 1
    return temp

测试模块

   同时再构建【demo.py】函数对核心代码部分的main部分进行设定好损失函数以及优化器和读取init.py数据。


def test(dataloader):
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for test_image, test_label in dataloader:
            test_image, test_label = test_image.to(device), test_label.to(device)
            test_pred = model_s(test_image)
            test_loss += criterion(test_pred, test_label).item()
            correct += (test_pred.argmax(1) == test_label).type(torch.float).sum().item()

    test_loss /= test_num
    correct /= size
    print(f"Test Error: n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} n")


if __name__ == "__main__":

    # 设置网络的部分
    model_t, model_s = get_net()
    # 设置数据集的部分
    train_dataloader, test_dataloader, train_num, test_num = get_Data()
    # 定义温度参数T
    T = 10
    # 定义交叉熵损失函数
    criterion = nn.CrossEntropyLoss()
    # 定义优化器
    optimizer = optim.Adam(net_s.parameters(), lr=0.001)
    # 设置超参数的部分
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.94)    # LR变动的学习率
    epochs = 100
    for epoch in range(epochs):
        train(train_dataloader)
        test(test_dataloader)

   需要注意的是,在训练结束后,我们需要使用小型模型进行预测,而不是使用带有温度参数的“软目标”进行预测。

基础蒸馏

   基础蒸馏是最简单的蒸馏机制,它通过让小型模型学习大型模型的预测结果来传递知识。在这种方法中,大型模型的输出被视为“软目标”,并被用作小型模型的训练目标。这种方法的优点是简单易用,但是它通常需要使用较高的温度参数,导致模型的精度不如其他方法。

from torch import optim, nn
from init import *
import torch.nn.functional as F

# 定义训练函数
def train(dataloader):
    model_s.train()
    for batch, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model_s(inputs)

        # 对大型模型输出进行softmax转换并缩放温度参数
        teacher_outputs = model_t(inputs).detach()
        teacher_outputs = nn.functional.softmax(teacher_outputs / T, dim=1)

        # 计算损失函数,其中使用“软目标”作为辅助损失
        loss = criterion(outputs, labels) + nn.KLDivLoss()(nn.functional.log_softmax(outputs / T, dim=1),
                                                           teacher_outputs) * T * T

        # 反向传播更新参数
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss = loss.item()
            print(f"loss: {loss:>7f}")

   在训练函数中,我们首先将大型模型的输出进行softmax转换,并使用温度参数对其进行缩放,从而得到“软目标”的分布。然后,我们使用“软目标”作为辅助损失来帮助小型模型进行训练。最后,我们将两个损失函数相加,得到最终的损失函数,并使用反向传播更新参数。在训练过程中,我们使用KL散度来衡量小型模型输出分布与“软目标”分布之间的距离,从而提高小型模型的泛化能力。

负责蒸馏

   负责蒸馏是一种基于网络层次结构的蒸馏方法。它的基本思想是让小型模型学习大型模型中的“特定”层次信息,如中间层的特征图。这些信息可以被看作是大型模型的“中间目标”,并被用作小型模型的训练目标。这种方法的优点是可以更精确地传递知识,从而获得更高的模型精度。

def train(dataloader):
    alpha = 0.5
    model_s.train()
    for batch, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model_s(inputs)

        # 对大型模型输出进行softmax转换并缩放温度参数
        teacher_outputs = model_t(inputs).detach()
        teacher_outputs = nn.functional.softmax(teacher_outputs / T, dim=1)

        # 计算损失函数,其中使用“软目标”作为辅助损失
        loss = (1 - alpha) * criterion(outputs, labels) + alpha * nn.KLDivLoss()(
            nn.functional.log_softmax(outputs / T, dim=1), teacher_outputs) * T * T

        # 反向传播更新参数
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss = loss.item()
            print(f"loss: {loss:>7f}")

   我们使用VGG11作为大型模型,ShuffleNet V2 x0.5作为小型模型。在训练函数中,我们首先对大型模型的输出进行softmax转换,并使用温度参数对其进行缩放,从而得到“软目标”的分布。然后,我们对小型模型的输出进行相同的操作。接着,我们通过对所有层的输出计算KL散度来定义辅助损失函数,并将所有辅助损失函数相加,得到总辅助损失。最后,我们将交叉熵损失函数和总辅助损失加权相加,得到最终的损失函数,并使用反向传播更新参数。在训练过程中,我们使用KL散度来衡量小型模型输出分布与“软目标”分布之间的距离,并使用权重参数。

激励蒸馏

   激励蒸馏是一种基于激活函数的蒸馏方法,它通过让小型模型学习大型模型中的激活函数来传递知识。具体而言,我们可以通过让小型模型学习大型模型中的激活函数输出,来获得更加精确的模型压缩效果。这种方法的优点是可以更精确地控制模型的软硬程度,从而获得更好的模型压缩效果。

alpha = 0.5
def train(dataloader):
    model_s.train()
    model_t.eval()
    for batch, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model_s(inputs)

        # 获取大模型的中间层激活
        with torch.no_grad():
            teacher_outputs = model_t.features(inputs)  # 大家可以自行寻找

        # 获取小模型的中间层激活
        student_outputs = model_s.features(inputs)  # 大家可以自行寻找

        # 定义辅助损失函数
        loss = 0
        for i in range(len(teacher_outputs)):
            # 对大模型和小模型的中间层激活进行softmax转换并缩放温度参数
            teacher_activation = nn.functional.softmax(teacher_outputs[i] / T, dim=1)
            student_activation = nn.functional.softmax(student_outputs[i] / T, dim=1)
            # 计算中间层激活的KL散度作为辅助损失
            loss += nn.KLDivLoss()(student_activation, teacher_activation.detach()) * T * T

        # 计算主要损失函数
        loss += (1 - alpha) * criterion(outputs, labels)

        # 反向传播更新参数
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss = loss.item()
            print(f"loss: {loss:>7f}")

自适应蒸馏

   自适应蒸馏是一种基于动态温度控制的蒸馏方法,它通过动态调整温度参数来实现更好的模型压缩效果。具体而言,我们可以根据小型模型的训练进展来自适应地调整温度参数,从而达到更好的模型压缩效果。这种方法的优点是可以自适应地调整模型的软硬程度,从而获得更好的压缩效果。

alpha = 0.5
def train(dataloader):
    model_s.train()
    model_t.eval()
    for batch, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model_s(inputs)

        # 获取大模型的中间层激活
        with torch.no_grad():
            teacher_outputs = model_t.features(inputs)

        # 获取小模型的中间层激活
        student_outputs = model_s.features(inputs)

        # 定义辅助损失函数
        loss = 0
        for i in range(len(teacher_outputs)):
            # 计算大模型和小模型中间层激活的L2范数,作为动态的温度参数
            T = torch.norm(teacher_outputs[i] - student_outputs[i], p=2) / torch.numel(teacher_outputs[i])

            # 对大模型和小模型的中间层激活进行softmax转换并缩放温度参数
            teacher_activation = nn.functional.softmax(teacher_outputs[i] / T, dim=1)
            student_activation = nn.functional.softmax(student_outputs[i] / T, dim=1)

            # 计算中间层激活的KL散度作为辅助损失
            loss += nn.KLDivLoss()(student_activation, teacher_activation.detach()) * T * T

        # 计算主要损失函数
        loss += criterion(outputs, labels)

        # 反向传播更新参数
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss = loss.item()
            print(f"loss: {loss:>7f}")

   在上述示例中,我们使用了 PyTorch 的 vgg11shufflenet_v2_x0_5 模型,并通过 features 属性获取了这些模型的中间层激活。我们通过计算大模型和小模型中间层激活的L2范数来动态地确定温度参数T。

结语

   蒸馏机制是一种有效的深度学习模型压缩技术,它可以将一个复杂的模型压缩成一个简单的模型,同时保持高精度。在实际应用中,蒸馏机制已经被广泛应用于各种场景,例如移动端深度学习模型的压缩、模型部署以及高效模型训练等方面。不同的蒸馏机制适用于不同的应用场景,需要根据具体的应用需求进行选择

*本文正在参加 人工智能创作者扶持计划 ” *

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

BEiT-3:Image as a Foreign Language - 多模态模型性能数据炸裂解读

2023-12-16 10:56:14

AI教程

世界杯首场大冷门!沙特2:1反超阿根廷背后的科技与狠活

2023-12-16 11:06:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索