前回までで、Gumbel AlphaZeroの自己対局で訓練データを生成できるようになったので、今回は学習処理を実装する。
訓練データ
自己対局で、以下のようなNumpyの構造体の配列をバイナリファイルとして出力する。
ランダムアクセスが可能なように固定長にしている。
# Fixed-length record layout of one self-play training sample.
# Fields, in storage order:
#   hcp    - the position, stored as a HuffmanCodedPos record
#   policy - MOVE_LABELS_NUM float32 values (the policy target)
#   result - game result code stored as an unsigned byte
dtypeTrainingData = np.dtype(
    {
        "names": ["hcp", "policy", "result"],
        "formats": [
            HuffmanCodedPos,
            np.dtype((np.float32, MOVE_LABELS_NUM)),
            np.uint8,
        ],
    }
)
データローダ
データローダは、PyTorchのDatasetとDataLoaderを使用して実装する。
複数ワーカに対応して、Pythonでの特徴量作成処理がボトルネックにならないようにする。
複数のワーカがそれぞれデータをメモリに読み込むと大量にメモリを消費するため、np.memmapを使用してメモリマップト配列として読み込む。
固定長なので、配列のインデックスを指定してランダムアクセスができる。
自己対局で生成したデータは、新しいものから引数で与えられたnum_filesファイル分だけ使用する。
class TrainingDataset(Dataset):
    """Map-style dataset over fixed-length training records stored in ``*.data`` files.

    The files are never loaded eagerly: each one is opened lazily with
    ``np.memmap`` the first time a sample from it is requested, and a sample is
    addressed by (file, record index) computed from its global index.
    """

    def __init__(self, data_dir, num_files=None, worker_id=None, num_workers=None):
        """
        Args:
            data_dir: Directory containing .data files
            num_files: Number of newest files to use (None for all)
            worker_id: Worker ID for multi-worker setup
            num_workers: Total number of workers
        """
        super().__init__()

        # Assigned first so that __del__ stays safe even if the directory scan
        # below raises part-way through construction.
        self._mmaps = {}
        self._board = None

        self.data_dir = data_dir
        self.num_files = num_files
        self.worker_id = worker_id or 0
        self.num_workers = num_workers or 1

        # Find and sort data files by modification time (newest first)
        pattern = os.path.join(data_dir, "*.data")
        files = sorted(glob.glob(pattern), key=os.path.getmtime, reverse=True)

        # Select specified number of files
        if num_files is not None:
            files = files[:num_files]
        self.files = files

        # Each entry is (first global index, number of samples, file path).
        self.file_offsets = []
        self.total_samples = 0
        for file_path in self.files:
            # A trailing partial record (file still being written) is ignored
            # by the integer division.
            num_samples = os.path.getsize(file_path) // dtypeTrainingData.itemsize
            self.file_offsets.append((self.total_samples, num_samples, file_path))
            self.total_samples += num_samples

        # For multi-worker setup, divide data among workers
        if self.num_workers > 1:
            samples_per_worker = self.total_samples // self.num_workers
            self.start_idx = self.worker_id * samples_per_worker
            if self.worker_id == self.num_workers - 1:
                # Last worker takes remaining samples
                self.end_idx = self.total_samples
            else:
                self.end_idx = (self.worker_id + 1) * samples_per_worker
        else:
            self.start_idx = 0
            self.end_idx = self.total_samples
        self.worker_total_samples = self.end_idx - self.start_idx

    def _get_mmap(self, file_path):
        """Get or create the read-only memory map for a file."""
        if file_path not in self._mmaps:
            self._mmaps[file_path] = np.memmap(
                file_path, dtype=dtypeTrainingData, mode="r"
            )
        return self._mmaps[file_path]

    def _find_file_and_offset(self, global_idx):
        """Return ``(file_path, index_within_file)`` for the sample at global_idx.

        Raises:
            IndexError: If global_idx is negative or past the last sample.
        """
        # A negative index would otherwise match the first file and be
        # interpreted by NumPy as an offset from the end of that file.
        if global_idx >= 0:
            for file_start, file_samples, file_path in self.file_offsets:
                if global_idx < file_start + file_samples:
                    return file_path, global_idx - file_start
        raise IndexError(f"Index {global_idx} out of range")

    def _get_board(self):
        """Get or create a Board instance for HCP conversion"""
        if self._board is None:
            self._board = Board()
        return self._board

    def __len__(self):
        return self.worker_total_samples

    def __getitem__(self, idx):
        """Return ``(features, policy, result)`` tensors for one sample.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: If idx is outside this dataset's index range.
        """
        local_idx = idx + self.worker_total_samples if idx < 0 else idx
        if not 0 <= local_idx < self.worker_total_samples:
            raise IndexError(f"Index {idx} out of range for worker {self.worker_id}")

        # Convert worker-local index to global index
        global_idx = self.start_idx + local_idx

        # Find file and local offset, then read the record through the memmap
        file_path, local_offset = self._find_file_and_offset(global_idx)
        sample = self._get_mmap(file_path)[local_offset]

        # Convert HCP to input features
        board = self._get_board()
        board.set_hcp(np.asarray(sample["hcp"]))
        features = np.empty((FEATURES_NUM, 9, 9), dtype=np.float32)
        make_input_features(board, features)

        # The stored result is absolute (who won); the value target is
        # relative to the side to move in this position.
        result_code = sample["result"]
        if result_code == 1:  # BLACK_WIN
            result = 1 if board.turn == BLACK else 0
        elif result_code == 2:  # WHITE_WIN
            result = 1 if board.turn == WHITE else 0
        else:  # DRAW
            result = 0.5

        # Convert to tensors. The policy is copied because the record is a
        # view into a read-only memory map.
        features_tensor = torch.from_numpy(features)
        policy_tensor = torch.from_numpy(sample["policy"].copy())
        result_tensor = torch.tensor(result, dtype=torch.float32)
        return features_tensor, policy_tensor, result_tensor

    def __del__(self):
        # Clean up memory maps. The attribute may be missing if __init__ did
        # not run to completion.
        mmaps = getattr(self, "_mmaps", None)
        if not mmaps:
            return
        for mapped in mmaps.values():
            raw = getattr(mapped, "_mmap", None)
            if raw is None:
                continue
            try:
                raw.close()
            except (BufferError, OSError):
                # Best effort only: a finalizer must not raise.
                pass
        mmaps.clear()


