TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

将棋AIの進捗 その57(SWAの修正)

dlshogiの学習では、SWA(Stochastic Weight Averaging)を導入している。

今までは、1世代学習するごとに、平均化した重みを出力して、次の世代ではその重みを使用して学習していた。
しかし、SWAは通常複数エポックに渡って平均化してから、最後に平均化した重みを出力を行う。
また、Leela Chess Zeroでは、複数世代にわたって重みを平均化しているようなので、dlshogiもそのように変更した。

実装方法の変更

dlshogiにSWAを実装したときは、まだPyTorchに正式にSWAが組み込まれていなかったため、contribのソースをコピーして使用していたが、このソースはSWAモデルの保存/読み込みに対応していなかったので、PyTorch 1.7以降で使えるようになったAveragedModelを使用するように変更した。

forwardの複数引数対応

AveragedModelは、モデルのforwardの引数が一つの場合しか想定されていない。
dlshogiでは、2つの引数を使用するため、そのままでは利用できなかった。

そこで、AveragedModelのインスタンスのforwardメソッドを一時的に書き換える処理を行った。

以下のように2つの引数を1つの辞書で受け取って、それを展開して元のメソッドに渡すようにした。

        forward_ = swa_model.forward
        swa_model.forward = lambda x : forward_(**x)
        update_bn(hcpe_loader(train_data, args.batchsize), swa_model)
        del swa_model.forward

データローダでは、2つの入力を辞書で渡すようにした。

def hcpe_loader(data, batchsize):
    for x1, x2, t1, t2, value in Hcpe3DataLoader(data, batchsize, device):
        yield { 'x1':x1, 'x2':x2 }

精度比較

今までの1世代ごとに平均化する方法と、複数世代に渡って平均化する方法で精度を比較した。

訓練データには、dlshogiの強化学習で生成した3世代分のデータ(各世代1億局面)を使用した。
1世代目はSWAなしで学習し、2世代目からSWAありで学習した。
今までの方法では、2世代目平均化した重みを出力して、それを読み込んで3世代目を学習した。
新しい方法では、2世代目と3世代目で続けて重みを平均化した。
4回測定し、テスト損失の平均で評価した。
テストデータには、floodgateのレート3500以上の対局の棋譜からサンプリングした856,923局面を使用した。

テスト方策損失 テスト価値損失
今までの方法 1.782588468 0.522283693
新しい方法 1.780505845 0.521579308
SWAなし 1.84344225 0.545063128

新しい方法では、今までの方法より、方策、価値ともに損失が低下しており精度が向上している。
今までの方法でも、SWAなしよりは精度が高い。

使用方法

新しい方法は、dlshogiのtrain_hcpe3.pyに実装している。
複数世代でSWAの重みを引き継げるように、モデルと状態の保存形式を、PyTorchのcheckpoint形式に変更した。
今まではChainerからの互換性を引きずっていたので、モデルと状態は別々のファイルになっていた。

--checkpointと--resume

「--checkpoint」で、保存ファイル名を指定する。
PyTorchのチュートリアルによると拡張子に*.pthをつけるようだ。
読み込みは「--resume」に「--checkpoint」で保存したファイル名を指定する。

--initmodelと--stateの廃止

今まで使用していた「--initmodel(-m)」は廃止する(互換性のため残している)。
互換性のため「--resume」で今までの「--state」で保存した状態も読めるようにしている。
「--state」は廃止した。

--model

「--model」で、SWAの重みを反映したモデルを出力する。
このファイルは、これまでのモデルファイルと互換性がある。

その際、Batch Normalizationの再計算を行う。
ここが時間がかかるので、途中でSWAを反映したモデル出力が必要なければ、「--checkpoint」だけで学習は継続できる。

なお、Batch Normalizationの再計算は全訓練データの順伝播を行うため時間がかかるが、使用する訓練データの量が精度にかなり影響するため、減らさない方がよい。


今回の変更は、train_hcpe3.pyにしか反映していない(train_hcpe3.pyは、train_hcpe.pyと同等の学習もできる)。
train_hcpe.pyは互換性のために変更していない。

まとめ

複数世代に渡って重みを平均化できるように、SWAの実装を変更した。
新しい方法の方が精度が上がることが確かめられた。