You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

flow_ctrl_pass.cc 24 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/flow_ctrl_pass.h"
  17. #include <memory>
  18. #include <string>
  19. #include <vector>
  20. #include "framework/common/debug/ge_log.h"
  21. #include "graph/debug/ge_attr_define.h"
  22. #include "graph/common/omg_util.h"
  23. #include "common/ge/ge_util.h"
  24. #include "graph/manager/graph_var_manager.h"
  25. #include "graph/passes/pass_utils.h"
  26. namespace ge {
  27. // when namespace change to ge, please delete the using code.
  28. Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) {
  29. GE_CHECK_NOTNULL(compute_graph);
  30. if (!PassUtils::IsNeedTrainIteFlowCtrl(compute_graph)) {
  31. GELOGI("No need FlowCtrl for graph %u.", compute_graph->GetGraphID());
  32. return NOT_CHANGED;
  33. }
  34. GELOGI("FlowCtrl pass begin.graph is [%s].", compute_graph->GetName().c_str());
  35. bool graph_change = false;
  36. // 1. Add FP/BP flow ctrl (big cycle)
  37. for (auto &node : compute_graph->GetDirectNode()) {
  38. if (node == nullptr) {
  39. continue;
  40. }
  41. GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
  42. uint32_t true_stream_id = 0;
  43. bool is_found = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, true_stream_id);
  44. // FP/BP cycle flag is true_stream_id == 0
  45. if (is_found && (true_stream_id == TRUE_STREAM_ID)) {
  46. // Add big cycle
  47. Status ret = AddFpBpIteratorCtrl(compute_graph, node);
  48. if (ret != SUCCESS) {
  49. GELOGE(ret, "AddFpBpIteratorCtrl fail, node: %s.", node->GetName().c_str());
  50. return ret;
  51. }
  52. graph_change = true;
  53. // only one big cycle, so break.
  54. break;
  55. }
  56. }
  57. // 2. Add special node flow ctrl. eg, IteratorGetNext. (small cycle)
  58. // NOTE: Small cycle share the variables with big cycle.
  59. for (auto &node : compute_graph->GetDirectNode()) {
  60. if (node == nullptr) {
  61. continue;
  62. }
  63. GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
  64. bool need_cycle_flag = false;
  65. bool is_found = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, need_cycle_flag);
  66. // small cycle flag is need_stream_cycle_event == true
  67. if (is_found && need_cycle_flag) {
  68. Status ret = AddSpecialNodeIteratorCtrl(compute_graph, node);
  69. if (ret != SUCCESS) {
  70. GELOGE(ret, "AddSpecialNodeIteratorCtrl fail, node: %s.", node->GetName().c_str());
  71. return ret;
  72. }
  73. graph_change = true;
  74. }
  75. }
  76. // add edge operation below depends on memcpy node in itertor loop set single stream,or may cause block
  77. for (auto &active_node : active_nodes_in_iter_loop_) {
  78. auto ret = GraphUtils::AddEdge(active_node->GetOutControlAnchor(),
  79. assign_add_node_in_fpbp_loop_->GetInControlAnchor());
  80. if (ret != GRAPH_SUCCESS) {
  81. GELOGW("add control edge between iter_loop_node:%s and fpbp_loop_node:%s fail, may cause block",
  82. active_node->GetName().c_str(), assign_add_node_in_fpbp_loop_->GetName().c_str());
  83. }
  84. }
  85. GELOGI("FlowCtrl pass end, graph is %s.", graph_change ? "changed" : "not changed");
  86. return graph_change ? SUCCESS : NOT_CHANGED;
  87. }
  88. bool FlowCtrlPass::CheckMultiDataSet(ComputeGraphPtr &compute_graph) {
  89. int data_set_num = 0;
  90. for (auto &node : compute_graph->GetDirectNode()) {
  91. if (node == nullptr) {
  92. continue;
  93. }
  94. string type;
  95. bool is_found = AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  96. if (is_found && type == "IteratorV2") {
  97. data_set_num++;
  98. }
  99. }
  100. GELOGI("The ComputeGraph contain %d dataSet.", data_set_num);
  101. return (data_set_num > 1) ? true : false;
  102. }
  103. NodePtr FlowCtrlPass::InsertOp(ComputeGraphPtr &compute_graph, const string &node_type, const string &node_name,
  104. const std::vector<GeTensorDesc> &input_list,
  105. const std::vector<GeTensorDesc> &output_list) {
  106. OpDescPtr op_desc = MakeShared<OpDesc>(node_name, node_type);
  107. if (op_desc == nullptr) {
  108. GELOGE(FAILED, "Make OpDesc failed, name:%s, type:%s.", node_name.c_str(), node_type.c_str());
  109. return nullptr;
  110. }
  111. for (auto &input_desc : input_list) {
  112. graphStatus graph_status = op_desc->AddInputDesc(input_desc);
  113. if (graph_status != GRAPH_SUCCESS) {
  114. GELOGE(FAILED, "Add node:%s intput desc failed, error=%u.", node_name.c_str(), graph_status);
  115. return nullptr;
  116. }
  117. }
  118. for (auto &output_desc : output_list) {
  119. graphStatus graph_status = op_desc->AddOutputDesc(output_desc);
  120. if (graph_status != GRAPH_SUCCESS) {
  121. GELOGE(FAILED, "Add node:%s output desc failed, error=%u.", node_name.c_str(), graph_status);
  122. return nullptr;
  123. }
  124. }
  125. GE_IF_BOOL_EXEC(compute_graph == nullptr, DOMI_LOGE("compute_graph is nullptr"); return nullptr);
  126. NodePtr node = compute_graph->AddNode(op_desc);
  127. if (node == nullptr) {
  128. GELOGE(FAILED, "add node failed, name:%s, type:%s.", node_name.c_str(), node_type.c_str());
  129. return nullptr;
  130. }
  131. GELOGI("Insert op success, name:%s, type:%s.", node_name.c_str(), node_type.c_str());
  132. return node;
  133. }
  134. NodePtr FlowCtrlPass::InsertStreamSwitchOp(ComputeGraphPtr &compute_graph, const string &switch_name,
  135. const NodePtr &loop_cond, const NodePtr &iter_per_loop) {
  136. GE_IF_BOOL_EXEC(loop_cond == nullptr || loop_cond->GetOpDesc() == nullptr,
  137. GELOGE(FAILED, "loop_cond is null"); return nullptr);
  138. GE_IF_BOOL_EXEC(iter_per_loop == nullptr || iter_per_loop->GetOpDesc() == nullptr,
  139. GELOGE(FAILED, "iter_per_loop is nullptr"); return nullptr);
  140. std::vector<GeTensorDesc> input_desc_list = {loop_cond->GetOpDesc()->GetOutputDesc(0),
  141. iter_per_loop->GetOpDesc()->GetOutputDesc(0)};
  142. std::vector<GeTensorDesc> output_desc_list;
  143. NodePtr stream_switch = InsertOp(compute_graph, STREAMSWITCH, switch_name, input_desc_list, output_desc_list);
  144. if (stream_switch == nullptr) {
  145. GELOGE(FAILED, "InsertStreamSwitchOp failed, name:%s.", switch_name.c_str());
  146. return nullptr;
  147. }
  148. // set input 0
  149. graphStatus add_ret = GraphUtils::AddEdge(loop_cond->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0));
  150. if (add_ret != GRAPH_SUCCESS) {
  151. GELOGE(FAILED, "Add loop_cond_node to switch_node:%s edge failed, ret = %u.", switch_name.c_str(), add_ret);
  152. return nullptr;
  153. }
  154. // set input 1
  155. add_ret = GraphUtils::AddEdge(iter_per_loop->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(1));
  156. if (add_ret != GRAPH_SUCCESS) {
  157. GELOGE(FAILED, "Add iter_per_loop_node to switch_node:%s edge failed, ret = %u.", switch_name.c_str(), add_ret);
  158. return nullptr;
  159. }
  160. // stream switch op need switch cond by attr.
  161. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(stream_switch->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND,
  162. static_cast<int64_t>(RT_LESS)),
  163. DOMI_LOGE("set ATTR_NAME_STREAM_SWITCH_COND failed"); return nullptr);
  164. return stream_switch;
  165. }
  166. NodePtr FlowCtrlPass::AddVariableNode(ComputeGraphPtr &compute_graph, const string &name) {
  167. GE_IF_BOOL_EXEC(compute_graph == nullptr, DOMI_LOGE("compute_graph is nullptr"); return nullptr);
  168. NodePtr exist_node = compute_graph->FindNode(name);
  169. if (exist_node != nullptr) {
  170. GELOGD("Node %s already exist, no need add.", name.c_str());
  171. return exist_node;
  172. }
  173. // fetch and set tensor desc
  174. GeTensorDesc tensor_desc;
  175. if (ge::VarManager::Instance(compute_graph->GetSessionID()) == nullptr) {
  176. return nullptr;
  177. }
  178. Status ret = ge::VarManager::Instance(compute_graph->GetSessionID())->GetCurVarDesc(name, tensor_desc);
  179. if (ret != SUCCESS) {
  180. GELOGE(FAILED, "Get var desc fail, name:%s", name.c_str());
  181. return nullptr;
  182. }
  183. std::vector<GeTensorDesc> input_desc_list;
  184. std::vector<GeTensorDesc> output_desc_list = {tensor_desc};
  185. // insert node
  186. return InsertOp(compute_graph, VARIABLE, name, input_desc_list, output_desc_list);
  187. }
  188. Status FlowCtrlPass::AddGlobalStepVariableNode(ComputeGraphPtr &compute_graph) {
  189. NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT);
  190. if (output_node == nullptr) {
  191. GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID());
  192. return SUCCESS;
  193. }
  194. // Global step just add to main graph's netoutput node.And the main graph must be known shape
  195. if ((compute_graph->GetParentGraph() != nullptr) ||
  196. ((compute_graph->GetParentGraph() == nullptr) && (GraphUtils::IsUnknownShapeGraph(compute_graph)))) {
  197. GELOGD("Subgraph %s no need global step variable.", compute_graph->GetName().c_str());
  198. return SUCCESS;
  199. }
  200. NodePtr exist_node = compute_graph->FindNode(NODE_NAME_GLOBAL_STEP);
  201. if (exist_node != nullptr) {
  202. GELOGD("Node %s already exist, no need add.", NODE_NAME_GLOBAL_STEP.c_str());
  203. return SUCCESS;
  204. }
  205. // set global step tensor desc
  206. GeTensorDesc tensor_desc(GeShape({1}), FORMAT_ND, DT_UINT64);
  207. std::vector<GeTensorDesc> input_desc_list = {};
  208. std::vector<GeTensorDesc> output_desc_list = {tensor_desc};
  209. NodePtr global_step = InsertOp(compute_graph, VARIABLE, NODE_NAME_GLOBAL_STEP,
  210. input_desc_list, output_desc_list);
  211. if (global_step == nullptr) {
  212. GELOGE(FAILED, "Add global_step node failed, global_step is null.");
  213. return FAILED;
  214. }
  215. // add ctrl edges
  216. graphStatus add_ret = GraphUtils::AddEdge(global_step->GetOutControlAnchor(), output_node->GetInControlAnchor());
  217. if (add_ret != GRAPH_SUCCESS) {
  218. GELOGE(FAILED, "Add global_step to netoutput edge failed, add_ret=%u.", add_ret);
  219. return FAILED;
  220. }
  221. GELOGD("Add global_step to netoutput edge in graph %u success", compute_graph->GetGraphID());
  222. return SUCCESS;
  223. }
  224. NodePtr FlowCtrlPass::InsertAssignOp(ge::ComputeGraphPtr &compute_graph, const string &node_type,
  225. const string &node_name, const NodePtr &ref_node, const NodePtr &value_node) {
  226. GE_IF_BOOL_EXEC(ref_node == nullptr || value_node == nullptr ||
  227. ref_node->GetOpDesc() == nullptr || value_node->GetOpDesc() == nullptr,
  228. GELOGE(FAILED, "ref node or value node is null");
  229. return nullptr);
  230. GeTensorDesc ref_tensor_desc = ref_node->GetOpDesc()->GetOutputDesc(0);
  231. GeTensorDesc val_tensor_desc = value_node->GetOpDesc()->GetOutputDesc(0);
  232. std::vector<GeTensorDesc> input_desc_list = {ref_tensor_desc, val_tensor_desc};
  233. std::vector<GeTensorDesc> output_desc_list = {ref_tensor_desc};
  234. NodePtr assign_node = InsertOp(compute_graph, node_type, node_name, input_desc_list, output_desc_list);
  235. if (assign_node == nullptr) {
  236. GELOGE(FAILED, "Insert node %s(%s) failed.", node_name.c_str(), node_type.c_str());
  237. return nullptr;
  238. }
  239. // assign node input 0 = ref_node
  240. graphStatus add_ret = GraphUtils::AddEdge(ref_node->GetOutDataAnchor(0), assign_node->GetInDataAnchor(0));
  241. if (add_ret != GRAPH_SUCCESS) {
  242. GELOGE(FAILED, "Add ref_node to %s edge failed, add_ret=%u.", node_name.c_str(), add_ret);
  243. return nullptr;
  244. }
  245. // assign input 1 = value_node
  246. add_ret = GraphUtils::AddEdge(value_node->GetOutDataAnchor(0), assign_node->GetInDataAnchor(1));
  247. if (add_ret != GRAPH_SUCCESS) {
  248. GELOGE(FAILED, "Add value_node to %s edge failed, add_ret=%u.", node_name.c_str(), add_ret);
  249. return nullptr;
  250. }
  251. (void)ge::AttrUtils::SetBool(assign_node->GetOpDesc(), ATTR_NEED_COMPILE, true);
  252. return assign_node;
  253. }
  254. Status FlowCtrlPass::CreateIterCtrlTrueBranch(ComputeGraphPtr &compute_graph, const NodePtr &loop_cond_node,
  255. const NodePtr &loop_inc_node, NodePtr &switch_node) {
  256. /*
  257. * loopCond
  258. * |
  259. * v
  260. * switch --> AssignAdd --> active
  261. * ^
  262. * |
  263. * loopIncrement
  264. */
  265. // Insert AssignAdd node
  266. assign_add_node_in_fpbp_loop_ =
  267. InsertAssignOp(compute_graph, ASSIGNADD, NODE_NAME_FLOWCTRL_LOOP_ASSIGNADD, loop_cond_node, loop_inc_node);
  268. if (assign_add_node_in_fpbp_loop_ == nullptr || switch_node == nullptr) {
  269. GELOGE(PARAM_INVALID, "assign add node or switch node is null");
  270. return FAILED;
  271. }
  272. string active_name = switch_node->GetName() + "_StreamActive";
  273. // add attr for stream assign model to break branch.
  274. GE_CHK_STATUS_RET(SetStreamLabel(assign_add_node_in_fpbp_loop_, active_name), "set stream label failed");
  275. // used for stream assign to find true branch
  276. GE_CHK_STATUS_RET(SetActiveLabelList(switch_node, { active_name }), "set active label list failed");
  277. // 2. Insert active node
  278. NodePtr active_node = InsertOp(compute_graph, STREAMACTIVE, active_name, {}, {});
  279. if (active_node == nullptr) {
  280. GELOGE(FAILED, "Insert stream active node:%s for IterCtrlTrueStream failed.", active_name.c_str());
  281. return FAILED;
  282. }
  283. GE_CHK_STATUS_RET(SetStreamLabel(active_node, active_name), "set stream label failed");
  284. GE_IF_BOOL_EXEC(!AttrUtils::SetBool(active_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, true),
  285. DOMI_LOGE("set ATTR_NAME_IS_LOOP_ACTIVE failed"); return FAILED);
  286. // add ctrl edges
  287. graphStatus add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(),
  288. assign_add_node_in_fpbp_loop_->GetInControlAnchor());
  289. if (add_ret != GRAPH_SUCCESS) {
  290. GELOGE(FAILED, "Add switch_node to assign_add_node ctrl edge failed, add_ret=%u.", add_ret);
  291. return FAILED;
  292. }
  293. add_ret = GraphUtils::AddEdge(assign_add_node_in_fpbp_loop_->GetOutControlAnchor(),
  294. active_node->GetInControlAnchor());
  295. if (add_ret != GRAPH_SUCCESS) {
  296. GELOGE(FAILED, "Add assign_add_node to active_node ctrl edge failed, add_ret=%u.", add_ret);
  297. return FAILED;
  298. }
  299. GELOGI("CreateIterCtrlTrueBranch success. StreamActive op:%s.", active_node->GetName().c_str());
  300. return SUCCESS;
  301. }
  302. Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, const NodePtr &loop_cond_node,
  303. const NodePtr &loop_reset_node, NodePtr &switch_node) {
  304. /*
  305. * loopCond
  306. * |
  307. * v
  308. * switch --> Assign --> active --> ModelExit
  309. * ^
  310. * |
  311. * loopReset
  312. */
  313. // Insert Assign node and ctrl edge
  314. NodePtr assign_node =
  315. InsertAssignOp(compute_graph, ASSIGN, NODE_NAME_FLOWCTRL_LOOP_ASSIGN, loop_cond_node, loop_reset_node);
  316. if (assign_node == nullptr || switch_node == nullptr) {
  317. GELOGE(PARAM_INVALID, "assign_node or switch node is null.");
  318. return FAILED;
  319. }
  320. GE_CHK_STATUS_RET(SetStreamLabel(assign_node, switch_node->GetName()), "set stream label failed.");
  321. graphStatus add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), assign_node->GetInControlAnchor());
  322. if (add_ret != GRAPH_SUCCESS) {
  323. GELOGE(FAILED, "Add switch_node to assign_node ctrl edge failed, add_ret=%u.", add_ret);
  324. return FAILED;
  325. }
  326. if (CheckMultiDataSet(compute_graph)) {
  327. GELOGI("Multi dataSae exist, model_exit node is need.");
  328. // 2. Insert active node and add ctrl edge
  329. string active_name = switch_node->GetName() + "_StreamExitActive";
  330. NodePtr active_node = InsertOp(compute_graph, STREAMACTIVE, active_name, {}, {});
  331. if (active_node == nullptr) {
  332. GELOGE(FAILED, "Insert stream active node:%s for IterCtrlTrueStream failed.", active_name.c_str());
  333. return FAILED;
  334. }
  335. GE_CHK_STATUS_RET(SetStreamLabel(active_node, switch_node->GetName()), "set stream label failed");
  336. GE_CHK_STATUS_RET(SetSwitchBranchNodeLabel(active_node, switch_node->GetName()),
  337. "set switch branch node label failed.");
  338. string model_exit_name = switch_node->GetName() + "_ModelExit";
  339. GE_CHK_STATUS_RET(SetActiveLabelList(active_node, { model_exit_name }), "set active label list failed");
  340. add_ret = GraphUtils::AddEdge(assign_node->GetOutControlAnchor(), active_node->GetInControlAnchor());
  341. if (add_ret != GRAPH_SUCCESS) {
  342. GELOGE(FAILED, "Add assign_node to active_node ctrl edge failed, add_ret=%u.", add_ret);
  343. return FAILED;
  344. }
  345. // 3. Insert model exit node and add ctrl edge
  346. NodePtr model_exit_node = InsertOp(compute_graph, MODELEXIT, model_exit_name, {}, {});
  347. if (model_exit_node == nullptr) {
  348. GELOGE(FAILED, "Insert model_exit node:%s for IterCtrlTrueStream failed.", model_exit_name.c_str());
  349. return FAILED;
  350. }
  351. GE_CHK_STATUS_RET(SetStreamLabel(model_exit_node, model_exit_name), "set stream label failed");
  352. add_ret = GraphUtils::AddEdge(active_node->GetOutControlAnchor(), model_exit_node->GetInControlAnchor());
  353. if (add_ret != GRAPH_SUCCESS) {
  354. GELOGE(FAILED, "Add active_node to model_exit_node ctrl edge failed, add_ret=%u.", add_ret);
  355. return FAILED;
  356. }
  357. }
  358. GELOGI("CreateIterCtrlFalseBranch success.");
  359. return SUCCESS;
  360. }
  361. Status FlowCtrlPass::AddFpBpIteratorCtrl(ComputeGraphPtr &compute_graph, NodePtr &pre_node) {
  362. GE_IF_BOOL_EXEC(pre_node == nullptr, DOMI_LOGE("pre_node is nullptr."); return FAILED);
  363. string pre_node_name = pre_node->GetName();
  364. GELOGI("Add FpBp Iterator ctrl, pre node:%s.", pre_node_name.c_str());
  365. // 1. Get or add variables
  366. NodePtr loop_cond_node = AddVariableNode(compute_graph, NODE_NAME_FLOWCTRL_LOOP_COND);
  367. if (loop_cond_node == nullptr) {
  368. GELOGE(FAILED, "Add variable:%s failed.", NODE_NAME_FLOWCTRL_LOOP_COND.c_str());
  369. return FAILED;
  370. }
  371. NodePtr loop_inc_node = AddVariableNode(compute_graph, NODE_NAME_FLOWCTRL_LOOP_INCREMENT);
  372. if (loop_inc_node == nullptr) {
  373. GELOGE(FAILED, "Add variable:%s failed.", NODE_NAME_FLOWCTRL_LOOP_INCREMENT.c_str());
  374. return FAILED;
  375. }
  376. NodePtr loop_reset_node = AddVariableNode(compute_graph, NODE_NAME_FLOWCTRL_LOOP_RESETVALUE);
  377. if (loop_reset_node == nullptr) {
  378. GELOGE(FAILED, "Add variable:%s failed.", NODE_NAME_FLOWCTRL_LOOP_RESETVALUE.c_str());
  379. return FAILED;
  380. }
  381. NodePtr iter_per_loop_node = AddVariableNode(compute_graph, NODE_NAME_FLOWCTRL_LOOP_PER_ITER);
  382. if (iter_per_loop_node == nullptr) {
  383. GELOGE(FAILED, "Add variable:%s failed.", NODE_NAME_FLOWCTRL_LOOP_PER_ITER.c_str());
  384. return FAILED;
  385. }
  386. // 2. Add StreamSwitch
  387. string switch_name = pre_node_name + "_" + NODE_NAME_STREAM_SWITCH;
  388. NodePtr switch_node = InsertStreamSwitchOp(compute_graph, switch_name, loop_cond_node, iter_per_loop_node);
  389. if (switch_node == nullptr) {
  390. GELOGE(FAILED, "InsertStreamSwitchOp:%s failed.", switch_name.c_str());
  391. return FAILED;
  392. }
  393. GE_CHK_STATUS_RET(SetStreamLabel(switch_node, switch_name), "set stream label failed");
  394. graphStatus add_ret = GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), switch_node->GetInControlAnchor());
  395. if (add_ret != GRAPH_SUCCESS) {
  396. GELOGE(FAILED, "Add pre node:%s to switch_node:%s ctrl edge failed, ret = %u.", pre_node_name.c_str(),
  397. switch_name.c_str(), add_ret);
  398. return FAILED;
  399. }
  400. // 3. Create switch false branch: return results and reset the loopCond
  401. Status ret = CreateIterCtrlFalseBranch(compute_graph, loop_cond_node, loop_reset_node, switch_node);
  402. if (ret != SUCCESS) {
  403. GELOGE(ret, "CreateIterCtrlFalseBranch fail, pre node:%s.", pre_node_name.c_str());
  404. return ret;
  405. }
  406. // 4. Create switch true branch:
  407. // active train streams and increase the loopCond
  408. ret = CreateIterCtrlTrueBranch(compute_graph, loop_cond_node, loop_inc_node, switch_node);
  409. if (ret != SUCCESS) {
  410. GELOGE(ret, "CreateIterCtrlTrueBranch fail, pre node:%s.", pre_node_name.c_str());
  411. return ret;
  412. }
  413. return SUCCESS;
  414. }
  415. Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph, NodePtr &loop_after_node) {
  416. /*
  417. * before add:
  418. * iterator
  419. * |
  420. * v
  421. * MemcpyAsync
  422. *
  423. * after add:
  424. * iterator ----------┐
  425. * | ┆c
  426. * v c v c
  427. * MemcpyAsync-----> switch -----> active
  428. * ^
  429. * / \
  430. * itersPerLoop loopCond
  431. */
  432. GE_IF_BOOL_EXEC(loop_after_node == nullptr || compute_graph == nullptr,
  433. DOMI_LOGE("loop after node or compute graph is null."); return FAILED);
  434. InDataAnchorPtr in_anchor = loop_after_node->GetInDataAnchor(0);
  435. if (in_anchor == nullptr || in_anchor->GetPeerOutAnchor() == nullptr) {
  436. GELOGE(FAILED, "Find %s in data anchor failed.", loop_after_node->GetName().c_str());
  437. return FAILED;
  438. }
  439. NodePtr loop_pre_node = in_anchor->GetPeerOutAnchor()->GetOwnerNode();
  440. // 1. Get variables
  441. NodePtr loop_cond_node = compute_graph->FindNode(NODE_NAME_FLOWCTRL_LOOP_COND);
  442. if (loop_cond_node == nullptr) {
  443. GELOGE(FAILED, "Find node :%s failed.", NODE_NAME_FLOWCTRL_LOOP_COND.c_str());
  444. return FAILED;
  445. }
  446. NodePtr iter_per_loop_node = compute_graph->FindNode(NODE_NAME_FLOWCTRL_LOOP_PER_ITER);
  447. if (iter_per_loop_node == nullptr) {
  448. GELOGE(FAILED, "Find node :%s failed.", NODE_NAME_FLOWCTRL_LOOP_PER_ITER.c_str());
  449. return FAILED;
  450. }
  451. // 2. Add StreamSwitch and edges to switch_node.
  452. GE_IF_BOOL_EXEC(loop_pre_node == nullptr, DOMI_LOGE("loop pre node is null."); return FAILED);
  453. string switch_name = loop_pre_node->GetName() + "_" + NODE_NAME_STREAM_SWITCH;
  454. NodePtr switch_node = InsertStreamSwitchOp(compute_graph, switch_name, loop_cond_node, iter_per_loop_node);
  455. if (switch_node == nullptr) {
  456. GELOGE(FAILED, "InsertStreamSwitchOp:%s failed.", switch_name.c_str());
  457. return FAILED;
  458. }
  459. GE_CHK_STATUS_RET(SetStreamLabel(switch_node, switch_name), "set stream label failed.");
  460. graphStatus add_ret = GraphUtils::AddEdge(loop_pre_node->GetOutControlAnchor(), switch_node->GetInControlAnchor());
  461. if (add_ret != GRAPH_SUCCESS) {
  462. GELOGE(FAILED, "Add loop_pre_node:%s to switch_node:%s ctrl edge failed, ret = %u.",
  463. loop_pre_node->GetName().c_str(), switch_name.c_str(), add_ret);
  464. return FAILED;
  465. }
  466. add_ret = GraphUtils::AddEdge(loop_after_node->GetOutControlAnchor(), switch_node->GetInControlAnchor());
  467. if (add_ret != GRAPH_SUCCESS) {
  468. GELOGE(FAILED, "Add node:%s to switch_node:%s ctrl edge failed, ret = %u.", loop_after_node->GetName().c_str(),
  469. switch_name.c_str(), add_ret);
  470. return FAILED;
  471. }
  472. // 3. Create switch true branch: only active
  473. string active_name = switch_name + "_StreamActive";
  474. NodePtr active_node = InsertOp(compute_graph, STREAMACTIVE, active_name, {}, {});
  475. if (active_node == nullptr) {
  476. GELOGE(FAILED, "Insert stream active node:%s for SpecialNodeIteratorCtrl failed.", active_name.c_str());
  477. return FAILED;
  478. }
  479. GE_CHK_STATUS_RET(SetStreamLabel(active_node, active_name), "set stream label failed.");
  480. GE_IF_BOOL_EXEC(!AttrUtils::SetBool(active_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, true),
  481. DOMI_LOGE("set ATTR_NAME_IS_LOOP_ACTIVE failed"); return FAILED);
  482. add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), active_node->GetInControlAnchor());
  483. if (add_ret != GRAPH_SUCCESS) {
  484. GELOGE(FAILED, "Add switch_node:%s to active_node:%s ctrl edge failed, ret = %u.", switch_name.c_str(),
  485. active_name.c_str(), add_ret);
  486. return FAILED;
  487. }
  488. // used for stream assign to find true branch
  489. GE_CHK_STATUS_RET(SetActiveLabelList(switch_node, { active_name }), "set active label list failed.");
  490. // used for stream assign to find active stream
  491. GE_CHK_STATUS_RET(SetActiveLabelList(active_node, { loop_pre_node->GetName() }), "set active label list failed");
  492. active_nodes_in_iter_loop_.push_back(active_node);
  493. return SUCCESS;
  494. }
  495. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。