两个引子:
最近学习RNN(循环神经网络),在进行Pytorch的代码实现时,发现文本数据的读取与图像数据的读取有较大的区别,通过上网查阅资料与文档并简要阅读原码,对Pytorch中Dataset、DataLoader、Sampler三个类进行了浅要分析。
另外,如何组织数据是一个很重要的问题,在SGD(随机梯度下降)的过程中,batch的大小对训练的速率有较大的影响,因此如何对数据集进行采样,也是一个重要的问题。
三者间的关系
废话不多说,先用一张图来解释他们间的关系
Pytorch官网文档对于这三个类相关介绍的第一句话就是:
At the heart of PyTorch data loading utility is the PyTorch 源码解读之 torch.utils.data:解析数据处理全流程 – 知乎 (zhihu.com)
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系 – 知乎 (zhihu.com) torch.utils.data — PyTorch 1.13 documentation