PyTorch 介绍
PyTorch是由Feacbook开源,基于Torch二次开发的Python机器学习库,用于自然语言处理等应用程序。PyTorch既可以看作加入了GPU支持的numpy,同时也可以看成一个拥有自动求导功能的强大的深度神经网络。
PyTorch 环境搭建
1. 安装开发工具和python
anaconda下载,安装教程自行百度~~
2. 创建项目
-
选择 Conda 创建新环境
-
选择 python 版本
-
创建完成后,可以看到当前环境
3. 安装PyTorch环境
-
我们可以去PyTorch官方:https://pytorch.org/get-started/locally/ 进行选择,然后执行官方命令安装。
-
也可以通过pip直接安装PyTorch
# cpu用户 pip install torch==2.0.0 # gpu用户 # 只支持英伟达的显卡 # 需要提前安装 cuda ,我这里是11.7,所有下面指定cu117 # 如果已经安装了cuda不知道什么版本,可以通过 nvidia-smi 命令查看 pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu117
在终端执行
-
验证PyTorch是否安装成功
在python控制台执行
import torch
,没有报错说明安装成功 -
验证PyTorch是否支持GPU
在python控制台执行
torch.cuda.is_available()
,返回 True 说明可以运行在显卡上
4. 熟悉PyTorch都有什么
我们可以通过 dir
和 help
来熟悉
-
通过
dir
看torch都有哪些工具包 -
通过
help
看torch.cuda.is_available的用法和说明
PyTorch 学习
一、学习 Dataset
Dataset是什么东西,有什么作用
我们将使用此 数据集 进行演示(也可以用自己的~~)
-
由于是图片数据集,我们需要提前安装
opencv
pip install opencv-python==4.6.0.66
-
实现我们Dataset
需要重写三个方法
__init__
类被创建会执行(类似Java构造器),我们需要在这里初始化数据__getitem__
根据索引查询,返回 label 和 image__len__
返回 数据集长度from torch.utils.data import Dataset from torch.utils.data.dataset import T_co import os import cv2 as cv # 读取我们 label 文件的第一行内容 def read_label(path): file = open(path, "r", encoding='utf-8') label = file.readline() file.close() return label class MyDataset(Dataset): def __init__(self, train_path): # 给对象赋值,让其他方法也可以获取到 self.train_path = train_path self.image_path = os.path.join(train_path, 'image') self.label_path = os.path.join(train_path, 'label') # 拿到文件夹下所有图片名 self.image_path_list = os.listdir(self.image_path) def __getitem__(self, index) -> T_co: # 读取图片 image_name = self.image_path_list[index] image_path = os.path.join(self.image_path, image_name) img = cv.imread(image_path) # 读取图片对应的label label_name = 'txt'.join(image_name.rsplit(image_name.split('.')[-1], 1)) label_path = os.path.join(self.label_path, label_name) label = read_label(label_path) return img, label def __len__(self): # 返回数据集长度 return len(self.image_path_list) # 测试 创建 MyDataset 对象 my_dataset = MyDataset("dataset/train") # 拿到下标100的 image 和 label data_index = 100 img, label = my_dataset[data_index] # 展示出来 我们这里用到了 __len__ cv.imshow(label + ' (' + str(data_index) + '/' + str(len(my_dataset)) + ')', img) cv.waitKey(0)
效果如下:
二、下期预告
1. TensorBoard 的使用
2. Transforms 的使用
会尽快更新~~~