TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

【dlshogi】自動生成された定跡を強化学習に活用する

将棋AIの大会では、2年ほど前から定跡により勝敗が高い確率で決まるという状況になっている。

相手の準備していない定跡で嵌めたり、相手が準備した定跡に嵌らないようにするという盤外の戦術が必要になっており、AIの技術的な要素とは別の戦いになっている。
この状況に不満を感じているが、大会のルール上、定跡の対策なしには勝てなくなっているため、対策を強いられている。

手動で定跡を作成しているチームが多い中、「手動で作成するのはAIの大会ではない」という考えを持っているので、dlshogiは定跡も自動生成で行っている。

dlshogiの定跡作成の手法の概要は、第34回世界コンピュータ将棋選手権の詳細アピール文書に記載している。

定跡の強化学習への活用

自動定跡作成は、上記のように大会への対策として着手したが、モデルの学習においても活用できることに気付き、強化学習にも自動生成した定跡を活用している。

具体的には、以下の箇所に定跡を活用できる。
1. 定跡に登録されている局面を自己対局の開始局面とする
2. モデルにより推論した方策と価値に、定跡を反映する
3. 定跡により評価が誤っていたことがわかった局面を重点的に学習する
4. 大量の互角局面集を作成して棋力測定に使用する

定跡に登録されている局面を自己対局の開始局面とする

大会では、序盤は定跡を使用するため、モデルを使用するのは定跡を抜けた後の局面からになる。
機械学習するモデルは、ノーフリーランチ定理で主張されるようにあらゆる局面に対応できる万能のモデルは作成できない。
つまり、大会での勝率を高めるには、定跡を抜けた後の局面を重点的に学習することが必要になる。
そのため、定跡に登録されている局面を自己対局の開始局面とすることが効果的である。

モデルにより推論した方策と価値に、定跡を反映する

定跡自動生成により、モデルの評価が誤っていた局面は、その先まで調べて局面により評価が訂正される。
それを、自己対局時に反映することで、モデルの誤りを積極的に訂正することができる。

dlshogiでは、Use_Book_PolicyというUSIオプションがあり、モデルの推論結果の方策と価値に、定跡の出現回数を元にした方策と価値を反映することができる。

事前確率に定跡の遷移確率も使用する(use_book_policyオプション追加) · TadaoYamaoka/DeepLearningShogi@a059c29 · GitHub


これは、棋神アナリティクスでも採用されており、オプションを有効にすることで無効の場合と比較して勝率は57%になる。
第34回世界コンピュータ将棋選手権準優勝! dlshogiアップデート!!|棋神アナリティクス開発

定跡により評価が誤っていたことがわかった局面を重点的に学習する

定跡自動作成時に大会と同じくらいの時間をかけて探索を行っている。
その後、定跡内の局面をαβ法で探索することで、定跡をリファインする。

リファイン前後の差分から、評価を誤っていた局面が抽出できる。

その誤っていた局面を開始局面として、NNUE系を含めた連続対局を行うことで、評価の誤りを訂正するための棋譜を生成できる。

大量の互角局面集を作成して棋力測定に使用する

自動生成した局面から互角局面を抽出することで、floodgateの棋譜から作成した互角局面よりも多くの精度の高い互角局面集を作成できる。
定跡から作成した互角局面集を使うことで、より多様な局面での棋力測定ができる。

また、中終盤の互角局面など条件を指定した互角局面集を作成することもできる。

dlshogi定跡の統計

現状のdlshogiの自動生成した定跡は、約285万局面、1000万手登録されている。

手数の統計は、以下の通り。

count 2848786
mean  50.4030
std   29.8483
min   0
25%   29.0000
50%   45.0000
75%   67.0000
max   404

出現頻度の高い手順ほど深くまで掘られており、最大で404まで登録されている。平均で50手ある。

先手の評価値の統計は、以下の通り。

count 1400665
mean  425.3173
std   1658.8552
min   -30000
25%   -17.0000
50%   247.0000
75%   647.0000
max   30000

後手の評価値の統計は、以下の通り。

count 1448121
mean  -138.5972
std   1111.3596
min   -30000
25%   -388.0000
50%   -89.0000
75%   116.0000
max   30000

