TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

【論文】Bridging the Human–AI Knowledge Gap: Concept Discovery and Transfer in AlphaZeroを読む

DeepMindarXivで発表した、AlphaZeroからチェスの新しい概念を抽出して人間のパフォーマンス向上に使えるようにする方法について述べた「Bridging the Human-AI Knowledge Gap: Concept Discovery and Transfer in AlphaZero」を読んだ際のメモ。

概要

  • AlphaZeroから隠されたチェスの知識を概念として抽出する
  • 概念は、AlphaZeroの中間層から概念ベクトルとして抽出する
  • 人間が知らずAlphaZeroだけが知っていること発見する
  • 学習可能かつ新規な概念を抽出する
  • 人間のグランドマスターが概念を学習可能か検証した

結果

  • 概念と局面を提示することで、チェスのグランドマスターのテスト局面の正解率が向上した
  • AlphaZeroが学習した知識は、人間の理解を超えるものではなく学習可能な知識であることが示された

詳細

導入

  • 知識には、AIと人間の両方が知っていること(M∩H)、人間だけが知っていること(H-M)、AIだけが知っていること(M-H)がある
  • 既存研究は主にM∩Hに焦点をあてているが、M-Hに焦点をあてる
  • 正解が検証しやすいチェスを題材とする
  • 4人のグランドマスターに概念を教えることができることを示す
  • AlphaZeroの潜在表現を分析することでM-Hが存在することの証拠を発見する
  • AlphaZeroの方策ネットワークとMCTSツリーを使用する
  • 新規性と教育可能の観点でフィルタリングする


概念の定義

  • 概念を知識の単位として定義
  • 概念の2つの特性
    • 知識(役立つ情報)が含まれる
    • 最小性(無関係な情報が含まれない)
  • 別のエージェントに転送可能であれば、自己完結型で、タスクに役立つことを意味する
  • 概念はニューラル ネットワークの潜在空間に線形にエンコードされると仮定する
  • 計画を生み出す概念を発見することを目指す

概念の発見方法

  • 凸最適化を使用して AZ の概念を表すベクトルを発掘する
  • 教えやすさと新規性に基づいて概念をフィルタリングする
概念ベクトルの発掘
  • 概念発見を凸最適化問題として定式化
  • 最小性は、L1ノルムを通じてスパース性を促進することによって達成する


ここで、v_{c,l} \in \mathbb{R}^{d_l} は概念 c を表すために層 𝓁 の潜在空間に存在するベクトル、d_l は層 𝓁 の次元

  • 静的概念と動的概念で概念制約が異なる
  • 静的概念は単一の状態で見られるように定義される
  • 動的概念は一連の状態で見出されるように定義される
静的概念の概念制約
  • 静的概念は、単一の状態のみを含む概念として定義
  • 局面にラベル(概念cを含むか含まないか)付けされた教師ありデータを使用
  • 教師ありデータの概念は人間の知識をコード化しているので検証できる
  • 正例と負例を使用して潜在表現(ニューラルネットワーク内の活性化の後の表現)を生成
  • 概念を含む方が、潜在表現と概念ベクトルの内積が大きいという仮説を立てる
  • 以下の通り凸最適化問題を定式化


ここで、\mathbb{Z}^+_lは正例の層𝓁の潜在表現、\mathbb{Z}^-_lは正例の層𝓁の潜在表現

  • ホールド分割したテストセットで評価できる
動的概念の概念制約
  • 動的概念は、一連の状態に見られる概念として定義
  • モンテカルロ木探索 (MCTS) 統計を使用して、意味のある状態シーケンスの候補を見つける
  • 最適なロールアウトと途中の標準以下のロールアウトを対比する
  • MCTSの価値推定値または訪問数に従って最適でないパスと定義する
  • 概念cにより、最適なパスが選択されたと仮定
  • 3つの異なる説明
    • シナリオ1 積極的な計画: ロールアウトにより概念が増加させる
    • シナリオ2 予防計画: ロールアウトにより概念の存在が回避される
    • シナリオ3 両方のロールアウトに存在する概念
  • シナリオ1と2に関心がある
  • シナリオ1は、以下の通り凸最適化問題を定式化


ここで、Tはロールアウトの深さ

  • シナリオ2では不等式が反対になる

  • 最適なパスを、さまざまな MCTS 深度にわたる複数の標準以下のパスと対比
  • 複数の標準以下のパスを使用することで多意味ベクトルを学習する可能性を減らし、解空間をさらに縮小する


ここで、T_jは最適ではないロールアウトが見つかる最大の深さ

概念のフィルタリング

  • 抽出できた概念には、多くは既知の概念または一般化できない概念を含む
  • 有用(転移可能)で新規なものにフィルタリングする
教えやすさ
  • 概念を別のエージェント教えてタスクを解決できるか
  • アイディア
    • フェーズ1: ベースラインとなる概念を知らないエージェント(方策のトップ1の重複が0.2 未満)を見つける
    • フェーズ2: 概念プロトタイプ(概念を例示するサンプル)を使用してエージェントに概念を教える
    • フェーズ3: 概念に関連するタスクでパフォーマンスを評価する
  • 概念が教育可能であればフェーズ3でパフォーマンス向上が期待される
  • 人間とのアプローチを評価するときも、同様のプロセスを使用して評価
  • フェーズ2では、AlphaZeroを教師とする
  • プロトタイプの選択
    • 静的概念の場合、概念スコア(概念ベクトルと潜在表現の内積)の上位2.5%
    • 動的概念の場合、MCTSを実行して標準以下のロールアウトを見つける
  • 学習は、教師(AlphaZero)の方策と生徒の方策のKLダイバージェンスを最小化
  • 評価は、トップ1の手を選択する頻度で評価
  • AlphaZeroの対局からランダムにサンプリングして学習場合と比較する

  • 概念プロトタイプで学習すると(濃い線)、ランダムなプロトタイプで学習する(薄い線)より大幅にパフォーマンスが向上する
新規性
  • アプローチ
    • AlphaZeroの訓練の後期段階で学習される概念を見つける
    • 新規性の指標に基づいて概念をフィルタリングする
  • 概念は、AlphaZeroの訓練の後期段階で学習される
  • 最新と、75 Elo ポイントが異なるバージョンの AlphaZero を使用
  • 意見が一致しないチェスの局面を選択
  • 検証は、AlphaZeroのゲームで発生し、人間のゲームでは発生しないかどうかで判断
  • 概念をAlphaZeroのゲーム ベクトル空間と人間のゲーム ベクトル空間に回帰することによって新規性スコアを測定

  • ランクは、層ごとに異なる
  • 層19(ボトルネックの最終層)と層23(方策の層)で、AphaZeroゲームのスコアが人間ゲームのランクより高い
  • これらの層で新しい概念を見つける
  • AlphaZeroゲームで発生する基底ベクトルのセットを使用して概念ベクトルをどの程度うまく再構築できるか(特異値分解を使う)に基づいて新規性スコアを定義
  • 教育可能性ベースのフィルタリング後に残った概念のうち、新規性の指標を使用してさらに27.1%を削除

