什么是Diffusion模型?

作者:Ray Wang,原文:什么是Diffusion模型?

Diffusion过程

扩散(Diffusion)在热力学中指细小颗粒从高密度区域扩散至低密度区域,在统计领域,扩散则指将复杂的分布转换为一个简单的分布的过程。Diffusion模型定义了一个概率分布转换模型T\mathcal{T},能将原始数据x0x_0构成的复杂分布pcomplexp_{\mathrm{complex}}转换为一个简单的已知参数的先验分布ppriorp_{\mathrm{prior}}

x0pcomplex    T(x0)pprior\mathbf{x}_0 \sim p_{\mathrm{complex}} \implies \mathcal{T}(\mathbf{x}_0) \sim p_{\mathrm{prior}}

受到物理领域的热动力学相关知识启发,Diffusion模型提出可以用马尔科夫链(Markov Chain)来构造T\mathcal{T},即定义一系列条件概率分布q(xtxt1),t{1,2,3...T}q(\mathbf{x}_t \vert \mathbf{x}_{t-1}),\quad t\in\{1,2,3...T\},将x0\mathbf{x}_0依次转换为x1\mathbf{x}_1x2\mathbf{x}_2,…,xT\mathbf{x}_T,希望当TinfT \rightarrow \inf时,xTpprior\mathbf{x}_T \sim p_{\mathrm{prior}}

能满足xpprior\mathbf{x}_{\infty} \sim p_{\mathrm{prior}}这个期望的{q,pprior}\{q, p_{\mathrm{prior}}\}组合选择有很多,最简洁有效的选择就是正态分布,即:

(1)q(xtxt1)=N(xt;1βtxt1,βtI)\tag{1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t \mathrm{I})

q(xT)=pprior(xT)=N(xT;0,I)where Tinfq(\mathbf{x}_T) = p_{\mathrm{prior}}(\mathbf{x}_T) = \mathcal{N}(\mathbf{x}_T; \mathbf{0}, \mathrm{I}) \quad where\ T \rightarrow \inf

即已知xt1\textbf{x}_{t-1}时,xt\textbf{x}_t的概率分布是一个平均值为1βtxt1\sqrt{1-\beta_t}\textbf{x}_{t-1},协方差为βtI\beta_t\textbf{I}的正态分布。

根据重参数化技巧可得:

(2)xt=1βtxt1+βtzt1where zt1N(0,I)\tag{2} \textbf{x}_t=\sqrt{1-\beta_t}\textbf{x}_{t-1}+\sqrt{\beta_t}\textbf{z}_{t-1} \quad where\ \textbf{z}_{t-1}\in\mathcal{N}(0, \textbf{I})

这一过程可以视作xt1\textbf{x}_{t-1}与标准正态分布噪声z\textbf{z}混合,扩散率系数βt\beta_t控制融合xt1\textbf{x}_{t-1}分布和标准正态分布的混合比例。从原始数据分布x0\textbf{x}_{0}xT\textbf{x}_{T},这一过程可以视作是在重复地给原始数据分布添加噪声,直到变为一个简单固定的分布为止。

扩散率βt\beta_t到底是什么呢?数据分布混合噪声分布时的比例为什么要设计成1βt\sqrt{1-\beta_t}βt\sqrt{\beta_t}呢?

假设αt=1βt\alpha_t = 1 - \beta_tαˉt=i=1Tαi\bar{\alpha}_t = \prod_{i=1}^T \alpha_i,那么:

(3)xt=αtxt1+1αtzt1 ;where zt1,zt2,N(0,I)=αtαt1xt2+αt(1αt1)zt2+1αtzt1=αtαt1xt2+1αtαt1zˉt2 ;where zˉt2,zˉt3,N(0,I)==αˉtx0+1αˉtz\tag{3}\begin{aligned} \mathbf{x}_t &= \sqrt{\alpha_t}\mathbf{x}_{t-1} + \sqrt{1 - \alpha_t}\mathbf{z}_{t-1} & \text{ ;where } \mathbf{z}_{t-1}, \mathbf{z}_{t-2}, \dots \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\ &= \sqrt{\alpha_t \alpha_{t-1}} \mathbf{x}_{t-2} + \sqrt{\alpha_{t}(1 - \alpha_{t-1})} \mathbf{z}_{t-2}+ \sqrt{1 - \alpha_t}\mathbf{z}_{t-1} \\ &= \sqrt{\alpha_t \alpha_{t-1}} \mathbf{x}_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}} \bar{\mathbf{z}}_{t-2} & \text{ ;where } \bar{\mathbf{z}}_{t-2}, \bar{\mathbf{z}}_{t-3}, \dots \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\ &= \dots \\ &= \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\mathbf{z} \end{aligned}

