You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transop_without_reshape_fusion_pass.cc 56 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/transop_without_reshape_fusion_pass.h"
  17. #include <algorithm>
  18. #include <memory>
  19. #include <sstream>
  20. #include <string>
  21. #include <atomic>
  22. #include "common/ge/ge_util.h"
  23. #include "common/ge_inner_error_codes.h"
  24. #include "common/types.h"
  25. #include "graph/common/transop_util.h"
  26. #include "graph/compute_graph.h"
  27. #include "graph/debug/ge_attr_define.h"
  28. #include "graph/ge_tensor.h"
  29. #include "graph/op_desc.h"
  30. #include "graph/utils/graph_utils.h"
  31. #include "graph/utils/node_utils.h"
  32. #include "graph/utils/op_desc_utils.h"
  33. #include "graph/utils/type_utils.h"
  34. #include "init/gelib.h"
  35. namespace {
  36. const char *const kRemainNode = "node_remain";
  37. const int kInvalidFusionOpCount = -1;
  38. const char *const kAttrNameSrcFormat = "src_format";
  39. const char *const kAttrNameDstFormat = "dst_format";
  40. } // namespace
  41. namespace ge {
  42. void TransOpWithoutReshapeFusionPass::SetRemainNode(
  43. const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_anchor) {
  44. auto iter = nodes_anchor.begin();
  45. while (iter != nodes_anchor.end()) {
  46. auto in_anchor = iter->second;
  47. if (in_anchor == nullptr) {
  48. return;
  49. }
  50. auto in_node = in_anchor->GetOwnerNode();
  51. ++iter;
  52. if (in_node == nullptr) {
  53. return;
  54. }
  55. if (!IsTransOp(in_node)) {
  56. continue;
  57. }
  58. auto op_desc = in_node->GetOpDesc();
  59. if (op_desc == nullptr) {
  60. continue;
  61. }
  62. GELOGI("SetRemainNode node is %s", op_desc->GetName().c_str());
  63. GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true),
  64. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", kRemainNode,
  65. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  66. GELOGE(INTERNAL_ERROR, "set ext attr failed"); return);
  67. }
  68. }
  69. bool TransOpWithoutReshapeFusionPass::FormatContinuousCheck(const OutDataAnchorPtr &out_anchor,
  70. const InDataAnchorPtr &in_anchor) {
  71. if (out_anchor == nullptr || in_anchor == nullptr || in_anchor->GetOwnerNode() == nullptr ||
  72. out_anchor->GetOwnerNode() == nullptr) {
  73. return false;
  74. }
  75. auto in_node = in_anchor->GetOwnerNode();
  76. GE_IF_BOOL_EXEC(in_node == nullptr,
  77. REPORT_INNER_ERROR("E19999", "Param in_anchor's owner node is nullptr, check invalid");
  78. GELOGE(INTERNAL_ERROR, "in_node is null"); return false);
  79. auto in_op = in_node->GetOpDesc();
  80. auto out_owner_node = out_anchor->GetOwnerNode();
  81. GE_IF_BOOL_EXEC(out_owner_node == nullptr,
  82. REPORT_INNER_ERROR("E19999", "Param out_anchor's owner node is nullptr, check invalid");
  83. GELOGE(INTERNAL_ERROR, "out_owner_node is null"); return false);
  84. auto out_op = out_owner_node->GetOpDesc();
  85. GE_IF_BOOL_EXEC(in_op == nullptr,
  86. REPORT_INNER_ERROR("E19999", "Param in_anchor's owner op_desc is nullptr, check invalid");
  87. GELOGE(INTERNAL_ERROR, "in_op is null"); return false);
  88. GE_IF_BOOL_EXEC(out_op == nullptr,
  89. REPORT_INNER_ERROR("E19999", "Param out_anchor's owner op_desc is nullptr, check invalid");
  90. GELOGE(INTERNAL_ERROR, "out_op is null"); return false);
  91. auto in_op_desc = in_op->GetInputDescPtr(in_anchor->GetIdx());
  92. auto out_op_desc = out_op->GetOutputDescPtr(out_anchor->GetIdx());
  93. GE_IF_BOOL_EXEC(in_op_desc == nullptr,
  94. REPORT_INNER_ERROR("E19999", "Param in_anchor corresponding tensor is nullptr, check invalid");
  95. GELOGE(INTERNAL_ERROR, "in_op_desc is null"); return false);
  96. GE_IF_BOOL_EXEC(out_op_desc == nullptr,
  97. REPORT_INNER_ERROR("E19999", "Param out_anchor corresponding tensor is nullptr, check invalid");
  98. GELOGE(INTERNAL_ERROR, "out_op_desc is null"); return false);
  99. if (!ShapeEqualCheck(in_op_desc->GetShape(), out_op_desc->GetShape())) {
  100. return false;
  101. }
  102. if (in_op->GetType() == CAST || out_op->GetType() == CAST) {
  103. return TransOpUtil::CheckPrecisionLoss(in_node);
  104. }
  105. if (in_op_desc->GetFormat() == FORMAT_ND) {
  106. return false;
  107. }
  108. if (out_op_desc->GetFormat() == FORMAT_ND) {
  109. return false;
  110. }
  111. if (in_op_desc->GetFormat() != out_op_desc->GetFormat()) {
  112. return false;
  113. }
  114. return FusionFormatSupport(in_op_desc->GetFormat());
  115. }
  116. graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphNodesInfo() {
  117. vector<bool> sub_graph_has_reshape_node(sub_graph_anchors_.size(), false);
  118. vector<int> transop_num_count(sub_graph_anchors_.size(), 0);
  119. vector<vector<NodePtr>> sub_graph_nodes(sub_graph_anchors_.size());
  120. for (size_t i = 0; i < sub_graph_anchors_.size(); ++i) {
  121. auto nodes_anchor = sub_graph_anchors_[i];
  122. vector<NodePtr> nodes_tmp;
  123. auto iter = nodes_anchor.begin();
  124. auto first_out_anchor = iter->first;
  125. if (first_out_anchor == nullptr) {
  126. continue;
  127. }
  128. nodes_tmp.push_back(first_out_anchor->GetOwnerNode());
  129. while (iter != nodes_anchor.end()) {
  130. auto in_anchor = iter->second;
  131. GE_CHECK_NOTNULL(in_anchor);
  132. auto in_node = in_anchor->GetOwnerNode();
  133. GE_CHECK_NOTNULL(in_node);
  134. if (in_node->GetType() == RESHAPE) {
  135. sub_graph_has_reshape_node[i] = true;
  136. break;
  137. }
  138. if (in_node->GetType() == TRANSPOSE || in_node->GetType() == TRANSPOSED) {
  139. auto input_format = in_node->GetOpDesc()->GetInputDescPtr(0)->GetFormat();
  140. auto output_format = in_node->GetOpDesc()->GetOutputDescPtr(0)->GetFormat();
  141. if (input_format == output_format) {
  142. sub_graph_has_reshape_node[i] = true;
  143. break;
  144. }
  145. }
  146. auto out_anchor = iter->first;
  147. GE_CHECK_NOTNULL(out_anchor);
  148. if (!FormatContinuousCheck(out_anchor, in_anchor)) {
  149. sub_graph_has_reshape_node[i] = true;
  150. break;
  151. }
  152. nodes_tmp.push_back(in_node);
  153. if (IsTransOp(in_node)) {
  154. // count transop num
  155. transop_num_count[i]++;
  156. }
  157. ++iter;
  158. }
  159. sub_graph_nodes[i].swap(nodes_tmp);
  160. if (sub_graph_has_reshape_node[i]) {
  161. SetRemainNode(nodes_anchor);
  162. }
  163. }
  164. sub_graph_has_reshape_node_.swap(sub_graph_has_reshape_node);
  165. transop_num_count_.swap(transop_num_count);
  166. sub_graph_nodes_.swap(sub_graph_nodes);
  167. return GRAPH_SUCCESS;
  168. }
  169. void TransOpWithoutReshapeFusionPass::GetOutDataPeerInControlAnchors(
  170. const size_t index, vector<vector<InControlAnchorPtr>> &out_data_peer_in_control_anchors) {
  171. // The caller guarantees that the index is legal.
  172. for (size_t j = 1; j < sub_graph_anchors_[index].size(); ++j) {
  173. auto nodes_anchor = sub_graph_anchors_[index][j];
  174. auto out_data_anchor = nodes_anchor.first;
  175. GE_CHECK_NOTNULL_JUST_RETURN(out_data_anchor);
  176. for (const auto &peer_in_control_anchor : out_data_anchor->GetPeerInControlAnchors()) {
  177. GE_CHECK_NOTNULL_JUST_RETURN(peer_in_control_anchor);
  178. auto peer_node = peer_in_control_anchor->GetOwnerNode();
  179. if (peer_node == nullptr) {
  180. continue;
  181. }
  182. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  183. if (iter == sub_graph_nodes_[index].end()) {
  184. out_data_peer_in_control_anchors[index].push_back(peer_in_control_anchor);
  185. } else {
  186. sub_graph_has_out_data_peer_in_control_edge_[index] = true;
  187. }
  188. }
  189. }
  190. }
  191. void TransOpWithoutReshapeFusionPass::GetInControlPeerOutControlAnchors(
  192. const size_t index, vector<vector<OutControlAnchorPtr>> &in_control_peer_out_control_anchors) {
  193. // The caller guarantees that the index is legal.
  194. for (size_t j = 1; j < (sub_graph_nodes_[index].size() - 1); ++j) {
  195. auto node = sub_graph_nodes_[index][j];
  196. GE_CHECK_NOTNULL_JUST_RETURN(node);
  197. auto in_control_anchor = node->GetInControlAnchor();
  198. if (in_control_anchor == nullptr) {
  199. continue;
  200. }
  201. for (const auto &peer_out_anchor : in_control_anchor->GetPeerOutControlAnchors()) {
  202. GE_CHECK_NOTNULL_JUST_RETURN(peer_out_anchor);
  203. auto peer_node = peer_out_anchor->GetOwnerNode();
  204. if (peer_node == nullptr) {
  205. continue;
  206. }
  207. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  208. if (iter == sub_graph_nodes_[index].end()) {
  209. in_control_peer_out_control_anchors[index].push_back(peer_out_anchor);
  210. } else {
  211. sub_graph_has_control_edge_[index] = true;
  212. }
  213. }
  214. }
  215. }
  216. void TransOpWithoutReshapeFusionPass::GetOutControlPeerAnchors(
  217. const size_t index, vector<vector<InControlAnchorPtr>> &out_control_peer_in_control_anchors,
  218. vector<vector<InDataAnchorPtr>> &out_control_peer_in_data_anchors) {
  219. for (size_t j = 0; j < sub_graph_nodes_[index].size() - 1; ++j) {
  220. auto node = sub_graph_nodes_[index][j];
  221. GE_CHECK_NOTNULL_JUST_RETURN(node);
  222. auto out_control_anchor = node->GetOutControlAnchor();
  223. GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor);
  224. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) {
  225. GE_CHECK_NOTNULL_JUST_RETURN(peer_in_anchor);
  226. auto peer_node = peer_in_anchor->GetOwnerNode();
  227. if (peer_node == nullptr) {
  228. continue;
  229. }
  230. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  231. if (iter == sub_graph_nodes_[index].end()) {
  232. if (j > 0) {
  233. out_control_peer_in_control_anchors[index].push_back(peer_in_anchor);
  234. }
  235. } else {
  236. sub_graph_has_control_edge_[index] = true;
  237. }
  238. }
  239. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) {
  240. GE_CHECK_NOTNULL_JUST_RETURN(peer_in_anchor);
  241. auto peer_node = peer_in_anchor->GetOwnerNode();
  242. if (peer_node == nullptr) {
  243. continue;
  244. }
  245. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  246. if (iter == sub_graph_nodes_[index].end()) {
  247. if (j > 0) {
  248. out_control_peer_in_data_anchors[index].push_back(peer_in_anchor);
  249. }
  250. } else {
  251. sub_graph_has_control_edge_[index] = true;
  252. }
  253. }
  254. }
  255. }
  256. void TransOpWithoutReshapeFusionPass::GetControlAnchors() {
  257. vector<vector<OutControlAnchorPtr>> in_control_peer_out_control_anchors(sub_graph_nodes_.size());
  258. vector<vector<InControlAnchorPtr>> out_control_peer_in_control_anchors(sub_graph_nodes_.size());
  259. vector<vector<InDataAnchorPtr>> out_control_peer_in_data_anchors(sub_graph_nodes_.size());
  260. vector<vector<InControlAnchorPtr>> out_data_peer_in_control_anchors(sub_graph_nodes_.size());
  261. vector<bool> sub_graph_has_control_edge(sub_graph_nodes_.size(), false);
  262. sub_graph_has_control_edge_.swap(sub_graph_has_control_edge);
  263. vector<bool> sub_graph_has_out_data_peer_in_control_edge(sub_graph_nodes_.size(), false);
  264. sub_graph_has_out_data_peer_in_control_edge_.swap(sub_graph_has_out_data_peer_in_control_edge);
  265. for (size_t i = 0; i < sub_graph_nodes_.size(); ++i) {
  266. if (sub_graph_has_reshape_node_[i]) {
  267. continue;
  268. }
  269. GetOutDataPeerInControlAnchors(i, out_data_peer_in_control_anchors);
  270. GetInControlPeerOutControlAnchors(i, in_control_peer_out_control_anchors);
  271. GetOutControlPeerAnchors(i, out_control_peer_in_control_anchors, out_control_peer_in_data_anchors);
  272. }
  273. in_control_peer_out_control_anchors_.swap(in_control_peer_out_control_anchors);
  274. out_control_peer_in_control_anchors_.swap(out_control_peer_in_control_anchors);
  275. out_control_peer_in_data_anchors_.swap(out_control_peer_in_data_anchors);
  276. out_data_peer_in_control_anchors_.swap(out_data_peer_in_control_anchors);
  277. }
  278. void TransOpWithoutReshapeFusionPass::EraseInvalidAnchorsPair() {
  279. auto sub_graph_iter = sub_graph_anchors_.begin();
  280. while (sub_graph_iter != sub_graph_anchors_.end()) {
  281. if (sub_graph_iter->size() <= 1) {
  282. sub_graph_iter = sub_graph_anchors_.erase(sub_graph_iter);
  283. } else {
  284. ++sub_graph_iter;
  285. }
  286. }
  287. }
  288. void TransOpWithoutReshapeFusionPass::UpdateOutputName(const OutDataAnchorPtr &out_anchor,
  289. const InDataAnchorPtr &old_peer_in_anchor,
  290. const NodePtr &in_owner_node) {
  291. if (out_anchor == nullptr || old_peer_in_anchor == nullptr || in_owner_node == nullptr) {
  292. GELOGI("out_anchor or old_peer_in_anchor or in_owner_node is nullptr");
  293. return;
  294. }
  295. auto out_owner_node = out_anchor->GetOwnerNode();
  296. GE_CHECK_NOTNULL_JUST_RETURN(out_owner_node);
  297. GE_CHECK_NOTNULL_JUST_RETURN(old_peer_in_anchor->GetOwnerNode());
  298. auto old_peer_in_name = old_peer_in_anchor->GetOwnerNode()->GetName();
  299. auto output_op = out_owner_node->GetOpDesc();
  300. GE_CHECK_NOTNULL_JUST_RETURN(output_op);
  301. auto output_names = output_op->GetAllOutputName();
  302. auto old_peer_in_name_iter = output_names.find(old_peer_in_name);
  303. if (old_peer_in_name_iter != output_names.end()) {
  304. output_names.erase(old_peer_in_name_iter);
  305. }
  306. output_names[in_owner_node->GetName()] = out_anchor->GetIdx();
  307. if (!output_op->UpdateOutputName(output_names)) {
  308. GELOGW("output_op UpdateOutputName failed");
  309. }
  310. }
  311. void TransOpWithoutReshapeFusionPass::UpdateInputName(const OutDataAnchorPtr &old_peer_out_anchor,
  312. const InDataAnchorPtr &in_anchor, const NodePtr &out_owner_node) {
  313. if (old_peer_out_anchor == nullptr || in_anchor == nullptr || out_owner_node == nullptr) {
  314. GELOGI("old_peer_out_anchor or in_anchor or out_owner_node is nullptr");
  315. return;
  316. }
  317. auto old_node = old_peer_out_anchor->GetOwnerNode();
  318. GE_CHECK_NOTNULL_JUST_RETURN(old_node);
  319. auto old_peer_out_name = old_node->GetName();
  320. auto in_owner_node = in_anchor->GetOwnerNode();
  321. GE_CHECK_NOTNULL_JUST_RETURN(in_owner_node);
  322. auto input_op = in_owner_node->GetOpDesc();
  323. GE_CHECK_NOTNULL_JUST_RETURN(input_op);
  324. auto input_names = input_op->GetAllInputName();
  325. auto old_peer_out_name_iter = input_names.find(old_peer_out_name);
  326. if (old_peer_out_name_iter != input_names.end()) {
  327. input_names.erase(old_peer_out_name_iter);
  328. }
  329. input_names[out_owner_node->GetName()] = in_anchor->GetIdx();
  330. input_op->UpdateInputName(input_names);
  331. }
  332. graphStatus TransOpWithoutReshapeFusionPass::RelinkSubGraphControlEdges(
  333. const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  334. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  335. auto out_anchor = begin_anchors_pair.first;
  336. GE_CHECK_NOTNULL(out_anchor);
  337. auto out_owner_node = out_anchor->GetOwnerNode();
  338. GE_CHECK_NOTNULL(out_owner_node);
  339. auto in_anchor = end_anchors_pair.second;
  340. GE_CHECK_NOTNULL(in_anchor);
  341. auto in_owner_node = in_anchor->GetOwnerNode();
  342. GE_CHECK_NOTNULL(in_owner_node);
  343. if (sub_graph_has_control_edge_[index]) {
  344. GELOGI("add control edge.src:%s, dst:%s", out_owner_node->GetName().c_str(), in_owner_node->GetName().c_str());
  345. if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), in_owner_node->GetInControlAnchor()) !=
  346. GRAPH_SUCCESS) {
  347. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  348. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  349. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str());
  350. return GRAPH_FAILED;
  351. }
  352. }
  353. if (sub_graph_has_out_data_peer_in_control_edge_[index]) {
  354. GELOGI("add out data 2 in contorl edge.src:%s, dst:%s", out_owner_node->GetName().c_str(),
  355. in_owner_node->GetName().c_str());
  356. if (GraphUtils::AddEdge(out_anchor, in_owner_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
  357. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  358. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  359. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str());
  360. return GRAPH_FAILED;
  361. }
  362. }
  363. return GRAPH_SUCCESS;
  364. }
  365. graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdgesWhenDescNotChanged(
  366. const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  367. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  368. if (RelinkSubGraphControlEdges(begin_anchors_pair, end_anchors_pair, index) != GRAPH_SUCCESS) {
  369. return GRAPH_FAILED;
  370. }
  371. auto out_anchor = begin_anchors_pair.first;
  372. GE_CHECK_NOTNULL(out_anchor);
  373. auto out_owner_node = out_anchor->GetOwnerNode();
  374. GE_CHECK_NOTNULL(out_owner_node);
  375. auto in_anchor = end_anchors_pair.second;
  376. GE_CHECK_NOTNULL(in_anchor);
  377. auto in_owner_node = in_anchor->GetOwnerNode();
  378. GE_CHECK_NOTNULL(in_owner_node);
  379. // can not remove old control edge
  380. for (const auto &peer_in_anchor : out_control_peer_in_control_anchors_[index]) {
  381. GE_CHECK_NOTNULL(peer_in_anchor);
  382. GELOGI("add control edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  383. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  384. if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  385. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  386. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  387. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  388. peer_in_anchor->GetOwnerNode()->GetType().c_str());
  389. return GRAPH_FAILED;
  390. }
  391. }
  392. for (const auto &peer_out_anchor : in_control_peer_out_control_anchors_[index]) {
  393. GE_CHECK_NOTNULL(peer_out_anchor);
  394. GELOGI("add control edge.src:%s, src idx:%d, dst:%s", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  395. peer_out_anchor->GetIdx(), in_owner_node->GetName().c_str());
  396. if (GraphUtils::AddEdge(peer_out_anchor, in_owner_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
  397. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  398. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  399. peer_out_anchor->GetOwnerNode()->GetType().c_str(),
  400. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str());
  401. return GRAPH_FAILED;
  402. }
  403. }
  404. for (const auto &peer_in_anchor : out_control_peer_in_data_anchors_[index]) {
  405. GE_CHECK_NOTNULL(peer_in_anchor);
  406. GELOGI("add out control 2 in data edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  407. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  408. if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  409. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  410. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  411. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  412. peer_in_anchor->GetOwnerNode()->GetType().c_str());
  413. return GRAPH_FAILED;
  414. }
  415. }
  416. for (const auto &peer_in_anchor : out_data_peer_in_control_anchors_[index]) {
  417. GE_CHECK_NOTNULL(peer_in_anchor);
  418. GELOGI("add out data 2 in control edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  419. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  420. if (GraphUtils::AddEdge(out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
  421. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  422. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  423. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  424. peer_in_anchor->GetOwnerNode()->GetType().c_str());
  425. return GRAPH_FAILED;
  426. }
  427. }
  428. return GRAPH_SUCCESS;
  429. }
  430. graphStatus TransOpWithoutReshapeFusionPass::RelinkNodesWhenDescNotChanged(
  431. const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  432. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  433. auto out_anchor = begin_anchors_pair.first;
  434. GE_CHECK_NOTNULL(out_anchor);
  435. auto out_owner_node = out_anchor->GetOwnerNode();
  436. GE_CHECK_NOTNULL(out_owner_node);
  437. auto in_anchor = end_anchors_pair.second;
  438. GE_CHECK_NOTNULL(in_anchor);
  439. auto in_owner_node = in_anchor->GetOwnerNode();
  440. GE_CHECK_NOTNULL(in_owner_node);
  441. GELOGI("remove edge.src %s, src idx:%d, dst:%s, dst idx:%d",
  442. end_anchors_pair.first->GetOwnerNode()->GetName().c_str(), end_anchors_pair.first->GetIdx(),
  443. in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  444. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(end_anchors_pair.first, in_anchor), "remove edge failed");
  445. GELOGI("relink node.src node:%s, src idx:%d, dst node:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  446. out_anchor->GetIdx(), in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  447. if (GraphUtils::AddEdge(out_anchor, in_anchor) != GRAPH_SUCCESS) {
  448. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  449. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(), out_anchor->GetIdx(),
  450. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str(), in_anchor->GetIdx());
  451. GELOGE(GRAPH_FAILED, "add edge failed!src:%s, src idx:%d, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  452. out_anchor->GetIdx(), in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  453. return GRAPH_FAILED;
  454. } else {
  455. auto old_peer_in_anchor = begin_anchors_pair.second;
  456. UpdateOutputName(out_anchor, old_peer_in_anchor, in_owner_node);
  457. auto old_peer_out_anchor = end_anchors_pair.first;
  458. UpdateInputName(old_peer_out_anchor, in_anchor, out_owner_node);
  459. }
  460. return RelinkControlEdgesWhenDescNotChanged(begin_anchors_pair, end_anchors_pair, index);
  461. }
  462. OpDescPtr TransOpWithoutReshapeFusionPass::GetFormatTransferOp(const GeTensorDesc &format_trans_input_desc,
  463. const GeTensorDesc &format_trans_output_desc) {
  464. static std::atomic_long atomic_fusion_format_transfer_op_count(1);
  465. auto fusion_format_transfer_op_count = atomic_fusion_format_transfer_op_count.fetch_add(1);
  466. std::stringstream format_transfer_op_name;
  467. format_transfer_op_name << "fusion_format_transfer_" << fusion_format_transfer_op_count;
  468. OpDescPtr format_transfer_op = MakeShared<OpDesc>(format_transfer_op_name.str().c_str(), TRANSDATA);
  469. if (format_transfer_op == nullptr) {
  470. REPORT_CALL_ERROR("E19999", "New GeTensor failed");
  471. GELOGE(INTERNAL_ERROR, "new format transfer op failed!");
  472. return nullptr;
  473. }
  474. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(format_transfer_op, ATTR_NAME_INPUT_FORMAT,
  475. static_cast<int64_t>(format_trans_input_desc.GetFormat())),
  476. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_INPUT_FORMAT.c_str(),
  477. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  478. GELOGE(INTERNAL_ERROR, "set ATTR_NAME_INPUT_FORMAT failed");
  479. return nullptr);
  480. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(format_transfer_op, ATTR_NAME_OUTPUT_FORMAT,
  481. static_cast<int64_t>(format_trans_output_desc.GetFormat())),
  482. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_OUTPUT_FORMAT.c_str(),
  483. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  484. GELOGE(INTERNAL_ERROR, "set ATTR_NAME_OUTPUT_FORMAT failed");
  485. return nullptr);
  486. string src_format = TypeUtils::FormatToSerialString(format_trans_input_desc.GetFormat());
  487. string dst_format = TypeUtils::FormatToSerialString(format_trans_output_desc.GetFormat());
  488. GE_IF_BOOL_EXEC(!AttrUtils::SetStr(format_transfer_op, kAttrNameSrcFormat, src_format),
  489. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", kAttrNameSrcFormat,
  490. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  491. GELOGE(INTERNAL_ERROR, "set kAttrNameSrcFormat failed");
  492. return nullptr);
  493. GE_IF_BOOL_EXEC(!AttrUtils::SetStr(format_transfer_op, kAttrNameDstFormat, dst_format),
  494. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", kAttrNameDstFormat,
  495. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  496. GELOGE(INTERNAL_ERROR, "set kAttrNameDstFormat failed");
  497. return nullptr);
  498. GE_IF_BOOL_EXEC(format_transfer_op->AddInputDesc(format_trans_input_desc) != GRAPH_SUCCESS,
  499. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
  500. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  501. GELOGE(INTERNAL_ERROR, "add input desc failed");
  502. return nullptr);
  503. GE_IF_BOOL_EXEC(format_transfer_op->AddOutputDesc(format_trans_output_desc) != GRAPH_SUCCESS,
  504. REPORT_CALL_ERROR("E19999", "Add ouput desc to op:%s(%s) failed",
  505. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  506. GELOGE(INTERNAL_ERROR, "add output desc failed");
  507. return nullptr);
  508. GE_IF_BOOL_EXEC(!ge::AttrUtils::SetBool(format_transfer_op, ATTR_NEED_COMPILE, true),
  509. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NEED_COMPILE.c_str(),
  510. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  511. GELOGE(INTERNAL_ERROR, "set ext attr failed");
  512. return nullptr);
  513. return format_transfer_op;
  514. }
  515. OpDescPtr TransOpWithoutReshapeFusionPass::GetCastOp(const GeTensorDesc &cast_input_desc,
  516. const GeTensorDesc &cast_output_desc) {
  517. static std::atomic_long atomic_fusion_cast_op_count(1);
  518. auto fusion_cast_op_count = atomic_fusion_cast_op_count.fetch_add(1);
  519. std::stringstream cast_op_name;
  520. cast_op_name << "fusion_cast_op_" << fusion_cast_op_count;
  521. auto node_op = ge::OperatorFactory::CreateOperator(cast_op_name.str(), CAST);
  522. auto cast_op = ge::OpDescUtils::GetOpDescFromOperator(node_op);
  523. node_op.BreakConnect();
  524. if (cast_op == nullptr) {
  525. REPORT_CALL_ERROR("E19999", "Create operator:%s(%s) failed", cast_op_name.str().c_str(), CAST);
  526. GELOGE(INTERNAL_ERROR, "new cast op failed!");
  527. return nullptr;
  528. }
  529. const int default_input_index = 0;
  530. const int default_output_index = 0;
  531. if (cast_op->GetInputsSize() == 0) {
  532. GE_IF_BOOL_EXEC(cast_op->AddInputDesc(cast_input_desc) != GRAPH_SUCCESS,
  533. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
  534. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  535. GELOGE(INTERNAL_ERROR, "add input desc failed");
  536. return nullptr);
  537. } else {
  538. GE_IF_BOOL_EXEC(cast_op->UpdateInputDesc(default_input_index, cast_input_desc) != GRAPH_SUCCESS,
  539. REPORT_CALL_ERROR("E19999", "Update input:%d desc of op:%s(%s) failed", default_input_index,
  540. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  541. GELOGE(INTERNAL_ERROR, "update input desc failed");
  542. return nullptr);
  543. }
  544. if (cast_op->GetOutputsSize() == 0) {
  545. GE_IF_BOOL_EXEC(cast_op->AddOutputDesc(cast_output_desc) != GRAPH_SUCCESS,
  546. REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed",
  547. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  548. GELOGE(INTERNAL_ERROR, "add output desc failed");
  549. return nullptr);
  550. } else {
  551. GE_IF_BOOL_EXEC(cast_op->UpdateOutputDesc(default_output_index, cast_output_desc) != GRAPH_SUCCESS,
  552. REPORT_CALL_ERROR("E19999", "Update output:%d desc of op:%s(%s) failed", default_output_index,
  553. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  554. GELOGE(INTERNAL_ERROR, "update output desc failed");
  555. return nullptr);
  556. }
  557. if (!AttrUtils::SetInt(cast_op, CAST_ATTR_DST_TYPE, static_cast<int64_t>(cast_output_desc.GetDataType()))) {
  558. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", CAST_ATTR_DST_TYPE.c_str(),
  559. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  560. GELOGE(INTERNAL_ERROR, "set dst_type attr failed");
  561. return nullptr;
  562. }
  563. if (!AttrUtils::SetBool(cast_op, ATTR_NEED_COMPILE, true)) {
  564. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NEED_COMPILE.c_str(),
  565. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  566. GELOGE(INTERNAL_ERROR, "set need_compile attr failed");
  567. return nullptr;
  568. }
  569. return cast_op;
  570. }
  571. bool TransOpWithoutReshapeFusionPass::InsertCastFirstCheck(const GeTensorDesc &out_desc,
  572. const GeTensorDesc &in_desc) const {
  573. return out_desc.GetDataType() != in_desc.GetDataType() && out_desc.GetDataType() != DT_FLOAT16 &&
  574. in_desc.GetDataType() == DT_FLOAT16;
  575. }
  576. void TransOpWithoutReshapeFusionPass::GetFormatTransferDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc,
  577. GeTensorDesc &format_transfer_input,
  578. GeTensorDesc &format_transfer_output) {
  579. bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  580. if (insert_cast_first) {
  581. format_transfer_input = out_desc;
  582. format_transfer_input.SetDataType(in_desc.GetDataType());
  583. format_transfer_output = in_desc;
  584. } else {
  585. format_transfer_input = out_desc;
  586. format_transfer_output = in_desc;
  587. format_transfer_output.SetDataType(out_desc.GetDataType());
  588. }
  589. }
  590. void TransOpWithoutReshapeFusionPass::GetCastOpDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc,
  591. GeTensorDesc &cast_input, GeTensorDesc &cast_output) {
  592. bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  593. if (insert_cast_first) {
  594. cast_input = out_desc;
  595. cast_output = out_desc;
  596. cast_output.SetDataType(in_desc.GetDataType());
  597. } else {
  598. cast_input = in_desc;
  599. cast_input.SetDataType(out_desc.GetDataType());
  600. cast_output = in_desc;
  601. }
  602. }
  603. void TransOpWithoutReshapeFusionPass::GetBeginOutDescAndEndInDesc(const int index, GeTensorDesc &out_desc,
  604. GeTensorDesc &in_desc) {
  605. auto nodes_anchor = sub_graph_anchors_[index];
  606. auto out_peer_anchor = nodes_anchor.front().second;
  607. GE_CHECK_NOTNULL_JUST_RETURN(out_peer_anchor);
  608. auto out_owner_node = out_peer_anchor->GetOwnerNode();
  609. GE_CHECK_NOTNULL_JUST_RETURN(out_owner_node);
  610. auto out_peer_op_desc = out_owner_node->GetOpDesc();
  611. GE_IF_BOOL_EXEC(out_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "out_peer_op_desc is nullptr"); return);
  612. out_desc = out_peer_op_desc->GetInputDesc(out_peer_anchor->GetIdx());
  613. auto in_peer_anchor = nodes_anchor.back().first;
  614. GE_CHECK_NOTNULL_JUST_RETURN(in_peer_anchor);
  615. auto in_owner_node = in_peer_anchor->GetOwnerNode();
  616. GE_CHECK_NOTNULL_JUST_RETURN(in_owner_node);
  617. auto in_peer_op_desc = in_owner_node->GetOpDesc();
  618. GE_IF_BOOL_EXEC(in_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "in_peer_op_desc is nullptr"); return);
  619. in_desc = in_peer_op_desc->GetOutputDesc(in_peer_anchor->GetIdx());
  620. }
  621. graphStatus TransOpWithoutReshapeFusionPass::FormatFusion(const int index, OpDescPtr &format_transfer_op,
  622. int32_t &fusion_op_count, bool &fusion_continue) {
  623. GeTensorDesc out_desc;
  624. GeTensorDesc in_desc;
  625. GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);
  626. GeTensorDesc format_transfer_input;
  627. GeTensorDesc format_transfer_output;
  628. GetFormatTransferDesc(out_desc, in_desc, format_transfer_input, format_transfer_output);
  629. if (out_desc.GetFormat() == in_desc.GetFormat() &&
  630. (!ShapeEqualCheck(out_desc.GetShape(), in_desc.GetShape()) ||
  631. !ShapeEqualCheck(out_desc.GetOriginShape(), in_desc.GetOriginShape()))) {
  632. SetRemainNode(sub_graph_anchors_[index]);
  633. return GRAPH_SUCCESS;
  634. }
  635. if (out_desc.GetFormat() != in_desc.GetFormat() && FusionFormatSupport(out_desc.GetFormat()) &&
  636. FusionFormatSupport(in_desc.GetFormat())) {
  637. // create format transop
  638. format_transfer_op = GetFormatTransferOp(format_transfer_input, format_transfer_output);
  639. if (format_transfer_op == nullptr) {
  640. return GRAPH_FAILED;
  641. }
  642. if (OpAccuracyAbilityCheck(format_transfer_op)) {
  643. ++fusion_op_count;
  644. GELOGI("support format transfer op %s", format_transfer_op->GetName().c_str());
  645. } else {
  646. GELOGW("ability not support.src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
  647. format_transfer_input.GetFormat(), format_transfer_input.GetDataType(), format_transfer_output.GetFormat(),
  648. format_transfer_output.GetDataType());
  649. fusion_op_count = kInvalidFusionOpCount;
  650. }
  651. } else if (out_desc.GetFormat() != in_desc.GetFormat()) {
  652. SetRemainNode(sub_graph_anchors_[index]);
  653. return GRAPH_SUCCESS;
  654. }
  655. fusion_continue = true;
  656. return GRAPH_SUCCESS;
  657. }
  658. graphStatus TransOpWithoutReshapeFusionPass::DataTypeFusion(const int index, OpDescPtr &cast_op,
  659. int32_t &fusion_op_count) {
  660. GeTensorDesc out_desc;
  661. GeTensorDesc in_desc;
  662. GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);
  663. GeTensorDesc cast_input;
  664. GeTensorDesc cast_output;
  665. GetCastOpDesc(out_desc, in_desc, cast_input, cast_output);
  666. if (fusion_op_count != kInvalidFusionOpCount && out_desc.GetDataType() != in_desc.GetDataType()) {
  667. // create cast op
  668. cast_op = GetCastOp(cast_input, cast_output);
  669. if (cast_op == nullptr) {
  670. fusion_op_count = kInvalidFusionOpCount;
  671. return GRAPH_FAILED;
  672. }
  673. if (OpAccuracyAbilityCheck(cast_op)) {
  674. ++fusion_op_count;
  675. GELOGI("support cast op %s. src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
  676. cast_op->GetName().c_str(), cast_input.GetFormat(), cast_input.GetDataType(), cast_output.GetFormat(),
  677. cast_output.GetDataType());
  678. } else {
  679. GELOGW("ability not support.src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
  680. cast_input.GetFormat(), cast_input.GetDataType(), cast_output.GetFormat(), cast_output.GetDataType());
  681. fusion_op_count = kInvalidFusionOpCount;
  682. }
  683. }
  684. return GRAPH_SUCCESS;
  685. }
  686. graphStatus TransOpWithoutReshapeFusionPass::TransOpFuseHandle(const ComputeGraphPtr &graph, const int index) {
  687. bool fusion_continue = false;
  688. OpDescPtr format_transfer_op = nullptr;
  689. int32_t fusion_op_count = 0;
  690. auto fortmat_fusion_ret = FormatFusion(index, format_transfer_op, fusion_op_count, fusion_continue);
  691. if (fortmat_fusion_ret != GRAPH_SUCCESS || !fusion_continue) {
  692. SetRemainNode(sub_graph_anchors_[index]);
  693. return GRAPH_SUCCESS;
  694. }
  695. OpDescPtr cast_op = nullptr;
  696. if (DataTypeFusion(index, cast_op, fusion_op_count) != GRAPH_SUCCESS) {
  697. SetRemainNode(sub_graph_anchors_[index]);
  698. return GRAPH_SUCCESS;
  699. }
  700. if (fusion_op_count != kInvalidFusionOpCount && fusion_op_count < transop_num_count_[index]) {
  701. GeTensorDesc out_desc;
  702. GeTensorDesc in_desc;
  703. GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);
  704. bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  705. if (InsertNewTransOp(graph, cast_op, format_transfer_op, index, insert_cast_first) != GRAPH_SUCCESS) {
  706. return GRAPH_FAILED;
  707. }
  708. } else {
  709. // remain all nodes
  710. SetRemainNode(sub_graph_anchors_[index]);
  711. }
  712. return GRAPH_SUCCESS;
  713. }
  714. void TransOpWithoutReshapeFusionPass::RemoveNousedNodes(const ComputeGraphPtr &graph) {
  715. if (graph == nullptr) {
  716. return;
  717. }
  718. for (size_t i = 0; i < sub_graph_nodes_.size(); ++i) {
  719. if (sub_graph_has_reshape_node_[i]) {
  720. continue;
  721. }
  722. for (const auto &node : sub_graph_nodes_[i]) {
  723. GE_CHECK_NOTNULL_JUST_RETURN(node);
  724. // remove nodes
  725. if (!IsTransOp(node)) {
  726. continue;
  727. }
  728. auto op_desc = node->GetOpDesc();
  729. GE_CHECK_NOTNULL_JUST_RETURN(op_desc);
  730. bool node_remain_flag = op_desc->TryGetExtAttr(kRemainNode, false);
  731. if (node_remain_flag) {
  732. continue;
  733. }
  734. GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return);
  735. GELOGI("remove node:%s", node->GetName().c_str());
  736. if (GraphUtils::IsolateNode(node, {0}) != GRAPH_SUCCESS) {
  737. GELOGW("Isolate node: %s failed.", node->GetName().c_str());
  738. continue;
  739. }
  740. if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
  741. GELOGW("Remove node: %s failed.", node->GetName().c_str());
  742. continue;
  743. }
  744. }
  745. }
  746. }
  747. graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) {
  748. GELOGI("[TransOpWithoutReshapeFusionPass]: optimize begin.");
  749. if (graph == nullptr) {
  750. return GRAPH_SUCCESS;
  751. }
  752. for (const auto &node : graph->GetDirectNode()) {
  753. GE_CHECK_NOTNULL(node);
  754. if (IsTransOp(node)) {
  755. continue;
  756. }
  757. bool is_unknown = false;
  758. auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown);
  759. if (ret != GRAPH_SUCCESS) {
  760. GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(),
  761. node->GetType().c_str());
  762. continue;
  763. }
  764. if (is_unknown) {
  765. GELOGI("Current node %s, type %s is unknown shape which should be skip.", node->GetName().c_str(),
  766. node->GetType().c_str());
  767. continue;
  768. }
  769. GELOGI("Current normal node name: %s, type: %s.", node->GetName().c_str(), node->GetType().c_str());
  770. for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
  771. GE_CHECK_NOTNULL(out_anchor);
  772. vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>> sub_graph_anchors;
  773. vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> nodes_list;
  774. if (GetSubGraphsBetweenNormalNode(out_anchor, sub_graph_anchors, nodes_list) != GRAPH_SUCCESS) {
  775. GELOGW("get transops failed!");
  776. continue;
  777. }
  778. sub_graph_anchors_.swap(sub_graph_anchors);
  779. EraseInvalidAnchorsPair();
  780. if (sub_graph_anchors_.empty()) {
  781. continue;
  782. }
  783. // check reshape node
  784. if (GetSubGraphNodesInfo() != GRAPH_SUCCESS) {
  785. continue;
  786. }
  787. // save control edge
  788. GetControlAnchors();
  789. if (TransOpFuse(graph) != GRAPH_SUCCESS) {
  790. return GRAPH_FAILED;
  791. }
  792. }
  793. }
  794. GELOGI("[TransOpWithoutReshapeFusionPass]: Optimize end.");
  795. return GRAPH_SUCCESS;
  796. }
  797. bool TransOpWithoutReshapeFusionPass::DescEqualCheck(ConstGeTensorDescPtr &desc_src,
  798. ConstGeTensorDescPtr &desc_dst) const {
  799. if (desc_src == nullptr || desc_dst == nullptr) {
  800. return false;
  801. }
  802. if (desc_src->GetFormat() != desc_dst->GetFormat() || desc_src->GetDataType() != desc_dst->GetDataType()) {
  803. return false;
  804. }
  805. if (!ShapeEqualCheck(desc_src->GetShape(), desc_dst->GetShape())) {
  806. return false;
  807. }
  808. return ShapeEqualCheck(desc_src->GetOriginShape(), desc_dst->GetOriginShape());
  809. }
  810. bool TransOpWithoutReshapeFusionPass::ShapeEqualCheck(const GeShape &src, const GeShape &dst) const {
  811. if (src.GetDims().size() != dst.GetDims().size()) {
  812. return false;
  813. }
  814. for (size_t i = 0; i < src.GetDims().size(); ++i) {
  815. if (src.GetDim(i) != dst.GetDim(i)) {
  816. return false;
  817. }
  818. }
  819. return true;
  820. }
  821. graphStatus TransOpWithoutReshapeFusionPass::TransOpFuse(const ComputeGraphPtr &graph) {
  822. for (size_t i = 0; i < sub_graph_anchors_.size(); ++i) {
  823. if (sub_graph_has_reshape_node_[i]) {
  824. continue;
  825. }
  826. auto nodes_anchor = sub_graph_anchors_[i];
  827. auto out_anchor = nodes_anchor.front().first;
  828. GE_CHECK_NOTNULL(out_anchor);
  829. auto out_op_desc = out_anchor->GetOwnerNode()->GetOpDesc();
  830. GE_CHECK_NOTNULL(out_op_desc);
  831. auto out_desc = out_op_desc->GetOutputDescPtr(out_anchor->GetIdx());
  832. GE_CHECK_NOTNULL(out_desc);
  833. auto in_anchor = nodes_anchor.back().second;
  834. GE_CHECK_NOTNULL(in_anchor);
  835. auto in_op_desc = in_anchor->GetOwnerNode()->GetOpDesc();
  836. GE_CHECK_NOTNULL(in_op_desc);
  837. auto in_desc = in_op_desc->GetInputDescPtr(in_anchor->GetIdx());
  838. GE_CHECK_NOTNULL(in_desc);
  839. if (FusionFormatSupport(out_desc->GetFormat()) && DescEqualCheck(out_desc, in_desc)) {
  840. // relink begin_out to end_in
  841. if (RelinkNodesWhenDescNotChanged(nodes_anchor.front(), nodes_anchor.back(), static_cast<int>(i)) !=
  842. GRAPH_SUCCESS) {
  843. return GRAPH_FAILED;
  844. }
  845. } else {
  846. if (TransOpFuseHandle(graph, static_cast<int>(i)) != GRAPH_SUCCESS) {
  847. return GRAPH_FAILED;
  848. }
  849. }
  850. }
  851. RemoveNousedNodes(graph);
  852. return GRAPH_SUCCESS;
  853. }
  854. graphStatus TransOpWithoutReshapeFusionPass::AddTransNode(const ComputeGraphPtr &graph, const OpDescPtr &transop,
  855. NodePtr &trans_node) {
  856. if (graph == nullptr) {
  857. return GRAPH_SUCCESS;
  858. }
  859. if (transop == nullptr) {
  860. return GRAPH_SUCCESS;
  861. }
  862. trans_node = graph->AddNode(transop);
  863. if (trans_node == nullptr) {
  864. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  865. transop->GetName().c_str(), transop->GetType().c_str(), graph->GetName().c_str());
  866. GELOGE(GRAPH_FAILED, "add node failed!");
  867. return GRAPH_FAILED;
  868. }
  869. return GRAPH_SUCCESS;
  870. }
  871. graphStatus TransOpWithoutReshapeFusionPass::GetTransNode(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
  872. const OpDescPtr &format_transfer_op,
  873. const bool insert_cast_first,
  874. std::vector<NodePtr> &new_trans_nodes) {
  875. NodePtr format_transfer_node;
  876. if (AddTransNode(graph, format_transfer_op, format_transfer_node) != GRAPH_SUCCESS) {
  877. return GRAPH_FAILED;
  878. }
  879. NodePtr cast_node;
  880. if (AddTransNode(graph, cast_op, cast_node) != GRAPH_SUCCESS) {
  881. return GRAPH_FAILED;
  882. }
  883. if (insert_cast_first) {
  884. if (cast_node != nullptr) {
  885. new_trans_nodes.push_back(cast_node);
  886. }
  887. if (format_transfer_node != nullptr) {
  888. new_trans_nodes.push_back(format_transfer_node);
  889. }
  890. } else {
  891. if (format_transfer_node != nullptr) {
  892. new_trans_nodes.push_back(format_transfer_node);
  893. }
  894. if (cast_node != nullptr) {
  895. new_trans_nodes.push_back(cast_node);
  896. }
  897. }
  898. return GRAPH_SUCCESS;
  899. }
  900. graphStatus TransOpWithoutReshapeFusionPass::InsertNewTransOp(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
  901. const OpDescPtr &format_transfer_op, const int index,
  902. const bool insert_cast_first) {
  903. std::vector<NodePtr> new_trans_nodes;
  904. if (GetTransNode(graph, cast_op, format_transfer_op, insert_cast_first, new_trans_nodes) != GRAPH_SUCCESS) {
  905. return GRAPH_FAILED;
  906. }
  907. if (new_trans_nodes.empty()) {
  908. GELOGI("No new trans node. Do not need insert new transop.");
  909. return GRAPH_SUCCESS;
  910. }
  911. pair<OutDataAnchorPtr, InDataAnchorPtr> begin_out = sub_graph_anchors_[index].front();
  912. pair<OutDataAnchorPtr, InDataAnchorPtr> end_in = sub_graph_anchors_[index].back();
  913. auto out_anchor = begin_out.first;
  914. GE_CHECK_NOTNULL(out_anchor);
  915. auto out_owner_node = out_anchor->GetOwnerNode();
  916. GE_CHECK_NOTNULL(out_owner_node);
  917. auto in_anchor = end_in.second;
  918. GE_CHECK_NOTNULL(in_anchor);
  919. auto in_owner_node = in_anchor->GetOwnerNode();
  920. GE_CHECK_NOTNULL(in_owner_node);
  921. GELOGI("remove edge.src:%s, src idx:%d, dst:%s, dst idx:%d", end_in.first->GetOwnerNode()->GetName().c_str(),
  922. end_in.first->GetIdx(), in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetIdx());
  923. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(end_in.first, in_anchor), "remove edge failed");
  924. GELOGI("add edge.src:%s, src idx:%d, dst:%s", out_anchor->GetOwnerNode()->GetName().c_str(), out_anchor->GetIdx(),
  925. new_trans_nodes.front()->GetName().c_str());
  926. if (GraphUtils::AddEdge(out_anchor, new_trans_nodes.front()->GetInAnchor(0)) != GRAPH_SUCCESS) {
  927. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed",
  928. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(), out_anchor->GetIdx(),
  929. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str());
  930. return GRAPH_FAILED;
  931. } else {
  932. auto old_peer_in_anchor = begin_out.second;
  933. GE_CHECK_NOTNULL(old_peer_in_anchor);
  934. UpdateOutputName(out_anchor, old_peer_in_anchor, in_owner_node);
  935. }
  936. if (new_trans_nodes.size() > 1) {
  937. GELOGI("add edge.src:%s, dst:%s", new_trans_nodes.front()->GetName().c_str(),
  938. new_trans_nodes.back()->GetName().c_str());
  939. if (GraphUtils::AddEdge(new_trans_nodes.front()->GetOutAnchor(0), new_trans_nodes.back()->GetInAnchor(0)) !=
  940. GRAPH_SUCCESS) {
  941. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed",
  942. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str(),
  943. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str());
  944. return GRAPH_FAILED;
  945. } else {
  946. auto old_peer_out_anchor = end_in.first;
  947. GE_CHECK_NOTNULL(old_peer_out_anchor);
  948. UpdateInputName(old_peer_out_anchor, in_anchor, out_owner_node);
  949. }
  950. }
  951. GELOGI("add edge.src:%s, dst:%s, dst idx:%d", new_trans_nodes.back()->GetName().c_str(),
  952. in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetIdx());
  953. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutAnchor(0), in_anchor) != GRAPH_SUCCESS) {
  954. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed",
  955. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str(),
  956. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str(), in_anchor->GetIdx());
  957. return GRAPH_FAILED;
  958. }
  959. return RelinkControlEdge(index, out_anchor, new_trans_nodes);
  960. }
  961. graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdge(const int index, const OutDataAnchorPtr &out_anchor,
  962. const vector<NodePtr> &new_trans_nodes) {
  963. GE_CHECK_NOTNULL(out_anchor);
  964. if (new_trans_nodes.front() == nullptr || new_trans_nodes.back() == nullptr) {
  965. REPORT_INNER_ERROR("E19999", "Param new_trans_nodes front or back is nullptr, check invalid");
  966. return GRAPH_FAILED;
  967. }
  968. if (sub_graph_has_control_edge_[index]) {
  969. GELOGI("add control edge.src:%s, dst:%s", out_anchor->GetOwnerNode()->GetName().c_str(),
  970. new_trans_nodes.front()->GetName().c_str());
  971. if (GraphUtils::AddEdge(out_anchor->GetOwnerNode()->GetOutControlAnchor(),
  972. new_trans_nodes.front()->GetInControlAnchor()) != GRAPH_SUCCESS) {
  973. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  974. out_anchor->GetOwnerNode()->GetName().c_str(), out_anchor->GetOwnerNode()->GetType().c_str(),
  975. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str());
  976. return GRAPH_FAILED;
  977. }
  978. }
  979. for (const auto &peer_in_anchor : out_control_peer_in_control_anchors_[index]) {
  980. GE_CHECK_NOTNULL(peer_in_anchor);
  981. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  982. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  983. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  984. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  985. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(),
  986. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  987. peer_in_anchor->GetOwnerNode()->GetType().c_str());
  988. return GRAPH_FAILED;
  989. }
  990. }
  991. for (const auto &peer_out_anchor : in_control_peer_out_control_anchors_[index]) {
  992. GE_CHECK_NOTNULL(peer_out_anchor);
  993. GELOGI("add control edge.src:%s, dst:%s", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  994. new_trans_nodes.front()->GetName().c_str());
  995. if (GraphUtils::AddEdge(peer_out_anchor, new_trans_nodes.front()->GetInControlAnchor()) != GRAPH_SUCCESS) {
  996. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  997. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  998. peer_out_anchor->GetOwnerNode()->GetType().c_str(),
  999. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str());
  1000. return GRAPH_FAILED;
  1001. }
  1002. }
  1003. for (const auto &peer_in_anchor : out_control_peer_in_data_anchors_[index]) {
  1004. GE_CHECK_NOTNULL(peer_in_anchor);
  1005. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  1006. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  1007. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  1008. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  1009. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(),
  1010. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  1011. peer_in_anchor->GetOwnerNode()->GetType().c_str());
  1012. return GRAPH_FAILED;
  1013. }
  1014. }
  1015. for (const auto &peer_in_anchor : out_data_peer_in_control_anchors_[index]) {
  1016. GE_CHECK_NOTNULL(peer_in_anchor);
  1017. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  1018. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  1019. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutDataAnchor(0), peer_in_anchor) != GRAPH_SUCCESS) {
  1020. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed",
  1021. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(),
  1022. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  1023. peer_in_anchor->GetOwnerNode()->GetType().c_str(), peer_in_anchor->GetIdx());
  1024. return GRAPH_FAILED;
  1025. }
  1026. }
  1027. if (sub_graph_has_out_data_peer_in_control_edge_[index]) {
  1028. auto in_anchor = sub_graph_anchors_[index].back().second;
  1029. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  1030. in_anchor->GetOwnerNode()->GetName().c_str());
  1031. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutDataAnchor(0),
  1032. in_anchor->GetOwnerNode()->GetInControlAnchor()) != GRAPH_SUCCESS) {
  1033. return GRAPH_FAILED;
  1034. }
  1035. }
  1036. return GRAPH_SUCCESS;
  1037. }
  1038. bool TransOpWithoutReshapeFusionPass::OpAccuracyAbilityCheck(const OpDescPtr &op_desc) {
  1039. auto instance = GELib::GetInstance();
  1040. if ((instance == nullptr) || (!instance->InitFlag())) {
  1041. GELOGW("GELib is not initialized!");
  1042. return false;
  1043. }
  1044. if (op_desc == nullptr) {
  1045. return false;
  1046. }
  1047. OpsKernelManager &ops_kernel_manager = instance->OpsKernelManagerObj();
  1048. vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
  1049. if (op_infos.empty()) {
  1050. GELOGI("Can not get op info by op type:%s", op_desc->GetType().c_str());
  1051. return false;
  1052. }
  1053. std::string unsupported_reason;
  1054. for (const auto &it : op_infos) {
  1055. auto kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
  1056. auto &kernel_name = it.opKernelLib;
  1057. auto kernel_info_store = kernel_map.find(kernel_name);
  1058. if (kernel_info_store != kernel_map.end()) {
  1059. if (kernel_info_store->second != nullptr &&
  1060. kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason)) {
  1061. op_desc->SetOpEngineName(it.engine);
  1062. op_desc->SetOpKernelLibName(kernel_name);
  1063. GELOGI("Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), it.engine.c_str(),
  1064. op_desc->GetName().c_str());
  1065. return true;
  1066. }
  1067. }
  1068. }
  1069. GELOGI("op %s CheckAccuracySupported failed!reason:%s", op_desc->GetType().c_str(), unsupported_reason.c_str());
  1070. return false;
  1071. }
  1072. bool TransOpWithoutReshapeFusionPass::FusionFormatSupport(Format format) {
  1073. return format == FORMAT_NCHW || format == FORMAT_NHWC || format == FORMAT_FRACTAL_Z || format == FORMAT_NC1HWC0;
  1074. }
  1075. graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphsBetweenNormalNode(
  1076. const OutDataAnchorPtr &out_anchor, std::vector<vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out,
  1077. vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_list) {
  1078. graphStatus ret = GRAPH_SUCCESS;
  1079. if (out_anchor == nullptr) {
  1080. REPORT_INNER_ERROR("E19999", "Param out_anchor is nullptr, check invalid");
  1081. return GRAPH_FAILED;
  1082. }
  1083. for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
  1084. if (peer_in_anchor == nullptr || peer_in_anchor->GetOwnerNode() == nullptr ||
  1085. peer_in_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  1086. continue;
  1087. }
  1088. nodes_list.emplace_back(out_anchor, peer_in_anchor);
  1089. auto peer_in_node = peer_in_anchor->GetOwnerNode();
  1090. GE_CHECK_NOTNULL(peer_in_node);
  1091. if (!IsTransOp(peer_in_node)) {
  1092. sub_graphs_out.push_back(nodes_list);
  1093. nodes_list.pop_back();
  1094. } else {
  1095. for (const auto &peer_out_anchor : peer_in_node->GetAllOutDataAnchors()) {
  1096. ret = GetSubGraphsBetweenNormalNode(peer_out_anchor, sub_graphs_out, nodes_list);
  1097. if (ret != GRAPH_SUCCESS) {
  1098. GELOGE(GRAPH_FAILED, "get all transops between normal node failed!node:%s", peer_in_node->GetName().c_str());
  1099. return GRAPH_FAILED;
  1100. }
  1101. }
  1102. nodes_list.pop_back();
  1103. }
  1104. }
  1105. return GRAPH_SUCCESS;
  1106. }
  1107. bool TransOpWithoutReshapeFusionPass::IsTransOp(const NodePtr &node) {
  1108. // The caller guarantees that the pointer is not null.
  1109. return node->GetType() == CAST || node->GetType() == RESHAPE || node->GetType() == TRANSPOSE ||
  1110. node->GetType() == TRANSPOSED || node->GetType() == TRANSDATA;
  1111. }
  1112. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示