人間による評価

  • グランドマスターによる人間の評価は、教えやすさの測定方法と同様に3 つの段階に従う
    • フェーズ1: 概念に対応する一連のパズルで、ベースラインのパフォーマンスを測定
    • フェーズ2: フェーズ1と同じパズルを、AlphaZeroの推奨手と合わせて提示
    • フェーズ3: 概念に対応する未見のパズルで、最終パフォーマンスを測定
  • 4人のグランドマスターが評価に協力
  • 3~4つの概念に対する 4 つのパズル (概念ごとに)を提示
  • 一部重複があったが、異なる参加者には異なる概念が示された

グランドマスターのパフォーマンス

  • すべての参加者がフェーズ1 と 3 の間で顕著に改善
  • 向上の大きさは、チェスプレイヤーの強さ(Elo レーティング)とは相関しない
  • パフォーマンスに影響を与えた可能性のある要因
    • 難易度と品質のばらつき
    • 教えやすさのばらつき
    • 熟考の結果: 自由形式のコメントでAlphaZeroの戦略に言及しているが最終的に選択しなかった

概念の定性分析

  • 全体として、グランドマスターは概念を高く評価
  • イデアにはしばしば新しい要素が含まれていることを発見
  • 局面が非常に複雑であることに気づいた
  • 人間がラベル付けしたコンセプトを使用してグラフ分析することで概念の意味について洞察できる

  • エッジの重みが最も大きい2つの概念
    • 空間: 空間を増やすことが計画の重要な要素
    • 奪還: 駒を取り戻す/獲得する計画
学習に失敗した例
  • 駒損するにも関わらず、空間の優位性を獲得するプレイ
  • パズルの証拠からグランドマスターがこの概念を学ばなかったことを示唆している
  • 人間はこれらの概念に異なる優先順位を付ける傾向がある(できるだけ早く王を安全な場所に連れて行くことを優先するなど)
人間とAlphaZeroの違い
  • 定性的な例は、AlphaZeroがチェスの局面における概念の関連性に関して人間とは異なる事前条件を持っていることを示唆
  • 人間のチェスプレイヤーは、ヒューリスティックなチェスの原則を採用する
  • 一例は、オープニングの 3つの「黄金律」(センターをコントロールし、自分の駒を展開し、キングを安全な場所に連れて行く)
  • AlphaZeroは前提知識を持っていないため、より柔軟になることができる
  • AlphaZeroは特定のサイドに集中するのではなく、ボード全体でプレイする
  • 駒の価値はあまり重要視せず、空間と駒のアクティビティを優先する
  • 人間は他の人間とチェスをするため、情報の非対称性と不完全なプレイを想定する可能性がある
  • 人間は、対戦相手が間違いを犯すことを期待して、チェスの局面を複雑にしたり、最良の手があまり明確ではない継続を選択することがある
  • AphaZeroと人間では時間の役割が違う
  • 人間はできるだけ早く勝利を確実にし、リスクを最小限に抑えようとして、チェスの局面を単純化することがある
  • AlphaZeroは勝つことだけを目標に訓練しているため、戦略的に遅い勝利を選択することがある

結論

  • AlphaZeroの表現空間から知識を抽出するために、教師なしの方法で新しいチェスの概念を発見するフレームワークを開発
  • 概念が別のエージェントに教えられること、新規性があることを検証した
  • 4人の世界トップのグランドマスターが概念を学習した後、パフォーマンスが向上した
  • 線形な疎な概念ベクトルに限定したが、他の概念が非線形ベクトルの形で発見される可能性がある
  • インタラクティブにすることで、反事実のシナリオを調査できるようになり、なぜAlphaZeroがアプローチを選択しなかったのかをより深く理解できるようになる

感想

将棋の世界では、棋士の間でAIは評価値は教えてくれるが、なぜ最善かは教えてくれないと言われている。
AIで「なぜ」に答える方法について関心があり、この論文を見つけたので読んでみた。

人間に理解できる自然言語で「なぜ」に答えるのが理想だと思うが、この論文は、AlphaZeroが学習している概念をベクトルとして抽出して、そこから概念を表す局面を例示することで、人間が学習できることを示している。
だだし、例示で学習できる人間はチェスのトッププレイヤーに限られると思うので、普通の人間でも理解できように自然言語で答えるにはまだハードルがあると思う。
概念を、既知のチェスの概念と関連付けることで、分析できるように試みているので、この手法を発展されることで言語化は可能かもしれない。

この論文では人間が知らない新規の概念に焦点をあてているが、既知の概念を表す局面を抽出する手法として使えるので、初心者が既知の概念を練習するための問題集の作成方法としても使えそうである。

まとめ

AlphaZeroから、チェスの学習可能な新規の概念を抽出する方法に関する論文を読んだ。
概念をベクトルとして抽出して関連する局面を例示することで、人間に学習可能であることを示している。
チェスのグランドマスターが協力して、パフォーマンスが向上できることを検証している。

将棋でも、この手法でdlshogiの概念を抽出して、人間が学習可能か興味があるので、試してみたい。

【論文】Grandmaster-Level Chess Without Searchを読む

DeepmindarXivで発表したチェスで探索なしでグランドマスターレベルに達したTransformerベースモデルに関する論文「Grandmaster-Level Chess Without Search」を読んだ際のメモ。

概要

  • 標準的なTransformerのデコーダ教師あり学習することで探索なしでチェスのグランドマスターレベルに達した
  • Lichessの1,000万対局の棋譜を使用
  • 行動価値を学習する
  • 棋譜の各局面について全合法手を展開した局面に対してStockfish 16で50ミリ秒探索して、行動価値を学習するための訓練データを作成
  • 行動価値(勝率)を128個のビンに分け、クラス分類問題として学習する
  • 2億7,000 万のパラメーターのTransformerモデルを使用
結果
  • Lichessの早指し(blitz)戦で、対人間でグランドマスターレベル(2895 Elo)に達した
  • 対AIエンジンでは、2299 Eloと低い値になった(対AIでは1手のミスの影響が大きいためと分析している)

詳細

データ
  • Lichess (lichess.org) から2023年2月の1,000 万対局をダウンロード
  • 全合法手を展開した局面の状態価値を、Stockfish 16 を使用して50ミリ秒探索して推定
  • 50ミリ秒の探索は、Lichessのblitzで2713 Eloに相当する
  • 訓練に使用する1,532億の行動価値を推定するデータが得られる
  • テストデータとして、別の月(2023年3月)の1,000対局を使用
  • テストデータには、序盤は訓練データと重複する14.7%の局面が含まれるが、分布が歪むため除かない
