CLIP模型预训练和推理方法详解

释放双眼,带上耳机,听听看~!
本文详细介绍了OpenAI发表的《Learning Transferable Visual Models From Natural Language Supervision》论文中的CLIP模型预训练和推理方法,包括数据集构建、预训练方法以及对比学习和文本预测方法的性能比较。

前言

随着ChatGPT的横空出现,AIGC迅速成为当下最热门的技术领域,但是构建ChatGPT的底层相关研究——NLP、多模态、大模型已陆续发展多年。作为一位非算法的研发工程师,本着持续学习、保持进步的初心,也计划对相关论文进行系统的梳理、阅读,计划包含以下部分:

  • Transformer;
  • BERT;
  • GPT系列论文;
  • CLIP;
  • Diffusion;

本文是对CLIP论文的阅读笔记。

《Learning Transferable Visual Models From Natural Language Supervision》是OpenAI于2021年发表的一篇论文,其中发布了CLIP模型,该模型基于大规模的图片、自然语言文本(即论文标题中的Natural Language Supervision)对数据集进行预训练,挖掘图片和文本之间的相关性,所得的模型具有极强的迁移学习能力(即论文标题中的Transferable Visual Models),可直接用于各种图片分类任务,并相对于已有监督学习模型,取得较高的准确度。

介绍

论文首先提到,基于大规模数据集预训练大模型近几年来在NLP领域有着飞速发展,这些生成式模型在各种NLP任务上表现出了Zero-Shot的迁移能力。因此,作者也想把这一思路应用到计算机视觉领域。
首先,作者构造了一个新的计算机视觉领域大规模数据集,其中包含了4亿条图片和自然语言文本对。然后,作者基于这个数据集,采用对比学习,预训练了CLIP(Contrastive Language-Image Pre-training)模型,该模型包含了文本编码器和图片编码器两部分,将文本和图片映射到同一个向量空间中,而模型训练目标就是最大化正确图片和文本对之间的相关性。最后,在推理阶段,作者通过提示工程将各分类任务类别转化为文本,计算图片和文本之间的相关性,选择最相关文本所对应类别作为类别预测从而实现Zero-Shot。作者进行了大量实验,充分论证了CLIP在Zero-Shot上的有效性。

方法

数据集构建

作者并没有采用计算机视觉领域常用的标注数据集,比如ImageNet,而是从互联网上爬取图片和自然语言文本对。具体爬取流程:先从英文维基百科收集至少出现100次的单词作为Query数据集(共50万个),再对每个Query收集最多2万个图片、自然语言文本对,最终共收集了4亿个图片、自然语言文本对,该数据集量级远大于千万级的ImageNet。

预训练方法

CLIP模型预训练和推理方法详解

CLIP模型预训练和推理如图1所示。由于描述图片的自然语言文本多样,根据图片预测文本的方法较难,因此,论文采用对比学习进行预训练,即图1左侧所示,对于一批给定的NN个图片和文本,预测全组合后的N×NN times N个图片、文本对的相关性。通过对比学习,预训练图片编码器和文本编码器,将图片和文本转化为Embedding向量,最大化真实的NN个图片、文本对中图片和文本Embedding向量的余弦相似度,最小化这个批次中其他不正确的N2−NN^2-N个图片、文本对中图片和文本Embedding向量的余弦相似度。

CLIP模型预训练和推理方法详解

论文比较了对比学习和文本预测这两种方法的性能,如图2所示,其中横坐标表示训练处理的图片数量,纵坐标表示在ImageNet数据集以Zero-Shot方式进行推理的准确度,绿色和黄色曲线分别表示对比学习和文本预测随着处理图片数量的增加、准确度的提升,从中可以看出,当准确度相同时,文本预测图片数量是对比学习的4倍,也就是说对比学习的学习效率是文本预测的4倍。
图1左侧的预训练可以通过以下的伪代码描述:

# image_encoder - ResNet or Vision Transformer
# text_encoder 	- CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] 		- minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t 			- learned temperature parameter

# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]

# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)

# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)

# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss   = (loss_i + loss_t)/2