公式中第二行到第三行的转换利用了一个性质:两个正态分布N(0,σ12I)\mathcal{N}(\mathbf{0}, \sigma_1^2\mathbf{I})N(0,σ22I)\mathcal{N}(\mathbf{0}, \sigma_2^2\mathbf{I})相加,新分布为N(0,(σ12+σ22)I)\mathcal{N}(\mathbf{0}, (\sigma_1^2 + \sigma_2^2)\mathbf{I})

将公式3写为条件概率形式,可以得到:

(4)q(xtx0)=N(xt;αˉtx0,(1αˉt)I)\tag{4} q(\mathbf{x}_t \vert \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1-\bar{\alpha}_t)\mathbf{I})

由于βt(0,1)\beta_t\in(0,1),那么αt(0,1)\alpha_t\in(0,1)tinft \rightarrow \inf时,αˉt0\bar{\alpha}_t \rightarrow 0。可以看出,1βt\sqrt{1-\beta_t}βt\sqrt{\beta_t}作为系数保证了当 TinfT \rightarrow \inf时,q(xT)=pprior(xT)=N(0,I)q(\mathbf{x}_T) = p_{\mathrm{prior}}(\mathbf{x}_T)=\mathcal{N}(0,\textbf{I})。实际上,只要TT取一个足够大的值,不需要无限次迭代,得到的分布就已经很接近于标准正态分布了。

βtR\beta_t \in \mathbb{R}具体的取值可以预先定义。原论文使用从0.0001到0.02的线性插值作为所有β\beta的取值。

以上就是原数据分布到简单先验噪声分布的转换过程T\mathcal{T}的描述。值得注意的是,当βt\beta_t预定义时,上述整个扩散过程没有出现一个可学习的参数,就可以将任意原始复杂的分布转换为简单先验分布(标准正态分布)。

下面的示意图展示了一维数据分布pcomplexp_{\mathrm{complex}}中的两个样本(分别标识为蓝色和红色),经过多次加噪,最终被转换为ppriorp_{\mathrm{prior}}中的两个样本的过程。

通过Diffusion模型的前向过程,复杂的分布pcomplexp_{\mathrm{complex}}被转换为了一个标准正态分布ppriorp_{\mathrm{prior}}

逆转Diffusion过程

和GAN类似,Diffusion模型的最终目标是从ppriorp_{\mathrm{prior}}中采样一个样本,将其转换为原始数据分布中的一个样本。显然,如果逆转上一节提到的Diffusion过程,依次从q(xt1xt),t{T,T1,T2...0}q(\mathbf{x}_{t-1} \vert \mathbf{x}_t),\quad t\in\{T,T-1,T-2...0\}中采样,Diffusion模型就可以实现从xTN(0,I)\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})到数据分布pcomplexp_{\mathrm{complex}}的转换。

现在问题来了,q(xt1xt)q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)到底是什么样的分布呢?Feller等人在1949年: On the Theory of Stochastic Processes, with Particular Reference to Applications证明连续扩散过程的逆转具有与正向过程相同的分布形式。即当扩散率βt\beta_t足够小,扩散次数足够多时,离散扩散过程接近于连续扩散过程,q(xt1xt)q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)的分布形式同q(xtxt1)q(\mathbf{x}_{t} \vert \mathbf{x}_{t-1})一致,同样是高斯分布。 但是很难直接写出q(xt1xt)q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)的分布参数。为此,可以用分布pθ(xt1xt)p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t)来近似q(xt1xt)q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)

(5)pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))\tag{5} p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \boldsymbol{\mu}_\theta(\mathbf{x}_t, t), \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t))

其中μθ\boldsymbol{\mu}_\thetaΣθ\boldsymbol{\Sigma}_\theta都是要学习的函数,接受xt,t\mathbf{x}_t, t作为参数。

这样,连续迭代多次后,可以得到近似的真实数据分布pθ(x0)p_\theta(\mathbf{x}_{0})为:

