Browse Source

add check of ut

tags/v1.2.0
zhou_lili 3 years ago
parent
commit
0db227b67f
2 changed files with 34 additions and 17 deletions
  1. +31
    -17
      tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc
  2. +3
    -0
      tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc

+ 31
- 17
tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc View File

@@ -85,13 +85,13 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) {
EXPECT_EQ(input_data_node_num, 3); EXPECT_EQ(input_data_node_num, 3);


ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data");
auto data0 = MakeNode(sub_graph, 1, 1, "data0", "Data");
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data");
auto data1 = MakeNode(sub_graph, 1, 1, "data1", "Data");
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
auto data2 = MakeNode(parent_graph, 1, 1, "data2", "Data");
auto data2 = MakeNode(sub_graph, 1, 1, "data2", "Data");
data2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); data2->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
data2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); data2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); (void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0);
@@ -100,19 +100,28 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) {


sub_graph->SetParentNode(parent_case); sub_graph->SetParentNode(parent_case);
sub_graph->SetParentGraph(parent_graph); sub_graph->SetParentGraph(parent_graph);
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS);
// after pass, data1 and data2 are fused to data0
parent_graph->AddSubgraph(sub_graph->GetName(), sub_graph);
size_t sub_graph_num = parent_graph->GetAllSubgraphs().size();
EXPECT_EQ(sub_graph_num, 1);

auto data1_node = sub_graph->FindNode("data1"); auto data1_node = sub_graph->FindNode("data1");
EXPECT_EQ(data1_node, nullptr);
EXPECT_NE(data1_node, nullptr);
auto data2_node = sub_graph->FindNode("data2"); auto data2_node = sub_graph->FindNode("data2");
EXPECT_EQ(data2_node, nullptr);
EXPECT_NE(data2_node, nullptr);


EXPECT_EQ(pass_manager.Run(parent_graph), SUCCESS);

// after pass, data1 and data2 are fused to data0
data1_node = sub_graph->FindNode("data1");
EXPECT_EQ(data1_node, nullptr);
data2_node = sub_graph->FindNode("data2");
EXPECT_EQ(data2_node, nullptr);
} }


/// graph with subgraph /// graph with subgraph
/// const /// const
/// / \ /// / \
/// cast1 cast2
/// cast1 cast1
/// \ / /// \ /
/// case /// case
/// | /// |
@@ -127,7 +136,6 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) {
ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph"); ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph");
auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const"); auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const");
auto parent_cast1 = MakeNode(parent_graph, 1, 1, "parent_cast1", "Cast"); auto parent_cast1 = MakeNode(parent_graph, 1, 1, "parent_cast1", "Cast");
auto parent_cast2 = MakeNode(parent_graph, 1, 1, "parent_cast2", "Cast");
auto parent_case = MakeNode(parent_graph, 2, 1, "parent_case", "Case"); auto parent_case = MakeNode(parent_graph, 2, 1, "parent_case", "Case");
auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput"); auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput");


@@ -136,23 +144,21 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) {
parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
parent_cast1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); parent_cast1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
parent_cast1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); parent_cast1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
parent_cast2->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
parent_cast2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc); parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc); parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc);
parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);


GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0)); GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0));
GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0)); GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0));
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast2->GetInDataAnchor(0));
GraphUtils::AddEdge(parent_cast2->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1));
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0));
GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1));
GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0)); GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0));


ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data");
auto data0 = MakeNode(sub_graph, 1, 1, "data0", "Data");
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data");
auto data1 = MakeNode(sub_graph, 1, 1, "data1", "Data");
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); (void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0);
@@ -160,9 +166,17 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) {


sub_graph->SetParentNode(parent_case); sub_graph->SetParentNode(parent_case);
sub_graph->SetParentGraph(parent_graph); sub_graph->SetParentGraph(parent_graph);
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS);
// after pass, data1 is fused to data0
parent_graph->AddSubgraph(sub_graph->GetName(), sub_graph);

size_t sub_graph_num = parent_graph->GetAllSubgraphs().size();
EXPECT_EQ(sub_graph_num, 1);
auto data1_node = sub_graph->FindNode("data1"); auto data1_node = sub_graph->FindNode("data1");
EXPECT_NE(data1_node, nullptr);

EXPECT_EQ(pass_manager.Run(parent_graph), SUCCESS);

// after pass, data1 is fused to data0
data1_node = sub_graph->FindNode("data1");
EXPECT_EQ(data1_node, nullptr); EXPECT_EQ(data1_node, nullptr);
} }
} // namespace ge } // namespace ge

+ 3
- 0
tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc View File

@@ -194,6 +194,9 @@ TEST_F(UtestMultiBatchClonePass, graph_with_subgraph) {
auto func_node = MakeNode(owner, 3, 1, "test_if", "If"); auto func_node = MakeNode(owner, 3, 1, "test_if", "If");
graph->SetParentNode(func_node); graph->SetParentNode(func_node);
graph->SetParentGraph(owner); graph->SetParentGraph(owner);
owner->AddSubgraph(graph->GetName(), graph);
size_t sub_graph_num = owner->GetAllSubgraphs().size();
EXPECT_EQ(sub_graph_num, 1);
EXPECT_EQ(pass_manager.Run(graph), SUCCESS); EXPECT_EQ(pass_manager.Run(graph), SUCCESS);
} }




Loading…
Cancel
Save