伪代码中:

  • I[n, h, w, c],表示每批图片集合,4维矩阵,n为图片数,h为图片高度,w为图片宽度,c为图像通道数(如RGB三通道);
  • T[n, l],表示每批文本集合,2维矩阵,n为文本数,l为文本长度;
  • image_encoder,表示图片编码器,将I[n, h, w, c]编码输出得到I_f[n, d_i],2维矩阵,即将每张图片编码输出为d_i维的向量;
  • text_encoder,表示文本编码器,将T[n,l]编码输出得到T_f[n, d_t],2维矩阵,即将每个文本编码输出为d_t维的向量;
  • np.dot(I_f, W_i),表示通过I_f[n, d_i]与W_i[d_i, d_e]的点积运算将每个图片从d_i维的向量空间线性投影到d_e维的向量空间,再使用l2_normalize函数按行对每个图片的向量进行归一化,从而得到当前批次每个图片在多模态空间下的Embedding向量I_e[n, d_e];
  • np.dot(T_f, W_t),表示通过T_f[n, d_t]与W_t[d_t, d_e]的点积运算将每个文本从d_t维的向量空间线性投影到d_e维的向量空间,再使用l2_normalize函数按行对每个文本的向量进行归一化,从而得到当前批次每个文本在多模态空间下的Embedding向量T_e[n, d_e];
  • np.dot(I_e, T_e.T) * np.exp(t),表示通过I_e与T_e转置的点积运算得到当前批次图片和文本Embedding向量两两之间的余弦相似度矩阵logits[n, n],并通过ete^t这一可学习的参数控制相似度的大小;
  • 最后采用对称损失函数;
  • labels = np.arange(n),表示当前批次中相同序号的图片和文本成对,所以输出[0,1,…,n-1]的标量;
  • 伪代码最后三行使用logits和labels计算损失函数。损失函数使用的是对比学习中常用的InfoNCE损失函数。

对于上述损失函数,我的理解是先按列通过Softmax函数将相似度转化为概率,即对于一个文本jj,计算其与各个图片ii配对的概率:
pi,j=exp⁡(⟨Ii,Tj⟩)∑k=1Nexp⁡(⟨Ik,Tj⟩)p_{i,j}=frac{exp{(langle I_i,T_jrangle)}}{sum_{k=1}^{N}{exp(langle I_k,T_jrangle)}}
其中,⟨⟩langle rangle表示去取两个向量的余弦相似度,计算示意如图3所示。

CLIP模型预训练和推理方法详解

然后按列对于每个文本jj,取序号和其相等的图片jj的配对(即真实配对)概率pj,jp_{j,j},类似负对数似然损失函数,即最小化以下损失函数:

L1=−1N∑j=1Nlog⁡pj,jmathcal{L}_1=-frac{1}{N}sum_{j=1}^{N}{log{p_{j,j}}}

即伪代码中的loss_i = cross_entropy_loss(logits, labels, axis=0)。
再按行通过Softmax函数将相似度转化为概率,即对于一个图片ii,计算其与各个文本jj配对的概率:

pi,j=exp⁡(⟨Ij,Ti⟩)∑k=1Nexp⁡(⟨Ij,Tk⟩)p_{i,j}=frac{exp{(langle I_j,T_irangle)}}{sum_{k=1}^{N}{exp(langle I_j,T_krangle)}}

计算示意如图4所示。

CLIP模型预训练和推理方法详解

然后再按行对于每个图片ii,取序号和其相等的文本ii的配对(即真实配对)概率pi,ip_{i,i},同样类似负对数似然损失函数,最小化以下损失函数:

L2=−1N∑i=1Nlog⁡pi,imathcal{L}_2=-frac{1}{N}sum_{i=1}^{N}{log{p_{i,i}}}

即伪代码中的loss_t = cross_entropy_loss(logits, labels, axis=1)。最后将上述两个损失函数取均值得到最终的损失函数,即伪代码中的loss = (loss_i + loss_t)/2。

模型选择

对于图片编码器,论文使用了两种模型架构,分别是计算机视觉领域常用的深度残差网络(ResNet)和最近于2021年由Google提出的基于Transformer的Vision Transformer(ViT)。

