コンピュータチェスのCeresでは、PUCTによるノード選択の処理をAVXを使って高速化している。
これは、Ceres独自の「parallelized descent algorithm」(並列降下アルゴリズム)と合わせて使用することで、効果を発揮するもののようだ。
Ceresで実際にどれくらいNPSがでるか確認したところ、15ブロック、192フィルタのモデルを使用して、RTX30901枚で、初期局面のNPSは、142,685であった。
go nodes 1000000 Loaded network weights: 750181: 15x192 WDL MLH from H:\src\Ceres\lc0\weights_run3_750181.pb.gz info depth 1 seldepth 2 time 1016 nodes 1 score cp -0 tbhits 0 nps 1 pv g2g3 string M= 157 (略) info depth 10 seldepth 22 time 7005 nodes 999471 score cp 8 tbhits 0 nps 142685 pv c2c4 e7e6 g2g3 d7d5 f1g2 d5d4 g1f3 c7c5 e2e3 b8c6 e3d4 c5d4 e1g1 g8f6 f1e1 f8d6 a2a3 a7a5 d2d3 e8g8 b1d2 string M= 157 bestmove c2c4
一方、dlshogiは、10ブロック、192フィルタのモデルで、初期局面のNPSは48,245である。
チェスと将棋の盤面サイズ、終端ノードでの詰み探索のありなどの条件の違いはあるが、Ceresはdlshogiより大きいモデルで約2.96倍のNPSが出ている。
dlshogiの探索部もまだまだ高速化の余地があることが分かる。
Ceresの並列アルゴリズムをdlshogiにも導入したいと考えているが、とりあえず、並列アルゴリズムは現在のままで、ノード選択処理のAVX化を行ってみた。
Ceresは並列アルゴリズムと組み合わせてAVX化しているが、単にノードを選択する際に、複数の子ノードのUCBを計算して、UCBが最大のノードを選ぶ処理をAVX化した。
AVX化前の処理
for (int i = 0; i < child_num; i++) { const WinType win = uct_child[i].win; const int move_count = uct_child[i].move_count; if (move_count == 0) { // 未探索のノードの価値に、親ノードの価値を使用する q = parent_q; u = init_u; } else { q = (float)(win / move_count); u = sqrt_sum / (1 + move_count); } const float rate = uct_child[i].nnrate; const float ucb_value = q + c * u * rate; if (ucb_value > max_value) { max_value = ucb_value; max_child = i; } }
AVX化後の処理
const __m256i m256i_zero{}; const __m256i m256i_one = _mm256_set1_epi32(1); const __m256 m256_one = _mm256_set1_ps(1); const __m256i m256i_eight = _mm256_set1_epi32(8); __m256 m256_c = _mm256_broadcast_ss(&c); __m256 m256_parent_q = _mm256_broadcast_ss(&parent_q); __m256 m256_init_u = _mm256_broadcast_ss(&init_u); __m256 m256_sqrt_sum = _mm256_broadcast_ss(&sqrt_sum); __m256 vmaxvalue = _mm256_set1_ps(-FLT_MAX); __m256i vnowposition = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); __m256i vmaxposition = vnowposition; for (size_t i = 0; i < child_num; i += 8) { if (i + 8 > child_num) { // 残り8未満 __m256i mask_rest; switch (child_num - i) { case 1: mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1); break; case 2: mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1); break; case 3: mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1); break; case 4: mask_rest = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1); break; case 5: mask_rest = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1); break; case 6: mask_rest = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1); break; case 7: mask_rest = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1); break; default: // unreachable mask_rest = _mm256_set1_epi32(0); break; } __m256i m256i_move_count = _mm256_maskload_epi32(move_count + i, mask_rest); __m256i mask = _mm256_cmpgt_epi32(m256i_move_count, m256i_zero); // q = (float)(win / move_count); __m256 m256_win = _mm256_maskload_ps(win + i, mask_rest); __m256 m256_move_count = _mm256_cvtepi32_ps(m256i_move_count); __m256 m256_q_tmp = _mm256_div_ps(m256_win, m256_move_count); __m256 m256_q = _mm256_blendv_ps(m256_parent_q, m256_q_tmp, _mm256_castsi256_ps(mask)); // u = sqrt_sum / (1 + move_count); __m256 m256_move_count_plus1 = _mm256_add_ps(m256_move_count, m256_one); __m256 m256_u_tmp = _mm256_div_ps(m256_sqrt_sum, m256_move_count_plus1); __m256 m256_u = _mm256_blendv_ps(m256_init_u, m256_u_tmp, _mm256_castsi256_ps(mask)); __m256 m256_rate = _mm256_maskload_ps(nnrate + i, mask_rest); //const float ucb_value = q + c * u * rate; __m256 m256_ucb_value = _mm256_mul_ps(m256_c, m256_u); m256_ucb_value = _mm256_mul_ps(m256_ucb_value, m256_rate); m256_ucb_value = _mm256_add_ps(m256_q, m256_ucb_value); // mask m256_ucb_value = _mm256_and_ps(m256_ucb_value, _mm256_castsi256_ps(mask_rest)); // find max __m256 vcmp = _mm256_cmp_ps(m256_ucb_value, vmaxvalue, _CMP_GT_OS); vmaxvalue = _mm256_max_ps(m256_ucb_value, vmaxvalue); vmaxposition = _mm256_blendv_epi8(vmaxposition, vnowposition, _mm256_castps_si256(vcmp)); vnowposition = _mm256_add_epi32(vnowposition, m256i_eight); break; } //if (move_count == 0) { __m256i m256i_move_count = _mm256_load_si256((__m256i*)(move_count + i)); __m256i mask = _mm256_cmpgt_epi32(m256i_move_count, m256i_zero); // // 未探索のノードの価値に、親ノードの価値を使用する // q = parent_q; // u = init_u; // --> 下記のelseの計算結果と合わせて_mm256_blendv_psで設定する //} //else { // q = (float)(win / move_count); __m256 m256_win = _mm256_load_ps(win + i); __m256 m256_move_count = _mm256_cvtepi32_ps(m256i_move_count); __m256 m256_q_tmp = _mm256_div_ps(m256_win, m256_move_count); __m256 m256_q = _mm256_blendv_ps(m256_parent_q, m256_q_tmp, _mm256_castsi256_ps(mask)); // u = sqrt_sum / (1 + move_count); __m256 m256_move_count_plus1 = _mm256_add_ps(m256_move_count, m256_one); __m256 m256_u_tmp = _mm256_div_ps(m256_sqrt_sum, m256_move_count_plus1); __m256 m256_u = _mm256_blendv_ps(m256_init_u, m256_u_tmp, _mm256_castsi256_ps(mask)); //} //const float rate = uct_child[i].nnrate; __m256 m256_rate = _mm256_load_ps(nnrate + i); //const float ucb_value = q + c * u * rate; __m256 m256_ucb_value = _mm256_mul_ps(m256_c, m256_u); m256_ucb_value = _mm256_mul_ps(m256_ucb_value, m256_rate); m256_ucb_value = _mm256_add_ps(m256_q, m256_ucb_value); // find max __m256 vcmp = _mm256_cmp_ps(m256_ucb_value, vmaxvalue, _CMP_GT_OS); vmaxvalue = _mm256_max_ps(m256_ucb_value, vmaxvalue); vmaxposition = _mm256_blendv_epi8(vmaxposition, vnowposition, _mm256_castps_si256(vcmp)); vnowposition = _mm256_add_epi32(vnowposition, m256i_eight); } // find max const int* maxposition = (int*)&vmaxposition; __m256 vallmax = _mm256_max_ps(vmaxvalue, _mm256_shuffle_ps(vmaxvalue, vmaxvalue, 0xb1)); vallmax = _mm256_max_ps(vallmax, _mm256_shuffle_ps(vallmax, vallmax, 0x4e)); vallmax = _mm256_max_ps(vallmax, _mm256_permute2f128_ps(vallmax, vallmax, 0x01)); __m256 vcmp = _mm256_cmp_ps(vallmax, vmaxvalue, _CMP_EQ_US); int mask = _mm256_movemask_ps(vcmp); max_child = maxposition[__builtin_ctz(mask)];
パッと見には何をやっているか分からない(;'∀')
8未満を処理するためにswichで分けている部分が美しくない。
最大値のインデックスを探す処理は、はじめ思いつかなかったが、Discordで教えを乞うたところ、やねうらお氏から参考情報を教えていただき、それを元に実装したコードを、@wain_CGP氏に添削していただいた(最終的にはもらったコードほぼそのまま)。
測定
上記のAVX化した処理を探索部に組み込んで、どれくら速くなるか測定を行った。
V100 8枚で、floodgateからサンプリングした100局面(1局面は詰みなので除外)で測定した。
https://github.com/TadaoYamaoka/DeepLearningShogi/blob/master/utils/benchmark.py
測定結果
AVXなし | AVXあり | 比 | |
---|---|---|---|
平均 | 229621 | 239431 | 1.04 |
中央値 | 235244 | 242866 | 1.03 |
最大 | 264526 | 392199 | 1.48 |
最小 | 148579 | 152466 | 1.03 |
平均で、4%程NPSが上昇した。
最大では、1.48倍の局面がある。
探索する局面の子ノードの数によって効果が変わってくる。
コード
探索部に組み込んだ処理はこちら。
DeepLearningShogi/UctSearch.cpp at 63d17b043c4ca0a37d15ea4e417a63cdffb29cc6 · TadaoYamaoka/DeepLearningShogi · GitHub
詰みのAND-OR木探索も含んでいるので、さらに美しくなくなっている。
まとめ
Ceresにインスパイアされてノード選択処理のAVX化を行った。
結果、平均で4%程高速化することができた。
AVX化を行ってもCeresのNPSには遠く及んでいない。
Ceresは、AVX化よりも衝突を回避した並列アルゴリズムの方がNPSへの寄与が大きいと思われる。
そのうち、Ceresの並列アルゴリズムをdlshogiにも導入したいと考えている。
2021/4/4 追記
AVX化の前後で強さを確認を測定した。
持ち時間3分、1手1秒加算
# PLAYER : RATING ERROR POINTS PLAYED (%) CFS(%) W D L D(%) 1 dlshogi_avx2 : 0.2 9.0 500.5 1000 50 52 370 261 369 26 2 dlshogi : -0.2 9.0 499.5 1000 50 --- 369 261 370 26
AVX化しても強くなっていない。