You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tuning_utils.cc 29 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/tuning_utils.h"
  17. #include "../debug/ge_util.h"
  18. #include "../debug/ge_op_types.h"
  19. namespace ge {
  20. const std::string peer_node_name_attr = "_peerNodeName";
  21. const std::string parent_node_name_attr = "_parentNodeName";
  22. const std::string alias_name_attr = "_aliasName";
  23. const std::string parent_node_attr = "parentNode";
  24. const std::string parent_node_anchor_index_attr = "_parentNodeAnchorIndex";
  25. const std::string tuning_subgraph_prefix = "/aicore_subgraph_";
  26. const std::string non_tuning_subgraph_prefix = "/subgraph_";
  27. const std::set<std::string> kPartitionOpTypes = {PLACEHOLDER, END};
  28. const std::set<std::string> kExeTypes = {DATA, NETOUTPUT};
  29. NodeNametoNodeNameMap TuningUtils::data_2_netoutput_;
  30. NodetoNodeNameMap TuningUtils::data_node_2_netoutput_;
  31. NodetoNodeMap TuningUtils::data_node_2_netoutput_node_;
  32. NodeSet TuningUtils::netoutput_nodes_;
  33. NodeSet TuningUtils::merged_graph_nodes_;
  34. SubgraphCreateOutNode TuningUtils::create_output_;
  35. std::mutex TuningUtils::mutex_;
  36. std::string TuningUtils::PrintCheckLog() {
  37. std::stringstream ss;
  38. ss << "d2n:{";
  39. for (const auto &pair : data_2_netoutput_) {
  40. ss << "data:" << pair.first << "-"
  41. << "netoutput:" << pair.second;
  42. ss << " | ";
  43. }
  44. ss << "}";
  45. ss << "netoutputs:{";
  46. for (const auto &node : netoutput_nodes_) {
  47. ss << "netoutput:" << node->GetName();
  48. ss << " | ";
  49. }
  50. ss << "}";
  51. return ss.str();
  52. }
  53. std::string TuningUtils::GetNodeNameByAnchor(const Anchor *anchor) {
  54. if (anchor == nullptr) {
  55. GELOGE(GRAPH_FAILED, "Anchor is nullptr");
  56. return "Null";
  57. }
  58. auto node = anchor->GetOwnerNode();
  59. return node == nullptr ? "Null" : node->GetName();
  60. }
  61. // part 1
  62. graphStatus TuningUtils::ConvertGraphToFile(std::vector<ComputeGraphPtr> tuning_subgraphs,
  63. std::vector<ComputeGraphPtr> non_tuning_subgraphs, bool exe_flag,
  64. const std::string &path, const std::string &user_path) {
  65. int64_t i = 0;
  66. int64_t j = 0;
  67. std::lock_guard<std::mutex> lock(mutex_);
  68. for (auto &subgraph : tuning_subgraphs) {
  69. create_output_.emplace(subgraph, nullptr);
  70. auto help_info = HelpInfo{i, exe_flag, true, path, user_path};
  71. if (MakeExeGraph(subgraph, help_info) != SUCCESS) {
  72. GELOGE(GRAPH_FAILED, "TUU:subgraph %zu generate exe graph failed", i);
  73. return GRAPH_FAILED;
  74. }
  75. i++;
  76. }
  77. for (auto &subgraph : non_tuning_subgraphs) {
  78. create_output_.emplace(subgraph, nullptr);
  79. auto help_info = HelpInfo{j, true, false, path, user_path};
  80. if (MakeExeGraph(subgraph, help_info) != SUCCESS) {
  81. GELOGE(GRAPH_FAILED, "TUU:non tuning_subgraph %zu generate exe graph failed", j);
  82. return GRAPH_FAILED;
  83. }
  84. j++;
  85. }
  86. create_output_.clear();
  87. return SUCCESS;
  88. }
  89. // +---------------+
  90. // | pld pld |
  91. // | \ / |
  92. // | relu relu |
  93. // | \ / |
  94. // | add |
  95. // | | |
  96. // | end |
  97. // +---------------+
  98. // |
  99. // |
  100. // V
  101. // +---------------+
  102. // | data data |
  103. // | \ / |
  104. // | relu relu |
  105. // | \ / |
  106. // | add |
  107. // | | |
  108. // | netoutput |
  109. // +---------------+
  110. graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info) {
  111. GE_CHECK_NOTNULL(exe_graph);
  112. // if not make exe, just dump and return
  113. if (!help_info.exe_flag) {
  114. DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path);
  115. GELOGI("TUU:just return, dump original sub_graph[%s]index[%d]", exe_graph->GetName().c_str(), help_info.index);
  116. return SUCCESS;
  117. }
  118. // modify sub graph
  119. for (NodePtr &node : exe_graph->GetDirectNode()) {
  120. // 1.handle pld
  121. if (node->GetType() == PLACEHOLDER) {
  122. if (HandlePld(node) != SUCCESS) {
  123. GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(),
  124. exe_graph->GetName().c_str());
  125. return FAILED;
  126. }
  127. }
  128. // 2.handle end
  129. if (node->GetType() == END) {
  130. if (HandleEnd(node) != SUCCESS) {
  131. GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(),
  132. exe_graph->GetName().c_str());
  133. return FAILED;
  134. }
  135. }
  136. }
  137. graphStatus ret = exe_graph->TopologicalSorting();
  138. if (ret != SUCCESS) {
  139. GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", exe_graph->GetName().c_str(), ret);
  140. return ret;
  141. }
  142. // dump subgraphs which modified by us
  143. if (help_info.user_path.empty()) {
  144. DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path);
  145. } else {
  146. GraphUtils::DumpGEGraph(exe_graph, "", true, help_info.user_path);
  147. }
  148. return SUCCESS;
  149. }
  150. void TuningUtils::DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path) {
  151. if (!path.empty()) {
  152. if (is_tuning_graph) {
  153. GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt");
  154. } else {
  155. GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt");
  156. }
  157. } else {
  158. path = "./";
  159. if (is_tuning_graph) {
  160. GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt");
  161. } else {
  162. GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt");
  163. }
  164. }
  165. }
  166. graphStatus TuningUtils::CreateDataNode(NodePtr &node, NodePtr &data_node) {
  167. auto graph = node->GetOwnerComputeGraph();
  168. GE_CHECK_NOTNULL(graph);
  169. auto data_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), DATA);
  170. GE_CHECK_NOTNULL(data_op_desc);
  171. auto pld_op_desc = node->GetOpDesc();
  172. GE_CHECK_NOTNULL(pld_op_desc);
  173. auto output_desc = pld_op_desc->GetOutputDesc(0); // only one output for pld and data
  174. // data inputdesc & outputdesc set as same
  175. if (data_op_desc->AddInputDesc(output_desc) != SUCCESS) {
  176. GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str());
  177. return FAILED;
  178. }
  179. if (data_op_desc->AddOutputDesc(output_desc) != SUCCESS) {
  180. GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str());
  181. return FAILED;
  182. }
  183. data_node = graph->AddNode(data_op_desc);
  184. GE_CHECK_NOTNULL(data_node);
  185. if (data_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) {
  186. GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed");
  187. return FAILED;
  188. }
  189. return SUCCESS;
  190. }
  191. graphStatus TuningUtils::AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node) {
  192. auto op_desc = data_node->GetOpDesc();
  193. GE_CHECK_NOTNULL(op_desc);
  194. auto pld_desc = pld->GetOpDesc();
  195. GE_CHECK_NOTNULL(pld_desc);
  196. // inherit
  197. // a. set `end's input node type` as attr
  198. std::string parent_op_type;
  199. if (!AttrUtils::GetStr(pld_desc, "parentOpType", parent_op_type)) {
  200. GELOGE(FAILED, "TUU:pld %s get parentOpType failed", pld_desc->GetName().c_str());
  201. return FAILED;
  202. }
  203. (void)AttrUtils::SetStr(op_desc, "parentOpType", parent_op_type);
  204. // b. set `end's input node name` as attr
  205. std::string parent_op_name;
  206. if (!AttrUtils::GetStr(pld_desc, parent_node_name_attr, parent_op_name)) {
  207. GELOGE(FAILED, "TUU:pld %s get _parentNodeName failed", pld_desc->GetName().c_str());
  208. return FAILED;
  209. }
  210. (void)AttrUtils::SetStr(op_desc, parent_node_name_attr, parent_op_name);
  211. // c. set `end's input node's out anchor index` as attr
  212. int parent_node_anchor_index;
  213. if (!AttrUtils::GetInt(pld_desc, "anchorIndex", parent_node_anchor_index)) {
  214. GELOGE(FAILED, "TUU:pld %s get anchorIndex failed", pld_desc->GetName().c_str());
  215. return FAILED;
  216. }
  217. (void)AttrUtils::SetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index);
  218. GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(),
  219. data_node->GetName().c_str(), data_node->GetType().c_str());
  220. // d. set `end node name` as attr
  221. std::string peer_end_name;
  222. if (!AttrUtils::GetStr(pld_desc, peer_node_name_attr, peer_end_name)) {
  223. GELOGE(FAILED, "TUU:pld %s get _peerNodeName failed", pld_desc->GetName().c_str());
  224. return FAILED;
  225. }
  226. (void)AttrUtils::SetStr(op_desc, peer_node_name_attr, peer_end_name);
  227. GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(),
  228. data_node->GetName().c_str(), data_node->GetType().c_str());
  229. return SUCCESS;
  230. }
  231. graphStatus TuningUtils::ChangePld2Data(NodePtr &node, NodePtr &data_node) {
  232. auto type_pld = node->GetType();
  233. auto type_data = data_node->GetType();
  234. if (type_pld != PLACEHOLDER || type_data != DATA) {
  235. GELOGE(FAILED, "TUU:Failed to change node %s from type %s to type %s", node->GetName().c_str(), type_pld.c_str(),
  236. type_data.c_str());
  237. return FAILED;
  238. }
  239. auto graph = node->GetOwnerComputeGraph();
  240. GE_CHECK_NOTNULL(graph);
  241. std::vector<int> output_map(node->GetAllOutDataAnchorsSize());
  242. for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) {
  243. output_map[i] = static_cast<int>(i);
  244. }
  245. auto ret = GraphUtils::ReplaceNodeAnchors(data_node, node, {}, output_map);
  246. if (ret != GRAPH_SUCCESS) {
  247. GELOGE(FAILED, "TUU:Failed to replace node %s by node %s error node %u", node->GetName().c_str(),
  248. data_node->GetName().c_str(), ret);
  249. return FAILED;
  250. }
  251. NodeUtils::UnlinkAll(*node);
  252. ret = GraphUtils::RemoveNodeWithoutRelink(graph, node);
  253. if (ret != GRAPH_SUCCESS) {
  254. GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str());
  255. return FAILED;
  256. }
  257. GELOGD("TUU:Remove node %s(%s) by the ChangePld2Data process, replace it with node %s(%s)", node->GetName().c_str(),
  258. node->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str());
  259. return ret;
  260. }
  261. graphStatus TuningUtils::HandlePld(NodePtr &node) {
  262. GE_CHECK_NOTNULL(node);
  263. auto graph = node->GetOwnerComputeGraph();
  264. GE_CHECK_NOTNULL(graph);
  265. NodePtr data_node = nullptr;
  266. // 1. create data node
  267. if (CreateDataNode(node, data_node) != SUCCESS) {
  268. GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
  269. return FAILED;
  270. }
  271. // 2. add necessary info to data_node for recovery whole graph
  272. if (AddAttrToDataNodeForMergeGraph(node, data_node) != SUCCESS) {
  273. GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
  274. return FAILED;
  275. }
  276. // 3. replace pld node by data node created before
  277. if (ChangePld2Data(node, data_node) != SUCCESS) {
  278. GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
  279. return FAILED;
  280. }
  281. GELOGD("TUU:pld[%s] handle success", node->GetName().c_str());
  282. return SUCCESS;
  283. }
  284. graphStatus TuningUtils::CreateNetOutput(NodePtr &node, NodePtr &out_node) {
  285. GE_CHECK_NOTNULL(node);
  286. auto graph = node->GetOwnerComputeGraph();
  287. GE_CHECK_NOTNULL(graph);
  288. auto search = create_output_.find(graph);
  289. if (search == create_output_.end()) {
  290. GELOGE(FAILED, "TUU:node %s's owner sub graph %s not exist in create_output map", node->GetName().c_str(),
  291. graph->GetName().c_str());
  292. return FAILED;
  293. }
  294. if (search->second != nullptr) {
  295. out_node = search->second;
  296. GELOGD("TUU:sub graph %s has created output node, just return", graph->GetName().c_str());
  297. return SUCCESS;
  298. }
  299. auto out_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), NETOUTPUT);
  300. GE_CHECK_NOTNULL(out_op_desc);
  301. out_node = graph->AddNode(out_op_desc);
  302. GE_CHECK_NOTNULL(out_node);
  303. if (out_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) {
  304. GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed");
  305. return FAILED;
  306. }
  307. create_output_[graph] = out_node;
  308. return SUCCESS;
  309. }
  310. graphStatus TuningUtils::AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node) {
  311. GE_CHECK_NOTNULL(end);
  312. GE_CHECK_NOTNULL(out_node);
  313. auto op_desc = out_node->GetOpDesc();
  314. GE_CHECK_NOTNULL(op_desc);
  315. std::vector<std::string> alias_names = {};
  316. (void)AttrUtils::GetListStr(op_desc, alias_name_attr, alias_names);
  317. alias_names.push_back(end->GetName());
  318. (void)AttrUtils::SetListStr(op_desc, alias_name_attr, alias_names);
  319. return SUCCESS;
  320. }
  321. graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) {
  322. GE_CHECK_NOTNULL(end_node);
  323. GE_CHECK_NOTNULL(out_node);
  324. // get end in node is control node or normal node
  325. AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr)
  326. ? Anchor::DynamicAnchorCast<Anchor>(end_node->GetInControlAnchor())
  327. : Anchor::DynamicAnchorCast<Anchor>(end_node->GetInDataAnchor(0));
  328. auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1
  329. if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) {
  330. GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
  331. GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
  332. GetNodeNameByAnchor(end_in_anchor.get()).c_str(), end_in_anchor->GetIdx(), end_node->GetName().c_str(),
  333. end_node->GetOwnerComputeGraph()->GetName().c_str());
  334. return FAILED;
  335. }
  336. // add edge between `end in node` and `out_node`
  337. if (src_anchor->IsTypeOf<OutDataAnchor>()) {
  338. std::shared_ptr<InDataAnchor> anchor =
  339. ComGraphMakeShared<InDataAnchor>(out_node, out_node->GetAllInDataAnchors().size());
  340. GE_CHECK_NOTNULL(anchor);
  341. out_node->in_data_anchors_.push_back(anchor);
  342. if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) {
  343. GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
  344. GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
  345. GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(),
  346. end_node->GetOwnerComputeGraph()->GetName().c_str());
  347. return FAILED;
  348. }
  349. auto end_op_desc = end_node->GetOpDesc();
  350. GE_CHECK_NOTNULL(end_op_desc);
  351. auto out_node_op_desc = out_node->GetOpDesc();
  352. GE_CHECK_NOTNULL(out_node_op_desc);
  353. // end node always has one input
  354. if (out_node_op_desc->AddInputDesc(end_op_desc->GetInputDesc(0)) != GRAPH_SUCCESS) {
  355. GELOGE(FAILED, "TUU:node %s add input desc failed.", out_node_op_desc->GetName().c_str());
  356. return FAILED;
  357. }
  358. } else if (src_anchor->IsTypeOf<OutControlAnchor>()) {
  359. auto anchor = out_node->GetInControlAnchor();
  360. if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) {
  361. GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
  362. GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
  363. GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(),
  364. end_node->GetOwnerComputeGraph()->GetName().c_str());
  365. return FAILED;
  366. }
  367. } else {
  368. GELOGE(FAILED, "TUU: node_name:%s, graph_name:%s handled failed", end_node->GetName().c_str(),
  369. end_node->GetOwnerComputeGraph()->GetName().c_str());
  370. return FAILED;
  371. }
  372. return SUCCESS;
  373. }
  374. graphStatus TuningUtils::ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) {
  375. GE_CHECK_NOTNULL(end_node);
  376. GE_CHECK_NOTNULL(out_node);
  377. auto type_end = end_node->GetType();
  378. auto type_out = out_node->GetType();
  379. if (type_end != END || type_out != NETOUTPUT) {
  380. GELOGE(FAILED, "TUU:Failed to change end_node %s from type %s to type %s", end_node->GetName().c_str(),
  381. type_end.c_str(), type_out.c_str());
  382. return FAILED;
  383. }
  384. // link all `end nodes's in node` to this out_node
  385. if (LinkEnd2NetOutput(end_node, out_node) != SUCCESS) {
  386. GELOGE(FAILED, "TUU:end_node [%s] LinkEnd2NetOutput failed.", end_node->GetName().c_str());
  387. return FAILED;
  388. }
  389. // remove `end node`
  390. NodeUtils::UnlinkAll(*end_node);
  391. auto graph = end_node->GetOwnerComputeGraph();
  392. GE_CHECK_NOTNULL(graph);
  393. if (GraphUtils::RemoveNodeWithoutRelink(graph, end_node) != SUCCESS) {
  394. GELOGE(FAILED, "TUU:end node [%s] RemoveNodeWithoutRelink failed.", end_node->GetName().c_str());
  395. return FAILED;
  396. }
  397. return SUCCESS;
  398. }
  399. graphStatus TuningUtils::HandleEnd(NodePtr &node) {
  400. GE_CHECK_NOTNULL(node);
  401. auto graph = node->GetOwnerComputeGraph();
  402. GE_CHECK_NOTNULL(graph);
  403. NodePtr out_node = nullptr;
  404. // 1. create net_output node , add only one NetOutput node to one subgraph
  405. if (CreateNetOutput(node, out_node) != SUCCESS) {
  406. GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
  407. return FAILED;
  408. }
  409. // 2. add necessary info to out_node for recovery whole graph
  410. if (AddAttrToNetOutputForMergeGraph(node, out_node) != SUCCESS) {
  411. GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
  412. return FAILED;
  413. }
  414. // 3. replace all end nodes by one output node created before
  415. if (ChangeEnd2NetOutput(node, out_node) != SUCCESS) {
  416. GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
  417. return FAILED;
  418. }
  419. GELOGD("TUU:end[%s] handle success", node->GetName().c_str());
  420. return SUCCESS;
  421. }
  422. // part 2
  423. graphStatus TuningUtils::ConvertFileToGraph(const map<int64_t, string> &options, ge::Graph &graph) {
  424. // 1. get all subgraph object
  425. std::vector<ComputeGraphPtr> graphs;
  426. // options format like {index:"subgraph_path"}
  427. for (const auto &pair : options) {
  428. ComputeGraphPtr compute_graph = ComGraphMakeShared<ComputeGraph>(std::to_string(pair.first));
  429. if (!ge::GraphUtils::LoadGEGraph(pair.second.c_str(), *compute_graph)) {
  430. GELOGE(FAILED, "TUU:load graph from file failed");
  431. }
  432. graphs.push_back(compute_graph);
  433. }
  434. // 2. merge graph
  435. ComputeGraphPtr merged_graph = ComGraphMakeShared<ComputeGraph>("whole_graph_after_tune");
  436. GE_CHECK_NOTNULL(merged_graph);
  437. if (MergeAllSubGraph(graphs, merged_graph) != SUCCESS) {
  438. GELOGE(FAILED, "TUU:MergeGraph failed");
  439. return FAILED;
  440. }
  441. // 3. set parent graph
  442. for (const auto &node : merged_graph->GetDirectNode()) {
  443. GE_CHECK_NOTNULL(node);
  444. if (node->SetOwnerComputeGraph(merged_graph) != GRAPH_SUCCESS) {
  445. GELOGE(FAILED, "TUU:node %s set owner graph failed", node->GetName().c_str());
  446. return FAILED;
  447. }
  448. }
  449. graph = GraphUtils::CreateGraphFromComputeGraph(merged_graph);
  450. return SUCCESS;
  451. }
  452. // +----------------------------------+
  453. // | const const |
  454. // | \ / |
  455. // | netoutput(end,end) |
  456. // +----------------------------------+
  457. // +
  458. // +----------------------------------+
  459. // | data(pld) data(pld) |
  460. // | \ / |
  461. // | relu relu |
  462. // | \ / |
  463. // | \ / |
  464. // | add |
  465. // | | |
  466. // | netoutput(end) |
  467. // +----------------------------------+
  468. // +
  469. // +----------------------------------+
  470. // | data(pld) |
  471. // | / |
  472. // | netoutput |
  473. // +----------------------------------+
  474. // |
  475. // |
  476. // V
  477. // +----------------------------------+
  478. // | const const |
  479. // | \ / |
  480. // | relu relu |
  481. // | \ / |
  482. // | \ / |
  483. // | add |
  484. // | | |
  485. // | netoutput |
  486. // +----------------------------------+
  487. graphStatus TuningUtils::MergeAllSubGraph(std::vector<ComputeGraphPtr> &subgraphs,
  488. ComputeGraphPtr &output_merged_compute_graph) {
  489. GE_CHECK_NOTNULL(output_merged_compute_graph);
  490. // 1. handle all subgraphs
  491. for (auto &subgraph : subgraphs) {
  492. Status ret_status = MergeSubGraph(subgraph);
  493. if (ret_status != SUCCESS) {
  494. GELOGE(ret_status, "TUU:subgraph %s merge failed", subgraph->GetName().c_str());
  495. return ret_status;
  496. }
  497. }
  498. for (const auto &node : merged_graph_nodes_) {
  499. (void)output_merged_compute_graph->AddNode(node);
  500. GELOGD("TUU:graph %s add node %s success", output_merged_compute_graph->GetName().c_str(), node->GetName().c_str());
  501. }
  502. // 2. remove data and output node added by us
  503. if (RemoveDataNetoutputEdge(output_merged_compute_graph) != SUCCESS) {
  504. GELOGE(FAILED, "TUU:Failed to merge graph %s", output_merged_compute_graph->GetName().c_str());
  505. return FAILED;
  506. }
  507. graphStatus ret = output_merged_compute_graph->TopologicalSorting();
  508. if (ret != SUCCESS) {
  509. GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", output_merged_compute_graph->GetName().c_str(), ret);
  510. return ret;
  511. }
  512. GELOGD("TUU:Print-%s", PrintCheckLog().c_str());
  513. GELOGI("TUU:output_merged_compute_graph %s success", output_merged_compute_graph->GetName().c_str());
  514. return SUCCESS;
  515. }
  516. graphStatus TuningUtils::MergeSubGraph(ComputeGraphPtr &subgraph) {
  517. for (auto &node : subgraph->GetDirectNode()) {
  518. if (kPartitionOpTypes.count(node->GetType()) > 0) {
  519. GELOGE(FAILED, "TUU:subgraph passed in should not contain nodes of end or pld type");
  520. return FAILED;
  521. }
  522. // handle data converted from pld node
  523. if (node->GetType() == DATA) {
  524. auto op_desc = node->GetOpDesc();
  525. GE_CHECK_NOTNULL(op_desc);
  526. std::string peer_out_name;
  527. bool has_valid_str = (AttrUtils::GetStr(op_desc, peer_node_name_attr, peer_out_name)) && (!peer_out_name.empty());
  528. if (has_valid_str) {
  529. std::lock_guard<std::mutex> lock(mutex_);
  530. data_2_netoutput_.emplace(op_desc->GetName(), peer_out_name);
  531. data_node_2_netoutput_.emplace(node, peer_out_name);
  532. continue;
  533. }
  534. }
  535. // handle netoutput converted from end node
  536. if (node->GetType() == NETOUTPUT) {
  537. auto op_desc = node->GetOpDesc();
  538. GE_CHECK_NOTNULL(op_desc);
  539. std::vector<string> out_alias_name;
  540. bool has_valid_str =
  541. (AttrUtils::GetListStr(op_desc, alias_name_attr, out_alias_name)) && (!out_alias_name.empty());
  542. if (has_valid_str) {
  543. std::lock_guard<std::mutex> lock(mutex_);
  544. netoutput_nodes_.insert(node);
  545. }
  546. }
  547. {
  548. std::lock_guard<std::mutex> lock(mutex_);
  549. merged_graph_nodes_.emplace(node);
  550. }
  551. GELOGD("TUU:subgraph %s add node %s success", subgraph->GetName().c_str(), node->GetName().c_str());
  552. }
  553. GELOGI("TUU:merge subgraph %s success", subgraph->GetName().c_str());
  554. return SUCCESS;
  555. }
  556. graphStatus TuningUtils::RemoveDataNetoutputEdge(ComputeGraphPtr &graph) {
  557. GE_CHECK_NOTNULL(graph);
  558. // 1. traverse
  559. for (auto &pair : data_node_2_netoutput_) {
  560. auto data_node = pair.first;
  561. GE_CHECK_NOTNULL(data_node);
  562. auto netoutput_name = pair.second;
  563. auto netoutput_node = graph->FindNode(netoutput_name);
  564. GE_CHECK_NOTNULL(netoutput_node);
  565. data_node_2_netoutput_node_.emplace(data_node, netoutput_node);
  566. // 2. get `data out anchor` and `net output in anchor` and `net output in node's out anchor`
  567. AnchorPtr data_out_anchor = (data_node->GetOutDataAnchor(0)->GetFirstPeerAnchor() == nullptr)
  568. ? Anchor::DynamicAnchorCast<Anchor>(data_node->GetOutControlAnchor())
  569. : Anchor::DynamicAnchorCast<Anchor>(data_node->GetOutDataAnchor(0));
  570. AnchorPtr net_output_in_anchor = nullptr;
  571. AnchorPtr src_out_anchor = nullptr;
  572. if (GetInAndOutAnchorPair(data_node, netoutput_node, net_output_in_anchor, src_out_anchor) != GRAPH_SUCCESS) {
  573. GELOGE(FAILED, "TUU:get out node:%s 's in anchor related with data node:%s failed",
  574. netoutput_node->GetName().c_str(), data_node->GetName().c_str());
  575. return FAILED;
  576. }
  577. // 3. relink
  578. if (GraphUtils::RemoveEdge(src_out_anchor, net_output_in_anchor) != GRAPH_SUCCESS) {
  579. GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
  580. GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(),
  581. GetNodeNameByAnchor(net_output_in_anchor.get()).c_str(), net_output_in_anchor->GetIdx(),
  582. data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
  583. return FAILED;
  584. }
  585. GE_CHECK_NOTNULL(data_out_anchor);
  586. for (const auto &peer_in_anchor : data_out_anchor->GetPeerAnchors()) {
  587. if (GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
  588. GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
  589. GetNodeNameByAnchor(data_out_anchor.get()).c_str(), data_out_anchor->GetIdx(),
  590. GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(),
  591. data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
  592. return FAILED;
  593. }
  594. if (GraphUtils::AddEdge(src_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
  595. GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
  596. GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(),
  597. GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(),
  598. data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
  599. return FAILED;
  600. }
  601. }
  602. }
  603. // 4. remove out nodes added by us
  604. for (auto &node : netoutput_nodes_) {
  605. NodeUtils::UnlinkAll(*node);
  606. if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
  607. GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str());
  608. return FAILED;
  609. }
  610. GELOGD("TUU:Remove node %s by the RemoveDataNetoutputEdge process success", node->GetName().c_str());
  611. }
  612. return SUCCESS;
  613. }
  614. graphStatus TuningUtils::GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor,
  615. AnchorPtr &src_out_anchor) {
  616. // 1. get `data parent node name`, i.e. `netoutput input node name`
  617. std::string netoutput_input_name;
  618. auto op_desc = data_node->GetOpDesc();
  619. GE_CHECK_NOTNULL(op_desc);
  620. if (!AttrUtils::GetStr(op_desc, parent_node_name_attr, netoutput_input_name)) {
  621. GELOGE(FAILED, "TUU:Failed to get parent node attr from node %s", op_desc->GetName().c_str());
  622. return FAILED;
  623. }
  624. // 2. find index
  625. int parent_node_anchor_index;
  626. if (!AttrUtils::GetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index)) {
  627. GELOGE(FAILED, "TUU:Failed to get parent node anchor index attr from node %s", op_desc->GetName().c_str());
  628. return FAILED;
  629. }
  630. // 3.find in data or ctrl anchor by 1&2 step
  631. for (auto &in_anchor : out_node->GetAllInAnchors()) {
  632. GE_CHECK_NOTNULL(in_anchor);
  633. for (auto &src_anchor : in_anchor->GetPeerAnchors()) { // get all peer anchors for ctrl
  634. GE_CHECK_NOTNULL(src_anchor);
  635. auto src_node = src_anchor->GetOwnerNode();
  636. GE_CHECK_NOTNULL(src_node);
  637. if (src_node->GetName() == netoutput_input_name && src_anchor->GetIdx() == parent_node_anchor_index) {
  638. dest_in_anchor = in_anchor;
  639. src_out_anchor = src_anchor;
  640. GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s",
  641. out_node->GetName().c_str(), dest_in_anchor->GetIdx(), netoutput_input_name.c_str(),
  642. parent_node_anchor_index, data_node->GetName().c_str());
  643. break;
  644. }
  645. }
  646. }
  647. GE_CHECK_NOTNULL(dest_in_anchor);
  648. GE_CHECK_NOTNULL(src_out_anchor);
  649. return SUCCESS;
  650. }
  651. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示