TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

将棋AIの実験ノート:AVX対応

コンピュータチェスのCeresでは、PUCTによるノード選択の処理をAVXを使って高速化している。
これは、Ceres独自の「parallelized descent algorithm」(並列降下アルゴリズム)と合わせて使用することで、効果を発揮するもののようだ。

Ceresで実際にどれくらいNPSがでるか確認したところ、15ブロック、192フィルタのモデルを使用して、RTX30901枚で、初期局面のNPSは、142,685であった。

go nodes 1000000
Loaded network weights: 750181: 15x192 WDL MLH  from H:\src\Ceres\lc0\weights_run3_750181.pb.gz
info depth 1 seldepth 2 time 1016 nodes 1 score cp -0 tbhits 0 nps 1 pv g2g3  string M= 157
(略)
info depth 10 seldepth 22 time 7005 nodes 999471 score cp 8 tbhits 0 nps 142685 pv c2c4 e7e6 g2g3 d7d5 f1g2 d5d4 g1f3 c7c5 e2e3 b8c6 e3d4 c5d4 e1g1 g8f6 f1e1 f8d6 a2a3 a7a5 d2d3 e8g8 b1d2  string M= 157
bestmove c2c4

一方、dlshogiは、10ブロック、192フィルタのモデルで、初期局面のNPSは48,245である。
チェスと将棋の盤面サイズ、終端ノードでの詰み探索のありなどの条件の違いはあるが、Ceresはdlshogiより大きいモデルで約2.96倍のNPSが出ている。

dlshogiの探索部もまだまだ高速化の余地があることが分かる。

Ceresの並列アルゴリズムをdlshogiにも導入したいと考えているが、とりあえず、並列アルゴリズムは現在のままで、ノード選択処理のAVX化を行ってみた。

Ceresは並列アルゴリズムと組み合わせてAVX化しているが、単にノードを選択する際に、複数の子ノードのUCBを計算して、UCBが最大のノードを選ぶ処理をAVX化した。

AVX化前の処理

	for (int i = 0; i < child_num; i++) {
		const WinType win = uct_child[i].win;
		const int move_count = uct_child[i].move_count;

		if (move_count == 0) {
			// 未探索のノードの価値に、親ノードの価値を使用する
			q = parent_q;
			u = init_u;
		}
		else {
			q = (float)(win / move_count);
			u = sqrt_sum / (1 + move_count);
		}

		const float rate = uct_child[i].nnrate;
		const float ucb_value = q + c * u * rate;
		if (ucb_value > max_value) {
			max_value = ucb_value;
			max_child = i;
		}
	}

AVX化後の処理

