How does the PyTorch DataLoader transfer data between processes while loading?
01
# Core of _worker_loop in torch/utils/data/_utils/worker.py: each worker pulls
# (idx, index) pairs from index_queue, fetches the batch, and puts the result
# (or a wrapped exception) on data_queue.
while watchdog.is_alive():
    try:
        r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
    except queue.Empty:
        continue
    if isinstance(r, _ResumeIteration):
        # Acknowledge the main process
        data_queue.put((r, None))
        iteration_end = False
        # Recreate the fetcher for worker-reuse policy
        fetcher = _DatasetKind.create_fetcher(
            dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        continue
    elif r is None:
        # Received the final signal
        assert done_event.is_set() or iteration_end
        break
    elif done_event.is_set() or iteration_end:
        # `done_event` is set. But I haven't received the final signal
        # (None) yet. I will keep continuing until get it, and skip the
        # processing steps.
        continue
    idx, index = r
    data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
    if init_exception is not None:
        data = init_exception
        init_exception = None
    else:
        try:
            data = fetcher.fetch(index)
        except Exception as e:
            if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
                data = _IterableDatasetStopIteration(worker_id)
                # Set `iteration_end`
                # (1) to save future `next(...)` calls, and
                # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                iteration_end = True
            else:
                # It is important that we don't store exc_info in a variable.
                # `ExceptionWrapper` does the correct thing.
                # See NOTE [ Python Traceback Reference Cycle Problem ]
                data = ExceptionWrapper(
                    where="in DataLoader worker process {}".format(worker_id))
    data_queue.put((idx, data))
    del data, idx, index, r  # save memory
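The loop above is the heart of _worker_loop: a worker blocks on index_queue, turns each batch of indices into data with the fetcher, and sends the result back on data_queue, with None acting as the shutdown signal. Stripped of the PyTorch-specific bookkeeping, the protocol can be sketched with the standard library alone; the names toy_worker, index_queue and data_queue below are illustrative, not PyTorch's:

import multiprocessing as mp

def toy_worker(dataset, index_queue, data_queue):
    while True:
        r = index_queue.get()
        if r is None:                            # final signal, mirrors the `r is None` branch above
            break
        idx, indices = r
        batch = [dataset[i] for i in indices]    # stand-in for fetcher.fetch(index)
        data_queue.put((idx, batch))

if __name__ == "__main__":
    dataset = list(range(100))
    index_queue, data_queue = mp.Queue(), mp.Queue()
    w = mp.Process(target=toy_worker, args=(dataset, index_queue, data_queue))
    w.start()
    index_queue.put((0, [3, 1, 4]))
    index_queue.put(None)
    print(data_queue.get())                      # (0, [3, 1, 4])
    w.join()

Everything PyTorch adds on top of this skeleton (the watchdog, resume messages, exception wrapping) deals with shutdown and error propagation, not with the data path itself.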
02
03
# From torch/multiprocessing/reductions.py: how a CPU storage is pickled when
# it has to cross a process boundary.
def reduce_storage(storage):
    from . import get_sharing_strategy
    if storage.is_cuda:
        raise RuntimeError("Cannot pickle CUDA storage; try pickling a CUDA tensor instead")
    elif get_sharing_strategy() == 'file_system':
        metadata = storage._share_filename_()
        cache_key = metadata[1]
        rebuild = rebuild_storage_filename
        storage._shared_incref()
    elif storage.size() == 0:
        # This is special cased because Empty tensors
        # (with size 0) cannot be mmapped.
        return (rebuild_storage_empty, (type(storage),))
    else:
        fd, size = storage._share_fd_()
        df = multiprocessing.reduction.DupFd(fd)
        cache_key = fd_id(fd)
        metadata = (df, size)
        rebuild = rebuild_storage_fd  # type: ignore[assignment]

    shared_cache[cache_key] = StorageWeakRef(storage)
    return (rebuild, (type(storage),) + metadata)
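reduce_storage is the reduction torch.multiprocessing registers for CPU storages: instead of serializing the tensor's bytes into the queue, the storage is moved into shared memory and only a small handle (a duplicated file descriptor, or a file name under the file_system strategy) plus a rebuild function travels through the pipe. The receiving process maps the same memory, so nothing is copied. A small experiment, not from the article, makes this visible; it assumes a Linux machine with the default file_descriptor sharing strategy:

import torch
import torch.multiprocessing as mp

def child(q, done):
    t = q.get()
    print("child sees shared storage:", t.is_shared())   # True
    t.fill_(42)                 # in-place write into the shared buffer
    done.put(None)              # tell the parent we are finished

if __name__ == "__main__":
    q, done = mp.Queue(), mp.Queue()
    t = torch.zeros(4)
    p = mp.Process(target=child, args=(q, done))
    p.start()
    q.put(t)                    # pickling runs reduce_storage -> _share_fd_()
    done.get()                  # wait for the child's write
    print(t.is_shared())        # True: the parent's storage was moved, not copied
    print(t)                    # tensor([42., 42., 42., 42.]) in the parent too
    p.join()

Because both processes hold views of the same shared-memory block, in-place writes made by a worker are observable from the main process.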
04
05
# From default_collate in torch/utils/data/_utils/collate.py: the tensor branch.
if isinstance(elem, torch.Tensor):
    out = None
    if torch.utils.data.get_worker_info() is not None:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum([x.numel() for x in batch])
        storage = elem.storage()._new_shared(numel)
        out = elem.new(storage)
    return torch.stack(batch, 0, out=out)
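This branch of default_collate ties the two previous pieces together: when collation happens inside a worker process, the batch produced by torch.stack is allocated directly in shared memory, so reduce_storage later finds it already there and no extra copy is needed on the way back through data_queue. A custom collate_fn can borrow the same trick; the sketch below is illustrative only (RandomVectors is a made-up dataset, and the private _new_shared call mirrors the snippet above but may differ between PyTorch versions):

import torch
from torch.utils.data import DataLoader, Dataset, get_worker_info

class RandomVectors(Dataset):              # toy dataset, illustrative only
    def __len__(self):
        return 64
    def __getitem__(self, idx):
        return torch.randn(8)

def shared_memory_collate(batch):
    elem = batch[0]
    out = None
    if get_worker_info() is not None:      # only when running inside a worker
        numel = sum(x.numel() for x in batch)
        storage = elem.storage()._new_shared(numel)            # shared-memory backing
        out = elem.new(storage).resize_(len(batch), *elem.shape)
    return torch.stack(batch, 0, out=out)

if __name__ == "__main__":
    loader = DataLoader(RandomVectors(), batch_size=16, num_workers=2,
                        collate_fn=shared_memory_collate)
    batch = next(iter(loader))
    print(batch.shape, batch.is_shared())  # torch.Size([16, 8]) True

With num_workers=0 the get_worker_info() check returns None and the batch is stacked into ordinary memory, since there is no process boundary to cross.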
