You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_serialize.cc 27 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/model_serialize.h"
  17. #include <google/protobuf/text_format.h>
  18. #include <queue>
  19. #include <iostream>
  20. #include "debug/ge_attr_define.h"
  21. #include "debug/ge_log.h"
  22. #include "debug/ge_util.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "graph/detail/model_serialize_imp.h"
  25. #include "proto/ge_ir.pb.h"
  26. #include "utils/graph_utils.h"
  27. #include "debug/ge_op_types.h"
  28. using std::map;
  29. using std::string;
  30. namespace ge {
  31. bool ModelSerializeImp::ParseNodeIndex(const string &node_index, string &node_name, int32_t &index) {
  32. auto sep = node_index.rfind(":");
  33. if (sep == string::npos) {
  34. GELOGW("separator is not found in node_index.");
  35. return false;
  36. }
  37. node_name = node_index.substr(0, sep);
  38. auto index_str = node_index.substr(sep + 1);
  39. index = static_cast<int32_t>(std::strtol(index_str.c_str(), nullptr, 10));
  40. return true;
  41. }
  42. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeTensor(const ConstGeTensorPtr &tensor,
  43. proto::TensorDef *tensor_proto) {
  44. GE_CHK_BOOL_EXEC(tensor != nullptr, return false, "tensor is null.");
  45. GE_CHK_BOOL_EXEC(tensor_proto != nullptr, return false, "tensor_proto is null.");
  46. if (tensor->tensor_def_.GetProtoMsg() != nullptr) {
  47. *tensor_proto = *tensor->tensor_def_.GetProtoMsg();
  48. return true;
  49. }
  50. return false;
  51. }
  52. bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_proto) {
  53. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is null.");
  54. GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null.");
  55. op_def_proto->clear_input();
  56. // Inputs
  57. for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
  58. if (in_data_anchor != nullptr) {
  59. auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  60. if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) {
  61. op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" +
  62. std::to_string(peer_out_anchor->GetIdx()));
  63. } else {
  64. op_def_proto->add_input("");
  65. }
  66. }
  67. }
  68. // Control edge
  69. auto control_anchor = node->GetInControlAnchor();
  70. if (control_anchor != nullptr) {
  71. auto peer_out_anchors = control_anchor->GetPeerOutControlAnchors();
  72. for (const auto &peer_out_anchor : peer_out_anchors) {
  73. if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) {
  74. op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":-1");
  75. }
  76. }
  77. }
  78. return true;
  79. }
  80. bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) {
  81. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null.");
  82. GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null.");
  83. if (op_desc->op_def_.GetProtoMsg() != nullptr) {
  84. *op_def_proto = *op_desc->op_def_.GetProtoMsg();
  85. // Delete unnecessary attr
  86. if (is_dump) {
  87. auto attr = op_def_proto->mutable_attr();
  88. attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF);
  89. attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF);
  90. attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF);
  91. GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP),
  92. attr->erase(ATTR_NAME_WEIGHTS));
  93. }
  94. op_def_proto->clear_input_desc();
  95. op_def_proto->clear_output_desc();
  96. // Input descs
  97. if (op_desc->GetAllInputsSize() > 0) {
  98. auto size = static_cast<uint32_t>(op_desc->GetAllInputsSize());
  99. for (uint32_t i = 0; i < size; i++) {
  100. auto tensor_desc = op_desc->GetInputDescPtrDfault(i);
  101. if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) {
  102. *op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg());
  103. }
  104. }
  105. }
  106. // Output descs
  107. if (op_desc->GetOutputsSize() > 0) {
  108. auto size = static_cast<uint32_t>(op_desc->GetOutputsSize());
  109. for (uint32_t i = 0; i < size; i++) {
  110. auto tensor_desc = op_desc->GetOutputDescPtr(i);
  111. if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) {
  112. *op_def_proto->add_output_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg());
  113. }
  114. }
  115. }
  116. op_def_proto->set_id(op_desc->GetId());
  117. for (const std::string &name : op_desc->GetSubgraphInstanceNames()) {
  118. op_def_proto->add_subgraph_name(name);
  119. }
  120. OpDescToAttrDef(op_desc, op_def_proto);
  121. }
  122. return true;
  123. }
  124. void ModelSerializeImp::OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) {
  125. proto::AttrDef key_in;
  126. proto::AttrDef value_in;
  127. auto op_desc_attr = op_def_proto->mutable_attr();
  128. if (!op_desc->input_name_idx_.empty()) {
  129. for (auto &item : op_desc->input_name_idx_) {
  130. key_in.mutable_list()->add_s(item.first);
  131. value_in.mutable_list()->add_i(item.second);
  132. }
  133. op_desc_attr->insert({"_input_name_key", key_in});
  134. op_desc_attr->insert({"_input_name_value", value_in});
  135. }
  136. proto::AttrDef key_out;
  137. proto::AttrDef value_out;
  138. if (!op_desc->output_name_idx_.empty()) {
  139. for (auto &item : op_desc->output_name_idx_) {
  140. key_out.mutable_list()->add_s(item.first);
  141. value_out.mutable_list()->add_i(item.second);
  142. }
  143. op_desc_attr->insert({"_output_name_key", key_out});
  144. op_desc_attr->insert({"_output_name_value", value_out});
  145. }
  146. proto::AttrDef opt_input;
  147. if (!op_desc->optional_input_names_.empty()) {
  148. for (auto &item : op_desc->optional_input_names_) {
  149. opt_input.mutable_list()->add_s(item);
  150. }
  151. op_desc_attr->insert({"_opt_input", opt_input});
  152. }
  153. }
  154. bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) {
  155. if (node == nullptr || op_def_proto == nullptr) {
  156. GELOGE(GRAPH_FAILED, "Input Para Node Invalid");
  157. return false;
  158. }
  159. if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) {
  160. GELOGE(GRAPH_FAILED, "Serialize OpDesc failed");
  161. return false;
  162. }
  163. if (SerializeEdge(node, op_def_proto)) {
  164. return true;
  165. } else {
  166. return false;
  167. }
  168. }
  169. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph,
  170. proto::GraphDef *graph_proto,
  171. bool is_dump) {
  172. if (graph == nullptr || graph_proto == nullptr) {
  173. GELOGE(GRAPH_FAILED, "Input para Invalid");
  174. return false;
  175. }
  176. graph_proto->set_name(graph->GetName());
  177. // Inputs
  178. for (const auto &input : graph->GetInputNodes()) {
  179. if (input != nullptr) {
  180. graph_proto->add_input(input->GetName() + ":0");
  181. }
  182. }
  183. // Outputs
  184. for (const auto &output : graph->GetGraphOutNodesInfo()) {
  185. if (output.first != nullptr) {
  186. graph_proto->add_output(output.first->GetName() + ":" + std::to_string(output.second));
  187. GELOGI("Add output to graph proto, node name:%s, index:%ld", output.first->GetName().c_str(), output.second);
  188. }
  189. }
  190. if (graph->attrs_.GetProtoMsg() != nullptr) {
  191. *graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg();
  192. }
  193. for (const auto &node : graph->GetDirectNode()) {
  194. if (!SerializeNode(node, graph_proto->add_op(), is_dump)) {
  195. if (node->GetOpDesc() != nullptr) {
  196. GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str());
  197. }
  198. return false;
  199. }
  200. }
  201. return true;
  202. }
  203. bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) {
  204. if (model_proto == nullptr) {
  205. GELOGE(GRAPH_FAILED, "model_proto para Invalid");
  206. return false;
  207. }
  208. model_proto->set_name(model.GetName());
  209. model_proto->set_custom_version(model.GetPlatformVersion());
  210. model_proto->set_version(model.GetVersion());
  211. if (model.attrs_.GetProtoMsg()) {
  212. *model_proto->mutable_attr() = *model.attrs_.GetProtoMsg();
  213. }
  214. auto &graph = model.graph_;
  215. auto compute_graph = GraphUtils::GetComputeGraph(graph);
  216. if (compute_graph == nullptr) {
  217. GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr");
  218. return false;
  219. }
  220. if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) {
  221. GELOGE(GRAPH_FAILED, "SerializeGraph fail");
  222. return false;
  223. }
  224. for (auto subgraph : compute_graph->GetAllSubgraphs()) {
  225. if (!SerializeGraph(subgraph, model_proto->add_graph(), is_dump)) {
  226. GELOGE(GRAPH_FAILED, "Serialize subgraph failed");
  227. return false;
  228. }
  229. }
  230. return true;
  231. }
  232. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeTensor(
  233. GeTensorPtr &tensor, proto::TensorDef &tensor_proto) {
  234. tensor = std::shared_ptr<GeTensor>(new (std::nothrow) GeTensor(protobuf_owner_, &tensor_proto));
  235. if (tensor == nullptr) {
  236. GELOGE(GRAPH_FAILED, "tensor is nullptr");
  237. return false;
  238. } else {
  239. return true;
  240. }
  241. }
  242. void ModelSerializeImp::AttrDefToOpDesc(OpDescPtr &op_desc, std::vector<string> &key_in, std::vector<string> &key_out,
  243. std::vector<uint32_t> &value_in, std::vector<uint32_t> &value_out,
  244. std::vector<string> &opt_input) {
  245. if (!key_in.empty()) {
  246. if (key_in.size() != value_in.size()) {
  247. GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(),
  248. value_in.size());
  249. } else {
  250. for (uint32_t i = 0; i < key_in.size(); ++i) {
  251. op_desc->input_name_idx_.insert(std::pair<string, uint32_t>(key_in.at(i), value_in.at(i)));
  252. }
  253. }
  254. }
  255. if (!key_out.empty()) {
  256. if (key_out.size() != value_out.size()) {
  257. GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(),
  258. value_out.size());
  259. } else {
  260. for (uint32_t i = 0; i < key_out.size(); ++i) {
  261. op_desc->output_name_idx_.insert(std::pair<string, uint32_t>(key_out.at(i), value_out.at(i)));
  262. }
  263. }
  264. }
  265. if (!opt_input.empty()) {
  266. for (const auto &i : opt_input) {
  267. op_desc->optional_input_names_.insert(i);
  268. }
  269. }
  270. }
  271. bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) {
  272. std::vector<string> opt_input;
  273. std::vector<string> key_in;
  274. std::vector<uint32_t> value_in;
  275. if (op_def_proto.attr().count("_opt_input") > 0) {
  276. auto &name_list = op_def_proto.attr().at("_opt_input").list();
  277. for (const auto &item_s : name_list.s()) {
  278. opt_input.push_back(item_s);
  279. }
  280. auto op_desc_attr = op_def_proto.mutable_attr();
  281. op_desc_attr->erase("_opt_input");
  282. }
  283. if (op_def_proto.attr().count("_input_name_key") > 0) {
  284. auto &output_name_key_list = op_def_proto.attr().at("_input_name_key").list();
  285. for (const auto &item_s : output_name_key_list.s()) {
  286. key_in.push_back(item_s);
  287. }
  288. auto op_desc_attr = op_def_proto.mutable_attr();
  289. op_desc_attr->erase("_input_name_key");
  290. }
  291. if (op_def_proto.attr().count("_input_name_value") > 0) {
  292. auto &input_name_value_list = op_def_proto.attr().at("_input_name_value").list();
  293. for (const auto &item_i : input_name_value_list.i()) {
  294. value_in.push_back(static_cast<uint32_t>(item_i));
  295. }
  296. auto op_desc_attr = op_def_proto.mutable_attr();
  297. op_desc_attr->erase("_input_name_value");
  298. }
  299. std::vector<string> key_out;
  300. std::vector<uint32_t> value_out;
  301. if (op_def_proto.attr().count("_output_name_key") > 0) {
  302. auto &output_name_key_list = op_def_proto.attr().at("_output_name_key").list();
  303. for (const auto &item_s : output_name_key_list.s()) {
  304. key_out.push_back(item_s);
  305. }
  306. auto op_desc_attr = op_def_proto.mutable_attr();
  307. op_desc_attr->erase("_output_name_key");
  308. }
  309. if (op_def_proto.attr().count("_output_name_value") > 0) {
  310. auto &output_name_value_list = op_def_proto.attr().at("_output_name_value").list();
  311. for (const auto &item_i : output_name_value_list.i()) {
  312. value_out.push_back(static_cast<uint32_t>(item_i));
  313. }
  314. auto op_desc_attr = op_def_proto.mutable_attr();
  315. op_desc_attr->erase("_output_name_value");
  316. }
  317. op_desc = std::shared_ptr<OpDesc>(new (std::nothrow) OpDesc(protobuf_owner_, &op_def_proto));
  318. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr.");
  319. // Input tensor
  320. for (auto &input_desc : *op_def_proto.mutable_input_desc()) {
  321. std::shared_ptr<GeTensorDesc> temp_value =
  322. std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &input_desc));
  323. GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr");
  324. op_desc->inputs_desc_.push_back(temp_value);
  325. }
  326. // Output tensor
  327. for (auto &output_desc : *op_def_proto.mutable_output_desc()) {
  328. std::shared_ptr<GeTensorDesc> temp_value =
  329. std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &output_desc));
  330. GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr");
  331. op_desc->outputs_desc_.push_back(temp_value);
  332. }
  333. op_desc->SetId(op_def_proto.id());
  334. uint32_t graph_index = 0;
  335. for (const std::string &name : op_def_proto.subgraph_name()) {
  336. op_desc->AddSubgraphName(name);
  337. op_desc->SetSubgraphInstanceName(graph_index++, name);
  338. }
  339. // insert name index by key and value
  340. AttrDefToOpDesc(op_desc, key_in, key_out, value_in, value_out, opt_input);
  341. return true;
  342. }
  343. bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op_def_proto) {
  344. GE_RT_FALSE_CHECK_NOTNULL(graph);
  345. OpDescPtr op_desc = nullptr;
  346. if (!UnserializeOpDesc(op_desc, op_def_proto)) {
  347. GELOGW("UnserializeOpDesc error.");
  348. }
  349. NodePtr node = graph->AddNode(op_desc, op_desc->GetId());
  350. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr.");
  351. // Inputs
  352. int dst_index = 0;
  353. for (const auto &input : op_def_proto.input()) {
  354. string node_name;
  355. int32_t index = 0;
  356. if (ParseNodeIndex(input, node_name, index)) {
  357. node_input_node_names_.push_back(NodeNameNodeReq{node_name, index, node, dst_index, op_def_proto.name()});
  358. }
  359. if (index >= 0) {
  360. dst_index++;
  361. }
  362. }
  363. node_map_[op_def_proto.name()] = node;
  364. return true;
  365. }
  366. bool ModelSerializeImp::HandleNodeNameRef() {
  367. // Edges
  368. for (auto &item : node_input_node_names_) {
  369. auto src_node_it = node_map_.find(item.src_node_name);
  370. if (src_node_it == node_map_.end()) {
  371. GELOGE(GRAPH_FAILED, "cannot find node %s", item.src_node_name.c_str());
  372. return false;
  373. }
  374. GE_IF_BOOL_EXEC(src_node_it->second == nullptr || item.dst_node == nullptr, continue);
  375. if (item.src_out_index >= 0) {
  376. auto src_anchor = src_node_it->second->GetOutDataAnchor(item.src_out_index);
  377. auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index);
  378. if (src_anchor == nullptr || dst_anchor == nullptr) {
  379. GELOGE(GRAPH_FAILED, "get anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index,
  380. item.dst_node_name.c_str(), item.dst_in_index);
  381. return false;
  382. }
  383. GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737
  384. } else {
  385. // Control edge
  386. auto src_anchor = src_node_it->second->GetOutControlAnchor();
  387. auto dst_anchor = item.dst_node->GetInControlAnchor();
  388. if (src_anchor != nullptr && dst_anchor != nullptr) {
  389. GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737
  390. }
  391. }
  392. }
  393. // Graph input
  394. for (auto &item : graph_input_node_names_) {
  395. auto node_it = node_map_.find(item.node_name);
  396. if (node_it == node_map_.end()) {
  397. GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str());
  398. return false;
  399. }
  400. GE_IF_BOOL_EXEC(item.graph == nullptr, continue);
  401. auto ret = item.graph->AddInputNode(node_it->second);
  402. if (ret == nullptr) {
  403. return false;
  404. }
  405. }
  406. // Graph output
  407. for (auto &item : graph_output_node_names_) {
  408. auto node_it = node_map_.find(item.node_name);
  409. if (node_it == node_map_.end()) {
  410. GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str());
  411. return false;
  412. }
  413. GE_IF_BOOL_EXEC(item.graph == nullptr, continue);
  414. auto ret = item.graph->AddOutputNodeByIndex(node_it->second, item.index);
  415. GELOGI("node name:%s, item.index:%ld", node_it->second->GetName().c_str(), item.index);
  416. if (ret == nullptr) {
  417. GELOGE(GRAPH_FAILED, "AddOutputNode failed.");
  418. return false;
  419. }
  420. }
  421. node_input_node_names_.clear();
  422. graph_input_node_names_.clear();
  423. graph_output_node_names_.clear();
  424. node_map_.clear();
  425. return true;
  426. }
  427. bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, map<string, ComputeGraphPtr> &subgraphs) {
  428. std::queue<ComputeGraphPtr> all_graphs;
  429. all_graphs.emplace(compute_graph);
  430. while (!all_graphs.empty()) {
  431. ComputeGraphPtr graph = all_graphs.front();
  432. all_graphs.pop();
  433. for (const NodePtr &node : graph->GetDirectNode()) {
  434. const OpDescPtr op_desc = node->GetOpDesc();
  435. for (const std::string &name : op_desc->GetSubgraphInstanceNames()) {
  436. auto it = subgraphs.find(name);
  437. if (it == subgraphs.end()) {
  438. GELOGE(GRAPH_FAILED, "Node:%s, Subgraph:%s not found, num:%zu.", op_desc->GetName().c_str(), name.c_str(),
  439. subgraphs.size());
  440. return false;
  441. }
  442. ComputeGraphPtr &subgraph = it->second;
  443. subgraph->SetParentGraph(graph);
  444. subgraph->SetParentNode(node);
  445. compute_graph->AddSubgraph(subgraph->GetName(), subgraph);
  446. all_graphs.emplace(subgraph);
  447. }
  448. }
  449. }
  450. return true;
  451. }
  452. bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) {
  453. model.name_ = model_proto.name();
  454. model.version_ = model_proto.version();
  455. model.platform_version_ = model_proto.custom_version();
  456. model.attrs_ = ProtoAttrMapHelper(protobuf_owner_, model_proto.mutable_attr());
  457. auto &graphs_proto = *model_proto.mutable_graph();
  458. if (!graphs_proto.empty()) {
  459. auto &graph_proto = graphs_proto[0];
  460. ComputeGraphPtr compute_graph_ptr;
  461. if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) {
  462. model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr);
  463. }
  464. // 0 is main graph, following is subgraph.
  465. map<string, ComputeGraphPtr> subgraphs;
  466. for (int idx = 1; idx < graphs_proto.size(); ++idx) {
  467. ComputeGraphPtr subgraph;
  468. ModelSerializeImp impl;
  469. if (!impl.UnserializeGraphWithoutEdge(subgraph, graphs_proto[idx])) {
  470. GELOGE(GRAPH_FAILED, "UnserializeGraphWithoutEdge failed");
  471. return false;
  472. }
  473. if (!impl.HandleNodeNameRef()) {
  474. GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed");
  475. return false;
  476. }
  477. subgraphs[subgraph->GetName()] = subgraph;
  478. }
  479. if (!RebuildOwnership(compute_graph_ptr, subgraphs)) {
  480. GELOGE(GRAPH_FAILED, "Rebuild graph ownership failed");
  481. return false;
  482. }
  483. }
  484. if (!HandleNodeNameRef()) {
  485. GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed");
  486. return false;
  487. }
  488. return true;
  489. }
  490. bool ModelSerializeImp::UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graph_proto) {
  491. graph = ComGraphMakeShared<ComputeGraph>(graph_proto.name());
  492. if (graph == nullptr) {
  493. GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed");
  494. return false;
  495. }
  496. // Inputs
  497. for (auto input : graph_proto.input()) {
  498. string node_name;
  499. int32_t index;
  500. if (ParseNodeIndex(input, node_name, index)) {
  501. graph_input_node_names_.push_back(NodeNameGraphReq{node_name, index, graph});
  502. }
  503. }
  504. // Outputs
  505. for (auto output : graph_proto.output()) {
  506. string node_name;
  507. int32_t index;
  508. if (ParseNodeIndex(output, node_name, index)) {
  509. graph_output_node_names_.push_back(NodeNameGraphReq{node_name, index, graph});
  510. }
  511. }
  512. graph->attrs_ = ProtoAttrMapHelper(protobuf_owner_, graph_proto.mutable_attr());
  513. for (auto &op_def_proto : *graph_proto.mutable_op()) {
  514. if (!UnserializeNode(graph, op_def_proto)) {
  515. GELOGE(GRAPH_FAILED, "UnserializeNode fail");
  516. return false;
  517. }
  518. }
  519. return true;
  520. }
  521. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeGraph(ComputeGraphPtr &graph,
  522. proto::GraphDef &graph_proto) {
  523. if (!UnserializeGraphWithoutEdge(graph, graph_proto)) {
  524. GELOGW("UnserializeGraphWithoutEdge fail");
  525. }
  526. if (!HandleNodeNameRef()) {
  527. GELOGE(GRAPH_FAILED, "Link Anchor or set graph input or output fail");
  528. return false;
  529. }
  530. return true;
  531. }
  532. bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf::Message *proto) {
  533. GE_CHK_BOOL_EXEC(data != nullptr, return false, "data is null.");
  534. GE_CHK_BOOL_EXEC(proto != nullptr, return false, "proto is null.");
  535. google::protobuf::io::CodedInputStream coded_stream(data, len);
  536. // 2048M -1
  537. coded_stream.SetTotalBytesLimit(INT32_MAX, -1);
  538. if (!proto->ParseFromCodedStream(&coded_stream)) {
  539. GELOGE(GRAPH_FAILED, "ReadProtoFromBinaryFile failed len %zu", len);
  540. return false;
  541. }
  542. return true;
  543. }
  544. Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) {
  545. proto::ModelDef model_def;
  546. ModelSerializeImp imp;
  547. if (!imp.SerializeModel(model, &model_def, is_dump)) {
  548. return Buffer();
  549. }
  550. #if !defined(__ANDROID__) && !defined(ANDROID)
  551. Buffer buffer(model_def.ByteSizeLong());
  552. #else
  553. Buffer buffer(model_def.ByteSize());
  554. #endif
  555. GE_CHK_BOOL_ONLY_LOG(buffer.GetSize() != 0, "get size failed");
  556. GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
  557. auto ret = model_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
  558. if (ret != true) {
  559. GELOGW("serialize to array fail.");
  560. }
  561. return buffer;
  562. }
  563. size_t ModelSerialize::GetSerializeModelSize(const Model &model) {
  564. proto::ModelDef model_def;
  565. ModelSerializeImp imp;
  566. if (!imp.SerializeModel(model, &model_def)) {
  567. return 0;
  568. }
  569. #if !defined(__ANDROID__) && !defined(ANDROID)
  570. return model_def.ByteSizeLong();
  571. #else
  572. return model_def.ByteSize();
  573. #endif
  574. }
  575. Model ModelSerialize::UnserializeModel(const uint8_t *data, size_t len) {
  576. if (data == nullptr) {
  577. GELOGE(GRAPH_FAILED, "data is nullptr");
  578. return Model();
  579. }
  580. std::shared_ptr<proto::ModelDef> model_proto_ptr;
  581. model_proto_ptr = ComGraphMakeShared<proto::ModelDef>();
  582. if (model_proto_ptr == nullptr) {
  583. GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed");
  584. return Model();
  585. }
  586. auto &model_proto = *model_proto_ptr;
  587. if (!ReadProtoFromBinaryFile(data, len, &model_proto)) {
  588. GELOGE(GRAPH_FAILED, "ParseFromArray fail");
  589. return Model();
  590. }
  591. Model model;
  592. ModelSerializeImp imp;
  593. imp.SetProtobufOwner(model_proto_ptr);
  594. if (!imp.UnserializeModel(model, model_proto)) {
  595. GELOGE(GRAPH_FAILED, "Unserialize Model fail");
  596. return Model();
  597. }
  598. return model;
  599. }
  600. Model ModelSerialize::UnserializeModel(ge::proto::ModelDef &model_def) {
  601. std::shared_ptr<proto::ModelDef> model_def_ptr = ComGraphMakeShared<proto::ModelDef>(model_def);
  602. GE_CHK_BOOL_EXEC(model_def_ptr != nullptr, return Model(), "mode_def make shared failed");
  603. ModelSerializeImp imp;
  604. imp.SetProtobufOwner(model_def_ptr);
  605. Model model;
  606. if (!imp.UnserializeModel(model, *model_def_ptr)) {
  607. GELOGE(GRAPH_FAILED, "Unserialize Model fail");
  608. return Model();
  609. }
  610. return model;
  611. }
  612. Buffer ModelSerialize::SerializeGraph(const ComputeGraphPtr &graph) {
  613. proto::GraphDef graph_def;
  614. ModelSerializeImp imp;
  615. if (!imp.SerializeGraph(graph, &graph_def)) {
  616. return Buffer();
  617. }
  618. #if !defined(__ANDROID__) && !defined(ANDROID)
  619. Buffer buffer(graph_def.ByteSizeLong());
  620. #else
  621. Buffer buffer(graph_def.ByteSize());
  622. #endif
  623. GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed");
  624. GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
  625. auto ret = graph_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
  626. if (ret != true) {
  627. GE_LOGE("serialize to array fail.");
  628. }
  629. return buffer;
  630. }
  631. ComputeGraphPtr ModelSerialize::UnserializeGraph(const uint8_t *data, size_t len) {
  632. if (data == nullptr) {
  633. GELOGE(GRAPH_FAILED, "data is nullptr");
  634. return nullptr;
  635. }
  636. std::shared_ptr<proto::GraphDef> graph_proto_ptr;
  637. graph_proto_ptr = ComGraphMakeShared<proto::GraphDef>();
  638. if (graph_proto_ptr == nullptr) {
  639. GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
  640. return nullptr;
  641. }
  642. proto::GraphDef &graph_proto = *graph_proto_ptr;
  643. if (!ReadProtoFromBinaryFile(data, len, &graph_proto)) {
  644. GELOGE(GRAPH_FAILED, "ParseFromArray fail");
  645. return nullptr;
  646. }
  647. ComputeGraphPtr graph;
  648. ModelSerializeImp imp;
  649. imp.SetProtobufOwner(graph_proto_ptr);
  650. if (!imp.UnserializeGraph(graph, graph_proto)) {
  651. return nullptr;
  652. }
  653. return graph;
  654. }
  655. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer ModelSerialize::SerializeOpDesc(const ConstOpDescPtr &op_desc) {
  656. proto::OpDef op_def;
  657. ModelSerializeImp imp;
  658. if (!imp.SerializeOpDesc(op_desc, &op_def)) {
  659. return Buffer();
  660. }
  661. #if !defined(__ANDROID__) && !defined(ANDROID)
  662. Buffer buffer(op_def.ByteSizeLong());
  663. #else
  664. Buffer buffer(op_def.ByteSize());
  665. #endif
  666. GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed");
  667. GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
  668. auto ret = op_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
  669. if (ret != true) {
  670. GE_LOGE("serialize to array fail.");
  671. }
  672. return buffer;
  673. }
  674. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr ModelSerialize::UnserializeOpDesc(const uint8_t *data,
  675. size_t len) {
  676. if (data == nullptr) {
  677. GELOGE(GRAPH_FAILED, "data is nullptr");
  678. return nullptr;
  679. }
  680. std::shared_ptr<proto::OpDef> op_def_ptr;
  681. op_def_ptr = ComGraphMakeShared<proto::OpDef>();
  682. if (op_def_ptr == nullptr) {
  683. GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed");
  684. return nullptr;
  685. }
  686. proto::OpDef &op_def = *op_def_ptr;
  687. if (!ReadProtoFromBinaryFile(data, len, &op_def)) {
  688. GELOGE(GRAPH_FAILED, "ParseFromArray fail");
  689. return nullptr;
  690. }
  691. OpDescPtr op_desc;
  692. ModelSerializeImp imp;
  693. imp.SetProtobufOwner(op_def_ptr);
  694. if (!imp.UnserializeOpDesc(op_desc, op_def)) {
  695. GELOGW("UnserializeOpDesc error.");
  696. }
  697. return op_desc;
  698. }
  699. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：