Two Level KL
For KL divergence in LLM reinforcement learning, let the policy model be $\pi_\theta$ and the reference model be $\pi_{ref}$. The KL divergence between the two models is defined as
$$ D_{KL}(\pi_\theta\Vert\pi_{ref})=\mathbb E_{y\sim\pi_\theta}\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}=\sum_{y\in\mathcal Y}\pi_\theta(y)\log\frac{\pi_\theta(y)}{\pi_{ref}(y)} $$
Beyond this, for a single sampled text $y$, one can also compute the average per-token KL divergence between the two models along that text:
$$ \begin{align} D_{KL}^y(\pi_\theta\Vert\pi_{ref})&=\frac{1}{\vert y\vert}\sum_{t=1}^{\vert y\vert}D_{KL}(\pi_\theta(\cdot\vert y_{< t})\Vert\pi_{ref}(\cdot\vert y_{< t})) \\ &=\frac{1}{\vert y\vert}\sum_{t=1}^{\vert y\vert}\sum_{v_t\in\mathcal V}\pi_\theta(v_t\vert y_{< t})\log\frac{\pi_\theta(v_t\vert y_{< t})}{\pi_{ref}(v_t\vert y_{< t})} \end{align} $$
The latter is commonly used in knowledge distillation, whereas reinforcement learning generally uses the former.
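To make the two levels concrete, here is a minimal PyTorch sketch (the tensor names, shapes, and random logits are illustrative assumptions, not from any particular codebase): the per-token KL averages full-vocabulary KLs at each position of one sampled sequence, while the sequence-level KL is an expectation over whole sequences, here estimated from a single sample.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, T = 8, 5  # toy vocabulary size and sequence length (assumed values)

# Per-position logits of the policy and reference models along one sampled y.
# In practice these come from two LLM forward passes; here they are random.
logits_theta = torch.randn(T, V)
logits_ref = torch.randn(T, V)

logp_theta = F.log_softmax(logits_theta, dim=-1)   # log pi_theta(. | y_<t)
logp_ref = F.log_softmax(logits_ref, dim=-1)       # log pi_ref(. | y_<t)

# Level 2: average per-token KL over the full vocabulary at each position.
p_theta = logp_theta.exp()
per_token_kl = (p_theta * (logp_theta - logp_ref)).sum(dim=-1)  # KL at each t
kl_level2 = per_token_kl.mean()

# Level 1 (Monte Carlo): the sequence-level KL is
# E_{y~pi_theta}[log pi_theta(y) - log pi_ref(y)], estimated here from the
# log-probs of the tokens actually sampled (a single-sample k1 estimate).
y = torch.multinomial(p_theta, num_samples=1).squeeze(-1)       # sampled tokens
seq_logratio = (logp_theta[torch.arange(T), y] - logp_ref[torch.arange(T), y]).sum()
print(kl_level2.item(), seq_logratio.item())
```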
RKL & FKL
For two models $\pi_\theta$ and $\pi_{ref}$, where $\pi_\theta$ is the model distribution being optimized, the usual definitions are:
- $D_{KL}(\pi_{ref}\Vert\pi_\theta)=\mathbb E_{y\sim\pi_{ref}}\log\frac{\pi_{ref}(y)}{\pi_\theta(y)}$ is the forward KL divergence (Forward KL, FKL)
- $D_{KL}(\pi_\theta\Vert\pi_{ref})=\mathbb E_{y\sim\pi_\theta}\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}$ is the reverse KL divergence (Reverse KL, RKL)
For the FKL, minimizing it, $\min\limits_\theta\sum_{y\in\mathcal Y}\pi_{ref}(y)\log\frac{\pi_{ref}(y)}{\pi_\theta(y)}$, is equivalent to $\max\limits_\theta\sum_{y\in\mathcal Y}\pi_{ref}(y)\log\pi_\theta(y)$, i.e. $\max\limits_\theta\mathbb E_{y\sim\pi_{ref}}[\log\pi_\theta(y)]$, which is exactly maximum likelihood estimation (MLE); the pretraining and SFT objectives of autoregressive LLMs are of this form. The FKL's optimization characteristics are mean-seeking, zero-avoiding, and inclusive. Looking at the term $\pi_{ref}(y)\log\frac{\pi_{ref}(y)}{\pi_\theta(y)}$: if some $y$ has $\pi_{ref}(y)> 0$ while $\pi_\theta(y)\rightarrow 0$, then $\log\frac{\pi_{ref}}{\pi_\theta}\rightarrow\infty$ and the KL blows up. Therefore $\pi_\theta$ cannot place zero probability anywhere $\pi_{ref}$ has mass, and it is stretched to cover all of $\pi_{ref}$'s high-probability regions.
For the RKL, the objective is $\min\limits_\theta\sum_{y\in\mathcal Y}\pi_\theta(y)\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}$, whose optimization characteristics are mode-seeking, zero-forcing, and exclusive. Looking at the term $\pi_\theta(y)\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}$: if $\pi_{ref}(y)\approx 0$, then $\pi_\theta(y)$ must also go to 0 to keep the overall KL small, so $\pi_\theta$ strongly avoids placing probability where $\pi_{ref}$ has little. See Kristiadi's blog for a detailed analysis.
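A tiny numerical illustration of this asymmetry (a toy example constructed here, not taken from the cited blog): fit a constrained categorical $q$ (a discretized Gaussian with learnable mean and scale) to a bimodal target $p$ by gradient descent on FKL versus RKL. The FKL fit spreads mass over both modes, while the RKL fit locks onto one mode.

```python
import torch

# Bimodal target p over 21 bins: two peaks near x = -2 and x = +2.
x = torch.linspace(-4, 4, 21)
p = torch.exp(-(x + 2) ** 2) + torch.exp(-(x - 2) ** 2)
p = p / p.sum()

def fit(direction: str):
    # q is a discretized Gaussian with learnable mean / log-scale, so it cannot
    # represent both modes at once; this is what exposes the FKL/RKL asymmetry.
    mu = torch.tensor(0.5, requires_grad=True)
    log_s = torch.tensor(0.0, requires_grad=True)
    opt = torch.optim.Adam([mu, log_s], lr=0.05)
    for _ in range(2000):
        q = torch.softmax(-(x - mu) ** 2 / (2 * torch.exp(log_s) ** 2), dim=0)
        if direction == "FKL":   # KL(p || q): zero-avoiding, covers both modes
            loss = (p * (p.log() - q.log())).sum()
        else:                    # KL(q || p): zero-forcing, collapses onto one mode
            loss = (q * (q.log() - p.log())).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return mu.item(), torch.exp(log_s).item()

print("FKL fit (mean, scale):", fit("FKL"))  # mean near 0, large scale: covers both modes
print("RKL fit (mean, scale):", fit("RKL"))  # mean near +2 (or -2), small scale: one mode
```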
KL Estimator
For the first definition in Two Level KL, it is impossible to sample every possible $y$, so the KL has to be approximated by an estimator. Schulman's blog post discusses three approximations of the KL divergence:
- $k1=\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}$
- $k2=\frac{1}{2}(\log\frac{\pi_\theta(y)}{\pi_{ref}(y)})^2$
- $k3=\frac{\pi_{ref}(y)}{\pi_\theta(y)}-\log\frac{\pi_{ref}(y)}{\pi_\theta(y)}-1$
The $k1$ estimator is unbiased: clearly $\mathbb E_{y\sim\pi_\theta}[k1]=D_{KL}(\pi_\theta\Vert\pi_{ref})$. Its variance is $\text{Var}_{\pi_\theta}(k1)=\mathbb E_{y\sim\pi_\theta}\left[(\log\frac{\pi_\theta(y)}{\pi_{ref}(y)})^2\right]-(D_{KL}(\pi_\theta\Vert\pi_{ref}))^2$; since $k1$ itself can be negative even though the KL is non-negative, individual samples fluctuate widely around the mean and the overall variance is large. $k2$ is clearly a biased estimator, but its variance is smaller ($k2$ is always non-negative). $k3$ is also unbiased, because $\mathbb E_{y\sim\pi_\theta}\left[\frac{\pi_{ref}(y)}{\pi_\theta(y)}-1\right]=\sum_{y\in\mathcal Y}(\pi_{ref}(y)-\pi_\theta(y))=0$, and since $x\ge1+\log x$ for any $x> 0$ we have $k3\ge 0$, so the variance of $k3$ is also smaller than that of $k1$.
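A quick Monte Carlo check of these claims, in the spirit of Schulman's blog but with toy categorical distributions chosen here (the reference is a small perturbation of the policy, as is typical in RL fine-tuning): $k1$ and $k3$ average to the true KL, $k2$ carries a (here small) bias, and $k1$ shows a noticeably larger per-sample spread.

```python
import torch

torch.manual_seed(0)
V = 50
logits_theta = torch.randn(V)
logits_ref = logits_theta + 0.1 * torch.randn(V)   # reference close to the policy
p_theta = torch.softmax(logits_theta, dim=0)
p_ref = torch.softmax(logits_ref, dim=0)

true_kl = (p_theta * (p_theta.log() - p_ref.log())).sum()

# Sample y ~ pi_theta and evaluate the three estimators on each sample.
y = torch.multinomial(p_theta, 200_000, replacement=True)
logr = p_theta.log()[y] - p_ref.log()[y]           # log pi_theta(y)/pi_ref(y)
k1 = logr
k2 = 0.5 * logr ** 2
k3 = (-logr).exp() + logr - 1                      # pi_ref/pi_theta - log(pi_ref/pi_theta) - 1

print(f"true KL     : {true_kl:.6f}")
for name, k in [("k1", k1), ("k2", k2), ("k3", k3)]:
    print(f"{name} mean / std: {k.mean():.6f} / {k.std():.6f}")
```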
These estimators show up in several LLM post-training RL algorithms. In PPO's reward design, for instance, the per-token reward is $r_t=\begin{cases}-\beta\cdot\log\frac{\pi_\theta(y_t)}{\pi_{ref}(y_t)},&0\le t< \vert y\vert-1 \\R(y)-\beta\cdot\log\frac{\pi_\theta(y_t)}{\pi_{ref}(y_t)},&t=\vert y\vert-1\end{cases}$, which uses the $k1$ estimator with the sequence-level log-ratio split across individual tokens. GRPO, on the other hand, adds a $k3$-based term directly to the loss, but with a twist: for a rollout $o_i$, the KL penalty term in the loss is $-\frac{1}{\vert o_i\vert}\sum_{t=1}^{\vert o_i\vert}\beta\cdot\left[\frac{\pi_{ref}(o_{i,t})}{\pi_\theta(o_{i,t})}-\log\frac{\pi_{ref}(o_{i,t})}{\pi_\theta(o_{i,t})}-1\right]$. Looking closely, this still differs from the $k3$ estimator: carrying out the sum gives $\left(\sum_{t=1}^{\vert o_i\vert}\frac{\pi_{ref}(o_{i,t})}{\pi_\theta(o_{i,t})}\right)-\log\frac{\pi_{ref}(o_i)}{\pi_\theta(o_i)}-\vert o_i\vert$, and, constants aside, clearly $\frac{\pi_{ref}(o_i)}{\pi_\theta(o_i)}=\prod_{t=1}^{\vert o_i\vert}\frac{\pi_{ref}(o_{i,t})}{\pi_\theta(o_{i,t})}\neq\sum_{t=1}^{\vert o_i\vert}\frac{\pi_{ref}(o_{i,t})}{\pi_\theta(o_{i,t})}$. The per-token form is used because directly computing the product $\prod_{t=1}^{\vert o_i\vert}\frac{\pi_{ref}(o_{i,t})}{\pi_\theta(o_{i,t})}$ easily blows up numerically when $o_i$ is very long.
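A minimal sketch of this per-token GRPO-style penalty, assuming the per-token log-probabilities of the rollout tokens under $\pi_\theta$ and $\pi_{ref}$ have already been gathered into tensors (the function and tensor names are illustrative, not from any GRPO implementation):

```python
import torch

def grpo_kl_penalty(logp_theta_tok: torch.Tensor,
                    logp_ref_tok: torch.Tensor,
                    beta: float = 0.04) -> torch.Tensor:
    """Per-token k3-style KL penalty in the spirit of the GRPO loss (sketch).

    logp_theta_tok / logp_ref_tok: shape [T], log-probs of the rollout tokens o_{i,t}.
    Returns beta * mean_t [ r_t - log r_t - 1 ] with r_t = pi_ref / pi_theta,
    i.e. the per-token form that avoids the length-T product of ratios.
    """
    log_ratio = logp_ref_tok - logp_theta_tok      # log(pi_ref / pi_theta) per token
    k3 = log_ratio.exp() - log_ratio - 1.0         # always >= 0
    return beta * k3.mean()

# Toy usage with random stand-in log-probs (gradients would flow through logp_theta_tok).
torch.manual_seed(0)
lp_theta = torch.log(torch.rand(16).clamp_min(1e-3))
lp_ref = torch.log(torch.rand(16).clamp_min(1e-3))
print(grpo_kl_penalty(lp_theta, lp_ref))
```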
Reward, Loss with k1 and k3 Estimator
Recently, Shah et al. analyzed in detail the four combinations that arise in LLM RL training: whether the KL regularization term is added to the reward or to the loss, and whether the $k1$ or the $k3$ estimator is used. Starting from the gradient of the true KL, they analyze whether each combination yields a biased gradient and verify each with experiments. Their conclusion is that the $k1$ estimator with the KL term placed in the reward outperforms the other three combinations on both in-domain and out-of-domain tests, because it is the only combination whose KL term contributes an unbiased expected gradient to the true objective. Below we retrace the reasoning of that paper. Note that its starting point is whether the gradient of the KL term is biased, which differs from Schulman's blog, whose starting point is whether the KL estimate itself is biased and which designs the estimators accordingly.
For $\pi_\theta$ and $\pi_{ref}$, the gradient of their KL divergence is $\nabla_\theta\mathbb E_{y\sim\pi_\theta}\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}=\mathbb E_{y\sim\pi_\theta}\left[\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}\nabla_\theta\log\pi_\theta(y)\right]$:
$$ \begin{align} \nabla_\theta\mathbb E_{y\sim\pi_\theta}\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}&=\nabla_\theta\sum_{y\in\mathcal Y}\pi_\theta(y)\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}=\sum_{y\in\mathcal Y}\nabla_\theta\left(\pi_\theta(y)\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}\right) \\ &=\sum_{y\in\mathcal Y}\left(\nabla_\theta\pi_\theta(y)\right)\cdot\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}+\pi_\theta(y)\cdot\nabla_\theta\log\frac{\pi_\theta(y)}{\pi_{ref}(y)} \\ &=\sum_{y\in\mathcal Y}\pi_\theta(y)\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}\nabla_\theta\log\pi_\theta(y)+\sum_{y\in\mathcal Y}\pi_\theta(y)\cdot\nabla_\theta\log\pi_\theta(y) \\ &=\mathbb E_{y\sim\pi_\theta}\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}\nabla_\theta\log\pi_\theta(y)+\sum_{y\in\mathcal Y}\nabla_\theta\pi_\theta(y) \\ &=\mathbb E_{y\sim\pi_\theta}\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}\nabla_\theta\log\pi_\theta(y) \end{align} $$
Following the REINFORCE algorithm, the authors examine the gradient contributed by the KL part of the loss when the KL term is added to the reward versus added to the loss. Define $KL_t$ as the per-token KL estimate:
$$ KL_t=\begin{cases} \log\frac{\pi_\theta(y_t)}{\pi_{ref}(y_t)},&\text{k1 estimator} \\ \frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}-\log\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}-1,&\text{k3 estimator} \end{cases} $$
In the REINFORCE algorithm, if the KL term is added to the reward, then $r_t=s_t-\beta\,\text{sg}(KL_t)$ and $A_t=\sum_{t=1}^{\vert y\vert}r_t-b=R-\beta\sum_{t=1}^{\vert y\vert}\text{sg}(KL_t)-b$; as in GRPO, $A_t$ takes the same value at every token position of $y$. The gradient of the loss is then $\nabla_\theta J(\theta)=\mathbb E_{y\sim\pi_\theta}\left[\left(R-\beta\sum_{t=1}^{\vert y\vert}\text{sg}(KL_t)-b\right)\nabla_\theta\log\pi_\theta(y)\right]$, in which the KL term contributes the gradient $\left(\sum_{t=1}^{\vert y\vert}\text{sg}(KL_t)\right)\nabla_\theta\log\pi_\theta(y)$. If the KL term is added to the loss instead, then $A_t=R-b$ and $\nabla_\theta J(\theta)=\mathbb E_{y\sim\pi_\theta}\left[\left(R-b\right)\nabla_\theta\log\pi_\theta(y)-\beta\sum_{t=1}^{\vert y\vert}\nabla_\theta KL_t\right]$, in which the KL term contributes $\sum_{t=1}^{\vert y\vert}\nabla_\theta KL_t$. The KL gradient contributions are summarized below:
$$ \nabla_\theta KL=\begin{cases} \mathbb E_{y\sim\pi_\theta}\left[\left(\sum_{t=1}^{\vert y\vert}\text{sg}(KL_t)\right)\nabla_\theta\log\pi_\theta(y)\right],& \text{kl in reward} \\ \mathbb E_{y\sim\pi_\theta}\left[\sum_{t=1}^{\vert y\vert}\nabla_\theta KL_t\right],&\text{kl in loss} \end{cases} $$
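The following sketch makes the two placements concrete as surrogate losses for a single sampled sequence (a schematic of my own, not any specific library's implementation): "KL in reward" folds a stop-gradient KL estimate into the advantage that multiplies $\log\pi_\theta(y)$, while "KL in loss" adds a differentiable KL term directly.

```python
import torch

def reinforce_losses(logp_theta_tok, logp_ref_tok, reward, baseline, beta=0.04):
    """Schematic REINFORCE-style losses for one sampled sequence (illustrative names).

    logp_theta_tok: [T] log pi_theta(y_t | y_<t), carries gradients.
    logp_ref_tok:   [T] log pi_ref(y_t | y_<t), treated as constant.
    """
    kl_t = logp_theta_tok - logp_ref_tok          # per-token k1 (swap in k3 if desired)

    # KL in reward: the KL enters through the (stop-gradient) advantage.
    adv = reward - beta * kl_t.detach().sum() - baseline
    loss_kl_in_reward = -(adv * logp_theta_tok.sum())

    # KL in loss: the advantage ignores the KL, which is added as a differentiable penalty.
    adv_plain = reward - baseline
    loss_kl_in_loss = -(adv_plain * logp_theta_tok.sum()) + beta * kl_t.sum()

    return loss_kl_in_reward, loss_kl_in_loss

# Toy usage with random stand-in log-probs.
torch.manual_seed(0)
loss_r, loss_l = reinforce_losses(torch.randn(8, requires_grad=True),
                                  torch.randn(8), reward=1.0, baseline=0.3)
print(loss_r.item(), loss_l.item())
```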
K1 Estimator & Reward
$$ \begin{align} \nabla_\theta KL &= \mathbb E_{y\sim\pi_\theta} \left[ \left(\sum_{t=1}^{\vert y\vert}\log\frac{\pi_\theta(y_t)}{\pi_{ref}(y_t)}\right)\cdot\nabla_\theta\log\pi_\theta(y)\right] \\ &=\mathbb E_{y\sim\pi_\theta}\left[\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}\cdot\nabla_\theta\log\pi_\theta(y)\right] \end{align} $$
In this case the KL gradient is unbiased.
K1 Estimator & Loss
$$ \begin{align} \nabla_\theta KL&=\mathbb E_{y\sim\pi_\theta}\left[\sum_{t=1}^{\vert y\vert}\nabla_\theta\log\frac{\pi_\theta(y_t)}{\pi_{ref}(y_t)}\right]=\mathbb E_{y\sim\pi_\theta}\left[\nabla_\theta\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}\right] \\ &=\mathbb E_{y\sim\pi_\theta}\nabla_\theta\log\pi_\theta(y)=\sum_{y\in\mathcal Y}\pi_\theta(y)\nabla_\theta\log\pi_\theta(y) \\ &=\sum_{y\in\mathcal Y}\nabla_\theta\pi_\theta(y)=0 \end{align} $$
Clearly, the expected KL gradient here is 0, which is badly biased.
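This can be checked exactly with autograd on a toy categorical policy, where full-support summation stands in for the expectation (all names and sizes below are arbitrary choices):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(10, requires_grad=True)
p_theta = torch.softmax(logits, dim=0)
p_ref = torch.softmax(torch.randn(10), dim=0)

# Expected "k1 in loss" objective: E_{y~pi_theta}[log pi_theta(y)/pi_ref(y)],
# with the sampling weights detached (sampling itself is not differentiated).
loss = (p_theta.detach() * (p_theta.log() - p_ref.log())).sum()
grad = torch.autograd.grad(loss, logits)[0]
print(grad.abs().max())   # ~0: the expected gradient vanishes identically
```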
K3 Estimator & Reward
$$ \begin{align} \nabla_\theta KL&=\mathbb E_{y\sim\pi_\theta}\left[\left(\sum_{t=1}^{\vert y\vert}\left(\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}-\log\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}-1\right)\right)\cdot\nabla_\theta\log\pi_\theta(y)\right] \\ &=\mathbb E_{y\sim\pi_\theta}\left[\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}\cdot\nabla_\theta\log\pi_\theta(y)\right] + \mathbb E_{y\sim\pi_\theta}\left[\left(\sum_{t=1}^{\vert y\vert}\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}\right)\cdot\nabla_\theta\log\pi_\theta(y)\right] \end{align} $$
This is also biased; the bias term is $\mathbb E_{y\sim\pi_\theta}\left[\left(\sum_{t=1}^{\vert y\vert}\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}\right)\cdot\nabla_\theta\log\pi_\theta(y)\right]$ (the $\sum_{t=1}^{\vert y\vert}(-1)=-\vert y\vert$ part is dropped because $\mathbb E_{y\sim\pi_\theta}\left[\nabla_\theta\log\pi_\theta(y)\right]=0$ when the length is treated as fixed).
K3 Estimator & Loss
$$ \begin{align} \nabla_\theta KL&=\mathbb E_{y\sim\pi_\theta}\left[\sum_{t=1}^{\vert y\vert}\nabla_\theta\left(\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}-\log\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}-1\right)\right] \\ &=\mathbb E_{y\sim\pi_\theta}\left[\sum_{t=1}^{\vert y\vert}\left(\nabla_\theta\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}+\nabla_\theta\log\pi_\theta(y_t)\right)\right] \\ &=\mathbb E_{y\sim\pi_\theta}\left[\sum_{t=1}^{\vert y\vert}\nabla_\theta\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}\right]=\mathbb E_{y\sim\pi_\theta}\left[\sum_{t=1}^{\vert y\vert}-\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}\nabla_\theta\log\pi_\theta(y_t)\right] \end{align} $$
Observe that this gradient approximately matches the gradient of the forward KL $D_{KL}(\pi_{ref}\Vert\pi_\theta)$:
$$ \begin{align} \nabla_\theta\mathbb E_{y\sim\pi_{ref}}\log\frac{\pi_{ref}(y)}{\pi_\theta(y)}&=\nabla_\theta\sum_{y\in\mathcal Y}\pi_{ref}(y)\log\frac{\pi_{ref}(y)}{\pi_\theta(y)}\\ &=\sum_{y\in\mathcal Y}\pi_{ref}(y)\cdot\frac{\pi_\theta(y)}{\pi_{ref}(y)}\cdot\left(-\frac{\pi_{ref}(y)}{\pi_\theta^2(y)}\right)\cdot\nabla_\theta\pi_\theta(y) \\ &=\sum_{y\in\mathcal Y}-\frac{\pi_{ref}(y)}{\pi_\theta(y)}\cdot\pi_\theta(y)\cdot\nabla_\theta\log\pi_\theta(y)\\ &=\mathbb E_{y\sim\pi_\theta}-\frac{\pi_{ref}(y)}{\pi_\theta(y)}\cdot\nabla_\theta\log\pi_\theta(y) \end{align} $$
Interestingly, we can also compute the gradient of the sequence-level $k3$ estimator placed in the loss (what was computed above is the token-level $k3$ estimator in the loss):
$$ \begin{align} \nabla_\theta(k3)&=\mathbb E_{y\sim\pi_\theta} \nabla_\theta\left(\frac{\pi_{ref}(y)}{\pi_\theta(y)}-\log\frac{\pi_{ref}(y)}{\pi_\theta(y)}-1\right) \\ &=\mathbb E_{y\sim\pi_\theta}\left(-\frac{\pi_{ref}(y)}{\pi_\theta(y)}\cdot\nabla_\theta\log\pi_\theta(y)\right)-\mathbb E_{y\sim\pi_\theta}\left(\frac{\pi_\theta(y)}{\pi_{ref}(y)}\cdot\left(-\frac{\pi_{ref}(y)}{\pi_\theta^2(y)}\right)\cdot\nabla_\theta\pi_\theta(y)\right) \\ &=\mathbb E_{y\sim\pi_\theta}\left(-\frac{\pi_{ref}(y)}{\pi_\theta(y)}\cdot\nabla_\theta\log\pi_\theta(y)\right)+\mathbb E_{y\sim\pi_\theta}\nabla_\theta\log\pi_\theta(y) \\ &=\mathbb E_{y\sim\pi_\theta}\left(-\frac{\pi_{ref}(y)}{\pi_\theta(y)}\cdot\nabla_\theta\log\pi_\theta(y)\right) \\ &=\nabla_\theta\mathbb E_{y\sim\pi_{ref}}\log\frac{\pi_{ref}(y)}{\pi_\theta(y)} \end{align} $$
To summarize: the token-level $k3$ estimator used in the loss is also biased, and its gradient approximates that of the sequence-level $k3$ estimator. Moreover, the gradient of the sequence-level $k3$ estimator (not the token-level one) coincides exactly with the gradient of the forward KL $D_{KL}(\pi_{ref}\Vert\pi_\theta)$ between $\pi_\theta$ and $\pi_{ref}$, even though $k3$ itself is an unbiased estimate of the reverse KL $D_{KL}(\pi_\theta\Vert\pi_{ref})$.
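The same kind of autograd check illustrates this coincidence on a single-token toy model, where sequence level and token level coincide (names and sizes are arbitrary):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(10, requires_grad=True)
p_theta = torch.softmax(logits, dim=0)
p_ref = torch.softmax(torch.randn(10), dim=0)

# Expected "k3 in loss" objective, with the sampling distribution detached.
ratio = p_ref / p_theta
k3_loss = (p_theta.detach() * (ratio - ratio.log() - 1)).sum()
g_k3 = torch.autograd.grad(k3_loss, logits, retain_graph=True)[0]

# Forward KL D_KL(pi_ref || pi_theta); pi_ref carries no gradient.
fkl = (p_ref * (p_ref.log() - p_theta.log())).sum()
g_fkl = torch.autograd.grad(fkl, logits)[0]

print(torch.allclose(g_k3, g_fkl, atol=1e-6))   # True: the two gradients coincide
```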
K2 Estimator & Loss
$$ \begin{align} \nabla_\theta(k2)&=\mathbb E_{y\sim\pi_\theta}\nabla_\theta\left(\frac{1}{2}\left(\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}\right)^2\right) \\ &=\mathbb E_{y\sim\pi_\theta}\left[\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}\cdot\frac{\pi_{ref}(y)}{\pi_\theta(y)}\cdot\frac{\nabla_\theta\pi_\theta(y)}{\pi_{ref}(y)}\right] \\ &=\mathbb E_{y\sim\pi_\theta}\left[\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}\nabla_\theta\log\pi_\theta(y)\right] \end{align} $$
We additionally computed the gradient of the sequence-level $k2$ estimator in the loss and found that it coincides with the gradient of $D_{KL}(\pi_\theta\Vert\pi_{ref})$. For the token-level $k2$ estimator, the gradient is
$$ \mathbb E_{y\sim\pi_\theta}\sum_{t=1}^{\vert y\vert}\nabla_\theta(\frac{1}{2}(\log\frac{\pi_\theta(y_t)}{\pi_{ref}(y_t)})^2)=\mathbb E_{y\sim\pi_\theta}\sum_{t=1}^{\vert y\vert}\log\frac{\pi_\theta(y_t)}{\pi_{ref}(y_t)}\nabla_\theta\log\pi_\theta(y_t) $$
The relationship here mirrors that between the sequence-level and token-level $k3$ gradients in the loss, yet in practice the token-level form (i.e. the token-level KL estimate) is the one used, mainly for three reasons: 1) the sequence-level gradient contains the whole-sequence ratio $\frac{\pi_\theta(y)}{\pi_{ref}(y)}$, which has large variance and makes convergence difficult, whereas the token-level form sums per-token ratios and has much lower variance; 2) note that $\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}\nabla_\theta\log\pi_\theta(y)$ actually equals $\left(\sum_{t=1}^{\vert y\vert}\log\frac{\pi_\theta(y_t)}{\pi_{ref}(y_t)}\right)\cdot\left(\sum_{t=1}^{\vert y\vert}\nabla_\theta\log\pi_\theta(y_t)\right)$, which compared with $\sum_{t=1}^{\vert y\vert}\log\frac{\pi_\theta(y_t)}{\pi_{ref}(y_t)}\nabla_\theta\log\pi_\theta(y_t)$ carries all the cross terms $\sum_{i\neq j}\log\frac{\pi_\theta(y_i)}{\pi_{ref}(y_i)}\nabla_\theta\log\pi_\theta(y_j)$; in LLM decoding each token's contribution is largely independent, so letting the quality of a token at one position modulate the gradient at another position is somewhat counter-intuitive; 3) the token-level gradient formula clearly applies each token's own deviation from $\pi_{ref}$, $\log\frac{\pi_\theta(y_t)}{\pi_{ref}(y_t)}$, to that token's own gradient $\nabla_\theta\log\pi_\theta(y_t)$, giving the model a per-token signal and making the update direction more fine-grained.
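To see the cross terms concretely, here is a small sketch contrasting the two surrogates for one sampled sequence; the per-token log-probabilities are treated as free parameters to stand in for the dependence on $\theta$, and all names are illustrative:

```python
import torch

torch.manual_seed(0)
T = 6
# Per-token log-probs of the sampled tokens; requires_grad stands in for theta.
logp_theta_tok = torch.randn(T, requires_grad=True)
logp_ref_tok = torch.randn(T)
log_ratio = logp_theta_tok - logp_ref_tok          # log pi_theta(y_t)/pi_ref(y_t)

# Sequence-level k2: 0.5 * (sum_t log_ratio_t)^2 -> its gradient carries cross terms i != j.
k2_seq = 0.5 * log_ratio.sum() ** 2
g_seq = torch.autograd.grad(k2_seq, logp_theta_tok, retain_graph=True)[0]

# Token-level k2: 0.5 * sum_t log_ratio_t^2 -> each token only sees its own ratio.
k2_tok = 0.5 * (log_ratio ** 2).sum()
g_tok = torch.autograd.grad(k2_tok, logp_theta_tok)[0]

print(g_seq)  # every entry equals the full-sequence log-ratio (cross terms present)
print(g_tok)  # entry t equals that token's own log-ratio only
```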
Summary of K1 K2 K3
To summarize: both $k1$ in reward and $k2$ in loss yield unbiased gradients ($k1$ in reward is unbiased for the true sequence-level gradient), while $k3$ in loss turns out to optimize in a direction equivalent to minimizing the forward KL divergence.
Experiments
The authors train Qwen2.5-7B and Llama3-8B on the Hendrycks MATH dataset. The test sets include the in-domain MATH500 and MATH$^2$, and the out-of-domain MMLU college physics, college chemistry, and college biology.
On-Policy Setting
Off-Policy Setting
Using Correct Gradient Estimators
Finally, the authors run a rather interesting experiment. We saw above that in the on-policy setting the $k3$-in-loss gradient is biased while the $k1$-in-reward gradient is unbiased. In addition, K1 Estimator & Loss showed that the gradient of $k1$ in loss is 0, so adding $k1$ in reward and $k1$ in loss together should still be unbiased. Likewise, K3 Estimator & Reward and K3 Estimator & Loss give the $k3$-in-reward gradient as $\mathbb E_{y\sim\pi_\theta}\left[\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}\cdot\nabla_\theta\log\pi_\theta(y)\right] + \mathbb E_{y\sim\pi_\theta}\left[\left(\sum_{t=1}^{\vert y\vert}\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}\right)\cdot\nabla_\theta\log\pi_\theta(y)\right]$ and the $k3$-in-loss gradient as $\mathbb E_{y\sim\pi_\theta}\left[\sum_{t=1}^{\vert y\vert}-\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}\nabla_\theta\log\pi_\theta(y_t)\right]$, so adding $k3$ in reward and $k3$ in loss together cancels the main (diagonal) part of the original bias, leaving a residual bias of $\mathbb E_{y\sim\pi_\theta}\left[\sum_{i\neq j}\frac{\pi_{ref}(y_i)}{\pi_\theta(y_i)}\cdot\nabla_\theta\log\pi_\theta(y_j)\right]$ (intuitively, the dominant bias term is removed and only a smaller secondary term remains). The authors then compare these three settings ($k1$ in reward; $k1$ in reward + $k1$ in loss; $k3$ in reward + $k3$ in loss, where the first two are unbiased and the last is still biased but less so than $k3$ in loss) against the most biased setting, $k3$ in loss. All three outperform $k3$ in loss; notably, $k3$ in reward + $k3$ in loss even achieves the best results in most settings. This mainly demonstrates that reducing the gradient bias of the KL estimator can indeed improve model performance.
KL in DeepSeek-V3.2
In the DeepSeek-V3.2 paper, the KL term in GRPO also receives a gradient-level correction. It builds on $k3$ in loss but adjusts $k3$ as follows:
$$ KL_t=D_{KL}(\pi_\theta(y_t)\Vert\pi_{ref}(y_t))=\frac{\pi_\theta(y_t)}{\pi_{old}(y_t)}\left(\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}-\log\frac{\pi_{ref}(y_t)}{\pi_\theta(y_t)}-1\right) $$
This estimator has an unbiased gradient in the off-policy setting, which can be shown as follows:
$$ \begin{align} \nabla_\theta\left(\frac{\pi_\theta}{\pi_{old}}\left(\frac{\pi_{ref}}{\pi_\theta}-\log\frac{\pi_{ref}}{\pi_\theta}-1\right)\right)&=\nabla_\theta\left(\frac{\pi_{ref}}{\pi_{old}}-\frac{\pi_\theta}{\pi_{old}}\log\frac{\pi_{ref}}{\pi_\theta}-\frac{\pi_\theta}{\pi_{old}}\right) \\ &=0-\frac{\nabla_\theta\pi_\theta}{\pi_{old}}\log\frac{\pi_{ref}}{\pi_\theta}+\frac{\nabla_\theta\pi_\theta}{\pi_{old}}-\frac{\nabla_\theta\pi_\theta}{\pi_{old}} \\ &=-\frac{\nabla_\theta\pi_\theta}{\pi_{old}}\log\frac{\pi_{ref}}{\pi_\theta}=\frac{\nabla_\theta\pi_\theta}{\pi_{old}}\log\frac{\pi_\theta}{\pi_{ref}} \end{align} $$
In the off-policy setting, the expectation of this gradient under the sampling distribution is:
$$ \mathbb E_{y\sim\pi_{old}}\frac{\nabla_\theta\pi_\theta}{\pi_{old}}\log\frac{\pi_\theta}{\pi_{ref}}=\sum_{y\in\mathcal Y}\pi_{old}\cdot\left(\frac{\nabla_\theta\pi_\theta}{\pi_{old}}\log\frac{\pi_\theta}{\pi_{ref}}\right)=\sum_{y\in\mathcal Y}\nabla_\theta\pi_\theta\log\frac{\pi_\theta}{\pi_{ref}} $$
which equals the exact KL gradient $\sum_{y\in\mathcal Y}\nabla_\theta\pi_\theta(y)\log\frac{\pi_\theta(y)}{\pi_{ref}(y)}$ derived earlier, so in theory it is unbiased.
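A sketch of this importance-weighted per-token term, assuming the per-token log-probs under $\pi_\theta$ (with gradients), $\pi_{old}$, and $\pi_{ref}$ (both treated as constants) are available; the function and tensor names are illustrative, not DeepSeek's code:

```python
import torch

def v32_style_kl_term(logp_theta_tok, logp_old_tok, logp_ref_tok):
    """Per-token KL term (pi_theta/pi_old) * (r - log r - 1) with r = pi_ref/pi_theta.

    Only logp_theta_tok carries gradients; pi_old and pi_ref are constants, which is
    what makes the expected gradient under y ~ pi_old match the true KL gradient.
    """
    is_ratio = (logp_theta_tok - logp_old_tok).exp()   # pi_theta / pi_old (differentiable)
    log_r = logp_ref_tok - logp_theta_tok              # log(pi_ref / pi_theta)
    k3 = log_r.exp() - log_r - 1.0
    return (is_ratio * k3).mean()

# Toy usage with random stand-in log-probs.
torch.manual_seed(0)
lp_theta = (-torch.rand(8)).requires_grad_()
lp_old, lp_ref = -torch.rand(8), -torch.rand(8)
kl = v32_style_kl_term(lp_theta, lp_old, lp_ref)
kl.backward()
print(kl.item(), lp_theta.grad.norm().item())
```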
KL in On-Policy Distillation
Everything above concerns the first level of KL, i.e., the KL between the two models as distributions over sequences. Here we look at the second level, $D^y_{KL}(\pi_\theta\Vert\pi_{ref})$. As noted earlier, this form of KL is generally used for knowledge distillation, and the best-known current direction there is On-Policy Distillation; let us work out what its gradient-based optimization looks like.
$$ \begin{align} \nabla_\theta D^y_{KL}(\pi_\theta\Vert\pi_{ref})&=\nabla_\theta\left[\frac{1}{\vert y\vert}\sum_{t=1}^{\vert y\vert}\sum_{v_t\in\mathcal V}\pi_\theta(v_t\vert y_{< t})\log\frac{\pi_\theta(v_t\vert y_{< t})}{\pi_{ref}(v_t\vert y_{< t})}\right] \\ &=\frac{1}{\vert y\vert}\sum_{t=1}^{\vert y\vert}\sum_{v_t\in\mathcal V}\nabla_\theta\pi_\theta(v_t\vert y_{< t})\log\frac{\pi_\theta(v_t\vert y_{< t})}{\pi_{ref}(v_t\vert y_{< t})} \\ &=\frac{1}{\vert y\vert}\sum_{t=1}^{\vert y\vert}\sum_{v_t\in\mathcal V}\pi_\theta(v_t\vert y_{< t})\log\frac{\pi_\theta(v_t\vert y_{< t})}{\pi_{ref}(v_t\vert y_{< t})}\nabla_\theta\log\pi_\theta(v_t\vert y_{< t}) \\ &=\frac{1}{\vert y\vert}\sum_{t=1}^{\vert y\vert}\mathbb E_{v_t\sim\pi_\theta(\cdot\vert y_{< t})}\left[\log\frac{\pi_\theta(v_t\vert y_{< t})}{\pi_{ref}(v_t\vert y_{< t})}\nabla_\theta\log\pi_\theta(v_t\vert y_{< t})\right] \end{align} $$
The basic on-policy distillation loss in Xiaomi's earlier MOPD looks like this:
$$J(\theta)=-\frac{1}{\vert y\vert}\sum_{t=1}^{\vert y\vert}\text{sg}(\log\frac{\pi_{\text{Teacher}}(y_t)}{\pi_\theta(y_t)})\log\pi_\theta(y_t)$$
Taking its gradient gives
$$\nabla_\theta J(\theta)=\frac{1}{\vert y\vert}\sum_{t=1}^{\vert y\vert}\log\frac{\pi_\theta(y_t)}{\pi_{\text{Teacher}}(y_t)}\nabla_\theta\log\pi_\theta(y_t)$$
Notice this differs from $\nabla_\theta D^y_{KL}(\pi_\theta\Vert\pi_{ref})$: the gradient here effectively treats $\pi_\theta(v_t\vert y_{< t})$ as a one-hot distribution, since only the sampled token contributes instead of the full inner expectation over the vocabulary.
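To make the difference concrete, here is a sketch computing both forms from student and teacher logits along a student-sampled sequence (all shapes, names, and random logits are assumptions for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, V = 5, 32
student_logits = torch.randn(T, V, requires_grad=True)   # pi_theta(. | y_<t)
teacher_logits = torch.randn(T, V)                        # pi_Teacher(. | y_<t), no grad

logp_s = F.log_softmax(student_logits, dim=-1)
logp_t = F.log_softmax(teacher_logits, dim=-1)
y = torch.multinomial(logp_s.detach().exp(), 1).squeeze(-1)   # student-sampled tokens

# Full per-token reverse KL D^y_KL(pi_theta || pi_Teacher): sums over the whole vocabulary.
loss_full = (logp_s.exp() * (logp_s - logp_t)).sum(dim=-1).mean()

# MOPD-style surrogate: uses only the sampled token's log-ratio as a stop-gradient weight.
idx = torch.arange(T)
lr_sampled = (logp_t[idx, y] - logp_s[idx, y]).detach()
loss_mopd = -(lr_sampled * logp_s[idx, y]).mean()

print(loss_full.item(), loss_mopd.item())   # related objectives, but different gradients
```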
References
[1] Shah et al. "A Comedy of Estimators: On KL Regularization in RL Training of LLMs" ICLR OpenReview 2026.
[2] Schulman. “Approximating KL Divergence” joschu.net 2020.
[3] Kristiadi. “KL Divergence: Forward vs Reverse?” agustinus.kristia.de 2016.
[4] DeepSeek-AI. “DeepSeek-V3.2: Pushing the Frontier of Open Large Language Models” DeepSeek 2025.