はじめに

Variational Autoencoder(VAE)と呼ばれるモデルが提案されたのは 2014 年の Auto-Encoding Variational Bayes です。ただ VAE の何がすごいのか、Stable Diffusion などの前処理でもいまだ VAE が使用されていますが、何が凄いのか、よく分からなかったので本稿でその背景に迫ってみたいと思います。

Autoencoder との比較記事などがあり「VAEは AutoencoderのSOTA モデルとして登場したのか?何が凄いのか?」と混乱したりしていました。 Autoencoder に非線形変換を加えたモデルが VAE です。そのため複雑な構造をもつデータのモデル化が可能になります。

本記事で EM アルゴリズムから出発して、VAE までを抑えてみたいと思います。

最尤推定とベイズ推定との違い

最尤推定とベイズ推定はどちらもモデルのパラメーター θ\theta に関する推定を行うための手法ですが、パラメータの取り扱いに違いがあります。

最尤推定

尤度関数 p(θD)p(\theta|D) を最大化する θ\theta を予測値として扱い、θ\theta の値が一意に決まる推定方法です。

ベイズ推定

パラメータに事前分布 p(θ)p(\theta) を与えて確率変数として取り扱い、データ DD を観測したときの事後確率 p(θD)p(\theta|D) を推定します。事後確率はベイズの定理から求めることができます1

p(θD)=p(Dθ)p(θ)p(D) p(\theta|D) = \frac{p(D|\theta)p(\theta)}{p(D)}

潜在変数をもつモデル

さて(唐突に)潜在変数を持つモデルのパラメータを推定する問題を、混合ガウス分布を題材にして考えてみたいと思います2

確率分布を線形結合して作る確率分布を混合分布(mixture distribution)と呼びますが、このときにガウス分布を使用したモデルが混合ガウス分布です。KK 個のガウス分布の重ね合わせ

p(x)=k=1KπkN(xμk,Σk) p(x) = \sum_{k=1}^K \pi_k N(\rm{x}|\mu_k, \Sigma_k)

で表されます。KK 個のガウス分布はそれぞれ個別に平均と共分散のパラメータを持っており、十分な数のガウス分布を用いることでほぼ任意の連続な密度関数を任意の精度で近似することができます。

問題

混合ガウス分布が有用であるということは(天下り的に)分かったのですが、ではそれらのパラメータをどのように求めれば良いのでしょうか?{x1,,xN}\{x_1,…,x_N\} のデータセットが観測されているとき、対数尤度関数は

lnp(Xπ,μ,Σ)=n=1Nln{k=1KπkN(xnμk,Σk)} \ln p(X|\pi, \mu, \Sigma) = \sum_{n=1}^N \ln \left\{\sum_{k=1}^K \pi_k N(x_n|\mu_k, \Sigma_k) \right\}

となるので、これを最適化すれば最尤推定で解けそうに思えます。

しかし、この対数尤度関数は実は解析的には計算できなくなってしまっています。対数の中に和 Σ\Sigma が入っているからだなどと説明されたりしますが、「まぁそんなもんか」と思って頂く(少なくとも自分はそうです)方が心穏やかに過ごせると思います。

ではどのようにするかですが、混合ガウス分布に潜在変数を導入して EM アルゴリズムと呼ばれる逐次的に最適化する枠組みを使用することでパラメータ推定が可能となります。

EM アルゴリズム

EMアルゴリズムを題材として、最尤推定とベイズ推定の違いについて更に深ぼって見ていきたいと思います。根本の違いはそれぞれの推定方法の差異であるモデルパラメータの取り扱いに起因します。

最尤法的な取り扱い (通常のEMアルゴリズム)

観測データ XX が(観測できない)潜在変数 ZZ から生成されているような状況を取り扱います。straight forward な推論方法であれば、対数尤度を最大化する θ\theta を求めることになります。

arg maxθ lnp(Xθ) \argmax_\theta~\ln p(X|\theta)