CLIP模型预训练和推理方法详解

ViT模型结构如图5所示,其将图片按块拆分成多个patch,每个patch类似Transformer中的一个token,通过对patch及其位置进行Embedding计算转化为向量输入Transformer的Encoder。另外,在Transformer的Encoder的输入中,ViT会额外增加一个分类token,用于汇总所有patch的信息,将该token对应的最终Encoder的输出作为图片表征再输入到分类头中用于计算图片的分类。
对于文本编码器,论文使用了Transformer,其编码器层数为12,维度为512,多头注意力头数为8,模型参数量为6300万。对于文本,论文采用BPE(Byte Pair Encoding)编码,编码后的字典量级为49152。论文限制文本序列长度最大为76,并在文本序列的开头和结尾加上固定的token——“[SOS]”和“[EOS]”,并将“[EOS]”对应的最终Encoder的输出作为文本表征线性投影到多模态Embedding向量空间。

训练

对于图片编码器,论文分别训练了5个ResNet模型和3个ViT模型。
对于RestNet模型,论文训练了ResNet-50、ResNet-101和另外3个EfficientNet,其模型规模分别是ResNet-50的4倍、16倍和64倍,被分别称为RN50x4、RN50x16和RN50x64。
对于ViT模型,论文训练了ViT-B/32、ViT-B/16和ViT-L/14。
论文在这里介绍了训练的一些细节,比如所有模型训练32个epoch,采用Adam优化器等,不再一一列出。
最大的RestNet模型——RN50x64,在592个V100 GPU实例上需要训练18天,而最大的ViT模型——ViT-L/14,在256个V100 GPU实例上需要训练12天。
对于ViT-L/14,论文还使用更高分辨率(336×336)数据集额外进行了一个epoch的训练,并将训练所得的模型称为ViT-L/14@336px。由于ViT-L/14@336px的效果最好,因此论文指出在后面实验部分所指的CLIP模型均为ViT-L/14@336px。

实验

Zero-Shot

Zero-Shot,在计算机视觉领域一般表示模型能泛化、识别未见过的类别。论文在这里使用该术语表示模型识别未见过的数据集,即对于计算机视觉领域的各分类任务,不再使用任务本身的监督数据集训练模型,而是使用CLIP模型直接在分类任务的测试数据集上进行推理。

Zero-Shot实现

Zero-Shot的实现如图1右侧所示,使用分类任务中的所有类别名称按照模版构造文本,比如类别名称“plane”、“dog”,按照模版“A photo of a {label}.”构造文本“A photo of a plane.”、“A photo of a dog.”。将构造的多个文本输入文本编码器得到文本Embedding向量。将需预测分类的图片输入图片编码器得到图片Embedding向量。对各文本和图片Embedding向量计算余弦相似度,并通过Softmax函数转化为概率。概率最高的文本所包含的类别即对图片类别的预测。

和Visual N-Grams的对比

在CLIP之前,实现Zero-Shot的模型是Visual N-Grams。论文对比了CLIP和Visual N-Grams,如图6所示,CLIP在ImageNet及其他两个数据集的准确度上大幅提升,特别是在ImageNet上,准确度从11.5%提升至76.5%。

CLIP模型预训练和推理方法详解

论文解释了CLIP相比Visual N-Grams在准确度上提升明显的原因,一是训练数据集扩大了10倍,模型规模扩大了100倍,更多数据和更大模型带来模型性能提升,二是Transformer的引入。

提示构造和组合

论文在“Zero-Shot实现”部分介绍了通过类似“A photo of a {label}._”的模版构造文本(论文称之为提示),而不是直接采用类别名称作为文本。这么做的原因一是一词多义,如果文本编码器的输入只是类别名称会由于缺乏上下文导致含义不明确,比如Oxford-IIT中的“boxer”在该数据集下表示为一种宠物,但如果没有上下文,可能会被理解为拳击手。另外一个原因是,预训练数据集中图片文本极少是一个词,通过提示能够避免训练和预测时文本的差异,_从而提升Zero-Shot性能,论文指出,相比只使用类别名称,使用提示在ImageNet上的准确度能够提升1.3%。
另外,论文指出可以针对不同的任务设计相应的提示模版,例如对于宠物数据集Oxford-IIIT,可以使用以下模版:“A photo of a {label}, a type of pet.”,而对于卫星照片,可以使用以下模版:“a satellite photo of a {label}.”。
论文还提到可以同时使用多个提示模版,比如对于类别名称使用“A photo of a big {label}.”和“A photo of a small {label}”分别构造两个提示,提示通过文本编码器输出两个Embedding向量后,在向量空间进行合并,再和图片Embedding向量计算相似度。论文在ImageNet上同时使用了80个提示模版,准确度相比只使用一个模版提升3.5%。

