TadaoYamaokaの日記

山岡忠夫Homeで公開しているプログラムの開発ネタを中心に書いていきます。

将棋AIの実験ノート:Normalizer-Free Networks

ちょうどFixup Initializationを試したタイミングで、DeepMindからBatch Normalizerを削除してSOTAを達成したという論文が発表された。

さっそく試してみようとしたが、ソースコードJaxで実装されているため、PyTorchで実装し直す必要がある。
deepmind-research/nfnets at master · deepmind/deepmind-research · GitHub

はじめ自力で実装していたが、途中で非公式のPyTorchの実装が公開されたたため、そちらを大いに参考にさせてもらった。
pytorch-image-models/nfnet.py at master · rwightman/pytorch-image-models · GitHub

ひとまず実装できたので、現在のdlshogiのResNet(Batch Normalizerあり)と、Normalizer-Free ResNetsで比較してみた。

論文の要旨

Batch Normalizerの効果
  • 損失を滑らかにし、より大きな学習率とより大きなバッチサイズで安定した訓練を可能とする。
  • 正則化の効果もある。
Batch Normalizerの欠点
  • 驚くほど高価な計算プリミティブであり、メモリオーバーヘッドが発生し、一部のネットワークで勾配を評価するために必要な時間が大幅に増加する。
  • 訓練中と推論時のモデルの動作の間に不一致が生じ、調整が必要な隠れたハイパーパラメーターが導入される。
  • ミニバッチの訓練サンプル間の独立性を壊す。
提案手法
  • ネットワーク構成は、Normalizer-Free ResNets(NF-ResNets)をベースにする。残差ブロックの分散の増加に合わせてダウンスケールする。

[2101.08692] Characterizing signal propagation to close the performance gap in unnormalized ResNets

  • 残差ブロックの最後に、学習可能なスカラー(SkipInit)を含める。
  • 畳み込み層の重みの平均と分散により、再パラメータ化する「Scaled Weight Standardization」を導入する。
  • 従来の勾配クリッピングは、閾値の選択に敏感であるため、重みのノルムによる比を用いた「Adaptive Gradient Clipping (AGC) 」を導入する。

\displaystyle
\begin{equation}
G^{\ell}_i \rightarrow
    \begin{cases}
    
    \lambda \frac{\|W^{\ell}_i\|^\star_F}{\|G^{\ell}_i\|_F}G^{\ell}_i& \text{if $\frac{\|G^{\ell}_i\|_F}{\|W^{\ell}_i\|^\star_F} > \lambda$}, \\
    G^{\ell}_i & \text{otherwise.}
    \end{cases}
\end{equation}

実験方法

dlshogiの強化学習で生成した31,564,618局面で訓練し、floodgateからサンプリングした856,923局面で評価する。
MomentumSGD(lr=0.1)で学習する。
SWAはオフにする。
重み減衰は有効(rate=0.0001)にする。
論文の残差ブロックの畳み込み層は3層だが、dlshogiは2層なので2層に合わせる。
活性化関数は論文ではgeluだが、dlshogiに合わせてsiluを使用する。
論文では、AGCは、最終層には適用しないとあるので、policyヘッドと、valueヘッドの全結合層は除外する。

測定結果

f:id:TadaoYamaoka:20210220155549p:plainf:id:TadaoYamaoka:20210220155552p:plainf:id:TadaoYamaoka:20210220155555p:plainf:id:TadaoYamaoka:20210220155559p:plain

考察

Batch Normalizerを用いたResNetの方が、速く訓練損失が低下しており、テスト損失も低い。
NFNetsには、いくつかパラメータがあるため、いくつか変更して試してみたが、ほとんど結果は変わらなかった。

まとめ

Normalizer-Free Networksを、dlshogiの学習で試してみた。
論文ではImageNetの学習で、SOTAを達成したと報告されているが、将棋の学習では効果を確認できなかった。

画像のデータセットで実装したコードを使用してそもそも再現できるのかや、論文と残差ブロックの畳み込み層の数が異なるなど条件が異なる部分もあるため、そろえた場合にどうなるかは別途確認したい。

学習時間

学習時間は以下の通りであった。
PyTorchでAMPを使用している。

訓練時間
ResNet(Batch Normalizerあり) 1:16:34 100%
NFNets 1:09:05 90.2%