さきほど導入した通り lnp(X)\ln p(X) は直接計算できないため、以下の様に q(Z)q(Z) という確率分布を導入し対数尤度を変形します。

lnp(Xθ)=lnp(X,Zθ)p(ZX,θ)=ln{p(X,Zθ)p(ZX,θ)q(Z)q(Z)}=ln{p(X,Zθ)q(Z)q(Z)p(ZX,θ)}=lnp(X,Zθ)q(Z)lnp(ZX,θ)q(Z) \begin{aligned} \ln p(\rm{X}|\theta) &= \ln \frac{p(X, Z|\theta)}{p(Z|X, \theta)} \\ &= \ln \left\{ \frac{p(\rm{X}, \rm{Z}|\theta)}{p(\rm{Z}|\rm{X}, \theta)}\frac{q(Z)}{q(Z)} \right\}\\ &= \ln \left\{ \frac{p(\rm{X}, \rm{Z}|\theta)}{q(Z)}\frac{q(Z)}{p(\rm{Z}|\rm{X}, \theta)} \right\}\\ &= \ln \frac{p(\rm{X}, \rm{Z}|\theta)}{q(Z)} - \ln \frac{p(\rm{Z}|\rm{X}, \theta)}{q(Z)} \\ \end{aligned}

1行目では確率の乗法定理を使用しています。両辺に q(Z)q(Z) を掛けて潜在変数に関する和を取ります。Zq(Z)=1\sum_Z q(Z)=1 であることを用いると

Zq(Z)lnp(Xθ)=Zq(Z)lnp(X,Zθ)q(Z)Zq(Z)lnp(ZX,θ)q(Z)lnp(Xθ)=Zq(Z)lnp(X,Zθ)q(Z)Zq(Z)lnp(ZX,θ)q(Z) \begin{aligned} \sum_{{\rm{Z}}}q({\rm{Z}}) \ln p(X|\theta) &= \sum_{{\rm{Z}}} q({\rm{Z}}) \ln \frac{p({\rm{X}}, \rm{Z}|\theta)}{q(Z)} - \sum_{\rm{Z} q(\rm{Z})} \ln \frac{p(\rm{Z}|\rm{X}, \theta)}{q(Z)} \\ \therefore \ln p(X|\theta) &= \sum_{\rm{Z}} q({\rm{Z}}) \ln \frac{p(\rm{X}, \rm{Z}|\theta)}{q(Z)} - \sum_{\rm{Z}} q({\rm{Z}}) \ln \frac{p(\rm{Z}|\rm{X}, \theta)}{q(Z)} \end{aligned}

と変形することができます。

L(q(Z),θ)=Zq(Z)lnp(X,Zθ)q(Z)DKL(qp)=Zq(Z)lnp(ZX,θ)q(Z) \begin{aligned} L(q(Z), \theta) &= \sum_{\rm{Z}} q({\rm{Z}}) \ln \frac{p(\rm{X}, \rm{Z}|\theta)}{q(Z)} \\ D_{KL}(q||p) &= - \sum_{\rm{Z}} q({\rm{Z}}) \ln \frac{p(\rm{Z}|\rm{X}, \theta)}{q(Z)} \end{aligned}

と置いて、対数尤度を以下のように表現しておきます。

lnp(Xθ)=L(q(Z),θ)+DKL(qp) \ln p(X|\theta) = L(q(Z), \theta) + D_{KL} (q||p)

ここで DKLD_{KL} は KL ダイバージェンスであり、q(Z)q(Z)p(ZX,θ)p(Z|X, \theta) の「距離」に関する指標です。DKL0D_{KL} \geq 0 であることを踏まえると、

lnp(Xθ)=L(q(Z),θ)+KL(qp)L(q(Z),θ) \ln p(X|\theta) = L(q(Z), \theta) + {\rm{KL}} (q||p) \geq L(q(Z), \theta)

