TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

dropless MoE(Mixture of Experts)を試す その3(プラグイン)

前回、ONNXにエクスポートしたカスタムノードに対応する処理をTensorRTのプラグインで実装する。

TensorRTのプラグイン

公式ドキュメントに記載されている通り、プラグインクラスとプラグイン クリエーターを実装する。 プラグインがREGISTER_TENSORRT_PLUGINで登録されていると、ONNXのパーサが自動でカスタムノードをプラグインに置換する。

今回は、互換性を優先して、TensorRT 8ベースで実装する。 TensorRT 8ベースだと、nvinfer1::IPluginV2DynamicExtを継承したプラグインクラスと、nvinfer1::IPluginCreatorを継承したプラグイン クリエーターを実装する。

ここでは、プラグインクラスのenqueue()から呼び出すCUDAカーネルの実装を中心に解説する。 詳細な実装は、GitHubのソースを参照して欲しい。

CUDAカーネルは何パターンか実装しているが、CUTLASS grouped GEMMを使うenqueueTypedCutlassHalfに絞って解説する。

cudaError_t enqueueTypedCutlassHalf(
    const CustomMoeConfig& cfg,
    int64_t M,
    const void* x_void,
    const float* router_w,
    const void* w1_void,
    const void* b1_void,
    const void* w2_void,
    const void* b2_void,
    void* y_void,
    void* workspace,
    cudaStream_t stream) {
    const __half* x = static_cast<const __half*>(x_void);
    const __half* w1 = static_cast<const __half*>(w1_void);
    const __half* b1 = static_cast<const __half*>(b1_void);
    const __half* w2 = static_cast<const __half*>(w2_void);
    const __half* b2 = static_cast<const __half*>(b2_void);
    __half* y = static_cast<__half*>(y_void);

    WorkspaceParts ws = splitWorkspace(workspace, cfg, M, nvinfer1::DataType::kHALF);
    const int64_t N = M * cfg.top_k;
    const int threads = 256;

    cudaMemsetAsync(ws.counts, 0, sizeof(int) * cfg.num_experts, stream);
    cudaMemsetAsync(y, 0, sizeof(__half) * M * cfg.in_features, stream);

    routeTopKKernel<__half><<<static_cast<unsigned>((M + threads - 1) / threads), threads, 0, stream>>>(
        x, router_w, ws.topk_experts, ws.topk_gates, M, cfg.in_features, cfg.num_experts, cfg.top_k);
    countExpertsKernel<<<static_cast<unsigned>((N + threads - 1) / threads), threads, 0, stream>>>(
        ws.topk_experts, ws.counts, N);
    prefixAndResetKernel<<<1, 1, 0, stream>>>(ws.counts, ws.offsets, ws.write_ptr, cfg.num_experts);
    packAssignmentsKernel<<<static_cast<unsigned>((N + threads - 1) / threads), threads, 0, stream>>>(
        ws.topk_experts, ws.topk_gates, ws.write_ptr, ws.assignment_dst, ws.token_sorted,
        ws.expert_sorted, ws.gate_sorted, M, cfg.top_k);
    copyPackedXKernel<__half><<<static_cast<unsigned>((N * cfg.in_features + threads - 1) / threads), threads, 0, stream>>>(
        x, ws.assignment_dst, static_cast<__half*>(ws.x_sorted), M, cfg.in_features, cfg.top_k);

    cudaError_t err = runCutlassGroupedGemm(
        ws,
        cfg.num_experts,
        cfg.in_features,
        cfg.hidden_features,
        static_cast<const __half*>(ws.x_sorted),
        w1,
        static_cast<__half*>(ws.h_sorted),
        static_cast<__half*>(ws.h_sorted),
        stream);
    if (err != cudaSuccess) return err;
    addBiasGeluKernel<__half><<<static_cast<unsigned>((N * cfg.hidden_features + threads - 1) / threads), threads, 0, stream>>>(
        static_cast<__half*>(ws.h_sorted), b1, ws.expert_sorted, N, cfg.hidden_features);

    err = runCutlassGroupedGemm(
        ws,
        cfg.num_experts,
        cfg.hidden_features,
        cfg.in_features,
        static_cast<const __half*>(ws.h_sorted),
        w2,
        static_cast<__half*>(ws.y_sorted),
        static_cast<__half*>(ws.y_sorted),
        stream);
    if (err != cudaSuccess) return err;
    addBiasGateCombineKernel<__half><<<static_cast<unsigned>((N * cfg.in_features + threads - 1) / threads), threads, 0, stream>>>(
        static_cast<__half*>(ws.y_sorted), b2, ws.token_sorted, ws.expert_sorted, ws.gate_sorted,
        y, N, cfg.in_features);
    return cudaGetLastError();
}

