10条PyTorch避坑指南
点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
本文转载自:机器之心 | 作者:Eugene Khvedchenya
高性能 PyTorch 的训练管道是什么样的?是产生最高准确率的模型?是最快的运行速度?是易于理解和扩展?还是容易并行化?答案是,包括以上提到的所有。
建议 0:了解你代码中的瓶颈在哪里
建议 1:如果可能的话,将数据的全部或部分移至 RAM。
class RAMDataset(Dataset):
def __init__(image_fnames, targets):
self.targets = targets
self.images = []
for fname in tqdm(image_fnames, desc="Loading files in RAM"):
with open(fname, "rb") as f:
self.images.append(f.read())
def __len__(self):
return len(self.targets)
def __getitem__(self, index):
target = self.targets[index]
image, retval = cv2.imdecode(self.images[index], cv2.IMREAD_COLOR)
return image, target
建议 2:解析、度量、比较。每次你在管道中提出任何改变,要深入地评估它全面的影响。
# Profile CPU bottlenecks
python -m cProfile training_script.py --profiling
# Profile GPU bottlenecks
nvprof --print-gpu-trace python train_mnist.py
# Profile system calls bottlenecks
strace -fcT python training_script.py -e trace=open,close,read
Advice 3: *Preprocess everything offline*
建议 3:离线预处理所有内容
建议 4:调整 DataLoader 的工作程序
假设我们为 Cityscapes 训练图像分割模型,其批处理大小为 32,RGB 图像大小是 512x512x3(高、宽、通道)。我们在 CPU 端进行图像标准化(稍后我将会解释为什么这一点比较重要)。在这种情况下,我们最终的图像 tensor 将会是 512 * 512 * 3 * sizeof(float32) = 3,145,728 字节。与批处理大小相乘,结果是 100,663,296 字节,大约 100Mb;
除了图像之外,我们还需要提供 ground-truth 掩膜。它们各自的大小为(默认情况下,掩膜的类型是 long,8 个字节)——512 * 512 * 1 * 8 * 32 = 67,108,864 或者大约 67Mb;
因此一批数据所需要的总内存是 167Mb。假设有 8 个工作程序,内存的总需求量将是 167 Mb * 8 = 1,336 Mb。
将 RGB 图像保持在每个通道深度 8 位。可以轻松地在 GPU 上将图像转换为浮点形式或者标准化。
在数据集中用 uint8 或 uint16 数据类型代替 long。
class MySegmentationDataset(Dataset):
...
def __getitem__(self, index):
image = cv2.imread(self.images[index])
target = cv2.imread(self.masks[index])
# No data normalization and type casting here
return torch.from_numpy(image).permute(2,0,1).contiguous(),
torch.from_numpy(target).permute(2,0,1).contiguous()
class Normalize(nn.Module):
# https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/modules/normalize.py
def __init__(self, mean, std):
super().__init__()
self.register_buffer("mean", torch.tensor(mean).float().reshape(1, len(mean), 1, 1).contiguous())
self.register_buffer("std", torch.tensor(std).float().reshape(1, len(std), 1, 1).reciprocal().contiguous())
def forward(self, input: torch.Tensor) -> torch.Tensor:
return (input.to(self.mean.type) - self.mean) * self.std
class MySegmentationModel(nn.Module):
def __init__(self):
self.normalize = Normalize([0.221 * 255], [0.242 * 255])
self.loss = nn.CrossEntropyLoss()
def forward(self, image, target):
image = self.normalize(image)
output = self.backbone(image)
if target is not None:
loss = self.loss(output, target.long())
return loss
return output
model = nn.DataParallel(model) # Runs model on all available GPUs
GPU 负载不平衡;
在主 GPU 上聚合需要额外的视频内存
在训练期间继续在前向推导内使用 nn.DataParallel 计算损耗。在这种情况下。za 不会将密集的预测掩码返回给主 GPU,而只会返回单个标量损失;
使用分布式训练,也称为 nn.DistributedDataParallel。借助分布式训练的另一个好处是可以看到 GPU 实现 100% 负载。
https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255
https://medium.com/@theaccelerators/learn-pytorch-multi-gpu-properly-3eb976c030ee
https://towardsdatascience.com/how-to-scale-training-on-multiple-gpus-dae1041f49d2
建议 5: 如果你拥有两个及以上的 GPU
def test_loss_profiling():
loss = nn.BCEWithLogitsLoss()
with torch.autograd.profiler.profile(use_cuda=True) as prof:
input = torch.randn((8, 1, 128, 128)).cuda()
input.requires_grad = True
target = torch.randint(1, (8, 1, 128, 128)).cuda().float()
for i in range(10):
l = loss(input, target)
l.backward()
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
建议 9: 如果设计自定义模块和损失——配置并测试他们
通过硬件升级可以更轻松地解决某些瓶颈。
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
下载2:Python视觉实战项目52讲 在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。
下载3:OpenCV实战项目20讲 在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~