ビン
  • 行動価値を128のビンに分割してクラス分類問題として学習する
モデル
  • デコーダのみのTransformerモデル
  • コンテキスト サイズは、行動価値の予測の場合79 (状態価値予測、動作クリーニングの場合78)
  • 出力サイズはビン数128 (動作クリーニングの場合は合法手の数1968)
  • コンテキストサイズが固定のため学習された位置エンコーディングを使用
  • パラメーターは、約2億7,000万
  • Llamaで提案されたpost-normalizationとSwiGLUを使用
トークン化
  • FEN文字列の各文字をトークン化
  • FENの空白は、空白の数の分だけ「.」で表現することで固定長にする
  • 77文字の固定長文字列に変換される
  • 先後で盤面は回転しない
  • 指し手は全合法手に付けたインデックスをトークンとする(1トークン)
訓練
  • 学習率 1e-4 のAdam を使用
  • 交差エントロピー損失を使用
  • バッチサイズ4096
  • 1,000万ステップ学習(2.67 エポックに相当)
  • Polyak averagingを使用
予測ターゲット
  • 行動価値の予測以外に、状態価値の予測、動作クリーニングをするモデルを学習して比較した

評価

以下の指標で評価する

  • 指し手の一致率(Action-accuracy)
  • テストセット全体の平均 Kendall 順位相関 (Kendall の 𝜏)
  • パズルの精度: Lichessの10,000 個のパズルを使用
  • 強さ(Elo)
    • (i) Lichess で人間のみまたはボットのみを相手に Blitz ゲームをプレイ
    • (ii) すべてのプレイヤー間で内部トーナメント(AlphaZeroの方策関数及び価値関数(400 MCTS、温度パラメータ使用)、Stockfish(時間制限))
    • AlphaZeroは履歴局面にアクセスできる点に注意
    • Transformerモデルは現局面のみのため千日手にならないように千日手になる手を勝率50%にする

結果

表1に結果が整理されている。

  • モデルが大きいほどスコアは一貫して向上
  • 対人間でグランドマスターレベル(2895 Elo)
  • Lichessでボットと対戦すると Elo が低下
パズル


  • 270MのモデルがStockfish 16に次いで精度が高い
  • AlphaZero、GPT-3.5-turbo-instructは精度が大幅に低い
  • パズルを解くには正しい手順が必要にも拘わらず貪欲なポリシーで正解している
スケーリング則


  • 訓練データのサイズが小さいと、大きいモデルでは過学習する
  • 訓練データのサイズが大きい場合、データセットのサイズに関係なくモデルサイズが大きいほど最終的な精度が向上する

比較実験


予測ターゲット

  • 3つの予測ターゲットを比較すると、行動価値のEloが高かった
  • 状態価値、動作クリーニングと比較して、全合法を展開して訓練データを作成しているためデータ数が不公平
  • データ数を公平にすると、状態価値のEloが最も高く、行動価値もほぼ同じ
  • 予測ターゲットよりも、データセットサイズに起因している
  • 動作クリーニングは最善手1手のみを学習している情報量が少なくなった結果と考えられる(確率分布にすることで改善する可能性あり)
層の深さ
  • Transformerの層の深さは、パラメータ数を同一とした場合、深くするほど精度があがるが、8で飽和
データ サンプラー
  • 訓練データで重複局面を削除した方が精度が大幅に向上
  • データの多様性が重要
ビニング
  • ビンの数が多いほどパフォーマンスが向上する
  • 指し手の一致率では128が一番精度が高い
損失関数
  • スカラー回帰問題として扱うより分類問題として学習した方が、指し手一致率、Kendallの𝜏で精度が高い
  • パズルはわずかに低くなった

考察

千日手に対する盲点
  • 履歴局面にアクセスできないため、千日手を指してしまう
  • 回避するために、千日手になる手を勝率50%とした
勝てる局面での優柔不断さ
  • 行動価値をビンで学習している最高の勝率のビンにある手はランダムで選択される
  • それにより勝てる局面を長引かせる傾向がある
対人間と対ボットでのElo
  • 対人間と比較して、対ボットでEloが低下する
  • 原因は、人間は諦める傾向があることと、Lichessで人間vsボットの対戦数の少なさに起因していること、対ボットでは1つのミスで大きなペナルティになることが考えられる
限界
  • Stockfish 16との差は完全に埋められない(一つの要因は千日手を認識しないこと)
  • 行動価値はすべての合法手を考慮するため1手探索ではないか(次の状態への遷移はしないため「探索なし」と主張する)
  • ポリシーの決定的な性質により搾取される可能性がある

定性的評価

  • 駒を頻繁に犠牲にする、攻撃的で進取的なスタイル
  • 対戦相手に難しい決断をさせるような動きを好む
  • 人間の分析者の感想:「通常のエンジンをプレイするよりも楽しい」、「絶望的に打ちのめされているだけではないように感じる」
  • 多くの局面でStockfish と一致する。不一致の局面ではStockfish が正しいことが観測された

結論

  • Stockfish 16の教師データから未知の局面に一般化されるモデルが学習できることが実証された
  • 複雑で洗練されたアルゴリズムがTransformerモデルに蒸留できることが示された
  • 大きなTransformerは「単なる」統計的パターン認識器ではなく、一般的なアルゴリズムを近似できる強力な技術である

感想

探索なしでグランドマスターレベルに達したという結果については、将棋AIでもdlshogiのポリシーのみでも十分な強さがあることからそれほど驚きはない。
それよりも、大規模な比較実験の結果からいろいろな示唆が得られる論文だと思った。

訓練データの重複データを削除すると大幅に精度が上がるという結果については、dlshogiの実験でも同一局面の出力を平均化するということを行っており、効果があることを確認していた。
チェスにおいてもデータの多様性が重要という結果が得られたことで、その裏付けとなったと思う。

指し手のみを学習するより、すべての合法手の行動価値を学習する方が精度が向上するという結果については、モデルを蒸留する際の示唆が得られた。
指し手を学習(最善手のみ)や蒸留元のモデルの出力を学習するよりも、全合法手を展開して状態価値から分布を学習する方が情報量が増えて精度が向上するということがわかる。
これは、モデル蒸留で方策を学習する際にも効果的な手法になりそうである。

モデルのパラメータサイズについて、学習データが十分であれば、学習データを増やさなくてもモデルサイズが大きいほどよいという結果は、モデルサイズに応じて学習データを増やす必要があるという認識だったので、少し意外であった。
大規模な実験をしないと分からない結果なので、参考になった。

Transformerモデルの入力のトークン化、層数、ポジショナルエンコーダの方法について、以前に将棋AIでTransformerモデルを学習した際に、それほど精度が上がらなかったので、Transformerモデルを使用する際の参考になる。

