年末に新しいCPUが届いたので、正月はPCを組んでいた。
同時にフルタワーケースを買ったのだが、GPU3枚だと熱対策をしないと安定動作しなかったので、ドリルで加工してファンを増設したりと正月から働いてしまったorz
安定動作するようになったので、前回記事にしたUSIエンジンをリーグに加えた強化学習を走らせている。
10サイクルほど学習したら結果を記事にする予定である。
SWA(Stochastic Weight Averaging)
Leela Chess Zeroの最新動向を調べていて、SWA(Stochastic Weight Averaging)という手法がdlshogiでも効果がありそうなので、試してみることにした。
SWAは、訓練中にネットワークの重みを一定間隔置きに平均化することで、局所最適ではなくグローバル最適に収束させる手法である。
SWAの実装は、ChainerではGitHubで実装例が公開されている。
PyTorchでは、オプティマイザの拡張機能として提供されており、PyTorchの方が簡単に実装できる。
Stochastic Weight Averaging in PyTorch | PyTorch
dlshogiの学習部にはChainerを使用していたが、Chainerは開発が終了したため、Chainerで実装するよりも、今後も考えてPyTorchに移行してから、SWAを試すことにした。
PyTorchに移行
ChainerからPyTorchへの移行は、APIが類似しており、ほぼAPIが1対1で対応しているため、比較的容易に移行できる。
以下では、単純な変換で対応できなかった点を記す。
モデルの保存形式
dlshogiでは対局プログラムはcuDNNのAPIを直接使用しており、Chainerで保存したモデルをC++で読み込んでいる。
PyTorchの標準的な方法では、モデルはPythonのPickle形式に、パラメータのstate_dictを格納して保存される。
C++のプログラム側でこれを読み込むのは苦労するので、モデルの保存形式は、Python側でChainerと互換性を持たせて保存することにした。
Chainerの保存形式は、Numpy標準のnpz形式なので、state_dictに格納されているtorch.TensorをNumpyに変換してからnp.savez()で保存するればよい。
また、畳み込み層や、全結合層のパラメータとバイアスの名前が違うために、互換性を持たすには変換が必要だった。
Chainer | PyTorch |
---|---|
W | weight |
b | bias |
また、BatchNormのパラメータの名前は、PyTorchでは、畳み込みと全結合と同じweightとbiasが使用されているため、保存時に層の種類を判定する必要がある。
処理を簡易にするため、層の名前にnormもしくはbnを含むかで区別するようにした。
BatchNormのパラメータの対応は以下のようになる。
Chainer | PyTorch |
---|---|
gamma | weight |
beta | bias |
avg_mean | running_mean |
avg_var | running_var |
N | num_batches_tracked |
これで、C++側のプログラムの修正なしに、モデルが読み込めるようになった。
交差エントロピー誤差
Chainerでは、2値分類の交差エントロピー誤差の正解データには、0か1の整数を与える必要がある。
dlshogiでは、価値関数をブートストラップするために、自己対局中の探索結果の価値と、ゲーム結果の両方を損失として使用している。
前者の損失には、2確率変数の交差エントロピーを計算する必要があるため、自作の損失関数を使用していた。
PyTorchでは、BCEWithLogitsLossが正解データに確率変数を与えることができるため、自作の処理が不要になった。
後者の損失でも、引き分けを学習するため、同様に自作の損失関数を使用していたが、BCEWithLogitsLossを使用することができた。
正解率の計算
Chainerには、正解率の計算をするaccuracy関数とbinary_accuracy関数があるが、PyTorchでは用意されていない。
そのため、以下のような自作の関数を作成した。
def accuracy(y, t): return (torch.max(y, 1)[1] == t).sum().item() / len(t) def binary_accuracy(y, t): pred = y >= 0 truth = t >= 0.5 return pred.eq(truth).sum().item() / len(t)
これは標準で用意されていてもよいと思うのが。
まとめ
ChainerからPyTorchへの移行は比較的容易にできた。
今まで自作でがんばっていたがPyTorchで標準できるようになった箇所もあるが、逆にChainerにあった便利な機能がない箇所もあった。
ChainerからPyTorchへの移行できたので、次はSWAを試す予定である。