前回、dlshogiのPyTorch Lightning対応の基本部分を実装した。
今回は、EMAを実装する。
EMA
EMAは、Exponential Moving Averageの略で、重みの指数移動平均をとり平準化する方法である。
dlshogiでは、SWAと呼んでいたが、SWAで使われるスケジューラは使用しておらず、重みのエポックごとの平均化ではなく指数移動平均を計算している。
その場合、EMAという別の名称を使うことが一般的になっている。
PyTorchでも、こちらのissueでSWAのモジュールをEMAに対応させている。
ドキュメントにもEMAの記載が追加されている。
torch.optim — PyTorch 2.2 documentation
PyTorch LightningのEMA対応
PyTorch Lightningには、StochasticWeightAveragingというコールバックが実装されているが、SWAのスケジューラが強制されるため、EMAに対応させることができない。
そのため、自前でEMAの実装が必要である。
こちらのissueで、EMAの実装方法について議論されていたので、参考にして実装した。
PyTorchの公式ドキュメントにあるサンプルコードを元にして、LightningModuleの適切な箇所に処理を追加する。
EMAのサンプルコード
loader, optimizer, model, loss_fn = ... ema_model = torch.optim.swa_utils.AveragedModel(model, \ multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999)) for epoch in range(300): for input, target in loader: optimizer.zero_grad() loss_fn(model(input), target).backward() optimizer.step() ema_model.update_parameters(model) # Update bn statistics for the ema_model at the end torch.optim.swa_utils.update_bn(loader, ema_model) # Use ema_model to make predictions on test data preds = ema_model(test_input)
ema_modelの初期化
ema_modelの初期化は、LightningModuleのコンストラクタに実装する。
def __init__( self, network="resnet10_relu", use_ema=False, update_bn=True, ema_start_epoch=1, ema_freq=250, ema_decay=0.9, ): super().__init__() self.save_hyperparameters() self.model = policy_value_network(network) if use_ema: self.ema_model = AveragedModel( self.model, multi_avg_fn=get_ema_multi_avg_fn(ema_decay) )
コンストラクタの引数
use_ema | EMAを使用するか |
update_bn | 最終エポックでBatch Normalizationの統計を更新するか |
ema_start_epoch | EMAの開始エポック |
ema_freq | パラメータ更新間隔 |
ema_decay | 指数減衰係数 |
EMAのパラメータ更新
EMAのパラメータ更新は、on_train_batch_end()で行う。
on_train_batch_end()がどのタイミングで呼ばれるかは、PyTorch Lightningの疑似コードを見ると理解できる。
on_train_batch_end()は、optimizer_step()の後に呼ばれるため、この位置が適切である。
def on_train_batch_end(self, outputs, batch, batch_idx): if ( self.hparams.use_ema and self.current_epoch >= self.hparams.ema_start_epoch and self.global_step % self.hparams.ema_freq == 0 ): self.ema_model.update_parameters(self.model)
update_bn
訓練中のBatch Normalizationの統計情報は、EMAモデルではなくオリジナルのパラメータに最適化されるため、EMAモデルで統計情報を計算し直す必要がある。
PyTorchにはそのためにupdate_bnという関数が用意されている。
これを訓練の最終エポックの後に実行する。
LightningModuleには、on_train_end()という訓練の終了後に呼び出されるメソッドを定義できるが、このタイミングで実行すると、チェックポイントに保存されないため、毎エポックの後に呼ばれるon_train_epoch_end()で、最終エポックかをチェックして実行する。
def on_train_epoch_end(self): if ( self.hparams.use_ema and self.hparams.update_bn and self.current_epoch == self.trainer.max_epochs - 1 and self.current_epoch >= self.hparams.ema_start_epoch ): def data_loader(): for x1, x2, _, _, _ in Tqdm( self.trainer.datamodule.train_dataloader(), desc="update_bn", dynamic_ncols=True, bar_format=TQDMProgressBar.BAR_FORMAT, ): yield {"x1": x1.to(self.device), "x2": x2.to(self.device)} forward_ = self.ema_model.forward self.ema_model.forward = lambda x: forward_(**x) with self.trainer.precision_plugin.train_step_context(): update_bn(data_loader(), self.ema_model) del self.ema_model.forward
dlshogiモデルのforward()は、引数が2つになっているが、update_bnは引数が1つの場合しか対応していないため、細工を行う必要がある。
forward()メソッドを一時的に置き換えて、引数を辞書型の一つの変数で受け取るようにし、辞書型を展開して元のforward()に渡すようにする。
update_bnでは、データローダから返されるリストかタプルの1つ目の変数がforward()に渡されるため、2つの変数を渡すためにいったん辞書型に格納する。
そのために、訓練用データローダをイテレータでラップして、辞書型に変換する処理を実装している。
update_bnには時間がかかるため、進捗状況を表示するようにしている。
PyTorch Lightningのtqdmを拡張したクラスを流用している。
EMAモデルの評価
PyTorch Lightningの仕組みでは、validationはon_train_epoch_end()の前に実行されるため、update_bnを実行した後のEMAモデルを評価するタイミングがない。
あまり適切ではないが、testサブコマンドで、use_empがTrueの場合は、EMAモデルを使用することで、EMAモデルを評価できるようにした。
def on_test_start(self): if self.hparams.use_ema: self.tmp_model = self.model self.model = self.ema_model return super().on_test_start() def test_step(self, batch, batch_idx): self.validation_step(batch, batch_idx) def on_test_epoch_end(self): for key, val in self.validation_step_outputs.items(): key = "test" + key[3:] self.log(key, torch.stack(val).mean()) val.clear() def on_test_end(self): super().on_test_end() if self.hparams.use_ema: self.model = self.tmp_model del self.tmp_model
一時的にモデルをEMAモデルに置き換える処理を行っている。
実行結果
更新間隔を5ステップ、減衰係数0.999で、EMAありとなしで比較した。
訓練時間
※オレンジがEMAなし、青がEMAあり
EMAありの方が少し訓練時間が長い。
なお、このグラフは評価の時間も含んでいる。
実験で使用した訓練データより評価データの方が多いため、評価に半分くらいの時間がかかっている。
訓練損失
訓練損失は、どちらもオリジナルのモデルの値のため違いはない。
評価
評価のメトリクスは、どちらもオリジナルのモデルの値のため大きな違いはない。
ばらつきがあるのは、訓練データのシャッフルの影響である。
まとめ
dlshogiのPyTorch Lightning対応にEMAを実装した。
以前のdlshogiと比べて、勾配クリッピング、モデルのエクスポート処理が不足しているので別途実装予定である。