という関係式を導くことができ、対数尤度の下界が L(q(Z),θ)L(q(Z), \theta) で表現されることが分かりました。 では LL の計算は簡単なのかというと、q(Z)q(Z)θ\theta という変数を扱う必要があり、特に q(Z)q(Z) に関しては確率分布であるということ以外、その関数の形や特性が全くの不明であるものです。そのため LL も直接的に計算はできないのですが、そこで登場するのが EM アルゴリズムです。

イェンセンの不等式を使うパターン

以上の式変形はイェンセンの不等式を用いることで、より直接的に下界を求めることができます。

lnp(Xθ)=lnZp(X,Zθ)=lnZq(Z)p(X,Zθ)q(Z)Zq(Z)lnp(X,Zθ)q(Z)L(q(Z),θ) \begin{aligned} \ln p(X|\theta) &= \ln \sum_Z p(X,Z|\theta) \\ &= \ln \sum_Z q(Z)\frac{p(X,Z|\theta)}{q(Z)} \\ &\geq \sum_Z q(Z) \ln \frac{p(X,Z|\theta)}{q(Z)} \equiv L(q(Z), \theta)\\ \end{aligned}

以上より対数尤度の下界が求まったので、LL を最大化するようなパラメータ θ\theta を求めればよいことになります。ただしイェンセンの不等式を用いた議論でも、結局はKLダイバージェンスを用いた議論が必要になるので先程求めた関係式に戻ることになります。

パラメータ推定方法

対数尤度 lnp(Xθ)\ln p(X|\theta) を最大化する代わりに、LL を最大化するようなパラメータ θ\theta を求める方針をとれば良いことが分かりました。 しかしここで今一度、LL についての表式を確認してみると、

L(q(Z),θ)=Zq(Z)lnp(X,Zθ)q(Z) L(q(Z), \theta) = \sum_{Z} q(Z) \ln \frac{p(X, Z|\theta)}{q(Z)}

LL はパラメータ θ\theta の関数であると同時に、q(Z)q(Z) の汎関数であることが分かります。つまり LL を最大化するには θ\thetaq(Z)q(Z) の組み合わせを見つける必要があるのですが、q(Z)q(Z) は関数であるので変分法などを用いる必要が出てきます。もう少し簡単に解くために、EMアルゴリズムと呼ばれる枠組みを導入します。

EMアルゴリズムでは θ\thetaq(Z)q(Z) をそれぞれ逐次的に更新することで対数尤度の最大化を行います。徐々に LL を大きくしていくようなイメージです。

L(q0,θ0)L(q1,θ1)..argmaxq, θ L(q,θ)lnp(Xθ) L(q_0, \theta_0) \leq L(q_1, \theta_1) \leq .. \leq \underset{q,~\theta}{\operatorname{argmax}}~L(q, \theta) \leq \ln p(X|\theta)

具体的には以下の二段階の逐次的最適化手法を採ることで、最尤推定値を見つけます。各ステップは以下のように:

それぞれ qqθ\theta に対して L(q,θ)L(q, \theta) を最大化していきます。

Eステップ

まず

を考えます。単純な発想だと汎関数 L(q,θ)L(q, \theta) の変分問題を解く必要がありそうなのですが、対数尤度とKLダイバージェンスとの関係式を用いることでその煩雑さを回避することができるというのが肝です。

いま θ\theta を固定して、q(Z)q(Z) に関して L(q(Z),θ)L(q(Z),\theta) を最大化します。対数尤度関数の関係式に立ち戻ると3

lnp(Xθold)=L(q,θold)+DKL(qp)L(q,θold)=DKL(qp)+lnp(Xθold) \begin{aligned} \ln p(X|\theta^{old}) &= L(q, \theta^{old}) + D_{KL} (q||p) \\ \Leftrightarrow L(q, \theta^{old}) &= - D_{KL} (q||p) + \ln p(X|\theta^{old}) \end{aligned}

いま q(Z)q(Z) のみが変数であるので

L(q,θold)=DKL(qp)+const. L(q, \theta^{old}) = - D_{KL} (q||p) + const.

