You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ge_ir_utils.cc 52 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/utils/ge_ir_utils.h"
  17. #include <utility>
  18. #include "framework/common/debug/ge_log.h"
  19. #include "mmpa/mmpa_api.h"
  20. namespace {
  21. const char *const kControlAnchorIndex = ":-1";
  22. const char *const kNodeTypeForSubgraph = "subgraph";
  23. const char *const kPrefixForInputDesc = "input_desc_attr_";
  24. const char *const kPrefixForOutputDesc = "output_desc_attr_";
  25. const char *const kDumpGEGraph = "DUMP_GE_GRAPH";
  26. const int8_t kMaxRecursionDepth = 10;
  27. const int kBase = 10;
  28. char kDumpGeGraph[MMPA_MAX_PATH] = { 0x00 };
  29. const int64_t kDumpLevel =
  30. (mmGetEnv(kDumpGEGraph, kDumpGeGraph, MMPA_MAX_PATH) == EN_OK) ?
  31. std::strtol(kDumpGeGraph, nullptr, kBase) : ge::OnnxUtils::NO_DUMP;
  32. const int64_t kInputPrefixLength = 5;
  33. const int64_t kOutputPrefixLength = 6;
  34. using AttrDefPair = ::google::protobuf::MapPair<std::string, ge::proto::AttrDef>;
  35. } // namespace
  36. namespace ge {
  37. // Part 1: from IR convert to ONNX Protobuf
  38. namespace{
  39. const std::map<ge::DataType, onnx::TensorProto_DataType> kGeDataTypeToOnnxMap = {
  40. {DT_INT64, onnx::TensorProto_DataType_INT64}, {DT_UINT64, onnx::TensorProto_DataType_UINT64},
  41. {DT_FLOAT, onnx::TensorProto_DataType_FLOAT}, {DT_INT32, onnx::TensorProto_DataType_INT32},
  42. {DT_UINT32, onnx::TensorProto_DataType_UINT32}, {DT_INT8, onnx::TensorProto_DataType_INT8},
  43. {DT_UINT8, onnx::TensorProto_DataType_UINT8}, {DT_INT16, onnx::TensorProto_DataType_INT16},
  44. {DT_UINT16, onnx::TensorProto_DataType_UINT16}, {DT_FLOAT16, onnx::TensorProto_DataType_FLOAT16},
  45. {DT_DOUBLE, onnx::TensorProto_DataType_DOUBLE}, {DT_BOOL, onnx::TensorProto_DataType_BOOL},
  46. };
  47. }
  48. struct AttrNameComp {
  49. inline bool operator()(const onnx::AttributeProto &lsh, const onnx::AttributeProto &rsh) {
  50. return lsh.name() < rsh.name();
  51. }
  52. };
  53. onnx::TensorProto_DataType OnnxUtils::EncodeDataType(DataType data_type) {
  54. auto it = kGeDataTypeToOnnxMap.find(data_type);
  55. if (it != kGeDataTypeToOnnxMap.end()) {
  56. return it->second;
  57. } else {
  58. GELOGW("EncodeDataType: datatype not support %u", data_type);
  59. return onnx::TensorProto_DataType_UNDEFINED;
  60. }
  61. }
  62. void OnnxUtils::AddAttrProtoFromAttribute(const std::pair<const std::string, ge::GeAttrValue> &string_attr_value,
  63. onnx::NodeProto *node_proto) {
  64. if (node_proto == nullptr) {
  65. GELOGE(FAILED, "Node proto is nullptr.");
  66. return;
  67. }
  68. auto attr = node_proto->add_attribute();
  69. if (attr == nullptr) {
  70. GELOGE(GRAPH_FAILED, "attr is nullptr.");
  71. return;
  72. }
  73. auto attr_name = string_attr_value.first;
  74. attr->set_name(attr_name);
  75. auto attr_value = string_attr_value.second;
  76. auto value_type = attr_value.GetValueType();
  77. switch (value_type) {
  78. case GeAttrValue::VT_FLOAT: {
  79. GeAttrValue::FLOAT data_f = 0;
  80. (void)attr_value.GetValue(data_f);
  81. attr->set_f(data_f);
  82. attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
  83. break;
  84. }
  85. case GeAttrValue::VT_LIST_FLOAT: {
  86. GeAttrValue::LIST_FLOAT data_fs = {};
  87. (void)attr_value.GetValue(data_fs);
  88. attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
  89. for (auto &v : data_fs) {
  90. attr->add_floats(v);
  91. }
  92. break;
  93. }
  94. case GeAttrValue::VT_INT: {
  95. GeAttrValue::INT data_i = 0;
  96. (void)attr_value.GetValue(data_i);
  97. attr->set_type(onnx::AttributeProto_AttributeType_INT);
  98. attr->set_i(data_i);
  99. break;
  100. }
  101. case GeAttrValue::VT_LIST_INT: {
  102. GeAttrValue::LIST_INT data_is = {};
  103. (void)attr_value.GetValue(data_is);
  104. attr->set_type(onnx::AttributeProto_AttributeType_INTS);
  105. for (auto &v : data_is) {
  106. attr->add_ints(v);
  107. }
  108. break;
  109. }
  110. case GeAttrValue::VT_STRING: {
  111. GeAttrValue::STR data_s;
  112. (void)attr_value.GetValue(data_s);
  113. attr->set_type(onnx::AttributeProto_AttributeType_STRING);
  114. attr->set_s(data_s);
  115. break;
  116. }
  117. case GeAttrValue::VT_LIST_STRING: {
  118. GeAttrValue::LIST_STR data_ss = {};
  119. (void)attr_value.GetValue(data_ss);
  120. attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
  121. for (auto &v : data_ss) {
  122. attr->add_strings(v);
  123. }
  124. break;
  125. }
  126. default:
  127. GELOGW("GeAttrValue ValueType: %u is not supported for now", value_type);
  128. break;
  129. }
  130. }
  131. void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
  132. void *data) {
  133. if (node_proto == nullptr) {
  134. GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
  135. return;
  136. }
  137. auto attr = node_proto->add_attribute();
  138. if (attr == nullptr) {
  139. GELOGE(GRAPH_FAILED, "attr is nullptr.");
  140. return;
  141. }
  142. attr->set_name(name);
  143. switch (type) {
  144. case onnx::AttributeProto_AttributeType_FLOAT:
  145. attr->set_f((*(static_cast<float *>(data))));
  146. attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
  147. break;
  148. case onnx::AttributeProto_AttributeType_FLOATS:
  149. attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
  150. for (auto &v : (*(static_cast<std::vector<float> *>(data)))) {
  151. attr->add_floats(v);
  152. }
  153. break;
  154. case onnx::AttributeProto_AttributeType_INT:
  155. attr->set_type(onnx::AttributeProto_AttributeType_INT);
  156. attr->set_i((*(static_cast<int64_t *>(data))));
  157. break;
  158. case onnx::AttributeProto_AttributeType_INTS:
  159. attr->set_type(onnx::AttributeProto_AttributeType_INTS);
  160. for (auto &v : *(static_cast<std::vector<int64_t> *>(data))) {
  161. attr->add_ints(v);
  162. }
  163. break;
  164. case onnx::AttributeProto_AttributeType_STRING:
  165. attr->set_type(onnx::AttributeProto_AttributeType_STRING);
  166. attr->set_s((*(static_cast<std::string *>(data))));
  167. break;
  168. case onnx::AttributeProto_AttributeType_STRINGS:
  169. attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
  170. for (auto &v : *(static_cast<std::vector<std::string> *>(data))) {
  171. attr->add_strings(v);
  172. }
  173. break;
  174. default:
  175. GELOGW("AttributeProto AttributeType: %u is not supported for now", type);
  176. break;
  177. }
  178. }
  179. void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
  180. ::google::protobuf::RepeatedField<::google::protobuf::int64> data) {
  181. if (node_proto == nullptr) {
  182. GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
  183. return;
  184. }
  185. if (!data.empty()) {
  186. auto attr = node_proto->add_attribute();
  187. if (attr == nullptr) {
  188. GELOGE(GRAPH_FAILED, "attr is nullptr.");
  189. return;
  190. }
  191. attr->set_name(name);
  192. for (auto &v : data) {
  193. attr->add_ints(v);
  194. }
  195. attr->set_type(type);
  196. }
  197. }
  198. void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
  199. ::google::protobuf::RepeatedField<bool> data) {
  200. if (node_proto == nullptr) {
  201. GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str());
  202. return;
  203. }
  204. if (!data.empty()) {
  205. auto attr = node_proto->add_attribute();
  206. if (attr == nullptr) {
  207. GELOGE(GRAPH_FAILED, "attr is nullptr.");
  208. return;
  209. }
  210. attr->set_name(name);
  211. for (auto &v : data) {
  212. attr->add_ints(static_cast<int64_t>(v));
  213. }
  214. attr->set_type(type);
  215. }
  216. }
  217. void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
  218. ::google::protobuf::RepeatedField<float> data) {
  219. if (node_proto == nullptr) {
  220. GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
  221. return;
  222. }
  223. if (!data.empty()) {
  224. auto attr = node_proto->add_attribute();
  225. if (attr == nullptr) {
  226. GELOGE(GRAPH_FAILED, "attr is nullptr.");
  227. return;
  228. }
  229. attr->set_name(name);
  230. for (auto &v : data) {
  231. attr->add_floats(v);
  232. }
  233. attr->set_type(type);
  234. }
  235. }
  236. void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
  237. ::google::protobuf::RepeatedPtrField<::std::string> data) {
  238. if (node_proto == nullptr) {
  239. GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str());
  240. return;
  241. }
  242. if (!data.empty()) {
  243. auto attr = node_proto->add_attribute();
  244. if (attr == nullptr) {
  245. GELOGE(GRAPH_FAILED, "attr is nullptr.");
  246. return;
  247. }
  248. attr->set_name(name);
  249. for (auto &v : data) {
  250. attr->add_strings(v);
  251. }
  252. attr->set_type(type);
  253. }
  254. }
  255. void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc) {
  256. if (node_proto == nullptr || op_desc == nullptr) {
  257. GELOGE(GRAPH_FAILED, "node_proto or op_desc is nullptr");
  258. return;
  259. }
  260. // Input describes
  261. auto size_in = op_desc->GetAllInputsSize();
  262. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_nums", &size_in);
  263. if (size_in > 0) {
  264. for (uint32_t i = 0; i < size_in; i++) {
  265. auto input_desc = op_desc->GetInputDescPtrDfault(i);
  266. if (input_desc != nullptr) {
  267. auto data_type = TypeUtils::DataTypeToSerialString(input_desc->GetDataType());
  268. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  269. "input_desc_dtype:" + std::to_string(i), &data_type);
  270. auto data_type_origin = TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType());
  271. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  272. "input_desc_origin_dtype:" + std::to_string(i), &data_type_origin);
  273. auto dims = input_desc->GetShape().GetDims();
  274. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS,
  275. "input_desc_shape:" + std::to_string(i), &dims);
  276. auto dims_origin = input_desc->GetOriginShape().GetDims();
  277. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS,
  278. "input_desc_origin_shape:" + std::to_string(i), &dims_origin);
  279. auto layout = TypeUtils::FormatToSerialString(input_desc->GetFormat());
  280. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  281. "input_desc_layout:" + std::to_string(i), &layout);
  282. auto layout_origin = TypeUtils::FormatToSerialString(input_desc->GetOriginFormat());
  283. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  284. "input_desc_origin_layout:" + std::to_string(i), &layout_origin);
  285. auto tensor_descriptor = input_desc->tensor_descriptor_.GetProtoMsg();
  286. if (tensor_descriptor != nullptr) {
  287. auto size = tensor_descriptor->size();
  288. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  289. "input_desc_size:" + std::to_string(i), &size);
  290. auto weight_size = tensor_descriptor->weight_size();
  291. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  292. "input_desc_weight_size:" + std::to_string(i), &weight_size);
  293. auto reuse_input = tensor_descriptor->reuse_input();
  294. auto reuse_input_int = static_cast<int64_t>(reuse_input);
  295. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  296. "input_desc_reuse_input:" + std::to_string(i), &reuse_input_int);
  297. auto output_tensor = tensor_descriptor->output_tensor();
  298. auto output_tensor_int = static_cast<int64_t>(output_tensor);
  299. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  300. "input_desc_output_tensor:" + std::to_string(i), &output_tensor_int);
  301. auto device_type = tensor_descriptor->device_type();
  302. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  303. "input_desc_device_type:" + std::to_string(i), &device_type);
  304. auto input_tensor = tensor_descriptor->input_tensor();
  305. auto input_tensor_int = static_cast<int64_t>(input_tensor);
  306. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  307. "input_desc_input_tensor:" + std::to_string(i), &input_tensor_int);
  308. auto real_dim_cnt = tensor_descriptor->real_dim_cnt();
  309. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  310. "input_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt);
  311. auto data_offset = tensor_descriptor->data_offset();
  312. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  313. "input_desc_data_offset:" + std::to_string(i), &data_offset);
  314. auto cmps_size = tensor_descriptor->cmps_size();
  315. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_cmps_size:" + std::to_string(i),
  316. &cmps_size);
  317. auto cmps_tab = tensor_descriptor->cmps_tab();
  318. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  319. "input_desc_cmps_tab:" + std::to_string(i), &cmps_tab);
  320. auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset();
  321. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  322. "input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset);
  323. const auto &tensor_desc_map = tensor_descriptor->attr();
  324. std::string suffix = ":" + std::to_string(i);
  325. AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForInputDesc, suffix);
  326. } else {
  327. GELOGW("Tensor descriptor is nullptr");
  328. continue;
  329. }
  330. } else {
  331. GELOGW("Input desc is nullptr");
  332. continue;
  333. }
  334. }
  335. }
  336. // Output describes
  337. auto size_out = op_desc->GetOutputsSize();
  338. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_nums", &size_out);
  339. if (size_out > 0) {
  340. for (uint32_t i = 0; i < size_out; i++) {
  341. auto output_desc = op_desc->GetOutputDescPtr(i);
  342. if (output_desc != nullptr) {
  343. auto data_type = TypeUtils::DataTypeToSerialString(output_desc->GetDataType());
  344. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  345. "output_desc_dtype:" + std::to_string(i), &data_type);
  346. auto origin_data_type = TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType());
  347. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  348. "output_desc_origin_dtype:" + std::to_string(i), &origin_data_type);
  349. auto dims = output_desc->GetShape().GetDims();
  350. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS,
  351. "output_desc_shape:" + std::to_string(i), &dims);
  352. auto dims_origin = output_desc->GetOriginShape().GetDims();
  353. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS,
  354. "output_desc_origin_shape:" + std::to_string(i), &dims_origin);
  355. auto layout = TypeUtils::FormatToSerialString(output_desc->GetFormat());
  356. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_layout:" + std::to_string(i),
  357. &layout);
  358. auto layout_origin = TypeUtils::FormatToSerialString(output_desc->GetOriginFormat());
  359. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  360. "output_desc_origin_layout:" + std::to_string(i), &layout_origin);
  361. auto tensor_descriptor = output_desc->tensor_descriptor_.GetProtoMsg();
  362. if (tensor_descriptor != nullptr) {
  363. auto size = tensor_descriptor->size();
  364. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_size:" + std::to_string(i),
  365. &size);
  366. auto weight_size = tensor_descriptor->weight_size();
  367. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  368. "output_desc_weight_size:" + std::to_string(i), &weight_size);
  369. auto device_type = tensor_descriptor->device_type();
  370. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  371. "output_desc_device_type:" + std::to_string(i), &device_type);
  372. auto real_dim_cnt = tensor_descriptor->real_dim_cnt();
  373. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  374. "output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt);
  375. const auto &tensor_desc_map = tensor_descriptor->attr();
  376. std::string suffix = ":" + std::to_string(i);
  377. AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForOutputDesc, suffix);
  378. } else {
  379. GELOGW("Tensor descriptor is nullptr");
  380. continue;
  381. }
  382. } else {
  383. GELOGW("Output desc is nullptr");
  384. continue;
  385. }
  386. }
  387. }
  388. }
  389. void OnnxUtils::AddAttrProtoForAttrsFromAttrMap(
  390. const ::google::protobuf::Map<std::string, ::ge::proto::AttrDef> &attr_map, onnx::NodeProto *node_proto,
  391. const std::string& prefix, const std::string& suffix) {
  392. for (const auto &item : attr_map) {
  393. auto attr_name = item.first;
  394. auto attr_def = item.second;
  395. auto attr_type = attr_def.value_case();
  396. if (attr_type == ge::proto::AttrDef::kT) {
  397. const auto &tensor_def = attr_def.t();
  398. const auto &tensor_desc = tensor_def.desc();
  399. auto data_type = ge::proto::DataType_Name(tensor_desc.dtype());
  400. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  401. prefix + attr_name + "_desc_dtype" + suffix, &data_type);
  402. auto dims = tensor_desc.shape().dim();
  403. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS,
  404. prefix + attr_name + "_desc_shape" + suffix, dims);
  405. auto layout = tensor_desc.layout();
  406. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  407. prefix + attr_name + "_desc_layout" + suffix, &layout);
  408. auto device_type = tensor_desc.device_type();
  409. AddAttrProto(node_proto, ge::onnx::AttributeProto_AttributeType_STRING,
  410. prefix + attr_name + "_desc_device_type" + suffix, &device_type);
  411. if (kDumpLevel == DUMP_ALL) {
  412. auto data = tensor_def.data();
  413. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  414. prefix + attr_name + "_data" + suffix, &data);
  415. }
  416. }
  417. if (attr_type == ge::proto::AttrDef::kS) {
  418. if (kDumpLevel == DUMP_ALL) {
  419. auto str_value = attr_def.s();
  420. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value);
  421. }
  422. }
  423. if (attr_type == ge::proto::AttrDef::kI) {
  424. auto int_value = attr_def.i();
  425. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value);
  426. }
  427. if (attr_type == ge::proto::AttrDef::kF) {
  428. auto float_value = attr_def.f();
  429. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value);
  430. }
  431. if (attr_type == ge::proto::AttrDef::kB) {
  432. auto int_value = static_cast<int64_t>(attr_def.b());
  433. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value);
  434. }
  435. if (attr_type == ge::proto::AttrDef::kList) {
  436. const auto &list_value = attr_def.list();
  437. auto list_value_type = list_value.val_type();
  438. if (list_value_type ==
  439. ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) {
  440. if (kDumpLevel == DUMP_ALL) {
  441. const auto &strings = list_value.s();
  442. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings);
  443. }
  444. }
  445. if (list_value_type ==
  446. ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) {
  447. const auto &floats = list_value.f();
  448. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats);
  449. }
  450. if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) {
  451. const auto &ints = list_value.i();
  452. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints);
  453. }
  454. if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) {
  455. const auto &bools = list_value.b();
  456. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools);
  457. }
  458. }
  459. }
  460. }
  461. void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto) {
  462. if (node == nullptr) {
  463. GELOGE(GRAPH_FAILED, "node is nullptr");
  464. return;
  465. }
  466. // 1.Attributes added from node's methods
  467. auto send_list = node->send_event_id_list_;
  468. if (!send_list.empty()) {
  469. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "send_event_id_list", &send_list);
  470. }
  471. auto recv_list = node->recv_event_id_list_;
  472. if (!recv_list.empty()) {
  473. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "recv_event_id_list", &recv_list);
  474. }
  475. auto op_desc = node->op_;
  476. if (op_desc != nullptr) {
  477. // for input_name_idx_ in opdesc
  478. auto input_name_2_indexs = op_desc->GetAllInputName();
  479. ::google::protobuf::RepeatedPtrField<::std::string> input_names;
  480. ::google::protobuf::RepeatedField<::google::protobuf::int64> input_indexes;
  481. for (const auto &input_name_2_index: input_name_2_indexs) {
  482. std::string input_name = input_name_2_index.first;
  483. input_names.Add(std::move(input_name));
  484. input_indexes.Add(input_name_2_index.second);
  485. }
  486. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "_input_name_key", input_names);
  487. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "_input_name_value", input_indexes);
  488. // 2.Attributes added from node's op_(message OpDef)
  489. // Input and out describes
  490. AddAttrProtoForOpInAndOutDesc(node_proto, op_desc);
  491. // Others
  492. auto op_def = op_desc->op_def_.GetProtoMsg();
  493. if (op_def != nullptr) {
  494. auto id = op_def->id();
  495. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "id", &id);
  496. auto stream_id = op_def->stream_id();
  497. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "stream_id", &stream_id);
  498. const auto &input_name = op_def->input_name();
  499. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "input_name", input_name);
  500. const auto &src_name = op_def->src_name();
  501. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "src_name", src_name);
  502. const auto &src_index = op_def->src_index();
  503. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "src_index", src_index);
  504. const auto &dst_name = op_def->dst_name();
  505. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "dst_name", dst_name);
  506. const auto &dst_index = op_def->dst_index();
  507. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "dst_index", dst_index);
  508. const auto &input_i = op_def->input_i();
  509. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_i", input_i);
  510. const auto &output_i = op_def->output_i();
  511. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_i", output_i);
  512. const auto &workspace = op_def->workspace();
  513. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace", workspace);
  514. const auto &workspace_bytes = op_def->workspace_bytes();
  515. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes);
  516. const auto &is_input_const = op_def->is_input_const();
  517. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const);
  518. const auto &op_def_attr_map = op_def->attr();
  519. AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto);
  520. } else {
  521. GELOGE(FAILED, "Opdef is nullptr");
  522. return;
  523. }
  524. } else {
  525. GELOGE(FAILED, "Opdesc is nullptr");
  526. return;
  527. }
  528. }
  529. bool OnnxUtils::EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto) {
  530. if ((node == nullptr) || (node_proto == nullptr)) {
  531. GELOGE(GRAPH_FAILED, "EncodeOpDesc: Input Para Node Invalid");
  532. return false;
  533. }
  534. // 2.Encode map<string, GeAttrValue> attrs_ to AttributeProto
  535. for (auto &node_attr : node->attrs_) {
  536. AddAttrProtoFromAttribute(node_attr, node_proto);
  537. }
  538. // 3.Encode ge::Node members to AttributeProto
  539. AddAttrProtoFromNodeMembers(node, node_proto);
  540. // 4. Sort node attributes by name.
  541. std::sort(node_proto->mutable_attribute()->begin(), node_proto->mutable_attribute()->end(), AttrNameComp());
  542. return true;
  543. }
  544. void OnnxUtils::EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto) {
  545. if ((node == nullptr) || (node_proto == nullptr)) {
  546. GELOGE(GRAPH_FAILED, "EncodeNodeLinkForNetronVisual: Input Para Node Invalid");
  547. return;
  548. }
  549. const auto &node_name = node->GetName();
  550. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  551. if ((out_data_anchor != nullptr) && (!out_data_anchor->GetPeerInDataAnchors().empty())) {
  552. node_proto->add_output(node_name + ":" + std::to_string(out_data_anchor->GetIdx()));
  553. }
  554. }
  555. auto out_control_anchor = node->GetOutControlAnchor();
  556. if ((out_control_anchor != nullptr) && (!out_control_anchor->GetPeerInControlAnchors().empty())) {
  557. node_proto->add_output(node_name + kControlAnchorIndex);
  558. }
  559. }
  560. bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) {
  561. if ((node == nullptr) || (node_proto == nullptr)) {
  562. GELOGE(GRAPH_FAILED, "EncodeNodeLink: Input Para Node Invalid");
  563. return false;
  564. }
  565. node_proto->clear_input();
  566. // 1. Add input by in data edge
  567. for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
  568. auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  569. if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) {
  570. node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" +
  571. std::to_string(peer_out_anchor->GetIdx()));
  572. } else {
  573. // Add "" input
  574. node_proto->add_input("");
  575. }
  576. }
  577. // 2. Add input by in control edge
  578. auto in_control_anchor = node->GetInControlAnchor();
  579. if (in_control_anchor != nullptr) {
  580. auto peer_out_anchors = in_control_anchor->GetPeerOutControlAnchors();
  581. for (const auto &peer_out_anchor : peer_out_anchors) {
  582. if (peer_out_anchor->GetOwnerNode()) {
  583. node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex);
  584. }
  585. }
  586. } else {
  587. GELOGE(FAILED, "Incontrol anchor is nullptr");
  588. return false;
  589. }
  590. // 3. Add output for Netron visual support
  591. EncodeNodeLinkForNetronVisual(node, node_proto);
  592. return true;
  593. }
  594. bool OnnxUtils::EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto) {
  595. if ((node == nullptr) || (node_proto == nullptr)) {
  596. GELOGE(GRAPH_FAILED, "EncodeNode: Input Para Node Invalid");
  597. return false;
  598. }
  599. // 1. Encode name and type
  600. node_proto->set_name(node->GetName());
  601. /// Netron believes that some operators, such as the activation operator of softplus, only have one input,
  602. /// while the link relation of control anchor may exist in ge, resulting in two inputs. Therefore, "ge:" prefix
  603. /// is added to correctly display the link relation at the expense of some color features
  604. node_proto->set_op_type("ge:" + node->GetType());
  605. if (kDumpLevel != DUMP_WITH_OUT_DESC) {
  606. // 2.for attr
  607. if (!EncodeNodeDesc(node, node_proto)) {
  608. GELOGE(GRAPH_FAILED, "Encode NodeDesc: %s failed", node->GetName().c_str());
  609. return false;
  610. }
  611. }
  612. // 3.for link info
  613. return EncodeNodeLink(node, node_proto);
  614. }
  615. void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type) {
  616. if ((node == nullptr) || (tensor_type == nullptr)) {
  617. GELOGE(GRAPH_FAILED, "EncodeTypeProtoTensorType: Input Para Node or tensor_type Invalid");
  618. return;
  619. }
  620. const auto &op_desc = node->GetOpDesc();
  621. if (op_desc != nullptr) {
  622. uint32_t size_out = static_cast<uint32_t>(op_desc->GetOutputsSize());
  623. if (size_out > 0) {
  624. for (uint32_t i = 0; i < size_out; i++) {
  625. const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i);
  626. if (ge_tensor != nullptr) {
  627. auto ge_data_type = ge_tensor->GetDataType();
  628. auto onnx_data_type = EncodeDataType(ge_data_type);
  629. tensor_type->set_elem_type(onnx_data_type);
  630. onnx::TensorShapeProto *shape = tensor_type->mutable_shape();
  631. if (shape != nullptr) {
  632. for (auto d : ge_tensor->GetShape().GetDims()) {
  633. auto dim = shape->add_dim();
  634. dim->set_dim_value(d);
  635. }
  636. } else {
  637. GELOGW("Shape is nullptr");
  638. continue;
  639. }
  640. } else {
  641. GELOGW("Ge tensor is nullptr");
  642. continue;
  643. }
  644. }
  645. }
  646. } else {
  647. GELOGW("OpDesc Is Empty, nodeName %s nodeType %s", node->GetName().c_str(), node->GetType().c_str());
  648. return;
  649. }
  650. }
  651. void OnnxUtils::EncodeValueInfo(const NodePtr &node, onnx::ValueInfoProto *value_info_proto) {
  652. if ((node == nullptr) || (value_info_proto == nullptr)) {
  653. GELOGE(GRAPH_FAILED, "EncodeValueInfo: Input Para Node or value_info_proto Invalid");
  654. return;
  655. }
  656. value_info_proto->set_name(node->GetName());
  657. onnx::TypeProto *t = value_info_proto->mutable_type();
  658. onnx::TypeProto_Tensor *tensor_type = t->mutable_tensor_type();
  659. EncodeTypeProtoTensorType(node, tensor_type);
  660. }
  661. bool OnnxUtils::EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto) {
  662. if ((graph == nullptr) || (graph_proto == nullptr)) {
  663. GELOGE(GRAPH_FAILED, "EncodeGraph: Input para Invalid");
  664. return false;
  665. }
  666. graph_proto->set_name(graph->GetName());
  667. // 1. Add graph inputs
  668. for (const auto &input : graph->GetInputNodes()) {
  669. auto value_info_proto = graph_proto->add_input();
  670. EncodeValueInfo(input, value_info_proto);
  671. }
  672. // 2. Add graph outputs
  673. for (const auto &output : graph->GetOutputNodes()) {
  674. auto value_info_proto = graph_proto->add_output();
  675. EncodeValueInfo(output, value_info_proto);
  676. }
  677. // 3. Add nodes
  678. for (const auto &node : graph->GetDirectNode()) {
  679. if (!EncodeNode(node, graph_proto->add_node())) {
  680. GELOGW("EncodeNode failed");
  681. continue;
  682. }
  683. }
  684. return true;
  685. }
  686. bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto) {
  687. model_proto.set_model_version(model.GetVersion());
  688. model_proto.set_ir_version(onnx::IR_VERSION);
  689. model_proto.set_producer_name(model.GetName());
  690. auto &graph = model.graph_;
  691. auto compute_graph = GraphUtils::GetComputeGraph(graph);
  692. if (compute_graph == nullptr) {
  693. GELOGE(GRAPH_FAILED, "GetComputeGraph: return nullptr");
  694. return false;
  695. }
  696. auto graph_proto = model_proto.mutable_graph();
  697. if (graph_proto == nullptr) {
  698. GELOGE(GRAPH_FAILED, "mutable_graph: %s return nullptr", compute_graph->GetName().c_str());
  699. return false;
  700. }
  701. if (!EncodeGraph(compute_graph, graph_proto)) {
  702. GELOGE(GRAPH_FAILED, "EncodeGraph: %s fail", compute_graph->GetName().c_str());
  703. return false;
  704. }
  705. // For subgraphs: a subgraph is represented by a node
  706. for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) {
  707. if (sub_compute_graph != nullptr) {
  708. auto node_proto = graph_proto->add_node();
  709. if (node_proto == nullptr) {
  710. GELOGW("Node proto is nullptr");
  711. continue;
  712. }
  713. node_proto->set_name(sub_compute_graph->GetName());
  714. node_proto->set_op_type(kNodeTypeForSubgraph);
  715. auto attr = node_proto->add_attribute();
  716. attr->set_name("graph");
  717. attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
  718. auto sub_graph_proto = attr->mutable_g();
  719. if (sub_graph_proto == nullptr) {
  720. GELOGW("Sub graph proto is nullptr");
  721. continue;
  722. }
  723. if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) {
  724. GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str());
  725. continue;
  726. }
  727. } else {
  728. GELOGW("Graph: %s subgraph is nullptr, skip EncodeGraph", compute_graph->GetName().c_str());
  729. continue;
  730. }
  731. }
  732. return true;
  733. }
  734. // Part 2: from ONNX Protobuf convert to IR
  735. static std::map<onnx::TensorProto_DataType, ge::DataType> onnxDataTypeToGeMap = {
  736. {onnx::TensorProto_DataType_INT64, DT_INT64}, {onnx::TensorProto_DataType_UINT64, DT_UINT64},
  737. {onnx::TensorProto_DataType_FLOAT, DT_FLOAT}, {onnx::TensorProto_DataType_INT32, DT_INT32},
  738. {onnx::TensorProto_DataType_UINT32, DT_UINT32}, {onnx::TensorProto_DataType_INT8, DT_INT8},
  739. {onnx::TensorProto_DataType_UINT8, DT_UINT8}, {onnx::TensorProto_DataType_INT16, DT_INT16},
  740. {onnx::TensorProto_DataType_UINT16, DT_UINT16}, {onnx::TensorProto_DataType_FLOAT16, DT_FLOAT16},
  741. {onnx::TensorProto_DataType_DOUBLE, DT_DOUBLE}, {onnx::TensorProto_DataType_BOOL, DT_BOOL},
  742. };
  743. ge::DataType OnnxUtils::DecodeDataType(onnx::TensorProto_DataType data_type) {
  744. auto it = onnxDataTypeToGeMap.find(data_type);
  745. if (it != onnxDataTypeToGeMap.end()) {
  746. return it->second;
  747. } else {
  748. GELOGW("DecodeDataType: datatype not support %u", data_type);
  749. return ge::DT_UNDEFINED;
  750. }
  751. }
  752. bool OnnxUtils::ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index) {
  753. auto sep = node_name_index.rfind(':');
  754. if (sep == std::string::npos) {
  755. return false;
  756. }
  757. node_name = node_name_index.substr(0, sep);
  758. auto index_str = node_name_index.substr(sep + 1);
  759. index = static_cast<int32_t>(std::strtol(index_str.c_str(), nullptr, 10));
  760. return true;
  761. }
  762. bool OnnxUtils::DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr) {
  763. if (node_ptr == nullptr) {
  764. GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp: node_ptr is nullptr");
  765. return false;
  766. }
  767. // Data edge
  768. if (item.src_out_index >= 0) {
  769. auto src_anchor = node_ptr->GetOutDataAnchor(item.src_out_index);
  770. auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index);
  771. if ((src_anchor == nullptr) || (dst_anchor == nullptr)) {
  772. GELOGE(GRAPH_FAILED, "Get data anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index,
  773. item.dst_node_name.c_str(), item.dst_in_index);
  774. return false;
  775. }
  776. if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) {
  777. GELOGE(GRAPH_FAILED, "Data Anchor: src_anchor->LinkTo(dst_anchor) failed");
  778. return false;
  779. }
  780. // Control edge
  781. } else {
  782. auto src_anchor = node_ptr->GetOutControlAnchor();
  783. auto dst_anchor = item.dst_node->GetInControlAnchor();
  784. if ((src_anchor == nullptr) || (dst_anchor == nullptr)) {
  785. GELOGE(GRAPH_FAILED, "Get control anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index,
  786. item.dst_node_name.c_str(), item.dst_in_index);
  787. return false;
  788. }
  789. if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) {
  790. GELOGE(GRAPH_FAILED, "Control Anchor: src_anchor->LinkTo(dst_anchor) failed");
  791. return false;
  792. }
  793. }
  794. return true;
  795. }
  796. bool OnnxUtils::DecodeNodeLink(const std::vector<onnx::NodeProto> &node_proto_vector,
  797. const std::map<std::string, NodePtr> &node_map) {
  798. for (const auto &node_proto : node_proto_vector) {
  799. const auto &node_name = node_proto.name();
  800. auto dst_node = node_map.find(node_name);
  801. if ((dst_node == node_map.end()) || (dst_node->second == nullptr)) {
  802. GELOGE(GRAPH_FAILED, "destination node: %s find failed or is nullptr", node_name.c_str());
  803. return false;
  804. }
  805. int32_t dst_index = 0;
  806. for (const auto &input : node_proto.input()) {
  807. std::string input_node_name;
  808. int32_t index = 0;
  809. if (ParseNameIndex(input, input_node_name, index)) {
  810. auto item = NodeLinkInfo{input_node_name, index, dst_node->second, dst_index, node_proto.name()};
  811. auto src_node = node_map.find(input_node_name);
  812. if (src_node == node_map.end()) {
  813. GELOGE(GRAPH_FAILED, "find src node: %s failed", input_node_name.c_str());
  814. return false;
  815. }
  816. auto node_ptr = src_node->second;
  817. if (node_ptr == nullptr) {
  818. GELOGE(GRAPH_FAILED, "src node: %s is nullptr", input_node_name.c_str());
  819. return false;
  820. }
  821. if (!DecodeNodeLinkImp(item, node_ptr)) {
  822. GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp node: %s failed", input_node_name.c_str());
  823. return false;
  824. }
  825. }
  826. if (index >= 0) {
  827. dst_index++;
  828. }
  829. }
  830. }
  831. return true;
  832. }
  833. void OnnxUtils::DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::vector<std::string> &strings) {
  834. if (attr_proto.type() != ge::onnx::AttributeProto_AttributeType_STRINGS) {
  835. GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
  836. return;
  837. }
  838. for (int i = 0; i < attr_proto.strings_size(); i++) {
  839. strings.push_back(attr_proto.strings(i));
  840. }
  841. }
  842. void OnnxUtils::DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::string &value) {
  843. if (attr_proto.type() != ge::onnx::AttributeProto_AttributeType_STRING) {
  844. GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
  845. return;
  846. }
  847. value = attr_proto.s();
  848. }
  849. void OnnxUtils::DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::vector<int64_t> &ints) {
  850. if (attr_proto.type() != ge::onnx::AttributeProto_AttributeType_INTS) {
  851. GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
  852. return;
  853. }
  854. for (int i = 0; i < attr_proto.ints_size(); i++) {
  855. ints.push_back(attr_proto.ints(i));
  856. }
  857. }
  858. void OnnxUtils::DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, int64_t &value) {
  859. if (attr_proto.type() != ge::onnx::AttributeProto_AttributeType_INT) {
  860. GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
  861. return;
  862. }
  863. value = attr_proto.i();
  864. }
  865. void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto,
  866. const std::string &attr_name_for_input_desc, int32_t index,
  867. OpDescPtr &op_desc) {
  868. if (op_desc->MutableInputDesc(static_cast<uint32_t>(index)) == nullptr) {
  869. GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast<uint32_t>(index)) is nullptr",
  870. op_desc->GetName().c_str(), attr_name_for_input_desc.c_str());
  871. return;
  872. }
  873. if (attr_name_for_input_desc == "input_desc_dtype") {
  874. auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s());
  875. op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetDataType(data_type);
  876. } else if (attr_name_for_input_desc == "input_desc_shape") {
  877. std::vector<std::int64_t> ints;
  878. DecodeAttribute(attr_proto, ints);
  879. GeShape ge_shape(ints);
  880. op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape);
  881. } else if (attr_name_for_input_desc == "input_desc_layout") {
  882. auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
  883. op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetFormat(data_format);
  884. } else if (attr_name_for_input_desc == "input_desc_origin_shape") {
  885. std::vector<std::int64_t> ints;
  886. DecodeAttribute(attr_proto, ints);
  887. GeShape ge_shape(ints);
  888. op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape);
  889. } else if (attr_name_for_input_desc == "input_desc_origin_layout") {
  890. auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
  891. op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format);
  892. } else if (attr_name_for_input_desc == "input_desc_size") {
  893. int64_t input_size = 0;
  894. auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg();
  895. DecodeAttribute(attr_proto, input_size);
  896. tensor_descriptor->set_size(input_size);
  897. } else if (attr_name_for_input_desc == "input_desc_data_offset") {
  898. auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg();
  899. int64_t offset = 0;
  900. DecodeAttribute(attr_proto, offset);
  901. tensor_descriptor->set_data_offset(offset);
  902. } else {
  903. return;
  904. }
  905. }
  906. void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto,
  907. const std::string &attr_name_for_output_desc, int32_t index,
  908. OpDescPtr &op_desc) {
  909. if (op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) == nullptr) {
  910. GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) is nullptr",
  911. op_desc->GetName().c_str(), attr_name_for_output_desc.c_str());
  912. return;
  913. }
  914. if (attr_name_for_output_desc == "output_desc_dtype") {
  915. auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s());
  916. op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetDataType(data_type);
  917. } else if (attr_name_for_output_desc == "output_desc_shape") {
  918. std::vector<std::int64_t> ints;
  919. DecodeAttribute(attr_proto, ints);
  920. GeShape ge_shape(ints);
  921. op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape);
  922. } else if (attr_name_for_output_desc == "output_desc_layout") {
  923. auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
  924. op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetFormat(data_format);
  925. } else if (attr_name_for_output_desc == "output_desc_origin_shape") {
  926. std::vector<std::int64_t> ints;
  927. DecodeAttribute(attr_proto, ints);
  928. GeShape ge_shape(ints);
  929. op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape);
  930. } else if (attr_name_for_output_desc == "output_desc_origin_layout") {
  931. auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
  932. op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format);
  933. } else if (attr_name_for_output_desc == "output_desc_size") {
  934. int64_t output_size = 0;
  935. auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg();
  936. DecodeAttribute(attr_proto, output_size);
  937. tensor_descriptor->set_size(output_size);
  938. } else if (attr_name_for_output_desc == "output_desc_data_offset") {
  939. auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg();
  940. int64_t offset = 0;
  941. DecodeAttribute(attr_proto, offset);
  942. tensor_descriptor->set_data_offset(offset);
  943. } else {
  944. return;
  945. }
  946. }
  947. void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto,
  948. const std::string &attr_name_for_input_output_desc, int32_t index,
  949. OpDescPtr &op_desc) {
  950. if (op_desc == nullptr) {
  951. GELOGE(GRAPH_FAILED, "op_desc is nullptr");
  952. return;
  953. }
  954. if (attr_name_for_input_output_desc.substr(0, kInputPrefixLength) == "input") {
  955. DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc);
  956. } else if (attr_name_for_input_output_desc.substr(0, kOutputPrefixLength) == "output") {
  957. DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc);
  958. } else {
  959. return;
  960. }
  961. }
  962. void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) {
  963. auto attr_map = op_def.mutable_attr();
  964. const auto &attr_name = attr_proto.name();
  965. ge::proto::AttrDef op_attr;
  966. int64_t value = 0;
  967. DecodeAttribute(attr_proto, value);
  968. op_attr.set_i(value);
  969. attr_map->insert(AttrDefPair(attr_name, op_attr));
  970. }
  971. void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) {
  972. if (op_desc == nullptr) {
  973. GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr");
  974. return;
  975. }
  976. const auto &attr_name = attr_proto.name();
  977. std::string attr_name_for_input_output_desc;
  978. int32_t index = 0;
  979. if (!ParseNameIndex(attr_name, attr_name_for_input_output_desc, index)) {
  980. if (attr_name == "id") {
  981. op_desc->SetId(attr_proto.i());
  982. } else if (attr_name == "stream_id") {
  983. op_desc->SetStreamId(attr_proto.i());
  984. } else if (attr_name == "src_name") {
  985. std::vector<std::string> strings;
  986. DecodeAttribute(attr_proto, strings);
  987. op_desc->SetSrcName(strings);
  988. } else if (attr_name == "dst_name") {
  989. std::vector<std::string> strings;
  990. DecodeAttribute(attr_proto, strings);
  991. op_desc->SetDstName(strings);
  992. } else if (attr_name == "src_index") {
  993. std::vector<std::int64_t> ints;
  994. DecodeAttribute(attr_proto, ints);
  995. op_desc->SetSrcIndex(ints);
  996. } else if (attr_name == "dst_index") {
  997. std::vector<std::int64_t> ints;
  998. DecodeAttribute(attr_proto, ints);
  999. op_desc->SetDstIndex(ints);
  1000. } else if (attr_name == "fusion_scope") {
  1001. DecodeNodeAttributeForOpDef(attr_proto, *op_desc->op_def_.GetProtoMsg());
  1002. } else if (attr_name == "input_i") {
  1003. std::vector<std::int64_t> ints;
  1004. DecodeAttribute(attr_proto, ints);
  1005. op_desc->SetInputOffset(ints);
  1006. } else if (attr_name == "output_i") {
  1007. std::vector<std::int64_t> ints;
  1008. DecodeAttribute(attr_proto, ints);
  1009. op_desc->SetOutputOffset(ints);
  1010. } else {
  1011. return;
  1012. }
  1013. // Update input and output desc
  1014. } else {
  1015. DecodeNodeAttributeForOpInAndOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc);
  1016. }
  1017. }
  1018. bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_desc) {
  1019. if (op_desc == nullptr || node_proto == nullptr) {
  1020. GELOGE(GRAPH_FAILED, " Op_desc is nullptr or node_proto is nullptr");
  1021. return false;
  1022. }
  1023. // 1. Decode node_proto name and type
  1024. op_desc->SetName(node_proto->name());
  1025. const auto &node_type_with_ge_prefix = node_proto->op_type();
  1026. auto sep = node_type_with_ge_prefix.find(':');
  1027. if (sep == std::string::npos) {
  1028. return false;
  1029. }
  1030. auto node_type = node_type_with_ge_prefix.substr(sep + 1);
  1031. op_desc->SetType(node_type);
  1032. // 2. Add empty input and output desc
  1033. for (const auto &attr : node_proto->attribute()) {
  1034. if (attr.name() == "input_desc_nums") {
  1035. auto size_in = attr.i();
  1036. for (int64_t i = 0; i < size_in; i++) {
  1037. GeTensorDesc ge_tensor_desc;
  1038. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed.");
  1039. }
  1040. }
  1041. if (attr.name() == "output_desc_nums") {
  1042. auto size_out = attr.i();
  1043. for (int64_t i = 0; i < size_out; i++) {
  1044. GeTensorDesc ge_tensor_desc;
  1045. GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed.");
  1046. }
  1047. }
  1048. }
  1049. // 3.Decode node_proto attributes
  1050. for (int i = 0; i < node_proto->attribute_size(); i++) {
  1051. DecodeNodeAttributeForOpDesc(node_proto->attribute(i), op_desc);
  1052. }
  1053. return true;
  1054. }
  1055. bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph) {
  1056. if (recursion_depth > kMaxRecursionDepth) {
  1057. GELOGE(GRAPH_FAILED, "DecodeGraph: recursion depth is too large, abort");
  1058. return false;
  1059. }
  1060. graph = ComGraphMakeShared<ge::ComputeGraph>(graph_proto.name());
  1061. GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed");
  1062. /// 1. Decode all nodes first, node should include input
  1063. /// and output nodes and nodes which represent sub graphs
  1064. std::map<std::string, NodePtr> node_map;
  1065. std::vector<onnx::NodeProto> node_proto_vector;
  1066. for (const auto &node_proto : graph_proto.node()) {
  1067. // a. nodes represent sub graphs
  1068. if (node_proto.op_type() == kNodeTypeForSubgraph) {
  1069. ComputeGraphPtr compute_graph;
  1070. // in this case, node only have one attr, whose type is AttributeProto_AttributeType_GRAPH
  1071. const auto &node_attr = node_proto.attribute(0);
  1072. if ((node_attr.type() == onnx::AttributeProto_AttributeType_GRAPH) &&
  1073. DecodeGraph(recursion_depth + 1, node_attr.g(), compute_graph)) {
  1074. (void)graph->AddSubGraph(compute_graph);
  1075. } else {
  1076. GELOGE(GRAPH_FAILED, "Decode sub graph %s failed with node type:%d", node_proto.name().c_str(),
  1077. node_attr.type());
  1078. return false;
  1079. }
  1080. // b. direct nodes in graph
  1081. } else {
  1082. node_proto_vector.push_back(node_proto);
  1083. OpDescPtr op_desc = ComGraphMakeShared<OpDesc>();
  1084. // b.1 For node desc
  1085. if (!DecodeNodeDesc(&node_proto, op_desc)) {
  1086. GELOGE(GRAPH_FAILED, "Decode node desc %s failed ", node_proto.name().c_str());
  1087. return false;
  1088. }
  1089. auto node = graph->AddNode(op_desc);
  1090. node_map.insert(std::make_pair(node_proto.name(), node));
  1091. }
  1092. }
  1093. /// We get all nodes in graph here
  1094. /// b.2 For node link
  1095. if (!DecodeNodeLink(node_proto_vector, node_map)) {
  1096. GELOGE(GRAPH_FAILED, "Decode node link failed");
  1097. return false;
  1098. }
  1099. // 2. Add inputs nodes for graph
  1100. for (const auto &input : graph_proto.input()) {
  1101. const auto &input_node_name = input.name();
  1102. auto input_node_item = node_map.find(input_node_name);
  1103. if (input_node_item == node_map.end()) {
  1104. GELOGE(GRAPH_FAILED, "cannot find graph's input node %s in node_", input_node_name.c_str());
  1105. return false;
  1106. }
  1107. auto ret = graph->AddInputNode(input_node_item->second);
  1108. GE_CHK_BOOL_EXEC(ret != nullptr, continue, "Add inputnode failed");
  1109. }
  1110. // 3. Add outputs nodes for graph
  1111. for (const auto &output : graph_proto.output()) {
  1112. const auto &output_node_name = output.name();
  1113. auto output_node_item = node_map.find(output_node_name);
  1114. if (output_node_item == node_map.end()) {
  1115. GELOGE(GRAPH_FAILED, "cannot find graph's output node %s in node_", output_node_name.c_str());
  1116. return false;
  1117. }
  1118. auto ret = graph->AddOutputNode(output_node_item->second);
  1119. if (ret == nullptr) {
  1120. GELOGW("Add outputnode failed,out put node is %s", output_node_name.c_str());
  1121. continue;
  1122. }
  1123. }
  1124. return true;
  1125. }
  1126. bool OnnxUtils::ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model) {
  1127. model.name_ = model_proto.producer_name();
  1128. model.version_ = static_cast<uint32_t>(model_proto.model_version());
  1129. auto &graph_proto = model_proto.graph();
  1130. ComputeGraphPtr compute_graph;
  1131. // 0 means recursion depth, father call
  1132. if (!DecodeGraph(0, graph_proto, compute_graph)) {
  1133. GELOGE(GRAPH_FAILED, "Decode compute graph from graph_proto failed");
  1134. return false;
  1135. }
  1136. model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  1137. return true;
  1138. }
  1139. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：