深層強化学習では、メインの学習タスクに加えて、補助タスクを同時学習することで、パフォーマンスを改善するということが行われている。
[1611.05397] Reinforcement Learning with Unsupervised Auxiliary Tasks
チェスAIの例
チェスAIのLeela Chess Zeroでは、
- 残り何手で終局か(Move Left)
が補助ターゲットとして使用されている。
通常、補助ターゲットは、学習時のみ使用して、推論時には利用しない。
ただし、Leela Chess ZeroのMove Leftは、ゲームをより短い手数で終了するために探索でも活用されている。
将棋での補助ターゲット
将棋は詰みによる終局は、直前まで駒損していても詰ませれば勝ちなので、終局時の駒得のような補助ターゲットは適切ではない。
入玉の有無はありかもしれない。
Leela Chess Zeroのような後何手で終局かは、将棋の詰み手順は、ディープラーニングが苦手としているので、強化学習時に最短手数で詰ませられないと正しく学習できないかもしれない。
データ作成が簡単で、効果がありそうな補助ターゲットとして、思いつくものとしては、
がある。
ここでは、この2つを補助ターゲットに追加して、効果があるかを検証してみる。
教師データ
強化学習で検証するには時間がかかるので、floodgateの棋譜を使用して教師ありで検証する。
2018年~2020年の棋譜から、レーティングが3000以上、80手以上の棋譜から、序盤は偏りがあるため30手目以降の局面を使用する。
最大手数に達した対局は除く。
訓練:テスト=9:1の割合で、棋譜単位で分割する。
棋譜数 | 局面数 | |
---|---|---|
訓練 | 133114 | 13971526 |
テスト | 14818 | 1555972 |
実装方法
ニューラルネットワーク
価値ヘッドの出力は1つで勝率を出力しているが、そこに千日手と入玉宣言勝ちに対応する出力ユニットを追加する。
# value network self.l22_v = nn.Conv2d(in_channels=k, out_channels=MAX_MOVE_LABEL_NUM, kernel_size=1, bias=False) self.l23_v = nn.Linear(9*9*MAX_MOVE_LABEL_NUM, fcl) self.l24_v = nn.Linear(fcl, 1) # sennichite, nyugyoku self.l24_aux = nn.Linear(fcl, 2)
損失
千日手と入玉宣言勝ちの出力の損失関数に、シグモイド交差エントロピーを使用する。
千日手と入玉宣言勝ちの損失平均を補助ターゲットの損失とする。
千日手と入玉宣言勝ちで終局する対局は、全体からすると割合が小さいため、正例に重み付けを行う。
PyTorchでは、BCEWithLogitsLossのpos_weightで重みを与えることができる。
学習データによって割合は異なるため、訓練データ中の千日手と入玉宣言勝ちの割合を調べて、動的に重みを与える。
pos_weight = torch.tensor([ len(train_data) / (train_data['result'] // 4 & 1).sum(), # sennichite len(train_data) / (train_data['result'] // 8 & 1).sum(), # nyugyoku ], dtype=torch.float32, device=device) bce_with_logits_loss_aux = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
通常の損失と補助ターゲットの損失で加重平均をとり、全体の損失とする。
補助ターゲットの重みは0.1とした。
学習結果
学習結果は以下のようになった。
※末尾に(aux)の付いた系列(オレンジ)が補助ターゲットありの結果
※テストデータは100ステップごとに640局面をサンプリング
訓練データすべてを学習後(SWAあり)、テストデータ全部で評価した結果は以下の通り。
テスト方策損失 | テスト価値損失 | テスト補助ターゲット損失 | |
---|---|---|---|
通常 | 0.93448454 | 0.39452724 | - |
補助ターゲットあり | 0.94912096 | 0.39225218 | 1.08832618 |
テスト方策正解率 | テスト価値正解率 | |
---|---|---|
通常 | 0.41683142 | 0.79983546 |
補助ターゲットあり | 0.43945136 | 0.80143781 |