先手も後手も手順によっては詰みまで探索されている。
先手は平均で評価値はプラスで、後手はマイナスになっており、AIによると、いかに後手が不利なゲームであるかがわかる。

まとめ

定跡の自動生成がモデルの学習にも活用できることを紹介した。
自己対局で現れる終盤の局面は、理論的に収束する強化学習アルゴリズムであっても、ほとんどの場合終盤局面は二度と現れない局面であり、一度誤って学習された局面が訂正されることは現実的には期待できない。
自動生成した定跡を活用することで評価の誤っていた局面を認識して、それを積極的に訂正することができる。
これにより、通常の強化学習では見過ごされがちな終盤局面の評価ミスを効果的に補正でき、モデルの精度向上に寄与することが期待できる。

定跡自動生成を強化学習のサイクルに組み込んだ模式図である。

【dlshogi】入玉特徴量 その3

前回入玉宣言に関する特徴量を加えたモデルの強さを測定したが、入玉宣言勝ちした棋譜がなく入玉の精度が測定できていなかった。

NNUE系を相手にして連続対局を行い測定しなおした。

強さ

NNUE系を相手に互角局面から連続対局を行った結果は以下の通り。
NNUE系は持ち時間を調整しているので、参考値。

   # PLAYER          :  RATING  ERROR  POINTS  PLAYED   (%)  CFS(%)    W    D    L  D(%)
   1 nnue-4th        :    29.2   20.9   333.0     600    56      53  313   40  247     7
   2 resnet          :    27.5   40.1    99.5     200    50      92   92   15   93     8
   3 nyugyoku20      :   -22.0   40.7    85.5     200    43      64   79   13  108     6
   4 nyugyoku10      :   -34.6   38.7    82.0     200    41     ---   76   12  112     6

White advantage = 11.90 +/- 13.71
Draw rate (equal opponents) = 6.74 % +/- 1.05

前回と同様に、入玉特徴量なしの通常のResNetが一番強い(信頼度 92%)という結果になった。
入玉特徴量を加えると、詰みによって勝つ将棋には負の影響があることがわかった。

入玉宣言勝ちの精度

入玉宣言勝ちの数をカウントした結果は以下の通り。

モデル 入玉宣言勝ち
入玉特徴量なし 18
入玉特徴量あり(残り点数10点未満) 16
入玉特徴量あり(残り点数20点未満) 24

入玉特徴量あり(残り点数20点未満)の場合が、入玉宣言勝ちの数が最も多かった。

カイ二乗検定で、適合度検定を行うと、帰無仮説=すべてのモデルで入玉宣言勝ちの割合は同じに対して、p値は0.48となり、帰無仮説は棄却されない。
よって、有意差はない。

import numpy as np
import scipy.stats as stats

observed = np.array([18, 16, 24])
total_samples = 200 * 3
expected_prob = sum(observed) / total_samples
expected = np.array([200 * expected_prob] * 3)
chi2_stat, p_value = stats.chisquare(f_obs=observed, f_exp=expected)
print(f"p値: {p_value}")
p値: 0.4079740440452

200回の対局では測定できない程度の違いであり、入玉特徴量は実際の対局では効果はなさそうである。

まとめ

入玉特徴量を加えたモデルで、NNUE系と連続対局を行い強さを比較した。
結果、入玉特徴量がない通常のResNetが一番強いという結果になった。
また、入玉宣言勝ちした数を比較したが、有意差はなかった。
入玉特徴量が効果を発揮するケースは少なく、通常の対局が弱くなるため、今回実験した条件では入玉特徴量を加えることは効果的でないと言える。
ただし、入玉の訓練データを増やした場合など条件を変えれば効果がある可能性はある。

【dlshogi】入玉特徴量 その2

前回は、入玉宣言に関する入力特徴量を追加して、入玉宣言した棋譜に対する評価精度が向上することを確認した。

今回は、入力特徴量の宣言までの残り点数を前回10点未満にしたところを、20点未満にした場合で比較した。
また、入力特徴量を増やしたことでNPSがどれくらい低下するか、互角局面から対局した場合の強さについて比較した。

残り点数

