AlphaZeroのMCTSのルートノードではディリクレノイズを加えることで、全ての手をランダムで選ばれやすくしている。
以前の記事で、2次元のディリクレ分布を可視化したが、3次元の場合の可視化ができないか調べていたら、以下のページを見つけたので試してみた。
Visualizing Dirichlet Distributions with Matplotlib
x, y, zを2次元の正3角形に投影して、確率密度を色で表示している。
本当は、x, y, zを3角形の平面に投影して、確率密度を高さの軸にして3Dのグラフにしたかったが、やり方がわからなかった。
上記のページは、python2用のコードになっていたので、python3で動かすには、
import functools
を追加して、
reduceをfunctools.reduceに置換する。
[0.03, 0.03, 0.03]のディリクレ分布を可視化するコードは以下の通り。
%matplotlib inline import numpy as np import matplotlib.pyplot as plt import matplotlib.tri as tri corners = np.array([[0, 0], [1, 0], [0.5, 0.75**0.5]]) triangle = tri.Triangulation(corners[:, 0], corners[:, 1]) # Mid-points of triangle sides opposite of each corner midpoints = [(corners[(i + 1) % 3] + corners[(i + 2) % 3]) / 2.0 \ for i in range(3)] def xy2bc(xy, tol=1.e-3): '''Converts 2D Cartesian coordinates to barycentric.''' s = [(corners[i] - midpoints[i]).dot(xy - midpoints[i]) / 0.75 \ for i in range(3)] return np.clip(s, tol, 1.0 - tol) class Dirichlet(object): def __init__(self, alpha): from math import gamma from operator import mul self._alpha = np.array(alpha) self._coef = gamma(np.sum(self._alpha)) / \ functools.reduce(mul, [gamma(a) for a in self._alpha]) def pdf(self, x): '''Returns pdf value for `x`.''' from operator import mul return self._coef * functools.reduce(mul, [xx ** (aa - 1) for (xx, aa)in zip(x, self._alpha)]) def draw_pdf_contours(dist, nlevels=200, subdiv=8, **kwargs): import math refiner = tri.UniformTriRefiner(triangle) trimesh = refiner.refine_triangulation(subdiv=subdiv) pvals = [dist.pdf(xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)] plt.tricontourf(trimesh, pvals, nlevels, **kwargs) plt.axis('equal') plt.xlim(0, 1) plt.ylim(0, 0.75**0.5) plt.axis('off') draw_pdf_contours(Dirichlet([0.03, 0.03, 0.03]))
実行結果
よく見ないとわからないが、角の部分の色が赤色(=高い値)になっている。
それぞれの角は(1, 0, 0)、(0, 1, 0)、(0, 0, 1)を意味する。
つまり、いずれかの手が選ばれやすくなる。
[0.03, 0.03, 0.03]だと見にくいので、[0.9, 0.9, 0.9]とすると、
draw_pdf_contours(Dirichlet([0.9, 0.9, 0.9]))
角の値の色が変わっているのが分かりやすくなった。
alphaが正の値の場合、
draw_pdf_contours(Dirichlet([5, 5, 5]))
中央の値が選ばれやすくなる。
乱数生成
Pythonでディリクレ分布に従って乱数を生成するには、numpy.random.dirichletを使用する。
import numpy as np [np.random.dirichlet([0.03, 0.03, 0.03]) for _ in range(10)]
[array([ 9.30062319e-13, 2.57427053e-04, 9.99742573e-01]), array([ 1.00000000e+00, 2.17381455e-19, 6.85158127e-35]), array([ 1.00000000e+00, 1.16336290e-15, 3.83232435e-19]), array([ 8.76967918e-01, 1.15399381e-36, 1.23032082e-01]), array([ 1.94448447e-45, 5.91371213e-01, 4.08628787e-01]), array([ 5.51018131e-56, 1.00000000e+00, 1.32521447e-17]), array([ 2.68526987e-21, 5.73712429e-01, 4.26287571e-01]), array([ 1.00000000e+00, 7.11978680e-14, 4.22528386e-24]), array([ 3.84802928e-58, 5.40350337e-20, 1.00000000e+00]), array([ 1.56164809e-01, 8.43835191e-01, 2.39362055e-11])]
どれか一つが選ばれやすくなっているのが分かる。