機械学習の理論 PR

【Python】Annealed Importance Sampling(焼きなまし重点サンプリング)の解説と実装

記事内に商品プロモーションを含む場合があります

 

本記事では、Annealed Importance Sampling(焼きなまし重点サンプリング)の解説と実装を紹介します。

 

Annealed Importance Sampling(焼きなまし重点サンプリング)

 

Importance Sampling(重点サンプリング)は、目的の確率分布\(p(x)\)とは別にサンプリングが容易な提案分布\(q(x)\)を用いて、確率分布\(p(x)\)の期待値や規格化定数を推定する方法です。

もし、Importance Sampling(重点サンプリング)についてご存知でない方は先に下記の記事を一読してから本記事を読んでいただけると幸いです。

【Python】重点サンプリング(Importance Sampling)の理論と実装重点サンプリングの理論とPythonによる実装を紹介しました。重点サンプリングを用いて規格化定数を評価する方法も紹介しています。ぜひ、参考にしてみてください。...

重点サンプリングでは、提案分布\(q(x)\)が目的の確率分布\(p(x)\)と似たような分布であることが重要であり、似ていない分布を用いると推定値の分散が大きくなり、推定精度が悪化します。

しかし、サンプリングが用意で確率分布\(p(x)\)と似たような分布\(q(x)\)を用意することは容易ではありません。

その問題を解決するのが、本記事で紹介するAnnealed Importance Sampling(焼きなまし重点サンプリング)です。

まずは、問題設定を簡単に説明した後にアルゴリズムの内容を説明をします。

問題設定

 

規格化定数が未知の確率分布\(p_{K}(x)\)の規格化定数やその期待値(統計量)を知ることを目的とします(Kは後の表記を簡単にするためのインデックスです。この段階ではあまり気にしないでください)

以降、その確率分布\(p_{K}(x)\)を規格化定数\(Z_{K}\)とそれ以外の部分に分けて次のように記述することにします。

$$p(x) = \frac{1}{Z_{K}} f_{K}(x)$$

ここで、規格化定数は、以下の条件を満たすように設定しいます。

$$Z_{K} = \int f_{K}(x) dx$$

今回は、連続確率分布を扱っていますが、離散確率分布でも以降の導出は容易です。

Annealed Importance Sampling(AIS)側の確率分布

 

まずは、Importance samplingと同様に、規格化定数が既知(または計算が容易な規格化定数)で、サンプリングすることが容易な近似分布\(p_{0}(x)\)を準備します。

$$p_{0}(x) = \frac{1}{Z_{0}} f_{0}(x)$$

再度強調しますが、この規格化定数\(Z_{0}\)は、既知または計算が容易な分布を使用しましょう。また、\(p(x)\)と\(p_{0}(x)\)は必ずしも類似した分布である必要はありません。

さらに、AISでは以下のような中間分布を準備します。

$$p_{k}(x; \beta_{k}) = \frac{1}{Z_{k}} f_{0}(x)^{\beta_{k}} f(x)^{1-\beta_{k}}$$

中間分布と呼んでいる理由は、パラメータが\(\beta_{k} = 0\)のときは近似分布\(p_{0}\)になり、パラメータが\(\beta_{k} =1\)のときにはターゲットの分布\(p_{K}\)となることに起因します。

中間分布のパラメータは、ユーザーが決定する量で次の性質を満たすように設定します。

$$0 = \beta_{0} < \beta_{1} < \cdots < \beta_{K-1} < \beta_{K} =1$$

 

規格化定数と期待値の計算方法

 

AISは規格化定数の比\(Z_{k+1}/Z_{k}\)が次のように表せることを利用します。

\begin{align} \frac{Z_{k+1}}{Z_{k}} &= \frac{1}{Z_{k}} \sum_{x}f_{k+1}(x) \\ &= \mathbb{E}_{p_{k}}\left[\frac{f_{k+1}(x)}{f_{k}(x)}  \right] \\ &= \lim_{M \to \infty} \frac{1}{M} \sum_{m=1}^{M} \frac{f_{k+1}(x_{k}^{m})}{f_{k}(x_{k}^{m})} \end{align}

ここで、\(x_{k}^{(m)}\)は、確率分布\(p_{k}\)に従うサンプル列の\(m\)番目のサンプルを表すことにします。

この性質を使用すると以下の等式を得ることができます。

\begin{align} \frac{Z_{K}}{Z_{0}} &= \frac{Z_{1}}{Z_{0}} \frac{Z_{2}}{Z_{1}} \times \cdots \times \frac{Z_{K}}{Z_{K-1}} \\ &= \prod_{k=1}^{K} \mathbb{E}_{p_{k}}\left[\frac{f_{k+1}(x)}{f_{k}(x)}  \right]  \\ &= \frac{1}{M} \sum_{m=1}^{M} \prod_{k}^{K-1} \frac{f_{k+1}(x^{(m)}_{k})}{f_{k}(x^{(m)}_{k})} \\ &= \frac{1}{M} \sum_{m=1}^{M}  \frac{f_{1}(x_{0}^{(m)})}{f_{0}(x_{0}^{(m)})}\frac{f_{2}(x_{1}^{(m)})}{f_{1}(x_{1}^{(m)})} \cdots \frac{f_{K}(x_{K-1}^{(m)})}{f_{K-1}(x_{K-1}^{(m)})} \\ &\equiv \frac{1}{M} \sum_{m=1}^{M} \omega^{(m)} \end{align}

