MiMo-V2-Flash: Hybrid Sliding Window Attention Architecture

论文链接:MiMo-V2-Flash Technical Report,模型开源地址:GitHub

MiMo-V2-Flash 是小米发布的 MoE 模型(309B 总参数 / 15B 激活参数),核心架构创新在于 Hybrid Sliding Window Attention + Learnable Attention Sink Bias。以下从标准全局注意力出发,逐步推导到 SWA 再到 Attention Sink 的完整公式体系。

符号约定

符号 含义
$x_i^{(l)}$ 第 $l$ 层、第 $i$ 个 token 的输入隐藏状态
$W_Q^{(l)}, W_K^{(l)}, W_V^{(l)}, W_O^{(l)}$ 第 $l$ 层的 Query/Key/Value/Output 投影矩阵
$d$ 每个注意力头的维度(MiMo-V2-Flash 中为 192)
$n$ 序列总长度
$W$ 滑动窗口大小(MiMo-V2-Flash 中为 128)

1. 全局 Attention

1.1 QKV 投影
$$ q_i^{(l)} = W_Q^{(l)} x_i^{(l)}, \quad k_i^{(l)} = W_K^{(l)} x_i^{(l)}, \quad v_i^{(l)} = W_V^{(l)} x_i^{(l)} $$
1.2 注意力得分
$$ a_{ij}^{(l)} = \frac{q_i^{(l)} \cdot {k_j^{(l)}}^\top}{\sqrt{d}}, \quad j \in \{1, 2, \dots, n\} $$
1.3 Softmax 归一化
$$ s_{ij}^{(l)} = \frac{\exp(a_{ij}^{(l)})}{\sum_{j'=1}^{n} \exp(a_{ij'}^{(l)})} $$

关键约束:$\sum_{j=1}^{n} s_{ij}^{(l)} = 1$,注意力必须 100% 分配给真实 token。

1.4 加权求和 + 输出投影
$$ o_i^{(l)} = \sum_{j=1}^{n} s_{ij}^{(l)} v_j^{(l)}, \quad h_i^{(l)} = W_O^{(l)} o_i^{(l)} $$

全局注意力的计算复杂度:每个 token 对 $n$ 个 token 求和 → $O(n^2)$。

1.5 层间传递
$$ \hat{x}_i^{(l)} = x_i^{(l)} + h_i^{(l)} $$
$$ x_i^{(l+1)} = \hat{x}_i^{(l)} + \text{FFN}^{(l)}\big(\text{LayerNorm}(\hat{x}_i^{(l)})\big) $$

2. Sliding Window Attention (SWA)

SWA 最早由 Longformer (Beltagy et al., 2020) 系统性提出并验证,核心思想来自自然语言的局部性先验——理解当前 token 通常只需参考附近少量 token。更早的局部注意力雏形见 Sparse Transformer (Child et al., 2019)

SWA 与全局 Attention 的唯一区别:$j$ 的求和范围从全序列缩小到局部窗口

2.1 定义窗口范围
$$ \mathcal{W}(i) = \{\, j \mid \max(1,\, i - W + 1) \le j \le i \,\} $$

token $i$ 只能看到自己左边 $W-1$ 个 token 加自己,总共 $W$ 个 token(因果遮码)。

2.2 QKV 投影——不变
$$ q_i^{(l)} = W_Q^{(l)} x_i^{(l)}, \quad k_i^{(l)} = W_K^{(l)} x_i^{(l)}, \quad v_i^{(l)} = W_V^{(l)} x_i^{(l)} $$
2.3 注意力得分——$j$ 被限制在窗口内
$$ a_{ij}^{(l)} = \frac{q_i^{(l)} \cdot {k_j^{(l)}}^\top}{\sqrt{d}}, \quad j \in \mathcal{W}(i) $$
2.4 Softmax 归一化——求和范围变为窗口
$$ s_{ij}^{(l)} = \frac{\exp(a_{ij}^{(l)})}{\sum_{j' \in \mathcal{W}(i)} \exp(a_{ij'}^{(l)})}, \quad j \in \mathcal{W}(i) $$

