本文是LLM后训练的第三篇,主要介绍近端策略优化(Proximal Policy Optimization,简称PPO)算法。PPO算法在LLM后训练中起到举足轻重的作用。PPO算法更像是一个强化学习技巧的集合体,其设计重点在于保证训练的稳定性。因此在介绍PPO算法之前,需要先了解其背后的一些基础背景。
广义优势估计 (GAE)
广义优势估计(Generalized Advantage Estimation, 简称GAE)是一种用于强化学习中的优势函数估计方法。在LLM后训练(二)中,我们提到了TD误差实际上是优势函数的一个估计。然而,这种估计往往是有偏的,因为对状态值函数的估计可能不准确。与之相对,如果使用蒙特卡洛采样实际样本去估计优势函数,虽然可以做到无偏,但是方差会非常大。GAE方法通过引入一个衰减因子$\lambda$,在TD误差和蒙特卡洛采样之间进行权衡,从而得到一个既有较低偏差又有较低方差的优势函数估计。 下面给出TD(0)版本的优势函数估计:
$$ \begin{align} \hat{A}_t &= \delta_{t} \nonumber \\ &= r_t + \gamma V(s_{t+1}) - V(s_t) \nonumber \end{align} $$这里只使用了一个时间步的奖励和状态函数估计,进一步我们可以写出TD(1)版本的优势函数估计:
$$ \begin{align} \hat{A}_t &= r_t + \gamma r_{t+1} + \gamma^2V(s_{t+2}) - V(s_t) \nonumber \\ &= r_t + \gamma V(s_{t+1}) - V(s_t) + \gamma(r_{t+1} + \gamma V(s_{t+2}) - V(s_{t+1})) \nonumber \\ &= \delta_t + \gamma \delta_{t+1} \nonumber \end{align} $$其中第二个等号使用了逆向递归展开,最终利用不同时间步的TD(0)误差表示出了TD(1)版本的优势函数估计。 以此类推,我们可以得到TD(k)版本的优势函数估计:
$$ \begin{align} \hat{A}_t &= \sum_{l=0}^{k} \gamma^l r_{t+l} + \gamma^{k+1} V(s_{t+k}) - V(s_t) \nonumber \\&= \sum_{l=0}^{k} \gamma^l \delta_{t+l} \nonumber \end{align} $$这里省略了中间的递归展开步骤,直接给出了TD(k)版本的优势函数估计。可以看到,TD(k)版本的优势函数估计是TD(0)版本误差的一个加权和,其中权重是$\gamma^l$。
需要注意的是,尽管TD(k)最终可以表示为各时间步TD(0)误差的加权和,但每一项TD(0)误差都包含一步真实奖励的采样,因此奖励项必不可少。随着展开时间步数的增加,TD(k)版本的优势函数估计偏差会逐渐减小,但方差会逐渐增大。为了权衡偏差和方差,GAE方法引入了一个衰减因子$\lambda$,将TD(k)版本的优势函数估计进行加权平均:
$$ \begin{align} \hat{A}_t^{GAE} &= (1 - \lambda) \sum_{k=0}^{\infty} \lambda^k \hat{A}_t^{TD(k)} \nonumber \\&= (1 - \lambda) \Big[ \delta_t + \lambda(\delta_t + \gamma \delta_{t+1}) + \lambda^2(\delta_t + \gamma \delta_{t+1} + \gamma^2 \delta_{t+2}) + \cdots \Big] \nonumber \\&= (1 - \lambda) \Big[ (1+\lambda + \lambda^2 + \cdots) \delta_t + (\lambda + \lambda^2 + \cdots) \gamma \delta_{t+1} + (\lambda^2 + \cdots) \gamma^2 \delta_{t+2} + \cdots \Big] \nonumber \\&= \delta_t + \lambda \gamma \delta_{t+1} + \lambda^2 \gamma^2 \delta_{t+2} + \cdots \nonumber \\&= \sum_{l=0}^{\infty} (\lambda \gamma)^l \delta_{t+l} \nonumber \end{align} $$其中,最后一个等号利用了无穷级数求和公式$\sum_{k=0}^{\infty} \lambda^k = \frac{1}{1-\lambda}$。上述推导实际上是基于无穷项级数的求和来得到GAE版本的优势函数估计。在实际问题中,求和通常会被截断到一个有限的时间步数$T$,但工程上仍然复用上述结果(理论严谨性有待商榷)。从末尾向前递归,可得到所有时间步的GAE(忽略关于$\lambda$的归一化系数):
$$ \begin{align} A_T^{GAE} &\leftarrow \delta_T \nonumber \\ \hat{A}_t^{GAE} &\leftarrow \delta_t + \lambda \gamma \hat{A}_{t+1}^{GAE} \nonumber \end{align} $$信赖域策略优化(TRPO)
信赖域策略优化(Trust Region Policy Optimization, 简称TRPO)是一种用于强化学习中的策略优化算法。TRPO的核心思想是通过限制每次更新的策略变化范围来保证训练的稳定性和收敛性。
首先回顾策略梯度方法的目标函数对参数的梯度:
$$ \nabla_\theta J(\theta) = \mathbb{E}_{\pi} \Big[ A_t \nabla_\theta log\pi_{\theta}(a_t|s_t) \Big] \nonumber $$然而,直接使用上述目标函数有如下几个问题:
- 上述式子是对策略进行采样,反过来更新策略本身。策略每次更新后,为了让采样更准确,都需要重新采样数据,这会导致训练效率非常低。
- 即使每次更新后重新采样数据,直接使用上述目标函数进行优化也可能导致训练不稳定,并不能保证每次更新都能提升策略的性能。
因此,TRPO算法引入了一个新的目标函数,称为重要性采样(Importance Sampling)目标函数:
$$ \begin{align} \nabla_\theta J(\theta) &= \mathbb{E}_{\pi_{\theta_{old}}} \Big[ \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} A_t \nabla_\theta log\pi_{\theta}(a_t|s_t) \Big] \nonumber \end{align} $$它使用了重要性采样定理(由简单的积分变换即可推导),在旧策略$\pi_{\theta_{old}}$上采样数据,并以重要性权重$\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$来修正目标函数的梯度。这种方法允许我们在不重新采样数据的情况下更新策略,从而提高了训练效率。对应的目标函数形式(去掉梯度算子)为:
$$ \begin{align} J(\theta) &= \mathbb{E}_{\pi_{\theta_{old}}} \Big[ \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} A_t \Big] \nonumber \end{align} $$然而,直接优化上述目标函数仍然可能导致训练不稳定,因为策略的更新可能会过大,从而导致性能下降。因为重要性采样定理在实际使用中最好保证$\pi_{\theta}$和$\pi_{\theta_{old}}$之间的差异较小,否则需要很多采样点才能得到一个准确的估计。因此,需要引入一个约束条件来限制每次更新的策略变化范围。TRPO算法使用KL散度(Kullback-Leibler Divergence)作为约束条件,确保新策略$\pi_{\theta}$与旧策略$\pi_{\theta_{old}}$之间的KL散度不超过一个预设的阈值$\delta$,除此之外,还使用广义优势估计(GAE)来估计优势函数,以进一步提高训练的稳定性和效率。最终的TRPO算法的优化问题可以表示为:
$$ \begin{align} \max_\theta & \quad \mathbb{E}_{\pi_{\theta_{old}}} \Big[ \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} A_t^{GAE} \Big] \nonumber \\ \text{subject to} & \quad \mathbb{E}_{s_t \sim d^{\pi_{\theta_{old}}}} \Big[ D_{KL}(\pi_{\theta_{old}}(\cdot|s_t) || \pi_{\theta}(\cdot|s_t)) \Big] \leq \delta \nonumber \end{align} $$这里只介绍TRPO的核心思想,不再展开求解方法,感兴趣的读者可以参考原论文 Trust Region Policy Optimization。
近端策略优化(PPO)
在介绍完GAE和TRPO算法之后,就可以引入最终的PPO算法了。PPO算法是TRPO算法的一个改进版本。TRPO除优化目标外还引入了约束条件来限制每次更新的策略变化幅度,这使得其实现和求解都相对复杂。PPO算法通过在优化目标中引入一个截断函数(Clipping Function)来隐式地约束策略变化范围,从而省去了TRPO中的显式约束,使得实现和求解都更加简洁高效。PPO算法综合了GAE、Actor-Critic方法、重要性采样等多种强化学习技术,成为目前LLM后训练中最常用的算法之一。
Actor部分的优化目标
PPO算法的优化目标可以表示为:
$$ \begin{align} J(\theta) &= \mathbb{E}_{\pi_{\theta_{old}}} \Big[ \min \Big( \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} A_t^{GAE}, clip\Big( \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}, 1-\epsilon, 1+\epsilon \Big) A_t^{GAE} \Big) \Big] \nonumber \end{align} $$这个公式引入了一个$\min$函数和一个$\text{clip}$函数。要理解PPO的优化目标,需要对这两个函数进行联合分析:
- 假设$A_t^{GAE} > 0$,即当前动作的优势函数估计为正,则优化目标的第一项$\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} A_t^{GAE}$会鼓励策略增大选择动作$a_t$的概率。然而,如果新旧策略的概率比已经非常大,超出了$1+\epsilon$的范围,$\text{clip}$函数便会触发,此时$\min$会选择截断项作为优化目标,而截断项对概率比不传导梯度,因此等价于放弃对该动作的进一步优化。
- 假设$A_t^{GAE} < 0$,即当前动作的优势函数估计为负,则优化目标的第一项$\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} A_t^{GAE}$会鼓励策略减小选择动作$a_t$的概率。然而,如果新旧策略的概率比已经非常小,低于$1-\epsilon$的范围,$\text{clip}$函数便会触发,此时$\min$会选择截断项作为优化目标,同样不传导梯度,等价于放弃对该动作的进一步优化。
PPO算法还有一个变体叫做PPO-penalty, 其优化目标如下:
$$ \begin{align} J(\theta) &= \mathbb{E}_{\pi_{\theta_{old}}} \Big[ \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} A_t^{GAE} - \beta D_{KL}(\pi_{\theta_{old}}(\cdot|s_t) || \pi_{\theta}(\cdot|s_t)) \Big] \nonumber \end{align} $$PPO-penalty引入了一个KL散度惩罚项来限制每次更新的策略变化范围,$\beta$是控制惩罚项权重的超参数。这个变体不常用,因此不再赘述。
Critic部分的优化目标
在LLM后训练(二)中,我们提到Critic部分的优化目标是最小化TD误差:
$$ \mathcal{L}(\phi) = \mathbb{E}\left[ (r_t + \gamma V_{\phi}(s_{t+1}) - V_{\phi}(s_t))^2 \right] \nonumber $$在PPO算法中,Critic部分采用类似的优化目标,但使用了GAE版本的优势函数估计来替代TD误差:
$$ \begin{align} V_{target} &= A_t^{GAE} + V_{\phi_{old}}(s_t) \nonumber \\ \mathcal{L}(\phi) &= \mathbb{E}\left[ (V_{target} - V_{\phi}(s_t))^2 \right] \nonumber \end{align} $$上式为Critic部分的优化目标的原始形式,为了和之前的TD误差版本进行对比,我们可以将$V_{target}$展开:
$$ \begin{align} V_{target} &= A_t^{GAE} + V_{\phi_{old}}(s_t) \nonumber \\&= \delta_t + \lambda \gamma A_{t+1}^{GAE} + V_{\phi_{old}}(s_t) \nonumber \\&= r_t + \gamma V_{\phi_{old}}(s_{t+1}) - V_{\phi_{old}}(s_t) + \lambda \gamma A_{t+1}^{GAE} + V_{\phi_{old}}(s_t) \nonumber \\&= r_t + \gamma V_{\phi_{old}}(s_{t+1}) + \lambda \gamma A_{t+1}^{GAE} \nonumber \end{align} $$
相比于原版TD误差的优化目标,PPO算法的Critic优化目标额外引入了一个GAE折扣项$\lambda \gamma A_{t+1}^{GAE}$。需要注意,$V_{target}$中的量均基于旧的Critic网络计算,旧网络的参数$\phi_{old}$不参与梯度回传,从而能够更稳定地更新Critic网络的参数$\phi$。
PPO算法伪代码如下:

