You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

net_output_pass.cc 28 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/net_output_pass.h"
  17. #include <map>
  18. #include <memory>
  19. #include <string>
  20. #include <utility>
  21. #include <vector>
  22. #include "common/ge/ge_util.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "framework/common/ge_inner_error_codes.h"
  25. #include "framework/omg/omg_inner_types.h"
  26. #include "graph/debug/ge_attr_define.h"
  27. #include "graph/common/local_context.h"
  28. #include "graph/passes/pass_utils.h"
  29. #include "graph/utils/tensor_utils.h"
  30. #include "graph/utils/type_utils.h"
  31. namespace ge {
  32. static std::map<std::string, ge::DataType> output_type_str_to_datatype = {
  33. {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"INT8", ge::DT_INT8}, {"INT16", ge::DT_INT16},
  34. {"UINT16", ge::DT_UINT16}, {"UINT8", ge::DT_UINT8}, {"INT32", ge::DT_INT32}, {"INT64", ge::DT_INT64},
  35. {"UINT32", ge::DT_UINT32}, {"UINT64", ge::DT_UINT64}, {"DOUBLE", ge::DT_DOUBLE}};
  36. // the size of user defined output datatype or format string after split by ":".
  37. const size_t kUserDefinedElementCount = 2;
  38. const size_t kNodesCount = 2;
  39. Status NetOutputPass::GetRetvalOutputInfo(const ge::NodePtr &node,
  40. std::map<int32_t, RetvalInfo> &retval_node_index_map) {
  41. GE_CHECK_NOTNULL(node);
  42. GE_CHECK_NOTNULL(node->GetOpDesc());
  43. int64_t output_index = 0;
  44. if (!AttrUtils::GetInt(node->GetOpDesc(), RETVAL_ATTR_NAME_INDEX, output_index)) {
  45. GELOGE(PARAM_INVALID, "Get output index failed.");
  46. return PARAM_INVALID;
  47. }
  48. if (retval_node_index_map.count(output_index) > 0) {
  49. GELOGE(PARAM_INVALID, "Retval has duplicate index.");
  50. return PARAM_INVALID;
  51. }
  52. int parent_node_index = -1;
  53. (void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_node_index);
  54. InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(0);
  55. GE_CHECK_NOTNULL(in_data_anchor);
  56. GE_CHECK_NOTNULL(in_data_anchor->GetPeerOutAnchor());
  57. int32_t src_node_index = in_data_anchor->GetPeerOutAnchor()->GetIdx();
  58. NodePtr src_node_ptr = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode();
  59. retval_node_index_map[output_index] = {src_node_ptr, src_node_index, parent_node_index};
  60. // if user targets include retval node,delete it from set and insert its input node instead
  61. // better to GetInNodes here
  62. auto iter = targets_.find(node);
  63. if (iter != targets_.end()) {
  64. targets_.erase(iter);
  65. targets_.insert(src_node_ptr);
  66. GELOGI("node [%s] is in user def targets, do not output result to user!", node->GetName().c_str());
  67. }
  68. is_include_special_node_ = true;
  69. return SUCCESS;
  70. }
  71. Status NetOutputPass::GetOutputNode(const ge::ComputeGraphPtr &graph, std::vector<RetvalInfo> &output_nodes_info) {
  72. std::map<int32_t, RetvalInfo> retval_node_index_map;
  73. for (NodePtr &node : graph->GetDirectNode()) {
  74. Status ret = SUCCESS;
  75. if ((node->GetOpDesc() != nullptr) && (node->GetOpDesc()->HasAttr(RETVAL_ATTR_NAME_INDEX))) {
  76. /// Set the output according to the Retval operator,
  77. /// identify by whether there is an index parameter
  78. ret = GetRetvalOutputInfo(node, retval_node_index_map);
  79. }
  80. if (ret != SUCCESS) {
  81. GELOGE(ret, "GetRetvalOutputInfo failed");
  82. return ret;
  83. }
  84. }
  85. GELOGI("Get retval node size:%zu.", retval_node_index_map.size());
  86. std::vector<RetvalInfo> out_nodes_tmp;
  87. /// The Netoutput output is determined by Retval, and the input order
  88. /// of Netoutput is sorted according to the index value of Retval.
  89. for (auto &it : retval_node_index_map) {
  90. out_nodes_tmp.push_back(it.second);
  91. }
  92. // when user set targets, mean that no output result
  93. for (auto &ele : graph->GetGraphOutNodesInfo()) {
  94. auto iter = targets_.find(ele.first);
  95. if (iter != targets_.end()) {
  96. GELOGI("user set out node [%s] is found in user def targets, out node is prio!", ele.first->GetName().c_str());
  97. targets_.erase(iter);
  98. }
  99. auto op_desc = ele.first->GetOpDesc();
  100. GE_CHECK_NOTNULL(op_desc);
  101. if (op_desc->HasAttr(ATTR_ATC_USER_DEFINE_OUTPUT_NODES)) {
  102. is_user_define_ouput_nodes = true;
  103. }
  104. int parent_index = -1;
  105. auto output_desc = op_desc->MutableOutputDesc(ele.second);
  106. if (output_desc == nullptr) {
  107. GELOGE(FAILED, "[Get][OutputDesc]Can not find output tensor desc from node:%s, index %d",
  108. op_desc->GetName().c_str(), ele.second);
  109. return FAILED;
  110. }
  111. (void)ge::AttrUtils::GetInt(output_desc, ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index);
  112. output_nodes_info.push_back({ele.first, ele.second, parent_index});
  113. }
  114. GELOGI("Output node set by user or leaf node, size:%zu.", output_nodes_info.size());
  115. for (auto &ele : out_nodes_tmp) {
  116. // add member, no need to remove duplicated because we need to keep all edges
  117. output_nodes_info.push_back(ele);
  118. }
  119. GELOGI("Get output node, size:%zu.", output_nodes_info.size());
  120. Status check_ret = CheckOutputNodeInfo(graph, output_nodes_info);
  121. if (check_ret != SUCCESS) {
  122. return check_ret;
  123. }
  124. return SUCCESS;
  125. }
  126. Status NetOutputPass::CheckOutputNodeInfo(const ComputeGraphPtr &graph, const std::vector<RetvalInfo> &outputs) {
  127. for (auto &item : outputs) {
  128. NodePtr node = item.output_node;
  129. if (node == nullptr) {
  130. GELOGE(PARAM_INVALID, "Node in outputs is null.");
  131. return PARAM_INVALID;
  132. } else {
  133. if (graph->FindNode(node->GetName()) == nullptr) {
  134. GELOGE(INTERNAL_ERROR, "Out node (%s) is not in graph.", node->GetName().c_str());
  135. return INTERNAL_ERROR;
  136. }
  137. GE_CHECK_NOTNULL(node->GetOpDesc());
  138. int32_t out_size = node->GetOpDesc()->GetOutputsSize();
  139. int32_t index = item.node_output_index;
  140. if (index < 0 || index >= out_size) {
  141. GELOGE(PARAM_INVALID,
  142. "User declared out node (%s) output index:%d must be smaller "
  143. "than node ouput size:%d and cann't be negative!",
  144. node->GetName().c_str(), index, out_size);
  145. return PARAM_INVALID;
  146. }
  147. }
  148. }
  149. return SUCCESS;
  150. }
  151. Status NetOutputPass::RemoveUnusedNode(const ge::ComputeGraphPtr &graph) {
  152. std::vector<ge::NodePtr> node_to_delete;
  153. // Delete _Retval operator.
  154. for (auto &node : graph->GetDirectNode()) {
  155. GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, GELOGW("Node OpDesc is nullptr"); continue);
  156. bool need_be_deleted = node->GetInDataNodes().size() != 0 && node->GetOutDataNodesSize() == 0 &&
  157. (node->GetOpDesc()->HasAttr(RETVAL_ATTR_NAME_INDEX));
  158. if (need_be_deleted) {
  159. node_to_delete.push_back(node);
  160. }
  161. }
  162. for (NodePtr &node : node_to_delete) {
  163. auto iter = targets_.find(node);
  164. if (iter != targets_.end()) {
  165. GELOGI("[Net output pass] node[%s] is in user set targets.so do not remove!", node->GetName().c_str());
  166. continue;
  167. }
  168. if (graph->RemoveNode(node) != GRAPH_SUCCESS) {
  169. GELOGE(INTERNAL_ERROR, "Remove node failed, node name:%s.", node->GetName().c_str());
  170. return INTERNAL_ERROR;
  171. }
  172. }
  173. return SUCCESS;
  174. }
  175. Status NetOutputPass::UpdateNetOutputDesc(const ge::NodePtr &net_output) {
  176. OpDescPtr net_output_desc = net_output->GetOpDesc();
  177. if (net_output_desc == nullptr) {
  178. GELOGE(INTERNAL_ERROR, "Opdesc of net output node is nullptr.");
  179. return INTERNAL_ERROR;
  180. }
  181. if (net_output_desc->GetInputsSize() == 0) {
  182. GELOGE(INTERNAL_ERROR, "Net output node input is empty.");
  183. return INTERNAL_ERROR;
  184. }
  185. std::vector<bool> is_input_const;
  186. for (const auto &in_anchor : net_output->GetAllInDataAnchors()) {
  187. GE_CHECK_NOTNULL(in_anchor);
  188. uint32_t index = static_cast<uint32_t>(in_anchor->GetIdx());
  189. if (index >= net_output_desc->GetAllInputsDesc().size()) {
  190. GELOGE(INTERNAL_ERROR, "Index is invalid, index:%u, size:%zu.", index,
  191. net_output_desc->GetAllInputsDesc().size());
  192. return INTERNAL_ERROR;
  193. }
  194. GE_CHECK_NOTNULL(in_anchor->GetPeerOutAnchor());
  195. is_input_const.push_back(PassUtils::IsConstant(in_anchor->GetPeerOutAnchor()->GetOwnerNode()));
  196. OpDescPtr src_op_desc = in_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetOpDesc();
  197. GE_CHECK_NOTNULL(src_op_desc);
  198. uint32_t peer_index = static_cast<uint32_t>(in_anchor->GetPeerOutAnchor()->GetIdx());
  199. ge::GeTensorDesc output_in_desc = src_op_desc->GetOutputDesc(peer_index);
  200. if (net_output_desc->UpdateInputDesc(index, output_in_desc) != GRAPH_SUCCESS) {
  201. GELOGE(INTERNAL_ERROR, "Update input desc failed, index:%u.", index);
  202. return INTERNAL_ERROR;
  203. }
  204. GELOGD("Update desc, format:%s, data type:%s, index:%u.",
  205. TypeUtils::FormatToSerialString(output_in_desc.GetFormat()).c_str(),
  206. TypeUtils::DataTypeToSerialString(output_in_desc.GetDataType()).c_str(), index);
  207. }
  208. net_output_desc->SetIsInputConst(is_input_const);
  209. return SUCCESS;
  210. }
  211. Status NetOutputPass::AddCtrlEdgeForTargets(const ge::NodePtr &net_out_node) {
  212. if (net_out_node == nullptr) {
  213. GELOGE(PARAM_INVALID, "net out node is null.");
  214. return PARAM_INVALID;
  215. }
  216. // Add ctrl edge for targets
  217. for (auto &node : targets_) {
  218. if (node == nullptr) {
  219. continue;
  220. }
  221. // no need to check null because have handled it in run SaveAndRemoveTargets function
  222. graphStatus status = GraphUtils::AddEdge(node->GetOutControlAnchor(), net_out_node->GetInControlAnchor());
  223. if (status != GRAPH_SUCCESS) {
  224. GELOGE(INTERNAL_ERROR, "Add ctrl edge to netoutput node[%s] for target node [%s] failed!",
  225. net_out_node->GetName().c_str(), node->GetName().c_str());
  226. return INTERNAL_ERROR;
  227. }
  228. GELOGD("Add ctrl edge to netoutput node[%s] for target node [%s] success!", net_out_node->GetName().c_str(),
  229. node->GetName().c_str());
  230. }
  231. return SUCCESS;
  232. }
  233. void NetOutputPass::SaveAndRemoveTargets(const ge::ComputeGraphPtr &graph) {
  234. // save user targets node
  235. for (auto &node : graph->GetGraphTargetNodesInfo()) {
  236. if (node == nullptr) {
  237. GELOGW("User pointed targets contains null node.ignore it !");
  238. continue;
  239. }
  240. targets_.insert(node);
  241. }
  242. GELOGI("User pointed targets size is %zu !", targets_.size());
  243. }
  244. Status NetOutputPass::AddEdgesForNetOutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node,
  245. const std::vector<RetvalInfo> &output_nodes_info) {
  246. int32_t net_input_index = 0;
  247. for (auto &item : output_nodes_info) {
  248. NodePtr src_node = item.output_node;
  249. GE_CHECK_NOTNULL(src_node);
  250. graphStatus status = GraphUtils::AddEdge(src_node->GetOutDataAnchor(item.node_output_index),
  251. net_out_node->GetInDataAnchor(net_input_index));
  252. if (status != GRAPH_SUCCESS) {
  253. GELOGE(INTERNAL_ERROR, "AddEdge failed, src name:%s, src index:%d, dst index:%d.", src_node->GetName().c_str(),
  254. item.node_output_index, net_input_index);
  255. return INTERNAL_ERROR;
  256. }
  257. GELOGD("AddEdge to output node, src name:%s, src index:%d, dst index:%d.", src_node->GetName().c_str(),
  258. item.node_output_index, net_input_index);
  259. if (item.parent_node_index >= 0) {
  260. GELOGI("Add parent node index %d for the netoutput input %d on graph %s", item.parent_node_index, net_input_index,
  261. graph->GetName().c_str());
  262. auto input_desc = net_out_node->GetOpDesc()->MutableInputDesc(net_input_index);
  263. if (input_desc == nullptr) {
  264. GELOGE(INTERNAL_ERROR, "Can not find intput tensor desc from NetOutput, index %d", net_input_index);
  265. return INTERNAL_ERROR;
  266. }
  267. if (!AttrUtils::SetInt(input_desc, ATTR_NAME_PARENT_NODE_INDEX, item.parent_node_index)) {
  268. GELOGE(INTERNAL_ERROR, "Failed to add parent index to NetOutput, index %d", net_input_index);
  269. return INTERNAL_ERROR;
  270. }
  271. }
  272. net_input_index++;
  273. }
  274. if (RemoveUnusedNode(graph) != SUCCESS) {
  275. GELOGE(INTERNAL_ERROR, "Remove unused nodes failed.");
  276. return INTERNAL_ERROR;
  277. }
  278. if (AddCtrlEdgeForTargets(net_out_node) != SUCCESS) {
  279. GELOGE(INTERNAL_ERROR, "Add ctrl edge for targets failed.");
  280. return INTERNAL_ERROR;
  281. }
  282. // Add true stream, netoutput is 0
  283. GE_IF_BOOL_EXEC(!ge::AttrUtils::SetInt(net_out_node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, 0),
  284. GELOGE(INTERNAL_ERROR, "set ATTR_NAME_TRUE_BRANCH_STREAM failed");
  285. return INTERNAL_ERROR);
  286. return SUCCESS;
  287. }
  288. bool NetOutputPass::CheckNodeIsInOutputNodes(const ge::ComputeGraphPtr &graph, const ge::NodePtr &node) {
  289. for (auto &ele : graph->GetGraphOutNodesInfo()) {
  290. auto out_node = ele.first;
  291. if (node == out_node) {
  292. return true;
  293. }
  294. }
  295. return false;
  296. }
  297. Status NetOutputPass::UnLinkDataAnchorOfNetoutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node) {
  298. if (net_out_node == nullptr) {
  299. GELOGE(PARAM_INVALID, "net out node is null.");
  300. return PARAM_INVALID;
  301. }
  302. Status ret = SUCCESS;
  303. // unlink all anchor to data anchor of netoutput
  304. for (auto &in_data_anchor : net_out_node->GetAllInDataAnchors()) {
  305. if (in_data_anchor == nullptr) {
  306. continue;
  307. }
  308. auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  309. if (peer_out_anchor == nullptr) {
  310. GELOGI("PeerOutAnchor is null!");
  311. continue;
  312. }
  313. auto node = peer_out_anchor->GetOwnerNode();
  314. auto iter = targets_.find(node);
  315. if (iter != targets_.end()) {
  316. if (!CheckNodeIsInOutputNodes(graph, node)) {
  317. ret = in_data_anchor->Unlink(peer_out_anchor);
  318. if (ret != SUCCESS) {
  319. GELOGE(INTERNAL_ERROR, "Unlink peer_out_anchor fail!");
  320. return ret;
  321. }
  322. } else {
  323. targets_.erase(iter);
  324. }
  325. }
  326. }
  327. return ret;
  328. }
  329. Status NetOutputPass::UnLinkControlAnchorOfNetoutput(const ge::ComputeGraphPtr &graph,
  330. const ge::NodePtr &net_out_node) {
  331. if (net_out_node == nullptr) {
  332. GELOGE(PARAM_INVALID, "net out node is null.");
  333. return PARAM_INVALID;
  334. }
  335. Status ret = SUCCESS;
  336. auto in_control_anchor = net_out_node->GetInControlAnchor();
  337. if (in_control_anchor == nullptr) {
  338. GELOGE(PARAM_INVALID, "in control anchor is null.");
  339. return PARAM_INVALID;
  340. }
  341. // unlink all data anchor to control anchor of netoutput
  342. for (auto &peer_out_data_anchor : in_control_anchor->GetPeerOutDataAnchors()) {
  343. if (peer_out_data_anchor == nullptr) {
  344. GELOGI("PeerOutControlAnchor is null!");
  345. } else {
  346. auto node = peer_out_data_anchor->GetOwnerNode();
  347. auto iter = targets_.find(node);
  348. if (iter != targets_.end()) {
  349. if (CheckNodeIsInOutputNodes(graph, node) == false) {
  350. ret = in_control_anchor->Unlink(peer_out_data_anchor);
  351. if (ret != SUCCESS) {
  352. GELOGE(INTERNAL_ERROR, "Unlink peer_out_anchor fail!");
  353. return ret;
  354. }
  355. } else {
  356. targets_.erase(iter);
  357. }
  358. }
  359. }
  360. }
  361. /// check all control anchor to control anchor of netoutput and delete it from targets
  362. /// to avoid duplicated add control edge;
  363. for (auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) {
  364. if (peer_out_control_anchor == nullptr) {
  365. GELOGI("PeerOutControlAnchor is null");
  366. } else {
  367. auto node = peer_out_control_anchor->GetOwnerNode();
  368. auto iter = targets_.find(node);
  369. if (iter != targets_.end()) {
  370. targets_.erase(iter);
  371. }
  372. }
  373. }
  374. return ret;
  375. }
  376. Status NetOutputPass::UnLink(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node) {
  377. GELOGI("[NetOutputPass] Enter Unlink process.");
  378. Status ret = UnLinkDataAnchorOfNetoutput(graph, net_out_node);
  379. if (ret != SUCCESS) {
  380. GELOGI("[NetOutputPass] UnLinkDataAnchorOfNetoutput process fail.");
  381. return ret;
  382. }
  383. ret = UnLinkControlAnchorOfNetoutput(graph, net_out_node);
  384. if (ret != SUCCESS) {
  385. GELOGI("[NetOutputPass] UnLinkControlAnchorOfNetoutput process fail.");
  386. return ret;
  387. }
  388. return ret;
  389. }
  390. Status NetOutputPass::ProcessWithNetoutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &output_node) {
  391. if (UpdateNetOutputDesc(output_node) != SUCCESS) {
  392. GELOGE(INTERNAL_ERROR, "Update net output desc failed.");
  393. return INTERNAL_ERROR;
  394. }
  395. if (UnLink(graph, output_node) != SUCCESS) {
  396. GELOGE(INTERNAL_ERROR, "UnLink connection between netoutput node and user set target node");
  397. return INTERNAL_ERROR;
  398. }
  399. if (AddCtrlEdgeForTargets(output_node) != SUCCESS) {
  400. GELOGE(INTERNAL_ERROR, "Add ctrl edge for targets failed.");
  401. return INTERNAL_ERROR;
  402. }
  403. return SUCCESS;
  404. }
  405. Status NetOutputPass::AddCtrlEdgesBetweenLeafAndNetOutput(const ge::ComputeGraphPtr &graph,
  406. const ge::NodePtr &net_out_node) {
  407. GE_CHECK_NOTNULL(net_out_node);
  408. if (!GetLocalOmgContext().user_out_nodes.empty() || is_user_define_ouput_nodes) {
  409. GELOGI("No need to add ctrl edge to netoutput because user out nodes have been set.");
  410. return SUCCESS;
  411. }
  412. bool graph_has_only_one_node_except_netoutput = (graph->GetDirectNodesSize() == kNodesCount);
  413. for (const auto &node : graph->GetDirectNode()) {
  414. if (node == nullptr || node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() == NETOUTPUT) {
  415. continue;
  416. }
  417. if ((node->GetInControlNodes().size() != 0 || node->GetInDataNodes().size() != 0 ||
  418. graph_has_only_one_node_except_netoutput) &&
  419. node->GetOutDataNodesSize() == 0 && node->GetOutControlNodes().size() == 0) {
  420. GE_CHK_STATUS_RET(GraphUtils::AddEdge(node->GetOutControlAnchor(), net_out_node->GetInControlAnchor()),
  421. "add edge failed");
  422. GELOGD("Add ctrl edge success. src name :%s, dst name :%s", node->GetName().c_str(),
  423. net_out_node->GetName().c_str());
  424. }
  425. }
  426. return SUCCESS;
  427. }
  428. Status NetOutputPass::CreateNetOutputNode(OpDescPtr &net_output_desc, const ge::ComputeGraphPtr &graph) {
  429. // Only flush subgraph name
  430. string node_name =
  431. (graph->GetParentGraph() != nullptr) ? (graph->GetName() + "_" + NODE_NAME_NET_OUTPUT) : NODE_NAME_NET_OUTPUT;
  432. net_output_desc = MakeShared<OpDesc>(node_name, NETOUTPUT);
  433. if (net_output_desc == nullptr) {
  434. GELOGE(MEMALLOC_FAILED, "Make shared net output op failed.");
  435. return MEMALLOC_FAILED;
  436. }
  437. (void)AttrUtils::SetListStr(net_output_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES,
  438. std::move(std::vector<std::string>()));
  439. return SUCCESS;
  440. }
  441. Status NetOutputPass::Run(ge::ComputeGraphPtr graph) {
  442. if (graph == nullptr) {
  443. GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null.");
  444. return GE_GRAPH_PARAM_NULLPTR;
  445. }
  446. GELOGI("NetOutputPass Run.graph is [%s]", graph->GetName().c_str());
  447. NodePtr output_node = graph->FindFirstNodeMatchType(NETOUTPUT);
  448. // save user targets node
  449. SaveAndRemoveTargets(graph);
  450. // If graph already has a netoutput node, doesn't need to create it again.
  451. if (output_node != nullptr) {
  452. (void)AttrUtils::SetListStr(output_node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES,
  453. std::move(std::vector<std::string>()));
  454. if (ProcessWithNetoutput(graph, output_node) != SUCCESS) {
  455. GELOGE(INTERNAL_ERROR, "Process with netoutput node failed.");
  456. return INTERNAL_ERROR;
  457. }
  458. } else {
  459. if (AddNetOutputNodeToGraph(graph, output_node) != SUCCESS) {
  460. GELOGE(INTERNAL_ERROR, "Set user define dtype and format for netoutput failed.");
  461. return INTERNAL_ERROR;
  462. }
  463. }
  464. // Add userdef attrs to netoutput node
  465. return SetUserDefDTypeAndFormatFromAtcParams(output_node);
  466. }
  467. Status NetOutputPass::AddNetOutputNodeToGraph(const ge::ComputeGraphPtr &graph, NodePtr &output_node) {
  468. OpDescPtr net_output_desc = nullptr;
  469. if (CreateNetOutputNode(net_output_desc, graph) != SUCCESS) {
  470. GELOGE(INTERNAL_ERROR, "Get net output nodes failed.");
  471. return INTERNAL_ERROR;
  472. }
  473. std::vector<RetvalInfo> output_nodes_info;
  474. if (GetOutputNode(graph, output_nodes_info) != SUCCESS) {
  475. GELOGE(INTERNAL_ERROR, "Get net output nodes failed.");
  476. return INTERNAL_ERROR;
  477. }
  478. GELOGI("[NETOUTPUT PASS] OutNodesInfo size:%zu, Targets Size:%zu, is_include_special_node_:%d",
  479. graph->GetGraphOutNodesInfo().size(), graph->GetGraphTargetNodesInfo().size(), is_include_special_node_);
  480. // If user does not set out nodes and targets and no retval node, also add netoutput node
  481. if ((graph->GetGraphOutNodesInfo().empty()) && (graph->GetGraphTargetNodesInfo().empty()) &&
  482. !is_include_special_node_) {
  483. GELOGI("[NETOUTPUT PASS] output_nodes and target_nodes and special nodes is empty!Add netoutput!");
  484. output_node = graph->AddNode(net_output_desc);
  485. GE_CHK_STATUS_RET(AddCtrlEdgesBetweenLeafAndNetOutput(graph, output_node),
  486. "add ctrl edge between leaf and netoutput failed");
  487. return SUCCESS;
  488. }
  489. GELOGI("[NETOUTPUT PASS] Output node size:%lu.", output_nodes_info.size());
  490. if (output_nodes_info.empty()) {
  491. // because retval node is contained by output_nodes_info, here means targets is non-empty
  492. output_node = graph->AddNode(net_output_desc);
  493. if (output_node == nullptr) {
  494. GELOGE(INTERNAL_ERROR, "Add output node failed.");
  495. return INTERNAL_ERROR;
  496. }
  497. GE_CHK_STATUS_RET(AddCtrlEdgeForTargets(output_node), "add ctrl edge for targets failed");
  498. // Add true stream, netoutput is 0
  499. GE_IF_BOOL_EXEC(!ge::AttrUtils::SetInt(output_node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, 0),
  500. GELOGE(INTERNAL_ERROR, "set ATTR_NAME_TRUE_BRANCH_STREAM failed");
  501. return INTERNAL_ERROR);
  502. return SUCCESS;
  503. }
  504. AddInOutForNetOutputOp(graph, net_output_desc, output_nodes_info);
  505. output_node = graph->AddNode(net_output_desc);
  506. if (output_node == nullptr) {
  507. GELOGE(INTERNAL_ERROR, "Add output node failed.");
  508. return INTERNAL_ERROR;
  509. }
  510. if (AddEdgesForNetOutput(graph, output_node, output_nodes_info) != SUCCESS) {
  511. GELOGE(INTERNAL_ERROR, "Add edges for net output node failed.");
  512. return INTERNAL_ERROR;
  513. }
  514. if (AddCtrlEdgesBetweenLeafAndNetOutput(graph, output_node) != SUCCESS) {
  515. GELOGE(INTERNAL_ERROR, "Add control edges between leaf and netoutput failed.");
  516. return INTERNAL_ERROR;
  517. }
  518. GELOGI("Add NetOutput node success.");
  519. return SUCCESS;
  520. }
  521. void NetOutputPass::AddInOutForNetOutputOp(const ComputeGraphPtr &graph, OpDescPtr &net_output_desc,
  522. vector<RetvalInfo> &output_nodes_info) {
  523. std::vector<bool> is_input_const;
  524. for (auto iter = output_nodes_info.begin(); iter != output_nodes_info.end();) {
  525. NodePtr src_node = iter->output_node;
  526. if (src_node == nullptr) {
  527. continue;
  528. }
  529. int32_t src_index = iter->node_output_index;
  530. // if src_node is in targets_, no need to Add in and out for netoutput
  531. auto it = targets_.find(src_node);
  532. if (it != targets_.end()) {
  533. iter = output_nodes_info.erase(iter);
  534. GELOGD("node [%s] is in processed targets, do not add inout for netoutput!", src_node->GetName().c_str());
  535. continue;
  536. }
  537. /// Get the output attribute of src_node,
  538. /// and set to the input/output of net_out_node.
  539. if (src_node == nullptr || src_node->GetOpDesc() == nullptr || net_output_desc == nullptr) {
  540. GELOGE(INTERNAL_ERROR, "src node or net output desc is null.");
  541. return;
  542. }
  543. ge::GeTensorDesc out_desc = src_node->GetOpDesc()->GetOutputDesc(src_index);
  544. out_desc.SetFormat(FORMAT_ND);
  545. out_desc.SetOriginFormat(FORMAT_ND);
  546. GE_IF_BOOL_EXEC(net_output_desc->AddInputDesc(out_desc) != SUCCESS, GELOGW("add input desc failed"); return );
  547. is_input_const.push_back(PassUtils::IsConstant(src_node));
  548. ++iter;
  549. }
  550. net_output_desc->SetIsInputConst(is_input_const);
  551. }
  552. bool NeedUpdateOutputByOutputTypeParm(std::string &output_type, OpDescPtr &op_desc, uint32_t &src_index,
  553. ge::DataType &dt) {
  554. if (output_type_str_to_datatype.find(output_type) != output_type_str_to_datatype.end()) {
  555. dt = output_type_str_to_datatype[output_type];
  556. return true;
  557. }
  558. vector<string> output_dt_str;
  559. if (ge::AttrUtils::GetListStr(op_desc, "_user_defined_output_data_type", output_dt_str)) {
  560. for (const auto &dt_str : output_dt_str) {
  561. vector<string> dt_str_split = StringUtils::Split(dt_str, ':');
  562. if (dt_str_split.size() == kUserDefinedElementCount) {
  563. if (dt_str_split[0] == to_string(src_index)) {
  564. dt = TypeUtils::SerialStringToDataType(dt_str_split[1]);
  565. return true;
  566. }
  567. } else {
  568. GELOGW("The size of [%s] is not 2 after split.", dt_str.c_str());
  569. continue;
  570. }
  571. }
  572. }
  573. return false;
  574. }
  575. bool NeedUpdateOutputFp16Nc1hwc0(OpDescPtr &op_desc, uint32_t &src_index) {
  576. vector<string> output_dt_str;
  577. if (ge::AttrUtils::GetListStr(op_desc, "_user_defined_output_fp16_5hd", output_dt_str)) {
  578. for (const auto &dt_str : output_dt_str) {
  579. vector<string> dt_str_split = StringUtils::Split(dt_str, ':');
  580. if (dt_str_split.size() == kUserDefinedElementCount) {
  581. if (dt_str_split[0] == to_string(src_index)) {
  582. return true;
  583. }
  584. } else {
  585. GELOGW("The size of [%s] is not 2 after split.", dt_str.c_str());
  586. continue;
  587. }
  588. }
  589. }
  590. return false;
  591. }
  592. Status NetOutputPass::SetUserDefDTypeAndFormatFromAtcParams(const NodePtr &output_node) {
  593. if (output_node == nullptr) {
  594. GELOGI("[NETOUTPUT PASS] The graph no need netoutput node!");
  595. return SUCCESS;
  596. }
  597. auto output_type = GetLocalOmgContext().output_type;
  598. auto op_desc = output_node->GetOpDesc();
  599. GE_CHECK_NOTNULL(op_desc);
  600. std::vector<std::string> userdef_dtypes;
  601. std::vector<std::string> userdef_formats;
  602. ge::DataType output_data_type = ge::DT_FLOAT;
  603. for (const auto &in_anchor : output_node->GetAllInDataAnchors()) {
  604. auto index = static_cast<uint32_t>(in_anchor->GetIdx());
  605. auto peer_out = in_anchor->GetPeerOutAnchor();
  606. if (peer_out == nullptr) {
  607. // If user set target, peer_out anchor will be unlinked.
  608. continue;
  609. }
  610. auto src_index = static_cast<uint32_t>(peer_out->GetIdx());
  611. auto src_node = peer_out->GetOwnerNode();
  612. GE_CHECK_NOTNULL(src_node);
  613. OpDescPtr src_op_desc = src_node->GetOpDesc();
  614. GE_CHECK_NOTNULL(src_op_desc);
  615. // Update datatype
  616. if (NeedUpdateOutputByOutputTypeParm(output_type, src_op_desc, src_index, output_data_type)) {
  617. GELOGD("Add user-define datatype:%s to netoutput node.",
  618. TypeUtils::DataTypeToSerialString(output_data_type).c_str());
  619. userdef_dtypes.push_back(
  620. std::to_string(index).append(":").append(TypeUtils::DataTypeToSerialString(output_data_type)));
  621. continue;
  622. }
  623. // Output_node is not set,check if is_output_adjust_hw_layout is set
  624. bool set_fp16_nc1hwc0 = NeedUpdateOutputFp16Nc1hwc0(src_op_desc, src_index);
  625. if (set_fp16_nc1hwc0) {
  626. // Set DT_FLOAT16 & FORMAT_NC1HWC0
  627. userdef_dtypes.push_back(std::to_string(index).append(":").append(TypeUtils::DataTypeToSerialString(DT_FLOAT16)));
  628. userdef_formats.push_back(
  629. std::to_string(index).append(":").append(TypeUtils::FormatToSerialString(FORMAT_NC1HWC0)));
  630. }
  631. }
  632. if (!userdef_dtypes.empty() && !ge::AttrUtils::SetListStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, userdef_dtypes)) {
  633. GELOGE(INTERNAL_ERROR, "Set user_define_dtype attr list for netoutput failed.");
  634. return INTERNAL_ERROR;
  635. }
  636. if (!userdef_formats.empty() && !ge::AttrUtils::SetListStr(op_desc, ATTR_ATC_USER_DEFINE_FORMAT, userdef_formats)) {
  637. GELOGE(INTERNAL_ERROR, "Set user_define_format attr list for netoutput failed.");
  638. return INTERNAL_ERROR;
  639. }
  640. return SUCCESS;
  641. }
  642. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户并不感知。