def worker_init_fn(worker_id):
    """
    Worker initialization function for DataLoader.

    Re-initializes the dataset copy held by this worker process: the data
    directory is scanned again and the per-process state (memory maps, Board)
    is reset so that nothing is shared with the parent process.

    The shard parameters are rebuilt from ``dataset.num_workers``. A dataset
    built by ``create_dataloader`` leaves that value at 1, so in that setup
    every worker addresses the full index range handed out by the sampler.

    Args:
        worker_id: ID of the worker process, supplied by DataLoader.
    """
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # The dataset instance in this worker.

    # NOTE(review): the directory is re-scanned here, so the file list can
    # differ from the parent's if files change after the DataLoader was built.
    worker_dataset = TrainingDataset(
        data_dir=dataset.data_dir,
        num_files=dataset.num_files,
        worker_id=worker_id,
        num_workers=dataset.num_workers,
    )

    # Copy the attributes from the new dataset to the one in the worker
    dataset.files = worker_dataset.files
    dataset.file_offsets = worker_dataset.file_offsets
    dataset.total_samples = worker_dataset.total_samples
    dataset.worker_id = worker_dataset.worker_id
    dataset.start_idx = worker_dataset.start_idx
    dataset.end_idx = worker_dataset.end_idx
    dataset.worker_total_samples = worker_dataset.worker_total_samples
    dataset._mmaps = {}  # Each worker needs its own memory maps
    dataset._board = None  # Reset the board for each worker


