状态崩溃
状态崩溃(SC)指的是,RNN模型在输入上表现出异常行为的时间比训练期间看到的时间更长的现象
上图展示了Mamba-2和RWKV-6在训练长度之外的语言建模损失。为了可控性和合成任意长度的提示,这个损失是在仅由「\n」字符组成的提示上计算的(称为「newlines」提示)。
结果表明,当上下文长度远大于其训练长度时,两个RNN的性能都会严重下降,最后就跟瞎猜差不多了。
语言建模可能无法反映下游能力,上图给出了Mamba-2(在8K上下文窗口上训练)在密钥检索任务上的评估结果。
我们可以发现,Mamba-2在8K上下文中具有近乎完美的检索准确性,但在序列长度超过16K后就没法看了,无论模型参数量大小。
从上面的公式来看,这种结果可能出人意料,因为内部状态ht的更新应该具有稳定的指数内存衰减,即对于最后k个token具有良好的检索准确性。
问题出在哪里?
由于递归状态的维度不会随时间而变化,因此状态崩溃期间行为的急剧变化一定是状态值变化的结果。
作者对Mamba-2 370M中每一层的递归状态进行了统计,发现当上下文长度超过训练长度时,一些头部的平均值和方差会急剧变化:
图5显示了模型第38层第2个头的状态,在t=20K时方差爆炸。从中可以发现这种方差爆炸在很大程度上可以归因于少数异常通道,其余大多数通道则相对稳定。
分析一下公式,与ht计算有关的∆t、Bt和xt:
如上图所示,虽然三者都是输入的函数,但xt相对稳定,而Bt比∆t更早发生爆炸,进一步探索还能发现生成∆t和Bt的卷积权重明显更大。
作者认为,产生SC的原因是,对于训练长度来说,状态容量过大,模型能够实现强大的语言建模性能,而无需学习如何忘记。
上图显示了第一个token在不同时间步的内存强度,作者发现爆炸的头(第38层的第2、4、7个头)强烈倾向于在训练长度内保留所有信息,在t=8K时内存强度超过0.8。
解决方案
为了缓解SC,使模型沿序列长度更好地泛化,作者提出了3种解决方案,总的思想是修改状态的update规则来避免其溢出。
Method 1: Forget More and Remember Less
通过增加状态衰减量(忘记更多)或减少输入信息的数量(记住更少)来减少SC,作者选择干预Bt和αt(分别控制输入强度和内存衰减强度)。
Method 2: State Normalization
在每次更新后对状态进行归一化,以确保状态的范数始终低于阈值:
PS:这种方式会将模型转换为非线性RNN,无法以与原始模型相同的方式并行化,预填充速度要慢得多。
Method 3: Sliding Window by State Difference
利用状态ht可以写为加权和的形式,来模拟滑动窗口机制,无需在每一步都从窗口的开头重新处理。
此方法适用于所有可以写成加权和的RNN,包括RWKV 5和6、RetNet、GLA等。尽管会使生成的计算和内存成本翻倍,但仍然是一个可以接受的权衡,因为RNN的生成成本比Transformer低很多。
以上3个是不需要训练的方案,而基于SC是由状态参数过拟合引起的假设,我们也可以尝试使用超过状态容量的序列长度来训练模型。
容量上限
根据以上的讨论,当且仅当训练长度包含的信息少于状态容量时,才会发生SC,所以我们可以通过实验间接估计模型的状态容量。
研究人员训练了多个具有不同状态大小和训练长度的Mamba-2,并将SC未发生的最小训练长度视为状态容量。
实验数据选择RedPajama-V2,一个从CommonCrawl中提取的30T token的开放数据集,进行去重以确保数据质量。
在评估过程中,对长度超过16K token的文档进行抽样,如果不够长,则对其进行拼接。
研究人员试验了具有不同状态大小的模型配置,包括来自Mamba-2官方checkpoint的三个预训练模型,大小分别为130M、370M和780M,另外3个模型(36M、47M、85M)则从头开始训练。
实验结果
上图展示了在Mamba-2 780M上无训练长度泛化方法的结果。我们可以看到,虽然LongMamba大大提高了模型的长度泛化性(3倍以上),但它在较短的序列上会导致明显更大的困惑度,并且仍然不可避免地表现出SC。
相比之下,本文的所有的方法都成功地抑制了SC,使模型能够泛化到超过64K个token。
三种方案中,状态归一化在较短序列上的性能大大低于其他方法,这可能是因为归一化折叠状态会改变heads之间的规范比率,破坏了学习机制。
上图显示了Mamba-2在语言建模和密钥检索方面的状态容量。两个图中最右边的数据点对应于Mamba-2 370M。