PyTorch中的 Dataset、DataLoader 和 enumerate()
PyTorch:关于Dataset,DataLoader 和 enumerate()
本博文主要参考了 Pytorch中DataLoader的使用方法详解 和 pytorch:关于enumerate,Dataset和Dataloader 两篇文章进行总结和归纳。
DataLoader 隶属 PyTorch 中 torch.utils.data 下的一个类,任何继承 torch.utils.data.Data 类的子类均需要重载__getitem__()及__len__()两个函数,且子类在__init__()函数产生的数据路径,将作为 DataLoader 参数 DataSets 的实参。该类将自定义的 Dataset 根据 batch size 大小、是否 shuffle 等封装成一个 Batch Size 大小的 Tensor,用于后面的训练。
Dataset 类构建
在构建数据集类时,除了__init__(self),还要有__len__(self)与__getitem__(self,item)两个方法,这三个是必不可少的,至于其它用于数据处理的函数,可以任意定义。这里的 Dateset 可以指整个数据集,也可以是训练集,测试集等。
class Dataset:
def __init__(self,...):
...
def __len__(self,...):
return n
def __getitem__(self,item):
return data[item]
正常情况下,该数据集是要继承 Pytorch 中 Dataset 类的,但实际操作中,即使不继承,数据集类构建后仍可以用 Dataloader() 加载的。
在dataset类中,len(self)返回数据集中数据的总个数,getitem(self,item)表示每次返回第 item 条(个)数据。
①__init__:传入数据,或者像下面一样直接在函数里加载数据
②__len__:返回这个数据集一共有多少个 item
③__getitem__:返回一条(个)训练样本的数据,并将其转换成 tensor
在 dataset 实例化时一般要传入数据集的路径,一般在__init__() 函数中指定数据集路径等相关信息(可以通过相关路径读取包含图像名称、标签等相关信息的 json 或者 csv 等类型的文件);通过__getitem__(self,item) 得到对应的图像并将进行 transform 转换(缩放、裁剪、转换成 tensor 等操作),最终以 tensor 的形式返回。
DataLoader 使用
在构建 Dataset 类后,即可使用 DataLoader 加载。DataLoader 中常用参数如下:
- dataset:需要载入的数据集,如前面构造的 dataset 类。
- batch_size:批大小,在神经网络训练时我们很少逐条数据训练,而是几条数据作为一个 batch 进行训练。
- shuffle:是否在打乱数据集样本顺序。True 为打乱,False 反之。
- num_workers:这个参数决定了有几个进程来处理 data loading。0 意味着所有的数据都会被 load 进主进程。(默认num_workers=0,在 Windows 系统下需要设置为 0)
- drop_last:是否舍去最后一个batch的数据(很多情况下数据总数 N 与 batch size 不整除,导致最后一个 batch 不为 batch size)。True 为舍去,False 反之。
注意:使用 DataLoader 读取数据时,为了加快效率,所以使用了多个线程,即 num_workers 不为0,在 windows 系统下报如下的错误。
RuntimeError: Couldn’t open shared file mapping: <torch_16716_3565374679>, error code: <1455>
DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
参照 DataLoader windows平台下 多线程读数据报错 | BrokenPipeError: [Errno 32] Broken pipe | freeze_support() 教程中提到,在 https://github.com/pytorch/pytorch/pull/5585 中给出了一些官方解释,应该是 Windows下的一些线程文件读写的问题。
在 Windows 上,FileMapping 对象应必须在所有相关进程都关闭后,才能释放。启用多线程处理时,子进程将创建 FileMapping,然后主进程将打开它。 之后当子进程将尝试释放它的时候,因为父进程还在引用,所以它的引用计数不为零,无法释放。 但是当前代码没有提供在可能的情况下再次关闭它的机会。这个版本官方说 num_workers=1 是可以用的,更多的线程还在解决,不过现在即便是用 2 个子进程也已经可以了。
加载数据的过程
pytorch 中加载数据的顺序是:
- 创建一个 dataset 对象
- 创建一个 dataloader 对象
- 循环 dataloader 对象,将 data, label 拿到模型中去训练
enumerate() 函数
在对 Dataloader 进行读取时,通常使用 enumerate() 函数,enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标。调用 enumerate(dataloader) 时每次都会读出一个 batch_size 大小的数据。例如,数据集中总共包含 245 张图像,train_loader = dataloader(dataset, batch_size=32, drop_last=True) 被实例化时,经过以下代码后输出的 count 为 224(正好等于32*7),而多出来的 245-224=21 张图像不够一个 batch 因此被 drop 掉了。下面展示了如何从 dataloader 中通过 enumerate() 返回一个batch_size的数据。
for k, images, target in enumerate(dataloader):
其中,k代表下标值,images, target 代表可遍历的数据对象。因为 enumerate(dataloader) 一次会返回一个 batch 的数据,所以返回的 images 为 batch_size 长度的list,target 也为 batch_size 长度的 list。
通常,dataloader 里包含很多个数据对象,那么我们应该怎么保证 batch 就是我们所需要的数据呢?通过 Dataset 的定义可以实现我们需要的数据。Dataset 是用来定义数据从哪里读取,以及如何读取的问题,通过重写 Dataset 抽象类的__getitem__()函数。enumerate(dataloader) 得到的数据就是 __getitem__() 函数返回的数据,只不过 enumerate(dataloader) 一次会得到 batch_size 个不同 item 的数据组成的 list。
def __getitem__(self, item):
images = self.data[item]
target = self.label[item]
return images, target
返回 item 对应的数据,就是 enumerate(dataloader) 得到的数据的一部分。
def __len__(self):
return len(self.data)
返回 dataset 中总的数据个数,用于控制返回多少个 batch 的数据,enumerate(dataloader) 一次会返回 batch_size 大小的 list。
Reference
Pytorch中DataLoader的使用方法详解
pytorch:关于enumerate,Dataset和Dataloader
DataLoader windows平台下 多线程读数据报错 | BrokenPipeError: [Errno 32] Broken pipe | freeze_support()