TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

将棋AIの実験ノート:Attention Branch Network

以前に、Mask-Attentionについて記事にしたが、同様の手法にAttention Branch Network(ABN)がある。
ABNは、Attention Branchの損失も計算して訓練する点がMask-Attentionと異なる。
f:id:TadaoYamaoka:20210129213306p:plain
ここでは、ABNをdlshogiのネットワークに適用して、AIが注視している座標を可視化してみる。

ネットワーク構成

policyブランチとvalueブランチそれぞれに、Attentionブランチを追加する。
policyブランチとvalueブランチは、元のネットワークから変更していない。
f:id:TadaoYamaoka:20210129220935p:plain
policy attention mapとvalue attention mapは、2Dの画像になっており、AIが注視している箇所を表す。

損失

policy attentionブランチの損失はソフトマックス交差エントロピーvalue attentionブランチの損失はシグモイド交差エントロピーとして、元のネットワークの損失に加算したものを全体の損失とする。

訓練

dlshogiの強化学習で生成した60,679,983局面を使用して訓練した。

可視化

訓練したモデルを使用して、いくつかの局面で可視化してみた。

初期局面

f:id:TadaoYamaoka:20210129222532p:plain:w300

policy attention map

f:id:TadaoYamaoka:20210129223155p:plain:w300

value attention map

f:id:TadaoYamaoka:20210129223225p:plain:w300

policyの予測は「2六歩」と「7六歩」が高いが、policy attention mapはその位置の値が低くなっている。
高くなるなら分かりやすいが、低くなっているので、値が何を示しているのか良くわからない。

逆に、value attention mapは、「2六歩」と「7六歩」の位置の値が高くなっているので、なんとなく注視している箇所を示していそうである。

サンプル局面

f:id:TadaoYamaoka:20210129223852p:plain:w300

policy attention map

f:id:TadaoYamaoka:20210129224157p:plain:w300

value attention map

f:id:TadaoYamaoka:20210129224224p:plain:w300

policyの予測は「3三桂成」が高いが、policy attention mapはその位置の値が低くなっている。
やはり値が何を示しているのか良くわからない。

value attention mapは、「3三」の値が高くなっている。
玉の周りも値が高いので、注視箇所を示していそうである。

精度

ABNは可視化とパフォーマンス向上を両立できる点に利点がある。
SENetでは、チャネル方向にAttentionの導入してパフォーマンスを向上している。
ABNでは画素方向にAttentionを適用している。

元のネットワークとABNを追加したネットワークで、精度を比較した。

訓練損失
policy平均損失 value平均損失
元のネットワーク(resnet+swish) 0.68139506 0.38120323
ABN 0.63697344 0.37672683
テスト損失

floodgateからサンプリングした棋譜で評価した際のテスト損失

policy損失 value損失
元のネットワーク(resnet+swish) 0.97505525 0.54548708
ABN 0.97899158 0.54168929
テスト正解率
policy正解率 value正解率
元のネットワーク(resnet+swish) 0.42881633 0.70599075
ABN 0.43730732 0.71036645

policyのテスト損失は良くなっていないが、正解率は高くなっている。
valueのテスト損失、正解率ともに、元のネットワークよりも良くなっている。

ソース

Attention Branch Network · TadaoYamaoka/DeepLearningShogi@d62f7c7 · GitHub

まとめ

将棋のニューラルネットワークにABNを適用することで、注視している箇所の可視化ができるか試してみた。
policy attention mapについては、値の意味が良くわからなかったが、value attention mapについては注視している箇所を示しているかもしれない。

AIが注視している箇所を可視化することで、初心者がどのあたりに注目して形勢を判断すればよいかのヒントに使えないかと思って試してみたが、これが役に立つかは将棋がわからないとよくわからないのであった・・・。