TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

dlshogiのPyTorch Lightning対応 その2(EMA)

前回、dlshogiのPyTorch Lightning対応の基本部分を実装した。

今回は、EMAを実装する。

EMA

EMAは、Exponential Moving Averageの略で、重みの指数移動平均をとり平準化する方法である。

dlshogiでは、SWAと呼んでいたが、SWAで使われるスケジューラは使用しておらず、重みのエポックごとの平均化ではなく指数移動平均を計算している。
その場合、EMAという別の名称を使うことが一般的になっている。

PyTorchでも、こちらのissueでSWAのモジュールをEMAに対応させている。
ドキュメントにもEMAの記載が追加されている。
torch.optim — PyTorch 2.2 documentation

PyTorch LightningのEMA対応

PyTorch Lightningには、StochasticWeightAveragingというコールバックが実装されているが、SWAのスケジューラが強制されるため、EMAに対応させることができない。

そのため、自前でEMAの実装が必要である。

こちらのissueで、EMAの実装方法について議論されていたので、参考にして実装した。

PyTorchの公式ドキュメントにあるサンプルコードを元にして、LightningModuleの適切な箇所に処理を追加する。

EMAのサンプルコード
loader, optimizer, model, loss_fn = ...
ema_model = torch.optim.swa_utils.AveragedModel(model, \
            multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
for epoch in range(300):
      for input, target in loader:
          optimizer.zero_grad()
          loss_fn(model(input), target).backward()
          optimizer.step()
          ema_model.update_parameters(model)
# Update bn statistics for the ema_model at the end
torch.optim.swa_utils.update_bn(loader, ema_model)
# Use ema_model to make predictions on test data
preds = ema_model(test_input)

ema_modelの初期化

ema_modelの初期化は、LightningModuleのコンストラクタに実装する。

    def __init__(
        self,
        network="resnet10_relu",
        use_ema=False,
        update_bn=True,
        ema_start_epoch=1,
        ema_freq=250,
        ema_decay=0.9,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.model = policy_value_network(network)
        if use_ema:
            self.ema_model = AveragedModel(
                self.model, multi_avg_fn=get_ema_multi_avg_fn(ema_decay)
            )

コンストラクタの引数

use_ema EMAを使用するか
update_bn 最終エポックでBatch Normalizationの統計を更新するか
ema_start_epoch EMAの開始エポック
ema_freq パラメータ更新間隔
ema_decay 指数減衰係数

EMAのパラメータ更新

EMAのパラメータ更新は、on_train_batch_end()で行う。
on_train_batch_end()がどのタイミングで呼ばれるかは、PyTorch Lightningの疑似コードを見ると理解できる。
on_train_batch_end()は、optimizer_step()の後に呼ばれるため、この位置が適切である。

    def on_train_batch_end(self, outputs, batch, batch_idx):
        if (
            self.hparams.use_ema
            and self.current_epoch >= self.hparams.ema_start_epoch
            and self.global_step % self.hparams.ema_freq == 0
        ):
            self.ema_model.update_parameters(self.model)

update_bn

訓練中のBatch Normalizationの統計情報は、EMAモデルではなくオリジナルのパラメータに最適化されるため、EMAモデルで統計情報を計算し直す必要がある。
PyTorchにはそのためにupdate_bnという関数が用意されている。

これを訓練の最終エポックの後に実行する。
LightningModuleには、on_train_end()という訓練の終了後に呼び出されるメソッドを定義できるが、このタイミングで実行すると、チェックポイントに保存されないため、毎エポックの後に呼ばれるon_train_epoch_end()で、最終エポックかをチェックして実行する。

    def on_train_epoch_end(self):
        if (
            self.hparams.use_ema
            and self.hparams.update_bn
            and self.current_epoch == self.trainer.max_epochs - 1
            and self.current_epoch >= self.hparams.ema_start_epoch
        ):

            def data_loader():
                for x1, x2, _, _, _ in Tqdm(
                    self.trainer.datamodule.train_dataloader(),
                    desc="update_bn",
                    dynamic_ncols=True,
                    bar_format=TQDMProgressBar.BAR_FORMAT,
                ):
                    yield {"x1": x1.to(self.device), "x2": x2.to(self.device)}

            forward_ = self.ema_model.forward
            self.ema_model.forward = lambda x: forward_(**x)
            with self.trainer.precision_plugin.train_step_context():
                update_bn(data_loader(), self.ema_model)
            del self.ema_model.forward

dlshogiモデルのforward()は、引数が2つになっているが、update_bnは引数が1つの場合しか対応していないため、細工を行う必要がある。
forward()メソッドを一時的に置き換えて、引数を辞書型の一つの変数で受け取るようにし、辞書型を展開して元のforward()に渡すようにする。
update_bnでは、データローダから返されるリストかタプルの1つ目の変数がforward()に渡されるため、2つの変数を渡すためにいったん辞書型に格納する。
そのために、訓練用データローダをイテレータでラップして、辞書型に変換する処理を実装している。

update_bnには時間がかかるため、進捗状況を表示するようにしている。
PyTorch Lightningのtqdmを拡張したクラスを流用している。

EMAモデルの評価

PyTorch Lightningの仕組みでは、validationはon_train_epoch_end()の前に実行されるため、update_bnを実行した後のEMAモデルを評価するタイミングがない。

あまり適切ではないが、testサブコマンドで、use_empがTrueの場合は、EMAモデルを使用することで、EMAモデルを評価できるようにした。

    def on_test_start(self):
        if self.hparams.use_ema:
            self.tmp_model = self.model
            self.model = self.ema_model
        return super().on_test_start()

    def test_step(self, batch, batch_idx):
        self.validation_step(batch, batch_idx)

    def on_test_epoch_end(self):
        for key, val in self.validation_step_outputs.items():
            key = "test" + key[3:]
            self.log(key, torch.stack(val).mean())
            val.clear()

    def on_test_end(self):
        super().on_test_end()
        if self.hparams.use_ema:
            self.model = self.tmp_model
            del self.tmp_model

一時的にモデルをEMAモデルに置き換える処理を行っている。

実行結果

更新間隔を5ステップ、減衰係数0.999で、EMAありとなしで比較した。

訓練時間



※オレンジがEMAなし、青がEMAあり

EMAありの方が少し訓練時間が長い。
なお、このグラフは評価の時間も含んでいる。
実験で使用した訓練データより評価データの方が多いため、評価に半分くらいの時間がかかっている。

訓練損失

訓練損失は、どちらもオリジナルのモデルの値のため違いはない。

評価

評価のメトリクスは、どちらもオリジナルのモデルの値のため大きな違いはない。
ばらつきがあるのは、訓練データのシャッフルの影響である。

EMAモデルの評価

update_bnを実行した後のEMAモデルの評価結果をオリジナルモデルと比較した結果は以下の通り。

metric オリジナル EMAモデル
loss 3.988 4.063
policy_accuracy 0.2403 0.2301
value_accuracy 0.5949 0.5977
policy_entropy 2.774 3.139
value_entropy 0.5987 0.6248

EMAモデルの方がlossが少し高い。
今回実験した訓練データが少ないため、EMAの効果は確認できなかった。

まとめ

dlshogiのPyTorch Lightning対応にEMAを実装した。
以前のdlshogiと比べて、勾配クリッピング、モデルのエクスポート処理が不足しているので別途実装予定である。