TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

dlshogiのPyTorch Lightning対応

dlshogiの学習は、PyTorchを使用して、モデルの訓練処理を独自に実装していた。

マルチGPUによる分散学習に対応させようと考えているが、独自に実装するより、PyTorch lightningに対応させた方が実装が楽になるため、dlshogiをPyTorch Lightningに対応させたいと考えている。

まずは、訓練の基本部分の実装を行った。

PyTorch Lightning CLI

ボイラープレートをできるだけ削除するため、PyTorch Lightning CLIを使用して実装する。

PyTorch Lightning CLIを使用すると、コマンド引数のパース処理など含めて自動で行ってくれる。

起動部分の処理は以下のように記述するだけでよい。

def main():
    LightningCLI(Model, DataModule)


if __name__ == "__main__":
    main()

ハイパーパラメータ

ハイパーパラメータは、configファイルに記述し、ハイパーパラメータを簡単に変更できるようにし、実験条件をファイルとして残せるようにする。
オプティマイザーやLRスケジューラに何を使うかも指定できる。

config.yaml
# lightning.pytorch==2.2.0.post0
seed_everything: 0
trainer:
  max_epochs: 4
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: -1
        monitor: val/loss
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
model:
  network: resnet10_relu
  val_lambda: 0.333
data:
  train_files:
    - F:\hcpe3\floodgate_2022_0214_r3800_nomate.hcpe3
  val_files:
    - F:\hcpe3\floodgate.hcpe
  batch_size: 1024
  val_batch_size: 1024
  use_average: false
  use_evalfix: false
  temperature: 1.0
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 1e-4
    betas:
    - 0.9
    - 0.999
    eps: 1e-08
    weight_decay: 1e-2
lr_scheduler:
  class_path: torch.optim.lr_scheduler.StepLR
  init_args:
    step_size: 1
    gamma: 0.5

データローダ

dlshogiは、データローダを独自に実装しているので、それをそのまま使用して、PyTorch Lightningで使用できるようにラッパーを記述する。
PyTorchのデータローダは、標準では、サンプル単位にデータセットからサンプルデータを取得する。
一方、dlshogiのデータローダは、Cythonを使用してミニバッチ単位でデータを作成している。

dlshogiのデータローダをサンプル単位の取得に変更すると、呼び出しのオーバーヘッドが大きくなるため、ミニバッチ単位はそのままにして、PyTorchのデータローダをミニバッチ単位の処理に対応させる。

それには、データローダのコンストラクタのcollate_fn引数にカスタムcollate関数を渡せばよい。
なお、collate_fn引数を指定しなかった場合は、デフォルトのcollate関数が使用され、データセットからサンプリングしたサンプルのリストをtorch.Tensorに変換する処理を行っている。

データセットに、__getitems__を実装することで、サンプラーがサンプリングしたインデックスのリストが渡され、ミニバッチデータを返却できる。
カスタムcollate関数では、__getitems__で返却したミニバッチデータが渡されるので、それをそのまま返せばよい。
__getitems__で、ミニバッチデータを作成する時点で、torch.Tensorにしておけば、カスタムcollate関数では何も処理をしなくてよい。

class Hcpe3Dataset(Dataset):
    def __getitems__(self, indexes):
        batch_size = len(indexes)
        indexes = np.array(indexes, dtype=np.uint32)

        features1 = torch.empty(
            (batch_size, FEATURES1_NUM, 9, 9), dtype=torch.float32, pin_memory=True
        )
        features2 = torch.empty(
            (batch_size, FEATURES2_NUM, 9, 9), dtype=torch.float32, pin_memory=True
        )
        probability = torch.empty(
            (batch_size, 9 * 9 * MAX_MOVE_LABEL_NUM),
            dtype=torch.float32,
            pin_memory=True,
        )
        result = torch.empty((batch_size, 1), dtype=torch.float32, pin_memory=True)
        value = torch.empty((batch_size, 1), dtype=torch.float32, pin_memory=True)

        cppshogi.hcpe3_decode_with_value(
            indexes,
            features1.numpy(),
            features2.numpy(),
            probability.numpy(),
            result.numpy(),
            value.numpy(),
        )

        return features1, features2, probability, result, value
def collate(data):
    return data

データモジュール

PyTorch Lightning CLIでは、データセットとデータローダをデータモジュールというクラスで管理する。
サブコマンドに応じた使用するデータセットの使い分けを記述する。

save_hyperparametersを使うことで、データモジュールのコンストラクタの引数をconfigファイルに記述できるようになる。

class DataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_files,
        val_files,
        batch_size=1024,
        val_batch_size=1024,
        use_average=False,
        use_evalfix=False,
        temperature=1.0,
        patch=None,
        cache=None,
    ):
        super().__init__()
        self.save_hyperparameters()

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            self.train_dataset = Hcpe3Dataset(
                self.hparams.train_files,
                self.hparams.use_average,
                self.hparams.use_evalfix,
                self.hparams.temperature,
                self.hparams.patch,
                self.hparams.cache,
            )
            self.val_dataset = HcpeDataset(self.hparams.val_files)

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage == "predict":
            self.test_dataset = HcpeDataset(self.hparams.val_files)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            collate_fn=collate,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.hparams.val_batch_size,
            collate_fn=collate,
        )

    def test_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.hparams.val_batch_size)

    def predict_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.hparams.val_batch_size)

