深度学习分类模型部署到移动端的方法

释放双眼,带上耳机,听听看~!
本文介绍了如何将深度学习分类模型部署到移动端,包括模型转换、目标检测和使用pytorch等内容。

本文 已参与「新人创作礼」活动,一起开启掘金创作之路。

想必大家做深度学习分类的时候经常会遇到一个疑问,分类的模型最后需要怎么用,今天就来简单的介绍下这个问题

1、先预览

深度学习分类模型部署到移动端的方法
其实也就是分析每一帧照片,来判断最大可能属于哪个类别,跟目标检测很像,目标检测判断一张照片里有哪几种物品,而分类仅仅是判断这个照片属于哪个类别

2、官网地址

深度学习分类模型部署到移动端的方法

深度学习分类模型部署到移动端的方法
我们可以从官网的这里找到pytroch部署到移动端的例子,当然,现在这个还不是很成熟,感觉自己训练的数据集和网络,可能会多少有点问题,建议使用官方提供的网络模型来训练

3、流程

深度学习分类模型部署到移动端的方法

4、加载库

implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'

5、代码分析

1、加载图片

bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));

2、加载模型


module = LiteModuleLoader.load(assetFilePath(this, "model.pt"));

3、准备Tensor

final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);

4、运行model

final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

5、得到分数集合

final float[] scores = outputTensor.getDataAsFloatArray();

6、获取物体标签

float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
  if (scores[i] > maxScore) {
    maxScore = scores[i];
    maxScoreIdx = i;
  }
}

String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

7、注意事项

这里的Demo加载的是需要pt模型数据,而我们Pytroch模型得到的数据都是pth的,如果直接使用,肯定是不行的,因此需要先转换后再使用,转换的方法如下所示

import torch
import torch.utils.data.distributed

# pytorch环境中
from models.base_model import BaseModel

model_pth = 'ghostnet.pth' #模型的参数文件
mobile_pt ='model.pt' # 将模型保存为Android可以调用的文件

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = BaseModel(name='resnet', num_classes=5).to(device)

model.load_state_dict(torch.load('ghostnet.pth', map_location=device), strict=False)

model.eval() # 模型设为评估模式

# 1张3通道224*224的图片
input_tensor = torch.rand(1, 3, 224, 224) # 设定输入数据格式

mobile = torch.jit.trace(model, input_tensor) # 模型转化
mobile.save(mobile_pt) # 保存文件

以上的例子,在官网中均有介绍,建议读者去官网下载好代码例子后跑通,在使用自己的数据集,目前感觉这个方法还是不太成熟,如果读者想做目标检测,比如YOLO的相关算法部署,小编在这里建议读者使用腾讯的NCNN这个推理框架,还是很好用的,我屡试不爽。

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

Stable Diffusion v1-5模型权重下载及使用教程

2023-12-14 11:14:14

AI教程

深度学习模型中的不变风险最小化陷阱

2023-12-14 11:25:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索