TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

【dlshogi】軽量価値ネットワーク その9(埋め込み層)

前回は、軽量価値ネットワークの連結と活性化関数の推論をSIMDで実装した。

今回は、埋め込み層の推論処理を実装する。

埋め込み層

軽量価値ネットワーク

    def __init__(self, dims=(16, 4, 32), activation=nn.ReLU()):
        self.l1_1 = nn.Embedding(NUM_EMBEDDINGS1 + 1, dims[0], padding_idx=NUM_EMBEDDINGS1)
        self.l1_2 = nn.EmbeddingBag(NUM_EMBEDDINGS2 + 1, dims[0], mode='sum', padding_idx=NUM_EMBEDDINGS2)
        # ...

    def forward(self, x1, x2):
        h1_1 = self.l1_1(x1).permute(0, 2, 1).view(-1, self.dims[0], 9, 9)
        h1_2 = self.l1_2(x2).view(-1, self.dims[0], 1, 1)
        h1 = h1_1 + h1_2
        # ...

の部分の処理を実装する。

以前に軽量価値ネットワークを学習したときは、permuteを忘れていたため、テンソルのメモリレイアウトが不適切だった。

修正して再学習したところ、価値の正解率が59.7%となり、permuteなしの時(61.3%)より下がってしまった。

permuteがないと盤面の空間構造を畳み込み層に正しく伝えることはできないが、逆に精度が上がるのは謎である。

Claude 3.7 Sonnet Extended Thinkingに原因を聞いてみたところ、

実際の盤面データとモデルの相互作用により、理論的には正しいと思われる変換よりも、実際のデータ構造に適した処理の方が良い結果をもたらすことは珍しくありません。これは「データサイエンスの実践は理論だけでは説明できない」側面を表しています。

とのことである。

一旦、正しい方で実装する。

実装

実装したコードは、以下の通り。

埋め込み層の推論は、計算を行わず、重みをインデックスで辞書引きするだけである。
持ち駒などのスカラ特徴量は、EmbeddingBagの結果を盤の各座標の埋め込みにブロードキャストする。

PyTorchでは、入力は固定長ベクトルで、パディングを行っているが、推論時は、入力ベクトルを作らずに直接重みを辞書引きすることで効率化する。

#include <immintrin.h>
#include <cstring>
#include <algorithm>
#include <cassert>