まとめ

チェスで探索なしでグランドマスターレベルに達したエージェントについての論文を読んだ。
結果自体は、それほど驚きはなかったが、大規模な比較実験の結果は、将棋AIの学習でも参考になる点が多かった。
dlshogiのモデルを蒸留して、CPUのみでも従来側のエージェントに匹敵するモデルを作りたいと思っており、蒸留の方法に応用できそうである。

また、LLMを応用することで、将棋AIの指し手の解説ができないかと考えているので、学習済みの言語モデル+将棋AIで新しいことができないか考えているので、将棋AIでTransformerモデルを扱う際のヒントにもなりそうである。

PokéLLMonの論文を読む

LLMを使用してポケモンバトルをプレイするPokéLLMon論文を読んだ際のメモ。

概要

  • LLMを使用して、ポケモンバトルをプレイするエージェントを作成する。
  • 現在の状態をテキストとして与え、行動を生成する
  • 以前のターンの結果をテキストとして、コンテキストに与える「インコンテキスト強化学習
  • 外部知識として、技や能力の効果、タイプ相性をコンテキストに与える
  • パニック スイッチング(強い相手のときに交代を繰り返す)を防ぐため、プロンプトエンジニアリング手法のSelf-Consistency(SC)を使用

行っていないこと

  • 探索は行わない
  • モデルのパラメータの学習は行わない

アルゴリズムの概要は、図4が分かりやすい。

結果

  • オンライン対戦(Pokemon showdown)において、ラダー戦で勝率49%、招待戦で勝率56%を達成
  • 平均的な人間レベル

課題

  • 長期計画に弱い
  • 短期的な利益を達成できる行動を取る傾向がある
  • 人間プレイヤーの消耗戦略に対して脆弱

プロンプト

論文では詳細なプロンプトについて記載がなかったので、GitHubのソースを実際に動かしてプロンプトを確認してみた。
環境構築については、後で記載する。

環境変数「OPENAI_LOG=debug」を設定することで、実際にOpenAI APIにPOSTされたJSONデータを確認した。
後で気づいたが、battle_log/pokellmon_vs_bot/output.jsonlにも出力される。

1ターン目のプロンプト

{
"messages": [
{
"role": "system",
"content": "You are a pokemon battler that targets to win the pokemon battle. You can choose to take a move or switch in another pokemon. Here are some battle tips: Use status-boosting moves like swordsdance, calmmind, dragondance, nastyplot strategically. The boosting will be reset when pokemon switch out. Set traps like stickyweb, spikes, toxicspikes, stealthrock strategically. When face to a opponent is boosting or has already boosted its attack/special attack/speed, knock it out as soon as possible, even sacrificing your pokemon. if choose to switch, you forfeit to take a move this turn and the opposing pokemon will definitely move first. Therefore, you should pay attention to speed, type-resistance and defense of your switch-in pokemon to bear the damage from the opposing pokemon. And If the switch-in pokemon has a slower speed then the opposing pokemon, the opposing pokemon will move twice continuously."
},
{
"role": "user",
"content": "Historical turns:\nBattle start: You sent out Qwilfish. Opponent sent out Tsareena. Qwilfish\'s ability: Intimidate. It decreased opposing Tsareena\'s atk 1 level.\nTurn 1: Current battle state:\nOpponent has 6 pokemons left.\nOpposing pokemon:tsareena,Type:Grass,HP:100%,Is dynamax:False,Attack:168(-1 stage boosted),Defense:213,Special attack:132,Special defense:213,Speed:169,Ability:Queenly Majesty(The Pokemon\'s majesty pressures opponents and makes them unable to use priority moves against the Pokemon or its allies.) Ice, Bug-type attack is super-effective (2x damage) to tsareena. Water, Ground-type attack is ineffective (0.5x damage) to tsareena.\ntsareena\'s all the possible attacks:[powerwhip,grass,Power:120],[rapidspin,normal,Power:50],[knockoff,dark,Power:65],[uturn,bug,Power:70],[highjumpkick,fighting,Power:130],[tripleaxel,ice,Power:20],\nYour current pokemon:qwilfish,Type:Water and Poison,HP:100%,Attack:210,Defense:193,Special attack:142,Special defense:142,Speed:193(faster than tsareena).Ability:Intimidate(When the Pokemon enters a battle, it intimidates opposing Pokemon and makes them cower, lowering their Attack stats.),Item:Focus Sash(An item to be held by a Pokemon. If it has full HP, the holder will endure one potential KO attack, leaving 1 HP.) Ice, Fighting, Bug-type attack is ineffective (0.5x damage) to qwilfish.\nYour qwilfish has 4 moves:\nMove:toxicspikes,Type:Poison,Status-move,Power:0,Acc:100%,Effect:Scatters poisoned spikes, poisoning opposing Pokemon that switch in.\nMove:spikes,Type:Ground,Status-move,Power:0,Acc:100%,Effect:Scatters Spikes, hurting opposing Pokemon that switch in.\nMove:taunt,Type:Dark,Status-move,Power:0,Acc:100%,Effect:For the next few turns, the target can only use damaging moves.\nMove:waterfall,Type:Water,Power:79,Acc:100%,Effect:Has a 20% chance to make the target flinch.(ineffective (0.5x damage) to tsareena)\nYou have 5 pokemons:\nPokemon:rhyperior,Type:Ground and Rock,HP:100%,Attack:273,Defense:257,Special attack:136,Special defense:136,Speed:111(slower than tsareena). Moves:[stoneedge,Rock,1x damage],[megahorn,Bug,2x damage],[earthquake,Ground,0.5x damage], Grass-type attack is extremely-effective (4x damage) to rhyperior. Ice, Fighting-type attack is super-effective (2x damage) to rhyperior. Normal-type attack is ineffective (0.5x damage) to rhyperior.\nPokemon:greedent,Type:Normal,HP:100%,Attack:208,Defense:208,Special attack:141,Special defense:174,Speed:82(slower than tsareena). Moves:[bodyslam,Normal,1x damage],[payback,Dark,1x damage],[earthquake,Ground,0.5x damage], Fighting-type attack is super-effective (2x damage) to greedent.\nPokemon:heracross,Type:Bug and Fighting,HP:100%,Attack:246,Defense:166,Special attack:110,Special defense:198,Speed:182(faster than tsareena). Moves:[megahorn,Bug,2x damage],[knockoff,Dark,1x damage],[closecombat,Fighting,1x damage],[facade,Normal,1x damage], Fighting, Grass, Dark, Bug-type attack is ineffective (0.5x damage) to heracross.\nPokemon:inteleon,Type:Water,HP:100%,Attack:182,Defense:150,Special attack:246,Special defense:150,Speed:238(faster than tsareena). Moves:[darkpulse,Dark,1x damage],[uturn,Bug,2x damage],[icebeam,Ice,2x damage],[hydropump,Water,0.5x damage], Grass-type attack is super-effective (2x damage) to inteleon. Ice-type attack is ineffective (0.5x damage) to inteleon.\nPokemon:ninetalesalola,Type:Ice and Fairy,HP:100%,Attack:110,Defense:164,Special attack:174,Special defense:204,Speed:218(faster than tsareena). Moves:[moonblast,Fairy,1x damage],[freezedry,Ice,2x damage],[blizzard,Ice,2x damage], Ice, Dark, Bug-type attack is ineffective (0.5x damage) to ninetalesalola.\nChoose the best action and your output MUST be a JSON like: {\"move\":\"\"} or {\"switch\":\"\"}\n"
}
],
"model": "gpt-4-0125-preview",
"max_tokens": 100,
"response_format": {
"type": "json_object"
},
"stop": [],
"stream": False,
"temperature": 0.8
}

