PyTorch Dataloader读取时,如何在进程之间传输数据?

共 5943字,需浏览 12分钟

 ·

2021-10-12 18:58

点击上方机器学习与生成对抗网络”,关注星标

获取有趣、好玩的前沿干货!

来源|知乎  作者|Envy
链接|https://zhuanlan.zhihu.com/p/409629586
编辑|人工智能前沿讲习
最近我在做PyTorch的Dataloader相关的开发,有一个问题让我比较在意:PyTorch的Dataloader在启动多个进程读取样本的时候,这些数据是怎么在进程之间进行传输的?会不会引入多余的内存拷贝?

01

Dataloader使用multiprocess.Queue来传输数据
先简单介绍一下Dataloader的多进程模式,Dataloader在构造的时候,若num_workers不为0,就会启动num_workers个worker进程,然后主进程会向worker进程分发读取任务,worker进程读到数据之后,再把数据放到队列中供主进程取用。Worker进程所执行的代码片段如下:
while watchdog.is_alive():    try:        r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)    except queue.Empty:        continue    if isinstance(r, _ResumeIteration):        # Acknowledge the main process        data_queue.put((r, None))        iteration_end = False        # Recreate the fetcher for worker-reuse policy        fetcher = _DatasetKind.create_fetcher(            dataset_kind, dataset, auto_collation, collate_fn, drop_last)        continue    elif r is None:        # Received the final signal        assert done_event.is_set() or iteration_end        break    elif done_event.is_set() or iteration_end:        # `done_event` is set. But I haven't received the final signal        # (None) yet. I will keep continuing until get it, and skip the        # processing steps.        continue    idx, index = r    data: Union[_IterableDatasetStopIteration, ExceptionWrapper]    if init_exception is not None:        data = init_exception        init_exception = None    else:        try:            data = fetcher.fetch(index)        except Exception as e:            if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:                data = _IterableDatasetStopIteration(worker_id)                # Set `iteration_end`                #   (1) to save future `next(...)` calls, and                #   (2) to avoid sending multiple `_IterableDatasetStopIteration`s.                iteration_end = True            else:                # It is important that we don't store exc_info in a variable.                # `ExceptionWrapper` does the correct thing.                # See NOTE [ Python Traceback Reference Cycle Problem ]                data = ExceptionWrapper(                    where="in DataLoader worker process {}".format(worker_id))    data_queue.put((idx, data))    del data, idx, index, r  # save memory
其中,data_queue通常是一个torch.multiprocessing.Queue的实例。
看到这里,似乎Dataloader只是平平无奇地使用了torch.multiprocessing.Queue的接口,难道torch.multiprocessing.Queue有一些高级的技巧?
继续看torch.multiprocessing.Queue的代码,发现它只是简单地把multiprocessing.Queue包了一下。众所周知, multiprocessing.Queue在Linux使用socket来实现,那难道读上来的数据需要在socket之间传来传去吗,效率也太低了吧?!不对,PyTorch一定还有其他骚操作。

02

Tensor(CPU Tensor)在multiprocessing.Queue中的序列化和反序列化
在通常的用法里,Dataloader从Dataset里读出来的数据都会被collate_fn转成CPU Tensor,那我们就继续看看,Tensor是怎么在队列中序列化和反序列化的。
可以看到,torch.Tensor重载了__reduce_ex__()函数,序列化的时候只会用到 Tensor.storage, Tensor.storage_offset, size, stride, requires_grad和backward_hooks;而反序列化的torch._utils._rebuild_tensor_v2()也只会用到以上信息。
multiprocessing.Queue是使用pickle来做序列化和反序列化的,而重载__reduce_ex__()正是自定义序列化反序列化方式的方法之一。
那看起来,CPU Tensor在进程中传输时,是在接收进程中把Tensor重新构建了一遍,而构建Tensor时候用到的信息, Tensor.storage_offset, size, stride, requires_grad和backward_hooks,都只是用于描述Tensor的meta信息,实际和数据相关的,就只有Tensor.storage了。

03

