前回、ONNXにエクスポートしたカスタムノードに対応する処理をTensorRTのプラグインで実装する。
TensorRTのプラグイン
公式ドキュメントに記載されている通り、プラグインクラスとプラグイン クリエーターを実装する。 プラグインがREGISTER_TENSORRT_PLUGINで登録されていると、ONNXのパーサが自動でカスタムノードをプラグインに置換する。
今回は、互換性を優先して、TensorRT 8ベースで実装する。 TensorRT 8ベースだと、nvinfer1::IPluginV2DynamicExtを継承したプラグインクラスと、nvinfer1::IPluginCreatorを継承したプラグイン クリエーターを実装する。
ここでは、プラグインクラスのenqueue()から呼び出すCUDAカーネルの実装を中心に解説する。 詳細な実装は、GitHubのソースを参照して欲しい。
CUDAカーネルは何パターンか実装しているが、CUTLASS grouped GEMMを使うenqueueTypedCutlassHalfに絞って解説する。
cudaError_t enqueueTypedCutlassHalf( const CustomMoeConfig& cfg, int64_t M, const void* x_void, const float* router_w, const void* w1_void, const void* b1_void, const void* w2_void, const void* b2_void, void* y_void, void* workspace, cudaStream_t stream) { const __half* x = static_cast<const __half*>(x_void); const __half* w1 = static_cast<const __half*>(w1_void); const __half* b1 = static_cast<const __half*>(b1_void); const __half* w2 = static_cast<const __half*>(w2_void); const __half* b2 = static_cast<const __half*>(b2_void); __half* y = static_cast<__half*>(y_void); WorkspaceParts ws = splitWorkspace(workspace, cfg, M, nvinfer1::DataType::kHALF); const int64_t N = M * cfg.top_k; const int threads = 256; cudaMemsetAsync(ws.counts, 0, sizeof(int) * cfg.num_experts, stream); cudaMemsetAsync(y, 0, sizeof(__half) * M * cfg.in_features, stream); routeTopKKernel<__half><<<static_cast<unsigned>((M + threads - 1) / threads), threads, 0, stream>>>( x, router_w, ws.topk_experts, ws.topk_gates, M, cfg.in_features, cfg.num_experts, cfg.top_k); countExpertsKernel<<<static_cast<unsigned>((N + threads - 1) / threads), threads, 0, stream>>>( ws.topk_experts, ws.counts, N); prefixAndResetKernel<<<1, 1, 0, stream>>>(ws.counts, ws.offsets, ws.write_ptr, cfg.num_experts); packAssignmentsKernel<<<static_cast<unsigned>((N + threads - 1) / threads), threads, 0, stream>>>( ws.topk_experts, ws.topk_gates, ws.write_ptr, ws.assignment_dst, ws.token_sorted, ws.expert_sorted, ws.gate_sorted, M, cfg.top_k); copyPackedXKernel<__half><<<static_cast<unsigned>((N * cfg.in_features + threads - 1) / threads), threads, 0, stream>>>( x, ws.assignment_dst, static_cast<__half*>(ws.x_sorted), M, cfg.in_features, cfg.top_k); cudaError_t err = runCutlassGroupedGemm( ws, cfg.num_experts, cfg.in_features, cfg.hidden_features, static_cast<const __half*>(ws.x_sorted), w1, static_cast<__half*>(ws.h_sorted), static_cast<__half*>(ws.h_sorted), stream); if (err != cudaSuccess) return err; addBiasGeluKernel<__half><<<static_cast<unsigned>((N * cfg.hidden_features + threads - 1) / threads), threads, 0, stream>>>( static_cast<__half*>(ws.h_sorted), b1, ws.expert_sorted, N, cfg.hidden_features); err = runCutlassGroupedGemm( ws, cfg.num_experts, cfg.hidden_features, cfg.in_features, static_cast<const __half*>(ws.h_sorted), w2, static_cast<__half*>(ws.y_sorted), static_cast<__half*>(ws.y_sorted), stream); if (err != cudaSuccess) return err; addBiasGateCombineKernel<__half><<<static_cast<unsigned>((N * cfg.in_features + threads - 1) / threads), threads, 0, stream>>>( static_cast<__half*>(ws.y_sorted), b2, ws.token_sorted, ws.expert_sorted, ws.gate_sorted, y, N, cfg.in_features); return cudaGetLastError(); }
PyTorchで実装したDroplessMoEMlpの流れをそのままCUDAで実装している。
routeTopKKernel countExpertsKernel prefixAndResetKernel packAssignmentsKernel copyPackedXKernel CUTLASS Grouped GEMM FC1 addBiasGeluKernel CUTLASS Grouped GEMM FC2 addBiasGateCombineKernel
関数冒頭:void ポインタのキャスト
enqueueTypedCutlassHalf は、TensorRT plugin から void* で渡されたテンソルを __half* にキャストする。ただし、router_w だけは const float* のまま保持する。これは routing 計算を FP32 で行い、数値的安定性を確保する設計に基づいている。
Workspace の分割
splitWorkspace により、TensorRT が提供する一時領域を各内部バッファへ割り当てる。
重要な点は、x_sorted, h_sorted, y_sorted が assignment 単位(N = M * K 行)を持つことである。1 token が K 個の expert に送られるため、expert MLP の入力行数は M ではなく N となる。
初期化:counts と y のゼロクリア
counts(expert ごとの割当数)と出力 y を cudaMemsetAsync で 0 初期化する。y の初期化が必要な理由は、最終的な結合処理において y[token, d] += gate * expert_output という atomic add を行うためである。
routeTopKKernel:Router と Top-k Gate の計算
1 thread が 1 token を担当し、以下の処理を行う。
- Logits 計算:
x(FP16) とrouter_w(FP32) の内積を float で計算。 - Top-k 選択と正規化: E 個の logits から top-k を選び、その中だけで再正規化(Softmax)を行い
gate値を算出する。
count / prefixAndResetKernel:配置の決定
- countExpertsKernel: 各 assignment がどの expert に割り当てられたかを走査し、
atomicAddでcounts[e]を算出する。 - prefixAndResetKernel:
countsから prefix sum を計算し、offsets(各 expert の開始位置)とwrite_ptrを作成する。これにより、各 expert が処理するパックドバッファ上の範囲が確定する。
pack / copyPackedXKernel:データの並べ替え
packAssignmentsKernel で各 assignment の書き込み先インデックスを確定し、copyPackedXKernel で実データを x_sorted へコピーする。これにより、x_sorted は同一の expert に送られる token が連続して並ぶレイアウト(expert-major)となる。
FC1:CUTLASS Grouped GEMM
本関数の主目的である。runCutlassGroupedGemm を呼び出し、全 expert の行列積を一括実行する。
各 expert e において、A_e = [M_e, D]、B_e = [D, H] となり、結果 D_e = [M_e, H] を得る。
setupGroupedGemmMetaKernel では、CUTLASS が要求する各 expert の問題サイズ(Problem Size)やポインタの配列を構築する。
- Arch: Sm80 (Ampere)
- 精度: 入出力 FP16、Accumulator FP32
- Epilogue:
LinearCombination(beta=0)
addBiasGeluKernel:FC1 の Bias と GELU
CUTLASS の Grouped GEMM は純粋な行列積のみを担当するため、Bias 加算と GELU 活性化関数は別個の kernel で実行する。この際、expert_sorted を参照して各行に対応する expert の bias を特定する。
FC2:CUTLASS Grouped GEMM
FC1 と同様に runCutlassGroupedGemm を実行する。
- 入力 (A):
h_sorted[N, H] - 重み (B):
w2[E, H, D] - 出力 (D):
y_sorted[N, D]
addBiasGateCombineKernel:FC2 Bias と Scatter-Add
最後に FC2 の出力に対して bias を加算し、gate 値を乗じた上で、元の token インデックスに基づき y へ足し込む。
atomicAdd を用いることで、1 token に対する K 個の expert 出力を正しく集約する。
まとめ
MoEのカスタムノードに対応するTensorRTプラグインを実装し、CUTLASS Grouped GEMMを使うCUDAカーネルを中心に解説した。 この実装の最適化は不十分で、Biasや活性化関数は一つのカーネルに融合する余地がある。
次回は、プラグインをロードしてTensorRTエンジンをビルドする処理を実装したい。