dlshogiは現在TensorRTを使用しているが、TensorRTは再配布できないなどライセンスが厳しいのと、環境構築が大変なため、OpenCLに対応させたいと思っている。
OpenCLでも、NVIDIAのGPUのTensorCoreをPTXインラインアセンブリという方法で使用することができるため、上手に実装すればパフォーマンスは同等にできる見込みである。
KataGoのv1.5.0で、OpenCLでのTensorCoreサポートが実装されているため、ソースを流用すればなんとか実装できると思っている。
ただし、活性化関数にswishを使っていたりといった違いがあるため、そのままでは流用できず、KataGoのOpenCLのソースを理解した上で取り込む必要があり、なかなかハードルが高い。
また、dlshogiはモデルファイルのフォーマットにONNXを採用しているため、モデルファイルをパースして、グラフ定義とパラメータを取り出す処理も実装が必要である。
ということで初めの一歩として、ONNXファイルのパース処理を試してみた。
ONNXファイルのフォーマット
ONNXファイルのフォーマットは、公式サイトで公開されている。
onnx/IR.md at main · onnx/onnx · GitHub
フォーマットは、Protocol Buffersで定義されており、onnx.proto3から、各言語向けのパース用コードが生成できる。
Protocol Buffersのインストール
WindowsにC++向けProtocol Buffersのインストールを行う。
インストール方法は、公式のページに記載がある。
protobuf/src at master · protocolbuffers/protobuf · GitHub
「C++ Installation - Windows」の記載内容にしたがって、vcpkgを使用してインストールする。
vcpkgのインストールは、vcpkgのページに従って行う。
GitHub - microsoft/vcpkg: C++ Library Manager for Windows, Linux, and MacOS
スタティックビルドを行いため、vcpkgからインストールする際に
x64-windows-static
を指定する。
インストールコマンドは以下の通りになる。
>vcpkg install protobuf protobuf:x64-windows-static
.proto3からC++に変換
.proto3からC++のコードへの変換は、protocというツールで行う。
protocは以下の場所にインストールされている。
C:\vcpkg\packages\protobuf_x64-windows\tools\protobuf\protoc.exe
onnxのソースをgit cloneして、以下のコマンドを実行するとC++のコードが生成される。
>git clone https://github.com/onnx/onnx.git >cd onnx >C:\vcpkg\packages\protobuf_x64-windows\tools\protobuf\protoc.exe --cpp_out=. onnx.prot3
生成されるソースコード
- onnx.proto3.pb.cc
- onnx.proto3.pb.h
コンパイル
生成されたソースをビルドするには、Protocol Buffersのインクルードディレクトリとライブラリディレクトリの設定が必要である。
Vissual Studioのプロジェクトの設定に、
インクルードディレクトリ:「C:\vcpkg\installed\x64-windows-static\include」
ライブラリディレクトリ:「C:\vcpkg\installed\x64-windows-static\lib」
を追加する。
また、リンクするライブラリに、「libprotobuf.lib」を追加する。
ビルド構成がDEBUGの場合は、
ライブラリディレクトリを「C:\\vcpkg\\installed\\x64-windows-static\\debug\\lib」に変更して、リンクするライブラリを「libprotobufd.lib」にする。
また、スタティックリンクするようにランタイムライブラリをスタティック版に設定する。
コンパイルエラーが起きる問題
生成されたソースコードをVisual C++でコンパイルすると以下の箇所でコンパイルエラーが発生する。
::PROTOBUF_NAMESPACE_ID::internal::memswap< PROTOBUF_FIELD_OFFSET(AttributeProto, type_) + sizeof(AttributeProto::type_) - PROTOBUF_FIELD_OFFSET(AttributeProto, t_)>( reinterpret_cast<char*>(&t_), reinterpret_cast<char*>(&other->t_));
※同様の処理が何か所かある。
memswapのテンプレート引数が定数でないために発生している。
PROTOBUF_FIELD_OFFSETというマクロ定義の中で、reinterpret_castを使用しており、定数扱いにならないことが原因であった。
Protocol Buffersのバグのようだが、生成されたソースコードの方を修正することにした。
上記の箇所を、テンプレート引数が定数になるように以下のように修正した。
switch (PROTOBUF_FIELD_OFFSET(AttributeProto, type_) + sizeof(AttributeProto::type_) - PROTOBUF_FIELD_OFFSET(AttributeProto, t_)) { case 0: ::PROTOBUF_NAMESPACE_ID::internal::memswap<0>( reinterpret_cast<char*>(&t_), reinterpret_cast<char*>(&other->t_)); break; case 2: ::PROTOBUF_NAMESPACE_ID::internal::memswap<2>( reinterpret_cast<char*>(&t_), reinterpret_cast<char*>(&other->t_)); break; case 8: ::PROTOBUF_NAMESPACE_ID::internal::memswap<8>( reinterpret_cast<char*>(&t_), reinterpret_cast<char*>(&other->t_)); break; case 16: ::PROTOBUF_NAMESPACE_ID::internal::memswap<16>( reinterpret_cast<char*>(&t_), reinterpret_cast<char*>(&other->t_)); break; }
※同様の修正を複数個所に行う。
パースのサンプルプログラム
ONNXファイルをパースして、内容を表示するサンプルプログラムを作成した。
#include <iostream> #include <fstream> #include "onnx.proto3.pb.h" int main(int argc, char* argv[]) { GOOGLE_PROTOBUF_VERIFY_VERSION; if (argc != 2) { std::cerr << "Usage: " << argv[0] << " onnx" << std::endl; return -1; } onnx::ModelProto model; std::fstream in(argv[1], std::ios::in | std::ios::binary); model.ParseFromIstream(&in); std::cout << model.ir_version() << "\n" << model.producer_name() << std::endl; const auto& graph = model.graph(); std::cout << "---- inputs ----" << std::endl; for (int i = 0; i < graph.input_size(); i++) { const auto& input = graph.input(i); std::cout << input.name() << "\t"; const auto& tensor_type = input.type().tensor_type(); const auto& shape = tensor_type.shape(); std::cout << tensor_type.elem_type() << "["; for (int n = 0; n < shape.dim_size(); n++) { if (n != 0) std::cout << ","; std::cout << shape.dim(n).dim_value(); } std::cout << "]" << std::endl; } std::cout << "---- outputs ----" << std::endl; for (int i = 0; i < graph.output_size(); i++) { const auto& output = graph.output(i); std::cout << output.name() << "\t"; const auto& tensor_type = output.type().tensor_type(); const auto& shape = tensor_type.shape(); std::cout << tensor_type.elem_type() << "["; for (int n = 0; n < shape.dim_size(); n++) { if (n != 0) std::cout << ","; std::cout << shape.dim(n).dim_value(); } std::cout << "]" << std::endl; } std::cout << "---- nodes ----" << std::endl; for (int i = 0; i < graph.node_size(); i++) { const auto& node = graph.node(i); std::cout << node.name() << "\tinputs["; for (int n = 0; n < node.input_size(); n++) { if (n != 0) std::cout << ","; std::cout << node.input(n); } std::cout << "]\toutputs["; for (int n = 0; n < node.output_size(); n++) { if (n != 0) std::cout << ","; std::cout << node.output(n); } std::cout << "]\n"; } std::cout << "---- initializers ----" << std::endl; for (int i = 0; i < graph.initializer_size(); i++) { const auto& initializer = graph.initializer(i); std::cout << initializer.name() << "\t"; std::cout << initializer.data_type() << ":" << initializer.dims_size() << "["; for (int n = 0; n < initializer.dims_size(); n++) { if (n != 0) std::cout << ","; std::cout << initializer.dims(n); } std::cout << "]" << std::endl; } return 0; }
dlshogiのモデルファイルを入力にして実行すると、以下のように表示される。
6 pytorch ---- inputs ---- input1 1[0,62,9,9] input2 1[0,57,9,9] ---- outputs ---- output_policy 1[0,2187] output_value 1[0,1] ---- nodes ---- Conv_0 inputs[input1,l1_1_1.weight] outputs[142] Conv_1 inputs[input1,l1_1_2.weight] outputs[143] Conv_2 inputs[input2,l1_2.weight] outputs[144] Add_3 inputs[142,143] outputs[145] Add_4 inputs[145,144] outputs[146] BatchNormalization_5 inputs[146,norm1.weight,norm1.bias,norm1.running_mean,norm1.running_var] outputs[147] Sigmoid_6 inputs[147] outputs[148] Mul_7 inputs[147,148] outputs[149] Conv_8 inputs[149,256,257] outputs[255] Sigmoid_9 inputs[255] outputs[152] Mul_10 inputs[255,152] outputs[153] Conv_11 inputs[153,259,260] outputs[258] Add_12 inputs[258,149] outputs[156] Sigmoid_13 inputs[156] outputs[157] Mul_14 inputs[156,157] outputs[158] Conv_15 inputs[158,262,263] outputs[261] Sigmoid_16 inputs[261] outputs[161] Mul_17 inputs[261,161] outputs[162] Conv_18 inputs[162,265,266] outputs[264] Add_19 inputs[264,158] outputs[165] Sigmoid_20 inputs[165] outputs[166] Mul_21 inputs[165,166] outputs[167] Conv_22 inputs[167,268,269] outputs[267] Sigmoid_23 inputs[267] outputs[170] Mul_24 inputs[267,170] outputs[171] Conv_25 inputs[171,271,272] outputs[270] Add_26 inputs[270,167] outputs[174] Sigmoid_27 inputs[174] outputs[175] Mul_28 inputs[174,175] outputs[176] Conv_29 inputs[176,274,275] outputs[273] Sigmoid_30 inputs[273] outputs[179] Mul_31 inputs[273,179] outputs[180] Conv_32 inputs[180,277,278] outputs[276] Add_33 inputs[276,176] outputs[183] Sigmoid_34 inputs[183] outputs[184] Mul_35 inputs[183,184] outputs[185] Conv_36 inputs[185,280,281] outputs[279] Sigmoid_37 inputs[279] outputs[188] Mul_38 inputs[279,188] outputs[189] Conv_39 inputs[189,283,284] outputs[282] Add_40 inputs[282,185] outputs[192] Sigmoid_41 inputs[192] outputs[193] Mul_42 inputs[192,193] outputs[194] Conv_43 inputs[194,286,287] outputs[285] Sigmoid_44 inputs[285] outputs[197] Mul_45 inputs[285,197] outputs[198] Conv_46 inputs[198,289,290] outputs[288] Add_47 inputs[288,194] outputs[201] Sigmoid_48 inputs[201] outputs[202] Mul_49 inputs[201,202] outputs[203] Conv_50 inputs[203,292,293] outputs[291] Sigmoid_51 inputs[291] outputs[206] Mul_52 inputs[291,206] outputs[207] Conv_53 inputs[207,295,296] outputs[294] Add_54 inputs[294,203] outputs[210] Sigmoid_55 inputs[210] outputs[211] Mul_56 inputs[210,211] outputs[212] Conv_57 inputs[212,298,299] outputs[297] Sigmoid_58 inputs[297] outputs[215] Mul_59 inputs[297,215] outputs[216] Conv_60 inputs[216,301,302] outputs[300] Add_61 inputs[300,212] outputs[219] Sigmoid_62 inputs[219] outputs[220] Mul_63 inputs[219,220] outputs[221] Conv_64 inputs[221,304,305] outputs[303] Sigmoid_65 inputs[303] outputs[224] Mul_66 inputs[303,224] outputs[225] Conv_67 inputs[225,307,308] outputs[306] Add_68 inputs[306,221] outputs[228] Sigmoid_69 inputs[228] outputs[229] Mul_70 inputs[228,229] outputs[230] Conv_71 inputs[230,310,311] outputs[309] Sigmoid_72 inputs[309] outputs[233] Mul_73 inputs[309,233] outputs[234] Conv_74 inputs[234,313,314] outputs[312] Add_75 inputs[312,230] outputs[237] Sigmoid_76 inputs[237] outputs[238] Mul_77 inputs[237,238] outputs[239] Conv_78 inputs[239,l22.weight] outputs[240] Constant_79 inputs[] outputs[241] Reshape_80 inputs[240,241] outputs[242] Add_81 inputs[242,l22_2.bias] outputs[output_policy] Conv_82 inputs[239,316,317] outputs[315] Sigmoid_83 inputs[315] outputs[246] Mul_84 inputs[315,246] outputs[247] Constant_85 inputs[] outputs[248] Reshape_86 inputs[247,248] outputs[249] Gemm_87 inputs[249,l23_v.weight,l23_v.bias] outputs[250] Sigmoid_88 inputs[250] outputs[251] Mul_89 inputs[250,251] outputs[252] Gemm_90 inputs[252,l24_v.weight,l24_v.bias] outputs[253] Sigmoid_91 inputs[253] outputs[output_value] ---- initializers ---- 256 1:4[192,192,3,3] 257 1:1[192] 259 1:4[192,192,3,3] 260 1:1[192] 262 1:4[192,192,3,3] 263 1:1[192] 265 1:4[192,192,3,3] 266 1:1[192] 268 1:4[192,192,3,3] 269 1:1[192] 271 1:4[192,192,3,3] 272 1:1[192] 274 1:4[192,192,3,3] 275 1:1[192] 277 1:4[192,192,3,3] 278 1:1[192] 280 1:4[192,192,3,3] 281 1:1[192] 283 1:4[192,192,3,3] 284 1:1[192] 286 1:4[192,192,3,3] 287 1:1[192] 289 1:4[192,192,3,3] 290 1:1[192] 292 1:4[192,192,3,3] 293 1:1[192] 295 1:4[192,192,3,3] 296 1:1[192] 298 1:4[192,192,3,3] 299 1:1[192] 301 1:4[192,192,3,3] 302 1:1[192] 304 1:4[192,192,3,3] 305 1:1[192] 307 1:4[192,192,3,3] 308 1:1[192] 310 1:4[192,192,3,3] 311 1:1[192] 313 1:4[192,192,3,3] 314 1:1[192] 316 1:4[27,192,1,1] 317 1:1[27] l1_1_1.weight 1:4[192,62,3,3] l1_1_2.weight 1:4[192,62,1,1] l1_2.weight 1:4[192,57,1,1] l22.weight 1:4[27,192,1,1] l22_2.bias 1:1[2187] l23_v.bias 1:1[256] l23_v.weight 1:2[256,2187] l24_v.bias 1:1[1] l24_v.weight 1:2[1,256] norm1.bias 1:1[192] norm1.running_mean 1:1[192] norm1.running_var 1:1[192] norm1.weight 1:1[192]
まとめ
OpenCL対応の準備として、ONNXをパースする処理を試した。
パースすることができるようになったので、ONNXファイルからパラメータを取り出して、OpenCLの計算に使用することができる。
ニューラルネットワークの各層の処理は、ONNXファイルのnode一覧の上から順に計算していけばよい。
nodeのOPコードに対応するOpenCLのカーネルを呼び出していけばよいが、KataGoのOpenCLのカーネルは畳み込みとBNとReLUをまとめていたりするので、汎用的に対応するには何パターンもカーネルを実装しておく必要がある。
現実的には現在のdlshogiで使用するものだけ対応するようになりそうである。
次は、KataGoのOpenCLのソースを流用して簡単なテストをしてみるつもりである。