PyTorchで実装したDroplessMoEMlpの流れをそのままCUDAで実装している。

  routeTopKKernel
  countExpertsKernel
  prefixAndResetKernel
  packAssignmentsKernel
  copyPackedXKernel
  CUTLASS Grouped GEMM FC1
  addBiasGeluKernel
  CUTLASS Grouped GEMM FC2
  addBiasGateCombineKernel

関数冒頭:void ポインタのキャスト

enqueueTypedCutlassHalf は、TensorRT plugin から void* で渡されたテンソルを __half* にキャストする。ただし、router_w だけは const float* のまま保持する。これは routing 計算を FP32 で行い、数値的安定性を確保する設計に基づいている。

Workspace の分割

splitWorkspace により、TensorRT が提供する一時領域を各内部バッファへ割り当てる。 重要な点は、x_sorted, h_sorted, y_sorted が assignment 単位(N = M * K 行)を持つことである。1 token が K 個の expert に送られるため、expert MLP の入力行数は M ではなく N となる。

初期化:counts と y のゼロクリア

counts(expert ごとの割当数)と出力 ycudaMemsetAsync で 0 初期化する。y の初期化が必要な理由は、最終的な結合処理において y[token, d] += gate * expert_output という atomic add を行うためである。

routeTopKKernel:Router と Top-k Gate の計算

1 thread が 1 token を担当し、以下の処理を行う。

  1. Logits 計算: x (FP16) と router_w (FP32) の内積を float で計算。
  2. Top-k 選択と正規化: E 個の logits から top-k を選び、その中だけで再正規化(Softmax)を行い gate 値を算出する。

count / prefixAndResetKernel:配置の決定

  • countExpertsKernel: 各 assignment がどの expert に割り当てられたかを走査し、atomicAddcounts[e] を算出する。
  • prefixAndResetKernel: counts から prefix sum を計算し、offsets(各 expert の開始位置)と write_ptr を作成する。これにより、各 expert が処理するパックドバッファ上の範囲が確定する。

pack / copyPackedXKernel:データの並べ替え

packAssignmentsKernel で各 assignment の書き込み先インデックスを確定し、copyPackedXKernel で実データを x_sorted へコピーする。これにより、x_sorted は同一の expert に送られる token が連続して並ぶレイアウト(expert-major)となる。

FC1:CUTLASS Grouped GEMM

本関数の主目的である。runCutlassGroupedGemm を呼び出し、全 expert の行列積を一括実行する。 各 expert e において、A_e = [M_e, D]、B_e = [D, H] となり、結果 D_e = [M_e, H] を得る。

setupGroupedGemmMetaKernel では、CUTLASS が要求する各 expert の問題サイズ(Problem Size)やポインタの配列を構築する。

  • Arch: Sm80 (Ampere)
  • 精度: 入出力 FP16、Accumulator FP32
  • Epilogue: LinearCombination (beta=0)

addBiasGeluKernel:FC1 の Bias と GELU

CUTLASS の Grouped GEMM は純粋な行列積のみを担当するため、Bias 加算と GELU 活性化関数は別個の kernel で実行する。この際、expert_sorted を参照して各行に対応する expert の bias を特定する。

FC2:CUTLASS Grouped GEMM

FC1 と同様に runCutlassGroupedGemm を実行する。

  • 入力 (A): h_sorted [N, H]
  • 重み (B): w2 [E, H, D]
  • 出力 (D): y_sorted [N, D]

addBiasGateCombineKernel:FC2 Bias と Scatter-Add

最後に FC2 の出力に対して bias を加算し、gate 値を乗じた上で、元の token インデックスに基づき y へ足し込む。 atomicAdd を用いることで、1 token に対する K 個の expert 出力を正しく集約する。

まとめ

MoEのカスタムノードに対応するTensorRTプラグインを実装し、CUTLASS Grouped GEMMを使うCUDAカーネルを中心に解説した。 この実装の最適化は不十分で、Biasや活性化関数は一つのカーネルに融合する余地がある。

次回は、プラグインをロードしてTensorRTエンジンをビルドする処理を実装したい。