template <Color turn, int EmbeddingDim>
void embedding_layers(
    const Position& position,
    const float* __restrict embedding_table1,
    const float* __restrict embedding_table2,
    float* __restrict output   // 出力サイズ:EmbeddingDim * SquareNum
) {
    static_assert(EmbeddingDim % 8 == 0, "EmbeddingDim must be a multiple of 8");
    // 32バイト整列チェック
    assert(reinterpret_cast<uintptr_t>(embedding_table1) % 32 == 0);
    assert(reinterpret_cast<uintptr_t>(embedding_table2) % 32 == 0);
    assert(reinterpret_cast<uintptr_t>(output) % 32 == 0);

    // 出力バッファを0で初期化(全チャネル・全マス)
    const __m256 zero = _mm256_setzero_ps();
    for (int i = 0; i < EmbeddingDim * SquareNum; i += 8) {
        _mm256_store_ps(output + i, zero);
    }

    // 持ち駒・王手フラグ用の一時バッファ (アライメント済み)
    alignas(32) float hand_embed[EmbeddingDim];
    for (int i = 0; i < EmbeddingDim; i += 8) {
        _mm256_store_ps(hand_embed + i, zero);
    }

    // --- 盤面上の駒の埋め込み (l1_1の処理) ---
    Bitboard occupied_bb = position.occupiedBB();

    // FOREACH_BB マクロ内で各盤面上の駒について処理
    FOREACH_BB(occupied_bb, Square sq, {
        const Piece pc = position.piece(sq);
        const PieceType pt = pieceToPieceType(pc);
        Color c = pieceToColor(pc);

        // 後手の場合、色を反転し、盤面を180度回転
        if (turn == White) {
            c = oppositeColor(c);
            sq = SQ99 - sq;
        }

        // 駒配置に対応するインデックス (例: idx = PIECETYPE_NUM * (int)c + pt - 1)
        const int idx = PIECETYPE_NUM * (int)c + pt - 1;

        // embedding_table1 の該当行から埋め込みベクトルを取得し、出力バッファに「転置して」格納する。
        // 具体的には、各チャネル j について出力のインデックスは [j * SquareNum + sq] となる。
        const float* embed_vec = &embedding_table1[idx * EmbeddingDim];
        float* out_ptr = output + sq;
        if constexpr (EmbeddingDim == 16) {
            // EmbeddingDim==16 の場合、アンローリングして各チャネルを個別に格納
            out_ptr[0 * SquareNum] = embed_vec[0];
            out_ptr[1 * SquareNum] = embed_vec[1];
            out_ptr[2 * SquareNum] = embed_vec[2];
            out_ptr[3 * SquareNum] = embed_vec[3];
            out_ptr[4 * SquareNum] = embed_vec[4];
            out_ptr[5 * SquareNum] = embed_vec[5];
            out_ptr[6 * SquareNum] = embed_vec[6];
            out_ptr[7 * SquareNum] = embed_vec[7];
            out_ptr[8 * SquareNum] = embed_vec[8];
            out_ptr[9 * SquareNum] = embed_vec[9];
            out_ptr[10 * SquareNum] = embed_vec[10];
            out_ptr[11 * SquareNum] = embed_vec[11];
            out_ptr[12 * SquareNum] = embed_vec[12];
            out_ptr[13 * SquareNum] = embed_vec[13];
            out_ptr[14 * SquareNum] = embed_vec[14];
            out_ptr[15 * SquareNum] = embed_vec[15];
        }
        else {
            for (int j = 0; j < EmbeddingDim; j++) {
                out_ptr[j * SquareNum] = embed_vec[j];
            }
        }
    });

    // --- 持ち駒の埋め込み (l1_2の処理) ---
    // 各色ごとに、保持している駒の埋め込みベクトルを hand_embed に加算する。
    for (Color c = Black; c < ColorNum; ++c) {
        // 後手の場合、色反転
        const Color c2 = (turn == Black) ? c : oppositeColor(c);

        // 持ち駒情報
        const Hand hand = position.hand(c);
        int p = 0;
        for (HandPiece hp = HPawn; hp < HandPieceNum; ++hp) {
            const u32 num = std::min(hand.numOf(hp), MAX_PIECES_IN_HAND[hp]);
            const int base_idx = MAX_PIECES_IN_HAND_SUM * static_cast<int>(c2) + p;

            for (u32 i = 0; i < num; ++i) {
                const int idx = base_idx + i;
                const float* embed_vec = &embedding_table2[idx * EmbeddingDim];

                if constexpr (EmbeddingDim == 16) {
                    // AVX2で8要素ずつ加算
                    __m256 curr1 = _mm256_load_ps(hand_embed);
                    __m256 curr2 = _mm256_load_ps(hand_embed + 8);
                    __m256 add1 = _mm256_load_ps(embed_vec);
                    __m256 add2 = _mm256_load_ps(embed_vec + 8);

                    curr1 = _mm256_add_ps(curr1, add1);
                    curr2 = _mm256_add_ps(curr2, add2);

                    _mm256_store_ps(hand_embed, curr1);
                    _mm256_store_ps(hand_embed + 8, curr2);
                }
                else {
                    for (int j = 0; j < EmbeddingDim; j += 8) {
                        __m256 curr = _mm256_load_ps(hand_embed + j);
                        __m256 add = _mm256_load_ps(embed_vec + j);
                        _mm256_store_ps(hand_embed + j, _mm256_add_ps(curr, add));
                    }
                }
            }
            p += MAX_PIECES_IN_HAND[hp];
        }
    }

    // 王手フラグの埋め込みを加算
    if (position.inCheck()) {
        const int idx = MAX_FEATURES2_HAND_NUM;
        const float* check_embed_vec = &embedding_table2[idx * EmbeddingDim];

        if constexpr (EmbeddingDim == 16) {
            __m256 curr1 = _mm256_load_ps(hand_embed);
            __m256 curr2 = _mm256_load_ps(hand_embed + 8);
            __m256 add1 = _mm256_load_ps(check_embed_vec);
            __m256 add2 = _mm256_load_ps(check_embed_vec + 8);
            _mm256_store_ps(hand_embed, _mm256_add_ps(curr1, add1));
            _mm256_store_ps(hand_embed + 8, _mm256_add_ps(curr2, add2));
        }
        else {
            for (int j = 0; j < EmbeddingDim; j += 8) {
                __m256 curr = _mm256_load_ps(hand_embed + j);
                __m256 add = _mm256_load_ps(check_embed_vec + j);
                _mm256_store_ps(hand_embed + j, _mm256_add_ps(curr, add));
            }
        }
    }

    // --- 持ち駒・王手フラグの埋め込みを盤面全体にブロードキャストして加算 ---
    // 出力バッファはチャネル優先のレイアウトになっているため、各チャネルごとにhand_embedの値を各マスに加算する。
    for (int c = 0; c < EmbeddingDim; c++) {
        float* out_channel = &output[c * SquareNum];
        __m256 hand_vec = _mm256_set1_ps(hand_embed[c]);

        // 32要素ずつ処理(4x8=32)
        for (int s = 0; s < 64; s += 32) {
            __m256 data1 = _mm256_loadu_ps(out_channel + s);
            __m256 data2 = _mm256_loadu_ps(out_channel + s + 8);
            __m256 data3 = _mm256_loadu_ps(out_channel + s + 16);
            __m256 data4 = _mm256_loadu_ps(out_channel + s + 24);

            data1 = _mm256_add_ps(data1, hand_vec);
            data2 = _mm256_add_ps(data2, hand_vec);
            data3 = _mm256_add_ps(data3, hand_vec);
            data4 = _mm256_add_ps(data4, hand_vec);

            _mm256_storeu_ps(out_channel + s, data1);
            _mm256_storeu_ps(out_channel + s + 8, data2);
            _mm256_storeu_ps(out_channel + s + 16, data3);
            _mm256_storeu_ps(out_channel + s + 24, data4);
        }

        // 残り17要素(64+16=80、残り1要素)
        __m256 data5 = _mm256_loadu_ps(out_channel + 64);
        __m256 data6 = _mm256_loadu_ps(out_channel + 72);
        data5 = _mm256_add_ps(data5, hand_vec);
        data6 = _mm256_add_ps(data6, hand_vec);
        _mm256_storeu_ps(out_channel + 64, data5);
        _mm256_storeu_ps(out_channel + 72, data6);

        // 最後の1要素
        out_channel[80] += hand_embed[c];
    }
}

