TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

dlshogiのPyTorch Lightning対応 その5(Warm-upに対応したスケジューラ)

大規模なモデルの学習に効果があるとされる学習率スケジューリングの手法にWarm-upがある。
しかし、Pytorchの標準のスケジューラには、Warm-upに対応したスケジューラが提供されていない。

PyTorch Lightning Boltsには、Warm-upに対応したCosineAnnealingLRがある。
Linear Warmup Cosine Annealing — Lightning-Bolts 0.7.0 documentation

まだレビュー中のステータスで、機能的にもリスタートや減衰には対応していない。

深層学習界隈でよく使われるtimmには、Warm-upの他にリスタート回数や減衰率の調整が可能なCosineAnnealingLR学習率スケジューラがある。
SGDR - Stochastic Gradient Descent with Warm Restarts | timmdocs

そこで、先日作成したdlshogiのPyTorch Lightning CLIの訓練スクリプトで、timmのCosineAnnealingLRを使おうと試そうとしたが、PyTorchのLRSchedulerを継承していないため、標準的な方法では使うことができなかった。

PyTorch Lightning CLIをカスタマイズすることで使えるようにはできるようだが、PyTorch Lightning CLIを使うメリットが薄れてくるので、できればあまりカスタマイズしたくない。
How to utilize timm's scheduler? · Issue #5555 · Lightning-AI/pytorch-lightning · GitHub

そこで、LRSchedulerを継承して、timmのCosineAnnealingLRと同じ動作をするスケジュールを自作することにした。

LRSchedulerを継承したtimmのCosineAnnealingLRと同等のスケジューラ

""" This code is based on the Cosine Learning Rate Scheduler implementation found at:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/scheduler/cosine_lr.py
"""

import math

from torch.optim.lr_scheduler import LRScheduler


class CosineLRScheduler(LRScheduler):
    def __init__(
        self,
        optimizer,
        t_initial,
        lr_min=0.0,
        cycle_mul=1.0,
        cycle_decay=1.0,
        cycle_limit=1,
        warmup_t=0,
        warmup_lr_init=0,
        warmup_prefix=False,
        k_decay=1.0,
        last_epoch=-1,
    ):
        self.t_initial = t_initial
        self.lr_min = lr_min
        self.cycle_mul = cycle_mul
        self.cycle_decay = cycle_decay
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.k_decay = k_decay

        if last_epoch == -1:
            base_lrs = [group["lr"] for group in optimizer.param_groups]
        else:
            base_lrs = [group["initial_lr"] for group in optimizer.param_groups]
        if self.warmup_t:
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in base_lrs]
        else:
            self.warmup_steps = [1 for _ in base_lrs]

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        t = self.last_epoch
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            if self.warmup_prefix:
                t = t - self.warmup_t

            if self.cycle_mul != 1:
                i = math.floor(
                    math.log(
                        1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul
                    )
                )
                t_i = self.cycle_mul**i * self.t_initial
                t_curr = (
                    t - (1 - self.cycle_mul**i) / (1 - self.cycle_mul) * self.t_initial
                )
            else:
                i = t // self.t_initial
                t_i = self.t_initial
                t_curr = t - (self.t_initial * i)

            gamma = self.cycle_decay**i
            lr_max_values = [v * gamma for v in self.base_lrs]
            k = self.k_decay

            if i < self.cycle_limit:
                lrs = [
                    self.lr_min
                    + 0.5
                    * (lr_max - self.lr_min)
                    * (1 + math.cos(math.pi * t_curr**k / t_i**k))
                    for lr_max in lr_max_values
                ]
            else:
                lrs = [self.lr_min for _ in self.base_lrs]

        return lrs

timmの方では、最大の学習率をoptimizerから取得して、メンバ変数に保持しているが、PyTorchのスケジュールは、optimizerのinitial_lrにコピーを持たせているので、その作法に合わせている。

また、timmのスケジューラは、step()にステップ数を渡すようになっているが、PyTorchのスケジューラは、step()は引数なしでlast_epochから取得するようになっているため、そちらに合わせた。

更新間隔

PyTorch Lightningの標準のスケジューラの更新間隔は、epoch単位になっている。
将棋AIでは1epochが大きいステップ数になるため、Warm-upやCosineAnnealingLRを使う場合は、ステップ単位で更新したい。

PyTorch Lightningは、更新間隔を変更するオプションがないため、LightningCLIを継承してクラスメソッドのconfigure_optimizersをオーバーライドする必要がある。
Change the scheduler interval in CLI · Lightning-AI pytorch-lightning · Discussion #13975 · GitHub

class LightningCLI(cli.LightningCLI):
    @staticmethod
    def configure_optimizers(lightning_module, optimizer, lr_scheduler=None):
        if lr_scheduler is None:
            return optimizer
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                **(
                    {"monitor": lr_scheduler.monitor}
                    if isinstance(lr_scheduler, ReduceLROnPlateau)
                    else {}
                ),
            },
        }

動作確認

config.yamlに以下のように定義を行い動作確認を行った。
チェックポイントをリジュームして継続できるかも確認した。

lr_scheduler:
  class_path: dlshogi.lr_scheduler.CosineLRScheduler
  init_args:
    t_initial: 200
    lr_min: 1e-8
    cycle_mul: 2
    cycle_limit: 8
    cycle_decay: 0.5
    warmup_t: 100
    warmup_lr_init: 1e-7
    warmup_prefix: true

Tensorboardのグラフが50ステップ間隔なため、カクカクでわかりにくいが、warmup_tやcycle_mul、cycle_decayのパラメータが機能していることが確認できる。

まとめ

dlshogiにWarm-upに対応したスケジューラを追加した。
Warm-upにどれくらい効果があるかは別途検証したい。