ガンベル最大トリックを解説して実装してみる

こんにちは、ぐぐりら(@guglilac)です。
久しぶりの投稿ですね。
ゆるーくやってます笑

今回はガンベル最大トリックについて解説して、試しに実装してみるとこまでやってみます。

実際に使うことはあんまなさそうだけれど、参考書読んでて気になったので。
自分用の備忘録も兼ねて。

ガンベル最大トリックとは


とはいうものの、ガンベル最大トリックってなんぞ?という感じだと思います。
僕も参考書読んで初めて知りました。ぐぐってもそんな出てこないので普及した名前ではないのかも。別名があるのかなあ。

ガンベル最大トリックは、サンプリング手法の一種です。
深層学習の出力層で使うそうです。

深層学習の文脈では、活性化関数のsoftmax関数をかけて各クラスに属する確率分布を計算することがあります。
softmax関数に入力する前のデータをスコアと呼び、一般にいろんな大きさの値をとります。

softmax関数はこのスコアを受け取り、正規化することで成分の和が1の確率ベクトルを得ます。
スコアを$u$、確率ベクトルを$p$とすると、softmax関数はこのように計算します。

\[p_i=\frac{\exp u_i}{\sum_i u_i}\]

この各クラスに属する確率のうち、一番高いクラスを出力する決定論的な出力を行うのが一般的ではありますが、この確率に応じて無作為にサンプリングしたクラスを出力するという構成を持つニューラルネットワークモデルもあります。

この際に用いられるサンプリング手法が、ガンベル最大トリックです。
ガンベル最大トリックを用いると、softmax関数をかけて確率を計算することなく、スコアからそのままサンプリングできます。

アルゴリズム

ガンベル最大トリックの疑似コードを示します。


for i 1 to N:
    $r$~U(0,1)
    $g=-\log(−\log r)$
    $z_i=u_i+g$

return argmax z 
U(0,1)は0から1の間の一様分布で、rはその分布に従う確率変数です。
uは先ほど述べたスコアです。
このようにして各スコアに対して$z_i$を計算し、$z$の成分の中で最大の要素を持つインデックスを返します。

このように計算したインデックスが、ちょうどsoftmax関数によって得られる確率に従う確率変数のサンプリングになっています。

なにをやっているのか、どうしてこれで正しくサンプリングできるのかを解説します。

解説と証明

疑似コードを日本語訳すると、
  1. 一様分布に従う確率変数$r$を一つ得る
  2. $r$を用いてガンベル分布に従う確率変数$g$を一つ得る
  3. gをスコアに足して$z_i$を得る
以上をスコアの全成分にそれぞれ行い、一番大きい$z_i$を持つ$i$を返しています。

一様分布はいいと思いますが、ガンベル分布は耳慣れないと思います。
ガンベル分布の累積分布関数$F(x)$、確率密度関数$f(x)$はそれぞれ次のようになります。

 \[F(x)=\exp(-\exp(-x))\]
 \[f(x)=\exp(x)F(x)\]

このガンベル分布に従う確率変数をサンプリングする際に$g=-\log(−\log r)$のように計算しています。

これは逆関数法というテクニックによるものです。
確率変数の欲しい累積分布関数の逆関数を陽に求めることができるときに使える便利な手法です。

確率変数の欲しい累積分布関数の逆関数に(0,1)の一様分布からサンプリングした値を入れると、返り値がその確率分布に従った確率変数になります。

今回は、ガンベル分布の累積分布関数$F(x)$の逆関数が$F^{-1}(x)=-\log(−\log x)$と求めることができるため、この手法を用いることができています。

逆関数が計算できれば他の分布でも使えるので、覚えとくといいかもしれません。

逆関数法の証明

なぜ逆関数法でサンプリングできるのかを証明します。
$r$を(0,1)一様分布に従う確率変数とすると、定義から
\[P(r\le u)=u\]

累積分布関数は単調増加関数なので、逆関数も単調増加。よって

\[P(F^{-1}(r)\le F^{-1}(u))=u\]

$x=F^{-1}(u)$とおくと、$u=F(x)$なので

\[P(F^{-1}(r)\le x)=F(x)\]

この通り、$F^{-1}(r)$の累積分布関数が$F(x)$になっています。

ガンベル最大トリックの正当性証明

