3.5 torchvision 经典dataset学习

前面已经学习了Dataset,DataLoader,以及常用的函数,通常足以应对大多数需求,但距离熟练编写自己的Dataset可能还有一段距离。 为了让大家能轻松掌握各种情况下的dataset编写,本小节对torchvision中提供的几个常见dataset进行分析,观察它们的代码共性,总结编写dataset的经验。

X-MNIST

由于MNIST数据使用广泛,在多领域均可基于这个小数据集进行初步的研发与验证,因此基于MNIST数据格式的各类X-MNIST数据层出不穷,在mnist.py文件中也提供了多个X-MNIST的编写,这里需要大家体会类继承

示例表明FashionMNIST、KMNIST两个dataset仅需要修改数据url(mirrors、resources)和类别名称(classes),其余的函数均可复用MNIST中写好的功能,这一点体现了面向对象编程的优点。

来看dataset的 getitem,十分简洁,因为已经把图片和标签处理好,存在self.data和self.targets中使用了:

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        img, target = self.data[index], int(self.targets[index])
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

代码参阅:D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torchvision\datasets\mnist.py

cifar-10

cifar-10是除MNIST之外使用最多的公开数据集,同样,让我们直接关注其 Dataset 实现的关键部分

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

核心代码还是这一行: img, target = self.data[index], self.targets[index]

接下来,去分析data和self.targets是如何从磁盘上获取的?通过代码搜索可以看到它们来自这里(D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torchvision\datasets\cifar.py CIFAR10 类的 init函数):

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

这一段的作用于MNIST的_load_data(), 我们的_get_img_info()一样,就是读取数据信息。

总结:

  • getitem函数中十分简洁,逻辑简单

  • 初始化时需完成数据信息的采集,存储到变量中,供getitem使用

VOC

之前讨论的数据集主要用于教学目的,比较复杂的目标检测数据是否具有较高的编写难度?答案是,一点也不,仍旧可以用我们分析出来的逻辑进行编写。

下面来看第一个大规模应用的目标检测数据集——PASCAL VOC

D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torchvision\datasets\voc.py的

VOCDetection类的getitem函数:

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        img = Image.open(self.images[index]).convert("RGB")
        target = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

更简洁了,与我们的案例中的getitem一样一样的,那么images和annotations从哪里来?相信大家已经知道答案了,那就是初始化的时候根据数据格式、数据组织结构,从磁盘中读取。

COCO

说到目标检测就不得不提COCO数据集,COCO数据集是微软提出的大规模视觉数据集,主要用于目标检测,它从数据量、类别量都远超VOC,对于深度学习模型的落地应用起到了推动作用。

对于CV那么重要的COCO,它的dataset难吗?答案是,不难。反而更简单了,整个类仅40多行。

getitem函数连注释都显得是多余的:

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        id = self.ids[index]
        image = self._load_image(id)
        target = self._load_target(id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

其实,这一切得益于COCO的应用过于广泛,因此有了针对COCO数据集的轮子——pycocotools,它非常好用,建议使用COCO数据集的话,一定要花几天时间熟悉pycocotools。pycocotools把getitem需要的东西都准备好了,因此这个类只需要40多行代码。

小结

本章从数据模块中两个核心——Dataset&Dataloader出发,剖析pytorch是如何从硬盘中读取数据、组装数据和处理数据的。在数据处理流程中深入介绍数据预处理、数据增强模块transforms,并通过notebook的形式展示了常用的transforms方法使用,最后归纳总结torchvision中常见的dataset,为大家将来应对五花八门的任务时都能写出dataset代码。 下一章将介绍模型模块。

Copyright © TingsongYu 2021 all right reserved,powered by Gitbook文件修订时间: 2024年04月26日21:48:10

results matching ""

    No results matching ""