TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

ONNXファイルのパース処理

dlshogiは現在TensorRTを使用しているが、TensorRTは再配布できないなどライセンスが厳しいのと、環境構築が大変なため、OpenCLに対応させたいと思っている。

OpenCLでも、NVIDIAGPUのTensorCoreをPTXインラインアセンブリという方法で使用することができるため、上手に実装すればパフォーマンスは同等にできる見込みである。
KataGoのv1.5.0で、OpenCLでのTensorCoreサポートが実装されているため、ソースを流用すればなんとか実装できると思っている。
ただし、活性化関数にswishを使っていたりといった違いがあるため、そのままでは流用できず、KataGoのOpenCLのソースを理解した上で取り込む必要があり、なかなかハードルが高い。

また、dlshogiはモデルファイルのフォーマットにONNXを採用しているため、モデルファイルをパースして、グラフ定義とパラメータを取り出す処理も実装が必要である。

ということで初めの一歩として、ONNXファイルのパース処理を試してみた。

ONNXファイルのフォーマット

ONNXファイルのフォーマットは、公式サイトで公開されている。
onnx/IR.md at main · onnx/onnx · GitHub

フォーマットは、Protocol Buffersで定義されており、onnx.proto3から、各言語向けのパース用コードが生成できる。

Protocol Buffersのインストール

WindowsC++向けProtocol Buffersのインストールを行う。

インストール方法は、公式のページに記載がある。
protobuf/src at master · protocolbuffers/protobuf · GitHub

C++ Installation - Windows」の記載内容にしたがって、vcpkgを使用してインストールする。
vcpkgのインストールは、vcpkgのページに従って行う。
GitHub - microsoft/vcpkg: C++ Library Manager for Windows, Linux, and MacOS

スタティックビルドを行いため、vcpkgからインストールする際に
x64-windows-static
を指定する。
インストールコマンドは以下の通りになる。

>vcpkg install protobuf protobuf:x64-windows-static

.proto3からC++に変換

.proto3からC++のコードへの変換は、protocというツールで行う。
protocは以下の場所にインストールされている。

C:\vcpkg\packages\protobuf_x64-windows\tools\protobuf\protoc.exe

onnxのソースをgit cloneして、以下のコマンドを実行するとC++のコードが生成される。

>git clone https://github.com/onnx/onnx.git
>cd onnx
>C:\vcpkg\packages\protobuf_x64-windows\tools\protobuf\protoc.exe --cpp_out=. onnx.prot3

生成されるソースコード

  • onnx.proto3.pb.cc
  • onnx.proto3.pb.h

コンパイル

生成されたソースをビルドするには、Protocol Buffersのインクルードディレクトリとライブラリディレクトリの設定が必要である。
Vissual Studioのプロジェクトの設定に、
インクルードディレクトリ:「C:\vcpkg\installed\x64-windows-static\include」
ライブラリディレクトリ:「C:\vcpkg\installed\x64-windows-static\lib」
を追加する。
また、リンクするライブラリに、「libprotobuf.lib」を追加する。

ビルド構成がDEBUGの場合は、
ライブラリディレクトリを「C:\\vcpkg\\installed\\x64-windows-static\\debug\\lib」に変更して、リンクするライブラリを「libprotobufd.lib」にする。

また、スタティックリンクするようにランタイムライブラリをスタティック版に設定する。

コンパイルエラーが起きる問題

生成されたソースコードをVisual C++コンパイルすると以下の箇所でコンパイルエラーが発生する。

  ::PROTOBUF_NAMESPACE_ID::internal::memswap<
      PROTOBUF_FIELD_OFFSET(AttributeProto, type_)
      + sizeof(AttributeProto::type_)
      - PROTOBUF_FIELD_OFFSET(AttributeProto, t_)>(
          reinterpret_cast<char*>(&t_),
          reinterpret_cast<char*>(&other->t_));

※同様の処理が何か所かある。

memswapのテンプレート引数が定数でないために発生している。
PROTOBUF_FIELD_OFFSETというマクロ定義の中で、reinterpret_castを使用しており、定数扱いにならないことが原因であった。