窗口外:$s_{ij}^{(l)} = 0$。仍满足 $\sum_{j \in \mathcal{W}(i)} s_{ij} = 1$。

2.5 加权求和
$$ o_i^{(l)} = \sum_{j \in \mathcal{W}(i)} s_{ij}^{(l)} v_j^{(l)} $$

计算复杂度从 $O(n^2)$ 降至 $O(nW)$。当 $W=128, n=256000$ 时降低约 2000 倍。

2.6 多层 SWA 的感受野

单层 SWA,token $i$ 的直接感受野为 $[i-W+1, i]$。多层叠加后,第 $L$ 层的理论感受野:

$$ \mathcal{R}^{(L)}(i) = [\max(1, i - LW + L), \; i] $$

有效感受野大小 $\approx L \times W$。但信息经过多层间接传递会衰减(类似电话游戏),远不如全局注意力的直接访问精确。

3. Hybrid Attention:混合架构

纯 SWA 长距离信息衰减严重,纯全局 $O(n^2)$ 太贵。MiMo-V2-Flash 的折中方案——按层交替使用:

$$ \text{Attn}^{(l)} = \begin{cases} \text{Global Attention}, & l \bmod (N+1) = 0 \\ \text{SWA}(W=128), & \text{otherwise} \end{cases} $$

其中 $N=5$,即每 5 层 SWA 后跟 1 层 GA。总共 39 层 SWA + 9 层 GA(第 1 层特殊使用 GA + Dense FFN 稳定早期训练)。GA 层每隔 5 层做一次全局信息刷新,修正 SWA 接力传递中的信息损失。

4. Attention Sink:从问题到公式

4.1 问题:SWA 中 Softmax 的"被迫分配"

回顾 SWA 的 Softmax:

$$ \sum_{j \in \mathcal{W}(i)} s_{ij} = 1 $$

即使窗口内所有 token 都不相关,注意力仍必须 100% 分配给它们。在全局注意力($n$ 很大)中,多余注意力分散到大量 token 上,单个 token 被污染程度 $\sim O(1/n)$,影响较小。但在 $W=128$ 的小窗口中,污染集中于少数 token,对输出 $o_i = \sum s_{ij} v_j$ 的干扰显著放大。

早期研究(Xiao et al., 2023)发现模型会自发将多余注意力集中到序列首 token(如 <BOS>),该 token 成为"注意力水槽"(Attention Sink)。但在 SWA 中,远处 token 看不到序列首 token——窗口外不存在这个"水槽"。

4.2 解法:在 Softmax 分母引入可学习标量

引入一个可学习标量 $\text{sink} \in \mathbb{R}$(每个注意力头各一个),作为 Softmax 分母中的额外虚拟项。

标准 SWA 的 Softmax(分母只有真实 token):

$$ s_{ij} = \frac{\exp(a_{ij})}{\sum_{j' \in \mathcal{W}(i)} \exp(a_{ij'})} $$

带 Sink 的 Softmax(分母多一个虚拟项):

$$ s_{ij} = \frac{\exp(a_{ij})}{\exp(\text{sink}) + \sum_{j' \in \mathcal{W}(i)} \exp(a_{ij'})} $$

此时注意力权重总和:

$$ \sum_{j \in \mathcal{W}(i)} s_{ij} = \frac{\sum_{j} \exp(a_{ij})}{\exp(\text{sink}) + \sum_{j} \exp(a_{ij})} < 1 $$

缺失的 $1 - \sum_j s_{ij}$ 即被虚拟 sink 吸收的部分。sink 不对应任何 $v$ 向量,吸收的注意力不产生输出,不污染 $o_i$。

