dlshogiの学習は、PyTorchを使用して、モデルの訓練処理を独自に実装していた。
マルチGPUによる分散学習に対応させようと考えているが、独自に実装するより、PyTorch lightningに対応させた方が実装が楽になるため、dlshogiをPyTorch Lightningに対応させたいと考えている。
まずは、訓練の基本部分の実装を行った。
PyTorch Lightning CLI
ボイラープレートをできるだけ削除するため、PyTorch Lightning CLIを使用して実装する。
PyTorch Lightning CLIを使用すると、コマンド引数のパース処理など含めて自動で行ってくれる。
起動部分の処理は以下のように記述するだけでよい。
def main(): LightningCLI(Model, DataModule) if __name__ == "__main__": main()
ハイパーパラメータ
ハイパーパラメータは、configファイルに記述し、ハイパーパラメータを簡単に変更できるようにし、実験条件をファイルとして残せるようにする。
オプティマイザーやLRスケジューラに何を使うかも指定できる。
config.yaml
# lightning.pytorch==2.2.0.post0 seed_everything: 0 trainer: max_epochs: 4 callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: save_top_k: -1 monitor: val/loss - class_path: lightning.pytorch.callbacks.LearningRateMonitor model: network: resnet10_relu val_lambda: 0.333 data: train_files: - F:\hcpe3\floodgate_2022_0214_r3800_nomate.hcpe3 val_files: - F:\hcpe3\floodgate.hcpe batch_size: 1024 val_batch_size: 1024 use_average: false use_evalfix: false temperature: 1.0 optimizer: class_path: torch.optim.AdamW init_args: lr: 1e-4 betas: - 0.9 - 0.999 eps: 1e-08 weight_decay: 1e-2 lr_scheduler: class_path: torch.optim.lr_scheduler.StepLR init_args: step_size: 1 gamma: 0.5
データローダ
dlshogiは、データローダを独自に実装しているので、それをそのまま使用して、PyTorch Lightningで使用できるようにラッパーを記述する。
PyTorchのデータローダは、標準では、サンプル単位にデータセットからサンプルデータを取得する。
一方、dlshogiのデータローダは、Cythonを使用してミニバッチ単位でデータを作成している。
dlshogiのデータローダをサンプル単位の取得に変更すると、呼び出しのオーバーヘッドが大きくなるため、ミニバッチ単位はそのままにして、PyTorchのデータローダをミニバッチ単位の処理に対応させる。
それには、データローダのコンストラクタのcollate_fn引数にカスタムcollate関数を渡せばよい。
なお、collate_fn引数を指定しなかった場合は、デフォルトのcollate関数が使用され、データセットからサンプリングしたサンプルのリストをtorch.Tensorに変換する処理を行っている。
データセットに、__getitems__を実装することで、サンプラーがサンプリングしたインデックスのリストが渡され、ミニバッチデータを返却できる。
カスタムcollate関数では、__getitems__で返却したミニバッチデータが渡されるので、それをそのまま返せばよい。
__getitems__で、ミニバッチデータを作成する時点で、torch.Tensorにしておけば、カスタムcollate関数では何も処理をしなくてよい。
class Hcpe3Dataset(Dataset): def __getitems__(self, indexes): batch_size = len(indexes) indexes = np.array(indexes, dtype=np.uint32) features1 = torch.empty( (batch_size, FEATURES1_NUM, 9, 9), dtype=torch.float32, pin_memory=True ) features2 = torch.empty( (batch_size, FEATURES2_NUM, 9, 9), dtype=torch.float32, pin_memory=True ) probability = torch.empty( (batch_size, 9 * 9 * MAX_MOVE_LABEL_NUM), dtype=torch.float32, pin_memory=True, ) result = torch.empty((batch_size, 1), dtype=torch.float32, pin_memory=True) value = torch.empty((batch_size, 1), dtype=torch.float32, pin_memory=True) cppshogi.hcpe3_decode_with_value( indexes, features1.numpy(), features2.numpy(), probability.numpy(), result.numpy(), value.numpy(), ) return features1, features2, probability, result, value
def collate(data): return data
データモジュール
PyTorch Lightning CLIでは、データセットとデータローダをデータモジュールというクラスで管理する。
サブコマンドに応じた使用するデータセットの使い分けを記述する。
save_hyperparametersを使うことで、データモジュールのコンストラクタの引数をconfigファイルに記述できるようになる。
class DataModule(pl.LightningDataModule): def __init__( self, train_files, val_files, batch_size=1024, val_batch_size=1024, use_average=False, use_evalfix=False, temperature=1.0, patch=None, cache=None, ): super().__init__() self.save_hyperparameters() def setup(self, stage: str): # Assign train/val datasets for use in dataloaders if stage == "fit": self.train_dataset = Hcpe3Dataset( self.hparams.train_files, self.hparams.use_average, self.hparams.use_evalfix, self.hparams.temperature, self.hparams.patch, self.hparams.cache, ) self.val_dataset = HcpeDataset(self.hparams.val_files) # Assign test dataset for use in dataloader(s) if stage == "test" or stage == "predict": self.test_dataset = HcpeDataset(self.hparams.val_files) def train_dataloader(self): return DataLoader( self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True, collate_fn=collate, ) def val_dataloader(self): return DataLoader( self.val_dataset, batch_size=self.hparams.val_batch_size, collate_fn=collate, ) def test_dataloader(self): return DataLoader(self.val_dataset, batch_size=self.hparams.val_batch_size) def predict_dataloader(self): return DataLoader(self.val_dataset, batch_size=self.hparams.val_batch_size)
訓練と評価の処理
訓練と評価の処理は、LightningModuleに実装する。
ループは記述する必要なく、訓練と評価の各ステップで実行する処理をそれぞれtraining_stepとvalidation_stepに記述する。
def training_step(self, batch, batch_idx): features1, features2, probability, result, value = batch y1, y2 = self.model(features1, features2) loss1 = cross_entropy_loss_with_soft_target(y1, probability).mean() loss2 = bce_with_logits_loss(y2, result) loss3 = bce_with_logits_loss(y2, value) loss = ( loss1 + (1 - self.hparams.val_lambda) * loss2 + self.hparams.val_lambda * loss3 ) self.log("train/loss", loss) return loss def validation_step(self, batch, batch_idx): features1, features2, move, result, value = batch y1, y2 = self.model(features1, features2) loss1 = cross_entropy_loss(y1, move).mean() loss2 = bce_with_logits_loss(y2, result) loss3 = bce_with_logits_loss(y2, value) loss = ( loss1 + (1 - self.hparams.val_lambda) * loss2 + self.hparams.val_lambda * loss3 ) self.validation_step_outputs["val/loss"].append(loss) self.validation_step_outputs["val/policy_accuracy"].append(accuracy(y1, move)) self.validation_step_outputs["val/value_accuracy"].append( binary_accuracy(y2, result) ) entropy1 = (-F.softmax(y1, dim=1) * F.log_softmax(y1, dim=1)).sum(dim=1) self.validation_step_outputs["val/policy_entropy"].append(entropy1.mean()) p2 = y2.sigmoid() # entropy2 = -(p2 * F.log(p2) + (1 - p2) * F.log(1 - p2)) log1p_ey2 = F.softplus(y2) entropy2 = -(p2 * (y2 - log1p_ey2) + (1 - p2) * -log1p_ey2) self.validation_step_outputs["val/value_entropy"].append(entropy2.mean())
評価は、評価データ全体の平均で行うため、validation_stepでは、評価メトリックをリストに追加しておき、on_validation_epoch_endで平均を出力する。
def on_validation_epoch_end(self): for key, val in self.validation_step_outputs.items(): self.log(key, torch.stack(val).mean()) val.clear()
以上のように記述することで、訓練と評価の処理を実装できる。
dlshogiの独自実装と比較して、コードの記述量は半分くらいになっている。
実行
訓練は、サブコマンドfitにconfigファイルのパスをオプションに指定して実行する。
python ptl.py fit -c config.yaml
コンソール出力結果
コンソール出力結果は、以下の通り。
anaconda3\lib\site-packages\lightning\fabric\utilities\seed.py:40: No seed found, seed set to 0 Seed set to 0 GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs HPU available: False, using: 0 HPUs You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision Missing logger folder: dlshogi\lightning_logs LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params --------------------------------------------- 0 | model | PolicyValueNetwork | 7.3 M --------------------------------------------- 7.3 M Trainable params 0 Non-trainable params 7.3 M Total params 29.386 Total estimated model params size (MB) Sanity Checking: | | 0/? [00:00<?, ?it/s] anaconda3\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance. anaconda3\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance. Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████████| 254/254 [00:55<00:00, 4.59it/s, v_num=0]`Trainer.fit` stopped: `max_epochs=4` reached. Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████████| 254/254 [00:56<00:00, 4.47it/s, v_num=0]
DataLoaderに、num_workersを指定するように警告が表示されているが、ワーカーはマルチプロセスで起動されてプロセス間の通信が発生するため、指定しない方がよい。
特に、Windowsではプロセスがspawnされるため、ワーカーのプロセス間の通信がpickleによるデータ送信になるため、非常に遅い。
dlshogiのデータローダは内部でOpenMPでサンプル単位に並列化しているため、十分に速い。
TensorBoard
ログは、TensorBoardに出力される。
LRスケジューラのログ、訓練損失、評価メトリクスが確認できる。
チェックポイント
チェックポイントは、lightning_logs\version_0\checkpointsに保存される。
version_0の部分は、実行ごとにインクリメントされる。
チェックポイントは、エポックごとに保存するようにconfigで指定しているが、top_kを残すなど別の保存方法にもできる。
まとめ
dlshogiをPyTorch Lightning対応させた。
SWAの処理が実装できていないため、別途実装予定である。
StochasticWeightAveragingというコールバックが用意されているため、configだけで対応できるかもしれない。
マルチGPUによる分散学習が簡単に試せるようになるので、8GPUを使用した大規模将棋モデルの学習を試してみたい。