Protocol Buffersのバグのようだが、生成されたソースコードの方を修正することにした。
上記の箇所を、テンプレート引数が定数になるように以下のように修正した。

  switch (PROTOBUF_FIELD_OFFSET(AttributeProto, type_)
      + sizeof(AttributeProto::type_)
      - PROTOBUF_FIELD_OFFSET(AttributeProto, t_)) {
  case 0:
      ::PROTOBUF_NAMESPACE_ID::internal::memswap<0>(
              reinterpret_cast<char*>(&t_),
              reinterpret_cast<char*>(&other->t_));
      break;
  case 2:
      ::PROTOBUF_NAMESPACE_ID::internal::memswap<2>(
          reinterpret_cast<char*>(&t_),
          reinterpret_cast<char*>(&other->t_));
      break;
  case 8:
      ::PROTOBUF_NAMESPACE_ID::internal::memswap<8>(
          reinterpret_cast<char*>(&t_),
          reinterpret_cast<char*>(&other->t_));
      break;
  case 16:
      ::PROTOBUF_NAMESPACE_ID::internal::memswap<16>(
          reinterpret_cast<char*>(&t_),
          reinterpret_cast<char*>(&other->t_));
      break;
  }

※同様の修正を複数個所に行う。

パースのサンプルプログラム

ONNXファイルをパースして、内容を表示するサンプルプログラムを作成した。

#include <iostream>
#include <fstream>
#include "onnx.proto3.pb.h"

int main(int argc, char* argv[]) {
	GOOGLE_PROTOBUF_VERIFY_VERSION;

	if (argc != 2) {
		std::cerr << "Usage:  " << argv[0] << " onnx" << std::endl;
		return -1;
	}

	onnx::ModelProto model;

	std::fstream in(argv[1], std::ios::in | std::ios::binary);
	model.ParseFromIstream(&in);

	std::cout
		<< model.ir_version() << "\n"
		<< model.producer_name() << std::endl;

	const auto& graph = model.graph();

	std::cout << "---- inputs ----" << std::endl;
	for (int i = 0; i < graph.input_size(); i++) {
		const auto& input = graph.input(i);
		std::cout << input.name() << "\t";
		const auto& tensor_type = input.type().tensor_type();
		const auto& shape = tensor_type.shape();
		std::cout << tensor_type.elem_type() << "[";
		for (int n = 0; n < shape.dim_size(); n++) {
			if (n != 0) std::cout << ",";
			std::cout << shape.dim(n).dim_value();
		}
		std::cout << "]" << std::endl;
	}

	std::cout << "---- outputs ----" << std::endl;
	for (int i = 0; i < graph.output_size(); i++) {
		const auto& output = graph.output(i);
		std::cout << output.name() << "\t";
		const auto& tensor_type = output.type().tensor_type();
		const auto& shape = tensor_type.shape();
		std::cout << tensor_type.elem_type() << "[";
		for (int n = 0; n < shape.dim_size(); n++) {
			if (n != 0) std::cout << ",";
			std::cout << shape.dim(n).dim_value();
		}
		std::cout << "]" << std::endl;
	}

	std::cout << "---- nodes ----" << std::endl;
	for (int i = 0; i < graph.node_size(); i++) {
		const auto& node = graph.node(i);
		std::cout << node.name() << "\tinputs[";
		for (int n = 0; n < node.input_size(); n++) {
			if (n != 0) std::cout << ",";
			std::cout << node.input(n);
		}
		std::cout << "]\toutputs[";
		for (int n = 0; n < node.output_size(); n++) {
			if (n != 0) std::cout << ",";
			std::cout << node.output(n);
		}
		std::cout << "]\n";
	}

	std::cout << "---- initializers ----" << std::endl;
	for (int i = 0; i < graph.initializer_size(); i++) {
		const auto& initializer = graph.initializer(i);
		std::cout << initializer.name() << "\t";
		std::cout << initializer.data_type() << ":" << initializer.dims_size() << "[";
		for (int n = 0; n < initializer.dims_size(); n++) {
			if (n != 0) std::cout << ",";
			std::cout << initializer.dims(n);
		}
		std::cout << "]" << std::endl;
	}

	return 0;
}

