You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_utils.cc 28 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/node_utils.h"
  17. #include "graph/utils/graph_utils.h"
  18. #include "debug/ge_op_types.h"
  19. #include "debug/ge_util.h"
  20. #include "framework/common/debug/ge_log.h"
  21. #include "graph/anchor.h"
  22. #include "graph/debug/ge_attr_define.h"
  23. #include "graph/types.h"
  24. #include "utils/tensor_utils.h"
  25. #include "utils/type_utils.h"
  26. namespace ge {
  27. std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{};
  28. std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{};
  29. const std::set<std::string> kConstOpTypes = {"Const", "Constant"};
  30. const std::set<std::string> kIfOpTypes = {"If", "_If", "StatelessIf"};
  31. const std::set<std::string> kWhileOpTypes = {"While", "_While", "StatelessWhile"};
  32. const std::set<std::string> kCaseOpTypes = {"Case"};
  33. const std::set<std::string> kForOpTypes = {"For"};
  34. bool OpShapeIsUnknown(const OpDescPtr &desc) {
  35. for (const auto &ptr : desc->GetAllInputsDescPtr()) {
  36. auto ge_shape = ptr->GetShape();
  37. for (const auto &dim : ge_shape.GetDims()) {
  38. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  39. return true;
  40. }
  41. }
  42. }
  43. for (const auto &ptr : desc->GetAllOutputsDescPtr()) {
  44. auto ge_shape = ptr->GetShape();
  45. for (const auto &dim : ge_shape.GetDims()) {
  46. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  47. return true;
  48. }
  49. }
  50. }
  51. return false;
  52. }
  53. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node,
  54. const uint32_t &event_id) {
  55. GE_CHECK_NOTNULL(node);
  56. map_send_info_[node].push_back(event_id);
  57. return GRAPH_SUCCESS;
  58. }
  59. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node,
  60. const uint32_t &event_id) {
  61. GE_CHECK_NOTNULL(node);
  62. map_recv_info_[node].push_back(event_id);
  63. return GRAPH_SUCCESS;
  64. }
  65. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  66. NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_send) {
  67. GE_CHECK_NOTNULL(node);
  68. auto find = map_send_info_.find(node);
  69. if (find == map_send_info_.end()) {
  70. return GRAPH_FAILED;
  71. } else {
  72. vec_send = find->second;
  73. return GRAPH_SUCCESS;
  74. }
  75. }
  76. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  77. NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_recv) {
  78. GE_CHECK_NOTNULL(node);
  79. auto find = map_recv_info_.find(node);
  80. if (find == map_recv_info_.end()) {
  81. return GRAPH_FAILED;
  82. } else {
  83. vec_recv = find->second;
  84. return GRAPH_SUCCESS;
  85. }
  86. }
  87. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() {
  88. map_send_info_.clear();
  89. return GRAPH_SUCCESS;
  90. }
  91. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() {
  92. map_recv_info_.clear();
  93. return GRAPH_SUCCESS;
  94. }
  95. graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) {
  96. GE_CHECK_NOTNULL(src);
  97. NodePtr cur_ptr;
  98. if (depth < 1) {
  99. return GRAPH_FAILED;
  100. }
  101. for (int i = 0; i < depth; i++) {
  102. if (src->GetOutDataNodes().size() != 1) {
  103. return GRAPH_FAILED;
  104. }
  105. cur_ptr = src->GetOutDataNodes().at(0);
  106. GE_CHECK_NOTNULL(cur_ptr);
  107. }
  108. dst = cur_ptr;
  109. return GRAPH_SUCCESS;
  110. }
  111. graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data,
  112. InControlAnchorPtr &in_control) {
  113. GE_CHECK_NOTNULL(node_ptr);
  114. for (const auto &p : node_ptr->GetAllOutDataAnchors()) {
  115. GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr");
  116. for (const auto &p_in : p->GetPeerInControlAnchors()) {
  117. GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr");
  118. out_data = p;
  119. in_control = p_in;
  120. return GRAPH_SUCCESS;
  121. }
  122. }
  123. return GRAPH_FAILED;
  124. }
  125. graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) {
  126. GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED,
  127. "node or in_data_anchor is nullptr");
  128. bool find_flag = false;
  129. uint32_t index = 0;
  130. vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end();
  131. for (const auto &tmp : node_ptr->in_data_anchors_) {
  132. if (tmp == in_data_anchor) {
  133. find_flag = true;
  134. auto iter = node_ptr->in_data_anchors_.begin() + index;
  135. if (iter != node_ptr->in_data_anchors_.end()) {
  136. it = node_ptr->in_data_anchors_.erase(iter);
  137. }
  138. break;
  139. }
  140. index++;
  141. }
  142. for (; it != node_ptr->in_data_anchors_.end(); ++it) {
  143. (*it)->SetIdx(index);
  144. index++;
  145. }
  146. if (!find_flag) {
  147. return GRAPH_FAILED;
  148. }
  149. return GRAPH_SUCCESS;
  150. }
  151. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) {
  152. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr");
  153. GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed");
  154. return GRAPH_SUCCESS;
  155. }
  156. graphStatus NodeUtils::SetAllAnchorStatus(Node &node) {
  157. node.anchor_status_updated_ = true;
  158. return GRAPH_SUCCESS;
  159. }
  160. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) {
  161. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr");
  162. return IsAnchorStatusSet(*node_ptr);
  163. }
  164. bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; }
  165. graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) {
  166. if ((origin_node == nullptr) || (new_node == nullptr)) {
  167. return GRAPH_FAILED;
  168. }
  169. auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors();
  170. auto new_out_data_anchors = new_node->GetAllOutDataAnchors();
  171. if (origin_out_data_anchors.size() != new_out_data_anchors.size()) {
  172. return GRAPH_FAILED;
  173. }
  174. for (size_t i = 0; i < origin_out_data_anchors.size(); ++i) {
  175. for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) {
  176. GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
  177. "unlink peer_anchor failed");
  178. GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  179. "linkto peer_anchor failed");
  180. }
  181. for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) {
  182. GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
  183. "unlink peer_anchor failed");
  184. GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  185. "linkto peer_anchor failed");
  186. }
  187. }
  188. auto origin_out_control_anchor = origin_node->GetOutControlAnchor();
  189. GE_CHECK_NOTNULL(origin_out_control_anchor);
  190. auto new_out_control_anchor = new_node->GetOutControlAnchor();
  191. GE_CHECK_NOTNULL(new_out_control_anchor);
  192. for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) {
  193. GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  194. "linkto peer_anchor failed");
  195. }
  196. for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) {
  197. GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  198. "linkto peer_anchor failed");
  199. }
  200. origin_out_control_anchor->UnlinkAll();
  201. return GRAPH_SUCCESS;
  202. }
  203. bool NodeUtils::IsConst(const Node &node) {
  204. auto src_node_type = node.GetType();
  205. bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP));
  206. return is_const;
  207. }
  208. void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) {
  209. if (node_ptr == nullptr) {
  210. GELOGE(GRAPH_FAILED, "node is null");
  211. return;
  212. }
  213. UpdateIsInputConst(*node_ptr);
  214. }
  215. ///
  216. /// update is_input_const
  217. /// @param node
  218. /// @return void
  219. ///
  220. void NodeUtils::UpdateIsInputConst(Node &node) {
  221. std::vector<bool> is_input_const;
  222. size_t anchor_num = node.GetAllInDataAnchors().size();
  223. for (size_t i = 0; i < anchor_num; i++) {
  224. auto in_anchor = node.GetInDataAnchor(static_cast<int>(i));
  225. if (in_anchor == nullptr) {
  226. is_input_const.push_back(false);
  227. continue;
  228. }
  229. auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
  230. if (peer_out_anchor == nullptr) {
  231. is_input_const.push_back(false);
  232. continue;
  233. }
  234. auto src_node = peer_out_anchor->GetOwnerNode();
  235. if (src_node == nullptr) {
  236. is_input_const.push_back(false);
  237. continue;
  238. }
  239. if (IsConst(*(src_node))) {
  240. is_input_const.push_back(true);
  241. } else {
  242. is_input_const.push_back(false);
  243. }
  244. }
  245. if (node.GetOpDesc() == nullptr) {
  246. GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr");
  247. return;
  248. }
  249. node.GetOpDesc()->SetIsInputConst(is_input_const);
  250. }
  251. void NodeUtils::UnlinkAll(const Node &node) {
  252. for (const auto &anchor : node.GetAllOutAnchors()) {
  253. anchor->UnlinkAll();
  254. }
  255. for (const auto &anchor : node.GetAllInAnchors()) {
  256. anchor->UnlinkAll();
  257. }
  258. }
  259. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) {
  260. if (node_ptr == nullptr) {
  261. GELOGE(GRAPH_FAILED, "Nodeptr is nullptr");
  262. return GRAPH_FAILED;
  263. }
  264. auto op_desc = node_ptr->GetOpDesc();
  265. if (op_desc == nullptr) {
  266. return GRAPH_FAILED;
  267. }
  268. bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  269. if (is_unknown_graph) {
  270. return GRAPH_SUCCESS;
  271. }
  272. for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
  273. auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
  274. ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
  275. output_tensor->SetOriginShape(output_tensor->GetShape());
  276. output_tensor->SetOriginDataType(output_tensor->GetDataType());
  277. GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
  278. node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
  279. TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
  280. TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
  281. for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
  282. if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  283. GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null");
  284. continue;
  285. }
  286. auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx());
  287. if (peer_input_desc == nullptr) {
  288. GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr");
  289. continue;
  290. }
  291. GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d",
  292. peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(),
  293. output_tensor->GetDataType(), output_tensor->GetOriginDataType());
  294. peer_input_desc->SetOriginShape(output_tensor->GetOriginShape());
  295. peer_input_desc->SetShape(output_tensor->GetShape());
  296. peer_input_desc->SetDataType(output_tensor->GetDataType());
  297. peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType());
  298. std::vector<std::pair<int64_t, int64_t>> shape_range;
  299. (void)output_tensor->GetShapeRange(shape_range);
  300. peer_input_desc->SetShapeRange(shape_range);
  301. ge::TensorUtils::SetRealDimCnt(*peer_input_desc,
  302. static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
  303. GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
  304. peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(),
  305. peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType());
  306. }
  307. }
  308. return GRAPH_SUCCESS;
  309. }
  310. bool NodeUtils::IsInNodesEmpty(const Node &node) {
  311. for (const auto &in_anchor : node.in_data_anchors_) {
  312. if (in_anchor != nullptr) {
  313. auto out_anchor = in_anchor->GetPeerOutAnchor();
  314. if (out_anchor != nullptr) {
  315. if (out_anchor->GetOwnerNode() != nullptr) {
  316. return false;
  317. }
  318. }
  319. }
  320. }
  321. if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) {
  322. auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors();
  323. for (const auto &out_control_anchor : peer_out_control_anchors) {
  324. if (out_control_anchor != nullptr) {
  325. if (out_control_anchor->GetOwnerNode() != nullptr) {
  326. return false;
  327. }
  328. }
  329. }
  330. }
  331. return true;
  332. }
  333. GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) {
  334. auto desc = node.GetOpDesc();
  335. if (desc == nullptr) {
  336. return GeTensorDesc();
  337. }
  338. return desc->GetOutputDesc(index);
  339. }
  340. GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) {
  341. auto desc = node.GetOpDesc();
  342. if (desc == nullptr) {
  343. return GeTensorDesc();
  344. }
  345. return desc->GetInputDesc(index);
  346. }
  347. graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) {
  348. auto desc = node.GetOpDesc();
  349. if (desc == nullptr) {
  350. return GRAPH_PARAM_INVALID;
  351. }
  352. auto output_desc = desc->MutableOutputDesc(index);
  353. if (output_desc == nullptr) {
  354. return GRAPH_PARAM_INVALID;
  355. }
  356. output_desc->SetShape(shape);
  357. return GRAPH_SUCCESS;
  358. }
  359. graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) {
  360. auto desc = node.GetOpDesc();
  361. if (desc == nullptr) {
  362. return GRAPH_PARAM_INVALID;
  363. }
  364. auto input_desc = desc->MutableInputDesc(index);
  365. if (input_desc == nullptr) {
  366. return GRAPH_PARAM_INVALID;
  367. }
  368. input_desc->SetShape(shape);
  369. return GRAPH_SUCCESS;
  370. }
  371. graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) {
  372. auto desc = node.GetOpDesc();
  373. GE_CHECK_NOTNULL(desc);
  374. // check self
  375. is_unknow = OpShapeIsUnknown(desc);
  376. if (is_unknow) {
  377. return GRAPH_SUCCESS;
  378. }
  379. auto sub_graph_names = desc->GetSubgraphInstanceNames();
  380. if (sub_graph_names.empty()) {
  381. return GRAPH_SUCCESS;
  382. } else {
  383. auto owner_graph = node.GetOwnerComputeGraph();
  384. GE_CHECK_NOTNULL(owner_graph);
  385. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  386. if (root_graph == nullptr) {
  387. GE_LOGE("Node %s gets null root graph", node.GetName().c_str());
  388. return GRAPH_PARAM_INVALID;
  389. }
  390. for (auto &sub_graph_name : sub_graph_names) {
  391. auto sub_graph = root_graph->GetSubgraph(sub_graph_name);
  392. GE_CHECK_NOTNULL(sub_graph);
  393. for (const auto &node_ptr : sub_graph->GetDirectNode()) {
  394. auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow);
  395. if (status != GRAPH_SUCCESS) {
  396. GE_LOGE("get node unknown shape status failed!");
  397. return status;
  398. }
  399. if (is_unknow) {
  400. return GRAPH_SUCCESS;
  401. }
  402. }
  403. }
  404. }
  405. return GRAPH_SUCCESS;
  406. }
  407. std::string NodeUtils::GetNodeType(const Node &node) {
  408. if (node.GetType() != FRAMEWORKOP) {
  409. return node.GetType();
  410. }
  411. std::string type;
  412. (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  413. return type;
  414. }
  415. ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) {
  416. auto op_desc = node.GetOpDesc();
  417. if (op_desc == nullptr) {
  418. return nullptr;
  419. }
  420. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  421. if (root_graph == nullptr) {
  422. return nullptr;
  423. }
  424. return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index));
  425. }
  426. graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) {
  427. if (subgraph == nullptr) {
  428. GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index);
  429. return GRAPH_PARAM_INVALID;
  430. }
  431. auto op_desc = node.GetOpDesc();
  432. if (op_desc == nullptr) {
  433. return GRAPH_PARAM_INVALID;
  434. }
  435. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  436. if (root_graph == nullptr) {
  437. GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str());
  438. return GRAPH_PARAM_INVALID;
  439. }
  440. auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName());
  441. if (ret != GRAPH_SUCCESS) {
  442. GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index);
  443. return ret;
  444. }
  445. subgraph->SetParentNode(node.shared_from_this());
  446. subgraph->SetParentGraph(node.GetOwnerComputeGraph());
  447. return root_graph->AddSubgraph(subgraph);
  448. }
  449. ///
  450. /// Check if node is input of subgraph
  451. /// @param [in] node
  452. /// @return bool
  453. ///
  454. bool NodeUtils::IsSubgraphInput(const NodePtr &node) {
  455. if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
  456. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
  457. return false;
  458. }
  459. auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
  460. if (parent_op_desc == nullptr) {
  461. return false;
  462. }
  463. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
  464. bool is_unknown_shape = false;
  465. (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape);
  466. if (is_unknown_shape) return false;
  467. }
  468. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) &&
  469. kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 &&
  470. kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) {
  471. return false;
  472. }
  473. return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX);
  474. }
  475. ///
  476. /// Check if node is output of subgraph
  477. /// @param [in] node
  478. /// @return bool
  479. ///
  480. bool NodeUtils::IsSubgraphOutput(const NodePtr &node) {
  481. if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
  482. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) {
  483. return false;
  484. }
  485. auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
  486. if (parent_op_desc == nullptr) {
  487. return false;
  488. }
  489. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
  490. bool is_unknown_shape = false;
  491. (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape);
  492. if (is_unknown_shape) return false;
  493. }
  494. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) &&
  495. kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 &&
  496. kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) {
  497. return false;
  498. }
  499. for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) {
  500. if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) {
  501. return true;
  502. }
  503. }
  504. return false;
  505. }
  506. ///
  507. /// @brief Get subgraph original input node.
  508. /// @param [in] node
  509. /// @return Node
  510. ///
  511. NodePtr NodeUtils::GetParentInput(const NodePtr &node) {
  512. GE_CHECK_NOTNULL_EXEC(node, return nullptr);
  513. uint32_t parent_index = 0;
  514. if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
  515. return nullptr;
  516. }
  517. // Subgraph Data Node, check for constant input.
  518. const ComputeGraphPtr &graph = node->GetOwnerComputeGraph();
  519. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  520. const NodePtr &parent_node = graph->GetParentNode();
  521. GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr);
  522. const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index);
  523. GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr);
  524. const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor();
  525. GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr);
  526. return peer_out_anchor->GetOwnerNode();
  527. }
  528. ///
  529. /// @brief Check is varying_input for while node
  530. /// @param [in] node: Data node for subgraph
  531. /// @return bool
  532. ///
  533. bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) {
  534. if (node == nullptr) {
  535. return false;
  536. }
  537. if (node->GetType() != DATA) {
  538. return false; // not input_node for subgraph
  539. }
  540. const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode();
  541. if (parent_node == nullptr) {
  542. return false; // root graph
  543. }
  544. if (kWhileOpTypes.count(parent_node->GetType()) == 0) {
  545. return false; // not input_node for while subgraph
  546. }
  547. uint32_t index_i = 0;
  548. if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) {
  549. GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str());
  550. return false;
  551. }
  552. bool varying_flag = true;
  553. for (const auto &item : node->GetOutDataNodesAndAnchors()) {
  554. if (item.first->GetType() != NETOUTPUT) {
  555. continue;
  556. }
  557. OpDescPtr op_desc = item.first->GetOpDesc();
  558. uint32_t index_o = 0;
  559. if ((op_desc == nullptr) ||
  560. !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) {
  561. continue; // input for while-cond subgraph
  562. }
  563. if (index_i != index_o) {
  564. continue; // varying input for while-body subgraph
  565. }
  566. varying_flag = false;
  567. break;
  568. }
  569. return varying_flag;
  570. }
  571. ///
  572. /// @brief Get subgraph input is constant.
  573. /// @param [in] node
  574. /// @param [out] string
  575. /// @return bool
  576. ///
  577. bool NodeUtils::GetConstOpType(const NodePtr &in_node, std::string &op_type) {
  578. GE_CHECK_NOTNULL_EXEC(in_node, return false);
  579. if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
  580. op_type = in_node->GetType();
  581. return true;
  582. }
  583. if (in_node->GetType() == DATA) {
  584. std::string const_type;
  585. if (!AttrUtils::GetStr(in_node->GetOpDesc(), ATTR_NAME_PARENT_CONST_TYPE, const_type)) {
  586. return false;
  587. }
  588. if ((const_type == CONSTANT) || (const_type == CONSTANTOP)) {
  589. op_type = const_type;
  590. return true;
  591. }
  592. }
  593. return false;
  594. }
  595. ///
  596. /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph.
  597. /// @param [in] node
  598. /// @return return GRAPH_SUCCESS if remove successfully, other for failed.
  599. ///
  600. Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) {
  601. GE_CHECK_NOTNULL(node);
  602. auto op_desc = node->GetOpDesc();
  603. GE_CHECK_NOTNULL(op_desc);
  604. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  605. if (subgraph_names.empty()) {
  606. return GRAPH_SUCCESS;
  607. } else {
  608. auto owner_graph = node->GetOwnerComputeGraph();
  609. GE_CHECK_NOTNULL(owner_graph);
  610. auto root_graph = GraphUtils::FindRootGraph(owner_graph);
  611. GE_CHECK_NOTNULL(root_graph);
  612. std::unordered_set<std::string> subgraph_to_remove;
  613. for (auto &subgraph_name : subgraph_names) {
  614. std::deque<std::string> queue;
  615. queue.push_back(subgraph_name);
  616. subgraph_to_remove.insert(subgraph_name);
  617. op_desc->RemoveSubgraphInstanceName(subgraph_name);
  618. while (!queue.empty()) {
  619. auto graph_name = queue.front();
  620. queue.pop_front();
  621. auto subgraph = root_graph->GetSubgraph(graph_name);
  622. GE_CHECK_NOTNULL(subgraph);
  623. for (const auto &sub_node : subgraph->GetDirectNode()) {
  624. auto sub_op_desc = sub_node->GetOpDesc();
  625. GE_CHECK_NOTNULL(sub_op_desc);
  626. auto sub_names = sub_op_desc->GetSubgraphInstanceNames();
  627. // Subgraph and all nodes in it will be removed later,
  628. // no need to remove 'SubgraphInstanceName' in op desc here.
  629. for (auto &name : sub_names) {
  630. if (subgraph_to_remove.insert(name).second) {
  631. queue.push_back(name);
  632. }
  633. }
  634. }
  635. }
  636. }
  637. // Remove subgraph from root_graph
  638. for (const auto &name : subgraph_to_remove) {
  639. GELOGI("Remove subgraph:%s.", name.c_str());
  640. root_graph->RemoveSubgraph(name);
  641. }
  642. }
  643. return GRAPH_SUCCESS;
  644. }
  645. ///
  646. /// @brief Get subgraph input data node by index.
  647. /// @param [in] node
  648. /// @return Node
  649. ///
  650. vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) {
  651. vector<NodePtr> in_data_node_vec;
  652. auto op_desc = node.GetOpDesc();
  653. GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec);
  654. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  655. if (subgraph_names.empty()) {
  656. GELOGW("Node %s is single node without sub graph.", node.GetName().c_str());
  657. return in_data_node_vec;
  658. }
  659. auto compute_graph = node.GetOwnerComputeGraph();
  660. for (const std::string &instance_name : subgraph_names) {
  661. auto subgraph = compute_graph->GetSubgraph(instance_name);
  662. for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
  663. int parent_index = 0;
  664. if (NodeUtils::IsSubgraphInput(node_in_subgraph)) {
  665. (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index);
  666. if (parent_index == index) {
  667. in_data_node_vec.emplace_back(node_in_subgraph);
  668. }
  669. }
  670. }
  671. }
  672. return in_data_node_vec;
  673. }
  674. ///
  675. /// @brief Get subgraph input data node by index.
  676. /// @param [in] node
  677. /// @return Node
  678. ///
  679. vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) {
  680. vector<NodePtr> out_data_node_vec;
  681. auto op_desc = node.GetOpDesc();
  682. GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec);
  683. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  684. if (subgraph_names.empty()) {
  685. GELOGI("Node %s is single node without sub graph.", node.GetName().c_str());
  686. return out_data_node_vec;
  687. }
  688. auto compute_graph = node.GetOwnerComputeGraph();
  689. for (const std::string &instance_name : subgraph_names) {
  690. auto subgraph = compute_graph->GetSubgraph(instance_name);
  691. for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
  692. if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) {
  693. out_data_node_vec.emplace_back(node_in_subgraph);
  694. }
  695. }
  696. }
  697. return out_data_node_vec;
  698. }
  699. NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) {
  700. if (node.GetInDataAnchor(index) == nullptr) {
  701. return nullptr;
  702. }
  703. if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) {
  704. return nullptr;
  705. }
  706. return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode();
  707. }
  708. vector<NodePtr> NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) {
  709. vector<NodePtr> out_data_nodes;
  710. auto out_data_anchor = node.GetOutDataAnchor(index);
  711. if (out_data_anchor == nullptr) {
  712. return out_data_nodes;
  713. }
  714. for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  715. if (peer_in_anchor == nullptr) {
  716. continue;
  717. }
  718. if (peer_in_anchor->GetOwnerNode() == nullptr) {
  719. continue;
  720. }
  721. out_data_nodes.emplace_back(peer_in_anchor->GetOwnerNode());
  722. }
  723. return out_data_nodes;
  724. }
  725. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。