TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

【dlshogi】軽量価値ネットワーク

NNUEとは別のアーキテクチャで軽量の価値ネットワークを学習して、Stockfish系のヒューリスティック探索と組み合わせられないか検討したいと思っている。
軽量の畳み込みニューラルネットワークを使うことで、NNUEに匹敵する軽量モデルを学習することを目指したい。

検討中のアーキテクチャ

入力層

入力層は、NNUEではEmbeddingの重みを辞書引きすることで計算を省いている。
軽量モデルにおいて、Embeddingを使うことは必要と考える。

1x1の畳み込みは、Embeddingに置き換えることができるため、dlshogiの入力層の内、盤上の駒の配置と、持ち駒を1x1の畳み込みで埋め込みベクトルにする。

盤上の駒は、座標ごとにワンホットで表すことができるため、Embeddingを使用する。
持ち駒は、駒の種類と枚数がマルチホットになるため、EmbeddingBagを使用する。
持ち駒の埋め込みベクトルはブロードキャストして、盤上の駒の埋め込みベクトルに加算する。

入力層の出力は、9x9の特徴マップになる。

2層目

9x1と1x9カーネルを並列に並べた畳み込み層とする。
9x1と1x9カーネルは、スライドが必要なく計算量を削減できる(これは、RyfamateのRyfcNetにインスパイアされている)。
9x1と1x9に分解したラージカーネルにより1層のみで盤面全体のパターンを認識することを目的としている。
それぞれのカーネルの出力は、ブロードキャストして加算を行わず、そのまま連結する。

9x1と1x9カーネルにより、位置情報が失われるが、価値を求める場合は問題にならない。

3層目

全結合とする。

出力層

全結合で、価値を出力する。

計算量

入力層の埋め込み次元は、16とする。
各座標ごとに埋め込みが出力されるので、合計は、16 x 81 = 1296次元になる。

2層目のチャンネル数は、4次元とする。
カーネルごとに9次元のベクトルが出力されるので、合計は、4 * (9 + 9) = 72次元になる。

3層目は、32次元とする。

これで、NNUEと計算量は近くなる。
今後、精度と計算量でチューニングを行う予定である。

モデルの実装

PyTorchで、以下のようにモデルを定義した。

import torch
import torch.nn as nn

from dlshogi.common import PIECETYPE_NUM, MAX_PIECES_IN_HAND_SUM

NUM_EMBEDDINGS1 = PIECETYPE_NUM * 2
NUM_EMBEDDINGS2 = MAX_PIECES_IN_HAND_SUM * 2 + 1

class LiteValueNetwork(nn.Module):
    def __init__(self, dims=(16, 4, 32), activation=nn.ReLU()):
        super(LiteValueNetwork, self).__init__()
        self.l1_1 = nn.Embedding(NUM_EMBEDDINGS1 + 1, dims[0], padding_idx=NUM_EMBEDDINGS1)
        self.l1_2 = nn.EmbeddingBag(NUM_EMBEDDINGS2 + 1, dims[0], mode='sum', padding_idx=NUM_EMBEDDINGS2)
        self.l2_1 = nn.Conv2d(in_channels=dims[0], out_channels=dims[1], kernel_size=(9, 1), bias=False)
        self.l2_2 = nn.Conv2d(in_channels=dims[0], out_channels=dims[1], kernel_size=(1, 9), bias=False)
        self.bn2_1 = nn.BatchNorm2d(dims[1])
        self.bn2_2 = nn.BatchNorm2d(dims[1])
        self.l3 = nn.Linear(dims[1] * 9 * 2, dims[2])
        self.l4 = nn.Linear(dims[2], 1)
        self.act = activation
        self.dims = dims

    def forward(self, x1, x2):
        h1_1 = self.l1_1(x1).view(-1, self.dims[0], 9, 9)
        h1_2 = self.l1_2(x2).view(-1, self.dims[0], 1, 1)
        h1 = h1_1 + h1_2
        h2_1 = self.bn2_1(self.l2_1(h1))
        h2_2 = self.bn2_2(self.l2_2(h1))
        h2 = self.act(torch.cat((h2_1.view(-1, self.dims[1] * 9), h2_2.view(-1, self.dims[1] * 9)), 1))
        h3 = self.act(self.l3(h2))
        h4 = self.l4(h3)

        return h4

特徴量作成

dlshogiの特徴量作成処理を流用して、特徴量をint64_tの辞書インデックスで作成するようにした。
パディングインデックスは、特徴量インデックスの最大値+1とした。

Stockfishの推論に組み込む場合は、指し手ごとに差分を更新することになる。

template <Color turn>
inline void make_input_features_lite(const Position& position, features1_lite_t features1, features2_lite_t features2) {
    Bitboard occupied_bb = position.occupiedBB();

    FOREACH_BB(occupied_bb, Square sq, {
        const Piece pc = position.piece(sq);
        const PieceType pt = pieceToPieceType(pc);
        Color c = pieceToColor(pc);

        // 後手の場合、色を反転し、盤面を180度回転
        if (turn == White) {
            c = oppositeColor(c);
            sq = SQ99 - sq;
        }

        // 駒の配置
        features1[sq] = PIECETYPE_NUM * (int)c + pt - 1;
    });

    for (Color c = Black; c < ColorNum; ++c) {
        // 後手の場合、色を反転
        const Color c2 = turn == Black ? c : oppositeColor(c);

        // 持ち駒
        const Hand hand = position.hand(c);
        int p = 0;
        for (HandPiece hp = HPawn; hp < HandPieceNum; ++hp) {
            u32 num = hand.numOf(hp);
            if (num >= MAX_PIECES_IN_HAND[hp]) {
                num = MAX_PIECES_IN_HAND[hp];
            }
            for (size_t i = 0; i < num; ++i) {
                const int64_t idx = MAX_PIECES_IN_HAND_SUM * (int)c2 + p + i;
                features2[idx] = idx;
            }
            p += MAX_PIECES_IN_HAND[hp];
        }
    }

    // is check
    if (position.inCheck()) {
        features2[MAX_FEATURES2_HAND_NUM] = MAX_FEATURES2_HAND_NUM;
    }
}

学習条件

dlshogiの訓練データのサブセットの36,819,741局面を使用して、上記のモデルが学習できるか実験した。

訓練条件:

  • バッチサイズ: 8192
  • 学習率: 0.04
  • オプティマイザ: MomentumSDG
  • エポック数: 8

Stockfishの探索に組み込む場合は、静止探索中の局面を省くなどの工夫が必要になるが、今回は特に局面を省かず実験した。

学習結果

テストデータに、2017年~2018年6月のfloodgateのR3500以上の棋譜からサンプリングした856,923局面(重複なし)を使用して、評価した結果は以下の通り。

訓練損失


評価損失


評価正解率

価値正解率は、61.77%となっており、軽量の畳み込みアーキテクチャでも学習できることが確認できた。

なお、同一の訓練データで、ResNet20ブロック256フィルタのモデルを4エポック学習した際の正解率は、70.19%である。

まとめ

NNUEと同等の計算量の軽量の畳み込みアーキテクチャで、価値ネットワークが学習できるか検証した。
結果、軽量モデルでも価値を学習できることが確認できた。

今後、Stockfish系の探索に組み込むことを検証したいが、先は長そうである。