|
@@ -58,8 +58,6 @@ public: |
|
|
/// netoutput |
|
|
/// netoutput |
|
|
/// ... |
|
|
/// ... |
|
|
/// data0 data1 data2 |
|
|
/// data0 data1 data2 |
|
|
/// | \ / |
|
|
|
|
|
/// conv add |
|
|
|
|
|
TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { |
|
|
TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { |
|
|
PassManager pass_manager; |
|
|
PassManager pass_manager; |
|
|
pass_manager.AddPass("FuseDataNodesWithCommonInputPass", new (std::nothrow) FuseDataNodesWithCommonInputPass); |
|
|
pass_manager.AddPass("FuseDataNodesWithCommonInputPass", new (std::nothrow) FuseDataNodesWithCommonInputPass); |
|
@@ -81,6 +79,11 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { |
|
|
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(2)); |
|
|
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(2)); |
|
|
GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0)); |
|
|
GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0)); |
|
|
|
|
|
|
|
|
|
|
|
auto case_node = parent_graph->FindNode("parent_case"); |
|
|
|
|
|
EXPECT_NE(case_node, nullptr); |
|
|
|
|
|
size_t input_data_node_num = case_node->GetInDataNodes().size(); |
|
|
|
|
|
EXPECT_EQ(input_data_node_num, 3); |
|
|
|
|
|
|
|
|
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); |
|
|
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); |
|
|
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data"); |
|
|
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data"); |
|
|
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
@@ -98,6 +101,12 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { |
|
|
sub_graph->SetParentNode(parent_case); |
|
|
sub_graph->SetParentNode(parent_case); |
|
|
sub_graph->SetParentGraph(parent_graph); |
|
|
sub_graph->SetParentGraph(parent_graph); |
|
|
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); |
|
|
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); |
|
|
|
|
|
// after pass, data1 and data2 are fused to data0 |
|
|
|
|
|
auto data1_node = sub_graph->FindNode("data1"); |
|
|
|
|
|
EXPECT_EQ(data1_node, nullptr); |
|
|
|
|
|
auto data2_node = sub_graph->FindNode("data2"); |
|
|
|
|
|
EXPECT_EQ(data2_node, nullptr); |
|
|
|
|
|
|
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// graph with subgraph |
|
|
/// graph with subgraph |
|
@@ -152,5 +161,8 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) { |
|
|
sub_graph->SetParentNode(parent_case); |
|
|
sub_graph->SetParentNode(parent_case); |
|
|
sub_graph->SetParentGraph(parent_graph); |
|
|
sub_graph->SetParentGraph(parent_graph); |
|
|
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); |
|
|
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); |
|
|
|
|
|
// after pass, data1 is fused to data0 |
|
|
|
|
|
auto data1_node = sub_graph->FindNode("data1"); |
|
|
|
|
|
EXPECT_EQ(data1_node, nullptr); |
|
|
} |
|
|
} |
|
|
} // namespace ge |
|
|
} // namespace ge |