Pytorch 数据流中常见Trick总结

机器学习与生成对抗网络

共 3727字,需浏览 8分钟

 ·

2021-12-23 19:50


 戳我,查看GAN的系列专辑~!

等你着陆!【GAN生成对抗网络】知识星球!

转载zlhroughlove@知乎 编辑极市平台 分享,侵权

来源https://zhuanlan.zhihu.com/p/441317369

前言

在使用Pytorch建模时,常见的流程为先写Model,再写Dataset,最后写Trainer。Dataset 是整个项目开发中投入时间第二多,也是中间关键的步骤。往往需要事先对于其设计有明确的思考,不然可能会因为Dataset的一些问题又要去调整Model,Trainer。本文将目前开发中的一些思考以及遇到的问题做一个总结,提供给各位读者一个比较通用的模版,抛砖引玉~

一、Dataset的定义

from torch.utils.data import Dataset, DataLoader, RandomSampler

对于不同类型的建模任务,模型的输入各不相同。自然语言,多模态,点击率预估,往往这些场景输入模型的数据并不是来自于单一文件,而且可能无法全部存入内存。Dataset需要整合项目的数据,对于单条样本涉及到的数据做一个提取与归纳。不但如此,项目可能还涉及到多种模型,任务的训练。Dataset需要为不同的模型以及训练任务提供不同的单条样本输入,作为一个数据生成器,把后续模型训练任务需要的所有基础数据,标签全返回了。所以往往我们可以定义一个BaseDataset类,继承torch.utils.data.Dataset,这个类可以初始化一些文件路径,配置等。后面不同的模型训练任务定义相应的Dataset类继承BaseDataset。

Dataset通用的结构为:

class BaseDataset(Dataset):

    def __init__(self, config):
        self.config = config
        if os.path.isfile(config.file_path) is False:
            raise ValueError(f"Input file path {config.file_path} not found")
        logger.info(f"Creating features from dataset file at {config.file_path}")
        # 一次性全读进内存
        self.data = joblib.load(config.file_path)
        self.nums = len(self.data)

    def __len__(self):
        return self.nums

    def __getitem__(self, i) -> Dict[str, tensor]:
        sample_i = self.data[i]
        return {"f1":torch.tensor(sample_i["f1"]).long(),"f2":torch.tensor(sample_i["f2"]).long(),torch.LongTensor([sample_i["label"]])}

如果无法全部读取进内存需要再__getitem__方法内构建数据,做自然语言则可以吧tokenizer初始化到该类中,在__getitem__方法内完成tokenizer。改方法的输出推荐做成字典形式。

对于不同的训练任务可以通过以下方法返回响应的数据生成器

def build_dataset(task_type, features, **kwargs):
    assert task_type in ['task1''task2'], 'task mismatch'

    if task_type == 'task1':
        dataset = task1Dataset(features))
    else:
        dataset = task2Dataset(features)

    return dataset

有时模型的训练任务需要做数据增强,对比学习,构造多种的预训练任务输入。Dataset的职能边界是提供一套基础的单样本数据输入生成器。如果是MLM任务,可以在Dataset内生成maskposition以及label。如果是在batch内的对比学习则应该在DataLoader生产batch数据后再进行。

二、DataLoader的定义

DataLoader的作用是对Dataset进行多进程高效地构建每个训练批次的数据。传入的数据可以认为是长度为batch大小的多个__getitem__ 方法返回的字典list。DataLoader的职能边界是根据Dataset提供的单条样本数据有选择的构建一个batch的模型输入数据。

其通常的结构为对Train,Valid,Test分别建立:

train_sampler = RandomSampler(train_dataset)
train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.train_batch_size,
                              sampler=train_sampler,
                              shuffle=(train_sampler is None)
                              collate_fn=None, # 一般不用设置
                              num_workers=4)

首先对于sampler 还有一种定义方式:

sampler = torch.utils.data.distributed.DistributedSampler(dataset)

至于batch内数据是否需要做shuffle也需要根据损失函数确定(对比学习慎用)

DataLoader会自动合并__getitem__ 方法返回的字典内每个key内每个tensor,在tensor的第0维度新增一个batch大小的维度。如果该方法返回的每条样本长度不同无法拼接,batchsize>1就会报错。但是又一些任务在还没有确定后续的批样本对应的任务时,Dataset可能返回的字典里每个key可能就是长度不同的tensor,甚至是list,这时候需要使用collate_fn参数告诉DataLoader如何取样。我们可以定义自己的函数来准确地实现想要的功能。

如果__getitem__方法返回的是tuple((list, list)) 可以使用:

def merge_sample(x):
    return zip(*x)

train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.train_batch_size,
                              sampler=train_sampler,
                              shuffle=(train_sampler is None)
                              collate_fn=merge_sample,
                              num_workers=4)

拼接数据,后续再做进一步处理。(此时list内数据还是不等长,无法转为tensor)

如果__getitem_方法返回的是Dict[str,tensor],自定义的collate_fn方法内需要实现:List[Dict[str,tensor(xx)]]->Dict[str,tensor(bs,xx)]的操作,pad_sequence过程也可以在自定义方法内实现。(总之collate_fn中不但可以处理不等长数据,还可以对一个batch的数据做精修。当然也可以在DataLoader之后再做修改batch内的数据。)

值得注意的是在cpu环境下,如果要自定义collate_fn,num_workers必须设置为0,不然就会有问题..

通过以下方式可以检查一下输入后续模型的数据是否已经是想要的格式

for step, batch_data in enumerate(train_loader):
    if step < 1:
        print(batch_data)
    else:
        break

之后数据将数据放入gpu device, 一个batch的数据进入device端后就与内存上的数据不再互相干扰。之后数据就可以喂给模型了:

loss = model(**batch_data)

for key in batch_data.keys():
    batch_data[key] = batch_data[key].to(device)



猜您喜欢:

超110篇!CVPR 2021最全GAN论文汇总梳理!

超100篇!CVPR 2020最全GAN论文梳理汇总!

拆解组新的GAN:解耦表征MixNMatch

StarGAN第2版:多域多样性图像生成


附下载 | 《可解释的机器学习》中文版

附下载 |《TensorFlow 2.0 深度学习算法实战》

附下载 |《计算机视觉中的数学方法》分享


《基于深度学习的表面缺陷检测方法综述》

《零样本图像分类综述: 十年进展》

《基于深度神经网络的少样本学习综述》


浏览 28
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报