今年 2 月,月之暗面提出了一种名为 MoBA 的注意力机制,即 Mixture of Block Attention,可以直译为「块注意力混合」。
据介绍,MoBA 是「一种将混合专家(MoE)原理应用于注意力机制的创新方法。」该方法遵循「更少结构」原则,并不会引入预定义的偏见,而是让模型自主决定关注哪些位置。
MoBA 在处理长上下文时表现出极强的潜力,它允许 Query 只稀疏地关注少量 Key-Value 块,从而大幅降低计算成本。
然而,目前业界对 MoBA 性能背后的设计原则仍缺乏深入理解,同时也缺少高效的 GPU 实现,这限制了其实际应用。
在这篇论文中,来自 MIT、NVIDIA 机构的研究者首先建立了一个统计模型,用于分析 MoBA 的内部机制。模型显示,其性能关键取决于路由器是否能够基于 Query-Key 的相似度,准确区分相关块与无关块。研究者进一步推导出一个信噪比,将架构参数与检索准确率建立起形式化联系。
基于这一分析,本文识别出两条主要的改进路径:一是采用更小的块大小,二是在 Key 上应用短卷积,使语义相关信号在块内聚集,从而提升路由准确性。
然而,尽管小块尺寸在理论上更优,但在现有的 GPU 实现中,小块会导致严重的内存访问碎片化和低并行度,速度甚至慢于稠密注意力。
为解决这一矛盾,研究者进一步提出了 FlashMoBA,一种硬件友好的 CUDA kernel,可在小块配置下仍然高效地执行 MoBA。
结果显示优化后的 MoBA 在性能上可与密集注意力基线相匹敌。对于小块场景,FlashMoBA 相比 FlashAttention-2 可实现最高 14.7 倍加速。

理论模型表明,较小的块尺寸能带来显著的质量提升,但朴素的 GPU 实现效率低下。由月之暗面发布的原始 MoBA 实现,在配置小块尺寸时会遭遇性能瓶颈,这些瓶颈抵消了稀疏性带来的计算节省,导致执行速度比稠密注意力更慢。
研究者推出了 FlashMoBA,这是一种硬件感知的 CUDA 内核,旨在使小块 MoBA 变得实用且高效。
小块尺寸引入了几个关键的性能挑战,要在实际部署中应用必须解决这些问题。
首先,在为每个查询收集稀疏、不连续的键值块时,会出现低效的内存访问,导致从 HBM 读取数据时出现非合并内存读取。

最后,由于每个块的工作量减少以及启动大量独立内核的开销,导致 GPU 占用率低,进而造成并行度差和硬件利用率低。
为了克服这些挑战,FlashMoBA 采用了三个融合内核,以最大限度地减少 HBM 往返次数,并使计算与 GPU 架构相对齐,如图 1 所示。
Top-k 选择过程是原始 MoBA 实现中的主要瓶颈,该实现显式生成了完整的分数矩阵并串行处理批次序列。研究者将其替换为 Flash TopK(图 1 中的步骤 1),这是一个由融合内核组成的高度优化的三阶段流水线。



最后,一个高效的后处理步骤将以查询为中心的索引重新格式化为以键块为中心的变长布局,以便进行主注意力传递。整个流水线在批次和注意力头之间完全并行化,消除了原始的性能瓶颈。
为了处理 MoBA 的不规则稀疏性,前向内核使用了一种基于两级分块机制的「收集并致密化」策略,详见算法 1。

要区分两种类型的块:

这种两级方法是关键所在,因为在 SRAM 中缓存查询允许在逻辑键块的所有物理图块之间复用数据,从而通过高效的稠密 GEMM(通用矩阵乘法)分摊昂贵的不规则内存访问成本。
反向传播利用了 FlashAttention-2 的内存高效设计,并实现为三个内核的序列(算法 5)。

主内核在键维度上并行化计算,每个线程块处理一个键块。为了处理稀疏性,它镜像了前向传播的「收集并致密化」策略,使用变长索引收集查询子集并将梯度输出到片上图块中。

这种设计确保了反向传播在序列长度上保持线性复杂度,这是相对于标准注意力的二次复杂度的一个关键改进。由于反向传播通常构成优化注意力实现的主要性能瓶颈(通常比前向传播慢 2-3 倍),因此我们需要反向内核的高效率对于实现长序列的实际训练至关重要。
本文从零开始预训练模型,并进行可控实验来验证 MoBA 的设计原则。实验共训练了两个模型,所有实验均在 8× H100 80GB GPU 上完成:
本文在语言建模、长上下文检索以及真实任务上对 MoBA 的表现进行了评估。实验结果表明,改进后的模型在多种基准测试中提高了性能。


这一趋势在所有基准和不同模型规模上都保持一致。对 340M 模型来说,将块大小从 512 缩小到原来的 1/4 到 128,可带来如下提升:



总体来看,小块尺寸对于 MoBA 达到与密集注意力相当的性能是必要的。
Key Convolution 。Key Convolution 在不同任务中都能带来性能提升,而且具有任务偏好特性。对于 340M 模型:
对于 1B 模型:



注:卷积核宽度 W∈{3,5},分别记作 kconv3 和 kconv5。
稀疏匹配密集注意力机制。在多个基准测试和规模下,MoBA 的表现与密集注意力机制相当甚至更胜一筹。

虽然理论上小块尺寸能够带来更高的模型质量,但此前由于 GPU 利用率低下,小块一直难以在实际中使用。FlashMoBA 的出现让这些配置真正变得可行。
端到端性能。图 3 对比了不同序列长度(8K 至 512K token)下的延迟和内存占用。FlashMoBA 在两项指标上都显著优于原始实现。
在 N=64K 且 B=128 的配置下:FlashMoBA 比原始 MoBA 快 7.4 倍,内存占用减少 6.1 倍,原始 MoBA 在 128K 序列就会 OOM(内存溢出),而 FlashMoBA 能扩展到 512K。
随着序列越长、块越小,优势更明显,因为 FlashMoBA 消除了全局 reindex 的开销,在长序列条件下可实现最高 14.7× 快于 FlashAttention-2 的速度。

为了理解 FlashMoBA 的提速来源,图 4 展示了在 N=64K 下前向传播的耗时分布。
原始 MoBA 包含 5 个阶段:(1)计算质心并执行 top-k、(2)全局 reindex、(3)在路由后的索引上执行注意力、(4)局部因果注意力以及(5)合并结果。
其中步骤 (1)、(2)、(5) 占据了超过 70% 的执行时间。
FlashMoBA 则使用两个融合 kernel,这种融合设计将 64K 序列下的前向传播时间降至 49 ms,而 FlashAttention-2 在相同设置下为 99 ms。

文章来自于“机器之心”,作者 “机器之心编辑部”。