You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

node_utils.cc 34 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/node_utils.h"
  17. #include "utils/op_desc_utils.h"
  18. #include "graph/utils/graph_utils.h"
  19. #include "debug/ge_op_types.h"
  20. #include "debug/ge_util.h"
  21. #include "framework/common/debug/ge_log.h"
  22. #include "graph/anchor.h"
  23. #include "graph/debug/ge_attr_define.h"
  24. #include "graph/types.h"
  25. #include "utils/tensor_utils.h"
  26. #include "utils/type_utils.h"
  27. namespace ge {
  28. std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{};
  29. std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{};
  30. const std::set<std::string> kConstOpTypes = {"Const", "Constant"};
  31. const std::set<std::string> kIfOpTypes = {"If", "_If", "StatelessIf"};
  32. const std::set<std::string> kWhileOpTypes = {"While", "_While", "StatelessWhile"};
  33. const std::set<std::string> kCaseOpTypes = {"Case"};
  34. const std::set<std::string> kForOpTypes = {"For"};
  35. bool OpShapeIsUnknown(const OpDescPtr &desc) {
  36. for (const auto &ptr : desc->GetAllInputsDescPtr()) {
  37. auto ge_shape = ptr->GetShape();
  38. for (const auto &dim : ge_shape.GetDims()) {
  39. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  40. return true;
  41. }
  42. }
  43. }
  44. for (const auto &ptr : desc->GetAllOutputsDescPtr()) {
  45. auto ge_shape = ptr->GetShape();
  46. for (const auto &dim : ge_shape.GetDims()) {
  47. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  48. return true;
  49. }
  50. }
  51. }
  52. return false;
  53. }
  54. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node,
  55. const uint32_t &event_id) {
  56. GE_CHECK_NOTNULL(node);
  57. map_send_info_[node].push_back(event_id);
  58. return GRAPH_SUCCESS;
  59. }
  60. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node,
  61. const uint32_t &event_id) {
  62. GE_CHECK_NOTNULL(node);
  63. map_recv_info_[node].push_back(event_id);
  64. return GRAPH_SUCCESS;
  65. }
  66. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  67. NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_send) {
  68. GE_CHECK_NOTNULL(node);
  69. auto find = map_send_info_.find(node);
  70. if (find == map_send_info_.end()) {
  71. return GRAPH_FAILED;
  72. } else {
  73. vec_send = find->second;
  74. return GRAPH_SUCCESS;
  75. }
  76. }
  77. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  78. NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_recv) {
  79. GE_CHECK_NOTNULL(node);
  80. auto find = map_recv_info_.find(node);
  81. if (find == map_recv_info_.end()) {
  82. return GRAPH_FAILED;
  83. } else {
  84. vec_recv = find->second;
  85. return GRAPH_SUCCESS;
  86. }
  87. }
  88. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() {
  89. map_send_info_.clear();
  90. return GRAPH_SUCCESS;
  91. }
  92. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() {
  93. map_recv_info_.clear();
  94. return GRAPH_SUCCESS;
  95. }
  96. graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) {
  97. GE_CHECK_NOTNULL(src);
  98. NodePtr cur_ptr;
  99. if (depth < 1) {
  100. return GRAPH_FAILED;
  101. }
  102. for (int i = 0; i < depth; i++) {
  103. if (src->GetOutDataNodes().size() != 1) {
  104. return GRAPH_FAILED;
  105. }
  106. cur_ptr = src->GetOutDataNodes().at(0);
  107. GE_CHECK_NOTNULL(cur_ptr);
  108. }
  109. dst = cur_ptr;
  110. return GRAPH_SUCCESS;
  111. }
  112. graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data,
  113. InControlAnchorPtr &in_control) {
  114. GE_CHECK_NOTNULL(node_ptr);
  115. for (const auto &p : node_ptr->GetAllOutDataAnchors()) {
  116. GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr");
  117. for (const auto &p_in : p->GetPeerInControlAnchors()) {
  118. GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr");
  119. out_data = p;
  120. in_control = p_in;
  121. return GRAPH_SUCCESS;
  122. }
  123. }
  124. return GRAPH_FAILED;
  125. }
  126. graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) {
  127. GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED,
  128. "node or in_data_anchor is nullptr");
  129. bool find_flag = false;
  130. uint32_t index = 0;
  131. vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end();
  132. for (const auto &tmp : node_ptr->in_data_anchors_) {
  133. if (tmp == in_data_anchor) {
  134. find_flag = true;
  135. auto iter = node_ptr->in_data_anchors_.begin() + index;
  136. if (iter != node_ptr->in_data_anchors_.end()) {
  137. it = node_ptr->in_data_anchors_.erase(iter);
  138. }
  139. break;
  140. }
  141. index++;
  142. }
  143. for (; it != node_ptr->in_data_anchors_.end(); ++it) {
  144. (*it)->SetIdx(index);
  145. index++;
  146. }
  147. if (!find_flag) {
  148. return GRAPH_FAILED;
  149. }
  150. return GRAPH_SUCCESS;
  151. }
  152. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) {
  153. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr");
  154. GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed");
  155. return GRAPH_SUCCESS;
  156. }
  157. graphStatus NodeUtils::SetAllAnchorStatus(Node &node) {
  158. node.anchor_status_updated_ = true;
  159. return GRAPH_SUCCESS;
  160. }
  161. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) {
  162. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr");
  163. return IsAnchorStatusSet(*node_ptr);
  164. }
  165. bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; }
  166. graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) {
  167. if ((origin_node == nullptr) || (new_node == nullptr)) {
  168. return GRAPH_FAILED;
  169. }
  170. auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors();
  171. auto new_out_data_anchors = new_node->GetAllOutDataAnchors();
  172. if (origin_out_data_anchors.size() != new_out_data_anchors.size()) {
  173. return GRAPH_FAILED;
  174. }
  175. for (size_t i = 0; i < origin_out_data_anchors.size(); ++i) {
  176. for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) {
  177. GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
  178. "unlink peer_anchor failed");
  179. GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  180. "linkto peer_anchor failed");
  181. }
  182. for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) {
  183. GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
  184. "unlink peer_anchor failed");
  185. GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  186. "linkto peer_anchor failed");
  187. }
  188. }
  189. auto origin_out_control_anchor = origin_node->GetOutControlAnchor();
  190. GE_CHECK_NOTNULL(origin_out_control_anchor);
  191. auto new_out_control_anchor = new_node->GetOutControlAnchor();
  192. GE_CHECK_NOTNULL(new_out_control_anchor);
  193. for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) {
  194. GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  195. "linkto peer_anchor failed");
  196. }
  197. for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) {
  198. GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
  199. "linkto peer_anchor failed");
  200. }
  201. origin_out_control_anchor->UnlinkAll();
  202. return GRAPH_SUCCESS;
  203. }
  204. bool NodeUtils::IsConst(const Node &node) {
  205. auto src_node_type = node.GetType();
  206. bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP));
  207. return is_const;
  208. }
  209. void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) {
  210. if (node_ptr == nullptr) {
  211. GELOGE(GRAPH_FAILED, "node is null");
  212. return;
  213. }
  214. UpdateIsInputConst(*node_ptr);
  215. }
  216. ///
  217. /// update is_input_const
  218. /// @param node
  219. /// @return void
  220. ///
  221. void NodeUtils::UpdateIsInputConst(Node &node) {
  222. std::vector<bool> is_input_const;
  223. size_t anchor_num = node.GetAllInDataAnchors().size();
  224. for (size_t i = 0; i < anchor_num; i++) {
  225. auto in_anchor = node.GetInDataAnchor(static_cast<int>(i));
  226. if (in_anchor == nullptr) {
  227. is_input_const.push_back(false);
  228. continue;
  229. }
  230. auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
  231. if (peer_out_anchor == nullptr) {
  232. is_input_const.push_back(false);
  233. continue;
  234. }
  235. auto src_node = peer_out_anchor->GetOwnerNode();
  236. if (src_node == nullptr) {
  237. is_input_const.push_back(false);
  238. continue;
  239. }
  240. if (IsConst(*(src_node))) {
  241. is_input_const.push_back(true);
  242. } else {
  243. is_input_const.push_back(false);
  244. }
  245. }
  246. if (node.GetOpDesc() == nullptr) {
  247. GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr");
  248. return;
  249. }
  250. node.GetOpDesc()->SetIsInputConst(is_input_const);
  251. }
  252. void NodeUtils::UnlinkAll(const Node &node) {
  253. for (const auto &anchor : node.GetAllOutAnchors()) {
  254. anchor->UnlinkAll();
  255. }
  256. for (const auto &anchor : node.GetAllInAnchors()) {
  257. anchor->UnlinkAll();
  258. }
  259. }
  260. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) {
  261. if (node_ptr == nullptr) {
  262. GELOGE(GRAPH_FAILED, "Nodeptr is nullptr");
  263. return GRAPH_FAILED;
  264. }
  265. auto op_desc = node_ptr->GetOpDesc();
  266. if (op_desc == nullptr) {
  267. return GRAPH_FAILED;
  268. }
  269. bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  270. if (is_unknown_graph) {
  271. return GRAPH_SUCCESS;
  272. }
  273. for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
  274. auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
  275. auto out_dims = output_tensor->GetShape().GetDims();
  276. auto out_dtype = output_tensor->GetDataType();
  277. ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
  278. output_tensor->SetOriginShape(output_tensor->GetShape());
  279. output_tensor->SetOriginDataType(output_tensor->GetDataType());
  280. GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
  281. node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
  282. TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
  283. TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
  284. for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
  285. if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  286. GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null");
  287. continue;
  288. }
  289. auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx());
  290. if (peer_input_desc == nullptr) {
  291. GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr");
  292. continue;
  293. }
  294. // check shape and dtype continuity. do not stop process
  295. auto peer_input_dims = peer_input_desc->GetShape().GetDims();
  296. auto peer_input_dtype = peer_input_desc->GetDataType();
  297. if (out_dtype != peer_input_dtype) {
  298. GELOGW(
  299. "current node [%s] [%d]\'th out_dtype is [%s].peer input node [%s] [%d]\'th "
  300. "input_dtype is [%s].The two dtype should be same! Please check graph and fix it",
  301. node_ptr->GetName().c_str(), out_anchor->GetIdx(), TypeUtils::DataTypeToSerialString(out_dtype).c_str(),
  302. peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(),
  303. TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str());
  304. } else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) {
  305. string out_shape_str, peer_in_shape_str;
  306. out_shape_str += "[";
  307. for (int64_t dim : out_dims) {
  308. out_shape_str += std::to_string(dim) + " ";
  309. }
  310. out_shape_str += "]";
  311. peer_in_shape_str += "[";
  312. for (int64_t dim : peer_input_dims) {
  313. peer_in_shape_str += std::to_string(dim) + " ";
  314. }
  315. peer_in_shape_str += "]";
  316. GELOGW(
  317. "current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th "
  318. "input_shape is [%s].The two shape should be same! Please check graph and fix it",
  319. node_ptr->GetName().c_str(), out_anchor->GetIdx(), out_shape_str.c_str(),
  320. peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), peer_in_shape_str.c_str());
  321. }
  322. GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d",
  323. peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(),
  324. output_tensor->GetDataType(), output_tensor->GetOriginDataType());
  325. peer_input_desc->SetOriginShape(output_tensor->GetOriginShape());
  326. peer_input_desc->SetShape(output_tensor->GetShape());
  327. peer_input_desc->SetDataType(output_tensor->GetDataType());
  328. peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType());
  329. std::vector<std::pair<int64_t, int64_t>> shape_range;
  330. (void)output_tensor->GetShapeRange(shape_range);
  331. peer_input_desc->SetShapeRange(shape_range);
  332. ge::TensorUtils::SetRealDimCnt(*peer_input_desc,
  333. static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
  334. GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
  335. peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(),
  336. peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType());
  337. }
  338. }
  339. return GRAPH_SUCCESS;
  340. }
  341. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node,
  342. uint32_t num) {
  343. if (node == nullptr) {
  344. GELOGE(GRAPH_FAILED, "Input node is null");
  345. return GRAPH_FAILED;
  346. }
  347. GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
  348. const auto &op_desc = node->GetOpDesc();
  349. for (size_t i = op_desc->GetInputsSize(); i < num; ++i) {
  350. if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) {
  351. GELOGE(GRAPH_FAILED, "Add input desc failed");
  352. return GRAPH_FAILED;
  353. }
  354. auto anchor = ComGraphMakeShared<InDataAnchor>(node, i);
  355. if (anchor == nullptr) {
  356. GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed.");
  357. return GRAPH_FAILED;
  358. }
  359. node->in_data_anchors_.push_back(anchor);
  360. }
  361. return GRAPH_SUCCESS;
  362. }
  363. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node,
  364. uint32_t num) {
  365. if (node == nullptr) {
  366. GELOGE(GRAPH_FAILED, "Input node is null");
  367. return GRAPH_FAILED;
  368. }
  369. const auto &op_desc = node->GetOpDesc();
  370. while (op_desc->GetInputsSize() > num) {
  371. if (!OpDescUtils::ClearInputDesc(op_desc, num)) {
  372. return GRAPH_FAILED;
  373. }
  374. }
  375. auto input_names = op_desc->GetAllInputName();
  376. (void)op_desc->UpdateInputName(input_names);
  377. auto is_input_const = op_desc->GetIsInputConst();
  378. is_input_const.resize(num);
  379. op_desc->SetIsInputConst(is_input_const);
  380. while (node->in_data_anchors_.size() > num) {
  381. node->in_data_anchors_.pop_back();
  382. }
  383. return GRAPH_SUCCESS;
  384. }
  385. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendOutputAnchor(const NodePtr &node,
  386. uint32_t num) {
  387. if (node == nullptr) {
  388. GELOGE(GRAPH_FAILED, "Input node is null");
  389. return GRAPH_FAILED;
  390. }
  391. GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
  392. const OpDescPtr &op_desc = node->GetOpDesc();
  393. for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) {
  394. if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) {
  395. GELOGE(GRAPH_FAILED, "Add output desc failed");
  396. return GRAPH_FAILED;
  397. }
  398. auto anchor = ComGraphMakeShared<OutDataAnchor>(node, i);
  399. if (anchor == nullptr) {
  400. GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed.");
  401. return GRAPH_FAILED;
  402. }
  403. node->out_data_anchors_.push_back(anchor);
  404. }
  405. return GRAPH_SUCCESS;
  406. }
  407. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveOutputAnchor(const NodePtr &node,
  408. uint32_t num) {
  409. if (node == nullptr) {
  410. GELOGE(GRAPH_FAILED, "Input node is null");
  411. return GRAPH_FAILED;
  412. }
  413. const auto &op_desc = node->GetOpDesc();
  414. auto output_names = op_desc->GetAllOutputName();
  415. while (op_desc->GetOutputsSize() > num) {
  416. if (!OpDescUtils::ClearOutputDesc(op_desc, num)) {
  417. return GRAPH_FAILED;
  418. }
  419. }
  420. (void)op_desc->UpdateOutputName(output_names);
  421. while (node->out_data_anchors_.size() > num) {
  422. node->out_data_anchors_.pop_back();
  423. }
  424. return GRAPH_SUCCESS;
  425. }
  426. bool NodeUtils::IsInNodesEmpty(const Node &node) {
  427. for (const auto &in_anchor : node.in_data_anchors_) {
  428. if (in_anchor != nullptr) {
  429. auto out_anchor = in_anchor->GetPeerOutAnchor();
  430. if (out_anchor != nullptr) {
  431. if (out_anchor->GetOwnerNode() != nullptr) {
  432. return false;
  433. }
  434. }
  435. }
  436. }
  437. if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) {
  438. auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors();
  439. for (const auto &out_control_anchor : peer_out_control_anchors) {
  440. if (out_control_anchor != nullptr) {
  441. if (out_control_anchor->GetOwnerNode() != nullptr) {
  442. return false;
  443. }
  444. }
  445. }
  446. }
  447. return true;
  448. }
  449. GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) {
  450. auto desc = node.GetOpDesc();
  451. if (desc == nullptr) {
  452. return GeTensorDesc();
  453. }
  454. return desc->GetOutputDesc(index);
  455. }
  456. GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) {
  457. auto desc = node.GetOpDesc();
  458. if (desc == nullptr) {
  459. return GeTensorDesc();
  460. }
  461. return desc->GetInputDesc(index);
  462. }
  463. graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) {
  464. auto desc = node.GetOpDesc();
  465. if (desc == nullptr) {
  466. return GRAPH_PARAM_INVALID;
  467. }
  468. auto output_desc = desc->MutableOutputDesc(index);
  469. if (output_desc == nullptr) {
  470. return GRAPH_PARAM_INVALID;
  471. }
  472. output_desc->SetShape(shape);
  473. return GRAPH_SUCCESS;
  474. }
  475. graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) {
  476. auto desc = node.GetOpDesc();
  477. if (desc == nullptr) {
  478. return GRAPH_PARAM_INVALID;
  479. }
  480. auto input_desc = desc->MutableInputDesc(index);
  481. if (input_desc == nullptr) {
  482. return GRAPH_PARAM_INVALID;
  483. }
  484. input_desc->SetShape(shape);
  485. return GRAPH_SUCCESS;
  486. }
  487. graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) {
  488. auto desc = node.GetOpDesc();
  489. GE_CHECK_NOTNULL(desc);
  490. // check self
  491. is_unknow = OpShapeIsUnknown(desc);
  492. if (is_unknow) {
  493. return GRAPH_SUCCESS;
  494. }
  495. auto sub_graph_names = desc->GetSubgraphInstanceNames();
  496. if (sub_graph_names.empty()) {
  497. return GRAPH_SUCCESS;
  498. } else {
  499. auto owner_graph = node.GetOwnerComputeGraph();
  500. GE_CHECK_NOTNULL(owner_graph);
  501. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  502. if (root_graph == nullptr) {
  503. GE_LOGE("Node %s gets null root graph", node.GetName().c_str());
  504. return GRAPH_PARAM_INVALID;
  505. }
  506. for (auto &sub_graph_name : sub_graph_names) {
  507. auto sub_graph = root_graph->GetSubgraph(sub_graph_name);
  508. GE_CHECK_NOTNULL(sub_graph);
  509. for (const auto &node_ptr : sub_graph->GetDirectNode()) {
  510. auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow);
  511. if (status != GRAPH_SUCCESS) {
  512. GE_LOGE("get node unknown shape status failed!");
  513. return status;
  514. }
  515. if (is_unknow) {
  516. return GRAPH_SUCCESS;
  517. }
  518. }
  519. }
  520. }
  521. return GRAPH_SUCCESS;
  522. }
  523. std::string NodeUtils::GetNodeType(const Node &node) {
  524. if (node.GetType() != FRAMEWORKOP) {
  525. return node.GetType();
  526. }
  527. std::string type;
  528. (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  529. return type;
  530. }
  531. std::string NodeUtils::GetNodeType(const NodePtr &node) { return node == nullptr ? "" : GetNodeType(*node); }
  532. graphStatus NodeUtils::GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor) {
  533. return GRAPH_SUCCESS;
  534. }
  535. graphStatus NodeUtils::GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor) {
  536. return GRAPH_SUCCESS;
  537. }
  538. ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) {
  539. auto op_desc = node.GetOpDesc();
  540. if (op_desc == nullptr) {
  541. return nullptr;
  542. }
  543. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  544. if (root_graph == nullptr) {
  545. return nullptr;
  546. }
  547. return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index));
  548. }
  549. graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) {
  550. if (subgraph == nullptr) {
  551. GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index);
  552. return GRAPH_PARAM_INVALID;
  553. }
  554. auto op_desc = node.GetOpDesc();
  555. if (op_desc == nullptr) {
  556. return GRAPH_PARAM_INVALID;
  557. }
  558. auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  559. if (root_graph == nullptr) {
  560. GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str());
  561. return GRAPH_PARAM_INVALID;
  562. }
  563. auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName());
  564. if (ret != GRAPH_SUCCESS) {
  565. GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index);
  566. return ret;
  567. }
  568. subgraph->SetParentNode(node.shared_from_this());
  569. subgraph->SetParentGraph(node.GetOwnerComputeGraph());
  570. return root_graph->AddSubgraph(subgraph);
  571. }
  572. ///
  573. /// Check if node is input of subgraph
  574. /// @param [in] node
  575. /// @return bool
  576. ///
  577. bool NodeUtils::IsSubgraphInput(const NodePtr &node) {
  578. if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
  579. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
  580. return false;
  581. }
  582. auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
  583. if (parent_op_desc == nullptr) {
  584. return false;
  585. }
  586. // dynamic shape unknown graph false
  587. // dynamic shape known graph with functional subgraph maybe true
  588. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
  589. if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) {
  590. return false;
  591. } else {
  592. if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) {
  593. return false;
  594. }
  595. }
  596. }
  597. return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX);
  598. }
  599. ///
  600. /// Check if node is output of subgraph
  601. /// @param [in] node
  602. /// @return bool
  603. ///
  604. bool NodeUtils::IsSubgraphOutput(const NodePtr &node) {
  605. if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
  606. (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) {
  607. return false;
  608. }
  609. auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
  610. if (parent_op_desc == nullptr) {
  611. return false;
  612. }
  613. if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
  614. if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) {
  615. return false;
  616. } else {
  617. if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) {
  618. return false;
  619. }
  620. }
  621. }
  622. for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) {
  623. if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) {
  624. return true;
  625. }
  626. }
  627. return false;
  628. }
  629. ///
  630. /// @brief Get subgraph original input node.
  631. /// @param [in] node
  632. /// @return Node
  633. ///
  634. NodePtr NodeUtils::GetParentInput(const Node &node) {
  635. uint32_t parent_index = 0;
  636. if (!AttrUtils::GetInt(node.GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
  637. return nullptr;
  638. }
  639. // Subgraph Data Node, check for constant input.
  640. const ComputeGraphPtr &graph = node.GetOwnerComputeGraph();
  641. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  642. const NodePtr &parent_node = graph->GetParentNode();
  643. GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr);
  644. const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index);
  645. GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr);
  646. const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor();
  647. GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr);
  648. return peer_out_anchor->GetOwnerNode();
  649. }
  650. NodePtr NodeUtils::GetParentInput(const NodePtr &node) { return node == nullptr ? node : GetParentInput(*node); }
  651. ///
  652. /// @brief Get is dynamic shape graph from node.
  653. /// @param [in] node
  654. /// @return bool
  655. ///
  656. bool NodeUtils::IsDynamicShape(const Node &node) {
  657. const auto graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  658. if (graph == nullptr) {
  659. return false;
  660. }
  661. bool is_dynamic_shape = false;
  662. (void)AttrUtils::GetBool(graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape);
  663. return is_dynamic_shape;
  664. }
  665. bool NodeUtils::IsDynamicShape(const NodePtr &node) { return node == nullptr ? false : IsDynamicShape(*node); }
  666. ///
  667. /// @brief Check is varying_input for while node
  668. /// @param [in] node: Data node for subgraph
  669. /// @return bool
  670. ///
  671. bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) {
  672. if (node == nullptr) {
  673. return false;
  674. }
  675. if (node->GetType() != DATA) {
  676. return false; // not input_node for subgraph
  677. }
  678. const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode();
  679. if (parent_node == nullptr) {
  680. return false; // root graph
  681. }
  682. if (kWhileOpTypes.count(parent_node->GetType()) == 0) {
  683. return false; // not input_node for while subgraph
  684. }
  685. uint32_t index_i = 0;
  686. if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) {
  687. GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str());
  688. return false;
  689. }
  690. bool varying_flag = true;
  691. for (const auto &item : node->GetOutDataNodesAndAnchors()) {
  692. if (item.first->GetType() != NETOUTPUT) {
  693. continue;
  694. }
  695. OpDescPtr op_desc = item.first->GetOpDesc();
  696. uint32_t index_o = 0;
  697. if ((op_desc == nullptr) ||
  698. !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) {
  699. continue; // input for while-cond subgraph
  700. }
  701. if (index_i != index_o) {
  702. continue; // varying input for while-body subgraph
  703. }
  704. varying_flag = false;
  705. break;
  706. }
  707. return varying_flag;
  708. }
  709. ///
  710. /// @brief Get subgraph input is constant.
  711. /// @param [in] node
  712. /// @param [out] string
  713. /// @return bool
  714. ///
  715. bool NodeUtils::GetConstOpType(const NodePtr &node, std::string &type) {
  716. if (node == nullptr) {
  717. return false;
  718. }
  719. if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) {
  720. type = node->GetType();
  721. return true;
  722. }
  723. if (node->GetType() != DATA) {
  724. return false; // not subgraph input node
  725. }
  726. const auto &parent = GetParentInput(node);
  727. return GetConstOpType(parent, type);
  728. }
  729. ///
  730. /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph.
  731. /// @param [in] node
  732. /// @return return GRAPH_SUCCESS if remove successfully, other for failed.
  733. ///
  734. Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) {
  735. GE_CHECK_NOTNULL(node);
  736. auto op_desc = node->GetOpDesc();
  737. GE_CHECK_NOTNULL(op_desc);
  738. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  739. if (subgraph_names.empty()) {
  740. return GRAPH_SUCCESS;
  741. } else {
  742. auto owner_graph = node->GetOwnerComputeGraph();
  743. GE_CHECK_NOTNULL(owner_graph);
  744. auto root_graph = GraphUtils::FindRootGraph(owner_graph);
  745. GE_CHECK_NOTNULL(root_graph);
  746. std::unordered_set<std::string> subgraph_to_remove;
  747. for (auto &subgraph_name : subgraph_names) {
  748. std::deque<std::string> queue;
  749. queue.push_back(subgraph_name);
  750. subgraph_to_remove.insert(subgraph_name);
  751. op_desc->RemoveSubgraphInstanceName(subgraph_name);
  752. while (!queue.empty()) {
  753. auto graph_name = queue.front();
  754. queue.pop_front();
  755. auto subgraph = root_graph->GetSubgraph(graph_name);
  756. GE_CHECK_NOTNULL(subgraph);
  757. for (const auto &sub_node : subgraph->GetDirectNode()) {
  758. auto sub_op_desc = sub_node->GetOpDesc();
  759. GE_CHECK_NOTNULL(sub_op_desc);
  760. auto sub_names = sub_op_desc->GetSubgraphInstanceNames();
  761. // Subgraph and all nodes in it will be removed later,
  762. // no need to remove 'SubgraphInstanceName' in op desc here.
  763. for (auto &name : sub_names) {
  764. if (subgraph_to_remove.insert(name).second) {
  765. queue.push_back(name);
  766. }
  767. }
  768. }
  769. }
  770. }
  771. // Remove subgraph from root_graph
  772. for (const auto &name : subgraph_to_remove) {
  773. GELOGI("Remove subgraph:%s.", name.c_str());
  774. root_graph->RemoveSubgraph(name);
  775. }
  776. }
  777. return GRAPH_SUCCESS;
  778. }
  779. ///
  780. /// @brief Get subgraph input data node by index.
  781. /// @param [in] node
  782. /// @return Node
  783. ///
  784. vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) {
  785. vector<NodePtr> in_data_node_vec;
  786. auto op_desc = node.GetOpDesc();
  787. GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec);
  788. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  789. if (subgraph_names.empty()) {
  790. GELOGW("Node %s is single node without sub graph.", node.GetName().c_str());
  791. return in_data_node_vec;
  792. }
  793. auto compute_graph = node.GetOwnerComputeGraph();
  794. for (const std::string &instance_name : subgraph_names) {
  795. auto subgraph = compute_graph->GetSubgraph(instance_name);
  796. for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
  797. int parent_index = -1;
  798. if (NodeUtils::IsSubgraphInput(node_in_subgraph)) {
  799. (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index);
  800. if (parent_index == index) {
  801. in_data_node_vec.emplace_back(node_in_subgraph);
  802. }
  803. }
  804. }
  805. }
  806. return in_data_node_vec;
  807. }
  808. ///
  809. /// @brief Get subgraph input data node by index.
  810. /// @param [in] node
  811. /// @return Node
  812. ///
  813. vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) {
  814. vector<NodePtr> out_data_node_vec;
  815. auto op_desc = node.GetOpDesc();
  816. GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec);
  817. auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  818. if (subgraph_names.empty()) {
  819. GELOGI("Node %s is single node without sub graph.", node.GetName().c_str());
  820. return out_data_node_vec;
  821. }
  822. auto compute_graph = node.GetOwnerComputeGraph();
  823. for (const std::string &instance_name : subgraph_names) {
  824. auto subgraph = compute_graph->GetSubgraph(instance_name);
  825. for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
  826. if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) {
  827. out_data_node_vec.emplace_back(node_in_subgraph);
  828. }
  829. }
  830. }
  831. return out_data_node_vec;
  832. }
  833. NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, const int index) {
  834. if (node.GetInDataAnchor(index) == nullptr) {
  835. return nullptr;
  836. }
  837. if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) {
  838. return nullptr;
  839. }
  840. return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode();
  841. }
  842. vector<pair<InDataAnchorPtr, NodePtr>> NodeUtils::GetOutDataNodesWithAnchorByIndex(const Node &node, const int index) {
  843. vector<pair<InDataAnchorPtr, NodePtr>> out_data_nodes;
  844. auto out_data_anchor = node.GetOutDataAnchor(index);
  845. if (out_data_anchor == nullptr) {
  846. return out_data_nodes;
  847. }
  848. for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  849. if (peer_in_anchor == nullptr) {
  850. continue;
  851. }
  852. if (peer_in_anchor->GetOwnerNode() == nullptr) {
  853. continue;
  854. }
  855. out_data_nodes.emplace_back(std::make_pair(peer_in_anchor, peer_in_anchor->GetOwnerNode()));
  856. }
  857. return out_data_nodes;
  858. }
  859. ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) { return oprt.GetNode(); }
  860. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：