(6)pθ(x0)=pθ(x0:T)dx1:T\tag{6} p_\theta(\mathbf{x}_{0})=\int p_\theta(\mathbf{x}_{0:T})d\mathbf{x}_{1:T}

其中pθ(x0:T)p_\theta(\mathbf{x}_{0:T})x0,x1,,xT\mathbf{x}_{0},\mathbf{x}_{1},\cdots,\mathbf{x}_{T}的联合概率分布。借助条件概率公式:

(7)pθ(x0:T)=p(xT)t=1Tpθ(xt1xt)\tag{7} p_\theta(\mathbf{x}_{0:T})=p(\mathbf{x}_T) \prod^T_{t=1}p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t)

在有了μθ\boldsymbol{\mu}_\thetaΣθ\boldsymbol{\Sigma}_\theta后,pθ(xt1xt)p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t)就被确认了下来。就可以完成逆转Diffusion过程了。如下图所示:

首先从N(0,I)\mathcal{N}(\mathbf{0}, \mathbf{I})中采样得到xT\mathbf{x}_T,然后在以μθ(xT,T)\boldsymbol{\mu}_\theta(\mathbf{x}_T, T)为均值,Σθ(xT,T)\boldsymbol{\Sigma}_\theta(\mathbf{x}_T, T)为方差的正态分布中采样得到xT1\mathbf{x}_{T-1}。依次重复这个过程,直到得到最终结果x0\mathbf{x}_0。由于q(xt1xt)q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)未知,所以在逆转Diffusion过程中,用学习到的pθ(xt1xt)p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t)代替它。

以二维平面上的一个数据集为例,原数据集中的二维数据构成了类似于字母e的一个图案,在Diffusion前向过程中,经过两次迭代,原始数据分布就被转换为了第一行第三列这样丧失了所有结构信息的分布,接近于高斯噪声分布。而逆转Diffusion过程则在t=Tt=T时刻开始,先从高斯噪声中采样,然后依次得到t=T2t=\frac{T}{2}t=0t=0时刻的数据分布。重建得到的分布很接近于原始数据分布,得到了一个非常不错的生成模型。

上图中还有一行值得注意,即第三行的漂移项。从xt\mathbf{x}_{t}xt1\mathbf{x}_{t-1},实际是一个采样过程。逆扩散过程中,第tt时刻的一个数据点xtx_t,对应于第t1t-1时刻xt1\mathbf{x}_{t-1}的一个高斯分布。这听起来有些奇怪,期望逆Diffusion过程能将非结构化的噪声分布转化为结构化的数据分布,中间每一个步骤应当更“结构化”才对,怎么tt时刻的一个数据点变成了t1t-1时刻的一个高斯分布了呢?点到分布,似乎更“乱”,更“非结构化”了。实际上,对应分布的方差Σθ(xt,t)=σt2I\mathbf{\Sigma}_{\theta}(\mathbf{x}_t, t) = \sigma_t^2\textbf{I}σt2\sigma_t^2的取值很接近于βt\beta_{t},即方差很小,而平均值μθ(xt,t)\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)是被网络预测(可以视作一个去噪过程)得到。只要μθ(xt,t)\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)预测的准,能准确的去除xt\mathbf{x}_{t}的噪声,就消除了分布中的“非结构化”信息。如第三列所示,μθ(xt,t)xt\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)-\mathbf{x}_t在噪声比较大的地方(即远离原数据分布的点),值也大;而噪声小的地方,接近于0。

总结地说,Diffusion过程可以被视作在逐渐加噪声,而逆Diffusion过程则是在逐渐去噪声。学习的网络需要建模估计输入图片中的噪声。

训练目标

现在只剩下最后一个问题,究竟怎么优化得到理想的μθ\boldsymbol{\mu}_\thetaΣθ\boldsymbol{\Sigma}_\theta?类似于其它生成模型,可以最小化在真实数据期望下,模型预测分布的负对数似然,即最小化预测pdata=q(x0)p_{\mathrm{data}}=q({\mathbf{x}_0})pθ(x0)p_{\theta}(\mathbf{x}_0)的交叉熵:

(8)L=Ex0q(x0)[logpθ(x0)]\tag{8} \mathcal{L} = \mathbb{E}_{\mathbf{x}_0 \sim q({\mathbf{x}_0})}\big[ - \log p_{\theta}(\mathbf{x}_0) \big]

