TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

【dlshogi】軽量価値ネットワーク その8(連結と活性化関数)

前回、軽量価値ネットワークの畳み込みの層をSIMDで実装した。

今回は、9x1と1x9の畳み込みの層の出力を連結して活性化関数ReLUを適用する処理をSIMDで実装する。

連結と活性化関数

軽量価値ネットワーク

h2 = self.act(torch.cat((h2_1.view(-1, self.dims[1] * 9), h2_2.view(-1, self.dims[1] * 9)), 1))

の部分の処理を実装する。

連結処理は、畳み込みの層の出力をはじめから一つの配列にポインタをずらして書き込むことで実現できるため、処理は省ける。
活性化関数をSIMDでベクトル化する。

実装

実装したコードは、以下の通り。

連結は畳み込みの層の処理に含まれるため、畳み込みの層の処理も再掲する。

#include <immintrin.h>
#include <cstddef>

// 水平方向の合計計算(AVX2)
inline float horizontal_sum_avx2(__m256 vec) {
    __m256 t1 = _mm256_hadd_ps(vec, vec);
    __m256 t2 = _mm256_hadd_ps(t1, t1);
    __m128 t3 = _mm256_extractf128_ps(t2, 1);
    __m128 t4 = _mm_add_ss(_mm256_castps256_ps128(t2), t3);
    return _mm_cvtss_f32(t4);
}

// 9x1の畳み込み処理
template <int InChannels, int OutChannels>
void conv9x1_avx2(const float* __restrict input,
    const float* __restrict weights,
    float* __restrict output) {
    constexpr int KERNEL_SIZE = 9;
    constexpr int inner_size = InChannels * KERNEL_SIZE;
    constexpr int padded_inner_size = ((inner_size + 7) / 8) * 8;

    for (int w = 0; w < 9; ++w) {
        alignas(32) float patch[inner_size];
        for (int ic = 0; ic < InChannels; ++ic) {
            const float* in_channel = input + ic * 81;
            int base = ic * KERNEL_SIZE;
            // 手動アンロールして9要素を展開
            patch[base + 0] = in_channel[0 * 9 + w];
            patch[base + 1] = in_channel[1 * 9 + w];
            patch[base + 2] = in_channel[2 * 9 + w];
            patch[base + 3] = in_channel[3 * 9 + w];
            patch[base + 4] = in_channel[4 * 9 + w];
            patch[base + 5] = in_channel[5 * 9 + w];
            patch[base + 6] = in_channel[6 * 9 + w];
            patch[base + 7] = in_channel[7 * 9 + w];
            patch[base + 8] = in_channel[8 * 9 + w];
        }

        int oc = 0;
        constexpr int unroll_factor = 4;
        int main_loop_count = (OutChannels / unroll_factor) * unroll_factor;

        for (; oc < main_loop_count; oc += unroll_factor) {
            __m256 sum0 = _mm256_setzero_ps();
            __m256 sum1 = _mm256_setzero_ps();
            __m256 sum2 = _mm256_setzero_ps();
            __m256 sum3 = _mm256_setzero_ps();
            const float* w_ptr0 = weights + (oc + 0) * padded_inner_size;
            const float* w_ptr1 = weights + (oc + 1) * padded_inner_size;
            const float* w_ptr2 = weights + (oc + 2) * padded_inner_size;
            const float* w_ptr3 = weights + (oc + 3) * padded_inner_size;
            int i = 0;
            for (; i <= inner_size - 8; i += 8) {
                __m256 patch_vec = _mm256_load_ps(patch + i);
                __m256 w0 = _mm256_load_ps(w_ptr0 + i);
                __m256 w1 = _mm256_load_ps(w_ptr1 + i);
                __m256 w2 = _mm256_load_ps(w_ptr2 + i);
                __m256 w3 = _mm256_load_ps(w_ptr3 + i);
                sum0 = _mm256_fmadd_ps(patch_vec, w0, sum0);
                sum1 = _mm256_fmadd_ps(patch_vec, w1, sum1);
                sum2 = _mm256_fmadd_ps(patch_vec, w2, sum2);
                sum3 = _mm256_fmadd_ps(patch_vec, w3, sum3);
            }
            float dot0 = horizontal_sum_avx2(sum0);
            float dot1 = horizontal_sum_avx2(sum1);
            float dot2 = horizontal_sum_avx2(sum2);
            float dot3 = horizontal_sum_avx2(sum3);
            for (; i < inner_size; ++i) {
                dot0 += patch[i] * w_ptr0[i];
                dot1 += patch[i] * w_ptr1[i];
                dot2 += patch[i] * w_ptr2[i];
                dot3 += patch[i] * w_ptr3[i];
            }
            // 結果を書き込み
            output[(oc + 0) * 9 + w] = dot0;
            output[(oc + 1) * 9 + w] = dot1;
            output[(oc + 2) * 9 + w] = dot2;
            output[(oc + 3) * 9 + w] = dot3;
        }

        for (; oc < OutChannels; ++oc) {
            __m256 sum = _mm256_setzero_ps();
            const float* w_ptr = weights + oc * padded_inner_size;
            int i = 0;
            for (; i <= inner_size - 8; i += 8) {
                __m256 patch_vec = _mm256_load_ps(patch + i);
                __m256 w_vec = _mm256_load_ps(w_ptr + i);
                sum = _mm256_fmadd_ps(patch_vec, w_vec, sum);
            }
            float dot = horizontal_sum_avx2(sum);
            for (; i < inner_size; ++i) {
                dot += patch[i] * w_ptr[i];
            }
            output[oc * 9 + w] = dot;
        }
    }
}

