TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

dropless MoE(Mixture of Experts)を試す その4(エンジンビルド)

前回、実装したプラグインを使用して、カスタムノードを含むONNXをTensorRT の serialized engineに変換する。

dlshogiでは、C++でエンジンビルドを実装しているが、今回はPythonのtensorrtライブラリを使用して実装する。 生成したserialized engineを後でC++の推論プログラムからロードする。

エンジンビルド

TensorRT の serialized engine とは、TensorRT が ONNX などのモデルから生成したバイナリで、以下を含む。

  • TensorRT が最適化した network graph
  • 選択済みの CUDA kernel / tactic
  • layer fusion の結果
  • dynamic shape 用 optimization profile
  • weight
  • plugin layer の情報

実行環境に最適化するため、TensorRT version、GPU architecture、plugin ABI、CUDA 環境に依存する。

変換スクリプト

build_engine.py

"""Build a TensorRT engine from ONNX.

For plugin-mode ONNX, load libcustom_moe_plugin.so before parsing.
"""
from __future__ import annotations

import argparse
import ctypes
from pathlib import Path

import tensorrt as trt


def main() -> None:
    p = argparse.ArgumentParser()
    p.add_argument("--onnx", type=Path, required=True)
    p.add_argument("--engine", type=Path, required=True)
    p.add_argument("--plugin", type=Path, default=None, help="Path to libcustom_moe_plugin.so")
    p.add_argument("--min-batch", type=int, default=1)
    p.add_argument("--opt-batch", type=int, default=8)
    p.add_argument("--max-batch", type=int, default=32)
    p.add_argument("--channels", type=int, default=3)
    p.add_argument("--height", type=int, default=32)
    p.add_argument("--width", type=int, default=32)
    p.add_argument("--fp16", action="store_true")
    p.add_argument("--workspace-gb", type=float, default=4.0)
    p.add_argument("--version-compatible", action="store_true")
    args = p.parse_args()

    logger = trt.Logger(trt.Logger.INFO)
    trt.init_libnvinfer_plugins(logger, "")
    if args.plugin is not None:
        ctypes.CDLL(args.plugin.as_posix(), mode=ctypes.RTLD_GLOBAL)
        print(f"[INFO] loaded plugin: {args.plugin}")

    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    builder = trt.Builder(logger)
    network = builder.create_network(explicit_batch)
    parser = trt.OnnxParser(network, logger)

    data = args.onnx.read_bytes()
    if not parser.parse(data):
        print("[ERROR] ONNX parse failed")
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise SystemExit(1)

    config = builder.create_builder_config()
    if hasattr(config, "set_memory_pool_limit"):
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, int(args.workspace_gb * (1 << 30)))
    else:
        config.max_workspace_size = int(args.workspace_gb * (1 << 30))

    if args.fp16 and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        print("[INFO] enabled FP16")
    elif args.fp16:
        print("[WARN] requested FP16, but platform_has_fast_fp16 is false")

    if args.version_compatible and hasattr(trt.BuilderFlag, "VERSION_COMPATIBLE"):
        config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
        if hasattr(parser, "get_used_vc_plugin_libraries") and hasattr(config, "set_plugins_to_serialize"):
            libs = parser.get_used_vc_plugin_libraries()
            if libs:
                config.set_plugins_to_serialize(libs)
                print(f"[INFO] serializing plugin libraries into engine: {libs}")

    inp = network.get_input(0)
    profile = builder.create_optimization_profile()
    min_shape = (args.min_batch, args.channels, args.height, args.width)
    opt_shape = (args.opt_batch, args.channels, args.height, args.width)
    max_shape = (args.max_batch, args.channels, args.height, args.width)
    profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)
    print(f"[INFO] optimization profile for {inp.name}: min={min_shape}, opt={opt_shape}, max={max_shape}")

    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("TensorRT engine build failed")
    args.engine.parent.mkdir(parents=True, exist_ok=True)
    args.engine.write_bytes(bytes(serialized))
    print(f"[OK] wrote {args.engine}")


if __name__ == "__main__":
    main()

解説

build_engine.py は、export_moe_onnx.py で出力した ONNX を TensorRT の serialized engine.engine ファイル)に変換するスクリプトである。特に plugin モードの ONNX には CustomMoE カスタムノードが含まれるため、ONNX のパース前に libcustom_moe_plugin.so をロードする役割も担っている。

1. 役割

このスクリプトの処理フローは、概念的に以下の通りである。

ONNX
  ↓
TensorRT OnnxParser
  ↓
TensorRT network
  ↓
Builder + BuilderConfig + OptimizationProfile
  ↓
Serialized TensorRT engine
  ↓
.engine file

plugin モードの場合は、さらに前処理として以下の流れが加わる。

libcustom_moe_plugin.so をロード
  ↓
CustomMoEPluginCreator が TensorRT plugin registry に登録
  ↓
ONNX parser が trt.plugins::CustomMoE を plugin layer として解決

2. CLI 引数

