最近のLLMでは、AttentionにQKNormが使われている。
特に、RMSNormを使う実装が主流になっている。
世界コンピュータ将棋の会場でnshogiの開発者と話した際に、SwiGLUとQKNormが効果があったということだった。
SwiGLUは、dlshogiでも採用して効果が高いことを確認していたが、QKNormは試していなかったので試してみた。
QKNorm
Attentionの Query と Key を内積前に正規化して、attention logits のスケール暴走や softmax saturation を防ぎ、学習を安定化する手法である。
元の論文では、正規化にL2 normalizationが使用されている。
また、内積の結果に通常のの代わりに learnable scalar
を掛けている。
RMSNorm
入力ベクトルを平均ではなく RMS(root mean square)だけで正規化することで、スケール変化に対する不変性を保ちながら LayerNorm を軽量化した正規化手法である。
QKNormに、RMSNormを使うのが主流であり、LLaMAやGemmaでもRMSNormが使用されている。
Gemma4の実装では、ValueにもRMSNormが使用されている。
比較パターン
- QKNormにRMSNormを使用
- QKNormにRMSNormを使用し、Q、Kそれぞれにlearnable weightを掛ける
- QKNormに加えて、ValueにもRMSNormを使用
- QKNormに加えて、ValueにもRMSNormを使用し、Q、Kそれぞれにlearnable weightを掛ける
- QKNormにLayerNormを使用
- QKNormにBatchNorm2Dを使用
訓練条件
- WCSC36のdlshogiのResNet+Transformerモデルの20ブロック256フィルタ
- 訓練データ約3.9億局面
- バッチサイズ4096
- Momentum SGD
- 学習率0.04からエポックごとに半減
- 8エポック
評価データは、2017年~2018年6月のfloodgateのR3500以上の棋譜からサンプリングした856,923局面を使用。
シードを変えて、2回測定して平均をとる。
実験結果
| 方策損失 | 価値損失 | 方策正解率 | 価値正解率 | |
|---|---|---|---|---|
| WCSC36版 | 1.41530 | 0.46225 | 0.52857 | 0.76188 |
| QK Norm(RMS Norm) | 1.40991 | 0.46068 | 0.53044 | 0.76306 |
| QK Norm(RMS Norm + weight) | 1.41369 | 0.46050 | 0.52923 | 0.76342 |
| QK Norm(RMS Norm)+V RMSNorm | 1.41012 | 0.46052 | 0.52989 | 0.76334 |
| QK Norm(RMS Norm + weight)+V RMSNorm | 1.41564 | 0.46104 | 0.52966 | 0.76279 |
| QK Norm(Layer Norm) | 1.41171 | 0.46335 | 0.53010 | 0.76078 |
| QK Norm(Batch Norm) | 1.41429 | 0.46268 | 0.52927 | 0.76157 |


考察
QKNormにRMSNormを使用したパターンが最も方策損失が低く、方策正解率が高い。
価値損失は、QKNormにlearnable weightを使用した方が低いが僅差である。
learnable weight
方策損失はlearnable weightはない方が良い。
価値損失はlearnable weightがある方が、わずかに良くなっているが誤差程度である。
ValueにRMSNormを使うパターンでは、learnable weightがない方が方策、価値ともによい。
learnable weightはない方が良いと言える。
ValueのRMSNorm
ValueのRMSNormを適用した場合、方策、価値ともに条件により少し良くなったり、悪くなったりしており、大きな改善は見られない。
効果はなさそうである。
LayerNormとBatchNorm2D
方策、価値ともに、RMSNormの方が明らかに良い。
まとめ
dlshogiのResNet+Transformerモデルに、QKNormが効果があるか試してみた。
結果、QKNormにRMSNormを使用した場合、方策、価値ともに精度が改善することが確かめられた。