// 1x9の畳み込み処理
template <int InChannels, int OutChannels>
void conv1x9_avx2(const float* __restrict input,
    const float* __restrict weights,
    float* __restrict output) {
    constexpr int KERNEL_SIZE = 9;
    constexpr int inner_size = InChannels * KERNEL_SIZE;
    constexpr int padded_inner_size = ((inner_size + 7) / 8) * 8;

    for (int h = 0; h < 9; ++h) {
        alignas(32) float patch[inner_size];
        for (int ic = 0; ic < InChannels; ++ic) {
            const float* in_channel = input + ic * 81;
            int base = ic * KERNEL_SIZE;
            patch[base + 0] = in_channel[h * 9 + 0];
            patch[base + 1] = in_channel[h * 9 + 1];
            patch[base + 2] = in_channel[h * 9 + 2];
            patch[base + 3] = in_channel[h * 9 + 3];
            patch[base + 4] = in_channel[h * 9 + 4];
            patch[base + 5] = in_channel[h * 9 + 5];
            patch[base + 6] = in_channel[h * 9 + 6];
            patch[base + 7] = in_channel[h * 9 + 7];
            patch[base + 8] = in_channel[h * 9 + 8];
        }

        int oc = 0;
        constexpr int unroll_factor = 4;
        int main_loop_count = (OutChannels / unroll_factor) * unroll_factor;

        for (; oc < main_loop_count; oc += unroll_factor) {
            __m256 sum0 = _mm256_setzero_ps();
            __m256 sum1 = _mm256_setzero_ps();
            __m256 sum2 = _mm256_setzero_ps();
            __m256 sum3 = _mm256_setzero_ps();
            const float* w_ptr0 = weights + (oc + 0) * padded_inner_size;
            const float* w_ptr1 = weights + (oc + 1) * padded_inner_size;
            const float* w_ptr2 = weights + (oc + 2) * padded_inner_size;
            const float* w_ptr3 = weights + (oc + 3) * padded_inner_size;
            int i = 0;
            for (; i <= inner_size - 8; i += 8) {
                __m256 patch_vec = _mm256_load_ps(patch + i);
                __m256 w0 = _mm256_load_ps(w_ptr0 + i);
                __m256 w1 = _mm256_load_ps(w_ptr1 + i);
                __m256 w2 = _mm256_load_ps(w_ptr2 + i);
                __m256 w3 = _mm256_load_ps(w_ptr3 + i);
                sum0 = _mm256_fmadd_ps(patch_vec, w0, sum0);
                sum1 = _mm256_fmadd_ps(patch_vec, w1, sum1);
                sum2 = _mm256_fmadd_ps(patch_vec, w2, sum2);
                sum3 = _mm256_fmadd_ps(patch_vec, w3, sum3);
            }
            float dot0 = horizontal_sum_avx2(sum0);
            float dot1 = horizontal_sum_avx2(sum1);
            float dot2 = horizontal_sum_avx2(sum2);
            float dot3 = horizontal_sum_avx2(sum3);
            for (; i < inner_size; ++i) {
                dot0 += patch[i] * w_ptr0[i];
                dot1 += patch[i] * w_ptr1[i];
                dot2 += patch[i] * w_ptr2[i];
                dot3 += patch[i] * w_ptr3[i];
            }
            output[(oc + 0) * 9 + h] = dot0;
            output[(oc + 1) * 9 + h] = dot1;
            output[(oc + 2) * 9 + h] = dot2;
            output[(oc + 3) * 9 + h] = dot3;
        }

        for (; oc < OutChannels; ++oc) {
            __m256 sum = _mm256_setzero_ps();
            const float* w_ptr = weights + oc * padded_inner_size;
            int i = 0;
            for (; i <= inner_size - 8; i += 8) {
                __m256 patch_vec = _mm256_load_ps(patch + i);
                __m256 w_vec = _mm256_load_ps(w_ptr + i);
                sum = _mm256_fmadd_ps(patch_vec, w_vec, sum);
            }
            float dot = horizontal_sum_avx2(sum);
            for (; i < inner_size; ++i) {
                dot += patch[i] * w_ptr[i];
            }
            output[oc * 9 + h] = dot;
        }
    }
}