def create_dataloader(
    data_dir,
    batch_size=32,
    num_files=None,
    num_workers=0,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
    **kwargs,
):
    """
    Create a DataLoader for training data

    Args:
        data_dir: Directory containing .data files
        batch_size: Batch size
        num_files: Number of newest files to use (None for all)
        num_workers: Number of worker processes
        shuffle: Whether to shuffle data
        pin_memory: Whether to pin memory for GPU transfer
        drop_last: Whether to drop the last incomplete batch
        **kwargs: Additional arguments for DataLoader

    Returns:
        DataLoader instance
    """
    dataset = TrainingDataset(data_dir, num_files)

    # Worker-related arguments are only passed when worker processes are
    # requested; otherwise DataLoader's own defaults apply.
    worker_kwargs = {}
    if num_workers > 0:
        worker_kwargs = {
            "num_workers": num_workers,
            "worker_init_fn": worker_init_fn,
        }

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=pin_memory,
        drop_last=drop_last,
        **worker_kwargs,
        **kwargs,
    )
テスト用のデータは、dlshogiと同様にhcpe形式のファイルを使用する。
hcpe形式のデータローダも同様に定義する。
訓練処理
PyTorchの標準的な方法で学習する。
ポリシーの損失関数は、CrossEntropyLossを使用してGumbel AlphaZeroの改善されたポリシーの分布を学習する。
価値の損失は、BCEWithLogitsLossを使用して、勝ち/負け/引き分けをそれぞれ1, 0, 0.5として学習する。
ニューラルネットワークの定義は、python-dlshogi2を流用する。
オプションでAMP(Automatic Mixed Precision)に対応する。
tqdmで、進捗状況を表示する。
def train_epoch(model, dataloader, optimizer, scaler, device, amp_enabled=True):
    """Train for one epoch

    Args:
        model: Network returning ``(policy_output, value_output)``
        dataloader: Iterable of ``(features, policy_targets, value_targets)``
        optimizer: Optimizer updating the model parameters
        scaler: Gradient scaler used for the backward pass and optimizer step
        device: Device the batches are moved to
        amp_enabled: Whether to run the forward pass under autocast

    Returns:
        Dict with the per-batch mean of "total_loss", "policy_loss" and
        "value_loss".

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    model.train()

    policy_criterion = nn.CrossEntropyLoss()
    value_criterion = nn.BCEWithLogitsLoss()

    total_loss = 0.0
    policy_loss_sum = 0.0
    value_loss_sum = 0.0
    num_batches = 0

    with tqdm(dataloader, desc="Training", unit="batch") as pbar:
        for features, policy_targets, value_targets in pbar:
            features = features.to(device)
            policy_targets = policy_targets.to(device)
            value_targets = value_targets.to(device)

            optimizer.zero_grad()

            with autocast(enabled=amp_enabled):
                policy_output, value_output = model(features)

                # Policy loss - targets are probability distributions
                policy_loss = policy_criterion(policy_output, policy_targets)

                # Value loss - targets are game results
                value_loss = value_criterion(value_output.squeeze(-1), value_targets)

                total_batch_loss = policy_loss + value_loss

            scaler.scale(total_batch_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Read each scalar once and reuse it for both the running sums
            # and the progress bar.
            batch_loss = total_batch_loss.item()
            batch_policy_loss = policy_loss.item()
            batch_value_loss = value_loss.item()

            total_loss += batch_loss
            policy_loss_sum += batch_policy_loss
            value_loss_sum += batch_value_loss
            num_batches += 1

            # Update progress bar
            pbar.set_postfix(
                {
                    "Loss": f"{batch_loss:.4f}",
                    "Policy": f"{batch_policy_loss:.4f}",
                    "Value": f"{batch_value_loss:.4f}",
                }
            )

    if num_batches == 0:
        # Happens e.g. when drop_last discards the only (incomplete) batch.
        raise ValueError(
            "train_epoch: the dataloader produced no batches; "
            "check the training data and the batch size"
        )

    return {
        "total_loss": total_loss / num_batches,
        "policy_loss": policy_loss_sum / num_batches,
        "value_loss": value_loss_sum / num_batches,
    }


def evaluate(model, dataloader, device, amp_enabled=True):
    """Evaluate the model

    Args:
        model: Network returning ``(policy_output, value_output)``
        dataloader: Iterable of ``(features, policy_targets, value_targets)``
        device: Device the batches are moved to
        amp_enabled: Whether to run the forward pass under autocast

    Returns:
        Dict with the per-batch mean of "total_loss", "policy_loss" and
        "value_loss", and the per-sample "policy_accuracy" and
        "value_accuracy". All values are NaN if the dataloader is empty.
    """
    model.eval()

    policy_criterion = nn.CrossEntropyLoss()
    value_criterion = nn.BCEWithLogitsLoss()

    total_loss = 0.0
    policy_loss_sum = 0.0
    value_loss_sum = 0.0
    policy_correct = 0
    value_correct = 0
    total_samples = 0
    num_batches = 0

    with torch.no_grad():
        with tqdm(dataloader, desc="Evaluating", unit="batch") as pbar:
            for features, policy_targets, value_targets in pbar:
                features = features.to(device)
                policy_targets = policy_targets.to(device)
                value_targets = value_targets.to(device)

                with autocast(enabled=amp_enabled):
                    policy_output, value_output = model(features)

                    # For test data, policy_targets are labels (long), not distributions
                    policy_loss = policy_criterion(policy_output, policy_targets)

                    # Calculate accuracy for classification
                    policy_pred = torch.argmax(policy_output, dim=1)
                    policy_correct += (policy_pred == policy_targets).sum().item()

                    value_logits = value_output.squeeze(-1)
                    value_loss = value_criterion(value_logits, value_targets)

                    # Value accuracy (threshold at 0.5)
                    value_pred = torch.sigmoid(value_logits)
                    value_binary_pred = (value_pred > 0.5).float()
                    value_binary_target = (value_targets > 0.5).float()
                    value_correct += (
                        (value_binary_pred == value_binary_target).sum().item()
                    )

                    total_batch_loss = policy_loss + value_loss

                batch_loss = total_batch_loss.item()
                total_loss += batch_loss
                policy_loss_sum += policy_loss.item()
                value_loss_sum += value_loss.item()
                total_samples += features.size(0)
                num_batches += 1

                # Update progress bar
                pbar.set_postfix(
                    {
                        "Loss": f"{batch_loss:.4f}",
                        "Policy Acc": f"{policy_correct/total_samples:.4f}",
                        "Value Acc": f"{value_correct/total_samples:.4f}",
                    }
                )

    if num_batches == 0 or total_samples == 0:
        # Nothing was evaluated: report NaN instead of dividing by zero.
        nan = float("nan")
        return {
            "total_loss": nan,
            "policy_loss": nan,
            "value_loss": nan,
            "policy_accuracy": nan,
            "value_accuracy": nan,
        }

    return {
        "total_loss": total_loss / num_batches,
        "policy_loss": policy_loss_sum / num_batches,
        "value_loss": value_loss_sum / num_batches,
        "policy_accuracy": policy_correct / total_samples,
        "value_accuracy": value_correct / total_samples,
    }


def save_checkpoint(model, optimizer, scaler, epoch, loss, filepath):
    """Save training checkpoint

    Args:
        model: Model whose state dict is stored
        optimizer: Optimizer whose state dict is stored
        scaler: Gradient scaler whose state dict is stored
        epoch: Epoch index stored under "epoch"
        loss: Loss value stored under "loss"
        filepath: Destination path of the checkpoint file
    """
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "loss": loss,
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved to {filepath}")


def load_checkpoint(filepath, model, optimizer, scaler, device):
    """Load training checkpoint

    Restores the state of model, optimizer and scaler in place.

    Args:
        filepath: Path of a checkpoint written by save_checkpoint
        model: Model to restore
        optimizer: Optimizer to restore
        scaler: Gradient scaler to restore
        device: Device the stored tensors are mapped to

    Returns:
        Tuple ``(epoch, loss)`` as stored in the checkpoint.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(filepath, map_location=device)

    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scaler.load_state_dict(checkpoint["scaler_state_dict"])

    epoch = checkpoint["epoch"]
    loss = checkpoint["loss"]
    print(f"Checkpoint loaded from {filepath}, epoch {epoch}, loss {loss:.4f}")
    return epoch, loss


