TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

【dlshogi】軽量価値ネットワーク その5(LibTorchで推論)

前回、Stockfishの探索部を移植して、駒得評価関数で動くことを確認した。
1手詰めや、入力宣言勝ち、秒読みなど将棋特有の処理も実装した。

今回は、評価関数を以前に検討した軽量価値ネットワークにすることを試す。

軽量価値ネットワーク

以前に検討した通り、入力層を埋め込みベクトルにした、軽量の畳み込みニューラルネットワークを使用する。
計算量はNNUEとだいたい同じになるようにしている。

class LiteValueNetwork(nn.Module):
    def __init__(self, dims=(16, 4, 32), activation=nn.ReLU()):
        super(LiteValueNetwork, self).__init__()
        self.l1_1 = nn.Embedding(NUM_EMBEDDINGS1 + 1, dims[0], padding_idx=NUM_EMBEDDINGS1)
        self.l1_2 = nn.EmbeddingBag(NUM_EMBEDDINGS2 + 1, dims[0], mode='sum', padding_idx=NUM_EMBEDDINGS2)
        self.l2_1 = nn.Conv2d(in_channels=dims[0], out_channels=dims[1], kernel_size=(9, 1), bias=False)
        self.l2_2 = nn.Conv2d(in_channels=dims[0], out_channels=dims[1], kernel_size=(1, 9), bias=False)
        self.bn2_1 = nn.BatchNorm2d(dims[1])
        self.bn2_2 = nn.BatchNorm2d(dims[1])
        self.l3 = nn.Linear(dims[1] * 9 * 2, dims[2])
        self.l4 = nn.Linear(dims[2], 1)
        self.act = activation
        self.dims = dims

    def forward(self, x1, x2):
        h1_1 = self.l1_1(x1).view(-1, self.dims[0], 9, 9)
        h1_2 = self.l1_2(x2).view(-1, self.dims[0], 1, 1)
        h1 = h1_1 + h1_2
        h2_1 = self.bn2_1(self.l2_1(h1))
        h2_2 = self.bn2_2(self.l2_2(h1))
        h2 = self.act(torch.cat((h2_1.view(-1, self.dims[1] * 9), h2_2.view(-1, self.dims[1] * 9)), 1))
        h3 = self.act(self.l3(h2))
        h4 = self.l4(h3)

        return h4

学習

まずは動作確認のため、floodgateの棋譜と、dlshogiとNNUE系の対局の棋譜から作成した36,819,741局面を8エポックだけ学習した。
floodgateの棋譜に対する正解率は、61.3%となった。

モデルエクスポート

学習したPyTorchのモデルをtorch.jit.traceで、TorchScripに変換して、torch.jit.freezeで推論用に変換して保存する。
torch.jit.freezeで、BatchNorm2dは直前の畳み込み層とレイヤー融合される。

    traced_model = torch.jit.trace(model, (x1, x2))
    traced_model = torch.jit.freeze(traced_model)
    traced_model.save(args.torchscript)

推論処理

学習したモデルをそのまま使用できるため、まずはLibTorchのCPU版で推論処理を実装した。

struct TorchInitializer {
    TorchInitializer() {
        at::set_num_threads(1);
        at::set_num_interop_threads(1);
    }
};
TorchInitializer torch_initializer;
c10::InferenceMode guard;
auto model = torch::jit::load(R"(model_lite-008.pt)");

    // Evaluate is the evaluator for the outer world. It returns a static evaluation
// of the position from the point of view of the side to move.
Value Eval::evaluate(const Position& pos) {

    Value v = 0;

    features1_lite_t features1;
    features2_lite_t features2;

    // set all padding index
    std::fill_n((int64_t*)features1, sizeof(features1_lite_t) / sizeof(int64_t), PIECETYPE_NUM * 2);
    std::fill_n((int64_t*)features2, sizeof(features2_lite_t) / sizeof(int64_t), MAX_FEATURES2_NUM);

    make_input_features_lite(pos, features1, features2);

    std::vector<torch::jit::IValue> x = {
        torch::from_blob(features1, { 1, (size_t)SquareNum }, torch::dtype(torch::kInt64)),
        torch::from_blob(features2, { 1, MAX_FEATURES2_NUM }, torch::dtype(torch::kInt64))
    };

    const auto y = model.forward(x);
    const float value = *y.toTensor().data_ptr<float>();

    v = int(value * 600);

    return v;
}

初期局面での探索

初期局面で、指し手が2四歩か、7四歩になれば正しく推論できていることが確認できる。

推論結果は、以下のようになった。