Tensor.Storage的序列化与反序列化
Tensor.Storage同样重载了pickle的序列化与反序列化过程,在torch/multiprocessing/reduction.py中,给Tensor.Storage 注册了reduce函数reduce_storage.
这里为什么使用copyreg库而不是重载__reduce__(), 在copyreg的注释里说copyreg是专用于C extension的.按照这个说法,在reductions.py里为Tensor注册的reduce function应该是没有起效的。
def reduce_storage(storage):    from . import get_sharing_strategy    if storage.is_cuda:        raise RuntimeError("Cannot pickle CUDA storage; try pickling a CUDA tensor instead")    elif get_sharing_strategy() == 'file_system':        metadata = storage._share_filename_()        cache_key = metadata[1]        rebuild = rebuild_storage_filename        storage._shared_incref()    elif storage.size() == 0:        # This is special cased because Empty tensors        # (with size 0) cannot be mmapped.        return (rebuild_storage_empty, (type(storage),))    else:        fd, size = storage._share_fd_()        df = multiprocessing.reduction.DupFd(fd)        cache_key = fd_id(fd)        metadata = (df, size)        rebuild = rebuild_storage_fd  # type: ignore[assignment]
shared_cache[cache_key] = StorageWeakRef(storage) return (rebuild, (type(storage),) + metadata)
这个函数首先根据环境中的sharing strategy来决定共享内存的使用方式,然后若storage原本不在共享内存中的话,就把它拷到共享内存中去,比如_share_fd_()的实现。


04

小结
看到这里,我们可以大概得出一个结论了。
Worker进程从Dataset中读出来的Tensor本身是普通的CPU Tensor,但当把它放到multiprocessing.Queue中去的时候,这个Tensor的数据会被拷到共享内存中,Queue只会发送这个Tensor所具有的meta信息,主进程接到这些meta信息之后,就可以从共享内存中的数据重新构建Tensor。
这里有一点值得注意,如果你想要验证这个过程,在发送进程调用multiprocessing.Queue.put()之后,立即调用Tensor.is_shared()并不会返回True,因为put()是非阻塞的,只有当Tensor被QueueFeedThread序列化完成之后再调用is_shared(),才会得到预期中的结果。

05

Dataloader的小心机
在default_collate中,有这样一段小代码:
if isinstance(elem, torch.Tensor):    out = None    if torch.utils.data.get_worker_info() is not None:        # If we're in a background process, concatenate directly into a        # shared memory tensor to avoid an extra copy        numel = sum([x.numel() for x in batch])        storage = elem.storage()._new_shared(numel)        out = elem.new(storage)    return torch.stack(batch, 0, out=out)
含义是,如果batch中的数据已经是Tensor了,那么,如果这是一个Worker进程,就开一段共享内存把这个batch放进去。因为collate的时候无论如何都会有一次内存拷贝(除非底层的Dataset有其他保证),那么这个操作就省掉了之后放进队列中的那一次内存拷贝。
不过我随便找了几个自定义了collate_fn的模型看了一下他们写的collate过程,是没有把这一点考虑进去的。这也算是Dataloader的一个小心机吧,有缘人就用得上。


猜您喜欢:

等你着陆!【GAN生成对抗网络】知识星球!

CVPR 2021专题1:GAN的改进

CVPR 2021 | GAN的说话人驱动、3D人脸论文汇总

CVPR 2021 | 图像转换 今如何?几篇GAN论文

【CVPR 2021】通过GAN提升人脸识别的遗留难题

CVPR 2021生成对抗网络GAN部分论文汇总

经典GAN不得不读:StyleGAN

最新最全20篇!基于 StyleGAN 改进或应用相关论文

超100篇!CVPR 2020最全GAN论文梳理汇总!

附下载 | 《Python进阶》中文版

附下载 | 经典《Think Python》中文版

附下载 | 《Pytorch模型训练实用教程》

附下载 | 最新2020李沐《动手学深度学习》

附下载 | 《可解释的机器学习》中文版

附下载 |《TensorFlow 2.0 深度学习算法实战》

附下载 | 超100篇!CVPR 2020最全GAN论文梳理汇总!

附下载 |《计算机视觉中的数学方法》分享

浏览 28
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报