Diffusion Models
Introduction
Diffusion models 包含前向和反向两个过程,前向过程是从一个真实分布开始,对其采样\(x_0\sim q(x)\),然后不断添加噪声获得一系列加噪样本\(x_1,\dots,x_T\),加噪的步长大小通过一个方差函数\(\beta_t\in (0,1)\)控制,随加噪次数增加,样本逐渐接近高斯分布。
\[\begin{aligned} q(x_t\vert x_{t-1}) &= \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1},\beta_tI)\\ q(x_{1:T}\vert x_0) &= \prod_{t=1}^T q(x_t\vert x_{t-1}) \end{aligned}\]我们可以用重参数技巧推断\(x_0\)和\(x_t\)的关系,令\(\alpha_t = 1-\beta_t\),则有
\[\begin{aligned} x_t &= \sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}\epsilon_{t-1}&\epsilon\sim\mathcal{N}(0,I)\\ &= \sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}\epsilon_{t-2})+\sqrt{1-\alpha_t}\epsilon_{t-1}\\ &= \sqrt{\alpha_t\alpha_{t-1}}x_{t-2}+\sqrt{\alpha_t(1-\alpha_{t-1})}\epsilon_{t-2}+\sqrt{1-\alpha_t}\epsilon_{t-1}\\ &= \sqrt{\alpha_t\alpha_{t-1}}x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\bar{\epsilon}_{t-2} \\ &=\dots\\ &= \sqrt{\prod_{s=1}^t\alpha_s}x_0 + \sqrt{1-\prod_{s=1}^t\alpha_s}\bar{\epsilon} \end{aligned}\]两个均值为0的高斯噪声\(\mathcal{N}(0,\sigma_1^2I)\)和\(\mathcal{N}(0,\sigma_2^2I)\)的和服从\(\mathcal{N}(0,\sigma_1^2+\sigma_2^2)\)。这样我们就得到了\(x_t\)和\(x_0\)的关系,可以通过\(x_0\)推断\(x_t\)。记\(\bar{\alpha}_t = \prod_{s=1}^t\alpha_s\),则有
\[q(x_t\vert x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t}x_0, 1-\bar{\alpha}_t I)\]通常我们让\(\beta\)逐渐变大,即\(\beta_1 < \beta_2 < \dots < \beta_T\),则\(\bar{\alpha}_t = \prod_{s=1}^t(1-\beta_s)\)逐渐变小,\(\bar{\alpha}_1 > \bar{\alpha}_2 > \dots > \bar{\alpha}_T\)。
反向过程是从一个噪声分布\(x_T\sim \mathcal{N}(0,I)\)开始的去噪过程\(q(x_{t-1}\vert x_t)\),所以要学习一个去噪模型\(p_\theta(x_{t-1}\vert x_t)=\mathcal{N}(x_{t-1};\mu_\theta(x_t, t), \Sigma_\theta(x_t, t))\)。事实上以\(x_0\)为条件去噪是可行的,即推断\(q(x_{t-1}\vert x_t, x_0)=\mathcal{N}(x_{t-1};\tilde{\mu}(x_t, x_0),\tilde{\beta}_t I)\)。这里我们根据贝叶斯公式直接给出结论
\[\begin{aligned} \tilde{\beta}_t &= \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t\\ \tilde{\mu}_t &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x_0 \end{aligned}\]再将前向过程中\(x_0\)和\(x_t\)的关系代入,我们可以得到
\[\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t\right)\]所以如果我们用一个参数为\(\theta\)的神经网络来学习\(\mu_\theta(x_t, t)\),我们实际上可以是在学习一个噪声模型\(\epsilon_\theta(x_t, t)\)。那么
\[x_{t-1} = \mathcal{N}(x_{t-1};\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t, t)\right), \Sigma_\theta(x_t, t))\\ \text{where }x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon_t\]损失函数可以是最小化\(\tilde{\mu}_t\)和\(\mu_\theta(x_t, t)\)的差距
\[\begin{aligned} L &= \mathbb{E}_{x_0, \epsilon}\left[\frac{1}{2\Vert \Sigma_\theta(x_t, t)\Vert^2}\Vert \tilde{\mu}_t-\mu_\theta(x_t, t)\Vert^2\right]\\ &= \mathbb{E}_{x_0, \epsilon}\left[\frac{1}{2\Vert \Sigma_\theta(x_t, t)\Vert^2}\Vert \frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t\right)-\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t, t)\right)\Vert^2\right]\\ &= \mathbb{E}_{x_0, \epsilon}\left[\frac{1}{2\Vert \Sigma_\theta(x_t, t)\Vert^2}\Vert \frac{1-\alpha_t}{\sqrt{\alpha_t(1-\bar{\alpha}_t)}}(\epsilon_t-\epsilon_\theta(x_t, t))\Vert^2\right]\\ &= \mathbb{E}_{x_0, \epsilon}\left[\frac{(1-\alpha_t)^2}{2\alpha_t(1-\bar{\alpha}_t)\Vert \Sigma_\theta(x_t, t)\Vert^2}\Vert \epsilon_t-\epsilon_\theta(x_t, t)\Vert^2\right]\\ &= \mathbb{E}_{x_0, \epsilon}\left[\frac{(1-\alpha_t)^2}{2\alpha_t(1-\bar{\alpha}_t)\Vert \Sigma_\theta(x_t, t)\Vert^2}\Vert \epsilon_t-\epsilon_\theta(\sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon_t, t)\Vert^2\right] \end{aligned}\]简化的损失函数用重参数化技巧直接预测\(\epsilon_\theta(x_t, t)\)代替用变分下界中的负对数损失进行优化,直接预测噪声而不需要对\(\mu_\theta\)显式建模。简单总结下
- 在正向过程,只需要做采样和\(\beta\)的设计,不需要学习
- repeat
- \[x_0\sim q(x)\]
- \[t\sim \text{Uniform}(1, T)\]
- \[\epsilon_t\sim \mathcal{N}(0, I)\]
- \[\nabla_\theta L \Vert \epsilon_t-\epsilon_\theta(\sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon_t, t)\Vert^2\]
- repeat
- 反向过程需要对噪声模型\(\epsilon_\theta(x_t, t)\)进行学习,也可以对协方差矩阵\(\Sigma_\theta(x_t, t)\)进行学习。固定的\(\Sigma\)用正向过程的\(\beta\)来替代即可,即\(\Sigma_\theta(x_t, t)=\beta_t I\),更准确来说是\(\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t I\)。
- \[x_T\sim \mathcal{N}(0, I)\]
- for \(t=T, T-1, \dots, 1\)
- \[z\sim \mathcal{N}(0, I)\]
- \[x_{t-1}= \frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t, t)\right) + \sigma_t z\]
- return \(x_0\)
需要说明的是我们学习的噪声\(\epsilon\)是描述\(x_0\)和\(x_t\)之间的关系的噪声,这个噪声可以用来对\(x_t\)和\(x_{t-1}\)之间的关系进行建模,但它并不是\(x_t = \sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}\epsilon_{t-1}\)里的\(\epsilon_{t-1}\)。所以我们在反向推理时,除了预测噪声\(\epsilon_\theta(x_t, t)\),还需要用一个额外的噪声\(z\)来模拟\(x_{t-1}\)相对于\(\tilde{\mu}_t\)的偏差。
RL for Diffusion Models
扩散模型的逆向过程是一个天然的MDP,因此可以对标RL来进行建模
- \(s\triangleq (x_t, t, c)\),\(c\) for condition
- \[a\triangleq x_{t-1}\]
- \[\pi(a\vert s)\triangleq p_\theta(x_{t-1}\vert x_t, t, c)\]
- \(r\)的设计比较主观,一般我们要定义一个对\(x_0\)的评估函数\(r(x_0,c)\),提供稀疏的奖励信号。
这样就可以套用策略梯度方法来学习,比如REINFORCE算法或者重要性采样算法。
\[\begin{aligned} \nabla_\theta J(\theta) =& \mathbb{E}\left[\sum_{t=1}^T\nabla_\theta\log p_\theta(x_{t-1}\vert x_t, t, c)R_t\right]\text{ or }\\ &\mathbb{E}\left[\sum_{t=1}^T\frac{p_{\theta}(x_{t-1}\vert x_t, t, c)}{p_{\theta_{\text{old}}}(x_{t-1}\vert x_t, t, c)}\nabla_\theta\log p_\theta(x_{t-1}\vert x_t, t, c)R_t\right] \text{ where }\\ p_\theta(x_{t-1}\vert x_t, t, c) =& \mathcal{N}(x_{t-1};\mu_\theta(x_t, t, c), \Sigma_\theta(x_t, t, c)) \text{ i.e. } \\ & \frac{1}{\sqrt{2\pi}\sigma_t}\exp\left(-\frac{1}{2\sigma_t^2}(x_{t-1}-\mu_\theta(x_t, t, c))^2\right) \end{aligned}\]在一般的策略梯度算法中,奖励\(R_t\)是时刻\(t\)之后的累计奖励\(\sum_{i=t}^T \gamma^{i-t}r_t\),但扩散模型是一个逆向过程,只有当\(t=1\)时提供非零奖励,所以我们可以直接用\(r(x_0,c)\)作为奖励,或者用\(r(x_0,c)\gamma^{t-1}\)作为奖励。
为了便于求导,我们假定\(\Sigma\)是固定的,那么策略梯度只与\(\mu\)有关,我们可以直接对\(\mu\)求导
\[\begin{aligned} \nabla_\theta \log p_\theta(x_{t-1}\vert x_t, t, c) &= \nabla_\theta \log \frac{1}{\sqrt{2\pi}\sigma_t}\exp\left(-\frac{1}{2\sigma_t^2}(x_{t-1}-\mu_\theta(x_t, t, c))^2\right)\\ &= -\frac{1}{2\sigma_t^2}\nabla_\theta(x_{t-1}-\mu_\theta(x_t, t, c))^2\\ &= \frac{1}{\sigma_t^2}(x_{t-1}-\mu_\theta(x_t, t, c))\nabla_\theta\mu_\theta(x_t, t, c) \\ &= \frac{1}{\sigma_t^2}(x_{t-1}-\mu_\theta(x_t, t, c))\nabla_\theta\left[\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t, t,c)\right)\right]\\ &= \frac{\beta_t}{\sqrt{\alpha_t(1-\bar{\alpha}_t)}\sigma_t^2}(\mu_\theta(x_t, t, c)-x_{t-1})\nabla_\theta\epsilon_\theta(x_t, t,c)\\ &= \frac{\beta_tz}{\sqrt{\alpha_t(1-\bar{\alpha}_t)}\sigma_t}\nabla_\theta\epsilon_\theta(x_t, t,c)\\ \end{aligned}\]其实上述公式还可以进一步化简,因为\(\sigma_t\)和\(\beta_t\)也是固定的,且有一定的关系,假设我们根据\(\tilde{\beta}_t=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t\)来设计\(\sigma\),那么\(\sigma_t=\sqrt{\tilde{\beta}_t}\),所以
\[\begin{aligned} \nabla_\theta \log p_\theta(x_{t-1}\vert x_t, t, c) &= \frac{\beta_tz}{\sqrt{\alpha_t(1-\bar{\alpha}_t)}\sqrt{\tilde{\beta}_t}}\nabla_\theta\epsilon_\theta(x_t, t, c)\\ &= \frac{\beta_tz}{\sqrt{\alpha_t(1-\bar{\alpha}_t)}}\sqrt{\frac{1-\bar{\alpha}_{t}}{(1-\bar{\alpha}_{t-1})\beta_t}}\nabla_\theta\epsilon_\theta(x_t, t, c)\\ &= \sqrt{\frac{\beta_t}{\alpha_t(1-\bar{\alpha}_{t-1})}}z\nabla_\theta\epsilon_\theta(x_t, t, c) \\ &= \sqrt{\frac{1-\alpha_t}{\alpha_t-\bar{\alpha}_t}}z\nabla_\theta\epsilon_\theta(x_t, t, c) \end{aligned}\]