TadaoYamaokaの日記

山岡忠夫Homeで公開しているプログラムの開発ネタを中心に書いていきます。

将棋AIの実験ノート:活性化関数Mishを試す

以前にdlshogiのモデルで活性化関数をReLUからSwishにした場合の比較を行った。

今回は、活性化関数Mishを試した。

Mish

Mishは、
\displaystyle
f(x)=x \tanh(softplus(x))
で表される活性化関数である。

論文によると、6層CNNのCIFAR-10の訓練で、Swishの正解率を上回ると報告されている。
[1908.08681] Mish: A Self Regularized Non-Monotonic Activation Function

PyTorchの実装

Mishの論文が出た直後に、PyTorchにリクエストのissueが上がっていたが、コミュニティに広く受け入れられていないという理由でリジェクトされていた。

しかし、PyTorch 1.9で、標準機能として実装されていた。
Mish — PyTorch 1.9.1 documentation

PyTorchに実装されていないため、dlshogiで試すのを保留していたが、標準で実装されたので試してみた。

比較条件

dlshogiのResNet15ブロック224フィルタのモデルの活性化関数を、ReLUとSwish、Mishにした場合で比較を行った。

訓練データには、dlshogiの最新モデルの自己対局で生成した50000157局面(同一局面の平均化後は41895205局面)を使用した。
dlshogi.trainのオプションは、「--use_average」、「--use_evalfix」、「--use_amp」、「--use_swa」を有効にした。

バッチサイズ4096で、学習率0.04から始めて、エポックごとに半減、2エポックを学習して、SWAモデルのテストデータに対する精度を比較した。

それぞれ4回測定し、平均で比較した。

テストデータには、floodgateのレート3500以上の対局の棋譜からサンプリングした856,923局面を使用した。

比較結果

方策損失 価値損失 方策正解率 価値正解率 方策エントロピー
ReLU 1.708107425 0.543135325 0.45590225 0.712226825 1.67308225
Swish 1.66914235 0.5330944 0.4644194 0.720270625 1.631262975
Mish 1.678887075 0.53691395 0.462033475 0.7182632 1.6450056
正解率比較グラフ


訓練時間
時間/epoch
ReLU 2:08:18
Swish 2:12:37
Mish 2:15:17

考察

精度は、方策、価値ともに、Swish > Mish > ReLU の順になった。
訓練時間は、ReLU < Swish < Mishの順になった。

MishはSwishと比べると、精度は低くなり、訓練時間も長くなることがわかった。
ReLUと比べると、精度は高くなっているが、訓練時間が長くなる。

まとめ

活性化関数MishがPyTorchの標準に取り込まれたので、dlshogiのモデルで精度が上がるか試した。
結果、Swishの方が精度が高く、訓練時間も短いという結果であった。
dlshogiで、Mishを採用するメリットはないと言える。