TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

【dlshogi】軽量価値ネットワーク その12(探索に組み込み)

前回、軽量価値ネットワークのモデル全体の推論をC++SIMDを使用して実装した。
今回は、推論を探索に組み込む。

ネットワーク構成見直し

推論を組み込む前にネットワーク構成で気になる箇所があったので見直した。

埋め込み層を実装した際に、permuteを忘れていて、permuteを追加したところ逆に精度が下がるという現象があった。
permuteがないと、9x1と1x9の畳み込みに将棋盤の正しい空間構造が伝わらないはずだが、精度が上がるのが謎であった。
仮説として、9x1と1x9の畳み込みは、縦と横だけの特徴量を抽出するので、斜め方向の特徴が捉えられていないことが原因と考えた。

そこで、2x2カーネルのDepth wiseの畳み込みを追加して、局所的な斜め方向の特徴も捉えるようにする。
パディングを行わないことで、出力のサイズを8x8にできるので、SIMDで計算する上でも都合がよい。

3x3も試したが、計算量が増えて、出力サイズ7x7になるため、後段の層のパラメータが減るため精度は上がらなかった。

Depth wiseにすることで、計算量の増加を抑えている。

修正後のネットワーク構成は以下のようになる。

class LiteValueNetwork(nn.Module):
    def __init__(self, dims=(16, 4, 32), activation=nn.ReLU()):
        super(LiteValueNetwork, self).__init__()
        self.l1_1 = nn.Embedding(NUM_EMBEDDINGS1 + 1, dims[0], padding_idx=NUM_EMBEDDINGS1)
        self.l1_2 = nn.EmbeddingBag(NUM_EMBEDDINGS2 + 1, dims[0], mode='sum', padding_idx=NUM_EMBEDDINGS2)
        self.l2 = nn.Conv2d(in_channels=dims[0], out_channels=dims[0], kernel_size=2, groups=dims[0], bias=False)
        self.bn2 = nn.BatchNorm2d(dims[0])
        self.l3_1 = nn.Conv2d(in_channels=dims[0], out_channels=dims[1], kernel_size=(8, 1), bias=False)
        self.l3_2 = nn.Conv2d(in_channels=dims[0], out_channels=dims[1], kernel_size=(1, 8), bias=False)
        self.bn3_1 = nn.BatchNorm2d(dims[1])
        self.bn3_2 = nn.BatchNorm2d(dims[1])
        self.l4 = nn.Linear(dims[1] * 8 * 2, dims[2])
        self.l5 = nn.Linear(dims[2], 1)
        self.act = activation
        self.dims = dims

    def forward(self, x1, x2):
        h1_1 = self.l1_1(x1).permute(0, 2, 1).view(-1, self.dims[0], 9, 9)
        h1_2 = self.l1_2(x2).view(-1, self.dims[0], 1, 1)
        h1 = h1_1 + h1_2
        h2 = self.act(self.bn2(self.l2(h1)))
        h3_1 = self.bn3_1(self.l3_1(h2))
        h3_2 = self.bn3_2(self.l3_2(h2))
        h3 = self.act(torch.cat((h3_1.reshape(-1, self.dims[1] * 8), h3_2.reshape(-1, self.dims[1] * 8)), 1))
        h4 = self.act(self.l4(h3))
        h5 = self.l5(h4)

        return h5

精度

修正後、同一条件で学習したところ、精度が向上した。
価値の正解率が修正前59.7%から、修正後60.8%になった。

推論速度

C++側の実装も修正して、推論速度を測定した。

Performance: 10000 iterations in 14.7859 ms
Average time per iteration: 0.00147859 ms

修正前より、速くなっている。

探索に組み込み

推論処理を探索に組み込んだ。

Eval::evaluateの実装は済んでいるので、USIオプションでパラメータを読み込むだけである。
実装を簡略化するため、グルーバル変数に読み込むようにしている。

NPS

初期局面で、1スレッドで探索した場合のNPSは、約70万になった。
水匠5を1スレッドで探索した際のNPSは、約124万である。
水匠5の56%くらいの速度である。

水匠5のネットワーク構成は、HalfKP-256×2-32-32で、今回実装した軽量価値ネットワークは、埋め込み(9x9x16=1296)を除くと、1024-64-32なので、パラメータ数は比較的多い。
パラメータ数に対しては、妥当な速度と言えそうである。

最近のNNUEもパラメータ数を増やして精度を上げる方向なので、一旦このままで進めることにする。

以前にLibTorchで実装した時は、2万程度だったので35倍程度速くなっている。

対局

Lesserkai相手には余裕で勝てることを確認した。

GPSFishにはまだ勝てなかった。

36,819,741局面を8エポック学習しただけなので、学習が足りていないためと考える。

まとめ

C++SIMDを使用して実装した軽量価値ネットワークをStockfishの探索に組み込んだ。
以前にLibTorchで実装した時に比べて、各段に速くなった。
Lesserkaiには余裕で勝つが、まだ学習が足りていないため、GPSFishには勝てない強さである。
次は、数十億局面学習させて強くなるか検証したい。