前回の入力特徴量は以下の通り。

入玉特徴量:

  • 入玉しているか
  • 敵陣の玉を除く枚数(10枚までの残り枚数。9枚以下をワンホットで与える。)
  • 残り点数(先手は28点、後手は27点までの残り点数。9点以下をワンホットで与える。)

残り点数を、9点以下をワンホットで与えていたが、今回は、19点以下を与えるように変更した。

精度

残り点数を増やした場合の精度は以下の通り。

評価データ:floodgateからサンプリングしたデータ
入玉特徴量有無 方策損失 価値損失 方策正解率 価値正解率
入玉特徴量なし 1.4413 0.4682 0.5224 0.7580
残り点数10点未満 1.4422 0.4677 0.5222 0.7586
残り点数20点未満 1.4435 0.4678 0.5220 0.7587
評価データ:入玉宣言勝ちした棋譜
入玉特徴量有無 方策損失 価値損失 方策正解率 価値正解率
入玉特徴量なし 2.3377 0.1030 0.3884 0.9539
残り点数10点未満 2.3371 0.0883 0.3874 0.9622
残り点数20点未満 2.3314 0.0820 0.3874 0.9654

floodgateからサンプリングした評価データでは、方策、価値ともにほぼ違いがない。

入玉宣言勝ちした棋譜では、価値の損失・正解率がわずかに改善している。

NPS

floodgateの棋譜からサンプリングした100局面でA100で1秒探索した際のNPS(4回測定)は以下の通り。

入玉特徴量なし 残り点数10点未満 残り点数20点未満
平均値 45485 45058 45003
中央値 47950 47449 47468
最大値 49150 48682 48676
最小値 25530 24493 25058

平均値は、入力特徴量なしのResNetと比較して、残り点数10点未満、残り点数20点未満はそれぞれ、99.06%、98.94%に低下している。
低下の割合は小さい。

ワンホットの特徴量は、値が0の特徴量が多くなるため、TensorRTがうまく計算を省略してくれることで影響が小さくなっていると考える。
また、ワンホットの特徴量は、1ビットにしてGPUに転送しているので、特徴量が増えてもほとんど影響を受けない。

強さ

入力特徴量なし(resnet)、残り点数10点未満(nyugyoku10)、残り点数20点未満(nyugyoku20)で、互角局面からリーグ戦で持ち時間3分1手2秒加算で連続対局した結果は以下の通り。

   # PLAYER        :  RATING  ERROR  POINTS  PLAYED   (%)  CFS(%)    W    D    L  D(%)
   1 resnet        :    20.4   18.1   326.0     600    54      87  297   58  245    10
   2 nyugyoku20    :     2.7   17.8   303.5     600    51      95  279   49  272     8
   3 nyugyoku10    :   -23.1   18.2   270.5     600    45     ---  244   53  303     9

White advantage = 20.43 +/- 10.98
Draw rate (equal opponents) = 8.94 % +/- 0.98

入力特徴量なしが最も強い(信頼度 87%)という結果になった。
残り点数は、20点未満の方が強いという結果になった。

各モデルがどれくらい入玉宣言勝ちしているか確認したところ、入玉宣言勝ちは1局もなかった。
今回の測定では入玉宣言の精度は測定できていなかった。
対局相手にNNUE系も含めて測定してみたい。

まとめ

入玉宣言に関する入力特徴量のうち宣言までの残り点数を、前回は10点未満のところ、今回は20点未満まで含めるようにして比較した。
結果、入玉宣言勝ちした棋譜に対する価値の精度がわずかに向上することが確認できた。
また、NPSは、入玉宣言に関する入力特徴量を増やしても、98.94%くらいにしか低下しないことがわかった。
強さの比較では、入玉宣言に関する入力特徴量がない方が強いという結果になった。
ただし、入玉宣言勝ちした棋譜が1つもなかったため、入玉の強さは測定はできていない。条件を変えて測定してみたい。

【dlshogi】入玉特徴量

現在のdlshogiの入力特徴量には、入玉宣言に関連する特徴量を含んでいない。
入玉宣言の精度を上げるため、入玉宣言に関する特徴量を加えることを検討する。

現在の入力特徴量

