TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

【dlshogi】QKNormを試す

最近のLLMでは、AttentionにQKNormが使われている。
特に、RMSNormを使う実装が主流になっている。

世界コンピュータ将棋の会場でnshogiの開発者と話した際に、SwiGLUとQKNormが効果があったということだった。
SwiGLUは、dlshogiでも採用して効果が高いことを確認していたが、QKNormは試していなかったので試してみた。

QKNorm

Attentionの Query と Key を内積前に正規化して、attention logits のスケール暴走や softmax saturation を防ぎ、学習を安定化する手法である。
元の論文では、正規化にL2 normalizationが使用されている。

また、内積の結果に通常の1 / \sqrt{d}の代わりに learnable scalar \alpha を掛けている。

RMSNorm

入力ベクトルを平均ではなく RMS(root mean square)だけで正規化することで、スケール変化に対する不変性を保ちながら LayerNorm を軽量化した正規化手法である。
QKNormに、RMSNormを使うのが主流であり、LLaMAGemmaでもRMSNormが使用されている。

Gemma4の実装では、ValueにもRMSNormが使用されている。

比較パターン

  • QKNormにRMSNormを使用
  • QKNormにRMSNormを使用し、Q、Kそれぞれにlearnable weightを掛ける
  • QKNormに加えて、ValueにもRMSNormを使用
  • QKNormに加えて、ValueにもRMSNormを使用し、Q、Kそれぞれにlearnable weightを掛ける
  • QKNormにLayerNormを使用
  • QKNormにBatchNorm2Dを使用

訓練条件

  • WCSC36のdlshogiのResNet+Transformerモデルの20ブロック256フィルタ
  • 訓練データ約3.9億局面
  • バッチサイズ4096
  • Momentum SGD
  • 学習率0.04からエポックごとに半減
  • 8エポック

評価データは、2017年~2018年6月のfloodgateのR3500以上の棋譜からサンプリングした856,923局面を使用。
シードを変えて、2回測定して平均をとる。

実験結果

方策損失 価値損失 方策正解率 価値正解率
WCSC36版 1.41530 0.46225 0.52857 0.76188
QK Norm(RMS Norm) 1.40991 0.46068 0.53044 0.76306
QK Norm(RMS Norm + weight) 1.41369 0.46050 0.52923 0.76342
QK Norm(RMS Norm)+V RMSNorm 1.41012 0.46052 0.52989 0.76334
QK Norm(RMS Norm + weight)+V RMSNorm 1.41564 0.46104 0.52966 0.76279
QK Norm(Layer Norm) 1.41171 0.46335 0.53010 0.76078
QK Norm(Batch Norm) 1.41429 0.46268 0.52927 0.76157


考察

QKNormにRMSNormを使用したパターンが最も方策損失が低く、方策正解率が高い。
価値損失は、QKNormにlearnable weightを使用した方が低いが僅差である。

learnable weight

方策損失はlearnable weightはない方が良い。
価値損失はlearnable weightがある方が、わずかに良くなっているが誤差程度である。
ValueにRMSNormを使うパターンでは、learnable weightがない方が方策、価値ともによい。

learnable weightはない方が良いと言える。

ValueのRMSNorm

ValueのRMSNormを適用した場合、方策、価値ともに条件により少し良くなったり、悪くなったりしており、大きな改善は見られない。
効果はなさそうである。

LayerNormとBatchNorm2D

方策、価値ともに、RMSNormの方が明らかに良い。

まとめ

dlshogiのResNet+Transformerモデルに、QKNormが効果があるか試してみた。
結果、QKNormにRMSNormを使用した場合、方策、価値ともに精度が改善することが確かめられた。