TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

エントロピー正則化項の微分

以前に方策が決定論的にならないようにするために、損失にエントロピー正則化項を加えることを書いたが、誤差逆伝播する際の微分の式が誤っていたので訂正する。

方策がソフトマックス関数の場合のエントロピー微分

エントロピーは以下の式で与えられる。
\displaystyle
H(p) = -\sum_j p_j \log p_j   \tag{1}

pがソフトマックス関数の場合、ロジットをzとすると、エントロピー偏微分は、以下の通りシンプルな式になる。
\displaystyle
\frac{\partial H(p)}{\partial z_i} = p_i(-\log p_i - H(p))   \tag{2}

この式は、以下の論文に記載されていた。
[1701.06548] Regularizing Neural Networks by Penalizing Confident Output Distributions

式の導出

以下、この式の導出について記載する。

上記の式は、Σと偏微分がうまく消えていてすっきりした形になっている。
論文には式の展開が記載されていないため、すぐにはどのように導出したのかわからなかった。
しかし、微分の連鎖則、積の微分公式、ソフトマックスの微分エントロピーの定義と確率の合計が1になることを使うと導出できる(ネットを検索してもこの式について書かれている情報がなかったので自分で導出したが、思いつくのに1日くらいかかった)。

まず、微分の連鎖則を使うと偏微分の式は以下の通りに書ける。
\displaystyle
\begin{eqnarray}
\frac{\partial H(p)}{\partial z_i} &=& \frac{\partial}{\partial z_i} (-\sum_j p_j \log p_j) \\
&=& -\sum_j ( \frac{\partial}{\partial p_j} (p_j  \log p_j) \cdot \frac{\partial p_j}{\partial z_i} )  \tag{3}
\end{eqnarray}

p_j  \log p_j微分は、積の微分公式を使うと、1 + \log p_jとなるので、
\displaystyle
\frac{\partial H(p)}{\partial z_i} = -\sum_j ( (1 + \log p_j) \cdot \frac{\partial p_j}{\partial z_i} ) \tag{4}
となる。

\frac{\partial p_j}{\partial z_i}は、pがソフトマックス関数なので、ソフトマックス関数の微分から、
\displaystyle
\frac{\partial p_j}{\partial z_i} = \left\{ \begin{array}{ll}
    p_i(1-p_i) & (i = j) \\
    -p_i p_j & (i \ne j)
  \end{array} \right.   \tag{5}
となる。

したがって、式(4)のΣの中は、
\displaystyle
\begin{eqnarray}
(1 + \log p_j) \cdot \frac{\partial p_j}{\partial z_i} &=& \left\{ \begin{array}{ll}
    (1 + \log p_i) p_i(1-p_i) & (i = j) \\
    (1 + \log p_j) ( -p_i p_j ) & (i \ne j)
  \end{array} \right. \\

&=& \left\{ \begin{array}{ll}
    p_i (1 - p_i + \log p_i - p_i \log p_i ) & (i = j) \\
    p_i (-p_j - p_j \log p_j) & (i \ne j)
  \end{array} \right.
\tag{6}
\end{eqnarray}
となる。

これを式(4)のΣに戻すと、i = ji \ne jをまとめられるため、
\displaystyle
\frac{\partial H(p)}{\partial z_i} = - p_i (1 -\sum_j p_j + \log p_i -\sum_j (p_j \log p_j)) \tag{7}
となる。

ここで、確率の合計は1になるため\sum_j p_j = 1となり、エントロピーの定義から、-\sum_j (p_j \log p_j) = H(p)となるため、
\displaystyle
\frac{\partial H(p)}{\partial z_i} = p_i (- \log p_i -H(p)) \tag{8}
となり、式(2)が導出できた。


この式を使うと、エントロピー正則化項の誤差逆伝播が効率よく計算できる。
ただ、この式を使わなくても、ディープラーニングフレームワークの計算グラフを使えば、

loss = policy_loss + beta * F.mean(F.sum(F.softmax(z) * F.log_softmax(z), axis=1))

のように記述すれば、多少効率は落ちるがフレームワークがよろしくやってくれるので中身の式は気にする必要はない。