// ReLU処理
template <int Output>
inline void relu_avx2(float* __restrict output) {
    __m256 zero = _mm256_setzero_ps();
    int i = 0;
    // 32要素ずつ処理
    for (; i <= Output - 32; i += 32) {
        __m256 vec1 = _mm256_load_ps(output + i);
        __m256 vec2 = _mm256_load_ps(output + i + 8);
        __m256 vec3 = _mm256_load_ps(output + i + 16);
        __m256 vec4 = _mm256_load_ps(output + i + 24);

        vec1 = _mm256_max_ps(vec1, zero);
        vec2 = _mm256_max_ps(vec2, zero);
        vec3 = _mm256_max_ps(vec3, zero);
        vec4 = _mm256_max_ps(vec4, zero);

        _mm256_store_ps(output + i, vec1);
        _mm256_store_ps(output + i + 8, vec2);
        _mm256_store_ps(output + i + 16, vec3);
        _mm256_store_ps(output + i + 24, vec4);
    }
    // 16要素ずつ処理
    for (; i <= Output - 16; i += 16) {
        __m256 vec1 = _mm256_load_ps(output + i);
        __m256 vec2 = _mm256_load_ps(output + i + 8);
        vec1 = _mm256_max_ps(vec1, zero);
        vec2 = _mm256_max_ps(vec2, zero);
        _mm256_store_ps(output + i, vec1);
        _mm256_store_ps(output + i + 8, vec2);
    }
    // 8要素ずつ処理
    for (; i <= Output - 8; i += 8) {
        __m256 vec = _mm256_load_ps(output + i);
        vec = _mm256_max_ps(vec, zero);
        _mm256_store_ps(output + i, vec);
    }
    // 余りのスカラー処理
    for (; i < Output; ++i) {
        output[i] = output[i] > 0.f ? output[i] : 0.f;
    }
}

