前回、Stockfishの探索部を移植して、駒得評価関数で動くことを確認した。
1手詰めや、入力宣言勝ち、秒読みなど将棋特有の処理も実装した。
今回は、評価関数を以前に検討した軽量価値ネットワークにすることを試す。
軽量価値ネットワーク
以前に検討した通り、入力層を埋め込みベクトルにした、軽量の畳み込みニューラルネットワークを使用する。
計算量はNNUEとだいたい同じになるようにしている。
class LiteValueNetwork(nn.Module): def __init__(self, dims=(16, 4, 32), activation=nn.ReLU()): super(LiteValueNetwork, self).__init__() self.l1_1 = nn.Embedding(NUM_EMBEDDINGS1 + 1, dims[0], padding_idx=NUM_EMBEDDINGS1) self.l1_2 = nn.EmbeddingBag(NUM_EMBEDDINGS2 + 1, dims[0], mode='sum', padding_idx=NUM_EMBEDDINGS2) self.l2_1 = nn.Conv2d(in_channels=dims[0], out_channels=dims[1], kernel_size=(9, 1), bias=False) self.l2_2 = nn.Conv2d(in_channels=dims[0], out_channels=dims[1], kernel_size=(1, 9), bias=False) self.bn2_1 = nn.BatchNorm2d(dims[1]) self.bn2_2 = nn.BatchNorm2d(dims[1]) self.l3 = nn.Linear(dims[1] * 9 * 2, dims[2]) self.l4 = nn.Linear(dims[2], 1) self.act = activation self.dims = dims def forward(self, x1, x2): h1_1 = self.l1_1(x1).view(-1, self.dims[0], 9, 9) h1_2 = self.l1_2(x2).view(-1, self.dims[0], 1, 1) h1 = h1_1 + h1_2 h2_1 = self.bn2_1(self.l2_1(h1)) h2_2 = self.bn2_2(self.l2_2(h1)) h2 = self.act(torch.cat((h2_1.view(-1, self.dims[1] * 9), h2_2.view(-1, self.dims[1] * 9)), 1)) h3 = self.act(self.l3(h2)) h4 = self.l4(h3) return h4
学習
まずは動作確認のため、floodgateの棋譜と、dlshogiとNNUE系の対局の棋譜から作成した36,819,741局面を8エポックだけ学習した。
floodgateの棋譜に対する正解率は、61.3%となった。
モデルエクスポート
学習したPyTorchのモデルをtorch.jit.traceで、TorchScripに変換して、torch.jit.freezeで推論用に変換して保存する。
torch.jit.freezeで、BatchNorm2dは直前の畳み込み層とレイヤー融合される。
traced_model = torch.jit.trace(model, (x1, x2))
traced_model = torch.jit.freeze(traced_model)
traced_model.save(args.torchscript)
推論処理
学習したモデルをそのまま使用できるため、まずはLibTorchのCPU版で推論処理を実装した。
struct TorchInitializer { TorchInitializer() { at::set_num_threads(1); at::set_num_interop_threads(1); } }; TorchInitializer torch_initializer; c10::InferenceMode guard; auto model = torch::jit::load(R"(model_lite-008.pt)"); // Evaluate is the evaluator for the outer world. It returns a static evaluation // of the position from the point of view of the side to move. Value Eval::evaluate(const Position& pos) { Value v = 0; features1_lite_t features1; features2_lite_t features2; // set all padding index std::fill_n((int64_t*)features1, sizeof(features1_lite_t) / sizeof(int64_t), PIECETYPE_NUM * 2); std::fill_n((int64_t*)features2, sizeof(features2_lite_t) / sizeof(int64_t), MAX_FEATURES2_NUM); make_input_features_lite(pos, features1, features2); std::vector<torch::jit::IValue> x = { torch::from_blob(features1, { 1, (size_t)SquareNum }, torch::dtype(torch::kInt64)), torch::from_blob(features2, { 1, MAX_FEATURES2_NUM }, torch::dtype(torch::kInt64)) }; const auto y = model.forward(x); const float value = *y.toTensor().data_ptr<float>(); v = int(value * 600); return v; }
初期局面での探索
初期局面で、指し手が2四歩か、7四歩になれば正しく推論できていることが確認できる。
推論結果は、以下のようになった。
position startpos go info string Available processors: 0-63 info string Using 1 thread info depth 1 seldepth 2 multipv 1 score cp 40 nodes 30 nps 545 hashfull 0 time 55 pv 7g7f info depth 2 seldepth 2 multipv 1 score cp 110 nodes 64 nps 1163 hashfull 0 time 55 pv 2g2f info depth 3 seldepth 2 multipv 1 score cp 110 nodes 101 nps 1803 hashfull 0 time 56 pv 2g2f info depth 4 seldepth 2 multipv 1 score cp 110 nodes 132 nps 2357 hashfull 0 time 56 pv 2g2f info depth 5 seldepth 3 multipv 1 score cp 87 nodes 337 nps 5349 hashfull 0 time 63 pv 2g2f 3c3d info depth 6 seldepth 6 multipv 1 score cp 41 nodes 906 nps 11048 hashfull 0 time 82 pv 2g2f 3c3d 2f2e info depth 7 seldepth 6 multipv 1 score cp 60 nodes 1069 nps 11494 hashfull 0 time 93 pv 2g2f 3c3d 2f2e info depth 8 seldepth 8 multipv 1 score cp 85 nodes 2101 nps 14489 hashfull 0 time 145 pv 2g2f 3c3d 7g7f 2b8h+ 7i8h info depth 9 seldepth 7 multipv 1 score cp 162 nodes 2490 nps 16168 hashfull 1 time 154 pv 2g2f 7a7b 7g7f info depth 10 seldepth 10 multipv 1 score cp 35 nodes 5809 nps 22171 hashfull 2 time 262 pv 2g2f 3c3d 2f2e 8c8d 7g7f 2b8h+ 2h8h 8d8e info depth 11 seldepth 10 multipv 1 score cp 31 nodes 7735 nps 24021 hashfull 2 time 322 pv 2g2f 3c3d 2f2e 8c8d 2e2d 2c2d 2h2d 8d8e info depth 12 seldepth 12 multipv 1 score cp 41 nodes 9523 nps 23283 hashfull 2 time 409 pv 2g2f 3c3d 2f2e 8c8d 2e2d 2c2d 7g7f 2b8h+ 7i8h info depth 13 seldepth 13 multipv 1 score cp 40 nodes 11926 nps 23522 hashfull 2 time 507 pv 2g2f 3c3d 2f2e 8c8d 2e2d 2c2d 2h2d 7a7b 2d3d 2b4d info depth 14 seldepth 15 multipv 1 score cp 45 nodes 15422 nps 23762 hashfull 3 time 649 pv 2g2f 8c8d 2f2e 8d8e 2e2d 2c2d 2h2d 8e8f 7g7f 3c3d 8h2b+ 3a2b 8g8f info depth 15 seldepth 18 multipv 1 score cp 36 nodes 31538 nps 19760 hashfull 8 time 1596 pv 2g2f 3c3d 2f2e 8c8d 7g7f 2b8h+ 7i8h 8d8e 3i3h 8e8f 8g8f info depth 16 seldepth 15 multipv 1 score cp 40 nodes 35783 nps 20148 hashfull 9 time 1776 pv 2g2f 3c3d 2f2e 8c8d 2e2d 2c2d 2h2d 8d8e 3i3h 8e8f 8g8f 8b8f info depth 17 seldepth 19 multipv 1 score cp 30 nodes 50257 nps 20175 hashfull 11 time 2491 pv 2g2f 3c3d 2f2e 8c8d 2e2d 2c2d 2h2d 8d8e 7g7f 8e8f 8h2b+ 8b2b P*2c 2b3b 8g8f info depth 18 seldepth 20 multipv 1 score cp 27 nodes 114177 nps 23339 hashfull 22 time 4892 pv 2g2f 3c3d 2f2e 8c8d 7g7f 8d8e 8h2b+ 3a2b 3i3h 8e8f 2e2d 8f8g+ 2d2c+ info depth 19 seldepth 21 multipv 1 score cp 15 nodes 144242 nps 21954 hashfull 30 time 6570 pv 2g2f 8c8d 2f2e 8d8e 7g7f 8e8f 8g8f 3c3d 2e2d 2c2d 8h2b+ 3a2b 8f8e 8b8e 2h2d info depth 20 seldepth 22 multipv 1 score cp 4 nodes 249419 nps 20140 hashfull 71 time 12384 pv 2g2f 8c8d 2f2e 8d8e 2e2d 2c2d 2h2d 8e8f P*2c 8f8g+ 2c2b+ 8b2b 2d2b+ 3a2b
2四歩が最善手となり、正しく推論できていそうである。
NPS
初期局面で、2万程度しかNPSが出ない。
Stockfish系の探索エンジンとしては遅すぎる。
LibTorchは、推論が別スレッドになるなど、オーバーヘッドが大きいため、軽量のモデルをバッチサイズ1で、推論する用途には向いていない。
推論処理は、NNUEと同じようにSIMDを使用して、作りこむ必要がありそうだ。
今回は、動作することが確かめられたので良しとする。
Lesserkaiと対局
Lesserkaiと対局して、勝てることを確認した。
評価値も大きく間違ってはいなさそうである。

まとめ
前回移植したStockfishの探索部に、軽量価値ネットワークを組み込んだ。
推論処理をLibTorchで実装し、正しい評価値が出力されることが確認できた。
しかし、LibTorchでは推論がStockfish系の探索エンジンとしては遅すぎることがわかった。
今後は、SIMDで推論処理を作りこんで、NNUE並みの高速化を目指す。