以前に、Mask-Attentionについて記事にしたが、同様の手法にAttention Branch Network(ABN)がある。
ABNは、Attention Branchの損失も計算して訓練する点がMask-Attentionと異なる。
ここでは、ABNをdlshogiのネットワークに適用して、AIが注視している座標を可視化してみる。
ネットワーク構成
policyブランチとvalueブランチそれぞれに、Attentionブランチを追加する。
policyブランチとvalueブランチは、元のネットワークから変更していない。
policy attention mapとvalue attention mapは、2Dの画像になっており、AIが注視している箇所を表す。
損失
policy attentionブランチの損失はソフトマックス交差エントロピー、value attentionブランチの損失はシグモイド交差エントロピーとして、元のネットワークの損失に加算したものを全体の損失とする。
訓練
dlshogiの強化学習で生成した60,679,983局面を使用して訓練した。
可視化
訓練したモデルを使用して、いくつかの局面で可視化してみた。
初期局面
policy attention map
精度
ABNは可視化とパフォーマンス向上を両立できる点に利点がある。
SENetでは、チャネル方向にAttentionの導入してパフォーマンスを向上している。
ABNでは画素方向にAttentionを適用している。
元のネットワークとABNを追加したネットワークで、精度を比較した。
ソース
Attention Branch Network · TadaoYamaoka/DeepLearningShogi@d62f7c7 · GitHub
まとめ
将棋のニューラルネットワークにABNを適用することで、注視している箇所の可視化ができるか試してみた。
policy attention mapについては、値の意味が良くわからなかったが、value attention mapについては注視している箇所を示しているかもしれない。
AIが注視している箇所を可視化することで、初心者がどのあたりに注目して形勢を判断すればよいかのヒントに使えないかと思って試してみたが、これが役に立つかは将棋がわからないとよくわからないのであった・・・。