|
@@ -57,6 +57,36 @@ ut::GraphBuilder Graph1Builder() { |
|
|
builder.AddDataEdge(cast1, 0, conv2d, 0); |
|
|
builder.AddDataEdge(cast1, 0, conv2d, 0); |
|
|
return builder; |
|
|
return builder; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/// data1 const1 |
|
|
|
|
|
/// \ / |
|
|
|
|
|
/// add1 |
|
|
|
|
|
/// | |
|
|
|
|
|
/// data2 -> switch1 (empty) |
|
|
|
|
|
/// | |
|
|
|
|
|
/// conv2d |
|
|
|
|
|
ut::GraphBuilder Graph2Builder() { |
|
|
|
|
|
ut::GraphBuilder builder = ut::GraphBuilder("graph2"); |
|
|
|
|
|
auto data1 = builder.AddNode("data1", "Data", 0, 1); |
|
|
|
|
|
auto data2 = builder.AddNode("data2", "Data", 0, 1); |
|
|
|
|
|
auto const1 = builder.AddNode("const1", "Const", 0, 1); |
|
|
|
|
|
auto add1 = builder.AddNode("add1", "Add", 2, 1); |
|
|
|
|
|
auto switch1 = builder.AddNode("switch1", "Switch", 2, 1); |
|
|
|
|
|
auto conv2d = builder.AddNode("conv2d", "Conv2D", 1, 0); |
|
|
|
|
|
|
|
|
|
|
|
add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1, 1, 8, 8}),FORMAT_NCHW)); |
|
|
|
|
|
add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1, 1, 8, 8}),FORMAT_NCHW)); |
|
|
|
|
|
add1->GetOpDesc()->AddOutputDesc(GeTensorDesc(GeShape({1, 1, 8, 8}),FORMAT_NCHW)); |
|
|
|
|
|
GeTensorDesc empty_tensor(GeShape({1, 0, 8, 8}),FORMAT_NCHW); |
|
|
|
|
|
switch1->GetOpDesc()->UpdateOutputDesc(0, empty_tensor); |
|
|
|
|
|
|
|
|
|
|
|
builder.AddDataEdge(data1, 0, add1, 0); |
|
|
|
|
|
builder.AddDataEdge(const1, 0, add1, 1); |
|
|
|
|
|
builder.AddDataEdge(add1, 0, switch1, 0); |
|
|
|
|
|
builder.AddDataEdge(data2, 0, switch1, 1); |
|
|
|
|
|
builder.AddDataEdge(switch1, 0, conv2d, 0); |
|
|
|
|
|
return builder; |
|
|
|
|
|
} |
|
|
} // namespace |
|
|
} // namespace |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -85,4 +115,19 @@ TEST_F(UtestReplaceWithEmptyConstPass, replace_whith_empty_const_success) { |
|
|
auto conv2d = graph->FindNode("conv2d"); |
|
|
auto conv2d = graph->FindNode("conv2d"); |
|
|
EXPECT_EQ(conv2d->GetInDataNodes().at(0)->GetType(),"Const"); |
|
|
EXPECT_EQ(conv2d->GetInDataNodes().at(0)->GetType(),"Const"); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
TEST_F(UtestReplaceWithEmptyConstPass, replace_whith_empty_switch_skip) { |
|
|
|
|
|
auto builder = Graph2Builder(); |
|
|
|
|
|
auto graph = builder.GetGraph(); |
|
|
|
|
|
graph->SetSessionID(0); |
|
|
|
|
|
ReplaceWithEmptyConstPass replace_with_empty_const_pass; |
|
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(graph->GetDirectNodesSize(), 6); |
|
|
|
|
|
// run pass on switch1, graph still has 6 nodes |
|
|
|
|
|
auto switch1 = graph->FindNode("switch1"); |
|
|
|
|
|
EXPECT_NE(switch1, nullptr); |
|
|
|
|
|
Status ret = replace_with_empty_const_pass.Run(switch1); |
|
|
|
|
|
EXPECT_EQ(ret, SUCCESS); |
|
|
|
|
|
EXPECT_EQ(graph->GetDirectNodesSize(), 6); |
|
|
|
|
|
} |
|
|
} // namespace ge |
|
|
} // namespace ge |