/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include "graph/passes/infer_base_pass.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/graph_utils.h" #include "graph_builder_utils.h" using namespace std; using namespace testing; namespace ge { class ChildPassBuilder; static const char *kInferTimes = "infer_times"; class InferBasePassStub : public InferBasePass { public: friend class ChildPassBuilder; graphStatus Infer(NodePtr &node) override{ call_infer_times++; for (size_t i = 0; i < node->GetOutDataNodesSize(); ++i) { auto output_td = node->GetOpDesc()->MutableOutputDesc(i); int times = 0; AttrUtils::GetInt(output_td, kInferTimes, times); AttrUtils::SetInt(output_td, kInferTimes, times + 1); } return infer_result_; }; int32_t call_infer_times = 0; int32_t call_update_tensor_desc_times = 0; int32_t call_update_from_subgraph_times = 0; int32_t call_update_from_subgraph_multi_dims_times = 0; std::vector> update_td_pairs; private: bool NeedInfer(const NodePtr &node) const override { return need_infer_; }; std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override { return "test SerialTensorInfo"; }; graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override { call_update_tensor_desc_times++; changed = td_changed_; int times = 0; if (AttrUtils::GetInt(src, kInferTimes, times)) { AttrUtils::SetInt(dst, kInferTimes, times); } update_td_pairs.emplace_back(src, dst); return GRAPH_SUCCESS; }; graphStatus UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) override { call_update_from_subgraph_times++; return GRAPH_SUCCESS; }; graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, GeTensorDescPtr &dst) override { call_update_from_subgraph_multi_dims_times++; return GRAPH_SUCCESS; }; bool td_changed_; bool need_infer_; graphStatus infer_result_; }; class ChildPassBuilder { public: ChildPassBuilder &SetNeedInferFlag(bool flag) { need_infer_ = flag; return *this; } ChildPassBuilder &SetInferResult(graphStatus ret) { infer_result_ = ret; return *this; } ChildPassBuilder &SetTdChangedFlag(bool changed_flag) { td_changed_ = changed_flag; return *this; } InferBasePassStub Build() { InferBasePassStub ib; ib.td_changed_ = td_changed_; ib.need_infer_ = need_infer_; ib.infer_result_ = infer_result_; return ib; } private: bool td_changed_ = false; bool need_infer_ = true; graphStatus infer_result_ = GRAPH_SUCCESS; }; class UtestGraphInferBasePassStub : public testing::Test { protected: void SetUp() {} void TearDown() {} }; /* * data1 data2 * \ / * sub1 * | * netoutput */ ut::GraphBuilder TestSubgraphBuilder() { ut::GraphBuilder builder = ut::GraphBuilder("branch_graph"); std::vector shape1 = {1,1}; auto data1 = builder.AddNode("data1_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape1); auto data1_desc = data1->GetOpDesc(); EXPECT_NE(data1_desc, nullptr); AttrUtils::SetInt(data1_desc, "_parent_node_index", 0); std::vector shape2 = {2,2}; auto data2 = builder.AddNode("data2_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape2); auto data2_desc = data2->GetOpDesc(); EXPECT_NE(data2_desc, nullptr); AttrUtils::SetInt(data2_desc, "_parent_node_index", 1); auto sub1 = builder.AddNode("Sub", "Sub", 2, 1); std::vector shape7 = {8,8}; auto netoutput = builder.AddNode("output", NETOUTPUT, 1, 0, FORMAT_NCHW, DT_INT32, shape7); auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0); EXPECT_NE(input0_desc, nullptr); AttrUtils::SetInt(input0_desc, "_parent_node_index", 0); builder.AddDataEdge(data1, 0, sub1, 0); builder.AddDataEdge(data2, 0, sub1, 1); builder.AddDataEdge(sub1, 0, netoutput, 0); return builder; } /* * data1 data2 * \ / * case1 * | * netoutput */ ut::GraphBuilder RootGraphBuilder() { ut::GraphBuilder builder = ut::GraphBuilder("root_graph"); auto data1 = builder.AddNode("data1", "Data", 0, 1); auto data2 = builder.AddNode("data2", "Data", 0, 1); auto case1 = builder.AddNode("case1", CASE, 2, 1); auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); builder.AddDataEdge(data1, 0, case1, 0); builder.AddDataEdge(data2, 0, case1, 1); builder.AddDataEdge(case1, 0, netoutput, 0); auto parent_graph = builder.GetGraph(); auto subgraph_builder = TestSubgraphBuilder(); auto subgraph = subgraph_builder.GetGraph(); case1->GetOpDesc()->AddSubgraphName(subgraph->GetName()); case1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph->GetName()); subgraph->SetParentNode(case1); subgraph->SetParentGraph(parent_graph); EXPECT_EQ(parent_graph->AddSubgraph(subgraph->GetName(), subgraph), GRAPH_SUCCESS); return builder; } /* * data1 data2 * \ / * add1 * | * netoutput */ ut::GraphBuilder NoSubgraphBuilder() { ut::GraphBuilder builder = ut::GraphBuilder("no_subgraph"); auto data1 = builder.AddNode("data1", "Data", 0, 1); auto data2 = builder.AddNode("data2", "Data", 0, 1); auto add1 = builder.AddNode("add1", ADD, 2, 1); auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); builder.AddDataEdge(data1, 0, add1, 0); builder.AddDataEdge(data2, 0, add1, 1); builder.AddDataEdge(add1, 0, netoutput, 0); return builder; } TEST_F(UtestGraphInferBasePassStub, CallInfer_WhenNeedInferReturnTrue) { auto builder = NoSubgraphBuilder(); auto test_graph = builder.GetGraph(); auto add_node = test_graph->FindNode("add1"); EXPECT_NE(add_node, nullptr); ChildPassBuilder pass_builder; auto stub_base_pass = pass_builder.Build(); // NeedInfer return true EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS); EXPECT_EQ(stub_base_pass.call_infer_times, 1); int times = -1; EXPECT_TRUE(AttrUtils::GetInt(add_node->GetOpDesc()->GetOutputDescPtr(0), kInferTimes, times)); EXPECT_EQ(times, 1); } TEST_F(UtestGraphInferBasePassStub, NotCallInfer_WhenNeedInferReturnFalse) { auto builder = NoSubgraphBuilder(); auto test_graph = builder.GetGraph(); auto add_node = test_graph->FindNode("add1"); EXPECT_NE(add_node, nullptr); ChildPassBuilder pass_builder; auto stub_base_pass = pass_builder.SetNeedInferFlag(false).Build(); // NeedInfer return false EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS); EXPECT_EQ(stub_base_pass.call_infer_times, 0); int times = -1; EXPECT_FALSE(AttrUtils::GetInt(add_node->GetOpDesc()->GetOutputDescPtr(0), kInferTimes, times)); } TEST_F(UtestGraphInferBasePassStub, NotAddCurNodeRepass_CallUpdatePeerNode_WhenInferReturnSuccess) { auto builder = NoSubgraphBuilder(); auto test_graph = builder.GetGraph(); auto add_node = test_graph->FindNode("add1"); auto netoutput = test_graph->FindNode("netoutput"); EXPECT_NE(add_node, nullptr); EXPECT_NE(netoutput, nullptr); ChildPassBuilder pass_builder; auto stub_base_pass = pass_builder.Build(); EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS); EXPECT_EQ(stub_base_pass.call_infer_times, 1); EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 1); std::vector> expected_updated_tensor_desc_pairs = { {add_node->GetOpDesc()->MutableOutputDesc(0), netoutput->GetOpDesc()->MutableInputDesc(0)}}; EXPECT_EQ(stub_base_pass.update_td_pairs, expected_updated_tensor_desc_pairs); EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set({})); } TEST_F(UtestGraphInferBasePassStub, AddCurNodeRepass_NotCallUpdatePeerNode_WhenInferReturnNeedRepass) { auto builder = NoSubgraphBuilder(); auto test_graph = builder.GetGraph(); auto add_node = test_graph->FindNode("add1"); EXPECT_NE(add_node, nullptr); ChildPassBuilder pass_builder; auto stub_base_pass = pass_builder.SetInferResult(GRAPH_NODE_NEED_REPASS).Build(); // do re_pass EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS); EXPECT_EQ(stub_base_pass.call_infer_times, 1); EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 0); // EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set({add_node})); } TEST_F(UtestGraphInferBasePassStub, NotAddPeerNodeRepass_AfterUpdatePeerNode_WhenUnchanged) { auto builder = NoSubgraphBuilder(); auto test_graph = builder.GetGraph(); auto add_node = test_graph->FindNode("add1"); auto netoutput = test_graph->FindNode("netoutput"); EXPECT_NE(add_node, nullptr); EXPECT_NE(netoutput, nullptr); ChildPassBuilder pass_builder; auto stub_base_pass = pass_builder.Build(); EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS); EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 1); EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set({})); int times = -1; EXPECT_TRUE(AttrUtils::GetInt(add_node->GetOpDesc()->GetOutputDescPtr(0), kInferTimes, times)); EXPECT_EQ(times, 1); times = -1; EXPECT_TRUE(AttrUtils::GetInt(netoutput->GetOpDesc()->GetInputDescPtr(0), kInferTimes, times)); EXPECT_EQ(times, 1); } TEST_F(UtestGraphInferBasePassStub, AddPeerNodeRepass_AfterUpdatePeerNode_WhenChanged) { auto builder = NoSubgraphBuilder(); auto test_graph = builder.GetGraph(); auto add_node = test_graph->FindNode("add1"); auto netoutput = test_graph->FindNode("netoutput"); EXPECT_NE(add_node, nullptr); EXPECT_NE(netoutput, nullptr); ChildPassBuilder pass_builder; auto stub_base_pass = pass_builder.SetTdChangedFlag(true).Build(); EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS); EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 1); // EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set({netoutput})); } TEST_F(UtestGraphInferBasePassStub, TestUpdateSubgraphData_WhenBeforeSubgraph) { auto builder = RootGraphBuilder(); auto parent_graph = builder.GetGraph(); auto subgraphs = parent_graph->GetAllSubgraphs(); EXPECT_EQ(subgraphs.size(), 1); auto case_node = parent_graph->FindNode("case1"); auto data1 = subgraphs[0]->FindNode("data1_1"); auto data2 = subgraphs[0]->FindNode("data2_1"); EXPECT_NE(case_node, nullptr); EXPECT_NE(data1, nullptr); EXPECT_NE(data2, nullptr); ChildPassBuilder pass_builder; auto stub_base_pass = pass_builder.SetInferResult(GRAPH_NODE_NEED_REPASS).Build(); EXPECT_EQ(stub_base_pass.Run(case_node), SUCCESS); // when GRAPH_NODE_NEED_REPASS, not update peer node, only update two data, update input and output, 2*2 EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 4); std::vector> expected_updated_tensor_desc_pairs = { {case_node->GetOpDesc()->MutableInputDesc(0), data1->GetOpDesc()->MutableInputDesc(0)}, {case_node->GetOpDesc()->MutableInputDesc(0), data1->GetOpDesc()->MutableOutputDesc(0)}, {case_node->GetOpDesc()->MutableInputDesc(1), data2->GetOpDesc()->MutableInputDesc(0)}, {case_node->GetOpDesc()->MutableInputDesc(1), data2->GetOpDesc()->MutableOutputDesc(0)}, }; EXPECT_EQ(stub_base_pass.update_td_pairs, expected_updated_tensor_desc_pairs); } TEST_F(UtestGraphInferBasePassStub, TestUpdateParentNodeOutput_WhenAfterSubgraph) { auto builder = RootGraphBuilder(); auto parent_graph = builder.GetGraph(); auto subgraphs = parent_graph->GetAllSubgraphs(); EXPECT_EQ(subgraphs.size(), 1); auto case_node = parent_graph->FindNode("case1"); EXPECT_NE(case_node, nullptr); ChildPassBuilder pass_builder; auto stub_base_pass = pass_builder.Build(); stub_base_pass.SetOption(kOptimizeAfterSubGraph, ""); EXPECT_EQ(stub_base_pass.Run(case_node), SUCCESS); EXPECT_EQ(stub_base_pass.call_update_from_subgraph_times, 1); EXPECT_EQ(stub_base_pass.call_update_from_subgraph_multi_dims_times, 0); } TEST_F(UtestGraphInferBasePassStub, TestUpdateParentNodeOutputForMultiDims_WhenAfterSubgraph) { auto builder = RootGraphBuilder(); auto parent_graph = builder.GetGraph(); auto subgraphs = parent_graph->GetAllSubgraphs(); EXPECT_EQ(subgraphs.size(), 1); auto case_node = parent_graph->FindNode("case1"); auto set_ret = AttrUtils::SetInt(case_node->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2); EXPECT_EQ(set_ret, true); EXPECT_NE(case_node, nullptr); ChildPassBuilder pass_builder; auto stub_base_pass = pass_builder.Build(); stub_base_pass.SetOption(kOptimizeAfterSubGraph, ""); EXPECT_EQ(stub_base_pass.Run(case_node), SUCCESS); EXPECT_EQ(stub_base_pass.call_update_from_subgraph_times, 0); EXPECT_EQ(stub_base_pass.call_update_from_subgraph_multi_dims_times, 1); } } // namespace ge