You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

omg_util.cc 8.5 kB

5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "graph/common/omg_util.h"

#include <algorithm>
#include <cstdint>
#include <limits>

#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/tensor_utils.h"
#include "common/math/math_util.h"
  23. namespace ge {
  24. ///
  25. /// @brief get the Original Type of FrameworkOp
  26. /// @param [in] node
  27. /// @param [out] type
  28. /// @return Status
  29. ///
  30. Status GetOriginalType(const ge::NodePtr &node, string &type) {
  31. GE_CHECK_NOTNULL(node);
  32. type = node->GetType();
  33. GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS);
  34. GE_CHECK_NOTNULL(node->GetOpDesc());
  35. bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  36. if (!ret) {
  37. REPORT_INNER_ERROR("E19999", "Get Attr:%s fail for op:%s(%s)", ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(),
  38. node->GetName().c_str(), node->GetType().c_str());
  39. GELOGE(INTERNAL_ERROR, "Get FrameWorkOp original type [%s]", type.c_str());
  40. return INTERNAL_ERROR;
  41. }
  42. GELOGD("Get FrameWorkOp original type [%s]", type.c_str());
  43. return SUCCESS;
  44. }
  45. ///
  46. /// @brief set op stream_label
  47. /// @param [in] node
  48. /// @param [in] label
  49. /// @return Status
  50. ///
  51. Status SetStreamLabel(const ge::NodePtr &node, const std::string &label) {
  52. GE_CHECK_NOTNULL(node);
  53. OpDescPtr tmp_desc = node->GetOpDesc();
  54. GE_CHECK_NOTNULL(tmp_desc);
  55. if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_STREAM_LABEL, label)) {
  56. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_STREAM_LABEL.c_str(),
  57. node->GetName().c_str(), node->GetType().c_str());
  58. GELOGE(FAILED, "Op: %s set ATTR_NAME_STREAM_LABEL failed", node->GetName().c_str());
  59. return FAILED;
  60. }
  61. return SUCCESS;
  62. }
  63. ///
  64. /// @brief set op cycle_event flag
  65. /// @param [in] node
  66. /// @return Status
  67. ///
  68. Status SetCycleEvent(const ge::NodePtr &node) {
  69. GE_CHECK_NOTNULL(node);
  70. OpDescPtr tmp_desc = node->GetOpDesc();
  71. GE_CHECK_NOTNULL(tmp_desc);
  72. if (!AttrUtils::SetBool(tmp_desc, ge::ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, true)) {
  73. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_STREAM_CYCLE_EVENT_FLAG.c_str(),
  74. node->GetName().c_str(), node->GetType().c_str());
  75. GELOGE(FAILED, "Op: %s set ATTR_NAME_STREAM_CYCLE_EVENT_FLAG failed", node->GetName().c_str());
  76. return FAILED;
  77. }
  78. return SUCCESS;
  79. }
  80. ///
  81. /// @brief set op active_label_list
  82. /// @param [in] node
  83. /// @param [in] active_label_list
  84. /// @return Status
  85. ///
  86. Status SetActiveLabelList(const ge::NodePtr &node, const std::vector<std::string> &active_label_list) {
  87. GE_CHECK_NOTNULL(node);
  88. OpDescPtr tmp_desc = node->GetOpDesc();
  89. GE_CHECK_NOTNULL(tmp_desc);
  90. if (!AttrUtils::SetListStr(tmp_desc, ge::ATTR_NAME_ACTIVE_LABEL_LIST, active_label_list)) {
  91. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_ACTIVE_LABEL_LIST.c_str(),
  92. node->GetName().c_str(), node->GetType().c_str());
  93. GELOGE(FAILED, "Op: %s set ATTR_NAME_ACTIVE_LABEL_LIST failed", node->GetName().c_str());
  94. return FAILED;
  95. }
  96. return SUCCESS;
  97. }
  98. ///
  99. /// @brief set op branch_label
  100. /// @param [in] node
  101. /// @param [in] branch_label
  102. /// @return Status
  103. ///
  104. Status SetSwitchBranchNodeLabel(const ge::NodePtr &node, const std::string &branch_label) {
  105. GE_CHECK_NOTNULL(node);
  106. OpDescPtr tmp_desc = node->GetOpDesc();
  107. GE_CHECK_NOTNULL(tmp_desc);
  108. if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_SWITCH_BRANCH_NODE_LABEL, branch_label)) {
  109. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_SWITCH_BRANCH_NODE_LABEL.c_str(),
  110. node->GetName().c_str(), node->GetType().c_str());
  111. GELOGE(FAILED, "Op: %s set ATTR_NAME_SWITCH_BRANCH_NODE_LABEL failed", node->GetName().c_str());
  112. return FAILED;
  113. }
  114. return SUCCESS;
  115. }
  116. ///
  117. /// @brief set op true_branch flag
  118. /// @param [in] node
  119. /// @param [in] value
  120. /// @return Status
  121. ///
  122. Status SetSwitchTrueBranchFlag(const ge::NodePtr &node, bool value) {
  123. GE_CHECK_NOTNULL(node);
  124. OpDescPtr tmp_desc = node->GetOpDesc();
  125. GE_CHECK_NOTNULL(tmp_desc);
  126. if (!AttrUtils::SetBool(tmp_desc, ge::ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value)) {
  127. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG.c_str(),
  128. node->GetName().c_str(), node->GetType().c_str());
  129. GELOGE(FAILED, "Op: %s set ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG failed", node->GetName().c_str());
  130. return FAILED;
  131. }
  132. return SUCCESS;
  133. }
  134. ///
  135. /// @brief set op original name
  136. /// @param [in] node
  137. /// @param [in] orig_name
  138. /// @return Status
  139. ///
  140. Status SetOriginalNodeName(const ge::NodePtr &node, const std::string &orig_name) {
  141. GE_CHECK_NOTNULL(node);
  142. OpDescPtr tmp_desc = node->GetOpDesc();
  143. GE_CHECK_NOTNULL(tmp_desc);
  144. if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_ORIG_NODE_NAME, orig_name)) {
  145. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_ORIG_NODE_NAME.c_str(),
  146. node->GetName().c_str(), node->GetType().c_str());
  147. GELOGE(FAILED, "Op: %s set ATTR_NAME_ORIG_NODE_NAME failed", node->GetName().c_str());
  148. return FAILED;
  149. }
  150. return SUCCESS;
  151. }
  152. ///
  153. /// @brief set op cyclic_dependence flag
  154. /// @param [in] node
  155. /// @return Status
  156. ///
  157. Status SetCyclicDependenceFlag(const ge::NodePtr &node) {
  158. GE_CHECK_NOTNULL(node);
  159. OpDescPtr tmp_desc = node->GetOpDesc();
  160. GE_CHECK_NOTNULL(tmp_desc);
  161. if (!AttrUtils::SetBool(tmp_desc, ge::ATTR_NAME_CYCLIC_DEPENDENCE_FLAG, true)) {
  162. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CYCLIC_DEPENDENCE_FLAG.c_str(),
  163. node->GetName().c_str(), node->GetType().c_str());
  164. GELOGE(FAILED, "Op: %s set ATTR_NAME_CYCLIC_DEPENDENCE_FLAG failed", node->GetName().c_str());
  165. return FAILED;
  166. }
  167. return SUCCESS;
  168. }
  169. ///
  170. /// @brief set op next_iteration name
  171. /// @param [in] node
  172. /// @param [in] next
  173. /// @return Status
  174. ///
  175. Status SetNextIteration(const ge::NodePtr &node, const std::string &next) {
  176. GE_CHECK_NOTNULL(node);
  177. OpDescPtr tmp_desc = node->GetOpDesc();
  178. GE_CHECK_NOTNULL(tmp_desc);
  179. if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_NEXT_ITERATION, next)) {
  180. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(),
  181. node->GetName().c_str(), node->GetType().c_str());
  182. GELOGE(FAILED, "Op: %s set ATTR_NAME_NEXT_ITERATION failed", node->GetName().c_str());
  183. return FAILED;
  184. }
  185. return SUCCESS;
  186. }
  187. ///
  188. /// @brief Align the memory
  189. /// @param [in/out] memory size
  190. /// @param [in] alinment
  191. /// @return void
  192. ///
  193. void AlignMemSize(int64_t &mem_size, int64_t align_size) {
  194. if (mem_size <= 0) {
  195. return;
  196. }
  197. mem_size = (mem_size + align_size - 1) / align_size * align_size;
  198. }
  199. ///
  200. /// @brief Get memory size from tensor desc
  201. /// @param [in] node
  202. /// @param [out] memory size
  203. /// @return Status
  204. ///
  205. Status GetMemorySize(const NodePtr &node, int64_t &output_size) {
  206. GE_CHECK_NOTNULL(node->GetOpDesc());
  207. auto output_op_desc = node->GetOpDesc()->GetOutputDescPtr(kBufferPoolNodeOutIndex);
  208. GE_CHECK_NOTNULL(output_op_desc);
  209. int64_t size = 0;
  210. auto ret = ge::TensorUtils::GetSize(*output_op_desc, size);
  211. if (ret != ge::GRAPH_SUCCESS) {
  212. GELOGE(INTERNAL_ERROR, "[Get][Size]Node:%s.", node->GetName().c_str());
  213. REPORT_INNER_ERROR("E19999", "Failed to get output size, node:%s.", node->GetName().c_str());
  214. return INTERNAL_ERROR;
  215. }
  216. FMK_INT64_ADDCHECK(size, kBufferPoolMemAlignSize);
  217. AlignMemSize(size, kBufferPoolMemAlignSize);
  218. // The HCOM operator requires an additional 512 bytes before and after
  219. FMK_INT64_ADDCHECK(size, (kBufferPoolMemAlignSize + kBufferPoolMemAlignSize));
  220. output_size = kBufferPoolMemAlignSize + size + kBufferPoolMemAlignSize;
  221. return SUCCESS;
  222. }
  223. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：