template <int EmbeddingDim>
void embedding_layers(
    const Position& position,
    const float* __restrict embedding_table1,
    const float* __restrict embedding_table2,
    float* __restrict output
) {
    if (position.turn() == Black) {
        embedding_layers<Black, EmbeddingDim>(position, embedding_table1, embedding_table2, output);
    }
    else {
        embedding_layers<White, EmbeddingDim>(position, embedding_table1, embedding_table2, output);
    }
}

void embedding_layers(
    const Position& position,
    const float* __restrict embedding_table1,
    const float* __restrict embedding_table2,
    float* __restrict output
) {
    embedding_layers<EMBEDDING_DIM>(position, embedding_table1, embedding_table2, output);
}

以下の点が高速化のために工夫されている。

  • AVX2命令によるSIMD並列化

8要素ずつ同時に処理するため、浮動小数点演算をAVX2の __m256 命令で実施している。

  • 32バイト整列

埋め込みテーブルや出力バッファのアドレスが32バイト境界に整列されているかをassertでチェックし、最適なメモリアクセスを保証している。

EmbeddingDimが16の場合、ループをアンローリングして各チャネルへの代入を直接記述することで、ループオーバーヘッドを削減している。

  • 持ち駒埋め込みの集約と効率的なブロードキャスト

持ち駒の各埋め込みベクトルをAVX2で加算し、さらにその合算結果を各盤面マスに効率的にブロードキャストして加算している。