// 畳み込み層+cat+ReLU
template <int InChannels, int OutChannels>
void conv_cat_relu(
    const float* __restrict input,
    const float* __restrict weights_conv9x1,
    const float* __restrict weights_conv1x9,
    float* __restrict output) {

    constexpr int output_offset = OutChannels * 9; // conv1x9 の出力開始オフセット

    // 1. 9x1の畳み込み処理(ReLU適用は後でまとめて実施)
    conv9x1_avx2<InChannels, OutChannels>(input, weights_conv9x1, output);

    // 2. 1x9の畳み込み処理(出力の後半部分に格納)
    conv1x9_avx2<InChannels, OutChannels>(input, weights_conv1x9, output + output_offset);

    // 3. 連結後の出力全体に対してテンプレート関数のReLUを適用
    constexpr int total_output = OutChannels * 9 * 2;
    relu_avx2<total_output>(output);
}


以下の点が高速化のために工夫されている。

  • 出力バッファの共有利用

9x1の畳み込み結果は出力バッファの前半に、1x9の結果はその直後の連続領域に格納しており、別々のバッファを用いずに連続したメモリ領域に配置している。

  • 余分なコピー処理の排除

それぞれの畳み込み結果を直接所定の位置に書き込むことで、後から結果を結合するための追加のコピー処理が不要になり、オーバーヘッドを削減している。

  • 連続領域への一括ReLU適用

連結された出力全体に対してReLUを一度に適用することで、各部分ごとに別途処理する必要がなく、メモリアクセスが効率化されている。

  • ReLU処理のベクトル化

ReLU関数では32要素、16要素、8要素とブロック単位でベクトル化処理を行い、スカラー処理の割合を最小化している。

結果の検証

以下の通り、ランダム初期値で推論するベンチマークコードを実装した。
結果は、バイナリで出力して、Pythonのnumpyで検証できるようにする。

#include <iostream>
#include <chrono>
#include <random>
#include <algorithm>
#include <fstream>

