3.5 torchvision 经典dataset学习
前面已经学习了Dataset,DataLoader,以及常用的函数,通常足以应对大多数需求,但距离熟练编写自己的Dataset可能还有一段距离。 为了让大家能轻松掌握各种情况下的dataset编写,本小节对torchvision中提供的几个常见dataset进行分析,观察它们的代码共性,总结编写dataset的经验。
X-MNIST
由于MNIST数据使用广泛,在多领域均可基于这个小数据集进行初步的研发与验证,因此基于MNIST数据格式的各类X-MNIST数据层出不穷,在mnist.py文件中也提供了多个X-MNIST的编写,这里需要大家体会类继承。
示例表明FashionMNIST、KMNIST两个dataset仅需要修改数据url(mirrors、resources)和类别名称(classes),其余的函数均可复用MNIST中写好的功能,这一点体现了面向对象编程的优点。
来看dataset的 getitem,十分简洁,因为已经把图片和标签处理好,存在self.data和self.targets中使用了:
def __getitem__(self, index: int) -> Tuple[Any, Any]:
img, target = self.data[index], int(self.targets[index])
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
代码参阅:D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torchvision\datasets\mnist.py
cifar-10
cifar-10是除MNIST之外使用最多的公开数据集,同样,让我们直接关注其 Dataset
实现的关键部分
def __getitem__(self, index: int) -> Tuple[Any, Any]:
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
核心代码还是这一行: img, target = self.data[index], self.targets[index]
接下来,去分析data和self.targets是如何从磁盘上获取的?通过代码搜索可以看到它们来自这里(D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torchvision\datasets\cifar.py CIFAR10 类的 init函数):
# now load the picked numpy arrays
for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.targets.extend(entry['labels'])
else:
self.targets.extend(entry['fine_labels'])
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
这一段的作用于MNIST的_load_data(), 我们的_get_img_info()一样,就是读取数据信息。
总结:
getitem函数中十分简洁,逻辑简单
初始化时需完成数据信息的采集,存储到变量中,供getitem使用
VOC
之前讨论的数据集主要用于教学目的,比较复杂的目标检测数据是否具有较高的编写难度?答案是,一点也不,仍旧可以用我们分析出来的逻辑进行编写。
下面来看第一个大规模应用的目标检测数据集——PASCAL VOC,
D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torchvision\datasets\voc.py的
VOCDetection类的getitem函数:
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is a dictionary of the XML tree.
"""
img = Image.open(self.images[index]).convert("RGB")
target = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
更简洁了,与我们的案例中的getitem一样一样的,那么images和annotations从哪里来?相信大家已经知道答案了,那就是初始化的时候根据数据格式、数据组织结构,从磁盘中读取。
COCO
说到目标检测就不得不提COCO数据集,COCO数据集是微软提出的大规模视觉数据集,主要用于目标检测,它从数据量、类别量都远超VOC,对于深度学习模型的落地应用起到了推动作用。
对于CV那么重要的COCO,它的dataset难吗?答案是,不难。反而更简单了,整个类仅40多行。
getitem函数连注释都显得是多余的:
def __getitem__(self, index: int) -> Tuple[Any, Any]:
id = self.ids[index]
image = self._load_image(id)
target = self._load_target(id)
if self.transforms is not None:
image, target = self.transforms(image, target)
return image, target
其实,这一切得益于COCO的应用过于广泛,因此有了针对COCO数据集的轮子——pycocotools,它非常好用,建议使用COCO数据集的话,一定要花几天时间熟悉pycocotools。pycocotools把getitem需要的东西都准备好了,因此这个类只需要40多行代码。
小结
本章从数据模块中两个核心——Dataset&Dataloader出发,剖析pytorch是如何从硬盘中读取数据、组装数据和处理数据的。在数据处理流程中深入介绍数据预处理、数据增强模块transforms,并通过notebook的形式展示了常用的transforms方法使用,最后归纳总结torchvision中常见的dataset,为大家将来应对五花八门的任务时都能写出dataset代码。 下一章将介绍模型模块。