揭开loss看梯度 -- DiffusionRL的实际优化目标

  • 发布于 2026-01-20
  • 4 次阅读

作者:王峰

https://zhuanlan.zhihu.com/p/1995493539912623296

这段时间一直在搞Diffusion+RL,各种调参下来也算有了一些心得,但没什么创新点,感觉不值得发一篇paper,上次为了写CPS那篇论文又是研究底层理论、又是做对比实验,还要跟审稿人对线,对于我这种已经工作的人来说实在太浪费时间,于是还是写一篇博客吧。

本文所使用的分析方法,其实也是一众做loss function论文中常用的:一个loss对神经网络起作用的是它的梯度,而loss本身是什么形式并没有那么重要,所以我现在看到一个loss就习惯性地会对它求导,看看它的梯度,推导一下,再积分回去,就可以得到另一种loss形式,有时就会有一些新的体会。

本文准备介绍一下Flow-GRPOAWMDiffusionNFT这三篇比较有代表性的文章,分别求一下它们的梯度和等效loss,通过这些分析,可以对DiffusionRL究竟在做什么有更深的理解。

阅读之前,建议还是先对这三篇文章有所了解,本文并不会过多介绍它们的基本原理,感兴趣的直接在知乎上搜对应文章即可。

一、Flow-GRPO

Flow-GRPO这篇文章之前也说过,提供了一个相当不错的代码基础和benchmark,过去一年有很多DiffusionRL的工作都是在此之上进行的。关于这篇文章的分析也挺多的,例如Flow Matching RL(一):Flow-GRPO在学什么? 和 DiffusionNFT的附录,我的结论与他们类似,看过的话可以跳过这一节。

Flow-GRPO使用SDE来在采样中间过程引入随机性,在每一步新加一个高斯噪声 ​\epsilon

d{x}_t = [{v}_{\theta}({x}_t, t) + \frac{\sigma_t^2}{2t}({x}_t+(1-t)\hat{{v}}_{\theta}({x}_t, t))]dt + \sigma_t \sqrt{dt} {\epsilon}

在使用SDE获得了一系列样本(图像) ​x_0 之后,输入reward model即可得到其reward ​R(x_0) ,之后计算GRPO Advantage:

A^i = \frac{R({x}_0^i) - \text{mean}(\{R({x}^i_0)\}_{i=1}^G)}{\text{std}(\{R({x}_0^i)\}_{i=1}^G)}

其最终的loss就是GRPO的loss(在此忽略KL项):

\mathcal{L}(\theta) = \mathbb{E}_{{x}^i\sim\pi_{\theta_{\text{old}}}}\frac{1}{G} \sum \limits_{i=1}^{G} \frac{1}{T} \sum \limits_{t=0}^{T-1}\mathop{min}\Big(r_t^i(\theta)A^i, \text{clip}(r_t^i(\theta), 1-\epsilon, 1+\epsilon)A^i)\Big)

其中

r_t^i(\theta) = \frac{p_{\theta}({x}^i_{t-1}|{x}_t^i)}{p_{\theta_{\text{old}}}({x}^i_{t-1}|{x}_t^i)}

​p_{\theta}({x}^i_{t-1}|{x}_t^i) 服从高斯分布,其对数概率被定义为:

\log p_{\theta}({x}^i_{t-1}|{x}_t^i) = -\frac{\|{x}_{t-\Delta t} - \mu_{old}({x}_t, t)\|^2}{2\sigma_t^2} - \log\sigma_t - \log\sqrt{2\pi}

其中

\mu_{old}({x}_t, t)={x}_t + [{v}_{old}({x}_t, t) + \frac{\sigma_t^2}{2t}({x}_t+(1-t)\hat{{v}}_{old}({x}_t, t))]dt

代表高斯分布的均值,在每一步, ​- \log\sigma_t - \log\sqrt{2\pi} 是一个常数,在

r_t^i(\theta) = \frac{p_{\theta}({x}^i_{t-1}|{x}_t^i)}{p_{\theta_{\text{old}}}({x}^i_{t-1}|{x}_t^i)}

中可以被约掉,所以对数概率可以化简为:

\begin{align} \log p_{\theta}({x}^i_{t-1}|{x}_t^i) &= -\frac{\|{x}_{t-\Delta t} - \mu_{old}({x}_t, t)\|^2}{2\sigma_t^2\Delta t}\\ &= -\frac{\|(\Delta t+\frac{(1-t)\sigma_t^2\Delta t}{2t})({v}_\theta - {v}_{old}) + \sigma_t\sqrt{\Delta t} {\epsilon}\|^2}{2\sigma_t^2 \Delta t}\\ & = -\frac{\|C_1({v}_\theta - {v}_{old}) + C_2 {\epsilon}\|^2}{2C_2^2}. \end{align}

这里将两个常数分别用 ​C_1​C_2 替换以减少公式复杂度。

