TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

dropless MoE(Mixture of Experts)を試す その2(ONNXエクスポート)

前回、PyTorchで実装したMoE対応のSwinTransformerモデルをTensorRTで推論できるように、ONNXにエクスポートする。

最低限、grouped_mmのみをカスタムオペレータとすればよいが、router + dispatch + expert + combineを分離するとTensorRTのグラフとpluginの入出力が複雑になるため、MoE化したMLPをまとめてプラグイン化する。
ソースでは、DroplessMoEMlpをカスタムノードにする。

カスタムノード

以下のようなforwardが空のノードを定義して、symbolicにONNXのオペレータのメタ情報を定義する。

class CustomMoEFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        x: Tensor,
        router_w: Tensor,
        w1: Tensor,
        b1: Tensor,
        w2: Tensor,
        b2: Tensor,
        top_k: int,
        num_experts: int,
        in_features: int,
        hidden_features: int,
    ) -> Tensor:
        # During ONNX tracing the symbolic() method below emits the actual
        # CustomMoE node. The eager value is needed only for shape propagation,
        # so avoid tracing the expensive dynamic routing reference here.
        return torch.empty_like(x)

    @staticmethod
    def symbolic(
        g: Any,
        x: Any,
        router_w: Any,
        w1: Any,
        b1: Any,
        w2: Any,
        b2: Any,
        top_k: int,
        num_experts: int,
        in_features: int,
        hidden_features: int,
    ) -> Any:
        out = g.op(
            "trt.plugins::CustomMoE",
            x,
            router_w,
            w1,
            b1,
            w2,
            b2,
            top_k_i=int(top_k),
            num_experts_i=int(num_experts),
            in_features_i=int(in_features),
            hidden_features_i=int(hidden_features),
            activation_s="gelu",
            plugin_version_s="1",
            plugin_namespace_s="",
            outputs=1,
        )
        out.setType(x.type())
        return out

元の MoE weight を custom node に接続する wrapperを用意する。
ONNXにエクスポートする際に、パラメータを書きこむために必要である。

class PluginMoEExportWrapper(nn.Module):
    def __init__(self, src: nn.Module) -> None:
        super().__init__()
        self.in_features = int(src.in_features)
        self.hidden_features = int(src.hidden_features)
        self.num_experts = int(src.num_experts)
        self.top_k = int(src.top_k)
        # Keep these as registered parameters so torch.onnx.export writes them
        # as ONNX initializers connected to the custom node.
        self.router_weight = src.router.weight
        self.w1 = src.w1
        self.b1 = src.b1
        self.w2 = src.w2
        self.b2 = src.b2

    def forward(self, x: Tensor) -> Tensor:
        return CustomMoEFunction.apply(
            x,
            self.router_weight,
            self.w1,
            self.b1,
            self.w2,
            self.b2,
            self.top_k,
            self.num_experts,
            self.in_features,
            self.hidden_features,
        )

元のモデルのノードを置換する。

def replace_moe(model: nn.Module, model_module: Any, mode: str, fp16_experts: bool) -> int:
    replaced = 0
    for name, child in list(model.named_modules()):
        if isinstance(child, model_module.DroplessMoEMlp):
            if fp16_experts:
                child.w1.data = child.w1.data.half()
                child.b1.data = child.b1.data.half()
                child.w2.data = child.w2.data.half()
                child.b2.data = child.b2.data.half()
                child.router.float()
            wrapper: nn.Module
            if mode == "plugin":
                wrapper = PluginMoEExportWrapper(child)
            elif mode == "dense":
                wrapper = DenseMoEExportWrapper(child)
            else:
                raise ValueError(f"Unsupported mode: {mode}")
            set_child(model, name, wrapper)
            replaced += 1
    return replaced

コード全体は、以下の場所にある。
grouped_mmを使わずDense MoEとして出力するdenseモードも実装している。

github.com

ONNXエクスポート結果

ONNXにエクスポートした結果を、Netronで確認すると、CustomMoEノードが出力されていることが確認できた。


なお、denseモードでエクスポートすると複雑なグラフがエキスパートの数分だけ出力される。

まとめ

dropless MoEを実装したSwinTransformerをカスタムノードを使用してONNXにエクスポートする処理を実装した。
エクスポート結果で、MoEを実装したMLPがカスタムノードとして出力されていることが確認できた。
次は、TensorRTのプラグインを実装したい。