/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/utils/ge_ir_utils.h"
#include
#include "framework/common/debug/ge_log.h"

namespace {
// Suffix appended to a node name to denote a control-edge link ("<node>:-1").
const char *const kControlAnchorIndex = ":-1";
// op_type given to the NodeProto that stands in for a subgraph.
const char *const kNodeTypeForSubgraph = "subgraph";
// Name prefixes for attributes taken from an input / output tensor descriptor's attr map.
const char *const kPrefixForInputDesc = "input_desc_attr_";
const char *const kPrefixForOutputDesc = "output_desc_attr_";
// Name of the environment variable that selects the dump level.
const char *const kDumpGEGraph = "DUMP_GE_GRAPH";
// Not referenced in the part of the file reviewed here.
const int8_t kMaxRecursionDepth = 10;
// Raw value of the environment variable, read once during static initialization (nullptr when unset).
const char *const kDumpGeGraph = std::getenv(kDumpGEGraph);
// Dump level parsed as a base-10 integer; NO_DUMP when the variable is unset.
const int64_t kDumpLevel =
  (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP;
// Number of leading characters of an attribute name compared when decoding desc attributes.
// NOTE(review): presumably the lengths of "input" and "output" - confirm against the use sites.
const int64_t kInputPrefixLength = 5;
const int64_t kOutputPrefixLength = 6;
using AttrDefPair = ::google::protobuf::MapPair;
}  // namespace

namespace ge {
// Part 1: from IR convert to ONNX Protobuf
/// Lookup table from GE data types to ONNX tensor data types. Types that are not listed are reported as
/// UNDEFINED by EncodeDataType.
static const std::map kGeDataTypeToOnnxMap = {
  {DT_INT64, onnx::TensorProto_DataType_INT64},   {DT_UINT64, onnx::TensorProto_DataType_UINT64},
  {DT_FLOAT, onnx::TensorProto_DataType_FLOAT},   {DT_INT32, onnx::TensorProto_DataType_INT32},
  {DT_UINT32, onnx::TensorProto_DataType_UINT32}, {DT_INT8, onnx::TensorProto_DataType_INT8},
  {DT_UINT8, onnx::TensorProto_DataType_UINT8},   {DT_INT16, onnx::TensorProto_DataType_INT16},
  {DT_UINT16, onnx::TensorProto_DataType_UINT16}, {DT_FLOAT16, onnx::TensorProto_DataType_FLOAT16},
  {DT_DOUBLE, onnx::TensorProto_DataType_DOUBLE}, {DT_BOOL, onnx::TensorProto_DataType_BOOL},
};

/// Translates a GE data type into its ONNX counterpart.
/// @param data_type GE data type to translate.
/// @return The mapped ONNX type, or TensorProto_DataType_UNDEFINED (after logging a warning) when the
///         type has no entry in kGeDataTypeToOnnxMap.
onnx::TensorProto_DataType OnnxUtils::EncodeDataType(DataType data_type) {
  auto it = kGeDataTypeToOnnxMap.find(data_type);
  if (it != kGeDataTypeToOnnxMap.end()) {
    return it->second;
  } else {
    GELOGW("EncodeDataType: datatype not support %u", data_type);
    return onnx::TensorProto_DataType_UNDEFINED;
  }
}

/// Appends one attribute to node_proto from a (name, GeAttrValue) pair.
/// Handles float, int and string values and lists of them. For any other value type a warning is logged;
/// the attribute has already been added and named by then, so it remains in the proto without a value or
/// type.
/// @param string_attr_value Attribute name and value to encode.
/// @param node_proto        Destination node; the call is a logged no-op when it is nullptr.
void OnnxUtils::AddAttrProtoFromAttribute(const std::pair &string_attr_value, onnx::NodeProto *node_proto) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node proto is nullptr.");
    return;
  }
  auto attr = node_proto->add_attribute();
  if (attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  auto attr_name = string_attr_value.first;
  attr->set_name(attr_name);
  auto attr_value = string_attr_value.second;
  auto value_type = attr_value.GetValueType();
  // The result of every GetValue call below is deliberately discarded; on failure the default-initialized
  // local is what gets written.
  switch (value_type) {
    case GeAttrValue::VT_FLOAT: {
      GeAttrValue::FLOAT data_f = 0;
      (void)attr_value.GetValue(data_f);
      attr->set_f(data_f);
      attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
      break;
    }
    case GeAttrValue::VT_LIST_FLOAT: {
      GeAttrValue::LIST_FLOAT data_fs = {};
      (void)attr_value.GetValue(data_fs);
      attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
      for (auto &v : data_fs) {
        attr->add_floats(v);
      }
      break;
    }
    case GeAttrValue::VT_INT: {
      GeAttrValue::INT data_i = 0;
      (void)attr_value.GetValue(data_i);
      attr->set_type(onnx::AttributeProto_AttributeType_INT);
      attr->set_i(data_i);
      break;
    }
    case GeAttrValue::VT_LIST_INT: {
      GeAttrValue::LIST_INT data_is = {};
      (void)attr_value.GetValue(data_is);
      attr->set_type(onnx::AttributeProto_AttributeType_INTS);
      for (auto &v : data_is) {
        attr->add_ints(v);
      }
      break;
    }
    case GeAttrValue::VT_STRING: {
      GeAttrValue::STR data_s;
      (void)attr_value.GetValue(data_s);
      attr->set_type(onnx::AttributeProto_AttributeType_STRING);
      attr->set_s(data_s);
      break;
    }
    case GeAttrValue::VT_LIST_STRING: {
      GeAttrValue::LIST_STR data_ss = {};
      (void)attr_value.GetValue(data_ss);
      attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
      for (auto &v : data_ss) {
        attr->add_strings(v);
      }
      break;
    }
    default:
      GELOGW("GeAttrValue ValueType: %u is not supported for now", value_type);
      break;
  }
}

/// Appends one attribute to node_proto from a type tag and an untyped pointer to the value.
/// @param node_proto Destination node; the call is a logged no-op when it is nullptr.
/// @param type       Selects both the ONNX attribute type and how `data` is reinterpreted.
/// @param name       Attribute name.
/// @param data       Pointer to the value. It is cast and dereferenced without a null check, so the caller
///                   must pass a valid pointer to an object of the kind the selected branch expects.
/// An unsupported `type` only logs a warning; the attribute has already been added and named by then.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             void *data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
    return;
  }
  auto attr = node_proto->add_attribute();
  if (attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  attr->set_name(name);
  switch (type) {
    case onnx::AttributeProto_AttributeType_FLOAT:
      attr->set_f((*(static_cast(data))));
      attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
      break;
    case onnx::AttributeProto_AttributeType_FLOATS:
      attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
      for (auto &v : (*(static_cast *>(data)))) {
        attr->add_floats(v);
      }
      break;
    case onnx::AttributeProto_AttributeType_INT:
      attr->set_type(onnx::AttributeProto_AttributeType_INT);
      attr->set_i((*(static_cast(data))));
      break;
    case onnx::AttributeProto_AttributeType_INTS:
      attr->set_type(onnx::AttributeProto_AttributeType_INTS);
      for (auto &v : *(static_cast *>(data))) {
        attr->add_ints(v);
      }
      break;
    case onnx::AttributeProto_AttributeType_STRING:
      attr->set_type(onnx::AttributeProto_AttributeType_STRING);
      attr->set_s((*(static_cast(data))));
      break;
    case onnx::AttributeProto_AttributeType_STRINGS:
      attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
      for (auto &v : *(static_cast *>(data))) {
        attr->add_strings(v);
      }
      break;
    default:
      GELOGW("AttributeProto AttributeType: %u is not supported for now", type);
      break;
  }
}

