TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

Gumbel dlshogiを作る その4(データローダと訓練の実装)

前回までで、Gumbel AlphaZeroの自己対局で訓練データを生成できるようになったので、今回は学習処理を実装する。

訓練データ

自己対局で、以下のようなNumpyの構造体の配列をバイナリファイルとして出力する。
ランダムアクセスが可能なように固定長にしている。

dtypeTrainingData = np.dtype(
    [
        ("hcp", HuffmanCodedPos),
        ("policy", np.dtype((np.float32, MOVE_LABELS_NUM))),
        ("result", np.uint8),
    ]
)

データローダ

データローダは、PyTorchのDatasetとDataLoaderを使用して実装する。
複数ワーカに対応して、Pythonでの特徴量作成処理がボトルネックにならないようにする。

複数のワーカがそれぞれデータをメモリに読み込むと大量にメモリを消費するため、np.memmapを使用してメモリマップト配列として読み込む。
固定長なので、配列のインデックスを指定してランダムアクセスができる。

自己対局で生成したデータは、新しいものから引数で与えられたnum_filesファイル分だけ使用する。

class TrainingDataset(Dataset):
    def __init__(self, data_dir, num_files=None, worker_id=None, num_workers=None):
        """
        Args:
            data_dir: Directory containing .data files
            num_files: Number of newest files to use (None for all)
            worker_id: Worker ID for multi-worker setup
            num_workers: Total number of workers
        """
        self.data_dir = data_dir
        self.num_files = num_files
        self.worker_id = worker_id or 0
        self.num_workers = num_workers or 1

        # Find and sort data files by modification time (newest first)
        pattern = os.path.join(data_dir, "*.data")
        files = glob.glob(pattern)
        files.sort(key=os.path.getmtime, reverse=True)

        # Select specified number of files
        if num_files is not None:
            files = files[:num_files]

        self.files = files
        self.file_offsets = []
        self.total_samples = 0

        # Calculate file offsets and total samples
        for file_path in self.files:
            file_size = os.path.getsize(file_path)
            num_samples = file_size // dtypeTrainingData.itemsize
            self.file_offsets.append((self.total_samples, num_samples, file_path))
            self.total_samples += num_samples

        # For multi-worker setup, divide data among workers
        if self.num_workers > 1:
            samples_per_worker = self.total_samples // self.num_workers
            self.start_idx = self.worker_id * samples_per_worker
            if self.worker_id == self.num_workers - 1:
                # Last worker takes remaining samples
                self.end_idx = self.total_samples
            else:
                self.end_idx = (self.worker_id + 1) * samples_per_worker
            self.worker_total_samples = self.end_idx - self.start_idx
        else:
            self.start_idx = 0
            self.end_idx = self.total_samples
            self.worker_total_samples = self.total_samples

        # Create memory maps for files
        self._mmaps = {}

        self._board = None

    def _get_mmap(self, file_path):
        """Get or create memory map for a file"""
        if file_path not in self._mmaps:
            self._mmaps[file_path] = np.memmap(
                file_path, dtype=dtypeTrainingData, mode="r"
            )
        return self._mmaps[file_path]

    def _find_file_and_offset(self, global_idx):
        """Find which file contains the sample at global_idx"""
        for file_start, file_samples, file_path in self.file_offsets:
            if global_idx < file_start + file_samples:
                return file_path, global_idx - file_start
        raise IndexError(f"Index {global_idx} out of range")

    def _get_board(self):
        """Get or create a Board instance for HCP conversion"""
        if self._board is None:
            self._board = Board()
        return self._board

    def __len__(self):
        return self.worker_total_samples

    def __getitem__(self, idx):
        if idx >= self.worker_total_samples:
            raise IndexError(f"Index {idx} out of range for worker {self.worker_id}")

        # Convert worker-local index to global index
        global_idx = self.start_idx + idx

        # Find file and local offset
        file_path, local_offset = self._find_file_and_offset(global_idx)

        # Get memory map and read sample
        mmap = self._get_mmap(file_path)
        sample = mmap[local_offset]

        # Convert HCP to input features
        board = self._get_board()
        board.set_hcp(np.asarray(sample["hcp"]))

        features = np.empty((FEATURES_NUM, 9, 9), dtype=np.float32)
        make_input_features(board, features)

        match sample["result"]:
            case 1:  # BLACK_WIN
                result = 1 if board.turn == BLACK else 0
            case 2:  # WHITE_WIN
                result = 1 if board.turn == WHITE else 0
            case _:  # DRAW
                result = 0.5

        # Convert to tensors
        features_tensor = torch.from_numpy(features)
        policy_tensor = torch.from_numpy(sample["policy"].copy())
        result_tensor = torch.tensor(result, dtype=torch.float32)

        return features_tensor, policy_tensor, result_tensor

    def __del__(self):
        # Clean up memory maps
        for mmap in self._mmaps.values():
            if hasattr(mmap, "_mmap"):
                mmap._mmap.close()


