3.1 torch.utils.data.Dataset
数据交互模块——Dataset
虽然说pytorch数据模块的核心是DataLoader,但是对于使用者而言,改动最多的、与源数据最接近的是Dataset, 本小节就详细分析Dataset的作用,并通过三个案例学习如何编写自定义Dataset来读取自己的数据集。
Dataset的功能
pytorch提供的torch.utils.data.Dataset类是一个抽象基类An abstract class representing a Dataset
,供用户继承,编写自己的dataset,实现对数据的读取。在Dataset类的编写中必须要实现的两个函数是__getitem__和__len__(由于markdown语法问题,后续双下划线就省略了)。
getitem:需要实现读取一个样本的功能。通常是传入索引(index,可以是序号或key),然后实现从磁盘中读取数据,并进行预处理(包括online的数据增强),然后返回一个样本的数据。数据可以是包括模型需要的输入、标签,也可以是其他元信息,例如图片的路径。getitem返回的数据会在dataloader中组装成一个batch。即,通常情况下是在dataloader中调用Dataset的getitem函数获取一个样本。
len:返回数据集的大小,数据集的大小也是个最要的信息,它在dataloader中也会用到。如果这个函数返回的是0,dataloader会报错:"ValueError: num_samples should be a positive integer value, but got num_samples=0"
这个报错相信大家经常会遇到,这通常是文件路径没写对,导致你的dataset找不到数据,数据个数为0。
了解Dataset类的概念,下面通过一幅示意图,来理解Dataset与DataLoader的关系。
dataset负责与磁盘打交道,将磁盘上的数据读取并预处理好,提供给DataLoader,而DataLoader只需要关心如何组装成批数据,以及如何采样。采样的体现是出现在传入getitem函数的索引,这里采样的规则可以通过sampler由用户自定义,可以方便地实现均衡采样、随机采样、有偏采样、渐进式采样等,这个留在DataLoader中会详细展开。
此处,先分析Dataset如何与磁盘构建联系。
从2.2节的例子中可以看到,我们为COVID19Dataset定义了一个_get_img_info函数,该函数就是用来建立磁盘关系的,在这个函数中收集并处理样本的路径信息、标签信息,存储到一个list中,供getitem函数使用。getitem函数只需要拿到序号,就可获得图片的路径信息、标签信息,接着进行图片预处理,最后返回一个样本信息。
希望大家体会_get_img_info函数的作用,对于各种不同的数据形式,都可以用这个模板实现Dataset的构建,只需将数据信息(路径、标签等)收集并存储至列表中,供__getitem__
函数使用”。
三个Dataset案例
相信大家在做自己的任务时,遇到的第一个问题就是,怎么把自己的数据放到github的模型上跑起来。很多朋友通常会把自己的数据调整为与现成项目数据一模一样的数据形式,然后执行相关代码。这样虽然快捷,但缺少灵活性。
为了让大家能掌握各类数据形式的读取,这里构建三个不同的数据形式进行编写Dataset。
第一个:2.2中的类型。数据的划分及标签在txt中。
第二个:数据的划分及标签在文件夹中体现
第三个:数据的划分及标签在csv中
详情请结合 配套代码,深刻理解_get_img_info及Dataset做了什么。
代码输出主要有两部分,
第一部分是两种dataset的getitem输出。
第二部分是结合DataLoader进行数据加载。
先看第一部分,输出的是 PIL对象及图像标签,这里可以进入getitem函数看到采用了
img = Image.open(path_img).convert('L')
对图片进行了读取,得到了PIL对象,由于transform为None,不对图像进行任何预处理,因此getitem函数返回的图像是PIL对象。
2 (
, 1) 2 ( , 1)
第二部分是结合DataLoader的使用,这种形式更贴近真实场景,在这里为Dataset设置了一些transform,有图像的缩放,ToTensor, normalize三个方法。因此,getitem返回的图像变为了张量的形式,并且在DataLoader中组装成了batchsize的形式。大家可以尝试修改缩放的大小来观察输出,也可以注释normalize来观察它们的作用。
0 torch.Size([2, 1, 4, 4]) tensor([[[[-0.0431, -0.1216, -0.0980, -0.1373], [-0.0667, -0.2000, -0.0824, -0.2392], [-0.1137, 0.0353, 0.1843, -0.2078], [ 0.0510, 0.3255, 0.3490, -0.0510]]],
[[[-0.3569, -0.2863, -0.3333, -0.4118], [ 0.0196, -0.3098, -0.2941, 0.1059], [-0.2392, -0.1294, 0.0510, -0.2314], [-0.1059, 0.4118, 0.4667, 0.0275]]]]) torch.Size([2]) tensor([1, 0])
关于transform的系列方法以及工作原理,将在本章后半部分讲解数据增强部分再详细展开。
小结
本小结介绍了torch.utils.data.Dataset类的结构及工作原理,并通过三个案例实践,加深大家对自行编写Dataset的认识,关于Dataset的编写,torchvision也有很多常用公开数据集的Dataset模板,建议大家学习,本章后半部分也会挑选几个Dataset进行分析。下一小节将介绍DataLoader类的使用。
补充学习建议
- IDE的debug: 下一小节的代码将采用debug模式进行逐步分析,建议大家提前熟悉pycharm等IDE的debug功能。
- python的迭代器:相信很多初学者对代码中的“next(iter(train_set))”不太了解,这里建议大家了解iter概念、next概念、迭代器概念、以及双下划线函数概念。