ここで、\(\omega^{(m)}\)は重みと呼ばれ、以下のように定義しました。

$$\omega^{(m)} =   \frac{f_{1}(x_{0}^{(m)})}{f_{0}(x_{0}^{(m)})}\frac{f_{2}(x_{1}^{(m)})}{f_{1}(x_{1}^{(m)})} \cdots \frac{f(x_{K-1}^{(m)})}{f_{K-1}(x_{K-1}^{(m)})}$$

 

この重みを使うことでターゲットとなる確率分布\(p_{K}\)に関する\(A(x)\)の期待値を以下のように計算することができます。

$$\mathbb{E}_{p_{K}}\left[ A(x) \right]  = \frac{\sum_{m=1}^{M} A(x^{m}_{0}) \omega^{(m)}}{\sum_{m=1}^{M} \omega^{(m)}}$$

また、規格化定数は次のように評価することができます。

$$Z_{K} = \frac{Z_{0}}{M} \sum_{m=1}^{M} w^{(m)}$$

AISのアルゴリズム

 

規格化定数や期待値の計算方法はわかりましたが、定式化の中には、中間分布から生成されるサンプル列\(\{x_{0}^{(m)}, x_{1}^{(m)}, \ldots, x_{K-1}^{(m)}\}_{m=1}^{M}\)が必要となります。

その部分は、適当な遷移確率を利用して、マルコフ連鎖モンテカルロ法によって生成します。

以下に具体的なアルゴリズムをまとめます。

AISのアルゴリズム

  • 入力 : パラメータ列 \(\{\beta_{k}\}_{k=0}^{K}\)
  • 以下を\(1 \sim M\)繰り返す
    • \(p_{0}\)から\(x_{0}^{(m)}\)をサンプル
    • 以下を\(1 \sim K\)回繰り返す
      • 初期値を\(x_{k-1}^{(m)}\)として、マルコフ連鎖モンテカルロ法により\(p_{k}(x)\)から\(x_{k}^{(m)}\)をサンプル
    • 重みを計算
      $$\omega^{(m)} \leftarrow \frac{f_{1}(x_{0}^{(m)})}{f_{0}(x_{0}^{(m)})}\frac{f_{2}(x_{1}^{(m)})}{f_{1}(x_{1}^{(m)})} \cdots \frac{f(x_{K-1}^{(m)})}{f_{K-1}(x_{K-1}^{(m)})}$$
  • 規格化定数と\(A(x)\)を以下のように計算

\begin{align} Z_{K} &= \frac{Z_{0}}{M} \sum_{m=1}^{M} \omega^{(m)} \\ \mathbb{E}_{p_{k}} \left[A(x)  \right] &= \frac{\sum_{m=1}^{M} A(x^{m}_{0}) \omega^{(m)}}{\sum_{m=1}^{M} \omega^{(m)}} \end{align} 

 

Annealed Importance Sampling(焼きなまし重点サンプリング)の実装例

 

必要なライブラリをインポート

 

import numpy as np
from matplotlib import pyplot as plt
import scipy.stats as stats
import seaborn as sns
sns.set()

# SEEDを固定
np.random.seed(0)

 

 

問題設定

 

問題設定を以下にまとめます。

  • 目的の確率分布

$$0.5 \times \mathcal{N}(0, 1) + 0.5 \times \mathcal{N}(4, 1)$$

  • 既知の情報(目的の確率分布の規格化定数は未知)

$$f_{K}(x) = 0.5 \exp\left( – \frac{(x-4)^{2}}{2} \right) + 0.5 \exp\left(- \frac{x^{2}}{2}  \right)$$

  • 提案分布

$$f_{0}(x) = \exp \left(- \frac{(x- \mu)^{2}}{2 \sigma^{2}} \right)$$

$$Z_{0} = \sqrt{2 \pi \sigma^{2}}$$

 

まずは、この問題設定に合うように\(f_{K}(x)\), \(f_{0}(x)\), \(f_{k}(x)\)を実装しましょう。

# target dist
def f_K(x, mu1=0, mu2=4.0, pi=0.5):
    pi1 =pi
    pi2 =1-pi1
    a1 = - ((x - mu1)**(2)) / 2
    a2 = - ((x - mu2)**(2)) / 2 
    return pi1*np.exp(a1)+pi2*np.exp(a2)

# proposal dist
def f_0(x, mu=0, scale=1):
    return np.exp(-((x-mu)**2) / 2)

# intermediate dist
def f_k(x, beta):
    return (f_0(x)**(1-beta))*(f_K(x)**beta)

 

簡単に中間分布・提案分布・目的の分布の関係を図示しておきます。

x_test = np.linspace(-5, 10, 100)
fig, ax = plt.subplots()

