You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_utils.cc 37 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/utils/node_utils.h"
  17. #include "graph/utils/op_desc_utils.h"
  18. #include "graph/utils/graph_utils.h"
  19. #include "debug/ge_op_types.h"
  20. #include "debug/ge_util.h"
  21. #include "framework/common/debug/ge_log.h"
  22. #include "graph/anchor.h"
  23. #include "graph/debug/ge_attr_define.h"
  24. #include "graph/types.h"
  25. #include "external/graph/operator.h"
  26. #include "graph/ge_context.h"
  27. #include "graph/runtime_inference_context.h"
  28. #include "graph/utils/op_desc_utils.h"
  29. #include "graph/utils/tensor_utils.h"
  30. #include "graph/utils/tensor_adapter.h"
  31. #include "graph/utils/type_utils.h"
  32. namespace ge {
  33. std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{};
  34. std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{};
  // Op type names grouped by category: constant producers and functional control-flow ops.
  const std::set<std::string> kConstOpTypes{"Const", "Constant"};
  const std::set<std::string> kIfOpTypes{"If", "_If", "StatelessIf"};
  const std::set<std::string> kWhileOpTypes{"While", "_While", "StatelessWhile"};
  const std::set<std::string> kCaseOpTypes{"Case"};
  const std::set<std::string> kForOpTypes{"For"};
  40. bool OpShapeIsUnknown(const OpDescPtr &desc) {
  41. for (const auto &ptr : desc->GetAllInputsDescPtr()) {
  42. auto ge_shape = ptr->GetShape();
  43. for (const auto &dim : ge_shape.GetDims()) {
  44. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  45. return true;
  46. }
  47. }
  48. }
  49. for (const auto &ptr : desc->GetAllOutputsDescPtr()) {
  50. auto ge_shape = ptr->GetShape();
  51. for (const auto &dim : ge_shape.GetDims()) {
  52. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  53. return true;
  54. }
  55. }
  56. }
  57. return false;
  58. }
  59. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node,
  60. const uint32_t &event_id) {
  61. GE_CHECK_NOTNULL(node);
  62. map_send_info_[node].push_back(event_id);
  63. return GRAPH_SUCCESS;
  64. }
  65. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node,
  66. const uint32_t &event_id) {
  67. GE_CHECK_NOTNULL(node);
  68. map_recv_info_[node].push_back(event_id);
  69. return GRAPH_SUCCESS;
  70. }
  71. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  72. NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_send) {
  73. GE_CHECK_NOTNULL(node);
  74. auto find = map_send_info_.find(node);
  75. if (find == map_send_info_.end()) {
  76. return GRAPH_FAILED;
  77. } else {
  78. vec_send = find->second;
  79. return GRAPH_SUCCESS;
  80. }
  81. }
  82. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  83. NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_recv) {
  84. GE_CHECK_NOTNULL(node);
  85. auto find = map_recv_info_.find(node);
  86. if (find == map_recv_info_.end()) {
  87. return GRAPH_FAILED;
  88. } else {
  89. vec_recv = find->second;
  90. return GRAPH_SUCCESS;
  91. }
  92. }
  93. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() {
  94. map_send_info_.clear();
  95. return GRAPH_SUCCESS;
  96. }
  97. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() {
  98. map_recv_info_.clear();
  99. return GRAPH_SUCCESS;
  100. }
  101. graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) {
  102. GE_CHECK_NOTNULL(src);
  103. NodePtr cur_ptr;
  104. if (depth < 1) {
  105. return GRAPH_FAILED;
  106. }
  107. for (int i = 0; i < depth; i++) {
  108. if (src->GetOutDataNodes().size() != 1) {
  109. return GRAPH_FAILED;
  110. }
  111. cur_ptr = src->GetOutDataNodes().at(0);
  112. GE_CHECK_NOTNULL(cur_ptr);
  113. }
  114. dst = cur_ptr;
  115. return GRAPH_SUCCESS;
  116. }
  117. graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data,
  118. InControlAnchorPtr &in_control) {
  119. GE_CHECK_NOTNULL(node_ptr);
  120. for (const auto &p : node_ptr->GetAllOutDataAnchors()) {
  121. GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr");
  122. for (const auto &p_in : p->GetPeerInControlAnchors()) {
  123. GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr");
  124. out_data = p;
  125. in_control = p_in;
  126. return GRAPH_SUCCESS;
  127. }
  128. }
  129. return GRAPH_FAILED;
  130. }
  131. graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) {
  132. GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED,
  133. "node or in_data_anchor is nullptr");
  134. bool find_flag = false;
  135. uint32_t index = 0;
  136. vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end();
  137. for (const auto &tmp : node_ptr->in_data_anchors_) {
  138. if (tmp == in_data_anchor) {
  139. find_flag = true;
  140. auto iter = node_ptr->in_data_anchors_.begin() + index;
  141. if (iter != node_ptr->in_data_anchors_.end()) {
  142. it = node_ptr->in_data_anchors_.erase(iter);
  143. }
  144. break;
  145. }
  146. index++;
  147. }
  148. for (; it != node_ptr->in_data_anchors_.end(); ++it) {
  149. (*it)->SetIdx(index);
  150. index++;
  151. }
  152. if (!find_flag) {
  153. return GRAPH_FAILED;
  154. }
  155. return GRAPH_SUCCESS;
  156. }
  157. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) {
  158. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr");
  159. GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed");
  160. return GRAPH_SUCCESS;
  161. }
  162. graphStatus NodeUtils::SetAllAnchorStatus(Node &node) {
  163. node.anchor_status_updated_ = true;
  164. return GRAPH_SUCCESS;
  165. }
  166. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) {
  167. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr");
  168. return IsAnchorStatusSet(*node_ptr);
  169. }
  170. bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; }
  171. graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) {
  172. if ((origin_node == nullptr) || (new_node == nullptr)) {
  173. return GRAPH_FAILED;
  174. }
  175. auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors();
  176. auto new_out_data_anchors = new_node->GetAllOutDataAnchors();
  177. if (origin_out_data_anchors.size() != new_out_data_anchors.size()) {
  178. return GRAPH_FAILED;
  179. }
  180. for (size_t i = 0; i < origin_out_data_anchors.size(); ++i) {
  181. for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) {
  182. GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
  183. "unlink peer_anchor failed");
  184. GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  185. "linkto peer_anchor failed");
  186. }
  187. for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) {
  188. GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
  189. "unlink peer_anchor failed");
  190. GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  191. "linkto peer_anchor failed");
  192. }
  193. }
  194. auto origin_out_control_anchor = origin_node->GetOutControlAnchor();
  195. GE_CHECK_NOTNULL(origin_out_control_anchor);
  196. auto new_out_control_anchor = new_node->GetOutControlAnchor();
  197. GE_CHECK_NOTNULL(new_out_control_anchor);
  198. for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) {
  199. GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  200. "linkto peer_anchor failed");
  201. }
  202. for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) {
  203. GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  204. "linkto peer_anchor failed");
  205. }
  206. origin_out_control_anchor->UnlinkAll();
  207. return GRAPH_SUCCESS;
  208. }
  209. bool NodeUtils::IsConst(const Node &node) {
  210. auto src_node_type = node.GetType();
  211. bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP));
  212. return is_const;
  213. }
  214. void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) {
  215. if (node_ptr == nullptr) {
  216. GELOGE(GRAPH_FAILED, "node is null");
  217. return;
  218. }
  219. UpdateIsInputConst(*node_ptr);
  220. }
  221. ///
  222. /// update is_input_const
  223. /// @param node
  224. /// @return void
  225. ///
  226. void NodeUtils::UpdateIsInputConst(Node &node) {
  227. std::vector<bool> is_input_const;
  228. size_t anchor_num = node.GetAllInDataAnchors().size();
  229. for (size_t i = 0; i < anchor_num; i++) {
  230. auto in_anchor = node.GetInDataAnchor(static_cast<int>(i));
  231. if (in_anchor == nullptr) {
  232. is_input_const.push_back(false);
  233. continue;
  234. }
  235. auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
  236. if (peer_out_anchor == nullptr) {
  237. is_input_const.push_back(false);
  238. continue;
  239. }
  240. auto src_node = peer_out_anchor->GetOwnerNode();
  241. if (src_node == nullptr) {
  242. is_input_const.push_back(false);
  243. continue;
  244. }
  245. if (IsConst(*(src_node))) {
  246. is_input_const.push_back(true);
  247. } else {
  248. is_input_const.push_back(false);
  249. }
  250. }
  251. if (node.GetOpDesc() == nullptr) {
  252. GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr");
  253. return;
  254. }
  255. node.GetOpDesc()->SetIsInputConst(is_input_const);
  256. }
  257. void NodeUtils::UnlinkAll(const Node &node) {
  258. for (const auto &anchor : node.GetAllOutAnchors()) {
  259. anchor->UnlinkAll();
  260. }
  261. for (const auto &anchor : node.GetAllInAnchors()) {
  262. anchor->UnlinkAll();
  263. }
  264. }
  265. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) {
  266. if (node_ptr == nullptr) {
  267. GELOGE(GRAPH_FAILED, "Nodeptr is nullptr");
  268. return GRAPH_FAILED;
  269. }
  270. auto op_desc = node_ptr->GetOpDesc();
  271. if (op_desc == nullptr) {
  272. return GRAPH_FAILED;
  273. }
  274. bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  275. if (is_unknown_graph) {
  276. return GRAPH_SUCCESS;
  277. }
  278. for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
  279. auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
  280. auto out_dims = output_tensor->GetShape().GetDims();
  281. auto out_dtype = output_tensor->GetDataType();
  282. ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
  283. output_tensor->SetOriginShape(output_tensor->GetShape());
  284. output_tensor->SetOriginDataType(output_tensor->GetDataType());
  285. GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
  286. node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
  287. TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
  288. TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
  289. for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
  290. if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  291. GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null");
  292. continue;
  293. }
  294. auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx());
  295. if (peer_input_desc == nullptr) {
  296. GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr");
  297. continue;
  298. }
  299. // check shape and dtype continuity. do not stop process
  300. auto peer_input_dims = peer_input_desc->GetShape().GetDims();
  301. auto peer_input_dtype = peer_input_desc->GetDataType();
  302. if (out_dtype != peer_input_dtype) {
  303. GELOGW("current node [%s] [%d]\'th out_dtype is [%s].peer input node [%s] [%d]\'th "
  304. "input_dtype is [%s].The two dtype should be same! Please check graph and fix it",
  305. node_ptr->GetName().c_str(), out_anchor->GetIdx(), TypeUtils::DataTypeToSerialString(out_dtype).c_str(),
  306. peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(),
  307. TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str());
  308. } else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) {
  309. string out_shape_str, peer_in_shape_str;
  310. out_shape_str += "[";
  311. for (int64_t dim : out_dims) {
  312. out_shape_str += std::to_string(dim) + " ";
  313. }
  314. out_shape_str += "]";
  315. peer_in_shape_str += "[";
  316. for (int64_t dim : peer_input_dims) {
  317. peer_in_shape_str += std::to_string(dim) + " ";
  318. }
  319. peer_in_shape_str += "]";
  320. GELOGW("current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th "
  321. "input_shape is [%s].The two shape should be same! Please check graph and fix it",
  322. node_ptr->GetName().c_str(), out_anchor->GetIdx(), out_shape_str.c_str(),
  323. peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), peer_in_shape_str.c_str());
  324. }
  325. GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d",
  326. peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(),
  327. output_tensor->GetShape().GetDimNum(), output_tensor->GetDataType(),
  328. output_tensor->GetOriginDataType());
  329. peer_input_desc->SetOriginShape(output_tensor->GetOriginShape());
  330. peer_input_desc->SetShape(output_tensor->GetShape());
  331. peer_input_desc->SetDataType(output_tensor->GetDataType());
  332. peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType());
  333. std::vector<std::pair<int64_t, int64_t>> shape_range;
  334. (void) output_tensor->GetShapeRange(shape_range);
  335. peer_input_desc->SetShapeRange(shape_range);
  336. ge::TensorUtils::SetRealDimCnt(*peer_input_desc,
  337. static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
  338. GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
  339. peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(),
  340. peer_input_desc->GetShape().GetDimNum(), peer_input_desc->GetDataType(),
  341. peer_input_desc->GetOriginDataType());
  342. }
  343. }
  344. return GRAPH_SUCCESS;
  345. }
  346. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  347. graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node, uint32_t num) {
  348. if (node == nullptr) {
  349. GELOGE(GRAPH_FAILED, "Input node is null");
  350. return GRAPH_FAILED;
  351. }
  352. GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
  353. const auto &op_desc = node->GetOpDesc();
  354. for (size_t i = op_desc->GetInputsSize(); i < num; ++i) {
  355. if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) {
  356. GELOGE(GRAPH_FAILED, "Add input desc failed");
  357. return GRAPH_FAILED;
  358. }
  359. }
  360. for (size_t i = node->in_data_anchors_.size(); i < num; ++i) {
  361. auto anchor = ComGraphMakeShared<InDataAnchor>(node, i);
  362. if (anchor == nullptr) {
  363. GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed.");
  364. return GRAPH_FAILED;
  365. }
  366. node->in_data_anchors_.push_back(anchor);
  367. }
  368. return GRAPH_SUCCESS;
  369. }
  370. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  371. graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node, uint32_t num) {
  372. if (node == nullptr) {
  373. GELOGE(GRAPH_FAILED, "Input node is null");
  374. return GRAPH_FAILED;
  375. }
  376. const auto &op_desc = node->GetOpDesc();
  377. while (op_desc->GetInputsSize() > num) {
  378. if (!OpDescUtils::ClearInputDesc(op_desc, num)) {
  379. return GRAPH_FAILED;
  380. }
  381. }
  382. auto input_names = op_desc->GetAllInputName();
  383. (void)op_desc->UpdateInputName(input_names);
  384. auto is_input_const = op_desc->GetIsInputConst();
  385. is_input_const.resize(num);
  386. op_desc->SetIsInputConst(is_input_const);
  387. while (node->in_data_anchors_.size() > num) {
  388. node->in_data_anchors_.pop_back();
  389. }
  390. return GRAPH_SUCCESS;
  391. }
  392. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  393. graphStatus NodeUtils::AppendOutputAnchor(const NodePtr &node, uint32_t num) {
  394. if (node == nullptr) {
  395. GELOGE(GRAPH_FAILED, "Input node is null");
  396. return GRAPH_FAILED;
  397. }
  398. GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
  399. const OpDescPtr &op_desc = node->GetOpDesc();
  400. for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) {
  401. if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) {
  402. GELOGE(GRAPH_FAILED, "Add output desc failed");
  403. return GRAPH_FAILED;
  404. }
  405. }
  406. for (size_t i = node->out_data_anchors_.size(); i < num; ++i) {
  407. auto anchor = ComGraphMakeShared<OutDataAnchor>(node, i);
  408. if (anchor == nullptr) {
  409. GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed.");
  410. return GRAPH_FAILED;
  411. }
  412. node->out_data_anchors_.push_back(anchor);
  413. }
  414. return GRAPH_SUCCESS;
  415. }
  416. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  417. graphStatus NodeUtils::RemoveOutputAnchor(const NodePtr &node, uint32_t num) {
  418. if (node == nullptr) {
  419. GELOGE(GRAPH_FAILED, "Input node is null");
  420. return GRAPH_FAILED;
  421. }
  422. const auto &op_desc = node->GetOpDesc();
  423. auto output_names = op_desc->GetAllOutputName();
  424. while (op_desc->GetOutputsSize() > num) {
  425. if (!OpDescUtils::ClearOutputDesc(op_desc, num)) {
  426. return GRAPH_FAILED;
  427. }
  428. }
  429. (void)op_desc->UpdateOutputName(output_names);
  430. while (node->out_data_anchors_.size() > num) {
  431. node->out_data_anchors_.pop_back();
  432. }
  433. return GRAPH_SUCCESS;
  434. }
  435. bool NodeUtils::IsInNodesEmpty(const Node &node) {
  436. for (const auto &in_anchor : node.in_data_anchors_) {
  437. if (in_anchor != nullptr) {
  438. auto out_anchor = in_anchor->GetPeerOutAnchor();
  439. if (out_anchor != nullptr) {
  440. if (out_anchor->GetOwnerNode() != nullptr) {
  441. return false;
  442. }
  443. }
  444. }
  445. }
  446. if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) {
  447. auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors();
  448. for (const auto &out_control_anchor : peer_out_control_anchors) {
  449. if (out_control_anchor != nullptr) {
  450. if (out_control_anchor->GetOwnerNode() != nullptr) {
  451. return false;
  452. }
  453. }
  454. }
  455. }
  456. return true;
  457. }
  458. GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) {
  459. auto desc = node.GetOpDesc();
  460. if (desc == nullptr) {
  461. return GeTensorDesc();
  462. }
  463. return desc->GetOutputDesc(index);
  464. }
  465. GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) {
  466. auto desc = node.GetOpDesc();
  467. if (desc == nullptr) {
  468. return GeTensorDesc();
  469. }
  470. return desc->GetInputDesc(index);
  471. }
  472. graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) {
  473. auto desc = node.GetOpDesc();
  474. if (desc == nullptr) {
  475. return GRAPH_PARAM_INVALID;
  476. }
  477. auto output_desc = desc->MutableOutputDesc(index);
  478. if (output_desc == nullptr) {
  479. return GRAPH_PARAM_INVALID;
  480. }
  481. output_desc->SetShape(shape);
  482. return GRAPH_SUCCESS;
  483. }
  484. graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) {
  485. auto desc = node.GetOpDesc();
  486. if (desc == nullptr) {
  487. return GRAPH_PARAM_INVALID;
  488. }
  489. auto input_desc = desc->MutableInputDesc(index);
  490. if (input_desc == nullptr) {
  491. return GRAPH_PARAM_INVALID;
  492. }
  493. input_desc->SetShape(shape);
  494. return GRAPH_SUCCESS;
  495. }
  496. graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) {
  497. auto desc = node.GetOpDesc();
  498. GE_CHECK_NOTNULL(desc);
  499. // check self
  500. is_unknow = OpShapeIsUnknown(desc);
  501. if (is_unknow) {
  502. return GRAPH_SUCCESS;
  503. }
  504. auto sub_graph_names = desc->GetSubgraphInstanceNames();
  505. if (sub_graph_names.empty()) {
  506. return GRAPH_SUCCESS;
  507. } else {
  508. auto owner_graph = node.GetOwnerComputeGraph();
  509. GE_CHECK_NOTNULL(owner_graph);
  510. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  511. if (root_graph == nullptr) {
  512. GE_LOGE("Node %s gets null root graph", node.GetName().c_str());
  513. return GRAPH_PARAM_INVALID;
  514. }
  515. for (auto &sub_graph_name : sub_graph_names) {
  516. auto sub_graph = root_graph->GetSubgraph(sub_graph_name);
  517. GE_CHECK_NOTNULL(sub_graph);
  518. for (const auto &node_ptr : sub_graph->GetDirectNode()) {
  519. auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow);
  520. if (status != GRAPH_SUCCESS) {
  521. GE_LOGE("get node unknown shape status failed!");
  522. return status;
  523. }
  524. if (is_unknow) {
  525. return GRAPH_SUCCESS;
  526. }
  527. }
  528. }
  529. }
  530. return GRAPH_SUCCESS;
  531. }
  532. graphStatus NodeUtils::GetInputConstData(const ConstNodePtr& node_ptr,
  533. const string &dst_name,
  534. GeTensorPtr &ge_tensor) {
  535. GE_CHECK_NOTNULL(node_ptr);
  536. return NodeUtils::GetInputConstData(*node_ptr, dst_name, ge_tensor);
  537. }
  538. graphStatus NodeUtils::GetInputConstData(const Node &node,
  539. const string &dst_name,
  540. GeTensorPtr &ge_tensor) {
  541. // For inner compute graph
  542. auto op_desc = node.GetOpDesc();
  543. GE_CHECK_NOTNULL(op_desc);
  544. auto index = op_desc->GetInputIndexByName(dst_name);
  545. auto in_data_anchor = node.GetInDataAnchor(index);
  546. GE_CHECK_NOTNULL(in_data_anchor);
  547. auto out_data_anchor = in_data_anchor->GetPeerOutAnchor();
  548. GE_CHECK_NOTNULL(out_data_anchor);
  549. auto peer_node = out_data_anchor->GetOwnerNode();
  550. if (peer_node->GetType() == ENTER || peer_node->GetType() == REFENTER) {
  551. auto enter_in_data_anchor = peer_node->GetInDataAnchor(0);
  552. GE_CHECK_NOTNULL(enter_in_data_anchor);
  553. auto enter_peer_out_data_anchor = enter_in_data_anchor->GetPeerOutAnchor();
  554. GE_CHECK_NOTNULL(enter_peer_out_data_anchor);
  555. peer_node = enter_peer_out_data_anchor->GetOwnerNode();
  556. }
  557. auto peer_op_desc = peer_node->GetOpDesc();
  558. GE_CHECK_NOTNULL(peer_op_desc);
  559. auto peer_op_type = peer_op_desc->GetType();
  560. if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) {
  561. if (!AttrUtils::MutableTensor(peer_node->GetOpDesc(), ATTR_NAME_WEIGHTS, ge_tensor)) {
  562. GELOGW("get attr name %s failed.", ATTR_NAME_WEIGHTS.c_str());
  563. return GRAPH_FAILED;
  564. }
  565. return GRAPH_SUCCESS;
  566. } else if (peer_op_type == DATA) {
  567. auto parent_node = NodeUtils::GetParentInput(peer_node);
  568. while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) {
  569. parent_node = NodeUtils::GetParentInput(parent_node);
  570. }
  571. if ((parent_node != nullptr)
  572. && ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) {
  573. if (!AttrUtils::MutableTensor(parent_node->GetOpDesc(), ATTR_NAME_WEIGHTS, ge_tensor)) {
  574. GELOGW("get attr name %s failed.", ATTR_NAME_WEIGHTS.c_str());
  575. return GRAPH_FAILED;
  576. }
  577. return GRAPH_SUCCESS;
  578. }
  579. }
  580. // Try get from runtime inference context
  581. auto session_id = std::to_string(GetContext().SessionId());
  582. RuntimeInferenceContext *runtime_infer_ctx = nullptr;
  583. if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) {
  584. GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str());
  585. auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(),
  586. out_data_anchor->GetIdx(), ge_tensor);
  587. if (ret == GRAPH_SUCCESS) {
  588. return GRAPH_SUCCESS;
  589. }
  590. }
  591. GELOGW("node[%s]'s input[%s]'s peer node is not const", node.GetName().c_str(), dst_name.c_str());
  592. return GRAPH_FAILED;
  593. }
  594. std::string NodeUtils::GetNodeType(const Node &node) {
  595. if (node.GetType() != FRAMEWORKOP) {
  596. return node.GetType();
  597. }
  598. std::string type;
  599. (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  600. return type;
  601. }
  602. std::string NodeUtils::GetNodeType(const NodePtr &node) {
  603. return node == nullptr ? "" : GetNodeType(*node);
  604. }
  605. std::vector<ComputeGraphPtr> NodeUtils::GetAllSubgraphs(const Node &node) {
  606. auto op_desc = node.GetOpDesc();
  607. if (op_desc == nullptr) {
  608. GELOGE(GRAPH_FAILED, "Failed to get op desc from node %s ", node.GetName().c_str());
  609. return {};
  610. }
  611. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  612. if (root_graph == nullptr) {
  613. GELOGE(GRAPH_FAILED, "Failed to find root graph from node %s ", node.GetName().c_str());
  614. return {};
  615. }
  616. return root_graph->GetAllSubgraphs();
  617. }
  618. ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) {
  619. auto op_desc = node.GetOpDesc();
  620. if (op_desc == nullptr) {
  621. return nullptr;
  622. }
  623. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  624. if (root_graph == nullptr) {
  625. return nullptr;
  626. }
  627. return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index));
  628. }
  629. graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) {
  630. if (subgraph == nullptr) {
  631. GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index);
  632. return GRAPH_PARAM_INVALID;
  633. }
  634. auto op_desc = node.GetOpDesc();
  635. if (op_desc == nullptr) {
  636. return GRAPH_PARAM_INVALID;
  637. }
  638. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  639. if (root_graph == nullptr) {
  640. GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str());
  641. return GRAPH_PARAM_INVALID;
  642. }
  643. auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName());
  644. if (ret != GRAPH_SUCCESS) {
  645. GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index);
  646. return ret;
  647. }
  648. subgraph->SetParentNode(node.shared_from_this());
  649. subgraph->SetParentGraph(node.GetOwnerComputeGraph());
  650. return root_graph->AddSubgraph(subgraph);
  651. }
  652. ///
  653. /// Check if node is input of subgraph
  654. /// @param [in] node
  655. /// @return bool
  656. ///
  657. bool NodeUtils::IsSubgraphInput(const NodePtr &node) {
  658. if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
  659. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
  660. return false;
  661. }
  662. auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
  663. if (parent_op_desc == nullptr) {
  664. return false;
  665. }
  666. // dynamic shape unknown graph false
  667. // dynamic shape known graph with functional subgraph maybe true
  668. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
  669. if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) {
  670. return false;
  671. } else {
  672. if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) {
  673. return false;
  674. }
  675. }
  676. }
  677. return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX);
  678. }
  679. ///
  680. /// Check if node is output of subgraph
  681. /// @param [in] node
  682. /// @return bool
  683. ///
  684. bool NodeUtils::IsSubgraphOutput(const NodePtr &node) {
  685. if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
  686. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) {
  687. return false;
  688. }
  689. auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
  690. if (parent_op_desc == nullptr) {
  691. return false;
  692. }
  693. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
  694. if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) {
  695. return false;
  696. } else {
  697. if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) {
  698. return false;
  699. }
  700. }
  701. }
  702. for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) {
  703. if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) {
  704. return true;
  705. }
  706. }
  707. return false;
  708. }
  709. ///
  710. /// @brief Get subgraph original input node.
  711. /// @param [in] node
  712. /// @return Node
  713. ///
  714. NodePtr NodeUtils::GetParentInput(const Node &node) {
  715. uint32_t parent_index = 0;
  716. if (!AttrUtils::GetInt(node.GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
  717. return nullptr;
  718. }
  719. // Subgraph Data Node, check for constant input.
  720. const ComputeGraphPtr &graph = node.GetOwnerComputeGraph();
  721. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  722. const NodePtr &parent_node = graph->GetParentNode();
  723. GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr);
  724. const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index);
  725. GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr);
  726. const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor();
  727. GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr);
  728. return peer_out_anchor->GetOwnerNode();
  729. }
  730. NodePtr NodeUtils::GetParentInput(const NodePtr &node) {
  731. return node == nullptr ? node : GetParentInput(*node);
  732. }
  733. ///
  734. /// @brief Get is dynamic shape graph from node.
  735. /// @param [in] node
  736. /// @return bool
  737. ///
  738. bool NodeUtils::IsDynamicShape(const Node &node) {
  739. const auto graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  740. if (graph == nullptr) {
  741. return false;
  742. }
  743. bool is_dynamic_shape = false;
  744. (void)AttrUtils::GetBool(graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape);
  745. return is_dynamic_shape;
  746. }
  747. bool NodeUtils::IsDynamicShape(const NodePtr &node) {
  748. return node == nullptr ? false : IsDynamicShape(*node);
  749. }
  750. ///
  751. /// @brief Check is varying_input for while node
  752. /// @param [in] node: Data node for subgraph
  753. /// @return bool
  754. ///
  755. bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) {
  756. if (node == nullptr) {
  757. return false;
  758. }
  759. if (node->GetType() != DATA) {
  760. return false; // not input_node for subgraph
  761. }
  762. const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode();
  763. if (parent_node == nullptr) {
  764. return false; // root graph
  765. }
  766. if (kWhileOpTypes.count(parent_node->GetType()) == 0) {
  767. return false; // not input_node for while subgraph
  768. }
  769. uint32_t index_i = 0;
  770. if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) {
  771. GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str());
  772. return false;
  773. }
  774. bool varying_flag = true;
  775. for (const auto &item : node->GetOutDataNodesAndAnchors()) {
  776. if (item.first->GetType() != NETOUTPUT) {
  777. continue;
  778. }
  779. OpDescPtr op_desc = item.first->GetOpDesc();
  780. uint32_t index_o = 0;
  781. if ((op_desc == nullptr) ||
  782. !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) {
  783. continue; // input for while-cond subgraph
  784. }
  785. if (index_i != index_o) {
  786. continue; // varying input for while-body subgraph
  787. }
  788. varying_flag = false;
  789. break;
  790. }
  791. return varying_flag;
  792. }
  793. ///
  794. /// @brief Get subgraph input is constant.
  795. /// @param [in] node
  796. /// @param [out] string
  797. /// @return bool
  798. ///
  799. bool NodeUtils::GetConstOpType(const NodePtr &node, std::string &type) {
  800. if (node == nullptr) {
  801. return false;
  802. }
  803. if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) {
  804. type = node->GetType();
  805. return true;
  806. }
  807. if (node->GetType() != DATA) {
  808. return false; // not subgraph input node
  809. }
  810. const auto &parent = GetParentInput(node);
  811. return GetConstOpType(parent, type);
  812. }
  813. ///
  814. /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph.
  815. /// @param [in] node
  816. /// @return return GRAPH_SUCCESS if remove successfully, other for failed.
  817. ///
  818. Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) {
  819. GE_CHECK_NOTNULL(node);
  820. auto op_desc = node->GetOpDesc();
  821. GE_CHECK_NOTNULL(op_desc);
  822. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  823. if (subgraph_names.empty()) {
  824. return GRAPH_SUCCESS;
  825. } else {
  826. auto owner_graph = node->GetOwnerComputeGraph();
  827. GE_CHECK_NOTNULL(owner_graph);
  828. auto root_graph = GraphUtils::FindRootGraph(owner_graph);
  829. GE_CHECK_NOTNULL(root_graph);
  830. std::unordered_set<std::string> subgraph_to_remove;
  831. for (auto &subgraph_name : subgraph_names) {
  832. std::deque<std::string> queue;
  833. queue.push_back(subgraph_name);
  834. subgraph_to_remove.insert(subgraph_name);
  835. op_desc->RemoveSubgraphInstanceName(subgraph_name);
  836. while (!queue.empty()) {
  837. auto graph_name = queue.front();
  838. queue.pop_front();
  839. auto subgraph = root_graph->GetSubgraph(graph_name);
  840. GE_CHECK_NOTNULL(subgraph);
  841. for (const auto &sub_node : subgraph->GetDirectNode()) {
  842. auto sub_op_desc = sub_node->GetOpDesc();
  843. GE_CHECK_NOTNULL(sub_op_desc);
  844. auto sub_names = sub_op_desc->GetSubgraphInstanceNames();
  845. // Subgraph and all nodes in it will be removed later,
  846. // no need to remove 'SubgraphInstanceName' in op desc here.
  847. for (auto &name : sub_names) {
  848. if (subgraph_to_remove.insert(name).second) {
  849. queue.push_back(name);
  850. }
  851. }
  852. }
  853. }
  854. }
  855. // Remove subgraph from root_graph
  856. for (const auto &name : subgraph_to_remove) {
  857. GELOGI("Remove subgraph:%s.", name.c_str());
  858. root_graph->RemoveSubgraph(name);
  859. }
  860. }
  861. return GRAPH_SUCCESS;
  862. }
  863. ///
  864. /// @brief Get subgraph input data node by index.
  865. /// @param [in] node
  866. /// @return Node
  867. ///
  868. vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) {
  869. vector<NodePtr> in_data_node_vec;
  870. auto op_desc = node.GetOpDesc();
  871. GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec);
  872. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  873. if (subgraph_names.empty()) {
  874. GELOGW("Node %s is single node without sub graph.", node.GetName().c_str());
  875. return in_data_node_vec;
  876. }
  877. auto compute_graph = node.GetOwnerComputeGraph();
  878. for (const std::string &instance_name : subgraph_names) {
  879. auto subgraph = compute_graph->GetSubgraph(instance_name);
  880. for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
  881. int parent_index = -1;
  882. if (NodeUtils::IsSubgraphInput(node_in_subgraph)) {
  883. (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index);
  884. if (parent_index == index) {
  885. in_data_node_vec.emplace_back(node_in_subgraph);
  886. }
  887. }
  888. }
  889. }
  890. return in_data_node_vec;
  891. }
  892. ///
  893. /// @brief Get subgraph input data node by index.
  894. /// @param [in] node
  895. /// @return Node
  896. ///
  897. vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) {
  898. vector<NodePtr> out_data_node_vec;
  899. auto op_desc = node.GetOpDesc();
  900. GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec);
  901. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  902. if (subgraph_names.empty()) {
  903. GELOGI("Node %s is single node without sub graph.", node.GetName().c_str());
  904. return out_data_node_vec;
  905. }
  906. auto compute_graph = node.GetOwnerComputeGraph();
  907. for (const std::string &instance_name : subgraph_names) {
  908. auto subgraph = compute_graph->GetSubgraph(instance_name);
  909. for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
  910. if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) {
  911. out_data_node_vec.emplace_back(node_in_subgraph);
  912. }
  913. }
  914. }
  915. return out_data_node_vec;
  916. }
  917. NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, const int index) {
  918. if (node.GetInDataAnchor(index) == nullptr) {
  919. return nullptr;
  920. }
  921. if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) {
  922. return nullptr;
  923. }
  924. return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode();
  925. }
  926. vector<pair<InDataAnchorPtr, NodePtr>> NodeUtils::GetOutDataNodesWithAnchorByIndex(const Node &node, const int index) {
  927. vector<pair<InDataAnchorPtr, NodePtr>> out_data_nodes;
  928. auto out_data_anchor = node.GetOutDataAnchor(index);
  929. if (out_data_anchor == nullptr) {
  930. return out_data_nodes;
  931. }
  932. for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  933. if (peer_in_anchor == nullptr) {
  934. continue;
  935. }
  936. if (peer_in_anchor->GetOwnerNode() == nullptr) {
  937. continue;
  938. }
  939. out_data_nodes.emplace_back(std::make_pair(peer_in_anchor, peer_in_anchor->GetOwnerNode()));
  940. }
  941. return out_data_nodes;
  942. }
  943. ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) {
  944. return oprt.GetNode();
  945. }
  946. std::string NodeUtils::GetInConstNodeTypeCrossSubgraph(const NodePtr &node) {
  947. NodePtr input_node = node;
  948. while (input_node != nullptr) {
  949. if (input_node->GetType() != DATA) {
  950. return input_node->GetType();
  951. }
  952. auto owner_graph = input_node->GetOwnerComputeGraph();
  953. auto parent_node = owner_graph->GetParentNode();
  954. if ((parent_node == nullptr) || (kWhileOpTypes.count(parent_node->GetType()) > 0)) {
  955. return node->GetType(); // not in subgraph or while subgraph.
  956. }
  957. input_node = GetParentInput(input_node);
  958. }
  959. return "";
  960. }
  961. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承上启下的作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了专门的优化，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户无需感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：