TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

ChainerのモデルをC++で読み込む

以前の日記でAlphaGoのSL policy networkをChainerで学習した結果をC++囲碁プログラムで使用したいと考えている。
その際、C++からPythonを実行するのでは、オーバーヘッドが大きく、実行環境のハードルが上がりポータビリティが下がってしまう。
そこで、C++ではDCNNをcuDNNを使って実装して、Chainerで学習したモデルを読み込んで使用したい。

ChainerのモデルをPythonプログラムで読み込んで任意の形式で出力してからC++で読み込んでもよいが、Chainerのモデルを直接C++で読み込めるようにすれば変換処理を実行する手間が省ける。

Chainerのserializers.save_npzで保存したモデルは、zipファイル形式になっており、zipファイルの中身は各レイヤーのパラメータの種類ごとに「レイヤー名/パラメータ名.npy」というファイルになっている。
.npyファイルは、NumPyの保存形式で、簡単なヘッダーとfloat型の配列がバイナリで保存されている。

このようにシンプル構成のため、C++から読み込むことも容易にできる。

zipファイルの読み込み

zipファイルの読み込みは、ファイルフォーマットの仕様を見て実装すればよい。こちらのページに日本語のわかりやすい解説もある。

全ての形式のzipファイルフォーマットに対応するのは大変だが、Chainerで保存される圧縮形式は固定なので、1パターンのみ実装すればよく、Local file headerとFile dataのみ使用して実装可能である。

以前の日記のWaveファイルの読み込みと同じテクニックで、Local file headerをアライメント1の構造体で定義して、ファイルから読み込んだバイナリを構造体として扱う。

Local file headerに続くFile dataをサイズの情報がLocal file headerにあるので、そのサイズ分読み込む。
以上を繰り返す。終了条件は、Local file headerのlocal_file_header_signatureで判定できる。

コード例
#pragma pack(push, 1)
struct LocalFileHeader
{
	unsigned long local_file_header_signature; // 4_bytes (0x04034b50)
	unsigned short version_needed_to_extract; // 2_bytes
	unsigned short general_purpose_bit_flag; // 2_bytes
	unsigned short compression_method; // 2_bytes
	unsigned short last_mod_file_time; // 2_bytes
	unsigned short last_mod_file_date; // 2_bytes
	unsigned long crc_32; // 4_bytes
	unsigned long compressed_size; // 4_bytes
	unsigned long uncompressed_size; // 4_bytes
	unsigned short file_name_length; // 2_bytes
	unsigned short extra_field_length; // 2_bytes
	// ここまで30bytes

	char* file_name; // (variable_size)
	char* extra_field; // (variable_size)

	LocalFileHeader() : file_name(nullptr), extra_field(nullptr) {}
	~LocalFileHeader() {
		delete[] file_name;
		delete[] extra_field;
	}

};
#pragma pack(pop)

int main()
{
	ifstream infile("sl_policy.model", ios_base::in | ios_base::binary);

	while (true)
	{
		// Local file header
		LocalFileHeader lfh;
		infile.read((char*)&lfh, 30);

		if (lfh.local_file_header_signature != 0x04034b50)
		{
			break;
		}

		lfh.file_name = new char[lfh.file_name_length + 1];

		infile.read(lfh.file_name, lfh.file_name_length);
		lfh.file_name[lfh.file_name_length] = '\0';

		infile.seekg(lfh.extra_field_length, ios_base::cur);

		// File data
		unsigned char* file_data = new unsigned char[lfh.compressed_size];
		infile.read((char*)file_data, lfh.compressed_size);
	}

	return 0;
}

圧縮データの伸縮

File dataは、圧縮されているので、zlibを使用して伸縮する。
zlibは、Visual StudioのNuGetから簡単にインストールできる。

zlibの使用方法は簡単で、inflateInit2でz_streamを初期化し、z_streamにデータのポインタを設定した後、inflateで伸縮し、inflateEndで終了処理を行う。
細かいパラメータはMinizipのコードを参考にした。

コード例
		z_stream strm = { 0 };
		inflateInit2(&strm, -MAX_WBITS);
		
		strm.next_in = file_data;
		strm.avail_in = lfh.compressed_size;
		strm.next_out = uncompressed_data;
		strm.avail_out = lfh.uncompressed_size;
		inflate(&strm, Z_NO_FLUSH);
		inflateEnd(&strm);

.npyファイルの読み込み