position startpos
go
info string Available processors: 0-63
info string Using 1 thread
info depth 1 seldepth 2 multipv 1 score cp 40 nodes 30 nps 545 hashfull 0 time 55 pv 7g7f
info depth 2 seldepth 2 multipv 1 score cp 110 nodes 64 nps 1163 hashfull 0 time 55 pv 2g2f
info depth 3 seldepth 2 multipv 1 score cp 110 nodes 101 nps 1803 hashfull 0 time 56 pv 2g2f
info depth 4 seldepth 2 multipv 1 score cp 110 nodes 132 nps 2357 hashfull 0 time 56 pv 2g2f
info depth 5 seldepth 3 multipv 1 score cp 87 nodes 337 nps 5349 hashfull 0 time 63 pv 2g2f 3c3d
info depth 6 seldepth 6 multipv 1 score cp 41 nodes 906 nps 11048 hashfull 0 time 82 pv 2g2f 3c3d 2f2e
info depth 7 seldepth 6 multipv 1 score cp 60 nodes 1069 nps 11494 hashfull 0 time 93 pv 2g2f 3c3d 2f2e
info depth 8 seldepth 8 multipv 1 score cp 85 nodes 2101 nps 14489 hashfull 0 time 145 pv 2g2f 3c3d 7g7f 2b8h+ 7i8h
info depth 9 seldepth 7 multipv 1 score cp 162 nodes 2490 nps 16168 hashfull 1 time 154 pv 2g2f 7a7b 7g7f
info depth 10 seldepth 10 multipv 1 score cp 35 nodes 5809 nps 22171 hashfull 2 time 262 pv 2g2f 3c3d 2f2e 8c8d 7g7f 2b8h+ 2h8h 8d8e
info depth 11 seldepth 10 multipv 1 score cp 31 nodes 7735 nps 24021 hashfull 2 time 322 pv 2g2f 3c3d 2f2e 8c8d 2e2d 2c2d 2h2d 8d8e
info depth 12 seldepth 12 multipv 1 score cp 41 nodes 9523 nps 23283 hashfull 2 time 409 pv 2g2f 3c3d 2f2e 8c8d 2e2d 2c2d 7g7f 2b8h+ 7i8h
info depth 13 seldepth 13 multipv 1 score cp 40 nodes 11926 nps 23522 hashfull 2 time 507 pv 2g2f 3c3d 2f2e 8c8d 2e2d 2c2d 2h2d 7a7b 2d3d 2b4d
info depth 14 seldepth 15 multipv 1 score cp 45 nodes 15422 nps 23762 hashfull 3 time 649 pv 2g2f 8c8d 2f2e 8d8e 2e2d 2c2d 2h2d 8e8f 7g7f 3c3d 8h2b+ 3a2b 8g8f
info depth 15 seldepth 18 multipv 1 score cp 36 nodes 31538 nps 19760 hashfull 8 time 1596 pv 2g2f 3c3d 2f2e 8c8d 7g7f 2b8h+ 7i8h 8d8e 3i3h 8e8f 8g8f
info depth 16 seldepth 15 multipv 1 score cp 40 nodes 35783 nps 20148 hashfull 9 time 1776 pv 2g2f 3c3d 2f2e 8c8d 2e2d 2c2d 2h2d 8d8e 3i3h 8e8f 8g8f 8b8f
info depth 17 seldepth 19 multipv 1 score cp 30 nodes 50257 nps 20175 hashfull 11 time 2491 pv 2g2f 3c3d 2f2e 8c8d 2e2d 2c2d 2h2d 8d8e 7g7f 8e8f 8h2b+ 8b2b P*2c 2b3b 8g8f
info depth 18 seldepth 20 multipv 1 score cp 27 nodes 114177 nps 23339 hashfull 22 time 4892 pv 2g2f 3c3d 2f2e 8c8d 7g7f 8d8e 8h2b+ 3a2b 3i3h 8e8f 2e2d 8f8g+ 2d2c+
info depth 19 seldepth 21 multipv 1 score cp 15 nodes 144242 nps 21954 hashfull 30 time 6570 pv 2g2f 8c8d 2f2e 8d8e 7g7f 8e8f 8g8f 3c3d 2e2d 2c2d 8h2b+ 3a2b 8f8e 8b8e 2h2d
info depth 20 seldepth 22 multipv 1 score cp 4 nodes 249419 nps 20140 hashfull 71 time 12384 pv 2g2f 8c8d 2f2e 8d8e 2e2d 2c2d 2h2d 8e8f P*2c 8f8g+ 2c2b+ 8b2b 2d2b+ 3a2b

2四歩が最善手となり、正しく推論できていそうである。

NPS

初期局面で、2万程度しかNPSが出ない。
Stockfish系の探索エンジンとしては遅すぎる。

LibTorchは、推論が別スレッドになるなど、オーバーヘッドが大きいため、軽量のモデルをバッチサイズ1で、推論する用途には向いていない。

推論処理は、NNUEと同じようにSIMDを使用して、作りこむ必要がありそうだ。

今回は、動作することが確かめられたので良しとする。

Lesserkaiと対局

Lesserkaiと対局して、勝てることを確認した。
評価値も大きく間違ってはいなさそうである。

まとめ

前回移植したStockfishの探索部に、軽量価値ネットワークを組み込んだ。
推論処理をLibTorchで実装し、正しい評価値が出力されることが確認できた。
しかし、LibTorchでは推論がStockfish系の探索エンジンとしては遅すぎることがわかった。

今後は、SIMDで推論処理を作りこんで、NNUE並みの高速化を目指す。