訓練と評価の処理

訓練と評価の処理は、LightningModuleに実装する。
ループは記述する必要なく、訓練と評価の各ステップで実行する処理をそれぞれtraining_stepとvalidation_stepに記述する。

    def training_step(self, batch, batch_idx):
        features1, features2, probability, result, value = batch
        y1, y2 = self.model(features1, features2)
        loss1 = cross_entropy_loss_with_soft_target(y1, probability).mean()
        loss2 = bce_with_logits_loss(y2, result)
        loss3 = bce_with_logits_loss(y2, value)
        loss = (
            loss1
            + (1 - self.hparams.val_lambda) * loss2
            + self.hparams.val_lambda * loss3
        )
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        features1, features2, move, result, value = batch
        y1, y2 = self.model(features1, features2)
        loss1 = cross_entropy_loss(y1, move).mean()
        loss2 = bce_with_logits_loss(y2, result)
        loss3 = bce_with_logits_loss(y2, value)
        loss = (
            loss1
            + (1 - self.hparams.val_lambda) * loss2
            + self.hparams.val_lambda * loss3
        )
        self.validation_step_outputs["val/loss"].append(loss)

        self.validation_step_outputs["val/policy_accuracy"].append(accuracy(y1, move))
        self.validation_step_outputs["val/value_accuracy"].append(
            binary_accuracy(y2, result)
        )

        entropy1 = (-F.softmax(y1, dim=1) * F.log_softmax(y1, dim=1)).sum(dim=1)
        self.validation_step_outputs["val/policy_entropy"].append(entropy1.mean())

        p2 = y2.sigmoid()
        # entropy2 = -(p2 * F.log(p2) + (1 - p2) * F.log(1 - p2))
        log1p_ey2 = F.softplus(y2)
        entropy2 = -(p2 * (y2 - log1p_ey2) + (1 - p2) * -log1p_ey2)
        self.validation_step_outputs["val/value_entropy"].append(entropy2.mean())

評価は、評価データ全体の平均で行うため、validation_stepでは、評価メトリックをリストに追加しておき、on_validation_epoch_endで平均を出力する。

    def on_validation_epoch_end(self):
        for key, val in self.validation_step_outputs.items():
            self.log(key, torch.stack(val).mean())
            val.clear()


以上のように記述することで、訓練と評価の処理を実装できる。
dlshogiの独自実装と比較して、コードの記述量は半分くらいになっている。

実行

訓練は、サブコマンドfitにconfigファイルのパスをオプションに指定して実行する。

python ptl.py fit -c config.yaml
コンソール出力結果

コンソール出力結果は、以下の通り。

anaconda3\lib\site-packages\lightning\fabric\utilities\seed.py:40: No seed found, seed set to 0
Seed set to 0
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: dlshogi\lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type               | Params
---------------------------------------------
0 | model | PolicyValueNetwork | 7.3 M
---------------------------------------------
7.3 M     Trainable params
0         Non-trainable params
7.3 M     Total params
29.386    Total estimated model params size (MB)
Sanity Checking: |                                                                                                         | 0/? [00:00<?, ?it/s]
anaconda3\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
anaconda3\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████████| 254/254 [00:55<00:00,  4.59it/s, v_num=0]`Trainer.fit` stopped: `max_epochs=4` reached.                                                                                                     
Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████████| 254/254 [00:56<00:00,  4.47it/s, v_num=0] 

DataLoaderに、num_workersを指定するように警告が表示されているが、ワーカーはマルチプロセスで起動されてプロセス間の通信が発生するため、指定しない方がよい。
特に、Windowsではプロセスがspawnされるため、ワーカーのプロセス間の通信がpickleによるデータ送信になるため、非常に遅い。
dlshogiのデータローダは内部でOpenMPでサンプル単位に並列化しているため、十分に速い。

TensorBoard

ログは、TensorBoardに出力される。
LRスケジューラのログ、訓練損失、評価メトリクスが確認できる。

チェックポイント

チェックポイントは、lightning_logs\version_0\checkpointsに保存される。
version_0の部分は、実行ごとにインクリメントされる。

チェックポイントは、エポックごとに保存するようにconfigで指定しているが、top_kを残すなど別の保存方法にもできる。

まとめ

dlshogiをPyTorch Lightning対応させた。
SWAの処理が実装できていないため、別途実装予定である。
StochasticWeightAveragingというコールバックが用意されているため、configだけで対応できるかもしれない。

マルチGPUによる分散学習が簡単に試せるようになるので、8GPUを使用した大規模将棋モデルの学習を試してみたい。