携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第28天,点击查看活动详情
前言
在之前的文章中,我们介绍了如何去自定义去完成关于ResNet这样的网络结构,VGGNet这样的网络结构,MobileNet这样的网络结构,以及Inception这样不同的四大类结构。实际上,在Pytorch中提供了非常多的已经定义好的模型,这些模型也是目前来说比较标准的网络结构,我们经常会利用这些标准的网络结构去作为我们的预训练的模型,这样就可以节省很多的工作,就不需要自己去自定义模型结构。
今天,我们通过调用Pytorch提供的标准网络ResNet18来完成Cifar10模型的训练。
-
1.1 调用Pytorch提供的标准网络
相比于之前自定义的网络结构,使用Pytorch提供的标准网络的代码量是比较少的,如果不需要对网络结构进行自己定义或者进行模型压缩、裁剪等操作的时候,推荐大家使用Pytorch提供的标准网络结构。
import torch.nn as nn
from torchvision import models
class resnet18(nn.Module):
def __init__(self):
super(resnet18, self).__init__()
self.model = models.resnet18(pretrained=True)
# 这里主要用来解决cifar10,需要修改类别数
self.num_features = self.model.fc.in_features
self.model.fc = nn.Linear(self.num_features, 10)
def forward(self, x):
out = self.model(x)
return out
def pytorch_resnet18():
return resnet18()
注:cifar10数据训练的代码参考我之前的文章Pytorch——Cifar10图像分类中的训练模型的代码,只需要修改一下net即可。
在进行cifar10的数据训练,可以看到在第一个epoch之后,准确率到了26%,并且整个网络是处于收敛过程中的,如果需要使用其它的网络结构的时候,也可以利用这个模板来调用其他的模型。