TadaoYamaokaの日記

山岡忠夫Homeで公開しているプログラムの開発ネタを中心に書いていきます。

将棋AIの進捗 その54(補助ターゲット)

深層強化学習では、メインの学習タスクに加えて、補助タスクを同時学習することで、パフォーマンスを改善するということが行われている。
[1611.05397] Reinforcement Learning with Unsupervised Auxiliary Tasks

囲碁AIの例

囲碁AIのKataGoでは、

  • 占領した領域
  • スコア

が補助ターゲットとして使用されている。

チェスAIの例

チェスAIのLeela Chess Zeroでは、

  • 残り何手で終局か(Move Left)

が補助ターゲットとして使用されている。

通常、補助ターゲットは、学習時のみ使用して、推論時には利用しない。
ただし、Leela Chess ZeroのMove Leftは、ゲームをより短い手数で終了するために探索でも活用されている。

将棋での補助ターゲット

将棋は詰みによる終局は、直前まで駒損していても詰ませれば勝ちなので、終局時の駒得のような補助ターゲットは適切ではない。
入玉の有無はありかもしれない。

Leela Chess Zeroのような後何手で終局かは、将棋の詰み手順は、ディープラーニングが苦手としているので、強化学習時に最短手数で詰ませられないと正しく学習できないかもしれない。

データ作成が簡単で、効果がありそうな補助ターゲットとして、思いつくものとしては、

がある。

ここでは、この2つを補助ターゲットに追加して、効果があるかを検証してみる。

教師データ

強化学習で検証するには時間がかかるので、floodgateの棋譜を使用して教師ありで検証する。

2018年~2020年の棋譜から、レーティングが3000以上、80手以上の棋譜から、序盤は偏りがあるため30手目以降の局面を使用する。
最大手数に達した対局は除く。

訓練:テスト=9:1の割合で、棋譜単位で分割する。

棋譜 局面数
訓練 133114 13971526
テスト 14818 1555972

実装方法

教師データフォーマット

dlshogiで使用しているhcpeのgameResult(8bit)の3bit目と4bit目で、千日手入玉宣言勝ちを表現する。

ニューラルネットワーク

価値ヘッドの出力は1つで勝率を出力しているが、そこに千日手入玉宣言勝ちに対応する出力ユニットを追加する。

        # value network
        self.l22_v = nn.Conv2d(in_channels=k, out_channels=MAX_MOVE_LABEL_NUM, kernel_size=1, bias=False)
        self.l23_v = nn.Linear(9*9*MAX_MOVE_LABEL_NUM, fcl)
        self.l24_v = nn.Linear(fcl, 1)
        # sennichite, nyugyoku
        self.l24_aux = nn.Linear(fcl, 2)
損失

千日手入玉宣言勝ちの出力の損失関数に、シグモイド交差エントロピーを使用する。
千日手入玉宣言勝ちの損失平均を補助ターゲットの損失とする。

千日手入玉宣言勝ちで終局する対局は、全体からすると割合が小さいため、正例に重み付けを行う。
PyTorchでは、BCEWithLogitsLossのpos_weightで重みを与えることができる。

学習データによって割合は異なるため、訓練データ中の千日手入玉宣言勝ちの割合を調べて、動的に重みを与える。

pos_weight = torch.tensor([
    len(train_data) / (train_data['result'] // 4 & 1).sum(), # sennichite
    len(train_data) / (train_data['result'] // 8 & 1).sum(), # nyugyoku
    ], dtype=torch.float32, device=device)
bce_with_logits_loss_aux = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

通常の損失と補助ターゲットの損失で加重平均をとり、全体の損失とする。
補助ターゲットの重みは0.1とした。

学習結果

学習結果は以下のようになった。
※末尾に(aux)の付いた系列(オレンジ)が補助ターゲットありの結果
※テストデータは100ステップごとに640局面をサンプリング
f:id:TadaoYamaoka:20210121202103p:plainf:id:TadaoYamaoka:20210121202106p:plainf:id:TadaoYamaoka:20210121202110p:plainf:id:TadaoYamaoka:20210121202113p:plainf:id:TadaoYamaoka:20210121202116p:plainf:id:TadaoYamaoka:20210121202119p:plainf:id:TadaoYamaoka:20210121202123p:plainf:id:TadaoYamaoka:20210121202126p:plain

訓練データすべてを学習後(SWAあり)、テストデータ全部で評価した結果は以下の通り。

テスト方策損失 テスト価値損失 テスト補助ターゲット損失
通常 0.93448454 0.39452724 -
補助ターゲットあり 0.94912096 0.39225218 1.08832618
テスト方策正解率 テスト価値正解率
通常 0.41683142 0.79983546
補助ターゲットあり 0.43945136 0.80143781
考察

補助ターゲットを追加することで、同一データを用いて学習した場合でも、方策と価値の両方で、精度が上がっている。

テストの補助ターゲット損失は、グラフを見ると低下していないため、千日手入玉宣言勝ちの予測の汎化性能は高くない。
ただし、出現頻度が少ないため、さらにデータを増やすことでテストでも精度が上がる可能性がある。

まとめ

千日手入玉宣言勝ちを補助ターゲットに加えることで、精度が上がることが確かめられた。
ONNXに出力する際は補助ターゲットを削除できるため、推論の速度には影響を与えることなく精度が向上できる。

今後、最大手数で引き分けになったかや、終局時に入玉しているかなどを補助ターゲットに加えるなども検証してみたい。