について考えることになります。KLダイバージェンスが 0\geq 0 の量であることを考えると、LL

DKL(qp)=0 D_{KL}(q||p) = 0

を満たす場合に最大になり、

q(Z)=p(ZX,θold) q(Z) = p(Z|X, \theta^{old})

であることが分かります。

Mステップ

次に

について考えます。qq を先ほど求めた値に固定して θ\theta に関して LL を最大化します。

L(q,θ)=Zq(Z)lnp(X,Zθ)q(Z)=Zp(ZX,θold)lnp(X,Zθ)p(ZX,θold)=Zp(ZX,θold)lnp(X,Zθ)Zp(ZX,θold)lnp(ZX,θold) \begin{aligned} L(q, \theta) &= \sum_{Z} q(Z) \ln \frac{p(X, Z|\theta)}{q(Z)} \\ &= \sum_{Z} p(Z|X, \theta^{old}) \ln \frac{p(X, Z|\theta)}{p(Z|X, \theta^{old})} \\ &= \sum_{Z} p(Z|X, \theta^{old}) \ln p(X, Z|\theta) - \sum_{Z} p(Z|X, \theta^{old}) \ln p(Z|X, \theta^{old}) \\ \end{aligned}

最後に

以上の E ステップと M ステップとを交互に繰り返すことで最適化問題を解くというものが EM アルゴリズムです。EM アルゴリズムの計算結果としてパラメータ θ\theta が決定論的に求まるということは改めて抑えておきましょう。

変分推論 (変分ベイズ)

ここまでで潜在変数を持ったモデルの最尤推定に関する手法を議論してきました。さて次に、潜在変数を持ったモデルに関するベイズ的な取り扱いについて議論してきます。この議論こそが変分ベイズの導入部分です。

全てのパラメータに対して事前分布(prior distrubution)が与えられている完全なベイズモデルを考えます。決定論的なパラメータ θ\theta や潜在変数 ZZ はまとめて ZZ として表記します。観測データ XX に対して潜在変数などのパラメータ ZZ を持つ確率モデルを扱う際の主たる目標は

p(ZX) p(Z|X)

で表される事後確率(posterior distribution)を求めることです。

対数尤度への再訪

もちろん直接的に事後確率 p(ZX)p(Z|X) が求まるのであればそれで終わりなのですが、 現実的には潜在変数が高次元であることや閉形式の解がなかったりなどと、一般的に事後確率を直接評価することは困難です。そのため近似手法を用いて推論するという流れになっていきます。(1)決定論的な近似手法と、(2)確率的な(サンプリングを用いた)近似手法の2パターンが存在します。変分ベイズは手法1に相当し、マルコフ連鎖モンテカルロ法などは手法2に相当します4

対数尤度はこれまでと同様に

lnp(X)=L(q)+DKL(qp)L(q) \ln p(X) = L(q) + D_{KL} (q || p) \geq L(q)

として表され、各項は

L(q)=q(Z)ln{p(X,Z)q(Z)}dZDKL(qp)=q(Z)ln{q(Z)p(ZX)}dZ \begin{aligned} L(q) &= \int q(Z) \ln \left\{ \frac{p(X,Z)}{q(Z)} \right\} dZ \\ D_{KL}(q||p) &= \int q(Z) \ln \left\{ \frac{q(Z)}{p(Z|X)} \right \} dZ \\ \end{aligned}

です。(通常の)EMアルゴリズムと同様の手法を採ると、下界である L(q)L(q)qq に対して最大化するにはKLダイバージェンスを最小化すればよく、

q(Z)=p(ZX) q(Z) = p(Z|X)

の場合であることが分かります。事後確率 p(ZX)p(Z|X) が計算困難であるという前提に立っているため、この解を使用することはできません。そのため、シンプルな変分問題

arg maxq(Z)L(q(Z)) \argmax_{q(Z)} L(q(Z))

を解く必要が生じます。

変分法について

