前回までで、層ごとの推論処理が実装できたので、今回は、PyTorchで学習したモデルをレイヤー融合して、推論用にパラメータを保存する処理を実装する。
レイヤー融合
畳み込み層の直後のBatchNorm2dは、畳み込み層のパラメータに融合できる。
PyTorchでは、fuse_modulesを使用して簡単に実装できる。
fuse_modules — PyTorch 2.6 documentation
fuse_modules(model, [['l2_1', 'bn2_1'], ['l2_2', 'bn2_2']], inplace=True)
パラメータの保存
パラメータのファイルを扱いやすくするため、パラメータを1ファイルにまとめて出力する。
C++で、読み込みしやすいようにサイズなども合わせて出力する。
検証用にパラメータ名、テンソルの形状も合わせて出力する。
state_dict = model.state_dict()
with open(args.bin, "wb") as f:
# 全パラメータ数をunsigned intとして書き出し
num_params = len(state_dict)
f.write(struct.pack("I", num_params))
for key, tensor in state_dict.items():
print(key)
# CPU に移動、float32 に変換
tensor = tensor.detach().cpu().float()
# パラメータ名を UTF-8 でエンコードし、バイト長とともに書き出す
key_bytes = key.encode('utf-8')
f.write(struct.pack("I", len(key_bytes)))
f.write(key_bytes)
# テンソルの形状情報を書き出す
shape = tensor.shape
f.write(struct.pack("I", len(shape))) # 次元数
for dim in shape:
f.write(struct.pack("I", dim))
# 要素数を書き出す
numel = tensor.numel()
f.write(struct.pack("I", numel))
# テンソルデータを書き出す(float32 の連続バイト列)
f.write(tensor.numpy().tobytes())
C++での読み込み処理
アロケータ
AVX2で推論するためパラメータを32バイト整列する必要がある。
Stockfishのmemory.hとmemory.cppで定義されているstd_aligned_allocとstd_aligned_freeを使用して32バイト整列して読み込む。
実装
以下の通り読み込み処理を実装した。
// パラメータ情報を保持する構造体 struct Parameter { std::string name; std::vector<uint32_t> shape; // 各次元のサイズ std::unique_ptr<float[], decltype(&std_aligned_free)> data; // コンストラクタ Parameter(const std::string& name, const std::vector<uint32_t>& shape, float* data) : name(name), shape(shape), data(data, std_aligned_free) { } }; // パラメータ std::vector<Parameter> parameters; bool read_parameters(const std::string& filename) { std::ifstream fin(filename, std::ios::binary); if (!fin) { std::cerr << "Error opening file: " << filename << std::endl; return false; } // 全パラメータ数 (unsigned int) を読み込む uint32_t num_params = 0; fin.read(reinterpret_cast<char*>(&num_params), sizeof(num_params)); std::cout << "Number of parameters: " << num_params << std::endl; parameters.clear(); parameters.reserve(num_params); // 各パラメータについて読み込みを行う for (uint32_t i = 0; i < num_params; ++i) { // (a) パラメータ名のバイト長を読み込む uint32_t key_length = 0; fin.read(reinterpret_cast<char*>(&key_length), sizeof(key_length)); if (!fin) { std::cerr << "Error reading data" << std::endl; return false; } // (b) パラメータ名(UTF-8)の文字列を読み込む std::string name(key_length, ' '); fin.read(&name[0], key_length); if (!fin) { std::cerr << "Error reading data" << std::endl; return false; } // (c) テンソルの次元数を読み込む uint32_t ndim = 0; fin.read(reinterpret_cast<char*>(&ndim), sizeof(ndim)); std::vector<uint32_t> shape(ndim); for (uint32_t d = 0; d < ndim; ++d) { fin.read(reinterpret_cast<char*>(&shape[d]), sizeof(uint32_t)); if (!fin) { std::cerr << "Error reading data" << std::endl; return false; } } // (d) テンソルの要素数 (numel) を読み込む uint32_t numel = 0; fin.read(reinterpret_cast<char*>(&numel), sizeof(numel)); if (!fin) { std::cerr << "Error reading data" << std::endl; return false; } #ifndef NDEBUG { uint32_t prod = 1; for (uint32_t d : shape) { prod *= d; } assert(prod == numel && "Shape product does not match numel"); } #endif // (e) テンソルデータ(float32 の連続バイト列)を読み込む float* data = reinterpret_cast<float*>(std_aligned_alloc(32, numel * sizeof(float))); if (!data) { std::cerr << "Error: Memory allocation failed for parameter " << name << std::endl; return false; } fin.read(reinterpret_cast<char*>(data), numel * sizeof(float)); if (!fin) { std::cerr << "Error reading data" << std::endl; return false; } // 32バイト整列を確認する assert assert(reinterpret_cast<std::uintptr_t>(data) % 32 == 0 && "Parameter data is not 32-byte aligned"); // 読み込んだ内容を Parameter 構造体に格納して登録 parameters.emplace_back(name, shape, data); // ログ出力:パラメータ名、shape、要素数 std::cout << "Loaded parameter: " << name << " | shape: ("; for (size_t j = 0; j < shape.size(); ++j) { std::cout << shape[j] << (j + 1 < shape.size() ? ", " : ""); } std::cout << ") | numel: " << numel << std::endl; } fin.close(); return true; }
テスト
gtestで、テストコードを実装した。
TEST(StockfishTest, read_parameters) { using namespace Stockfish; Eval::read_parameters(R"(dlshogi\model-008.bin)"); }
実行結果
Number of parameters: 10 Loaded parameter: l1_1.weight | shape: (29, 16) | numel: 464 Loaded parameter: l1_2.weight | shape: (58, 16) | numel: 928 Loaded parameter: l2_1.weight | shape: (4, 16, 9, 1) | numel: 576 Loaded parameter: l2_1.bias | shape: (4) | numel: 4 Loaded parameter: l2_2.weight | shape: (4, 16, 1, 9) | numel: 576 Loaded parameter: l2_2.bias | shape: (4) | numel: 4 Loaded parameter: l3.weight | shape: (32, 72) | numel: 2304 Loaded parameter: l3.bias | shape: (32) | numel: 32 Loaded parameter: l4.weight | shape: (1, 32) | numel: 32 Loaded parameter: l4.bias | shape: (1) | numel: 1
正しく読み込めている。
まとめ
PyTorchで学習したモデルをレイヤー融合してパラメータを保存し、C++で32バイト整列して読み込む処理を実装した。
次は、モデル全体を推論する処理を実装したい。