ONNX是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。它使得不同的人工智能框架(如Pytorch, MXNet)可以采用相同格式存储模型数据并交互。 ONNX的规范及代码主要由微软,亚马逊 ,Facebook 和 IBM 等公司共同开发,以开放源代码的方式托管在Github上。目前官方支持加载ONNX模型并进行推理的深度学习框架有: Caffe2, PyTorch, MXNet,ML.NET,TensorRT 和 Microsoft CNTK,并且 TensorFlow 也非官方的支持ONNX。(参考维基百科)
如需自定义网络模型结构,尝试使用自定义的结构保存网络模型,首先需要解析ONNX的net结构。ONNX使用google定义的报文格式protocol buffer,用于RPC 系统和持续数据存储系统。
Protocol Buffers 是一种轻便高效的结构化数据存储格式,可以用于结构化数据串行化,或者说序列化。它很适合做数据存储或 RPC 数据交换格式。可用于通讯协议、数据存储等领域的语言无关、平台无关、可扩展的序列化结构数据格式。目前提供了 C++、Java、Python 三种语言的 API,本文使用了C++接口解析ONNX模型。
Protocol buffer结构通常会定义要给proto文件,通过ONNXgithub链接下载onnx.protohttps://github.com/onnx/onnx。解析模型用到的结构主要如下:
ModelProto:最高级别的结构,定义了整个网络模型结构;GraphProto: graph定义了模型的计算逻辑以及带有参数的node节点,组成一个有向图结构;NodeProto: 网络有向图的各个节点OP的结构,通常称为层,例如conv,relu层;AttributeProto:各OP的参数,通过该结构访问,例如:conv层的stride,dilation等;TensorProto: 序列化的tensor value,一般weight,bias等常量均保存为该种结构;TensorShapeProto:网络的输入shape以及constant输入tensor的维度信息均保存为该种结构;TypeProto:表示ONNX数据类型。具体解析流程是读取.onnx文件,获得一个model结构,通过model结构访问到graph结构,然后通过graph访问整个网络的所有node以及input,output,通过node结构可以访问到OP的参数。
下面给出解析demo:
void ReadProtoFromBinaryFile(const char* filename, google::protobuf::Message* proto) { int fd = open(filename, O_RDONLY); google::protobuf::io::FileInputStream* raw_input = new google::protobuf::io::FileInputStream(fd); google::protobuf::io::CodedInputStream* coded_input = new google::protobuf::io::CodedInputStream(raw_input); coded_input->SetTotalBytesLimit(INT_MAX, 536870912); bool success = proto->ParseFromCodedStream(coded_input); //bool success = proto->ParseFromZeroCopyStream(raw_input); delete coded_input; delete raw_input; close(fd); if (success != true) { exit(1); } } int main() { char* ch = "inception_v3.onnx"; onnx::ModelProto model_data; ReadProtoFromBinaryFile(ch, &model_data); //读取文件保存为modelproto结构 onnx::GraphProto graph = model_data.graph(); //访问graph结构 int num = graph.node_size(); //node节点个数 int input_size = graph.input_size(); //网络输入个数,input以及各层常量输入 for (int i = 0; i < input_size; ++i) { const std::string name = graph.input(i).name(); onnx::TypeProto type = graph.input(i).type(); onnx::TensorShapeProto shape = type.tensor_type().shape();//输入维度 for (int i = 0; i < shape.dim_size(); i++) { std::cout << shape.dim(i).dim_value(); std::cout<< std::endl; } } int output_size = graph.output_size(); //网络output个数 for (int i = 0; i < num; i++) //遍历每个node结构 { const onnx::NodeProto node = graph.node(i); std::string node_name = node.name(); std::cout <<"cur node name:"<< node_name << std::endl; const ::google::protobuf::RepeatedPtrField< ::onnx::AttributeProto> attr = node.attribute(); //每个node结构的参数信息 const std::string type = node.op_type(); int in_size = node.input_size(); int out_size = node.output_size(); } return 0; }
