TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

将棋AIの進捗 その38(SWA)

dlshogiの学習にSWA(Stochastic Weight Averaging)を実装して、測定した。

SWA

SWA(Stochastic Weight Averaging)は、一定間隔での重みを平均化することで、ニューラルネットワークのテスト精度を改善するテクニックである。
一般的なアンサンブルの手法では予測の結果を平均化するが、SWAでは重みを平均化することで実現する。

SWAの実装

SWAの実装は、PyTorchの実装を参考にした。
この実装では、学習開始時点からの平均を計算するようになっているが、強化学習に適用する場合は、古いステップの重みは忘れていった方が良い。
そこで、平均の代わりに、指数移動平均を使用するように変更した。
https://github.com/TadaoYamaoka/DeepLearningShogi/blob/master/dlshogi/swa.py

Leela Chess Zeroや、KataGoでも、指数移動平均を使用している。

重みを平均化する間隔はKataGoの設定を参考にして250バッチ(バッチサイズ1024)間隔、区間は適当に10とした。

効果測定

dlshogiの強化学習で生成した12サイクル分(1サイクル250万局面、1回の学習で過去10サイクル分を使用)のデータを学習して、188サイクル学習したモデルから追加で学習して効果を測定した。
SWAとdropoutは、どちらもアンサンブルの効果を狙ったもので、目的が重複しているためdropoutなしの場合も比較した。
テストデータには、floodgateの棋譜を使用した。

グラフ凡例
SWA ドロップアウト
master なし あり
nodropout なし あり
swa あり なし
swadropout あり あり
方策のテスト損失

f:id:TadaoYamaoka:20200118134510p:plain

価値のテスト損失

f:id:TadaoYamaoka:20200118134555p:plain

Q値(評価値)のテスト損失

f:id:TadaoYamaoka:20200118134624p:plain

テスト損失の合計

f:id:TadaoYamaoka:20200118135607p:plain

方策の正解率

f:id:TadaoYamaoka:20200118134640p:plain

価値の正解率

f:id:TadaoYamaoka:20200118134658p:plain

考察

方策のテスト損失を見ると、ドロップアウトありなしで、損失の値に開きがあるがこれは、学習時にドロップアウトにより素子が無効化されている影響である。
ドロップアウトの有無の条件が同じもの同士では、SWAを使った方が値が安定しており、最終的にわずかに損失が小さくなっている。

価値のテスト損失も同じ傾向である。
Q値のテスト損失は、ドロップアウトなしの場合SWAなしが最終的な損失が小さくなっているが、損失の合計ではSWAありの方が良い。


方策の正解率は、ドロップアウトなしの方が良く、ドロップアウトの有無が同じ条件では、SWAありの方が良い。
価値の正解率は、ドロップアウトなしの方が良く、ドロップアウトなしの場合SWAなしの方が最終的にわずかに高くなっているが各サイクルで逆転が起きている。


全体的に、SWAがある方が学習が安定して、テスト精度はわずかに高くなることが分かった。
また、ドロップアウトがない方が、テスト正解率が高くなる。

学習時間

SWAを使用した場合の学習時間を比較した。
各サイクルの学習時間平均は以下の通りとなった。

master 2:18
nodropout 2:19
swa 3:02
swadropout 3:03

SWAを使用すると学習時間が、1.3倍になる。
SWAは、学習終了時に重みを平均化したモデルでBatchNormalizationの統計を計算しなおす必要があるため、訓練データを使って順伝播を計算し直す必要がある。
そのため、学習時間が長くなっている。

まとめ

SWAを使用することで、floodgateの棋譜に対するテスト精度が少しだけ上がることが確認できた。
学習時間が長くなるが、学習中も強化学習で局面を生成しているので、それほど問題にはならないのでSWAを採用することにする。