Ceresで実際にどれくらいNPSがでるか確認したところ、15ブロック、192フィルタのモデルを使用して、RTX30901枚で、初期局面のNPSは、142,685であった。
AVX化後の処理
const __m256i m256i_zero{};
const __m256i m256i_one = _mm256_set1_epi32(1);
const __m256 m256_one = _mm256_set1_ps(1);
const __m256i m256i_eight = _mm256_set1_epi32(8);
__m256 m256_c = _mm256_broadcast_ss(&c);
__m256 m256_parent_q = _mm256_broadcast_ss(&parent_q);
__m256 m256_init_u = _mm256_broadcast_ss(&init_u);
__m256 m256_sqrt_sum = _mm256_broadcast_ss(&sqrt_sum);
__m256 vmaxvalue = _mm256_set1_ps(-FLT_MAX);
__m256i vnowposition = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
__m256i vmaxposition = vnowposition;
for (size_t i = 0; i < child_num; i += 8) {
if (i + 8 > child_num) {
__m256i mask_rest;
switch (child_num - i) {
case 1:
mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1);
break;
case 2:
mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1);
break;
case 3:
mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1);
break;
case 4:
mask_rest = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1);
break;
case 5:
mask_rest = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1);
break;
case 6:
mask_rest = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1);
break;
case 7:
mask_rest = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1);
break;
default:
mask_rest = _mm256_set1_epi32(0);
break;
}
__m256i m256i_move_count = _mm256_maskload_epi32(move_count + i, mask_rest);
__m256i mask = _mm256_cmpgt_epi32(m256i_move_count, m256i_zero);
__m256 m256_win = _mm256_maskload_ps(win + i, mask_rest);
__m256 m256_move_count = _mm256_cvtepi32_ps(m256i_move_count);
__m256 m256_q_tmp = _mm256_div_ps(m256_win, m256_move_count);
__m256 m256_q = _mm256_blendv_ps(m256_parent_q, m256_q_tmp, _mm256_castsi256_ps(mask));
__m256 m256_move_count_plus1 = _mm256_add_ps(m256_move_count, m256_one);
__m256 m256_u_tmp = _mm256_div_ps(m256_sqrt_sum, m256_move_count_plus1);
__m256 m256_u = _mm256_blendv_ps(m256_init_u, m256_u_tmp, _mm256_castsi256_ps(mask));
__m256 m256_rate = _mm256_maskload_ps(nnrate + i, mask_rest);
__m256 m256_ucb_value = _mm256_mul_ps(m256_c, m256_u);
m256_ucb_value = _mm256_mul_ps(m256_ucb_value, m256_rate);
m256_ucb_value = _mm256_add_ps(m256_q, m256_ucb_value);
m256_ucb_value = _mm256_and_ps(m256_ucb_value, _mm256_castsi256_ps(mask_rest));
__m256 vcmp = _mm256_cmp_ps(m256_ucb_value, vmaxvalue, _CMP_GT_OS);
vmaxvalue = _mm256_max_ps(m256_ucb_value, vmaxvalue);
vmaxposition = _mm256_blendv_epi8(vmaxposition, vnowposition, _mm256_castps_si256(vcmp));
vnowposition = _mm256_add_epi32(vnowposition, m256i_eight);
break;
}
__m256i m256i_move_count = _mm256_load_si256((__m256i*)(move_count + i));
__m256i mask = _mm256_cmpgt_epi32(m256i_move_count, m256i_zero);
__m256 m256_win = _mm256_load_ps(win + i);
__m256 m256_move_count = _mm256_cvtepi32_ps(m256i_move_count);
__m256 m256_q_tmp = _mm256_div_ps(m256_win, m256_move_count);
__m256 m256_q = _mm256_blendv_ps(m256_parent_q, m256_q_tmp, _mm256_castsi256_ps(mask));
__m256 m256_move_count_plus1 = _mm256_add_ps(m256_move_count, m256_one);
__m256 m256_u_tmp = _mm256_div_ps(m256_sqrt_sum, m256_move_count_plus1);
__m256 m256_u = _mm256_blendv_ps(m256_init_u, m256_u_tmp, _mm256_castsi256_ps(mask));
__m256 m256_rate = _mm256_load_ps(nnrate + i);
__m256 m256_ucb_value = _mm256_mul_ps(m256_c, m256_u);
m256_ucb_value = _mm256_mul_ps(m256_ucb_value, m256_rate);
m256_ucb_value = _mm256_add_ps(m256_q, m256_ucb_value);
__m256 vcmp = _mm256_cmp_ps(m256_ucb_value, vmaxvalue, _CMP_GT_OS);
vmaxvalue = _mm256_max_ps(m256_ucb_value, vmaxvalue);
vmaxposition = _mm256_blendv_epi8(vmaxposition, vnowposition, _mm256_castps_si256(vcmp));
vnowposition = _mm256_add_epi32(vnowposition, m256i_eight);
}
const int* maxposition = (int*)&vmaxposition;
__m256 vallmax = _mm256_max_ps(vmaxvalue, _mm256_shuffle_ps(vmaxvalue, vmaxvalue, 0xb1));
vallmax = _mm256_max_ps(vallmax, _mm256_shuffle_ps(vallmax, vallmax, 0x4e));
vallmax = _mm256_max_ps(vallmax, _mm256_permute2f128_ps(vallmax, vallmax, 0x01));
__m256 vcmp = _mm256_cmp_ps(vallmax, vmaxvalue, _CMP_EQ_US);
int mask = _mm256_movemask_ps(vcmp);
max_child = maxposition[__builtin_ctz(mask)];
パッと見には何をやっているか分からない(;'∀')
8未満を処理するためにswichで分けている部分が美しくない。
最大値のインデックスを探す処理は、はじめ思いつかなかったが、Discordで教えを乞うたところ、やねうらお氏から参考情報を教えていただき、それを元に実装したコードを、@wain_CGP氏に添削していただいた(最終的にはもらったコードほぼそのまま)。