You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

transop_without_reshape_fusion_pass.cc 64 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/transop_without_reshape_fusion_pass.h"
  17. #include <algorithm>
  18. #include <memory>
  19. #include <sstream>
  20. #include <string>
  21. #include <atomic>
  22. #include "common/ge/ge_util.h"
  23. #include "framework/common/ge_inner_error_codes.h"
  24. #include "framework/common/types.h"
  25. #include "common/transop_util.h"
  26. #include "graph/compute_graph.h"
  27. #include "graph/debug/ge_attr_define.h"
  28. #include "graph/ge_tensor.h"
  29. #include "graph/op_desc.h"
  30. #include "graph/utils/graph_utils.h"
  31. #include "graph/utils/node_utils.h"
  32. #include "graph/utils/op_desc_utils.h"
  33. #include "graph/utils/type_utils.h"
  34. #include "init/gelib.h"
  35. namespace {
  36. const char *const kRemainNode = "node_remain";
  37. const int kInvalidFusionOpCount = -1;
  38. const char *const kAttrNameSrcFormat = "src_format";
  39. const char *const kAttrNameDstFormat = "dst_format";
  40. } // namespace
  41. namespace ge {
  42. void TransOpWithoutReshapeFusionPass::SetRemainNode(
  43. const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_anchor) {
  44. auto iter = nodes_anchor.begin();
  45. while (iter != nodes_anchor.end()) {
  46. auto in_anchor = iter->second;
  47. if (in_anchor == nullptr) {
  48. return;
  49. }
  50. auto in_node = in_anchor->GetOwnerNode();
  51. ++iter;
  52. if (in_node == nullptr) {
  53. return;
  54. }
  55. if (!IsTransOp(in_node)) {
  56. continue;
  57. }
  58. auto op_desc = in_node->GetOpDesc();
  59. if (op_desc == nullptr) {
  60. continue;
  61. }
  62. GELOGI("SetRemainNode node is %s", op_desc->GetName().c_str());
  63. GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true),
  64. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", kRemainNode,
  65. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  66. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", kRemainNode,
  67. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  68. return);
  69. }
  70. }
  71. bool TransOpWithoutReshapeFusionPass::FormatContinuousCheck(const OutDataAnchorPtr &out_anchor,
  72. const InDataAnchorPtr &in_anchor) {
  73. if (out_anchor == nullptr || in_anchor == nullptr || in_anchor->GetOwnerNode() == nullptr ||
  74. out_anchor->GetOwnerNode() == nullptr) {
  75. return false;
  76. }
  77. auto in_node = in_anchor->GetOwnerNode();
  78. GE_IF_BOOL_EXEC(in_node == nullptr,
  79. REPORT_INNER_ERROR("E19999", "Param in_anchor's owner node is nullptr, check invalid");
  80. GELOGE(INTERNAL_ERROR, "[Check][Param]Param in_anchor's owner node is nullptr");
  81. return false);
  82. auto in_op = in_node->GetOpDesc();
  83. auto out_owner_node = out_anchor->GetOwnerNode();
  84. GE_IF_BOOL_EXEC(out_owner_node == nullptr,
  85. REPORT_INNER_ERROR("E19999", "Param out_anchor's owner node is nullptr, check invalid");
  86. GELOGE(INTERNAL_ERROR, "[Check][Param] Param out_anchor's owner node is nullptr");
  87. return false);
  88. auto out_op = out_owner_node->GetOpDesc();
  89. GE_IF_BOOL_EXEC(in_op == nullptr,
  90. REPORT_INNER_ERROR("E19999", "Param in_anchor's owner op_desc is nullptr, check invalid");
  91. GELOGE(INTERNAL_ERROR, "[Check][Param] Param in_anchor's owner op_desc is nullptr");
  92. return false);
  93. GE_IF_BOOL_EXEC(out_op == nullptr,
  94. REPORT_INNER_ERROR("E19999", "Param out_anchor's owner op_desc is nullptr, check invalid");
  95. GELOGE(INTERNAL_ERROR, "[Check][Param] Param out_anchor's owner op_desc is nullptr");
  96. return false);
  97. auto in_op_desc = in_op->GetInputDescPtr(in_anchor->GetIdx());
  98. auto out_op_desc = out_op->GetOutputDescPtr(out_anchor->GetIdx());
  99. GE_IF_BOOL_EXEC(in_op_desc == nullptr,
  100. REPORT_INNER_ERROR("E19999", "Param in_anchor corresponding tensor is nullptr, check invalid");
  101. GELOGE(INTERNAL_ERROR, "[Check][Param] Param in_anchor corresponding tensor is nullptr");
  102. return false);
  103. GE_IF_BOOL_EXEC(out_op_desc == nullptr,
  104. REPORT_INNER_ERROR("E19999", "Param out_anchor corresponding tensor is nullptr, check invalid");
  105. GELOGE(INTERNAL_ERROR, "[Check][Param] Param out_anchor corresponding tensor is nullptr");
  106. return false);
  107. if (!ShapeEqualCheck(in_op_desc->GetShape(), out_op_desc->GetShape())) {
  108. return false;
  109. }
  110. if (in_op->GetType() == CAST || out_op->GetType() == CAST) {
  111. return TransOpUtil::CheckPrecisionLoss(in_node);
  112. }
  113. if (in_op_desc->GetFormat() == FORMAT_ND) {
  114. return false;
  115. }
  116. if (out_op_desc->GetFormat() == FORMAT_ND) {
  117. return false;
  118. }
  119. if (in_op_desc->GetFormat() != out_op_desc->GetFormat()) {
  120. return false;
  121. }
  122. return FusionFormatSupport(in_op_desc->GetFormat());
  123. }
  124. graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphNodesInfo() {
  125. vector<bool> sub_graph_has_reshape_node(sub_graph_anchors_.size(), false);
  126. vector<int> transop_num_count(sub_graph_anchors_.size(), 0);
  127. vector<vector<NodePtr>> sub_graph_nodes(sub_graph_anchors_.size());
  128. for (size_t i = 0; i < sub_graph_anchors_.size(); ++i) {
  129. auto nodes_anchor = sub_graph_anchors_[i];
  130. vector<NodePtr> nodes_tmp;
  131. auto iter = nodes_anchor.begin();
  132. auto first_out_anchor = iter->first;
  133. if (first_out_anchor == nullptr) {
  134. continue;
  135. }
  136. nodes_tmp.push_back(first_out_anchor->GetOwnerNode());
  137. while (iter != nodes_anchor.end()) {
  138. auto in_anchor = iter->second;
  139. GE_CHECK_NOTNULL(in_anchor);
  140. auto in_node = in_anchor->GetOwnerNode();
  141. GE_CHECK_NOTNULL(in_node);
  142. if (in_node->GetType() == RESHAPE) {
  143. sub_graph_has_reshape_node[i] = true;
  144. break;
  145. }
  146. if (in_node->GetType() == TRANSPOSE || in_node->GetType() == TRANSPOSED) {
  147. auto input_format = in_node->GetOpDesc()->GetInputDescPtr(0)->GetFormat();
  148. auto output_format = in_node->GetOpDesc()->GetOutputDescPtr(0)->GetFormat();
  149. if (input_format == output_format) {
  150. sub_graph_has_reshape_node[i] = true;
  151. break;
  152. }
  153. }
  154. auto out_anchor = iter->first;
  155. GE_CHECK_NOTNULL(out_anchor);
  156. if (!FormatContinuousCheck(out_anchor, in_anchor)) {
  157. sub_graph_has_reshape_node[i] = true;
  158. break;
  159. }
  160. nodes_tmp.push_back(in_node);
  161. if (IsTransOp(in_node)) {
  162. // count transop num
  163. transop_num_count[i]++;
  164. }
  165. ++iter;
  166. }
  167. sub_graph_nodes[i].swap(nodes_tmp);
  168. if (sub_graph_has_reshape_node[i]) {
  169. SetRemainNode(nodes_anchor);
  170. }
  171. }
  172. sub_graph_has_reshape_node_.swap(sub_graph_has_reshape_node);
  173. transop_num_count_.swap(transop_num_count);
  174. sub_graph_nodes_.swap(sub_graph_nodes);
  175. return GRAPH_SUCCESS;
  176. }
  177. void TransOpWithoutReshapeFusionPass::GetOutDataPeerInControlAnchors(
  178. const size_t index, vector<vector<InControlAnchorPtr>> &out_data_peer_in_control_anchors) {
  179. // The caller guarantees that the index is legal.
  180. for (size_t j = 1; j < sub_graph_anchors_[index].size(); ++j) {
  181. auto nodes_anchor = sub_graph_anchors_[index][j];
  182. auto out_data_anchor = nodes_anchor.first;
  183. GE_CHECK_NOTNULL_JUST_RETURN(out_data_anchor);
  184. for (const auto &peer_in_control_anchor : out_data_anchor->GetPeerInControlAnchors()) {
  185. GE_CHECK_NOTNULL_JUST_RETURN(peer_in_control_anchor);
  186. auto peer_node = peer_in_control_anchor->GetOwnerNode();
  187. if (peer_node == nullptr) {
  188. continue;
  189. }
  190. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  191. if (iter == sub_graph_nodes_[index].end()) {
  192. out_data_peer_in_control_anchors[index].push_back(peer_in_control_anchor);
  193. } else {
  194. sub_graph_has_out_data_peer_in_control_edge_[index] = true;
  195. }
  196. }
  197. }
  198. }
  199. void TransOpWithoutReshapeFusionPass::GetInControlPeerOutControlAnchors(
  200. const size_t index, vector<vector<OutControlAnchorPtr>> &in_control_peer_out_control_anchors) {
  201. // The caller guarantees that the index is legal.
  202. for (size_t j = 1; j < (sub_graph_nodes_[index].size() - 1); ++j) {
  203. auto node = sub_graph_nodes_[index][j];
  204. GE_CHECK_NOTNULL_JUST_RETURN(node);
  205. auto in_control_anchor = node->GetInControlAnchor();
  206. if (in_control_anchor == nullptr) {
  207. continue;
  208. }
  209. for (const auto &peer_out_anchor : in_control_anchor->GetPeerOutControlAnchors()) {
  210. GE_CHECK_NOTNULL_JUST_RETURN(peer_out_anchor);
  211. auto peer_node = peer_out_anchor->GetOwnerNode();
  212. if (peer_node == nullptr) {
  213. continue;
  214. }
  215. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  216. if (iter == sub_graph_nodes_[index].end()) {
  217. in_control_peer_out_control_anchors[index].push_back(peer_out_anchor);
  218. } else {
  219. sub_graph_has_control_edge_[index] = true;
  220. }
  221. }
  222. }
  223. }
  224. void TransOpWithoutReshapeFusionPass::GetOutControlPeerAnchors(
  225. const size_t index, vector<vector<InControlAnchorPtr>> &out_control_peer_in_control_anchors,
  226. vector<vector<InDataAnchorPtr>> &out_control_peer_in_data_anchors) {
  227. for (size_t j = 0; j < sub_graph_nodes_[index].size() - 1; ++j) {
  228. auto node = sub_graph_nodes_[index][j];
  229. GE_CHECK_NOTNULL_JUST_RETURN(node);
  230. auto out_control_anchor = node->GetOutControlAnchor();
  231. GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor);
  232. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) {
  233. GE_CHECK_NOTNULL_JUST_RETURN(peer_in_anchor);
  234. auto peer_node = peer_in_anchor->GetOwnerNode();
  235. if (peer_node == nullptr) {
  236. continue;
  237. }
  238. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  239. if (iter == sub_graph_nodes_[index].end()) {
  240. if (j > 0) {
  241. out_control_peer_in_control_anchors[index].push_back(peer_in_anchor);
  242. }
  243. } else {
  244. sub_graph_has_control_edge_[index] = true;
  245. }
  246. }
  247. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) {
  248. GE_CHECK_NOTNULL_JUST_RETURN(peer_in_anchor);
  249. auto peer_node = peer_in_anchor->GetOwnerNode();
  250. if (peer_node == nullptr) {
  251. continue;
  252. }
  253. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  254. if (iter == sub_graph_nodes_[index].end()) {
  255. if (j > 0) {
  256. out_control_peer_in_data_anchors[index].push_back(peer_in_anchor);
  257. }
  258. } else {
  259. sub_graph_has_control_edge_[index] = true;
  260. }
  261. }
  262. }
  263. }
  264. void TransOpWithoutReshapeFusionPass::GetControlAnchors() {
  265. vector<vector<OutControlAnchorPtr>> in_control_peer_out_control_anchors(sub_graph_nodes_.size());
  266. vector<vector<InControlAnchorPtr>> out_control_peer_in_control_anchors(sub_graph_nodes_.size());
  267. vector<vector<InDataAnchorPtr>> out_control_peer_in_data_anchors(sub_graph_nodes_.size());
  268. vector<vector<InControlAnchorPtr>> out_data_peer_in_control_anchors(sub_graph_nodes_.size());
  269. vector<bool> sub_graph_has_control_edge(sub_graph_nodes_.size(), false);
  270. sub_graph_has_control_edge_.swap(sub_graph_has_control_edge);
  271. vector<bool> sub_graph_has_out_data_peer_in_control_edge(sub_graph_nodes_.size(), false);
  272. sub_graph_has_out_data_peer_in_control_edge_.swap(sub_graph_has_out_data_peer_in_control_edge);
  273. for (size_t i = 0; i < sub_graph_nodes_.size(); ++i) {
  274. if (sub_graph_has_reshape_node_[i]) {
  275. continue;
  276. }
  277. GetOutDataPeerInControlAnchors(i, out_data_peer_in_control_anchors);
  278. GetInControlPeerOutControlAnchors(i, in_control_peer_out_control_anchors);
  279. GetOutControlPeerAnchors(i, out_control_peer_in_control_anchors, out_control_peer_in_data_anchors);
  280. }
  281. in_control_peer_out_control_anchors_.swap(in_control_peer_out_control_anchors);
  282. out_control_peer_in_control_anchors_.swap(out_control_peer_in_control_anchors);
  283. out_control_peer_in_data_anchors_.swap(out_control_peer_in_data_anchors);
  284. out_data_peer_in_control_anchors_.swap(out_data_peer_in_control_anchors);
  285. }
  286. void TransOpWithoutReshapeFusionPass::EraseInvalidAnchorsPair() {
  287. auto sub_graph_iter = sub_graph_anchors_.begin();
  288. while (sub_graph_iter != sub_graph_anchors_.end()) {
  289. if (sub_graph_iter->size() <= 1) {
  290. sub_graph_iter = sub_graph_anchors_.erase(sub_graph_iter);
  291. } else {
  292. ++sub_graph_iter;
  293. }
  294. }
  295. }
  296. void TransOpWithoutReshapeFusionPass::UpdateOutputName(const OutDataAnchorPtr &out_anchor,
  297. const InDataAnchorPtr &old_peer_in_anchor,
  298. const NodePtr &in_owner_node) {
  299. if (out_anchor == nullptr || old_peer_in_anchor == nullptr || in_owner_node == nullptr) {
  300. GELOGI("out_anchor or old_peer_in_anchor or in_owner_node is nullptr");
  301. return;
  302. }
  303. auto out_owner_node = out_anchor->GetOwnerNode();
  304. GE_CHECK_NOTNULL_JUST_RETURN(out_owner_node);
  305. GE_CHECK_NOTNULL_JUST_RETURN(old_peer_in_anchor->GetOwnerNode());
  306. auto old_peer_in_name = old_peer_in_anchor->GetOwnerNode()->GetName();
  307. auto output_op = out_owner_node->GetOpDesc();
  308. GE_CHECK_NOTNULL_JUST_RETURN(output_op);
  309. auto output_names = output_op->GetAllOutputName();
  310. auto old_peer_in_name_iter = output_names.find(old_peer_in_name);
  311. if (old_peer_in_name_iter != output_names.end()) {
  312. output_names.erase(old_peer_in_name_iter);
  313. }
  314. output_names[in_owner_node->GetName()] = out_anchor->GetIdx();
  315. if (!output_op->UpdateOutputName(output_names)) {
  316. GELOGW("output_op UpdateOutputName failed");
  317. }
  318. }
  319. void TransOpWithoutReshapeFusionPass::UpdateInputName(const OutDataAnchorPtr &old_peer_out_anchor,
  320. const InDataAnchorPtr &in_anchor, const NodePtr &out_owner_node) {
  321. if (old_peer_out_anchor == nullptr || in_anchor == nullptr || out_owner_node == nullptr) {
  322. GELOGI("old_peer_out_anchor or in_anchor or out_owner_node is nullptr");
  323. return;
  324. }
  325. auto old_node = old_peer_out_anchor->GetOwnerNode();
  326. GE_CHECK_NOTNULL_JUST_RETURN(old_node);
  327. auto old_peer_out_name = old_node->GetName();
  328. auto in_owner_node = in_anchor->GetOwnerNode();
  329. GE_CHECK_NOTNULL_JUST_RETURN(in_owner_node);
  330. auto input_op = in_owner_node->GetOpDesc();
  331. GE_CHECK_NOTNULL_JUST_RETURN(input_op);
  332. auto input_names = input_op->GetAllInputName();
  333. auto old_peer_out_name_iter = input_names.find(old_peer_out_name);
  334. if (old_peer_out_name_iter != input_names.end()) {
  335. input_names.erase(old_peer_out_name_iter);
  336. }
  337. input_names[out_owner_node->GetName()] = in_anchor->GetIdx();
  338. input_op->UpdateInputName(input_names);
  339. }
  340. graphStatus TransOpWithoutReshapeFusionPass::RelinkSubGraphControlEdges(
  341. const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  342. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  343. auto out_anchor = begin_anchors_pair.first;
  344. GE_CHECK_NOTNULL(out_anchor);
  345. auto out_owner_node = out_anchor->GetOwnerNode();
  346. GE_CHECK_NOTNULL(out_owner_node);
  347. auto in_anchor = end_anchors_pair.second;
  348. GE_CHECK_NOTNULL(in_anchor);
  349. auto in_owner_node = in_anchor->GetOwnerNode();
  350. GE_CHECK_NOTNULL(in_owner_node);
  351. if (sub_graph_has_control_edge_[index]) {
  352. GELOGI("add control edge.src:%s, dst:%s", out_owner_node->GetName().c_str(), in_owner_node->GetName().c_str());
  353. if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), in_owner_node->GetInControlAnchor()) !=
  354. GRAPH_SUCCESS) {
  355. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  356. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  357. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str());
  358. GELOGE(GRAPH_FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  359. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  360. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str());
  361. return GRAPH_FAILED;
  362. }
  363. }
  364. if (sub_graph_has_out_data_peer_in_control_edge_[index]) {
  365. GELOGI("add out data 2 in contorl edge.src:%s, dst:%s", out_owner_node->GetName().c_str(),
  366. in_owner_node->GetName().c_str());
  367. if (GraphUtils::AddEdge(out_anchor, in_owner_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
  368. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  369. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  370. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str());
  371. GELOGE(GRAPH_FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  372. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  373. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str());
  374. return GRAPH_FAILED;
  375. }
  376. }
  377. return GRAPH_SUCCESS;
  378. }
  379. graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdgesWhenDescNotChanged(
  380. const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  381. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  382. if (RelinkSubGraphControlEdges(begin_anchors_pair, end_anchors_pair, index) != GRAPH_SUCCESS) {
  383. return GRAPH_FAILED;
  384. }
  385. auto out_anchor = begin_anchors_pair.first;
  386. GE_CHECK_NOTNULL(out_anchor);
  387. auto out_owner_node = out_anchor->GetOwnerNode();
  388. GE_CHECK_NOTNULL(out_owner_node);
  389. auto in_anchor = end_anchors_pair.second;
  390. GE_CHECK_NOTNULL(in_anchor);
  391. auto in_owner_node = in_anchor->GetOwnerNode();
  392. GE_CHECK_NOTNULL(in_owner_node);
  393. // can not remove old control edge
  394. for (const auto &peer_in_anchor : out_control_peer_in_control_anchors_[index]) {
  395. GE_CHECK_NOTNULL(peer_in_anchor);
  396. GELOGI("add control edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  397. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  398. if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  399. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  400. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  401. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  402. peer_in_anchor->GetOwnerNode()->GetType().c_str());
  403. GELOGE(GRAPH_FAILED, "[Add]ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  404. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  405. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  406. peer_in_anchor->GetOwnerNode()->GetType().c_str());
  407. return GRAPH_FAILED;
  408. }
  409. }
  410. for (const auto &peer_out_anchor : in_control_peer_out_control_anchors_[index]) {
  411. GE_CHECK_NOTNULL(peer_out_anchor);
  412. GELOGI("add control edge.src:%s, src idx:%d, dst:%s", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  413. peer_out_anchor->GetIdx(), in_owner_node->GetName().c_str());
  414. if (GraphUtils::AddEdge(peer_out_anchor, in_owner_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
  415. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  416. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  417. peer_out_anchor->GetOwnerNode()->GetType().c_str(),
  418. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str());
  419. GELOGE(GRAPH_FAILED, "[Add]ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  420. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  421. peer_out_anchor->GetOwnerNode()->GetType().c_str(),
  422. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str());
  423. return GRAPH_FAILED;
  424. }
  425. }
  426. for (const auto &peer_in_anchor : out_control_peer_in_data_anchors_[index]) {
  427. GE_CHECK_NOTNULL(peer_in_anchor);
  428. GELOGI("add out control 2 in data edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  429. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  430. if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  431. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  432. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  433. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  434. peer_in_anchor->GetOwnerNode()->GetType().c_str());
  435. GELOGE(GRAPH_FAILED, "[Add]ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  436. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  437. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  438. peer_in_anchor->GetOwnerNode()->GetType().c_str());
  439. return GRAPH_FAILED;
  440. }
  441. }
  442. for (const auto &peer_in_anchor : out_data_peer_in_control_anchors_[index]) {
  443. GE_CHECK_NOTNULL(peer_in_anchor);
  444. GELOGI("add out data 2 in control edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  445. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  446. if (GraphUtils::AddEdge(out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
  447. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  448. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  449. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  450. peer_in_anchor->GetOwnerNode()->GetType().c_str());
  451. GELOGE(GRAPH_FAILED, "[Add]ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  452. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(),
  453. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  454. peer_in_anchor->GetOwnerNode()->GetType().c_str());
  455. return GRAPH_FAILED;
  456. }
  457. }
  458. return GRAPH_SUCCESS;
  459. }
  460. graphStatus TransOpWithoutReshapeFusionPass::RelinkNodesWhenDescNotChanged(
  461. const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  462. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  463. auto out_anchor = begin_anchors_pair.first;
  464. GE_CHECK_NOTNULL(out_anchor);
  465. auto out_owner_node = out_anchor->GetOwnerNode();
  466. GE_CHECK_NOTNULL(out_owner_node);
  467. auto in_anchor = end_anchors_pair.second;
  468. GE_CHECK_NOTNULL(in_anchor);
  469. auto in_owner_node = in_anchor->GetOwnerNode();
  470. GE_CHECK_NOTNULL(in_owner_node);
  471. GELOGI("remove edge.src %s, src idx:%d, dst:%s, dst idx:%d",
  472. end_anchors_pair.first->GetOwnerNode()->GetName().c_str(), end_anchors_pair.first->GetIdx(),
  473. in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  474. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(end_anchors_pair.first, in_anchor),
  475. "[Remove][Edge] between %s(%s)(index:%d) and %s(%s)(index:%d) failed",
  476. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(), out_anchor->GetIdx(),
  477. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str(), in_anchor->GetIdx());
  478. GELOGI("relink node.src node:%s, src idx:%d, dst node:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  479. out_anchor->GetIdx(), in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  480. if (GraphUtils::AddEdge(out_anchor, in_anchor) != GRAPH_SUCCESS) {
  481. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  482. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(), out_anchor->GetIdx(),
  483. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str(), in_anchor->GetIdx());
  484. GELOGE(GRAPH_FAILED, "[Add][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  485. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(), out_anchor->GetIdx(),
  486. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str(), in_anchor->GetIdx());
  487. return GRAPH_FAILED;
  488. } else {
  489. auto old_peer_in_anchor = begin_anchors_pair.second;
  490. UpdateOutputName(out_anchor, old_peer_in_anchor, in_owner_node);
  491. auto old_peer_out_anchor = end_anchors_pair.first;
  492. UpdateInputName(old_peer_out_anchor, in_anchor, out_owner_node);
  493. }
  494. return RelinkControlEdgesWhenDescNotChanged(begin_anchors_pair, end_anchors_pair, index);
  495. }
  496. OpDescPtr TransOpWithoutReshapeFusionPass::GetFormatTransferOp(const GeTensorDesc &format_trans_input_desc,
  497. const GeTensorDesc &format_trans_output_desc) {
  498. static std::atomic_long atomic_fusion_format_transfer_op_count(1);
  499. auto fusion_format_transfer_op_count = atomic_fusion_format_transfer_op_count.fetch_add(1);
  500. std::stringstream format_transfer_op_name;
  501. format_transfer_op_name << "fusion_format_transfer_" << fusion_format_transfer_op_count;
  502. OpDescPtr format_transfer_op = MakeShared<OpDesc>(format_transfer_op_name.str().c_str(), TRANSDATA);
  503. if (format_transfer_op == nullptr) {
  504. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  505. GELOGE(INTERNAL_ERROR, "[New][OpDesc] failed");
  506. return nullptr;
  507. }
  508. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(format_transfer_op, ATTR_NAME_INPUT_FORMAT,
  509. static_cast<int64_t>(format_trans_input_desc.GetFormat())),
  510. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_INPUT_FORMAT.c_str(),
  511. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  512. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_INPUT_FORMAT.c_str(),
  513. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  514. return nullptr);
  515. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(format_transfer_op, ATTR_NAME_OUTPUT_FORMAT,
  516. static_cast<int64_t>(format_trans_output_desc.GetFormat())),
  517. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_OUTPUT_FORMAT.c_str(),
  518. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  519. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_OUTPUT_FORMAT.c_str(),
  520. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  521. return nullptr);
  522. string src_format = TypeUtils::FormatToSerialString(format_trans_input_desc.GetFormat());
  523. string dst_format = TypeUtils::FormatToSerialString(format_trans_output_desc.GetFormat());
  524. GE_IF_BOOL_EXEC(!AttrUtils::SetStr(format_transfer_op, kAttrNameSrcFormat, src_format),
  525. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", kAttrNameSrcFormat,
  526. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  527. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", kAttrNameSrcFormat,
  528. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  529. return nullptr);
  530. GE_IF_BOOL_EXEC(!AttrUtils::SetStr(format_transfer_op, kAttrNameDstFormat, dst_format),
  531. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", kAttrNameDstFormat,
  532. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  533. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", kAttrNameDstFormat,
  534. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  535. return nullptr);
  536. GE_IF_BOOL_EXEC(format_transfer_op->AddInputDesc(format_trans_input_desc) != GRAPH_SUCCESS,
  537. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
  538. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  539. GELOGE(INTERNAL_ERROR, "[Add][InputDesc] to op:%s(%s) failed",
  540. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  541. return nullptr);
  542. GE_IF_BOOL_EXEC(format_transfer_op->AddOutputDesc(format_trans_output_desc) != GRAPH_SUCCESS,
  543. REPORT_CALL_ERROR("E19999", "Add ouput desc to op:%s(%s) failed",
  544. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  545. GELOGE(INTERNAL_ERROR, "[Add][OutputDesc] to op:%s(%s) failed",
  546. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  547. return nullptr);
  548. GE_IF_BOOL_EXEC(!ge::AttrUtils::SetBool(format_transfer_op, ATTR_NEED_COMPILE, true),
  549. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NEED_COMPILE.c_str(),
  550. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  551. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NEED_COMPILE.c_str(),
  552. format_transfer_op->GetName().c_str(), format_transfer_op->GetType().c_str());
  553. return nullptr);
  554. return format_transfer_op;
  555. }
  556. OpDescPtr TransOpWithoutReshapeFusionPass::GetCastOp(const GeTensorDesc &cast_input_desc,
  557. const GeTensorDesc &cast_output_desc) {
  558. static std::atomic_long atomic_fusion_cast_op_count(1);
  559. auto fusion_cast_op_count = atomic_fusion_cast_op_count.fetch_add(1);
  560. std::stringstream cast_op_name;
  561. cast_op_name << "fusion_cast_op_" << fusion_cast_op_count;
  562. auto node_op = ge::OperatorFactory::CreateOperator(cast_op_name.str(), CAST);
  563. auto cast_op = ge::OpDescUtils::GetOpDescFromOperator(node_op);
  564. node_op.BreakConnect();
  565. if (cast_op == nullptr) {
  566. REPORT_CALL_ERROR("E19999", "Create operator:%s(%s) failed", cast_op_name.str().c_str(), CAST);
  567. GELOGE(INTERNAL_ERROR, "[Create][Operator] %s(%s) failed", cast_op_name.str().c_str(), CAST);
  568. return nullptr;
  569. }
  570. const int default_input_index = 0;
  571. const int default_output_index = 0;
  572. if (cast_op->GetInputsSize() == 0) {
  573. GE_IF_BOOL_EXEC(cast_op->AddInputDesc(cast_input_desc) != GRAPH_SUCCESS,
  574. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
  575. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  576. GELOGE(INTERNAL_ERROR, "[Add][InputDesc] to op:%s(%s) failed",
  577. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  578. return nullptr);
  579. } else {
  580. GE_IF_BOOL_EXEC(cast_op->UpdateInputDesc(default_input_index, cast_input_desc) != GRAPH_SUCCESS,
  581. REPORT_CALL_ERROR("E19999", "Update input:%d desc of op:%s(%s) failed", default_input_index,
  582. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  583. GELOGE(INTERNAL_ERROR, "[Update][InputDesc] of op:%s(%s) failed, input index:%d",
  584. cast_op->GetName().c_str(), cast_op->GetType().c_str(), default_input_index);
  585. return nullptr);
  586. }
  587. if (cast_op->GetOutputsSize() == 0) {
  588. GE_IF_BOOL_EXEC(cast_op->AddOutputDesc(cast_output_desc) != GRAPH_SUCCESS,
  589. REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed",
  590. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  591. GELOGE(INTERNAL_ERROR, "[Add][OutputDesc] to op:%s(%s) failed",
  592. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  593. return nullptr);
  594. } else {
  595. GE_IF_BOOL_EXEC(cast_op->UpdateOutputDesc(default_output_index, cast_output_desc) != GRAPH_SUCCESS,
  596. REPORT_CALL_ERROR("E19999", "Update output:%d desc of op:%s(%s) failed", default_output_index,
  597. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  598. GELOGE(INTERNAL_ERROR, "[Update][OutputDesc] of op:%s(%s) failed, output index:%d",
  599. cast_op->GetName().c_str(), cast_op->GetType().c_str(), default_output_index);
  600. return nullptr);
  601. }
  602. if (!AttrUtils::SetInt(cast_op, CAST_ATTR_DST_TYPE, static_cast<int64_t>(cast_output_desc.GetDataType()))) {
  603. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", CAST_ATTR_DST_TYPE.c_str(),
  604. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  605. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", CAST_ATTR_DST_TYPE.c_str(),
  606. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  607. return nullptr;
  608. }
  609. if (!AttrUtils::SetBool(cast_op, ATTR_NEED_COMPILE, true)) {
  610. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NEED_COMPILE.c_str(),
  611. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  612. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NEED_COMPILE.c_str(),
  613. cast_op->GetName().c_str(), cast_op->GetType().c_str());
  614. return nullptr;
  615. }
  616. return cast_op;
  617. }
  618. bool TransOpWithoutReshapeFusionPass::InsertCastFirstCheck(const GeTensorDesc &out_desc,
  619. const GeTensorDesc &in_desc) const {
  620. return out_desc.GetDataType() != in_desc.GetDataType() && out_desc.GetDataType() != DT_FLOAT16 &&
  621. in_desc.GetDataType() == DT_FLOAT16;
  622. }
  623. void TransOpWithoutReshapeFusionPass::GetFormatTransferDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc,
  624. GeTensorDesc &format_transfer_input,
  625. GeTensorDesc &format_transfer_output) {
  626. bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  627. if (insert_cast_first) {
  628. format_transfer_input = out_desc;
  629. format_transfer_input.SetDataType(in_desc.GetDataType());
  630. format_transfer_output = in_desc;
  631. } else {
  632. format_transfer_input = out_desc;
  633. format_transfer_output = in_desc;
  634. format_transfer_output.SetDataType(out_desc.GetDataType());
  635. }
  636. }
  637. void TransOpWithoutReshapeFusionPass::GetCastOpDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc,
  638. GeTensorDesc &cast_input, GeTensorDesc &cast_output) {
  639. bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  640. if (insert_cast_first) {
  641. cast_input = out_desc;
  642. cast_output = out_desc;
  643. cast_output.SetDataType(in_desc.GetDataType());
  644. } else {
  645. cast_input = in_desc;
  646. cast_input.SetDataType(out_desc.GetDataType());
  647. cast_output = in_desc;
  648. }
  649. }
  650. void TransOpWithoutReshapeFusionPass::GetBeginOutDescAndEndInDesc(const int index, GeTensorDesc &out_desc,
  651. GeTensorDesc &in_desc) {
  652. auto nodes_anchor = sub_graph_anchors_[index];
  653. auto out_peer_anchor = nodes_anchor.front().second;
  654. GE_CHECK_NOTNULL_JUST_RETURN(out_peer_anchor);
  655. auto out_owner_node = out_peer_anchor->GetOwnerNode();
  656. GE_CHECK_NOTNULL_JUST_RETURN(out_owner_node);
  657. auto out_peer_op_desc = out_owner_node->GetOpDesc();
  658. GE_IF_BOOL_EXEC(out_peer_op_desc == nullptr,
  659. GELOGE(INTERNAL_ERROR, "[Get][OpDesc] failed, out_peer_op_desc is nullptr"); return);
  660. out_desc = out_peer_op_desc->GetInputDesc(out_peer_anchor->GetIdx());
  661. auto in_peer_anchor = nodes_anchor.back().first;
  662. GE_CHECK_NOTNULL_JUST_RETURN(in_peer_anchor);
  663. auto in_owner_node = in_peer_anchor->GetOwnerNode();
  664. GE_CHECK_NOTNULL_JUST_RETURN(in_owner_node);
  665. auto in_peer_op_desc = in_owner_node->GetOpDesc();
  666. GE_IF_BOOL_EXEC(in_peer_op_desc == nullptr,
  667. GELOGE(INTERNAL_ERROR, "[Get][OpDesc] failed, in_peer_op_desc is nullptr"); return);
  668. in_desc = in_peer_op_desc->GetOutputDesc(in_peer_anchor->GetIdx());
  669. }
  670. graphStatus TransOpWithoutReshapeFusionPass::FormatFusion(const int index, OpDescPtr &format_transfer_op,
  671. int32_t &fusion_op_count, bool &fusion_continue) {
  672. GeTensorDesc out_desc;
  673. GeTensorDesc in_desc;
  674. GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);
  675. GeTensorDesc format_transfer_input;
  676. GeTensorDesc format_transfer_output;
  677. GetFormatTransferDesc(out_desc, in_desc, format_transfer_input, format_transfer_output);
  678. if (out_desc.GetFormat() == in_desc.GetFormat() &&
  679. (!ShapeEqualCheck(out_desc.GetShape(), in_desc.GetShape()) ||
  680. !ShapeEqualCheck(out_desc.GetOriginShape(), in_desc.GetOriginShape()))) {
  681. SetRemainNode(sub_graph_anchors_[index]);
  682. return GRAPH_SUCCESS;
  683. }
  684. if (out_desc.GetFormat() != in_desc.GetFormat() && FusionFormatSupport(out_desc.GetFormat()) &&
  685. FusionFormatSupport(in_desc.GetFormat())) {
  686. // create format transop
  687. format_transfer_op = GetFormatTransferOp(format_transfer_input, format_transfer_output);
  688. if (format_transfer_op == nullptr) {
  689. return GRAPH_FAILED;
  690. }
  691. if (OpAccuracyAbilityCheck(format_transfer_op)) {
  692. ++fusion_op_count;
  693. GELOGI("support format transfer op %s", format_transfer_op->GetName().c_str());
  694. } else {
  695. GELOGW("ability not support.src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
  696. format_transfer_input.GetFormat(), format_transfer_input.GetDataType(), format_transfer_output.GetFormat(),
  697. format_transfer_output.GetDataType());
  698. fusion_op_count = kInvalidFusionOpCount;
  699. }
  700. } else if (out_desc.GetFormat() != in_desc.GetFormat()) {
  701. SetRemainNode(sub_graph_anchors_[index]);
  702. return GRAPH_SUCCESS;
  703. }
  704. fusion_continue = true;
  705. return GRAPH_SUCCESS;
  706. }
  707. graphStatus TransOpWithoutReshapeFusionPass::DataTypeFusion(const int index, OpDescPtr &cast_op,
  708. int32_t &fusion_op_count) {
  709. GeTensorDesc out_desc;
  710. GeTensorDesc in_desc;
  711. GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);
  712. GeTensorDesc cast_input;
  713. GeTensorDesc cast_output;
  714. GetCastOpDesc(out_desc, in_desc, cast_input, cast_output);
  715. if (fusion_op_count != kInvalidFusionOpCount && out_desc.GetDataType() != in_desc.GetDataType()) {
  716. // create cast op
  717. cast_op = GetCastOp(cast_input, cast_output);
  718. if (cast_op == nullptr) {
  719. fusion_op_count = kInvalidFusionOpCount;
  720. return GRAPH_FAILED;
  721. }
  722. if (OpAccuracyAbilityCheck(cast_op)) {
  723. ++fusion_op_count;
  724. GELOGI("support cast op %s. src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
  725. cast_op->GetName().c_str(), cast_input.GetFormat(), cast_input.GetDataType(), cast_output.GetFormat(),
  726. cast_output.GetDataType());
  727. } else {
  728. GELOGW("ability not support.src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
  729. cast_input.GetFormat(), cast_input.GetDataType(), cast_output.GetFormat(), cast_output.GetDataType());
  730. fusion_op_count = kInvalidFusionOpCount;
  731. }
  732. }
  733. return GRAPH_SUCCESS;
  734. }
  735. graphStatus TransOpWithoutReshapeFusionPass::TransOpFuseHandle(const ComputeGraphPtr &graph, const int index) {
  736. bool fusion_continue = false;
  737. OpDescPtr format_transfer_op = nullptr;
  738. int32_t fusion_op_count = 0;
  739. auto fortmat_fusion_ret = FormatFusion(index, format_transfer_op, fusion_op_count, fusion_continue);
  740. if (fortmat_fusion_ret != GRAPH_SUCCESS || !fusion_continue) {
  741. SetRemainNode(sub_graph_anchors_[index]);
  742. return GRAPH_SUCCESS;
  743. }
  744. OpDescPtr cast_op = nullptr;
  745. if (DataTypeFusion(index, cast_op, fusion_op_count) != GRAPH_SUCCESS) {
  746. SetRemainNode(sub_graph_anchors_[index]);
  747. return GRAPH_SUCCESS;
  748. }
  749. if (fusion_op_count != kInvalidFusionOpCount && fusion_op_count < transop_num_count_[index]) {
  750. GeTensorDesc out_desc;
  751. GeTensorDesc in_desc;
  752. GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);
  753. bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  754. if (InsertNewTransOp(graph, cast_op, format_transfer_op, index, insert_cast_first) != GRAPH_SUCCESS) {
  755. return GRAPH_FAILED;
  756. }
  757. } else {
  758. // remain all nodes
  759. SetRemainNode(sub_graph_anchors_[index]);
  760. }
  761. return GRAPH_SUCCESS;
  762. }
  763. void TransOpWithoutReshapeFusionPass::RemoveNousedNodes(const ComputeGraphPtr &graph) {
  764. if (graph == nullptr) {
  765. return;
  766. }
  767. for (size_t i = 0; i < sub_graph_nodes_.size(); ++i) {
  768. if (sub_graph_has_reshape_node_[i]) {
  769. continue;
  770. }
  771. for (const auto &node : sub_graph_nodes_[i]) {
  772. GE_CHECK_NOTNULL_JUST_RETURN(node);
  773. // remove nodes
  774. if (!IsTransOp(node)) {
  775. continue;
  776. }
  777. auto op_desc = node->GetOpDesc();
  778. GE_CHECK_NOTNULL_JUST_RETURN(op_desc);
  779. bool node_remain_flag = op_desc->TryGetExtAttr(kRemainNode, false);
  780. if (node_remain_flag) {
  781. continue;
  782. }
  783. GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true),
  784. GELOGE(INTERNAL_ERROR, "[Set][ExtAttr] for op:%s failed", op_desc->GetName().c_str()); return);
  785. GELOGI("remove node:%s", node->GetName().c_str());
  786. if (GraphUtils::IsolateNode(node, {0}) != GRAPH_SUCCESS) {
  787. GELOGW("Isolate node: %s failed.", node->GetName().c_str());
  788. continue;
  789. }
  790. if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
  791. GELOGW("Remove node: %s failed.", node->GetName().c_str());
  792. continue;
  793. }
  794. }
  795. }
  796. }
  797. graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) {
  798. GELOGI("[TransOpWithoutReshapeFusionPass]: optimize begin.");
  799. if (graph == nullptr) {
  800. return GRAPH_SUCCESS;
  801. }
  802. for (const auto &node : graph->GetDirectNode()) {
  803. GE_CHECK_NOTNULL(node);
  804. if (IsTransOp(node)) {
  805. continue;
  806. }
  807. bool is_unknown = false;
  808. auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown);
  809. if (ret != GRAPH_SUCCESS) {
  810. GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(),
  811. node->GetType().c_str());
  812. continue;
  813. }
  814. if (is_unknown) {
  815. GELOGI("Current node %s, type %s is unknown shape which should be skip.", node->GetName().c_str(),
  816. node->GetType().c_str());
  817. continue;
  818. }
  819. GELOGI("Current normal node name: %s, type: %s.", node->GetName().c_str(), node->GetType().c_str());
  820. for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
  821. GE_CHECK_NOTNULL(out_anchor);
  822. vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>> sub_graph_anchors;
  823. vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> nodes_list;
  824. if (GetSubGraphsBetweenNormalNode(out_anchor, sub_graph_anchors, nodes_list) != GRAPH_SUCCESS) {
  825. GELOGW("get transops failed!");
  826. continue;
  827. }
  828. sub_graph_anchors_.swap(sub_graph_anchors);
  829. EraseInvalidAnchorsPair();
  830. if (sub_graph_anchors_.empty()) {
  831. continue;
  832. }
  833. // check reshape node
  834. if (GetSubGraphNodesInfo() != GRAPH_SUCCESS) {
  835. continue;
  836. }
  837. // save control edge
  838. GetControlAnchors();
  839. if (TransOpFuse(graph) != GRAPH_SUCCESS) {
  840. return GRAPH_FAILED;
  841. }
  842. }
  843. }
  844. GELOGI("[TransOpWithoutReshapeFusionPass]: Optimize end.");
  845. return GRAPH_SUCCESS;
  846. }
  847. bool TransOpWithoutReshapeFusionPass::DescEqualCheck(ConstGeTensorDescPtr &desc_src,
  848. ConstGeTensorDescPtr &desc_dst) const {
  849. if (desc_src == nullptr || desc_dst == nullptr) {
  850. return false;
  851. }
  852. if (desc_src->GetFormat() != desc_dst->GetFormat() || desc_src->GetDataType() != desc_dst->GetDataType()) {
  853. return false;
  854. }
  855. if (!ShapeEqualCheck(desc_src->GetShape(), desc_dst->GetShape())) {
  856. return false;
  857. }
  858. return ShapeEqualCheck(desc_src->GetOriginShape(), desc_dst->GetOriginShape());
  859. }
  860. bool TransOpWithoutReshapeFusionPass::ShapeEqualCheck(const GeShape &src, const GeShape &dst) const {
  861. if (src.GetDims().size() != dst.GetDims().size()) {
  862. return false;
  863. }
  864. for (size_t i = 0; i < src.GetDims().size(); ++i) {
  865. if (src.GetDim(i) != dst.GetDim(i)) {
  866. return false;
  867. }
  868. }
  869. return true;
  870. }
  871. graphStatus TransOpWithoutReshapeFusionPass::TransOpFuse(const ComputeGraphPtr &graph) {
  872. for (size_t i = 0; i < sub_graph_anchors_.size(); ++i) {
  873. if (sub_graph_has_reshape_node_[i]) {
  874. continue;
  875. }
  876. auto nodes_anchor = sub_graph_anchors_[i];
  877. auto out_anchor = nodes_anchor.front().first;
  878. GE_CHECK_NOTNULL(out_anchor);
  879. auto out_op_desc = out_anchor->GetOwnerNode()->GetOpDesc();
  880. GE_CHECK_NOTNULL(out_op_desc);
  881. auto out_desc = out_op_desc->GetOutputDescPtr(out_anchor->GetIdx());
  882. GE_CHECK_NOTNULL(out_desc);
  883. auto in_anchor = nodes_anchor.back().second;
  884. GE_CHECK_NOTNULL(in_anchor);
  885. auto in_op_desc = in_anchor->GetOwnerNode()->GetOpDesc();
  886. GE_CHECK_NOTNULL(in_op_desc);
  887. auto in_desc = in_op_desc->GetInputDescPtr(in_anchor->GetIdx());
  888. GE_CHECK_NOTNULL(in_desc);
  889. if (FusionFormatSupport(out_desc->GetFormat()) && DescEqualCheck(out_desc, in_desc)) {
  890. // relink begin_out to end_in
  891. if (RelinkNodesWhenDescNotChanged(nodes_anchor.front(), nodes_anchor.back(), static_cast<int>(i)) !=
  892. GRAPH_SUCCESS) {
  893. return GRAPH_FAILED;
  894. }
  895. } else {
  896. if (TransOpFuseHandle(graph, static_cast<int>(i)) != GRAPH_SUCCESS) {
  897. return GRAPH_FAILED;
  898. }
  899. }
  900. }
  901. RemoveNousedNodes(graph);
  902. return GRAPH_SUCCESS;
  903. }
  904. graphStatus TransOpWithoutReshapeFusionPass::AddTransNode(const ComputeGraphPtr &graph, const OpDescPtr &transop,
  905. NodePtr &trans_node) {
  906. if (graph == nullptr) {
  907. return GRAPH_SUCCESS;
  908. }
  909. if (transop == nullptr) {
  910. return GRAPH_SUCCESS;
  911. }
  912. trans_node = graph->AddNode(transop);
  913. if (trans_node == nullptr) {
  914. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  915. transop->GetName().c_str(), transop->GetType().c_str(), graph->GetName().c_str());
  916. GELOGE(GRAPH_FAILED, "[Add][Node] %s(%s) to graph:%s failed",
  917. transop->GetName().c_str(), transop->GetType().c_str(), graph->GetName().c_str());
  918. return GRAPH_FAILED;
  919. }
  920. return GRAPH_SUCCESS;
  921. }
  922. graphStatus TransOpWithoutReshapeFusionPass::GetTransNode(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
  923. const OpDescPtr &format_transfer_op,
  924. const bool insert_cast_first,
  925. std::vector<NodePtr> &new_trans_nodes) {
  926. NodePtr format_transfer_node;
  927. if (AddTransNode(graph, format_transfer_op, format_transfer_node) != GRAPH_SUCCESS) {
  928. return GRAPH_FAILED;
  929. }
  930. NodePtr cast_node;
  931. if (AddTransNode(graph, cast_op, cast_node) != GRAPH_SUCCESS) {
  932. return GRAPH_FAILED;
  933. }
  934. if (insert_cast_first) {
  935. if (cast_node != nullptr) {
  936. new_trans_nodes.push_back(cast_node);
  937. }
  938. if (format_transfer_node != nullptr) {
  939. new_trans_nodes.push_back(format_transfer_node);
  940. }
  941. } else {
  942. if (format_transfer_node != nullptr) {
  943. new_trans_nodes.push_back(format_transfer_node);
  944. }
  945. if (cast_node != nullptr) {
  946. new_trans_nodes.push_back(cast_node);
  947. }
  948. }
  949. return GRAPH_SUCCESS;
  950. }
  951. graphStatus TransOpWithoutReshapeFusionPass::InsertNewTransOp(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
  952. const OpDescPtr &format_transfer_op, const int index,
  953. const bool insert_cast_first) {
  954. std::vector<NodePtr> new_trans_nodes;
  955. if (GetTransNode(graph, cast_op, format_transfer_op, insert_cast_first, new_trans_nodes) != GRAPH_SUCCESS) {
  956. return GRAPH_FAILED;
  957. }
  958. if (new_trans_nodes.empty()) {
  959. GELOGI("No new trans node. Do not need insert new transop.");
  960. return GRAPH_SUCCESS;
  961. }
  962. pair<OutDataAnchorPtr, InDataAnchorPtr> begin_out = sub_graph_anchors_[index].front();
  963. pair<OutDataAnchorPtr, InDataAnchorPtr> end_in = sub_graph_anchors_[index].back();
  964. auto out_anchor = begin_out.first;
  965. GE_CHECK_NOTNULL(out_anchor);
  966. auto out_owner_node = out_anchor->GetOwnerNode();
  967. GE_CHECK_NOTNULL(out_owner_node);
  968. auto in_anchor = end_in.second;
  969. GE_CHECK_NOTNULL(in_anchor);
  970. auto in_owner_node = in_anchor->GetOwnerNode();
  971. GE_CHECK_NOTNULL(in_owner_node);
  972. GELOGI("remove edge.src:%s, src idx:%d, dst:%s, dst idx:%d", end_in.first->GetOwnerNode()->GetName().c_str(),
  973. end_in.first->GetIdx(), in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetIdx());
  974. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(end_in.first, in_anchor),
  975. "[Remove][Edge] between %s and %s failed",
  976. out_owner_node->GetName().c_str(), in_owner_node->GetName().c_str());
  977. GELOGI("add edge.src:%s, src idx:%d, dst:%s", out_anchor->GetOwnerNode()->GetName().c_str(), out_anchor->GetIdx(),
  978. new_trans_nodes.front()->GetName().c_str());
  979. if (GraphUtils::AddEdge(out_anchor, new_trans_nodes.front()->GetInAnchor(0)) != GRAPH_SUCCESS) {
  980. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed",
  981. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(), out_anchor->GetIdx(),
  982. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str());
  983. GELOGE(GRAPH_FAILED, "[Add][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed",
  984. out_owner_node->GetName().c_str(), out_owner_node->GetType().c_str(), out_anchor->GetIdx(),
  985. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str());
  986. return GRAPH_FAILED;
  987. } else {
  988. auto old_peer_in_anchor = begin_out.second;
  989. GE_CHECK_NOTNULL(old_peer_in_anchor);
  990. UpdateOutputName(out_anchor, old_peer_in_anchor, in_owner_node);
  991. }
  992. if (new_trans_nodes.size() > 1) {
  993. GELOGI("add edge.src:%s, dst:%s", new_trans_nodes.front()->GetName().c_str(),
  994. new_trans_nodes.back()->GetName().c_str());
  995. if (GraphUtils::AddEdge(new_trans_nodes.front()->GetOutAnchor(0), new_trans_nodes.back()->GetInAnchor(0)) !=
  996. GRAPH_SUCCESS) {
  997. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed",
  998. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str(),
  999. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str());
  1000. GELOGE(GRAPH_FAILED, "[Add][Edge] between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed",
  1001. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str(),
  1002. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str());
  1003. return GRAPH_FAILED;
  1004. } else {
  1005. auto old_peer_out_anchor = end_in.first;
  1006. GE_CHECK_NOTNULL(old_peer_out_anchor);
  1007. UpdateInputName(old_peer_out_anchor, in_anchor, out_owner_node);
  1008. }
  1009. }
  1010. GELOGI("add edge.src:%s, dst:%s, dst idx:%d", new_trans_nodes.back()->GetName().c_str(),
  1011. in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetIdx());
  1012. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutAnchor(0), in_anchor) != GRAPH_SUCCESS) {
  1013. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed",
  1014. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str(),
  1015. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str(), in_anchor->GetIdx());
  1016. GELOGE(GRAPH_FAILED, "[Add][Edge] between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed",
  1017. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str(),
  1018. in_owner_node->GetName().c_str(), in_owner_node->GetType().c_str(), in_anchor->GetIdx());
  1019. return GRAPH_FAILED;
  1020. }
  1021. return RelinkControlEdge(index, out_anchor, new_trans_nodes);
  1022. }
  1023. graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdge(const int index, const OutDataAnchorPtr &out_anchor,
  1024. const vector<NodePtr> &new_trans_nodes) {
  1025. GE_CHECK_NOTNULL(out_anchor);
  1026. if (new_trans_nodes.front() == nullptr || new_trans_nodes.back() == nullptr) {
  1027. REPORT_INNER_ERROR("E19999", "Param new_trans_nodes front or back is nullptr, check invalid");
  1028. GELOGE(GRAPH_FAILED, "[Check][Param] Param new_trans_nodes front or back is nullptr");
  1029. return GRAPH_FAILED;
  1030. }
  1031. if (sub_graph_has_control_edge_[index]) {
  1032. GELOGI("add control edge.src:%s, dst:%s", out_anchor->GetOwnerNode()->GetName().c_str(),
  1033. new_trans_nodes.front()->GetName().c_str());
  1034. if (GraphUtils::AddEdge(out_anchor->GetOwnerNode()->GetOutControlAnchor(),
  1035. new_trans_nodes.front()->GetInControlAnchor()) != GRAPH_SUCCESS) {
  1036. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  1037. out_anchor->GetOwnerNode()->GetName().c_str(), out_anchor->GetOwnerNode()->GetType().c_str(),
  1038. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str());
  1039. GELOGE(GRAPH_FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  1040. out_anchor->GetOwnerNode()->GetName().c_str(), out_anchor->GetOwnerNode()->GetType().c_str(),
  1041. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str());
  1042. return GRAPH_FAILED;
  1043. }
  1044. }
  1045. for (const auto &peer_in_anchor : out_control_peer_in_control_anchors_[index]) {
  1046. GE_CHECK_NOTNULL(peer_in_anchor);
  1047. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  1048. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  1049. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  1050. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  1051. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(),
  1052. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  1053. peer_in_anchor->GetOwnerNode()->GetType().c_str());
  1054. GELOGE(GRAPH_FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  1055. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(),
  1056. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetOwnerNode()->GetType().c_str());
  1057. return GRAPH_FAILED;
  1058. }
  1059. }
  1060. for (const auto &peer_out_anchor : in_control_peer_out_control_anchors_[index]) {
  1061. GE_CHECK_NOTNULL(peer_out_anchor);
  1062. GELOGI("add control edge.src:%s, dst:%s", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  1063. new_trans_nodes.front()->GetName().c_str());
  1064. if (GraphUtils::AddEdge(peer_out_anchor, new_trans_nodes.front()->GetInControlAnchor()) != GRAPH_SUCCESS) {
  1065. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  1066. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  1067. peer_out_anchor->GetOwnerNode()->GetType().c_str(),
  1068. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str());
  1069. GELOGE(GRAPH_FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  1070. peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetOwnerNode()->GetType().c_str(),
  1071. new_trans_nodes.front()->GetName().c_str(), new_trans_nodes.front()->GetType().c_str());
  1072. return GRAPH_FAILED;
  1073. }
  1074. }
  1075. for (const auto &peer_in_anchor : out_control_peer_in_data_anchors_[index]) {
  1076. GE_CHECK_NOTNULL(peer_in_anchor);
  1077. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  1078. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  1079. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  1080. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  1081. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(),
  1082. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  1083. peer_in_anchor->GetOwnerNode()->GetType().c_str());
  1084. GELOGE(GRAPH_FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  1085. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(),
  1086. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetOwnerNode()->GetType().c_str());
  1087. return GRAPH_FAILED;
  1088. }
  1089. }
  1090. for (const auto &peer_in_anchor : out_data_peer_in_control_anchors_[index]) {
  1091. GE_CHECK_NOTNULL(peer_in_anchor);
  1092. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  1093. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  1094. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutDataAnchor(0), peer_in_anchor) != GRAPH_SUCCESS) {
  1095. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed",
  1096. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(),
  1097. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  1098. peer_in_anchor->GetOwnerNode()->GetType().c_str(), peer_in_anchor->GetIdx());
  1099. GELOGE(GRAPH_FAILED, "[Add][Edge] between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed",
  1100. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(),
  1101. peer_in_anchor->GetOwnerNode()->GetName().c_str(),
  1102. peer_in_anchor->GetOwnerNode()->GetType().c_str(), peer_in_anchor->GetIdx());
  1103. return GRAPH_FAILED;
  1104. }
  1105. }
  1106. if (sub_graph_has_out_data_peer_in_control_edge_[index]) {
  1107. auto in_anchor = sub_graph_anchors_[index].back().second;
  1108. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  1109. in_anchor->GetOwnerNode()->GetName().c_str());
  1110. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutDataAnchor(0),
  1111. in_anchor->GetOwnerNode()->GetInControlAnchor()) != GRAPH_SUCCESS) {
  1112. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s) and op:%s(%s) failed",
  1113. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(),
  1114. in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetOwnerNode()->GetType().c_str());
  1115. GELOGE(GRAPH_FAILED, "[Add][Edge] between op:%s(%s) and op:%s(%s) failed",
  1116. new_trans_nodes.back()->GetName().c_str(), new_trans_nodes.back()->GetType().c_str(),
  1117. in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetOwnerNode()->GetType().c_str());
  1118. return GRAPH_FAILED;
  1119. }
  1120. }
  1121. return GRAPH_SUCCESS;
  1122. }
  1123. bool TransOpWithoutReshapeFusionPass::OpAccuracyAbilityCheck(const OpDescPtr &op_desc) {
  1124. auto instance = GELib::GetInstance();
  1125. if ((instance == nullptr) || (!instance->InitFlag())) {
  1126. GELOGW("GELib is not initialized!");
  1127. return false;
  1128. }
  1129. if (op_desc == nullptr) {
  1130. return false;
  1131. }
  1132. OpsKernelManager &ops_kernel_manager = instance->OpsKernelManagerObj();
  1133. vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
  1134. if (op_infos.empty()) {
  1135. GELOGI("Can not get op info by op type:%s", op_desc->GetType().c_str());
  1136. return false;
  1137. }
  1138. std::string unsupported_reason;
  1139. for (const auto &it : op_infos) {
  1140. auto kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
  1141. auto &kernel_name = it.opKernelLib;
  1142. auto kernel_info_store = kernel_map.find(kernel_name);
  1143. if (kernel_info_store != kernel_map.end()) {
  1144. if (kernel_info_store->second != nullptr &&
  1145. kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason)) {
  1146. op_desc->SetOpEngineName(it.engine);
  1147. op_desc->SetOpKernelLibName(kernel_name);
  1148. GELOGI("Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), it.engine.c_str(),
  1149. op_desc->GetName().c_str());
  1150. return true;
  1151. }
  1152. }
  1153. }
  1154. GELOGI("op %s CheckAccuracySupported failed!reason:%s", op_desc->GetType().c_str(), unsupported_reason.c_str());
  1155. return false;
  1156. }
  1157. bool TransOpWithoutReshapeFusionPass::FusionFormatSupport(Format format) {
  1158. return format == FORMAT_NCHW || format == FORMAT_NHWC || format == FORMAT_FRACTAL_Z || format == FORMAT_NC1HWC0;
  1159. }
  1160. graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphsBetweenNormalNode(
  1161. const OutDataAnchorPtr &out_anchor, std::vector<vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out,
  1162. vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_list) {
  1163. graphStatus ret = GRAPH_SUCCESS;
  1164. if (out_anchor == nullptr) {
  1165. REPORT_INNER_ERROR("E19999", "Param out_anchor is nullptr, check invalid");
  1166. GELOGE(GRAPH_FAILED, "[Check][Param] param out_anchor is nullptr");
  1167. return GRAPH_FAILED;
  1168. }
  1169. for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
  1170. if (peer_in_anchor == nullptr || peer_in_anchor->GetOwnerNode() == nullptr ||
  1171. peer_in_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  1172. continue;
  1173. }
  1174. nodes_list.emplace_back(out_anchor, peer_in_anchor);
  1175. auto peer_in_node = peer_in_anchor->GetOwnerNode();
  1176. GE_CHECK_NOTNULL(peer_in_node);
  1177. if (!IsTransOp(peer_in_node)) {
  1178. sub_graphs_out.push_back(nodes_list);
  1179. nodes_list.pop_back();
  1180. } else {
  1181. for (const auto &peer_out_anchor : peer_in_node->GetAllOutDataAnchors()) {
  1182. ret = GetSubGraphsBetweenNormalNode(peer_out_anchor, sub_graphs_out, nodes_list);
  1183. if (ret != GRAPH_SUCCESS) {
  1184. GELOGE(GRAPH_FAILED, "[Get][SubGraphs] Between Normal Node failed! node:%s",
  1185. peer_in_node->GetName().c_str());
  1186. return GRAPH_FAILED;
  1187. }
  1188. }
  1189. nodes_list.pop_back();
  1190. }
  1191. }
  1192. return GRAPH_SUCCESS;
  1193. }
  1194. bool TransOpWithoutReshapeFusionPass::IsTransOp(const NodePtr &node) {
  1195. // The caller guarantees that the pointer is not null.
  1196. return node->GetType() == CAST || node->GetType() == RESHAPE || node->GetType() == TRANSPOSE ||
  1197. node->GetType() == TRANSPOSED || node->GetType() == TRANSDATA;
  1198. }
  1199. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。