You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

get_original_format_pass.cc 9.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/get_original_format_pass.h"
  17. #include <vector>
  18. #include "common/debug/log.h"
  19. #include "common/types.h"
  20. #include "common/util.h"
  21. #include "framework/common/debug/ge_log.h"
  22. #include "framework/omg/omg_inner_types.h"
  23. #include "graph/utils/attr_utils.h"
  24. #include "graph/utils/op_desc_utils.h"
  25. #include "graph/common/local_context.h"
  26. using domi::DOMI_TENSOR_NCHW;
  27. using domi::DOMI_TENSOR_NHWC;
  28. using domi::DOMI_TENSOR_RESERVED;
  29. using domi::FAILED;
  30. using domi::PARAM_INVALID;
  31. using domi::SUCCESS;
  32. namespace ge {
  33. Status GetOriginalFormatPass::Run(ge::ComputeGraphPtr graph) {
  34. GE_CHECK_NOTNULL(graph);
  35. GE_RETURN_WITH_LOG_IF_ERROR(SetOriginalFormat(graph), "SetOriginalFormat failed");
  36. return SUCCESS;
  37. }
  38. Status GetOriginalFormatPass::SetOriginalFormat(const ge::ComputeGraphPtr &graph) {
  39. GE_CHECK_NOTNULL(graph);
  40. int64_t ori_format = 0;
  41. int64_t tmp_format = 0;
  42. for (auto &node_ptr : graph->GetDirectNode()) {
  43. GE_CHECK_NOTNULL(node_ptr);
  44. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(node_ptr->GetOpDesc(), ATTR_NAME_INFERRED_FORMAT, DOMI_TENSOR_RESERVED),
  45. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  46. ATTR_NAME_INFERRED_FORMAT.c_str(),
  47. node_ptr->GetName().c_str(), node_ptr->GetType().c_str());
  48. GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
  49. return FAILED);
  50. }
  51. for (auto &node_ptr : graph->GetDirectNode()) {
  52. GE_CHECK_NOTNULL(node_ptr);
  53. OpDescPtr desc_ptr = node_ptr->GetOpDesc();
  54. GE_CHECK_NOTNULL(desc_ptr);
  55. auto is_data = (desc_ptr->GetType() == DATA_TYPE || desc_ptr->GetType() == AIPP_DATA_TYPE);
  56. if (is_data) {
  57. GELOGI("Data node: %s,format :%d", node_ptr->GetName().c_str(), GetLocalOmgContext().format);
  58. ori_format = static_cast<int64_t>(GetLocalOmgContext().format);
  59. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_FORMAT, ori_format),
  60. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  61. ATTR_NAME_FORMAT.c_str(),
  62. desc_ptr->GetName().c_str(), desc_ptr->GetType().c_str());
  63. GELOGE(FAILED, "set ATTR_NAME_FORMAT failed");
  64. return FAILED);
  65. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_INFERRED_FORMAT, ori_format),
  66. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  67. ATTR_NAME_INFERRED_FORMAT.c_str(),
  68. desc_ptr->GetName().c_str(), desc_ptr->GetType().c_str());
  69. GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
  70. return FAILED);
  71. continue;
  72. }
  73. int32_t i = 0;
  74. bool continue_flag = false;
  75. bool ignore_pred_format = false;
  76. for (auto &bias_node_ptr : node_ptr->GetInDataNodes()) {
  77. GE_CHECK_NOTNULL(bias_node_ptr);
  78. OpDescPtr bias_op_ptr = bias_node_ptr->GetOpDesc();
  79. GE_CHECK_NOTNULL(bias_op_ptr);
  80. if (bias_op_ptr->GetType() == BIASADD) {
  81. ignore_pred_format = true;
  82. std::size_t tmp_size = ge::OpDescUtils::GetNonConstInputsSize(bias_node_ptr);
  83. GE_IF_BOOL_EXEC(tmp_size > 2 || tmp_size == 0,
  84. GELOGW("bias_node is node followed by %zu nodes, should be 1 or 2", tmp_size);
  85. continue_flag = true; break);
  86. OpDescPtr tmp_first_op_ptr = bias_node_ptr->GetInDataNodes().at(0)->GetOpDesc();
  87. GE_CHECK_NOTNULL(tmp_first_op_ptr);
  88. bias_op_ptr = tmp_first_op_ptr;
  89. // if biasadd have 2 input edges, format should be same
  90. if (tmp_size == 2) {
  91. int64_t first_input_format = 0;
  92. int64_t second_input_format = 0;
  93. OpDescPtr tmpSecondOpPtr = bias_node_ptr->GetInDataNodes().at(1)->GetOpDesc();
  94. GE_CHECK_NOTNULL(tmpSecondOpPtr);
  95. GE_IF_BOOL_EXEC(
  96. !AttrUtils::GetInt(tmp_first_op_ptr, ATTR_NAME_FORMAT, first_input_format), continue_flag = true; break);
  97. GE_IF_BOOL_EXEC(
  98. !AttrUtils::GetInt(tmpSecondOpPtr, ATTR_NAME_FORMAT, second_input_format), continue_flag = true; break);
  99. if (first_input_format != second_input_format) {
  100. GELOGW("biasadd node is followed two nodes with different format, get original format failed");
  101. continue_flag = true;
  102. break;
  103. }
  104. }
  105. }
  106. GE_IF_BOOL_EXEC(!AttrUtils::GetInt(bias_op_ptr, ATTR_NAME_FORMAT, tmp_format), continue_flag = true; break;);
  107. if (i == 0) {
  108. ori_format = tmp_format;
  109. }
  110. GE_IF_BOOL_EXEC(tmp_format != ori_format,
  111. GELOGW("node: %s , original format of src nodes must be same!", bias_node_ptr->GetName().c_str());
  112. continue_flag = true; break;);
  113. i++;
  114. }
  115. GE_IF_BOOL_EXEC(continue_flag, continue);
  116. OpDescPtr tmp_op_ptr = node_ptr->GetOpDesc();
  117. GE_CHECK_NOTNULL(tmp_op_ptr);
  118. if (IsFormatTranspose(tmp_op_ptr, static_cast<int32_t>(ori_format))) {
  119. ori_format = (ori_format == DOMI_TENSOR_NCHW) ? DOMI_TENSOR_NHWC : DOMI_TENSOR_NCHW;
  120. }
  121. if (ignore_pred_format) {
  122. GE_IF_BOOL_EXEC(!AttrUtils::SetBool(tmp_op_ptr, ATTR_NAME_IGNORE_PRED_FORMAT, true),
  123. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  124. ATTR_NAME_IGNORE_PRED_FORMAT.c_str(),
  125. tmp_op_ptr->GetName().c_str(), tmp_op_ptr->GetType().c_str());
  126. GELOGE(FAILED, "remove edge failed");
  127. return FAILED);
  128. }
  129. // Do not reset ATTR_NAME_FORMAT if it is set in the OpParser.
  130. if (!tmp_op_ptr->HasAttr(ATTR_NAME_FORMAT)) {
  131. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(tmp_op_ptr, ATTR_NAME_FORMAT, ori_format),
  132. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  133. ATTR_NAME_FORMAT.c_str(),
  134. tmp_op_ptr->GetName().c_str(), tmp_op_ptr->GetType().c_str());
  135. GELOGE(FAILED, "set ATTR_NAME_FORMAT failed");
  136. return FAILED);
  137. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(tmp_op_ptr, ATTR_NAME_INFERRED_FORMAT, ori_format),
  138. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  139. ATTR_NAME_INFERRED_FORMAT.c_str(),
  140. tmp_op_ptr->GetName().c_str(), tmp_op_ptr->GetType().c_str());
  141. GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
  142. return FAILED);
  143. } else {
  144. int64_t existingFormat = 0;
  145. GE_RETURN_WITH_LOG_IF_FALSE(AttrUtils::GetInt(tmp_op_ptr, ATTR_NAME_FORMAT, existingFormat),
  146. "Get existing_format attr failed");
  147. if (!AttrUtils::SetInt(tmp_op_ptr, ATTR_NAME_INFERRED_FORMAT, existingFormat)) {
  148. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  149. ATTR_NAME_INFERRED_FORMAT.c_str(),
  150. tmp_op_ptr->GetName().c_str(), tmp_op_ptr->GetType().c_str());
  151. GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
  152. return FAILED;
  153. }
  154. }
  155. }
  156. return SUCCESS;
  157. }
  158. bool GetOriginalFormatPass::IsFormatTranspose(const ge::OpDescPtr op_ptr, int32_t ori_format) {
  159. GE_CHK_BOOL_EXEC(op_ptr != nullptr, return false, "opdef is nullptr");
  160. if (op_ptr->GetType() == PERMUTE) {
  161. vector<int32_t> index_list;
  162. GE_IF_BOOL_EXEC(!AttrUtils::GetListInt(op_ptr, PERMUTE_ATTR_ORDER, index_list), return false);
  163. auto index_size = index_list.size();
  164. GE_IF_BOOL_EXEC(static_cast<int32_t>(index_size) != PERMUTE_ORDER_NUM, return false);
  165. int32_t perm_nchw[4] = {0, 2, 3, 1}; // 4 format nums, {0,2,3,1} NCHW -> NHWC
  166. int32_t perm_nhwc[4] = {0, 3, 1, 2}; // 4 format nums, {0,3,1,2} NHWC -> NCHW
  167. bool is_nchw = true;
  168. bool is_nhwc = true;
  169. for (size_t i = 0; i < index_size; ++i) {
  170. is_nchw = (perm_nchw[i] != index_list[i]) ? false : is_nchw;
  171. is_nhwc = (perm_nhwc[i] != index_list[i]) ? false : is_nhwc;
  172. }
  173. bool ret = (is_nchw && ori_format == DOMI_TENSOR_NCHW && !is_nhwc) ||
  174. (is_nhwc && ori_format == DOMI_TENSOR_NHWC && !is_nchw);
  175. return ret;
  176. }
  177. return false;
  178. }
  179. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户对此并无感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。