一、概述
- 介绍
概率模型有时既包含观测变量(observed variable),又包含隐变量(latent variable)。当概率模型只包含观测变量时,那么给定观测数据,就可以直接使用极大似然估计法或者贝叶斯估计法进行模型参数的求解。然而如果模型包含隐变量,就不能直接使用这些简单的方法了。EM算法就是用来解决这种含有隐变量的概率模型参数的极大似然参数估计法。这里只讨论极大似然估计,极大后验估计与其类似。
- 算法
EM算法的输入如下:
XX :观测数据
ZZ : 末观测数据 (隐变量)
p(x,z∣θ)p(x, z mid theta) : 联合分布
p(z∣x,θ)p(z mid x, theta) :后验分布
θtheta :parameter
在算法运行开始时需要选择模型的初始化参数 θ(0)theta^{(0)} 。EM算法是一种迭代更新的算法,其计算公式为:
θt+1=argmaxEz∣x,θt[logp(x,z∣θ)]θ=argmaxθ∫zlogp(x,z∣θ)⋅p(z∣x,θt)dzbegin{gathered}
theta^{t+1}=underset{theta}{operatorname{argmax} E_{z mid x, theta^t}[log p(x, z mid theta)]} \
=underset{theta}{operatorname{argmax}} int_z log p(x, z mid theta) cdot pleft(z mid x, theta^tright) mathrm{d} z
end{gathered}
这个公式包含了迭代的两步:
-
①E step: 计算 p(x,z∣θ)p(x, z mid theta) 在概率分布 p(z∣x,θt)pleft(z mid x, theta^tright) 下的期望;
-
②M step: 计算使这个期望最大化的参数得到下一个EM步骤的输入。
总结来说,EM算法包含以下步骤:
- ①选择初始化参数θ(0)theta ^{(0)};
- ②E step;
- ③M step;
- ④重复②③步直至收敛。
二、EM算法的收敛性
现在要证明迭代求得的 θttheta^t 序列会使得对应的 p(x∣θt)pleft(x mid theta^tright) 是单调递增的 (如果 p(x∣θt)pleft(x mid theta^tright) 是单调递 增的,那么训练数据的似然就是单调递增的),也就是说要证明 p(x∣θt)≤p(x∣θt+1)pleft(x mid theta^tright) leq pleft(x mid theta^{t+1}right) 。首先我们有:
logp(x∣θ)=logp(x,z∣θ)−logp(z∣x,θ)log p(x mid theta)=log p(x, z mid theta)-log p(z mid x, theta)
接下来等式两边同时求关于 p(z∣x,θt)pleft(z mid x, theta^tright) 的期望:
左边 =∫zp(z∣x,θt)⋅logp(x∣θ)dz=logp(x∣θ)∫zp(z∣x,θt)dz=logp(x∣θ) 右边 =∫zp(z∣x,θt)⋅logp(x,z∣θ)dz⏟记作 Q(θ,θt)−∫zp(z∣x,θt)⋅logp(z∣x,θ)dz⏟记作 H(θ,θt)begin{gathered}
text { 左边 }=int_z pleft(z mid x, theta^tright) cdot log p(x mid theta) mathrm{d} z \
=log p(x mid theta) int_z pleft(z mid x, theta^tright) mathrm{d} z \
=log p(x mid theta) \
text { 右边 }=underbrace{int_z pleft(z mid x, theta^tright) cdot log p(x, z mid theta) mathrm{d} z}_{text {记作 } Qleft(theta, theta^tright)}-underbrace{int_z pleft(z mid x, theta^tright) cdot log p(z mid x, theta) mathrm{d} z}_{text {记作 } Hleft(theta, theta^tright)}
end{gathered}
因此有:
logp(x∣θ)=∫zp(z∣x,θt)⋅p(x,z∣θ)dz−∫zp(z∣x,θt)⋅logp(z∣x,θ)dzlog p(x mid theta)=int_z pleft(z mid x, theta^tright) cdot p(x, z mid theta) mathrm{d} z-int_z pleft(z mid x, theta^tright) cdot log p(z mid x, theta) mathrm{d} z
这里定义了 Q(θ,θt)Qleft(theta, theta^tright) ,称为 Qmathrm{Q} 函数 ( Qmathrm{Q} function),这个函数也就是上面的概述中迭代公式里 用到的函数,因此满足 Q(θt+1,θt)≥Q(θt,θt)Qleft(theta^{t+1}, theta^tright) geq Qleft(theta^t, theta^tright) 。
接下来将上面的等式两边 θtheta 分别取 θt+1theta^{t+1} 和 θttheta^t 并相减:
logp(x∣θt+1)−logp(x∣θt)=[Q(θt+1,θt)−Q(θt,θt)]−[H(θt+1,θt)−H(θt,θt)]log pleft(x mid theta^{t+1}right)-log pleft(x mid theta^tright)=left[Qleft(theta^{t+1}, theta^tright)-Qleft(theta^t, theta^tright)right]-left[Hleft(theta^{t+1}, theta^tright)-Hleft(theta^t, theta^tright)right]
我们需要证明 logp(x∣θt+1)−logp(x∣θt)≥0log pleft(x mid theta^{t+1}right)-log pleft(x mid theta^tright) geq 0 ,同时已知Q(θt+1,θt)−Q(θt,θt)≥0Qleft(theta^{t+1}, theta^tright)-Qleft(theta^t, theta^tright) geq 0,现在来观察H(θt+1,θt)−H(θt,θt) : Hleft(theta^{t+1}, theta^tright)-Hleft(theta^t, theta^tright) text { : }
H(θt+1,θt)−H(θt,θt)=∫zp(z∣x,θt)⋅log p(z∣x,θt+1)dz−∫zp(z∣x,θt)⋅log p(z∣x,θt)dz=∫zp(z∣x,θt)⋅logp(z∣x,θt+1)p(z∣x,θt)dz≤log∫zp(z∣x,θt)p(z∣x,θt+1)p(z∣x,θt)dz=log∫zp(z∣x,θt+1)dz=log 1=0H(theta ^{t+1},theta ^{t})-H(theta ^{t},theta ^{t})\ =int _{z}p(z|x,theta ^{t})cdot log; p(z|x,theta ^{t+1})mathrm{d}z-int _{z}p(z|x,theta ^{t})cdot log; p(z|x,theta ^{t})mathrm{d}z\ =int _{z}p(z|x,theta ^{t})cdot logfrac{p(z|x,theta ^{t+1})}{p(z|x,theta ^{t})}mathrm{d}z\ leq logint _{z}p(z|x,theta ^{t})frac{p(z|x,theta ^{t+1})}{p(z|x,theta ^{t})}mathrm{d}z\ =logint _{z}p(z|x,theta ^{t+1})mathrm{d}z\ =log; 1\ =0
这里的不等号应用了Jensen不等式:
log∑jλjyj≥∑jλjlog yj,其中λj≥0,∑jλj=1logsum _{j}lambda _{j}y_{j}geq sum _{j}lambda _{j}log; y_{j},其中lambda _{j}geq 0,sum _{j}lambda _{j}=1
也可以使用KL散度来证明 ∫zp(z∣x,θt)⋅logp(z∣x,θt+1)p(z∣x,θt)dz≤0int_z pleft(z mid x, theta^tright) cdot log frac{pleft(z mid x, theta^{t+1}right)}{pleft(z mid x, theta^tright)} mathrm{d} z leq 0 ,两个概率分布 P(x)P(x) 和 Q(x)Q(x) 的KL散度是恒 ≥0geq 0 的,定义为:
DKL(P∥Q)=Ex∼P[logP(x)Q(x)]D_{K L}(P | Q)=E_{x sim P}left[log frac{P(x)}{Q(x)}right]
因此有:
∫zp(z∣x,θt)⋅logp(z∣x,θt+1)p(z∣x,θt)dz=−KL(p(z∣x,θt)∣∣p(z∣x,θt+1))≤0int_z pleft(z mid x, theta^tright) cdot log frac{pleft(z mid x, theta^{t+1}right)}{pleft(z mid x, theta^tright)} mathrm{d} z=-K Lleft(pleft(z mid x, theta^tright)|| pleft(z mid x, theta^{t+1}right)right) leq 0
因此得证 logp(x∣θt+1)−logp(x∣θt)≥0log pleft(x mid theta^{t+1}right)-log pleft(x mid theta^tright) geq 0 。这说明使用EM算法迭代更新参数可以使得 logp(x∣θ)log p(x mid theta) 逐步增大。
另外还有其他定理保证了EM的算法收敛性。首先对于 θi(i=1,2,⋯ )theta^i(i=1,2, cdots) 序列和其对应的对数似然序列 L(θt)=logp(x∣θt)(t=1,2,⋯ )Lleft(theta^tright)=log pleft(x mid theta^tright)(t=1,2, cdots) 有如下定理:
-
①如果 p(x∣θ)p(x mid theta) 有上界,则 L(θt)=logp(x∣θt)Lleft(theta^tright)=log pleft(x mid theta^tright) 收敛到某一值 L∗L^* ;
-
②在函数 Q(θ,θ′)Qleft(theta, theta^{prime}right) 与 L(θ)L(theta) 满足一定条件下,由EM算法得到的参数估计序列 θttheta^t 的收敛值 θ∗theta^* 是 L(θ)L(theta) 的稳定点。
三、EM算法的导出
- ELBO+KL散度的方法
对于前面用过的式子,首先引入一个新的概率分布q(z)q(z):
log p(x∣θ)=log p(x,z∣θ)−log p(z∣x,θ)=log p(x,z∣θ)q(z)−log p(z∣x,θ)q(z) q(z)≠0log; p(x|theta )=log; p(x,z|theta )-log; p(z|x,theta )\ =log; frac {p(x,z|theta )}{q(z)}-log; frac{p(z|x,theta )}{q(z)}; ; q(z)neq 0
以上引入一个关于zz的概率分布q(z)q(z),然后式子两边同时求对q(z)q(z)的期望:
左边=∫zq(z)⋅log p(x∣θ)dz=log p(x∣θ)∫zq(z)dz=log p(x∣θ)右边=∫zq(z)log p(x,z∣θ)q(z)dz⏟ELBO(evidence lower bound)−∫zq(z)log p(z∣x,θ)q(z)dz⏟KL(q(z)∣∣p(z∣x,θ))左边=int _{z}q(z)cdot log; p(x|theta )mathrm{d}z=log; p(x|theta )int _{z}q(z)mathrm{d}z=log; p(x|theta )\ 右边=underset{ELBO(evidence; lower; bound)}{underbrace{int _{z}q(z)log; frac{p(x,z|theta )}{q(z)}mathrm{d}z}}underset{KL(q(z)||p(z|x,theta ))}{underbrace{-int _{z}q(z)log; frac{p(z|x,theta )}{q(z)}mathrm{d}z}}
因此我们得出 logp(x∣θ)=ELBO+KL(q∥p)log p(x mid theta)=E L B O+K L(q | p) ,由于KL散度恒 ≥0geq 0 ,因此logp(x∣θ)≥ELBOlog p(x mid theta) geq E L B O ,则 ELBOE L B O 就是似然函数 logp(x∣θ)log p(x mid theta) 的下界。使得logp(x∣θ)=ELBOlog p(x mid theta)=E L B O 时,就必须有 KL(q∥p)=0K L(q | p)=0 ,也就是 q(z)=p(z∣x,θ)q(z)=p(z mid x, theta) 时。在
每次迭代中我们取 q(z)=p(z∣x,θt)q(z)=pleft(z mid x, theta^tright) ,就可以保证 logp(x∣θt)log pleft(x mid theta^tright) 与 ELBOE L B O 相等,也就是:
log p(x∣θ)=∫zp(z∣x,θt)log p(x,z∣θ)p(z∣x,θt)dz⏟ELBO−∫zp(z∣x,θt)log p(z∣x,θ)p(z∣x,θt)dz⏟KL(p(z∣x,θt)∣∣p(z∣x,θ))log; p(x|theta )=underset{ELBO}{underbrace{int _{z}p(z|x,theta ^{t})log; frac {p(x,z|theta )}{p(z|x,theta ^{t})}mathrm{d}z}}underset{KL(p(z|x,theta ^{t})||p(z|x,theta ))}{underbrace{-int _{z}p(z|x,theta ^{t})log; frac{p(z|x,theta )}{p(z|x,theta ^{t})}mathrm{d}z}}
当 θ=θttheta=theta^t 时, logp(x∣θt)log pleft(x mid theta^tright) 取ELBO,即:
log p(x∣θt)=∫zp(z∣x,θt)log p(x,z∣θt)p(z∣x,θt)dz⏟ELBO−∫zp(z∣x,θt)log p(z∣x,θt)p(z∣x,θt)dz⏟=0=ELBOlog; p(x|theta ^{t})=underset{ELBO}{underbrace{int _{z}p(z|x,theta ^{t})log; frac{p(x,z|theta ^{t})}{p(z|x,theta ^{t})}mathrm{d}z}}underset{=0}{underbrace{-int _{z}p(z|x,theta ^{t})log; frac{p(z|x,theta ^{t})}{p(z|x,theta ^{t})}mathrm{d}z}}=ELBO
也就是说 logp(x∣θ)log p(x mid theta) 与 ELBOE L B O 都是关于 θtheta 的函数,且满足 logp(x∣θ)≥ELBOlog p(x mid theta) geq E L B O ,也就 是说 logp(x∣θ)log p(x mid theta) 的图像总是在 ELBOE L B O 的图像的上面。
对于 q(z)q(z) ,我们取q(z)=p(z∣x,θt)q(z)=pleft(z mid x, theta^tright) ,这也就保证了只有在 θ=θttheta=theta^t 时 logp(x∣θ)log p(x mid theta) 与 ELBOE L B O 才会相等,因 此使 ELBOE L B O 取极大值的 θt+1theta^{t+1} 一定能使得 logp(x∣θt+1)≥logp(x∣θt)log pleft(x mid theta^{t+1}right) geq log pleft(x mid theta^tright) 。该过程如下图 所示:
然后我们观察一下ELBOELBO取极大值的过程:
θt+1=argmaxθELBO=argmaxθ∫zp(z∣x,θt)log p(x,z∣θ)p(z∣x,θt)dz=argmaxθ∫zp(z∣x,θt)log p(x,z∣θ)dz−argmaxθ∫zp(z∣x,θt)p(z∣x,θt)dz⏟与θ无关=argmaxθ∫zp(z∣x,θt)log p(x,z∣θ)dz=argmaxθEz∣x,θt[log p(x,z∣θ)]theta ^{t+1}=underset{theta }{argmax}ELBO \ =underset{theta }{argmax}int _{z}p(z|x,theta ^{t})log; frac{p(x,z|theta )}{p(z|x,theta ^{t})}mathrm{d}z\ =underset{theta }{argmax}int _{z}p(z|x,theta ^{t})log; p(x,z|theta )mathrm{d}z-underset{与theta 无关}{underbrace{underset{theta }{argmax}int _{z}p(z|x,theta ^{t})p(z|x,theta ^{t})mathrm{d}z}}\ {color{Red}{=underset{theta }{argmax}int _{z}p(z|x,theta ^{t})log; p(x,z|theta )mathrm{d}z}} \ {color{Red}{=underset{theta }{argmax}E_{z|x,theta ^{t}}[log; p(x,z|theta )]}}
由此我们就导出了EM算法的迭代公式。
- ELBO+Jensen不等式的方法
首先要具体介绍一下Jensen不等式:对于一个凹函数 f(x)f(x)(国内外对凹凸函数的定义恰好相反,这里的凹函数指的是国外定义的凹函数),我们查看其图像如下:
t∈[0,1]c=ta+(1−t)bϕ=tf(a)+(1−t)f(b)tin [0,1]\ c=ta+(1-t)b\ phi =tf(a)+(1-t)f(b)
凹函数恒有 f(c)≥ϕ\mathrm ,也就是 f(ta+(1−t)b)≥tf(a)+(1−t)f(b)f(c) geq phi \mathrm{~ , 也 就 是 ~} f(t a+(1-t) b) geq t f(a)+(1-t) f(b) ,当 t=12t=frac{1}{2} 时有 f(a2+b2)≥f(a)2+f(b)2fleft(frac{a}{2}+frac{b}{2}right) geq frac{f(a)}{2}+frac{f(b)}{2} ,可以理解为对于凹函数来说 先求期望再求函数值 恒 ≥geq 先求函数值再求期望,即 f(E)≥E[f]f(E) geq E[f] 。
上面的说明只是对Jensen不等式的一个形象的描述,而非严谨的证明。接下来应用Jensen不等式来导出EM算法:
logp(x∣θ)=log∫zp(x,z∣θ)dz=log∫zp(x,z∣θ)q(z)⋅q(z)dz=logEq(z)[p(x,z∣θ)q(z)]≥Eq(z)[logp(x,z∣θ)q(z)]⏟ELBObegin{gathered}
log p(x mid theta)=log int_z p(x, z mid theta) mathrm{d} z \
=log int_z frac{p(x, z mid theta)}{q(z)} cdot q(z) mathrm{d} z \
=log E_{q(z)}left[frac{p(x, z mid theta)}{q(z)}right] \
geq underbrace{E_{q(z)}left[log frac{p(x, z mid theta)}{q(z)}right]}_{E L B O}
end{gathered}
这里应用了Jensen不等式得到了上面出现过的 ELBOE L B O ,这里的 f(x)f(x) 函数也就是 loglog 函数, 显然这是一个凹函数。当 logP(x,z∣θ)q(z)log frac{P(x, z mid theta)}{q(z)} 这个函数是一个常数时会取得等号,利用这一点我们 也同样可以得到 q(z)=p(z∣x,θ)q(z)=p(z mid x, theta) 时能够使得 logp(x∣θ)=ELBOlog p(x mid theta)=E L B O 的结论:
p(x,z∣θ)q(z)=C⇒q(z)=p(x,z∣θ)C⇒∫zq(z)dz=∫z1Cp(x,z∣θ)dz⇒1=1C∫zp(x,z∣θ)dz⇒C=p(x∣θ)将C代入q(z)=p(x,z∣θ)C得q(z)=p(x,z∣θ)p(x∣θ)=p(z∣x,θ)frac{p(x,z|theta )}{q(z)}=C\ Rightarrow q(z)=frac{p(x,z|theta )}{C}\ Rightarrow int _{z}q(z)mathrm{d}z=int _{z}frac{1}{C}p(x,z|theta )mathrm{d}z\ Rightarrow 1=frac{1}{C}int _{z}p(x,z|theta )mathrm{d}z\ Rightarrow C=p(x|theta )\ 将C代入q(z)=frac{p(x,z|theta )}{C}得\ {color{Red}{q(z)=frac{p(x,z|theta )}{p(x|theta )}=p(z|x,theta )}}
这种方法到这里就和上面的方法一样了,总结来说就是:
log p(x∣θ)≥Eq(z)[logp(x,z∣θ)q(z)]⏟ELBOlog; p(x|theta )geq underset{ELBO}{underbrace{E_{q(z)}[logfrac{p(x,z|theta )}{q(z)}]}}
上面的不等式在q(z)=p(z∣x∣θ)q(z)=p(z|x|theta )时取等号,因此在迭代更新过程中取q(z)=p(z∣x,θt)q(z)=p(z|x,theta ^{t})接下来的推导过程就和第1种方法一样了。
四、广义EM算法
上面介绍的EM算法属于狭义的EM算法,它是广义EM的一个特例。在上面介绍的EM算法的E步中我们假定q(z)=p(z∣x,θt)q(z)=p(z|x,theta ^{t}),但是如果这个后验p(z∣x,θt)p(z|x,theta ^{t})无法求解,那么必须使⽤采样(MCMC)或者变分推断等⽅法来近似推断这个后验。前面我们得出了以下关系:
logp(x∣θ)=∫zq(z)logp(x,z∣θ)q(z)dz−∫zq(z)logp(z∣x,θ)q(z)dz=ELBO+KL(q∥p)log p(x mid theta)=int_z q(z) log frac{p(x, z mid theta)}{q(z)} mathrm{d} z-int_z q(z) log frac{p(z mid x, theta)}{q(z)} mathrm{d} z=E L B O+K L(q | p)
当我们对于固定的 θtheta ,我们希望 KL(q∥p)K L(q | p) 越小越好,这样就能使得 ELBOE L B O 更大:
固定θ,q^=argminqKL(q∥p)=argmaxqELBO固定 theta, hat{q}=underset{q}{operatorname{argmin}} K L(q | p)=underset{q}{operatorname{argmax}} E L B O
ELBOE L B O 是关于 qq 和 θtheta 的函数,写作 L(q,θ)L(q, theta) 。以下是广义EM算法的基本思路:
E step: qt+1=argmaxL(q,θt)q^{t+1}=operatorname{argmax} Lleft(q, theta^tright)
M step: θt+1=argmaxqL(qt+1,θ)theta^{t+1}=underset{q}{operatorname{argmax}} Lleft(q^{t+1}, thetaright)
再次观察一下 ELBOE L B O :
ELBO=L(q,θ)=Eq[log p(x,z)−log q]=Eq[log p(x,z)]−Eq[log q]⏟H[q]ELBO=L(q,theta )=E_{q}[log; p(x,z)-log; q]\ =E_{q}[log; p(x,z)]underset{H[q]}{underbrace{-E_{q}[log; q]}}
因此,我们看到,⼴义 EM 相当于在原来的式⼦中加⼊熵H[q]H[q]这⼀项。
五、EM的变种
EM 算法类似于坐标上升法,固定部分坐标,优化其他坐标,再⼀遍⼀遍的迭代。如果在 EM 框架中,⽆法求解zz后验概率,那么需要采⽤⼀些变种的 EM 来估算这个后验:
①基于平均场的变分推断,VBEM/VEM
②基于蒙特卡洛的EM,MCEM
“开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第 8 天,点击查看活动详情”