TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

【dlshogi】ResNet+Transformerモデルのアーキテクチャ改良

昨年の世界コンピュータ将棋選手権、電竜戦では、モデルサイズを大きくすることで強くすることを試みたが、精度は上がるものの探索速度が落ちることで強くできなかった

そこで、現在一番強い40ブロック512フィルタのモデルサイズで、モデルアーキテクチャを変えることで強くすることを試みた。

アーキテクチャ変更

電竜戦の後、様々なモデルアーキテクチャの変更を実験した。

Gated Attention

以前に、記事にしたGated Attentionはわずかに効果があるため取り入れた。
Attention層で、入力に応じた要素ごとの重みを乗じることで、重要でない駒の関係にSoftmaxの確率が割り当てられることを防ぐことができる。

SwiGLU FFN

FFNで、ゲートを学習可能にすることで、重要な情報のみを通すようにする。
LLMで標準的な手法となっている。

中間のブロックにTransformerを配置

角の効きを早期に捉えるために、10ブロック間隔で、ResNetをTransformerに置き換えるようにした。

相対位置バイアス

Transformerで角の効きを捉えるには、遮蔽している駒も合わせて捉える必要がある。
Transformerは、離れた位置の駒の関係を捉えることができるが、間にある駒は捉えることができない。
相対的な位置関係を相対位置バイアスで学習することで、FFNで位置関係を考慮することができ、遮蔽を捉えることができる。

絶対位置バイアス

入玉を正確に判断するには、自駒が相手の陣にいることを正確に識別する必要がある。
ResNetはプーリングを行っていないため、駒の座標に応じた表現はできているが、より明確にモデルに座標の情報を与えることにした。

ラージカーネル

ResNetは、通常効率を重視して3x3カーネルが用いられるが、5x5や7x7といったラージカーネルは精度改善に高い効果がある。
特に、入力側で、ラージカーネルを用いることで、早期にグローバルな特徴を捉えることができる。
dlshogiの入力層は、チャンネルサイズが中間層より小さいため、入力層のみをラージカーネルに置き換えることで、効率を大きく落とすことなく、精度を向上させることができた。

SEBlock

かなり前に実験して精度が上がることを確認していたが、探索速度が落ちて強くならなかったので採用していなかったが、5ブロック間隔で挿入することで十分な効果があることが確認できた。


他にも、group convや、Nested Bottleneckや、Pre Activationなど試したが、あまり効果がないか逆効果だったので採用しなかった。

訓練結果

訓練データ

アーキテクチャを改良したモデルで、40ブロック512フィルタのモデルを学習した。

訓練データは、60ブロック768フィルタのモデルから蒸留したデータと自己対局のデータ、NNUE系との対局データを使用した。
重複を平均化して、42.4億局面になった。

比較対象

以下と比較を行う。

  • pre59_40x512 : 改良前のResNet+Transformerモデル、40ブロック512フィルタ
  • pre59_50x640 : 改良前のResNet+Transformerモデル、50ブロック640フィルタ
  • pre66_60x768 : 改良前のResNet+Transformerモデル、60ブロック768フィルタ

今回学習したモデルは、pre68_40x512である。

なお、以前のモデルとは訓練データが異なるため、純粋なモデルアーキテクチャの比較にはなっていない。

精度


step val/policy_loss
pre59_40x512 9071084 1.253049373626709
pre59_50x640 9071084 1.2316398620605469
pre66_60x768 9160159 1.2185360193252563
pre68_40x512 9333303 1.2487926483154297


step val/value_loss
pre59_40x512 9071084 0.7068616151809692
pre59_50x640 9071084 0.7094608545303345
pre66_60x768 9160159 0.7174099683761597
pre68_40x512 9333303 0.7074635028839111


step val/result_loss
pre59_40x512 9071084 0.4324478209018707
pre59_50x640 9071084 0.42958199977874756
pre66_60x768 9160159 0.42668843269348145
pre68_40x512 9333303 0.43083781003952026


step val/policy_accuracy
pre59_40x512 9071084 0.5731422901153564
pre59_50x640 9071084 0.5795454382896423
pre66_60x768 9160159 0.5846914649009705
pre68_40x512 9333303 0.576049268245697


step val/value_accuracy
pre59_40x512 9071084 0.7828457355499268
pre59_50x640 9071084 0.7846994400024414
pre66_60x768 9160159 0.7863633632659912
pre68_40x512 9333303 0.7839466333389282


step val/policy_entropy
pre59_40x512 9071084 1.0936312675476074
pre59_50x640 9071084 1.076419711112976
pre66_60x768 9160159 1.0464400053024292
pre68_40x512 9333303 1.0613174438476562


step val/value_entropy
pre59_40x512 9071084 0.4637340009212494
pre59_50x640 9071084 0.4600807726383209
pre66_60x768 9160159 0.4531950354576111
pre68_40x512 9333303 0.462144672870636


floodgateの棋譜に対する評価精度は、改良前の40ブロックモデルと50ブロックモデルの中間くらいになっている。

パラメータ数

ONNXにした後、以下のスクリプトで算出したパラメータ数は以下の通り。

def count_parameters(model):
    total_params = 0
    for initializer in model.graph.initializer:
        dims = initializer.dims
        param_count = np.prod(dims)
        total_params += param_count
    return total_params
モデル パラメータ数
pre59_40x512 213891495
pre59_50x640 423872679
pre66_60x768 739767207
pre68_40x512 192373303

パラメータ数は、改良前の40ブロックモデルと比較して、89.9%になっている。
ResNetをTransformerに置き換えるブロックが増えているため、パラメータ数が減っている。
計算量は、Transformerの方が大きいため、探索速度はパラメータ数だけでは測れない。

探索速度

40ブロックモデルの、改良前後のNPSは以下の通り。
floodgateから抽出した100局面で、5回測定した平均。

改良前 改良後
平均 11659 10658
中央値 11702 10658
最大値 12821 16298
最小値 10779 9831

NPSは、平均で91.4%に低下している。

強さ

dlshogiの定跡の36手目から80手目から抽出した中終盤互角局面集で、持ち時間400秒、1手2秒加算で連続対局した結果は以下の通り。

# PLAYER          :  RATING  ERROR  POINTS  PLAYED   (%)  CFS(%)    W    D    L  D(%)
1 hayabusa-8th    :    52.9   23.2   285.5     482    59      99  277   17  188     4
2 pre68_40x512    :     1.2   22.2   246.0     487    51     100  231   30  226     6
3 pre60_40x512    :   -54.2   22.6   200.5     495    41     ---  187   27  281     5

White advantage = 147.58 +/- 13.45
Draw rate (equal opponents) = 5.63 % +/- 0.91

hayabusa-8thは、基準のためリーグに加えた氷彗の8スレッド。
dlshogiは、H100 PCI、1GPU。

改良前のモデルに比べて、R+55.4になっている。

探索速度は落ちているが、精度が向上したことで、強くなった。

改良前のモデルは、ファインチューニングしたモデルだが、改良後はファインチューニング前のため、まだ強くできる余地がある。

まとめ

dlshogiのモデルアーキテクチャ改良を行った。
Gated AttentionやSwiGLU、Transformer配置、位置バイアス、ラージカーネルなどを導入したことで精度が改善した。
探索速度は低下したが、精度向上により最終的にレーティングが+55.4向上し、強くすることができた。

さらにファインチューニングすることで強くできるか確認したい。