事实上没法写出pθ(x0)p_{\theta}(\mathbf{x}_0)的表达式,直接计算上面的交叉熵难度很大。目前已知的仅有公式6,7以及pθ(xt1xt)p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t)q(xtxt1)q(\mathbf{x}_t \vert \mathbf{x}_{t-1})的表达式。为此,可以做一些数学推导,将公式8中的pθ(x0)p_{\theta}(\mathbf{x}_0)转换为已知的东西:

(9)L=Eq(x0)logpθ(x0)=Eq(x0)log(pθ(x0:T)dx1:T)=Eq(x0)log(q(x1:Tx0)pθ(x0:T)q(x1:Tx0)dx1:T)=Eq(x0)log(Eq(x1:Tx0)pθ(x0:T)q(x1:Tx0))Eq(x0:T)logpθ(x0:T)q(x1:Tx0)=Eq(x0:T)[logq(x1:Tx0)pθ(x0:T)]\tag{9} \begin{aligned} \mathcal{L} &= - \mathbb{E}_{q(\mathbf{x}_0)} \log p_\theta(\mathbf{x}_0) \\ &= - \mathbb{E}_{q(\mathbf{x}_0)} \log \Big( \int p_\theta(\mathbf{x}_{0:T}) d\mathbf{x}_{1:T} \Big) \\ &= - \mathbb{E}_{q(\mathbf{x}_0)} \log \Big( \int q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})} d\mathbf{x}_{1:T} \Big) \\ &= - \mathbb{E}_{q(\mathbf{x}_0)} \log \Big( \mathbb{E}_{q(\mathbf{x}_{1:T} \vert \mathbf{x}_0)} \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})} \Big) \\ &\leq - \mathbb{E}_{q(\mathbf{x}_{0:T})} \log \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})} \\ &= \mathbb{E}_{q(\mathbf{x}_{0:T})}\Big[\log \frac{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}{p_\theta(\mathbf{x}_{0:T})} \Big] \end{aligned}

其中q(x1:Tx0)=t=1Tq(xtxt1)q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1}),等式变不等式那一步利用了Jensen不等式。根据公式9,为了最小化L\mathcal{L},我们可以转而去最小化其上界LVLBL_{VLB}

(10)LVLB=Eq(x0:T)[logq(x1:Tx0)pθ(x0:T)]=Eq[logt=1Tq(xtxt1)pθ(xT)t=1Tpθ(xt1xt)]=Eq[logpθ(xT)+t=1Tlogq(xtxt1)pθ(xt1xt)]=Eq[logpθ(xT)+t=2Tlogq(xtxt1)pθ(xt1xt)+logq(x1x0)pθ(x0x1)]=Eq[logpθ(xT)+t=2Tlog(q(xt1xt,x0)pθ(xt1xt)q(xtx0)q(xt1x0))+logq(x1x0)pθ(x0x1)]=Eq[logpθ(xT)+t=2Tlogq(xt1xt,x0)pθ(xt1xt)+t=2Tlogq(xtx0)q(xt1x0)+logq(x1x0)pθ(x0x1)]=Eq[logpθ(xT)+t=2Tlogq(xt1xt,x0)pθ(xt1xt)+logq(xTx0)q(x1x0)+logq(x1x0)pθ(x0x1)]=Eq[logq(xTx0)pθ(xT)+t=2Tlogq(xt1xt,x0)pθ(xt1xt)logpθ(x0x1)]=Eq[logpθ(x0x1)L0]+t=2TDKL(q(xt1xt,x0)pθ(xt1xt))Lt1+DKL(q(xTx0)pθ(xT))LT\tag{10} \begin{aligned} L_\text{VLB} &= \mathbb{E}_{q(\mathbf{x}_{0:T})} \Big[ \log\frac{q(\mathbf{x}_{1:T}\vert\mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})} \Big] \\ &= \mathbb{E}_q \Big[ \log\frac{\prod_{t=1}^T q(\mathbf{x}_t\vert\mathbf{x}_{t-1})}{ p_\theta(\mathbf{x}_T) \prod_{t=1}^T p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t) } \Big] \\ &= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=1}^T \log \frac{q(\mathbf{x}_t\vert\mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} \Big] \\ &= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \frac{\color{blue}{q(\mathbf{x}_t\vert\mathbf{x}_{t-1})}}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} + \log\frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] \\ &= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \Big( \frac{\color{blue}{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)}\cdot \frac{\color{blue}{q(\mathbf{x}_t \vert \mathbf{x}_0)}}{\color{blue}{q(\mathbf{x}_{t-1}\vert\mathbf{x}_0)}} \Big) + \log \frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] \\ &= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} + \sum_{t=2}^T \log \frac{q(\mathbf{x}_t \vert \mathbf{x}_0)}{q(\mathbf{x}_{t-1} \vert \mathbf{x}_0)} + \log\frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] \\ &= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} + \log\frac{q(\mathbf{x}_T \vert \mathbf{x}_0)}{q(\mathbf{x}_1 \vert \mathbf{x}_0)} + \log \frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big]\\ &= \mathbb{E}_q \Big[ \log\frac{q(\mathbf{x}_T \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_T)} + \sum_{t=2}^T \log \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} - \log p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1) \Big] \\ &= \mathbb{E}_q [ \underbrace{- \log p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)}_{L_0} ] + \sum_{t=2}^T \underbrace{D_\text{KL}(q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t))}_{L_{t-1}} + \underbrace{D_\text{KL}(q(\mathbf{x}_T \vert \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_T))}_{L_T} \end{aligned}