4.3 sink 值的效果
$\text{sink}$ 大小 $\exp(\text{sink})$ 效果
$\to -\infty$ $\approx 0$ 退化为标准 Softmax
适中 中等 吸收部分多余注意力
很大 很大 大量注意力被吸收,token 对窗口内所有人给低权重

$\text{sink}$ 可学习,模型自动为每个注意力头学到最合适的吸收量。

4.4 加上数值稳定性(论文完整公式)

为防止 $\exp$ 溢出,减去最大值 $m_i$(不改变数学结果,分子分母同乘常数):

$$ m_i = \max\!\Big(\max_{j \in \mathcal{W}(i)} a_{ij},\; \text{sink}\Big) $$
$$ s_{ij} = \frac{\exp(a_{ij} - m_i)}{\exp(\text{sink} - m_i) + \sum_{j' \in \mathcal{W}(i)} \exp(a_{ij'} - m_i)} $$

最终输出:

$$ o_i = \sum_{j \in \mathcal{W}(i)} s_{ij} v_j $$

5. 完整单层前向传播汇总

$$ \begin{align} & \textbf{1. QKV 投影:} \quad q_i = W_Q x_i, \quad k_i = W_K x_i, \quad v_i = W_V x_i \\\\ & \textbf{2. 注意力范围:} \quad \mathcal{J}_i = \begin{cases} \{1,\dots,n\}, & \text{GA 层} \\ \mathcal{W}(i), & \text{SWA 层} \end{cases} \\\\ & \textbf{3. 注意力得分:} \quad a_{ij} = \frac{q_i k_j^\top}{\sqrt{d}}, \quad j \in \mathcal{J}_i \\\\ & \textbf{4. Softmax:} \quad s_{ij} = \frac{\exp(a_{ij} - m_i)}{[\exp(\text{sink} - m_i)]_{\text{SWA}} + \sum_{j'} \exp(a_{ij'} - m_i)} \\\\ & \textbf{5. 输出:} \quad o_i = \sum_{j \in \mathcal{J}_i} s_{ij} v_j \\\\ & \textbf{6. 残差+FFN:} \quad x_i^{(l+1)} = x_i^{(l)} + W_O o_i + \text{FFN}(\cdot) \end{align} $$

其中 $[\cdot]_{\text{SWA}}$ 表示该项仅在 SWA 层存在,GA 层无此项。

6. 三种 Softmax 对比

全局 Attention SWA SWA + Sink
$j$ 的范围 ${1,\dots,n}$ $\mathcal{W}(i)$ $\mathcal{W}(i)$
分母 $\sum_{j’=1}^{n} \exp(a_{ij’})$ $\sum_{j’ \in \mathcal{W}} \exp(a_{ij’})$ $\exp(\text{sink}) + \sum_{j’ \in \mathcal{W}} \exp(a_{ij’})$
$\sum_j s_{ij}$ $= 1$ $= 1$ $< 1$
计算复杂度 $O(n^2)$ $O(nW)$ $O(nW)$

7. 实验验证

在 32B dense 模型上的消融实验结果(预训练 250B tokens):

配置 MMLU BBH GSM8K MATH
All GA 57.3 54.7 34.2 9.5
Hybrid SWA ($W$=128, 无 sink) 54.9 52.4 36.9 8.9
Hybrid SWA ($W$=128, 有 sink) 58.3 56.1 36.9 10.3
Hybrid SWA ($W$=512, 有 sink) 58.3 54.9 37.9 10.0

关键发现:

  1. 无 sink 的 SWA 全面劣于 All GA——验证了"被迫分配"问题在小窗口中的严重性。
  2. 加入 sink 后反超 All GA——sink 解决了 Softmax 强制归一化的问题。
  3. $W$=128 优于 $W$=512——更小窗口迫使 SWA 只处理局部信息,长距离依赖完全交给 GA 层,形成更清晰的分工。$W$=512 的 SWA 会"越界"部分处理长距离信息,模糊了局部/全局的边界,反而导致性能下降。