<日本語訳>
システムプロンプト】
あなたはポケモンバトルで勝利することを目指すポケモンバトラーです。 行動するか、別のポケモンに切り替えるかを選択できます。 戦闘のヒントは次のとおりです: ソードダンス、カームマインド、ドラゴンダンス、ナスティプロットなどのステータスを高める動きを戦略的に使用してください。 ブーストはポケモンが交代するとリセットされます。 スティッキーウェブ、スパイク、トキシックスパイク、ステルスロックなどのトラップを戦略的に設定します。 対戦相手がブースト中、またはすでに攻撃/特殊攻撃/スピードをブーストしている場合は、ポケモンを犠牲にしてでもできるだけ早く敵をノックアウトします。 切り替えることを選択した場合、このターンに移動することを放棄し、相手のポケモンが確実に最初に移動します。 したがって、相手のポケモンからのダメージに耐えるために、スイッチインポケモンの素早さ、タイプ耐性、防御力に注意を払う必要があります。 また、交代ポケモンの素早さが相手ポケモンより遅い場合、相手ポケモンは2回連続で動きます。

【ユーザプロンプト】
ターン履歴:
戦闘開始: クウィルフィッシュを送り出しました。 相手はツァリーナを送り出した。 クウィルフィッシュの能力: 威圧。 相手のツァリーナの攻撃力を1レベルダウンさせた。
ターン 1: 現在の戦闘状態:
相手の残りポケモンは6匹。
相手ポケモン:ツァリーナ、タイプ:草、HP:100%、ダイマックス:偽、攻撃:168(-1段階アップ)、防御:213、特攻:132、特防:213、すばやさ:169、特性:クイーンリー マジェスティ(ポケモンの威厳が相手にプレッシャーをかけ、ポケモンやその仲間に対して優先技を使えなくする。) こおり、むしタイプの攻撃はツァリーナに非常に効果的(ダメージ2倍)。 ツァリーナには水、じめんタイプの攻撃が効かない(ダメージ0.5倍)。
ツァリーナの可能なすべての攻撃:[パワーウィップ、グラス、パワー:120]、[ラピッドスピン、ノーマル、パワー:50]、[ノックオフ、ダーク、パワー:65]、[ターン、バグ、パワー:70]、[ ハイジャンプキック、格闘、パワー:130]、[トリプルアクセル、アイス、パワー:20]、
現在のポケモン:クウィルフィッシュ、タイプ:みずとどく、HP:100%、攻撃:210、防御:193、特攻:142、特防:142、すばやさ:193(ツァリーナより速い)、特性:威圧( ポケモンが戦闘に参加すると、相手のポケモンを威嚇してすくめさせ、攻撃ステータスを低下させます。 HPが1残る)クウィルフィッシュには氷・格闘・むし系の攻撃は無効(ダメージ0.5倍)。
クウィルフィッシュには 4 つの動きがあります。
技:毒スパイク、タイプ:どく、ステータスわざ、威力:0、命中率:100%、効果:毒スパイクをばらまき、入れ替わった相手ポケモンを毒状態にする。
技:トゲ、タイプ:じめん、ステータスわざ、威力:0、命中率:100%、効果:トゲをばらまいて、入れ替わった相手ポケモンにダメージを与える。
技:挑発、タイプ:闇、ステータス技、威力:0、命中率:100%、効果:数ターンの間、対象はダメージ技のみ使用可能。
技:滝、種族:水、威力:79、命中:100%、効果:20%の確率で対象をひるませる(ツァリーナには無効(ダメージ0.5倍))
あなたは5匹のポケモンを持っています:
ポケモン:ライペリア、タイプ:じめん、いわ、HP:100%、攻撃:273、防御:257、特攻:136、特防:136、すばやさ:111(ツァリーナより遅い)。 技:[ストーンエッジ、岩、ダメージ1倍]、[メガホーン、むし、ダメージ2倍]、[地震、地面、ダメージ0.5倍]、草タイプの攻撃がリペリアに非常に効果的(ダメージ4倍)。 氷、格闘タイプの攻撃がリスペリアに超有効(ダメージ2倍)。 ライペリアにはノーマルタイプの攻撃は無効(ダメージ0.5倍)。
ポケモン:グリーン、タイプ:ノーマル、HP:100%、攻撃:208、防御:208、特攻:141、特防:174、素早さ:82(ツァリーナより遅い)。 技:[ボディスラム、通常、ダメージ1倍]、[払い戻し、闇、ダメージ1倍]、[地震、地面、ダメージ0.5倍]、貪欲に格闘タイプの攻撃が超有効(ダメージ2倍)。
ポケモン:ヘラクロス、タイプ:むし・かくとう、HP:100%、攻撃:246、防御:166、特攻:110、特防:198、すばやさ:182(ツァリーナより速い)。 技:[メガホーン、むし、ダメージ2倍]、[なぎ倒し、あく、ダメージ1倍]、[接近戦、格闘、ダメージ1倍]、[正面、通常、ダメージ1倍]、格闘、草、闇、むし系の攻撃は無効 ヘラクロスに(ダメージ0.5倍)。
ポケモン:インテレオン、タイプ:みず、HP:100%、攻撃:182、防御:150、特攻:246、特防:150、すばやさ:238(ツァリーナより速い)。 技:[ダークパルス、闇、ダメージ1倍]、[ターン、バグ、ダメージ2倍]、[アイスビーム、氷、ダメージ2倍]、[ハイドロポンプ、水、ダメージ0.5倍]、くさタイプの攻撃が効果抜群(ダメージ2倍) )インテレオンに。 インテレオンにはこおりタイプの攻撃は無効(ダメージ0.5倍)。
ポケモン:ナインタレサローラ、タイプ:こおりとフェアリー、HP:100%、攻撃:110、防御:164、特攻:174、特防:204、素早さ:218(ツァリーナより速い)。 技:[ムーンブラスト、フェアリー、ダメージ1倍]、[フリーズドライ、氷、ダメージ2倍]、[ブリザード、氷、ダメージ2倍]、氷、あく、むし系の攻撃は効かない(ダメージ0.5倍)。
最適なアクションを選択し、出力は次のような JSON でなければなりません: {\"move\":\"\"} または {\"switch\":\"\"}

解説

システムプロンプトは、行動選択時と交代時でそれぞれ固定のプロントが使用される。

PokeLLMon/poke_env/player/gpt_player.py at bf3fa25c6c0e40715a87d56483f323056905b3e7 · git-disl/PokeLLMon · GitHub

ユーザプロンプトは、以前のターンの履歴情報、現在の状態、出力の指示文で構成されている。

ポケモンには、タイプ、HPなどのパラメータ、ワザ、特性の説明、どのタイプの攻撃が効くかが「くさタイプの攻撃が効果抜群(ダメージ2倍)」というように記述される。
早さについては、「ツァリーナより速い」といった情報を与えている。

指示文には、技名か、交代するポケモン名を出力するように指示している。

感想

外部情報として、どの技が効くかや、素早さだけでなくどちらが先に行動できるか情報を与えているため、LLMはそれらのテキストを解釈して、最適な行動を出力していると思われる。
答えをルールベースで教えているのに近いという印象を受けた。
同等以上の行動をとれるルールベースのエージェントを記述するのはおそらく可能である。

高度な状況判断は行っていないため、平均的な人間レベルといっても、上級者にはまったくかなわないと思われる。
平均的なプレイヤーもその状況で効果的な技を選んでいるだけだと思うので、正確なデータベースを参照して効果的な技の答えを与えている分、平均くらいのプレイができても驚きはない。

情報をテキストのみで与えて行動を選択できているという自然言語の解釈能力にしては、LLMの性能が発揮されている。
ただし、答えに近い情報を計算した上で渡しているため、わざわざテキストで与える意味は薄いと思った。

まとめ

LLMでポケモンバトルをプレイするエージェントに関する論文を読んだ。
将棋AIで、AIがなぜその手を指したのかはAIが教えてくれないという課題があり、LLMによって解決できないかと考えていて、そのヒントにならないかと思って読んでみた。
しかし、使われている手法は、データベースを使ってあらかじめ効果的な技を計算してテキストとして与えているため、将棋AIへの応用は難しいと思った。
将棋AIの場合、現在の状況を言語化することや、効果的な指し手を言語化して与えるのがそもそも難しい。

LLMを使用したより洗練されたアプローチや新しいアイディアが生まれることを期待したい。

参考情報

環境構築

ubuntu:22.04のDockerイメージを使用して構築した。

nodejsインストール
apt update &&
apt install nodejs npm curl -y &&
npm install -g n &&
n stable &&
apt purge nodejs npm -y
ソースclone
git clone https://github.com/git-disl/PokeLLMon.git
cd PokeLLMon
Pokemon Showdownインストールと起動

git clone https://github.com/smogon/pokemon-showdown.git
cd pokemon-showdown
npm install
cp config/config-example.js config/config.js
node pokemon-showdown start --no-security &
cd ..
|

Pythonインストール
apt install python3-pip python-is-python3 -y
ライブラリインストール
pip install openai numpy orjson requests websockets gymnasium
ソース修正

Llma2のエージェントは、peftのインストールが必要で関連ライブラリとしてPyTorchがインストールされるため、インストールを省くため、poke_env/player/__init__.pyの

from poke_env.player.llama_player import LLAMAPlayer

の行をコメントアウトした。

vs_bot.pyは、Pokemon Showdownのアカウントと、OpenAIのAPIキーをハードコードするようになっているので、環境変数から取得するように書き換えた。

                           api_key=os.getenv("OPENAI_API_KEY"),
                           ...
                           account_configuration=AccountConfiguration(os.getenv("USERNAME"), os.getenv("PASSWORD")),
実行
python vs_bot.py

ブラウザで、http://localhost:8000/にアクセスすると観戦できる。

麻雀AIを深層強化学習で作る その11(学習処理)

前回、自己対局で特徴量を作成する処理を実装した。

今回は、自己対局で生成したデータを読み込んで学習する処理を実装した。

データ読み込み

自己対局プログラムでzlibで圧縮したデータをPythonのzlibで解凍し、Numpyのndarrayとして読み込む。
C++の構造体をnumpyのdtypeとして定義しておくことで、「np.frombuffer(data, StepData)」とするだけで構造化されたデータを読み込むことができる。

Numpyの構造体定義
PublicFeatures = np.dtype((np.float32, (N_CHANNELS_PUBLIC + 4, 9, 4)))
PrivateFeatures = np.dtype((np.float32, (N_CHANNELS_PRIVATE, 9, 4)))
Policy = np.dtype((np.float32, N_ACTIONS))
Hupai = np.dtype((np.float32, 54))
HulePlayer = np.dtype((np.float32, 5))
TajiaTingpai = np.dtype((np.float32, (3, 34)))
Fenpei = np.dtype((np.float32, 4))
StepData = np.dtype(
    [
        ("public_features", PublicFeatures),
        ("private_features", PrivateFeatures),
        ("action", np.dtype((np.int64, (1, )))),
        ("value", np.float32),
        ("logits", Policy),
        ("advantage", np.float32),
        ("hupai", Hupai),
        ("hule_player", HulePlayer),
        ("tajia_tingpai", TajiaTingpai),
        ("fenpei", Fenpei),
    ]
)

データローダ

Numpyで読み込んだデータをPyTorchのデータセットでラップして、PyTorchのデータローダで読み込めるようにする。

データセット定義
class RolloutDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.rollout_data = np.empty(0, StepData)

    def load(self, path):
        logging.info(f"loading {path}")
        with open(path, "rb") as f:
            data = zlib.decompress(f.read())
        tmp = np.frombuffer(data, StepData)
        self.rollout_data = np.concatenate((self.rollout_data, tmp))

    def calc_log_prob(self, batch_size, device):
        self.log_prob = np.empty(len(self.rollout_data), np.float32)
        with torch.inference_mode():
            for i in range(0, len(self.rollout_data), batch_size):
                batch = self.rollout_data[i:i+batch_size]
                actions = torch.from_numpy(batch["action"]).to(device)
                logits = torch.from_numpy(batch["logits"]).to(device)
                log_prob = PolicyValueNetWithAux.log_prob(actions, logits)
                self.log_prob[i:i+batch_size] = log_prob.to("cpu").detach().numpy()

    def __len__(self):
        return len(self.rollout_data)

    def __getitem__(self, idx):
        data = self.rollout_data[idx]
        log_prob = self.log_prob[idx]
        
        return data["public_features"], data["private_features"], data["action"], log_prob, data["advantage"], data["advantage"] + data["value"], data["hupai"]
log probability

PPOでは、挙動方策の選択したアクションの確率と、現在の方策のアクションの確率の比の対数を算出する。
実装上は、それぞれの確率の対数の差を計算する。

自己対局では、ロジットを記録しているため、挙動方策の確率の対数(log probability)をあらかじめ計算しておく。
上記、データセットのcalc_log_probメソッドに実装している。

学習処理

PPOアルゴリズムで学習する処理を実装する。

for epoch in range(args.n_epochs):
    logging.info(f"training epoch {epoch}")
    for public_features, private_features, actions, old_log_prob, advantages, returns, hupai in rollout_buffer:
        values, log_prob, entropy, p_aux1, p_aux2, p_aux3, v_aux = model.evaluate_actions_with_aux(public_features, private_features, actions)

        values = values.flatten()

        # Normalize advantage
        if normalize_advantage:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # ratio between old and new policy, should be one at the first iteration
        ratio = torch.exp(log_prob - old_log_prob)

        # clipped surrogate loss
        policy_loss_1 = advantages * ratio
        policy_loss_2 = advantages * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
        policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()

        # Value loss using the TD(gae_lambda) target
        value_loss = F.mse_loss(returns, values)

        # Entropy loss favor exploration
        entropy_loss = -torch.mean(entropy)

        # 補助タスク1 役
        p_aux1_loss = bce_with_logits_loss(p_aux1, hupai)

        loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss + p_aux1_loss

        # Optimization step
        optimizer.zero_grad()
        loss.backward()
        # Clip grad norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
補助タスク

モデルに定義した補助タスクの内、今回は、和了時の役の予測のみ実装した。

学習結果

自己対局で生成したデータで学習を実行し、自己対局→学習のサイクルを4サイクル回して、訓練損失を確認した。

1サイクルで、約45万局面を生成し、各サイクルで10エポック学習した。
自己対局では、ランダムに選択した役から3向聴前の局面を初期局面とする。

訓練損失 合計


方策損失


価値損失


補助タスク(役の予測)の損失

各サイクルでグラフの色が分かれている。

訓練損失の合計は、1サイクル目が一番低く、2サイクル目で上昇してそこから徐々に下がっている。
1サイクル目は方策がランダムに近く、合法手を学習するだけで、方策はランダムに近いため、エントロピー損失が大きく下がったためと考えられる。
2サイクル目からは少しずつ意味のある方策学習され始めている。

方策損失は、4サイクル目が高い値になっているが、補助タスクの学習が先に進んだためと思われる。
価値損失は、4サイクル目が一番低い値となっており、局面の勝率を学習できていそうである。

補助タスクの損失は、4サイクル目が一番低い値となっており、局面から役の予測ができていそうである。
方策と補助タスクは、方策ネットワークのヘッドにあるため、損失はサイクルごとに一方が下がるともう一方が高くなっており、交互に学習が進んでいるようである。

まとめ

麻雀AIの学習処理を実装した。
実際に自己対局でデータを生成して学習するサイクルを4サイクルだけ実行して、うまく学習できているか確認した。
サイクルごとに訓練損失が下がることが確認できた。

次は、学習量を増やして、打牌や予測勝率、役の予測結果をサンプリングして確認してみたい。

将棋AIの棋風学習にLoRAを適用 追試

前回、将棋AIのモデルにLoRAを適用して、人間プレイヤーの傾向を学習できることを確かめた。

今回、LoRAが通常の追加学習と比較して効率的か比較してみた。
また、序盤、中盤、終盤で傾向に違いがあるか検証してみた。

追加学習と比較

事前学習済みモデルに人間プレイヤーの棋譜を追加学習した場合と比較した。

条件
  • 学習率、エポックは同一
  • 方策のみ学習
結果


※凡例:finetuningが追加学習、pretrainedは事前学習済みモデル

追加学習により、事前学習済みモデルより10%ほど一致率が高くなっている。
LoRAと追加学習を比較すると、LoRAの方が2%ほど一致率が高い。

LoRAの方が効率的に人間プレイヤーの傾向を学習できている。

序盤、中盤、終盤

序盤、中盤、終盤で傾向に違いがあるか確認した。

条件
  • 40手未満を序盤、40手以上80手未満を中盤、80手以降を終盤とする
  • 序盤、中盤、終盤ごとに指し手の一致率を計測
結果


※x軸:40が序盤、80が中盤、256が終盤

LoRAで学習したモデルの一致率は、序盤と比べて、中盤、終盤の一致率が高い。
中盤と終盤はわずかに終盤の方が高い。

序盤は、人間プレイヤーの戦型が多様であるため、一致率が下がっていると考えられる。
序盤よりも中盤以降の一致率が高いことから、序盤の定跡を学習しているだけではなく人間プレイヤーの傾向を学習できていると言えそうである。

LoRA、事前学習済みモデル、通常の追加学習の比較

事前学習済みモデルと通常の追加学習したモデルでも確認してみた。

事前学習済みモデルは、序盤と中盤以降の差が大きい。
序盤の一致率が低いのは、AIの指し手は戦型が偏っているのに比べて人間の戦型が多様なためと考えられる。

まとめ

LoRAを適用することで、通常の追加学習と比較して効率的に人間プレイヤーの傾向を学習できることが確認できた。

また、序盤、中盤、終盤では、序盤と比べて中盤以降の一致率が高いことがわかった。
序盤の定跡を学習しているだけではなく、人間プレイヤーの傾向を学習できていることが確認できた。

将棋AIの棋風学習にLoRAを適用

大規模言語モデルやStable Diffusionなどの拡散モデルのファインチューニングには、LoRAという手法が使われることが多い。
全てのパラメータを微調整するよりも効率的に学習できる。

LoRAを将棋AIのモデルに適用して、棋風の学習に使えるか試してみた。

将棋AIのモデル

将棋AIのモデルには、dlshogiの学習済みモデルを使用する。
dlshogiのモデルは、ResNetで構成されている。

LoRA

標準のLoRAは、全結合層に対して適用されるため、ResNetで使用している畳み込み層には適用できない。
LoRAを、畳み込み層にも適用できるようにしたものにLoConがある。

LoConは、カーネルサイズ3x3・フィルタ数Nの畳み込み層を、カーネルサイズ3x3・フィルタ数Rの畳み込み層とカーネルサイズ1x1・フィルタ数Nの畳み込み層に分解して表現する。
RはRankを表し、元のフィルタ数Nよりも小さい値とする。
推論時には、元のパラメータにマージすることができる。

棋風学習

棋風は、特定の傾向のあるプレイヤーの棋譜から学習する。
dlshogiのモデルは価値(勝率)も学習するが、指し手のみを学習する。

実験

条件
  • 将棋倶楽部24万局集」の人間プレイヤーの棋譜を学習する(以前の記事参照)。
  • レーティング帯別に訓練データと評価データを9:1に分割する。
  • レーティング帯別に一致率で評価する。
棋譜

レーティング別の件数は以下の通り。

1800から2000あたりの件数が多い。
1000から1300あたりや1500から1700あたりの件数が少ない。
そのため、レーティング帯を500間隔に分ける(2000以上は件数が少ないため一つにする)。

結果

4エポック学習し、評価データで一致率を測定した結果は以下の通り。
比較のため、LoRAを適用前の事前学習済みモデルでの結果も載せる。

事前学習済みモデルと比較して、一致率が10%以上高くなっていることがわかる。

レーティング帯別では、低いレーティング帯ほど一致率の差が大きくなっている。

人間プレイヤーの傾向を学習できていると言えそうである。

低いレーティング帯のみ学習

低いレーティング帯の訓練データのみを学習した場合、そのレーティング帯の一致率が上がるか検証した。

レーティング1500未満のデータを学習した結果を、全レーティング帯を学習したものと比較した結果は以下の通り。

差が分かりにくいので、差分をグラフにした。

x軸が1000(500~1000)と1500(1000~1500)の一致率が少しだけ高くなっているのが分かる。
逆に、2000以上では低くなっている。
500が低い理由はわからない。

差はわずかだが、低いレーティング帯の傾向を学習できていると言えそうである。

絶対的な一致率は、高いレーティング帯の方が高い。
これは、事前学習済みモデルの傾向が反映されてためと考えるが、低いレーティング帯のプレイヤーの棋譜のみで、高いレーティング帯の一致率も事前学習済みモデルから上昇するのは、レーティング帯によらず人間のプレイヤーに共通する傾向があるのかもしれない。

まとめ

将棋AIのモデルにLoRAを適用して棋風を学習できるか試してみた。
結果、事前学習済みモデルと比較して、訓練に使用した棋譜の傾向を学習できることが確認できた。

モデルを直接ファインチューニングした場合と比べてどうかも気になるため、別途追試をしてみたい。

LLMを活用した深層学習モデルの改善

前回の記事で、LLMを使用して数学的発見を行うFunSearchの論文を紹介した。

FunSearchは、LLMを使用してプログラムの変更を行い、進化的アルゴリズムでスコアの高いプログラムを選別することで、最適な解を出力するプログラムを生成する。

この仕組みは、深層学習のモデル構造の改良にも使えないかと思い試してみた。

進化的アルゴリズムを行うには評価に時間がかかるため、今回はLLMにより、モデル構造を変更する部分を試した。

対象

dlshogiで使用しているResNetブロックを初期バージョンとして、改善したバージョンをLLMに出力させる。

class ResNetBlock(nn.Module):
    def __init__(self, channels, activation):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.act = activation

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act(out)

        out = self.conv2(out)
        out = self.bn2(out)

        return self.act(out + x)

LLM

論文と同じCodey for Code Generationと、ChatGPT4を使用する。

プロンプト

import torch
import torch.nn as nn

class ResNetBlock_v0(nn.Module):
  """ResNetBlock that does not change input and output size"""
  def __init__(self, channels, activation):
    super(ResNetBlock_v0, self).__init__()
    self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(channels)
    self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(channels)
    self.act = activation

  def forward(self, x):
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.act(out)

    out = self.conv2(out)
    out = self.bn2(out)

    return self.act(out + x)

class ResNetBlock_v1(nn.Module):
  """A version of `ResNetBlock_v0` based on deep learning knowledge and using appropriate modules in PyTorch to change model structure and parameters to improve accuracy."""
  def __init__(self, channels, activation):

始めResNetBlock_v1のコメント部分を論文と同じように「Improved version of `ResNetBlock_v0`」としたが、出力されるコードがほとんど変わらなかったので、「深層学習の知識を元にPyTorchの適切なモジュールを使用してモデル構造とパラメータを変更して精度を向上」という指示を加えた。

結果

Codey for Code Generation

温度パラメータを1にして何回か出力したが、書き方は変わっているがモデル構造としてはほとんど変わらないものしか出力されなかった。

class ResNetBlock_v1(nn.Module):

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNetBlock_v1, self).__init__()

        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        # If input and output channels are different, then we need a 1x1 convolution to match the dimensions.
        self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) if stride != 1 or in_channels != out_channels else nn.Identity()

    def forward(self, x):
        return F.relu(self.residual_function(x) + self.shortcut(x))
