You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

next_iteration_pass.cc 16 kB

5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/next_iteration_pass.h"
  17. #include "common/ge/ge_util.h"
  18. #include "graph/common/omg_util.h"
  19. using std::string;
  20. namespace ge {
  21. namespace {
  22. const int64_t kLoopType = 1;
  23. }
  24. Status NextIterationPass::Run(ComputeGraphPtr graph) {
  25. GELOGD("NextIterationPass Enter");
  26. /// Enter-----------+
  27. /// +-> Merge -> Switch <- LoopCond <- Cond
  28. /// NextIteration---+
  29. for (auto &node : graph->GetDirectNode()) {
  30. const std::string type = node->GetType();
  31. if ((type != ENTER) && (type != REFENTER)) {
  32. continue;
  33. }
  34. if (GroupEnterNode(node) != SUCCESS) {
  35. GELOGE(INTERNAL_ERROR, "Group enter_node %s failed.", node->GetName().c_str());
  36. return INTERNAL_ERROR;
  37. }
  38. }
  39. if (FindWhileGroups() != SUCCESS) {
  40. GELOGE(INTERNAL_ERROR, "Find while groups failed.");
  41. return INTERNAL_ERROR;
  42. }
  43. if (!VerifyWhileGroup()) {
  44. GELOGE(INTERNAL_ERROR, "Verify while groups failed.");
  45. return INTERNAL_ERROR;
  46. }
  47. if (HandleWhileGroup(graph) != SUCCESS) {
  48. GELOGE(FAILED, "Handle while groups failed.");
  49. return FAILED;
  50. }
  51. GELOGD("NextIterationPass Leave");
  52. return SUCCESS;
  53. }
  54. ///
  55. /// @brief Group Enter node
  56. /// @param [in] enter_node
  57. /// @return Status
  58. ///
  59. Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) {
  60. OpDescPtr enter_desc = enter_node->GetOpDesc();
  61. GE_CHECK_NOTNULL(enter_desc);
  62. std::string frame_name;
  63. if (!ge::AttrUtils::GetStr(enter_desc, ENTER_ATTR_FRAME_NAME, frame_name) || frame_name.empty()) {
  64. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ENTER_ATTR_FRAME_NAME.c_str(),
  65. enter_desc->GetName().c_str(), enter_desc->GetType().c_str());
  66. GELOGE(FAILED, "Get attr ENTER_ATTR_FRAME_NAME failed, node: %s", enter_desc->GetName().c_str());
  67. return FAILED;
  68. }
  69. string batch_label;
  70. if (ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
  71. frame_name += batch_label;
  72. }
  73. auto iter = loop_group_map_.find(frame_name);
  74. if (iter == loop_group_map_.end()) {
  75. LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
  76. if (loop_group == nullptr) {
  77. REPORT_CALL_ERROR("E19999", "New LoopCondGroup failed");
  78. GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
  79. return FAILED;
  80. }
  81. loop_group->enter_nodes.emplace_back(enter_node);
  82. loop_group_map_[frame_name] = loop_group;
  83. } else {
  84. iter->second->enter_nodes.emplace_back(enter_node);
  85. }
  86. return SUCCESS;
  87. }
  88. ///
  89. /// @brief Find while groups
  90. /// @return Status
  91. ///
  92. Status NextIterationPass::FindWhileGroups() {
  93. for (const auto &loop_group_iter : loop_group_map_) {
  94. const std::string &frame_name = loop_group_iter.first;
  95. for (const auto &enter_node : loop_group_iter.second->enter_nodes) {
  96. for (const auto &out_node : enter_node->GetOutAllNodes()) {
  97. std::string type;
  98. GE_CHK_STATUS_RET(GetOriginalType(out_node, type), "Get node type failed.");
  99. if ((type != MERGE) && (type != REFMERGE)) {
  100. continue;
  101. }
  102. NodePtr next_node = nullptr;
  103. if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) {
  104. GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s", frame_name.c_str());
  105. return INTERNAL_ERROR;
  106. }
  107. loop_group_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node));
  108. NodePtr switch_node = nullptr;
  109. if (FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) {
  110. GELOGE(INTERNAL_ERROR, "Get Switch node failed, frame_name: %s.", frame_name.c_str());
  111. return INTERNAL_ERROR;
  112. }
  113. if (switch_node == nullptr) {
  114. continue;
  115. }
  116. if (!AttrUtils::SetInt(switch_node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_TYPE, kLoopType)) {
  117. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_STREAM_SWITCH_TYPE.c_str(),
  118. switch_node->GetName().c_str(), switch_node->GetType().c_str());
  119. GELOGE(INTERNAL_ERROR, "set int failed");
  120. return INTERNAL_ERROR;
  121. }
  122. NodePtr loop_cond = nullptr;
  123. if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) {
  124. GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str());
  125. return INTERNAL_ERROR;
  126. }
  127. loop_group_iter.second->switch_nodes.emplace_back(switch_node);
  128. if (loop_group_iter.second->loop_cond == nullptr) {
  129. loop_group_iter.second->loop_cond = loop_cond;
  130. } else if (loop_group_iter.second->loop_cond != loop_cond) {
  131. REPORT_INNER_ERROR("E19999", "Multi LoopCond nodes exist, frame_name:%s, check invalid", frame_name.c_str());
  132. GELOGE(FAILED, "Multi LoopCond nodes exist, frame_name: %s.", frame_name.c_str());
  133. return FAILED;
  134. }
  135. }
  136. }
  137. }
  138. return SUCCESS;
  139. }
  140. ///
  141. /// @brief Verify if valid
  142. /// @return bool
  143. ///
  144. bool NextIterationPass::VerifyWhileGroup() {
  145. // map<frame_name, LoopCondGroup>
  146. for (const auto &loop_group_iter : loop_group_map_) {
  147. const std::string &frame_name = loop_group_iter.first;
  148. if (frame_name.empty()) {
  149. REPORT_INNER_ERROR("E19999", "Verify while group failed, frame_name is empty");
  150. GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty.");
  151. return false;
  152. }
  153. if (loop_group_iter.second->loop_cond == nullptr) {
  154. REPORT_INNER_ERROR("E19999", "Verify while group failed, LoopCond is null, frame_name:%s.", frame_name.c_str());
  155. GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str());
  156. return false;
  157. }
  158. for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) {
  159. if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) {
  160. REPORT_INNER_ERROR("E19999", "Verify while group failed, merge_node/next_node is null, frame_name:%s.",
  161. frame_name.c_str());
  162. GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.",
  163. frame_name.c_str());
  164. return false;
  165. }
  166. // Mark loop as unknown shape If any merge has unknown shape output.
  167. const auto &op_desc = pair_iter.first->GetOpDesc();
  168. if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0))) {
  169. loop_group_iter.second->is_unknown_shape = true; // under check loop, cannot break.
  170. }
  171. }
  172. }
  173. return true;
  174. }
  175. ///
  176. /// @brief Handle while group
  177. /// @param [in] graph
  178. /// @return Status
  179. ///
  180. Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
  181. for (const auto &loop_cond_iter : loop_group_map_) {
  182. const LoopCondGroup &loop_group = *loop_cond_iter.second;
  183. const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName();
  184. GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());
  185. // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge
  186. NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE);
  187. NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE);
  188. if ((enter_active == nullptr) || (next_active == nullptr)) {
  189. GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str());
  190. return INTERNAL_ERROR;
  191. }
  192. for (const auto &enter_node : loop_cond_iter.second->enter_nodes) {
  193. // Enter --> Active
  194. if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
  195. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  196. enter_node->GetName().c_str(), enter_node->GetType().c_str(),
  197. enter_active->GetName().c_str(), enter_active->GetType().c_str());
  198. GELOGE(INTERNAL_ERROR, "Add control edge from %s to %s failed.", enter_node->GetName().c_str(),
  199. enter_active->GetName().c_str());
  200. return INTERNAL_ERROR;
  201. }
  202. MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape);
  203. }
  204. for (const auto &pair : loop_cond_iter.second->merge_next_pairs) {
  205. NodePtr merge_node = pair.first;
  206. NodePtr next_node = pair.second;
  207. // Active --> Merge
  208. if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
  209. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  210. enter_active->GetName().c_str(), enter_active->GetType().c_str(),
  211. merge_node->GetName().c_str(), merge_node->GetType().c_str());
  212. GELOGE(INTERNAL_ERROR, "Add control edge failed.");
  213. return INTERNAL_ERROR;
  214. }
  215. // NextIteration --> Active
  216. if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
  217. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  218. next_node->GetName().c_str(), next_node->GetType().c_str(),
  219. next_active->GetName().c_str(), next_active->GetType().c_str());
  220. GELOGE(INTERNAL_ERROR, "Add control edge failed.");
  221. return INTERNAL_ERROR;
  222. }
  223. // break link between NextIteration and Merge
  224. if (BreakNextIteration(next_node, merge_node) != SUCCESS) {
  225. GELOGE(INTERNAL_ERROR, "Break NextIteration failed");
  226. return INTERNAL_ERROR;
  227. }
  228. MarkForceUnknownShape(next_node, loop_group.is_unknown_shape);
  229. MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape);
  230. }
  231. if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
  232. (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) {
  233. GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed.");
  234. return INTERNAL_ERROR;
  235. }
  236. MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape);
  237. MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape);
  238. MarkForceUnknownShape(next_active, loop_group.is_unknown_shape);
  239. HandleSwitchExitNodes(loop_group);
  240. }
  241. return SUCCESS;
  242. }
  243. ///
  244. /// @brief Mark force unknown for Exit node
  245. /// @param [in] group of LoopCond
  246. /// @return void
  247. ///
  248. void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group) {
  249. if (!loop_group.is_unknown_shape) {
  250. return;
  251. }
  252. for (const auto &switch_node : loop_group.switch_nodes) {
  253. MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape);
  254. for (const auto &node : switch_node->GetOutDataNodes()) {
  255. std::string node_type;
  256. (void)GetOriginalType(node, node_type);
  257. if (node_type == EXIT || node_type == REFEXIT) {
  258. MarkForceUnknownShape(node, loop_group.is_unknown_shape);
  259. }
  260. }
  261. }
  262. }
  263. ///
  264. /// @brief Create Active Node
  265. /// @param [in] graph
  266. /// @param [in] name
  267. /// @return ge::NodePtr
  268. ///
  269. NodePtr NextIterationPass::CreateActiveNode(ComputeGraphPtr &graph, const std::string &name) {
  270. OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMACTIVE);
  271. if (op_desc == nullptr) {
  272. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  273. return nullptr;
  274. }
  275. GELOGI("Create StreamActive op:%s.", op_desc->GetName().c_str());
  276. NodePtr active_node = graph->AddNode(op_desc);
  277. if (active_node == nullptr) {
  278. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  279. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  280. GELOGE(INTERNAL_ERROR, "Create node[%s] failed.", name.c_str());
  281. return nullptr;
  282. }
  283. if (SetSwitchBranchNodeLabel(active_node, name) != SUCCESS) {
  284. REPORT_CALL_ERROR("E19999", "Set switch branch node label:%s to node:%s(%s) failed",
  285. name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str());
  286. GELOGE(INTERNAL_ERROR, "Set attr SWITCH_BRANCH_NODE_LABEL for node: %s failed.", active_node->GetName().c_str());
  287. return nullptr;
  288. }
  289. return active_node;
  290. }
  291. ///
  292. /// @brief Break NextIteration Link & add name to merge attr
  293. /// @param [in] next_node
  294. /// @param [in] merge_node
  295. /// @return Status
  296. ///
  297. Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr &merge_node) {
  298. if ((merge_node == nullptr) || (next_node == nullptr)) {
  299. GELOGE(PARAM_INVALID, "merge node or next node is null.");
  300. return PARAM_INVALID;
  301. }
  302. for (const auto &in_anchor : merge_node->GetAllInDataAnchors()) {
  303. OutDataAnchorPtr out_anchor = in_anchor->GetPeerOutAnchor();
  304. if ((out_anchor == nullptr) || (out_anchor->GetOwnerNode() != next_node)) {
  305. continue;
  306. }
  307. if (GraphUtils::RemoveEdge(out_anchor, in_anchor) != SUCCESS) {
  308. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  309. out_anchor->GetOwnerNode()->GetName().c_str(), out_anchor->GetOwnerNode()->GetType().c_str(),
  310. out_anchor->GetIdx(),
  311. merge_node->GetName().c_str(), merge_node->GetType().c_str(), in_anchor->GetIdx());
  312. GELOGE(INTERNAL_ERROR, "Remove data edge failed, %s->%s.", next_node->GetName().c_str(),
  313. merge_node->GetName().c_str());
  314. return INTERNAL_ERROR;
  315. }
  316. if (SetNextIteration(merge_node, next_node->GetName()) != SUCCESS) {
  317. REPORT_CALL_ERROR("E19999", "Set attr NEXT_ITERATION value:%s to node:%s(%s) failed",
  318. next_node->GetName().c_str(), merge_node->GetName().c_str(), merge_node->GetType().c_str());
  319. GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str());
  320. return INTERNAL_ERROR;
  321. }
  322. }
  323. return SUCCESS;
  324. }
  325. ///
  326. /// @brief find target node
  327. /// @param [in] node
  328. /// @param [in] target_type
  329. /// @param [in] is_input
  330. /// @param [out] target_node
  331. /// @return Status
  332. ///
  333. Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input,
  334. NodePtr &target_node) {
  335. if (node == nullptr) {
  336. REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
  337. GELOGE(PARAM_INVALID, "node is null.");
  338. return PARAM_INVALID;
  339. }
  340. std::vector<NodePtr> nodes;
  341. if (is_input) {
  342. for (const auto &tmp_node : node->GetInDataNodes()) {
  343. nodes.emplace_back(tmp_node);
  344. }
  345. } else {
  346. for (const auto &tmp_node : node->GetOutDataNodes()) {
  347. nodes.emplace_back(tmp_node);
  348. }
  349. }
  350. for (const auto &tmp_node : nodes) {
  351. std::string type;
  352. GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed.");
  353. if ((target_type == LOOPCOND) && (type == target_type)) {
  354. target_node = tmp_node;
  355. break;
  356. } else if ((type == target_type) || (type == "Ref" + target_type)) {
  357. target_node = tmp_node;
  358. break;
  359. }
  360. }
  361. if ((target_type != SWITCH) && (target_node == nullptr)) {
  362. REPORT_INNER_ERROR("E19999", "Find target_type:%s node around node:%s(%s) failed",
  363. target_type.c_str(), node->GetName().c_str(), node->GetType().c_str());
  364. GELOGE(INTERNAL_ERROR, "Find node %s failed.", target_type.c_str());
  365. return INTERNAL_ERROR;
  366. }
  367. return SUCCESS;
  368. }
  369. ///
  370. /// @brief Clear Status, used for subgraph pass
  371. /// @return SUCCESS
  372. ///
  373. Status NextIterationPass::ClearStatus() {
  374. loop_group_map_.clear();
  375. return SUCCESS;
  376. }
  377. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：