与Transformer相比,RNN模型的一大优势是应对长序列的能力。
比如Mamba,内部状态大小始终保持不变,计算随序列长度线性增长,吃得多,消化快。
理论虽如此,但实际情况却是,目前的这些RNN模型在长上下文中的有效性并不能令人满意。
为啥会这样?空有效率但实际上能力不行?
近日,来自清华的研究团队对此进行了深入的实验研究:
论文地址:https://arxiv.org/pdf/2410.07145v1
文章表明,Mamba这类RNN模型在长上下文中主要面临两个问题:
一是无法推断比训练长度更长的输入,原因是较短的训练数据导致了循环状态过拟合;
二是内存容量的上限,由于模型无法有效遗忘很久以前的信息,导致新的信息存不进来了。
——这俩问题明显不是RNN的锅。
而经过研究人员的对症下药,Mamba-2(370M)在256K上下文长度上达到了近乎完美的密钥检索精度。
所以结论就是,Mamba yes!「RNN神教」前景一片光明!
对此,Mamba的作者Albert Gu点赞转发,并发表了相当详细的见解:
「这是一篇很棒的论文(名字也很棒)—— 关于状态空间模型(SSM)的状态容量和长上下文能力的巧妙实验。」
令人惊讶的是,对于每个状态大小 M,当训练上下文长度达到或超过某个临界值 K 时,都会出现一个转折点,在这个点上 SSM 就能够稳健地实现长度泛化。 这是因为当上下文长度小于 K 时,循环状态没有被充分利用,导致模型在训练期间会「过拟合」。但一旦通过足够长序列的训练使模型的状态容量得到充分利用,它就会自动获得泛化能力。 值得注意的是,K 与 M 竟然呈线性关系!—— 这表明每个 token 可能存在某种固有的信息含量(即存在一个值 B,使得上下文中的每个 token 对应 B 字节的循环状态)。这个 B 值可能是由模型架构决定的?
「反过来说,过分担心循环模型的长度泛化问题可能是一个误区。我们无需设计新机制或特殊的缓解措施:只需要在更长的序列上训练(因为是线性时间复杂度,所以不会增加计算开销!),就能获得更好的泛化效果。」
最后,Albert Gu用一句话总结:要让你的Mamba吃得饱饱的,它就能发挥出最佳状态!
喂饱你的Mamba
先来复习一下基础知识。
本文以Mamba2作为主要研究对象,内部的计算表示为下图中的并行结构:
整体的输入输出遵循SSM(也即RNN)的形式:
而把上图中模块内部所有的计算写出来,就是下面这一坨公式:
之前提到的两个问题,核心在于模型的内部状态,也就是ht的表现。
所以下面在探索问题和解决方案时,咱们可以重点关注这些公式中,与ht计算相关的参数。
之前有研究表明,当上下文长度超过其训练长度时,Mamba-1和RWKV-4的性能会严重下降。
顺着这个思路,研究人员在两个方向上进行了实验分析:状态崩溃(STATE COLLAPSE)和容量上限(STATE CAPACITY)。