ChatGPT4

ChatGPT4では、DropoutやDilated Convolution、SqueezeNetを使うコードが出力された。

class ResNetBlock_v1(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None, activation=nn.ReLU()):
        super(ResNetBlock_v1, self).__init__()
        self.expansion = 4
        mid_channels = out_channels // self.expansion

        # Bottleneck architecture
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)

        # Dilated Convolution can be added here with dilation > 1 if needed
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)

        self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        
        self.se_block = SqueezeExcitationBlock(out_channels)  # Squeeze-and-Excitation block
        self.relu = activation
        self.downsample = downsample
        self.stride = stride
        self.dropout = nn.Dropout(0.5)  # Dropout

    def forward(self, x):
        residual = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out = self.se_block(out)  # Squeeze-and-Excitation block

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.dropout(out)  # Dropout
        out = self.relu(out)

        return out

class SqueezeExcitationBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SqueezeExcitationBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

他のバージョンも出力するように指示すると、Grouped Convolutionを使用したバージョンが出力された。

import torch
import torch.nn as nn
import torch.nn.functional as F

class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class ResNetBlock_v2(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, groups=1, dilation=1, use_se=False):
        super(ResNetBlock_v2, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=dilation, groups=groups, dilation=dilation, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=dilation, groups=groups, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.se = SEBlock(out_channels) if use_se else None

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.se is not None:
            out = self.se(out)
        out += self.shortcut(x)
        out = self.relu(out)

        return out

考察

Codey for Code Generationでは、記述方法が変わるだけでモデル構造が変わるコードは出力されなかった。
Codey for Code Generationを使用して、モデル構造を改善するのは難しそうである。

ChatGPT4の方は、深層学習の一般知識を使用して、Dilated Convolution、Grouped Convolution、SqueezeNetといったResNetに効果のある手法を適用したコードが出力された。
出力されたコードを何も考えずに評価して、改善されたものをさらに改善していくことで、良いモデル構造が見つけられそうである。

今回は、ResNetブロックだけに適用したが、ボディ部分の構造にも適用できそうである。

まとめ

LLMでResNetのモデル構造を改善できるか試してみた。
結果、Codey for Code Generationでは良い結果が得られなかったが、ChatGPT4では深層学習の知識を活かしたモデル構造のバリエーションを出力できることがわかった。

実際に将棋AIのモデルで精度が改善されるか時間があれば検証してみたい。