TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

PyTorchでPPOを実装する

麻雀AIの準備として、PyTorchでPPOアルゴリズムをスクラッチで実装した。

はじめ、最近リリースされたTorchRLで実装しようと思って試していたが、連続環境でのチュートリアルはあるが、いろいろ試したが離散環境に対応することができず断念した。

Stable Baselines3のようなPPOが実装された強化学習ライブラリがあるが、OpenAI Gymのインターフェースが前提になっており、麻雀AIは多人数でプレイするため、マルチエージェントで学習するには扱いにくい。
そこで、スクラッチで実装することにした。

PPO

Proximal Policy Optimization(PPO)は、オンポリシーのアルゴリズムで、現在の方策でエピソードを収集した後、以下の目標を最大化するように方策を学習する。方策勾配法よりも学習が安定し、エピソードを複数エポック学習することができサンプル効率がよい。

アドバンテージには、Generalized Advantage Estimation (GAE)を使用することが一般的である。

ハイパーパラメータ\gamma\lambdaにより、バイアスと分散のトレードオフを調整できる。
GAEの論文では、強化学習では、分散よりバイアスが有害であることが指摘されている。
分散は、データ量を増やすことで抑えることができる。

実装

アルゴリズムの実装にバグがないことを確認できるように、AtariのBreakoutをPPOで学習するコードを実装した。
実装は、Stable Baselines3の実装を参考にした。
GAEの計算がシンプルに実装されており参考になった。

結果

学習はバッチサイズ、学習率に敏感で、何度か失敗することがあった。
バッチサイズは64のときにうまく学習できた。

160kステップ学習した際の平均エピソード長、平均報酬、各訓練損失は以下の通り。

160kステップで1時間20分くらいかかった。

まとめ

麻雀AIの準備として、PyTorchでPPOアルゴリズムをスクラッチで実装した。
実行したコードでBreakoutの学習ができることが確認できた。
これでOpenAIGymインターフェースに縛られず、マルチエージェントで麻雀AIを学習する準備ができた。