|
@@ -85,13 +85,13 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { |
|
|
EXPECT_EQ(input_data_node_num, 3); |
|
|
EXPECT_EQ(input_data_node_num, 3); |
|
|
|
|
|
|
|
|
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); |
|
|
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); |
|
|
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data"); |
|
|
|
|
|
|
|
|
auto data0 = MakeNode(sub_graph, 1, 1, "data0", "Data"); |
|
|
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data"); |
|
|
|
|
|
|
|
|
auto data1 = MakeNode(sub_graph, 1, 1, "data1", "Data"); |
|
|
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
auto data2 = MakeNode(parent_graph, 1, 1, "data2", "Data"); |
|
|
|
|
|
|
|
|
auto data2 = MakeNode(sub_graph, 1, 1, "data2", "Data"); |
|
|
data2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
data2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
data2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
data2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); |
|
|
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); |
|
@@ -100,19 +100,28 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { |
|
|
|
|
|
|
|
|
sub_graph->SetParentNode(parent_case); |
|
|
sub_graph->SetParentNode(parent_case); |
|
|
sub_graph->SetParentGraph(parent_graph); |
|
|
sub_graph->SetParentGraph(parent_graph); |
|
|
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); |
|
|
|
|
|
// after pass, data1 and data2 are fused to data0 |
|
|
|
|
|
|
|
|
parent_graph->AddSubgraph(sub_graph->GetName(), sub_graph); |
|
|
|
|
|
size_t sub_graph_num = parent_graph->GetAllSubgraphs().size(); |
|
|
|
|
|
EXPECT_EQ(sub_graph_num, 1); |
|
|
|
|
|
|
|
|
auto data1_node = sub_graph->FindNode("data1"); |
|
|
auto data1_node = sub_graph->FindNode("data1"); |
|
|
EXPECT_EQ(data1_node, nullptr); |
|
|
|
|
|
|
|
|
EXPECT_NE(data1_node, nullptr); |
|
|
auto data2_node = sub_graph->FindNode("data2"); |
|
|
auto data2_node = sub_graph->FindNode("data2"); |
|
|
EXPECT_EQ(data2_node, nullptr); |
|
|
|
|
|
|
|
|
EXPECT_NE(data2_node, nullptr); |
|
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(pass_manager.Run(parent_graph), SUCCESS); |
|
|
|
|
|
|
|
|
|
|
|
// after pass, data1 and data2 are fused to data0 |
|
|
|
|
|
data1_node = sub_graph->FindNode("data1"); |
|
|
|
|
|
EXPECT_EQ(data1_node, nullptr); |
|
|
|
|
|
data2_node = sub_graph->FindNode("data2"); |
|
|
|
|
|
EXPECT_EQ(data2_node, nullptr); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// graph with subgraph |
|
|
/// graph with subgraph |
|
|
/// const |
|
|
/// const |
|
|
/// / \ |
|
|
/// / \ |
|
|
/// cast1 cast2 |
|
|
|
|
|
|
|
|
/// cast1 cast1 |
|
|
/// \ / |
|
|
/// \ / |
|
|
/// case |
|
|
/// case |
|
|
/// | |
|
|
/// | |
|
@@ -127,7 +136,6 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) { |
|
|
ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph"); |
|
|
ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph"); |
|
|
auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const"); |
|
|
auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const"); |
|
|
auto parent_cast1 = MakeNode(parent_graph, 1, 1, "parent_cast1", "Cast"); |
|
|
auto parent_cast1 = MakeNode(parent_graph, 1, 1, "parent_cast1", "Cast"); |
|
|
auto parent_cast2 = MakeNode(parent_graph, 1, 1, "parent_cast2", "Cast"); |
|
|
|
|
|
auto parent_case = MakeNode(parent_graph, 2, 1, "parent_case", "Case"); |
|
|
auto parent_case = MakeNode(parent_graph, 2, 1, "parent_case", "Case"); |
|
|
auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput"); |
|
|
auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput"); |
|
|
|
|
|
|
|
@@ -136,23 +144,21 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) { |
|
|
parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
parent_cast1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
parent_cast1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
parent_cast1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
parent_cast1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
parent_cast2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
|
|
|
parent_cast2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
|
|
|
parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc); |
|
|
parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc); |
|
|
parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
|
|
|
|
|
|
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0)); |
|
|
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0)); |
|
|
GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0)); |
|
|
GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0)); |
|
|
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast2->GetInDataAnchor(0)); |
|
|
|
|
|
GraphUtils::AddEdge(parent_cast2->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1)); |
|
|
|
|
|
|
|
|
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0)); |
|
|
|
|
|
GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1)); |
|
|
GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0)); |
|
|
GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0)); |
|
|
|
|
|
|
|
|
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); |
|
|
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); |
|
|
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data"); |
|
|
|
|
|
|
|
|
auto data0 = MakeNode(sub_graph, 1, 1, "data0", "Data"); |
|
|
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data"); |
|
|
|
|
|
|
|
|
auto data1 = MakeNode(sub_graph, 1, 1, "data1", "Data"); |
|
|
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); |
|
|
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); |
|
|
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); |
|
|
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); |
|
@@ -160,9 +166,17 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) { |
|
|
|
|
|
|
|
|
sub_graph->SetParentNode(parent_case); |
|
|
sub_graph->SetParentNode(parent_case); |
|
|
sub_graph->SetParentGraph(parent_graph); |
|
|
sub_graph->SetParentGraph(parent_graph); |
|
|
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); |
|
|
|
|
|
// after pass, data1 is fused to data0 |
|
|
|
|
|
|
|
|
parent_graph->AddSubgraph(sub_graph->GetName(), sub_graph); |
|
|
|
|
|
|
|
|
|
|
|
size_t sub_graph_num = parent_graph->GetAllSubgraphs().size(); |
|
|
|
|
|
EXPECT_EQ(sub_graph_num, 1); |
|
|
auto data1_node = sub_graph->FindNode("data1"); |
|
|
auto data1_node = sub_graph->FindNode("data1"); |
|
|
|
|
|
EXPECT_NE(data1_node, nullptr); |
|
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(pass_manager.Run(parent_graph), SUCCESS); |
|
|
|
|
|
|
|
|
|
|
|
// after pass, data1 is fused to data0 |
|
|
|
|
|
data1_node = sub_graph->FindNode("data1"); |
|
|
EXPECT_EQ(data1_node, nullptr); |
|
|
EXPECT_EQ(data1_node, nullptr); |
|
|
} |
|
|
} |
|
|
} // namespace ge |
|
|
} // namespace ge |