def worker_init_fn(worker_id):
    """
    Worker initialization function for DataLoader.
    Re-initializes the dataset for each worker process.
    """
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # The dataset instance in this worker.

    worker_dataset = TrainingDataset(
        data_dir=dataset.data_dir,
        num_files=dataset.num_files,
        worker_id=worker_id,
        num_workers=dataset.num_workers,
    )

    # Copy the attributes from the new dataset to the one in the worker
    dataset.files = worker_dataset.files
    dataset.file_offsets = worker_dataset.file_offsets
    dataset.total_samples = worker_dataset.total_samples
    dataset.worker_id = worker_dataset.worker_id
    dataset.start_idx = worker_dataset.start_idx
    dataset.end_idx = worker_dataset.end_idx
    dataset.worker_total_samples = worker_dataset.worker_total_samples
    dataset._mmaps = {}  # Each worker needs its own memory maps
    dataset._board = None  # Reset the board for each worker


def create_dataloader(
    data_dir,
    batch_size=32,
    num_files=None,
    num_workers=0,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
    **kwargs,
):
    """
    Create a DataLoader for training data

    Args:
        data_dir: Directory containing .data files
        batch_size: Batch size
        num_files: Number of newest files to use (None for all)
        num_workers: Number of worker processes
        shuffle: Whether to shuffle data
        pin_memory: Whether to pin memory for GPU transfer
        **kwargs: Additional arguments for DataLoader

    Returns:
        DataLoader instance
    """

    if num_workers > 0:
        # Create a dummy dataset for the main process
        dataset = TrainingDataset(data_dir, num_files)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
            worker_init_fn=worker_init_fn,
            **kwargs,
        )
    else:
        dataset = TrainingDataset(data_dir, num_files)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            pin_memory=pin_memory,
            drop_last=drop_last,
            **kwargs,
        )

    return dataloader

テスト用のデータは、dlshogiと同様にhcpe形式のファイルを使用する。
hcpe形式のデータローダも同様に定義する。

訓練処理

PyTorchの標準的な方法で学習する。

ポリシーの損失関数は、CrossEntropyLossを使用してGumble AlphaZeroの改善されたポリシーの分布を学習する。
価値の損失は、BCEWithLogitsLossを使用して、勝ち/負け/引き分けをそれぞれ1, 0, 0.5として学習する。

ニューラルネットワークの定義は、python-dlshogi2を流用する。

オプションでAMP(Automatic Mixed Precision)に対応する。

tqdmで、進捗状況を表示する。

