TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

将棋AIの進捗 その35(PyTorchに移行)

年末に新しいCPUが届いたので、正月はPCを組んでいた。
同時にフルタワーケースを買ったのだが、GPU3枚だと熱対策をしないと安定動作しなかったので、ドリルで加工してファンを増設したりと正月から働いてしまったorz
安定動作するようになったので、前回記事にしたUSIエンジンをリーグに加えた強化学習を走らせている。
10サイクルほど学習したら結果を記事にする予定である。

SWA(Stochastic Weight Averaging)

Leela Chess Zeroの最新動向を調べていて、SWA(Stochastic Weight Averaging)という手法がdlshogiでも効果がありそうなので、試してみることにした。

SWAは、訓練中にネットワークの重みを一定間隔置きに平均化することで、局所最適ではなくグローバル最適に収束させる手法である。
SWAの実装は、ChainerではGitHub実装例が公開されている。

PyTorchでは、オプティマイザの拡張機能として提供されており、PyTorchの方が簡単に実装できる。
Stochastic Weight Averaging in PyTorch | PyTorch

dlshogiの学習部にはChainerを使用していたが、Chainerは開発が終了したため、Chainerで実装するよりも、今後も考えてPyTorchに移行してから、SWAを試すことにした。

PyTorchに移行

ChainerからPyTorchへの移行は、APIが類似しており、ほぼAPIが1対1で対応しているため、比較的容易に移行できる。

以下では、単純な変換で対応できなかった点を記す。

モデルの保存形式

dlshogiでは対局プログラムはcuDNNのAPI直接使用しており、Chainerで保存したモデルをC++で読み込んでいる。

PyTorchの標準的な方法では、モデルはPythonのPickle形式に、パラメータのstate_dictを格納して保存される。
C++のプログラム側でこれを読み込むのは苦労するので、モデルの保存形式は、Python側でChainerと互換性を持たせて保存することにした。

Chainerの保存形式は、Numpy標準のnpz形式なので、state_dictに格納されているtorch.TensorをNumpyに変換してからnp.savez()で保存するればよい。
また、畳み込み層や、全結合層のパラメータとバイアスの名前が違うために、互換性を持たすには変換が必要だった。

Chainer PyTorch
W weight
b bias

また、BatchNormのパラメータの名前は、PyTorchでは、畳み込みと全結合と同じweightとbiasが使用されているため、保存時に層の種類を判定する必要がある。
処理を簡易にするため、層の名前にnormもしくはbnを含むかで区別するようにした。
BatchNormのパラメータの対応は以下のようになる。

Chainer PyTorch
gamma weight
beta bias
avg_mean running_mean
avg_var running_var
N num_batches_tracked

これで、C++側のプログラムの修正なしに、モデルが読み込めるようになった。

交差エントロピー誤差

Chainerでは、2値分類の交差エントロピー誤差の正解データには、0か1の整数を与える必要がある。
dlshogiでは、価値関数をブートストラップするために、自己対局中の探索結果の価値と、ゲーム結果の両方を損失として使用している。
前者の損失には、2確率変数の交差エントロピーを計算する必要があるため、自作の損失関数を使用していた。
PyTorchでは、BCEWithLogitsLossが正解データに確率変数を与えることができるため、自作の処理が不要になった。

後者の損失でも、引き分けを学習するため、同様に自作の損失関数を使用していたが、BCEWithLogitsLossを使用することができた。

正解率の計算

Chainerには、正解率の計算をするaccuracy関数とbinary_accuracy関数があるが、PyTorchでは用意されていない。
そのため、以下のような自作の関数を作成した。

def accuracy(y, t):
    return (torch.max(y, 1)[1] == t).sum().item() / len(t)

def binary_accuracy(y, t):
    pred = y >= 0
    truth = t >= 0.5
    return pred.eq(truth).sum().item() / len(t)

これは標準で用意されていてもよいと思うのが。

エポック数とイテレーション数のカウント

Chainerでは、オプティマイザが自動でイテレーション数のカウントして、エポック数をカウントアップするメソッドが用意されている。
PyTorchにはそれに相当する機能がないため、自分で変数をカウントアップし、オプティマイザの状態を保存する際に、それらも同時に保存する必要がある。

torch.save({
    'epoch': epoch,
    't': t,
    'optimizer_state_dict': optimizer.state_dict(),
    }, args.state)

これもPyTorchにも欲しい機能である。

まとめ

ChainerからPyTorchへの移行は比較的容易にできた。
今まで自作でがんばっていたがPyTorchで標準できるようになった箇所もあるが、逆にChainerにあった便利な機能がない箇所もあった。

ChainerからPyTorchへの移行できたので、次はSWAを試す予定である。