下面重点来了,对prob ratio ​r(\theta) 求导:

\begin{align} \nabla_\theta r_i(\theta) &= \nabla_\theta e^{\log p_{\theta}({x}^i_{t-1}|{x}_t^i) - \log p_{old}({x}^i_{t-1}|{x}_t^i)}\notag\\ &= sg(e^{\log p_{\theta}({x}^i_{t-1}|{x}_t^i) - \log p_{old}({x}^i_{t-1}|{x}_t^i)}) \nabla_\theta \log p_{\theta}({x}^i_{t-1}|{x}_t^i)\notag\\ & \approx \nabla_\theta \log p_{\theta}({x}^i_{t-1}|{x}_t^i). \end{align}

这里 ​sg(\cdot) 是梯度截止符,不改变梯度方向,而且 ​p_\theta​p_{old} 一般来说非常接近,在刚采样完成的几个iteration这两者甚至完全一样,所以这里直接约等于了。

注意到loss中还有一堆min和clip,这个是PPO clip,起到稳定训练的作用,我们在分析时可以将其省略。所以loss的梯度为:

\begin{align} \nabla_{v_\theta}\mathcal{L} &\approx A \nabla_{v_\theta} \log p_{\theta}({x}^i_{t-1}|{x}_t^i) \notag\\ & = -A \frac{2(C_1({v}_\theta - {v}_{old})+C_2\epsilon)C_1}{2C_2^2} \notag\\ &\approx -\frac{C_1}{C_2}A{\epsilon} \end{align}

这里第一个约等于是因为上一个公式有一个约等号,第二个约等于则是假设了 ​v_\theta \approx v_{old} ,这个假设跟上边假设 ​p_\theta \approx p_{old} 类似,但上边那个假设即使不成立也并不影响梯度方向,而两个速度相减直接影响了梯度方向,所以这个假设有一点牵强。但不进行这个假设的话,这个公式又有点太复杂了不利于理解。

好在后来,Flow-GRPO原作者组里又出了一篇文章叫GRPO-Guard,对这个问题进行了修正,他们修改了log-prob的定义,最终起到的效果之一就是在梯度中去掉了 ​{v}_\theta - {v}_{old} 那一项,修正之后,后一个约等号就变成等号了,具体的改动可以看一看GRPO-Guard这篇文章,里面也有对梯度的分析,与本文的结论一致。

所以说,Flow-GRPO的policy gradient实际上是使用advantage加权的噪声,这一点其实还是有一点奇怪的,速度应该朝另一个速度的方向移动,沿噪声方向移动有什么意义?而这可能也是其优化效率不如后两个算法的原因。

速度 ​v_\theta 沿着梯度方向进行梯度下降,我们也可以求得其对应的目标位置: ​v_\theta + \frac{C_1}{C_2}A \epsilon ,从而构成另一个loss:

​\mathcal{L} = \|v_\theta - sg(v_\theta + \frac{C_1}{C_2}A\epsilon)\|^2

可以一眼看出来,这个损失函数的梯度跟上边是一样的。

如果只在意梯度方向而不在意幅度的话,则另一个对应目标位置为: ​v_\theta + A \epsilon ,实际上梯度下降也不可能真的到达 ​v_\theta + \frac{C_1}{C_2}A \epsilon ,只能说向这个目标移动,射线上的每个点都可以作为目标,所以说 ​v_\theta + A \epsilon 是优化目标也没什么错误。

二、AWM

AWM这个算法相当的简单,但效果却很不错,具体可以看作者的介绍:优势加权匹配(AWM):让扩散模型的强化学习与预训练对齐 。原文是从方差的角度来解释的,但本文从梯度角度也可以解释为什么它的优化效率如此的高。

AWM的训练过程与Diffusion本身的训练过程很类似,在采样到样本 ​x_0 之后,对样本重新进行加噪来训练,加噪方式和loss都与Diffusion基本一致,只是在最终loss的前边套上了advantage作为loss weight:

\begin{align} x_t^i &= (1-t)x_0^i + t{\epsilon} \\ v_{gt} &= {\epsilon} - {x}_0^i\\ \mathcal{L}_{AWM} &= A^i\|{v}_\theta(x_t^i) - v_{gt}\|^2 \end{align}

这几个公式熟悉flow matching的人都应该能看出来,其实就是标准的FM训练过程加了个weight,所以这篇文章的名字就叫advantage weighted matching,大道至简,返璞归真。

下面对这个loss求一下梯度,也非常简单:

\nabla_{v_\theta}\mathcal{L} = 2A^i(v_\theta - v_{gt})

其对应的优化目标(之一)为: ​v_\theta + A(v_{gt}-v_\theta) ,这个优化目标就比Flow-GRPO的好理解多了,起码v应该是向另一个v移动:如果当前样本很好时,比如 ​A=1 ,那么优化目标就是 ​v_{gt} ;当样本很差例如 ​A=-1 时,要将其沿 ​v_{gt}-v_\theta 推向远离 ​v_{gt} 的方向,此时优化目标为 ​2v_\theta - v_{gt} ,也就是向外插1倍的位置,非常的直观。

