TadaoYamaokaの日記

山岡忠夫Homeで公開しているプログラムの開発ネタを中心に書いていきます。

PythonでAlphaZero Shogiを実装する その4

AlphaZeroでは、訓練と自己対局は並列で行われ、チェックポイントで自己対局で使用するネットワークが最新のネットワークに更新される。
チェックポイントは、ミニバッチサイズ4,096で、1,000ステップ間隔だが、チェックポイントの間に何ゲーム行われるかを論文から推測してみた。

チェックポイントあたりのゲーム数

  • 70万ステップを12時間で完了
  • チェックポイントは、ミニバッチサイズ4,096で、1,000ステップ間隔
  • 自己対局は、5,000アクター(※疑似コードから)が並列で対局

以上が、AlphaZeroの論文に記載されている条件である。

これだけでは、推測できないので、自己対局で増える局面数と学習に使える新規の局面数の増加速度が同じで、学習に使う1ゲームあたり平均150手という仮定をおくと、1,000ステップを5,000並列で実行し、1ステップが4096手分なので、チェックポイントあたりのゲーム数は、
1000ステップ / 5000 * 4096手 / 150手 = 5.46 ゲーム
と算出できる。

なお、論文では、AlphaZero Shogiでは、70万ステップを12時間で完了しているため、1手あたりの時間は、
12時間/70万ステップ/4096*5000 = 0.075秒
と計算できる。
1回の対局が平均150手と仮定すると、1ゲームあたりは、11.3秒となる。

実装への反映

dlshogi-zeroは、訓練と自己対局を別々のプログラムで作成しているので、訓練と自己対局のバランスを、AlphaZero Shogiと同じにしようとすると、頻繁にプログラムを切り替えることになる。
プログラムの開始時のモデルのコンパイルなどを毎回行うと効率が悪いので、効率の良い仕組みを考えたい。

GPU1枚でも実行できるようにしようとすると、GPUは排他的に利用されるため、同時に実行はできないので、交互に実行することになる。
そうすると、1つのプログラムで、モデルを共有して訓練と自己対局を行うようにした方がよさそうだ。

1つのプログラムで訓練と自己対局を交互に行うとすると、以下のような流れになる。

  1. チェックポイントを一定の局面数として、並列で実行しているエージェントのどれかのゲームが終了するたびに、チェックポイントに達しているかチェックする。
  2. チェックポイントに達していたら、すべてのエージェントの実行を止めて、訓練を行う。
  3. 訓練が完了したらモデルが更新される。
  4. 最新のモデルを使って自己対局を再開する。


できれば、この実装に変更して、技術書典6の本に反映したい(間にあうだろうか・・・)。

追記(別の方法で検証)

自己対局で増える局面数と学習に使える新規の局面数の増加速度が同じと仮定したが、実際は自己対局で増える局面数の方が多いかもしれない。
ここからは誤差の大きい推測になるが、初代AlphaGoでは方策ネットワークの推論に3msかかっているため、ネットワークの規模の比率から1手あたりの時間を推測して、算出してみる。
初代AlphaGoの方策ネットワークは13層、192フィルタ、画像サイズ19×19で、
AlphaZero Shogiの方策ネットワークは83層(Resnet40ブロック+入力1+出力2)、256フィルタ、画像サイズ9×9で、
推論の計算量は、層数L、フィルタ数K、画像の幅Wとすると、計算量は、O(L \times K^2 \times W^2)(あっているか自信なし)となるので、
計算量は、1.29倍となる。
したがって、1手800シミュレーションなので、1手あたり、6.11秒かかる。
1,000ステップあたりのゲーム数は、
12時間/70万ステップ*1000/6.11/150*5000=336.56ゲーム
となる。

自己対局で増える局面数と学習に使える新規の局面数の増加速度が同じという仮定は外れていそうだ。

こちらの推測を元にすると、別々のプログラムでもよいかもしれない。

2019/3/29追記

コメントで、AlphaZero論文のTable S3に訓練ゲーム数が記載されているとご指摘いただいたので、その値で計算し直してみました。
Table S3を完全に見落としていました。

Table S3には、

Mini-batches 700k
Traing Time 12h
Traning Games 24 million

と記載されています。

したがって、チェックポイントごとのゲーム数は、
24,000,000ゲーム*1000ステップ/70万ステップ=34285.71ゲーム
となります。
予測よりもずっと大きかったです。

あまりにも予測と差が大きいので、予測の方の何かの前提が間違っていそうです。
Table S3にThinking Timeが800sims~80msと記載されているので、1手80msのようです。

1つのアクターが並列化なしで800シミュレーションを80msで実行するのはおそらく不可能なので、一つ一つのアクターについてもなんらかの並列化を行っていそうです。