また、dlshogiはモデルファイルのフォーマットにONNXを採用しているため、モデルファイルをパースして、グラフ定義とパラメータを取り出す処理も実装する必要がある。
パースのサンプルプログラム
ONNXファイルをパースして、内容を表示するサンプルプログラムを作成した。
#include <iostream>
#include <fstream>
#include "onnx.proto3.pb.h"
namespace {

// Writes the elements of `items` to `os`, separated by commas (no trailing comma).
// Accepts any range of streamable elements, e.g. the protobuf repeated fields
// returned by NodeProto::input()/output() and TensorProto::dims().
template <typename Range>
void print_joined(std::ostream& os, const Range& items) {
    const char* sep = "";
    for (const auto& item : items) {
        os << sep << item;
        sep = ",";
    }
}

// Prints one line per graph input/output entry in the form
//   <name>\t<elem_type>[<dim>,<dim>,...]
// elem_type is printed as its numeric TensorProto.DataType value (1 = FLOAT).
// Only dim_value() is printed, so a dimension that has no dim_value set
// (e.g. a symbolic/variable batch dimension) prints as 0.
template <typename ValueInfos>
void print_value_infos(const ValueInfos& infos) {
    for (const auto& info : infos) {
        const auto& tensor_type = info.type().tensor_type();
        std::cout << info.name() << "\t" << tensor_type.elem_type() << "[";
        const char* sep = "";
        for (const auto& dim : tensor_type.shape().dim()) {
            std::cout << sep << dim.dim_value();
            sep = ",";
        }
        std::cout << "]\n";
    }
}

}  // namespace

// Dumps the structure of the ONNX model file given as the only argument:
// IR version, producer name, graph inputs/outputs, nodes and initializers.
// Returns 0 on success, -1 on bad usage or if the file cannot be opened/parsed.
int main(int argc, char* argv[]) {
    // Check that the protobuf headers and the linked runtime are compatible.
    GOOGLE_PROTOBUF_VERIFY_VERSION;

    if (argc != 2) {
        std::cerr << "Usage: " << argv[0] << " onnx" << std::endl;
        return -1;
    }

    // Fail loudly: without these checks a missing or corrupt file would be
    // reported as an empty model with exit status 0.
    std::ifstream in(argv[1], std::ios::binary);
    if (!in) {
        std::cerr << "Failed to open: " << argv[1] << std::endl;
        return -1;
    }
    onnx::ModelProto model;
    if (!model.ParseFromIstream(&in)) {
        std::cerr << "Failed to parse ONNX model: " << argv[1] << std::endl;
        return -1;
    }

    std::cout << model.ir_version() << "\n"
              << model.producer_name() << "\n";

    const auto& graph = model.graph();

    std::cout << "---- inputs ----\n";
    print_value_infos(graph.input());

    std::cout << "---- outputs ----\n";
    print_value_infos(graph.output());

    // Each node: name, then the tensor names it consumes and produces.
    std::cout << "---- nodes ----\n";
    for (const auto& node : graph.node()) {
        std::cout << node.name() << "\tinputs[";
        print_joined(std::cout, node.input());
        std::cout << "]\toutputs[";
        print_joined(std::cout, node.output());
        std::cout << "]\n";
    }

    // Each initializer (weight tensor): name, data type, rank, then dims.
    std::cout << "---- initializers ----\n";
    for (const auto& initializer : graph.initializer()) {
        std::cout << initializer.name() << "\t"
                  << initializer.data_type() << ":" << initializer.dims_size() << "[";
        print_joined(std::cout, initializer.dims());
        std::cout << "]\n";
    }
    return 0;
}
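dlshogiのモデルファイルを入力にして実行すると、以下のように表示される。なお、要素型の「1」はFLOATを表す。入出力の先頭の次元が0と表示されているのは、バッチサイズの次元に固定値（dim_value）が設定されていない（可変サイズとしてエクスポートされている）ためである。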
6
pytorch
---- inputs ----
input1 1[0,62,9,9]
input2 1[0,57,9,9]
---- outputs ----
output_policy 1[0,2187]
output_value 1[0,1]
---- nodes ----
Conv_0 inputs[input1,l1_1_1.weight] outputs[142]
Conv_1 inputs[input1,l1_1_2.weight] outputs[143]
Conv_2 inputs[input2,l1_2.weight] outputs[144]
Add_3 inputs[142,143] outputs[145]
Add_4 inputs[145,144] outputs[146]
BatchNormalization_5 inputs[146,norm1.weight,norm1.bias,norm1.running_mean,norm1.running_var] outputs[147]
Sigmoid_6 inputs[147] outputs[148]
Mul_7 inputs[147,148] outputs[149]
Conv_8 inputs[149,256,257] outputs[255]
Sigmoid_9 inputs[255] outputs[152]
Mul_10 inputs[255,152] outputs[153]
Conv_11 inputs[153,259,260] outputs[258]
Add_12 inputs[258,149] outputs[156]
Sigmoid_13 inputs[156] outputs[157]
Mul_14 inputs[156,157] outputs[158]
Conv_15 inputs[158,262,263] outputs[261]
Sigmoid_16 inputs[261] outputs[161]
Mul_17 inputs[261,161] outputs[162]
Conv_18 inputs[162,265,266] outputs[264]
Add_19 inputs[264,158] outputs[165]
Sigmoid_20 inputs[165] outputs[166]
Mul_21 inputs[165,166] outputs[167]
Conv_22 inputs[167,268,269] outputs[267]
Sigmoid_23 inputs[267] outputs[170]
Mul_24 inputs[267,170] outputs[171]
Conv_25 inputs[171,271,272] outputs[270]
Add_26 inputs[270,167] outputs[174]
Sigmoid_27 inputs[174] outputs[175]
Mul_28 inputs[174,175] outputs[176]
Conv_29 inputs[176,274,275] outputs[273]
Sigmoid_30 inputs[273] outputs[179]
Mul_31 inputs[273,179] outputs[180]
Conv_32 inputs[180,277,278] outputs[276]
Add_33 inputs[276,176] outputs[183]
Sigmoid_34 inputs[183] outputs[184]
Mul_35 inputs[183,184] outputs[185]
Conv_36 inputs[185,280,281] outputs[279]
Sigmoid_37 inputs[279] outputs[188]
Mul_38 inputs[279,188] outputs[189]
Conv_39 inputs[189,283,284] outputs[282]
Add_40 inputs[282,185] outputs[192]
Sigmoid_41 inputs[192] outputs[193]
Mul_42 inputs[192,193] outputs[194]
Conv_43 inputs[194,286,287] outputs[285]
Sigmoid_44 inputs[285] outputs[197]
Mul_45 inputs[285,197] outputs[198]
Conv_46 inputs[198,289,290] outputs[288]
Add_47 inputs[288,194] outputs[201]
Sigmoid_48 inputs[201] outputs[202]
Mul_49 inputs[201,202] outputs[203]
Conv_50 inputs[203,292,293] outputs[291]
Sigmoid_51 inputs[291] outputs[206]
Mul_52 inputs[291,206] outputs[207]
Conv_53 inputs[207,295,296] outputs[294]
Add_54 inputs[294,203] outputs[210]
Sigmoid_55 inputs[210] outputs[211]
Mul_56 inputs[210,211] outputs[212]
Conv_57 inputs[212,298,299] outputs[297]
Sigmoid_58 inputs[297] outputs[215]
Mul_59 inputs[297,215] outputs[216]
Conv_60 inputs[216,301,302] outputs[300]
Add_61 inputs[300,212] outputs[219]
Sigmoid_62 inputs[219] outputs[220]
Mul_63 inputs[219,220] outputs[221]
Conv_64 inputs[221,304,305] outputs[303]
Sigmoid_65 inputs[303] outputs[224]
Mul_66 inputs[303,224] outputs[225]
Conv_67 inputs[225,307,308] outputs[306]
Add_68 inputs[306,221] outputs[228]
Sigmoid_69 inputs[228] outputs[229]
Mul_70 inputs[228,229] outputs[230]
Conv_71 inputs[230,310,311] outputs[309]
Sigmoid_72 inputs[309] outputs[233]
Mul_73 inputs[309,233] outputs[234]
Conv_74 inputs[234,313,314] outputs[312]
Add_75 inputs[312,230] outputs[237]
Sigmoid_76 inputs[237] outputs[238]
Mul_77 inputs[237,238] outputs[239]
Conv_78 inputs[239,l22.weight] outputs[240]
Constant_79 inputs[] outputs[241]
Reshape_80 inputs[240,241] outputs[242]
Add_81 inputs[242,l22_2.bias] outputs[output_policy]
Conv_82 inputs[239,316,317] outputs[315]
Sigmoid_83 inputs[315] outputs[246]
Mul_84 inputs[315,246] outputs[247]
Constant_85 inputs[] outputs[248]
Reshape_86 inputs[247,248] outputs[249]
Gemm_87 inputs[249,l23_v.weight,l23_v.bias] outputs[250]
Sigmoid_88 inputs[250] outputs[251]
Mul_89 inputs[250,251] outputs[252]
Gemm_90 inputs[252,l24_v.weight,l24_v.bias] outputs[253]
Sigmoid_91 inputs[253] outputs[output_value]
---- initializers ----
256 1:4[192,192,3,3]
257 1:1[192]
259 1:4[192,192,3,3]
260 1:1[192]
262 1:4[192,192,3,3]
263 1:1[192]
265 1:4[192,192,3,3]
266 1:1[192]
268 1:4[192,192,3,3]
269 1:1[192]
271 1:4[192,192,3,3]
272 1:1[192]
274 1:4[192,192,3,3]
275 1:1[192]
277 1:4[192,192,3,3]
278 1:1[192]
280 1:4[192,192,3,3]
281 1:1[192]
283 1:4[192,192,3,3]
284 1:1[192]
286 1:4[192,192,3,3]
287 1:1[192]
289 1:4[192,192,3,3]
290 1:1[192]
292 1:4[192,192,3,3]
293 1:1[192]
295 1:4[192,192,3,3]
296 1:1[192]
298 1:4[192,192,3,3]
299 1:1[192]
301 1:4[192,192,3,3]
302 1:1[192]
304 1:4[192,192,3,3]
305 1:1[192]
307 1:4[192,192,3,3]
308 1:1[192]
310 1:4[192,192,3,3]
311 1:1[192]
313 1:4[192,192,3,3]
314 1:1[192]
316 1:4[27,192,1,1]
317 1:1[27]
l1_1_1.weight 1:4[192,62,3,3]
l1_1_2.weight 1:4[192,62,1,1]
l1_2.weight 1:4[192,57,1,1]
l22.weight 1:4[27,192,1,1]
l22_2.bias 1:1[2187]
l23_v.bias 1:1[256]
l23_v.weight 1:2[256,2187]
l24_v.bias 1:1[1]
l24_v.weight 1:2[1,256]
norm1.bias 1:1[192]
norm1.running_mean 1:1[192]
norm1.running_var 1:1[192]
norm1.weight 1:1[192]