You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tuning_utils.cc 33 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/tuning_utils.h"
  17. #include "../debug/ge_util.h"
  18. #include "../debug/ge_op_types.h"
  19. #include "framework/common/scope_guard.h"
  20. namespace ge {
  21. namespace {
  22. const std::string peer_node_name_attr = "_peerNodeName";
  23. const std::string parent_node_name_attr = "_parentNodeName";
  24. const std::string alias_name_attr = "_aliasName";
  25. const std::string parent_node_attr = "parentNode";
  26. const std::string parent_node_anchor_index_attr = "_parentNodeAnchorIndex";
  27. const std::string tuning_subgraph_prefix = "/aicore_subgraph_";
  28. const std::string non_tuning_subgraph_prefix = "/subgraph_";
  29. const std::set<std::string> kPartitionOpTypes = {PLACEHOLDER, END};
  30. const std::set<std::string> kExeTypes = {DATA, NETOUTPUT};
  31. }
  32. NodeNametoNodeNameMap TuningUtils::data_2_netoutput_;
  33. NodetoNodeNameMap TuningUtils::data_node_2_netoutput_ ;
  34. NodetoNodeMap TuningUtils::data_node_2_netoutput_node_;
  35. NodeVec TuningUtils::netoutput_nodes_;
  36. NodeVec TuningUtils::merged_graph_nodes_;
  37. SubgraphCreateOutNode TuningUtils::create_output_;
  38. std::mutex TuningUtils::mutex_;
  39. std::string TuningUtils::PrintCheckLog() {
  40. std::stringstream ss;
  41. ss << "d2n:{";
  42. for (const auto &pair : data_2_netoutput_) {
  43. ss << "data:" << pair.first << "-" << "netoutput:" << pair.second;
  44. ss << " | ";
  45. }
  46. ss << "}";
  47. ss << "netoutputs:{";
  48. for (const auto &node : netoutput_nodes_) {
  49. ss << "netoutput:" << node->GetName();
  50. ss << " | ";
  51. }
  52. ss << "}";
  53. return ss.str();
  54. }
  55. std::string TuningUtils::GetNodeNameByAnchor(const Anchor *anchor) {
  56. if (anchor == nullptr) {
  57. GELOGE(GRAPH_FAILED, "Anchor is nullptr");
  58. return "Null";
  59. }
  60. auto node = anchor->GetOwnerNode();
  61. return node == nullptr ? "Null" : node->GetName();
  62. }
  63. // part 1
  64. graphStatus TuningUtils::ConvertGraphToFile(std::vector<ComputeGraphPtr> tuning_subgraphs,
  65. std::vector<ComputeGraphPtr> non_tuning_subgraphs,
  66. bool exe_flag, const std::string &path, const std::string &user_path) {
  67. int64_t i = 0;
  68. int64_t j = 0;
  69. std::lock_guard<std::mutex> lock(mutex_);
  70. for (auto &subgraph : tuning_subgraphs) {
  71. create_output_.emplace(subgraph, nullptr);
  72. auto help_info = HelpInfo{i, exe_flag, true, path, user_path};
  73. if (MakeExeGraph(subgraph, help_info) != SUCCESS) {
  74. GELOGE(GRAPH_FAILED, "TUU:subgraph %zu generate exe graph failed", i);
  75. return GRAPH_FAILED;
  76. }
  77. i++;
  78. }
  79. for (auto &subgraph : non_tuning_subgraphs) {
  80. create_output_.emplace(subgraph, nullptr);
  81. auto help_info = HelpInfo{j, true, false, path, user_path};
  82. if (MakeExeGraph(subgraph, help_info) != SUCCESS) {
  83. GELOGE(GRAPH_FAILED, "TUU:non tuning_subgraph %zu generate exe graph failed", j);
  84. return GRAPH_FAILED;
  85. }
  86. j++;
  87. }
  88. create_output_.clear();
  89. return SUCCESS;
  90. }
  // +---------------+
  // | pld       pld |
  // |   \       /   |
  // |  relu   relu  |
  // |     \   /     |
  // |      add      |
  // |       |       |
  // |      end      |
  // +---------------+
  //         |
  //         |
  //         V
  // +---------------+
  // | data     data |
  // |   \       /   |
  // |  relu   relu  |
  // |     \   /     |
  // |      add      |
  // |       |       |
  // |   netoutput   |
  // +---------------+
  112. graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph,
  113. const HelpInfo& help_info) {
  114. GE_CHECK_NOTNULL(exe_graph);
  115. graphStatus ret = exe_graph->TopologicalSortingGraph(true);
  116. if (ret != SUCCESS) {
  117. GraphUtils::DumpGEGraphToOnnx(*exe_graph, "black_box");
  118. GELOGE(ret, "Graph[%s] topological sort failed, saved to file black_box ret:%d.",
  119. exe_graph->GetName().c_str(), ret);
  120. return ret;
  121. }
  122. // clear graph id
  123. GELOGI("TUU:clear [%s] session_graph_id %s", exe_graph->GetName().c_str(),
  124. (AttrUtils::SetStr(*exe_graph, ATTR_NAME_SESSION_GRAPH_ID, "") ? "success" : "not success"));
  125. // if not make exe, just dump and return
  126. if (!help_info.exe_flag) {
  127. DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path);
  128. GELOGI("TUU:just return, dump original sub_graph[%s]index[%d]", exe_graph->GetName().c_str(), help_info.index);
  129. return SUCCESS;
  130. }
  131. // modify sub graph
  132. for (NodePtr &node : exe_graph->GetDirectNode()) {
  133. // 1.handle pld
  134. if (node->GetType() == PLACEHOLDER) {
  135. if (HandlePld(node) != SUCCESS) {
  136. GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(),
  137. exe_graph->GetName().c_str());
  138. return FAILED;
  139. }
  140. }
  141. // 2.handle end
  142. if (node->GetType() == END) {
  143. if (HandleEnd(node) != SUCCESS) {
  144. GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(),
  145. exe_graph->GetName().c_str());
  146. return FAILED;
  147. }
  148. }
  149. }
  150. ret = exe_graph->TopologicalSortingGraph(true);
  151. if (ret != SUCCESS) {
  152. GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", exe_graph->GetName().c_str(), ret);
  153. return ret;
  154. }
  155. // dump subgraphs which modified by us
  156. if (help_info.user_path.empty()) {
  157. DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path);
  158. } else {
  159. GraphUtils::DumpGEGraph(exe_graph, "", true, help_info.user_path);
  160. }
  161. return SUCCESS;
  162. }
  163. void TuningUtils::DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index,
  164. bool is_tuning_graph, std::string path) {
  165. if (!path.empty()) {
  166. if (is_tuning_graph) {
  167. GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt");
  168. } else {
  169. GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt");
  170. }
  171. } else {
  172. path = "./";
  173. if (is_tuning_graph) {
  174. GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt");
  175. } else {
  176. GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt");
  177. }
  178. }
  179. }
  180. graphStatus TuningUtils::CreateDataNode(NodePtr &node, NodePtr &data_node) {
  181. auto graph = node->GetOwnerComputeGraph();
  182. GE_CHECK_NOTNULL(graph);
  183. auto data_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), DATA);
  184. GE_CHECK_NOTNULL(data_op_desc);
  185. auto pld_op_desc = node->GetOpDesc();
  186. GE_CHECK_NOTNULL(pld_op_desc);
  187. auto output_desc = pld_op_desc->GetOutputDesc(0); // only one output for pld and data
  188. // data inputdesc & outputdesc set as same
  189. if (data_op_desc->AddInputDesc(output_desc) != SUCCESS) {
  190. GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str());
  191. return FAILED;
  192. }
  193. if (data_op_desc->AddOutputDesc(output_desc) != SUCCESS) {
  194. GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str());
  195. return FAILED;
  196. }
  197. data_node = graph->AddNode(data_op_desc);
  198. GE_CHECK_NOTNULL(data_node);
  199. if (data_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) {
  200. GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed");
  201. return FAILED;
  202. }
  203. return SUCCESS;
  204. }
  205. graphStatus TuningUtils::AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node) {
  206. auto op_desc = data_node->GetOpDesc();
  207. GE_CHECK_NOTNULL(op_desc);
  208. auto pld_desc = pld->GetOpDesc();
  209. GE_CHECK_NOTNULL(pld_desc);
  210. // inherit
  211. // a. set `end's input node type` as attr
  212. std::string parent_op_type;
  213. if (!AttrUtils::GetStr(pld_desc, "parentOpType", parent_op_type)) {
  214. GELOGE(FAILED, "TUU:pld %s get parentOpType failed", pld_desc->GetName().c_str());
  215. return FAILED;
  216. }
  217. (void) AttrUtils::SetStr(op_desc, "parentOpType", parent_op_type);
  218. // b. set `end's input node name` as attr
  219. std::string parent_op_name;
  220. if (!AttrUtils::GetStr(pld_desc, parent_node_name_attr, parent_op_name)) {
  221. GELOGE(FAILED, "TUU:pld %s get _parentNodeName failed", pld_desc->GetName().c_str());
  222. return FAILED;
  223. }
  224. (void) AttrUtils::SetStr(op_desc, parent_node_name_attr, parent_op_name);
  225. // c. set `end's input node's out anchor index` as attr
  226. int parent_node_anchor_index;
  227. if (!AttrUtils::GetInt(pld_desc, "anchorIndex", parent_node_anchor_index)) {
  228. GELOGE(FAILED, "TUU:pld %s get anchorIndex failed", pld_desc->GetName().c_str());
  229. return FAILED;
  230. }
  231. (void) AttrUtils::SetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index);
  232. GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success",
  233. pld->GetName().c_str(), pld->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str());
  234. // d. set `end node name` as attr
  235. std::string peer_end_name;
  236. if (!AttrUtils::GetStr(pld_desc, peer_node_name_attr, peer_end_name)) {
  237. GELOGE(FAILED, "TUU:pld %s get _peerNodeName failed", pld_desc->GetName().c_str());
  238. return FAILED;
  239. }
  240. (void) AttrUtils::SetStr(op_desc, peer_node_name_attr, peer_end_name);
  241. GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success",
  242. pld->GetName().c_str(), pld->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str());
  243. return SUCCESS;
  244. }
  245. graphStatus TuningUtils::ChangePld2Data(NodePtr &node, NodePtr &data_node) {
  246. auto type_pld = node->GetType();
  247. auto type_data = data_node->GetType();
  248. if (type_pld != PLACEHOLDER || type_data != DATA) {
  249. GELOGE(FAILED, "TUU:Failed to change node %s from type %s to type %s",
  250. node->GetName().c_str(), type_pld.c_str(), type_data.c_str());
  251. return FAILED;
  252. }
  253. auto graph = node->GetOwnerComputeGraph();
  254. GE_CHECK_NOTNULL(graph);
  255. std::vector<int> output_map(node->GetAllOutDataAnchorsSize());
  256. for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) {
  257. output_map[i] = static_cast<int>(i);
  258. }
  259. auto ret = GraphUtils::ReplaceNodeAnchors(data_node, node, {}, output_map);
  260. if (ret != GRAPH_SUCCESS) {
  261. GELOGE(FAILED, "TUU:Failed to replace node %s by node %s error node %u",
  262. node->GetName().c_str(), data_node->GetName().c_str(), ret);
  263. return FAILED;
  264. }
  265. NodeUtils::UnlinkAll(*node);
  266. ret = GraphUtils::RemoveNodeWithoutRelink(graph, node);
  267. if (ret != GRAPH_SUCCESS) {
  268. GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str());
  269. return FAILED;
  270. }
  271. GELOGD("TUU:Remove node %s(%s) by the ChangePld2Data process, replace it with node %s(%s)",
  272. node->GetName().c_str(), node->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str());
  273. return ret;
  274. }
  275. graphStatus TuningUtils::HandlePld(NodePtr &node) {
  276. GE_CHECK_NOTNULL(node);
  277. auto graph = node->GetOwnerComputeGraph();
  278. GE_CHECK_NOTNULL(graph);
  279. if (HandleContinuousInputNodeNextData(node) != GRAPH_SUCCESS) {
  280. GELOGE(GRAPH_FAILED, "TUU:Failed to handle continuous node next to data node:%s",
  281. node->GetName().c_str());
  282. return GRAPH_FAILED;
  283. }
  284. NodePtr data_node = nullptr;
  285. // 1. create data node
  286. if (CreateDataNode(node, data_node) != SUCCESS) {
  287. GELOGE(FAILED,
  288. "TUU:Failed to handle node %s from graph %s",
  289. node->GetName().c_str(),
  290. graph->GetName().c_str());
  291. return FAILED;
  292. }
  293. // 2. add necessary info to data_node for recovery whole graph
  294. if (AddAttrToDataNodeForMergeGraph(node, data_node) != SUCCESS) {
  295. GELOGE(FAILED,
  296. "TUU:Failed to handle node %s from graph %s",
  297. node->GetName().c_str(),
  298. graph->GetName().c_str());
  299. return FAILED;
  300. }
  301. // 3. replace pld node by data node created before
  302. if (ChangePld2Data(node, data_node) != SUCCESS) {
  303. GELOGE(FAILED,
  304. "TUU:Failed to handle node %s from graph %s",
  305. node->GetName().c_str(),
  306. graph->GetName().c_str());
  307. return FAILED;
  308. }
  309. GELOGD("TUU:pld[%s] handle success", node->GetName().c_str());
  310. return SUCCESS;
  311. }
  312. graphStatus TuningUtils::CreateNetOutput(NodePtr &node, NodePtr &out_node) {
  313. GE_CHECK_NOTNULL(node);
  314. auto graph = node->GetOwnerComputeGraph();
  315. GE_CHECK_NOTNULL(graph);
  316. auto search = create_output_.find(graph);
  317. if (search == create_output_.end()) {
  318. GELOGE(FAILED,
  319. "TUU:node %s's owner sub graph %s not exist in create_output map",
  320. node->GetName().c_str(),
  321. graph->GetName().c_str());
  322. return FAILED;
  323. }
  324. if (search->second != nullptr) {
  325. out_node = search->second;
  326. GELOGD("TUU:sub graph %s has created output node, just return", graph->GetName().c_str());
  327. return SUCCESS;
  328. }
  329. auto out_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), NETOUTPUT);
  330. GE_CHECK_NOTNULL(out_op_desc);
  331. out_node = graph->AddNode(out_op_desc);
  332. GE_CHECK_NOTNULL(out_node);
  333. if (out_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) {
  334. GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed");
  335. return FAILED;
  336. }
  337. create_output_[graph] = out_node;
  338. return SUCCESS;
  339. }
  340. graphStatus TuningUtils::AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node) {
  341. GE_CHECK_NOTNULL(end);
  342. GE_CHECK_NOTNULL(out_node);
  343. auto op_desc = out_node->GetOpDesc();
  344. GE_CHECK_NOTNULL(op_desc);
  345. std::vector<std::string> alias_names = {};
  346. (void) AttrUtils::GetListStr(op_desc, alias_name_attr, alias_names);
  347. alias_names.push_back(end->GetName());
  348. (void) AttrUtils::SetListStr(op_desc, alias_name_attr, alias_names);
  349. return SUCCESS;
  350. }
  351. graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) {
  352. GE_CHECK_NOTNULL(end_node);
  353. GE_CHECK_NOTNULL(out_node);
  354. // get end in node is control node or normal node
  355. AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr)
  356. ? Anchor::DynamicAnchorCast<Anchor>(end_node->GetInControlAnchor())
  357. : Anchor::DynamicAnchorCast<Anchor>(end_node->GetInDataAnchor(0));
  358. GE_CHECK_NOTNULL(end_in_anchor);
  359. auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1
  360. GE_CHECK_NOTNULL(src_anchor);
  361. if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) {
  362. GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
  363. GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
  364. GetNodeNameByAnchor(end_in_anchor.get()).c_str(), end_in_anchor->GetIdx(),
  365. end_node->GetName().c_str(), end_node->GetOwnerComputeGraph()->GetName().c_str());
  366. return FAILED;
  367. }
  368. // add edge between `end in node` and `out_node`
  369. if (src_anchor->IsTypeOf<OutDataAnchor>()) {
  370. std::shared_ptr<InDataAnchor>
  371. anchor = ComGraphMakeShared<InDataAnchor>(out_node, out_node->GetAllInDataAnchors().size());
  372. GE_CHECK_NOTNULL(anchor);
  373. out_node->in_data_anchors_.push_back(anchor);
  374. if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) {
  375. GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
  376. GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
  377. GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(),
  378. end_node->GetName().c_str(), end_node->GetOwnerComputeGraph()->GetName().c_str());
  379. return FAILED;
  380. }
  381. auto end_op_desc = end_node->GetOpDesc();
  382. GE_CHECK_NOTNULL(end_op_desc);
  383. auto out_node_op_desc = out_node->GetOpDesc();
  384. GE_CHECK_NOTNULL(out_node_op_desc);
  385. // end node always has one input
  386. if (out_node_op_desc->AddInputDesc(end_op_desc->GetInputDesc(0)) != GRAPH_SUCCESS) {
  387. GELOGE(FAILED, "TUU:node %s add input desc failed.", out_node_op_desc->GetName().c_str());
  388. return FAILED;
  389. }
  390. } else if (src_anchor->IsTypeOf<OutControlAnchor>()) {
  391. auto anchor = out_node->GetInControlAnchor();
  392. if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) {
  393. GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
  394. GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
  395. GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(),
  396. end_node->GetName().c_str(), end_node->GetOwnerComputeGraph()->GetName().c_str());
  397. return FAILED;
  398. }
  399. } else {
  400. GELOGE(FAILED, "TUU: node_name:%s, graph_name:%s handled failed",
  401. end_node->GetName().c_str(), end_node->GetOwnerComputeGraph()->GetName().c_str());
  402. return FAILED;
  403. }
  404. return SUCCESS;
  405. }
  406. graphStatus TuningUtils::ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) {
  407. GE_CHECK_NOTNULL(end_node);
  408. GE_CHECK_NOTNULL(out_node);
  409. auto type_end = end_node->GetType();
  410. auto type_out = out_node->GetType();
  411. if (type_end != END || type_out != NETOUTPUT) {
  412. GELOGE(FAILED, "TUU:Failed to change end_node %s from type %s to type %s",
  413. end_node->GetName().c_str(), type_end.c_str(), type_out.c_str());
  414. return FAILED;
  415. }
  416. // link all `end nodes's in node` to this out_node
  417. if (LinkEnd2NetOutput(end_node, out_node) != SUCCESS) {
  418. GELOGE(FAILED, "TUU:end_node [%s] LinkEnd2NetOutput failed.", end_node->GetName().c_str());
  419. return FAILED;
  420. }
  421. // remove `end node`
  422. NodeUtils::UnlinkAll(*end_node);
  423. auto graph = end_node->GetOwnerComputeGraph();
  424. GE_CHECK_NOTNULL(graph);
  425. if (GraphUtils::RemoveNodeWithoutRelink(graph, end_node) != SUCCESS) {
  426. GELOGE(FAILED, "TUU:end node [%s] RemoveNodeWithoutRelink failed.", end_node->GetName().c_str());
  427. return FAILED;
  428. }
  429. return SUCCESS;
  430. }
  431. graphStatus TuningUtils::HandleEnd(NodePtr &node) {
  432. GE_CHECK_NOTNULL(node);
  433. auto graph = node->GetOwnerComputeGraph();
  434. GE_CHECK_NOTNULL(graph);
  435. NodePtr out_node = nullptr;
  436. // 1. create net_output node , add only one NetOutput node to one subgraph
  437. if (CreateNetOutput(node, out_node) != SUCCESS) {
  438. GELOGE(FAILED,
  439. "TUU:Failed to handle node %s from graph %s",
  440. node->GetName().c_str(),
  441. graph->GetName().c_str());
  442. return FAILED;
  443. }
  444. // 2. add necessary info to out_node for recovery whole graph
  445. if (AddAttrToNetOutputForMergeGraph(node, out_node) != SUCCESS) {
  446. GELOGE(FAILED,
  447. "TUU:Failed to handle node %s from graph %s",
  448. node->GetName().c_str(),
  449. graph->GetName().c_str());
  450. return FAILED;
  451. }
  452. // 3. replace all end nodes by one output node created before
  453. if (ChangeEnd2NetOutput(node, out_node) != SUCCESS) {
  454. GELOGE(FAILED,
  455. "TUU:Failed to handle node %s from graph %s",
  456. node->GetName().c_str(),
  457. graph->GetName().c_str());
  458. return FAILED;
  459. }
  460. GELOGD("TUU:end[%s] handle success", node->GetName().c_str());
  461. return SUCCESS;
  462. }
  463. // part 2
  464. graphStatus TuningUtils::ConvertFileToGraph(const map<int64_t, string> &options, ge::Graph &graph) {
  465. std::function<void()> callback = [&]() {
  466. data_2_netoutput_.clear();
  467. data_node_2_netoutput_.clear();
  468. data_node_2_netoutput_node_.clear();
  469. netoutput_nodes_.clear();
  470. merged_graph_nodes_.clear();
  471. };
  472. GE_MAKE_GUARD(release, callback);
  473. // 1. get all subgraph object
  474. std::vector<ComputeGraphPtr> graphs;
  475. // options format like {index:"subgraph_path"}
  476. for (const auto &pair : options) {
  477. ComputeGraphPtr compute_graph = ComGraphMakeShared<ComputeGraph>(std::to_string(pair.first));
  478. if (!ge::GraphUtils::LoadGEGraph(pair.second.c_str(), *compute_graph)) {
  479. GELOGE(FAILED, "TUU:load graph from file failed");
  480. }
  481. graphs.push_back(compute_graph);
  482. }
  483. // 2. merge graph
  484. ComputeGraphPtr merged_graph = ComGraphMakeShared<ComputeGraph>("whole_graph_after_tune");
  485. GE_CHECK_NOTNULL(merged_graph);
  486. if (MergeAllSubGraph(graphs, merged_graph) != SUCCESS) {
  487. GELOGE(FAILED, "TUU:MergeGraph failed");
  488. return FAILED;
  489. }
  490. // 3. set parent graph
  491. for (const auto &node : merged_graph->GetDirectNode()) {
  492. GE_CHECK_NOTNULL(node);
  493. if (node->SetOwnerComputeGraph(merged_graph) != GRAPH_SUCCESS) {
  494. GELOGE(FAILED, "TUU:node %s set owner graph failed", node->GetName().c_str());
  495. return FAILED;
  496. }
  497. }
  498. graph = GraphUtils::CreateGraphFromComputeGraph(merged_graph);
  499. return SUCCESS;
  500. }
  // +----------------------------------+
  // |    const         const           |
  // |      \         /                 |
  // |    netoutput(end,end)            |
  // +----------------------------------+
  //                  +
  // +----------------------------------+
  // |  data(pld)       data(pld)       |
  // |      \          /                |
  // |    relu          relu            |
  // |        \       /                 |
  // |          \   /                   |
  // |           add                    |
  // |            |                     |
  // |     netoutput(end)               |
  // +----------------------------------+
  //                  +
  // +----------------------------------+
  // |         data(pld)                |
  // |        /                         |
  // |   netoutput                      |
  // +----------------------------------+
  //                  |
  //                  |
  //                  V
  // +----------------------------------+
  // |    const         const           |
  // |      \          /                |
  // |    relu          relu            |
  // |        \       /                 |
  // |          \   /                   |
  // |           add                    |
  // |            |                     |
  // |        netoutput                 |
  // +----------------------------------+
  536. graphStatus TuningUtils::MergeAllSubGraph(std::vector<ComputeGraphPtr> &subgraphs,
  537. ComputeGraphPtr &output_merged_compute_graph) {
  538. GE_CHECK_NOTNULL(output_merged_compute_graph);
  539. // 1. handle all subgraphs
  540. for (auto &subgraph : subgraphs) {
  541. Status ret_status = MergeSubGraph(subgraph);
  542. if (ret_status != SUCCESS) {
  543. GELOGE(ret_status, "TUU:subgraph %s merge failed", subgraph->GetName().c_str());
  544. return ret_status;
  545. }
  546. }
  547. for (const auto &node: merged_graph_nodes_) {
  548. (void) output_merged_compute_graph->AddNode(node);
  549. GELOGD("TUU:graph %s add node %s success", output_merged_compute_graph->GetName().c_str(), node->GetName().c_str());
  550. vector<string> recover_attr_name;
  551. (void) ge::AttrUtils::GetListStr(node->GetOpDesc(), ATTR_NAME_NEED_RECOVER_ATTR, recover_attr_name);
  552. if (!recover_attr_name.empty()) {
  553. for (const auto &attr_name : recover_attr_name) {
  554. if (!ge::AttrUtils::SetBool(node->GetOpDesc(), attr_name, true)) {
  555. GELOGE(GRAPH_FAILED, "Recover attr %s for node:%s failed.", attr_name.c_str(), node->GetName().c_str());
  556. return GRAPH_FAILED;
  557. }
  558. }
  559. }
  560. }
  561. // 2. remove data and output node added by us
  562. if (RemoveDataNetoutputEdge(output_merged_compute_graph) != SUCCESS) {
  563. GELOGE(FAILED, "TUU:Failed to merge graph %s", output_merged_compute_graph->GetName().c_str());
  564. return FAILED;
  565. }
  566. graphStatus ret = output_merged_compute_graph->TopologicalSorting();
  567. if (ret != SUCCESS) {
  568. GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", output_merged_compute_graph->GetName().c_str(), ret);
  569. return ret;
  570. }
  571. GELOGD("TUU:Print-%s", PrintCheckLog().c_str());
  572. GELOGI("TUU:output_merged_compute_graph %s success", output_merged_compute_graph->GetName().c_str());
  573. return SUCCESS;
  574. }
  575. graphStatus TuningUtils::MergeSubGraph(ComputeGraphPtr &subgraph) {
  576. for (auto &node : subgraph->GetDirectNode()) {
  577. if (kPartitionOpTypes.count(node->GetType()) > 0) {
  578. GELOGE(FAILED, "TUU:subgraph passed in should not contain nodes of end or pld type");
  579. return FAILED;
  580. }
  581. // handle data converted from pld node
  582. if (node->GetType() == DATA) {
  583. auto op_desc = node->GetOpDesc();
  584. GE_CHECK_NOTNULL(op_desc);
  585. std::string peer_out_name;
  586. bool has_valid_str =
  587. (AttrUtils::GetStr(op_desc, peer_node_name_attr, peer_out_name)) && (!peer_out_name.empty());
  588. if (has_valid_str) {
  589. std::lock_guard<std::mutex> lock(mutex_);
  590. data_2_netoutput_.emplace(op_desc->GetName(), peer_out_name);
  591. data_node_2_netoutput_.emplace(node, peer_out_name);
  592. continue;
  593. }
  594. }
  595. // handle netoutput converted from end node
  596. if (node->GetType() == NETOUTPUT) {
  597. auto op_desc = node->GetOpDesc();
  598. GE_CHECK_NOTNULL(op_desc);
  599. std::vector<string> out_alias_name;
  600. bool has_valid_str =
  601. (AttrUtils::GetListStr(op_desc, alias_name_attr, out_alias_name)) && (!out_alias_name.empty());
  602. if (has_valid_str) {
  603. std::lock_guard<std::mutex> lock(mutex_);
  604. netoutput_nodes_.emplace_back(node);
  605. }
  606. }
  607. {
  608. std::lock_guard<std::mutex> lock(mutex_);
  609. merged_graph_nodes_.emplace_back(node);
  610. }
  611. GELOGD("TUU:subgraph %s add node %s success", subgraph->GetName().c_str(), node->GetName().c_str());
  612. }
  613. GELOGI("TUU:merge subgraph %s success", subgraph->GetName().c_str());
  614. return SUCCESS;
  615. }
  616. graphStatus TuningUtils::RemoveDataNetoutputEdge(ComputeGraphPtr &graph) {
  617. GE_CHECK_NOTNULL(graph);
  618. // 1. traverse
  619. for (auto &pair: data_node_2_netoutput_) {
  620. auto data_node = pair.first;
  621. GE_CHECK_NOTNULL(data_node);
  622. auto netoutput_name = pair.second;
  623. auto netoutput_node = graph->FindNode(netoutput_name);
  624. GE_CHECK_NOTNULL(netoutput_node);
  625. data_node_2_netoutput_node_.emplace(data_node, netoutput_node);
  626. // 2. get `data out anchor` and `net output in anchor` and `net output in node's out anchor`
  627. AnchorPtr data_out_anchor = (data_node->GetOutDataAnchor(0)->GetFirstPeerAnchor() == nullptr)
  628. ? Anchor::DynamicAnchorCast<Anchor>(data_node->GetOutControlAnchor())
  629. : Anchor::DynamicAnchorCast<Anchor>(data_node->GetOutDataAnchor(0));
  630. AnchorPtr net_output_in_anchor = nullptr;
  631. AnchorPtr src_out_anchor = nullptr;
  632. if (GetInAndOutAnchorPair(data_node, netoutput_node, net_output_in_anchor, src_out_anchor) != GRAPH_SUCCESS) {
  633. GELOGE(FAILED, "TUU:get out node:%s 's in anchor related with data node:%s failed",
  634. netoutput_node->GetName().c_str(), data_node->GetName().c_str());
  635. return FAILED;
  636. }
  637. // 3. relink
  638. // unlink netoutput_node with it's input in stage 4
  639. GE_CHECK_NOTNULL(data_out_anchor);
  640. for (const auto &peer_in_anchor : data_out_anchor->GetPeerAnchors()) {
  641. if (GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
  642. GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
  643. GetNodeNameByAnchor(data_out_anchor.get()).c_str(), data_out_anchor->GetIdx(),
  644. GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(),
  645. data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
  646. return FAILED;
  647. }
  648. if (GraphUtils::AddEdge(src_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
  649. GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
  650. GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(),
  651. GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(),
  652. data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
  653. return FAILED;
  654. }
  655. }
  656. }
  657. // 4. remove out nodes added by us
  658. for (auto &node: netoutput_nodes_) {
  659. NodeUtils::UnlinkAll(*node);
  660. if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
  661. GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str());
  662. return FAILED;
  663. }
  664. GELOGD("TUU:Remove node %s by the RemoveDataNetoutputEdge process success", node->GetName().c_str());
  665. }
  666. return SUCCESS;
  667. }
  668. graphStatus TuningUtils::GetInAndOutAnchorPair(NodePtr &data_node,
  669. NodePtr &out_node,
  670. AnchorPtr &dest_in_anchor,
  671. AnchorPtr &src_out_anchor) {
  672. // 1. get `data parent node name`, i.e. `netoutput input node name`
  673. std::string netoutput_input_name;
  674. auto op_desc = data_node->GetOpDesc();
  675. GE_CHECK_NOTNULL(op_desc);
  676. if (!AttrUtils::GetStr(op_desc, parent_node_name_attr, netoutput_input_name)) {
  677. GELOGE(FAILED, "TUU:Failed to get parent node attr from node %s", op_desc->GetName().c_str());
  678. return FAILED;
  679. }
  680. // 2. find index
  681. int parent_node_anchor_index;
  682. if (!AttrUtils::GetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index)) {
  683. GELOGE(FAILED, "TUU:Failed to get parent node anchor index attr from node %s", op_desc->GetName().c_str());
  684. return FAILED;
  685. }
  686. // 3.find in data or ctrl anchor by 1&2 step
  687. for (auto &in_anchor: out_node->GetAllInAnchors()) {
  688. GE_CHECK_NOTNULL(in_anchor);
  689. for (auto &src_anchor :in_anchor->GetPeerAnchors()) { // get all peer anchors for ctrl
  690. GE_CHECK_NOTNULL(src_anchor);
  691. auto src_node = src_anchor->GetOwnerNode();
  692. GE_CHECK_NOTNULL(src_node);
  693. std::string src_node_name = src_node->GetName();
  694. if (src_node_name.find(netoutput_input_name) != src_node_name.npos &&
  695. src_anchor->GetIdx() == parent_node_anchor_index) {
  696. dest_in_anchor = in_anchor;
  697. src_out_anchor = src_anchor;
  698. GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s",
  699. out_node->GetName().c_str(), dest_in_anchor->GetIdx(), netoutput_input_name.c_str(),
  700. parent_node_anchor_index, data_node->GetName().c_str());
  701. break;
  702. }
  703. }
  704. }
  705. GE_CHECK_NOTNULL(dest_in_anchor);
  706. GE_CHECK_NOTNULL(src_out_anchor);
  707. return SUCCESS;
  708. }
  709. graphStatus TuningUtils::HandleContinuousInputNodeNextData(NodePtr &node) {
  710. GE_CHECK_NOTNULL(node);
  711. for (const auto &out_anchor : node->GetAllOutAnchors()) {
  712. for (const auto &peer_in_anchor : out_anchor->GetPeerAnchors()) {
  713. auto next_node = peer_in_anchor->GetOwnerNode();
  714. vector<string> remove_attr_names;
  715. bool is_no_padding_continuous_input = false;
  716. bool is_continuous_input = false;
  717. bool is_no_task = false;
  718. (void) ge::AttrUtils::GetBool(next_node->GetOpDesc(), ATTR_NAME_CONTINUOUS_INPUT, is_continuous_input);
  719. (void) ge::AttrUtils::GetBool(next_node->GetOpDesc(),
  720. ATTR_NAME_NOPADDING_CONTINUOUS_INPUT,
  721. is_no_padding_continuous_input);
  722. (void) ge::AttrUtils::GetBool(next_node->GetOpDesc(), ATTR_NAME_NOTASK, is_no_task);
  723. if (is_continuous_input) {
  724. if (!ge::AttrUtils::SetBool(next_node->GetOpDesc(), ATTR_NAME_CONTINUOUS_INPUT, false)) {
  725. GELOGE(GRAPH_FAILED, "Remove attr ATTR_NAME_CONTINUOUS_INPUT for node:%s failed.",
  726. next_node->GetName().c_str());
  727. return GRAPH_FAILED;
  728. }
  729. remove_attr_names.emplace_back(ATTR_NAME_CONTINUOUS_INPUT);
  730. }
  731. if (is_no_padding_continuous_input) {
  732. if (!ge::AttrUtils::SetBool(next_node->GetOpDesc(), ATTR_NAME_NOPADDING_CONTINUOUS_INPUT, false)) {
  733. GELOGE(GRAPH_FAILED, "Remove attr ATTR_NAME_NOPADDING_CONTINUOUS_INPUT for node:%s failed.",
  734. next_node->GetName().c_str());
  735. return GRAPH_FAILED;
  736. }
  737. remove_attr_names.emplace_back(ATTR_NAME_NOPADDING_CONTINUOUS_INPUT);
  738. }
  739. if ((is_continuous_input || is_no_padding_continuous_input) && is_no_task) {
  740. if (!ge::AttrUtils::SetBool(next_node->GetOpDesc(), ATTR_NAME_NOTASK, false)) {
  741. GELOGE(GRAPH_FAILED, "Remove attr ATTR_NAME_NOTASK for node:%s failed.",
  742. next_node->GetName().c_str());
  743. return GRAPH_FAILED;
  744. }
  745. remove_attr_names.emplace_back(ATTR_NAME_NOTASK);
  746. }
  747. if (!remove_attr_names.empty()) {
  748. if (!ge::AttrUtils::SetListStr(next_node->GetOpDesc(),
  749. ATTR_NAME_NEED_RECOVER_ATTR,
  750. remove_attr_names)) {
  751. GELOGE(GRAPH_FAILED, "Set attr ATTR_NAME_NEED_RECOVER_ATTR for node:%s failed.",
  752. next_node->GetName().c_str());
  753. return GRAPH_FAILED;
  754. }
  755. }
  756. }
  757. }
  758. return GRAPH_SUCCESS;
  759. }
  760. }

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承上启下的作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了专门的优化，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：