そもそも変分法とはベイズ推論とは独立した一つのトピックで、汎関数の最適化問題を解くための手法です。ここまでで潜在変数を含んだモデルの対数尤度 lnp(X)\ln p(X) は、その下界 L(q)L(q) を最大化することでパラメータの推論を行うことを見てきました。このときに下界は q(Z)q(Z) に関する汎関数であり、そのため変分推論と呼ばれています。

余談

EM アルゴリズム周りを勉強しているときに何をしているか良く分からなくなる理由の一つが、lnp(X)\ln p(X) の式変形の導出過程かなと思っています。書籍によって

  1. 先に lnp(X)=L(q)+KL(qp)\ln p(X) = L(q) + {\rm{KL}} (q || p) の変形をして、L(q)L(q) の最大化の条件を議論する
  2. 先に p(ZX)q(Z)p(Z|X) \simeq q(Z) としたい気持ちを表明し、実は lnp(X)=L(q)+KL(qp)\ln p(X) = L(q) + {\rm{KL}} (q || p) ですと説明する

という方針に大きく別れていて、議論の運びがまちまちであると感じています^[方針1 の議論では、いきなり q(Z)q(Z) が天下り的にでてきて「???」となっているうちに、KL ダイバージェンスが出てきます。なぜKLダイバージェンスをわざわざ作ったのかが見えにくいです。また方針2の議論では、近似のために q(Z)q(Z) を導入すると思っていたら、何かよくわからない L(q)L(q) が出てきたと感じます。どちらも誤ってはいなくニワトリ卵のような議論でもあると感じるので、]。改めて以下の流れで、ELBO について整理しておきます。

  1. lnpL(q)\ln p \geq L(q) となるような下限で議論したい
  2. 事後分布の直接的な議論は難しいので、p(ZX)q(Z)p(Z|X) \simeq q(Z) となる近似分布を導入する 2-1. KL(q(Z)p(XZ))\rightarrow {\rm{KL}}(q(Z)||p(X|Z)) が指標になる
  3. lnp(X)\ln p(X)KL(q(Z)p(XZ)){\rm{KL}}(q(Z)||p(X|Z)) を含むように変形してみる 3-1. lnp(X)=F(q)+KL(qp)\ln p(X) = F(q) + {\rm{KL}} (q || p) の変形 3-2. KL ダイバージェンスの 0\geq 0 の性質から、F(q)F(q) は lower bound L(q)L(q) として解釈できる

平均場近似 (mean field theory)

事後分布 p(ZX)p(Z|X) の近似分布である q(Z)q(Z) の作り方は様々あるようですが、

q(Z)=i=1Mqi(Zi) q(Z) = \prod_{i=1}^M q_i(Z_i)

と因数分解した形式を用いる手法を平均場近似と呼びよく用いられている手法の一つです。分解できるという仮定以外は特別な条件は何も課してません。

さてここで、L(q)L(q) を各 qj(Zj)q_j(Z_j) について最適化する問題を考えます。

L(q)=qlnp(X,Z)qdZ=iqi{lnp(X,Z)lniqi}dZ={iqilnp(X,Z)iqiilnqi}dZ \begin{aligned} L(q) &= \int q \ln \frac{p(X, Z)}{q} dZ \\ &= \int \prod_i q_i \left\{ \ln p(X, Z) - \ln \prod_i q_i \right\} dZ \\ &= \int \left\{ \prod_i q_i \ln p(X, Z) - \prod_i q_i \sum_i \ln q_i \right\} dZ \\ \end{aligned}

いま

q=iqi=q1(Z1)q2(Z1)qj(Zj) q = \prod_i q_i = q_1(Z_1)q_2(Z_1)…q_j(Z_j)…

のように分解しているうちで、L(q)L(q)qjq_j に関して最適化しようとしているので qj(Zj)q_j(Z_j) に依存する項のみを抽出します。第一項目は