.npyファイルのフォーマットは、こちらのページに仕様がある。
先頭がヘッダーになっており、その後にデータの配列が格納されている。配列のshapeはヘッダーに書かれている。

.npyファイルもzipファイルと同様にアライメント1の構造体でヘッダーを定義して読み込む。
zipファイルの読み込みと合わせて、ヘッダーの内容とデータの先頭10個を表示する例を以下に示す。
サンプルのため配列はshapeを無視して1次元配列として扱っている。

コード例
#include <zlib.h>
#include <fstream>

using namespace std;

#pragma pack(push, 1)
struct LocalFileHeader
{
	unsigned long local_file_header_signature; // 4_bytes (0x04034b50)
	unsigned short version_needed_to_extract; // 2_bytes
	unsigned short general_purpose_bit_flag; // 2_bytes
	unsigned short compression_method; // 2_bytes
	unsigned short last_mod_file_time; // 2_bytes
	unsigned short last_mod_file_date; // 2_bytes
	unsigned long crc_32; // 4_bytes
	unsigned long compressed_size; // 4_bytes
	unsigned long uncompressed_size; // 4_bytes
	unsigned short file_name_length; // 2_bytes
	unsigned short extra_field_length; // 2_bytes
	// ここまで30bytes

	char* file_name; // (variable_size)
	char* extra_field; // (variable_size)

	LocalFileHeader() : file_name(nullptr), extra_field(nullptr) {}
	~LocalFileHeader() {
		delete[] file_name;
		delete[] extra_field;
	}

};

struct NPY
{
	char magic_string[6]; // 6 bytes (0x93NUMPY)
	unsigned char major_version; // 1 byte
	unsigned char minor_version; // 1 byte
	unsigned short header_len; // 2 bytes
	// ここまで10bytes
};
#pragma pack(pop)

int main()
{
	ifstream infile("sl_policy.model", ios_base::in | ios_base::binary);

	while (true)
	{
		// Local file header
		LocalFileHeader lfh;
		infile.read((char*)&lfh, 30);

		if (lfh.local_file_header_signature != 0x04034b50)
		{
			break;
		}

		lfh.file_name = new char[lfh.file_name_length + 1];

		infile.read(lfh.file_name, lfh.file_name_length);
		lfh.file_name[lfh.file_name_length] = '\0';

		infile.seekg(lfh.extra_field_length, ios_base::cur);

		// File data
		unsigned char* file_data = new unsigned char[lfh.compressed_size];
		infile.read((char*)file_data, lfh.compressed_size);

		unsigned char* uncompressed_data = new unsigned char[lfh.uncompressed_size];

		z_stream strm = { 0 };
		inflateInit2(&strm, -MAX_WBITS);
		
		strm.next_in = file_data;
		strm.avail_in = lfh.compressed_size;
		strm.next_out = uncompressed_data;
		strm.avail_out = lfh.uncompressed_size;
		inflate(&strm, Z_NO_FLUSH);
		inflateEnd(&strm);

		// NPY
		NPY* npy = (NPY*)uncompressed_data;

		char array_format_string[256];
		memcpy(array_format_string, uncompressed_data + 10, npy->header_len);
		array_format_string[npy->header_len] = '\0';
		printf("%s", array_format_string);

		float* w = (float*)(uncompressed_data + 10 + npy->header_len);
		for (int i = 0; i < 10; i++)
		{
			printf("%f ", w[i]);
		}
		printf("\n");
	}

	return 0;
}
実行結果
layer10/W.npy
{'descr': '<f4', 'fortran_order': False, 'shape': (192, 192, 3, 3), }
-0.016173 0.015109 -0.003892 0.008751 -0.021007 -0.006746 -0.009573 -0.006381 0.009846 -0.008884
layer11/b.npy
{'descr': '<f4', 'fortran_order': False, 'shape': (192,), }
0.002800 0.002900 -0.002989 0.004142 0.001429 -0.000620 0.005034 0.006323 -0.000228 0.005045
layer13_2/b.npy
{'descr': '<f4', 'fortran_order': False, 'shape': (361,), }
-0.007827 -0.000078 -0.001707 0.001981 -0.003589 -0.004128 0.002796 0.001678 0.003007 0.002428
layer6/W.npy
{'descr': '<f4', 'fortran_order': False, 'shape': (192, 192, 3, 3), }
0.017788 -0.003138 0.022310 -0.020758 -0.021130 0.013108 -0.009873 0.023039 0.008826 0.018210
(略)