|
|
@@ -67,6 +67,21 @@ class UtestTestPass : public BaseNodePass { |
|
|
|
names_to_add_repass_.erase(iter); |
|
|
|
} |
|
|
|
} |
|
|
|
// simulate infershape pass |
|
|
|
if(node->GetType() == WHILE){ |
|
|
|
bool need_repass = false; |
|
|
|
AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass); |
|
|
|
if(!OptionExists(kOptimizeAfterSubGraph)){ |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
if(need_repass){ |
|
|
|
AddImmediateRePassNode(node); |
|
|
|
} |
|
|
|
else{ |
|
|
|
// clear attr on while |
|
|
|
node->GetOpDesc()->DelAttr("need_infer_again_"); |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
void clear() { iter_nodes_.clear(); } |
|
|
@@ -429,6 +444,7 @@ TEST_F(UTESTGraphPassesBasePass, dead_loop) { |
|
|
|
EXPECT_EQ(test_pass.GetRunTimes(), 1007); |
|
|
|
} |
|
|
|
*/ |
|
|
|
|
|
|
|
TEST_F(UTESTGraphPassesBasePass, while_loop) { |
|
|
|
NamesToPass names_to_pass; |
|
|
|
auto test_pass = UtestTestPass(true); |
|
|
@@ -438,4 +454,69 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) { |
|
|
|
auto ge_pass = GEPass(graph); |
|
|
|
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); |
|
|
|
} |
|
|
|
|
|
|
|
/// data1 const |
|
|
|
/// \ / |
|
|
|
/// while |
|
|
|
/// / \ |
|
|
|
/// | | |
|
|
|
/// cast1 cast2 |
|
|
|
ComputeGraphPtr BuildWhileGraph1() { |
|
|
|
// build sub graph |
|
|
|
auto builder_sub = ut::GraphBuilder("sub"); |
|
|
|
auto data_1 = builder_sub.AddNode("data_1", DATA, 0, 1); |
|
|
|
auto data_2 = builder_sub.AddNode("data_2", DATA, 0, 1); |
|
|
|
auto add = builder_sub.AddNode("add", ADD, 2, 1); |
|
|
|
|
|
|
|
builder_sub.AddDataEdge(data_1, 0, add, 0); |
|
|
|
builder_sub.AddDataEdge(data_2, 0, add, 1); |
|
|
|
auto sub_graph = builder_sub.GetGraph(); |
|
|
|
sub_graph->SetName("while_sub"); |
|
|
|
// build root graph |
|
|
|
auto builder = ut::GraphBuilder("g1"); |
|
|
|
auto data = builder.AddNode("data1", DATA, 0, 1); |
|
|
|
auto const_op = builder.AddNode("const_op", CONSTANT, 0, 1); |
|
|
|
auto c1 = builder.AddNode("cast1", CAST, 1, 1); |
|
|
|
auto c2 = builder.AddNode("cast2", CAST, 1, 1); |
|
|
|
// add while op |
|
|
|
auto tensor_desc = std::make_shared<GeTensorDesc>(); |
|
|
|
tensor_desc->SetShape(GeShape({1,1,1,1})); |
|
|
|
tensor_desc->SetFormat(FORMAT_ND); |
|
|
|
tensor_desc->SetDataType(DT_INT32); |
|
|
|
|
|
|
|
auto op_desc = std::make_shared<OpDesc>("while", WHILE); |
|
|
|
for (int i = 0; i < 2; ++i) { |
|
|
|
op_desc->AddInputDesc(tensor_desc->Clone()); |
|
|
|
} |
|
|
|
for (int i = 0; i < 2; ++i) { |
|
|
|
op_desc->AddOutputDesc(tensor_desc->Clone()); |
|
|
|
} |
|
|
|
AttrUtils::SetBool(op_desc,"need_infer_again_", true); |
|
|
|
op_desc->AddSubgraphName(sub_graph->GetName()); |
|
|
|
op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); |
|
|
|
auto root_graph = builder.GetGraph(); |
|
|
|
auto while_op = root_graph->AddNode(op_desc); |
|
|
|
|
|
|
|
builder.AddDataEdge(data, 0, while_op, 0); |
|
|
|
builder.AddDataEdge(const_op, 0, while_op, 1); |
|
|
|
builder.AddDataEdge(while_op, 0, c1, 0); |
|
|
|
builder.AddDataEdge(while_op, 1, c2, 0); |
|
|
|
sub_graph->SetParentGraph(root_graph); |
|
|
|
sub_graph->SetParentNode(while_op); |
|
|
|
root_graph->AddSubgraph(sub_graph); |
|
|
|
return root_graph; |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(UTESTGraphPassesBasePass, while_infershape) { |
|
|
|
NamesToPass names_to_pass; |
|
|
|
auto test_pass = UtestTestPass(); |
|
|
|
names_to_pass.push_back(std::make_pair("test", &test_pass)); |
|
|
|
|
|
|
|
auto graph = BuildWhileGraph1(); |
|
|
|
auto ge_pass = GEPass(graph); |
|
|
|
auto while_node = graph->FindNode("while"); |
|
|
|
EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); |
|
|
|
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace ge |