const __m256i m256i_zero{};
const __m256i m256i_one = _mm256_set1_epi32(1);
const __m256 m256_one = _mm256_set1_ps(1);
const __m256i m256i_eight = _mm256_set1_epi32(8);

	__m256 m256_c = _mm256_broadcast_ss(&c);
	__m256 m256_parent_q = _mm256_broadcast_ss(&parent_q);
	__m256 m256_init_u = _mm256_broadcast_ss(&init_u);
	__m256 m256_sqrt_sum = _mm256_broadcast_ss(&sqrt_sum);

	__m256 vmaxvalue = _mm256_set1_ps(-FLT_MAX);
	__m256i vnowposition = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
	__m256i vmaxposition = vnowposition;
	for (size_t i = 0; i < child_num; i += 8) {
		if (i + 8 > child_num) {
			// 残り8未満
			__m256i mask_rest;
			switch (child_num - i) {
			case 1:
				mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1);
				break;
			case 2:
				mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1);
				break;
			case 3:
				mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1);
				break;
			case 4:
				mask_rest = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1);
				break;
			case 5:
				mask_rest = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1);
				break;
			case 6:
				mask_rest = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1);
				break;
			case 7:
				mask_rest = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1);
				break;
			default:
				// unreachable
				mask_rest = _mm256_set1_epi32(0);
				break;
			}
			__m256i m256i_move_count = _mm256_maskload_epi32(move_count + i, mask_rest);
			__m256i mask = _mm256_cmpgt_epi32(m256i_move_count, m256i_zero);

			//	q = (float)(win / move_count);
			__m256 m256_win = _mm256_maskload_ps(win + i, mask_rest);
			__m256 m256_move_count = _mm256_cvtepi32_ps(m256i_move_count);
			__m256 m256_q_tmp = _mm256_div_ps(m256_win, m256_move_count);
			__m256 m256_q = _mm256_blendv_ps(m256_parent_q, m256_q_tmp, _mm256_castsi256_ps(mask));

			//	u = sqrt_sum / (1 + move_count);
			__m256 m256_move_count_plus1 = _mm256_add_ps(m256_move_count, m256_one);
			__m256 m256_u_tmp = _mm256_div_ps(m256_sqrt_sum, m256_move_count_plus1);
			__m256 m256_u = _mm256_blendv_ps(m256_init_u, m256_u_tmp, _mm256_castsi256_ps(mask));
			__m256 m256_rate = _mm256_maskload_ps(nnrate + i, mask_rest);

			//const float ucb_value = q + c * u * rate;
			__m256 m256_ucb_value = _mm256_mul_ps(m256_c, m256_u);
			m256_ucb_value = _mm256_mul_ps(m256_ucb_value, m256_rate);
			m256_ucb_value = _mm256_add_ps(m256_q, m256_ucb_value);

			// mask
			m256_ucb_value = _mm256_and_ps(m256_ucb_value, _mm256_castsi256_ps(mask_rest));

			// find max
			__m256 vcmp = _mm256_cmp_ps(m256_ucb_value, vmaxvalue, _CMP_GT_OS);
			vmaxvalue = _mm256_max_ps(m256_ucb_value, vmaxvalue);
			vmaxposition = _mm256_blendv_epi8(vmaxposition, vnowposition, _mm256_castps_si256(vcmp));
			vnowposition = _mm256_add_epi32(vnowposition, m256i_eight);

			break;
		}

		//if (move_count == 0) {
		__m256i m256i_move_count = _mm256_load_si256((__m256i*)(move_count + i));
		__m256i mask = _mm256_cmpgt_epi32(m256i_move_count, m256i_zero);

		//	// 未探索のノードの価値に、親ノードの価値を使用する
		//	q = parent_q;
		//	u = init_u;
		//  --> 下記のelseの計算結果と合わせて_mm256_blendv_psで設定する
		//}
		//else {
		//	q = (float)(win / move_count);
		__m256 m256_win = _mm256_load_ps(win + i);
		__m256 m256_move_count = _mm256_cvtepi32_ps(m256i_move_count);
		__m256 m256_q_tmp = _mm256_div_ps(m256_win, m256_move_count);
		__m256 m256_q = _mm256_blendv_ps(m256_parent_q, m256_q_tmp, _mm256_castsi256_ps(mask));

		//	u = sqrt_sum / (1 + move_count);
		__m256 m256_move_count_plus1 = _mm256_add_ps(m256_move_count, m256_one);
		__m256 m256_u_tmp = _mm256_div_ps(m256_sqrt_sum, m256_move_count_plus1);
		__m256 m256_u = _mm256_blendv_ps(m256_init_u, m256_u_tmp, _mm256_castsi256_ps(mask));
		//}

		//const float rate = uct_child[i].nnrate;
		__m256 m256_rate = _mm256_load_ps(nnrate + i);

		//const float ucb_value = q + c * u * rate;
		__m256 m256_ucb_value = _mm256_mul_ps(m256_c, m256_u);
		m256_ucb_value = _mm256_mul_ps(m256_ucb_value, m256_rate);
		m256_ucb_value = _mm256_add_ps(m256_q, m256_ucb_value);

		// find max
		__m256 vcmp = _mm256_cmp_ps(m256_ucb_value, vmaxvalue, _CMP_GT_OS);
		vmaxvalue = _mm256_max_ps(m256_ucb_value, vmaxvalue);
		vmaxposition = _mm256_blendv_epi8(vmaxposition, vnowposition, _mm256_castps_si256(vcmp));
		vnowposition = _mm256_add_epi32(vnowposition, m256i_eight);
	}
	// find max
	const int* maxposition = (int*)&vmaxposition;
	__m256 vallmax = _mm256_max_ps(vmaxvalue, _mm256_shuffle_ps(vmaxvalue, vmaxvalue, 0xb1));
	vallmax = _mm256_max_ps(vallmax, _mm256_shuffle_ps(vallmax, vallmax, 0x4e));
	vallmax = _mm256_max_ps(vallmax, _mm256_permute2f128_ps(vallmax, vallmax, 0x01));
	__m256 vcmp = _mm256_cmp_ps(vallmax, vmaxvalue, _CMP_EQ_US);
	int mask = _mm256_movemask_ps(vcmp);
	max_child = maxposition[__builtin_ctz(mask)];

パッと見には何をやっているか分からない(;'∀')
8未満を処理するためにswichで分けている部分が美しくない。

最大値のインデックスを探す処理は、はじめ思いつかなかったが、Discordで教えを乞うたところ、やねうらお氏から参考情報を教えていただき、それを元に実装したコードを、@wain_CGP氏に添削していただいた(最終的にはもらったコードほぼそのまま)。

測定

上記のAVX化した処理を探索部に組み込んで、どれくら速くなるか測定を行った。
V100 8枚で、floodgateからサンプリングした100局面(1局面は詰みなので除外)で測定した。
https://github.com/TadaoYamaoka/DeepLearningShogi/blob/master/utils/benchmark.py

測定結果
AVXなし AVXあり
平均 229621 239431 1.04
中央値 235244 242866 1.03
最大 264526 392199 1.48
最小 148579 152466 1.03

平均で、4%程NPSが上昇した。
最大では、1.48倍の局面がある。
探索する局面の子ノードの数によって効果が変わってくる。

コード

探索部に組み込んだ処理はこちら。
DeepLearningShogi/UctSearch.cpp at 63d17b043c4ca0a37d15ea4e417a63cdffb29cc6 · TadaoYamaoka/DeepLearningShogi · GitHub
詰みのAND-OR木探索も含んでいるので、さらに美しくなくなっている。

まとめ

Ceresにインスパイアされてノード選択処理のAVX化を行った。
結果、平均で4%程高速化することができた。

AVX化を行ってもCeresのNPSには遠く及んでいない。
Ceresは、AVX化よりも衝突を回避した並列アルゴリズムの方がNPSへの寄与が大きいと思われる。
そのうち、Ceresの並列アルゴリズムをdlshogiにも導入したいと考えている。

2021/4/4 追記

AVX化の前後で強さを確認を測定した。
持ち時間3分、1手1秒加算

   # PLAYER          :  RATING  ERROR  POINTS  PLAYED   (%)  CFS(%)    W    D    L  D(%)
   1 dlshogi_avx2    :     0.2    9.0   500.5    1000    50      52  370  261  369    26
   2 dlshogi         :    -0.2    9.0   499.5    1000    50     ---  369  261  370    26

AVX化しても強くなっていない。