/// Appends an attribute whose `ints` field is filled from a protobuf repeated int64 field.
/// Nothing is added when `data` is empty. `type` is stored as given; it is not checked against the field
/// that is populated.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             ::google::protobuf::RepeatedField<::google::protobuf::int64> data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
    return;
  }
  if (!data.empty()) {
    auto attr = node_proto->add_attribute();
    if (attr == nullptr) {
      GELOGE(GRAPH_FAILED, "attr is nullptr.");
      return;
    }
    attr->set_name(name);
    for (auto &v : data) {
      attr->add_ints(v);
    }
    attr->set_type(type);
  }
}

/// Appends an attribute whose `ints` field is filled from a protobuf repeated field, converting every
/// element with a static_cast first. Nothing is added when `data` is empty.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             ::google::protobuf::RepeatedField data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str());
    return;
  }
  if (!data.empty()) {
    auto attr = node_proto->add_attribute();
    if (attr == nullptr) {
      GELOGE(GRAPH_FAILED, "attr is nullptr.");
      return;
    }
    attr->set_name(name);
    for (auto &v : data) {
      attr->add_ints(static_cast(v));
    }
    attr->set_type(type);
  }
}

/// Appends an attribute whose `floats` field is filled from a protobuf repeated field.
/// Nothing is added when `data` is empty.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             ::google::protobuf::RepeatedField data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
    return;
  }
  if (!data.empty()) {
    auto attr = node_proto->add_attribute();
    if (attr == nullptr) {
      GELOGE(GRAPH_FAILED, "attr is nullptr.");
      return;
    }
    attr->set_name(name);
    for (auto &v : data) {
      attr->add_floats(v);
    }
    attr->set_type(type);
  }
}

void
OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                        ::google::protobuf::RepeatedPtrField<::std::string> data) {
  // Appends an attribute whose `strings` field is filled from a protobuf repeated string field.
  // Nothing is added when `data` is empty; `type` is stored as given.
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str());
    return;
  }
  if (!data.empty()) {
    auto attr = node_proto->add_attribute();
    if (attr == nullptr) {
      GELOGE(GRAPH_FAILED, "attr is nullptr.");
      return;
    }
    attr->set_name(name);
    for (auto &v : data) {
      attr->add_strings(v);
    }
    attr->set_type(type);
  }
}

