740 TFLOPS!迄今最强 FlashAttention 来了。
随着大型语言模型(LLM)加速落地,扩展模型上下文窗口变得越来越重要。然而,Transformer 架构的核心 —— 注意力层的时间复杂度和空间复杂度与输入序列长度的平方成正比。这使得扩展模型上下文窗口存在挑战。
2022 年,一种快速、内存高效的注意力算法 ——FlashAttention 问世,该算法无需任何近似即可加速注意力并减少内存占用。
FlashAttention 对注意力计算进行重新排序的算法,并利用 tiling 和重计算来显著加快计算速度,将内存使用量从序列长度的二次减少到线性。
2023 年,研究团队宣布推出 FlashAttention-2,在算法、并行化和工作分区等方面有了显著改进。
现在,来自 Meta、英伟达、Together AI 等机构的研究者宣布推出 FlashAttention-3,它采用了加速 Hopper GPU 注意力的三种主要技术:
FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍,高达 740 TFLOPS,即 H100 理论最大 FLOPS 利用率为 75%。使用 FP8,FlashAttention-3 的速度更是接近 1.2 PFLOPS。
FlashAttention-3 的改进将带来:
论文标题:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
论文地址:https://tridao.me/publications/flash3/flash3.pdf
论文作者之一 、FlashAttention1-3 版本的参与者 Tri Dao 表示:FlashAttention 被广泛用于加速 Transformers,已经使注意力速度提高了 4-8 倍,但尚未利用现代 GPU。因而他们发布了 FlashAttention-3:在 FP16 上速度提高了 1.5-2 倍,在 H100 上高达 740 TFLOPS(75% 实用性),FP8 接近 1.2 PFLOPS!
Hopper GPU 硬件特性:WGMMA、TMA、FP8
虽然 FlashAttention-2 在 Ampere (A100) GPU 上可以实现 70% 的理论最大 FLOPS,但它尚未利用 Hopper GPU 上的新功能来最大限度地提高性能。接下来文章描述了一些新的 Hopper 特定功能,以及它们为何如此重要。
首先是 WGMMA(Warpgroup Matrix Multiply-Accumulate),该功能利用了 Hopper 架构上新的张量内核,比 Ampere 架构具有更高的吞吐量。
然后是 TMA(Tensor Memory Accelerator),这是一个特殊的硬件单元,可以加速全局内存和共享内存之间的数据传输,用于处理所有索引计算和边界外预测。这样一来寄存器就释放了,寄存器是增加 tile 大小和效率的宝贵资源。
低精度 FP8,让 Tensor Core 吞吐量翻了一倍。
FlashAttention-3 充分利用了 Hopper 架构的所有这些新功能。
异步:GEMM 和 Softmax 重叠
注意力机制主要有两个操作,GEMM 和 softmax。为什么要将它们重叠?
问题在于在现代加速器上,非矩阵乘法(matmul)运算比矩阵乘法运算慢。特殊函数如指数运算(如 softmax 函数)的吞吐量甚至低于浮点乘加操作;这些运算是由多功能单元处理的,这是一个与浮点乘加或矩阵乘加不同的单元。
理想情况下,研究者希望矩阵乘法和 softmax 能够并行操作。当 Tensor Cores 忙于矩阵乘法时,多功能单元应当在计算指数运算!
Inter-warpgroup 重叠
重叠 GEMM 和 softmax 最简单的方法是什么都不做,warp 调度程序会免费完成部分重叠。下图说明了 pingpong 调度,其中相同的颜色表示相同的迭代。
Intra-warpgroup 重叠
即使在一个 warpgroup 中,研究者也可以在运行该 warpgroup 的 GEMM 时运行 softmax 的某些部分。如图所示,相同的颜色表示相同的迭代。
这种 pipeline 流程可以将 FP16 注意力前向传播的吞吐量从大约 620 TFLOPS 提高到 640-660 TFLOPS,但代价是更高的寄存器压力,因而需要更多的寄存器来同时保存 GEMM 的累加器以及 Softmax 的输入 / 输出。
低精度:使用非相干处理减少量化误差
激活 LLM 可能存在一些极端值,导致量化困难,从而产生较大的量化误差。本文采用非相干处理(incoherent processing),该技术通过将查询和键与一个随机正交矩阵相乘来「分散(spread out)」极端值,从而减少量化误差。特别地,该研究使用了 Hadamard 变换,它可以在每个注意力头中以 O (d log d) 的时间复杂度完成,而不是 O (d^2),其中 d 是头部维度。
研究者发现非相干处理可以将量化误差减少很多,具体的数值误差比较见下表。
实验
文中展示了 FlashAttention-3 的一些结果,并将其与 FlashAttention-2 以及 Triton 和 cuDNN 中的实现进行了比较(两者都已经使用了 Hopper GPU 的新硬件功能)。
在 FP16 精度下,FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍。
对于 FP8,FlashAttention-3 接近 1.2 PFLOPS。
文章来源于“机器之心”