|
@@ -70,17 +70,17 @@ class UtestTestPass : public BaseNodePass { |
|
|
// simulate infershape pass |
|
|
// simulate infershape pass |
|
|
if(node->GetType() == WHILE){ |
|
|
if(node->GetType() == WHILE){ |
|
|
bool need_repass = false; |
|
|
bool need_repass = false; |
|
|
AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass); |
|
|
|
|
|
|
|
|
AttrUtils::GetBool(node->GetOpDesc(),"_need_infer_again", need_repass); |
|
|
if(!OptionExists(kOptimizeAfterSubGraph)){ |
|
|
if(!OptionExists(kOptimizeAfterSubGraph)){ |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
if(need_repass){ |
|
|
if(need_repass){ |
|
|
AttrUtils::SetBool(node->GetOpDesc(),"need_infer_again_", false); |
|
|
|
|
|
|
|
|
AttrUtils::SetBool(node->GetOpDesc(),"_need_infer_again", false); |
|
|
AddImmediateRePassNode(node); |
|
|
AddImmediateRePassNode(node); |
|
|
} |
|
|
} |
|
|
else{ |
|
|
else{ |
|
|
// clear attr on while |
|
|
// clear attr on while |
|
|
node->GetOpDesc()->DelAttr("need_infer_again_"); |
|
|
|
|
|
|
|
|
node->GetOpDesc()->DelAttr("_need_infer_again"); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
@@ -492,7 +492,7 @@ ComputeGraphPtr BuildWhileGraph1() { |
|
|
for (int i = 0; i < 2; ++i) { |
|
|
for (int i = 0; i < 2; ++i) { |
|
|
op_desc->AddOutputDesc(tensor_desc->Clone()); |
|
|
op_desc->AddOutputDesc(tensor_desc->Clone()); |
|
|
} |
|
|
} |
|
|
AttrUtils::SetBool(op_desc,"need_infer_again_", true); |
|
|
|
|
|
|
|
|
AttrUtils::SetBool(op_desc,"_need_infer_again", true); |
|
|
op_desc->AddSubgraphName(sub_graph->GetName()); |
|
|
op_desc->AddSubgraphName(sub_graph->GetName()); |
|
|
op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); |
|
|
op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); |
|
|
auto root_graph = builder.GetGraph(); |
|
|
auto root_graph = builder.GetGraph(); |
|
|