冒頭で argparse によって定義されている引数は以下の通りである。

p.add_argument("--onnx", type=Path, required=True)
p.add_argument("--engine", type=Path, required=True)
p.add_argument("--plugin", type=Path, default=None)
p.add_argument("--min-batch", type=int, default=1)
p.add_argument("--opt-batch", type=int, default=8)
p.add_argument("--max-batch", type=int, default=32)
p.add_argument("--channels", type=int, default=3)
p.add_argument("--height", type=int, default=32)
p.add_argument("--width", type=int, default=32)
p.add_argument("--fp16", action="store_true")
p.add_argument("--workspace-gb", type=float, default=4.0)
p.add_argument("--version-compatible", action="store_true")

主な引数の意味を以下にまとめる。

引数 意味
--onnx 入力 ONNX ファイル
--engine 出力 TensorRT engine ファイル
--plugin libcustom_moe_plugin.so のパス
--min-batch dynamic batch の最小値
--opt-batch TensorRT が最適化の基準とする batch サイズ
--max-batch dynamic batch の最大値
--channels 入力チャンネル数(CIFAR-10 の場合は通常 3)
--height 入力画像の高さ
--width 入力画像の幅
--fp16 TensorRT の FP16 ビルドフラグを有効化
--workspace-gb 戦略選択やビルドに使用するワークスペースの上限
--version-compatible TensorRT のバージョン互換エンジン機能を試行

3. logger と plugin の初期化

まず、TensorRT logger を作成する。

logger = trt.Logger(trt.Logger.INFO)

その後、標準の TensorRT プラグインを初期化する。

trt.init_libnvinfer_plugins(logger, "")

さらに --plugin が指定されている場合は、ctypes.CDLL を用いて .so ファイルをロードする。

if args.plugin is not None:
    ctypes.CDLL(args.plugin.as_posix(), mode=ctypes.RTLD_GLOBAL)
    print(f"[INFO] loaded plugin: {args.plugin}")

この工程は plugin モードにおいて極めて重要である。export_moe_onnx.py の plugin モードで作成された ONNX には trt.plugins::CustomMoE ノードが含まれている。そのため、パース前に libcustom_moe_plugin.so をロードしておかなければ、TensorRT parser は CustomMoE を解決できなくなる。

RTLD_GLOBAL を指定する理由は、ロードした共有ライブラリのシンボルやプラグインクリエイターをプロセス全体から参照可能にするためである。

4. explicit batch network の作成

次に、TensorRT の builder、network、ONNX parser を作成する。

explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
builder = trt.Builder(logger)
network = builder.create_network(explicit_batch)
parser = trt.OnnxParser(network, logger)

このコードでは explicit batch mode を採用している。これは現代の TensorRT における標準的な形式であり、入力シェイプを [N, C, H, W] のようにバッチ次元を含めて明示的に扱う。

export_moe_onnx.py ではバッチ軸を dynamic に設定して ONNX をエクスポートしているため、TensorRT 側でも explicit batch と optimization profile の組み合わせが必要となる。


5. ONNX のパース

ONNX ファイルをバイト列として読み込み、parser に渡す。

data = args.onnx.read_bytes()
if not parser.parse(data):
    print("[ERROR] ONNX parse failed")
    for i in range(parser.num_errors):
        print(parser.get_error(i))
    raise SystemExit(1)

パースに失敗した場合は、エラー内容をすべて表示して処理を終了する。plugin モードで失敗する典型的な原因は以下の通りである。

  • --plugin が指定されていない
  • libcustom_moe_plugin.so のロードに失敗している
  • ONNX のカスタムオペレーション名とプラグイン名が一致していない
  • プラグインのバージョンやネームスペースが一致していない
  • TensorRT がカスタムオペレーションの属性を解釈できない

6. BuilderConfig とワークスペース

パース完了後、エンジンのビルド設定(config)を作成する。

config = builder.create_builder_config()

ワークスペースの上限設定については、TensorRT のバージョンによる API の差異を吸収するため、2 通りの方法に対応させている。

if hasattr(config, "set_memory_pool_limit"):
    config.set_memory_pool_limit(
        trt.MemoryPoolType.WORKSPACE,
        int(args.workspace_gb * (1 << 30))
    )
else:
    config.max_workspace_size = int(args.workspace_gb * (1 << 30))

ワークスペースは、ビルド時の戦略選択や実行時の一時バッファとして利用されるメモリ領域である。値を大きくすれば必ずしも高速化するわけではないが、不足すると最適な戦略が選ばれず、ビルドの失敗や性能低下を招く。

7. FP16 フラグ

--fp16 が指定され、かつ GPU が高速な FP16 演算に対応している場合に限り、FP16 フラグを有効化する。

if args.fp16 and builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)
    print("[INFO] enabled FP16")
elif args.fp16:
    print("[WARN] requested FP16, but platform_has_fast_fp16 is false")

このフラグは、TensorRT に対して「可能なレイヤーや戦略において FP16 実行を許可する」ことを伝えるものである。

