PyTorch DataSet 和 Dataloader 加载步骤

机器学习与生成对抗网络

共 7617字,需浏览 16分钟

 ·

2022-01-01 09:28


 戳我,查看GAN的系列专辑~!

等你着陆!【GAN生成对抗网络】知识星球!

来源:知乎—端庄的汤汤  侵删

地址:https://zhuanlan.zhihu.com/p/381224748
在训练模型读取数据的时候一开始总是搞不清楚这两个接口的原理,这次就好好的整理整理,把整个过程一步一步梳理清楚,细节拉满。(内容来自深度之眼和自己的debug)
开局两张图,内容全靠编!!!

from torch.utils.data import Dataset

class MyDataset__(Dataset): def __init(self, *args, **kargs): pass def __len__(self): pass def __getitem__(self, idx):        pass
在写自己的dataset的时候这三个魔术方法是必须要写的,其余方法看自己需求。Dataset写法多样,此处就不表了。我在下面这个链接里写了要用到的魔术方法的解释。
https://zhuanlan.zhihu.com/p/381229502
实例化对象md = MyDataset(...)之后,len(md)会返回实例化dataset对象md的数据长度,md[idx]就可以索引该对象的数据,这就是两个魔术方法的非常方便的应用,和python的内置函数就联系起来了。强调:是因为魔术方法的改写,自己建立的dataset才能使用len()函数还有数组索引。示例代码结果如下。
DataLoader

