You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

assign_pass.cc 10 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/assign_pass.h"
  17. #include "framework/common/debug/log.h"
  18. #include "graph/utils/graph_utils.h"
  19. #include "graph/debug/ge_attr_define.h"
  20. namespace {
  21. const uint32_t kValidInputNodeOutputNum = 1;
  22. const int32_t kAssignRefInputIndex = 0;
  23. const int32_t kAssignValueInputIndex = 1;
  24. }
  25. namespace ge {
  26. #if (ENABLE_OPEN_SRC != True)
  27. Status AssignPass::Run(NodePtr &node) {
  28. GELOGD("AssignPass running");
  29. if (TransformAttr(node) != SUCCESS) {
  30. GELOGE(FAILED, "Transform assign_var_name attr failed, node=%s", node->GetName().c_str());
  31. return FAILED;
  32. }
  33. if (node->GetType() == ASSIGN) {
  34. if (OptimizedAssignNode(node) != SUCCESS) {
  35. GELOGE(FAILED, "Optimize for assign_node %s failed", node->GetName().c_str());
  36. return FAILED;
  37. }
  38. }
  39. GELOGD("AssignPass success");
  40. return SUCCESS;
  41. }
  42. ///
  43. /// @brief Optimize for assign_node
  44. /// @param [in] assign_node
  45. /// @return Status
  46. ///
  47. Status AssignPass::OptimizedAssignNode(NodePtr &assign_node) {
  48. const auto &ref_in_anchor = assign_node->GetInDataAnchor(kAssignRefInputIndex);
  49. const auto &value_in_anchor = assign_node->GetInDataAnchor(kAssignValueInputIndex);
  50. if ((ref_in_anchor == nullptr) || (value_in_anchor == nullptr)) {
  51. GELOGE(FAILED, "In data anchor is null, node:%s", assign_node->GetName().c_str());
  52. return FAILED;
  53. }
  54. const auto &ref_peer_anchor = ref_in_anchor->GetPeerOutAnchor();
  55. const auto &value_peer_anchor = value_in_anchor->GetPeerOutAnchor();
  56. if ((ref_peer_anchor == nullptr) || (value_peer_anchor == nullptr)) {
  57. GELOGE(FAILED, "Peer data anchor is null, node:%s", assign_node->GetName().c_str());
  58. return FAILED;
  59. }
  60. if (IsCondMatch(assign_node, ref_peer_anchor, value_peer_anchor)) {
  61. ///
  62. /// variable not-const not-const
  63. /// \ / |
  64. /// \ / |
  65. /// Assign ----> variable
  66. /// | |
  67. /// | |
  68. /// node node
  69. ///
  70. GELOGD("Optimization for assign_node %s start", assign_node->GetName().c_str());
  71. if (IsolateAndDeleteNode(assign_node, {kAssignRefInputIndex}) != SUCCESS) {
  72. GELOGE(FAILED, "Isolate and delete assign_node %s failed.", assign_node->GetName().c_str());
  73. return FAILED;
  74. }
  75. AddNodeDeleted(assign_node);
  76. const auto &ref_input = ref_peer_anchor->GetOwnerNode()->GetOpDesc();
  77. const auto &value_input = value_peer_anchor->GetOwnerNode()->GetOpDesc();
  78. if ((ref_input == nullptr) || (value_input == nullptr)) {
  79. GELOGE(FAILED, "value input is null");
  80. return FAILED;
  81. }
  82. // variable has and only has one input
  83. if (ref_input->UpdateInputDesc(0, value_input->GetOutputDesc(value_peer_anchor->GetIdx())) != GRAPH_SUCCESS) {
  84. GELOGE(FAILED, "Update input_desc for variable %s failed.", ref_input->GetName().c_str());
  85. return FAILED;
  86. }
  87. if (GraphUtils::AddEdge(value_peer_anchor, ref_peer_anchor->GetOwnerNode()->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
  88. GELOGE(FAILED, "Add data edge %s->%s failed", value_input->GetName().c_str(), ref_input->GetName().c_str());
  89. return FAILED;
  90. }
  91. GELOGD("add attr ASSIGN_VAR_NAME on node %s, var_name=%s",
  92. value_input->GetName().c_str(), ref_input->GetName().c_str());
  93. if (!AttrUtils::SetStr(value_input->MutableOutputDesc(value_peer_anchor->GetIdx()), ASSIGN_VAR_NAME,
  94. ref_input->GetName())) {
  95. GELOGE(FAILED, "Set attr ASSIGN_VAR_NAME failed.");
  96. return FAILED;
  97. }
  98. auto value_node = value_peer_anchor->GetOwnerNode();
  99. AddRePassNode(value_node);
  100. }
  101. return SUCCESS;
  102. }
  103. ///
  104. /// @brief Transform assign_var_name attr
  105. /// @param [in] node
  106. /// @return Status
  107. ///
  108. Status AssignPass::TransformAttr(NodePtr &node) {
  109. GE_CHECK_NOTNULL(node->GetOpDesc());
  110. for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDesc()) {
  111. int32_t inplace_input_idx = -1;
  112. std::string assign_var_name;
  113. if (AttrUtils::GetInt(output_desc, INPLACE_SUPPORT_INPUT_INDEX, inplace_input_idx) &&
  114. AttrUtils::GetStr(output_desc, ASSIGN_VAR_NAME, assign_var_name)) {
  115. GELOGD("Transform attr ASSIGN_VAR_NAME on node %s, assign_var_name=%s, inplace_input_idx=%d, ",
  116. node->GetName().c_str(), assign_var_name.c_str(), inplace_input_idx);
  117. const auto &in_data_anchor = node->GetInDataAnchor(inplace_input_idx);
  118. GE_CHECK_NOTNULL(in_data_anchor);
  119. const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
  120. GE_CHECK_NOTNULL(peer_data_anchor);
  121. auto in_node = peer_data_anchor->GetOwnerNode();
  122. GE_CHECK_NOTNULL(in_node->GetOpDesc());
  123. GELOGD("add attr ASSIGN_VAR_NAME on node %s, var_name=%s", in_node->GetName().c_str(), assign_var_name.c_str());
  124. if (!AttrUtils::SetStr(in_node->GetOpDesc()->MutableOutputDesc(peer_data_anchor->GetIdx()),
  125. ASSIGN_VAR_NAME, assign_var_name)) {
  126. GELOGE(FAILED, "Set attr ASSIGN_VAR_NAME failed.");
  127. return FAILED;
  128. }
  129. AddRePassNode(in_node);
  130. }
  131. }
  132. return SUCCESS;
  133. }
  134. #else
  135. Status AssignPass::Run(NodePtr &node) {
  136. GELOGD("AssignPass running");
  137. if (node->GetType() != ASSIGN) {
  138. GELOGD("No need run AssignPass on [%s, %s].", node->GetName().c_str(), node->GetType().c_str());
  139. return SUCCESS;
  140. }
  141. const auto &ref_in_anchor = node->GetInDataAnchor(kAssignRefInputIndex);
  142. const auto &value_in_anchor = node->GetInDataAnchor(kAssignValueInputIndex);
  143. if ((ref_in_anchor == nullptr) || (value_in_anchor == nullptr)) {
  144. GELOGE(FAILED, "In data anchor is null, node:%s", node->GetName().c_str());
  145. return FAILED;
  146. }
  147. const auto &ref_peer_anchor = ref_in_anchor->GetPeerOutAnchor();
  148. const auto &value_peer_anchor = value_in_anchor->GetPeerOutAnchor();
  149. if ((ref_peer_anchor == nullptr) || (value_peer_anchor == nullptr)) {
  150. GELOGE(FAILED, "Peer data anchor is null, node:%s", node->GetName().c_str());
  151. return FAILED;
  152. }
  153. if (IsCondMatch(node, ref_peer_anchor, value_peer_anchor)) {
  154. ///
  155. /// variable not-const not-const
  156. /// \ / |
  157. /// \ / |
  158. /// Assign ----> variable
  159. /// | |
  160. /// | |
  161. /// node node
  162. ///
  163. GELOGI("Optimization for assign_node %s start", node->GetName().c_str());
  164. if (IsolateAndDeleteNode(node, {kAssignRefInputIndex}) != SUCCESS) {
  165. GELOGE(FAILED, "Isolate and delete assign_node %s failed.", node->GetName().c_str());
  166. return FAILED;
  167. }
  168. AddNodeDeleted(node);
  169. const auto &ref_input = ref_peer_anchor->GetOwnerNode()->GetOpDesc();
  170. const auto &value_input = value_peer_anchor->GetOwnerNode()->GetOpDesc();
  171. if ((ref_input == nullptr) || (value_input == nullptr)) {
  172. GELOGE(FAILED, "value input is null");
  173. return FAILED;
  174. }
  175. if (!AttrUtils::SetStr(value_input->MutableOutputDesc(value_peer_anchor->GetIdx()), ASSIGN_VAR_NAME,
  176. ref_input->GetName())) {
  177. GELOGE(FAILED, "Set attr ASSIGN_VAR_NAME failed.");
  178. return FAILED;
  179. }
  180. // variable has and only has one input
  181. if (ref_input->UpdateInputDesc(0, value_input->GetOutputDesc(value_peer_anchor->GetIdx())) != GRAPH_SUCCESS) {
  182. GELOGE(FAILED, "Update input_desc for variable %s failed.", ref_input->GetName().c_str());
  183. return FAILED;
  184. }
  185. if (GraphUtils::AddEdge(value_peer_anchor, ref_peer_anchor->GetOwnerNode()->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
  186. GELOGE(FAILED, "Add data edge %s->%s failed", value_input->GetName().c_str(), ref_input->GetName().c_str());
  187. return FAILED;
  188. }
  189. }
  190. GELOGD("AssignPass success");
  191. return SUCCESS;
  192. }
  193. #endif
  194. ///
  195. /// @brief Check if need optimize for assign_node
  196. /// @param [in] assign_node
  197. /// @param [in] peer_data_anchor for ref_input of assign_node
  198. /// @param [in] peer_data_anchor for value_input of assign_node
  199. /// @return Status
  200. ///
  201. bool AssignPass::IsCondMatch(const NodePtr &node, const OutDataAnchorPtr &ref_peer_anchor,
  202. const OutDataAnchorPtr &value_peer_anchor) {
  203. GELOGD("Check if assign_node %s match optimization condition, ref_input: %s, value_input: %s",
  204. node->GetName().c_str(), ref_peer_anchor->GetOwnerNode()->GetName().c_str(),
  205. value_peer_anchor->GetOwnerNode()->GetName().c_str());
  206. const std::string &value_type = value_peer_anchor->GetOwnerNode()->GetType();
  207. if ((value_type == CONSTANTOP) || (value_type == CONSTANT)) {
  208. GELOGD("value input is const");
  209. return false;
  210. }
  211. const std::string &ref_type = ref_peer_anchor->GetOwnerNode()->GetType();
  212. if ((ref_type != VARIABLE) && (ref_type != VARIABLEV2)) {
  213. GELOGD("ref input is not var");
  214. return false;
  215. }
  216. if (!ref_peer_anchor->GetOwnerNode()->GetInDataNodes().empty()) {
  217. GELOGD("ref input has data input");
  218. return false;
  219. }
  220. if ((ref_peer_anchor->GetPeerInDataNodesSize() != kValidInputNodeOutputNum) ||
  221. (value_peer_anchor->GetPeerInDataNodesSize() != kValidInputNodeOutputNum)) {
  222. GELOGD("ref / value input has other output(s)");
  223. return false;
  224. }
  225. GELOGD("Optimization condition matches, assign_node: %s", node->GetName().c_str());
  226. return true;
  227. }
  228. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：