基于 Transformer 结构的视觉语言大模型(VLM)在各种下游的视觉语言任务上取得了巨大成功,但由于其较长的输入序列和较多的参数,导致其相应的计算开销地提升,阻碍了在实际环境中进一步部署。为了追求更为高效的推理速度,前人提出了一些针对 VLM 的加速方法,包括剪枝和蒸馏等,但是现有的这些方法大都采用静态架构,其针对不同输入实例采用同样的计算图进行推理,忽略了不同实例之间具有不同计算复杂性的事实:针对复杂的跨模态交互实例,自然需要更多计算才能完全理解图像和相关问题的复杂细节;相反,简单的实例则可以用更少的计算量解决。这也导致较高加速比下的 VLM 的性能严重下降。
为了解决上述这些问题,哈工大联合度小满推出针对多模态模型的自适应剪枝算法 SmartTrim,论文已被自然语言处理顶级会议 COLING 24 接收。
前期探究和研究动机
本文首先针对 VLM 中每一层的 token 表示和 attention head 的冗余情况进行分析,如下图所示。我们有了以下发现:(1)无论是哪种模态的 token 或者 head,层内相似性始终很高,说明模型是存在显着冗余。(2)Token 的冗余度随着深度而逐渐增加。(3)不同实例之间的冗余程度差异较大,进一步说明依赖于输入的自适应剪枝对于 VLM 加速的重要性。
在基于 VQA 微调的 METER 的跨模态编码器中,层内不同 token(上)和 attention head(下)表示的相似性。
基于上述发现,本文提出针对 VLM 的自适应剪枝框架:SmartTrim,从 token 和 attention head 两方面同时对模型冗余部分进行剪枝。
SmartTrim 框架结构图
跨模态感知的 Token 修剪器:
文本和图像各自的 Token 序列首先经过各自编码器进行编码,对于得到的序列表示,经过基于 MLP 结构的跨模态感知 Token 修剪器识别对于当前层不重要的 Token:在识别过程中模型不仅考虑 token 在当前模态序列的重要性,同时还要引入其在跨模态交互中的重要性。最终 token 的重要性分数转化成一个 0/1 的二值 mask 用来去除冗余 token。
模态自适应的注意力头修剪器:
VLM 分别通过 MSA(multi-head self-attention module) 和 MCA (multi-head cross-attention module)捕获模态内和模态间交互。正如前文分析,注意力部分计算开销根据输入的复杂性而变化,导致注意力模块出现的冗余会产生较大的开销。为此,我们将模态自适应注意力头修剪器集成到注意力模块中。该修剪器用以衡量各个注意力头的显著性,根据此对冗余的注意力头做修剪。
模型训练
在模型的训练过程中,我们在优化任务相关的训练目标的同时,还引入了计算开销相关的训练目标,
让模型在训练过程中对性能和效率进行权衡。针对上述修剪器生成的二值 mask(M)在训练中不可导的问题,我们采用了基于重参数化的技巧从而进行端到端的训练:
自蒸馏与课程训练策略:
我们还引入一种自蒸馏的训练策略来提高通过自适应剪枝得到的小模型:通过对齐剪枝后的小模型和全容量模型之间输出,使得剪枝模型的输出与全容量模型更为一致,进一步提高小模型的能力。另外我们利用课程学习的训练方式指导模型的训练,使模型稀疏度逐步减低到目标比例,从而保证了优化过程的稳定性。
最终的模型训练目标为:
实验结果
我们基于 METER 和 BLIP 这两个 VLM 作为原始模型并在一系列下游 VL 任务上评估 SmartTrim 以及其他方法的性能和效率,如下表所示:我们的方法将原始模型加速了 2-3 倍,同时性能下降最小。
具有不同加速比下的 VLM 加速方法结果。
与前人方法相比,SmartTrim 不需要额外的预训练,而且还通过 token 和 head 两个方面提供了更细粒度地控制模型的计算开销,以更好地探索效率与性能之间的权衡,下面的帕累托图显示我们的方法在 1.5x 的加速比下甚至相比原始模型性能有所提升,而在高加速比下的相比其他加速方法具有显著优势。
不同 VLM 加速方法在 NLVR2 上的效率与性能权衡的帕累托前沿。
我们进一步展示了一些随着深度增加 SmartTrim 逐步裁剪不同模态的冗余 token 的例子:
Token 的逐步裁剪修剪过程。
上图 (a)-(c) 是由我们提出的跨模态感知 Token 修剪器获得的,可以看到针对不同的问题我们的修剪器网络可以合适地选择更为相关的 patch。(d) 为去掉跨模态信息指导的基线模型地输出,我们也可以观察到其只保留了图片的主体部分但与问题并不相关的 patch token,并最终产生错误的答案。
我们还统计了在 vqa 数据的测试集上我们的 SmartTrim 为不同实例分配的计算量情况,如下图所示。可以发现 SmartTrim 可以自适应地根据跨模态交互的复杂性分配不同的计算开销,为简单实例(图左)分配更少的计算,为困难实例(图右)分配更多计算。
VQA 上 SmartTrim 的 FLOPs 直方图。
更多详细内容可以参考论文原文。论文提出的方法未来将结合到度小满轩辕大模型中,大模型项目地址:https://github.com/Duxiaoman-DI/XuanYuan,欢迎大家访问!
文章来自微信公众号 “ 机器之心 ”