TadaoYamaokaの日記

山岡忠夫Homeで公開しているプログラムの開発ネタを中心に書いていきます。

将棋AIの実験ノート:活性化関数Swishを試す

画像認識で高い精度を達成しているEfficientNetで使われている活性化関数Swishが将棋AIでも効果があるか試してみた。

EfficientNetでは、残差ブロックにMBConvというアーキテクチャが使用されており、その中の活性化関数にSwishが使用されている。
MBConvについても別途検証を行ったが、今回はSwishの効果について記す。

Swishについて

Swishは以下の式で定義される活性化関数である。
\displaystyle
x \cdot \sigma(\beta x)
ここで、\sigma(z)シグモイド関数\sigma(z)=(1+exp(-z))^{-1}である。

グラフの形は以下のようになる。
f:id:TadaoYamaoka:20200814131227p:plain

どのような場合でもReLUに代わるものではないが、多くのタスクでReLUよりも精度が向上することが報告されている。

PyTorchのSwish実装

GitHubで公開されているPyTorchのEfficientNetの実装からSwishのコードを流用した。

この実装では、微分の計算が効率化されている。

ただし、Onnxに出力する際はエラーになるようなので、オリジナル実装に切り替える必要がある。

dlshogiのネットワークへの適用

dlshogiのpolicy_value_network.pyのreluをswishに置き換えたソースがこちら。
DeepLearningShogi/policy_value_network_swish.py at feature/experimental_network · TadaoYamaoka/DeepLearningShogi · GitHub

比較実験

測定条件

dlshogiの強化学習で生成した教師局面を使用して、教師あり学習を行い、floodgateからサンプリングした856,923局面に対する損失と一致率で評価する。
強化学習で生成した局面は、1サイクルあたり約270万局面生成しており、それを過去10サイクル分を1回のデータセットとして学習する。
8サイクル分のデータを学習した時点で比較した。

学習率:0.001(最初の1サイクルのみ0.01)
バッチサイズ:1024
最適化:Momentum SGD(係数0.9)
学習則:dlshogiと同じ

学習には、ampを使用した。

測定結果
条件 訓練損失平均 テスト損失 policy一致率 value一致率
Swishなし 0.85036430 1.62555117 0.43175799 0.70151462
Swishあり 0.86290113 1.60377155 0.43336603 0.71022165

f:id:TadaoYamaoka:20200814143424p:plain
f:id:TadaoYamaoka:20200814143705p:plain
f:id:TadaoYamaoka:20200814143714p:plain

テストデータに対する、損失、policyの一致率、valueの一致率すべてでSwishを使用した方が良い結果になった。

まとめ

dlshogiのResnetを使用したモデルの学習にSwishを使用すると精度が改善されることが確認できた。

今回は初期値からの学習を行ったが、現在学習済みのモデルに対して、活性化関数を変更した場合にどうなるかについても確認したい。
途中から変更しても効果があるなら、現在継続している学習に変更を取り入れる予定である。

Swishは深いモデルでうまく機能する傾向があるそうだが、もしかしたらNNUEの学習でも効果があるのかもしれないので(私は試さないので)興味のある方は試してみてほしい。