diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index a54a15c1..b9a98f62 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -43,7 +43,7 @@ Status InferShapePass::Run(NodePtr &node) { return GE_GRAPH_INFERSHAPE_FAILED; } bool need_repass = false; - auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "need_infer_again_", need_repass); + auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "_need_infer_again", need_repass); if (has_attr) { if (!OptionExists(kOptimizeAfterSubGraph)) { return SUCCESS; @@ -53,7 +53,7 @@ Status InferShapePass::Run(NodePtr &node) { GELOGD("Node %s need repass immediately.", node->GetName().c_str()); } else { // clear attr on while - node->GetOpDesc()->DelAttr("need_infer_again_"); + node->GetOpDesc()->DelAttr("_need_infer_again"); } } return SUCCESS; diff --git a/tests/ut/ge/graph/passes/base_pass_unittest.cc b/tests/ut/ge/graph/passes/base_pass_unittest.cc index 129c11d8..9bba5d77 100644 --- a/tests/ut/ge/graph/passes/base_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/base_pass_unittest.cc @@ -70,17 +70,17 @@ class UtestTestPass : public BaseNodePass { // simulate infershape pass if(node->GetType() == WHILE){ bool need_repass = false; - AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass); + AttrUtils::GetBool(node->GetOpDesc(),"_need_infer_again", need_repass); if(!OptionExists(kOptimizeAfterSubGraph)){ return SUCCESS; } if(need_repass){ - AttrUtils::SetBool(node->GetOpDesc(),"need_infer_again_", false); + AttrUtils::SetBool(node->GetOpDesc(),"_need_infer_again", false); AddImmediateRePassNode(node); } else{ // clear attr on while - node->GetOpDesc()->DelAttr("need_infer_again_"); + node->GetOpDesc()->DelAttr("_need_infer_again"); } } return SUCCESS; @@ -492,7 +492,7 @@ ComputeGraphPtr BuildWhileGraph1() { for (int i = 0; i < 2; ++i) { op_desc->AddOutputDesc(tensor_desc->Clone()); } - AttrUtils::SetBool(op_desc,"need_infer_again_", true); + AttrUtils::SetBool(op_desc,"_need_infer_again", true); op_desc->AddSubgraphName(sub_graph->GetName()); op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); auto root_graph = builder.GetGraph();