以前にdlshogiのモデルで活性化関数をReLUからSwishにした場合の比較を行った。
今回は、活性化関数Mishを試した。
Mish
Mishは、
で表される活性化関数である。
論文によると、6層CNNのCIFAR-10の訓練で、Swishの正解率を上回ると報告されている。
[1908.08681] Mish: A Self Regularized Non-Monotonic Activation Function
PyTorchの実装
Mishの論文が出た直後に、PyTorchにリクエストのissueが上がっていたが、コミュニティに広く受け入れられていないという理由でリジェクトされていた。
しかし、PyTorch 1.9で、標準機能として実装されていた。
Mish — PyTorch 1.10 documentation
PyTorchに実装されていないため、dlshogiで試すのを保留していたが、標準で実装されたので試してみた。
比較条件
dlshogiのResNet15ブロック224フィルタのモデルの活性化関数を、ReLUとSwish、Mishにした場合で比較を行った。
訓練データには、dlshogiの最新モデルの自己対局で生成した50000157局面(同一局面の平均化後は41895205局面)を使用した。
dlshogi.trainのオプションは、「--use_average」、「--use_evalfix」、「--use_amp」、「--use_swa」を有効にした。
バッチサイズ4096で、学習率0.04から始めて、エポックごとに半減、2エポックを学習して、SWAモデルのテストデータに対する精度を比較した。
それぞれ4回測定し、平均で比較した。
テストデータには、floodgateのレート3500以上の対局の棋譜からサンプリングした856,923局面を使用した。
比較結果
方策損失 | 価値損失 | 方策正解率 | 価値正解率 | 方策エントロピー | |
---|---|---|---|---|---|
ReLU | 1.708107425 | 0.543135325 | 0.45590225 | 0.712226825 | 1.67308225 |
Swish | 1.66914235 | 0.5330944 | 0.4644194 | 0.720270625 | 1.631262975 |
Mish | 1.678887075 | 0.53691395 | 0.462033475 | 0.7182632 | 1.6450056 |
訓練時間
時間/epoch | |
---|---|
ReLU | 2:08:18 |
Swish | 2:12:37 |
Mish | 2:15:17 |
考察
精度は、方策、価値ともに、Swish > Mish > ReLU の順になった。
訓練時間は、ReLU < Swish < Mishの順になった。
MishはSwishと比べると、精度は低くなり、訓練時間も長くなることがわかった。
ReLUと比べると、精度は高くなっているが、訓練時間が長くなる。
まとめ
活性化関数MishがPyTorchの標準に取り込まれたので、dlshogiのモデルで精度が上がるか試した。
結果、Swishの方が精度が高く、訓練時間も短いという結果であった。
dlshogiで、Mishを採用するメリットはないと言える。