TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

【dlshogi】軽量価値ネットワーク その10(レイヤー融合とパラメータ保存)

前回までで、層ごとの推論処理が実装できたので、今回は、PyTorchで学習したモデルをレイヤー融合して、推論用にパラメータを保存する処理を実装する。

レイヤー融合

畳み込み層の直後のBatchNorm2dは、畳み込み層のパラメータに融合できる。

PyTorchでは、fuse_modulesを使用して簡単に実装できる。

fuse_modules — PyTorch 2.6 documentation

    fuse_modules(model, [['l2_1', 'bn2_1'], ['l2_2', 'bn2_2']], inplace=True)

パラメータの保存

パラメータのファイルを扱いやすくするため、パラメータを1ファイルにまとめて出力する。
C++で、読み込みしやすいようにサイズなども合わせて出力する。
検証用にパラメータ名、テンソルの形状も合わせて出力する。

    state_dict = model.state_dict()
    
    with open(args.bin, "wb") as f:
        # 全パラメータ数をunsigned intとして書き出し
        num_params = len(state_dict)
        f.write(struct.pack("I", num_params))
        
        for key, tensor in state_dict.items():
            print(key)

            # CPU に移動、float32 に変換
            tensor = tensor.detach().cpu().float()
            
            # パラメータ名を UTF-8 でエンコードし、バイト長とともに書き出す
            key_bytes = key.encode('utf-8')
            f.write(struct.pack("I", len(key_bytes)))
            f.write(key_bytes)
            
            # テンソルの形状情報を書き出す
            shape = tensor.shape
            f.write(struct.pack("I", len(shape)))  # 次元数
            for dim in shape:
                f.write(struct.pack("I", dim))
                
            # 要素数を書き出す
            numel = tensor.numel()
            f.write(struct.pack("I", numel))
            
            # テンソルデータを書き出す(float32 の連続バイト列)
            f.write(tensor.numpy().tobytes())

C++での読み込み処理

アロケータ

AVX2で推論するためパラメータを32バイト整列する必要がある。
Stockfishのmemory.hとmemory.cppで定義されているstd_aligned_allocとstd_aligned_freeを使用して32バイト整列して読み込む。

実装

以下の通り読み込み処理を実装した。

// パラメータ情報を保持する構造体
struct Parameter {
    std::string name;
    std::vector<uint32_t> shape;  // 各次元のサイズ
    std::unique_ptr<float[], decltype(&std_aligned_free)> data;

    // コンストラクタ
    Parameter(const std::string& name, const std::vector<uint32_t>& shape, float* data)
        : name(name), shape(shape), data(data, std_aligned_free) { }
};

// パラメータ
std::vector<Parameter> parameters;

bool read_parameters(const std::string& filename) {
    std::ifstream fin(filename, std::ios::binary);
    if (!fin) {
        std::cerr << "Error opening file: " << filename << std::endl;
        return false;
    }

    // 全パラメータ数 (unsigned int) を読み込む
    uint32_t num_params = 0;
    fin.read(reinterpret_cast<char*>(&num_params), sizeof(num_params));
    std::cout << "Number of parameters: " << num_params << std::endl;

    parameters.clear();
    parameters.reserve(num_params);

    // 各パラメータについて読み込みを行う
    for (uint32_t i = 0; i < num_params; ++i) {
        // (a) パラメータ名のバイト長を読み込む
        uint32_t key_length = 0;
        fin.read(reinterpret_cast<char*>(&key_length), sizeof(key_length));
        if (!fin) {
            std::cerr << "Error reading data" << std::endl;
            return false;
        }

        // (b) パラメータ名(UTF-8)の文字列を読み込む
        std::string name(key_length, ' ');
        fin.read(&name[0], key_length);
        if (!fin) {
            std::cerr << "Error reading data" << std::endl;
            return false;
        }

        // (c) テンソルの次元数を読み込む
        uint32_t ndim = 0;
        fin.read(reinterpret_cast<char*>(&ndim), sizeof(ndim));
        std::vector<uint32_t> shape(ndim);
        for (uint32_t d = 0; d < ndim; ++d) {
            fin.read(reinterpret_cast<char*>(&shape[d]), sizeof(uint32_t));
            if (!fin) {
                std::cerr << "Error reading data" << std::endl;
                return false;
            }
        }

        // (d) テンソルの要素数 (numel) を読み込む
        uint32_t numel = 0;
        fin.read(reinterpret_cast<char*>(&numel), sizeof(numel));
        if (!fin) {
            std::cerr << "Error reading data" << std::endl;
            return false;
        }

#ifndef NDEBUG
        {
            uint32_t prod = 1;
            for (uint32_t d : shape) {
                prod *= d;
            }
            assert(prod == numel && "Shape product does not match numel");
        }
#endif

        // (e) テンソルデータ(float32 の連続バイト列)を読み込む
        float* data = reinterpret_cast<float*>(std_aligned_alloc(32, numel * sizeof(float)));
        if (!data) {
            std::cerr << "Error: Memory allocation failed for parameter " << name << std::endl;
            return false;
        }
        fin.read(reinterpret_cast<char*>(data), numel * sizeof(float));
        if (!fin) {
            std::cerr << "Error reading data" << std::endl;
            return false;
        }

        // 32バイト整列を確認する assert
        assert(reinterpret_cast<std::uintptr_t>(data) % 32 == 0 && "Parameter data is not 32-byte aligned");

        // 読み込んだ内容を Parameter 構造体に格納して登録
        parameters.emplace_back(name, shape, data);

        // ログ出力:パラメータ名、shape、要素数
        std::cout << "Loaded parameter: " << name << " | shape: (";
        for (size_t j = 0; j < shape.size(); ++j) {
            std::cout << shape[j] << (j + 1 < shape.size() ? ", " : "");
        }
        std::cout << ") | numel: " << numel << std::endl;
    }

    fin.close();

    return true;
}
テスト

gtestで、テストコードを実装した。

TEST(StockfishTest, read_parameters) {
    using namespace Stockfish;

    Eval::read_parameters(R"(dlshogi\model-008.bin)");
}
実行結果
Number of parameters: 10
Loaded parameter: l1_1.weight | shape: (29, 16) | numel: 464
Loaded parameter: l1_2.weight | shape: (58, 16) | numel: 928
Loaded parameter: l2_1.weight | shape: (4, 16, 9, 1) | numel: 576
Loaded parameter: l2_1.bias | shape: (4) | numel: 4
Loaded parameter: l2_2.weight | shape: (4, 16, 1, 9) | numel: 576
Loaded parameter: l2_2.bias | shape: (4) | numel: 4
Loaded parameter: l3.weight | shape: (32, 72) | numel: 2304
Loaded parameter: l3.bias | shape: (32) | numel: 32
Loaded parameter: l4.weight | shape: (1, 32) | numel: 32
Loaded parameter: l4.bias | shape: (1) | numel: 1

正しく読み込めている。

まとめ

PyTorchで学習したモデルをレイヤー融合してパラメータを保存し、C++で32バイト整列して読み込む処理を実装した。
次は、モデル全体を推論する処理を実装したい。