$z_k$が最大になる確率がsoftmax関数をかけて得られる$p_k$に等しいことを示します。
「$z_k$が最大になる」事象をAと表記します。

$g_k$が与えられた条件下での事象Aが起こる確率は
$i \neq k$の全ての$i$で$z_k \ge z_i$となる確率です。
各$i$についてその確率を計算すると、以下のようになります。

\begin{eqnarray*}P(z_k \ge z_i |g_k)&=&P(u_k+g_k\ge u_i+g_i|g_k)\\&=&P(u_k+g_k-u_i\ge g_i|g_k)\\ &=&F(u_k+g_k-u_i)\end{eqnarray*}

最後は$g_i$がガンベル分布に従うことを用いています。

各$i$について独立なので、単純に$i \neq k$で積を取れば良く、$g_k$が与えられた条件下での事象Aが起こる確率は

\[P(A|g_k)=\prod_{i \neq k}F(u_k-u_i+g_k)\]

となります。

あとは$g_k$について周辺化すれば終わりです。
softmax関数の定義

\[p_k=\frac{\exp u_k}{\sum_i u_i}\]

に注意して計算を進めると以下のようになります。

\begin{eqnarray*}\int_{-\infty}^{\infty} f(g_k)\prod_{i \neq k}F(u_k-u_i+g_k)dg_k&=&\int_{-\infty}^{\infty}\exp(-g_k)F(g_k)\prod_{i \neq k}F(u_k-u_i+g_k)dg_k\\&=&\int_{-\infty}^{\infty}\exp(-g_k)\prod_{i}F(u_k-u_i+g_k)dg_k\\&=&\int_{-\infty}^{\infty}\exp(-g_k)\exp(-\frac{\exp(-g_k)}{p_k})dg_k\\&=&p_k\end{eqnarray*}

$z_k$が最大になる確率がsoftmax関数をかけて得られる$p_k$に等しいことが示せました。
$i=k$の部分を総乗の中に含めることができるのがポイントですね。
最後の積分の過程は省略しました。置換してほげほげする高校数学です。

シミュレーション

解説だけしてもあれなので、実際にサンプリングしてみました。
Python3で実装しています。
スコアは$[1,2,3,4,5]$で、10000回サンプリングして各インデックスが選ばれる確率を算出しました。

ほぼ理論値と一致してますね。

ソースコード

単純なアルゴリズムなので必要ないかもですが一応ソースコードも載せておきます。
あ、このスコアだと大丈夫ですが、スコアの値によっては$\exp$の計算のときにオーバーフローしたりするのでその辺は注意したほうがいいかもです。

 
import numpy as np
import math
import matplotlib.pyplot as plt
u = list(range(1, 6))
N = 10000


def gumbel(u):
    classes = len(u)
    r = np.random.rand(classes)
    z = []
    for i in range(classes):
        g = -math.log(-math.log(r[i]))
        z.append(u[i] + g)
    return np.argmax(np.array(z))


def softmax(u):
    classes = len(u)
    denom = sum([math.exp(u[i]) for i in range(classes)])
    return [math.exp(u[i]) / denom for i in range(classes)]


def simulate(u, N):
    classes = len(u)
    result = np.zeros(classes)
    for i in range(N):
        result[gumbel(u)] += 1
    return result / N


def plot(result, p, u):
    w = 0.4
    classes = len(u)
    left1 = list(range(1, classes + 1))
    left2 = [left1[i] + w for i in range(len(left1))]
    labels = [str(u[i]) for i in range(classes)]
    plt.bar(left1, result, align='center',
            color='#FFA0A0', width=w, label='Result')
    plt.bar(left2, p, align='center', color='#A0A0FF',
            width=w, label='True Value')
    plt.legend()

    label_left = [left1[i] + w / 2 for i in range(len(left1))]
    plt.xticks(label_left, labels)
    plt.show()
    return


result = simulate(u, N)
p = softmax(u)
print("Result{}".format(result))
print("True Value:{}".format(p))

plot(result, p, u)

まとめ

ガンベル最大トリックによるサンプリングを解説して、証明して、実装してみました。
実際に使うことはあまりないかもしれませんが、うまくいくとやっぱり気持ち良いもんですね。

個人的には、初めてブログで数式載せたのが大変でした笑
また一歩技術ブログ的に進歩したので良かったです。

ありがとうございました。

コメント