def train_epoch(model, dataloader, optimizer, scaler, device, amp_enabled=True):
    """Train for one epoch"""
    model.train()

    policy_criterion = nn.CrossEntropyLoss()
    value_criterion = nn.BCEWithLogitsLoss()

    total_loss = 0.0
    policy_loss_sum = 0.0
    value_loss_sum = 0.0
    num_batches = 0

    with tqdm(dataloader, desc="Training", unit="batch") as pbar:
        for features, policy_targets, value_targets in pbar:
            features = features.to(device)
            policy_targets = policy_targets.to(device)
            value_targets = value_targets.to(device)

            optimizer.zero_grad()

            with autocast(enabled=amp_enabled):
                policy_output, value_output = model(features)

                # Policy loss - targets are probability distributions
                policy_loss = policy_criterion(policy_output, policy_targets)

                # Value loss - targets are game results
                value_loss = value_criterion(value_output.squeeze(-1), value_targets)

                total_batch_loss = policy_loss + value_loss

            scaler.scale(total_batch_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += total_batch_loss.item()
            policy_loss_sum += policy_loss.item()
            value_loss_sum += value_loss.item()
            num_batches += 1

            # Update progress bar
            pbar.set_postfix(
                {
                    "Loss": f"{total_batch_loss.item():.4f}",
                    "Policy": f"{policy_loss.item():.4f}",
                    "Value": f"{value_loss.item():.4f}",
                }
            )

    return {
        "total_loss": total_loss / num_batches,
        "policy_loss": policy_loss_sum / num_batches,
        "value_loss": value_loss_sum / num_batches,
    }


def evaluate(model, dataloader, device, amp_enabled=True):
    """Evaluate the model"""
    model.eval()

    policy_criterion = nn.CrossEntropyLoss()
    value_criterion = nn.BCEWithLogitsLoss()

    total_loss = 0.0
    policy_loss_sum = 0.0
    value_loss_sum = 0.0
    policy_correct = 0
    value_correct = 0
    total_samples = 0

    with torch.no_grad():
        with tqdm(dataloader, desc="Evaluating", unit="batch") as pbar:
            for features, policy_targets, value_targets in pbar:
                features = features.to(device)
                policy_targets = policy_targets.to(device)
                value_targets = value_targets.to(device)

                with autocast(enabled=amp_enabled):
                    policy_output, value_output = model(features)

                    # For test data, policy_targets are labels (long), not distributions
                    policy_loss = policy_criterion(policy_output, policy_targets)
                    # Calculate accuracy for classification
                    policy_pred = torch.argmax(policy_output, dim=1)
                    policy_correct += (policy_pred == policy_targets).sum().item()

                    value_loss = value_criterion(
                        value_output.squeeze(-1), value_targets
                    )

                    # Value accuracy (threshold at 0.5)
                    value_pred = torch.sigmoid(value_output.squeeze(-1))
                    value_binary_pred = (value_pred > 0.5).float()
                    value_binary_target = (value_targets > 0.5).float()
                    value_correct += (
                        (value_binary_pred == value_binary_target).sum().item()
                    )

                    total_batch_loss = policy_loss + value_loss

                total_loss += total_batch_loss.item()
                policy_loss_sum += policy_loss.item()
                value_loss_sum += value_loss.item()
                total_samples += features.size(0)

                # Update progress bar
                pbar.set_postfix(
                    {
                        "Loss": f"{total_batch_loss.item():.4f}",
                        "Policy Acc": f"{policy_correct/total_samples:.4f}",
                        "Value Acc": f"{value_correct/total_samples:.4f}",
                    }
                )

    num_batches = len(dataloader)

    return {
        "total_loss": total_loss / num_batches,
        "policy_loss": policy_loss_sum / num_batches,
        "value_loss": value_loss_sum / num_batches,
        "policy_accuracy": policy_correct / total_samples,
        "value_accuracy": value_correct / total_samples,
    }


def save_checkpoint(model, optimizer, scaler, epoch, loss, filepath):
    """Save training checkpoint"""
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "loss": loss,
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved to {filepath}")


def load_checkpoint(filepath, model, optimizer, scaler, device):
    """Load training checkpoint"""
    checkpoint = torch.load(filepath, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scaler.load_state_dict(checkpoint["scaler_state_dict"])
    epoch = checkpoint["epoch"]
    loss = checkpoint["loss"]
    print(f"Checkpoint loaded from {filepath}, epoch {epoch}, loss {loss:.4f}")
    return epoch, loss


def train(args):
    """Main training function"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create model
    model = PolicyValueNetwork(
        blocks=args.blocks, channels=args.channels, fcl=args.fcl
    ).to(device)

    # Create optimizer
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # Create gradient scaler for mixed precision
    scaler = GradScaler(enabled=args.amp)

    # Create data loaders
    train_dataloader = create_dataloader(
        args.train_dir,
        batch_size=args.batch_size,
        num_files=args.num_files,
        num_workers=args.num_workers,
        shuffle=True,
    )

    test_dataloader = None
    if args.test_file:
        test_dataloader = create_test_dataloader(
            args.test_file,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            shuffle=False,
        )

    # Resume from checkpoint if specified
    start_epoch = 0
    if args.resume:
        start_epoch, _ = load_checkpoint(args.resume, model, optimizer, scaler, device)
        start_epoch += 1

    # Create checkpoint directory
    checkpoint_dir = Path(args.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Training loop
    for epoch in range(start_epoch, args.epochs):
        print(f"Epoch {epoch + 1}/{args.epochs}")

        # Training
        start_time = time.time()
        train_metrics = train_epoch(
            model, train_dataloader, optimizer, scaler, device, args.amp
        )
        train_time = time.time() - start_time

        print(
            f"Train - Loss: {train_metrics['total_loss']:.4f}, "
            f"Policy Loss: {train_metrics['policy_loss']:.4f}, "
            f"Value Loss: {train_metrics['value_loss']:.4f}, "
            f"Time: {train_time:.2f}s"
        )

        # Evaluation
        if test_dataloader and (epoch + 1) % args.eval_interval == 0:
            start_time = time.time()
            eval_metrics = evaluate(model, test_dataloader, device, args.amp)
            eval_time = time.time() - start_time

            print(
                f"Eval - Loss: {eval_metrics['total_loss']:.4f}, "
                f"Policy Loss: {eval_metrics['policy_loss']:.4f}, "
                f"Value Loss: {eval_metrics['value_loss']:.4f}, "
                f"Policy Acc: {eval_metrics['policy_accuracy']:.4f}, "
                f"Value Acc: {eval_metrics['value_accuracy']:.4f}, "
                f"Time: {eval_time:.2f}s"
            )

        # Save checkpoint
        if (epoch + 1) % args.save_interval == 0:
            checkpoint_path = checkpoint_dir / f"checkpoint_epoch_{epoch + 1}.pth"
            save_checkpoint(
                model,
                optimizer,
                scaler,
                epoch,
                train_metrics["total_loss"],
                checkpoint_path,
            )

    # Save final model
    final_model_path = checkpoint_dir / "final_model.pth"
    torch.save(model.state_dict(), final_model_path)
    print(f"Final model saved to {final_model_path}")

動作確認

python-dlshogi2の学習済みモデルを使用して自己対局で10万局面生成して、訓練できるか確認した。

自己対局

バッチサイズ64で、シミュレーション数32で、10万局面生成に1時間14分かかった。

100023pos [1:14:49, 22.28pos/s]
訓練結果

バッチサイズ64で、4エポック学習した結果は以下の通り。
評価データには2017年~2018年6月のfloodgateのR3500以上の棋譜からサンプリングした856,923局面(重複なし)を使用した。

Epoch 1/4
Training: 100%|████████████████████████████████████████████████████| 1562/1562 [00:40<00:00, 38.94batch/s, Loss=3.8103, Policy=3.2766, Value=0.5337] 
Train - Loss: 4.3428, Policy Loss: 3.6886, Value Loss: 0.6542, Time: 40.12s
Evaluating: 100%|███████████████████████████████████████| 13390/13390 [02:09<00:00, 103.32batch/s, Loss=4.6139, Policy Acc=0.2163, Value Acc=0.5180] 
Eval - Loss: 4.3369, Policy Loss: 3.6071, Value Loss: 0.7298, Policy Acc: 0.2163, Value Acc: 0.5180, Time: 129.61s
Epoch 2/4
Training: 100%|████████████████████████████████████████████████████| 1562/1562 [00:40<00:00, 38.37batch/s, Loss=2.8277, Policy=2.4360, Value=0.3917] 
Train - Loss: 3.2876, Policy Loss: 2.7699, Value Loss: 0.5177, Time: 40.72s
Evaluating: 100%|███████████████████████████████████████| 13390/13390 [02:10<00:00, 102.96batch/s, Loss=5.1747, Policy Acc=0.2416, Value Acc=0.5236] 
Eval - Loss: 4.1800, Policy Loss: 3.1499, Value Loss: 1.0302, Policy Acc: 0.2416, Value Acc: 0.5236, Time: 130.05s
Epoch 3/4
Training: 100%|████████████████████████████████████████████████████| 1562/1562 [00:40<00:00, 38.87batch/s, Loss=2.5337, Policy=2.1820, Value=0.3517]
Train - Loss: 2.8177, Policy Loss: 2.4598, Value Loss: 0.3579, Time: 40.19s
Evaluating: 100%|███████████████████████████████████████| 13390/13390 [02:09<00:00, 103.18batch/s, Loss=5.2742, Policy Acc=0.2576, Value Acc=0.5145]
Eval - Loss: 4.4099, Policy Loss: 3.0620, Value Loss: 1.3479, Policy Acc: 0.2576, Value Acc: 0.5145, Time: 129.78s
Epoch 4/4
Training: 100%|████████████████████████████████████████████████████| 1562/1562 [00:40<00:00, 38.42batch/s, Loss=2.3076, Policy=2.0778, Value=0.2298]
Train - Loss: 2.5607, Policy Loss: 2.2704, Value Loss: 0.2903, Time: 40.66s
Evaluating: 100%|███████████████████████████████████████| 13390/13390 [02:10<00:00, 102.85batch/s, Loss=6.4862, Policy Acc=0.2696, Value Acc=0.5170]
Eval - Loss: 4.6274, Policy Loss: 2.9429, Value Loss: 1.6845, Policy Acc: 0.2696, Value Acc: 0.5170, Time: 130.19s
Final model saved to checkpoints\final_model.pth

ポリシーの正解率は26.96%、価値の正解率は51.70%となった。
ポリシーは、10万局面の結果としては妥当だと思う。
価値はあまり学習できていないのは、Gumble AlphaZeroはランダム性が強いためかもしれない。
バグの可能性もあるので棋譜を増やして確認してみる必要がある。

ワーカー数

バッチサイズ64ではワーカーを増やしても効果なかった。
バッチサイズ1024にするとワーカーを2に増やすと、速くなった。
4にしても変わらなかったので、バッチサイズに応じてワーカー数は調整する必要がある。

ワーカー1:9.65batch/s
ワーカー2:17.71batch/s
ワーカー4:17.91batch/s

まとめ

自己対局によって生成した固定長バイナリ訓練データを、PyTorchのDataLoaderとnp.memmapで効率的に読み込む仕組みを構築した。
訓練処理の動作確認を行い、ポリシーを学習できることを確認した。
価値の正解率があまり上がらなかったため、バグがないか追加で確認を行いたい。

自己対局の生成に時間がかかるため、高速化を考えたい。
わかりやすさは優先したいため、現状の処理のままマルチプロセスでGPUの推論のみ排他制御して高速化できないか検討したい。

次回は、初期値のモデルから自己対局して学習できるか確認したい。