Distributed Training with PyTorch
size: the number of GPU devices taking part in training (the world size)
rank: a sequential ID assigned to each GPU device / process
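Before any of the snippets below can run, every process has to join the same process group using these two values. A minimal sketch, assuming the NCCL backend and a localhost rendezvous; the helper name init_distributed and the address/port values are illustrative, not from the original article:

import os
import torch
import torch.distributed as dist

def init_distributed(rank, world_size):
    # Illustrative rendezvous settings; torch.distributed.launch
    # normally provides MASTER_ADDR / MASTER_PORT for you
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    # NCCL is the usual backend for multi-GPU training
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    # Bind this process to its own GPU
    torch.cuda.set_device(rank)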
from torchvision import datasets
from torch.utils.data import DataLoader

# Download and initialize MNIST train dataset
train_dataset = datasets.MNIST('./mnist_data',
                               download=True,
                               train=True)

# Wrap train dataset into DataLoader
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4,
                          pin_memory=True)
from torch.utils.data.distributed import DistributedSampler

# Download and initialize MNIST train dataset
train_dataset = datasets.MNIST('./mnist_data',
                               download=True,
                               train=True,
                               transform=transform)

# Create distributed sampler pinned to rank
sampler = DistributedSampler(train_dataset,
                             num_replicas=world_size,
                             rank=rank,
                             shuffle=True,  # May be True
                             seed=42)

# Wrap train dataset into DataLoader
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=False,  # Must be False!
                          num_workers=4,
                          sampler=sampler,
                          pin_memory=True)
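To see how DistributedSampler partitions the data, it can be instantiated standalone with explicit num_replicas and rank values. This toy example is not part of the original script; the exact index order may vary across PyTorch versions, but each rank always receives a disjoint slice of the dataset:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset with 12 samples (indices 0..11)
toy_dataset = TensorDataset(torch.arange(12))

for r in range(4):
    s = DistributedSampler(toy_dataset, num_replicas=4, rank=r, shuffle=False)
    print(f"rank {r}: {list(s)}")
# rank 0: [0, 4, 8]
# rank 1: [1, 5, 9]
# rank 2: [2, 6, 10]
# rank 3: [3, 7, 11]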
import torch.nn as nn

def create_model():
    model = nn.Sequential(
        nn.Linear(28 * 28, 128),  # MNIST images are 28x28 pixels
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, 10, bias=False)  # 10 classes to predict
    )
    return model
import torch
from torch.nn.parallel import DistributedDataParallel

# Initialize the model
model = create_model()

# Create CUDA device
device = torch.device(f'cuda:{rank}')

# Send model parameters to the device
model = model.to(device)

# Wrap the model in DDP wrapper
model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)
for i in range(epochs):
    for x, y in train_loader:
        # do the training...
for i in range(epochs):
    train_loader.sampler.set_epoch(i)
    for x, y in train_loader:
        # do the training...
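Filling in the "# do the training..." placeholder, a sketch of one possible training step is shown below. The loss function, optimizer, and the flattening of the 28x28 images (which assumes the transform includes ToTensor) are illustrative choices, not taken from the original article:

import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()                   # assumed loss
optimizer = optim.SGD(model.parameters(), lr=0.01)  # assumed optimizer

for i in range(epochs):
    train_loader.sampler.set_epoch(i)
    for x, y in train_loader:
        x = x.view(x.size(0), -1).to(device)  # flatten 28x28 images for the Linear layers
        y = y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()   # DDP averages gradients across all ranks here
        optimizer.step()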
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int)
args = parser.parse_args()
rank = args.local_rank
if rank == 0:
    torch.save(model.module.state_dict(), 'model.pt')
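Because state_dict() is taken from model.module, the saved file contains the plain, unwrapped weights, so it can later be loaded into an ordinary non-DDP copy of the model, for example for single-GPU or CPU inference (an illustrative usage, not part of the original script):

# Rebuild the plain model and load the checkpoint, e.g. for inference
model = create_model()
model.load_state_dict(torch.load('model.pt', map_location='cpu'))
model.eval()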
python -m torch.distributed.launch --nproc_per_node=4 ddp_tutorial_multi_gpu.py
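On PyTorch 1.10 and later, the deprecated torch.distributed.launch can usually be replaced by torchrun. Note that torchrun exposes the local rank through the LOCAL_RANK environment variable rather than the --local_rank argument, so the argparse snippet above would change accordingly (a sketch, not from the original article):

torchrun --nproc_per_node=4 ddp_tutorial_multi_gpu.py

import os
rank = int(os.environ["LOCAL_RANK"])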