なお、MoE プラグインについては、x/outputexpert weights/biases は FP16 または FP32、router_w は FP32 である必要がある。FP16 での実行を最適化するには、ONNX エクスポート時に --fp16-experts を使用し、あらかじめ重みを FP16 に変換しておく設計が推奨される。

8. バージョン互換エンジン

--version-compatible オプションが指定された場合の処理は以下の通りである。

if args.version_compatible and hasattr(trt.BuilderFlag, "VERSION_COMPATIBLE"):
    config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
    if hasattr(parser, "get_used_vc_plugin_libraries") and hasattr(config, "set_plugins_to_serialize"):
        libs = parser.get_used_vc_plugin_libraries()
        if libs:
            config.set_plugins_to_serialize(libs)
            print(f"[INFO] serializing plugin libraries into engine: {libs}")

これはエンジンの互換性を高め、必要に応じてプラグインライブラリの情報をエンジン内にシリアライズする機能である。ただし、動作は TensorRT のバージョンやプラグインの実装に依存するため、通常の開発フェーズではこのフラグなしでビルドを行う方がトラブルシューティングは容易である。

9. Optimization Profile

ONNX 側でバッチ次元が dynamic になっているため、ビルドには optimization profile の設定が不可欠である。

inp = network.get_input(0)
profile = builder.create_optimization_profile()

min_shape = (args.min_batch, args.channels, args.height, args.width)
opt_shape = (args.opt_batch, args.channels, args.height, args.width)
max_shape = (args.max_batch, args.channels, args.height, args.width)

profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
config.add_optimization_profile(profile)

opt_shape は TensorRT が最適化の際に最も重視する形状である。実運用でバッチサイズ 1 の推論がメインであれば --opt-batch 1 とし、複数のバッチを処理するならデフォルトの 8 程度が妥当な設定となる。

10. エンジンのビルドと保存

最後に、エンジンをビルドする。

serialized = builder.build_serialized_network(network, config)
if serialized is None:
    raise RuntimeError("TensorRT engine build failed")

ビルドが成功すれば、指定されたパスにバイト列として .engine ファイルを保存する。

args.engine.parent.mkdir(parents=True, exist_ok=True)
args.engine.write_bytes(bytes(serialized))
print(f"[OK] wrote {args.engine}")

このエンジンファイルは、C++ 側の infer.cpp などで deserializeCudaEngine() を用いてロードされる。推論時にも、デシリアライズの前にビルド時と同じプラグイン .so をロードしておく必要がある点に注意したい。

11. 実行例

plugin モードの ONNX からエンジンを作成する場合

python build_engine.py \
  --onnx model_moe_plugin.onnx \
  --engine model_moe_plugin.engine \
  --plugin ./libcustom_moe_plugin.so \
  --min-batch 1 \
  --opt-batch 8 \
  --max-batch 32 \
  --workspace-gb 4

FP16 エンジンを作成する場合

python build_engine.py \
  --onnx model_moe_plugin_fp16.onnx \
  --engine model_moe_plugin_fp16.engine \
  --plugin ./libcustom_moe_plugin.so \
  --fp16 \
  --min-batch 1 \
  --opt-batch 8 \
  --max-batch 32

### dense モードの ONNX の場合

`CustomMoE` プラグインノードを含まないため、`--plugin` の指定は不要である。

python build_engine.py \ --onnx model_moe_dense.onnx \ --engine model_moe_dense.engine \ --fp16

## 12. 制限事項と注意点

* **入力数は 1 個を想定:** コード上、0 番目の入力テンソルのみを取得しているため、複数入力モデルには拡張が必要である。
* **バッチ以外の次元は固定:** `C/H/W` は引数で固定されており、同一エンジンで解像度を動的に切り替える設計にはなっていない。
* **プロファイルは 1 つ:** 複数のバッチ範囲や解像度を 1 つのエンジンに含める場合は、複数のプロファイルを追加する実装が必要である。
* **プラグインのロード順:** `ctypes.CDLL()` は、必ず `parser.parse()` よりも前に実行される必要がある。
* **環境の整合性:** `libcustom_moe_plugin.so` は、ビルド環境と実行環境の間で TensorRT、CUDA、コンパイラの ABI が一致していなければならない。


## 13. 処理順のまとめ

1. CLI 引数の読み込み
2. TensorRT logger の作成
3. TensorRT 標準プラグインの初期化
4. 必要に応じた `libcustom_moe_plugin.so` のロード
5. explicit batch network の作成
6. ONNX parser によるネットワークの構築
7. BuilderConfig の作成
8. ワークスペース上限の設定
9. FP16 フラグの設定(任意)
10. バージョン互換フラグの設定(任意)
11. Optimization Profile の設定
12. エンジンのビルド実行
13. `.engine` ファイルとしての保存

# まとめ
カスタムノードを含むONNXからTensorRT の serialized engineを生成するエンジンビルドスクリプトについて解説した。
次は、C++でTensorRTを使用した推論処理を実装したい。