diff --git a/ge/graph/passes/replace_with_empty_const_pass.cc b/ge/graph/passes/replace_with_empty_const_pass.cc index 9459c852..3176d1ee 100644 --- a/ge/graph/passes/replace_with_empty_const_pass.cc +++ b/ge/graph/passes/replace_with_empty_const_pass.cc @@ -21,7 +21,23 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +namespace { +const std::unordered_set kControlFlowOps = { + ge::SWITCH, + ge::REFSWITCH, + ge::MERGE, + ge::REFMERGE, + ge::ENTER, + ge::REFENTER, + ge::NEXTITERATION, + ge::REFNEXTITERATION, + ge::EXIT, + ge::REFEXIT, + ge::LOOPCOND +}; +} namespace ge { Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { GELOGD("ReplaceWithEmptyConstPass in."); @@ -39,6 +55,10 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { GELOGI("Node %s is const. Ignore current pass.", node->GetName().c_str()); return SUCCESS; } + if (kControlFlowOps.count(NodeUtils::GetNodeType(node)) != 0) { + GELOGI("Node %s is control flow op. Ignore current pass.", node->GetName().c_str()); + return SUCCESS; + } // Node like no op, it has no output if (node->GetOpDesc()->GetAllOutputsDescPtr().empty()) { GELOGI("Node %s has no output desc. Ignore current pass.", node->GetName().c_str()); diff --git a/tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc b/tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc index 6711b0d3..d353498c 100644 --- a/tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc @@ -57,6 +57,36 @@ ut::GraphBuilder Graph1Builder() { builder.AddDataEdge(cast1, 0, conv2d, 0); return builder; } + +/// data1 const1 +/// \ / +/// add1 +/// | +/// data2 -> switch1 (empty) +/// | +/// conv2d +ut::GraphBuilder Graph2Builder() { + ut::GraphBuilder builder = ut::GraphBuilder("graph2"); + auto data1 = builder.AddNode("data1", "Data", 0, 1); + auto data2 = builder.AddNode("data2", "Data", 0, 1); + auto const1 = builder.AddNode("const1", "Const", 0, 1); + auto add1 = builder.AddNode("add1", "Add", 2, 1); + auto switch1 = builder.AddNode("switch1", "Switch", 2, 1); + auto conv2d = builder.AddNode("conv2d", "Conv2D", 1, 0); + + add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1, 1, 8, 8}),FORMAT_NCHW)); + add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1, 1, 8, 8}),FORMAT_NCHW)); + add1->GetOpDesc()->AddOutputDesc(GeTensorDesc(GeShape({1, 1, 8, 8}),FORMAT_NCHW)); + GeTensorDesc empty_tensor(GeShape({1, 0, 8, 8}),FORMAT_NCHW); + switch1->GetOpDesc()->UpdateOutputDesc(0, empty_tensor); + + builder.AddDataEdge(data1, 0, add1, 0); + builder.AddDataEdge(const1, 0, add1, 1); + builder.AddDataEdge(add1, 0, switch1, 0); + builder.AddDataEdge(data2, 0, switch1, 1); + builder.AddDataEdge(switch1, 0, conv2d, 0); + return builder; +} } // namespace @@ -85,4 +115,19 @@ TEST_F(UtestReplaceWithEmptyConstPass, replace_whith_empty_const_success) { auto conv2d = graph->FindNode("conv2d"); EXPECT_EQ(conv2d->GetInDataNodes().at(0)->GetType(),"Const"); } + +TEST_F(UtestReplaceWithEmptyConstPass, replace_whith_empty_switch_skip) { + auto builder = Graph2Builder(); + auto graph = builder.GetGraph(); + graph->SetSessionID(0); + ReplaceWithEmptyConstPass replace_with_empty_const_pass; + + EXPECT_EQ(graph->GetDirectNodesSize(), 6); + // run pass on switch1, graph still has 6 nodes + auto switch1 = graph->FindNode("switch1"); + EXPECT_NE(switch1, nullptr); + Status ret = replace_with_empty_const_pass.Run(switch1); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(graph->GetDirectNodesSize(), 6); +} } // namespace ge