上式中蓝色部分直接的变换实际上利用了贝叶斯公式:

(11)q(xt1xt,x0)=q(xtxt1,x0)q(xt1x0)q(xtx0)\tag{11} q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = q(\mathbf{x}_t \vert \mathbf{x}_{t-1}, \mathbf{x}_0) \frac{ q(\mathbf{x}_{t-1} \vert \mathbf{x}_0) }{ q(\mathbf{x}_t \vert \mathbf{x}_0) }

注意由马尔科夫链的性质,有q(xtxt1,x0)=q(xtxt1)q(\mathbf{x}_t \vert \mathbf{x}_{t-1}, \mathbf{x}_0)=q(\mathbf{x}_t \vert \mathbf{x}_{t-1})

再重新看公式10的最后一行,可以看出LVLBL_{VLB}实际上由一个熵(L0L_0),以及多个KL散度(Lt,t{1,2,3,,T}L_{t},t \in \{1,2,3,\cdots,T\})构成。其中LTL_TxT\mathbf{x}_Tx0\mathbf{x}_0一个是先验分布,一个是数据分布,都是固定的,故LTL_T是一个常数,最小化LVLBL_{VLB}时可以忽略。可以只去研究L0L_0Lt,t{1,2,3,,T1}L_{t},t \in \{1,2,3,\cdots,T-1\}

分布q(xt1xt,x0)q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)和分布pθ(xt1xt)p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)之间的KL散度

根据公式5,分布pθ(xt1xt)p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)是一个高斯分布,其平均值和方差由Diffusion模型网络预测产生。

而分布q(xt1xt,x0)q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)可以根据贝叶斯定律,即公式11继续推下去得到:

(12)q(xt1xt,x0)=q(xtxt1,x0)q(xt1x0)q(xtx0)=q(xtxt1)q(xt1x0)q(xtx0)exp(12((xtαtxt1)2βt+(xt1αˉt1x0)21αˉt1(xtαˉtx0)21αˉt))=exp(12((αtβt+11αˉt1)xt12(2αtβtxt+2αˉt11αˉt1x0)xt1+C(xt,x0)))\tag{12} \begin{aligned} q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) &= q(\mathbf{x}_t \vert \mathbf{x}_{t-1}, \mathbf{x}_0) \frac{ q(\mathbf{x}_{t-1} \vert \mathbf{x}_0) }{ q(\mathbf{x}_t \vert \mathbf{x}_0) } \\ &= q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) \frac{ q(\mathbf{x}_{t-1} \vert \mathbf{x}_0) }{ q(\mathbf{x}_t \vert \mathbf{x}_0) } \\ &\propto \exp \Big(-\frac{1}{2} \big(\frac{(\mathbf{x}_t - \sqrt{\alpha_t} \mathbf{x}_{t-1})^2}{\beta_t} + \frac{(\mathbf{x}_{t-1} - \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0)^2}{1-\bar{\alpha}_{t-1}} - \frac{(\mathbf{x}_t - \sqrt{\bar{\alpha}_t} \mathbf{x}_0)^2}{1-\bar{\alpha}_t} \big) \Big) \\ &= \exp\Big( -\frac{1}{2} \big( \color{red}{(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}})} \mathbf{x}_{t-1}^2 - \color{blue}{(\frac{2\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t + \frac{2\sqrt{\bar{\alpha}_{t-1}}}{1 - \bar{\alpha}_{t-1}} \mathbf{x}_0)} \mathbf{x}_{t-1} + C(\mathbf{x}_t, \mathbf{x}_0) \big) \Big) \end{aligned}