現在の入力特徴量:

  • 盤上の駒
  • 駒の種類ごとの効き
  • 効き数
  • 持ち駒
  • 王手

持ち駒は、歩については、8枚以上は区別しない。
また、先手、後手の区別はなく、後手は盤面を180度回転して入力する。

課題

この入力特徴量では、入玉宣言勝ちの条件を正しく認識できない問題がある。

入玉宣言勝ちは、正確な持ち駒の枚数が必要なため、8枚以上の歩も認識する必要がある。
また、先手、後手で宣言に必要な点数が異なるため、先手、後手を区別する必要がある。

また、機械学習モデルでは、点数の合計など演算が必要な特徴量は、演算結果を特徴量として入力する方が精度が上がる場合がある。

入玉特徴量

先手、後手の区別について、入玉宣言に関わらない局面では、先手、後手の区別は不要であるため、単に先手、後手を特徴量とするよりも、入玉宣言に必要な残り点数を特徴量とすることで、間接的に先手、後手の区別することにする。
こうすることで、局面を先手、後手で等価に扱える。

点数すべてを特徴量にすると、計算コストが増えるため、残り点数10点未満を特徴量とする。

入玉特徴量:

  • 入玉しているか
  • 敵陣の玉を除く枚数(10枚までの残り枚数。9枚以下をワンホットで与える。)
  • 残り点数(先手は28点、後手は27点までの残り点数。9点以下をワンホットで与える。)

実験

入玉特徴量の有無で、精度が変わるか比較した。

訓練条件:

  • ResNet 20ブロック256フィルタのモデル
  • 訓練データ約3.9億局面
  • バッチサイズ4096
  • Momentum SGD
  • 学習率0.04からエポックごとに半減
  • 8エポック

評価データは、2017年~2018年6月のfloodgateのR3500以上の棋譜からサンプリングした856,923局面の他に、NNUE系の10M~80Mノードで対局した棋譜から入玉宣言勝ちした棋譜を抽出した20,544,764局面を用いた。
精度はばらつくため4回測定して平均する。

実験結果

評価データ:floodgateからサンプリングしたデータ
入玉特徴量有無 方策損失 価値損失 方策正解率 価値正解率
入玉特徴量なし 1.4413 0.4682 0.5224 0.7580
入玉特徴量あり 1.4422 0.4677 0.5222 0.7586
評価データ:入玉宣言勝ちした棋譜
入玉特徴量有無 方策損失 価値損失 方策正解率 価値正解率
入玉特徴量なし 2.3377 0.1030 0.3884 0.9539
入玉特徴量あり 2.3371 0.0883 0.3874 0.9622

考察

floodgateからサンプリングした評価データでは、方策の精度はほぼ変わらず、価値の損失もほぼ変わらない。

入玉宣言勝ちした棋譜から抽出した評価データでは、方策の精度はほぼ変わらず、価値の損失が入力特徴量がある場合少しだけ下がっている。

入玉特徴量を加えた方が、入玉宣言勝ちの局面評価の精度は少しだけ上がると言えそうである。

まとめ

dlshogiに入玉特徴量を加えて入玉宣言の精度が上がるか検証した。
結果、入玉宣言勝ちした棋譜の局面の評価精度が少し向上することが確認できた。

入玉特徴量を加えることで、モデルの入玉宣言に関する認知負荷が下がることで、全体の精度が上がることを期待していたが、そのようなことはなかった。
今回は残り10点未満を特徴量にしたが、20点未満などの条件でも検証してみたい。

第5回世界将棋AI電竜戦 結果報告

11/30、12/1に開催された文部科学大臣杯第5回世界将棋AI電竜戦にdlshogiとして参加しました。

結果

予選を1位で通過し、決勝リーグでは、1位と0.8勝差で準優勝という結果になりました。

dlshogiは、第2回から4年連続準優勝です。

今大会の感想

優勝した氷彗がNNUE系では飛躍的に棋力を伸ばしていたのが印象的です。
ここ数年NNUE系の棋力は頭打ち感があり、どのチームも定跡に力を入れて勝っていた印象がありましたが、氷彗は定跡だけでなく探索精度も向上していました。

