dlshogiにGumbel AlphaZeroの強化学習を導入したいと思って、Gumbel AlphaZeroの論文を読んでいるが、理論がなかなか難しいため公式実装を確認しながら理解する。
Gumbel AlphaZero
AlphaZeroの強化学習は、自己対局でUCTで探索した際の訪問回数を目標の確率分布として方策の改善を行う。
しかし、シミュレーション回数が少ない場合、方策を改善できる保証がないという問題がある。
Gumbel AlphaZeroは、より少ないシミュレーション回数で、方策が理論的に改善されることを保証する。
Gumbel AlphaZeroは、ルートノードでPUCBの代わりに、Gumbel-Top-kトリックを使用して確率的にサンプリングを行う。
探索中の中間ノードでは、 Sequential Halving アルゴリズムで、決定論的に行動を選択する。
公式実装
公式実装は、JaxのJITを使ってPythonでも高速に並列探索できるように実装されている。
そのため処理の流れの順になっておらず、コードの理解が難しい。
JITを行うと、デバッガで処理が追えなくなるので、JITを無効にして、ステップしながら理解できるようにした。
環境構築
WSL2のUbuntu 22.04上に、Minicondaをインストールして、conda仮想環境に構築した。
sudo apt-get install graphviz graphviz-dev pip install "jax[cuda12]" chex pygraphviz
JIT無効化
環境変数 JAX_DISABLE_JIT = true を設定する。
VS Codeのlauch.jsonのenvに設定するとよい。
"configurations": [ { "name": "visualization_demo", "type": "debugpy", "request": "launch", "program": "${workspaceFolder}/examples/visualization_demo.py", "console": "integratedTerminal", "env": { "JAX_DISABLE_JIT": "true" } } ]
まとめ
Gumbel AlphaZeroの論文を理解するために、公式実装のデバッグ環境を構築した。
次回からステップ実行しながらコードを理解する。