You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ge_ir_utils.h 7.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_
  17. #define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_
  18. #include <google/protobuf/map.h>
  19. #include <google/protobuf/repeated_field.h>
  20. #include <google/protobuf/stubs/port.h>
  21. #include <graph/anchor.h>
  22. #include <graph/debug/ge_log.h>
  23. #include <graph/debug/ge_util.h>
  24. #include <graph/detail/attributes_holder.h>
  25. #include <graph/ge_tensor.h>
  26. #include <graph/graph.h>
  27. #include <graph/model.h>
  28. #include <graph/node.h>
  29. #include <graph/utils/graph_utils.h>
  30. #include <graph/utils/type_utils.h>
  31. #include <map>
  32. #include <memory>
  33. #include <sstream>
  34. #include <string>
  35. #include <utility>
  36. #include <vector>
  37. #include "proto/ge_ir.pb.h"
  38. #include "proto/onnx.pb.h"
  39. namespace ge {
  // Character count of the ", " separator placed between list elements.
  constexpr int kOffsetToString = 2;
  41. ///
  42. /// @ingroup ge_ir_utils
  43. /// @brief RepeatedField->String
  44. /// @param [in] const rpd_field RepeatedField
  45. /// @return String
  46. ///
  47. template <typename T>
  48. const std::string ToString(const google::protobuf::RepeatedField<T> &rpd_field) {
  49. std::stringstream ss;
  50. ss << "[";
  51. for (const T &x : rpd_field) {
  52. ss << x;
  53. ss << ", ";
  54. }
  55. std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString);
  56. str_ret += "]";
  57. return str_ret;
  58. }
  59. ///
  60. /// @ingroup ge_ir_utils
  61. /// @brief RepeatedPtrField->String
  62. /// @param [in] const rpd_field RepeatedPtrField
  63. /// @return String
  64. ///
  65. template <typename T>
  66. const std::string ToString(const google::protobuf::RepeatedPtrField<T> &rpd_ptr_field) {
  67. std::stringstream ss;
  68. ss << "[";
  69. for (const T &x : rpd_ptr_field) {
  70. ss << x;
  71. ss << ", ";
  72. }
  73. std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString);
  74. str_ret += "]";
  75. return str_ret;
  76. }
  77. ///
  78. /// @ingroup ge_ir_utils
  79. /// @brief check, if not equal, log with tag
  80. /// @param [in] const left_value, right_value reference, log_info_tag
  81. /// @return bool
  82. ///
  83. template <typename T>
  84. bool IsEqual(const T &l_value, const T &r_value, const std::string &log_info_tag) {
  85. if (l_value == r_value) {
  86. return true;
  87. } else {
  88. GELOGE(GRAPH_FAILED, "Check failed with %s", log_info_tag.c_str());
  89. return false;
  90. }
  91. }
  92. class OnnxUtils {
  93. public:
  94. enum DumpLevel { NO_DUMP = 0, DUMP_ALL = 1, DUMP_WITH_OUT_DATA = 2, DUMP_WITH_OUT_DESC = 3, DUMP_LEVEL_END };
  95. static bool ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto);
  96. static bool ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model);
  97. private:
  98. // Part 1: from IR convert to ONNX Protobuf
  99. static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
  100. const std::string &name, void *data);
  101. static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
  102. const std::string &name, ::google::protobuf::RepeatedField<::google::protobuf::int64> data);
  103. static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
  104. const std::string &name, ::google::protobuf::RepeatedField<bool> data);
  105. static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
  106. const std::string &name, ::google::protobuf::RepeatedField<float> data);
  107. static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
  108. const std::string &name, ::google::protobuf::RepeatedPtrField<::std::string> data);
  109. static void AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto);
  110. static void AddAttrProtoFromAttribute(const std::pair<const std::string, ge::GeAttrValue> &string_attr_value,
  111. onnx::NodeProto *node_proto);
  112. static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc);
  113. static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map<std::string, ge::proto::AttrDef> &attr_map,
  114. onnx::NodeProto *node_proto, const std::string &prefix = "",
  115. const std::string &suffix = "");
  116. static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto);
  117. static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type);
  118. static void EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto);
  119. static bool EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto);
  120. static bool EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto);
  121. static bool EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto);
  122. static void EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type);
  123. static void EncodeValueInfo(const NodePtr &n, onnx::ValueInfoProto *v);
  124. static bool EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto);
  125. /// Part 2: from ONNX Protobuf convert to IR
  126. /// Describes node's link relationships
  127. struct NodeLinkInfo {
  128. std::string src_node_name;
  129. int32_t src_out_index;
  130. NodePtr dst_node;
  131. int32_t dst_in_index;
  132. std::string dst_node_name;
  133. };
  134. // Parse node name and index
  135. static bool ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index);
  136. static ge::DataType DecodeDataType(onnx::TensorProto_DataType data_type);
  137. static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<std::string> &strings);
  138. static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<int64_t> &ints);
  139. static void DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value);
  140. static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value);
  141. static void DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto,
  142. const std::string &attr_name_for_output_desc, int32_t index,
  143. OpDescPtr &op_desc);
  144. static void DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto,
  145. const std::string &attr_name_for_input_desc, int32_t index,
  146. OpDescPtr &op_desc);
  147. static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto,
  148. const std::string &attr_name_for_input_output_desc, int32_t index,
  149. OpDescPtr &op_desc);
  150. static void DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def);
  151. static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc);
  152. static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr);
  153. static bool DecodeNodeLink(const std::vector<onnx::NodeProto> &node_proto_vector,
  154. const std::map<std::string, NodePtr> &node_map);
  155. static bool DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &node);
  156. static bool DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph);
  157. };
  158. } // namespace ge
  159. #endif // COMMON_GRAPH_UTILS_GE_IR_UTILS_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。