TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

WhisperのモデルをONNXにする

WhisperのモデルをONNXに変換する方法について記述する。

Whisperのモデル

WhisperのモデルはPyTorchを使ってPythonで実装されている。
そのため、実行にはPyTorchをインストールしたPython環境が必要になる。
環境構築なしでスタンドアロンで利用できると用途が広がる。
また、アプリへの組み込みも行いやすくなる。

ONNXモデル

ONNXは、ニューラルネットワークの標準ファイルフォーマットである。
モデルをONNXにすると、ONNX Runtimeなどの推論用のライブラリを使って推論できる。
推論用のライブラリは、組み込みで使うことを意図しているので、スタンドアロンのアプリに組み込むことができる。

ONNXへの変換

WhisperのモデルからONNXへの変換は、pytorch.onnxを使って行う。
ただし、Whisperは、デコーダのループ処理で、前の演算結果を再利用する処理で、kv_cacheという辞書型のオブジェクトを使用しているため単純には変換できない。
kv_cacheをTensor型として入力するように変更が必要になる。

実装方法は、以下の記事を参考にした。
音声認識AIのWhisperをUnreal Engineでリアルタイムに動かすためにやったこと
この記事に書かれているコードは、torch.catの部分がうまく動かなかったため、kv_cacheの領域をあらかじめ確保する方法で実装した。

エンコーダの変更

エンコーダのモデルは、そのままpytorch.onnxで変換できるが、デコーダ側の効率化のために、エンコーダで変換した特徴マップをそのまま出力するのではなく、特徴マップを、デコーダ側のクロスアテンションで使用するkeyとvalueにあらかじめ変換して返すようにする。

        n_layer_cross_k_list = []
        n_layer_cross_v_list = []
        for block in self.textDecoder.blocks:
            n_layer_cross_k_list.append(block.cross_attn.key(audio_features))
            n_layer_cross_v_list.append(block.cross_attn.value(audio_features))

        return torch.stack(n_layer_cross_k_list), torch.stack(n_layer_cross_v_list)
デコーダの入力変更

デコーダの入力を以下のように変更する。

    def forward(self, tokens: Tensor,
                n_layer_self_k_cache: Tensor,
                n_layer_self_v_cache: Tensor,
                n_layer_cross_k: Tensor,
                n_layer_cross_v: Tensor,
                offset: Tensor,
                ):

n_layer_self_k_cacheと、n_layer_self_v_cacheは、それぞれセルフアテンションで使うkeyとvalueで、ループのたびに末尾にkeyとvalueを追加する。
呼び出し側で、n_text_ctxの分だけ領域をあらかじめ確保して入力する。

ResidualAttentionBlockに渡す前には、スライスして渡す。

            self_k_cache = n_layer_self_k_cache[i,:,:offset[0] + tokens.shape[-1],:]
            self_v_cache = n_layer_self_v_cache[i,:,:offset[0] + tokens.shape[-1],:]
セルフアテンションの変更

セルフアテンションでは、k_cacheとv_cacheの末尾に、keyとvalueを追加する。

        k_cache[:,-k.shape[1]:,:] = k
        v_cache[:,-v.shape[1]:,:] = v
クロスアテンションの変更

クロスアテンションのkeyとvalueは、エンコーダで事前に計算したkeyとvalueを使う。
それぞれ、引数n_layer_cross_kとn_layer_cross_vでエンコーダの出力から受け取る。

ONNXへ変換

以上の変更を行い。
エンコーダ、デコーダを別々にpytorch.onnxで変換する。

torch.onnx.export(
    encoder,
    mel,
    "encoder.onnx",
    verbose=True,
    input_names=['mel'],
    output_names=['n_layer_cross_k', 'n_layer_cross_v'])
torch.onnx.export(
    decoder,
    (tokens, n_layer_self_k_cache, n_layer_self_v_cache, n_layer_cross_k, n_layer_cross_v, offset),
    "decoder.onnx",
    verbose=True,
    input_names=['tokens', 'in_n_layer_self_k_cache', 'in_n_layer_self_v_cache', 'n_layer_cross_k', 'n_layer_cross_v', 'offset'],
    output_names=['logits', 'out_n_layer_self_k_cache', 'out_n_layer_self_v_cache'],
    dynamic_axes={
                'tokens' : { 0: 'n_audio', 1 : 'n_tokens' },
                'in_n_layer_self_k_cache' : { 1: 'n_audio' },
                'in_n_layer_self_v_cache' : { 1: 'n_audio' },
                })

テスト

変換したONNXで推論できるかテストする。
Whisperで音声からテキストにする処理は、モデルの推論のみだけではなく、音声をメルスペクトログラムにする処理や、モデルの出力を埋め込みから単語に変換する処理なども実装が必要である。

モデルの推論のみをテストするため、Whisperのdecoding.pyを改造して、ONNXによる推論に書き換え、改造前の結果と一致するかを確認する。

