PyTorch | 如何控制dataloader的随机shuffle

共 3402字,需浏览 7分钟

 ·

2022-05-27 10:26



前言 在使用PyTorch进行训练或者测试的过程中,一般来说dataloader在每个epoch返回的样本顺序是不一样的,但在某些特殊情况中,我们可能希望dataloader按照固定的顺序进行多个epoch。本文作者给出了一个简单方便的实现思路,附详解代码。

作者:魏鸿鑫@知乎
编辑:CV技术指南
原文:https://zhuanlan.zhihu.com/p/515697362


问题背景


在使用PyTorch进行训练或者测试的过程中,一般来说dataloader在每个epoch返回的样本顺序是不一样的,但在某些特殊情况中,我们可能希望dataloader按照固定的顺序进行多个epoch, 或者说,在一个epoch中按照固定的顺序进行多次的样本循环iteration。


现有Sampler


默认的 RandomSampler 在生成iteration的时候会重新做一次random shuffle,所以无法直接实现这个需求。

  def __iter__(self) -> Iterator[int]:
      n = len(self.data_source)
      if self.generator is None:
          seed = int(torch.empty((), dtype=torch.int64).random_().item())
          generator = torch.Generator()
          generator.manual_seed(seed)
      else:
          generator = self.generator

      if self.replacement:
          for _ in range(self.num_samples // 32):
              yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
          yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
      else:
          for _ in range(self.num_samples // n):
              yield from torch.randperm(n, generator=generator).tolist()
          yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

上面的代码是RandomSampler中最重要的__iter__函数,我们可以看到每次调用这个函数或者新的iter时会得到一个新的随机顺序的iteration。

再看看另一个常用的sampler,也就是 SequentialSampler。我们在test的时候经常会设置shuffle=false,这时候就相当于使用了SequentialSampler:

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

在代码中可以看到,这个sampler就是简单地创造并返回一个range序列,无法对其进行shuffle操作。


解决方案


结合上面两个现有的sampler,我们可以简单地自定义一个新的sampler来实现我们的需求。也就是说,我们希望能够手动控制何时进行shuffle操作,在没有shuffle时我们希望sampler按照前面的顺序返回iteration。

下面是我的实现:

class MySequentialSampler(SequentialSampler):
    def __init__(self, data_source, num_data=None):
        self.data_source = data_source
        self.my_list = list(range(len(self.data_source)))
        random.shuffle(self.my_list)
        if num_data is None:
            self.num_data = len(self.my_list)
        else:
            self.num_data = num_data
            self.my_list = self.my_list[:num_data]

    def __iter__(self):
        return iter(self.my_list)

    def __len__(self):
        return self.num_data

    def shuffle(self):
        self.my_list = list(range(len(self.data_source)))
        random.shuffle(self.my_list)
        self.my_list = self.my_list[:self.num_data]

这个实现非常简单而且使用方便。在默认情况下基本等同于SequentialSampler (去掉init函数中的shuffle即完全一致)。当我们需要重新shuffle序列的时候,只需要调用shuffle函数即可,比如:dataloader.sampler.shuffle(). 通过这个自定义sampler,我们就可以实现在指定的时候进行shuffle操作,而不是固定在每个iteration结束时进行shuffle。


ps: 理论上也可以直接通过对dataset进行shuffle,但这样操作的缺点是会改变对应的index,另外一般我们在train或者test函数中不会获取到dataset,而只能从loader进行操作(dataloader.dataset一般只能获取到length)。因此,修改sampler可以说是对原训练方法流程最少的方式。



猜您喜欢:

 戳我,查看GAN的系列专辑~!
一顿午饭外卖,成为CV视觉前沿弄潮儿!
CVPR 2022 | 25+方向、最新50篇GAN论文
 ICCV 2021 | 35个主题GAN论文汇总
超110篇!CVPR 2021最全GAN论文梳理
超100篇!CVPR 2020最全GAN论文梳理


拆解组新的GAN:解耦表征MixNMatch

StarGAN第2版:多域多样性图像生成


附下载 | 《可解释的机器学习》中文版

附下载 |《TensorFlow 2.0 深度学习算法实战》

附下载 |《计算机视觉中的数学方法》分享


《基于深度学习的表面缺陷检测方法综述》

《零样本图像分类综述: 十年进展》

《基于深度神经网络的少样本学习综述》


浏览 55
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报