You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

flow_ctrl_pass.cc 35 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/flow_ctrl_pass.h"
  17. #include <memory>
  18. #include <string>
  19. #include <vector>
  20. #include "framework/common/debug/ge_log.h"
  21. #include "graph/debug/ge_attr_define.h"
  22. #include "graph/common/omg_util.h"
  23. #include "common/ge/ge_util.h"
  24. #include "graph/manager/graph_var_manager.h"
  25. #include "graph/passes/pass_utils.h"
  26. namespace ge {
  27. // when namespace change to ge, please delete the using code.
  28. Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) {
  29. GE_CHECK_NOTNULL(compute_graph);
  30. if (!PassUtils::IsNeedTrainIteFlowCtrl(compute_graph)) {
  31. GELOGI("No need FlowCtrl for graph %u.", compute_graph->GetGraphID());
  32. return NOT_CHANGED;
  33. }
  34. GELOGI("FlowCtrl pass begin.graph is [%s].", compute_graph->GetName().c_str());
  35. bool graph_change = false;
  36. // 1. Add FP/BP flow ctrl (big cycle)
  37. for (auto &node : compute_graph->GetDirectNode()) {
  38. if (node == nullptr) {
  39. continue;
  40. }
  41. GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
  42. uint32_t true_stream_id = 0;
  43. bool is_found = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, true_stream_id);
  44. // FP/BP cycle flag is true_stream_id == 0
  45. if (is_found && (true_stream_id == TRUE_STREAM_ID)) {
  46. // Add big cycle
  47. Status ret = AddFpBpIteratorCtrl(compute_graph, node);
  48. if (ret != SUCCESS) {
  49. GELOGE(ret, "AddFpBpIteratorCtrl fail, node: %s.", node->GetName().c_str());
  50. return ret;
  51. }
  52. graph_change = true;
  53. // only one big cycle, so break.
  54. break;
  55. }
  56. }
  57. // 2. Add special node flow ctrl. eg, IteratorGetNext. (small cycle)
  58. // NOTE: Small cycle share the variables with big cycle.
  59. for (auto &node : compute_graph->GetDirectNode()) {
  60. if (node == nullptr) {
  61. continue;
  62. }
  63. GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
  64. bool need_cycle_flag = false;
  65. bool is_found = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, need_cycle_flag);
  66. // small cycle flag is need_stream_cycle_event == true
  67. if (is_found && need_cycle_flag) {
  68. Status ret = AddSpecialNodeIteratorCtrl(compute_graph, node);
  69. if (ret != SUCCESS) {
  70. GELOGE(ret, "AddSpecialNodeIteratorCtrl fail, node: %s.", node->GetName().c_str());
  71. return ret;
  72. }
  73. graph_change = true;
  74. }
  75. }
  76. // add edge operation below depends on memcpy node in itertor loop set single stream,or may cause block
  77. for (auto &active_node : active_nodes_in_iter_loop_) {
  78. auto ret = GraphUtils::AddEdge(active_node->GetOutControlAnchor(),
  79. assign_add_node_in_fpbp_loop_->GetInControlAnchor());
  80. if (ret != GRAPH_SUCCESS) {
  81. GELOGW("add control edge between iter_loop_node:%s and fpbp_loop_node:%s fail, may cause block",
  82. active_node->GetName().c_str(), assign_add_node_in_fpbp_loop_->GetName().c_str());
  83. }
  84. }
  85. GELOGI("FlowCtrl pass end, graph is %s.", graph_change ? "changed" : "not changed");
  86. return graph_change ? SUCCESS : NOT_CHANGED;
  87. }
  88. bool FlowCtrlPass::CheckMultiDataSet(ComputeGraphPtr &compute_graph) {
  89. int data_set_num = 0;
  90. for (auto &node : compute_graph->GetDirectNode()) {
  91. if (node == nullptr) {
  92. continue;
  93. }
  94. string type;
  95. bool is_found = AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  96. if (is_found && type == "IteratorV2") {
  97. data_set_num++;
  98. }
  99. }
  100. GELOGI("The ComputeGraph contain %d dataSet.", data_set_num);
  101. return (data_set_num > 1) ? true : false;
  102. }
  103. NodePtr FlowCtrlPass::InsertOp(ComputeGraphPtr &compute_graph, const string &node_type, const string &node_name,
  104. const std::vector<GeTensorDesc> &input_list,
  105. const std::vector<GeTensorDesc> &output_list) {
  106. OpDescPtr op_desc = MakeShared<OpDesc>(node_name, node_type);
  107. if (op_desc == nullptr) {
  108. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  109. GELOGE(FAILED, "Make OpDesc failed, name:%s, type:%s.", node_name.c_str(), node_type.c_str());
  110. return nullptr;
  111. }
  112. for (auto &input_desc : input_list) {
  113. graphStatus graph_status = op_desc->AddInputDesc(input_desc);
  114. if (graph_status != GRAPH_SUCCESS) {
  115. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
  116. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  117. GELOGE(FAILED, "Add node:%s intput desc failed, error=%u.", node_name.c_str(), graph_status);
  118. return nullptr;
  119. }
  120. }
  121. for (auto &output_desc : output_list) {
  122. graphStatus graph_status = op_desc->AddOutputDesc(output_desc);
  123. if (graph_status != GRAPH_SUCCESS) {
  124. REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed",
  125. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  126. GELOGE(FAILED, "Add node:%s output desc failed, error=%u.", node_name.c_str(), graph_status);
  127. return nullptr;
  128. }
  129. }
  130. GE_IF_BOOL_EXEC(compute_graph == nullptr,
  131. REPORT_INNER_ERROR("E19999", "Param compute_graph is nullptr, check invalid");
  132. DOMI_LOGE("compute_graph is nullptr");
  133. return nullptr);
  134. NodePtr node = compute_graph->AddNode(op_desc);
  135. if (node == nullptr) {
  136. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  137. op_desc->GetName().c_str(), op_desc->GetType().c_str(), compute_graph->GetName().c_str());
  138. GELOGE(FAILED, "add node failed, name:%s, type:%s.", node_name.c_str(), node_type.c_str());
  139. return nullptr;
  140. }
  141. GELOGI("Insert op success, name:%s, type:%s.", node_name.c_str(), node_type.c_str());
  142. return node;
  143. }
  144. NodePtr FlowCtrlPass::InsertStreamSwitchOp(ComputeGraphPtr &compute_graph, const string &switch_name,
  145. const NodePtr &loop_cond, const NodePtr &iter_per_loop) {
  146. GE_IF_BOOL_EXEC(loop_cond == nullptr || loop_cond->GetOpDesc() == nullptr,
  147. REPORT_INNER_ERROR("E19999", "Param loop_cond or its op_desc is nullptr, "
  148. "check invalid");
  149. GELOGE(FAILED, "loop_cond is null");
  150. return nullptr);
  151. GE_IF_BOOL_EXEC(iter_per_loop == nullptr || iter_per_loop->GetOpDesc() == nullptr,
  152. REPORT_INNER_ERROR("E19999", "Param iter_per_loop or its op_desc is nullptr, "
  153. "check invalid");
  154. GELOGE(FAILED, "iter_per_loop is nullptr");
  155. return nullptr);
  156. std::vector<GeTensorDesc> input_desc_list = {loop_cond->GetOpDesc()->GetOutputDesc(0),
  157. iter_per_loop->GetOpDesc()->GetOutputDesc(0)};
  158. std::vector<GeTensorDesc> output_desc_list;
  159. NodePtr stream_switch = InsertOp(compute_graph, STREAMSWITCH, switch_name, input_desc_list, output_desc_list);
  160. if (stream_switch == nullptr) {
  161. GELOGE(FAILED, "InsertStreamSwitchOp failed, name:%s.", switch_name.c_str());
  162. return nullptr;
  163. }
  164. // set input 0
  165. graphStatus add_ret = GraphUtils::AddEdge(loop_cond->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0));
  166. if (add_ret != GRAPH_SUCCESS) {
  167. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed",
  168. loop_cond->GetName().c_str(), loop_cond->GetType().c_str(),
  169. stream_switch->GetName().c_str(), stream_switch->GetType().c_str());
  170. GELOGE(FAILED, "Add loop_cond_node to switch_node:%s edge failed, ret = %u.", switch_name.c_str(), add_ret);
  171. return nullptr;
  172. }
  173. // set input 1
  174. add_ret = GraphUtils::AddEdge(iter_per_loop->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(1));
  175. if (add_ret != GRAPH_SUCCESS) {
  176. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:1) failed",
  177. iter_per_loop->GetName().c_str(), iter_per_loop->GetType().c_str(),
  178. stream_switch->GetName().c_str(), stream_switch->GetType().c_str());
  179. GELOGE(FAILED, "Add iter_per_loop_node to switch_node:%s edge failed, ret = %u.", switch_name.c_str(), add_ret);
  180. return nullptr;
  181. }
  182. // stream switch op need switch cond by attr.
  183. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(stream_switch->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND,
  184. static_cast<int64_t>(RT_LESS)),
  185. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  186. ATTR_NAME_STREAM_SWITCH_COND.c_str(),
  187. stream_switch->GetName().c_str(), stream_switch->GetType().c_str());
  188. DOMI_LOGE("set ATTR_NAME_STREAM_SWITCH_COND failed"); return nullptr);
  189. return stream_switch;
  190. }
  191. NodePtr FlowCtrlPass::AddVariableNode(ComputeGraphPtr &compute_graph, const string &name) {
  192. GE_IF_BOOL_EXEC(compute_graph == nullptr,
  193. REPORT_INNER_ERROR("E19999", "Param compute_graph is nullptr, check invalid");
  194. DOMI_LOGE("compute_graph is nullptr");
  195. return nullptr);
  196. NodePtr exist_node = compute_graph->FindNode(name);
  197. if (exist_node != nullptr) {
  198. GELOGD("Node %s already exist, no need add.", name.c_str());
  199. return exist_node;
  200. }
  201. // fetch and set tensor desc
  202. GeTensorDesc tensor_desc;
  203. if (ge::VarManager::Instance(compute_graph->GetSessionID()) == nullptr) {
  204. REPORT_INNER_ERROR("E19999", "Get VarManager by session_id:%lu failed",
  205. compute_graph->GetSessionID());
  206. return nullptr;
  207. }
  208. Status ret = ge::VarManager::Instance(compute_graph->GetSessionID())->GetCurVarDesc(name, tensor_desc);
  209. if (ret != SUCCESS) {
  210. REPORT_INNER_ERROR("E19999", "Get var tensor from VarManager by name:%s failed, session_id:%lu",
  211. name.c_str(), compute_graph->GetSessionID());
  212. GELOGE(FAILED, "Get var desc fail, name:%s", name.c_str());
  213. return nullptr;
  214. }
  215. std::vector<GeTensorDesc> input_desc_list;
  216. std::vector<GeTensorDesc> output_desc_list = {tensor_desc};
  217. // insert node
  218. return InsertOp(compute_graph, VARIABLE, name, input_desc_list, output_desc_list);
  219. }
  220. Status FlowCtrlPass::AddGlobalStepVariableNode(ComputeGraphPtr &compute_graph) {
  221. NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT);
  222. if (output_node == nullptr) {
  223. GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID());
  224. return SUCCESS;
  225. }
  226. // Global step just add to main graph's netoutput node.And the main graph must be known shape
  227. if ((compute_graph->GetParentGraph() != nullptr) ||
  228. ((compute_graph->GetParentGraph() == nullptr) && (GraphUtils::IsUnknownShapeGraph(compute_graph)))) {
  229. GELOGD("Subgraph %s no need global step variable.", compute_graph->GetName().c_str());
  230. return SUCCESS;
  231. }
  232. NodePtr exist_node = compute_graph->FindNode(NODE_NAME_GLOBAL_STEP);
  233. if (exist_node != nullptr) {
  234. GELOGD("Node %s already exist, no need add.", NODE_NAME_GLOBAL_STEP.c_str());
  235. return SUCCESS;
  236. }
  237. // set global step tensor desc
  238. GeTensorDesc tensor_desc(GeShape({1}), FORMAT_ND, DT_UINT64);
  239. std::vector<GeTensorDesc> input_desc_list = {};
  240. std::vector<GeTensorDesc> output_desc_list = {tensor_desc};
  241. NodePtr global_step = InsertOp(compute_graph, VARIABLE, NODE_NAME_GLOBAL_STEP,
  242. input_desc_list, output_desc_list);
  243. if (global_step == nullptr) {
  244. GELOGE(FAILED, "Add global_step node failed, global_step is null.");
  245. return FAILED;
  246. }
  247. // add ctrl edges
  248. graphStatus add_ret = GraphUtils::AddEdge(global_step->GetOutControlAnchor(), output_node->GetInControlAnchor());
  249. if (add_ret != GRAPH_SUCCESS) {
  250. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  251. global_step->GetName().c_str(), global_step->GetType().c_str(),
  252. output_node->GetName().c_str(), output_node->GetType().c_str());
  253. GELOGE(FAILED, "Add global_step to netoutput edge failed, add_ret=%u.", add_ret);
  254. return FAILED;
  255. }
  256. GELOGD("Add global_step to netoutput edge in graph %u success", compute_graph->GetGraphID());
  257. return SUCCESS;
  258. }
  259. NodePtr FlowCtrlPass::InsertAssignOp(ge::ComputeGraphPtr &compute_graph, const string &node_type,
  260. const string &node_name, const NodePtr &ref_node, const NodePtr &value_node) {
  261. GE_IF_BOOL_EXEC(ref_node == nullptr || value_node == nullptr ||
  262. ref_node->GetOpDesc() == nullptr || value_node->GetOpDesc() == nullptr,
  263. REPORT_INNER_ERROR("E19999", "Param ref_node or value_node or their op_desc has nullptr, "
  264. "check invalid");
  265. GELOGE(FAILED, "ref node or value node is null");
  266. return nullptr);
  267. GeTensorDesc ref_tensor_desc = ref_node->GetOpDesc()->GetOutputDesc(0);
  268. GeTensorDesc val_tensor_desc = value_node->GetOpDesc()->GetOutputDesc(0);
  269. std::vector<GeTensorDesc> input_desc_list = {ref_tensor_desc, val_tensor_desc};
  270. std::vector<GeTensorDesc> output_desc_list = {ref_tensor_desc};
  271. NodePtr assign_node = InsertOp(compute_graph, node_type, node_name, input_desc_list, output_desc_list);
  272. if (assign_node == nullptr) {
  273. GELOGE(FAILED, "Insert node %s(%s) failed.", node_name.c_str(), node_type.c_str());
  274. return nullptr;
  275. }
  276. // assign node input 0 = ref_node
  277. graphStatus add_ret = GraphUtils::AddEdge(ref_node->GetOutDataAnchor(0), assign_node->GetInDataAnchor(0));
  278. if (add_ret != GRAPH_SUCCESS) {
  279. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed",
  280. ref_node->GetName().c_str(), ref_node->GetType().c_str(),
  281. assign_node->GetName().c_str(), assign_node->GetType().c_str());
  282. GELOGE(FAILED, "Add ref_node to %s edge failed, add_ret=%u.", node_name.c_str(), add_ret);
  283. return nullptr;
  284. }
  285. // assign input 1 = value_node
  286. add_ret = GraphUtils::AddEdge(value_node->GetOutDataAnchor(0), assign_node->GetInDataAnchor(1));
  287. if (add_ret != GRAPH_SUCCESS) {
  288. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:1) failed",
  289. value_node->GetName().c_str(), value_node->GetType().c_str(),
  290. assign_node->GetName().c_str(), assign_node->GetType().c_str());
  291. GELOGE(FAILED, "Add value_node to %s edge failed, add_ret=%u.", node_name.c_str(), add_ret);
  292. return nullptr;
  293. }
  294. (void)ge::AttrUtils::SetBool(assign_node->GetOpDesc(), ATTR_NEED_COMPILE, true);
  295. return assign_node;
  296. }
  297. Status FlowCtrlPass::CreateIterCtrlTrueBranch(ComputeGraphPtr &compute_graph, const NodePtr &loop_cond_node,
  298. const NodePtr &loop_inc_node, NodePtr &switch_node) {
  299. /*
  300. * loopCond
  301. * |
  302. * v
  303. * switch --> AssignAdd --> active
  304. * ^
  305. * |
  306. * loopIncrement
  307. */
  308. // Insert AssignAdd node
  309. assign_add_node_in_fpbp_loop_ =
  310. InsertAssignOp(compute_graph, ASSIGNADD, NODE_NAME_FLOWCTRL_LOOP_ASSIGNADD, loop_cond_node, loop_inc_node);
  311. if (assign_add_node_in_fpbp_loop_ == nullptr || switch_node == nullptr) {
  312. GELOGE(PARAM_INVALID, "assign add node or switch node is null");
  313. return FAILED;
  314. }
  315. string active_name = switch_node->GetName() + "_StreamActive";
  316. // add attr for stream assign model to break branch.
  317. auto status = SetStreamLabel(assign_add_node_in_fpbp_loop_, active_name);
  318. if (status != ge::SUCCESS) {
  319. REPORT_CALL_ERROR("E19999", "Set stream_label:%s to op:%s(%s) failed",
  320. active_name.c_str(), assign_add_node_in_fpbp_loop_->GetName().c_str(),
  321. assign_add_node_in_fpbp_loop_->GetType().c_str());
  322. GELOGE(status, "Set stream label failed.");
  323. return status;
  324. }
  325. // used for stream assign to find true branch
  326. status = SetActiveLabelList(switch_node, { active_name });
  327. if (status != ge::SUCCESS) {
  328. REPORT_CALL_ERROR("E19999", "Set active label list:%s to op:%s(%s) failed",
  329. active_name.c_str(), switch_node->GetName().c_str(), switch_node->GetType().c_str());
  330. GELOGE(status, "set active_label_list failed.");
  331. return status;
  332. }
  333. // 2. Insert active node
  334. NodePtr active_node = InsertOp(compute_graph, STREAMACTIVE, active_name, {}, {});
  335. if (active_node == nullptr) {
  336. GELOGE(FAILED, "Insert stream active node:%s for IterCtrlTrueStream failed.", active_name.c_str());
  337. return FAILED;
  338. }
  339. status = SetStreamLabel(active_node, active_name);
  340. if (status != ge::SUCCESS) {
  341. REPORT_CALL_ERROR("E19999", "Set stream_label:%s to op:%s(%s) failed",
  342. active_name.c_str(), active_node->GetName().c_str(), active_node->GetType().c_str());
  343. GELOGE(status, "Set stream label failed.");
  344. return status;
  345. }
  346. GE_IF_BOOL_EXEC(!AttrUtils::SetBool(active_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, true),
  347. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  348. ATTR_NAME_IS_LOOP_ACTIVE.c_str(),
  349. active_node->GetName().c_str(), active_node->GetType().c_str());
  350. DOMI_LOGE("set ATTR_NAME_IS_LOOP_ACTIVE failed");
  351. return FAILED);
  352. // add ctrl edges
  353. graphStatus add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(),
  354. assign_add_node_in_fpbp_loop_->GetInControlAnchor());
  355. if (add_ret != GRAPH_SUCCESS) {
  356. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  357. switch_node->GetName().c_str(), switch_node->GetType().c_str(),
  358. assign_add_node_in_fpbp_loop_->GetName().c_str(),
  359. assign_add_node_in_fpbp_loop_->GetType().c_str());
  360. GELOGE(FAILED, "Add switch_node to assign_add_node ctrl edge failed, add_ret=%u.", add_ret);
  361. return FAILED;
  362. }
  363. add_ret = GraphUtils::AddEdge(assign_add_node_in_fpbp_loop_->GetOutControlAnchor(),
  364. active_node->GetInControlAnchor());
  365. if (add_ret != GRAPH_SUCCESS) {
  366. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  367. assign_add_node_in_fpbp_loop_->GetName().c_str(),
  368. assign_add_node_in_fpbp_loop_->GetType().c_str(),
  369. active_node->GetName().c_str(), active_node->GetType().c_str());
  370. GELOGE(FAILED, "Add assign_add_node to active_node ctrl edge failed, add_ret=%u.", add_ret);
  371. return FAILED;
  372. }
  373. GELOGI("CreateIterCtrlTrueBranch success. StreamActive op:%s.", active_node->GetName().c_str());
  374. return SUCCESS;
  375. }
  376. Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, const NodePtr &loop_cond_node,
  377. const NodePtr &loop_reset_node, NodePtr &switch_node) {
  378. /*
  379. * loopCond
  380. * |
  381. * v
  382. * switch --> Assign --> active --> ModelExit
  383. * ^
  384. * |
  385. * loopReset
  386. */
  387. // Insert Assign node and ctrl edge
  388. NodePtr assign_node =
  389. InsertAssignOp(compute_graph, ASSIGN, NODE_NAME_FLOWCTRL_LOOP_ASSIGN, loop_cond_node, loop_reset_node);
  390. if (assign_node == nullptr || switch_node == nullptr) {
  391. GELOGE(PARAM_INVALID, "assign_node or switch node is null.");
  392. return FAILED;
  393. }
  394. auto status = SetStreamLabel(assign_node, switch_node->GetName());
  395. if (status != ge::SUCCESS) {
  396. REPORT_CALL_ERROR("E19999", "Set stream_label:%s to op:%s(%s) failed",
  397. switch_node->GetName().c_str(), assign_node->GetName().c_str(), assign_node->GetType().c_str());
  398. GELOGE(status, "Set stream label failed.");
  399. return status;
  400. }
  401. graphStatus add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), assign_node->GetInControlAnchor());
  402. if (add_ret != GRAPH_SUCCESS) {
  403. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  404. switch_node->GetName().c_str(), switch_node->GetType().c_str(),
  405. assign_node->GetName().c_str(), assign_node->GetType().c_str());
  406. GELOGE(FAILED, "Add switch_node to assign_node ctrl edge failed, add_ret=%u.", add_ret);
  407. return FAILED;
  408. }
  409. if (CheckMultiDataSet(compute_graph)) {
  410. GELOGI("Multi dataSae exist, model_exit node is need.");
  411. // 2. Insert active node and add ctrl edge
  412. string active_name = switch_node->GetName() + "_StreamExitActive";
  413. NodePtr active_node = InsertOp(compute_graph, STREAMACTIVE, active_name, {}, {});
  414. if (active_node == nullptr) {
  415. GELOGE(FAILED, "Insert stream active node:%s for IterCtrlTrueStream failed.", active_name.c_str());
  416. return FAILED;
  417. }
  418. status = SetStreamLabel(active_node, switch_node->GetName());
  419. if (status != ge::SUCCESS) {
  420. REPORT_CALL_ERROR("E19999", "Set stream_label:%s to op:%s(%s) failed",
  421. switch_node->GetName().c_str(), active_node->GetName().c_str(), active_node->GetType().c_str());
  422. GELOGE(status, "Set stream label failed.");
  423. return status;
  424. }
  425. GE_CHK_STATUS_RET(SetSwitchBranchNodeLabel(active_node, switch_node->GetName()),
  426. "set switch branch node label failed.");
  427. string model_exit_name = switch_node->GetName() + "_ModelExit";
  428. status = SetActiveLabelList(active_node, { model_exit_name });
  429. if (status != ge::SUCCESS) {
  430. REPORT_CALL_ERROR("E19999", "Set active label list:%s to op:%s(%s) failed",
  431. model_exit_name.c_str(), active_node->GetName().c_str(), active_node->GetType().c_str());
  432. GELOGE(status, "set active_label_list failed.");
  433. return status;
  434. }
  435. add_ret = GraphUtils::AddEdge(assign_node->GetOutControlAnchor(), active_node->GetInControlAnchor());
  436. if (add_ret != GRAPH_SUCCESS) {
  437. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  438. assign_node->GetName().c_str(), assign_node->GetType().c_str(),
  439. active_node->GetName().c_str(), active_node->GetType().c_str());
  440. GELOGE(FAILED, "Add assign_node to active_node ctrl edge failed, add_ret=%u.", add_ret);
  441. return FAILED;
  442. }
  443. // 3. Insert model exit node and add ctrl edge
  444. NodePtr model_exit_node = InsertOp(compute_graph, MODELEXIT, model_exit_name, {}, {});
  445. if (model_exit_node == nullptr) {
  446. GELOGE(FAILED, "Insert model_exit node:%s for IterCtrlTrueStream failed.", model_exit_name.c_str());
  447. return FAILED;
  448. }
  449. status = SetStreamLabel(model_exit_node, model_exit_name);
  450. if (status != ge::SUCCESS) {
  451. REPORT_CALL_ERROR("E19999", "Set stream_label:%s to op:%s(%s) failed",
  452. model_exit_name.c_str(), model_exit_node->GetName().c_str(),
  453. model_exit_node->GetType().c_str());
  454. GELOGE(status, "Set stream label failed.");
  455. return status;
  456. }
  457. add_ret = GraphUtils::AddEdge(active_node->GetOutControlAnchor(), model_exit_node->GetInControlAnchor());
  458. if (add_ret != GRAPH_SUCCESS) {
  459. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  460. active_node->GetName().c_str(), assign_node->GetType().c_str(),
  461. model_exit_node->GetName().c_str(), model_exit_node->GetType().c_str());
  462. GELOGE(FAILED, "Add active_node to model_exit_node ctrl edge failed, add_ret=%u.", add_ret);
  463. return FAILED;
  464. }
  465. }
  466. GELOGI("CreateIterCtrlFalseBranch success.");
  467. return SUCCESS;
  468. }
  469. Status FlowCtrlPass::AddFpBpIteratorCtrl(ComputeGraphPtr &compute_graph, NodePtr &pre_node) {
  470. GE_IF_BOOL_EXEC(pre_node == nullptr, DOMI_LOGE("pre_node is nullptr."); return FAILED);
  471. string pre_node_name = pre_node->GetName();
  472. GELOGI("Add FpBp Iterator ctrl, pre node:%s.", pre_node_name.c_str());
  473. // 1. Get or add variables
  474. NodePtr loop_cond_node = AddVariableNode(compute_graph, NODE_NAME_FLOWCTRL_LOOP_COND);
  475. if (loop_cond_node == nullptr) {
  476. GELOGE(FAILED, "Add variable:%s failed.", NODE_NAME_FLOWCTRL_LOOP_COND.c_str());
  477. return FAILED;
  478. }
  479. NodePtr loop_inc_node = AddVariableNode(compute_graph, NODE_NAME_FLOWCTRL_LOOP_INCREMENT);
  480. if (loop_inc_node == nullptr) {
  481. GELOGE(FAILED, "Add variable:%s failed.", NODE_NAME_FLOWCTRL_LOOP_INCREMENT.c_str());
  482. return FAILED;
  483. }
  484. NodePtr loop_reset_node = AddVariableNode(compute_graph, NODE_NAME_FLOWCTRL_LOOP_RESETVALUE);
  485. if (loop_reset_node == nullptr) {
  486. GELOGE(FAILED, "Add variable:%s failed.", NODE_NAME_FLOWCTRL_LOOP_RESETVALUE.c_str());
  487. return FAILED;
  488. }
  489. NodePtr iter_per_loop_node = AddVariableNode(compute_graph, NODE_NAME_FLOWCTRL_LOOP_PER_ITER);
  490. if (iter_per_loop_node == nullptr) {
  491. GELOGE(FAILED, "Add variable:%s failed.", NODE_NAME_FLOWCTRL_LOOP_PER_ITER.c_str());
  492. return FAILED;
  493. }
  494. // 2. Add StreamSwitch
  495. string switch_name = pre_node_name + "_" + NODE_NAME_STREAM_SWITCH;
  496. NodePtr switch_node = InsertStreamSwitchOp(compute_graph, switch_name, loop_cond_node, iter_per_loop_node);
  497. if (switch_node == nullptr) {
  498. GELOGE(FAILED, "InsertStreamSwitchOp:%s failed.", switch_name.c_str());
  499. return FAILED;
  500. }
  501. auto status = SetStreamLabel(switch_node, switch_name);
  502. if (status != ge::SUCCESS) {
  503. REPORT_CALL_ERROR("E19999", "Set stream label:%s to op:%s(%s) failed",
  504. switch_name.c_str(), switch_node->GetName().c_str(), switch_node->GetType().c_str());
  505. GELOGE(status, "set stream label failed.");
  506. return status;
  507. }
  508. graphStatus add_ret = GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), switch_node->GetInControlAnchor());
  509. if (add_ret != GRAPH_SUCCESS) {
  510. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  511. pre_node->GetName().c_str(), pre_node->GetType().c_str(),
  512. switch_node->GetName().c_str(), switch_node->GetType().c_str());
  513. GELOGE(FAILED, "Add pre node:%s to switch_node:%s ctrl edge failed, ret = %u.", pre_node_name.c_str(),
  514. switch_name.c_str(), add_ret);
  515. return FAILED;
  516. }
  517. // 3. Create switch false branch: return results and reset the loopCond
  518. Status ret = CreateIterCtrlFalseBranch(compute_graph, loop_cond_node, loop_reset_node, switch_node);
  519. if (ret != SUCCESS) {
  520. GELOGE(ret, "CreateIterCtrlFalseBranch fail, pre node:%s.", pre_node_name.c_str());
  521. return ret;
  522. }
  523. // 4. Create switch true branch:
  524. // active train streams and increase the loopCond
  525. ret = CreateIterCtrlTrueBranch(compute_graph, loop_cond_node, loop_inc_node, switch_node);
  526. if (ret != SUCCESS) {
  527. GELOGE(ret, "CreateIterCtrlTrueBranch fail, pre node:%s.", pre_node_name.c_str());
  528. return ret;
  529. }
  530. return SUCCESS;
  531. }
  532. Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph, NodePtr &loop_after_node) {
  533. /*
  534. * before add:
  535. * iterator
  536. * |
  537. * v
  538. * MemcpyAsync
  539. *
  540. * after add:
  541. * iterator ----------┐
  542. * | ┆c
  543. * v c v c
  544. * MemcpyAsync-----> switch -----> active
  545. * ^
  546. * / \
  547. * itersPerLoop loopCond
  548. */
  549. GE_IF_BOOL_EXEC(loop_after_node == nullptr || compute_graph == nullptr,
  550. REPORT_INNER_ERROR("E19999", "Param loop_after_node or compute_graph is nullptr, "
  551. "check invalid");
  552. DOMI_LOGE("loop after node or compute graph is null.");
  553. return FAILED);
  554. InDataAnchorPtr in_anchor = loop_after_node->GetInDataAnchor(0);
  555. if (in_anchor == nullptr || in_anchor->GetPeerOutAnchor() == nullptr) {
  556. REPORT_INNER_ERROR("E19999", "Param loop_after_node:%s(%s) no in data node, check invalid",
  557. loop_after_node->GetName().c_str(), loop_after_node->GetType().c_str());
  558. GELOGE(FAILED, "Find %s in data anchor failed.", loop_after_node->GetName().c_str());
  559. return FAILED;
  560. }
  561. NodePtr loop_pre_node = in_anchor->GetPeerOutAnchor()->GetOwnerNode();
  562. // 1. Get variables
  563. NodePtr loop_cond_node = compute_graph->FindNode(NODE_NAME_FLOWCTRL_LOOP_COND);
  564. if (loop_cond_node == nullptr) {
  565. REPORT_INNER_ERROR("E19999", "Node:%s not found in graph:%s, check invalid",
  566. NODE_NAME_FLOWCTRL_LOOP_COND.c_str(), compute_graph->GetName().c_str());
  567. GELOGE(FAILED, "Find node :%s failed.", NODE_NAME_FLOWCTRL_LOOP_COND.c_str());
  568. return FAILED;
  569. }
  570. NodePtr iter_per_loop_node = compute_graph->FindNode(NODE_NAME_FLOWCTRL_LOOP_PER_ITER);
  571. if (iter_per_loop_node == nullptr) {
  572. REPORT_INNER_ERROR("E19999", "Node:%s not found in graph:%s, check invalid",
  573. NODE_NAME_FLOWCTRL_LOOP_PER_ITER.c_str(), compute_graph->GetName().c_str());
  574. GELOGE(FAILED, "Find node :%s failed.", NODE_NAME_FLOWCTRL_LOOP_PER_ITER.c_str());
  575. return FAILED;
  576. }
  577. // 2. Add StreamSwitch and edges to switch_node.
  578. GE_IF_BOOL_EXEC(loop_pre_node == nullptr,
  579. REPORT_INNER_ERROR("E19999", "Param loop_after_node:%s(%s) no in data node, "
  580. "check invalid", loop_after_node->GetName().c_str(),
  581. loop_after_node->GetType().c_str());
  582. DOMI_LOGE("loop pre node is null.");
  583. return FAILED);
  584. string switch_name = loop_pre_node->GetName() + "_" + NODE_NAME_STREAM_SWITCH;
  585. NodePtr switch_node = InsertStreamSwitchOp(compute_graph, switch_name, loop_cond_node, iter_per_loop_node);
  586. if (switch_node == nullptr) {
  587. GELOGE(FAILED, "InsertStreamSwitchOp:%s failed.", switch_name.c_str());
  588. return FAILED;
  589. }
  590. auto status = SetStreamLabel(switch_node, switch_name);
  591. if (status != ge::SUCCESS) {
  592. REPORT_CALL_ERROR("E19999", "Set stream label:%s to op:%s(%s) failed",
  593. switch_name.c_str(), switch_node->GetName().c_str(), switch_node->GetType().c_str());
  594. GELOGE(status, "set stream label failed.");
  595. return status;
  596. }
  597. graphStatus add_ret = GraphUtils::AddEdge(loop_pre_node->GetOutControlAnchor(), switch_node->GetInControlAnchor());
  598. if (add_ret != GRAPH_SUCCESS) {
  599. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  600. loop_pre_node->GetName().c_str(), loop_pre_node->GetType().c_str(),
  601. switch_node->GetName().c_str(), switch_node->GetType().c_str());
  602. GELOGE(FAILED, "Add loop_pre_node:%s to switch_node:%s ctrl edge failed, ret = %u.",
  603. loop_pre_node->GetName().c_str(), switch_name.c_str(), add_ret);
  604. return FAILED;
  605. }
  606. add_ret = GraphUtils::AddEdge(loop_after_node->GetOutControlAnchor(), switch_node->GetInControlAnchor());
  607. if (add_ret != GRAPH_SUCCESS) {
  608. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  609. loop_after_node->GetName().c_str(), loop_after_node->GetType().c_str(),
  610. switch_node->GetName().c_str(), switch_node->GetType().c_str());
  611. GELOGE(FAILED, "Add node:%s to switch_node:%s ctrl edge failed, ret = %u.", loop_after_node->GetName().c_str(),
  612. switch_name.c_str(), add_ret);
  613. return FAILED;
  614. }
  615. // 3. Create switch true branch: only active
  616. string active_name = switch_name + "_StreamActive";
  617. NodePtr active_node = InsertOp(compute_graph, STREAMACTIVE, active_name, {}, {});
  618. if (active_node == nullptr) {
  619. GELOGE(FAILED, "Insert stream active node:%s for SpecialNodeIteratorCtrl failed.", active_name.c_str());
  620. return FAILED;
  621. }
  622. status = SetStreamLabel(active_node, active_name);
  623. if (status != ge::SUCCESS) {
  624. REPORT_CALL_ERROR("E19999", "Set stream label:%s to op:%s(%s) failed",
  625. active_name.c_str(), active_node->GetName().c_str(), active_node->GetType().c_str());
  626. GELOGE(status, "set stream label failed.");
  627. return status;
  628. }
  629. GE_IF_BOOL_EXEC(!AttrUtils::SetBool(active_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, true),
  630. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  631. ATTR_NAME_IS_LOOP_ACTIVE.c_str(),
  632. active_node->GetName().c_str(), active_node->GetType().c_str());
  633. DOMI_LOGE("set ATTR_NAME_IS_LOOP_ACTIVE failed");
  634. return FAILED);
  635. add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), active_node->GetInControlAnchor());
  636. if (add_ret != GRAPH_SUCCESS) {
  637. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  638. switch_node->GetName().c_str(), switch_node->GetType().c_str(),
  639. active_node->GetName().c_str(), active_node->GetType().c_str());
  640. GELOGE(FAILED, "Add switch_node:%s to active_node:%s ctrl edge failed, ret = %u.", switch_name.c_str(),
  641. active_name.c_str(), add_ret);
  642. return FAILED;
  643. }
  644. // used for stream assign to find true branch
  645. status = SetActiveLabelList(switch_node, { active_name });
  646. if (status != ge::SUCCESS) {
  647. REPORT_CALL_ERROR("E19999", "Set active label list:%s to op:%s(%s) failed",
  648. active_name.c_str(), switch_node->GetName().c_str(), switch_node->GetType().c_str());
  649. GELOGE(status, "set active_label_list failed.");
  650. return status;
  651. }
  652. // used for stream assign to find active stream
  653. status = SetActiveLabelList(active_node, { loop_pre_node->GetName() });
  654. if (status != ge::SUCCESS) {
  655. REPORT_CALL_ERROR("E19999", "Set active label list:%s to op:%s(%s) failed",
  656. loop_pre_node->GetName().c_str(), active_node->GetName().c_str(), active_node->GetType().c_str());
  657. GELOGE(status, "set active_label_list failed.");
  658. return status;
  659. }
  660. active_nodes_in_iter_loop_.push_back(active_node);
  661. return SUCCESS;
  662. }
  663. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：