TadaoYamaokaの日記

山岡忠夫Homeで公開しているプログラムの開発ネタを中心に書いていきます。

ディリクレ分布の可視化

AlphaZeroのMCTSのルートノードではディリクレノイズを加えることで、全ての手をランダムで選ばれやすくしている。
P(x,a)=(1-\epsilon)p_a + \epsilon \eta_a, \eta \sim Dir(0.03)

以前の記事で、2次元のディリクレ分布を可視化したが、3次元の場合の可視化ができないか調べていたら、以下のページを見つけたので試してみた。
Visualizing Dirichlet Distributions with Matplotlib
x, y, zを2次元の正3角形に投影して、確率密度を色で表示している。
本当は、x, y, zを3角形の平面に投影して、確率密度を高さの軸にして3Dのグラフにしたかったが、やり方がわからなかった。

上記のページは、python2用のコードになっていたので、python3で動かすには、

import functools

を追加して、
reduceをfunctools.reduceに置換する。

[0.03, 0.03, 0.03]のディリクレ分布を可視化するコードは以下の通り。

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as tri

corners = np.array([[0, 0], [1, 0], [0.5, 0.75**0.5]])
triangle = tri.Triangulation(corners[:, 0], corners[:, 1])

# Mid-points of triangle sides opposite of each corner
midpoints = [(corners[(i + 1) % 3] + corners[(i + 2) % 3]) / 2.0 \
             for i in range(3)]
def xy2bc(xy, tol=1.e-3):
    '''Converts 2D Cartesian coordinates to barycentric.'''
    s = [(corners[i] - midpoints[i]).dot(xy - midpoints[i]) / 0.75 \
         for i in range(3)]
    return np.clip(s, tol, 1.0 - tol)

class Dirichlet(object):
    def __init__(self, alpha):
        from math import gamma
        from operator import mul
        self._alpha = np.array(alpha)
        self._coef = gamma(np.sum(self._alpha)) / \
                     functools.reduce(mul, [gamma(a) for a in self._alpha])
    def pdf(self, x):
        '''Returns pdf value for `x`.'''
        from operator import mul
        return self._coef * functools.reduce(mul, [xx ** (aa - 1)
                                         for (xx, aa)in zip(x, self._alpha)])

def draw_pdf_contours(dist, nlevels=200, subdiv=8, **kwargs):
    import math

    refiner = tri.UniformTriRefiner(triangle)
    trimesh = refiner.refine_triangulation(subdiv=subdiv)
    pvals = [dist.pdf(xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]

    plt.tricontourf(trimesh, pvals, nlevels, **kwargs)
    plt.axis('equal')
    plt.xlim(0, 1)
    plt.ylim(0, 0.75**0.5)
    plt.axis('off')

draw_pdf_contours(Dirichlet([0.03, 0.03, 0.03]))
実行結果

f:id:TadaoYamaoka:20171209223116p:plain

よく見ないとわからないが、角の部分の色が赤色(=高い値)になっている。
それぞれの角は(1, 0, 0)、(0, 1, 0)、(0, 0, 1)を意味する。
つまり、いずれかの手が選ばれやすくなる。

[0.03, 0.03, 0.03]だと見にくいので、[0.9, 0.9, 0.9]とすると、

draw_pdf_contours(Dirichlet([0.9, 0.9, 0.9]))

f:id:TadaoYamaoka:20171209223400p:plain
角の値の色が変わっているのが分かりやすくなった。

alphaが正の値の場合、

draw_pdf_contours(Dirichlet([5, 5, 5]))

f:id:TadaoYamaoka:20171209223735p:plain
中央の値が選ばれやすくなる。

乱数生成

Pythonでディリクレ分布に従って乱数を生成するには、numpy.random.dirichletを使用する。

import numpy as np
[np.random.dirichlet([0.03, 0.03, 0.03]) for _ in range(10)]
[array([  9.30062319e-13,   2.57427053e-04,   9.99742573e-01]),
 array([  1.00000000e+00,   2.17381455e-19,   6.85158127e-35]),
 array([  1.00000000e+00,   1.16336290e-15,   3.83232435e-19]),
 array([  8.76967918e-01,   1.15399381e-36,   1.23032082e-01]),
 array([  1.94448447e-45,   5.91371213e-01,   4.08628787e-01]),
 array([  5.51018131e-56,   1.00000000e+00,   1.32521447e-17]),
 array([  2.68526987e-21,   5.73712429e-01,   4.26287571e-01]),
 array([  1.00000000e+00,   7.11978680e-14,   4.22528386e-24]),
 array([  3.84802928e-58,   5.40350337e-20,   1.00000000e+00]),
 array([  1.56164809e-01,   8.43835191e-01,   2.39362055e-11])]

どれか一つが選ばれやすくなっているのが分かる。