You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_serialize.cc 24 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/model_serialize.h"
  17. #include <google/protobuf/text_format.h>
  18. #include <queue>
  19. #include <iostream>
  20. #include "debug/ge_attr_define.h"
  21. #include "debug/ge_log.h"
  22. #include "debug/ge_util.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "graph/detail/model_serialize_imp.h"
  25. #include "proto/ge_ir.pb.h"
  26. #include "utils/graph_utils.h"
  27. #include "debug/ge_op_types.h"
  28. using std::map;
  29. using std::string;
  30. namespace ge {
  31. bool ModelSerializeImp::ParseNodeIndex(const string &node_index, string &node_name, int32_t &index) {
  32. auto sep = node_index.rfind(":");
  33. if (sep == string::npos) {
  34. GELOGW("separator is not found in node_index.");
  35. return false;
  36. }
  37. node_name = node_index.substr(0, sep);
  38. auto index_str = node_index.substr(sep + 1);
  39. index = static_cast<int32_t>(std::strtol(index_str.c_str(), nullptr, 10));
  40. return true;
  41. }
  42. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeTensor(const ConstGeTensorPtr &tensor,
  43. proto::TensorDef *tensor_proto) {
  44. GE_CHK_BOOL_EXEC(tensor != nullptr, return false, "tensor is null.");
  45. GE_CHK_BOOL_EXEC(tensor_proto != nullptr, return false, "tensor_proto is null.");
  46. if (tensor->tensor_def_.GetProtoMsg() != nullptr) {
  47. *tensor_proto = *tensor->tensor_def_.GetProtoMsg();
  48. return true;
  49. }
  50. return false;
  51. }
  52. bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_proto) {
  53. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is null.");
  54. GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null.");
  55. op_def_proto->clear_input();
  56. // Inputs
  57. for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
  58. if (in_data_anchor != nullptr) {
  59. auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  60. if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) {
  61. op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" +
  62. std::to_string(peer_out_anchor->GetIdx()));
  63. } else {
  64. op_def_proto->add_input("");
  65. }
  66. }
  67. }
  68. // Control edge
  69. auto control_anchor = node->GetInControlAnchor();
  70. if (control_anchor != nullptr) {
  71. auto peer_out_anchors = control_anchor->GetPeerOutControlAnchors();
  72. for (const auto &peer_out_anchor : peer_out_anchors) {
  73. if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) {
  74. op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":-1");
  75. }
  76. }
  77. }
  78. return true;
  79. }
  80. bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) {
  81. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null.");
  82. GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null.");
  83. if (op_desc->op_def_.GetProtoMsg() != nullptr) {
  84. *op_def_proto = *op_desc->op_def_.GetProtoMsg();
  85. // Delete unnecessary attr
  86. if (is_dump) {
  87. auto attr = op_def_proto->mutable_attr();
  88. attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF);
  89. attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF);
  90. attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF);
  91. GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP),
  92. attr->erase(ATTR_NAME_WEIGHTS));
  93. }
  94. op_def_proto->clear_input_desc();
  95. op_def_proto->clear_output_desc();
  96. // Input descs
  97. if (op_desc->GetAllInputsSize() > 0) {
  98. auto size = static_cast<uint32_t>(op_desc->GetAllInputsSize());
  99. for (uint32_t i = 0; i < size; i++) {
  100. auto tensor_desc = op_desc->GetInputDescPtrDfault(i);
  101. if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) {
  102. *op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg());
  103. }
  104. }
  105. }
  106. // Output descs
  107. if (op_desc->GetOutputsSize() > 0) {
  108. auto size = static_cast<uint32_t>(op_desc->GetOutputsSize());
  109. for (uint32_t i = 0; i < size; i++) {
  110. auto tensor_desc = op_desc->GetOutputDescPtr(i);
  111. if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) {
  112. *op_def_proto->add_output_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg());
  113. }
  114. }
  115. }
  116. op_def_proto->set_id(op_desc->GetId());
  117. for (const std::string &name : op_desc->GetSubgraphInstanceNames()) {
  118. op_def_proto->add_subgraph_name(name);
  119. }
  120. if (!op_desc->output_name_idx_.empty()) {
  121. proto::AttrDef key;
  122. proto::AttrDef value;
  123. for (auto &item : op_desc->output_name_idx_) {
  124. key.mutable_list()->add_s(item.first);
  125. value.mutable_list()->add_i(item.second);
  126. }
  127. auto op_desc_attr = op_def_proto->mutable_attr();
  128. op_desc_attr->insert({"_output_name_key", key});
  129. op_desc_attr->insert({"_output_name_value", value});
  130. }
  131. }
  132. return true;
  133. }
  134. bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) {
  135. if (node == nullptr || op_def_proto == nullptr) {
  136. GELOGE(GRAPH_FAILED, "Input Para Node Invalid");
  137. return false;
  138. }
  139. if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) {
  140. GELOGE(GRAPH_FAILED, "Serialize OpDesc failed");
  141. return false;
  142. }
  143. if (SerializeEdge(node, op_def_proto)) {
  144. return true;
  145. } else {
  146. return false;
  147. }
  148. }
  149. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph,
  150. proto::GraphDef *graph_proto,
  151. bool is_dump) {
  152. if (graph == nullptr || graph_proto == nullptr) {
  153. GELOGE(GRAPH_FAILED, "Input para Invalid");
  154. return false;
  155. }
  156. graph_proto->set_name(graph->GetName());
  157. // Inputs
  158. for (const auto &input : graph->GetInputNodes()) {
  159. if (input != nullptr) {
  160. graph_proto->add_input(input->GetName() + ":0");
  161. }
  162. }
  163. // Outputs
  164. for (const auto &output : graph->GetOutputNodes()) {
  165. if (output != nullptr) {
  166. graph_proto->add_output(output->GetName() + ":0");
  167. }
  168. }
  169. if (graph->attrs_.GetProtoMsg() != nullptr) {
  170. *graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg();
  171. }
  172. for (const auto &node : graph->GetDirectNode()) {
  173. if (!SerializeNode(node, graph_proto->add_op(), is_dump)) {
  174. if (node->GetOpDesc() != nullptr) {
  175. GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str());
  176. }
  177. return false;
  178. }
  179. }
  180. return true;
  181. }
  182. bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) {
  183. if (model_proto == nullptr) {
  184. GELOGE(GRAPH_FAILED, "model_proto para Invalid");
  185. return false;
  186. }
  187. model_proto->set_name(model.GetName());
  188. model_proto->set_custom_version(model.GetPlatformVersion());
  189. model_proto->set_version(model.GetVersion());
  190. if (model.attrs_.GetProtoMsg()) {
  191. *model_proto->mutable_attr() = *model.attrs_.GetProtoMsg();
  192. }
  193. auto &graph = model.graph_;
  194. auto compute_graph = GraphUtils::GetComputeGraph(graph);
  195. if (compute_graph == nullptr) {
  196. GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr");
  197. return false;
  198. }
  199. if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) {
  200. GELOGE(GRAPH_FAILED, "SerializeGraph fail");
  201. return false;
  202. }
  203. for (auto subgraph : compute_graph->GetAllSubgraphs()) {
  204. if (!SerializeGraph(subgraph, model_proto->add_graph(), is_dump)) {
  205. GELOGE(GRAPH_FAILED, "Serialize subgraph failed");
  206. return false;
  207. }
  208. }
  209. return true;
  210. }
  211. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeTensor(
  212. GeTensorPtr &tensor, proto::TensorDef &tensor_proto) {
  213. tensor = std::shared_ptr<GeTensor>(new (std::nothrow) GeTensor(protobuf_owner_, &tensor_proto));
  214. if (tensor == nullptr) {
  215. GELOGE(GRAPH_FAILED, "tensor is nullptr");
  216. return false;
  217. } else {
  218. return true;
  219. }
  220. }
  221. bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) {
  222. std::vector<string> key;
  223. std::vector<uint32_t> value;
  224. if (op_def_proto.attr().count("_output_name_key") > 0) {
  225. auto &output_name_key_list = op_def_proto.attr().at("_output_name_key").list();
  226. for (const auto &item_s : output_name_key_list.s()) {
  227. key.push_back(item_s);
  228. }
  229. auto op_desc_attr = op_def_proto.mutable_attr();
  230. op_desc_attr->erase("_output_name_key");
  231. }
  232. if (op_def_proto.attr().count("_output_name_value") > 0) {
  233. auto &output_name_value_list = op_def_proto.attr().at("_output_name_value").list();
  234. for (const auto &item_i : output_name_value_list.i()) {
  235. value.push_back(static_cast<uint32_t>(item_i));
  236. }
  237. auto op_desc_attr = op_def_proto.mutable_attr();
  238. op_desc_attr->erase("_output_name_value");
  239. }
  240. op_desc = std::shared_ptr<OpDesc>(new (std::nothrow) OpDesc(protobuf_owner_, &op_def_proto));
  241. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr.");
  242. // Input tensor
  243. for (auto &input_desc : *op_def_proto.mutable_input_desc()) {
  244. std::shared_ptr<GeTensorDesc> temp_value =
  245. std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &input_desc));
  246. GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr");
  247. op_desc->inputs_desc_.push_back(temp_value);
  248. }
  249. // Output tensor
  250. for (auto &output_desc : *op_def_proto.mutable_output_desc()) {
  251. std::shared_ptr<GeTensorDesc> temp_value =
  252. std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &output_desc));
  253. GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr");
  254. op_desc->outputs_desc_.push_back(temp_value);
  255. }
  256. op_desc->SetId(op_def_proto.id());
  257. uint32_t graph_index = 0;
  258. for (const std::string &name : op_def_proto.subgraph_name()) {
  259. op_desc->AddSubgraphName(name);
  260. op_desc->SetSubgraphInstanceName(graph_index++, name);
  261. }
  262. if (key.size() != 0) {
  263. if (key.size() != value.size()) {
  264. GELOGE(GRAPH_FAILED, "twe vector size is different. key_size: %zu, value_size: %zu.", key.size(), value.size());
  265. } else {
  266. for (uint32_t i = 0; i < key.size(); ++i) {
  267. op_desc->output_name_idx_.insert(std::pair<string, uint32_t>(key.at(i), value.at(i)));
  268. }
  269. }
  270. }
  271. return true;
  272. }
  273. bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op_def_proto) {
  274. GE_RT_FALSE_CHECK_NOTNULL(graph);
  275. OpDescPtr op_desc = nullptr;
  276. if (!UnserializeOpDesc(op_desc, op_def_proto)) {
  277. GELOGW("UnserializeOpDesc error.");
  278. }
  279. NodePtr node = graph->AddNode(op_desc, op_desc->GetId());
  280. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr.");
  281. // Inputs
  282. int dst_index = 0;
  283. for (const auto &input : op_def_proto.input()) {
  284. string node_name;
  285. int32_t index = 0;
  286. if (ParseNodeIndex(input, node_name, index)) {
  287. node_input_node_names_.push_back(NodeNameNodeReq{node_name, index, node, dst_index, op_def_proto.name()});
  288. }
  289. if (index >= 0) {
  290. dst_index++;
  291. }
  292. }
  293. node_map_[op_def_proto.name()] = node;
  294. return true;
  295. }
  296. bool ModelSerializeImp::HandleNodeNameRef() {
  297. // Edges
  298. for (auto &item : node_input_node_names_) {
  299. auto src_node_it = node_map_.find(item.src_node_name);
  300. if (src_node_it == node_map_.end()) {
  301. GELOGE(GRAPH_FAILED, "cannot find node %s", item.src_node_name.c_str());
  302. return false;
  303. }
  304. GE_IF_BOOL_EXEC(src_node_it->second == nullptr || item.dst_node == nullptr, continue);
  305. if (item.src_out_index >= 0) {
  306. auto src_anchor = src_node_it->second->GetOutDataAnchor(item.src_out_index);
  307. auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index);
  308. if (src_anchor == nullptr || dst_anchor == nullptr) {
  309. GELOGE(GRAPH_FAILED, "get anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index,
  310. item.dst_node_name.c_str(), item.dst_in_index);
  311. return false;
  312. }
  313. GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed.");
  314. } else {
  315. // Control edge
  316. auto src_anchor = src_node_it->second->GetOutControlAnchor();
  317. auto dst_anchor = item.dst_node->GetInControlAnchor();
  318. if (src_anchor != nullptr && dst_anchor != nullptr) {
  319. GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed.");
  320. }
  321. }
  322. }
  323. // Graph input
  324. for (auto &item : graph_input_node_names_) {
  325. auto node_it = node_map_.find(item.node_name);
  326. if (node_it == node_map_.end()) {
  327. GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str());
  328. return false;
  329. }
  330. GE_IF_BOOL_EXEC(item.graph == nullptr, continue);
  331. auto ret = item.graph->AddInputNode(node_it->second);
  332. if (ret == nullptr) {
  333. return false;
  334. }
  335. }
  336. // Graph output
  337. for (auto &item : graph_output_node_names_) {
  338. auto node_it = node_map_.find(item.node_name);
  339. if (node_it == node_map_.end()) {
  340. GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str());
  341. return false;
  342. }
  343. GE_IF_BOOL_EXEC(item.graph == nullptr, continue);
  344. auto ret = item.graph->AddOutputNode(node_it->second);
  345. if (ret == nullptr) {
  346. GELOGE(GRAPH_FAILED, "AddOutputNode failed.");
  347. return false;
  348. }
  349. }
  350. node_input_node_names_.clear();
  351. graph_input_node_names_.clear();
  352. graph_output_node_names_.clear();
  353. node_map_.clear();
  354. return true;
  355. }
  356. bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, map<string, ComputeGraphPtr> &subgraphs) {
  357. std::queue<ComputeGraphPtr> all_graphs;
  358. all_graphs.emplace(compute_graph);
  359. while (!all_graphs.empty()) {
  360. ComputeGraphPtr graph = all_graphs.front();
  361. all_graphs.pop();
  362. for (const NodePtr &node : graph->GetDirectNode()) {
  363. const OpDescPtr op_desc = node->GetOpDesc();
  364. for (const std::string &name : op_desc->GetSubgraphInstanceNames()) {
  365. auto it = subgraphs.find(name);
  366. if (it == subgraphs.end()) {
  367. GELOGE(GRAPH_FAILED, "Node:%s, Subgraph:%s not found, num:%zu.", op_desc->GetName().c_str(), name.c_str(),
  368. subgraphs.size());
  369. return false;
  370. }
  371. ComputeGraphPtr &subgraph = it->second;
  372. subgraph->SetParentGraph(graph);
  373. subgraph->SetParentNode(node);
  374. compute_graph->AddSubgraph(subgraph->GetName(), subgraph);
  375. all_graphs.emplace(subgraph);
  376. }
  377. }
  378. }
  379. return true;
  380. }
  381. bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) {
  382. model.name_ = model_proto.name();
  383. model.version_ = model_proto.version();
  384. model.platform_version_ = model_proto.custom_version();
  385. model.attrs_ = ProtoAttrMapHelper(protobuf_owner_, model_proto.mutable_attr());
  386. auto &graphs_proto = *model_proto.mutable_graph();
  387. if (!graphs_proto.empty()) {
  388. auto &graph_proto = graphs_proto[0];
  389. ComputeGraphPtr compute_graph_ptr;
  390. if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) {
  391. model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr);
  392. }
  393. // 0 is main graph, following is subgraph.
  394. map<string, ComputeGraphPtr> subgraphs;
  395. for (int idx = 1; idx < graphs_proto.size(); ++idx) {
  396. ComputeGraphPtr subgraph;
  397. ModelSerializeImp impl;
  398. if (!impl.UnserializeGraphWithoutEdge(subgraph, graphs_proto[idx])) {
  399. GELOGE(GRAPH_FAILED, "UnserializeGraphWithoutEdge failed");
  400. return false;
  401. }
  402. if (!impl.HandleNodeNameRef()) {
  403. GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed");
  404. return false;
  405. }
  406. subgraphs[subgraph->GetName()] = subgraph;
  407. }
  408. if (!RebuildOwnership(compute_graph_ptr, subgraphs)) {
  409. GELOGE(GRAPH_FAILED, "Rebuild graph ownership failed");
  410. return false;
  411. }
  412. }
  413. if (!HandleNodeNameRef()) {
  414. GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed");
  415. return false;
  416. }
  417. return true;
  418. }
  419. bool ModelSerializeImp::UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graph_proto) {
  420. graph = ComGraphMakeShared<ComputeGraph>(graph_proto.name());
  421. if (graph == nullptr) {
  422. GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed");
  423. return false;
  424. }
  425. // Inputs
  426. for (auto input : graph_proto.input()) {
  427. string node_name;
  428. int32_t index;
  429. if (ParseNodeIndex(input, node_name, index)) {
  430. graph_input_node_names_.push_back(NodeNameGraphReq{node_name, index, graph});
  431. }
  432. }
  433. // Outputs
  434. for (auto output : graph_proto.output()) {
  435. string node_name;
  436. int32_t index;
  437. if (ParseNodeIndex(output, node_name, index)) {
  438. graph_output_node_names_.push_back(NodeNameGraphReq{node_name, index, graph});
  439. }
  440. }
  441. graph->attrs_ = ProtoAttrMapHelper(protobuf_owner_, graph_proto.mutable_attr());
  442. for (auto &op_def_proto : *graph_proto.mutable_op()) {
  443. if (!UnserializeNode(graph, op_def_proto)) {
  444. GELOGE(GRAPH_FAILED, "UnserializeNode fail");
  445. return false;
  446. }
  447. }
  448. return true;
  449. }
  450. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeGraph(ComputeGraphPtr &graph,
  451. proto::GraphDef &graph_proto) {
  452. if (!UnserializeGraphWithoutEdge(graph, graph_proto)) {
  453. GELOGW("UnserializeGraphWithoutEdge fail");
  454. }
  455. if (!HandleNodeNameRef()) {
  456. GELOGE(GRAPH_FAILED, "Link Anchor or set graph input or output fail");
  457. return false;
  458. }
  459. return true;
  460. }
  461. bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf::Message *proto) {
  462. GE_CHK_BOOL_EXEC(data != nullptr, return false, "data is null.");
  463. GE_CHK_BOOL_EXEC(proto != nullptr, return false, "proto is null.");
  464. google::protobuf::io::CodedInputStream coded_stream(data, len);
  465. // 2048M -1
  466. coded_stream.SetTotalBytesLimit(INT32_MAX, -1);
  467. if (!proto->ParseFromCodedStream(&coded_stream)) {
  468. GELOGE(GRAPH_FAILED, "ReadProtoFromBinaryFile failed len %zu", len);
  469. return false;
  470. }
  471. return true;
  472. }
  473. Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) {
  474. proto::ModelDef model_def;
  475. ModelSerializeImp imp;
  476. if (!imp.SerializeModel(model, &model_def, is_dump)) {
  477. return Buffer();
  478. }
  479. #if !defined(__ANDROID__) && !defined(ANDROID)
  480. Buffer buffer(model_def.ByteSizeLong());
  481. #else
  482. Buffer buffer(model_def.ByteSize());
  483. #endif
  484. GE_CHK_BOOL_ONLY_LOG(buffer.GetSize() != 0, "get size failed");
  485. GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
  486. auto ret = model_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
  487. if (ret != true) {
  488. GELOGW("serialize to array fail.");
  489. }
  490. return buffer;
  491. }
  492. size_t ModelSerialize::GetSerializeModelSize(const Model &model) {
  493. proto::ModelDef model_def;
  494. ModelSerializeImp imp;
  495. if (!imp.SerializeModel(model, &model_def)) {
  496. return 0;
  497. }
  498. #if !defined(__ANDROID__) && !defined(ANDROID)
  499. return model_def.ByteSizeLong();
  500. #else
  501. return model_def.ByteSize();
  502. #endif
  503. }
  504. Model ModelSerialize::UnserializeModel(const uint8_t *data, size_t len) {
  505. if (data == nullptr) {
  506. GELOGE(GRAPH_FAILED, "data is nullptr");
  507. return Model();
  508. }
  509. std::shared_ptr<proto::ModelDef> model_proto_ptr;
  510. model_proto_ptr = ComGraphMakeShared<proto::ModelDef>();
  511. if (model_proto_ptr == nullptr) {
  512. GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed");
  513. return Model();
  514. }
  515. auto &model_proto = *model_proto_ptr;
  516. if (!ReadProtoFromBinaryFile(data, len, &model_proto)) {
  517. GELOGE(GRAPH_FAILED, "ParseFromArray fail");
  518. return Model();
  519. }
  520. Model model;
  521. ModelSerializeImp imp;
  522. imp.SetProtobufOwner(model_proto_ptr);
  523. if (!imp.UnserializeModel(model, model_proto)) {
  524. GELOGE(GRAPH_FAILED, "Unserialize Model fail");
  525. return Model();
  526. }
  527. return model;
  528. }
  529. Model ModelSerialize::UnserializeModel(ge::proto::ModelDef &model_def) {
  530. std::shared_ptr<proto::ModelDef> model_def_ptr = ComGraphMakeShared<proto::ModelDef>(model_def);
  531. GE_CHK_BOOL_EXEC(model_def_ptr != nullptr, return Model(), "mode_def make shared failed");
  532. ModelSerializeImp imp;
  533. imp.SetProtobufOwner(model_def_ptr);
  534. Model model;
  535. if (!imp.UnserializeModel(model, *model_def_ptr)) {
  536. GELOGE(GRAPH_FAILED, "Unserialize Model fail");
  537. return Model();
  538. }
  539. return model;
  540. }
  541. Buffer ModelSerialize::SerializeGraph(const ComputeGraphPtr &graph) {
  542. proto::GraphDef graph_def;
  543. ModelSerializeImp imp;
  544. if (!imp.SerializeGraph(graph, &graph_def)) {
  545. return Buffer();
  546. }
  547. #if !defined(__ANDROID__) && !defined(ANDROID)
  548. Buffer buffer(graph_def.ByteSizeLong());
  549. #else
  550. Buffer buffer(graph_def.ByteSize());
  551. #endif
  552. GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed");
  553. GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
  554. auto ret = graph_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
  555. if (ret != true) {
  556. GE_LOGE("serialize to array fail.");
  557. }
  558. return buffer;
  559. }
  560. ComputeGraphPtr ModelSerialize::UnserializeGraph(const uint8_t *data, size_t len) {
  561. if (data == nullptr) {
  562. GELOGE(GRAPH_FAILED, "data is nullptr");
  563. return nullptr;
  564. }
  565. std::shared_ptr<proto::GraphDef> graph_proto_ptr;
  566. graph_proto_ptr = ComGraphMakeShared<proto::GraphDef>();
  567. if (graph_proto_ptr == nullptr) {
  568. GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
  569. return nullptr;
  570. }
  571. proto::GraphDef &graph_proto = *graph_proto_ptr;
  572. if (!ReadProtoFromBinaryFile(data, len, &graph_proto)) {
  573. GELOGE(GRAPH_FAILED, "ParseFromArray fail");
  574. return nullptr;
  575. }
  576. ComputeGraphPtr graph;
  577. ModelSerializeImp imp;
  578. imp.SetProtobufOwner(graph_proto_ptr);
  579. if (!imp.UnserializeGraph(graph, graph_proto)) {
  580. return nullptr;
  581. }
  582. return graph;
  583. }
  584. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer ModelSerialize::SerializeOpDesc(const ConstOpDescPtr &op_desc) {
  585. proto::OpDef op_def;
  586. ModelSerializeImp imp;
  587. if (!imp.SerializeOpDesc(op_desc, &op_def)) {
  588. return Buffer();
  589. }
  590. #if !defined(__ANDROID__) && !defined(ANDROID)
  591. Buffer buffer(op_def.ByteSizeLong());
  592. #else
  593. Buffer buffer(op_def.ByteSize());
  594. #endif
  595. GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed");
  596. GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
  597. auto ret = op_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
  598. if (ret != true) {
  599. GE_LOGE("serialize to array fail.");
  600. }
  601. return buffer;
  602. }
  603. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr ModelSerialize::UnserializeOpDesc(const uint8_t *data,
  604. size_t len) {
  605. if (data == nullptr) {
  606. GELOGE(GRAPH_FAILED, "data is nullptr");
  607. return nullptr;
  608. }
  609. std::shared_ptr<proto::OpDef> op_def_ptr;
  610. op_def_ptr = ComGraphMakeShared<proto::OpDef>();
  611. if (op_def_ptr == nullptr) {
  612. GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed");
  613. return nullptr;
  614. }
  615. proto::OpDef &op_def = *op_def_ptr;
  616. if (!ReadProtoFromBinaryFile(data, len, &op_def)) {
  617. GELOGE(GRAPH_FAILED, "ParseFromArray fail");
  618. return nullptr;
  619. }
  620. OpDescPtr op_desc;
  621. ModelSerializeImp imp;
  622. imp.SetProtobufOwner(op_def_ptr);
  623. if (!imp.UnserializeOpDesc(op_desc, op_def)) {
  624. GELOGW("UnserializeOpDesc error.");
  625. }
  626. return op_desc;
  627. }
  628. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。