[Decoding优化]原理&图解FlashDecoding/FlashDecoding++
共 4970字,需浏览 10分钟
·
2024-06-02 23:45
0x00 前言
FlashDecoding和FlashDecoding++单独摘出来,准备整理一篇Decoding优化的文章,后续会补充更多细节。上一篇Attention优化的文章,已经详细讲解了FlashAttention-1和FlashAttention-2算法中各自的优化点、FlashAttention IO复杂度分析以及适用场景、FlashAttention在分布式训推中的应用;并且,通过图解的方式通俗易懂地讲解了FlashAttention种关于MQA/GQA以及Causal Mask的处理。最后,还梳理了Memory-Efficient Attention。推荐先阅读完上一篇,再来阅读本篇:
0x01 FlashDecoding[1]
一般情况下FlashAttention forward pass在Q的seqlen维度以及batch_size维度做并行。可以看到,对于当前的Q的分块Queries,forward pass会在thread block中,逐个遍历所有的K, V分块,计算逐个分块的局部Attention输出。每个局部的Attention输出,会在thread block内部遍历的过程中,随着每一次迭代,根据当前次迭代的值进行scale,一直到沿着K,V的迭代完成后,就获得了最终正确的Output。
这种方式,对于训练的forward是work的,因为训练时,seqlen或bs会比较大,GPU资源能够被有效地利用。但是在推理的Generation阶段,是逐token生成,在利用KV Cache的情况下,每次推理实际的queries token数为1,已经无法通过queries进行并行了,GPU资源无法得到有效的利用,特别是如果bs还比较小,那GPU资源浪费将会更加严重。于是针对这种情况,FlashAttention作者开发了FlashDecoding,对推理阶段的forward进行优化。基本的思路其实也很直观:既然,Q和BS无法进一步并行了,那么对K,V进行并行是不是就可以了呢?没错,这就是FlashDecoding的思路。
FlashDecoding的做法如下:
1. 首先,将K/V切分成更小的块,比如5块;
2. 然后在这些K/V块上,使用标准FlashAttention进行计算,得到所有小块的局部结果
3. 最后,使用一个额外的kernel做全局的reduce,得到正确输出
在128K context的情况下,FlashDecoding比标准FlashAttention快50倍。
除了FlashAttention repo本身,目前像TRT-LLM和vLLM都在generation阶段,针对小bs*headnum使用了FlashDecoding的思路进行优化,TRT-LLM中提供了multi_block_mode选项进行控制,而在vLLM中则是实现了PagedAttention V2来支持。而在prompt阶段vLLM则通过xformers的flash-attn后端进行推理。
0x02 FlashDecoding++[2](非官方)
FlashDecoding++最主要的创新点,在于提出了基于统一max值的异步softmax。我们知道,safe-softmax的计算公式中,需要先求每行x的最大值,然后减去这个max(x)之后,再做softmax以防止数值溢出。
FlashDecoding++认为,这个max值,不一定需要online计算max(x),而是可以是一个合理的先验值。我们对上边的公式分子分母提取公因式,可以得到:
可以发现,使用先验值与直接计算max(x),最终softmax的结果,在数学上是等价的。问题在于如何确定这个先验值以防止数值异常,比如对于一个很小的x,这时如果使用一个非常大的先验值,就可能导致概率值异常。FlashDecoding++认为一个合理的先验值,可以直接从数据集中进行统计获得。对于不同的模型,这个先验值也是不一样的。
在工程实现上,FlashDecoding++采用了Fallback的做法,因为就算是从数据集中统计得到的先验值,依然无法覆盖所有的corner case,还是可能会导致overflow。因此,当出现数值溢出时,FlashDecoding++就是Fallback到FlashDecoding的计算。
结合一些工程上对GEMV/GEMM Tensor Cores padding和Kernel调度优化,FlashDecoding++对比FlashDecoding大概有37%的性能提升,性能提升还是很明显的。不过FlashDecoding++代码并没有开源,具体实现暂时无法探究。另外对于论文中提到异步softmax目前我也有些疑惑,因为FlashDecoding++虽然提出了统一先验值,解决了求max(x)值的问题,但是依然没有解决softmax依赖求和项的问题,所以,按照这个逻辑理解,softmax的计算应该还是无法真正意义上并行的。但是由于分子的计算不需要再在每次iteration中执行rescale计算,每个thread block的内层循环针对K,V,只需要负责当前块中的softmax分子计算以及累计求和项即可,确实能节省非matmul计算量。这点优化和FlashAttention-2中的减少非matmul计算的逻辑是异曲同工的。
这里首先感谢
(2)在(1)中的每个thread block得到局部结果后,再进行一次整体的softmax计算。
画一下FlashDecoding++和FlashDecoding的计算流程对比,如下。优化点在于,FlashDecoding++在Step[1],计算量比FlashDecoding直接使用的FA2要少。
FlashDecoding++对应的forward pass,估计大概长这样:(修改自FA2 forward pass)
对比一下原来FlashAttention2的forward pass:
可以看到FlashDecoding++的forward pass在step[1]中,内循环的每个迭代步,计算是可以完全并行,无需进行额外rescale。而FA2,由于需要rescale,KV内循环的每次迭代不是独立的,当前次迭代需要对上一次迭代的结果进行rescale。因此,对于FlashDecoding++,可以在K,V维度切成多个chunk,分给不同的thread block并行计算,最后再进行一次校正即可。 我们知道,FlashDecoding也在KV维度切成了多个chunk,只是每个chunk内的使用FlashAttention2计算,FlashAttention2还有针对KV的循环的micro chunk,在micro chunk这个循环中,需要每次迭代都进行rescale。
参考
^Flash-Decoding for long-context inference. https://crfm.stanford.edu/2023/10/12/flashdecoding.html
^FLASHDECODING++: FASTER LARGE LANGUAGE MODEL INFERENCE ON GPUS. https://arxiv.org/pdf/2311.01282.pdf
- The End -
长按二维码关注我们
本公众号专注:
1. 技术分享;
2. 学术交流;
3. 资料共享。
欢迎关注我们,一起成长!