使用 PyTorch 进行分布式训练
AI算法与图像处理
共 6607字,需浏览 14分钟
· 2021-07-10
点击下方“AI算法与图像处理”,一起进步!
重磅干货,第一时间送达
size:进行训练的 GPU 设备的数量
rank:对GPU设备有一个序列的id号
# Download and initialize MNIST train dataset
train_dataset = datasets.MNIST('./mnist_data',
download=True,
train=True)
# Wrap train dataset into DataLoader
train_loader = DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True)
# Download and initialize MNIST train dataset
train_dataset = datasets.MNIST('./mnist_data',
download=True,
train=True,
transform=transform)
# Create distributed sampler pinned to rank
sampler = DistributedSampler(train_dataset,
num_replicas=world_size,
rank=rank,
shuffle=True, # May be True
seed=42)
# Wrap train dataset into DataLoader
train_loader = DataLoader(train_dataset,
batch_size=batch_size,
shuffle=False, # Must be False!
num_workers=4,
sampler=sampler,
pin_memory=True)
def create_model():
model = nn.Sequential(
nn.Linear(28*28, 128), # MNIST images are 28x28 pixels
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, 10, bias=False) # 10 classes to predict
)
return model
# Initialize the model
model = create_model()
# Initialize the model
model = create_model()
# Create CUDA device
device = torch.device(f'cuda:{rank}')
# Send model parameters to the device
model = model.to(device)
# Wrap the model in DDP wrapper
model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)
for i in range(epochs):
for x, y in train_loader:
# do the training
...
for i in range(epochs):
train_loader.sampler.set_epoch(i)
for x, y in train_loader:
# do the training
...
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int)
args = parser.parse_args()
rank = args.local_rank
if rank == 0:
torch.save(model.module.state_dict(), 'model.pt')
python -m torch.distributed.launch --nproc_per_node=4
ddp_tutorial_multi_gpu.py
努力分享优质的计算机视觉相关内容,欢迎关注:
个人微信(如果没有备注不拉群!) 请注明:地区+学校/企业+研究方向+昵称
下载1:何恺明顶会分享
在「AI算法与图像处理」公众号后台回复:何恺明,即可下载。总共有6份PDF,涉及 ResNet、Mask RCNN等经典工作的总结分析
下载2:终身受益的编程指南:Google编程风格指南
在「AI算法与图像处理」公众号后台回复:c++,即可下载。历经十年考验,最权威的编程规范!
下载3 CVPR2021
在「AI算法与图像处理」公众号后台回复:CVPR,即可下载1467篇CVPR 2020论文 和 CVPR 2021 最新论文
点亮 ,告诉大家你也在看
评论
金融研究 | 使用Python测量关键审计事项的「信息含量」
Tips: 公众号推送后内容只能更改一次,且只能改20字符。如果内容出问题,或者想更新内容, 只能重复推送。为了更好的阅读体验,建议阅读本文博客版, 链接地址https://textdata.cn/blog/2023-01-13-information-content-of-critical-aud
大邓和他的Python
0
词向量(更新) | 使用MD&A2001-2022语料训练Word2Vec模型
buTips: 公众号推送后内容只能更改一次,且只能改20字符。 如果内容出问题,或者想更新内容, 只能重复推送。 为了更好的阅读体验,建议阅读本文博客版, 链接地址https://textdata.cn/blog/2023-03-24-load-w2v-and-expand-your-concpe
大邓和他的Python
0
金融研究(更新) | 使用Python构建关键审计事项的「信息含量」
Tips: 公众号推送后内容只能更改一次,且只能改20字符。如果内容出问题,或者想更新内容, 只能重复推送。为了更好的阅读体验,建议阅读本文博客版, 链接地址https://textdata.cn/blog/2023-01-13-information-content-of-critical-aud
大邓和他的Python
0
科普:深度学习训练,不同预算GPU选购指南
以下文章来源于微信公众号:DeepHub IMBA作者:Mike Clayton本文仅用于学术分享,如有侵权,请联系后台作删文处理导读购买显卡第一个要考虑的问题是什么?当然是预算。本文提供了不同预算的显卡选购指南,希望能对各位读者有所帮助。在进行机器学习项目时,特别是在处理深度学习和神经网络时,最好
机器学习初学者
0
管理世界2024 | 使用管理层讨论与分析测量「企业人工智能指标」
Tips: 公众号推送后内容只能更改一次,且只能改20字符。如果内容出问题,或者想更新内容, 只能重复推送。为了更好的阅读体验,建议阅读本文博客版, 链接地址 https://textdata.cn/blog/2024-04-19-ai-improve-firm-productivity/
大邓和他的Python
0
GPT的风也吹到了CV,详解自回归视觉模型的先驱! ImageGPT:使用图像序列训练图像 GPT模型
作者丨科技猛兽编辑丨极市平台导读 在 CIFAR-10 上,iGPT 使用 linear probing 实现了 96.3% 的精度,优于有监督的 Wide ResNet,并通过完全微调实现了 99.0% 的精度,匹配顶级监督预训练模型。本文目录1 自回归视觉模型的先驱 ImageGPT:
机器学习初学者
0
代码 | 使用 MD&A文本测量「企业不确定性感知FEPU」
Tips: 为了更好的阅读体验,建议阅读本文博客版, 链接地址https://textdata.cn/blog/2024-04-25-firm-economic-policy-uncertainty/本文使用的缩写EPU 经济政策不确定性(Economic Policy Uncerta
大邓和他的Python
0
面试官:在原生input上面使用v-model和组件上面使用有什么区别?
前言面试官:vue3的v-model都用过吧,来讲讲。粉丝:v-model其实就是一个语法糖,在编译时v-model会被编译成:modelValue属性和@update:modelValue事件。一般在子组件中定义一个名为modelValue的props来接收父组件v-model传递的值,然后当子组
高级前端进阶
0