前回、PyTorchで実装したMoE対応のSwinTransformerモデルをTensorRTで推論できるように、ONNXにエクスポートする。
最低限、grouped_mmのみをカスタムオペレータとすればよいが、router + dispatch + expert + combineを分離するとTensorRTのグラフとpluginの入出力が複雑になるため、MoE化したMLPをまとめてプラグイン化する。
ソースでは、DroplessMoEMlpをカスタムノードにする。
カスタムノード
以下のようなforwardが空のノードを定義して、symbolicにONNXのオペレータのメタ情報を定義する。
class CustomMoEFunction(torch.autograd.Function): @staticmethod def forward( ctx: Any, x: Tensor, router_w: Tensor, w1: Tensor, b1: Tensor, w2: Tensor, b2: Tensor, top_k: int, num_experts: int, in_features: int, hidden_features: int, ) -> Tensor: # During ONNX tracing the symbolic() method below emits the actual # CustomMoE node. The eager value is needed only for shape propagation, # so avoid tracing the expensive dynamic routing reference here. return torch.empty_like(x) @staticmethod def symbolic( g: Any, x: Any, router_w: Any, w1: Any, b1: Any, w2: Any, b2: Any, top_k: int, num_experts: int, in_features: int, hidden_features: int, ) -> Any: out = g.op( "trt.plugins::CustomMoE", x, router_w, w1, b1, w2, b2, top_k_i=int(top_k), num_experts_i=int(num_experts), in_features_i=int(in_features), hidden_features_i=int(hidden_features), activation_s="gelu", plugin_version_s="1", plugin_namespace_s="", outputs=1, ) out.setType(x.type()) return out
元の MoE weight を custom node に接続する wrapperを用意する。
ONNXにエクスポートする際に、パラメータを書きこむために必要である。
class PluginMoEExportWrapper(nn.Module): def __init__(self, src: nn.Module) -> None: super().__init__() self.in_features = int(src.in_features) self.hidden_features = int(src.hidden_features) self.num_experts = int(src.num_experts) self.top_k = int(src.top_k) # Keep these as registered parameters so torch.onnx.export writes them # as ONNX initializers connected to the custom node. self.router_weight = src.router.weight self.w1 = src.w1 self.b1 = src.b1 self.w2 = src.w2 self.b2 = src.b2 def forward(self, x: Tensor) -> Tensor: return CustomMoEFunction.apply( x, self.router_weight, self.w1, self.b1, self.w2, self.b2, self.top_k, self.num_experts, self.in_features, self.hidden_features, )
元のモデルのノードを置換する。
def replace_moe(model: nn.Module, model_module: Any, mode: str, fp16_experts: bool) -> int: replaced = 0 for name, child in list(model.named_modules()): if isinstance(child, model_module.DroplessMoEMlp): if fp16_experts: child.w1.data = child.w1.data.half() child.b1.data = child.b1.data.half() child.w2.data = child.w2.data.half() child.b2.data = child.b2.data.half() child.router.float() wrapper: nn.Module if mode == "plugin": wrapper = PluginMoEExportWrapper(child) elif mode == "dense": wrapper = DenseMoEExportWrapper(child) else: raise ValueError(f"Unsupported mode: {mode}") set_child(model, name, wrapper) replaced += 1 return replaced
コード全体は、以下の場所にある。
grouped_mmを使わずDense MoEとして出力するdenseモードも実装している。
ONNXエクスポート結果
ONNXにエクスポートした結果を、Netronで確認すると、CustomMoEノードが出力されていることが確認できた。

なお、denseモードでエクスポートすると複雑なグラフがエキスパートの数分だけ出力される。

まとめ
dropless MoEを実装したSwinTransformerをカスタムノードを使用してONNXにエクスポートする処理を実装した。
エクスポート結果で、MoEを実装したMLPがカスタムノードとして出力されていることが確認できた。
次は、TensorRTのプラグインを実装したい。