dlshogiも昨年までは定跡に力を入れていましたが、今大会ではモデル精度を高めました。
精度を高めるためにモデルサイズを大きくしたことで、NPSが落ちることで終盤読み負けを心配していましが、結果としては終盤ミスがあったのは1局のみで、負けたのは主に後手で定跡負けした対局で、終盤も比較的安定していました。


以下は、今大会の工夫点についてです。

2.2億パラメータのモデル

今大会では、Resnet+Transformerの2.2億パラメータのモデルを学習しました。
モデルの詳細は、以前の記事を参照してください。

前大会では、Resnetの30ブロック384フィルタのモデルで、パラメータ数は約0.8億です。
約2.7倍のパラメータ数になっています。

このサイズのモデルでは、80GのGPUメモリでもバッチサイズ4096では学習できず、8GPUを使用してバッチサイズ512で分散学習を行いました。
分散学習できるように、dlshogiのPyTorch Lightning対応を行いました。
8GPU使用して、収束するまで学習するのに約1か月かかっています。

モデル精度

今回のモデル(pre55)と前大会のモデル(pre44)の精度を比較した結果を示します。テストデータには、2017年~2018年6月のfloodgateのR3500以上の棋譜からサンプリングした856,923局面(重複なし)を使用しています。

モデル 方策損失 価値損失 方策正解率 価値正解率
pre55 1.255327 0.433520 0.571877 0.781784
pre44 1.305959 0.440115 0.558544 0.776513

方策で1.3%、価値で0.5%正解率が向上しています。

強さ

方策のみで強さを比較した結果では、前大会のモデルからR+77.5だけ強くなっています。
同一持ち時間の対局では、R+28.6だけ強くなっています。

NPS

NPSは、前大会のモデルと比較して45.6%に低下しています。

定跡

第34回世界コンピュータ将棋選手権から続けて定跡の自動生成を行いました。

戦型

定跡の自動生成は、開始局面から自動で行っていますが、大会中に相手によって戦型を切り替えられるように、特定戦型からも作成しています。
しかし、先手番は角換わりの勝率が高いため、今大会ではすべて角換わりとしました。

dlshogiは、終盤でNNUE系に読み負けることがあるため、終盤で読みが重要になる角換わりは少しリスクがあると思っていましたが、角換わりは手が狭いため先手はかなり深くまで準備できていました。

決勝リーグのRyfamateとの対局では、先手角換わりで、137手(時間調整の4手除く)まで定跡で、定跡を抜けた時点で評価値4824でした。

懸念した終盤で読み負けしたのは、決勝リーグのDaigorillaとの対局でした。終盤の精度はまだ改善が必要だと思っています。

後手の対策

後手番は、先手が角換わりを選択した場合、角換わりを避ける有効な手は見つかっていないため、角換わりを受けることにしました。
第34回世界コンピュータ将棋選手権では、勝率が少し下がっても角換わりを避けるため、6手目に1四歩を指す作戦を取りましたが、結果は負けと千日手が多かったため、成功とは言えませんでした。
今大会では、後手番の角換わりも広く定跡を準備したため、相手の定跡を外して互角に持ち込める自信はありました。
しかし、決勝リーグで2局は定跡負けしました。
決勝リーグのポン太との対局では、dlshogiは60手で定跡を抜けた後もポン太は75手目まで定跡で、定跡を抜けた後、先手勝勢となっていました。

後で調べると、67手まではfloodgateに前例があり、先手が優勢であることは事前に調べておけばわかっていました。
自動生成にこだわっていましたが、floodgateの棋譜も事前に調査しておく必要がありそうです。

まとめ

第5回世界将棋AI電竜戦に参加しました。
今大会では優勝した氷彗の躍進が印象的でした。
dlshogiもモデルのパラメータ数を増やすことで着実に強くなっており、今大会でも準優勝という成績を残すことができました。

最後に、大会を運営してくださった主催者様、対局してくださった他のチーム、そして応援してくださった皆様にお礼申し上げます。

Gumbel AlphaZeroの論文を読む その4

前回の続き

探索の内部処理

探索の処理は、searchに書かれている。

