前回、軽量価値ネットワークのモデル全体の推論をC++でSIMDを使用して実装した。
今回は、推論を探索に組み込む。
ネットワーク構成見直し
推論を組み込む前にネットワーク構成で気になる箇所があったので見直した。
埋め込み層を実装した際に、permuteを忘れていて、permuteを追加したところ逆に精度が下がるという現象があった。
permuteがないと、9x1と1x9の畳み込みに将棋盤の正しい空間構造が伝わらないはずだが、精度が上がるのが謎であった。
仮説として、9x1と1x9の畳み込みは、縦と横だけの特徴量を抽出するので、斜め方向の特徴が捉えられていないことが原因と考えた。
そこで、2x2カーネルのDepth wiseの畳み込みを追加して、局所的な斜め方向の特徴も捉えるようにする。
パディングを行わないことで、出力のサイズを8x8にできるので、SIMDで計算する上でも都合がよい。
3x3も試したが、計算量が増えて、出力サイズ7x7になるため、後段の層のパラメータが減るため精度は上がらなかった。
Depth wiseにすることで、計算量の増加を抑えている。
修正後のネットワーク構成は以下のようになる。
class LiteValueNetwork(nn.Module): def __init__(self, dims=(16, 4, 32), activation=nn.ReLU()): super(LiteValueNetwork, self).__init__() self.l1_1 = nn.Embedding(NUM_EMBEDDINGS1 + 1, dims[0], padding_idx=NUM_EMBEDDINGS1) self.l1_2 = nn.EmbeddingBag(NUM_EMBEDDINGS2 + 1, dims[0], mode='sum', padding_idx=NUM_EMBEDDINGS2) self.l2 = nn.Conv2d(in_channels=dims[0], out_channels=dims[0], kernel_size=2, groups=dims[0], bias=False) self.bn2 = nn.BatchNorm2d(dims[0]) self.l3_1 = nn.Conv2d(in_channels=dims[0], out_channels=dims[1], kernel_size=(8, 1), bias=False) self.l3_2 = nn.Conv2d(in_channels=dims[0], out_channels=dims[1], kernel_size=(1, 8), bias=False) self.bn3_1 = nn.BatchNorm2d(dims[1]) self.bn3_2 = nn.BatchNorm2d(dims[1]) self.l4 = nn.Linear(dims[1] * 8 * 2, dims[2]) self.l5 = nn.Linear(dims[2], 1) self.act = activation self.dims = dims def forward(self, x1, x2): h1_1 = self.l1_1(x1).permute(0, 2, 1).view(-1, self.dims[0], 9, 9) h1_2 = self.l1_2(x2).view(-1, self.dims[0], 1, 1) h1 = h1_1 + h1_2 h2 = self.act(self.bn2(self.l2(h1))) h3_1 = self.bn3_1(self.l3_1(h2)) h3_2 = self.bn3_2(self.l3_2(h2)) h3 = self.act(torch.cat((h3_1.reshape(-1, self.dims[1] * 8), h3_2.reshape(-1, self.dims[1] * 8)), 1)) h4 = self.act(self.l4(h3)) h5 = self.l5(h4) return h5
精度
修正後、同一条件で学習したところ、精度が向上した。
価値の正解率が修正前59.7%から、修正後60.8%になった。
推論速度
C++側の実装も修正して、推論速度を測定した。
Performance: 10000 iterations in 14.7859 ms Average time per iteration: 0.00147859 ms
修正前より、速くなっている。
探索に組み込み
推論処理を探索に組み込んだ。
Eval::evaluateの実装は済んでいるので、USIオプションでパラメータを読み込むだけである。
実装を簡略化するため、グルーバル変数に読み込むようにしている。
NPS
初期局面で、1スレッドで探索した場合のNPSは、約70万になった。
水匠5を1スレッドで探索した際のNPSは、約124万である。
水匠5の56%くらいの速度である。
水匠5のネットワーク構成は、HalfKP-256×2-32-32で、今回実装した軽量価値ネットワークは、埋め込み(9x9x16=1296)を除くと、1024-64-32なので、パラメータ数は比較的多い。
パラメータ数に対しては、妥当な速度と言えそうである。
最近のNNUEもパラメータ数を増やして精度を上げる方向なので、一旦このままで進めることにする。
以前にLibTorchで実装した時は、2万程度だったので35倍程度速くなっている。
対局
Lesserkai相手には余裕で勝てることを確認した。

GPSFishにはまだ勝てなかった。

36,819,741局面を8エポック学習しただけなので、学習が足りていないためと考える。