/// Writes the input and output tensor descriptions of op_desc into node_proto as attributes.
/// The counts are stored as "input_desc_nums" / "output_desc_nums"; every per-tensor field is stored under
/// "<field>:<index>". Entries from the tensor descriptor's own attr map are added through
/// AddAttrProtoForAttrsFromAttrMap with the input/output desc prefix and ":<index>" as suffix.
/// A null desc or a null tensor descriptor only logs a warning and moves on to the next index.
/// @param node_proto Destination node; must not be nullptr.
/// @param op_desc    Operator description to read; must not be nullptr.
// NOTE(review): scalar values are passed to AddAttrProto by address through a void *. Confirm that the type
// deduced by `auto` for every INT value (e.g. size_in, size_out, real_dim_cnt) is exactly the type the INT
// branch of AddAttrProto casts the pointer to.
void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc) {
  if (node_proto == nullptr || op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "node_proto or op_desc is nullptr");
    return;
  }
  // Input describes
  auto size_in = op_desc->GetAllInputsSize();
  AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_nums", &size_in);
  if (size_in > 0) {
    for (uint32_t i = 0; i < size_in; i++) {
      auto input_desc = op_desc->GetInputDescPtrDfault(i);
      if (input_desc != nullptr) {
        auto data_type = TypeUtils::DataTypeToSerialString(input_desc->GetDataType());
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_dtype:" + std::to_string(i),
                     &data_type);
        auto data_type_origin = TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType());
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
                     "input_desc_origin_dtype:" + std::to_string(i), &data_type_origin);
        auto dims = input_desc->GetShape().GetDims();
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_desc_shape:" + std::to_string(i),
                     &dims);
        auto dims_origin = input_desc->GetOriginShape().GetDims();
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS,
                     "input_desc_origin_shape:" + std::to_string(i), &dims_origin);
        auto layout = TypeUtils::FormatToSerialString(input_desc->GetFormat());
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_layout:" + std::to_string(i),
                     &layout);
        auto layout_origin = TypeUtils::FormatToSerialString(input_desc->GetOriginFormat());
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
                     "input_desc_origin_layout:" + std::to_string(i), &layout_origin);
        auto tensor_descriptor = input_desc->tensor_descriptor_.GetProtoMsg();
        if (tensor_descriptor != nullptr) {
          auto size = tensor_descriptor->size();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_size:" + std::to_string(i),
                       &size);
          auto weight_size = tensor_descriptor->weight_size();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
                       "input_desc_weight_size:" + std::to_string(i), &weight_size);
          // The next three flags are converted with a static_cast before being stored as INT attributes.
          auto reuse_input = tensor_descriptor->reuse_input();
          auto reuse_input_int = static_cast(reuse_input);
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
                       "input_desc_reuse_input:" + std::to_string(i), &reuse_input_int);
          auto output_tensor = tensor_descriptor->output_tensor();
          auto output_tensor_int = static_cast(output_tensor);
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
                       "input_desc_output_tensor:" + std::to_string(i), &output_tensor_int);
          auto device_type = tensor_descriptor->device_type();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
                       "input_desc_device_type:" + std::to_string(i), &device_type);
          auto input_tensor = tensor_descriptor->input_tensor();
          auto input_tensor_int = static_cast(input_tensor);
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
                       "input_desc_input_tensor:" + std::to_string(i), &input_tensor_int);
          auto real_dim_cnt = tensor_descriptor->real_dim_cnt();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
                       "input_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt);
          auto data_offset = tensor_descriptor->data_offset();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
                       "input_desc_data_offset:" + std::to_string(i), &data_offset);
          auto cmps_size = tensor_descriptor->cmps_size();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_cmps_size:" + std::to_string(i),
                       &cmps_size);
          auto cmps_tab = tensor_descriptor->cmps_tab();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
                       "input_desc_cmps_tab:" + std::to_string(i), &cmps_tab);
          auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
                       "input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset);
          const auto &tensor_desc_map = tensor_descriptor->attr();
          std::string suffix = ":" + std::to_string(i);
          AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForInputDesc, suffix);
        } else {
          GELOGW("Tensor descriptor is nullptr");
          continue;
        }
      } else {
        GELOGW("Input desc is nullptr");
        continue;
      }
    }
  }
  // Output describes
  // Fewer descriptor fields are written for outputs than for inputs.
  auto size_out = op_desc->GetOutputsSize();
  AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_nums", &size_out);
  if (size_out > 0) {
    for (uint32_t i = 0; i < size_out; i++) {
      auto output_desc = op_desc->GetOutputDescPtr(i);
      if (output_desc != nullptr) {
        auto data_type = TypeUtils::DataTypeToSerialString(output_desc->GetDataType());
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_dtype:" + std::to_string(i),
                     &data_type);
        auto origin_data_type = TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType());
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
                     "output_desc_origin_dtype:" + std::to_string(i), &origin_data_type);
        auto dims = output_desc->GetShape().GetDims();
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_desc_shape:" + std::to_string(i),
                     &dims);
        auto dims_origin = output_desc->GetOriginShape().GetDims();
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS,
                     "output_desc_origin_shape:" + std::to_string(i), &dims_origin);
        auto layout = TypeUtils::FormatToSerialString(output_desc->GetFormat());
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_layout:" + std::to_string(i),
                     &layout);
        auto layout_origin = TypeUtils::FormatToSerialString(output_desc->GetOriginFormat());
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
                     "output_desc_origin_layout:" + std::to_string(i), &layout_origin);
        auto tensor_descriptor = output_desc->tensor_descriptor_.GetProtoMsg();
        if (tensor_descriptor != nullptr) {
          auto size = tensor_descriptor->size();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_size:" + std::to_string(i),
                       &size);
          auto weight_size = tensor_descriptor->weight_size();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
                       "output_desc_weight_size:" + std::to_string(i), &weight_size);
          auto device_type = tensor_descriptor->device_type();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
                       "output_desc_device_type:" + std::to_string(i), &device_type);
          auto real_dim_cnt = tensor_descriptor->real_dim_cnt();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
                       "output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt);
          const auto &tensor_desc_map = tensor_descriptor->attr();
          std::string suffix = ":" + std::to_string(i);
          AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForOutputDesc, suffix);
        } else {
          GELOGW("Tensor descriptor is nullptr");
          continue;
        }
      } else {
        GELOGW("Output desc is nullptr");
        continue;
      }
    }
  }
}