引数
  • params: ルートおよび再帰関数に渡されるパラメータ。
  • rng_key: 乱数生成器の状態。
  • root: ルートノードの初期状態で、事前確率、価値、埋め込みを含む。
  • recurrent_fn: 葉ノードおよび未訪問アクションに対して呼び出される関数。
  • root_action_selection_fn: ルートでアクションを選択するための関数。
  • interior_action_selection_fn: シミュレーション中にアクションを選択するための関数。
  • num_simulations: 実行するシミュレーションの数。
  • max_depth: 探索木の最大深度。
  • invalid_actions: ルートでの無効なアクションのマスク。
  • extra_data: ツリーに渡される追加データ。
  • loop_fn: シミュレーションを実行するための関数。
行動選択関数の切り替え

ルートノードと中間ノードで行動選択関数を切り替える。
行動選択関数については、別途解説。

  action_selection_fn = action_selection.switching_action_selection_wrapper(
      root_action_selection_fn=root_action_selection_fn,
      interior_action_selection_fn=interior_action_selection_fn
  )
def switching_action_selection_wrapper(
    root_action_selection_fn: base.RootActionSelectionFn,
    interior_action_selection_fn: base.InteriorActionSelectionFn
) -> base.InteriorActionSelectionFn:
  """Wraps root and interior action selection fns in a conditional statement."""

  def switching_action_selection_fn(
      rng_key: chex.PRNGKey,
      tree: tree_lib.Tree,
      node_index: base.NodeIndices,
      depth: base.Depth) -> chex.Array:
    return jax.lax.cond(
        depth == 0,
        lambda x: root_action_selection_fn(*x[:3]),
        lambda x: interior_action_selection_fn(*x),
        (rng_key, tree, node_index, depth))

  return switching_action_selection_fn
バッチサイズとバッチインデックス範囲の設定

バッチサイズとバッチインデックス範囲を設定する。
バッチサイズはルートの価値の1番目の次元から取得する。
max_depthとinvalid_actionsが提供されていない場合、デフォルト値を設定する。

  # Do simulation, expansion, and backward steps.
  batch_size = root.value.shape[0]
  batch_range = jnp.arange(batch_size)
  if max_depth is None:
    max_depth = num_simulations
  if invalid_actions is None:
    invalid_actions = jnp.zeros_like(root.prior_logits)
ループの本体関数

シミュレーションを実行するループの本体を定義する。
以下のステップを実行する。

  • RNGキーを分割。
  • シミュレーションを実行し、行動を選択する。
  • ノードが未訪問の場合、新しいノードを追加して木を拡張する。
  • バックアップを行い木を更新する。

展開とバックアップは別途解説。

def body_fun(sim, loop_state):
  rng_key, tree = loop_state
  rng_key, simulate_key, expand_key = jax.random.split(rng_key, 3)
  simulate_keys = jax.random.split(simulate_key, batch_size)
  parent_index, action = simulate(simulate_keys, tree, action_selection_fn, max_depth)
  next_node_index = tree.children_index[batch_range, parent_index, action]
  next_node_index = jnp.where(next_node_index == Tree.UNVISITED, sim + 1, next_node_index)
  tree = expand(params, expand_key, tree, recurrent_fn, parent_index, action, next_node_index)
  tree = backward(tree, next_node_index)
  loop_state = rng_key, tree
  return loop_state
木の初期化

ルートノードから木を初期化する。

  # Allocate all necessary storage.
  tree = instantiate_tree_from_root(root, num_simulations,
                                    root_invalid_actions=invalid_actions,
                                    extra_data=extra_data)