这个类有那么多参数,左边这几个是常用的。dataset=train_data,来自上边黄色代码图片。num_workers代表多进程读取数据,windows下设置为0,因为pytorch多进程是fork,windows不是这种方式,所以不能使用多进程,Linux可以,一般设置为4或8都见过。shuffle在训练的时候为True,提高模型泛化性,验证或者推理一般就False,没必要多一步操作。
其他的参数在官方dataloader.py里面都有介绍,就不多说了,太抽象也不好理解。现在就一步一步debug下去,碰到什么就搞清楚什么,一路趟平过去。下图看到初始状态,还有参数,batchsize = 16。
创建Dataloader对象进来之后在def __init__(self,...)初始化方法中,有这么一个判断,传进来的dataset是不是迭代器,如果是迭代器不支持目前pytorch的采样器sampler等。目前传进来的dataset都不是迭代器。后面还有其他的判断,不影响主线,就不表了,感兴趣的朋友可以自己看看。初始化方法中后面的代码也不多,我就搬上来看一下逻辑。
首先这个第一行else就是接上图的如果dataset不是迭代器,if isinstance(...):。正常的我们传入的dataset,设置这么一个实例化属性,后面的那个变量_DatasetKind.Map==0。
然后我们看一下这个类,这个类也在dataloader.py中,32行开始的。可以看到这个类有两个参数Map = 0,当这个参数等于0的时候,会采用_MapDatasetFetcher的采样方法,进去看一下这个类到底是个什么玩意儿。
可以看到这是一个在fetch.py中的类,这里面有一个fetch方法,定义了收集数据的方法,这个一会儿估计会用到,先放到这里,标记为①。
现在我们继续回到dataloader.py中,刚才到216行代码,我把初始化方法中下面的内容都粘贴过来,不多,看看逻辑是什么样的。
到这个地方247行,就是初始化方法的全部内容了。从上面197~216行代码中看一下,如果自己指定了sampler即采样器,那么shuffle=True就不行了,如果指定了batchsampler,也是有一些参数会冲突,先往后看,一会儿再看看这个采样器是什么,batch_size不会不设置的,就不看了,也是异常处理语句。
来到上面这两张图片,dataloader.py里面的217~247,还是初始化方法里面的(def __init__(self,...))。第218行代码,如果sampler==None,这个条件一般是成立的,很少自己指定采样器,训练的时候shuffle==True,就是if shuffle:了,此时采用随机采样器,验证或者推理的时候shuffle==False,就采用顺序采样器,这个晚会儿再去看。
228行代码,指定了batchsize并且batch_sampler没有指定的情况下(这就是通常的情况),定义了一个batchsampler的采样器类,看参数传入的就是上面没指定sampler时,系统传入的或着随机采样器或者顺序采样器。下面都是定义实例属性了,把刚才的判断都用属性引用一下。不慌,接着看下一步的debug。
上面有三张图片,按照我的代码debug,shuffle=True直接进入随机采样器的,但是顺序采样器就在随机采样器的上面,就一起看看。debug就会进入sampler.py文件,先看一下第一张图SequentialSampler类。这个类很简单就是三个内置方法。首先初始化会传入一个dataset进去也就是58行的datasource。
dataset本质是什么呢?dataset是我们自己创建的MyDataset类的实例对象,系统给他分配了一块儿内存,这块儿内存里有实例属性,实例方法啥的。举个例子,实例属性有一个变量存储了一个列表,这个列表里存储的是所有图片的路径。def __getitem__(self, idx): 这个方法呢就可以索引这个列表读取图片。当然这个方法里面的过程是需要自己写的。比如上面说了有个列表是存储了图片的路径,我们可以对该列表进行索引,取出一个路径,然后读取图片,然后处理图片,最后返回图片。这些代码都是存储在内存中的。
正如最开始我们写的那样,getitem这个魔术方法让dataset实例对象这块内存可以像数组那样进行索引dataset[idx],可以使用len(dataset)函数直接返回长度。所以dataset可以理解为一块儿内存。
创建顺序采样器实例时,把dataset传入到SequentialSampler类,就是创建了一个dataset长度的迭代器,__iter__这个魔术方法最开始也提到了,可以返过去看一下。所以这个顺序采样器的本质我们就完全了解了,其实就是个迭代器,可以理解为跟range(n)没有什么太大的区别。
现在看随机采样器RandomSampler(Sampler)类。别的不用看,直接看__iter__()方法,有个if,如果要替代传进来的dataset,那么就返回一个指定了样本数量的列表。这个列表就是从[0, n-1]每次取一个值,取了指定的num_samples个,一般我们不会选这个,那返回的就是另一个。
其中,n是dataset的大小,返回的是一个[0, n-1]的乱序的列表的迭代器。这个列表中包含0~n-1的每一个值,但是是乱序的。这一步到现在也非常的明了了。看回dataloader.py的224行,得到sampler是个迭代器,迭代器里面是乱序的[0~n-1]的数值,继续往下看。
刚才说了dataloader.py的228行就是我们遇见的通常情况,所以debug就会进入到230行,然后创建实例对象,现在看一下这个类,是怎么对sampler迭代器进行操作的,返回的又是什么。
首先是初始化方法,创建一些实例属性,然后还是__iter/len__(self)这两个方法,这个batch采样器也很明了了,就是一个生成器。
过程就是:先创建一个列表,然后从我们存储有乱序的索引的迭代器sampler中取值,每取batchsize个就返回batch这个列表(里面存储的是数字),然后停在这里,等待下一次next()方法调用,下一次从yield处开始执行,先把batch这个列表置空,然后重新取值,知道最后取完,然后判断batch的长度到达设置的batchsize没有。如果有就返回,没有就完成循环了。然后执行下面的判断,drop_last是否为True,看看是舍弃还是继续返回。
就这样BatchSampler类的原理也了解了,就这,非常的简单。
然后继续debug,现在到了dataloader.py的232行了。
debug这些属性的时候都会跳到另一个魔术方法里,就下面这个
这个方法可以看一下这里面的解释python 中__setattr__, __getattr__,__getattribute__, __call_使用方法。重写_setattr__方法,意味着每次对实例属性进行赋值都会调用该方法。我们在debug的过程实例已经创建了,这个方法的目的是控制括号里的那5个参数不能更改。如果在创建实例之后,又重新对这几个实例属性赋值,当然或者对其他实例属性赋值,都会调用这个方法,只不过调用之后在方法内的逻辑是,如果是括号内的这几个实例属性,就会报错。看下错误,意思就是xx属性不应该在实例对象初始化之后再被设置。
现在就来到了238行collate_fn,然后再239行进行判断的时候就来到了下图295行这个方法。
上面的代码可以看到,self.batch_sampler是有值的,右值为BatchSampler类的实例对象,所以返回True。所以collate_fn = _utils.collate.default_collate,这个方法是怎么运行的一会儿再说。
现在整个dataset和dataloader的原理和内容基本都了解的差不多了,那下面就在循环代码中,看看它整个的过程是怎么一步一步取值的。下图是一个比较简单的常规训练代码,建立DataLoader的实例,起名字叫train_loader,进去看一下,跳到第二张图。
类中实现了__iter__方法,实例对象就是个迭代器。对括号内self的理解参见第一个卡片链接,几个特殊方法那个。在迭代过程中确认是单进程还是多进程加载数据,上面提到过了,在windows上是不支持多进程加载数据的,所以进去看一下单进程是怎么加载数据的。
这是一个比较简单的类,看下代码,首先看一下初始化方法,定义了一个fetcher的实例属性,看下右值是不是比较熟悉,在标记为①处的地方,已经说过这个类还有方法了。
在它实例方法里面有一个index的变量,这个右值在初始化方法中没有,应该是继承父类的,我们进去其父类看一下,这个右值在下面说明。
从上面三个框可以看到,刚才说没有的那个右值self._next_index()来自train_loader的_index_sampler,这是一个实例方法,返回BatchSampler类的实例对象(还记得这个对象返回的是什么吗?这是一个迭代器,每次返回一个batch乱序索引),看下图。
这个就是单进程类里面的东西,初始化完成之后,就返回这么一个实例对象。
for … in… 这个语法有两个功能。一是获得一个迭代器,即调用了__iter__()方法。第二个功能是循环调用__next__()方法。刚才是实现了第一个功能,获得一个迭代器,现在继续单步调试下去就是调用__next__()方法。对于迭代器来讲,这就是取数据的过程。
Iter()与 __iter__ 用于产生 iterator(迭代器),__iter__ 迭代器协议,凡是实现__iter__协议的对象,皆是迭代器对象。(next()也得实现,不然没法产生数据)。
Iter()迭代器工厂函数,凡是有定义有__iter__()函数,或者支持序列访问协议,也就是定义有__getitem__()函数的对象 皆可以通过 iter()工厂函数 产生迭代器(iterable)对象。
原文链接:https://blog.csdn.net/weixin_36670529/article/details/106641754
刚才说过了,这一步会调用对象的__next__()方法,现在的对象是_SingleProcessDataLoaderIter类的实例对象,这个方法是父类的,直接继承了。我们看到data = self._next_data(),这个是实例方法。
如图所示,进入了这个类的这个方法。上面刚刚说过了,self._next_index()是父类的一个实例方法,可以往上翻看一下,经过一系列调用,到了BatchSampler类,返回的是一个batchsize的乱序的索引列表。
从上图debug的index值可以看到具体列表,我的样本是199个,batchsize=16,所以列表内的值是无序的16个,不超过198的数值。
现在只是索引取完了,我们要的是数据和标签,所以下面到了真正的按照乱序索引取数据的过程。
到了fetch.py的这个类的fetch方法,这个方法上面提到过2次,if条件为真,所以直接执行下面代码,传进来的形参是刚才那个index(16个无序的不超过198的数字列表),然后就是一个列表生成式,索引的形式读取dataset返回的数据,这也是为什么上面一直强调,自定义的dataset类必须写__getiitem__()方法的原因。再强调一遍,只有写了这个魔术方法,才能和python的内置函数功能对应起来,实例对象才能够以索引的方式取数据。下一步debug就会直接调到我们自定义的dataset类的__getiitem__()方法读取数据。如下图所示。
在fetch方法中,得到那16个索引的dataset数据之后,data是一个含有image和label数据的长度为16的列表,每对image和label数据构成一个元组。return时对data有一个self.collate_fn(data)的操作,看一下这个操作什么样子的。
在dataloader的初始化方法中,就有这个函数,上面也提到过,这个函数一直作为参数传递,直到现在终于用到这个函数了,看下到底是个啥玩意儿。
红框的路径正是上面collate_fn的右值,collate_fn是在dataloader.py中的,dataloader.py是在torch.utils.data中的,所以这个右值就是是一个相对路径,并且它不是一个实例方法,它是一个函数。
这个函数的形参是个batch,其实就是我们刚才的data,data是一个含有image和label数据的长度为16的列表,每对image和label数据构成一个元组,下图可以看到一个示例。
default_collate这个函数代码逻辑很简单,就是各种判断,如果都不是,最后就raise个类型错误。显然,elem变量是个tuple类型数据,里面含有两个元素,如上所示一个tensor,一个数字,elem_type必然是tuple了。所以一路判断下来,进入了下图这个判断里面,下面注释着这个判断是检查batch或者说我们的data中的元素是不是尺寸一致的。因为我们每个元素都是tuple,每个tuple里面都是2个元素,每个tuple里面的两个元素的各自形式都是一致的。所以进入了这个判断。
https://blog.csdn.net/csdn15698845876/article/details/73411541
将batch这个列表变成迭代器,然后,取出一个数据,即一个elem,elem_size = 2,就是一个元组有几个元素嘛,有两个。然后判断是否所有elem的size都是2。然后解包再打包,这个过程看上面链接里的内容去理解。就是把batch这个列表解包,形成了16个元组,每个元组有2个元素,分别是image和label数据,然后再zip。元组每个有2个元素,所以zip完的迭代器如果转换为list形式就是2个元组,每个元组里分别有16个image和label元素。
zip函数返回的是一个迭代器,使用for这个语法取的时候,分别取出来两个参数,然后又送入到当前的default_collate函数进行判断。
第一个肯定是个tensor了,get_workerinfo()返回None,进入到return,对batch这组数据在0维进行stack操作,(batch是个元组,里面有16个tensor,在第0维进行stack相当于是增加了一维,这一维度的个数是元组的长度,也就是数据的个数)返回4维的tensor(n*c*h*w),这也是为什么返回的数据第一维是n(batch_size大小)。然后zip后的第二组数据也是元组,取出一个elem就是int,所以进入到如下代码段中。
讲label数据转换为tensor返回到调用zip后迭代器的那个列表中,然后进行return,return的是一个列表,里面有两个tensor元素。每个tensor元素都有16个数据。
然后进行内存加速操作,这个是什么锁页内存,可以对数据处理进行加速,然后返回data(十个列表),返回到哪儿了呢?
其实是返回到_SingleProcessDataLoaderIter的__next__(self)方法中,S继承了父类_BaseDataLoaderIter类的这个方法。然后再进行判断,此时判断条件不成立,我们的数据不是_DatasetKind.Iterable,而是_DatasetKind.Map,这个上面已经说过了。然后再返回data,此时返回的data就是转了一圈回来了。
终于传回来了,此时data就是for语句中的那个data,data里面有两组数据,将其解包,图像数据给inputs,标签数据给label,此时数据维度也正确了,然后进行后续的前向计算和反向传播等操作。



猜您喜欢:

超110篇!CVPR 2021最全GAN论文汇总梳理!

超100篇!CVPR 2020最全GAN论文梳理汇总!

拆解组新的GAN:解耦表征MixNMatch

StarGAN第2版:多域多样性图像生成


附下载 | 《可解释的机器学习》中文版

附下载 |《TensorFlow 2.0 深度学习算法实战》

附下载 |《计算机视觉中的数学方法》分享


《基于深度学习的表面缺陷检测方法综述》

《零样本图像分类综述: 十年进展》

《基于深度神经网络的少样本学习综述》


浏览 12
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报