TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

【dlshogi】ラージカーネル+Transformerモデルの学習

以前に検証したラージカーネルのモデルにTransformerを組み合わせたモデルの学習を行った。

実験段階では20ブロック256フィルタのモデルを使用したが、今回は11月末に行われる電竜戦向けに40ブロック512フィルタのモデルを学習した。

モデル構造

20ブロック256フィルタのモデルで事前検証を行い、精度が高くNPSが大きく下がらないモデル構造を選定した。

ラージカーネル

以前に検証した1x9, 9x1, 1x1の3つのカーネルを並列に並べたものを採用した。
これを5ブロックおきに配置した。

Transformer

標準的なマルチヘッドセルフアテンションを少し変更し、出力のプロジェクション層の前に活性化関数、後に正規化を加えた。
ResNetの構成に近くなりResNetブロックと組み合わせた場合に精度が上がることがわかった。これはLeVitのモデル構造を参考にした。

FFNは、標準的なTransformerと同様に4倍の次元とした。

ResNetの最終ブロックをTransformerブロックに置き換えた。

位置エンコーディング

相対位置エンコーディング、相対位置バイアスなど試したが、効果がなかったため、位置エンコーディングは使用しないことにした。
以前に考察した通り、dlshogiのモデルではプーリングを使用していないため、各座標の特徴量が位置情報を保持している。
そのため、明示的に位置情報をエンコードしなくても、入力に応じた相対的な位置を考慮できる。

入力層の正規化

dlshogiは、入力層に盤面の入力に3x3と1x1のカーネル、持ち駒など数値特徴量に1x1カーネルの畳み込みを使用して、それぞれの出力を加算してから正規化を行っていた。
この構成では、正規化層をレイヤー融合ができないため、加算前に正規化を行うように変更した。
これによりわずかにNPSが改善する。

学習

訓練データとして、hcpe3形式で重複局面を平均化した26億局面を使用した。

バッチサイズ4096で、12エポック学習した。

学習率スケジューラは、Cosineスケジューラだと高い学習率から始めると発散したため、StepLRSchedulerでエポックごとに1/2にした。
Transformerの学習を安定させるため、Warmupも行った。

AMPの混合精度は、float16だとTransformerの学習で損失がNaNになりやすいため、bfloat16で学習した。

学習結果

比較のために、Transformerありとなしのモデルを学習した。
また、訓練データが少しことなるが、以前に学習した40ブロック512フィルタの通常のResNetブロックとも比較する。

結果グラフのラベルの意味は以下の通り。
pre54 : 以前のResNetモデル
pre55 : ラージカーネル+Transformer
pre56 : ラージカーネルのみ

方策損失


価値損失


方策正解率


価値正解率


最終的な精度
モデル 方策損失 価値損失 方策正解率 価値正解率
pre54 1.266429 0.434222 0.569292 0.780081
pre55 1.255327 0.433520 0.571877 0.781784
pre56 1.256707 0.433032 0.571802 0.782097

今回学習したモデル(pre55, pre56)はどちらも以前のResNet(pre54)よりも、方策、価値ともに精度が上がっている。
ラージカーネル+Transformerのモデル(pre55)は、方策正解率は約0.25%、価値正解率は約0.17%向上している。
ラージカーネルのみのモデルの方が正解率はわずかに高い。

NPS

floodgateから抽出した100局面で、4回測定した際のNPSの統計量は以下の通り。
参考として、30ブロック384フィルタのモデル(pre44)も記載する。
RTX 4090で測定した。

pre54 pre55 pre56 pre44(参考)
平均 8008 8766 8682 19221
中央値 7982 8764 8667 19270
最大 9392 10140 9946 20858
最小 7596 8293 8147 17920

同じ40ブロック512フィルタのモデルで、ラージカーネル+TransformerのモデルがNPSが最も高かった。

以前に20ブロック256フィルタで実験した際は、1x9,9x9,1x1のラージカーネルは、3x3カーネルよりも少しNPSが下がったが、512フィルタのモデルでは逆にNPSが改善した。
今回は5ブロック間隔で配置したが、より積極的に使ってもよさそうである。

また、最終ブロックをTransformerに置き換えたモデル(pre55)の方が、ラージカーネルのみのモデルよりNPSが高くなった。
事前の検証ではTransformerにより少しNPSが落ちていたが、これも512フィルタのモデルでは逆にNPSが改善した。
Transformerブロックもより積極的に使ってもよさそうである。

強さ

同一持ち時間

互角局面集を使用して、持ち時間を同一とした場合の強さは以下の通り。
持ち時間は、400秒1手ごとに2秒加算とした。
H100で測定した。

基準ソフトとして、NNUE系もリーグに加えているが互角になるように持ち時間を調整しているため、以下の結果からは除外している。

# PLAYER  :  RATING  ERROR  POINTS  PLAYED   (%)  CFS(%)    W    D    L  D(%)
1 pre55   :    18.1   15.6   832.0    1632    51      50  728  208  696    13
2 pre54   :    18.0   20.3   467.5     935    50      80  414  107  414    11
3 pre56   :     3.9   25.5   264.0     553    48      76  230   68  255    12
4 pre44   :   -10.5   26.4   251.5     550    46      90  210   83  257    15

ラージカーネル+Transformerのモデルが最も勝率が高いが、有意差があるほどではない。

以前のResNetモデル(pre54)よりも明確に強くなることを期待していたが、ほとんど強くなっていない。

30ブロック384フィルタのモデルよりは有意に強くなっている。

同一ノード数

互角局面集を使用して、ノード数を固定した場合の強さは以下の通り。
ノード数は、50万,60万,70万,80万の4パターンで対局した。

# PLAYER  :  RATING  ERROR  POINTS  PLAYED   (%)  CFS(%)    W    D    L  D(%)
1 pre55   :    25.4   28.5   213.0     390    55      83  181   64  145    16
2 pre54   :     1.2   36.6   108.5     220    49      54   96   25   99    11
3 pre56   :    -2.7   51.5    55.5     113    49      55   48   15   50    13
4 pre44   :   -17.0   43.9    78.5     170    46     ---   66   25   79    15

ラージカーネル+Transformerのモデルが最も勝率が高いが、対局数が十分でなく有意差があるとは言えない。

以前のResNetモデル(pre54)よりも、ラージカーネルを加えたモデル(pre56)の勝率が低く、精度の向上が強さに反映されていない。

今回、強さの測定に互角局面集を用いたが、pre55,pre56の訓練データは、dlshogiの定跡で互角の局面を初期局面とした棋譜を増やしている。
そのため、互角局面集を用いた測定で精度が強さに反映されなかったのかもしれない。

まとめ

ラージカーネルとTransformerを使用した40ブロック512フィルタのモデルを学習した。
以前のResNetモデルと比べて、精度とNPSが向上することが確認できた。
しかし、互角局面集で強さを測定したところ有意に強くなったとは言えない。

モデル構造の検証にかなり時間をかけたが、モデル構造の工夫ではあまり強くならなかった。
ただし、精度は通常のResNetよりも早く上昇するため、収束するまで学習できない場合、同一訓練時間では強くなる可能性がある。

今後は、モデルサイズとデータを増やす方向で強くできるか検討したい。
また、終盤はNPSが高い方が有利の可能性もあるため、蒸留したモデルも学習したい。