From: @zhangxiaokun9 Reviewed-by: @wqtshg,@ji_chen Signed-off-by: @ji_chentags/v1.3.0
@@ -43,7 +43,7 @@ Status InferShapePass::Run(NodePtr &node) { | |||||
return GE_GRAPH_INFERSHAPE_FAILED; | return GE_GRAPH_INFERSHAPE_FAILED; | ||||
} | } | ||||
bool need_repass = false; | bool need_repass = false; | ||||
auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "need_infer_again_", need_repass); | |||||
auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "_need_infer_again", need_repass); | |||||
if (has_attr) { | if (has_attr) { | ||||
if (!OptionExists(kOptimizeAfterSubGraph)) { | if (!OptionExists(kOptimizeAfterSubGraph)) { | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -53,7 +53,7 @@ Status InferShapePass::Run(NodePtr &node) { | |||||
GELOGD("Node %s need repass immediately.", node->GetName().c_str()); | GELOGD("Node %s need repass immediately.", node->GetName().c_str()); | ||||
} else { | } else { | ||||
// clear attr on while | // clear attr on while | ||||
node->GetOpDesc()->DelAttr("need_infer_again_"); | |||||
node->GetOpDesc()->DelAttr("_need_infer_again"); | |||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -70,17 +70,17 @@ class UtestTestPass : public BaseNodePass { | |||||
// simulate infershape pass | // simulate infershape pass | ||||
if(node->GetType() == WHILE){ | if(node->GetType() == WHILE){ | ||||
bool need_repass = false; | bool need_repass = false; | ||||
AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass); | |||||
AttrUtils::GetBool(node->GetOpDesc(),"_need_infer_again", need_repass); | |||||
if(!OptionExists(kOptimizeAfterSubGraph)){ | if(!OptionExists(kOptimizeAfterSubGraph)){ | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
if(need_repass){ | if(need_repass){ | ||||
AttrUtils::SetBool(node->GetOpDesc(),"need_infer_again_", false); | |||||
AttrUtils::SetBool(node->GetOpDesc(),"_need_infer_again", false); | |||||
AddImmediateRePassNode(node); | AddImmediateRePassNode(node); | ||||
} | } | ||||
else{ | else{ | ||||
// clear attr on while | // clear attr on while | ||||
node->GetOpDesc()->DelAttr("need_infer_again_"); | |||||
node->GetOpDesc()->DelAttr("_need_infer_again"); | |||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -492,7 +492,7 @@ ComputeGraphPtr BuildWhileGraph1() { | |||||
for (int i = 0; i < 2; ++i) { | for (int i = 0; i < 2; ++i) { | ||||
op_desc->AddOutputDesc(tensor_desc->Clone()); | op_desc->AddOutputDesc(tensor_desc->Clone()); | ||||
} | } | ||||
AttrUtils::SetBool(op_desc,"need_infer_again_", true); | |||||
AttrUtils::SetBool(op_desc,"_need_infer_again", true); | |||||
op_desc->AddSubgraphName(sub_graph->GetName()); | op_desc->AddSubgraphName(sub_graph->GetName()); | ||||
op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); | op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); | ||||
auto root_graph = builder.GetGraph(); | auto root_graph = builder.GetGraph(); | ||||
@@ -26,12 +26,9 @@ | |||||
#include "graph/operator_factory.h" | #include "graph/operator_factory.h" | ||||
#include "graph/operator_reg.h" | #include "graph/operator_reg.h" | ||||
#include "graph_builder_utils.h" | #include "graph_builder_utils.h" | ||||
#undef protected | |||||
#undef private | |||||
using namespace std; | using namespace std; | ||||
using namespace testing; | using namespace testing; | ||||
using namespace ge; | |||||
namespace ge { | namespace ge { | ||||
class UtestGraphInfershapePass : public testing::Test { | class UtestGraphInfershapePass : public testing::Test { | ||||
protected: | protected: | ||||
@@ -52,4 +49,17 @@ TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { | |||||
InferShapePass infershape_pass; | InferShapePass infershape_pass; | ||||
EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); | EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); | ||||
} | } | ||||
TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { | |||||
auto graph = std::make_shared<ComputeGraph>("test"); | |||||
auto no_op_desc = std::make_shared<OpDesc>("No", "NoOp"); | |||||
auto no_op_node = graph->AddNode(no_op_desc); | |||||
AttrUtils::SetBool(no_op_desc, "_need_infer_again", false); | |||||
InferShapePass infershape_pass; | |||||
infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; | |||||
EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS); | |||||
} | |||||
} // namespace ge | } // namespace ge |