def instantiate_tree_from_root(
    root: base.RootFnOutput,
    num_simulations: int,
    root_invalid_actions: chex.Array,
    extra_data: Any) -> Tree:
  """Initializes tree state at search root."""
  chex.assert_rank(root.prior_logits, 2)
  batch_size, num_actions = root.prior_logits.shape
  chex.assert_shape(root.value, [batch_size])
  num_nodes = num_simulations + 1
  data_dtype = root.value.dtype
  batch_node = (batch_size, num_nodes)
  batch_node_action = (batch_size, num_nodes, num_actions)

  def _zeros(x):
    return jnp.zeros(batch_node + x.shape[1:], dtype=x.dtype)

  # Create a new empty tree state and fill its root.
  tree = Tree(
      node_visits=jnp.zeros(batch_node, dtype=jnp.int32),
      raw_values=jnp.zeros(batch_node, dtype=data_dtype),
      node_values=jnp.zeros(batch_node, dtype=data_dtype),
      parents=jnp.full(batch_node, Tree.NO_PARENT, dtype=jnp.int32),
      action_from_parent=jnp.full(
          batch_node, Tree.NO_PARENT, dtype=jnp.int32),
      children_index=jnp.full(
          batch_node_action, Tree.UNVISITED, dtype=jnp.int32),
      children_prior_logits=jnp.zeros(
          batch_node_action, dtype=root.prior_logits.dtype),
      children_values=jnp.zeros(batch_node_action, dtype=data_dtype),
      children_visits=jnp.zeros(batch_node_action, dtype=jnp.int32),
      children_rewards=jnp.zeros(batch_node_action, dtype=data_dtype),
      children_discounts=jnp.zeros(batch_node_action, dtype=data_dtype),
      embeddings=jax.tree.map(_zeros, root.embedding),
      root_invalid_actions=root_invalid_actions,
      extra_data=extra_data)

  root_index = jnp.full([batch_size], Tree.ROOT_INDEX)
  tree = update_tree_node(
      tree, root_index, root.prior_logits, root.value, root.embedding)
  return tree
シミュレーションの実行

ループ関数を使用してnum_simulations回シミュレーションを実行する。
ループ関数はデフォルトで、jax.lax.fori_loopを使用する。
結果の木を返す。

  # Allocate all necessary storage.
  tree = instantiate_tree_from_root(root, num_simulations,
                                    root_invalid_actions=invalid_actions,
                                    extra_data=extra_data)
  _, tree = loop_fn(
      0, num_simulations, body_fun, (rng_key, tree))

  return tree

まとめ

探索の内部処理を解説した。
次回は、探索の内部処理で呼ばれている各関数の詳細を解説する。

Gumbel AlphaZeroの論文を読む その3

前回に続き、examples/visualization_demo.py のソースを解説する。

探索

探索の処理は、gumbel_muzero_policyに書かれている。

引数は、以下の通り。

  • params: ルートおよび再帰関数に渡されるパラメータ。
  • rng_key: 乱数生成器の状態。
  • root: (prior_logits, value, embedding)の形式のRootFnOutput。prior_logitsは方策ネットワークからのもので、形状はそれぞれ([B, num_actions], [B], [B, ...])。
  • recurrent_fn: シミュレーションステップで取得された葉ノードおよび未訪問アクションに対して呼び出される関数。引数として(params, rng_key, action, embedding)を取り、RecurrentFnOutputと新しい状態埋め込みを返す。
  • num_simulations: シミュレーションの数。
  • invalid_actions: 無効なアクションのマスク。無効な行動は1、有効な行動は0のマスク。形状は[B, num_actions]。
  • max_depth: シミュレーション中に許可される最大探索木の深さ。
  • loop_fn: シミュレーションを実行するために使用される関数。Haikuモジュール内でこの関数を使用する場合、hk.fori_loopを渡す必要があるかもしれない。
  • qtransform: ノードの完成したQ値を取得するための関数。
  • max_num_considered_actions: ルートノードで展開される最大行動数。有効な行動の数が少ない場合は、より少ない行動が展開される。
  • gumbel_scale: Gumbelノイズのスケール。完全情報ゲームの評価ではgumbel_scale=0.0を使用できる。


処理の流れは以下の通り。

無効な行動をマスクする

無効な行動の事前確率のlogitをfloatの最小値にする。

  # Masking invalid actions.
  root = root.replace(
      prior_logits=_mask_invalid_actions(root.prior_logits, invalid_actions))
def _mask_invalid_actions(logits, invalid_actions):
  """Returns logits with zero mass to invalid actions."""
  if invalid_actions is None:
    return logits
  chex.assert_equal_shape([logits, invalid_actions])
  logits = logits - jnp.max(logits, axis=-1, keepdims=True)
  # At the end of an episode, all actions can be invalid. A softmax would then
  # produce NaNs, if using -inf for the logits. We avoid the NaNs by using
  # a finite `min_logit` for the invalid actions.
  min_logit = jnp.finfo(logits.dtype).min
  return jnp.where(invalid_actions, min_logit, logits)
