You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

flow_ctrl_pass_unittest.cc 18 kB

5 years ago
(line-number gutter of the code listing, lines 1–433)
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include "common/ge_inner_error_codes.h"
  18. #include "common/types.h"
  19. #include "graph/manager/graph_var_manager.h"
  20. #include "graph/utils/attr_utils.h"
  21. #include "graph/utils/graph_utils.h"
  22. #include "inc/pass_manager.h"
  23. #define private public
  24. #include "graph/passes/flow_ctrl_pass.h"
  25. #undef private
  26. using namespace domi;
  27. namespace ge {
  28. class UTEST_graph_passes_flow_ctrl_pass : public testing::Test {
  29. protected:
  30. void SetUp() {
  31. uint64_t session_id = 0;
  32. uint32_t device_id = 0;
  33. uint64_t job_id = 0;
  34. uint32_t session_version = 0;
  35. EXPECT_EQ(domi::SUCCESS, ge::VarManager::Instance(0)->Init(session_version, session_id, device_id, job_id));
  36. }
  37. void TearDown() { VarManagerPool::Instance().Destory(); }
  38. public:
  39. /// Set up a graph with the following network structure
  40. /// IteratorGetNext
  41. /// |
  42. /// MemcpyAsync
  43. /// |
  44. /// A
  45. /// |
  46. /// NetOutput
  47. void MakeGraph(ge::ComputeGraphPtr &graph) {
  48. auto desc_ptr = make_shared<ge::GeTensorDesc>();
  49. auto desc = *desc_ptr;
  50. ge::OpDescPtr op_desc_get_next = make_shared<ge::OpDesc>("IteratorGetNext", FRAMEWORKOP);
  51. op_desc_get_next->AddOutputDesc(desc);
  52. ge::OpDescPtr op_desc_memcpy = make_shared<ge::OpDesc>("MemcpyAsync", MEMCPYASYNC);
  53. op_desc_memcpy->AddInputDesc(desc);
  54. op_desc_memcpy->AddOutputDesc(desc);
  55. ge::AttrUtils::SetBool(op_desc_memcpy, ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, true);
  56. ge::OpDescPtr op_desc_a = make_shared<ge::OpDesc>("A", RESOURCEAPPLYMOMENTUM);
  57. op_desc_a->AddInputDesc(desc);
  58. op_desc_a->AddOutputDesc(desc);
  59. ge::OpDescPtr op_desc_gatherv2 = make_shared<ge::OpDesc>("GatherV2", GATHERV2);
  60. op_desc_gatherv2->AddInputDesc(desc);
  61. op_desc_gatherv2->AddOutputDesc(desc);
  62. ge::OpDescPtr op_desc_global_step = make_shared<ge::OpDesc>("global_step", VARIABLE);
  63. op_desc_global_step->AddOutputDesc(desc);
  64. ge::OpDescPtr op_desc_netout = make_shared<ge::OpDesc>("NetOutput", NETOUTPUT);
  65. ge::AttrUtils::SetInt(op_desc_netout, ATTR_NAME_TRUE_BRANCH_STREAM, TRUE_STREAM_ID);
  66. op_desc_netout->AddInputDesc(desc);
  67. op_desc_netout->AddInputDesc(desc);
  68. // add node
  69. ge::NodePtr get_next_node = graph->AddNode(op_desc_get_next);
  70. ge::NodePtr memcpy_node = graph->AddNode(op_desc_memcpy);
  71. ge::NodePtr node_a = graph->AddNode(op_desc_a);
  72. ge::NodePtr global_step = graph->AddNode(op_desc_global_step);
  73. ge::NodePtr gatherv2 = graph->AddNode(op_desc_gatherv2);
  74. ge::NodePtr netoutput = graph->AddNode(op_desc_netout);
  75. // add edge
  76. ge::GraphUtils::AddEdge(get_next_node->GetOutDataAnchor(0), memcpy_node->GetInDataAnchor(0));
  77. ge::GraphUtils::AddEdge(memcpy_node->GetOutDataAnchor(0), node_a->GetInDataAnchor(0));
  78. ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), netoutput->GetInDataAnchor(0));
  79. ge::GraphUtils::AddEdge(gatherv2->GetOutDataAnchor(0), netoutput->GetInDataAnchor(1));
  80. ge::GraphUtils::AddEdge(global_step->GetOutDataAnchor(0), gatherv2->GetInDataAnchor(0));
  81. }
  82. void AddSessionVariables(void) {
  83. static std::set<std::string> varList = {
  84. NODE_NAME_FLOWCTRL_LOOP_PER_ITER,
  85. NODE_NAME_FLOWCTRL_LOOP_COND,
  86. NODE_NAME_FLOWCTRL_LOOP_INCREMENT,
  87. NODE_NAME_FLOWCTRL_LOOP_RESETVALUE,
  88. NODE_NAME_GLOBAL_STEP,
  89. };
  90. uint8_t *dev_ptr = nullptr;
  91. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  92. for (std::string var_name : varList) {
  93. EXPECT_EQ(domi::SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  94. }
  95. }
  96. };
  97. TEST_F(UTEST_graph_passes_flow_ctrl_pass, FlowCtrlPass_Success_Test) {
  98. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("FlowCtrlPassSuccess");
  99. graph->SetNeedIteration(true);
  100. // Create graph
  101. MakeGraph(graph);
  102. graph->TopologicalSorting();
  103. AddSessionVariables();
  104. FlowCtrlPass flow_ctrl_pass;
  105. Status ret = flow_ctrl_pass.Run(graph);
  106. EXPECT_EQ(ret, SUCCESS);
  107. EXPECT_EQ(16, graph->GetDirectNodesSize());
  108. int stream_switch_cnt = 0;
  109. int stream_activeCnt = 0;
  110. for (ge::NodePtr node : graph->GetDirectNode()) {
  111. if (node->GetOpDesc()->GetType() == STREAMSWITCH) {
  112. stream_switch_cnt++;
  113. } else if (node->GetOpDesc()->GetType() == STREAMACTIVE) {
  114. stream_activeCnt++;
  115. }
  116. }
  117. EXPECT_EQ(stream_switch_cnt, 2);
  118. EXPECT_EQ(stream_activeCnt, 2);
  119. }
  120. TEST_F(UTEST_graph_passes_flow_ctrl_pass, FlowCtrlPass_Success_VAR_NODE_ADD_BEFORE) {
  121. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("FlowCtrlPassSuccess");
  122. graph->SetNeedIteration(true);
  123. // Create graph
  124. MakeGraph(graph);
  125. graph->TopologicalSorting();
  126. AddSessionVariables();
  127. FlowCtrlPass flow_ctrl_pass;
  128. NodePtr loop_cond_node = flow_ctrl_pass.AddVariableNode(graph, NODE_NAME_FLOWCTRL_LOOP_COND);
  129. EXPECT_NE(loop_cond_node, nullptr);
  130. NodePtr loop_increment_node = flow_ctrl_pass.AddVariableNode(graph, NODE_NAME_FLOWCTRL_LOOP_INCREMENT);
  131. EXPECT_NE(loop_increment_node, nullptr);
  132. NodePtr loop_reset_node = flow_ctrl_pass.AddVariableNode(graph, NODE_NAME_FLOWCTRL_LOOP_RESETVALUE);
  133. EXPECT_NE(loop_reset_node, nullptr);
  134. NodePtr iter_per_loop_node = flow_ctrl_pass.AddVariableNode(graph, NODE_NAME_FLOWCTRL_LOOP_PER_ITER);
  135. EXPECT_NE(iter_per_loop_node, nullptr);
  136. Status ret = flow_ctrl_pass.Run(graph);
  137. EXPECT_EQ(ret, ge::SUCCESS);
  138. }
  139. TEST_F(UTEST_graph_passes_flow_ctrl_pass, FlowCtrlPass_NOT_TRAIN) {
  140. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("TestNotChange");
  141. graph->SetNeedIteration(false);
  142. FlowCtrlPass flow_ctrl_pass;
  143. Status ret = flow_ctrl_pass.Run(graph);
  144. EXPECT_EQ(ret, NOT_CHANGED);
  145. }
  146. TEST_F(UTEST_graph_passes_flow_ctrl_pass, AddFpBpIteratorCtrl_WITHOUT_VAR) {
  147. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("TestNotChange");
  148. graph->SetNeedIteration(true);
  149. // Create graph
  150. MakeGraph(graph);
  151. graph->TopologicalSorting();
  152. // must have NODE_NAME_FLOWCTRL_LOOP_PER_ITER
  153. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  154. uint8_t *dev_ptr = nullptr;
  155. EXPECT_EQ(SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(NODE_NAME_FLOWCTRL_LOOP_PER_ITER, tensor_desc,
  156. dev_ptr, RT_MEMORY_HBM));
  157. // not add var
  158. FlowCtrlPass flow_ctrl_pass;
  159. Status ret = flow_ctrl_pass.Run(graph);
  160. EXPECT_NE(ret, ge::SUCCESS);
  161. }
  162. TEST_F(UTEST_graph_passes_flow_ctrl_pass, Run_AddSpecialNodeIteratorCtrl_NO_INANCHOR) {
  163. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_PER_ITER");
  164. graph->SetNeedIteration(true);
  165. // Create graph
  166. MakeGraph(graph);
  167. graph->TopologicalSorting();
  168. AddSessionVariables();
  169. FlowCtrlPass flow_ctrl_pass;
  170. NodePtr getnext_node = graph->FindNode("IteratorGetNext");
  171. NodePtr memcpy_node = graph->FindNode("MemcpyAsync");
  172. GraphUtils::RemoveEdge(getnext_node->GetOutDataAnchor(0), memcpy_node->GetInDataAnchor(0));
  173. Status ret = flow_ctrl_pass.Run(graph);
  174. EXPECT_NE(ret, ge::SUCCESS);
  175. }
  176. TEST_F(UTEST_graph_passes_flow_ctrl_pass, AddFpBpIteratorCtrl_WITHOUT_LOOP_COND) {
  177. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_COND");
  178. graph->SetNeedIteration(true);
  179. // Create graph
  180. MakeGraph(graph);
  181. graph->TopologicalSorting();
  182. std::set<std::string> varList = {
  183. NODE_NAME_FLOWCTRL_LOOP_PER_ITER,
  184. NODE_NAME_FLOWCTRL_LOOP_INCREMENT,
  185. NODE_NAME_FLOWCTRL_LOOP_RESETVALUE,
  186. NODE_NAME_GLOBAL_STEP,
  187. };
  188. // must have NODE_NAME_FLOWCTRL_LOOP_PER_ITER
  189. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  190. uint8_t *dev_ptr = nullptr;
  191. for (std::string var_name : varList) {
  192. EXPECT_EQ(domi::SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  193. }
  194. // not add var
  195. FlowCtrlPass flow_ctrl_pass;
  196. NodePtr pre_node = graph->FindNode("NetOutput");
  197. Status ret = flow_ctrl_pass.AddFpBpIteratorCtrl(graph, pre_node);
  198. EXPECT_EQ(ret, FAILED);
  199. }
  200. TEST_F(UTEST_graph_passes_flow_ctrl_pass, AddFpBpIteratorCtrl_WITHOUT_LOOP_INCREMENT) {
  201. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_INCREMENT");
  202. graph->SetNeedIteration(true);
  203. // Create graph
  204. MakeGraph(graph);
  205. graph->TopologicalSorting();
  206. std::set<std::string> varList = {
  207. NODE_NAME_FLOWCTRL_LOOP_PER_ITER,
  208. NODE_NAME_FLOWCTRL_LOOP_COND,
  209. NODE_NAME_FLOWCTRL_LOOP_RESETVALUE,
  210. NODE_NAME_GLOBAL_STEP,
  211. };
  212. // must have NODE_NAME_FLOWCTRL_LOOP_PER_ITER
  213. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  214. uint8_t *dev_ptr = nullptr;
  215. for (std::string var_name : varList) {
  216. EXPECT_EQ(domi::SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  217. }
  218. // not add var
  219. FlowCtrlPass flow_ctrl_pass;
  220. NodePtr pre_node = graph->FindNode("NetOutput");
  221. Status ret = flow_ctrl_pass.AddFpBpIteratorCtrl(graph, pre_node);
  222. EXPECT_EQ(ret, FAILED);
  223. }
  224. TEST_F(UTEST_graph_passes_flow_ctrl_pass, AddFpBpIteratorCtrl_WITHOUT_LOOP_RESETVALUE) {
  225. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_RESETVALUE");
  226. graph->SetNeedIteration(true);
  227. // Create graph
  228. MakeGraph(graph);
  229. graph->TopologicalSorting();
  230. std::set<std::string> varList = {
  231. NODE_NAME_FLOWCTRL_LOOP_PER_ITER,
  232. NODE_NAME_FLOWCTRL_LOOP_COND,
  233. NODE_NAME_FLOWCTRL_LOOP_INCREMENT,
  234. NODE_NAME_GLOBAL_STEP,
  235. };
  236. // must have NODE_NAME_FLOWCTRL_LOOP_PER_ITER
  237. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  238. uint8_t *dev_ptr = nullptr;
  239. for (std::string var_name : varList) {
  240. EXPECT_EQ(domi::SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  241. }
  242. // not add var
  243. FlowCtrlPass flow_ctrl_pass;
  244. NodePtr pre_node = graph->FindNode("NetOutput");
  245. Status ret = flow_ctrl_pass.AddFpBpIteratorCtrl(graph, pre_node);
  246. EXPECT_EQ(ret, FAILED);
  247. }
  248. TEST_F(UTEST_graph_passes_flow_ctrl_pass, AddFpBpIteratorCtrl_WITHOUT_LOOP_PER_ITER) {
  249. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_PER_ITER");
  250. graph->SetNeedIteration(true);
  251. // Create graph
  252. MakeGraph(graph);
  253. graph->TopologicalSorting();
  254. std::set<std::string> varList = {
  255. NODE_NAME_FLOWCTRL_LOOP_COND,
  256. NODE_NAME_FLOWCTRL_LOOP_INCREMENT,
  257. NODE_NAME_FLOWCTRL_LOOP_RESETVALUE,
  258. NODE_NAME_GLOBAL_STEP,
  259. };
  260. // must have NODE_NAME_FLOWCTRL_LOOP_PER_ITER
  261. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  262. uint8_t *dev_ptr = nullptr;
  263. for (std::string var_name : varList) {
  264. EXPECT_EQ(domi::SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  265. }
  266. FlowCtrlPass flow_ctrl_pass;
  267. NodePtr pre_node = graph->FindNode("NetOutput");
  268. Status ret = flow_ctrl_pass.AddFpBpIteratorCtrl(graph, pre_node);
  269. EXPECT_EQ(ret, FAILED);
  270. }
  271. TEST_F(UTEST_graph_passes_flow_ctrl_pass, AddSpecialNodeIteratorCtrl_WITHOUT_LOOP_COND) {
  272. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_COND");
  273. graph->SetNeedIteration(true);
  274. // Create graph
  275. MakeGraph(graph);
  276. graph->TopologicalSorting();
  277. std::set<std::string> varList = {
  278. NODE_NAME_FLOWCTRL_LOOP_PER_ITER,
  279. NODE_NAME_FLOWCTRL_LOOP_INCREMENT,
  280. NODE_NAME_FLOWCTRL_LOOP_RESETVALUE,
  281. NODE_NAME_GLOBAL_STEP,
  282. };
  283. // must have NODE_NAME_FLOWCTRL_LOOP_PER_ITER
  284. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  285. uint8_t *dev_ptr = nullptr;
  286. for (std::string var_name : varList) {
  287. EXPECT_EQ(domi::SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  288. }
  289. FlowCtrlPass flow_ctrl_pass;
  290. NodePtr iter_per_loop_node = flow_ctrl_pass.AddVariableNode(graph, NODE_NAME_FLOWCTRL_LOOP_PER_ITER);
  291. EXPECT_NE(iter_per_loop_node, nullptr);
  292. NodePtr memcpy_node = graph->FindNode("MemcpyAsync");
  293. Status ret = flow_ctrl_pass.AddSpecialNodeIteratorCtrl(graph, memcpy_node);
  294. EXPECT_EQ(ret, FAILED);
  295. }
  296. TEST_F(UTEST_graph_passes_flow_ctrl_pass, AddSpecialNodeIteratorCtrl_WITHOUT_LOOP_PER_ITER) {
  297. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_PER_ITER");
  298. graph->SetNeedIteration(true);
  299. // Create graph
  300. MakeGraph(graph);
  301. graph->TopologicalSorting();
  302. std::set<std::string> varList = {
  303. NODE_NAME_FLOWCTRL_LOOP_COND,
  304. NODE_NAME_FLOWCTRL_LOOP_INCREMENT,
  305. NODE_NAME_FLOWCTRL_LOOP_RESETVALUE,
  306. NODE_NAME_GLOBAL_STEP,
  307. };
  308. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  309. uint8_t *dev_ptr = nullptr;
  310. for (std::string var_name : varList) {
  311. EXPECT_EQ(domi::SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  312. }
  313. FlowCtrlPass flow_ctrl_pass;
  314. NodePtr loop_cond_node = flow_ctrl_pass.AddVariableNode(graph, NODE_NAME_FLOWCTRL_LOOP_COND);
  315. EXPECT_NE(loop_cond_node, nullptr);
  316. NodePtr memcpy_node = graph->FindNode("MemcpyAsync");
  317. Status ret = flow_ctrl_pass.AddSpecialNodeIteratorCtrl(graph, memcpy_node);
  318. EXPECT_EQ(ret, FAILED);
  319. }
  320. TEST_F(UTEST_graph_passes_flow_ctrl_pass, AddSpecialNodeIteratorCtrl_NO_INANCHOR) {
  321. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_PER_ITER");
  322. graph->SetNeedIteration(true);
  323. // Create graph
  324. MakeGraph(graph);
  325. graph->TopologicalSorting();
  326. FlowCtrlPass flow_ctrl_pass;
  327. NodePtr getnext_node = graph->FindNode("IteratorGetNext");
  328. NodePtr memcpy_node = graph->FindNode("MemcpyAsync");
  329. GraphUtils::RemoveEdge(getnext_node->GetOutDataAnchor(0), memcpy_node->GetInDataAnchor(0));
  330. Status ret = flow_ctrl_pass.AddSpecialNodeIteratorCtrl(graph, memcpy_node);
  331. EXPECT_EQ(ret, FAILED);
  332. }
  333. TEST_F(UTEST_graph_passes_flow_ctrl_pass, InsertAssignOp_SUCCESS) {
  334. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_InsertAssignOp");
  335. FlowCtrlPass flow_ctrl_pass;
  336. GeTensorDesc tmp_geT_tensor_desc;
  337. NodePtr ref_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {tmp_geT_tensor_desc});
  338. NodePtr value_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {tmp_geT_tensor_desc});
  339. NodePtr add_node = flow_ctrl_pass.InsertAssignOp(graph, ASSIGNADD, "add_node", ref_node, value_node);
  340. EXPECT_NE(add_node, nullptr);
  341. }
  342. TEST_F(UTEST_graph_passes_flow_ctrl_pass, InsertAssignOp_REF_NODE_NO_OUTANCHOR) {
  343. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_InsertAssignOp");
  344. FlowCtrlPass flow_ctrl_pass;
  345. GeTensorDesc tmp_geT_tensor_desc;
  346. NodePtr ref_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {});
  347. NodePtr value_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {tmp_geT_tensor_desc});
  348. NodePtr add_node = flow_ctrl_pass.InsertAssignOp(graph, ASSIGNADD, "add_node", ref_node, value_node);
  349. EXPECT_EQ(add_node, nullptr);
  350. }
  351. TEST_F(UTEST_graph_passes_flow_ctrl_pass, InsertAssignOp_VALUE_NODE_NO_OUTANCHOR) {
  352. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_InsertAssignOp");
  353. FlowCtrlPass flow_ctrl_pass;
  354. GeTensorDesc tmp_geT_tensor_desc;
  355. NodePtr ref_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {tmp_geT_tensor_desc});
  356. NodePtr value_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {});
  357. NodePtr add_node = flow_ctrl_pass.InsertAssignOp(graph, ASSIGNADD, "add_node", ref_node, value_node);
  358. EXPECT_EQ(add_node, nullptr);
  359. }
  360. TEST_F(UTEST_graph_passes_flow_ctrl_pass, CreateIterCtrlFalseBranch_InsertAssignOp_FAILED) {
  361. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_CreateIterCtrlFalseBranch_InsertAssignOp_FAILED");
  362. FlowCtrlPass flow_ctrl_pass;
  363. GeTensorDesc tmp_geT_tensor_desc;
  364. NodePtr ref_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {tmp_geT_tensor_desc});
  365. NodePtr value_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {});
  366. NodePtr switch_node = flow_ctrl_pass.InsertOp(graph, STREAMSWITCH, "switch_node", {}, {});
  367. Status ret = flow_ctrl_pass.CreateIterCtrlFalseBranch(graph, ref_node, value_node, switch_node);
  368. EXPECT_EQ(ret, FAILED);
  369. }
  370. TEST_F(UTEST_graph_passes_flow_ctrl_pass, CreateIterCtrlTrueBranch_InsertAssignOp_FAILED) {
  371. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("CreateIterCtrlTrueBranch_InsertAssignOp_FAILED");
  372. FlowCtrlPass flow_ctrl_pass;
  373. GeTensorDesc tmp_geT_tensor_desc;
  374. NodePtr ref_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {tmp_geT_tensor_desc});
  375. NodePtr value_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {});
  376. NodePtr switch_node = flow_ctrl_pass.InsertOp(graph, STREAMSWITCH, "switch_node", {}, {});
  377. Status ret = flow_ctrl_pass.CreateIterCtrlTrueBranch(graph, ref_node, value_node, switch_node);
  378. EXPECT_EQ(ret, FAILED);
  379. }
  380. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示