/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include "gtest/gtest.h" #define protected public #include "graph/passes/base_pass.h" #undef protected #include "framework/common/types.h" #include "graph/node.h" #include "graph/utils/graph_utils.h" #include "graph_builder_utils.h" template class std::unordered_set; namespace ge { class UtestTestPass : public BaseNodePass { public: UtestTestPass() = default; UtestTestPass(bool dead_loop) : dead_loop_(dead_loop), run_times_(0) {} Status Run(NodePtr &node) override { ++run_times_; iter_nodes_.push_back(node); auto iter = names_to_add_del_.find(node->GetName()); if (iter != names_to_add_del_.end()) { for (const auto &node_name : iter->second) { auto del_node = node->GetOwnerComputeGraph()->FindNode(node_name); GraphUtils::IsolateNode(del_node, {0}); AddNodeDeleted(del_node); } } iter = names_to_add_repass_.find(node->GetName()); if (iter != names_to_add_repass_.end()) { auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); for (const auto &node_name : iter->second) { for (auto &node_re_pass : all_nodes) { if (node_re_pass->GetName() == node_name) { AddRePassNode(node_re_pass); break; } } } if (!dead_loop_) { names_to_add_repass_.erase(iter); } } iter = names_to_add_repass_immediate_.find(node->GetName()); if (iter != names_to_add_repass_immediate_.end()) { auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); for (const auto &node_name : iter->second) { for (auto &node_re_pass : all_nodes) { if (node_re_pass->GetName() == node_name) { AddImmediateRePassNode(node_re_pass); break; } } } if (!dead_loop_) { names_to_add_repass_immediate_.erase(iter); } } iter = names_to_add_suspend_.find(node->GetName()); if (iter != names_to_add_suspend_.end()) { auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); for (const auto &node_name : iter->second) { for (auto &node_re_pass : all_nodes) { if (node_re_pass->GetName() == node_name) { AddNodeSuspend(node_re_pass); break; } } } if (!dead_loop_) { names_to_add_suspend_.erase(iter); } } iter = names_to_add_resume_.find(node->GetName()); if (iter != names_to_add_resume_.end()) { auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); for (const auto &node_name : iter->second) { for (auto &node_re_pass : all_nodes) { if (node_re_pass->GetName() == node_name) { AddNodeResume(node_re_pass); break; } } } if (!dead_loop_) { names_to_add_resume_.erase(iter); } } // simulate infershape pass if(node->GetType() == WHILE){ bool need_repass = false; AttrUtils::GetBool(node->GetOpDesc(),"_need_infer_again", need_repass); if(!OptionExists(kOptimizeAfterSubGraph)){ return SUCCESS; } if(need_repass){ AttrUtils::SetBool(node->GetOpDesc(),"_need_infer_again", false); AddImmediateRePassNode(node); } else{ // clear attr on while node->GetOpDesc()->DelAttr("_need_infer_again"); } } return SUCCESS; } Status OnSuspendNodesLeaked() override { // resume all node remain in suspend_nodes when leaked auto compute_graph = (iter_nodes_.size() > 0) ? iter_nodes_[0]->GetOwnerComputeGraph() : nullptr; if (compute_graph == nullptr) { return SUCCESS; } for (const auto &node_name : names_to_add_resume_onleaked_) { auto node_to_resume = compute_graph->FindNode(node_name); AddNodeResume(node_to_resume); } return SUCCESS; } void clear() { iter_nodes_.clear(); } std::vector GetIterNodes() { return iter_nodes_; } void AddRePassNodeName(const std::string &iter_node, const std::string &re_pass_node) { names_to_add_repass_[iter_node].insert(re_pass_node); } void AddDelNodeName(const std::string &iter_node, const std::string &del_node) { names_to_add_del_[iter_node].insert(del_node); } void AddRePassImmediateNodeName(const std::string &iter_node, const std::string &re_pass_node) { names_to_add_repass_immediate_[iter_node].insert(re_pass_node); } void AddSuspendNodeName(const std::string &iter_node, const std::string &suspend_node) { names_to_add_suspend_[iter_node].insert(suspend_node); } void AddResumeNodeName(const std::string &iter_node, const std::string &resume_node) { names_to_add_resume_[iter_node].insert(resume_node); } void AddResumeNodeNameOnLeaked(const std::string &resume_node) { names_to_add_resume_onleaked_.insert(resume_node); } unsigned int GetRunTimes() { return run_times_; } private: std::vector iter_nodes_; std::map> names_to_add_del_; std::map> names_to_add_repass_; std::map> names_to_add_repass_immediate_; std::map> names_to_add_suspend_; std::map> names_to_add_resume_; std::unordered_set names_to_add_resume_onleaked_; bool dead_loop_; unsigned int run_times_; }; class TestDelPass : public BaseNodePass { public: Status Run(NodePtr &node) override { return SUCCESS; } }; class UTESTGraphPassesBasePass : public testing::Test { protected: UTESTGraphPassesBasePass() { auto p1 = new UtestTestPass; names_to_pass_.push_back(std::make_pair("test1", p1)); } void SetUp() override { for (auto &name_to_pass : names_to_pass_) { dynamic_cast(name_to_pass.second)->clear(); } } ~UTESTGraphPassesBasePass() override { for (auto &name_to_pass : names_to_pass_) { delete name_to_pass.second; } } NamesToPass names_to_pass_; }; /// reshape1 /// | /// add1 /// / \. /// | | /// data1 const1 ComputeGraphPtr BuildGraph1() { auto builder = ut::GraphBuilder("g1"); auto data = builder.AddNode("data1", DATA, 0, 1); auto a1 = builder.AddNode("add1", ADD, 2, 1); auto c1 = builder.AddNode("const1", CONSTANT, 0, 1); auto r1 = builder.AddNode("reshape1", RESHAPE, 1, 1); builder.AddDataEdge(data, 0, a1, 0); builder.AddDataEdge(c1, 0, a1, 1); builder.AddDataEdge(a1, 0, r1, 0); return builder.GetGraph(); } /// sum1 /// / \. /// / \. /// / \. /// reshape1 addn1 /// | c | /// add1 <--- shape1 /// / \ | /// | | | /// data1 const1 const2 ComputeGraphPtr BuildGraph2() { auto builder = ut::GraphBuilder("g1"); auto data1 = builder.AddNode("data1", DATA, 0, 1); auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); auto const2 = builder.AddNode("const2", CONSTANT, 0, 1); auto add1 = builder.AddNode("add1", ADD, 2, 1); auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1); auto reshape1 = builder.AddNode("reshape1", RESHAPE, 1, 1); auto addn1 = builder.AddNode("addn1", ADDN, 1, 1); auto sum1 = builder.AddNode("sum1", SUM, 2, 1); builder.AddDataEdge(data1, 0, add1, 0); builder.AddDataEdge(const1, 0, add1, 1); builder.AddDataEdge(const2, 0, shape1, 0); builder.AddControlEdge(shape1, add1); builder.AddDataEdge(add1, 0, reshape1, 0); builder.AddDataEdge(shape1, 0, addn1, 0); builder.AddDataEdge(reshape1, 0, sum1, 0); builder.AddDataEdge(addn1, 0, sum1, 1); return builder.GetGraph(); } /// rnextiteration /// | | /// merge /// | /// data1 ComputeGraphPtr BuildGraph3() { auto builder = ut::GraphBuilder("g1"); auto data1 = builder.AddNode("data1", DATA, 0, 1); auto merge1 = builder.AddNode("merge1", MERGE, 2, 1); auto next1 = builder.AddNode("next1", NEXTITERATION, 1, 1); builder.AddDataEdge(data1, 0, merge1, 0); builder.AddDataEdge(merge1, 0, next1, 0); builder.AddDataEdge(next1, 0, merge1, 1); builder.AddControlEdge(merge1, next1); builder.AddControlEdge(next1, merge1); return builder.GetGraph(); } /// cast1--shape1 /// / /// data1 /// \ /// transdata1--shape2 ComputeGraphPtr BuildGraph4() { auto builder = ut::GraphBuilder("g1"); auto data1 = builder.AddNode("data1", DATA, 0, 1); auto cast1 = builder.AddNode("cast1", CAST, 1, 1); auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1); auto transdata1 = builder.AddNode("transdata1", TRANSDATA, 1, 1); auto shape2 = builder.AddNode("shape2", SHAPE, 1, 1); builder.AddDataEdge(data1, 0, cast1, 0); builder.AddDataEdge(data1, 0, transdata1, 0); builder.AddDataEdge(cast1, 0, shape1, 0); builder.AddDataEdge(transdata1, 0, shape2, 0); return builder.GetGraph(); } void CheckIterOrder(UtestTestPass *pass, std::vector> &nodes_layers) { std::unordered_set layer_nodes; size_t layer_index = 0; for (const auto &node : pass->GetIterNodes()) { layer_nodes.insert(node->GetName()); EXPECT_LT(layer_index, nodes_layers.size()); if (layer_nodes == nodes_layers[layer_index]) { layer_index++; layer_nodes.clear(); } } EXPECT_EQ(layer_index, nodes_layers.size()); } /// Op1 /// | /// Merge /// / \. /// Op2 Op3 TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { auto builder = ut::GraphBuilder("g1"); auto merge_node = builder.AddNode("Merge", MERGE, 1, 1); auto node1 = builder.AddNode("Op1", RELU, 1, 1); auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1); auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1); GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0)); GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); EXPECT_EQ(node1->GetOutDataNodes().size(), 1); TestDelPass del_pass; auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1}); EXPECT_EQ(ret, FAILED); OpDescPtr op_desc = std::make_shared("merge", MERGE); NodePtr node = shared_ptr(new (std::nothrow) Node(op_desc, nullptr)); ret = del_pass.IsolateAndDeleteNode(node, {0, -1}); EXPECT_EQ(ret, FAILED); } /// Op1 /// | /// Merge /// / \. /// Op2 Op3 TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { auto builder = ut::GraphBuilder("g1"); auto merge_node = builder.AddNode("Merge", MERGE, 1, 2); auto node1 = builder.AddNode("Op1", RELU, 1, 1); auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1); auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1); GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0)); GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); EXPECT_EQ(node1->GetOutDataNodes().size(), 1); TestDelPass del_pass; auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1}); EXPECT_EQ(ret, SUCCESS); } TEST_F(UTESTGraphPassesBasePass, data_graph) { auto graph = BuildGraph1(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); auto *pass = dynamic_cast(names_to_pass_[0].second); EXPECT_EQ(pass->GetIterNodes().size(), 4); std::vector> layers; layers.push_back({"data1", "const1"}); layers.push_back({"add1"}); layers.push_back({"reshape1"}); CheckIterOrder(pass, layers); } TEST_F(UTESTGraphPassesBasePass, graph_with_control_link) { auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); auto *pass = dynamic_cast(names_to_pass_[0].second); EXPECT_EQ(pass->GetIterNodes().size(), 8); EXPECT_EQ(pass->GetIterNodes().at(3)->GetName(), "shape1"); std::vector> layers; layers.push_back({"data1", "const1", "const2"}); layers.push_back({"shape1"}); layers.push_back({"add1", "addn1", "reshape1"}); layers.push_back({"sum1"}); CheckIterOrder(pass, layers); } TEST_F(UTESTGraphPassesBasePass, re_pass_after) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(); names_to_pass.push_back(std::make_pair("test", &test_pass)); test_pass.AddRePassNodeName("add1", "sum1"); test_pass.AddRePassNodeName("shape1", "sum1"); test_pass.AddRePassNodeName("shape1", "add1"); test_pass.AddRePassNodeName("data1", "add1"); auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); EXPECT_EQ(test_pass.GetIterNodes().size(), 8); } TEST_F(UTESTGraphPassesBasePass, re_pass_before) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(); names_to_pass.push_back(std::make_pair("test", &test_pass)); test_pass.AddRePassNodeName("add1", "data1"); auto graph = BuildGraph1(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); EXPECT_EQ(test_pass.GetIterNodes().size(), 5); EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1"); EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1"); EXPECT_EQ(test_pass.GetIterNodes().at(4)->GetName(), "data1"); } TEST_F(UTESTGraphPassesBasePass, re_pass_before_multi_times) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(); names_to_pass.push_back(std::make_pair("test", &test_pass)); test_pass.AddRePassNodeName("add1", "data1"); test_pass.AddRePassNodeName("add1", "const1"); test_pass.AddRePassNodeName("reshape1", "data1"); auto graph = BuildGraph1(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); EXPECT_EQ(test_pass.GetIterNodes().size(), 6); EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1"); EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1"); } TEST_F(UTESTGraphPassesBasePass, del_after) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(); names_to_pass.push_back(std::make_pair("test", &test_pass)); test_pass.AddDelNodeName("add1", "sum1"); auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); EXPECT_EQ(test_pass.GetIterNodes().size(), 7); } TEST_F(UTESTGraphPassesBasePass, del_after_multiple) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(); names_to_pass.push_back(std::make_pair("test", &test_pass)); test_pass.AddDelNodeName("add1", "sum1"); test_pass.AddDelNodeName("add1", "reshape1"); auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); EXPECT_EQ(test_pass.GetIterNodes().size(), 6); } TEST_F(UTESTGraphPassesBasePass, del_after_break_link) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(); names_to_pass.push_back(std::make_pair("test", &test_pass)); test_pass.AddDelNodeName("shape1", "add1"); test_pass.AddDelNodeName("shape1", "addn1"); test_pass.AddRePassNodeName("shape1", "shape1"); test_pass.AddRePassNodeName("shape1", "reshape1"); test_pass.AddRePassNodeName("shape1", "sum1"); auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); EXPECT_EQ(test_pass.GetIterNodes().size(), 7); } TEST_F(UTESTGraphPassesBasePass, del_self_and_after) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(); names_to_pass.push_back(std::make_pair("test", &test_pass)); test_pass.AddDelNodeName("shape1", "add1"); test_pass.AddDelNodeName("shape1", "addn1"); auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); EXPECT_EQ(test_pass.GetIterNodes().size(), 6); } TEST_F(UTESTGraphPassesBasePass, del_before) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(); names_to_pass.push_back(std::make_pair("test", &test_pass)); test_pass.AddDelNodeName("reshape1", "add1"); test_pass.AddDelNodeName("sum1", "addn1"); auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); EXPECT_EQ(test_pass.GetIterNodes().size(), 8); } TEST_F(UTESTGraphPassesBasePass, re_pass_and_del) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(); names_to_pass.push_back(std::make_pair("test", &test_pass)); test_pass.AddRePassNodeName("add1", "sum1"); test_pass.AddDelNodeName("reshape1", "sum1"); auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); EXPECT_EQ(test_pass.GetIterNodes().size(), 7); } /* TEST_F(UTESTGraphPassesBasePass, dead_loop) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(true); names_to_pass.push_back(std::make_pair("test", &test_pass)); test_pass.AddRePassNodeName("add1", "sum1"); test_pass.AddRePassNodeName("sum1", "add1"); auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); EXPECT_EQ(test_pass.GetRunTimes(), 1007); } */ TEST_F(UTESTGraphPassesBasePass, while_loop) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(true); names_to_pass.push_back(std::make_pair("test", &test_pass)); auto graph = BuildGraph3(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); } /// data1 const /// \ / /// while /// / \. /// | | /// cast1 cast2 ComputeGraphPtr BuildWhileGraph1() { // build sub graph auto builder_sub = ut::GraphBuilder("sub"); auto data_1 = builder_sub.AddNode("data_1", DATA, 0, 1); auto data_2 = builder_sub.AddNode("data_2", DATA, 0, 1); auto add = builder_sub.AddNode("add", ADD, 2, 1); builder_sub.AddDataEdge(data_1, 0, add, 0); builder_sub.AddDataEdge(data_2, 0, add, 1); auto sub_graph = builder_sub.GetGraph(); sub_graph->SetName("while_sub"); // build root graph auto builder = ut::GraphBuilder("g1"); auto data = builder.AddNode("data1", DATA, 0, 1); auto const_op = builder.AddNode("const_op", CONSTANT, 0, 1); auto c1 = builder.AddNode("cast1", CAST, 1, 1); auto c2 = builder.AddNode("cast2", CAST, 1, 1); // add while op auto tensor_desc = std::make_shared(); tensor_desc->SetShape(GeShape({1,1,1,1})); tensor_desc->SetFormat(FORMAT_ND); tensor_desc->SetDataType(DT_INT32); auto op_desc = std::make_shared("while", WHILE); for (int i = 0; i < 2; ++i) { op_desc->AddInputDesc(tensor_desc->Clone()); } for (int i = 0; i < 2; ++i) { op_desc->AddOutputDesc(tensor_desc->Clone()); } AttrUtils::SetBool(op_desc,"_need_infer_again", true); op_desc->AddSubgraphName(sub_graph->GetName()); op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); auto root_graph = builder.GetGraph(); auto while_op = root_graph->AddNode(op_desc); builder.AddDataEdge(data, 0, while_op, 0); builder.AddDataEdge(const_op, 0, while_op, 1); builder.AddDataEdge(while_op, 0, c1, 0); builder.AddDataEdge(while_op, 1, c2, 0); sub_graph->SetParentGraph(root_graph); sub_graph->SetParentNode(while_op); root_graph->AddSubgraph(sub_graph); return root_graph; } TEST_F(UTESTGraphPassesBasePass, while_infershape) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(); names_to_pass.push_back(std::make_pair("test", &test_pass)); auto graph = BuildWhileGraph1(); auto ge_pass = GEPass(graph); auto while_node = graph->FindNode("while"); EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); } TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) { auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); auto *test_pass = dynamic_cast(names_to_pass_[0].second); // repass pre_node immediately test_pass->AddRePassImmediateNodeName("reshape1", "add1"); EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); EXPECT_EQ(test_pass->GetIterNodes().size(), 9);// todo std::vector> layers; layers.push_back({"data1", "const1", "const2"}); layers.push_back({"shape1"}); layers.push_back({"add1", "addn1"}); layers.push_back({"reshape1", "add1", "sum1"}); CheckIterOrder(test_pass, layers); } TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) { auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); auto *test_pass = dynamic_cast(names_to_pass_[0].second); // repass cur_node immediately test_pass->AddRePassImmediateNodeName("reshape1", "reshape1"); EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); EXPECT_EQ(test_pass->GetIterNodes().size(), 9); std::vector> layers; layers.push_back({"data1", "const1", "const2"}); layers.push_back({"shape1"}); layers.push_back({"add1", "addn1"}); layers.push_back({"reshape1"}); layers.push_back({"reshape1", "sum1"}); CheckIterOrder(test_pass, layers); } TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) { auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); auto *test_pass = dynamic_cast(names_to_pass_[0].second); // repass next_node immediately test_pass->AddRePassImmediateNodeName("reshape1", "sum1"); // repass node after next_node immediately test_pass->AddRePassImmediateNodeName("add1", "sum1"); EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); EXPECT_EQ(test_pass->GetIterNodes().size(), 8); std::vector> layers; layers.push_back({"data1", "const1", "const2"}); layers.push_back({"shape1"}); layers.push_back({"add1", "addn1"}); layers.push_back({"reshape1", "sum1"}); CheckIterOrder(test_pass, layers); } /** * A->B->C * if node B suspend its pre_node A, and C resume A, it is a useless operation, so iter_order should follow normal order * when C resuem A, A will pass again. */ TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_C_resume_A) { auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); auto *test_pass = dynamic_cast(names_to_pass_[0].second); // add1->reshape1->sum1 test_pass->AddSuspendNodeName("reshape1", "add1"); test_pass->AddResumeNodeName("sum1", "add1"); EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); EXPECT_EQ(test_pass->GetIterNodes().size(), 9); std::vector> layers; layers.push_back({"data1", "const1", "const2"}); layers.push_back({"shape1"}); layers.push_back({"add1", "addn1"}); layers.push_back({"reshape1", "sum1"}); layers.push_back({"add1"}); CheckIterOrder(test_pass, layers); } /** * A->B->C * if node B suspend its pre_node A, and B resume A, it is a useless operation, so iter_order should follow normal order * when B resuem A, A will pass again. */ TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_B_resume_A) { auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); auto *test_pass = dynamic_cast(names_to_pass_[0].second); // add1->reshape1->sum1 test_pass->AddSuspendNodeName("reshape1", "add1"); test_pass->AddResumeNodeName("reshape1", "add1"); EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); EXPECT_EQ(test_pass->GetIterNodes().size(), 9); std::vector> layers; layers.push_back({"data1", "const1", "const2"}); layers.push_back({"shape1"}); layers.push_back({"add1", "addn1"}); layers.push_back({"reshape1", "sum1", "add1"}); CheckIterOrder(test_pass, layers); } /** * A->B->C * if node B resume C(which is not suspended), it is a useless operation, C will not pass. */ TEST_F(UTESTGraphPassesBasePass, B_resume_node_not_suspended) { auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); auto *test_pass = dynamic_cast(names_to_pass_[0].second); // add1->reshape1->sum1 test_pass->AddResumeNodeName("reshape1", "sum1"); EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); EXPECT_EQ(test_pass->GetIterNodes().size(), 8); std::vector> layers; layers.push_back({"data1", "const1", "const2"}); layers.push_back({"shape1"}); layers.push_back({"add1", "addn1"}); layers.push_back({"reshape1", "sum1"}); CheckIterOrder(test_pass, layers); } /** * A->B->C * if node B suspend its pre_node A, it is a useless operation, so iter_order should follow normal order * because nobody resume it ,which means A is a leaked node, so return fail */ TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_nobody_resume_it_return_failed) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(); names_to_pass.push_back(std::make_pair("test", &test_pass)); // suspend pre_node immediately test_pass.AddSuspendNodeName("reshape1", "add1"); auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), INTERNAL_ERROR); } /** * A->B->C * if node B suspend its pre_node A, it is a useless operation, * so iter_order should follow normal order * resume A on leaked, which means A will pass again */ TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_resume_it_onleaked) { auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); auto *test_pass = dynamic_cast(names_to_pass_[0].second); // suspend pre_node immediately test_pass->AddSuspendNodeName("reshape1", "add1"); test_pass->AddResumeNodeNameOnLeaked("add1"); EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); std::vector> layers; layers.push_back({"data1", "const1", "const2"}); layers.push_back({"shape1"}); layers.push_back({"add1", "addn1"}); layers.push_back({"reshape1", "sum1"}); layers.push_back({"add1"}); CheckIterOrder(test_pass, layers); } /// cast1--shape1 /// / /// data1 /// \ /// transdata1--shape2 /** * suspend cur node * cast1 suspend itself, shape2 resume cast1 * iter order follows : data1; cast1,transdata1; shape2; cast1 ; shape1 */ TEST_F(UTESTGraphPassesBasePass, cast1_suspend_cur_node_shape2_resume_cast1) { auto graph = BuildGraph4(); auto ge_pass = GEPass(graph); auto *test_pass = dynamic_cast(names_to_pass_[0].second); // suspend pre_node immediately test_pass->AddSuspendNodeName("cast1", "cast1"); test_pass->AddResumeNodeName("shape2", "cast1"); EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); EXPECT_EQ(test_pass->GetIterNodes().size(), 6); std::vector> layers; layers.push_back({"data1"}); layers.push_back({"cast1","transdata1"}); layers.push_back({"shape2"}); layers.push_back({"cast1", "shape1"}); CheckIterOrder(test_pass, layers); } /** * suspend cur node * cast1 suspend itself, then resume cast1 * iter order follows : data1; cast1,cast1,transdata1; shape2; shape1. */ TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_itself) { auto graph = BuildGraph4(); auto ge_pass = GEPass(graph); auto *test_pass = dynamic_cast(names_to_pass_[0].second); // suspend pre_node immediately test_pass->AddSuspendNodeName("cast1", "cast1"); test_pass->AddResumeNodeName("cast1", "cast1"); EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); EXPECT_EQ(test_pass->GetIterNodes().size(), 6); std::vector> layers; layers.push_back({"data1"}); layers.push_back({"cast1","transdata1","cast1","shape1", "shape2"}); CheckIterOrder(test_pass, layers); } /** * suspend cur node * cast1 suspend itself, then resume cast1 on leaked * iter order follows : data1; cast1,cast1,transdata1; shape2; shape1. */ TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_onleaked) { auto graph = BuildGraph4(); auto ge_pass = GEPass(graph); auto *test_pass = dynamic_cast(names_to_pass_[0].second); // suspend pre_node immediately test_pass->AddSuspendNodeName("cast1", "cast1"); test_pass->AddResumeNodeNameOnLeaked("cast1"); EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); EXPECT_EQ(test_pass->GetIterNodes().size(), 6); std::vector> layers; layers.push_back({"data1"}); layers.push_back({"cast1","transdata1", "shape2"}); layers.push_back({"cast1","shape1"}); CheckIterOrder(test_pass, layers); } /** * suspend next node * data1 suspend cast1, then resume cast1 on leaked * iter order follows : data1; transdata1, shape2; cast1, shape1. */ TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_resume_cast1_onleaked) { auto graph = BuildGraph4(); auto ge_pass = GEPass(graph); auto *test_pass = dynamic_cast(names_to_pass_[0].second); // suspend pre_node immediately test_pass->AddSuspendNodeName("data1", "cast1"); test_pass->AddResumeNodeNameOnLeaked("cast1"); EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); EXPECT_EQ(test_pass->GetIterNodes().size(), 5); std::vector> layers; layers.push_back({"data1"}); layers.push_back({"transdata1", "shape2"}); layers.push_back({"cast1","shape1"}); CheckIterOrder(test_pass, layers); } /** * suspend next node * data1 suspend cast1, nobody resume it * iter order follows : data1; transdata1, shape2; * run ret is failed ,because node leaked */ TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_nobody_resume) { auto graph = BuildGraph4(); auto ge_pass = GEPass(graph); auto *test_pass = dynamic_cast(names_to_pass_[0].second); // suspend pre_node immediately test_pass->AddSuspendNodeName("data1", "cast1"); EXPECT_EQ(ge_pass.Run(names_to_pass_), INTERNAL_ERROR); EXPECT_EQ(test_pass->GetIterNodes().size(), 3); } /* TEST_F(UTESTGraphPassesBasePass, suspend_pre_node) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(); names_to_pass.push_back(std::make_pair("test", &test_pass)); // repass next_node immediately test_pass.AddRePassNodeName("reshape1", "sum1"); // repass node after next_node immediately test_pass.AddRePassNodeName("add1", "sum1"); auto graph = BuildGraph2(); auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); EXPECT_EQ(test_pass.GetIterNodes().size(), 8);// todo std::vector> layers; layers.push_back({"data1", "const1", "const2"}); layers.push_back({"shape1"}); layers.push_back({"add1", "addn1"}); layers.push_back({"reshape1", "sum1"}); CheckIterOrder(&test_pass, layers); }*/ } // namespace ge