老黄又赢麻了!英伟达亲自下场推出 FlashAttention-3:H100利用率飙升至75%!

机器学习算法工程师

共 3732字,需浏览 8分钟

 ·

2024-07-12 14:11

点蓝色字关注“机器学习算法工程师

设为星标,干货直达!

740 TFLOPS!迄今最强 FlashAttention 来了。

随着大型语言模型(LLM)加速落地,扩展模型上下文窗口变得越来越重要。然而,Transformer 架构的核心 —— 注意力层的时间复杂度和空间复杂度与输入序列长度的平方成正比。这使得扩展模型上下文窗口存在挑战。


2022 年,一种快速、内存高效的注意力算法 ——FlashAttention 问世,该算法无需任何近似即可加速注意力并减少内存占用。


FlashAttention 对注意力计算进行重新排序的算法,并利用 tiling 和重计算来显著加快计算速度,将内存使用量从序列长度的二次减少到线性。



2023 年,研究团队宣布推出 FlashAttention-2,在算法、并行化和工作分区等方面有了显著改进。


现在,来自 Meta、英伟达、Together AI 等机构的研究者宣布推出 FlashAttention-3,它采用了加速 Hopper GPU 注意力的三种主要技术:


  • 通过 warp-specialization 重叠整体计算和数据移动;

  • 交错分块 matmul 和 softmax 运算;

  • 利用硬件支持 FP8 低精度的不连贯处理。


FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍,高达 740 TFLOPS,即 H100 理论最大 FLOPS 利用率为 75%。使用 FP8,FlashAttention-3 的速度更是接近 1.2 PFLOPS。


FlashAttention-3 的改进将带来:


  • 更高效的 GPU 利用率:H100 理论最大 FLOPS 利用率为 75%,而之前仅为 35%。这使得 LLM 的训练和运行速度比以前的版本快得多。

  • 较低精度下更好的性能:FlashAttention-3 可以在保持精度的同时使用较低精度的数字 (FP8)。这可以实现更快的处理速度并可能降低内存使用量,从而为运行大规模人工智能操作的客户节省成本并提高效率。

  • 能够在 LLM 中使用更长的上下文:通过加速注意力机制,FlashAttention-3 使 AI 模型能够更有效地处理更长的文本片段。这使得应用程序能够理解并生成更长、更复杂的内容而不会减慢速度。



论文标题:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision

论文地址:https://tridao.me/publications/flash3/flash3.pdf


论文作者之一 、FlashAttention1-3 版本的参与者 Tri Dao 表示:FlashAttention 被广泛用于加速 Transformers,已经使注意力速度提高了 4-8 倍,但尚未利用现代 GPU。因而他们发布了 FlashAttention-3:在 FP16 上速度提高了 1.5-2 倍,在 H100 上高达 740 TFLOPS(75% 实用性),FP8 接近 1.2 PFLOPS!



Hopper GPU 硬件特性:WGMMA、TMA、FP8


虽然 FlashAttention-2 在 Ampere (A100) GPU 上可以实现 70% 的理论最大 FLOPS,但它尚未利用 Hopper GPU 上的新功能来最大限度地提高性能。接下来文章描述了一些新的 Hopper 特定功能,以及它们为何如此重要。


首先是 WGMMA(Warpgroup Matrix Multiply-Accumulate),该功能利用了 Hopper 架构上新的张量内核,比 Ampere 架构具有更高的吞吐量。



然后是 TMA(Tensor Memory Accelerator),这是一个特殊的硬件单元,可以加速全局内存和共享内存之间的数据传输,用于处理所有索引计算和边界外预测。这样一来寄存器就释放了,寄存器是增加 tile 大小和效率的宝贵资源。



低精度 FP8,让 Tensor Core 吞吐量翻了一倍。



FlashAttention-3 充分利用了 Hopper 架构的所有这些新功能。


异步:GEMM 和 Softmax 重叠


注意力机制主要有两个操作,GEMM 和 softmax。为什么要将它们重叠?


问题在于在现代加速器上,非矩阵乘法(matmul)运算比矩阵乘法运算慢。特殊函数如指数运算(如 softmax 函数)的吞吐量甚至低于浮点乘加操作;这些运算是由多功能单元处理的,这是一个与浮点乘加或矩阵乘加不同的单元。


理想情况下,研究者希望矩阵乘法和 softmax 能够并行操作。当 Tensor Cores 忙于矩阵乘法时,多功能单元应当在计算指数运算! 


Inter-warpgroup 重叠


重叠 GEMM 和 softmax 最简单的方法是什么都不做,warp 调度程序会免费完成部分重叠。下图说明了 pingpong 调度,其中相同的颜色表示相同的迭代。



Intra-warpgroup 重叠


即使在一个 warpgroup 中,研究者也可以在运行该 warpgroup 的 GEMM 时运行 softmax 的某些部分。如图所示,相同的颜色表示相同的迭代。



这种 pipeline 流程可以将 FP16 注意力前向传播的吞吐量从大约 620 TFLOPS 提高到 640-660 TFLOPS,但代价是更高的寄存器压力,因而需要更多的寄存器来同时保存 GEMM 的累加器以及 Softmax 的输入 / 输出。


低精度:使用非相干处理减少量化误差


激活 LLM 可能存在一些极端值,导致量化困难,从而产生较大的量化误差。本文采用非相干处理(incoherent processing),该技术通过将查询和键与一个随机正交矩阵相乘来「分散(spread out)」极端值,从而减少量化误差。特别地,该研究使用了 Hadamard 变换,它可以在每个注意力头中以 O (d log d) 的时间复杂度完成,而不是 O (d^2),其中 d 是头部维度。


研究者发现非相干处理可以将量化误差减少很多,具体的数值误差比较见下表。



实验


文中展示了 FlashAttention-3 的一些结果,并将其与 FlashAttention-2 以及 Triton 和 cuDNN 中的实现进行了比较(两者都已经使用了 Hopper GPU 的新硬件功能)。


在 FP16 精度下,FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍。



对于 FP8,FlashAttention-3 接近 1.2 PFLOPS。



扩展阅读:


斯坦福提出新型Attention算法!提速2-4倍,BERT单节点训练最快 

比标准Attention提速5-9倍,大模型都在用的FlashAttention v2来了


参考链接:

https://tridao.me/blog/2024/flash3/  转自机器之心



推荐阅读

使用PyTorch 2.0加速Transformer:训练推理均拿下!

硬核解读Stable Diffusion(系列三)

硬核解读Stable Diffusion(系列二)

硬核解读Stable Diffusion(系列一)

带你入门扩散模型:DDPM


机器学习算法工程师


                                    一个用心的公众号


浏览 176
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报