結果の検証

入力にPositionクラスを使うため、前回までのようにPythonでの検証は行わない。
推論時間のみ測定する。

検証用コードは以下の通り。
gtestで実装している。

TEST(StockfishTest, embedding_layers) {
    using namespace Stockfish;

    initTable();

    constexpr int NUM_ITERATIONS = 10000;

    constexpr int NUM_EMBEDDINGS1 = PieceTypeNum * 2;
    constexpr int NUM_EMBEDDINGS2 = MAX_PIECES_IN_HAND_SUM * 2 + 1;

    // 32バイトアライメントでメモリ確保
    float* embedding_table1 = static_cast<float*>(_mm_malloc(NUM_EMBEDDINGS1 * Eval::EMBEDDING_DIM * sizeof(float), 32));
    float* embedding_table2 = static_cast<float*>(_mm_malloc(NUM_EMBEDDINGS2 * Eval::EMBEDDING_DIM * sizeof(float), 32));
    float* output = static_cast<float*>(_mm_malloc(Eval::EMBEDDING_DIM * (int)SquareNum * sizeof(float), 32));

    if (!embedding_table1 || !embedding_table2 || !output) {
        std::cerr << "Error: Failed to allocate aligned memory!" << std::endl;
        return;
    }

    // 乱数生成器の初期化
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> dist(-1.0f, 1.0f);

    // 埋め込みテーブルを乱数で初期化
    for (int i = 0; i < NUM_EMBEDDINGS1 * Eval::EMBEDDING_DIM; ++i) {
        embedding_table1[i] = dist(gen);
    }
    for (int i = 0; i < NUM_EMBEDDINGS2 * Eval::EMBEDDING_DIM; ++i) {
        embedding_table2[i] = dist(gen);
    }

    Stockfish::Position pos;
    pos.set("+P4g1nl/4g1k2/p3pp1p1/2p2Lp1p/2n4P1/3p1N2P/P1Ps+rPP2/LpK3BS1/R4G2L b B2SNPgp 103");


    // ベンチマーク実行
    std::cout << "\nRunning benchmark..." << std::endl;
    auto start = std::chrono::high_resolution_clock::now();

    for (int iter = 0; iter < NUM_ITERATIONS; ++iter) {
        Eval::embedding_layers(pos, embedding_table1, embedding_table2, output);
    }

    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double, std::milli> elapsed = end - start;

    // 結果出力
    std::cout << "\nBenchmark results:" << std::endl;
    std::cout << "Total time: " << std::fixed << std::setprecision(2) << elapsed.count() << " ms" << std::endl;
    std::cout << "Average time per position: " << std::fixed << std::setprecision(4)
        << elapsed.count() / NUM_ITERATIONS << " ms" << std::endl;
    std::cout << "Positions per second: " << std::fixed << std::setprecision(0)
        << (NUM_ITERATIONS * 1000) / elapsed.count() << std::endl;

    // メモリ解放
    _mm_free(embedding_table1);
    _mm_free(embedding_table2);
    _mm_free(output);
}
実行結果
Total time: 3.87 ms
Average time per position: 0.0004 ms
Positions per second: 2586586

十分に高速に推論できている。

まとめ

埋め込み層の推論処理をSIMDを使用して実装した。
固定長の入力ベクトルを作らず、特徴量作成と同時に重みを辞書引きすることで効率的に実装できた。
次は、PyTorchの学習済みパラメータをレイヤー融合してC++側で扱えるように出力する処理を実装したい。