Browse Source

Rename need_infer_again_

tags/v1.3.0
zhangxiaokun 3 years ago
parent
commit
7fbfe1467f
2 changed files with 6 additions and 6 deletions
  1. +2
    -2
      ge/graph/passes/infershape_pass.cc
  2. +4
    -4
      tests/ut/ge/graph/passes/base_pass_unittest.cc

+ 2
- 2
ge/graph/passes/infershape_pass.cc View File

@@ -43,7 +43,7 @@ Status InferShapePass::Run(NodePtr &node) {
return GE_GRAPH_INFERSHAPE_FAILED; return GE_GRAPH_INFERSHAPE_FAILED;
} }
bool need_repass = false; bool need_repass = false;
auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "need_infer_again_", need_repass);
auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "_need_infer_again", need_repass);
if (has_attr) { if (has_attr) {
if (!OptionExists(kOptimizeAfterSubGraph)) { if (!OptionExists(kOptimizeAfterSubGraph)) {
return SUCCESS; return SUCCESS;
@@ -53,7 +53,7 @@ Status InferShapePass::Run(NodePtr &node) {
GELOGD("Node %s need repass immediately.", node->GetName().c_str()); GELOGD("Node %s need repass immediately.", node->GetName().c_str());
} else { } else {
// clear attr on while // clear attr on while
node->GetOpDesc()->DelAttr("need_infer_again_");
node->GetOpDesc()->DelAttr("_need_infer_again");
} }
} }
return SUCCESS; return SUCCESS;


+ 4
- 4
tests/ut/ge/graph/passes/base_pass_unittest.cc View File

@@ -70,17 +70,17 @@ class UtestTestPass : public BaseNodePass {
// simulate infershape pass // simulate infershape pass
if(node->GetType() == WHILE){ if(node->GetType() == WHILE){
bool need_repass = false; bool need_repass = false;
AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass);
AttrUtils::GetBool(node->GetOpDesc(),"_need_infer_again", need_repass);
if(!OptionExists(kOptimizeAfterSubGraph)){ if(!OptionExists(kOptimizeAfterSubGraph)){
return SUCCESS; return SUCCESS;
} }
if(need_repass){ if(need_repass){
AttrUtils::SetBool(node->GetOpDesc(),"need_infer_again_", false);
AttrUtils::SetBool(node->GetOpDesc(),"_need_infer_again", false);
AddImmediateRePassNode(node); AddImmediateRePassNode(node);
} }
else{ else{
// clear attr on while // clear attr on while
node->GetOpDesc()->DelAttr("need_infer_again_");
node->GetOpDesc()->DelAttr("_need_infer_again");
} }
} }
return SUCCESS; return SUCCESS;
@@ -492,7 +492,7 @@ ComputeGraphPtr BuildWhileGraph1() {
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
op_desc->AddOutputDesc(tensor_desc->Clone()); op_desc->AddOutputDesc(tensor_desc->Clone());
} }
AttrUtils::SetBool(op_desc,"need_infer_again_", true);
AttrUtils::SetBool(op_desc,"_need_infer_again", true);
op_desc->AddSubgraphName(sub_graph->GetName()); op_desc->AddSubgraphName(sub_graph->GetName());
op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); op_desc->SetSubgraphInstanceName(0,sub_graph->GetName());
auto root_graph = builder.GetGraph(); auto root_graph = builder.GetGraph();


Loading…
Cancel
Save