dlshogiの学習では、SWA(Stochastic Weight Averaging)を導入している。
今までは、1世代学習するごとに、平均化した重みを出力して、次の世代ではその重みを使用して学習していた。
しかし、SWAは通常複数エポックに渡って平均化してから、最後に平均化した重みを出力を行う。
また、Leela Chess Zeroでは、複数世代にわたって重みを平均化しているようなので、dlshogiもそのように変更した。
実装方法の変更
dlshogiにSWAを実装したときは、まだPyTorchに正式にSWAが組み込まれていなかったため、contribのソースをコピーして使用していたが、このソースはSWAモデルの保存/読み込みに対応していなかったので、PyTorch 1.7以降で使えるようになったAveragedModelを使用するように変更した。
forwardの複数引数対応
AveragedModelは、モデルのforwardの引数が一つの場合しか想定されていない。
dlshogiでは、2つの引数を使用するため、そのままでは利用できなかった。
そこで、AveragedModelのインスタンスのforwardメソッドを一時的に書き換える処理を行った。
以下のように2つの引数を1つの辞書で受け取って、それを展開して元のメソッドに渡すようにした。
forward_ = swa_model.forward swa_model.forward = lambda x : forward_(**x) update_bn(hcpe_loader(train_data, args.batchsize), swa_model) del swa_model.forward
データローダでは、2つの入力を辞書で渡すようにした。
def hcpe_loader(data, batchsize): for x1, x2, t1, t2, value in Hcpe3DataLoader(data, batchsize, device): yield { 'x1':x1, 'x2':x2 }
精度比較
今までの1世代ごとに平均化する方法と、複数世代に渡って平均化する方法で精度を比較した。
訓練データには、dlshogiの強化学習で生成した3世代分のデータ(各世代1億局面)を使用した。
1世代目はSWAなしで学習し、2世代目からSWAありで学習した。
今までの方法では、2世代目平均化した重みを出力して、それを読み込んで3世代目を学習した。
新しい方法では、2世代目と3世代目で続けて重みを平均化した。
4回測定し、テスト損失の平均で評価した。
テストデータには、floodgateのレート3500以上の対局の棋譜からサンプリングした856,923局面を使用した。
テスト方策損失 | テスト価値損失 | |
---|---|---|
今までの方法 | 1.782588468 | 0.522283693 |
新しい方法 | 1.780505845 | 0.521579308 |
SWAなし | 1.84344225 | 0.545063128 |
新しい方法では、今までの方法より、方策、価値ともに損失が低下しており精度が向上している。
今までの方法でも、SWAなしよりは精度が高い。
使用方法
新しい方法は、dlshogiのtrain_hcpe3.pyに実装している。
複数世代でSWAの重みを引き継げるように、モデルと状態の保存形式を、PyTorchのcheckpoint形式に変更した。
今まではChainerからの互換性を引きずっていたので、モデルと状態は別々のファイルになっていた。
--checkpointと--resume
「--checkpoint」で、保存ファイル名を指定する。
PyTorchのチュートリアルによると拡張子に*.pthをつけるようだ。
読み込みは「--resume」に「--checkpoint」で保存したファイル名を指定する。
--initmodelと--stateの廃止
今まで使用していた「--initmodel(-m)」は廃止する(互換性のため残している)。
互換性のため「--resume」で今までの「--state」で保存した状態も読めるようにしている。
「--state」は廃止した。
--model
「--model」で、SWAの重みを反映したモデルを出力する。
このファイルは、これまでのモデルファイルと互換性がある。
その際、Batch Normalizationの再計算を行う。
ここが時間がかかるので、途中でSWAを反映したモデル出力が必要なければ、「--checkpoint」だけで学習は継続できる。
なお、Batch Normalizationの再計算は全訓練データの順伝播を行うため時間がかかるが、使用する訓練データの量が精度にかなり影響するため、減らさない方がよい。
今回の変更は、train_hcpe3.pyにしか反映していない(train_hcpe3.pyは、train_hcpe.pyと同等の学習もできる)。
train_hcpe.pyは互換性のために変更していない。
まとめ
複数世代に渡って重みを平均化できるように、SWAの実装を変更した。
新しい方法の方が精度が上がることが確かめられた。