diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index a54a15c1..b9a98f62 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -43,7 +43,7 @@ Status InferShapePass::Run(NodePtr &node) { return GE_GRAPH_INFERSHAPE_FAILED; } bool need_repass = false; - auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "need_infer_again_", need_repass); + auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "_need_infer_again", need_repass); if (has_attr) { if (!OptionExists(kOptimizeAfterSubGraph)) { return SUCCESS; @@ -53,7 +53,7 @@ Status InferShapePass::Run(NodePtr &node) { GELOGD("Node %s need repass immediately.", node->GetName().c_str()); } else { // clear attr on while - node->GetOpDesc()->DelAttr("need_infer_again_"); + node->GetOpDesc()->DelAttr("_need_infer_again"); } } return SUCCESS; diff --git a/tests/ut/ge/graph/passes/base_pass_unittest.cc b/tests/ut/ge/graph/passes/base_pass_unittest.cc index 129c11d8..9bba5d77 100644 --- a/tests/ut/ge/graph/passes/base_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/base_pass_unittest.cc @@ -70,17 +70,17 @@ class UtestTestPass : public BaseNodePass { // simulate infershape pass if(node->GetType() == WHILE){ bool need_repass = false; - AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass); + AttrUtils::GetBool(node->GetOpDesc(),"_need_infer_again", need_repass); if(!OptionExists(kOptimizeAfterSubGraph)){ return SUCCESS; } if(need_repass){ - AttrUtils::SetBool(node->GetOpDesc(),"need_infer_again_", false); + AttrUtils::SetBool(node->GetOpDesc(),"_need_infer_again", false); AddImmediateRePassNode(node); } else{ // clear attr on while - node->GetOpDesc()->DelAttr("need_infer_again_"); + node->GetOpDesc()->DelAttr("_need_infer_again"); } } return SUCCESS; @@ -492,7 +492,7 @@ ComputeGraphPtr BuildWhileGraph1() { for (int i = 0; i < 2; ++i) { op_desc->AddOutputDesc(tensor_desc->Clone()); } - AttrUtils::SetBool(op_desc,"need_infer_again_", true); + AttrUtils::SetBool(op_desc,"_need_infer_again", true); op_desc->AddSubgraphName(sub_graph->GetName()); op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); auto root_graph = builder.GetGraph(); diff --git a/tests/ut/ge/graph/passes/infershape_pass_unittest.cc b/tests/ut/ge/graph/passes/infershape_pass_unittest.cc index 8fa5b34e..a7628b2e 100644 --- a/tests/ut/ge/graph/passes/infershape_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/infershape_pass_unittest.cc @@ -26,12 +26,9 @@ #include "graph/operator_factory.h" #include "graph/operator_reg.h" #include "graph_builder_utils.h" -#undef protected -#undef private using namespace std; using namespace testing; -using namespace ge; namespace ge { class UtestGraphInfershapePass : public testing::Test { protected: @@ -52,4 +49,17 @@ TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { InferShapePass infershape_pass; EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); } + +TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { + auto graph = std::make_shared("test"); + + auto no_op_desc = std::make_shared("No", "NoOp"); + auto no_op_node = graph->AddNode(no_op_desc); + AttrUtils::SetBool(no_op_desc, "_need_infer_again", false); + + InferShapePass infershape_pass; + infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; + EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS); +} + } // namespace ge