KL-Based Divergences
给定两个离散分布$P(\mathcal C)$和$Q(\mathcal C)$,它们的KL散度定义为:
$$ \mathcal D_{KL}(P\Vert Q)=\sum_{c\in\mathcal C}P(c)\log\frac{P(c)}{Q(c)} $$
由于KL散度的不对称性:$\mathcal D_{KL}(P\Vert Q)\neq \mathcal D_{KL}(Q\Vert P)$,定义前向KL散度(forward KL)为$\mathcal D_{KL}(P\Vert Q)$,定义反向KL散度(reverse KL)为$\mathcal D_{KL}(Q\Vert P)$。
KL散度是无界的,一个常用的衡量概率分布有界散度为:$JSD$(Jensen-Shannon divergence)。$JSD(\beta)$结合了前向和后向KL散度(其中$0<\beta< 1$)
$$ \mathcal D_{JSD(\beta)}(P\Vert Q)=\beta\mathcal D_{KL}(P\Vert \beta P+(1-\beta)Q)+(1-\beta)\mathcal D_{KL}(Q\Vert \beta P+(1-\beta)Q) $$
经过证明$\lim_{\beta\rightarrow 0}\mathcal D_{JSD(\beta)}(P\Vert Q)/\beta=\mathcal D_{KL}(P\Vert Q)$,因此,$JSD(\beta)$的梯度当$\beta$接近0和1时分别与前向KL和反向KL的梯度相近。
Distillation For Auto-Regressive Sequence Models
定义学生policy和教师policy为$p_S$和$p_T$,定义学生策略有可学习参数$\theta$。基于给定的数据集$(X, Y)$,其中$Y$可以是预先准备的也可以是教师模型$p_T$基于$X$生成的。基于一个散度函数$\mathcal D$,定义$p_T$和$p_S$之间的token-level分布差异为:
$$ \mathcal D(p_T\Vert p_S^\theta)(y\vert x):=\frac{1}{L_y}\sum_{n=1}^{L_y}\mathcal D(p_T(\cdot\vert y_{< n},x)\Vert p_S^\theta(\cdot\vert y_{< n}, x)) $$
其中$x$为一条样本的input,$y$为一条样本的output,$L_y$为output的长度。
Supervised FT
没有教师policy,只有$(X,Y)$,可以使用最小负对数似然优化学生policy
$$ L_{SFT}(\theta)=\mathbb E_{(x,y)\sim(X,Y)}[-\log p_S^\theta(y\vert x)] $$
Sequence-Level KD
当有学生policy和教师policy,且最大化教师policy生成的sequence的似然,可以视作使用教师policy生成的output上$(X, Y_T)$,对学生policy做Supervieed FT
$$ L_{SeqKD}(\theta)=\mathbb E_{(x,y)\sim(X,Y_T)}[-\log p_S^\theta(y\vert x)] $$
Supervised KD
基于token-level的优化:
$$ L_{SD}(\theta)=\mathbb E_{(x,y)\sim(X,Y)}[-\mathcal D_{KL}(p_T\Vert p_S^\theta)(y\vert x)] $$
Generalized Knowledge Distillation
基于固定的$(X, Y)$或者$(X, Y_T)$训练的学生policy,对于训练分布外的数据,存在一定泛化问题,因此Agarwal et al.提出On-policy Distillation,简而言之,基于学生最新的policy生成的$(X, Y^{\theta_{new}}_S)$来做Supervised KD。
$$ L_{OD}(\theta)=\mathbb E_{x\sim X}[\mathbb E_{y\sim p_S(\cdot\vert x)}[-\mathcal D_{KL}(p_T\Vert p_S^\theta)(y\vert x)]] $$
此外,作者结合On-Policy策略和常规Supervised KD策略,给出Generalized KD(GKD)。
$$ L_{GKD}(\theta)=(1-\lambda)\mathbb E_{(x,y)\sim(X,Y)}[-\mathcal D(p_T\Vert p_S^\theta)(y\vert x)]+\lambda\mathbb E_{x\sim X}[\mathbb E_{y\sim p_S(\cdot \vert x)}[-\mathcal D(p_T\Vert p_S^\theta)(y\vert x)]] $$
RLT + On-Policy GKD
作者提出在RL阶段,可以引入GKD,使得学生policy在RL训练基于奖励$r$优化的过程中,不会偏离固定的教师policy。
$$ \mathbb E_{x\sim X}[(1-\alpha)\mathbb E_{y\sim p_S^\theta(\cdot\vert x)}r(y)-\alpha\mathbb E_{y\sim p_S(\cdot\vert x)}[\mathcal D(p_T\Vert p_S^\theta)(y\vert x)]] $$
其中$\alpha\in[0, 1]$。此外作者建议在RL训练中使用逆向KL或者使用$JSD(0.9)$。
On-Policy Distillation
此外,最近Thinking Machines的博客“On-Policy Distillation”也提出了类似的方法,基于on-policy的蒸馏方式。这篇博客关于on-policy的motivation比较合理:对于LLM中的on-policy训练,其奖励信号通常比较稀疏,比方说训练一些math或者coding的数据,其reward只能给出这道题目是否正确的奖励信号,但如果做错了,模型并不知道是哪里错了。而对于监督学习(SFT),模型能够获得token级别的信号,但SFT只能让模型学习老师的路径,但这些路径可能并不会在学生模型真实推理中遇到,相反,当学生模型SFT训练后碰到一些训练中没见过的路径,可能会与越来越偏离正确答案。
事实上,这里关于On-Policy训练奖励稀疏这块,作者提到比如一道数学题,On-Policy的奖励只能让模型知道其结果正确还是错误,如果错了并不知道是哪里错了。 基于GRPO算法,最终的advantage会公平地作用在每个generated sentence的所有token上,所以这里理解起来,模型也并不知道一道数学题做错了得到0奖励,是因为最终的答案错了还是中间的过程错了,因为每个rollout中的token得到一致的advantage。
关于On-Policy Distillation,作者推荐reverse KL:
$$ KL(\pi_\theta\Vert \pi_{\text{teacher}}) = \mathbb E_{x\sim\pi_\theta}[\log\pi_\theta(x_{t+1}\vert x_{1..t}-\log\pi_{\text{teacher}}(x_{t+1}\vert x_{1..t}))] $$
References
[1] Agarwal et al. “ON-POLICY DISTILLATION OF LANGUAGE MODELS: LEARNING FROM SELF-GENERATED MISTAKES ” arXiv preprint axXiv:2306.13649 (2023).
[2] Kevin Lu et al. “On-Policy Distillation” Thinking Machines (2025).