for i in range(11):
    ax.plot(x_test, f_k(x_test, i/10))

plt.show()

 

<output>

中間分布の可視化

 

マルコフ連鎖モンテカルロ法を実装

 

中間分布からのサンプリングにマルコフ連鎖モンテカルロ法が必要になります。

今回は、メトロポリス法を使用します。

def metropolis(xs_init, beta, steps=10):
    # inital condition
    xs = xs_init

    for i in range(len(xs_init)):
        for j in range(steps):
            prop_x = xs[i] + np.random.randn()
            accept_prob = f_k(prop_x, beta)/f_k(xs[i], beta)
            if np.random.rand() <= accept_prob:
                xs[i] = prop_x
    return xs

 

一応、\(\beta_{k}=0.75\)のときにしっかりサンプルが得られているかを確認しておきます。

x_0 = np.random.randn(1000)

xs = metropolis(x_0, 0.75, steps=300)

sns.distplot(xs, bins=50, kde='True')
plt.show()

 

<output>

メトロポリス法の実行例

 

Annealed Importance Sampling(焼きなまし重点サンプリング)の実装

 

これまでのコードを使用して、Annealed Importance Sampling(焼きなまし重点サンプリン)の実装を行います。

def AIS(n_samples=1000, n_inter=100, n_steps=30):
    '''
    Args:
    n_samples : サンプル列の長さ(上述のMに対応)
    n_inter : 中間分布の個数
    n_steps : MCMCのステップ数
    
    Return:
    norm_term : 目的の確率分布の規格化定数の推定値
    mean : 目的の確率分布の平均の推定値
    '''

    # 中間分布のパタメータ
    betas = np.linspace(0, 1, n_inter)
    xs_k = np.random.randn(n_samples)

    # 対数の値で保存(overflowを避けるため)
    log_weights = np.zeros(n_samples)
    
    for i in range(1, len(betas)):
        log_weights += np.log(f_k(xs_k, betas[i])) - np.log(f_k(xs_k, betas[i-1]))
        xs_k = metropolis(xs_k, betas[i], steps=n_steps)
    
    weights = np.exp(log_weights)

    # 規格化定数
    norm_term = (np.sum(weights)/n_samples)*(np.sqrt(2*np.pi))    
    # 平均
    mean = np.sum(xs_k*weights)/np.sum(weights)
    
    return norm_term, mean

 

実際に実行してみましょう。

AIS(n_samples=1000, n_inter=300, n_steps=30)

 

<output>

(2.5075079910242226, 2.060047935888935)

 

目的の確率分布の規格化定数は、\(\sqrt{2 \pi} = 2.5066282\cdots\)であり、平均値は\(2.0\)であるため、良い推定値が得られていることがわかります。

 

規格化定数や統計量の計算には、指数関数の和の対数を取る計算が含まれ、頻繁にオーバーフローが生じます。

その場合は、以下のlogsumexp法を使用することで、解決できるケースが多いです。

logsumexpとは|Numpy・PyTorchによる実装例も解説!本記事では、指数関数の和の対数を計算する際に生じるオーバフローを回避する方法の一つであるlogsumexp法について説明しました。また、本記事の後半では、NumpyとPyTorchによるlogsumexp法の実装例を示しました。...

 

まとめ

 

本記事では、Annealed Importance Samplingの理論と実装を紹介しました。

より複雑な確率分布に使用して、遊んでみてください!

ABOUT ME
努力のガリレオ
【運営者】 : 東大で理論物理を研究中(経歴)東京大学, TOEIC950点, NASA留学, カナダ滞在経験有り, 最優秀塾講師賞, オンライン英会話講師試験合格, ブログと独自コンテンツで収益6桁達成 【編集者】: イングリッシュアドバイザーとして勤務中(経歴)中学校教諭一種免許取得[英語],カナダ留学経験あり, TOEIC650点
Python学習を効率化させるサービス

 

Pythonを学習するのに効率的なサービスを紹介していきます。

まず最初におすすめするのは、Udemyです。

Udemyは、Pythonに特化した授業がたくさんあり、どの授業も良質です。

また、セール中は1500円定義で利用することができ、コスパも最強です。

下記の記事では、実際に私が15個以上の講義を受講して特におすすめだった講義を紹介しています。

 

【最新】UdemyでおすすめのPythonコース|東大生が厳選!10万を超える講座があるUdemyの中で、Pythonに関係する講座を厳選しました。また、本記事では、Udemyを使用しながらPythonをどのような順番で勉強するべきかを紹介しました。ぜひ参考にしてください。...

 

他のPythonに特化したオンライン・オフラインスクールも下記の記事でまとめています。

 

【最新】Pythonに強いプログラミングスクール7選|東大生が厳選Pythonの流行と共に、Pythonに強いプログラミングスクールが増えてきました。本記事では、特にPythonを効率的に学ぶことができるプログラミングスクールを経験をもとに厳選して、内容を詳しく解説しています。...

 

自分の学習スタイルに合わせて最適なものを選びましょう。

また、私がPythonを学ぶ際に使用した本を全て暴露しているので参考にしてください。