神奇的知识蒸馏:代码实战与温度对softmax的影响

释放双眼,带上耳机,听听看~!
本文介绍了知识蒸馏的实战代码和不同温度对softmax的影响,分享了两个好用的知识蒸馏代码库,以及通过代码实现不同温度下softmax的可视化。

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!

beginning

   
上期给大家介绍了知识蒸馏的核心原理,时间有点久不知道大家还记不记得。一句话概括就是——将教师模型的知识通过soft targets传递给轻量化的学生模型,从而提升学生模型性能,减少计算需求。还没看过或者忘记的小伙伴赶紧来看看叭➡从教师到学生:神奇的“知识蒸馏”之旅——原理详解篇。明白了原理之后,咱们今天就来实战一下,看看教师模型、学生模型是怎样用代码构建的,学习如何用知识蒸馏来提高学生模型的性能,让大家对蒸馏有一个更加直观的感受。除此之外,上期还有一些额外的知识蒸馏知识点没讲完,这次也一口气介绍完叭。废话不多说啦,如果你也对此感兴趣,想动手实现知识蒸馏看看效果,让我们一起愉快的学习叭🎈🎈🎈

神奇的知识蒸馏:代码实战与温度对softmax的影响

1.知识蒸馏代码实战

   
在介绍代码之前呢,给大家分享两个好用的知识蒸馏代码库:

   
第一个开源库包括剪枝、蒸馏、神经架构搜索和量化;第二个是大神发表的RepDistiller,里面有12种用pytorch实现的流行知识蒸馏算法。都对知识蒸馏的学习很有帮助滴🌈🌈🌈

1.1不同温度下softmax可视化

   
通过上期的学习,咱们知道了蒸馏温度T越高,soft targets就越soft,所以温度是至关重要滴,那首先咱们就来学着画一下不同温度对于softmax的影响叭

  1. 导入工具包:🎈🎈🎈
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

%config InlineBackend.figure_format = 'retina'
  1. 输入各类别的logits:🎈🎈🎈
logits = np.array([-5,2,7,9])

   
4个类别的logit,你可以理解成是神经网络最后一层的线性分类层输出的4个类别的logit,它们有正有负有大有小。

  1. 普通softmax( T=1 ):🎈🎈🎈
    softmax_1=np.exp(logits) / sum(np.exp(logits))
    softmax_1
    plt.plot(softmax_1,label='softmax_1')
    plt.legend()
    plt.show()

   
普通的softmax是蒸馏温度等于1,softmax_1=np.exp(logits) / sum(np.exp(logits))代表着把e−5+e2+e7+e9e^{-5}+e^2+e^7+e^9作为分母,e−5、e2、e7、e9e^{-5}、e^2、e^7、e^9分别作为分子算出来的各个数值,其代表了每一个softmax的后验概率。此时画出的图如下

神奇的知识蒸馏:代码实战与温度对softmax的影响

  1. 知识蒸馏softmax( T=3 ):🎈🎈🎈
plt.plot(softmax_1,label='T=1')

T=3
softmax_3 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_3,label='T=3')

T=5
softmax_5 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_5,label='T=5')

T=10
softmax_10 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_10,label='T=10')

T=100
softmax_100 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_100,label='T=100')

plt.xticks(np.arange(4), ['Cat', 'Dog','Donkey','Horse'])
plt.legend()
plt.show()

   
分别尝试温度T=3,T=5,T=10,T=100,画出它们的图如下所示,可以发现T越大,soft targets越soft,贫富差距就越小;T越小,两极分化就越大。所以T的选取很重要,若是过小的话就和没有蒸馏是一样的,过大又会陷入平均主义。

神奇的知识蒸馏:代码实战与温度对softmax的影响

1.2载入数据集

   
下面就以MNIST数据集为例,利用pytorch从头训练教师网络、从头训练学生网络,并用知识蒸馏训练学生网络比较性能

  1. 导入工具包:🎈🎈🎈
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transform
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm
#设置随机种子,便于复现
torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True
  1. 载入MNIST数据集:🎈🎈🎈
#载入数据集
train_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=True,
    transform=transforms.ToTensor(),
    download=True
)

#载入测试集
test_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=False,
    transform=transforms.ToTensor(),
    download=True
)

