【深度学习】干货!小显存如何训练大模型

机器学习初学者

共 3425字,需浏览 7分钟

 ·

2022-03-03 01:45

之前Kaggle有一个Jigsaw多语言毒舌评论分类[1]比赛,当时我只有一张11G显存的1080Ti,根本没法训练SOTA的Roberta-XLM-large模型,只能遗憾躺平。在这篇文章中,我将分享一些关于如何减少训练时显存使用的技巧,以便你可以用现有的GPU训练更大的网络。

混合精度训练

第一个可能已经普及的技巧是使用混合精度(mixed-precision)训练。当训练一个模型时,一般来说所有的参数都会存储在显存VRAM中。很简单,总的VRAM使用量等于存储的参数数量乘以单个参数的VRAM使用量。一个更大的模型不仅意味着更好的性能,而且也会使用更多的VRAM。由于性能相当重要,比如在Kaggle比赛中,我们不希望减小模型的规模。因此减少显存使用的唯一方法是减少每个变量的内存使用。默认情况下变量是32位浮点格式,这样一个变量就会消耗4个字节。幸运的是,人们发现可以在某些变量上使用16位浮点,而不会损失太多的精度。这意味着我们可以减少一半的内存消耗! 此外,使用低精度还可以提高训练速度,特别是在支持Tensor Core的GPU上。

在1.5版本之后,pytorch开始支持自动混合精度(AMP)训练。该框架可以识别需要全精度的模块,并对其使用32位浮点数,对其他模块使用16位浮点数。下面是Pytorch官方文档[2]中的一个示例代码。

# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

梯度积累

第二个技巧是使用梯度积累。梯度累积的想法很简单:在优化器更新参数之前,用相同的模型参数进行几次前后向传播。在每次反向传播时计算的梯度被累积(加总)。如果你的实际batch size是N,而你积累了M步的梯度,你的等效批处理量是N*M。然而,训练结果不会是严格意义上的相等,因为有些参数,如Batch Normalization,不能完全累积。

关于梯度累积,有一些事情需要注意:

  1. 当你在混合精度训练中使用梯度累积时,scale应该为有效批次进行校准,scale更新应该以有效批次的粒度进行。
  2. 当你在分布式数据并行(DDP)训练中使用梯度累积时,使用no_sync()上下文管理器来禁用前M-1步的梯度全还原,这可以增加训练的速度。

具体的实现方法可以参考文档[3]

梯度检查点

最后一个,也是最重要的技巧是使用梯度检查点(Gradient Checkpoint)。Gradient Checkpoint的基本思想是只将一些节点的中间结果保存为checkpoint,在反向传播过程中对这些节点之间的其他部分进行重新计算。据Gradient Checkpoint的作者说[4],在这个技巧的帮助下,他们可以把10倍大的模型放到GPU上,而计算时间只增加20%。Pytorch从0.4.0版本开始正式支持这一功能,一些非常常用的库如Huggingface Transformers也支持这一功能,而且非常简单,只需要下面的两行代码:

bert = AutoModel.from_pretrained(pretrained_model_name)
bert.config.gradient_checkpointing=True

实验

在这篇文章的最后,我想分享之前我在惠普Z4工作站上做的一个简单的benchmark。该工作站配备了2个24G VRAM的RTX6000 GPU(去年底升级到2个48G的A6000了),在实验中我只用了一个GPU。我用不同的配置在Kaggle Jigsaw多语言毒舌评论分类比赛的训练集上训练了XLM-Roberta Base/Large,并观察显存的使用量,结果如下。

ModelXLM-R BaseXLM-R Base 1XLM-R Base 2XLM-R LargeXLM-R Large 1XLM-R Large 2
Batch size/GPU8816888
Mixed-precisionoffononoffonon
gradient checkpointingoffoffoffoffoffon
VRAM usage12.28G10.95G16.96OOM23.5G11.8G
one epoch70min50min40min-100min110min

我们可以看到,混合精度训练不仅减少了内存消耗,而且还带来了显著的速度提升。梯度检查点的功能也非常强大。它将VRAM的使用量从23.5G减少到11.8G!

以上就是所有内容,希望对大家有帮助🙂

参考资料

[1]

Jigsaw多语言毒舌评论分类: https://www.kaggle.com/c/jigsaw-multilingual-toxic-comment-classification

[2]

Pytorch官方文档: https://pytorch.org/docs/1.8.1/notes/amp_examples.html

[3]

gradient-accumulation文档: https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation

[4]

据Gradient Checkpoint的作者说: https://github.com/cybertronai/gradient-checkpointing

往期精彩回顾




浏览 88
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报