TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

拡散モデルで将棋の方策を学習する

拡散モデルで、将棋の方策を学習できないか試してみた。

拡散モデル

拡散モデルは、高品質で多様なサンプルを生成できる生成モデルである。
昨年8月にStable Diffusionが公開されたことで注目を集めている。

拡散モデルは、確率微分方程式によって表される確率分布を近似するモデルで、モード崩壊を起こさず多様な分布を学習できるという特徴がある。

また、プロンプトと呼ばれるテキストにより条件付けを行い、テキストに従った画像を生成できる。

将棋の方策

将棋の方策は、座標と移動方向の組み合わせで表現でき、dlshogiで採用している表現方法では2187次元になる。
つまり、指し手は、局面によって条件づけられた2187次元の確率分布からサンプリングを行っていることになる。

拡散モデルの可能性

条件付けを行い高次元の確率分布からサンプリングを行うという仕組みは、将棋の方策においても適用できると考える。
将棋は、詰みのある局面や入玉した局面では指し手の性質が変わるという特徴があるが、拡散モデルの表現力であれば、異なるモードの状態でもうまく学習できるはずである。
現状のdlshogiでは、詰みのある局面は学習が難しいため、訓練データからは除外し、対局時に詰み探索に頼っている。

モデル構成

方策を複数チャンネルの9x9の画像として表現して、画像生成の拡散モデルと同じように学習する。

拡散過程では、方策からノイズに変換する。
逆拡散過程では、ノイズから方策に変換する。

条件付けは、局面をCNNで特徴マップに変換したものを使用する。
この部分はdlshogiの学習済みモデルのバックボーンの層を流用する。

実装

diffusersを使って、サンプルコードを参考にしながら実装した。

import argparse
import numpy as np
import types
from tqdm import tqdm

from diffusers import (
    DDPMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils import randn_tensor

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter

from dlshogi.common import *
from dlshogi.data_loader import Hcpe3DataLoader
from dlshogi.data_loader import DataLoader
from dlshogi import serializers
from dlshogi.network.policy_value_network import policy_value_network


parser = argparse.ArgumentParser()
parser.add_argument("cache")
parser.add_argument("test_data")
parser.add_argument("model")
parser.add_argument("--network", default="resnet30x384_relu")
parser.add_argument("-e", "--epoch", type=int, default=1)
parser.add_argument("-g", "--gpu", type=int, default=0)
parser.add_argument("-b", "--batch_size", type=int, default=32)
parser.add_argument("--eval_batch_size", type=int, default=16)
parser.add_argument('--eval_interval', type=int, default=1000)
parser.add_argument("--num_inference_steps", type=int, default=20)
args = parser.parse_args()


if args.gpu >= 0:
    device = torch.device(f"cuda:{args.gpu}")
else:
    device = torch.device("cpu")


train_len, actual_len = Hcpe3DataLoader.load_files([], cache=args.cache)
train_data = np.arange(train_len, dtype=np.uint32)
test_data = np.fromfile(args.test_data, dtype=HuffmanCodedPosAndEval)

train_dataloader = Hcpe3DataLoader(train_data, args.batch_size, device, shuffle=True)
test_dataloader = DataLoader(test_data, args.eval_batch_size, device)


model = policy_value_network(args.network)


def forward(self, x1, x2):
    u1_1_1 = self.l1_1_1(x1)
    u1_1_2 = self.l1_1_2(x1)
    u1_2 = self.l1_2(x2)
    u1 = self.act(self.norm1(u1_1_1 + u1_1_2 + u1_2))
    return self.blocks(u1)


model.forward = types.MethodType(forward, model)
model.to(device)
serializers.load_npz(args.model, model)
model.eval()


unet = UNet2DConditionModel(
    sample_size=9,
    in_channels=27,
    out_channels=27,
    encoder_hid_dim=model.policy.in_channels * 9 * 9,
)
unet.to(device)

optimizer = AdamW(
    unet.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-08,
)

lr_scheduler = StepLR(optimizer=optimizer, step_size=1, gamma=0.5)

noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
)

generator = torch.Generator(device=unet.device).manual_seed(0)

nl_loss = torch.nn.NLLLoss()
def accuracy(y, t):
    return (torch.max(y, 1)[1] == t).sum().detach().item() / len(t)

def eval():
    noise_scheduler = DDPMScheduler()

    x1, x2, policies, win, value = test_dataloader.sample()
    bsz = policies.shape[0]

    encoder_hidden_states = model(x1, x2).reshape(bsz, 1, -1)

    # inference (sample random noise and denoise)
    image = randn_tensor((bsz, 27, 9, 9), generator=generator, device=device)

    # set step values
    noise_scheduler.set_timesteps(args.num_inference_steps)

    unet.eval()
    for t in noise_scheduler.timesteps:
        # 1. predict noise model_output
        model_output = unet(image, t, encoder_hidden_states).sample

        # 2. compute previous image: x_t -> x_t-1
        image = noise_scheduler.step(model_output, t, image, generator=generator).prev_sample

    pred = image.reshape(bsz, -1)
    loss = nl_loss(torch.log(pred), policies)
    writer.add_scalar("loss/eval", loss.detach().item(), step)
    writer.add_scalar("accuracy/eval", accuracy(pred, policies), step)

    unet.train()

writer = SummaryWriter()


step = 0
for epoch in range(args.epoch):
    unet.train()
    for x1, x2, policies, win, value in tqdm(
        train_dataloader, total=train_len // args.batch_size, desc=f"epoch: {epoch}"
    ):
        bsz = policies.shape[0]
        policies = policies.reshape((bsz, 27, 9, 9))

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(policies)

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            noise_scheduler.config.num_train_timesteps,
            (bsz,),
            device=policies.device,
        )
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_policies = noise_scheduler.add_noise(policies, noise, timesteps)

        # Get the embedding for position
        encoder_hidden_states = model(x1, x2).reshape(bsz, 1, -1)

        target = noise

        # Predict the noise residual and compute loss
        model_pred = unet(noisy_policies, timesteps, encoder_hidden_states).sample

        loss = F.mse_loss(model_pred, target, reduction="mean")

        # Backpropagate
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        step += 1

        writer.add_scalar("loss/train", loss.detach().item(), step)

        if step % args.eval_interval == 0:
            eval()

    lr_scheduler.step()

条件付けは、テキストではなくdlshogiのモデルのバックボーンを使用している。
duffusersは、埋め込みのテンソルがシーケンスになっている必要があるため、シーケンスのサイズ=1とした。

Stable Diffusionのように、VAEで潜在空間への圧縮は行わず、直接方策の確率を入力している。

学習結果

まだ学習を少し試しただけだが、訓練損失は下がることを確認した。

逆拡散過程で生成した方策の正解率は0で、今のところ学習できていない。

確認したところ、出力に負の値が含まれていたため、方策の確率を直接学習するよりlogits空間に変換した方がよいかもしれない。

まとめ

拡散モデルによる将棋の方策の学習を試してみた。
推論速度が遅いため、大会向けの将棋AIとしては速度面で使い物にならなそうであるが、精度で既存モデルを上回るか検証したい。
今のところの感触は学習が難しそうである。

続く。
tadaoyamaoka.hatenablog.com