You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

branch_logical_remove_pass.cc 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/branch_logical_remove_pass.h"
  17. #include "graph/utils/node_utils.h"
  18. namespace ge {
  19. Status BranchLogicalRemovePass::Run(ComputeGraphPtr graph) {
  20. BranchExecCondCalculator calculator(graph);
  21. if (calculator.Calculate() != GRAPH_SUCCESS) {
  22. GELOGE(FAILED, "Calculate branch exec cond for removing logical dead branch failed.");
  23. return FAILED;
  24. }
  25. const auto &node_exec_cond = calculator.GetBranchExecCond();
  26. if (node_exec_cond.empty()) {
  27. GELOGI("No branch in graph %s, skip.", graph->GetName().c_str());
  28. return SUCCESS;
  29. }
  30. for (const auto &node : graph->GetDirectNode()) {
  31. const std::string &type = NodeUtils::GetNodeType(node);
  32. if ((type == SWITCH) || (type == REFSWITCH)) {
  33. if (RemoveRedundantSwitch(node_exec_cond, node) != SUCCESS) {
  34. GELOGE(FAILED, "Remove redundant switch %s failed.", node->GetName().c_str());
  35. return FAILED;
  36. }
  37. } else if ((type == MERGE) || (type == REFMERGE)) {
  38. if (RemoveDeadInputForMerge(node_exec_cond, node) != SUCCESS) {
  39. GELOGE(FAILED, "Remove dead input for merge %s failed.", node->GetName().c_str());
  40. return FAILED;
  41. }
  42. }
  43. }
  44. return SUCCESS;
  45. }
  46. Status BranchLogicalRemovePass::RemoveRedundantSwitch(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond,
  47. const NodePtr &switch_node) {
  48. const auto &iter = node_exec_cond.find(switch_node);
  49. if (iter == node_exec_cond.end()) {
  50. GELOGE(FAILED, "Find for exec cond for node %s.", switch_node->GetName().c_str());
  51. return FAILED;
  52. }
  53. if (!iter->second.IsValid()) {
  54. return SUCCESS;
  55. }
  56. for (const auto &pair : switch_node->GetOutDataNodesAndAnchors()) {
  57. const auto &out_node = pair.first;
  58. const auto &out_iter = node_exec_cond.find(out_node);
  59. if (out_iter == node_exec_cond.end()) {
  60. GELOGE(FAILED, "Find for exec cond for node %s.", out_node->GetName().c_str());
  61. return FAILED;
  62. }
  63. if (!out_iter->second.IsValid()) {
  64. GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(),
  65. pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
  66. if (GraphUtils::RemoveEdge(pair.second->GetPeerOutAnchor(), pair.second) != GRAPH_SUCCESS) {
  67. GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(),
  68. pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
  69. return FAILED;
  70. }
  71. } else if (iter->second.String() == out_iter->second.String()) {
  72. GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(),
  73. pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
  74. if (GraphUtils::RemoveEdge(pair.second->GetPeerOutAnchor(), pair.second) != GRAPH_SUCCESS) {
  75. GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(),
  76. pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
  77. return FAILED;
  78. }
  79. const auto data_input_anchor = switch_node->GetInDataAnchor(SWITCH_DATA_INPUT);
  80. GE_CHECK_NOTNULL(data_input_anchor);
  81. const auto &peer_out_anchor = data_input_anchor->GetPeerOutAnchor();
  82. GE_CHECK_NOTNULL(peer_out_anchor);
  83. GELOGI("Add data edge %s:%d->%s:%d.", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  84. pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
  85. if (GraphUtils::AddEdge(peer_out_anchor, pair.second) != GRAPH_SUCCESS) {
  86. GELOGE(FAILED, "Add data edge %s:%d->%s:%d failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  87. pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
  88. return FAILED;
  89. }
  90. }
  91. }
  92. return SUCCESS;
  93. }
  94. Status BranchLogicalRemovePass::RemoveDeadInputForMerge(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond,
  95. const NodePtr &merge_node) {
  96. const auto &iter = node_exec_cond.find(merge_node);
  97. if (iter == node_exec_cond.end()) {
  98. GELOGE(FAILED, "Find for exec cond for node %s.", merge_node->GetName().c_str());
  99. return FAILED;
  100. }
  101. if (!iter->second.IsValid()) {
  102. return SUCCESS;
  103. }
  104. for (const auto &in_data_anchor : merge_node->GetAllInDataAnchors()) {
  105. const auto &peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  106. if (peer_out_anchor == nullptr) { continue; }
  107. const auto &in_node = peer_out_anchor->GetOwnerNode();
  108. const auto &in_iter = node_exec_cond.find(in_node);
  109. if (in_iter == node_exec_cond.end()) {
  110. GELOGE(FAILED, "Find for exec cond for node %s.", in_node->GetName().c_str());
  111. return FAILED;
  112. }
  113. if (!in_iter->second.IsValid()) {
  114. GELOGI("Remove data edge %s:%d->%s:%d.", in_node->GetName().c_str(), peer_out_anchor->GetIdx(),
  115. merge_node->GetName().c_str(), in_data_anchor->GetIdx());
  116. if (GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor) != GRAPH_SUCCESS) {
  117. GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", in_node->GetName().c_str(), peer_out_anchor->GetIdx(),
  118. merge_node->GetName().c_str(), in_data_anchor->GetIdx());
  119. return FAILED;
  120. }
  121. }
  122. }
  123. return SUCCESS;
  124. }
  125. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：