| @@ -67,6 +67,21 @@ class UtestTestPass : public BaseNodePass { | |||||
| names_to_add_repass_.erase(iter); | names_to_add_repass_.erase(iter); | ||||
| } | } | ||||
| } | } | ||||
| // simulate infershape pass | |||||
| if(node->GetType() == WHILE){ | |||||
| bool need_repass = false; | |||||
| AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass); | |||||
| if(!OptionExists(kOptimizeAfterSubGraph)){ | |||||
| return SUCCESS; | |||||
| } | |||||
| if(need_repass){ | |||||
| AddImmediateRePassNode(node); | |||||
| } | |||||
| else{ | |||||
| // clear attr on while | |||||
| node->GetOpDesc()->DelAttr("need_infer_again_"); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void clear() { iter_nodes_.clear(); } | void clear() { iter_nodes_.clear(); } | ||||
| @@ -429,6 +444,7 @@ TEST_F(UTESTGraphPassesBasePass, dead_loop) { | |||||
| EXPECT_EQ(test_pass.GetRunTimes(), 1007); | EXPECT_EQ(test_pass.GetRunTimes(), 1007); | ||||
| } | } | ||||
| */ | */ | ||||
| TEST_F(UTESTGraphPassesBasePass, while_loop) { | TEST_F(UTESTGraphPassesBasePass, while_loop) { | ||||
| NamesToPass names_to_pass; | NamesToPass names_to_pass; | ||||
| auto test_pass = UtestTestPass(true); | auto test_pass = UtestTestPass(true); | ||||
| @@ -438,4 +454,69 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) { | |||||
| auto ge_pass = GEPass(graph); | auto ge_pass = GEPass(graph); | ||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | ||||
| } | } | ||||
| /// data1 const | |||||
| /// \ / | |||||
| /// while | |||||
| /// / \ | |||||
| /// | | | |||||
| /// cast1 cast2 | |||||
| ComputeGraphPtr BuildWhileGraph1() { | |||||
| // build sub graph | |||||
| auto builder_sub = ut::GraphBuilder("sub"); | |||||
| auto data_1 = builder_sub.AddNode("data_1", DATA, 0, 1); | |||||
| auto data_2 = builder_sub.AddNode("data_2", DATA, 0, 1); | |||||
| auto add = builder_sub.AddNode("add", ADD, 2, 1); | |||||
| builder_sub.AddDataEdge(data_1, 0, add, 0); | |||||
| builder_sub.AddDataEdge(data_2, 0, add, 1); | |||||
| auto sub_graph = builder_sub.GetGraph(); | |||||
| sub_graph->SetName("while_sub"); | |||||
| // build root graph | |||||
| auto builder = ut::GraphBuilder("g1"); | |||||
| auto data = builder.AddNode("data1", DATA, 0, 1); | |||||
| auto const_op = builder.AddNode("const_op", CONSTANT, 0, 1); | |||||
| auto c1 = builder.AddNode("cast1", CAST, 1, 1); | |||||
| auto c2 = builder.AddNode("cast2", CAST, 1, 1); | |||||
| // add while op | |||||
| auto tensor_desc = std::make_shared<GeTensorDesc>(); | |||||
| tensor_desc->SetShape(GeShape({1,1,1,1})); | |||||
| tensor_desc->SetFormat(FORMAT_ND); | |||||
| tensor_desc->SetDataType(DT_INT32); | |||||
| auto op_desc = std::make_shared<OpDesc>("while", WHILE); | |||||
| for (int i = 0; i < 2; ++i) { | |||||
| op_desc->AddInputDesc(tensor_desc->Clone()); | |||||
| } | |||||
| for (int i = 0; i < 2; ++i) { | |||||
| op_desc->AddOutputDesc(tensor_desc->Clone()); | |||||
| } | |||||
| AttrUtils::SetBool(op_desc,"need_infer_again_", true); | |||||
| op_desc->AddSubgraphName(sub_graph->GetName()); | |||||
| op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); | |||||
| auto root_graph = builder.GetGraph(); | |||||
| auto while_op = root_graph->AddNode(op_desc); | |||||
| builder.AddDataEdge(data, 0, while_op, 0); | |||||
| builder.AddDataEdge(const_op, 0, while_op, 1); | |||||
| builder.AddDataEdge(while_op, 0, c1, 0); | |||||
| builder.AddDataEdge(while_op, 1, c2, 0); | |||||
| sub_graph->SetParentGraph(root_graph); | |||||
| sub_graph->SetParentNode(while_op); | |||||
| root_graph->AddSubgraph(sub_graph); | |||||
| return root_graph; | |||||
| } | |||||
| TEST_F(UTESTGraphPassesBasePass, while_infershape) { | |||||
| NamesToPass names_to_pass; | |||||
| auto test_pass = UtestTestPass(); | |||||
| names_to_pass.push_back(std::make_pair("test", &test_pass)); | |||||
| auto graph = BuildWhileGraph1(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto while_node = graph->FindNode("while"); | |||||
| EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||