/// Converts every entry of a GE attribute map into attributes on node_proto.
/// Attribute names are built as prefix + <map key> (+ a field-specific part for tensors) + suffix.
/// Tensor payloads, plain strings and string lists are written only when kDumpLevel == DUMP_ALL; tensor
/// metadata and numeric values are written at every dump level. Bool values and bool lists are stored as
/// INT / INTS. Value kinds that are not tested for below produce no attribute.
/// @param attr_map   Map of attribute name to AttrDef.
/// @param node_proto Destination node (null is rejected inside each AddAttrProto call).
/// @param prefix     Text put in front of every attribute name.
/// @param suffix     Text appended to every attribute name.
void OnnxUtils::AddAttrProtoForAttrsFromAttrMap(
  const ::google::protobuf::Map &attr_map, onnx::NodeProto *node_proto, const std::string &prefix,
  const std::string &suffix) {
  for (const auto &item : attr_map) {
    auto attr_name = item.first;
    auto attr_def = item.second;
    auto attr_type = attr_def.value_case();
    if (attr_type == ge::proto::AttrDef::kT) {
      const auto &tensor_def = attr_def.t();
      const auto &tensor_desc = tensor_def.desc();
      auto data_type = ge::proto::DataType_Name(tensor_desc.dtype());
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_dtype" + suffix,
                   &data_type);
      // Passed as a repeated field (not by address), which selects a repeated-field overload.
      auto dims = tensor_desc.shape().dim();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + "_desc_shape" + suffix,
                   dims);
      auto layout = tensor_desc.layout();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_layout" + suffix,
                   &layout);
      auto device_type = tensor_desc.device_type();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
                   prefix + attr_name + "_desc_device_type" + suffix, &device_type);
      if (kDumpLevel == DUMP_ALL) {
        auto data = tensor_def.data();
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_data" + suffix,
                     &data);
      }
    }
    if (attr_type == ge::proto::AttrDef::kS) {
      if (kDumpLevel == DUMP_ALL) {
        auto str_value = attr_def.s();
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value);
      }
    }
    if (attr_type == ge::proto::AttrDef::kI) {
      auto int_value = attr_def.i();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value);
    }
    if (attr_type == ge::proto::AttrDef::kF) {
      auto float_value = attr_def.f();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value);
    }
    if (attr_type == ge::proto::AttrDef::kB) {
      auto int_value = static_cast(attr_def.b());
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value);
    }
    if (attr_type == ge::proto::AttrDef::kList) {
      const auto &list_value = attr_def.list();
      auto list_value_type = list_value.val_type();
      if (list_value_type ==
          ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) {
        if (kDumpLevel == DUMP_ALL) {
          const auto &strings = list_value.s();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings);
        }
      }
      if (list_value_type ==
          ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) {
        const auto &floats = list_value.f();
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats);
      }
      if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) {
        const auto &ints = list_value.i();
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints);
      }
      if (list_value_type ==
          ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) {
        const auto &bools = list_value.b();
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools);
      }
    }
  }
}

