You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transop_without_reshape_fusion_pass.cc 46 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/transop_without_reshape_fusion_pass.h"
  17. #include <algorithm>
  18. #include <memory>
  19. #include <sstream>
  20. #include <string>
  21. #include <atomic>
  22. #include "common/ge/ge_util.h"
  23. #include "common/ge_inner_error_codes.h"
  24. #include "common/types.h"
  25. #include "graph/common/transop_util.h"
  26. #include "graph/compute_graph.h"
  27. #include "graph/debug/ge_attr_define.h"
  28. #include "graph/ge_tensor.h"
  29. #include "graph/op_desc.h"
  30. #include "graph/utils/graph_utils.h"
  31. #include "graph/utils/node_utils.h"
  32. #include "graph/utils/op_desc_utils.h"
  33. #include "graph/utils/type_utils.h"
  34. #include "init/gelib.h"
  namespace {
  // Ext-attr name that SetRemainNode sets (to true) on trans ops which should stay in the graph.
  constexpr const char *kRemainNode = "node_remain";
  // Sentinel written into a fusion op counter when a generated replacement op is not supported.
  constexpr int kInvalidFusionOpCount = -1;
  // String attribute names recording the serialized source / destination format of a generated TransData op.
  constexpr const char *kAttrNameSrcFormat = "src_format";
  constexpr const char *kAttrNameDstFormat = "dst_format";
  }  // namespace
  41. namespace ge {
  42. void TransOpWithoutReshapeFusionPass::SetRemainNode(
  43. const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_anchor) {
  44. auto iter = nodes_anchor.begin();
  45. while (iter != nodes_anchor.end()) {
  46. auto in_anchor = iter->second;
  47. if (in_anchor == nullptr) {
  48. return;
  49. }
  50. auto in_node = in_anchor->GetOwnerNode();
  51. ++iter;
  52. if (in_node == nullptr) {
  53. return;
  54. }
  55. if (!IsTransOp(in_node)) {
  56. continue;
  57. }
  58. auto op_desc = in_node->GetOpDesc();
  59. if (op_desc == nullptr) {
  60. continue;
  61. }
  62. GELOGI("SetRemainNode node is %s", op_desc->GetName().c_str());
  63. GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return );
  64. }
  65. }
  66. bool TransOpWithoutReshapeFusionPass::FormatContinuousCheck(const OutDataAnchorPtr &out_anchor,
  67. const InDataAnchorPtr &in_anchor) {
  68. if (out_anchor == nullptr || in_anchor == nullptr || in_anchor->GetOwnerNode() == nullptr ||
  69. out_anchor->GetOwnerNode() == nullptr) {
  70. return false;
  71. }
  72. auto in_node = in_anchor->GetOwnerNode();
  73. GE_IF_BOOL_EXEC(in_node == nullptr, GELOGE(INTERNAL_ERROR, "in_node is null"); return false);
  74. auto in_op = in_node->GetOpDesc();
  75. auto out_owner_node = out_anchor->GetOwnerNode();
  76. GE_IF_BOOL_EXEC(out_owner_node == nullptr, GELOGE(INTERNAL_ERROR, "out_owner_node is null"); return false);
  77. auto out_op = out_owner_node->GetOpDesc();
  78. GE_IF_BOOL_EXEC(in_op == nullptr, GELOGE(INTERNAL_ERROR, "in_op is null"); return false);
  79. GE_IF_BOOL_EXEC(out_op == nullptr, GELOGE(INTERNAL_ERROR, "out_op is null"); return false);
  80. auto in_op_desc = in_op->GetInputDescPtr(in_anchor->GetIdx());
  81. auto out_op_desc = out_op->GetOutputDescPtr(out_anchor->GetIdx());
  82. GE_IF_BOOL_EXEC(in_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "in_op_desc is null"); return false);
  83. GE_IF_BOOL_EXEC(out_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "out_op_desc is null"); return false);
  84. if (!ShapeEqualCheck(in_op_desc->GetShape(), out_op_desc->GetShape())) {
  85. return false;
  86. }
  87. if (in_op->GetType() == CAST || out_op->GetType() == CAST) {
  88. return TransOpUtil::CheckPrecisionLoss(in_node);
  89. }
  90. if (in_op_desc->GetFormat() == FORMAT_ND) {
  91. return false;
  92. }
  93. if (out_op_desc->GetFormat() == FORMAT_ND) {
  94. return false;
  95. }
  96. if (in_op_desc->GetFormat() != out_op_desc->GetFormat()) {
  97. return false;
  98. }
  99. return FusionFormatSupport(in_op_desc->GetFormat());
  100. }
  101. graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphNodesInfo() {
  102. vector<bool> sub_graph_has_reshape_node(sub_graph_anchors_.size(), false);
  103. vector<int> transop_num_count(sub_graph_anchors_.size(), 0);
  104. vector<vector<NodePtr>> sub_graph_nodes(sub_graph_anchors_.size());
  105. for (size_t i = 0; i < sub_graph_anchors_.size(); ++i) {
  106. auto nodes_anchor = sub_graph_anchors_[i];
  107. vector<NodePtr> nodes_tmp;
  108. auto iter = nodes_anchor.begin();
  109. auto first_out_anchor = iter->first;
  110. if (first_out_anchor == nullptr) {
  111. continue;
  112. }
  113. nodes_tmp.push_back(first_out_anchor->GetOwnerNode());
  114. while (iter != nodes_anchor.end()) {
  115. auto in_anchor = iter->second;
  116. GE_CHECK_NOTNULL(in_anchor);
  117. auto in_node = in_anchor->GetOwnerNode();
  118. GE_CHECK_NOTNULL(in_node);
  119. if (in_node->GetType() == RESHAPE) {
  120. sub_graph_has_reshape_node[i] = true;
  121. break;
  122. }
  123. auto out_anchor = iter->first;
  124. GE_CHECK_NOTNULL(out_anchor);
  125. if (!FormatContinuousCheck(out_anchor, in_anchor)) {
  126. sub_graph_has_reshape_node[i] = true;
  127. break;
  128. }
  129. nodes_tmp.push_back(in_node);
  130. if (IsTransOp(in_node)) {
  131. // count transop num
  132. transop_num_count[i]++;
  133. }
  134. ++iter;
  135. }
  136. sub_graph_nodes[i].swap(nodes_tmp);
  137. if (sub_graph_has_reshape_node[i]) {
  138. SetRemainNode(nodes_anchor);
  139. }
  140. }
  141. sub_graph_has_reshape_node_.swap(sub_graph_has_reshape_node);
  142. transop_num_count_.swap(transop_num_count);
  143. sub_graph_nodes_.swap(sub_graph_nodes);
  144. return GRAPH_SUCCESS;
  145. }
  146. void TransOpWithoutReshapeFusionPass::GetOutDataPeerInControlAnchors(
  147. const size_t index, vector<vector<InControlAnchorPtr>> &out_data_peer_in_control_anchors) {
  148. // The caller guarantees that the index is legal.
  149. for (size_t j = 1; j < sub_graph_anchors_[index].size(); ++j) {
  150. auto nodes_anchor = sub_graph_anchors_[index][j];
  151. auto out_data_anchor = nodes_anchor.first;
  152. GE_CHECK_NOTNULL_JUST_RETURN(out_data_anchor);
  153. for (const auto &peer_in_control_anchor : out_data_anchor->GetPeerInControlAnchors()) {
  154. GE_CHECK_NOTNULL_JUST_RETURN(peer_in_control_anchor);
  155. auto peer_node = peer_in_control_anchor->GetOwnerNode();
  156. if (peer_node == nullptr) {
  157. continue;
  158. }
  159. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  160. if (iter == sub_graph_nodes_[index].end()) {
  161. out_data_peer_in_control_anchors[index].push_back(peer_in_control_anchor);
  162. } else {
  163. sub_graph_has_out_data_peer_in_control_edge_[index] = true;
  164. }
  165. }
  166. }
  167. }
  168. void TransOpWithoutReshapeFusionPass::GetInControlPeerOutControlAnchors(
  169. const size_t index, vector<vector<OutControlAnchorPtr>> &in_control_peer_out_control_anchors) {
  170. // The caller guarantees that the index is legal.
  171. for (size_t j = 1; j < (sub_graph_nodes_[index].size() - 1); ++j) {
  172. auto node = sub_graph_nodes_[index][j];
  173. GE_CHECK_NOTNULL_JUST_RETURN(node);
  174. auto in_control_anchor = node->GetInControlAnchor();
  175. if (in_control_anchor == nullptr) {
  176. continue;
  177. }
  178. for (const auto &peer_out_anchor : in_control_anchor->GetPeerOutControlAnchors()) {
  179. GE_CHECK_NOTNULL_JUST_RETURN(peer_out_anchor);
  180. auto peer_node = peer_out_anchor->GetOwnerNode();
  181. if (peer_node == nullptr) {
  182. continue;
  183. }
  184. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  185. if (iter == sub_graph_nodes_[index].end()) {
  186. in_control_peer_out_control_anchors[index].push_back(peer_out_anchor);
  187. } else {
  188. sub_graph_has_control_edge_[index] = true;
  189. }
  190. }
  191. }
  192. }
  193. void TransOpWithoutReshapeFusionPass::GetOutControlPeerAnchors(
  194. const size_t index, vector<vector<InControlAnchorPtr>> &out_control_peer_in_control_anchors,
  195. vector<vector<InDataAnchorPtr>> &out_control_peer_in_data_anchors) {
  196. for (size_t j = 0; j < sub_graph_nodes_[index].size() - 1; ++j) {
  197. auto node = sub_graph_nodes_[index][j];
  198. GE_CHECK_NOTNULL_JUST_RETURN(node);
  199. auto out_control_anchor = node->GetOutControlAnchor();
  200. GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor);
  201. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) {
  202. GE_CHECK_NOTNULL_JUST_RETURN(peer_in_anchor);
  203. auto peer_node = peer_in_anchor->GetOwnerNode();
  204. if (peer_node == nullptr) {
  205. continue;
  206. }
  207. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  208. if (iter == sub_graph_nodes_[index].end()) {
  209. if (j > 0) {
  210. out_control_peer_in_control_anchors[index].push_back(peer_in_anchor);
  211. }
  212. } else {
  213. sub_graph_has_control_edge_[index] = true;
  214. }
  215. }
  216. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) {
  217. GE_CHECK_NOTNULL_JUST_RETURN(peer_in_anchor);
  218. auto peer_node = peer_in_anchor->GetOwnerNode();
  219. if (peer_node == nullptr) {
  220. continue;
  221. }
  222. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  223. if (iter == sub_graph_nodes_[index].end()) {
  224. if (j > 0) {
  225. out_control_peer_in_data_anchors[index].push_back(peer_in_anchor);
  226. }
  227. } else {
  228. sub_graph_has_control_edge_[index] = true;
  229. }
  230. }
  231. }
  232. }
  233. void TransOpWithoutReshapeFusionPass::GetControlAnchors() {
  234. vector<vector<OutControlAnchorPtr>> in_control_peer_out_control_anchors(sub_graph_nodes_.size());
  235. vector<vector<InControlAnchorPtr>> out_control_peer_in_control_anchors(sub_graph_nodes_.size());
  236. vector<vector<InDataAnchorPtr>> out_control_peer_in_data_anchors(sub_graph_nodes_.size());
  237. vector<vector<InControlAnchorPtr>> out_data_peer_in_control_anchors(sub_graph_nodes_.size());
  238. vector<bool> sub_graph_has_control_edge(sub_graph_nodes_.size(), false);
  239. sub_graph_has_control_edge_.swap(sub_graph_has_control_edge);
  240. vector<bool> sub_graph_has_out_data_peer_in_control_edge(sub_graph_nodes_.size(), false);
  241. sub_graph_has_out_data_peer_in_control_edge_.swap(sub_graph_has_out_data_peer_in_control_edge);
  242. for (size_t i = 0; i < sub_graph_nodes_.size(); ++i) {
  243. if (sub_graph_has_reshape_node_[i]) {
  244. continue;
  245. }
  246. GetOutDataPeerInControlAnchors(i, out_data_peer_in_control_anchors);
  247. GetInControlPeerOutControlAnchors(i, in_control_peer_out_control_anchors);
  248. GetOutControlPeerAnchors(i, out_control_peer_in_control_anchors, out_control_peer_in_data_anchors);
  249. }
  250. in_control_peer_out_control_anchors_.swap(in_control_peer_out_control_anchors);
  251. out_control_peer_in_control_anchors_.swap(out_control_peer_in_control_anchors);
  252. out_control_peer_in_data_anchors_.swap(out_control_peer_in_data_anchors);
  253. out_data_peer_in_control_anchors_.swap(out_data_peer_in_control_anchors);
  254. }
  255. void TransOpWithoutReshapeFusionPass::EraseInvalidAnchorsPair() {
  256. auto sub_graph_iter = sub_graph_anchors_.begin();
  257. while (sub_graph_iter != sub_graph_anchors_.end()) {
  258. if (sub_graph_iter->size() <= 1) {
  259. sub_graph_iter = sub_graph_anchors_.erase(sub_graph_iter);
  260. } else {
  261. ++sub_graph_iter;
  262. }
  263. }
  264. }
  265. void TransOpWithoutReshapeFusionPass::UpdateOutputName(const OutDataAnchorPtr &out_anchor,
  266. const InDataAnchorPtr &old_peer_in_anchor,
  267. const NodePtr &in_owner_node) {
  268. if (out_anchor == nullptr || old_peer_in_anchor == nullptr || in_owner_node == nullptr) {
  269. GELOGI("out_anchor or old_peer_in_anchor or in_owner_node is nullptr");
  270. return;
  271. }
  272. auto out_owner_node = out_anchor->GetOwnerNode();
  273. GE_CHECK_NOTNULL_JUST_RETURN(out_owner_node);
  274. GE_CHECK_NOTNULL_JUST_RETURN(old_peer_in_anchor->GetOwnerNode());
  275. auto old_peer_in_name = old_peer_in_anchor->GetOwnerNode()->GetName();
  276. auto output_op = out_owner_node->GetOpDesc();
  277. GE_CHECK_NOTNULL_JUST_RETURN(output_op);
  278. auto output_names = output_op->GetAllOutputName();
  279. auto old_peer_in_name_iter = output_names.find(old_peer_in_name);
  280. if (old_peer_in_name_iter != output_names.end()) {
  281. output_names.erase(old_peer_in_name_iter);
  282. }
  283. output_names[in_owner_node->GetName()] = out_anchor->GetIdx();
  284. if (!output_op->UpdateOutputName(output_names)) {
  285. GELOGW("output_op UpdateOutputName failed");
  286. }
  287. }
  288. void TransOpWithoutReshapeFusionPass::UpdateInputName(const OutDataAnchorPtr &old_peer_out_anchor,
  289. const InDataAnchorPtr &in_anchor, const NodePtr &out_owner_node) {
  290. if (old_peer_out_anchor == nullptr || in_anchor == nullptr || out_owner_node == nullptr) {
  291. GELOGI("old_peer_out_anchor or in_anchor or out_owner_node is nullptr");
  292. return;
  293. }
  294. auto old_node = old_peer_out_anchor->GetOwnerNode();
  295. GE_CHECK_NOTNULL_JUST_RETURN(old_node);
  296. auto old_peer_out_name = old_node->GetName();
  297. auto in_owner_node = in_anchor->GetOwnerNode();
  298. GE_CHECK_NOTNULL_JUST_RETURN(in_owner_node);
  299. auto input_op = in_owner_node->GetOpDesc();
  300. GE_CHECK_NOTNULL_JUST_RETURN(input_op);
  301. auto input_names = input_op->GetAllInputName();
  302. auto old_peer_out_name_iter = input_names.find(old_peer_out_name);
  303. if (old_peer_out_name_iter != input_names.end()) {
  304. input_names.erase(old_peer_out_name_iter);
  305. }
  306. input_names[out_owner_node->GetName()] = in_anchor->GetIdx();
  307. input_op->UpdateInputName(input_names);
  308. }
  309. graphStatus TransOpWithoutReshapeFusionPass::RelinkSubGraphControlEdges(
  310. const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  311. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  312. auto out_anchor = begin_anchors_pair.first;
  313. GE_CHECK_NOTNULL(out_anchor);
  314. auto out_owner_node = out_anchor->GetOwnerNode();
  315. GE_CHECK_NOTNULL(out_owner_node);
  316. auto in_anchor = end_anchors_pair.second;
  317. GE_CHECK_NOTNULL(in_anchor);
  318. auto in_owner_node = in_anchor->GetOwnerNode();
  319. GE_CHECK_NOTNULL(in_owner_node);
  320. if (sub_graph_has_control_edge_[index]) {
  321. GELOGI("add control edge.src:%s, dst:%s", out_owner_node->GetName().c_str(), in_owner_node->GetName().c_str());
  322. if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), in_owner_node->GetInControlAnchor()) !=
  323. GRAPH_SUCCESS) {
  324. return GRAPH_FAILED;
  325. }
  326. }
  327. if (sub_graph_has_out_data_peer_in_control_edge_[index]) {
  328. GELOGI("add out data 2 in contorl edge.src:%s, dst:%s", out_owner_node->GetName().c_str(),
  329. in_owner_node->GetName().c_str());
  330. if (GraphUtils::AddEdge(out_anchor, in_owner_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
  331. return GRAPH_FAILED;
  332. }
  333. }
  334. return GRAPH_SUCCESS;
  335. }
  336. graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdgesWhenDescNotChanged(
  337. const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  338. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  339. if (RelinkSubGraphControlEdges(begin_anchors_pair, end_anchors_pair, index) != GRAPH_SUCCESS) {
  340. return GRAPH_FAILED;
  341. }
  342. auto out_anchor = begin_anchors_pair.first;
  343. GE_CHECK_NOTNULL(out_anchor);
  344. auto out_owner_node = out_anchor->GetOwnerNode();
  345. GE_CHECK_NOTNULL(out_owner_node);
  346. auto in_anchor = end_anchors_pair.second;
  347. GE_CHECK_NOTNULL(in_anchor);
  348. auto in_owner_node = in_anchor->GetOwnerNode();
  349. GE_CHECK_NOTNULL(in_owner_node);
  350. // can not remove old control edge
  351. for (const auto &peer_in_anchor : out_control_peer_in_control_anchors_[index]) {
  352. GE_CHECK_NOTNULL(peer_in_anchor);
  353. GELOGI("add control edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  354. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  355. if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  356. return GRAPH_FAILED;
  357. }
  358. }
  359. for (const auto &peer_out_anchor : in_control_peer_out_control_anchors_[index]) {
  360. GE_CHECK_NOTNULL(peer_out_anchor);
  361. GELOGI("add control edge.src:%s, src idx:%d, dst:%s", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  362. peer_out_anchor->GetIdx(), in_owner_node->GetName().c_str());
  363. if (GraphUtils::AddEdge(peer_out_anchor, in_owner_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
  364. return GRAPH_FAILED;
  365. }
  366. }
  367. for (const auto &peer_in_anchor : out_control_peer_in_data_anchors_[index]) {
  368. GE_CHECK_NOTNULL(peer_in_anchor);
  369. GELOGI("add out control 2 in data edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  370. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  371. if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  372. return GRAPH_FAILED;
  373. }
  374. }
  375. for (const auto &peer_in_anchor : out_data_peer_in_control_anchors_[index]) {
  376. GE_CHECK_NOTNULL(peer_in_anchor);
  377. GELOGI("add out data 2 in control edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  378. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  379. if (GraphUtils::AddEdge(out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
  380. return GRAPH_FAILED;
  381. }
  382. }
  383. return GRAPH_SUCCESS;
  384. }
  385. graphStatus TransOpWithoutReshapeFusionPass::RelinkNodesWhenDescNotChanged(
  386. const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  387. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  388. auto out_anchor = begin_anchors_pair.first;
  389. GE_CHECK_NOTNULL(out_anchor);
  390. auto out_owner_node = out_anchor->GetOwnerNode();
  391. GE_CHECK_NOTNULL(out_owner_node);
  392. auto in_anchor = end_anchors_pair.second;
  393. GE_CHECK_NOTNULL(in_anchor);
  394. auto in_owner_node = in_anchor->GetOwnerNode();
  395. GE_CHECK_NOTNULL(in_owner_node);
  396. GELOGI("remove edge.src %s, src idx:%d, dst:%s, dst idx:%d",
  397. end_anchors_pair.first->GetOwnerNode()->GetName().c_str(), end_anchors_pair.first->GetIdx(),
  398. in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  399. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(end_anchors_pair.first, in_anchor), "remove edge failed");
  400. GELOGI("relink node.src node:%s, src idx:%d, dst node:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  401. out_anchor->GetIdx(), in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  402. if (GraphUtils::AddEdge(out_anchor, in_anchor) != GRAPH_SUCCESS) {
  403. GELOGE(GRAPH_FAILED, "add edge failed!src:%s, src idx:%d, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  404. out_anchor->GetIdx(), in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  405. return GRAPH_FAILED;
  406. } else {
  407. auto old_peer_in_anchor = begin_anchors_pair.second;
  408. UpdateOutputName(out_anchor, old_peer_in_anchor, in_owner_node);
  409. auto old_peer_out_anchor = end_anchors_pair.first;
  410. UpdateInputName(old_peer_out_anchor, in_anchor, out_owner_node);
  411. }
  412. return RelinkControlEdgesWhenDescNotChanged(begin_anchors_pair, end_anchors_pair, index);
  413. }
  414. OpDescPtr TransOpWithoutReshapeFusionPass::GetFormatTransferOp(const GeTensorDesc &format_trans_input_desc,
  415. const GeTensorDesc &format_trans_output_desc) {
  416. static std::atomic_long atomic_fusion_format_transfer_op_count(1);
  417. auto fusion_format_transfer_op_count = atomic_fusion_format_transfer_op_count.fetch_add(1);
  418. std::stringstream format_transfer_op_name;
  419. format_transfer_op_name << "fusion_format_transfer_" << fusion_format_transfer_op_count;
  420. OpDescPtr format_transfer_op = MakeShared<OpDesc>(format_transfer_op_name.str().c_str(), TRANSDATA);
  421. if (format_transfer_op == nullptr) {
  422. GELOGE(INTERNAL_ERROR, "new format transfer op failed!");
  423. return nullptr;
  424. }
  425. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(format_transfer_op, ATTR_NAME_INPUT_FORMAT,
  426. static_cast<int64_t>(format_trans_input_desc.GetFormat())),
  427. GELOGE(INTERNAL_ERROR, "set ATTR_NAME_INPUT_FORMAT failed");
  428. return nullptr);
  429. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(format_transfer_op, ATTR_NAME_OUTPUT_FORMAT,
  430. static_cast<int64_t>(format_trans_output_desc.GetFormat())),
  431. GELOGE(INTERNAL_ERROR, "set ATTR_NAME_OUTPUT_FORMAT failed");
  432. return nullptr);
  433. string src_format = TypeUtils::FormatToSerialString(format_trans_input_desc.GetFormat());
  434. string dst_format = TypeUtils::FormatToSerialString(format_trans_output_desc.GetFormat());
  435. GE_IF_BOOL_EXEC(!AttrUtils::SetStr(format_transfer_op, kAttrNameSrcFormat, src_format),
  436. GELOGE(INTERNAL_ERROR, "set kAttrNameSrcFormat failed");
  437. return nullptr);
  438. GE_IF_BOOL_EXEC(!AttrUtils::SetStr(format_transfer_op, kAttrNameDstFormat, dst_format),
  439. GELOGE(INTERNAL_ERROR, "set kAttrNameDstFormat failed");
  440. return nullptr);
  441. GE_IF_BOOL_EXEC(format_transfer_op->AddInputDesc(format_trans_input_desc) != GRAPH_SUCCESS,
  442. GELOGE(INTERNAL_ERROR, "add input desc failed");
  443. return nullptr);
  444. GE_IF_BOOL_EXEC(format_transfer_op->AddOutputDesc(format_trans_output_desc) != GRAPH_SUCCESS,
  445. GELOGE(INTERNAL_ERROR, "add output desc failed");
  446. return nullptr);
  447. GE_IF_BOOL_EXEC(!ge::AttrUtils::SetBool(format_transfer_op, ATTR_NEED_COMPILE, true),
  448. GELOGE(INTERNAL_ERROR, "set ext attr failed");
  449. return nullptr);
  450. return format_transfer_op;
  451. }
  452. OpDescPtr TransOpWithoutReshapeFusionPass::GetCastOp(const GeTensorDesc &cast_input_desc,
  453. const GeTensorDesc &cast_output_desc) {
  454. static std::atomic_long atomic_fusion_cast_op_count(1);
  455. auto fusion_cast_op_count = atomic_fusion_cast_op_count.fetch_add(1);
  456. std::stringstream cast_op_name;
  457. cast_op_name << "fusion_cast_op_" << fusion_cast_op_count;
  458. auto node_op = ge::OperatorFactory::CreateOperator(cast_op_name.str(), CAST);
  459. auto cast_op = ge::OpDescUtils::GetOpDescFromOperator(node_op);
  460. node_op.BreakConnect();
  461. if (cast_op == nullptr) {
  462. GELOGE(INTERNAL_ERROR, "new cast op failed!");
  463. return nullptr;
  464. }
  465. const int default_input_index = 0;
  466. const int default_output_index = 0;
  467. if (cast_op->GetInputsSize() == 0) {
  468. GE_IF_BOOL_EXEC(cast_op->AddInputDesc(cast_input_desc) != GRAPH_SUCCESS,
  469. GELOGE(INTERNAL_ERROR, "add input desc failed");
  470. return nullptr);
  471. } else {
  472. GE_IF_BOOL_EXEC(cast_op->UpdateInputDesc(default_input_index, cast_input_desc) != GRAPH_SUCCESS,
  473. GELOGE(INTERNAL_ERROR, "update input desc failed");
  474. return nullptr);
  475. }
  476. if (cast_op->GetOutputsSize() == 0) {
  477. GE_IF_BOOL_EXEC(cast_op->AddOutputDesc(cast_output_desc) != GRAPH_SUCCESS,
  478. GELOGE(INTERNAL_ERROR, "add output desc failed");
  479. return nullptr);
  480. } else {
  481. GE_IF_BOOL_EXEC(cast_op->UpdateOutputDesc(default_output_index, cast_output_desc) != GRAPH_SUCCESS,
  482. GELOGE(INTERNAL_ERROR, "update output desc failed");
  483. return nullptr);
  484. }
  485. if (!AttrUtils::SetInt(cast_op, CAST_ATTR_DST_TYPE, static_cast<int64_t>(cast_output_desc.GetDataType()))) {
  486. GELOGE(INTERNAL_ERROR, "set dst_type attr failed");
  487. return nullptr;
  488. }
  489. if (!AttrUtils::SetBool(cast_op, ATTR_NEED_COMPILE, true)) {
  490. GELOGE(INTERNAL_ERROR, "set need_compile attr failed");
  491. return nullptr;
  492. }
  493. return cast_op;
  494. }
  495. bool TransOpWithoutReshapeFusionPass::InsertCastFirstCheck(const GeTensorDesc &out_desc,
  496. const GeTensorDesc &in_desc) const {
  497. return out_desc.GetDataType() != in_desc.GetDataType() && out_desc.GetDataType() != DT_FLOAT16 &&
  498. in_desc.GetDataType() == DT_FLOAT16;
  499. }
  500. void TransOpWithoutReshapeFusionPass::GetFormatTransferDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc,
  501. GeTensorDesc &format_transfer_input,
  502. GeTensorDesc &format_transfer_output) {
  503. bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  504. if (insert_cast_first) {
  505. format_transfer_input = out_desc;
  506. format_transfer_input.SetDataType(in_desc.GetDataType());
  507. format_transfer_output = in_desc;
  508. } else {
  509. format_transfer_input = out_desc;
  510. format_transfer_output = in_desc;
  511. format_transfer_output.SetDataType(out_desc.GetDataType());
  512. }
  513. }
  514. void TransOpWithoutReshapeFusionPass::GetCastOpDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc,
  515. GeTensorDesc &cast_input, GeTensorDesc &cast_output) {
  516. bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  517. if (insert_cast_first) {
  518. cast_input = out_desc;
  519. cast_output = out_desc;
  520. cast_output.SetDataType(in_desc.GetDataType());
  521. } else {
  522. cast_input = in_desc;
  523. cast_input.SetDataType(out_desc.GetDataType());
  524. cast_output = in_desc;
  525. }
  526. }
  527. void TransOpWithoutReshapeFusionPass::GetBeginOutDescAndEndInDesc(const int index, GeTensorDesc &out_desc,
  528. GeTensorDesc &in_desc) {
  529. auto nodes_anchor = sub_graph_anchors_[index];
  530. auto out_peer_anchor = nodes_anchor.front().second;
  531. GE_CHECK_NOTNULL_JUST_RETURN(out_peer_anchor);
  532. auto out_owner_node = out_peer_anchor->GetOwnerNode();
  533. GE_CHECK_NOTNULL_JUST_RETURN(out_owner_node);
  534. auto out_peer_op_desc = out_owner_node->GetOpDesc();
  535. GE_IF_BOOL_EXEC(out_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "out_peer_op_desc is nullptr"); return );
  536. out_desc = out_peer_op_desc->GetInputDesc(out_peer_anchor->GetIdx());
  537. auto in_peer_anchor = nodes_anchor.back().first;
  538. GE_CHECK_NOTNULL_JUST_RETURN(in_peer_anchor);
  539. auto in_owner_node = in_peer_anchor->GetOwnerNode();
  540. GE_CHECK_NOTNULL_JUST_RETURN(in_owner_node);
  541. auto in_peer_op_desc = in_owner_node->GetOpDesc();
  542. GE_IF_BOOL_EXEC(in_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "in_peer_op_desc is nullptr"); return );
  543. in_desc = in_peer_op_desc->GetOutputDesc(in_peer_anchor->GetIdx());
  544. }
  545. graphStatus TransOpWithoutReshapeFusionPass::FormatFusion(const int index, OpDescPtr &format_transfer_op,
  546. int32_t &fusion_op_count, bool &fusion_continue) {
  547. GeTensorDesc out_desc;
  548. GeTensorDesc in_desc;
  549. GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);
  550. GeTensorDesc format_transfer_input;
  551. GeTensorDesc format_transfer_output;
  552. GetFormatTransferDesc(out_desc, in_desc, format_transfer_input, format_transfer_output);
  553. if (out_desc.GetFormat() == in_desc.GetFormat() &&
  554. (!ShapeEqualCheck(out_desc.GetShape(), in_desc.GetShape()) ||
  555. !ShapeEqualCheck(out_desc.GetOriginShape(), in_desc.GetOriginShape()))) {
  556. SetRemainNode(sub_graph_anchors_[index]);
  557. return GRAPH_SUCCESS;
  558. }
  559. if (out_desc.GetFormat() != in_desc.GetFormat() && FusionFormatSupport(out_desc.GetFormat()) &&
  560. FusionFormatSupport(in_desc.GetFormat())) {
  561. // create format transop
  562. format_transfer_op = GetFormatTransferOp(format_transfer_input, format_transfer_output);
  563. if (format_transfer_op == nullptr) {
  564. return GRAPH_FAILED;
  565. }
  566. if (OpAccuracyAbilityCheck(format_transfer_op)) {
  567. ++fusion_op_count;
  568. GELOGI("support format transfer op %s", format_transfer_op->GetName().c_str());
  569. } else {
  570. GELOGW("ability not support.src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
  571. format_transfer_input.GetFormat(), format_transfer_input.GetDataType(), format_transfer_output.GetFormat(),
  572. format_transfer_output.GetDataType());
  573. fusion_op_count = kInvalidFusionOpCount;
  574. }
  575. } else if (out_desc.GetFormat() != in_desc.GetFormat()) {
  576. SetRemainNode(sub_graph_anchors_[index]);
  577. return GRAPH_SUCCESS;
  578. }
  579. fusion_continue = true;
  580. return GRAPH_SUCCESS;
  581. }
  582. graphStatus TransOpWithoutReshapeFusionPass::DataTypeFusion(const int index, OpDescPtr &cast_op,
  583. int32_t &fusion_op_count) {
  584. GeTensorDesc out_desc;
  585. GeTensorDesc in_desc;
  586. GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);
  587. GeTensorDesc cast_input;
  588. GeTensorDesc cast_output;
  589. GetCastOpDesc(out_desc, in_desc, cast_input, cast_output);
  590. if (fusion_op_count != kInvalidFusionOpCount && out_desc.GetDataType() != in_desc.GetDataType()) {
  591. // create cast op
  592. cast_op = GetCastOp(cast_input, cast_output);
  593. if (cast_op == nullptr) {
  594. fusion_op_count = kInvalidFusionOpCount;
  595. return GRAPH_FAILED;
  596. }
  597. if (OpAccuracyAbilityCheck(cast_op)) {
  598. ++fusion_op_count;
  599. GELOGI("support cast op %s. src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
  600. cast_op->GetName().c_str(), cast_input.GetFormat(), cast_input.GetDataType(), cast_output.GetFormat(),
  601. cast_output.GetDataType());
  602. } else {
  603. GELOGW("ability not support.src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
  604. cast_input.GetFormat(), cast_input.GetDataType(), cast_output.GetFormat(), cast_output.GetDataType());
  605. fusion_op_count = kInvalidFusionOpCount;
  606. }
  607. }
  608. return GRAPH_SUCCESS;
  609. }
  610. graphStatus TransOpWithoutReshapeFusionPass::TransOpFuseHandle(const ComputeGraphPtr &graph, const int index) {
  611. bool fusion_continue = false;
  612. OpDescPtr format_transfer_op = nullptr;
  613. int32_t fusion_op_count = 0;
  614. auto fortmat_fusion_ret = FormatFusion(index, format_transfer_op, fusion_op_count, fusion_continue);
  615. if (fortmat_fusion_ret != GRAPH_SUCCESS || !fusion_continue) {
  616. SetRemainNode(sub_graph_anchors_[index]);
  617. return GRAPH_SUCCESS;
  618. }
  619. OpDescPtr cast_op = nullptr;
  620. if (DataTypeFusion(index, cast_op, fusion_op_count) != GRAPH_SUCCESS) {
  621. SetRemainNode(sub_graph_anchors_[index]);
  622. return GRAPH_SUCCESS;
  623. }
  624. if (fusion_op_count != kInvalidFusionOpCount && fusion_op_count < transop_num_count_[index]) {
  625. GeTensorDesc out_desc;
  626. GeTensorDesc in_desc;
  627. GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);
  628. bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  629. if (InsertNewTransOp(graph, cast_op, format_transfer_op, index, insert_cast_first) != GRAPH_SUCCESS) {
  630. return GRAPH_FAILED;
  631. }
  632. } else {
  633. // remain all nodes
  634. SetRemainNode(sub_graph_anchors_[index]);
  635. }
  636. return GRAPH_SUCCESS;
  637. }
  638. void TransOpWithoutReshapeFusionPass::RemoveNousedNodes(const ComputeGraphPtr &graph) {
  639. if (graph == nullptr) {
  640. return;
  641. }
  642. for (size_t i = 0; i < sub_graph_nodes_.size(); ++i) {
  643. if (sub_graph_has_reshape_node_[i]) {
  644. continue;
  645. }
  646. for (const auto &node : sub_graph_nodes_[i]) {
  647. GE_CHECK_NOTNULL_JUST_RETURN(node);
  648. // remove nodes
  649. if (!IsTransOp(node)) {
  650. continue;
  651. }
  652. auto op_desc = node->GetOpDesc();
  653. GE_CHECK_NOTNULL_JUST_RETURN(op_desc);
  654. bool node_remain_flag = op_desc->TryGetExtAttr(kRemainNode, false);
  655. if (node_remain_flag) {
  656. continue;
  657. }
  658. GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return );
  659. GELOGI("remove node:%s", node->GetName().c_str());
  660. if (graph->RemoveNode(node) != GRAPH_SUCCESS) {
  661. GELOGW("remove node failed!node:%s", node->GetName().c_str());
  662. continue;
  663. }
  664. }
  665. }
  666. }
  667. graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) {
  668. GELOGI("[TransOpWithoutReshapeFusionPass]: optimize begin.");
  669. if (graph == nullptr) {
  670. return GRAPH_SUCCESS;
  671. }
  672. for (const auto &node : graph->GetDirectNode()) {
  673. GE_CHECK_NOTNULL(node);
  674. if (IsTransOp(node)) {
  675. continue;
  676. }
  677. bool is_unknown = false;
  678. auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown);
  679. if (ret != GRAPH_SUCCESS) {
  680. GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(),
  681. node->GetType().c_str());
  682. continue;
  683. }
  684. if (is_unknown) {
  685. GELOGI("Current node %s, type %s is unknown shape which should be skip.", node->GetName().c_str(),
  686. node->GetType().c_str());
  687. continue;
  688. }
  689. GELOGI("Current normal node name: %s, type: %s.", node->GetName().c_str(), node->GetType().c_str());
  690. for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
  691. GE_CHECK_NOTNULL(out_anchor);
  692. vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>> sub_graph_anchors;
  693. vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> nodes_list;
  694. if (GetSubGraphsBetweenNormalNode(out_anchor, sub_graph_anchors, nodes_list) != GRAPH_SUCCESS) {
  695. GELOGW("get transops failed!");
  696. continue;
  697. }
  698. sub_graph_anchors_.swap(sub_graph_anchors);
  699. EraseInvalidAnchorsPair();
  700. if (sub_graph_anchors_.empty()) {
  701. continue;
  702. }
  703. // check reshape node
  704. if (GetSubGraphNodesInfo() != GRAPH_SUCCESS) {
  705. continue;
  706. }
  707. // save control edge
  708. GetControlAnchors();
  709. if (TransOpFuse(graph) != GRAPH_SUCCESS) {
  710. return GRAPH_FAILED;
  711. }
  712. }
  713. }
  714. GELOGI("[TransOpWithoutReshapeFusionPass]: Optimize end.");
  715. return GRAPH_SUCCESS;
  716. }
  717. bool TransOpWithoutReshapeFusionPass::DescEqualCheck(ConstGeTensorDescPtr &desc_src,
  718. ConstGeTensorDescPtr &desc_dst) const {
  719. if (desc_src == nullptr || desc_dst == nullptr) {
  720. return false;
  721. }
  722. if (desc_src->GetFormat() != desc_dst->GetFormat() || desc_src->GetDataType() != desc_dst->GetDataType()) {
  723. return false;
  724. }
  725. if (!ShapeEqualCheck(desc_src->GetShape(), desc_dst->GetShape())) {
  726. return false;
  727. }
  728. return ShapeEqualCheck(desc_src->GetOriginShape(), desc_dst->GetOriginShape());
  729. }
  730. bool TransOpWithoutReshapeFusionPass::ShapeEqualCheck(const GeShape &src, const GeShape &dst) const {
  731. if (src.GetDims().size() != dst.GetDims().size()) {
  732. return false;
  733. }
  734. for (size_t i = 0; i < src.GetDims().size(); ++i) {
  735. if (src.GetDim(i) != dst.GetDim(i)) {
  736. return false;
  737. }
  738. }
  739. return true;
  740. }
  741. graphStatus TransOpWithoutReshapeFusionPass::TransOpFuse(const ComputeGraphPtr &graph) {
  742. for (size_t i = 0; i < sub_graph_anchors_.size(); ++i) {
  743. if (sub_graph_has_reshape_node_[i]) {
  744. continue;
  745. }
  746. auto nodes_anchor = sub_graph_anchors_[i];
  747. auto out_anchor = nodes_anchor.front().first;
  748. GE_CHECK_NOTNULL(out_anchor);
  749. auto out_op_desc = out_anchor->GetOwnerNode()->GetOpDesc();
  750. GE_CHECK_NOTNULL(out_op_desc);
  751. auto out_desc = out_op_desc->GetOutputDescPtr(out_anchor->GetIdx());
  752. GE_CHECK_NOTNULL(out_desc);
  753. auto in_anchor = nodes_anchor.back().second;
  754. GE_CHECK_NOTNULL(in_anchor);
  755. auto in_op_desc = in_anchor->GetOwnerNode()->GetOpDesc();
  756. GE_CHECK_NOTNULL(in_op_desc);
  757. auto in_desc = in_op_desc->GetInputDescPtr(in_anchor->GetIdx());
  758. GE_CHECK_NOTNULL(in_desc);
  759. if (FusionFormatSupport(out_desc->GetFormat()) && DescEqualCheck(out_desc, in_desc)) {
  760. // relink begin_out to end_in
  761. if (RelinkNodesWhenDescNotChanged(nodes_anchor.front(), nodes_anchor.back(), static_cast<int>(i)) !=
  762. GRAPH_SUCCESS) {
  763. return GRAPH_FAILED;
  764. }
  765. } else {
  766. if (TransOpFuseHandle(graph, static_cast<int>(i)) != GRAPH_SUCCESS) {
  767. return GRAPH_FAILED;
  768. }
  769. }
  770. }
  771. RemoveNousedNodes(graph);
  772. return GRAPH_SUCCESS;
  773. }
  774. graphStatus TransOpWithoutReshapeFusionPass::AddTransNode(const ComputeGraphPtr &graph, const OpDescPtr &transop,
  775. NodePtr &trans_node) {
  776. if (graph == nullptr) {
  777. return GRAPH_SUCCESS;
  778. }
  779. if (transop == nullptr) {
  780. return GRAPH_SUCCESS;
  781. }
  782. trans_node = graph->AddNode(transop);
  783. if (trans_node == nullptr) {
  784. GELOGE(GRAPH_FAILED, "add node failed!");
  785. return GRAPH_FAILED;
  786. }
  787. return GRAPH_SUCCESS;
  788. }
  789. graphStatus TransOpWithoutReshapeFusionPass::GetTransNode(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
  790. const OpDescPtr &format_transfer_op,
  791. const bool insert_cast_first,
  792. std::vector<NodePtr> &new_trans_nodes) {
  793. NodePtr format_transfer_node;
  794. if (AddTransNode(graph, format_transfer_op, format_transfer_node) != GRAPH_SUCCESS) {
  795. return GRAPH_FAILED;
  796. }
  797. NodePtr cast_node;
  798. if (AddTransNode(graph, cast_op, cast_node) != GRAPH_SUCCESS) {
  799. return GRAPH_FAILED;
  800. }
  801. if (insert_cast_first) {
  802. if (cast_node != nullptr) {
  803. new_trans_nodes.push_back(cast_node);
  804. }
  805. if (format_transfer_node != nullptr) {
  806. new_trans_nodes.push_back(format_transfer_node);
  807. }
  808. } else {
  809. if (format_transfer_node != nullptr) {
  810. new_trans_nodes.push_back(format_transfer_node);
  811. }
  812. if (cast_node != nullptr) {
  813. new_trans_nodes.push_back(cast_node);
  814. }
  815. }
  816. return GRAPH_SUCCESS;
  817. }
  818. graphStatus TransOpWithoutReshapeFusionPass::InsertNewTransOp(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
  819. const OpDescPtr &format_transfer_op, const int index,
  820. const bool insert_cast_first) {
  821. std::vector<NodePtr> new_trans_nodes;
  822. if (GetTransNode(graph, cast_op, format_transfer_op, insert_cast_first, new_trans_nodes) != GRAPH_SUCCESS) {
  823. return GRAPH_FAILED;
  824. }
  825. if (new_trans_nodes.empty()) {
  826. GELOGI("No new trans node. Do not need insert new transop.");
  827. return GRAPH_SUCCESS;
  828. }
  829. pair<OutDataAnchorPtr, InDataAnchorPtr> begin_out = sub_graph_anchors_[index].front();
  830. pair<OutDataAnchorPtr, InDataAnchorPtr> end_in = sub_graph_anchors_[index].back();
  831. auto out_anchor = begin_out.first;
  832. GE_CHECK_NOTNULL(out_anchor);
  833. auto out_owner_node = out_anchor->GetOwnerNode();
  834. GE_CHECK_NOTNULL(out_owner_node);
  835. auto in_anchor = end_in.second;
  836. GE_CHECK_NOTNULL(in_anchor);
  837. auto in_owner_node = in_anchor->GetOwnerNode();
  838. GE_CHECK_NOTNULL(in_owner_node);
  839. GELOGI("remove edge.src:%s, src idx:%d, dst:%s, dst idx:%d", end_in.first->GetOwnerNode()->GetName().c_str(),
  840. end_in.first->GetIdx(), in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetIdx());
  841. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(end_in.first, in_anchor), "remove edge failed");
  842. GELOGI("add edge.src:%s, src idx:%d, dst:%s", out_anchor->GetOwnerNode()->GetName().c_str(), out_anchor->GetIdx(),
  843. new_trans_nodes.front()->GetName().c_str());
  844. if (GraphUtils::AddEdge(out_anchor, new_trans_nodes.front()->GetInAnchor(0)) != GRAPH_SUCCESS) {
  845. return GRAPH_FAILED;
  846. } else {
  847. auto old_peer_in_anchor = begin_out.second;
  848. GE_CHECK_NOTNULL(old_peer_in_anchor);
  849. UpdateOutputName(out_anchor, old_peer_in_anchor, in_owner_node);
  850. }
  851. if (new_trans_nodes.size() > 1) {
  852. GELOGI("add edge.src:%s, dst:%s", new_trans_nodes.front()->GetName().c_str(),
  853. new_trans_nodes.back()->GetName().c_str());
  854. if (GraphUtils::AddEdge(new_trans_nodes.front()->GetOutAnchor(0), new_trans_nodes.back()->GetInAnchor(0)) !=
  855. GRAPH_SUCCESS) {
  856. return GRAPH_FAILED;
  857. } else {
  858. auto old_peer_out_anchor = end_in.first;
  859. GE_CHECK_NOTNULL(old_peer_out_anchor);
  860. UpdateInputName(old_peer_out_anchor, in_anchor, out_owner_node);
  861. }
  862. }
  863. GELOGI("add edge.src:%s, dst:%s, dst idx:%d", new_trans_nodes.back()->GetName().c_str(),
  864. in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetIdx());
  865. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutAnchor(0), in_anchor) != GRAPH_SUCCESS) {
  866. return GRAPH_FAILED;
  867. }
  868. return RelinkControlEdge(index, out_anchor, new_trans_nodes);
  869. }
  870. graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdge(const int index, const OutDataAnchorPtr &out_anchor,
  871. const vector<NodePtr> &new_trans_nodes) {
  872. GE_CHECK_NOTNULL(out_anchor);
  873. if (new_trans_nodes.front() == nullptr || new_trans_nodes.back() == nullptr) {
  874. return GRAPH_FAILED;
  875. }
  876. if (sub_graph_has_control_edge_[index]) {
  877. GELOGI("add control edge.src:%s, dst:%s", out_anchor->GetOwnerNode()->GetName().c_str(),
  878. new_trans_nodes.front()->GetName().c_str());
  879. if (GraphUtils::AddEdge(out_anchor->GetOwnerNode()->GetOutControlAnchor(),
  880. new_trans_nodes.front()->GetInControlAnchor()) != GRAPH_SUCCESS) {
  881. return GRAPH_FAILED;
  882. }
  883. }
  884. for (const auto &peer_in_anchor : out_control_peer_in_control_anchors_[index]) {
  885. GE_CHECK_NOTNULL(peer_in_anchor);
  886. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  887. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  888. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  889. return GRAPH_FAILED;
  890. }
  891. }
  892. for (const auto &peer_out_anchor : in_control_peer_out_control_anchors_[index]) {
  893. GE_CHECK_NOTNULL(peer_out_anchor);
  894. GELOGI("add control edge.src:%s, dst:%s", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  895. new_trans_nodes.front()->GetName().c_str());
  896. if (GraphUtils::AddEdge(peer_out_anchor, new_trans_nodes.front()->GetInControlAnchor()) != GRAPH_SUCCESS) {
  897. return GRAPH_FAILED;
  898. }
  899. }
  900. for (const auto &peer_in_anchor : out_control_peer_in_data_anchors_[index]) {
  901. GE_CHECK_NOTNULL(peer_in_anchor);
  902. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  903. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  904. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  905. return GRAPH_FAILED;
  906. }
  907. }
  908. for (const auto &peer_in_anchor : out_data_peer_in_control_anchors_[index]) {
  909. GE_CHECK_NOTNULL(peer_in_anchor);
  910. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  911. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  912. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutDataAnchor(0), peer_in_anchor) != GRAPH_SUCCESS) {
  913. return GRAPH_FAILED;
  914. }
  915. }
  916. if (sub_graph_has_out_data_peer_in_control_edge_[index]) {
  917. auto in_anchor = sub_graph_anchors_[index].back().second;
  918. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  919. in_anchor->GetOwnerNode()->GetName().c_str());
  920. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutDataAnchor(0),
  921. in_anchor->GetOwnerNode()->GetInControlAnchor()) != GRAPH_SUCCESS) {
  922. return GRAPH_FAILED;
  923. }
  924. }
  925. return GRAPH_SUCCESS;
  926. }
  927. bool TransOpWithoutReshapeFusionPass::OpAccuracyAbilityCheck(const OpDescPtr &op_desc) {
  928. auto instance = GELib::GetInstance();
  929. if ((instance == nullptr) || (!instance->InitFlag())) {
  930. GELOGW("GELib is not initialized!");
  931. return false;
  932. }
  933. if (op_desc == nullptr) {
  934. return false;
  935. }
  936. OpsKernelManager &ops_kernel_manager = instance->OpsKernelManagerObj();
  937. vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
  938. if (op_infos.empty()) {
  939. GELOGI("Can not get op info by op type:%s", op_desc->GetType().c_str());
  940. return false;
  941. }
  942. std::string unsupported_reason;
  943. for (const auto &it : op_infos) {
  944. auto kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
  945. auto &kernel_name = it.opKernelLib;
  946. auto kernel_info_store = kernel_map.find(kernel_name);
  947. if (kernel_info_store != kernel_map.end()) {
  948. if (kernel_info_store->second != nullptr &&
  949. kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason)) {
  950. op_desc->SetOpEngineName(it.engine);
  951. op_desc->SetOpKernelLibName(kernel_name);
  952. GELOGI("Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), it.engine.c_str(),
  953. op_desc->GetName().c_str());
  954. return true;
  955. }
  956. }
  957. }
  958. GELOGI("op %s CheckAccuracySupported failed!reason:%s", op_desc->GetType().c_str(), unsupported_reason.c_str());
  959. return false;
  960. }
  961. bool TransOpWithoutReshapeFusionPass::FusionFormatSupport(Format format) {
  962. return format == FORMAT_NCHW || format == FORMAT_NHWC || format == FORMAT_FRACTAL_Z || format == FORMAT_NC1HWC0;
  963. }
  964. graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphsBetweenNormalNode(
  965. const OutDataAnchorPtr &out_anchor, std::vector<vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out,
  966. vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_list) {
  967. graphStatus ret = GRAPH_SUCCESS;
  968. if (out_anchor == nullptr) {
  969. return GRAPH_FAILED;
  970. }
  971. for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
  972. if (peer_in_anchor == nullptr || peer_in_anchor->GetOwnerNode() == nullptr ||
  973. peer_in_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  974. continue;
  975. }
  976. nodes_list.emplace_back(out_anchor, peer_in_anchor);
  977. auto peer_in_node = peer_in_anchor->GetOwnerNode();
  978. GE_CHECK_NOTNULL(peer_in_node);
  979. if (!IsTransOp(peer_in_node)) {
  980. sub_graphs_out.push_back(nodes_list);
  981. nodes_list.pop_back();
  982. } else {
  983. for (const auto &peer_out_anchor : peer_in_node->GetAllOutDataAnchors()) {
  984. ret = GetSubGraphsBetweenNormalNode(peer_out_anchor, sub_graphs_out, nodes_list);
  985. if (ret != GRAPH_SUCCESS) {
  986. GELOGE(GRAPH_FAILED, "get all transops between normal node failed!node:%s", peer_in_node->GetName().c_str());
  987. return GRAPH_FAILED;
  988. }
  989. }
  990. nodes_list.pop_back();
  991. }
  992. }
  993. return GRAPH_SUCCESS;
  994. }
  995. bool TransOpWithoutReshapeFusionPass::IsTransOp(const NodePtr &node) {
  996. // The caller guarantees that the pointer is not null.
  997. return node->GetType() == CAST || node->GetType() == RESHAPE || node->GetType() == TRANSPOSE ||
  998. node->GetType() == TRANSPOSED || node->GetType() == TRANSDATA;
  999. }
  1000. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了专门的优化，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构如下图所示。