You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

net_output_pass_unittest.cc 36 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "graph/passes/net_output_pass.h"

#include <algorithm>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>

#include <gtest/gtest.h>

#include "common/ge_inner_error_codes.h"
#include "common/types.h"
#include "ge/ge_api.h"
#include "graph/compute_graph.h"
#include "graph/debug/graph_debug.h"
#include "graph/manager/graph_manager.h"
#include "graph/manager/graph_manager_utils.h"
#include "graph/operator_reg.h"
#include "graph/utils/op_desc_utils.h"
#include "inc/pass_manager.h"
#include "init/gelib.h"
#include "opskernel_manager/ops_kernel_manager.h"
using namespace std;
using namespace testing;
using namespace ge;
using namespace domi;
  34. class UTEST_graph_passes_net_output_pass : public testing::Test {
  35. protected:
  36. void SetUp() {}
  37. void TearDown() {}
  38. };
  39. ge::ComputeGraphPtr BuildClearWeightGraph(void) {
  40. ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
  41. ge::OpDescPtr cast_op = std::make_shared<ge::OpDesc>();
  42. cast_op->SetType(CAST);
  43. cast_op->SetName("Cast1");
  44. cast_op->AddInputDesc(ge::GeTensorDesc());
  45. cast_op->AddOutputDesc(ge::GeTensorDesc());
  46. ge::NodePtr cast_node = graph->AddNode(cast_op);
  47. ge::OpDescPtr const_op = std::make_shared<ge::OpDesc>();
  48. const_op->SetType(CONSTANT);
  49. const_op->SetName("Const1");
  50. const_op->AddOutputDesc(ge::GeTensorDesc());
  51. ge::NodePtr const_node = graph->AddNode(const_op);
  52. ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), cast_node->GetInDataAnchor(0));
  53. return graph;
  54. }
  55. ge::ComputeGraphPtr build_graph(bool with_leaf_node = false) {
  56. ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
  57. ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
  58. data_op->SetType(DATA);
  59. data_op->SetName("Data1");
  60. data_op->AddInputDesc(ge::GeTensorDesc());
  61. data_op->AddOutputDesc(ge::GeTensorDesc());
  62. ge::NodePtr data1 = graph->AddNode(data_op);
  63. ge::OpDescPtr relu_op1 = std::make_shared<ge::OpDesc>();
  64. relu_op1->SetType(ACTIVATION);
  65. relu_op1->SetName("Relu1");
  66. relu_op1->AddInputDesc(ge::GeTensorDesc());
  67. relu_op1->AddOutputDesc(ge::GeTensorDesc());
  68. ge::NodePtr relu1 = graph->AddNode(relu_op1);
  69. ge::OpDescPtr relu_op2 = std::make_shared<ge::OpDesc>();
  70. relu_op2->SetType(RELU);
  71. relu_op2->SetName("Relu2");
  72. relu_op2->AddInputDesc(ge::GeTensorDesc());
  73. relu_op2->AddOutputDesc(ge::GeTensorDesc());
  74. relu_op2->AddOutputDesc(ge::GeTensorDesc());
  75. ge::NodePtr relu2 = graph->AddNode(relu_op2);
  76. ge::OpDescPtr relu_op3 = std::make_shared<ge::OpDesc>();
  77. relu_op3->SetType(ACTIVATION);
  78. relu_op3->SetName("Relu3");
  79. relu_op3->AddInputDesc(ge::GeTensorDesc());
  80. relu_op3->AddOutputDesc(ge::GeTensorDesc());
  81. ge::NodePtr relu3;
  82. if (with_leaf_node == true) {
  83. relu3 = graph->AddNode(relu_op3);
  84. }
  85. ge::OpDescPtr mul_op = std::make_shared<ge::OpDesc>();
  86. mul_op->SetType(MUL);
  87. mul_op->SetName("Mul");
  88. mul_op->AddInputDesc(ge::GeTensorDesc());
  89. mul_op->AddInputDesc(ge::GeTensorDesc());
  90. mul_op->AddOutputDesc(ge::GeTensorDesc());
  91. mul_op->AddOutputDesc(ge::GeTensorDesc());
  92. mul_op->AddOutputDesc(ge::GeTensorDesc());
  93. mul_op->AddOutputDesc(ge::GeTensorDesc());
  94. ge::NodePtr mul = graph->AddNode(mul_op);
  95. ge::OpDescPtr mul_op1 = std::make_shared<ge::OpDesc>();
  96. mul_op1->SetType(MUL);
  97. mul_op1->SetName("Mul1");
  98. mul_op1->AddInputDesc(ge::GeTensorDesc());
  99. mul_op1->AddInputDesc(ge::GeTensorDesc());
  100. mul_op1->AddOutputDesc(ge::GeTensorDesc());
  101. ge::NodePtr mul1 = graph->AddNode(mul_op1);
  102. ge::OpDescPtr mul_op2 = std::make_shared<ge::OpDesc>();
  103. mul_op2->SetType(MUL);
  104. mul_op2->SetName("Mul2");
  105. mul_op2->AddInputDesc(ge::GeTensorDesc());
  106. mul_op2->AddInputDesc(ge::GeTensorDesc());
  107. mul_op2->AddOutputDesc(ge::GeTensorDesc());
  108. ge::NodePtr mul2 = graph->AddNode(mul_op2);
  109. ge::OpDescPtr fc_op = std::make_shared<ge::OpDesc>();
  110. fc_op->SetType(FULL_CONNECTION);
  111. fc_op->SetName("FullConnection");
  112. fc_op->AddInputDesc(ge::GeTensorDesc());
  113. fc_op->AddOutputDesc(ge::GeTensorDesc());
  114. fc_op->AddOutputDesc(ge::GeTensorDesc());
  115. ge::NodePtr fc = graph->AddNode(fc_op);
  116. ge::GraphUtils::AddEdge(data1->GetOutDataAnchor(0), relu1->GetInDataAnchor(0));
  117. ge::GraphUtils::AddEdge(relu1->GetOutDataAnchor(0), fc->GetInDataAnchor(0));
  118. ge::GraphUtils::AddEdge(fc->GetOutDataAnchor(0), relu2->GetInDataAnchor(0));
  119. if (with_leaf_node == true) {
  120. ge::GraphUtils::AddEdge(fc->GetOutDataAnchor(1), relu3->GetInDataAnchor(0));
  121. }
  122. ge::GraphUtils::AddEdge(relu2->GetOutDataAnchor(0), mul->GetInDataAnchor(0));
  123. ge::GraphUtils::AddEdge(relu2->GetOutDataAnchor(1), mul->GetInDataAnchor(1));
  124. ge::GraphUtils::AddEdge(mul->GetOutDataAnchor(0), mul1->GetInDataAnchor(0));
  125. ge::GraphUtils::AddEdge(mul->GetOutDataAnchor(1), mul1->GetInDataAnchor(1));
  126. ge::GraphUtils::AddEdge(mul->GetOutDataAnchor(2), mul2->GetInDataAnchor(0));
  127. ge::GraphUtils::AddEdge(mul->GetOutDataAnchor(3), mul2->GetInDataAnchor(1));
  128. return graph;
  129. }
  130. TEST_F(UTEST_graph_passes_net_output_pass, add_ctrl_edge_for_netout_from_leaf_success) {
  131. ge::ComputeGraphPtr compute_graph = build_graph(true);
  132. // construct targets
  133. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  134. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  135. ge::NodePtr relu3 = compute_graph->FindNode("Relu3");
  136. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{relu3, 0}};
  137. compute_graph->SetGraphOutNodesInfo(output_nodes);
  138. ge::PassManager pass_managers;
  139. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  140. Status status = pass_managers.Run(compute_graph);
  141. EXPECT_EQ(status, ge::SUCCESS);
  142. // check contain netoutput
  143. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  144. EXPECT_NE(net_out_node, nullptr);
  145. /// check input data node of netoutput
  146. /// when output and targets set conflicts each other , output set is prio
  147. /// Check data input
  148. int input_data_node_num = net_out_node->GetInDataNodes().size();
  149. EXPECT_EQ(input_data_node_num, 1);
  150. std::vector<string> expect_input_data_result{"Relu3"};
  151. for (auto node : net_out_node->GetInDataNodes()) {
  152. auto name = node->GetName();
  153. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  154. if (iter != expect_input_data_result.end()) {
  155. expect_input_data_result.erase(iter);
  156. }
  157. }
  158. input_data_node_num = expect_input_data_result.size();
  159. EXPECT_EQ(input_data_node_num, 0);
  160. // Check control input
  161. int control_node_num = net_out_node->GetInControlNodes().size();
  162. EXPECT_EQ(control_node_num, 2);
  163. std::vector<string> expect_result{"Mul1", "Mul2"};
  164. for (auto node : net_out_node->GetInControlNodes()) {
  165. auto name = node->GetName();
  166. auto iter = std::find(expect_result.begin(), expect_result.end(), name);
  167. if (iter != expect_result.end()) {
  168. expect_result.erase(iter);
  169. }
  170. }
  171. control_node_num = expect_result.size();
  172. EXPECT_EQ(control_node_num, 0);
  173. }
  174. TEST_F(UTEST_graph_passes_net_output_pass, only_target_node_success) {
  175. ge::ComputeGraphPtr compute_graph = build_graph();
  176. // construct targets
  177. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  178. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  179. std::vector<ge::NodePtr> target_nodes = {mul1, mul2};
  180. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  181. ge::PassManager pass_managers;
  182. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  183. Status status = pass_managers.Run(compute_graph);
  184. EXPECT_EQ(status, ge::SUCCESS);
  185. // check contain netoutput
  186. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  187. EXPECT_NE(net_out_node, nullptr);
  188. /// check input data node of netoutput
  189. /// Check data input
  190. int input_data_node_num = net_out_node->GetInDataNodes().size();
  191. EXPECT_EQ(input_data_node_num, 0);
  192. // Check control input
  193. int control_node_num = net_out_node->GetInControlNodes().size();
  194. EXPECT_EQ(control_node_num, 2);
  195. std::vector<string> expect_result{"Mul1", "Mul2"};
  196. for (auto node : net_out_node->GetInControlNodes()) {
  197. auto name = node->GetName();
  198. auto iter = std::find(expect_result.begin(), expect_result.end(), name);
  199. if (iter != expect_result.end()) {
  200. expect_result.erase(iter);
  201. }
  202. }
  203. control_node_num = expect_result.size();
  204. EXPECT_EQ(control_node_num, 0);
  205. }
  206. TEST_F(UTEST_graph_passes_net_output_pass, targets_with_retval_success) {
  207. ge::ComputeGraphPtr compute_graph = build_graph();
  208. // Imitate the output node of _Retval issued
  209. ge::OpDescPtr retval_node_desc1 = std::make_shared<ge::OpDesc>("reval_node1", FRAMEWORKOP);
  210. retval_node_desc1->AddInputDesc(ge::GeTensorDesc());
  211. (void)ge::AttrUtils::SetStr(retval_node_desc1, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  212. (void)ge::AttrUtils::SetInt(retval_node_desc1, RETVAL_ATTR_NAME_INDEX, 0);
  213. ge::NodePtr retval_node1 = compute_graph->AddNode(retval_node_desc1);
  214. EXPECT_NE(retval_node1, nullptr);
  215. ge::OpDescPtr retval_node_desc2 = std::make_shared<ge::OpDesc>("reval_node2", FRAMEWORKOP);
  216. retval_node_desc2->AddInputDesc(ge::GeTensorDesc());
  217. (void)ge::AttrUtils::SetStr(retval_node_desc2, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  218. (void)ge::AttrUtils::SetInt(retval_node_desc2, RETVAL_ATTR_NAME_INDEX, 1);
  219. ge::NodePtr retval_node2 = compute_graph->AddNode(retval_node_desc2);
  220. EXPECT_NE(retval_node2, nullptr);
  221. // construct targets
  222. std::vector<ge::NodePtr> target_nodes = {retval_node1, retval_node2};
  223. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  224. for (NodePtr node : compute_graph->GetDirectNode()) {
  225. if (node->GetName() == "Mul1") {
  226. GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node1->GetInDataAnchor(0));
  227. } else if (node->GetName() == "Mul2") {
  228. GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0));
  229. }
  230. }
  231. ge::PassManager pass_managers;
  232. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  233. Status status = pass_managers.Run(compute_graph);
  234. EXPECT_EQ(status, ge::SUCCESS);
  235. // check contain netoutput
  236. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  237. EXPECT_NE(net_out_node, nullptr);
  238. /// check input data node of netoutput
  239. /// Check data input
  240. int input_data_node_num = net_out_node->GetInDataNodes().size();
  241. EXPECT_EQ(input_data_node_num, 0);
  242. // Check control input
  243. int control_node_num = net_out_node->GetInControlNodes().size();
  244. EXPECT_EQ(control_node_num, 2);
  245. std::vector<string> expect_result{"Mul1", "Mul2"};
  246. for (auto node : net_out_node->GetInControlNodes()) {
  247. auto name = node->GetName();
  248. auto iter = std::find(expect_result.begin(), expect_result.end(), name);
  249. if (iter != expect_result.end()) {
  250. expect_result.erase(iter);
  251. }
  252. }
  253. control_node_num = expect_result.size();
  254. EXPECT_EQ(control_node_num, 0);
  255. // Check the deletion of _Retval node
  256. retval_node1 = compute_graph->FindNode("reval_node1");
  257. EXPECT_EQ(retval_node1, nullptr);
  258. retval_node2 = compute_graph->FindNode("reval_node2");
  259. EXPECT_EQ(retval_node2, nullptr);
  260. }
  261. TEST_F(UTEST_graph_passes_net_output_pass, output_node_and_target_node_no_duplicate_success) {
  262. ge::ComputeGraphPtr compute_graph = build_graph(true);
  263. // construct targets
  264. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  265. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  266. std::vector<ge::NodePtr> target_nodes = {mul1, mul2};
  267. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  268. ge::NodePtr relu3 = compute_graph->FindNode("Relu3");
  269. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{relu3, 0}};
  270. compute_graph->SetGraphOutNodesInfo(output_nodes);
  271. ge::PassManager pass_managers;
  272. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  273. Status status = pass_managers.Run(compute_graph);
  274. EXPECT_EQ(status, ge::SUCCESS);
  275. // check contain netoutput
  276. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  277. EXPECT_NE(net_out_node, nullptr);
  278. /// check input data node of netoutput
  279. /// when output and targets set conflicts each other , output set is prio
  280. /// Check data input
  281. int input_data_node_num = net_out_node->GetInDataNodes().size();
  282. EXPECT_EQ(input_data_node_num, 1);
  283. std::vector<string> expect_input_data_result{"Relu3"};
  284. for (auto node : net_out_node->GetInDataNodes()) {
  285. auto name = node->GetName();
  286. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  287. if (iter != expect_input_data_result.end()) {
  288. expect_input_data_result.erase(iter);
  289. }
  290. }
  291. input_data_node_num = expect_input_data_result.size();
  292. EXPECT_EQ(input_data_node_num, 0);
  293. // Check control input
  294. int control_node_num = net_out_node->GetInControlNodes().size();
  295. EXPECT_EQ(control_node_num, 2);
  296. std::vector<string> expect_result{"Mul1", "Mul2"};
  297. for (auto node : net_out_node->GetInControlNodes()) {
  298. auto name = node->GetName();
  299. auto iter = std::find(expect_result.begin(), expect_result.end(), name);
  300. if (iter != expect_result.end()) {
  301. expect_result.erase(iter);
  302. }
  303. }
  304. control_node_num = expect_result.size();
  305. EXPECT_EQ(control_node_num, 0);
  306. }
  307. TEST_F(UTEST_graph_passes_net_output_pass, output_node_and_target_node_duplicate_success) {
  308. ge::ComputeGraphPtr compute_graph = build_graph();
  309. // construct targets
  310. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  311. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  312. std::vector<ge::NodePtr> target_nodes = {mul2};
  313. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  314. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
  315. compute_graph->SetGraphOutNodesInfo(output_nodes);
  316. ge::PassManager pass_managers;
  317. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  318. Status status = pass_managers.Run(compute_graph);
  319. EXPECT_EQ(status, ge::SUCCESS);
  320. // check contain netoutput
  321. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  322. EXPECT_NE(net_out_node, nullptr);
  323. /// check input data node of netoutput
  324. /// Check data input
  325. int input_data_node_num = net_out_node->GetInDataNodes().size();
  326. EXPECT_EQ(input_data_node_num, 2);
  327. std::vector<string> expect_input_data_result{"Mul1"};
  328. for (auto node : net_out_node->GetInDataNodes()) {
  329. auto name = node->GetName();
  330. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  331. if (iter != expect_input_data_result.end()) {
  332. expect_input_data_result.erase(iter);
  333. }
  334. }
  335. input_data_node_num = expect_input_data_result.size();
  336. EXPECT_EQ(input_data_node_num, 0);
  337. // Check control input
  338. int control_node_num = net_out_node->GetInControlNodes().size();
  339. EXPECT_EQ(control_node_num, 0);
  340. }
  341. TEST_F(UTEST_graph_passes_net_output_pass, netoutput_node_and_target_node_success) {
  342. ge::ComputeGraphPtr compute_graph = build_graph();
  343. ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  344. netout->AddInputDesc(ge::GeTensorDesc());
  345. netout->AddInputDesc(ge::GeTensorDesc());
  346. netout->AddOutputDesc(ge::GeTensorDesc());
  347. netout->AddOutputDesc(ge::GeTensorDesc());
  348. ge::NodePtr netout_node = compute_graph->AddNode(netout);
  349. EXPECT_NE(netout_node, nullptr);
  350. for (NodePtr node : compute_graph->GetDirectNode()) {
  351. if (node->GetName() == "Mul1") {
  352. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(0));
  353. } else if (node->GetName() == "Mul2") {
  354. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(1));
  355. }
  356. }
  357. // construct targets
  358. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  359. std::vector<ge::NodePtr> target_nodes = {mul2};
  360. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  361. ge::PassManager pass_managers;
  362. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  363. Status status = pass_managers.Run(compute_graph);
  364. EXPECT_EQ(status, ge::SUCCESS);
  365. // check contain netoutput
  366. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  367. EXPECT_NE(net_out_node, nullptr);
  368. /// check input data node of netoutput
  369. /// Check data input
  370. int input_data_node_num = net_out_node->GetInDataNodes().size();
  371. EXPECT_EQ(input_data_node_num, 1);
  372. std::vector<string> expect_input_data_result{"Mul1"};
  373. for (auto node : net_out_node->GetInDataNodes()) {
  374. auto name = node->GetName();
  375. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  376. if (iter != expect_input_data_result.end()) {
  377. expect_input_data_result.erase(iter);
  378. }
  379. }
  380. input_data_node_num = expect_input_data_result.size();
  381. EXPECT_EQ(input_data_node_num, 0);
  382. // Check control input
  383. int control_node_num = net_out_node->GetInControlNodes().size();
  384. EXPECT_EQ(control_node_num, 1);
  385. std::vector<string> expect_control_data_result{"Mul2"};
  386. for (auto node : net_out_node->GetInControlNodes()) {
  387. auto name = node->GetName();
  388. auto iter = std::find(expect_control_data_result.begin(), expect_control_data_result.end(), name);
  389. if (iter != expect_control_data_result.end()) {
  390. expect_control_data_result.erase(iter);
  391. }
  392. }
  393. control_node_num = expect_control_data_result.size();
  394. EXPECT_EQ(control_node_num, 0);
  395. }
  396. /// graph have netoutput node.User set outputnodes and target nodes at the same time.output nodes
  397. /// include one common node with target nodes.
  398. /// Notice: output nodes set is more prio
  399. TEST_F(UTEST_graph_passes_net_output_pass, netoutput_node_and_output_nodes_and_target_node_success_1) {
  400. ge::ComputeGraphPtr compute_graph = build_graph();
  401. ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  402. netout->AddInputDesc(ge::GeTensorDesc());
  403. netout->AddInputDesc(ge::GeTensorDesc());
  404. netout->AddOutputDesc(ge::GeTensorDesc());
  405. netout->AddOutputDesc(ge::GeTensorDesc());
  406. ge::NodePtr netout_node = compute_graph->AddNode(netout);
  407. EXPECT_NE(netout_node, nullptr);
  408. for (NodePtr node : compute_graph->GetDirectNode()) {
  409. if (node->GetName() == "Mul1") {
  410. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(0));
  411. } else if (node->GetName() == "Mul2") {
  412. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(1));
  413. }
  414. }
  415. // construct targets
  416. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  417. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  418. std::vector<ge::NodePtr> target_nodes = {mul2};
  419. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  420. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
  421. compute_graph->SetGraphOutNodesInfo(output_nodes);
  422. ge::PassManager pass_managers;
  423. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  424. Status status = pass_managers.Run(compute_graph);
  425. EXPECT_EQ(status, ge::SUCCESS);
  426. // check contain netoutput
  427. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  428. EXPECT_NE(net_out_node, nullptr);
  429. /// check input data node of netoutput
  430. /// Check data input
  431. int input_data_node_num = net_out_node->GetInDataNodes().size();
  432. EXPECT_EQ(input_data_node_num, 2);
  433. std::vector<string> expect_input_data_result{"Mul1", "Mul2"};
  434. for (auto node : net_out_node->GetInDataNodes()) {
  435. auto name = node->GetName();
  436. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  437. if (iter != expect_input_data_result.end()) {
  438. expect_input_data_result.erase(iter);
  439. }
  440. }
  441. input_data_node_num = expect_input_data_result.size();
  442. EXPECT_EQ(input_data_node_num, 0);
  443. // Check control input
  444. int control_node_num = net_out_node->GetInControlNodes().size();
  445. EXPECT_EQ(control_node_num, 0);
  446. }
  447. /// graph have netoutput node.User set outputnodes and target nodes at the same time.output nodes
  448. /// include one common node with target nodes.
  449. /// Notice: output nodes set is more prio
  450. TEST_F(UTEST_graph_passes_net_output_pass, netoutput_node_and_output_nodes_and_target_node_success_2) {
  451. ge::ComputeGraphPtr compute_graph = build_graph(true);
  452. ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  453. netout->AddInputDesc(ge::GeTensorDesc());
  454. netout->AddOutputDesc(ge::GeTensorDesc());
  455. ge::NodePtr netout_node = compute_graph->AddNode(netout);
  456. EXPECT_NE(netout_node, nullptr);
  457. for (NodePtr node : compute_graph->GetDirectNode()) {
  458. if (node->GetName() == "Mul1") {
  459. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(0));
  460. }
  461. if (node->GetName() == "Mul2") {
  462. GraphUtils::AddEdge(node->GetOutControlAnchor(), netout_node->GetInControlAnchor());
  463. }
  464. if (node->GetName() == "Relu3") {
  465. GraphUtils::AddEdge(node->GetOutControlAnchor(), netout_node->GetInControlAnchor());
  466. }
  467. }
  468. // construct targets
  469. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  470. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  471. std::vector<ge::NodePtr> target_nodes = {mul2};
  472. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  473. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}};
  474. compute_graph->SetGraphOutNodesInfo(output_nodes);
  475. ge::PassManager pass_managers;
  476. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  477. Status status = pass_managers.Run(compute_graph);
  478. EXPECT_EQ(status, ge::SUCCESS);
  479. // check contain netoutput
  480. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  481. EXPECT_NE(net_out_node, nullptr);
  482. /// check input data node of netoutput
  483. /// Check data input
  484. int input_data_node_num = net_out_node->GetInDataNodes().size();
  485. EXPECT_EQ(input_data_node_num, 1);
  486. std::vector<string> expect_input_data_result{"Mul1"};
  487. for (auto node : net_out_node->GetInDataNodes()) {
  488. auto name = node->GetName();
  489. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  490. if (iter != expect_input_data_result.end()) {
  491. expect_input_data_result.erase(iter);
  492. }
  493. }
  494. input_data_node_num = expect_input_data_result.size();
  495. EXPECT_EQ(input_data_node_num, 0);
  496. // Check control input
  497. int control_node_num = net_out_node->GetInControlNodes().size();
  498. EXPECT_EQ(control_node_num, 2);
  499. std::vector<string> expect_control_data_result{"Mul2", "Relu3"};
  500. for (auto node : net_out_node->GetInControlNodes()) {
  501. auto name = node->GetName();
  502. auto iter = std::find(expect_control_data_result.begin(), expect_control_data_result.end(), name);
  503. if (iter != expect_control_data_result.end()) {
  504. expect_control_data_result.erase(iter);
  505. }
  506. }
  507. control_node_num = expect_control_data_result.size();
  508. EXPECT_EQ(control_node_num, 0);
  509. }
  510. /// graph have netoutput node.User set outputnodes and target nodes at the same time.output nodes
  511. /// include one common node with target nodes.
  512. /// Notice: output nodes set is more prio
  513. TEST_F(UTEST_graph_passes_net_output_pass, netoutput_node_and_output_nodes_and_target_node_success_3) {
  514. ge::ComputeGraphPtr compute_graph = build_graph();
  515. ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  516. netout->AddInputDesc(ge::GeTensorDesc());
  517. netout->AddOutputDesc(ge::GeTensorDesc());
  518. ge::NodePtr netout_node = compute_graph->AddNode(netout);
  519. EXPECT_NE(netout_node, nullptr);
  520. for (NodePtr node : compute_graph->GetDirectNode()) {
  521. if (node->GetName() == "Mul1") {
  522. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(0));
  523. }
  524. if (node->GetName() == "Mul2") {
  525. GraphUtils::AddEdge(node->GetOutControlAnchor(), netout_node->GetInControlAnchor());
  526. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInControlAnchor());
  527. }
  528. }
  529. // construct targets
  530. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  531. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  532. std::vector<ge::NodePtr> target_nodes = {mul2};
  533. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  534. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}};
  535. compute_graph->SetGraphOutNodesInfo(output_nodes);
  536. ge::PassManager pass_managers;
  537. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  538. Status status = pass_managers.Run(compute_graph);
  539. EXPECT_EQ(status, ge::SUCCESS);
  540. // check contain netoutput
  541. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  542. EXPECT_NE(net_out_node, nullptr);
  543. /// check input data node of netoutput
  544. /// Check data input
  545. int input_data_node_num = net_out_node->GetInDataNodes().size();
  546. EXPECT_EQ(input_data_node_num, 1);
  547. std::vector<string> expect_input_data_result{"Mul1"};
  548. for (auto node : net_out_node->GetInDataNodes()) {
  549. auto name = node->GetName();
  550. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  551. if (iter != expect_input_data_result.end()) {
  552. expect_input_data_result.erase(iter);
  553. }
  554. }
  555. input_data_node_num = expect_input_data_result.size();
  556. EXPECT_EQ(input_data_node_num, 0);
  557. // Check control input
  558. int control_node_num = net_out_node->GetInControlNodes().size();
  559. EXPECT_EQ(control_node_num, 1);
  560. std::vector<string> expect_control_data_result{"Mul2"};
  561. for (auto node : net_out_node->GetInControlNodes()) {
  562. auto name = node->GetName();
  563. auto iter = std::find(expect_control_data_result.begin(), expect_control_data_result.end(), name);
  564. if (iter != expect_control_data_result.end()) {
  565. expect_control_data_result.erase(iter);
  566. }
  567. }
  568. control_node_num = expect_control_data_result.size();
  569. EXPECT_EQ(control_node_num, 0);
  570. }
  571. TEST_F(UTEST_graph_passes_net_output_pass, no_output_no_target_no_retval_success) {
  572. ge::ComputeGraphPtr compute_graph = build_graph();
  573. // Construct specified output
  574. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  575. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  576. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
  577. compute_graph->SetGraphOutNodesInfo(output_nodes);
  578. ge::PassManager pass_managers;
  579. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  580. Status status = pass_managers.Run(compute_graph);
  581. EXPECT_EQ(status, ge::SUCCESS);
  582. }
  583. TEST_F(UTEST_graph_passes_net_output_pass, user_out_node_success) {
  584. ge::ComputeGraphPtr compute_graph = build_graph();
  585. // Construct specified output
  586. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  587. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  588. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
  589. compute_graph->SetGraphOutNodesInfo(output_nodes);
  590. ge::PassManager pass_managers;
  591. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  592. Status status = pass_managers.Run(compute_graph);
  593. EXPECT_EQ(status, ge::SUCCESS);
  594. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  595. EXPECT_NE(net_out_node, nullptr);
  596. // Check data input
  597. string str;
  598. for (ge::NodePtr input_data_node : net_out_node->GetInDataNodes()) {
  599. str += input_data_node->GetName() + ";";
  600. }
  601. EXPECT_EQ(str, "Mul1;Mul2;");
  602. // Check control input
  603. int control_node_num = net_out_node->GetInControlNodes().size();
  604. EXPECT_EQ(control_node_num, 0);
  605. }
  606. TEST_F(UTEST_graph_passes_net_output_pass, retval_node_for_out_success) {
  607. ge::ComputeGraphPtr compute_graph = build_graph();
  608. // Imitate the output node of _Retval issued
  609. ge::OpDescPtr retval_node_desc1 = std::make_shared<ge::OpDesc>("reval_node1", FRAMEWORKOP);
  610. retval_node_desc1->AddInputDesc(ge::GeTensorDesc());
  611. (void)ge::AttrUtils::SetStr(retval_node_desc1, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  612. (void)ge::AttrUtils::SetInt(retval_node_desc1, RETVAL_ATTR_NAME_INDEX, 0);
  613. ge::NodePtr retval_node1 = compute_graph->AddNode(retval_node_desc1);
  614. EXPECT_NE(retval_node1, nullptr);
  615. ge::OpDescPtr retval_node_desc2 = std::make_shared<ge::OpDesc>("reval_node2", FRAMEWORKOP);
  616. retval_node_desc2->AddInputDesc(ge::GeTensorDesc());
  617. (void)ge::AttrUtils::SetStr(retval_node_desc2, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  618. (void)ge::AttrUtils::SetInt(retval_node_desc2, RETVAL_ATTR_NAME_INDEX, 1);
  619. ge::NodePtr retval_node2 = compute_graph->AddNode(retval_node_desc2);
  620. EXPECT_NE(retval_node2, nullptr);
  621. for (NodePtr node : compute_graph->GetDirectNode()) {
  622. if (node->GetName() == "Mul1") {
  623. GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node1->GetInDataAnchor(0));
  624. } else if (node->GetName() == "Mul2") {
  625. GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0));
  626. }
  627. }
  628. ge::PassManager pass_managers;
  629. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  630. Status status = pass_managers.Run(compute_graph);
  631. EXPECT_EQ(status, ge::SUCCESS);
  632. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  633. EXPECT_NE(net_out_node, nullptr);
  634. // Check data input
  635. string str;
  636. for (ge::NodePtr input_data_node : net_out_node->GetInDataNodes()) {
  637. str += input_data_node->GetName() + ";";
  638. }
  639. EXPECT_EQ(str, "Mul1;Mul2;");
  640. // Check control input
  641. int control_node_num = net_out_node->GetInControlNodes().size();
  642. EXPECT_EQ(control_node_num, 0);
  643. // Check the deletion of _Retval node
  644. retval_node1 = compute_graph->FindNode("reval_node1");
  645. EXPECT_EQ(retval_node1, nullptr);
  646. retval_node2 = compute_graph->FindNode("reval_node2");
  647. EXPECT_EQ(retval_node2, nullptr);
  648. }
  649. TEST_F(UTEST_graph_passes_net_output_pass, check_order_and_const_flag_success) {
  650. ge::ComputeGraphPtr compute_graph = build_graph();
  651. ge::OpDescPtr const_node_desc = std::make_shared<ge::OpDesc>("const_output", CONSTANT);
  652. const_node_desc->AddOutputDesc(ge::GeTensorDesc());
  653. ge::NodePtr const_node = compute_graph->AddNode(const_node_desc);
  654. EXPECT_NE(const_node, nullptr);
  655. NodePtr mul1 = compute_graph->FindNode("Mul1");
  656. EXPECT_NE(mul1, nullptr);
  657. GraphUtils::AddEdge(mul1->GetOutControlAnchor(), const_node->GetInControlAnchor());
  658. // Construct specified output
  659. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{const_node, 0}};
  660. compute_graph->SetGraphOutNodesInfo(output_nodes);
  661. ge::OpDescPtr retval_node_desc2 = std::make_shared<ge::OpDesc>("reval_node2", FRAMEWORKOP);
  662. retval_node_desc2->AddInputDesc(ge::GeTensorDesc());
  663. (void)ge::AttrUtils::SetStr(retval_node_desc2, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  664. (void)ge::AttrUtils::SetInt(retval_node_desc2, RETVAL_ATTR_NAME_INDEX, 0);
  665. ge::NodePtr retval_node2 = compute_graph->AddNode(retval_node_desc2);
  666. EXPECT_NE(retval_node2, nullptr);
  667. NodePtr mul2 = compute_graph->FindNode("Mul2");
  668. EXPECT_NE(mul2, nullptr);
  669. GraphUtils::AddEdge(mul2->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0));
  670. ge::PassManager pass_managers;
  671. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  672. Status status = pass_managers.Run(compute_graph);
  673. EXPECT_EQ(status, ge::SUCCESS);
  674. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  675. EXPECT_NE(net_out_node, nullptr);
  676. // Check data input
  677. string str;
  678. for (ge::NodePtr input_data_node : net_out_node->GetInDataNodes()) {
  679. str += input_data_node->GetName() + ";";
  680. }
  681. EXPECT_EQ(str, "const_output;Mul2;");
  682. // Check control input
  683. int control_node_num = net_out_node->GetInControlNodes().size();
  684. EXPECT_EQ(control_node_num, 0);
  685. // Check is_input_const flag
  686. std::vector<bool> is_input_const = net_out_node->GetOpDesc()->GetIsInputConst();
  687. EXPECT_EQ(is_input_const.size(), 2);
  688. EXPECT_EQ(is_input_const[0], true);
  689. EXPECT_EQ(is_input_const[1], false);
  690. // Check the deletion of _Retval node
  691. retval_node2 = compute_graph->FindNode("reval_node2");
  692. EXPECT_EQ(retval_node2, nullptr);
  693. }
  694. TEST_F(UTEST_graph_passes_net_output_pass, out_node_check_fail) {
  695. ge::ComputeGraphPtr compute_graph = build_graph();
  696. // Construct specified output
  697. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  698. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  699. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_invalid_name = {{nullptr, 0}, {mul2, 0}};
  700. compute_graph->SetGraphOutNodesInfo(output_nodes_invalid_name);
  701. ge::PassManager pass_managers;
  702. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  703. Status status = pass_managers.Run(compute_graph);
  704. EXPECT_EQ(status, ge::INTERNAL_ERROR);
  705. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  706. EXPECT_EQ(net_out_node, nullptr);
  707. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_invalid_index = {{mul1, 0}, {mul2, 100}};
  708. compute_graph->SetGraphOutNodesInfo(output_nodes_invalid_index);
  709. status = pass_managers.Run(compute_graph);
  710. EXPECT_EQ(status, ge::INTERNAL_ERROR);
  711. net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  712. EXPECT_EQ(net_out_node, nullptr);
  713. }
  714. TEST_F(UTEST_graph_passes_net_output_pass, retval_node_check_fail) {
  715. ge::ComputeGraphPtr compute_graph = build_graph();
  716. // Imitate the output node of _Retval issued
  717. ge::OpDescPtr retval_node_desc1 = std::make_shared<ge::OpDesc>("reval_node1", FRAMEWORKOP);
  718. retval_node_desc1->AddInputDesc(ge::GeTensorDesc());
  719. (void)ge::AttrUtils::SetStr(retval_node_desc1, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  720. (void)ge::AttrUtils::SetInt(retval_node_desc1, RETVAL_ATTR_NAME_INDEX, 0);
  721. ge::NodePtr retval_node1 = compute_graph->AddNode(retval_node_desc1);
  722. EXPECT_NE(retval_node1, nullptr);
  723. ge::OpDescPtr retval_node_desc2 = std::make_shared<ge::OpDesc>("reval_node2", FRAMEWORKOP);
  724. retval_node_desc2->AddInputDesc(ge::GeTensorDesc());
  725. (void)ge::AttrUtils::SetStr(retval_node_desc2, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  726. (void)ge::AttrUtils::SetInt(retval_node_desc2, RETVAL_ATTR_NAME_INDEX, 0);
  727. ge::NodePtr retval_node2 = compute_graph->AddNode(retval_node_desc2);
  728. EXPECT_NE(retval_node2, nullptr);
  729. for (NodePtr node : compute_graph->GetDirectNode()) {
  730. if (node->GetName() == "Mul1") {
  731. GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node1->GetInDataAnchor(0));
  732. } else if (node->GetName() == "Mul2") {
  733. GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0));
  734. }
  735. }
  736. ge::PassManager pass_managers;
  737. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  738. Status status = pass_managers.Run(compute_graph);
  739. EXPECT_EQ(status, ge::INTERNAL_ERROR);
  740. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  741. EXPECT_EQ(net_out_node, nullptr);
  742. }
  743. TEST_F(UTEST_graph_passes_net_output_pass, out_node_update_desc_check_fail) {
  744. ge::ComputeGraphPtr compute_graph = build_graph();
  745. ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  746. ge::NodePtr netout_node = compute_graph->AddNode(netout);
  747. EXPECT_NE(netout_node, nullptr);
  748. ge::PassManager pass_managers;
  749. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  750. Status status = pass_managers.Run(compute_graph);
  751. EXPECT_EQ(status, ge::INTERNAL_ERROR);
  752. }
  753. TEST_F(UTEST_graph_passes_net_output_pass, out_node_remove_check_fail) {
  754. ge::ComputeGraphPtr compute_graph = build_graph();
  755. // Construct specified output
  756. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  757. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  758. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
  759. compute_graph->SetGraphOutNodesInfo(output_nodes);
  760. // compute_graph->RemoveNode(mul1);
  761. mul1->GetInDataAnchor(0)->UnlinkAll();
  762. mul1->GetInDataAnchor(1)->UnlinkAll();
  763. GraphUtils::RemoveNodeWithoutRelink(compute_graph, mul1);
  764. mul1 = compute_graph->FindNode("Mul1");
  765. EXPECT_EQ(mul1, nullptr);
  766. ge::PassManager pass_managers;
  767. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  768. Status status = pass_managers.Run(compute_graph);
  769. EXPECT_EQ(status, ge::SUCCESS);
  770. }
  771. TEST_F(UTEST_graph_passes_net_output_pass, clear_weight) {
  772. ge::ComputeGraphPtr compute_graph = BuildClearWeightGraph();
  773. auto cast = compute_graph->FindNode("Cast1");
  774. Status ret = ge::OpDescUtils::ClearWeights(cast);
  775. EXPECT_EQ(ge::SUCCESS, ret);
  776. }

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了专门的优化，以此充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：

Contributors (1)