From: @zhangxiaokun9 Reviewed-by: @wqtshg,@ji_chen Signed-off-by: @ji_chentags/v1.3.0
| @@ -43,7 +43,7 @@ Status InferShapePass::Run(NodePtr &node) { | |||||
| return GE_GRAPH_INFERSHAPE_FAILED; | return GE_GRAPH_INFERSHAPE_FAILED; | ||||
| } | } | ||||
| bool need_repass = false; | bool need_repass = false; | ||||
| auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "need_infer_again_", need_repass); | |||||
| auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "_need_infer_again", need_repass); | |||||
| if (has_attr) { | if (has_attr) { | ||||
| if (!OptionExists(kOptimizeAfterSubGraph)) { | if (!OptionExists(kOptimizeAfterSubGraph)) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -53,7 +53,7 @@ Status InferShapePass::Run(NodePtr &node) { | |||||
| GELOGD("Node %s need repass immediately.", node->GetName().c_str()); | GELOGD("Node %s need repass immediately.", node->GetName().c_str()); | ||||
| } else { | } else { | ||||
| // clear attr on while | // clear attr on while | ||||
| node->GetOpDesc()->DelAttr("need_infer_again_"); | |||||
| node->GetOpDesc()->DelAttr("_need_infer_again"); | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -70,17 +70,17 @@ class UtestTestPass : public BaseNodePass { | |||||
| // simulate infershape pass | // simulate infershape pass | ||||
| if(node->GetType() == WHILE){ | if(node->GetType() == WHILE){ | ||||
| bool need_repass = false; | bool need_repass = false; | ||||
| AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass); | |||||
| AttrUtils::GetBool(node->GetOpDesc(),"_need_infer_again", need_repass); | |||||
| if(!OptionExists(kOptimizeAfterSubGraph)){ | if(!OptionExists(kOptimizeAfterSubGraph)){ | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| if(need_repass){ | if(need_repass){ | ||||
| AttrUtils::SetBool(node->GetOpDesc(),"need_infer_again_", false); | |||||
| AttrUtils::SetBool(node->GetOpDesc(),"_need_infer_again", false); | |||||
| AddImmediateRePassNode(node); | AddImmediateRePassNode(node); | ||||
| } | } | ||||
| else{ | else{ | ||||
| // clear attr on while | // clear attr on while | ||||
| node->GetOpDesc()->DelAttr("need_infer_again_"); | |||||
| node->GetOpDesc()->DelAttr("_need_infer_again"); | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -492,7 +492,7 @@ ComputeGraphPtr BuildWhileGraph1() { | |||||
| for (int i = 0; i < 2; ++i) { | for (int i = 0; i < 2; ++i) { | ||||
| op_desc->AddOutputDesc(tensor_desc->Clone()); | op_desc->AddOutputDesc(tensor_desc->Clone()); | ||||
| } | } | ||||
| AttrUtils::SetBool(op_desc,"need_infer_again_", true); | |||||
| AttrUtils::SetBool(op_desc,"_need_infer_again", true); | |||||
| op_desc->AddSubgraphName(sub_graph->GetName()); | op_desc->AddSubgraphName(sub_graph->GetName()); | ||||
| op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); | op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); | ||||
| auto root_graph = builder.GetGraph(); | auto root_graph = builder.GetGraph(); | ||||
| @@ -26,12 +26,9 @@ | |||||
| #include "graph/operator_factory.h" | #include "graph/operator_factory.h" | ||||
| #include "graph/operator_reg.h" | #include "graph/operator_reg.h" | ||||
| #include "graph_builder_utils.h" | #include "graph_builder_utils.h" | ||||
| #undef protected | |||||
| #undef private | |||||
| using namespace std; | using namespace std; | ||||
| using namespace testing; | using namespace testing; | ||||
| using namespace ge; | |||||
| namespace ge { | namespace ge { | ||||
| class UtestGraphInfershapePass : public testing::Test { | class UtestGraphInfershapePass : public testing::Test { | ||||
| protected: | protected: | ||||
| @@ -52,4 +49,17 @@ TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { | |||||
| InferShapePass infershape_pass; | InferShapePass infershape_pass; | ||||
| EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); | EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); | ||||
| } | } | ||||
| TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { | |||||
| auto graph = std::make_shared<ComputeGraph>("test"); | |||||
| auto no_op_desc = std::make_shared<OpDesc>("No", "NoOp"); | |||||
| auto no_op_node = graph->AddNode(no_op_desc); | |||||
| AttrUtils::SetBool(no_op_desc, "_need_infer_again", false); | |||||
| InferShapePass infershape_pass; | |||||
| infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; | |||||
| EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||