/// Writes the members of a ge::Node and of its OpDesc / OpDef into node_proto as attributes:
/// the send/recv event id lists (only when non-empty), the input name -> index table, the input and output
/// tensor descriptions, the scalar and repeated OpDef fields, and finally the OpDef attr map.
/// Logs an error and returns early when node, its op desc or the op def message is nullptr; attributes
/// added before that point stay in node_proto.
/// @param node       Node to read.
/// @param node_proto Destination node (null is rejected inside each AddAttrProto call).
void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "node is nullptr");
    return;
  }
  // 1.Attributes added from node's methods
  auto send_list = node->send_event_id_list_;
  if (!send_list.empty()) {
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "send_event_id_list", &send_list);
  }
  auto recv_list = node->recv_event_id_list_;
  if (!recv_list.empty()) {
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "recv_event_id_list", &recv_list);
  }
  auto op_desc = node->op_;
  if (op_desc != nullptr) {
    // for input_name_idx_ in opdesc
    // Names and indexes are stored as two parallel lists: "_input_name_key" and "_input_name_value".
    auto input_name_2_indexs = op_desc->GetAllInputName();
    ::google::protobuf::RepeatedPtrField<::std::string> input_names;
    ::google::protobuf::RepeatedField<::google::protobuf::int64> input_indexes;
    for (const auto &input_name_2_index : input_name_2_indexs) {
      std::string input_name = input_name_2_index.first;
      input_names.Add(std::move(input_name));
      input_indexes.Add(input_name_2_index.second);
    }
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "_input_name_key", input_names);
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "_input_name_value", input_indexes);
    // 2.Attributes added from node's op_(message OpDef)
    // Input and out describes
    AddAttrProtoForOpInAndOutDesc(node_proto, op_desc);
    // Others
    auto op_def = op_desc->op_def_.GetProtoMsg();
    if (op_def != nullptr) {
      auto id = op_def->id();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "id", &id);
      auto stream_id = op_def->stream_id();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "stream_id", &stream_id);
      // Repeated fields go through the repeated-field overloads, which skip empty lists.
      const auto &input_name = op_def->input_name();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "input_name", input_name);
      const auto &src_name = op_def->src_name();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "src_name", src_name);
      const auto &src_index = op_def->src_index();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "src_index", src_index);
      const auto &dst_name = op_def->dst_name();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "dst_name", dst_name);
      const auto &dst_index = op_def->dst_index();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "dst_index", dst_index);
      const auto &input_i = op_def->input_i();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_i", input_i);
      const auto &output_i = op_def->output_i();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_i", output_i);
      const auto &workspace = op_def->workspace();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace", workspace);
      const auto &workspace_bytes = op_def->workspace_bytes();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes);
      const auto &is_input_const = op_def->is_input_const();
      AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const);
      const auto &op_def_attr_map = op_def->attr();
      AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto);
    } else {
GELOGE(FAILED, "Opdef is nullptr"); return; } } else { GELOGE(FAILED, "Opdesc is nullptr"); return; } } bool OnnxUtils::EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto) { if ((node == nullptr) || (node_proto == nullptr)) { GELOGE(GRAPH_FAILED, "EncodeOpDesc: Input Para Node Invalid"); return false; } // 2.Encode map attrs_ to AttributeProto for (auto &node_attr : node->attrs_) { AddAttrProtoFromAttribute(node_attr, node_proto); } // 3.Encode ge::Node members to AttributeProto AddAttrProtoFromNodeMembers(node, node_proto); return true; } void OnnxUtils::EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto) { if ((node == nullptr) || (node_proto == nullptr)) { GELOGE(GRAPH_FAILED, "EncodeNodeLinkForNetronVisual: Input Para Node Invalid"); return; } const auto &node_name = node->GetName(); for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { if ((out_data_anchor != nullptr) && (!out_data_anchor->GetPeerInDataAnchors().empty())) { node_proto->add_output(node_name + ":" + std::to_string(out_data_anchor->GetIdx())); } } auto out_control_anchor = node->GetOutControlAnchor(); if ((out_control_anchor != nullptr) && (!out_control_anchor->GetPeerInControlAnchors().empty())) { node_proto->add_output(node_name + kControlAnchorIndex); } } bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) { if ((node == nullptr) || (node_proto == nullptr)) { GELOGE(GRAPH_FAILED, "EncodeNodeLink: Input Para Node Invalid"); return false; } node_proto->clear_input(); // 1. Add input by in data edge for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) { node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + std::to_string(peer_out_anchor->GetIdx())); } else { // Add "" input node_proto->add_input(""); } } // 2. 
Add input by in control edge auto in_control_anchor = node->GetInControlAnchor(); if (in_control_anchor != nullptr) { auto peer_out_anchors = in_control_anchor->GetPeerOutControlAnchors(); for (const auto &peer_out_anchor : peer_out_anchors) { if (peer_out_anchor->GetOwnerNode()) { node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex); } } } else { GELOGE(FAILED, "Incontrol anchor is nullptr"); return false; } // 3. Add output for Netron visual support EncodeNodeLinkForNetronVisual(node, node_proto); return true; } bool OnnxUtils::EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto) { if ((node == nullptr) || (node_proto == nullptr)) { GELOGE(GRAPH_FAILED, "EncodeNode: Input Para Node Invalid"); return false; } // 1. Encode name and type node_proto->set_name(node->GetName()); /// Netron believes that some operators, such as the activation operator of softplus, only have one input, /// while the link relation of control anchor may exist in ge, resulting in two inputs. 
Therefore, "ge:" prefix /// is added to correctly display the link relation at the expense of some color features node_proto->set_op_type("ge:" + node->GetType()); if (kDumpLevel != DUMP_WITH_OUT_DESC) { // 2.for attr if (!EncodeNodeDesc(node, node_proto)) { GELOGE(GRAPH_FAILED, "Encode NodeDesc: %s failed", node->GetName().c_str()); return false; } } // 3.for link info return EncodeNodeLink(node, node_proto); } void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type) { if ((node == nullptr) || (tensor_type == nullptr)) { GELOGE(GRAPH_FAILED, "EncodeTypeProtoTensorType: Input Para Node or tensor_type Invalid"); return; } const auto &op_desc = node->GetOpDesc(); if (op_desc != nullptr) { uint32_t size_out = static_cast(op_desc->GetOutputsSize()); if (size_out > 0) { for (uint32_t i = 0; i < size_out; i++) { const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i); if (ge_tensor != nullptr) { auto ge_data_type = ge_tensor->GetDataType(); auto onnx_data_type = EncodeDataType(ge_data_type); tensor_type->set_elem_type(onnx_data_type); onnx::TensorShapeProto *shape = tensor_type->mutable_shape(); if (shape != nullptr) { for (auto d : ge_tensor->GetShape().GetDims()) { auto dim = shape->add_dim(); dim->set_dim_value(d); } } else { GELOGW("Shape is nullptr"); continue; } } else { GELOGW("Ge tensor is nullptr"); continue; } } } } else { GELOGW("OpDesc Is Empty, nodeName %s nodeType %s", node->GetName().c_str(), node->GetType().c_str()); return; } } void OnnxUtils::EncodeValueInfo(const NodePtr &node, onnx::ValueInfoProto *value_info_proto) { if ((node == nullptr) || (value_info_proto == nullptr)) { GELOGE(GRAPH_FAILED, "EncodeValueInfo: Input Para Node or value_info_proto Invalid"); return; } value_info_proto->set_name(node->GetName()); onnx::TypeProto *t = value_info_proto->mutable_type(); onnx::TypeProto_Tensor *tensor_type = t->mutable_tensor_type(); EncodeTypeProtoTensorType(node, tensor_type); } bool 
OnnxUtils::EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto) { if ((graph == nullptr) || (graph_proto == nullptr)) { GELOGE(GRAPH_FAILED, "EncodeGraph: Input para Invalid"); return false; } graph_proto->set_name(graph->GetName()); // 1. Add graph inputs for (const auto &input : graph->GetInputNodes()) { auto value_info_proto = graph_proto->add_input(); EncodeValueInfo(input, value_info_proto); } // 2. Add graph outputs for (const auto &output : graph->GetOutputNodes()) { auto value_info_proto = graph_proto->add_output(); EncodeValueInfo(output, value_info_proto); } // 3. Add nodes for (const auto &node : graph->GetDirectNode()) { if (!EncodeNode(node, graph_proto->add_node())) { GELOGW("EncodeNode failed"); continue; } } return true; } bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto) { model_proto.set_model_version(model.GetVersion()); model_proto.set_ir_version(onnx::IR_VERSION); model_proto.set_producer_name(model.GetName()); auto &graph = model.graph_; auto compute_graph = GraphUtils::GetComputeGraph(graph); if (compute_graph == nullptr) { GELOGE(GRAPH_FAILED, "GetComputeGraph: return nullptr"); return false; } auto graph_proto = model_proto.mutable_graph(); if (graph_proto == nullptr) { GELOGE(GRAPH_FAILED, "mutable_graph: %s return nullptr", compute_graph->GetName().c_str()); return false; } if (!EncodeGraph(compute_graph, graph_proto)) { GELOGE(GRAPH_FAILED, "EncodeGraph: %s fail", compute_graph->GetName().c_str()); return false; } // For subgraphs: a subgraph is represented by a node for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) { if (sub_compute_graph != nullptr) { auto node_proto = graph_proto->add_node(); if (node_proto == nullptr) { GELOGW("Node proto is nullptr"); continue; } node_proto->set_name(sub_compute_graph->GetName()); node_proto->set_op_type(kNodeTypeForSubgraph); auto attr = node_proto->add_attribute(); attr->set_name("graph"); 
attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); auto sub_graph_proto = attr->mutable_g(); if (sub_graph_proto == nullptr) { GELOGW("Sub graph proto is nullptr"); continue; } if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) { GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str()); continue; } } else { GELOGW("Graph: %s subgraph is nullptr, skip EncodeGraph", compute_graph->GetName().c_str()); continue; } } return true; } // Part 2: from ONNX Protobuf convert to IR static std::map onnxDataTypeToGeMap = { {onnx::TensorProto_DataType_INT64, DT_INT64}, {onnx::TensorProto_DataType_UINT64, DT_UINT64}, {onnx::TensorProto_DataType_FLOAT, DT_FLOAT}, {onnx::TensorProto_DataType_INT32, DT_INT32}, {onnx::TensorProto_DataType_UINT32, DT_UINT32}, {onnx::TensorProto_DataType_INT8, DT_INT8}, {onnx::TensorProto_DataType_UINT8, DT_UINT8}, {onnx::TensorProto_DataType_INT16, DT_INT16}, {onnx::TensorProto_DataType_UINT16, DT_UINT16}, {onnx::TensorProto_DataType_FLOAT16, DT_FLOAT16}, {onnx::TensorProto_DataType_DOUBLE, DT_DOUBLE}, {onnx::TensorProto_DataType_BOOL, DT_BOOL}, }; ge::DataType OnnxUtils::DecodeDataType(onnx::TensorProto_DataType data_type) { auto it = onnxDataTypeToGeMap.find(data_type); if (it != onnxDataTypeToGeMap.end()) { return it->second; } else { GELOGW("DecodeDataType: datatype not support %u", data_type); return ge::DT_UNDEFINED; } } bool OnnxUtils::ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index) { auto sep = node_name_index.rfind(':'); if (sep == std::string::npos) { return false; } node_name = node_name_index.substr(0, sep); auto index_str = node_name_index.substr(sep + 1); index = static_cast(std::strtol(index_str.c_str(), nullptr, 10)); return true; } bool OnnxUtils::DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr) { if (node_ptr == nullptr) { GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp: node_ptr is nullptr"); return false; } // Data edge if (item.src_out_index >= 0) { auto 
src_anchor = node_ptr->GetOutDataAnchor(item.src_out_index); auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index); if ((src_anchor == nullptr) || (dst_anchor == nullptr)) { GELOGE(GRAPH_FAILED, "Get data anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, item.dst_node_name.c_str(), item.dst_in_index); return false; } if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "Data Anchor: src_anchor->LinkTo(dst_anchor) failed"); return false; } // Control edge } else { auto src_anchor = node_ptr->GetOutControlAnchor(); auto dst_anchor = item.dst_node->GetInControlAnchor(); if ((src_anchor == nullptr) || (dst_anchor == nullptr)) { GELOGE(GRAPH_FAILED, "Get control anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, item.dst_node_name.c_str(), item.dst_in_index); return false; } if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "Control Anchor: src_anchor->LinkTo(dst_anchor) failed"); return false; } } return true; } bool OnnxUtils::DecodeNodeLink(const std::vector &node_proto_vector, const std::map &node_map) { for (const auto &node_proto : node_proto_vector) { const auto &node_name = node_proto.name(); auto dst_node = node_map.find(node_name); if ((dst_node == node_map.end()) || (dst_node->second == nullptr)) { GELOGE(GRAPH_FAILED, "destination node: %s find failed or is nullptr", node_name.c_str()); return false; } int32_t dst_index = 0; for (const auto &input : node_proto.input()) { std::string input_node_name; int32_t index = 0; if (ParseNameIndex(input, input_node_name, index)) { auto item = NodeLinkInfo{input_node_name, index, dst_node->second, dst_index, node_proto.name()}; auto src_node = node_map.find(input_node_name); if (src_node == node_map.end()) { GELOGE(GRAPH_FAILED, "find src node: %s failed", input_node_name.c_str()); return false; } auto node_ptr = src_node->second; if (node_ptr == nullptr) { GELOGE(GRAPH_FAILED, "src node: %s is nullptr", 
input_node_name.c_str()); return false; } if (!DecodeNodeLinkImp(item, node_ptr)) { GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp node: %s failed", input_node_name.c_str()); return false; } } if (index >= 0) { dst_index++; } } } return true; } void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &strings) { if (attr_proto.type() != onnx::AttributeProto_AttributeType_STRINGS) { GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); return; } for (int i = 0; i < attr_proto.strings_size(); i++) { strings.push_back(attr_proto.strings(i)); } } void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value) { if (attr_proto.type() != onnx::AttributeProto_AttributeType_STRING) { GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); return; } value = attr_proto.s(); } void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &ints) { if (attr_proto.type() != onnx::AttributeProto_AttributeType_INTS) { GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); return; } for (int i = 0; i < attr_proto.ints_size(); i++) { ints.push_back(attr_proto.ints(i)); } } void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value) { if (attr_proto.type() != onnx::AttributeProto_AttributeType_INT) { GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); return; } value = attr_proto.i(); } void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, const std::string &attr_name_for_input_desc, int32_t index, OpDescPtr &op_desc) { if (op_desc->MutableInputDesc(static_cast(index)) == nullptr) { GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast(index)) is nullptr", op_desc->GetName().c_str(), attr_name_for_input_desc.c_str()); return; } if 
(attr_name_for_input_desc == "input_desc_dtype") { auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); op_desc->MutableInputDesc(static_cast(index))->SetDataType(data_type); } else if (attr_name_for_input_desc == "input_desc_shape") { std::vector ints; DecodeAttribute(attr_proto, ints); GeShape ge_shape(ints); op_desc->MutableInputDesc(static_cast(index))->SetShape(ge_shape); } else if (attr_name_for_input_desc == "input_desc_layout") { auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); op_desc->MutableInputDesc(static_cast(index))->SetFormat(data_format); } else if (attr_name_for_input_desc == "input_desc_origin_shape") { std::vector ints; DecodeAttribute(attr_proto, ints); GeShape ge_shape(ints); op_desc->MutableInputDesc(static_cast(index))->SetOriginShape(ge_shape); } else if (attr_name_for_input_desc == "input_desc_origin_layout") { auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); op_desc->MutableInputDesc(static_cast(index))->SetOriginFormat(data_format); } else if (attr_name_for_input_desc == "input_desc_size") { int64_t input_size = 0; auto tensor_descriptor = op_desc->MutableInputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); DecodeAttribute(attr_proto, input_size); tensor_descriptor->set_size(input_size); } else if (attr_name_for_input_desc == "input_desc_data_offset") { auto tensor_descriptor = op_desc->MutableInputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); int64_t offset = 0; DecodeAttribute(attr_proto, offset); tensor_descriptor->set_data_offset(offset); } else { return; } } void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, const std::string &attr_name_for_output_desc, int32_t index, OpDescPtr &op_desc) { if (op_desc->MutableOutputDesc(static_cast(index)) == nullptr) { GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast(index)) is nullptr", op_desc->GetName().c_str(), 
attr_name_for_output_desc.c_str()); return; } if (attr_name_for_output_desc == "output_desc_dtype") { auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); op_desc->MutableOutputDesc(static_cast(index))->SetDataType(data_type); } else if (attr_name_for_output_desc == "output_desc_shape") { std::vector ints; DecodeAttribute(attr_proto, ints); GeShape ge_shape(ints); op_desc->MutableOutputDesc(static_cast(index))->SetShape(ge_shape); } else if (attr_name_for_output_desc == "output_desc_layout") { auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); op_desc->MutableOutputDesc(static_cast(index))->SetFormat(data_format); } else if (attr_name_for_output_desc == "output_desc_origin_shape") { std::vector ints; DecodeAttribute(attr_proto, ints); GeShape ge_shape(ints); op_desc->MutableOutputDesc(static_cast(index))->SetOriginShape(ge_shape); } else if (attr_name_for_output_desc == "output_desc_origin_layout") { auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); op_desc->MutableOutputDesc(static_cast(index))->SetOriginFormat(data_format); } else if (attr_name_for_output_desc == "output_desc_size") { int64_t output_size = 0; auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); DecodeAttribute(attr_proto, output_size); tensor_descriptor->set_size(output_size); } else if (attr_name_for_output_desc == "output_desc_data_offset") { auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); int64_t offset = 0; DecodeAttribute(attr_proto, offset); tensor_descriptor->set_data_offset(offset); } else { return; } } void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, const std::string &attr_name_for_input_output_desc, int32_t index, OpDescPtr &op_desc) { if (op_desc == nullptr) { GELOGE(GRAPH_FAILED, "op_desc is nullptr"); return; } if (attr_name_for_input_output_desc.substr(0, kInputPrefixLength) == 
"input") { DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); } else if (attr_name_for_input_output_desc.substr(0, kOutputPrefixLength) == "output") { DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); } else { return; } } void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) { auto attr_map = op_def.mutable_attr(); const auto &attr_name = attr_proto.name(); ge::proto::AttrDef op_attr; int64_t value = 0; DecodeAttribute(attr_proto, value); op_attr.set_i(value); attr_map->insert(AttrDefPair(attr_name, op_attr)); } void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) { if (op_desc == nullptr) { GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr"); return; } const auto &attr_name = attr_proto.name(); std::string attr_name_for_input_output_desc; int32_t index = 0; if (!ParseNameIndex(attr_name, attr_name_for_input_output_desc, index)) { if (attr_name == "id") { op_desc->SetId(attr_proto.i()); } else if (attr_name == "stream_id") { op_desc->SetStreamId(attr_proto.i()); } else if (attr_name == "src_name") { std::vector strings; DecodeAttribute(attr_proto, strings); op_desc->SetSrcName(strings); } else if (attr_name == "dst_name") { std::vector strings; DecodeAttribute(attr_proto, strings); op_desc->SetDstName(strings); } else if (attr_name == "src_index") { std::vector ints; DecodeAttribute(attr_proto, ints); op_desc->SetSrcIndex(ints); } else if (attr_name == "dst_index") { std::vector ints; DecodeAttribute(attr_proto, ints); op_desc->SetDstIndex(ints); } else if (attr_name == "fusion_scope") { DecodeNodeAttributeForOpDef(attr_proto, *op_desc->op_def_.GetProtoMsg()); } else if (attr_name == "input_i") { std::vector ints; DecodeAttribute(attr_proto, ints); op_desc->SetInputOffset(ints); } else if (attr_name == "output_i") { std::vector ints; 
DecodeAttribute(attr_proto, ints); op_desc->SetOutputOffset(ints); } else { return; } // Update input and output desc } else { DecodeNodeAttributeForOpInAndOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); } } bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_desc) { if (op_desc == nullptr || node_proto == nullptr) { GELOGE(GRAPH_FAILED, " Op_desc is nullptr or node_proto is nullptr"); return false; } // 1. Decode node_proto name and type op_desc->SetName(node_proto->name()); const auto &node_type_with_ge_prefix = node_proto->op_type(); auto sep = node_type_with_ge_prefix.find(':'); if (sep == std::string::npos) { return false; } auto node_type = node_type_with_ge_prefix.substr(sep + 1); op_desc->SetType(node_type); // 2. Add empty input and output desc for (const auto &attr : node_proto->attribute()) { if (attr.name() == "input_desc_nums") { auto size_in = attr.i(); for (int64_t i = 0; i < size_in; i++) { GeTensorDesc ge_tensor_desc; GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed."); } } if (attr.name() == "output_desc_nums") { auto size_out = attr.i(); for (int64_t i = 0; i < size_out; i++) { GeTensorDesc ge_tensor_desc; GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed."); } } } // 3.Decode node_proto attributes for (int i = 0; i < node_proto->attribute_size(); i++) { DecodeNodeAttributeForOpDesc(node_proto->attribute(i), op_desc); } return true; } bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph) { if (recursion_depth > kMaxRecursionDepth) { GELOGE(GRAPH_FAILED, "DecodeGraph: recursion depth is too large, abort"); return false; } graph = ComGraphMakeShared(graph_proto.name()); GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed"); /// 1. 
Decode all nodes first, node should include input /// and output nodes and nodes which represent sub graphs std::map node_map; std::vector node_proto_vector; for (const auto &node_proto : graph_proto.node()) { // a. nodes represent sub graphs if (node_proto.op_type() == kNodeTypeForSubgraph) { ComputeGraphPtr compute_graph; // in this case, node only have one attr, whose type is AttributeProto_AttributeType_GRAPH const auto &node_attr = node_proto.attribute(0); if ((node_attr.type() == onnx::AttributeProto_AttributeType_GRAPH) && DecodeGraph(recursion_depth + 1, node_attr.g(), compute_graph)) { (void)graph->AddSubGraph(compute_graph); } else { GELOGE(GRAPH_FAILED, "Decode sub graph %s failed with node type:%d", node_proto.name().c_str(), node_attr.type()); return false; } // b. direct nodes in graph } else { node_proto_vector.push_back(node_proto); OpDescPtr op_desc = ComGraphMakeShared(); // b.1 For node desc if (!DecodeNodeDesc(&node_proto, op_desc)) { GELOGE(GRAPH_FAILED, "Decode node desc %s failed ", node_proto.name().c_str()); return false; } auto node = graph->AddNode(op_desc); node_map.insert(std::make_pair(node_proto.name(), node)); } } /// We get all nodes in graph here /// b.2 For node link if (!DecodeNodeLink(node_proto_vector, node_map)) { GELOGE(GRAPH_FAILED, "Decode node link failed"); return false; } // 2. Add inputs nodes for graph for (const auto &input : graph_proto.input()) { const auto &input_node_name = input.name(); auto input_node_item = node_map.find(input_node_name); if (input_node_item == node_map.end()) { GELOGE(GRAPH_FAILED, "cannot find graph's input node %s in node_", input_node_name.c_str()); return false; } auto ret = graph->AddInputNode(input_node_item->second); GE_CHK_BOOL_EXEC(ret != nullptr, continue, "Add inputnode failed"); } // 3. 
Add outputs nodes for graph for (const auto &output : graph_proto.output()) { const auto &output_node_name = output.name(); auto output_node_item = node_map.find(output_node_name); if (output_node_item == node_map.end()) { GELOGE(GRAPH_FAILED, "cannot find graph's output node %s in node_", output_node_name.c_str()); return false; } auto ret = graph->AddOutputNode(output_node_item->second); if (ret == nullptr) { GELOGW("Add outputnode failed,out put node is %s", output_node_name.c_str()); continue; } } return true; } bool OnnxUtils::ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model) { model.name_ = model_proto.producer_name(); model.version_ = static_cast(model_proto.model_version()); auto &graph_proto = model_proto.graph(); ComputeGraphPtr compute_graph; // 0 means recursion depth, father call if (!DecodeGraph(0, graph_proto, compute_graph)) { GELOGE(GRAPH_FAILED, "Decode compute graph from graph_proto failed"); return false; } model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph); return true; } } // namespace ge