int main() {
    // テストパラメータ設定
    constexpr int InChannels = 16;
    constexpr int OutChannels = 4;
    constexpr int input_h = 9;
    constexpr int input_w = 9;
    constexpr int input_size = InChannels * input_h * input_w;  // 16 * 81 = 1296
    constexpr int kernel_size = 9;
    constexpr int inner_size = InChannels * kernel_size;         // 16 * 9 = 144
    constexpr int padded_inner_size = ((inner_size + 7) / 8) * 8;
    constexpr int num_iterations = 10000;
    constexpr int output_size = OutChannels * 9 * 2;  // 連結後のサイズ

    // メモリ確保(32バイトアラインメント)
    alignas(32) float input[input_size];
    alignas(32) float weights_conv9x1[OutChannels * padded_inner_size];
    alignas(32) float weights_conv1x9[OutChannels * padded_inner_size];
    alignas(32) float output[output_size];

    // 乱数による初期化
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> dist(-1.0f, 1.0f);

    for (int i = 0; i < input_size; ++i) {
        input[i] = dist(gen);
    }

    for (int oc = 0; oc < OutChannels; ++oc) {
        int offset = oc * padded_inner_size;
        for (int i = 0; i < inner_size; ++i) {
            weights_conv9x1[offset + i] = dist(gen);
            weights_conv1x9[offset + i] = dist(gen);
        }
        // パディング部分を0で初期化
        for (int i = inner_size; i < padded_inner_size; ++i) {
            weights_conv9x1[offset + i] = 0.f;
            weights_conv1x9[offset + i] = 0.f;
        }
    }

    // 統合処理のベンチマーク
    auto start_fused = std::chrono::high_resolution_clock::now();
    for (int iter = 0; iter < num_iterations; ++iter) {
        conv_cat_relu<InChannels, OutChannels>(input, weights_conv9x1, weights_conv1x9, output);
    }
    auto end_fused = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double, std::milli> elapsed_fused = end_fused - start_fused;

    // 演算量の計算:
    // conv9x1: 4 channels * 9 positions * (2*144) flops = 10368 flops
    // conv1x9: 同様に 10368 flops
    // 合計 = 20736 flops/反復
    const double flops = 20736.0 * num_iterations;

    // 結果の表示
    std::cout << "Output values: ";
    for (int i = 0; i < std::min(5, output_size); ++i) {
        std::cout << output[i] << " ";
    }
    std::cout << "...\n";
    std::cout << "Performance: " << num_iterations << " iterations in " << elapsed_fused.count() << " ms\n";
    std::cout << "Average time per iteration: " << elapsed_fused.count() / num_iterations << " ms\n";
    std::cout << "GFLOPS: " << flops / (elapsed_fused.count() * 1e6) << "\n";

    // バイナリファイル出力(必要に応じて)
    {
        std::ofstream fout("input.bin", std::ios::binary);
        fout.write(reinterpret_cast<const char*>(input), sizeof(float) * input_size);
    }
    {
        std::ofstream fout("weights_conv9x1.bin", std::ios::binary);
        fout.write(reinterpret_cast<const char*>(weights_conv9x1), sizeof(float) * OutChannels * padded_inner_size);
    }
    {
        std::ofstream fout("weights_conv1x9.bin", std::ios::binary);
        fout.write(reinterpret_cast<const char*>(weights_conv1x9), sizeof(float) * OutChannels * padded_inner_size);
    }
    {
        std::ofstream fout("output_fused.bin", std::ios::binary);
        fout.write(reinterpret_cast<const char*>(output), sizeof(float) * output_size);
    }
    {
        std::ofstream fout("meta.txt");
        fout << "InChannels " << InChannels << "\n";
        fout << "OutChannels " << OutChannels << "\n";
        fout << "InputH " << input_h << "\n";
        fout << "InputW " << input_w << "\n";
        fout << "KernelSize " << kernel_size << "\n";
        fout << "PaddedInnerSize " << padded_inner_size << "\n";
        fout << "OutputSize " << output_size << "\n";
    }

    return 0;
}
実行結果
Output values: 0.262765 0 0 0 2.36965 ...
Performance: 10000 iterations in 10.6277 ms
Average time per iteration: 0.00106277 ms
GFLOPS: 19.5113

前回のそれぞれの畳みこみの推論時間の約2倍で連結とReLUは無視できる程度に高速に推論できている。

結果検証

Pythonのnumpyで結果を検証した。

import numpy as np

# ----- パラメータ設定(C++側の設定と一致) -----
InChannels = 16
OutChannels = 4
InputH = 9
InputW = 9
KernelSize = 9           # 9x1, 1x9 なのでカーネル長は9
InnerSize = InChannels * KernelSize  # 16 * 9 = 144
# C++では padded_inner_size = ((InnerSize+7)/8)*8 ですが、144は既に8の倍数
PaddedInnerSize = InnerSize
# conv9x1 と conv1x9 の各出力は (OutChannels, 9) の形状となり、連結後の全体出力は 2*(OutChannels*9) 要素
OutputSize = OutChannels * 9 * 2  # 4*9*2 = 72

# ----- バイナリファイルの読み込み -----
input_data = np.fromfile("../input.bin", dtype=np.float32)               # サイズ: 16*9*9 = 1296
weights_conv9x1 = np.fromfile("../weights_conv9x1.bin", dtype=np.float32)  # サイズ: OutChannels * PaddedInnerSize = 4*144
weights_conv1x9 = np.fromfile("../weights_conv1x9.bin", dtype=np.float32)  # 同上
output_cpp = np.fromfile("../output_fused.bin", dtype=np.float32)          # サイズ: OutputSize = 72