继续推导下去,可以发现q(xt1xt,x0)q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)同样是一个高斯分布。假设:

(13)q(xt1xt,x0)=N(xt1;μ~(xt,x0),β~tI)\tag{13} q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \color{blue}{\tilde{\boldsymbol{\mu}}}(\mathbf{x}_t, \mathbf{x}_0), \color{red}{\tilde{\beta}_t} \mathbf{I})

那么,由公式12,公式13中的两个新变量:

(14)β~t=1/(αtβt+11αˉt1)=1αˉt11αˉtβtμ~t(xt,x0)=(αtβtxt+αˉt11αˉt1x0)/(αtβt+11αˉt1)=αt(1αˉt1)1αˉtxt+αˉt1βt1αˉtx0\tag{14} \begin{aligned} \tilde{\beta}_t &= 1/(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}}) = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t \\ \tilde{\boldsymbol{\mu}}_t (\mathbf{x}_t, \mathbf{x}_0) &= (\frac{\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}}{1 - \bar{\alpha}_{t-1}} \mathbf{x}_0)/(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}}) = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} \mathbf{x}_0\\ \end{aligned}

因此,最小化LtL_t这个KL损失实际上目标就是拉近下面这两个高斯分布的距离:

(15)q(xt1xt,x0)=N(xt1;μ~(xt,x0),β~tI)pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))\tag{15} q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\boldsymbol{\mu}}(\mathbf{x}_t, \mathbf{x}_0), {\tilde{\beta}_t} \mathbf{I}) \longleftrightarrow p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \boldsymbol{\mu}_\theta(\mathbf{x}_t, t), \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t))

多元正态分布之间的KL散度可以直接根据分布参数计算出来。

(16)Lt=Eq[12Σθ(xt,t)22μ~t(xt,x0)μθ(xt,t)2]+C\tag{16} L_t = \mathbb{E}_{q} \Big[\frac{1}{2 \| \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t) \|^2_2} \|{\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0)} - {\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)} \|^2 \Big] + C

上式中CC是一个不依赖于θ\theta的常数。为了模型简单,可以令Σθ(xt,t)=σt2I\mathbf{\Sigma}_{\theta}(\mathbf{x}_t, t) = \sigma_t^2\textbf{I},其中σt2\sigma_t^2可以设置为βt\beta_tβ~t\tilde{\beta}_t,论文说这两个选择效果差不多。实际上,当σt2=β~t\sigma_t^2=\tilde{\beta}_t时,公式15中的两个分布的方差就一样了。这一选择是为了简化计算,并不是唯一的。

从公式16来看,只需要定义一个网络μθ(xt,t)\mu_\theta(\mathbf{x}_t, t),使用L2损失约束其预测值同μ~t(xt,x0)\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0)一致即可。具体来说,可以定义一个接受xtx_ttt作为参数的网络,从原数据分布中采样一个数据x0x_0,通过公式3计算得到xtx_t,然后利用公式14计算得到μ~t(xt,x0)\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0),将xtx_ttt送入网络得到μθ(xt,t)\mu_\theta(\mathbf{x}_t, t)。使用L2损失约束两个样本一致,并优化网络。

但DDPM并没有停止于此,继续分析化简公式16。μ~t(xt,x0)\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0)的输入有xt,x0\mathbf{x}_t, \mathbf{x}_0,而μθ(xt,t)\mu_\theta(\mathbf{x}_t, t)xtx_t作为输入。借助公式3,可以得到x0=1αˉt(xt1αˉtzt)\mathbf{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t}\mathbf{z}_t)。将其代入,有:

(17)μ~t=αt(1αˉt1)1αˉtxt+αˉt1βt1αˉt1αˉt(xt1αˉtzt)=1αt(xtβt1αˉtzt)\tag{17} \begin{aligned} \tilde{\boldsymbol{\mu}}_t &= \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} \frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t}\mathbf{z}_t) \\ &= \frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{z}_t \Big) \end{aligned}