def train(args):
    """Main training function

    Builds the model, optimizer, gradient scaler and data loaders from args,
    optionally resumes from a checkpoint, runs the epoch loop with periodic
    evaluation and checkpointing, and finally saves the model state dict.

    Args:
        args: Namespace providing blocks, channels, fcl, lr, weight_decay,
            amp, train_dir, batch_size, num_files, num_workers, test_file,
            resume, checkpoint_dir, epochs, eval_interval and save_interval.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create model
    model = PolicyValueNetwork(
        blocks=args.blocks, channels=args.channels, fcl=args.fcl
    ).to(device)

    # Create optimizer
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # Create gradient scaler for mixed precision
    scaler = GradScaler(enabled=args.amp)

    # Create data loaders
    train_dataloader = create_dataloader(
        args.train_dir,
        batch_size=args.batch_size,
        num_files=args.num_files,
        num_workers=args.num_workers,
        shuffle=True,
    )

    test_dataloader = None
    if args.test_file:
        test_dataloader = create_test_dataloader(
            args.test_file,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            shuffle=False,
        )

    # Resume from checkpoint if specified
    start_epoch = 0
    if args.resume:
        # The checkpoint stores the index of the last finished epoch, so
        # training continues with the following one.
        start_epoch, _ = load_checkpoint(args.resume, model, optimizer, scaler, device)
        start_epoch += 1

    # Create checkpoint directory
    checkpoint_dir = Path(args.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Training loop
    for epoch in range(start_epoch, args.epochs):
        print(f"Epoch {epoch + 1}/{args.epochs}")

        # Training
        start_time = time.time()
        train_metrics = train_epoch(
            model, train_dataloader, optimizer, scaler, device, args.amp
        )
        train_time = time.time() - start_time

        print(
            f"Train - Loss: {train_metrics['total_loss']:.4f}, "
            f"Policy Loss: {train_metrics['policy_loss']:.4f}, "
            f"Value Loss: {train_metrics['value_loss']:.4f}, "
            f"Time: {train_time:.2f}s"
        )

        # Evaluation
        if test_dataloader and (epoch + 1) % args.eval_interval == 0:
            start_time = time.time()
            eval_metrics = evaluate(model, test_dataloader, device, args.amp)
            eval_time = time.time() - start_time

            print(
                f"Eval - Loss: {eval_metrics['total_loss']:.4f}, "
                f"Policy Loss: {eval_metrics['policy_loss']:.4f}, "
                f"Value Loss: {eval_metrics['value_loss']:.4f}, "
                f"Policy Acc: {eval_metrics['policy_accuracy']:.4f}, "
                f"Value Acc: {eval_metrics['value_accuracy']:.4f}, "
                f"Time: {eval_time:.2f}s"
            )

        # Save checkpoint
        if (epoch + 1) % args.save_interval == 0:
            checkpoint_path = checkpoint_dir / f"checkpoint_epoch_{epoch + 1}.pth"
            save_checkpoint(
                model,
                optimizer,
                scaler,
                epoch,
                train_metrics["total_loss"],
                checkpoint_path,
            )

    # Save final model
    final_model_path = checkpoint_dir / "final_model.pth"
    torch.save(model.state_dict(), final_model_path)
    print(f"Final model saved to {final_model_path}")
動作確認
python-dlshogi2の学習済みモデルを使用して自己対局で10万局面生成して、訓練できるか確認した。
自己対局
バッチサイズ64で、シミュレーション数32で、10万局面生成に1時間14分かかった。
100023pos [1:14:49, 22.28pos/s]
訓練結果
バッチサイズ64で、4エポック学習した結果は以下の通り。
評価データには2017年~2018年6月のfloodgateのR3500以上の棋譜からサンプリングした856,923局面(重複なし)を使用した。
Epoch 1/4 Training: 100%|████████████████████████████████████████████████████| 1562/1562 [00:40<00:00, 38.94batch/s, Loss=3.8103, Policy=3.2766, Value=0.5337] Train - Loss: 4.3428, Policy Loss: 3.6886, Value Loss: 0.6542, Time: 40.12s Evaluating: 100%|███████████████████████████████████████| 13390/13390 [02:09<00:00, 103.32batch/s, Loss=4.6139, Policy Acc=0.2163, Value Acc=0.5180] Eval - Loss: 4.3369, Policy Loss: 3.6071, Value Loss: 0.7298, Policy Acc: 0.2163, Value Acc: 0.5180, Time: 129.61s Epoch 2/4 Training: 100%|████████████████████████████████████████████████████| 1562/1562 [00:40<00:00, 38.37batch/s, Loss=2.8277, Policy=2.4360, Value=0.3917] Train - Loss: 3.2876, Policy Loss: 2.7699, Value Loss: 0.5177, Time: 40.72s Evaluating: 100%|███████████████████████████████████████| 13390/13390 [02:10<00:00, 102.96batch/s, Loss=5.1747, Policy Acc=0.2416, Value Acc=0.5236] Eval - Loss: 4.1800, Policy Loss: 3.1499, Value Loss: 1.0302, Policy Acc: 0.2416, Value Acc: 0.5236, Time: 130.05s Epoch 3/4 Training: 100%|████████████████████████████████████████████████████| 1562/1562 [00:40<00:00, 38.87batch/s, Loss=2.5337, Policy=2.1820, Value=0.3517] Train - Loss: 2.8177, Policy Loss: 2.4598, Value Loss: 0.3579, Time: 40.19s Evaluating: 100%|███████████████████████████████████████| 13390/13390 [02:09<00:00, 103.18batch/s, Loss=5.2742, Policy Acc=0.2576, Value Acc=0.5145] Eval - Loss: 4.4099, Policy Loss: 3.0620, Value Loss: 1.3479, Policy Acc: 0.2576, Value Acc: 0.5145, Time: 129.78s Epoch 4/4 Training: 100%|████████████████████████████████████████████████████| 1562/1562 [00:40<00:00, 38.42batch/s, Loss=2.3076, Policy=2.0778, Value=0.2298] Train - Loss: 2.5607, Policy Loss: 2.2704, Value Loss: 0.2903, Time: 40.66s Evaluating: 100%|███████████████████████████████████████| 13390/13390 [02:10<00:00, 102.85batch/s, Loss=6.4862, Policy Acc=0.2696, Value Acc=0.5170] Eval - Loss: 4.6274, Policy Loss: 2.9429, Value Loss: 1.6845, Policy Acc: 0.2696, Value Acc: 0.5170, Time: 130.19s 
Final model saved to checkpoints\final_model.pth
ポリシーの正解率は26.96%、価値の正解率は51.70%となった。
ポリシーは、10万局面の結果としては妥当だと思う。
価値があまり学習できていないのは、Gumbel AlphaZeroはランダム性が強いためかもしれない。
バグの可能性もあるので棋譜を増やして確認してみる必要がある。
ワーカー数
バッチサイズ64では、ワーカーを増やしても効果はなかった。
バッチサイズを1024にした場合は、ワーカーを2に増やすと速くなった。
4にしても変わらなかったので、バッチサイズに応じてワーカー数は調整する必要がある。
ワーカー1:9.65batch/s
ワーカー2:17.71batch/s
ワーカー4:17.91batch/s