| @@ -67,6 +67,21 @@ class UtestTestPass : public BaseNodePass { | |||
| names_to_add_repass_.erase(iter); | |||
| } | |||
| } | |||
| // simulate infershape pass | |||
| if(node->GetType() == WHILE){ | |||
| bool need_repass = false; | |||
| AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass); | |||
| if(!OptionExists(kOptimizeAfterSubGraph)){ | |||
| return SUCCESS; | |||
| } | |||
| if(need_repass){ | |||
| AddImmediateRePassNode(node); | |||
| } | |||
| else{ | |||
| // clear attr on while | |||
| node->GetOpDesc()->DelAttr("need_infer_again_"); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| void clear() { iter_nodes_.clear(); } | |||
| @@ -429,6 +444,7 @@ TEST_F(UTESTGraphPassesBasePass, dead_loop) { | |||
| EXPECT_EQ(test_pass.GetRunTimes(), 1007); | |||
| } | |||
| */ | |||
| TEST_F(UTESTGraphPassesBasePass, while_loop) { | |||
| NamesToPass names_to_pass; | |||
| auto test_pass = UtestTestPass(true); | |||
| @@ -438,4 +454,69 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) { | |||
| auto ge_pass = GEPass(graph); | |||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | |||
| } | |||
| /// data1 const | |||
| /// \ / | |||
| /// while | |||
| /// / \ | |||
| /// | | | |||
| /// cast1 cast2 | |||
| ComputeGraphPtr BuildWhileGraph1() { | |||
| // build sub graph | |||
| auto builder_sub = ut::GraphBuilder("sub"); | |||
| auto data_1 = builder_sub.AddNode("data_1", DATA, 0, 1); | |||
| auto data_2 = builder_sub.AddNode("data_2", DATA, 0, 1); | |||
| auto add = builder_sub.AddNode("add", ADD, 2, 1); | |||
| builder_sub.AddDataEdge(data_1, 0, add, 0); | |||
| builder_sub.AddDataEdge(data_2, 0, add, 1); | |||
| auto sub_graph = builder_sub.GetGraph(); | |||
| sub_graph->SetName("while_sub"); | |||
| // build root graph | |||
| auto builder = ut::GraphBuilder("g1"); | |||
| auto data = builder.AddNode("data1", DATA, 0, 1); | |||
| auto const_op = builder.AddNode("const_op", CONSTANT, 0, 1); | |||
| auto c1 = builder.AddNode("cast1", CAST, 1, 1); | |||
| auto c2 = builder.AddNode("cast2", CAST, 1, 1); | |||
| // add while op | |||
| auto tensor_desc = std::make_shared<GeTensorDesc>(); | |||
| tensor_desc->SetShape(GeShape({1,1,1,1})); | |||
| tensor_desc->SetFormat(FORMAT_ND); | |||
| tensor_desc->SetDataType(DT_INT32); | |||
| auto op_desc = std::make_shared<OpDesc>("while", WHILE); | |||
| for (int i = 0; i < 2; ++i) { | |||
| op_desc->AddInputDesc(tensor_desc->Clone()); | |||
| } | |||
| for (int i = 0; i < 2; ++i) { | |||
| op_desc->AddOutputDesc(tensor_desc->Clone()); | |||
| } | |||
| AttrUtils::SetBool(op_desc,"need_infer_again_", true); | |||
| op_desc->AddSubgraphName(sub_graph->GetName()); | |||
| op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); | |||
| auto root_graph = builder.GetGraph(); | |||
| auto while_op = root_graph->AddNode(op_desc); | |||
| builder.AddDataEdge(data, 0, while_op, 0); | |||
| builder.AddDataEdge(const_op, 0, while_op, 1); | |||
| builder.AddDataEdge(while_op, 0, c1, 0); | |||
| builder.AddDataEdge(while_op, 1, c2, 0); | |||
| sub_graph->SetParentGraph(root_graph); | |||
| sub_graph->SetParentNode(while_op); | |||
| root_graph->AddSubgraph(sub_graph); | |||
| return root_graph; | |||
| } | |||
| TEST_F(UTESTGraphPassesBasePass, while_infershape) { | |||
| NamesToPass names_to_pass; | |||
| auto test_pass = UtestTestPass(); | |||
| names_to_pass.push_back(std::make_pair("test", &test_pass)); | |||
| auto graph = BuildWhileGraph1(); | |||
| auto ge_pass = GEPass(graph); | |||
| auto while_node = graph->FindNode("while"); | |||
| EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); | |||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | |||
| } | |||
| } // namespace ge | |||