虽然AWM的论文中没有加入PPO clip,但我实际实验下来,还是需要套上的,否则优化一会就会剧烈抖动甚至直接归0。

三、DiffusionNFT

DiffusionNFT与AWM的主要区别是它是一个off-policy的方法,它会用滑动平均维护一个old model:

\theta^{old} \leftarrow \eta_i\theta^{old}+(1-\eta_i)\theta

采样和梯度计算都在这个old model上进行,这样的好处是训练会更加稳定,但缺点是在old model上计算的梯度方向终归与new model不那么匹配,所以优化速度会慢。

DiffusionNFT的motivation和推导过程请参见原论文,这里只列出其最终的损失函数:

\begin{align} \mathcal{L}(\theta) = \mathbb{E}_{c, \pi^{old}(x_0|c), t} &\left[ r \|v_{\theta}^{+}(x_t, c, t) - v_{gt}\|_2^2 + (1-r) \|v_{\theta}^{-}(x_t, c, t) - v_{gt}\|_2^2 \right], \notag \\ v_{\theta}^{+} &= (1-\beta)v_{old} + \beta v_{\theta},\notag\\ v_{\theta}^{-} &= (1+\beta)v_{old} - \beta v_{\theta}, \end{align}

这里的 ​\beta 在实际工程中一般取1即可, ​r 其实跟advantage很类似,只是它被归一化到了 ​[0, 1] 范围,而advantage一般是零均值的,在实际应用中它们之间的关系可以表达为: ​A = 2r-1

为了简化推导,我们定义old和gt之间的差距为 ​\Delta = v_{old} - v_{gt} ,定义new和old之间的差距为 ​\delta_{\theta} = v_{\theta} - v_{old} ,这样上边的loss可以化简为:

\mathcal{L} = r \|\Delta + \beta \delta_{\theta}\|_2^2 + (1-r) \|\Delta - \beta \delta_{\theta}\|_2^2

下面求一下梯度(注意到 ​\nabla_{v_\theta} \delta_\theta = I ,而 ​\Delta 里不含 ​\theta ):

\begin{align} \nabla_{v_{\theta}} \mathcal{L} &= r \cdot 2(\Delta + \beta \delta_{\theta}) \cdot \beta + (1-r) \cdot 2(\Delta - \beta \delta_{\theta}) \cdot (-\beta) \\ &= 2\beta \left[ r(\Delta + \beta \delta_{\theta}) - (1-r)(\Delta - \beta \delta_{\theta}) \right] \\ &= 2\beta \left[ (2r - 1)\Delta + \beta \delta_{\theta} \right] \\ &= 2\beta\underbrace{A(v_{old} - v_{gt})}_{\text{Optimize Direction}} + \underbrace{2\beta^2 (v_{\theta} - v_{old})}_{\text{Trust Region Regularization}} \end{align}

一番推导下来,可以看到这里有两项loss,其中第一项跟AWM的形式很类似,起到指定优化方向的作用,而第二项起到了trust region的作用,熟悉RL历史的朋友会一下子想起来TRPO,而PPO clip实际上是TRPO的一个改进。也是因为这里有一个类似trust region的约束项,所以DiffusionNFT并不需要使用PPO clip即可让优化稳定进行下去。

进一步推导:

\nabla_{v_{\theta}} \mathcal{L} = 2\beta^2\left(v_\theta - (v_{old} + \frac{A}{\beta}(v_{gt} - v_{old}))\right)

所以DiffusionNFT的等效loss可以写成:

\mathcal{L} = 2\beta^2\|v_\theta - (v_{old} + \frac{A}{\beta}(v_{gt} - v_{old}))\|^2

也就是说,它的优化目标为 ​v_{old} + \frac{A}{\beta}(v_{gt} - v_{old}) ,与AWM对比,其实就是将 ​v_\theta 更换为 ​v_{old} ,即使用off-policy的old model来计算优化目标,old model因为是moving average更新的,所以它更加稳定。

四、后记

本文通过对loss求导,再反推一个新的loss形式的技巧,对三种DiffusionRL方法进行了解析,这里我列了一个表格进行总结:

Forward表示像Diffusion训练时那样加噪,而Backward表示直接使用采样过程中迭代求得的xt

可以看到三个方法的优化目标其实都还是挺容易理解的,尤其是AWM和DiffusionNFT的区别几乎只在于on-policy或是off-policy。

目前就我自己的实验来看,AWM优化速度快、效果好但没那么稳定,需要搭配PPO clip并且需要一定的调参技巧,DiffusionNFT略慢但优化稳定,迁移到其他算法时,基本上只需要调一调ema的参数。实际使用中,根据自己的需求进行选择即可。