dlshogiのモデルファイルを入力にして実行すると、以下のように表示される。

6
pytorch
---- inputs ----
input1  1[0,62,9,9]
input2  1[0,57,9,9]
---- outputs ----
output_policy   1[0,2187]
output_value    1[0,1]
---- nodes ----
Conv_0  inputs[input1,l1_1_1.weight]    outputs[142]
Conv_1  inputs[input1,l1_1_2.weight]    outputs[143]
Conv_2  inputs[input2,l1_2.weight]      outputs[144]
Add_3   inputs[142,143] outputs[145]
Add_4   inputs[145,144] outputs[146]
BatchNormalization_5    inputs[146,norm1.weight,norm1.bias,norm1.running_mean,norm1.running_var]        outputs[147]
Sigmoid_6       inputs[147]     outputs[148]
Mul_7   inputs[147,148] outputs[149]
Conv_8  inputs[149,256,257]     outputs[255]
Sigmoid_9       inputs[255]     outputs[152]
Mul_10  inputs[255,152] outputs[153]
Conv_11 inputs[153,259,260]     outputs[258]
Add_12  inputs[258,149] outputs[156]
Sigmoid_13      inputs[156]     outputs[157]
Mul_14  inputs[156,157] outputs[158]
Conv_15 inputs[158,262,263]     outputs[261]
Sigmoid_16      inputs[261]     outputs[161]
Mul_17  inputs[261,161] outputs[162]
Conv_18 inputs[162,265,266]     outputs[264]
Add_19  inputs[264,158] outputs[165]
Sigmoid_20      inputs[165]     outputs[166]
Mul_21  inputs[165,166] outputs[167]
Conv_22 inputs[167,268,269]     outputs[267]
Sigmoid_23      inputs[267]     outputs[170]
Mul_24  inputs[267,170] outputs[171]
Conv_25 inputs[171,271,272]     outputs[270]
Add_26  inputs[270,167] outputs[174]
Sigmoid_27      inputs[174]     outputs[175]
Mul_28  inputs[174,175] outputs[176]
Conv_29 inputs[176,274,275]     outputs[273]
Sigmoid_30      inputs[273]     outputs[179]
Mul_31  inputs[273,179] outputs[180]
Conv_32 inputs[180,277,278]     outputs[276]
Add_33  inputs[276,176] outputs[183]
Sigmoid_34      inputs[183]     outputs[184]
Mul_35  inputs[183,184] outputs[185]
Conv_36 inputs[185,280,281]     outputs[279]
Sigmoid_37      inputs[279]     outputs[188]
Mul_38  inputs[279,188] outputs[189]
Conv_39 inputs[189,283,284]     outputs[282]
Add_40  inputs[282,185] outputs[192]
Sigmoid_41      inputs[192]     outputs[193]
Mul_42  inputs[192,193] outputs[194]
Conv_43 inputs[194,286,287]     outputs[285]
Sigmoid_44      inputs[285]     outputs[197]
Mul_45  inputs[285,197] outputs[198]
Conv_46 inputs[198,289,290]     outputs[288]
Add_47  inputs[288,194] outputs[201]
Sigmoid_48      inputs[201]     outputs[202]
Mul_49  inputs[201,202] outputs[203]
Conv_50 inputs[203,292,293]     outputs[291]
Sigmoid_51      inputs[291]     outputs[206]
Mul_52  inputs[291,206] outputs[207]
Conv_53 inputs[207,295,296]     outputs[294]
Add_54  inputs[294,203] outputs[210]
Sigmoid_55      inputs[210]     outputs[211]
Mul_56  inputs[210,211] outputs[212]
Conv_57 inputs[212,298,299]     outputs[297]
Sigmoid_58      inputs[297]     outputs[215]
Mul_59  inputs[297,215] outputs[216]
Conv_60 inputs[216,301,302]     outputs[300]
Add_61  inputs[300,212] outputs[219]
Sigmoid_62      inputs[219]     outputs[220]
Mul_63  inputs[219,220] outputs[221]
Conv_64 inputs[221,304,305]     outputs[303]
Sigmoid_65      inputs[303]     outputs[224]
Mul_66  inputs[303,224] outputs[225]
Conv_67 inputs[225,307,308]     outputs[306]
Add_68  inputs[306,221] outputs[228]
Sigmoid_69      inputs[228]     outputs[229]
Mul_70  inputs[228,229] outputs[230]
Conv_71 inputs[230,310,311]     outputs[309]
Sigmoid_72      inputs[309]     outputs[233]
Mul_73  inputs[309,233] outputs[234]
Conv_74 inputs[234,313,314]     outputs[312]
Add_75  inputs[312,230] outputs[237]
Sigmoid_76      inputs[237]     outputs[238]
Mul_77  inputs[237,238] outputs[239]
Conv_78 inputs[239,l22.weight]  outputs[240]
Constant_79     inputs[]        outputs[241]
Reshape_80      inputs[240,241] outputs[242]
Add_81  inputs[242,l22_2.bias]  outputs[output_policy]
Conv_82 inputs[239,316,317]     outputs[315]
Sigmoid_83      inputs[315]     outputs[246]
Mul_84  inputs[315,246] outputs[247]
Constant_85     inputs[]        outputs[248]
Reshape_86      inputs[247,248] outputs[249]
Gemm_87 inputs[249,l23_v.weight,l23_v.bias]     outputs[250]
Sigmoid_88      inputs[250]     outputs[251]
Mul_89  inputs[250,251] outputs[252]
Gemm_90 inputs[252,l24_v.weight,l24_v.bias]     outputs[253]
Sigmoid_91      inputs[253]     outputs[output_value]
---- initializers ----
256     1:4[192,192,3,3]
257     1:1[192]
259     1:4[192,192,3,3]
260     1:1[192]
262     1:4[192,192,3,3]
263     1:1[192]
265     1:4[192,192,3,3]
266     1:1[192]
268     1:4[192,192,3,3]
269     1:1[192]
271     1:4[192,192,3,3]
272     1:1[192]
274     1:4[192,192,3,3]
275     1:1[192]
277     1:4[192,192,3,3]
278     1:1[192]
280     1:4[192,192,3,3]
281     1:1[192]
283     1:4[192,192,3,3]
284     1:1[192]
286     1:4[192,192,3,3]
287     1:1[192]
289     1:4[192,192,3,3]
290     1:1[192]
292     1:4[192,192,3,3]
293     1:1[192]
295     1:4[192,192,3,3]
296     1:1[192]
298     1:4[192,192,3,3]
299     1:1[192]
301     1:4[192,192,3,3]
302     1:1[192]
304     1:4[192,192,3,3]
305     1:1[192]
307     1:4[192,192,3,3]
308     1:1[192]
310     1:4[192,192,3,3]
311     1:1[192]
313     1:4[192,192,3,3]
314     1:1[192]
316     1:4[27,192,1,1]
317     1:1[27]
l1_1_1.weight   1:4[192,62,3,3]
l1_1_2.weight   1:4[192,62,1,1]
l1_2.weight     1:4[192,57,1,1]
l22.weight      1:4[27,192,1,1]
l22_2.bias      1:1[2187]
l23_v.bias      1:1[256]
l23_v.weight    1:2[256,2187]
l24_v.bias      1:1[1]
l24_v.weight    1:2[1,256]
norm1.bias      1:1[192]
norm1.running_mean      1:1[192]
norm1.running_var       1:1[192]
norm1.weight    1:1[192]

まとめ

OpenCL対応の準備として、ONNXをパースする処理を試した。
パースすることができるようになったので、ONNXファイルからパラメータを取り出して、OpenCLの計算に使用することができる。

ニューラルネットワークの各層の処理は、ONNXファイルのnode一覧の上から順に計算していけばよい。
nodeのOPコードに対応するOpenCLカーネルを呼び出していけばよいが、KataGoのOpenCLカーネルは畳み込みとBNとReLUをまとめていたりするので、汎用的に対応するには何パターンもカーネルを実装しておく必要がある。
現実的には現在のdlshogiで使用するものだけ対応するようになりそうである。

次は、KataGoのOpenCLのソースを流用して簡単なテストをしてみるつもりである。