# 入力データを (InChannels, InputH, InputW) に変形
input_data = input_data.reshape(InChannels, InputH, InputW)
# 重みは各出力チャネルごとに 144 要素のベクトルとして格納されているので変形
weights_conv9x1 = weights_conv9x1.reshape(OutChannels, PaddedInnerSize)
weights_conv1x9 = weights_conv1x9.reshape(OutChannels, PaddedInnerSize)

# ----- conv9x1の計算 -----
# 各出力位置は、各チャネルのw列(9要素)を連結したパッチ(サイズ144)との内積で求める
conv9x1 = np.empty((OutChannels, InputW), dtype=np.float32)
for w in range(InputW):
    # 各チャネルについて、w列目(9要素)を連結(shape: (16, 9) -> (144,))
    patch = input_data[:, :, w].reshape(-1)
    conv9x1[:, w] = np.dot(weights_conv9x1, patch)

# ----- conv1x9の計算 -----
# 各出力位置は、各チャネルのh行(9要素)を連結したパッチ(サイズ144)との内積で求める
conv1x9 = np.empty((OutChannels, InputH), dtype=np.float32)
for h in range(InputH):
    patch = input_data[:, h, :].reshape(-1)
    conv1x9[:, h] = np.dot(weights_conv1x9, patch)

# ----- 結果の連結とReLU適用 -----
# C++ではまず conv9x1 の出力が連続して格納され、その後 conv1x9 の出力が格納されています。
fused_output = np.concatenate((conv9x1.flatten(), conv1x9.flatten()))
# ReLU適用(負の値を0に)
output_np = np.maximum(fused_output, 0)

# ----- 結果の比較(validation.pyと同様の表示) -----
print("Numpy computed output (first 5 elements):", output_np[:5])
print("C++ output (first 5 elements):", output_cpp[:5])
difference = np.abs(output_np - output_cpp)
print("Absolute difference:", difference)

結果:

Numpy computed output (first 5 elements): [0.2627653 0.        0.        0.        2.3696544]
C++ output (first 5 elements): [0.26276505 0.         0.         0.         2.3696537 ]
Absolute difference: [2.3841858e-07 0.0000000e+00 0.0000000e+00 0.0000000e+00 7.1525574e-07
 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 4.7683716e-07
 9.5367432e-07 0.0000000e+00 2.3841858e-07 0.0000000e+00 0.0000000e+00
 1.4305115e-06 7.1525574e-07 5.9604645e-07 4.7683716e-07 2.9802322e-07
 0.0000000e+00 0.0000000e+00 1.1920929e-07 0.0000000e+00 0.0000000e+00
 4.7683716e-07 5.3644180e-07 7.1525574e-07 0.0000000e+00 0.0000000e+00
 0.0000000e+00 0.0000000e+00 4.7683716e-07 0.0000000e+00 0.0000000e+00
 2.3841858e-07 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
 0.0000000e+00 0.0000000e+00 2.3841858e-07 4.1723251e-07 0.0000000e+00
 1.9073486e-06 0.0000000e+00 4.7683716e-07 0.0000000e+00 0.0000000e+00
 0.0000000e+00 9.5367432e-07 4.7683716e-07 2.3841858e-07 0.0000000e+00
 0.0000000e+00 0.0000000e+00 5.9604645e-08 0.0000000e+00 0.0000000e+00
 1.1920929e-07 0.0000000e+00 2.3841858e-07 0.0000000e+00 7.1525574e-07
 3.5762787e-07 4.7683716e-07 0.0000000e+00 4.7683716e-07 0.0000000e+00
 0.0000000e+00 0.0000000e+00]

わずかなずれがあるが、計算順序の違いによる浮動小数点の誤差で、無視できる程度である。

まとめ

連結と活性化関数の処理をSIMDで実装した。
出力バッファを共通化して連結を省略して、ReLUをベクトル化することで畳み込みの処理に比べて無視できる程度の速度で推論できた。
次は、埋め込みの処理を実装したい。