在某些训练框架中(例如GPT),Critic部分的优化目标也会引入一个剪切函数来限制每次更新的值函数变化范围,类似于Actor部分的优化目标:
$$ \begin{align} V_{clip} &= clip(V_{\phi}(s_t), V_{\phi_{old}}(s_t) - \epsilon, V_{\phi_{old}}(s_t) + \epsilon) \nonumber \\ \mathcal{L}(\phi) &= \mathbb{E}\left[ \max \Big( (V_{target} - V_{\phi}(s_t))^2, (V_{target} - V_{clip})^2 \right] \nonumber \end{align} $$其含义与Actor部分的截断函数类似,具体如下:
- 假设$V_{target} - V_{\phi}(s_t) > 0$,即当前值函数估计偏低,则第一项$(V_{target} - V_{\phi}(s_t))^2$会驱动Critic提高值函数的估计。然而,如果当前值函数相比于旧网络已经大幅上升,超出了$V_{\phi_{old}}(s_t) + \epsilon$的范围,$\text{clip}$函数便会触发,此时$\max$会选择$(V_{target} - V_{clip})^2$作为优化目标,而该截断项不传导梯度,等价于放弃对该状态的进一步优化。
- 假设$V_{target} - V_{\phi}(s_t) < 0$,即当前值函数估计偏高,则第一项$(V_{target} - V_{\phi}(s_t))^2$会驱动Critic降低值函数的估计。然而,如果当前值函数相比于旧网络已经大幅下降,低于$V_{\phi_{old}}(s_t) - \epsilon$的范围,$\text{clip}$函数便会触发,同样选择截断项作为优化目标,不传导梯度,等价于放弃对该状态的进一步优化。
PPO使用一批rollout数据对Actor和Critic进行多次交替更新,再采样新的rollout数据进行下一批次更新。因此它介于on-policy和off-policy之间,不过主流观点仍倾向于认为PPO是一种on-policy算法。
至此,LLM所需要的强化学习前置知识就介绍完了,下一篇将介绍LLM强化学习中如何应用这些算法进行后训练。

说些什么吧!