diff --git a/tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc b/tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc index aa69f6a3..8c3469c8 100644 --- a/tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc @@ -85,13 +85,13 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { EXPECT_EQ(input_data_node_num, 3); ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); - auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data"); + auto data0 = MakeNode(sub_graph, 1, 1, "data0", "Data"); data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); - auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data"); + auto data1 = MakeNode(sub_graph, 1, 1, "data1", "Data"); data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); - auto data2 = MakeNode(parent_graph, 1, 1, "data2", "Data"); + auto data2 = MakeNode(sub_graph, 1, 1, "data2", "Data"); data2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); data2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); (void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); @@ -100,19 +100,28 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { sub_graph->SetParentNode(parent_case); sub_graph->SetParentGraph(parent_graph); - EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); - // after pass, data1 and data2 are fused to data0 + parent_graph->AddSubgraph(sub_graph->GetName(), sub_graph); + size_t sub_graph_num = parent_graph->GetAllSubgraphs().size(); + EXPECT_EQ(sub_graph_num, 1); + auto data1_node = sub_graph->FindNode("data1"); - EXPECT_EQ(data1_node, nullptr); + EXPECT_NE(data1_node, nullptr); auto data2_node = sub_graph->FindNode("data2"); - EXPECT_EQ(data2_node, nullptr); + EXPECT_NE(data2_node, nullptr); + EXPECT_EQ(pass_manager.Run(parent_graph), SUCCESS); + + // after pass, data1 and data2 are fused to data0 + data1_node = sub_graph->FindNode("data1"); + EXPECT_EQ(data1_node, nullptr); + data2_node = sub_graph->FindNode("data2"); + EXPECT_EQ(data2_node, nullptr); } /// graph with subgraph /// const /// / \ -/// cast1 cast2 +/// cast1 cast1 /// \ / /// case /// | @@ -127,7 +136,6 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) { ComputeGraphPtr parent_graph = std::make_shared("parent_graph"); auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const"); auto parent_cast1 = MakeNode(parent_graph, 1, 1, "parent_cast1", "Cast"); - auto parent_cast2 = MakeNode(parent_graph, 1, 1, "parent_cast2", "Cast"); auto parent_case = MakeNode(parent_graph, 2, 1, "parent_case", "Case"); auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput"); @@ -136,23 +144,21 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) { parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); parent_cast1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); parent_cast1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); - parent_cast2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); - parent_cast2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc); parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc); parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0)); GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0)); - GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast2->GetInDataAnchor(0)); - GraphUtils::AddEdge(parent_cast2->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1)); + GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0)); + GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1)); GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0)); ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); - auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data"); + auto data0 = MakeNode(sub_graph, 1, 1, "data0", "Data"); data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); - auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data"); + auto data1 = MakeNode(sub_graph, 1, 1, "data1", "Data"); data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); (void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); @@ -160,9 +166,17 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) { sub_graph->SetParentNode(parent_case); sub_graph->SetParentGraph(parent_graph); - EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); - // after pass, data1 is fused to data0 + parent_graph->AddSubgraph(sub_graph->GetName(), sub_graph); + + size_t sub_graph_num = parent_graph->GetAllSubgraphs().size(); + EXPECT_EQ(sub_graph_num, 1); auto data1_node = sub_graph->FindNode("data1"); + EXPECT_NE(data1_node, nullptr); + + EXPECT_EQ(pass_manager.Run(parent_graph), SUCCESS); + + // after pass, data1 is fused to data0 + data1_node = sub_graph->FindNode("data1"); EXPECT_EQ(data1_node, nullptr); } } // namespace ge diff --git a/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc b/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc index b1cd6d4d..1b75a613 100644 --- a/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc @@ -194,6 +194,9 @@ TEST_F(UtestMultiBatchClonePass, graph_with_subgraph) { auto func_node = MakeNode(owner, 3, 1, "test_if", "If"); graph->SetParentNode(func_node); graph->SetParentGraph(owner); + owner->AddSubgraph(graph->GetName(), graph); + size_t sub_graph_num = owner->GetAllSubgraphs().size(); + EXPECT_EQ(sub_graph_num, 1); EXPECT_EQ(pass_manager.Run(graph), SUCCESS); }