根据公式16和公式17,μθ(xt,t)\mu_\theta(\mathbf{x}_t, t)在给定xt\mathbf{x}_t的情况下,需要预测出1αt(xtβt1αˉtzt)\frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{z}_t \Big)。为了降低学习的难度,可以直接定义:

(18)μθ(xt,t)=1αt(xtβt1αˉtzθ(xt,t))\tag{18} \boldsymbol{\mu}_\theta(\mathbf{x}_t, t) = {\frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{z}_\theta(\mathbf{x}_t, t) \Big)}

这样,公式16可以继续简化:

(19)LtC=Ex0,zt[12σt2μ~t(xt,x0)μθ(xt,t)2]=Ex0,zt[12σt21αt(xtβt1αˉtzt)1αt(xtβt1αˉtzθ(xt,t))2]=Ex0,zt[βt22αt(1αˉt)σt2ztzθ(xt,t)2]=Ex0,zt[βt22αt(1αˉt)σt2ztzθ(αˉtx0+1αˉtzt,t)2]\tag{19} \begin{aligned} L_t - C &= \mathbb{E}_{\mathbf{x}_0, \mathbf{z}_t} \Big[\frac{1}{2\sigma_t^2} \| \color{blue}{\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0)} - \color{green}{\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)} \|^2 \Big] \\ &= \mathbb{E}_{\mathbf{x}_0, \mathbf{z}_t} \Big[\frac{1}{2\sigma_t^2} \| \color{blue}{\frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{z}_t \Big)} - \color{green}{\frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\mathbf{z}}_\theta(\mathbf{x}_t, t) \Big)} \|^2 \Big] \\ &= \mathbb{E}_{\mathbf{x}_0, \mathbf{z}_t} \Big[\frac{ \beta_t^2 }{2 \alpha_t (1 - \bar{\alpha}_t) \sigma_t^2} \|\mathbf{z}_t - \mathbf{z}_\theta(\mathbf{x}_t, t)\|^2 \Big] \\ &= \mathbb{E}_{\mathbf{x}_0, \mathbf{z}_t} \Big[\frac{ \beta_t^2 }{2 \alpha_t (1 - \bar{\alpha}_t) \sigma_t^2} \|\mathbf{z}_t - \mathbf{z}_\theta(\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\mathbf{z}_t, t)\|^2 \Big] \end{aligned}

公式19表示在优化时,采样x0pdata\mathbf{x}_0 \sim \mathbf{p}_{data}ztN(0,I)\mathbf{z}_t \in \mathcal{N}(0, \mathbf{I}),后计算αˉtx0+1αˉtzt\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\mathbf{z}_t,然后联合时间tt,送入zθ\mathbf{z}_\theta,得到预测值,约束其与zt\mathbf{z}_t一致。

计算L0L_0

已知L0=Ex0,x1logpθ(x0x1)L_0=-\mathbb{E}_{\mathbf{x}_0, \mathbf{x}_1}\log p_\theta(\mathbf{x}_{0} \vert \mathbf{x}_1),而pθ(x0x1)=N(μθ(x1,1),σ12I)p_\theta(\mathbf{x}_{0} \vert \mathbf{x}_1) = \mathcal{N}(\boldsymbol{\mu}_\theta(\mathbf{x}_1, 1), \sigma_1^2\mathbf{I})。所以L0L_0实际上是一个多元高斯分布的负对数似然的期望,即其熵。多元高斯分布的熵仅与其协方差有关,即L0L_0仅与σ12I\sigma_1^2\mathbf{I}有关,L0L_0是一个常数。

然而,论文DDPM指出,一般而言,x0\mathbf{x}_0的分布实际上是离散的,而不是连续的。比如图片数据,像素值取值必须是整数,归一化到[1,1][-1,1]后,依然是离散的点。Diffusion前向的第一步实际上是为离散数据添加噪声。那么,逆Diffusion的最后一步,即从x1\mathbf{x}_1x0\mathbf{x}_0,也不能被简单地看作从N(μθ(x1,1),σ12I)\mathcal{N}(\boldsymbol{\mu}_\theta(\mathbf{x}_1, 1), \sigma_1^2\mathbf{I})中采样,而是在从N(μθ(x1,1),σ12I)\mathcal{N}(\boldsymbol{\mu}_\theta(\mathbf{x}_1, 1), \sigma_1^2\mathbf{I})采样的基础上再加上离散化操作。L0L_0也不再是一个常数,而是一个与μθ(x1,1)\mu_\theta(\mathbf{x}_1, 1)相关的积分,其具体表达式可以参考DDPM的Sec3.3。在忽略σ12\sigma_1^2和边缘效应后,L0L_0的取值可以被N(μθ(x1,1),σ12I)\mathcal{N}(\boldsymbol{\mu}_\theta(\mathbf{x}_1, 1), \sigma_1^2\mathbf{I})的密度函数与离散时的分块大小(bin width)相乘所拟合。

