作者:Ray Wang,原文:什么是Diffusion模型?
Diffusion过程
扩散(Diffusion)在热力学中指细小颗粒从高密度区域扩散至低密度区域,在统计领域,扩散则指将复杂的分布转换为一个简单的分布的过程。Diffusion模型定义了一个概率分布转换模型T,能将原始数据x0构成的复杂分布pcomplex转换为一个简单的已知参数的先验分布pprior:
x0∼pcomplex⟹T(x0)∼pprior
受到物理领域的热动力学相关知识启发,Diffusion模型提出可以用马尔科夫链(Markov Chain)来构造T,即定义一系列条件概率分布q(xt∣xt−1),t∈{1,2,3...T},将x0依次转换为x1,x2,…,xT,希望当T→inf时,xT∼pprior。
能满足x∞∼pprior这个期望的{q,pprior}组合选择有很多,最简洁有效的选择就是正态分布,即:
q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)(1)
q(xT)=pprior(xT)=N(xT;0,I)where T→inf
即已知xt−1时,xt的概率分布是一个平均值为1−βtxt−1,协方差为βtI的正态分布。
根据重参数化技巧可得:
xt=1−βtxt−1+βtzt−1where zt−1∈N(0,I)(2)
这一过程可以视作xt−1与标准正态分布噪声z混合,扩散率系数βt控制融合xt−1分布和标准正态分布的混合比例。从原始数据分布x0到xT,这一过程可以视作是在重复地给原始数据分布添加噪声,直到变为一个简单固定的分布为止。
扩散率βt到底是什么呢?数据分布混合噪声分布时的比例为什么要设计成1−βt和βt呢?
假设αt=1−βt,αˉt=∏i=1Tαi,那么:
xt=αtxt−1+1−αtzt−1=αtαt−1xt−2+αt(1−αt−1)zt−2+1−αtzt−1=αtαt−1xt−2+1−αtαt−1zˉt−2=…=αˉtx0+1−αˉtz ;where zt−1,zt−2,⋯∼N(0,I) ;where zˉt−2,zˉt−3,⋯∼N(0,I)(3)
公式中第二行到第三行的转换利用了一个性质:两个正态分布N(0,σ12I)和N(0,σ22I)相加,新分布为N(0,(σ12+σ22)I)。
将公式3写为条件概率形式,可以得到:
q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I)(4)
由于βt∈(0,1),那么αt∈(0,1)。t→inf时,αˉt→0。可以看出,1−βt和βt作为系数保证了当 T→inf时,q(xT)=pprior(xT)=N(0,I)。实际上,只要T取一个足够大的值,不需要无限次迭代,得到的分布就已经很接近于标准正态分布了。
βt∈R具体的取值可以预先定义。原论文使用从0.0001到0.02的线性插值作为所有β的取值。
以上就是原数据分布到简单先验噪声分布的转换过程T的描述。值得注意的是,当βt预定义时,上述整个扩散过程没有出现一个可学习的参数,就可以将任意原始复杂的分布转换为简单先验分布(标准正态分布)。
下面的示意图展示了一维数据分布pcomplex中的两个样本(分别标识为蓝色和红色),经过多次加噪,最终被转换为pprior中的两个样本的过程。
通过Diffusion模型的前向过程,复杂的分布pcomplex被转换为了一个标准正态分布pprior。
逆转Diffusion过程
和GAN类似,Diffusion模型的最终目标是从pprior中采样一个样本,将其转换为原始数据分布中的一个样本。显然,如果逆转上一节提到的Diffusion过程,依次从q(xt−1∣xt),t∈{T,T−1,T−2...0}中采样,Diffusion模型就可以实现从xT∼N(0,I)到数据分布pcomplex的转换。
现在问题来了,q(xt−1∣xt)到底是什么样的分布呢?Feller等人在1949年: On the Theory of Stochastic Processes, with Particular Reference to Applications证明连续扩散过程的逆转具有与正向过程相同的分布形式。即当扩散率βt足够小,扩散次数足够多时,离散扩散过程接近于连续扩散过程,q(xt−1∣xt)的分布形式同q(xt∣xt−1)一致,同样是高斯分布。 但是很难直接写出q(xt−1∣xt)的分布参数。为此,可以用分布pθ(xt−1∣xt)来近似q(xt−1∣xt):
pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))(5)
其中μθ和Σθ都是要学习的函数,接受xt,t作为参数。
这样,连续迭代多次后,可以得到近似的真实数据分布pθ(x0)为:
pθ(x0)=∫pθ(x0:T)dx1:T(6)
其中pθ(x0:T)为x0,x1,⋯,xT的联合概率分布。借助条件概率公式:
pθ(x0:T)=p(xT)t=1∏Tpθ(xt−1∣xt)(7)
在有了μθ和Σθ后,pθ(xt−1∣xt)就被确认了下来。就可以完成逆转Diffusion过程了。如下图所示:
首先从N(0,I)中采样得到xT,然后在以μθ(xT,T)为均值,Σθ(xT,T)为方差的正态分布中采样得到xT−1。依次重复这个过程,直到得到最终结果x0。由于q(xt−1∣xt)未知,所以在逆转Diffusion过程中,用学习到的pθ(xt−1∣xt)代替它。
以二维平面上的一个数据集为例,原数据集中的二维数据构成了类似于字母e的一个图案,在Diffusion前向过程中,经过两次迭代,原始数据分布就被转换为了第一行第三列这样丧失了所有结构信息的分布,接近于高斯噪声分布。而逆转Diffusion过程则在t=T时刻开始,先从高斯噪声中采样,然后依次得到t=2T和t=0时刻的数据分布。重建得到的分布很接近于原始数据分布,得到了一个非常不错的生成模型。
上图中还有一行值得注意,即第三行的漂移项。从xt到xt−1,实际是一个采样过程。逆扩散过程中,第t时刻的一个数据点xt,对应于第t−1时刻xt−1的一个高斯分布。这听起来有些奇怪,期望逆Diffusion过程能将非结构化的噪声分布转化为结构化的数据分布,中间每一个步骤应当更“结构化”才对,怎么t时刻的一个数据点变成了t−1时刻的一个高斯分布了呢?点到分布,似乎更“乱”,更“非结构化”了。实际上,对应分布的方差Σθ(xt,t)=σt2I,σt2的取值很接近于βt,即方差很小,而平均值μθ(xt,t)是被网络预测(可以视作一个去噪过程)得到。只要μθ(xt,t)预测的准,能准确的去除xt的噪声,就消除了分布中的“非结构化”信息。如第三列所示,μθ(xt,t)−xt在噪声比较大的地方(即远离原数据分布的点),值也大;而噪声小的地方,接近于0。
总结地说,Diffusion过程可以被视作在逐渐加噪声,而逆Diffusion过程则是在逐渐去噪声。学习的网络需要建模估计输入图片中的噪声。
训练目标
现在只剩下最后一个问题,究竟怎么优化得到理想的μθ和Σθ?类似于其它生成模型,可以最小化在真实数据期望下,模型预测分布的负对数似然,即最小化预测pdata=q(x0)和pθ(x0)的交叉熵:
L=Ex0∼q(x0)[−logpθ(x0)](8)
事实上没法写出pθ(x0)的表达式,直接计算上面的交叉熵难度很大。目前已知的仅有公式6,7以及pθ(xt−1∣xt)和q(xt∣xt−1)的表达式。为此,可以做一些数学推导,将公式8中的pθ(x0)转换为已知的东西:
L=−Eq(x0)logpθ(x0)=−Eq(x0)log(∫pθ(x0:T)dx1:T)=−Eq(x0)log(∫q(x1:T∣x0)q(x1:T∣x0)pθ(x0:T)dx1:T)=−Eq(x0)log(Eq(x1:T∣x0)q(x1:T∣x0)pθ(x0:T))≤−Eq(x0:T)logq(x1:T∣x0)pθ(x0:T)=Eq(x0:T)[logpθ(x0:T)q(x1:T∣x0)](9)
其中q(x1:T∣x0)=∏t=1Tq(xt∣xt−1),等式变不等式那一步利用了Jensen不等式。根据公式9,为了最小化L,我们可以转而去最小化其上界LVLB。
LVLB=Eq(x0:T)[logpθ(x0:T)q(x1:T∣x0)]=Eq[logpθ(xT)∏t=1Tpθ(xt−1∣xt)∏t=1Tq(xt∣xt−1)]=Eq[−logpθ(xT)+t=1∑Tlogpθ(xt−1∣xt)q(xt∣xt−1)]=Eq[−logpθ(xT)+t=2∑Tlogpθ(xt−1∣xt)q(xt∣xt−1)+logpθ(x0∣x1)q(x1∣x0)]=Eq[−logpθ(xT)+t=2∑Tlog(pθ(xt−1∣xt)q(xt−1∣xt,x0)⋅q(xt−1∣x0)q(xt∣x0))+logpθ(x0∣x1)q(x1∣x0)]=Eq[−logpθ(xT)+t=2∑Tlogpθ(xt−1∣xt)q(xt−1∣xt,x0)+t=2∑Tlogq(xt−1∣x0)q(xt∣x0)+logpθ(x0∣x1)q(x1∣x0)]=Eq[−logpθ(xT)+t=2∑Tlogpθ(xt−1∣xt)q(xt−1∣xt,x0)+logq(x1∣x0)q(xT∣x0)+logpθ(x0∣x1)q(x1∣x0)]=Eq[logpθ(xT)q(xT∣x0)+t=2∑Tlogpθ(xt−1∣xt)q(xt−1∣xt,x0)−logpθ(x0∣x1)]=Eq[L0−logpθ(x0∣x1)]+t=2∑TLt−1DKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt))+LTDKL(q(xT∣x0)∥pθ(xT))(10)
上式中蓝色部分直接的变换实际上利用了贝叶斯公式:
q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt∣x0)q(xt−1∣x0)(11)
注意由马尔科夫链的性质,有q(xt∣xt−1,x0)=q(xt∣xt−1)。
再重新看公式10的最后一行,可以看出LVLB实际上由一个熵(L0),以及多个KL散度(Lt,t∈{1,2,3,⋯,T})构成。其中LT中xT和x0一个是先验分布,一个是数据分布,都是固定的,故LT是一个常数,最小化LVLB时可以忽略。可以只去研究L0和Lt,t∈{1,2,3,⋯,T−1}。
分布q(xt−1∣xt,x0)和分布pθ(xt−1∣xt)之间的KL散度
根据公式5,分布pθ(xt−1∣xt)是一个高斯分布,其平均值和方差由Diffusion模型网络预测产生。
而分布q(xt−1∣xt,x0)可以根据贝叶斯定律,即公式11继续推下去得到:
q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt∣x0)q(xt−1∣x0)=q(xt∣xt−1)q(xt∣x0)q(xt−1∣x0)∝exp(−21(βt(xt−αtxt−1)2+1−αˉt−1(xt−1−αˉt−1x0)2−1−αˉt(xt−αˉtx0)2))=exp(−21((βtαt+1−αˉt−11)xt−12−(βt2αtxt+1−αˉt−12αˉt−1x0)xt−1+C(xt,x0)))(12)
继续推导下去,可以发现q(xt−1∣xt,x0)同样是一个高斯分布。假设:
q(xt−1∣xt,x0)=N(xt−1;μ~(xt,x0),β~tI)(13)
那么,由公式12,公式13中的两个新变量:
β~tμ~t(xt,x0)=1/(βtαt+1−αˉt−11)=1−αˉt1−αˉt−1⋅βt=(βtαtxt+1−αˉt−1αˉt−1x0)/(βtαt+1−αˉt−11)=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1βtx0(14)
因此,最小化Lt这个KL损失实际上目标就是拉近下面这两个高斯分布的距离:
q(xt−1∣xt,x0)=N(xt−1;μ~(xt,x0),β~tI)⟷pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))(15)
多元正态分布之间的KL散度可以直接根据分布参数计算出来。
Lt=Eq[2∥Σθ(xt,t)∥221∥μ~t(xt,x0)−μθ(xt,t)∥2]+C(16)
上式中C是一个不依赖于θ的常数。为了模型简单,可以令Σθ(xt,t)=σt2I,其中σt2可以设置为βt或β~t,论文说这两个选择效果差不多。实际上,当σt2=β~t时,公式15中的两个分布的方差就一样了。这一选择是为了简化计算,并不是唯一的。
从公式16来看,只需要定义一个网络μθ(xt,t),使用L2损失约束其预测值同μ~t(xt,x0)一致即可。具体来说,可以定义一个接受xt和t作为参数的网络,从原数据分布中采样一个数据x0,通过公式3计算得到xt,然后利用公式14计算得到μ~t(xt,x0),将xt和t送入网络得到μθ(xt,t)。使用L2损失约束两个样本一致,并优化网络。
但DDPM并没有停止于此,继续分析化简公式16。μ~t(xt,x0)的输入有xt,x0,而μθ(xt,t)以xt作为输入。借助公式3,可以得到x0=αˉt1(xt−1−αˉtzt)。将其代入,有:
μ~t=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1βtαˉt1(xt−1−αˉtzt)=αt1(xt−1−αˉtβtzt)(17)
根据公式16和公式17,μθ(xt,t)在给定xt的情况下,需要预测出αt1(xt−1−αˉtβtzt)。为了降低学习的难度,可以直接定义:
μθ(xt,t)=αt1(xt−1−αˉtβtzθ(xt,t))(18)
这样,公式16可以继续简化:
Lt−C=Ex0,zt[2σt21∥μ~t(xt,x0)−μθ(xt,t)∥2]=Ex0,zt[2σt21∥αt1(xt−1−αˉtβtzt)−αt1(xt−1−αˉtβtzθ(xt,t))∥2]=Ex0,zt[2αt(1−αˉt)σt2βt2∥zt−zθ(xt,t)∥2]=Ex0,zt[2αt(1−αˉt)σt2βt2∥zt−zθ(αˉtx0+1−αˉtzt,t)∥2](19)
公式19表示在优化时,采样x0∼pdata和zt∈N(0,I),后计算αˉtx0+1−αˉtzt,然后联合时间t,送入zθ,得到预测值,约束其与zt一致。
计算L0
已知L0=−Ex0,x1logpθ(x0∣x1),而pθ(x0∣x1)=N(μθ(x1,1),σ12I)。所以L0实际上是一个多元高斯分布的负对数似然的期望,即其熵。多元高斯分布的熵仅与其协方差有关,即L0仅与σ12I有关,L0是一个常数。
然而,论文DDPM指出,一般而言,x0的分布实际上是离散的,而不是连续的。比如图片数据,像素值取值必须是整数,归一化到[−1,1]后,依然是离散的点。Diffusion前向的第一步实际上是为离散数据添加噪声。那么,逆Diffusion的最后一步,即从x1到x0,也不能被简单地看作从N(μθ(x1,1),σ12I)中采样,而是在从N(μθ(x1,1),σ12I)采样的基础上再加上离散化操作。L0也不再是一个常数,而是一个与μθ(x1,1)相关的积分,其具体表达式可以参考DDPM的Sec3.3。在忽略σ12和边缘效应后,L0的取值可以被N(μθ(x1,1),σ12I)的密度函数与离散时的分块大小(bin width)相乘所拟合。
另外值得一提的是,逆Diffusion的最后一步,DDPM直接取μθ(x1,1)作为x0。
简化训练目标
上文已经分别描述了Lt,t∈{0,1,2,3,⋯,T−1}的计算过程,最终可以按照公式10,最小化L0+∑t=1T−1Lt来优化网络。论文DDPM发现,去除Lt中的加权系数2αt(1−αˉt)σt2βt2,得到简化的训练目标如下:
Lsimple(θ):=Et,x0,ϵt[∥ϵt−zθ(αˉtx0+1−αˉtϵt,t)∥2](20)
公式中t从{1,2,⋯,T}中均匀采样。t=1时对应于L0的一个近似,t>1时对应于去除了加权系数的公式19。
相对于直接计算LVLB,Lsimple实现起来更加简单,t较小时的Lt权重被减少,t较大时的权重被增加。这样网络能更专注于t较大,图片中噪声更多时,更难更复杂的噪声预测任务。
训练采样流程
可以将上文描述的Diffusion模型的训练采样过程分别总结如下:
训练时,分别从q(x0)、Uniform(1,⋯,T)、N(0,I)中采样得到x0,t和ϵ,利用公式3计算得到xt,将xt和t送入网络,预测得到一个噪声。最小化预测噪声和真实采样的ϵ之间的距离。重复这一过程直到网络收敛。
Diffusion模型的逆转采样每个时刻主要包含以下三步:
- 将xt和t送入网络,预测得到噪声ϵ
- 利用估计的噪声ϵ和xt,计算μθ=αt1(xt−1−αˉtβtϵ)
- 如果t>1,需要从N(μθ,σt2I)中采样得到xt−1,利用重参数化技巧,可以将采样过程转换为首先采样z∈N(0,I),然后计算xt−1=μθ+σtz。如果t=1,直接令x0=μθ
总结
Diffusion模型的每一步推导都有严密的数学基础,调整其细节时,必须仔细思考背后的数学基础。如果它火起来,成为生成模型的主流,简直是不给我这种调参侠活路!
参考文献
写本篇博客时,我主要参考了下述论文和博客文章。
相关论文:
- Sohl-Dickstein, J., Weiss, E.A., Maheswaranathan, N., & Ganguli, S. (2015). Deep Unsupervised Learning using Nonequilibrium Thermodynamics. ArXiv, abs/1503.03585.
- Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. ArXiv, abs/2006.11239.
网页链接:
- What are Diffusion Models? | Lil’Log (lilianweng.github.io)
- diffusion_models/Diffusion_models.ipynb at main · InFoCusp/diffusion_models (github.com)
- Ayan Das · An introduction to Diffusion Probabilistic Models