#生成dataloader
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

   
先导入工具包,一般的代码是要放在云gpu上的(最好是);然后载入训练集和测试集,生成训练集的DataLoader和测试集的DataLoader。

1.3构建并训练教师模型

构建教师模型:🎈🎈🎈

class TeacherModel(nn.Module):
    def __init__(self, in_channels=1,num_classes=10):
        super(TeacherModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784,1200)
        self.fc2 = nn.Linear(1200,1200)
        self.fc3 = nn.Linear(1200,num_classes)
        self.dropout = nn.Dropout(p=0.5)
        
    def forward(self, x):
        x = x.view(-1,784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)
        
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)
        
        x = self.fc3(x)
        
        return x

   
构造一个教师网络,这个教师网络有三层隐含层,每一层都加了dropout,防止过拟合。第一层是把输入的MNIST中784个像素映射到1200个神经元,第二层是把1200个神经元映射成1200个神经元,第三层是把1200个神经元映射成10个类别。

从头训练教师模型:🎈🎈🎈

model = TeacherModel()
model = model.to(device)
summary(model)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 6
for epoch in range(epochs):
    model.train()
    
    #训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        
        #前向预测
        preds = model(data)
        loss = criterion(preds, targets)
        
        #反向传播,优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    #测试集上评估模型性能
    model.eval()
    num_correct = 0
    num_samples = 0
    
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()
        
    model.train()
    print('Epoch:{}t Accuracy:{:.4f}'.format(epoch+1, acc))
teacher_model = model

   
首先指定一个交叉熵分类损失函数CrossEntropyLoss,指定优化器和学习率,开始训练6轮,每一次训练都是先前向再反向,每一轮之后再在测试集上评估模型的性能。运行之后看到如下所示的结果,准确率为0.9762(PS:其实这些代码都是很简单的基础知识,在前面的学习中也详解讲过啦,这里就不再细说辽)🌞🌞🌞

神奇的知识蒸馏:代码实战与温度对softmax的影响

1.4构建并训练学生模型

构建学生模型:🎈🎈🎈

class StudentModel(nn.Module):
    def __init__(self, in_channels=1,num_classes=10):
        super(StudentModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784,20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_classes)
        
    def forward(self, x):
        x = x.view(-1,784)
        x = self.fc1(x)
        x = self.relu(x)
        
        x = self.fc2(x)
        x = self.relu(x)
        
        x = self.fc3(x)
        
        return x

   
构建的学生模型就要小得多啦,它的每一层只有20个神经元,构建方法和上面的一样。

从头训练学生模型:🎈🎈🎈

model = StudentModel()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 3
for epoch in range(epochs):
    model.train()
    
    #训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        
        #前向预测
        preds = model(data)
        loss = criterion(preds, targets)
        
        #反向传播,优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    #测试集上评估模型性能
    model.eval()
    num_correct = 0
    num_samples = 0
    
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()
        
    model.train()
    print('Epoch:{}t Accuracy:{:.4f}'.format(epoch+1, acc))
student_model_scratch = model

   
从头训练学生模型和上面训练教师模型也是一样的,最后运行得到的结果如下,准确率只有0.8986,所以我们要用知识蒸馏来训练学生模型,提高它的性能。

神奇的知识蒸馏:代码实战与温度对softmax的影响

1.5知识蒸馏训练学生模型

知识蒸馏训练学生模型:🎈🎈🎈

#准备预训练好的教师模型
teacher_model.eval()

#准备新的学生模型
model = StudentModel()
model = model.to(device)
model.train()

#蒸馏温度
temp = 7
#hard_loss
hard_loss = nn.CrossEntropyLoss()
#hard_loss 权重
alpha = 0.3

# soft_loss
soft_loss = nn.KLDivLoss(reduction="batchmean")
optimizer = torch.optim.Adam(model.paramaters(), lr=1e-4)
epochs = 3
for epoch in range(epochs):
    
    #训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        
        #教师模型预测
        with torch.no_grad():
            teacher_preds = teacher_model(data)
        
        #学生模型预测
        student_preds = model(data)
        #计算hard_loss
        student_loss = hard_loss(student_preds, targets)
        
        #计算蒸馏后的预测结果及soft_loss
        ditillation_loss = soft_loss(
            F.softmax(student_preds / temp, dim=1),
            F.softmax(teacher_preds / temp, dim=1)
        )
        #将hard_loss和soft_loss加权求和
        loss = alpha * student_loss + (1-alpha) * ditillation_loss
        
        #反向传播,优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    #测试集上评估模型性能
    model.eval()
    num_correct = 0
    num_samples = 0
    
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()
        
    model.train()
    print('Epoch:{}t Accuracy:{:.4f}'.format(epoch+1, acc))

   