モデルの推論処理は、_get_audio_featuresと、_detect_language、_main_loopの3か所にある。

_get_audio_featuresの変更

_get_audio_featuresを以下のように変更する。

        io_binding = self.encoder_session.io_binding()
        io_binding.bind_input('mel', device_type='cuda', device_id=0, element_type=np.float32, shape=mel.shape, buffer_ptr=mel.data_ptr())
        io_binding.bind_output('n_layer_cross_k', device_type='cuda')
        io_binding.bind_output('n_layer_cross_v', device_type='cuda')

        self.encoder_session.run_with_iobinding(io_binding)

        n_layer_cross_k, n_layer_cross_v = io_binding.get_outputs()

        return n_layer_cross_k, n_layer_cross_v
_detect_languageの変更
        logits = self.model.logits(x, mel)[:, 0]

を以下のように変更する。

        n_layer_self_k_cache = onnxruntime.OrtValue.ortvalue_from_shape_and_type((len(self.model.decoder.blocks), n_audio, self.model.dims.n_text_ctx, self.model.dims.n_text_state), element_type=np.float32, device_type='cuda', device_id=0)
        n_layer_self_v_cache = onnxruntime.OrtValue.ortvalue_from_shape_and_type((len(self.model.decoder.blocks), n_audio, self.model.dims.n_text_ctx, self.model.dims.n_text_state), element_type=np.float32, device_type='cuda', device_id=0)
        offset = torch.zeros(1, dtype=torch.int64)

        io_binding = self.decoder_session.io_binding()
        io_binding.bind_input('tokens', device_type='cuda', device_id=0, element_type=np.int64, shape=x.shape, buffer_ptr=x.data_ptr())
        io_binding.bind_ortvalue_input('in_n_layer_self_k_cache', n_layer_self_k_cache)
        io_binding.bind_ortvalue_input('in_n_layer_self_v_cache', n_layer_self_v_cache)
        io_binding.bind_ortvalue_input('n_layer_cross_k', n_layer_cross_k)
        io_binding.bind_ortvalue_input('n_layer_cross_v', n_layer_cross_v)
        io_binding.bind_cpu_input('offset', offset.numpy())
        io_binding.bind_output('logits')
        io_binding.bind_output('out_n_layer_self_k_cache')
        io_binding.bind_output('out_n_layer_self_v_cache')

        self.decoder_session.run_with_iobinding(io_binding)

        logits_onnx, n_layer_self_k_cache, n_layer_self_v_cache = io_binding.get_outputs()
        logits = torch.from_numpy(logits_onnx.numpy()[:, 0])
_main_loopの変更
                logits = self.inference.logits(tokens, audio_features)

を以下のように変更する。

                if tokens.shape[-1] > self.inference.initial_token_length:
                    # only need to use the last token except in the first forward pass
                    offset = np.array([tokens.shape[1] - 1], np.int64)
                    tokens_onnx = tokens[:, -1:]
                else:
                    offset = np.zeros(1, np.int64)
                    tokens_onnx = tokens

                io_binding = self.decoder_session.io_binding()
                io_binding.bind_input('tokens', device_type='cuda', device_id=0, element_type=np.int64, shape=tokens_onnx.shape, buffer_ptr=tokens_onnx.data_ptr())
                io_binding.bind_ortvalue_input('in_n_layer_self_k_cache', n_layer_self_k_cache)
                io_binding.bind_ortvalue_input('in_n_layer_self_v_cache', n_layer_self_v_cache)
                io_binding.bind_ortvalue_input('n_layer_cross_k', n_layer_cross_k)
                io_binding.bind_ortvalue_input('n_layer_cross_v', n_layer_cross_v)
                io_binding.bind_cpu_input('offset', offset)
                io_binding.bind_output('logits')
                io_binding.bind_output('out_n_layer_self_k_cache')
                io_binding.bind_output('out_n_layer_self_v_cache')

                self.decoder_session.run_with_iobinding(io_binding)

                logits_onnx, n_layer_self_k_cache, n_layer_self_v_cache = io_binding.get_outputs()
                logits = torch.from_numpy(logits_onnx.numpy()).to(tokens.device)

テストを実行した結果、変更前と同一の音声で同じ結果になることが確認できた。

まとめ

WhisperのモデルをONNXにする方法について記述した。

PyTorchとONNXでtorch.catの動作が異なるため、試行錯誤が必要だった。
また、スライスしたTensorを更新する場合も、PyTorchはinplaceで行われるが、ONNXだと新規メモリ割り当てが行われるという違いがあり、デバッグに時間がかかった。
ONNXの内部では、デバッガが使えないため、エラーメッセージから原因を推測して、トライ&エラーを繰り返す必要があった。

ONNXに変換できたので、スタンドアロンでリアルタイムに音声をテキストに変換するツールを作成してみたい。
また、ONNXモデルの最適化も試したい。

続き
tadaoyamaoka.hatenablog.com