另外值得一提的是,逆Diffusion的最后一步,DDPM直接取μθ(x1,1)\mu_\theta(\mathbf{x}_1, 1)作为x0\mathbf{x}_0

简化训练目标

上文已经分别描述了Lt,t{0,1,2,3,,T1}L_{t},t \in \{0,1,2,3,\cdots,T-1\}的计算过程,最终可以按照公式10,最小化L0+t=1T1LtL_0+\sum_{t=1}^{T-1} L_{t}来优化网络。论文DDPM发现,去除LtL_{t}中的加权系数βt22αt(1αˉt)σt2\frac{ \beta_t^2 }{2 \alpha_t (1 - \bar{\alpha}_t) \sigma_t^2},得到简化的训练目标如下:

(20)Lsimple(θ):=Et,x0,ϵt[ϵtzθ(αˉtx0+1αˉtϵt,t)2]\tag{20} L_\text{simple}(\theta) := \mathbb{E}_{t,\mathbf{x}_0, \mathbf{\epsilon}_t} \Big[\|\mathbf{\epsilon}_t - \mathbf{z}_\theta(\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\mathbf{\epsilon}_t, t)\|^2 \Big]

公式中tt{1,2,,T}\{1,2,\cdots,T\}中均匀采样。t=1t=1时对应于L0L_0的一个近似,t>1t>1时对应于去除了加权系数的公式19。

相对于直接计算LVLBL_{VLB}LsimpleL_\text{simple}实现起来更加简单,tt较小时的LtL_t权重被减少,tt较大时的权重被增加。这样网络能更专注于tt较大,图片中噪声更多时,更难更复杂的噪声预测任务。

训练采样流程

可以将上文描述的Diffusion模型的训练采样过程分别总结如下:

训练时,分别从q(x0)q(\mathbf{x}_0)Uniform(1,,T)Uniform({1,\cdots,T})N(0,I)\mathcal{N}(\mathbf{0},\textbf{I})中采样得到x0x_0ttϵ\epsilon,利用公式3计算得到xtx_t,将xtx_ttt送入网络,预测得到一个噪声。最小化预测噪声和真实采样的ϵ\epsilon之间的距离。重复这一过程直到网络收敛。

Diffusion模型的逆转采样每个时刻主要包含以下三步:

  1. xtx_ttt送入网络,预测得到噪声ϵ\epsilon
  2. 利用估计的噪声ϵ\epsilonxtx_t,计算μθ=1αt(xtβt1αˉtϵ)\mu_\theta= \frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon \Big)
  3. 如果t>1t>1,需要从N(μθ,σt2I)\mathcal{N}(\mu_\theta, \sigma_t^2\mathbf{I})中采样得到xt1x_{t-1},利用重参数化技巧,可以将采样过程转换为首先采样zN(0,I)z\in\mathcal{N}(\mathbf{0},\textbf{I}),然后计算xt1=μθ+σtzx_{t-1}=\mu_\theta+\sigma_tz。如果t=1t=1,直接令x0=μθx_0=\mu_\theta

总结

Diffusion模型的每一步推导都有严密的数学基础,调整其细节时,必须仔细思考背后的数学基础。如果它火起来,成为生成模型的主流,简直是不给我这种调参侠活路!

参考文献

写本篇博客时,我主要参考了下述论文和博客文章。

相关论文:

  1. Sohl-Dickstein, J., Weiss, E.A., Maheswaranathan, N., & Ganguli, S. (2015). Deep Unsupervised Learning using Nonequilibrium Thermodynamics. ArXiv, abs/1503.03585.
  2. Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. ArXiv, abs/2006.11239.

网页链接:

  1. What are Diffusion Models? | Lil’Log (lilianweng.github.io)
  2. diffusion_models/Diffusion_models.ipynb at main · InFoCusp/diffusion_models (github.com)
  3. Ayan Das · An introduction to Diffusion Probabilistic Models