蒸馏温度选为7,hard_loss是一个普通的分类交叉熵损失函数,而soft_loss是一个KL散度(差不多也是交叉熵损失函数)。训练时也是先前向后反向,前向时先获取教师网络的预测结果,对教师网络预测结果进行蒸馏和softmax,然后把学生网络温度为temp和教师网络温度为temp时分别算出来softmax,一起作为soft_loss,算出来一个总的损失函数loss = alpha * student_loss + (1-alpha) * ditillation_loss。其余反向传播、评估性能等步骤和刚刚是一样的。运行得到的结果如下,可以看到准确率相比没蒸馏前有了提升(虽然提升不大,但这只是一个小demo,具体的还要进行调参优化)🌟🌟🌟

神奇的知识蒸馏:代码实战与温度对softmax的影响

   
其实,我们并不能用最后的分数来衡量它知识蒸馏是好是坏,因为知识蒸馏并不只是能涨点,并不只是能压缩模型提高性能,它还有很多潜在的好处,比如我们可以用海量的无监督的大数据集,可以防止过拟合,可以实现知识从大模型到小模型的迁移,这才是关于知识蒸馏我们要把握的点✨✨✨

2.知识蒸馏的补充知识

知识蒸馏为什么work???color{blue}{知识蒸馏为什么work???}

   
学完了知识蒸馏的原理与代码实现之后,小伙伴们有没有认真想一想知识蒸馏为什么会有用呢🧐🧐🧐比较让人信服的一个机理解释是:

神奇的知识蒸馏:代码实战与温度对softmax的影响

   
如上图,绿色是教师网络的求解空间,因为教师网络比较大嘛,所以它的表达能力和拟合能力比较强;学生网络是比较小的蓝色区域,它的表达能力比较差,求解空间比较小。训练教师网络之后,假如教师网络收敛到了红圈里面,如果我们单独训练学生网络(不蒸馏,直接用原来的数据集和标签),那么学生网络会收敛到黄色区域。毫无疑问,此时的学生网络和教师网络是有一定距离的,如果单纯的用hard label来训练学生网络,它是没法达到教师网络的水平滴;但我们加上知识蒸馏(橙色区域)之后,教师网络就会引导这个黄圈,告诉它怎么去收敛,那么它最终会收敛到这个橙圈里,而橙圈是教师网络的一个子集,它离原生的学生网络的收敛空间更接近,离教师网络越近,效果就越好😁😁😁

知识蒸馏与迁移学习???color{blue}{知识蒸馏与迁移学习???}

   
学完知识蒸馏后,小伙伴们有没有感觉和迁移学习很像,毕竟知识蒸馏是从教师模型迁移到学生模型上的,那它俩到底是一个什么关系腻🧐🧐🧐其实,知识蒸馏和迁移学习是没关系滴,它俩的概念是正交的,迁移学习指的是把一个领域训练的模型,让其泛化到另一个领域,比如说用X胸片的数据集去训练一个原本识别猫狗的模型,然后猫狗模型就慢慢学会去分辨x光胸片的各种病,这种把猫狗域迁移到了医疗域属于迁移学习(侧重于领域的迁移);而知识蒸馏是把一个模型的知识迁移到另一个模型上,通常是大模型迁移到小模型(侧重于模型的迁移)。所以这俩是可以交叉的,可以用知识蒸馏实现迁移学习……也可以完全没有任何关系

ending

   
看到这里相信盆友们都对如何用代码实现知识蒸馏有了一个全面深入的了解啦,小伙伴们学废了没呀👀很开心能把学到的知识以文章的形式分享给大家🌴🌴🌴如果你也觉得我的分享对你有所帮助,please一键三连嗷!!!下期见

神奇的知识蒸馏:代码实战与温度对softmax的影响

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

4个强大的生成AI Python库,助您提升人工智能应用

2023-11-20 9:34:14

AI教程

文档块索引与嵌入技术

2023-11-20 9:47:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索