{iqilnp(X,Z)}dZ={iqilnp(X,Z)}dZ1dZ2=qj{ijqilnp(X,Z)dZi}dZj \begin{aligned} \int \left\{ \prod_i q_i \ln p(X, Z) \right\} dZ &= \int \int … \int \left\{ \prod_i q_i \ln p(X, Z) \right\} dZ_1dZ_2… \\ &= \int q_j \left\{ \prod_{i\neq j} \int q_i \ln p(X, Z) dZ_i \right\} dZ_j \end{aligned}

第二項目は

q(Z)dZ=1 \begin{aligned} \int q(Z) dZ = 1 \end{aligned}

であることを踏まえると、対数の中に qjq_j が残らない項は qjdZj=1\int q_j dZ_j = 1 として周辺化されて定数となるので、

{iqiilnqi}dZ=qjdZj{ijqiijlnqi}dZijijqidZijqjlnqjdZj=qjlnqjdZj+const. \begin{aligned} \int \left\{ \prod_i q_i \sum_i \ln q_i \right\} dZ &= - \int q_j dZ_j \left\{ \prod_{i\neq j} q_i \sum_{i\neq j} \ln q_i \right\} dZ_{i\neq j} \\ &- \int \prod_{i\neq j} q_i dZ_{i\neq j} \int q_j \ln q_j dZ_j \\ &= - \int q_j \ln q_j dZ_j + const. \end{aligned}

以上のように lnqj\ln q_j を含む項だけが残ります。従って

L(q)=qj{ijqilnp(X,Z)dZi}dZjqjlnqjdZj+const. \begin{aligned} L(q) = \int q_j \left\{ \prod_{i\neq j} \int q_i \ln p(X, Z) dZ_i \right\} dZ_j - \int q_j \ln q_j dZ_j + const. \end{aligned}

と変形することができます。また、

ijqilnp(X,Z)dZi=lnp~(X,Zj) \int \prod_{i\neq j} \int q_i \ln p(X, Z) dZ_i = \ln \tilde{p}(X, Z_j)

と置いて、

L(q)=qjqjlnp~(X,Zj)dZj+const. \begin{aligned} L(q) = - \int q_j \frac{q_j}{\ln \tilde{p}(X, Z_j)} dZ_j + const. \end{aligned}

と変形することができます。今、 qijq_{i\neq j} を固定したまま L(q)L(q) を最大化をしたいわけなのですが、求めた L(q)L(q) の式を眺めると、(負の)KLダイバージェンスの形になっていることが分かります。そのため L(q)L(q) を最大化するためには該当の KLダイバージェンスを最小化すれば良く^[負符号がついていて、かつ KL ダイバージェンスが 0\geq 0 であることを使っています]、

q(Zj)=lnp~(X,Zj) q(Z_j) = \ln\tilde{p}(X, Z_j)

であることが分かります。

まとめ

VAE までを理解する第一回目の記事として、EMアルゴリズムから変分ベイズの冒頭までを追ってみました。次回以降でさらに変分ベイズを深堀りして、VAE の理解につなげていきたいと思います。


  1. 現実的にはただ単に公式に当てはめて計算することは不可能なので、ベイズ推論を計算機で実現するためには様々な近似手法が考案されています。そのため「ベイズ推論よくわからない…」と思いがちですが、よく分からないのはその近似手法であって、ベイズ推論自体はシンプルな発想です。 ↩︎

  2. 潜在変数を持つモデルの尤度関数を計算するためにEMアルゴリズムが導入され、また完全なベイズ的取り扱いをしたときに変分推論が導入される、という流れになっています。 ↩︎

  3. EMアルゴリズムが初見で「??」となってしまう原因の一つがここにあるかと思っています。最大化しようとしている L(q,θ)L(q,\theta) が本当に最大化したい lnp(Xθ)\ln p(X|\theta) に依存しているように見えて、「何やってるんだこれ?」と感じたりします。 ↩︎

  4. ここでの「決定論的」とは、数式で関係式を記述できるといった程度の意味合いかと思います。 ↩︎

#ベイズ推論 #VAE