You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_utils.cc 22 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/node_utils.h"
  17. #include "graph/utils/graph_utils.h"
  18. #include "debug/ge_op_types.h"
  19. #include "debug/ge_util.h"
  20. #include "framework/common/debug/ge_log.h"
  21. #include "graph/anchor.h"
  22. #include "graph/debug/ge_attr_define.h"
  23. #include "graph/types.h"
  24. #include "utils/tensor_utils.h"
  25. #include "utils/type_utils.h"
  26. namespace ge {
  27. std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{};
  28. std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{};
  // Op type names grouped by category. Each set lists every spelling of the type string
  // that belongs to the category.
  const std::set<std::string> kConstOpTypes{"Const", "Constant"};
  const std::set<std::string> kIfOpTypes{"If", "_If", "StatelessIf"};
  const std::set<std::string> kWhileOpTypes{"While", "_While", "StatelessWhile"};
  const std::set<std::string> kCaseOpTypes{"Case"};
  const std::set<std::string> kForOpTypes{"For"};
  34. bool OpShapeIsUnknown(const OpDescPtr &desc) {
  35. for (const auto &ptr : desc->GetAllInputsDescPtr()) {
  36. auto ge_shape = ptr->GetShape();
  37. for (const auto &dim : ge_shape.GetDims()) {
  38. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  39. return true;
  40. }
  41. }
  42. }
  43. for (const auto &ptr : desc->GetAllOutputsDescPtr()) {
  44. auto ge_shape = ptr->GetShape();
  45. for (const auto &dim : ge_shape.GetDims()) {
  46. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  47. return true;
  48. }
  49. }
  50. }
  51. return false;
  52. }
  53. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node,
  54. const uint32_t &event_id) {
  55. GE_CHECK_NOTNULL(node);
  56. map_send_info_[node].push_back(event_id);
  57. return GRAPH_SUCCESS;
  58. }
  59. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node,
  60. const uint32_t &event_id) {
  61. GE_CHECK_NOTNULL(node);
  62. map_recv_info_[node].push_back(event_id);
  63. return GRAPH_SUCCESS;
  64. }
  65. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  66. NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_send) {
  67. GE_CHECK_NOTNULL(node);
  68. auto find = map_send_info_.find(node);
  69. if (find == map_send_info_.end()) {
  70. return GRAPH_FAILED;
  71. } else {
  72. vec_send = find->second;
  73. return GRAPH_SUCCESS;
  74. }
  75. }
  76. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  77. NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_recv) {
  78. GE_CHECK_NOTNULL(node);
  79. auto find = map_recv_info_.find(node);
  80. if (find == map_recv_info_.end()) {
  81. return GRAPH_FAILED;
  82. } else {
  83. vec_recv = find->second;
  84. return GRAPH_SUCCESS;
  85. }
  86. }
  87. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() {
  88. map_send_info_.clear();
  89. return GRAPH_SUCCESS;
  90. }
  91. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() {
  92. map_recv_info_.clear();
  93. return GRAPH_SUCCESS;
  94. }
  95. graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) {
  96. GE_CHECK_NOTNULL(src);
  97. NodePtr cur_ptr;
  98. if (depth < 1) {
  99. return GRAPH_FAILED;
  100. }
  101. for (int i = 0; i < depth; i++) {
  102. if (src->GetOutDataNodes().size() != 1) {
  103. return GRAPH_FAILED;
  104. }
  105. cur_ptr = src->GetOutDataNodes().at(0);
  106. GE_CHECK_NOTNULL(cur_ptr);
  107. }
  108. dst = cur_ptr;
  109. return GRAPH_SUCCESS;
  110. }
  111. graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data,
  112. InControlAnchorPtr &in_control) {
  113. GE_CHECK_NOTNULL(node_ptr);
  114. for (const auto &p : node_ptr->GetAllOutDataAnchors()) {
  115. GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr");
  116. for (const auto &p_in : p->GetPeerInControlAnchors()) {
  117. GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr");
  118. out_data = p;
  119. in_control = p_in;
  120. return GRAPH_SUCCESS;
  121. }
  122. }
  123. return GRAPH_FAILED;
  124. }
  125. graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) {
  126. GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED,
  127. "node or in_data_anchor is nullptr");
  128. bool find_flag = false;
  129. uint32_t index = 0;
  130. vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end();
  131. for (const auto &tmp : node_ptr->in_data_anchors_) {
  132. if (tmp == in_data_anchor) {
  133. find_flag = true;
  134. auto iter = node_ptr->in_data_anchors_.begin() + index;
  135. if (iter != node_ptr->in_data_anchors_.end()) {
  136. it = node_ptr->in_data_anchors_.erase(iter);
  137. }
  138. break;
  139. }
  140. index++;
  141. }
  142. for (; it != node_ptr->in_data_anchors_.end(); ++it) {
  143. (*it)->SetIdx(index);
  144. index++;
  145. }
  146. if (!find_flag) {
  147. return GRAPH_FAILED;
  148. }
  149. return GRAPH_SUCCESS;
  150. }
  151. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) {
  152. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr");
  153. GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed");
  154. return GRAPH_SUCCESS;
  155. }
  156. graphStatus NodeUtils::SetAllAnchorStatus(Node &node) {
  157. node.anchor_status_updated_ = true;
  158. return GRAPH_SUCCESS;
  159. }
  160. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) {
  161. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr");
  162. return IsAnchorStatusSet(*node_ptr);
  163. }
  164. bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; }
  165. graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) {
  166. if ((origin_node == nullptr) || (new_node == nullptr)) {
  167. return GRAPH_FAILED;
  168. }
  169. auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors();
  170. auto new_out_data_anchors = new_node->GetAllOutDataAnchors();
  171. if (origin_out_data_anchors.size() != new_out_data_anchors.size()) {
  172. return GRAPH_FAILED;
  173. }
  174. for (size_t i = 0; i < origin_out_data_anchors.size(); ++i) {
  175. for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) {
  176. GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
  177. "unlink peer_anchor failed");
  178. GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  179. "linkto peer_anchor failed");
  180. }
  181. for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) {
  182. GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
  183. "unlink peer_anchor failed");
  184. GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  185. "linkto peer_anchor failed");
  186. }
  187. }
  188. auto origin_out_control_anchor = origin_node->GetOutControlAnchor();
  189. GE_CHECK_NOTNULL(origin_out_control_anchor);
  190. auto new_out_control_anchor = new_node->GetOutControlAnchor();
  191. GE_CHECK_NOTNULL(new_out_control_anchor);
  192. for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) {
  193. GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  194. "linkto peer_anchor failed");
  195. }
  196. for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) {
  197. GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  198. "linkto peer_anchor failed");
  199. }
  200. origin_out_control_anchor->UnlinkAll();
  201. return GRAPH_SUCCESS;
  202. }
  203. bool NodeUtils::IsConst(const Node &node) {
  204. auto src_node_type = node.GetType();
  205. bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP));
  206. return is_const;
  207. }
  208. void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) {
  209. if (node_ptr == nullptr) {
  210. GELOGE(GRAPH_FAILED, "node is null");
  211. return;
  212. }
  213. UpdateIsInputConst(*node_ptr);
  214. }
  215. ///
  216. /// update is_input_const
  217. /// @param node
  218. /// @return void
  219. ///
  220. void NodeUtils::UpdateIsInputConst(Node &node) {
  221. std::vector<bool> is_input_const;
  222. size_t anchor_num = node.GetAllInDataAnchors().size();
  223. for (size_t i = 0; i < anchor_num; i++) {
  224. auto in_anchor = node.GetInDataAnchor(static_cast<int>(i));
  225. if (in_anchor == nullptr) {
  226. is_input_const.push_back(false);
  227. continue;
  228. }
  229. auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
  230. if (peer_out_anchor == nullptr) {
  231. is_input_const.push_back(false);
  232. continue;
  233. }
  234. auto src_node = peer_out_anchor->GetOwnerNode();
  235. if (src_node == nullptr) {
  236. is_input_const.push_back(false);
  237. continue;
  238. }
  239. if (IsConst(*(src_node))) {
  240. is_input_const.push_back(true);
  241. } else {
  242. is_input_const.push_back(false);
  243. }
  244. }
  245. if (node.GetOpDesc() == nullptr) {
  246. GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr");
  247. return;
  248. }
  249. node.GetOpDesc()->SetIsInputConst(is_input_const);
  250. }
  251. void NodeUtils::UnlinkAll(const Node &node) {
  252. for (const auto &anchor : node.GetAllOutAnchors()) {
  253. anchor->UnlinkAll();
  254. }
  255. for (const auto &anchor : node.GetAllInAnchors()) {
  256. anchor->UnlinkAll();
  257. }
  258. }
  259. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) {
  260. if (node_ptr == nullptr) {
  261. GELOGE(GRAPH_FAILED, "Nodeptr is nullptr");
  262. return GRAPH_FAILED;
  263. }
  264. auto op_desc = node_ptr->GetOpDesc();
  265. if (op_desc == nullptr) {
  266. return GRAPH_FAILED;
  267. }
  268. for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
  269. GeTensorDesc output_tensor = op_desc->GetOutputDesc(out_anchor->GetIdx());
  270. ge::TensorUtils::SetRealDimCnt(output_tensor, static_cast<uint32_t>(output_tensor.GetShape().GetDims().size()));
  271. output_tensor.SetOriginShape(output_tensor.GetShape());
  272. output_tensor.SetOriginDataType(output_tensor.GetDataType());
  273. GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
  274. node_ptr->GetName().c_str(), output_tensor.GetOriginShape().GetShapeSize(),
  275. TypeUtils::FormatToSerialString(output_tensor.GetOriginFormat()).c_str(),
  276. TypeUtils::DataTypeToSerialString(output_tensor.GetOriginDataType()).c_str());
  277. (void)op_desc->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor);
  278. for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
  279. if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  280. GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null");
  281. continue;
  282. }
  283. auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx());
  284. if (peer_input_desc == nullptr) {
  285. GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr");
  286. continue;
  287. }
  288. GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d",
  289. peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(),
  290. output_tensor.GetDataType(), output_tensor.GetOriginDataType());
  291. peer_input_desc->SetShape(output_tensor.GetShape());
  292. peer_input_desc->SetOriginShape(output_tensor.GetOriginShape());
  293. peer_input_desc->SetDataType(output_tensor.GetDataType());
  294. peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType());
  295. std::vector<std::pair<int64_t, int64_t>> shape_range;
  296. (void)output_tensor.GetShapeRange(shape_range);
  297. peer_input_desc->SetShapeRange(shape_range);
  298. ge::TensorUtils::SetRealDimCnt(*peer_input_desc,
  299. static_cast<uint32_t>(output_tensor.GetShape().GetDims().size()));
  300. GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
  301. peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(),
  302. peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType());
  303. }
  304. }
  305. return GRAPH_SUCCESS;
  306. }
  307. bool NodeUtils::IsInNodesEmpty(const Node &node) {
  308. for (const auto &in_anchor : node.in_data_anchors_) {
  309. if (in_anchor != nullptr) {
  310. auto out_anchor = in_anchor->GetPeerOutAnchor();
  311. if (out_anchor != nullptr) {
  312. if (out_anchor->GetOwnerNode() != nullptr) {
  313. return false;
  314. }
  315. }
  316. }
  317. }
  318. if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) {
  319. auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors();
  320. for (const auto &out_control_anchor : peer_out_control_anchors) {
  321. if (out_control_anchor != nullptr) {
  322. if (out_control_anchor->GetOwnerNode() != nullptr) {
  323. return false;
  324. }
  325. }
  326. }
  327. }
  328. return true;
  329. }
  330. GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) {
  331. auto desc = node.GetOpDesc();
  332. if (desc == nullptr) {
  333. return GeTensorDesc();
  334. }
  335. return desc->GetOutputDesc(index);
  336. }
  337. GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) {
  338. auto desc = node.GetOpDesc();
  339. if (desc == nullptr) {
  340. return GeTensorDesc();
  341. }
  342. return desc->GetInputDesc(index);
  343. }
  344. graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) {
  345. auto desc = node.GetOpDesc();
  346. if (desc == nullptr) {
  347. return GRAPH_PARAM_INVALID;
  348. }
  349. auto output_desc = desc->MutableOutputDesc(index);
  350. if (output_desc == nullptr) {
  351. return GRAPH_PARAM_INVALID;
  352. }
  353. output_desc->SetShape(shape);
  354. return GRAPH_SUCCESS;
  355. }
  356. graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) {
  357. auto desc = node.GetOpDesc();
  358. if (desc == nullptr) {
  359. return GRAPH_PARAM_INVALID;
  360. }
  361. auto input_desc = desc->MutableInputDesc(index);
  362. if (input_desc == nullptr) {
  363. return GRAPH_PARAM_INVALID;
  364. }
  365. input_desc->SetShape(shape);
  366. return GRAPH_SUCCESS;
  367. }
  368. graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) {
  369. auto desc = node.GetOpDesc();
  370. GE_CHECK_NOTNULL(desc);
  371. auto sub_graph_names = desc->GetSubgraphInstanceNames();
  372. if (sub_graph_names.empty()) {
  373. is_unknow = OpShapeIsUnknown(desc);
  374. return GRAPH_SUCCESS;
  375. } else {
  376. auto owner_graph = node.GetOwnerComputeGraph();
  377. GE_CHECK_NOTNULL(owner_graph);
  378. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  379. if (root_graph == nullptr) {
  380. GE_LOGE("Node %s gets null root graph", node.GetName().c_str());
  381. return GRAPH_PARAM_INVALID;
  382. }
  383. for (auto &sub_graph_name : sub_graph_names) {
  384. auto sub_graph = root_graph->GetSubgraph(sub_graph_name);
  385. GE_CHECK_NOTNULL(sub_graph);
  386. for (const auto &node_ptr : sub_graph->GetDirectNode()) {
  387. auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow);
  388. if (status != GRAPH_SUCCESS) {
  389. GE_LOGE("get node unknown shape status failed!");
  390. return status;
  391. }
  392. if (is_unknow) {
  393. return GRAPH_SUCCESS;
  394. }
  395. }
  396. }
  397. }
  398. return GRAPH_SUCCESS;
  399. }
  400. std::string NodeUtils::GetNodeType(const Node &node) {
  401. if (node.GetType() != FRAMEWORKOP) {
  402. return node.GetType();
  403. }
  404. std::string type;
  405. (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  406. return type;
  407. }
  408. ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) {
  409. auto op_desc = node.GetOpDesc();
  410. if (op_desc == nullptr) {
  411. return nullptr;
  412. }
  413. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  414. if (root_graph == nullptr) {
  415. return nullptr;
  416. }
  417. return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index));
  418. }
  419. graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) {
  420. if (subgraph == nullptr) {
  421. GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index);
  422. return GRAPH_PARAM_INVALID;
  423. }
  424. auto op_desc = node.GetOpDesc();
  425. if (op_desc == nullptr) {
  426. return GRAPH_PARAM_INVALID;
  427. }
  428. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  429. if (root_graph == nullptr) {
  430. GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str());
  431. return GRAPH_PARAM_INVALID;
  432. }
  433. auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName());
  434. if (ret != GRAPH_SUCCESS) {
  435. GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index);
  436. return ret;
  437. }
  438. subgraph->SetParentNode(node.shared_from_this());
  439. subgraph->SetParentGraph(node.GetOwnerComputeGraph());
  440. return root_graph->AddSubgraph(subgraph);
  441. }
  442. ///
  443. /// Check if node is input of subgraph
  444. /// @param [in] node
  445. /// @return bool
  446. ///
  447. bool NodeUtils::IsSubgraphInput(const NodePtr &node) {
  448. if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
  449. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
  450. return false;
  451. }
  452. auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
  453. if (parent_op_desc == nullptr) {
  454. return false;
  455. }
  456. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
  457. return false;
  458. }
  459. return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX);
  460. }
  461. ///
  462. /// Check if node is output of subgraph
  463. /// @param [in] node
  464. /// @return bool
  465. ///
  466. bool NodeUtils::IsSubgraphOutput(const NodePtr &node) {
  467. if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
  468. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) {
  469. return false;
  470. }
  471. auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
  472. if (parent_op_desc == nullptr) {
  473. return false;
  474. }
  475. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
  476. return false;
  477. }
  478. for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) {
  479. if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) {
  480. return true;
  481. }
  482. }
  483. return false;
  484. }
  485. ///
  486. /// @brief Get subgraph original input node.
  487. /// @param [in] node
  488. /// @return Node
  489. ///
  490. NodePtr NodeUtils::GetParentInput(const NodePtr &node) {
  491. GE_CHECK_NOTNULL_EXEC(node, return nullptr);
  492. uint32_t parent_index = 0;
  493. if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
  494. return nullptr;
  495. }
  496. // Subgraph Data Node, check for constant input.
  497. const ComputeGraphPtr &graph = node->GetOwnerComputeGraph();
  498. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  499. const NodePtr &parent_node = graph->GetParentNode();
  500. GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr);
  501. const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index);
  502. GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr);
  503. const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor();
  504. GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr);
  505. return peer_out_anchor->GetOwnerNode();
  506. }
  507. ///
  508. /// @brief Get subgraph input is constant.
  509. /// @param [in] node
  510. /// @param [out] string
  511. /// @return bool
  512. ///
  513. bool NodeUtils::GetConstOpType(const NodePtr &in_node, std::string &op_type) {
  514. GE_CHECK_NOTNULL_EXEC(in_node, return false);
  515. if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
  516. op_type = in_node->GetType();
  517. return true;
  518. }
  519. if (in_node->GetType() == DATA) {
  520. std::string const_type;
  521. if (!AttrUtils::GetStr(in_node->GetOpDesc(), ATTR_NAME_PARENT_CONST_TYPE, const_type)) {
  522. return false;
  523. }
  524. if ((const_type == CONSTANT) || (const_type == CONSTANTOP)) {
  525. op_type = const_type;
  526. return true;
  527. }
  528. }
  529. return false;
  530. }
  531. ///
  532. /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph.
  533. /// @param [in] node
  534. /// @return return GRAPH_SUCCESS if remove successfully, other for failed.
  535. ///
  536. Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) {
  537. GE_CHECK_NOTNULL(node);
  538. auto op_desc = node->GetOpDesc();
  539. GE_CHECK_NOTNULL(op_desc);
  540. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  541. if (subgraph_names.empty()) {
  542. return GRAPH_SUCCESS;
  543. } else {
  544. auto owner_graph = node->GetOwnerComputeGraph();
  545. GE_CHECK_NOTNULL(owner_graph);
  546. auto root_graph = GraphUtils::FindRootGraph(owner_graph);
  547. GE_CHECK_NOTNULL(root_graph);
  548. std::unordered_set<std::string> subgraph_to_remove;
  549. for (auto &subgraph_name : subgraph_names) {
  550. std::deque<std::string> queue;
  551. queue.push_back(subgraph_name);
  552. subgraph_to_remove.insert(subgraph_name);
  553. op_desc->RemoveSubgraphInstanceName(subgraph_name);
  554. while (!queue.empty()) {
  555. auto graph_name = queue.front();
  556. queue.pop_front();
  557. auto subgraph = root_graph->GetSubgraph(graph_name);
  558. GE_CHECK_NOTNULL(subgraph);
  559. for (const auto &sub_node : subgraph->GetDirectNode()) {
  560. auto sub_op_desc = sub_node->GetOpDesc();
  561. GE_CHECK_NOTNULL(sub_op_desc);
  562. auto sub_names = sub_op_desc->GetSubgraphInstanceNames();
  563. // Subgraph and all nodes in it will be removed later,
  564. // no need to remove 'SubgraphInstanceName' in op desc here.
  565. for (auto &name : sub_names) {
  566. if (subgraph_to_remove.insert(name).second) {
  567. queue.push_back(name);
  568. }
  569. }
  570. }
  571. }
  572. }
  573. // Remove subgraph from root_graph
  574. for (const auto &name : subgraph_to_remove) {
  575. GELOGI("Remove subgraph:%s.", name.c_str());
  576. root_graph->RemoveSubgraph(name);
  577. }
  578. }
  579. return GRAPH_SUCCESS;
  580. }
  581. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户对此并无感知。