Zero-Shot性能

为了分析CLIP Zero-Shot的性能,论文这里选择ResNet-50进行Linear Probe方式的监督学习作为基线模型,即将ResNet-50最后几层所学到的特征输入到一个分类器中,然后针对特定数据集进行微调(仅微调分类器的参数)。在27个任务的准确度上,Zero-Shot CLIP和Linear Probe RestNet-50的对比如图7所示,其中,绿色表示Zero-Shot CLIP相对Linear Probe RestNet-50,准确度提升及提升幅度,蓝色表示Zero-Shot CLIP相对Linear Probe RestNet-50,准确度下降及下降幅度。

CLIP模型预训练和推理方法详解

在16个任务(超过半数)的准确度上,Zero-Shot CLIP超过监督学习的ResNet-50。
论文对于各任务的准确度结果进行了分析,这里不一一列出,只介绍一下论文对CLIP准确度相对不高的任务的分析。CLIP在一些专业、复杂和抽象的任务上表现不佳,比如卫星图像识别(EuroSAT和RESISC45)、淋巴肿瘤探测(PatchCamelyon)、物体个数检测(CLEVRCounts)、交通信号识别(GTSRB)、车距识别(KITTI Distance)等。类似淋巴肿瘤探测(PatchCamelyon)这样的专业任务即使对于人类也需要一些先验知识,因此论文对CLIP进行了监督学习微调,即Few-Shot。具体方法是采用Linear Probe方式,对CLIP和其他模型(目前在ImageNet上表现最好的模型BiT-M、SimCLRv2和ResNet-50)增加分类器头并对分类器头进行监督学习微调,微调时选取了27个任务中的20个任务(这20个任务能保证每个类别至少包含16个训练样本),对每个类别选取一定数量的训练样本进行微调。结果如图8所示,其中横轴是每个类别选取的训练样本的数量,纵轴是模型在20个任务上的平均准确度,整体上CLIP在各训练样本数量下的准确度在所有模型中都是最高的。

CLIP模型预训练和推理方法详解

但从图8可以看出,对于CLIP,每个类别只使用1个或2个训练样本进行监督学习微调,准确度低于Zero-Shot,每个类别使用4个训练样本进行监督学习微调,准确度才和Zero-Shot持平。直觉上Few-Shot一般都要好于Zero-Shot,而对于上述在1个或2个训练样本上相反的结果,论文给出的解释是因为增加了分类器头,CLIP需要足够多训练样本的监督学习才能基于训练样本进行间接推理。
论文进一步实验,对于27个任务,每个任务在进行Few-Shot时,其每个类别需要多少个训练样本,准确度才能和Zero-Shot持平,结果如图9所示,其中,一半的数据集每个类别只需要5个以内的训练样本,中位数是5.4个训练样本,平均数是20.8个训练样本,而在ImageNet上,Zero-Shot和16-Shot的准确度持平。

CLIP模型预训练和推理方法详解

论文还实验,对于27个任务,如果使用所有训练样本进行Linear Probe方式的微调,大部分情况下,CLIP模型在准确度上Few-Shot优于Zero-Shot。
最后论文还使用ResNet-50、ResNet-101、RN50x4、RN50x16、RN50x64论证了模型越大,Zero-Shot效果越好,如图10所示。

CLIP模型预训练和推理方法详解

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

深度学习在飞机故障检测中的应用

2023-11-24 9:51:00

AI教程

使用PyTorch解决多分类问题:构建、训练和评估深度学习模型

2023-11-24 9:55:00

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索