Gumbelノイズの生成

Jaxでは同じキーからは同じ乱数が生成されるため、jax.random.splitで新しいキーを生成する。
jax.random.gumbelを使用して、Gumbel分布に従う乱数を生成する。
次元は、事前確率のlogitsの次元とする。

  # Generating Gumbel.
  rng_key, gumbel_rng = jax.random.split(rng_key)
  gumbel = gumbel_scale * jax.random.gumbel(
      gumbel_rng, shape=root.prior_logits.shape, dtype=root.prior_logits.dtype)
探索の実行

探索処理を呼び出す。
探索処理の内容は別途解説。

  # Searching.
  extra_data = action_selection.GumbelMuZeroExtraData(root_gumbel=gumbel)
  search_tree = search.search(
      params=params,
      rng_key=rng_key,
      root=root,
      recurrent_fn=recurrent_fn,
      root_action_selection_fn=functools.partial(
          action_selection.gumbel_muzero_root_action_selection,
          num_simulations=num_simulations,
          max_num_considered_actions=max_num_considered_actions,
          qtransform=qtransform,
      ),
      interior_action_selection_fn=functools.partial(
          action_selection.gumbel_muzero_interior_action_selection,
          qtransform=qtransform,
      ),
      num_simulations=num_simulations,
      max_depth=max_depth,
      invalid_actions=invalid_actions,
      extra_data=extra_data,
      loop_fn=loop_fn)
  summary = search_tree.summary()
最適な行動の選択

最も訪問された行動数を計算(considered_visit)する。
qtransformで、未訪問の行動のQ値を補完する。
seq_halving.score_consideredで、行動ごとのスコアを計算する。
スコアが最大の行動を選択する。

  # Acting with the best action from the most visited actions.
  # The "best" action has the highest `gumbel + logits + q`.
  # Inside the minibatch, the considered_visit can be different on states with
  # a smaller number of valid actions.
  considered_visit = jnp.max(summary.visit_counts, axis=-1, keepdims=True)
  # The completed_qvalues include imputed values for unvisited actions.
  completed_qvalues = jax.vmap(qtransform, in_axes=[0, None])(  # pytype: disable=wrong-arg-types  # numpy-scalars  # pylint: disable=line-too-long
      search_tree, search_tree.ROOT_INDEX)
  to_argmax = seq_halving.score_considered(
      considered_visit, gumbel, root.prior_logits, completed_qvalues,
      summary.visit_counts)
  action = action_selection.masked_argmax(to_argmax, invalid_actions)

スコアの計算処理は、以下の通り。
訪問数が最大の行動にペナルティを課して選択されないようにする。
gumbel + logits(=root.prior_logits) + normalized_qvalues(=completed_qvalues)を計算する。
gumbelは乱数なので、確率的に行動を選択することになる。

def score_considered(considered_visit, gumbel, logits, normalized_qvalues,
                     visit_counts):
  """Returns a score usable for an argmax."""
  # We allow to visit a child, if it is the only considered child.
  low_logit = -1e9
  logits = logits - jnp.max(logits, keepdims=True, axis=-1)
  penalty = jnp.where(
      visit_counts == considered_visit,
      0, -jnp.inf)
  chex.assert_equal_shape([gumbel, logits, normalized_qvalues, penalty])
  return jnp.maximum(low_logit, gumbel + logits + normalized_qvalues) + penalty
行動の重みの生成

方策の学習に使用する行動の重みを計算する。
root.prior_logits + completed_qvaluesに、softmaxを適用して新しい方策を求める。

# Producing action_weights usable to train the policy network.
  completed_search_logits = _mask_invalid_actions(
      root.prior_logits + completed_qvalues, invalid_actions)
  action_weights = jax.nn.softmax(completed_search_logits)

まとめ

探索処理(gumbel_muzero_policy)の流れを解説した。
次回は、探索の内部処理を解説する。