You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transop_symmetry_elimination_pass.cc 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/transop_symmetry_elimination_pass.h"
  17. #include "common/formats/utils/formats_trans_utils.h"
  18. #include "framework/common/debug/ge_log.h"
  19. #include "framework/common/util.h"
  20. #include "common/transop_util.h"
  21. #include "graph/debug/ge_attr_define.h"
  22. #include "graph/utils/graph_utils.h"
  23. #include "graph/utils/node_utils.h"
  24. #include "graph/utils/type_utils.h"
  25. #include "framework/common/types.h"
  26. namespace {
  27. const std::set<std::string> white_list_op{ge::TRANSPOSED, ge::RESHAPE, ge::REFORMAT, ge::CAST, ge::TRANSDATA};
  28. } // namespace
  29. namespace ge {
  30. Status TransOpSymmetryEliminationPass::Run(NodePtr &node) {
  31. GE_CHECK_NOTNULL(node);
  32. GE_CHECK_NOTNULL(node->GetOpDesc());
  33. if (white_list_op.find(node->GetType()) == white_list_op.end()) { return SUCCESS; }
  34. GELOGD("Symmetry Elimination Pass in.");
  35. for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
  36. GE_CHECK_NOTNULL(out_anchor);
  37. for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
  38. GE_CHECK_NOTNULL(peer_in_anchor);
  39. GE_CHECK_NOTNULL(peer_in_anchor->GetOwnerNode());
  40. GE_CHECK_NOTNULL(peer_in_anchor->GetOwnerNode()->GetOpDesc());
  41. if (!CheckCanBeEliminated(node, peer_in_anchor)) { continue; }
  42. auto dst_node = peer_in_anchor->GetOwnerNode();
  43. Status ret = EliminateTransOp(node, out_anchor, dst_node, peer_in_anchor);
  44. if (ret != SUCCESS) {
  45. // if eliminate failed ,it should't break precess, so give a warning here
  46. GELOGW("Eliminate %s and %s failed, ignore current pass.", node->GetName().c_str(),
  47. dst_node->GetName().c_str());
  48. return ret;
  49. }
  50. }
  51. }
  52. GELOGD("Symmetry Elimination Pass end.");
  53. return SUCCESS;
  54. }
  55. bool TransOpSymmetryEliminationPass::CheckCanBeEliminated(const ge::NodePtr &src_node,
  56. const InDataAnchorPtr &dst_in_anchor) {
  57. auto dst_node = dst_in_anchor->GetOwnerNode();
  58. if (src_node->GetType() != dst_node->GetType()) {
  59. GELOGD("Pre node %s type %s is not equal with node %s type %s. Ignore pass.", src_node->GetName().c_str(),
  60. src_node->GetType().c_str(), dst_node->GetName().c_str(), dst_node->GetType().c_str());
  61. return false;
  62. }
  63. if (dst_in_anchor->GetIdx() != TransOpUtil::GetTransOpDataIndex(src_node)) {
  64. GELOGD("Next node %s type %s input %d is not for transform. Ignore pass.", dst_node->GetName().c_str(),
  65. dst_node->GetType().c_str(), dst_in_anchor->GetIdx());
  66. return false;
  67. }
  68. if (src_node->GetType() == ge::RESHAPE) {
  69. GE_CHECK_NOTNULL(src_node->GetOpDesc());
  70. auto unknown_dims_num = GetUnknownDimsNum(src_node->GetOpDesc()->GetInputDesc(0));
  71. if (unknown_dims_num != 0 && (unknown_dims_num == UNKNOWN_DIM_NUM || unknown_dims_num > 1)) {
  72. GELOGD("Pre node %s is reshape op which input is dynamic shape and has more than one unknown dimension. "
  73. "Ignore pass.",
  74. src_node->GetName().c_str());
  75. return false;
  76. }
  77. } else if (src_node->GetType() == ge::TRANSPOSED) {
  78. if (!JudgeTransposeDBack2Raw(src_node, dst_node)) {
  79. GELOGD("Two Transpose op src node %s dst node %s will change the raw data. Ignore pass.",
  80. src_node->GetName().c_str(), dst_node->GetName().c_str());
  81. return false;
  82. }
  83. } else if (src_node->GetType() == ge::TRANSDATA) {
  84. auto unknown_dims_num = GetUnknownDimsNum(src_node->GetOpDesc()->GetInputDesc(0));
  85. if (unknown_dims_num == UNKNOWN_DIM_NUM) {
  86. GELOGD("Pre node %s is transdata op which input is dynamic shape and all dimension are unknown(-2). Ignore pass.",
  87. src_node->GetName().c_str());
  88. return false;
  89. }
  90. }
  91. return TransOpUtil::CheckPrecisionLoss(src_node) && DescAreSymmetry(src_node, dst_node);
  92. }
  93. bool TransOpSymmetryEliminationPass::DescAreSymmetry(const NodePtr &src_node, const NodePtr &dst_node) {
  94. const auto &src_input_desc = src_node->GetOpDesc()->MutableInputDesc(0);
  95. const auto &dst_output_desc = dst_node->GetOpDesc()->MutableOutputDesc(0);
  96. GE_CHECK_NOTNULL(src_input_desc);
  97. GE_CHECK_NOTNULL(dst_output_desc);
  98. const auto &src_input_dtype = src_input_desc->GetDataType();
  99. const auto &src_input_format = src_input_desc->GetFormat();
  100. const auto &src_input_shape = src_input_desc->GetShape().GetDims();
  101. const auto &dst_output_dtype = dst_output_desc->GetDataType();
  102. const auto &dst_output_format = dst_output_desc->GetFormat();
  103. const auto &dst_output_shape = dst_output_desc->GetShape().GetDims();
  104. bool is_symmetry = true;
  105. if (src_node->GetType() == CAST && dst_node->GetType() == CAST) {
  106. bool is_format_symmetry =
  107. (src_input_format == dst_output_format) || (dst_output_format == FORMAT_ND) || (src_input_format == FORMAT_ND);
  108. is_symmetry = (src_input_dtype == dst_output_dtype) && is_format_symmetry;
  109. } else {
  110. is_symmetry = (src_input_dtype == dst_output_dtype) && (src_input_shape == dst_output_shape)
  111. && (src_input_format == dst_output_format);
  112. }
  113. if (!is_symmetry) {
  114. GELOGD("Not satisfied symmetry. ignore pass.\n"
  115. "Src node %s input type: %s format: %s shape: %s, "
  116. "dst node %s output type: %s format: %s shape: %s. ",
  117. src_node->GetName().c_str(), TypeUtils::DataTypeToSerialString(src_input_dtype).c_str(),
  118. TypeUtils::FormatToSerialString(src_input_format).c_str(), formats::ShapeToString(src_input_shape).c_str(),
  119. dst_node->GetName().c_str(), TypeUtils::DataTypeToSerialString(dst_output_dtype).c_str(),
  120. TypeUtils::FormatToSerialString(dst_output_format).c_str(),
  121. formats::ShapeToString(dst_output_shape).c_str());
  122. }
  123. return is_symmetry;
  124. }
  125. int TransOpSymmetryEliminationPass::GetUnknownDimsNum(const GeTensorDesc& node_desc){
  126. //
  127. // unknown_dims_num != 0 , is dynamic shape
  128. // unknown_dims_num = UNKNOWN_DIM_NUM , all dims are unknown
  129. // unknown_dims_num = n , n > 0 , has n dims unknown
  130. //
  131. int unknown_dims_num = 0;
  132. auto ge_shape = node_desc.GetShape();
  133. for (const auto dim : ge_shape.GetDims()) {
  134. if (dim == UNKNOWN_DIM_NUM) { return UNKNOWN_DIM_NUM; }
  135. if (dim == UNKNOWN_DIM) { ++unknown_dims_num; }
  136. }
  137. return unknown_dims_num;
  138. }
  139. bool TransOpSymmetryEliminationPass::JudgeTransposeDBack2Raw(const NodePtr &src_node, const NodePtr &dst_node) {
  140. //
  141. // A transpose to C : A---->(perm_1)---->B---->(perm_2)---->C
  142. // we want to judge A is equal with C or not
  143. // suppose A = C then:
  144. // 1. B[i] = A[perm_1[i]]
  145. // 2. C[i] = B[perm_2[i]]
  146. // 3. combine 1 and 2 then: C[i] = A[perm_1[perm_2[i]]]
  147. // which we get through 3: i = perm_1[perm_2[i]]
  148. //
  149. vector<int64_t> src_node_perm;
  150. (void)AttrUtils::GetListInt(src_node->GetOpDesc(), ge::PERMUTE_ATTR_PERM, src_node_perm);
  151. vector<int64_t> dst_node_perm;
  152. (void)AttrUtils::GetListInt(dst_node->GetOpDesc(), ge::PERMUTE_ATTR_PERM, dst_node_perm);
  153. if (src_node_perm.size() != dst_node_perm.size()) { return false; }
  154. for (size_t src_index = 0; src_index < src_node_perm.size(); ++src_index) {
  155. if (dst_node_perm[src_index] >= static_cast<int64_t>(src_node_perm.size())) { return false; }
  156. if (static_cast<int64_t>(src_index) != src_node_perm[dst_node_perm[src_index]]) { return false; }
  157. }
  158. return true;
  159. }
  160. Status TransOpSymmetryEliminationPass::EliminateTransOp(NodePtr &src_node, const OutDataAnchorPtr &src_out_anchor,
  161. NodePtr &dst_node, const InDataAnchorPtr &dst_in_anchor) {
  162. // Two transform nodes can be offset like A->T1->T2->B
  163. // 1.Unlink T1->T2
  164. auto ret = src_out_anchor->Unlink(dst_in_anchor);
  165. if (ret != GRAPH_SUCCESS) {
  166. REPORT_CALL_ERROR("E19999",
  167. "Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed",
  168. src_out_anchor->GetOwnerNode()->GetName().c_str(),
  169. src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(),
  170. dst_in_anchor->GetOwnerNode()->GetName().c_str(),
  171. dst_in_anchor->GetOwnerNode()->GetType().c_str(), dst_in_anchor->GetIdx());
  172. GELOGE(FAILED, "[Unlink][DataAnchor] from %s(%s)(index:%d) to %s(%s)(index:%d) failed.",
  173. src_out_anchor->GetOwnerNode()->GetName().c_str(),
  174. src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(),
  175. dst_in_anchor->GetOwnerNode()->GetName().c_str(),
  176. dst_in_anchor->GetOwnerNode()->GetType().c_str(), dst_in_anchor->GetIdx());
  177. return ret;
  178. }
  179. // 2.Link A->T2
  180. auto data_idx = TransOpUtil::GetTransOpDataIndex(src_node);
  181. auto in_anchor = src_node->GetInDataAnchor(data_idx);
  182. GE_CHECK_NOTNULL(in_anchor);
  183. GE_CHECK_NOTNULL(in_anchor->GetPeerOutAnchor());
  184. auto pre_normal_node = in_anchor->GetPeerOutAnchor()->GetOwnerNode();
  185. ret = GraphUtils::AddEdge(in_anchor->GetPeerOutAnchor(), dst_in_anchor);
  186. if (ret != GRAPH_SUCCESS) {
  187. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  188. pre_normal_node->GetName().c_str(), pre_normal_node->GetType().c_str(),
  189. in_anchor->GetPeerOutAnchor()->GetIdx(),
  190. dst_in_anchor->GetOwnerNode()->GetName().c_str(),
  191. dst_in_anchor->GetOwnerNode()->GetType().c_str(), dst_in_anchor->GetIdx());
  192. GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  193. pre_normal_node->GetName().c_str(), pre_normal_node->GetType().c_str(),
  194. in_anchor->GetPeerOutAnchor()->GetIdx(),
  195. dst_in_anchor->GetOwnerNode()->GetName().c_str(),
  196. dst_in_anchor->GetOwnerNode()->GetType().c_str(), dst_in_anchor->GetIdx());
  197. return ret;
  198. }
  199. // 3.Copy in-control/data-in-control from T1->T2
  200. ret = GraphUtils::CopyInCtrlEdges(src_node, dst_node);
  201. if (ret != GRAPH_SUCCESS) {
  202. REPORT_CALL_ERROR("E19999", "Copy in control edge from node:%s(%s) to node:%s(%s) failed",
  203. src_node->GetName().c_str(), src_node->GetType().c_str(),
  204. dst_node->GetName().c_str(), dst_node->GetType().c_str());
  205. GELOGE(FAILED, "[Copy][InCtrlEdges] from node:%s(%s) to node:%s(%s) failed",
  206. src_node->GetName().c_str(), src_node->GetType().c_str(),
  207. dst_node->GetName().c_str(), dst_node->GetType().c_str());
  208. return ret;
  209. }
  210. // 4.Add control edge from T1 other input to T2, like reshape second input
  211. for (const auto &in_node : src_node->GetInDataNodes()) {
  212. if (in_node->GetName() == pre_normal_node->GetName()) { continue; }
  213. ret = GraphUtils::AddEdge(in_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
  214. if (ret != GRAPH_SUCCESS) {
  215. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  216. in_node->GetName().c_str(), in_node->GetType().c_str(),
  217. dst_node->GetName().c_str(), dst_node->GetType().c_str());
  218. GELOGE(FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  219. in_node->GetName().c_str(), in_node->GetType().c_str(),
  220. dst_node->GetName().c_str(), dst_node->GetType().c_str());
  221. return ret;
  222. }
  223. }
  224. // 5.IsolateAndDelete T2, A will link to B automatically, and all control edge will also relink.
  225. ret = IsolateAndDeleteNode(dst_node, {0});
  226. if (ret != GRAPH_SUCCESS) {
  227. REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed",
  228. dst_node->GetName().c_str(), dst_node->GetType().c_str());
  229. GELOGE(INTERNAL_ERROR, "[IsolateAndDelete][Node] failed, node name:%s, node type:%s ",
  230. dst_node->GetName().c_str(), dst_node->GetType().c_str());
  231. return ret;
  232. }
  233. GELOGI("Trans op symmetry eliminate successfully. Node %s has been removed.", dst_node->GetName().c_str());
  234. // 6.If T1 has no data out, isolate and deleted it.
  235. ret = RemoveTransOpWithoutOutput(pre_normal_node, src_node);
  236. if (ret != GRAPH_SUCCESS) {
  237. GELOGE(ret, "[Call][RemoveTransOpWithoutOutput] for node:%s(%s) failed",
  238. src_node->GetName().c_str(), src_node->GetType().c_str());
  239. return ret;
  240. }
  241. return SUCCESS;
  242. }
  243. Status TransOpSymmetryEliminationPass::RemoveTransOpWithoutOutput(NodePtr &pre_node, NodePtr &trans_node) {
  244. if (trans_node->GetOutDataNodesSize() == 0) {
  245. // 6.1 Copy out control to pre normal node
  246. Status ret = GraphUtils::CopyOutCtrlEdges(trans_node, pre_node);
  247. if (ret != GRAPH_SUCCESS) {
  248. REPORT_CALL_ERROR("E19999", "Copy out control edge from node:%s(%s) to node:%s(%s) failed",
  249. trans_node->GetName().c_str(), trans_node->GetType().c_str(),
  250. pre_node->GetName().c_str(), pre_node->GetType().c_str());
  251. GELOGE(FAILED, "[Copy][OutCtrlEdges] from %s to %s failed.", trans_node->GetName().c_str(),
  252. pre_node->GetName().c_str());
  253. return ret;
  254. }
  255. // 6.2 Isolate and delete T1
  256. ret = IsolateAndDeleteNode(trans_node, {});
  257. if (ret != GRAPH_SUCCESS) {
  258. REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed",
  259. trans_node->GetName().c_str(), trans_node->GetType().c_str());
  260. GELOGE(INTERNAL_ERROR, "[IsolateAndDelete][Node] %s(%s) failed", trans_node->GetName().c_str(),
  261. trans_node->GetType().c_str());
  262. return ret;
  263. }
  264. GELOGI("Trans op symmetry eliminate successfully. Node %s has been removed.", trans_node->GetName().c_str());
  265. }
  266. return SUCCESS;
  267. }
  268. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示：