/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #define protected public #define private public #include "graph/passes/infer_base_pass.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/graph_utils.h" #include "graph_builder_utils.h" #include "inc/external/graph/operator_reg.h" #include "inc/external/graph/operator.h" #include "inc/external/graph/operator_factory.h" #include "inc/graph/operator_factory_impl.h" using namespace std; using namespace testing; namespace ge { class InferBasePassStub : public InferBasePass { public: graphStatus Infer(NodePtr &node) override{ auto op_desc = node->GetOpDesc(); auto input_desc = op_desc->MutableInputDesc(0); auto output_desc = op_desc->MutableOutputDesc(0); if (input_desc->GetShape().GetDims() != output_desc->GetShape().GetDims()) { input_desc->SetShape(output_desc->GetShape()); return GRAPH_NODE_NEED_REPASS; } return GRAPH_SUCCESS; }; private: std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override { return "test SerialTensorInfo"; }; graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override { if (src->GetShape().GetDims() != dst->GetShape().GetDims()) { changed = true; } else { changed = false; } dst->SetShape(src->GetShape()); return GRAPH_SUCCESS; }; graphStatus UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) override { dst->SetShape(src[0]->GetShape()); return GRAPH_SUCCESS; }; graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, GeTensorDescPtr &dst) override { dst->SetShape(src[0]->GetShape()); return GRAPH_SUCCESS; }; }; class UtestGraphInferBasePassStub : public testing::Test { protected: void SetUp() {} void TearDown() {} }; /* * data1 data2 * \ / * merge * | * netoutput */ ut::GraphBuilder TestSubgraphBuilder() { ut::GraphBuilder builder = ut::GraphBuilder("branch_graph"); std::vector shape1 = {1,1}; auto data1 = builder.AddNode("data1_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape1); auto data1_desc = data1->GetOpDesc(); EXPECT_NE(data1_desc, nullptr); AttrUtils::SetInt(data1_desc, "_parent_node_index", 0); std::vector shape2 = {2,2}; auto data2 = builder.AddNode("data2_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape2); auto data2_desc = data2->GetOpDesc(); EXPECT_NE(data2_desc, nullptr); AttrUtils::SetInt(data2_desc, "_parent_node_index", 1); auto merge = builder.AddNode("merge", "Merge", 2, 1); std::vector shape7 = {8,8}; auto netoutput = builder.AddNode("output", NETOUTPUT, 1, 0, FORMAT_NCHW, DT_INT32, shape7); auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0); EXPECT_NE(input0_desc, nullptr); AttrUtils::SetInt(input0_desc, "_parent_node_index", 0); builder.AddDataEdge(data1, 0, merge, 0); builder.AddDataEdge(data2, 0, merge, 1); builder.AddDataEdge(merge, 0, netoutput, 0); return builder; } /* * data1 data2 * \ / * case1 * | * netoutput */ ut::GraphBuilder RootGraphBuilder() { ut::GraphBuilder builder = ut::GraphBuilder("root_graph"); auto data1 = builder.AddNode("data1", "Data", 0, 1); auto data2 = builder.AddNode("data2", "Data", 0, 1); auto case1 = builder.AddNode("case1", CASE, 2, 1); auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); builder.AddDataEdge(data1, 0, case1, 0); builder.AddDataEdge(data2, 0, case1, 1); builder.AddDataEdge(case1, 0, netoutput, 0); auto parent_graph = builder.GetGraph(); auto subgraph_builder = TestSubgraphBuilder(); auto subgraph = subgraph_builder.GetGraph(); case1->GetOpDesc()->AddSubgraphName(subgraph->GetName()); case1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph->GetName()); subgraph->SetParentNode(case1); subgraph->SetParentGraph(parent_graph); EXPECT_EQ(parent_graph->AddSubgraph(subgraph->GetName(), subgraph), GRAPH_SUCCESS); return builder; } TEST_F(UtestGraphInferBasePassStub, infer_base_before_subgraph) { auto builder = RootGraphBuilder(); auto parent_graph = builder.GetGraph(); auto subgraphs = parent_graph->GetAllSubgraphs(); EXPECT_EQ(subgraphs.size(), 1); // check base pass run auto case_node = parent_graph->FindNode("case1"); EXPECT_NE(case_node, nullptr); InferBasePassStub base_pass; EXPECT_EQ(base_pass.Run(case_node), SUCCESS); // check subgraph data update auto data_node = subgraphs[0]->FindNode("data1_1"); auto data_out_0_desc = data_node->GetOpDesc()->MutableOutputDesc(0); auto data_out_0_dims = data_out_0_desc->GetShape().GetDims(); EXPECT_EQ(data_out_0_dims.size(), 4); std::vector data_target_dims = {1, 1, 224, 224}; EXPECT_EQ(data_out_0_dims, data_target_dims); // check peer input update auto netoutput_node = parent_graph->FindNode("netoutput"); EXPECT_NE(netoutput_node, nullptr); auto netoutput_in_0_desc = netoutput_node->GetOpDesc()->MutableInputDesc(0); auto netoutput_in_0_dims = netoutput_in_0_desc->GetShape().GetDims(); EXPECT_EQ(netoutput_in_0_dims.size(), 4); std::vector target_dims = {1, 1, 224, 224}; EXPECT_EQ(netoutput_in_0_dims, target_dims); } TEST_F(UtestGraphInferBasePassStub, infer_base_after_subgraph_need_repass) { auto builder = RootGraphBuilder(); auto parent_graph = builder.GetGraph(); auto subgraphs = parent_graph->GetAllSubgraphs(); EXPECT_EQ(subgraphs.size(), 1); // check base pass run auto case_node = parent_graph->FindNode("case1"); EXPECT_NE(case_node, nullptr); InferBasePassStub base_pass; base_pass.options_[kOptimizeAfterSubGraph] = "yes"; EXPECT_EQ(base_pass.Run(case_node), SUCCESS); // check subgraph data update auto data_node = subgraphs[0]->FindNode("data1_1"); auto data_out_0_desc = data_node->GetOpDesc()->MutableOutputDesc(0); auto data_out_0_dims = data_out_0_desc->GetShape().GetDims(); EXPECT_EQ(data_out_0_dims.size(), 2); std::vector data_target_dims = {1, 1}; EXPECT_EQ(data_out_0_dims, data_target_dims); // check peer input update auto netoutput_node = parent_graph->FindNode("netoutput"); EXPECT_NE(netoutput_node, nullptr); auto netoutput_in_0_desc = netoutput_node->GetOpDesc()->MutableInputDesc(0); auto netoutput_in_0_dims = netoutput_in_0_desc->GetShape().GetDims(); EXPECT_EQ(netoutput_in_0_dims.size(), 4); std::vector target_dims = {1, 1, 224, 224}; EXPECT_EQ(netoutput_in_0_dims, target_dims); } TEST_F(UtestGraphInferBasePassStub, infer_base_after_subgraph_no_repass) { auto builder = RootGraphBuilder(); auto parent_graph = builder.GetGraph(); auto subgraphs = parent_graph->GetAllSubgraphs(); EXPECT_EQ(subgraphs.size(), 1); // check base pass run auto case_node = parent_graph->FindNode("case1"); EXPECT_NE(case_node, nullptr); // update case in shape, do not re_pass auto case_in_shape_no_repass = GeShape({8,8}); case_node->GetOpDesc()->MutableInputDesc(0)->SetShape(case_in_shape_no_repass); InferBasePassStub base_pass; base_pass.options_[kOptimizeAfterSubGraph] = "yes"; EXPECT_EQ(base_pass.Run(case_node), SUCCESS); // check subgraph data update auto data_node = subgraphs[0]->FindNode("data1_1"); auto data_out_0_desc = data_node->GetOpDesc()->MutableOutputDesc(0); auto data_out_0_dims = data_out_0_desc->GetShape().GetDims(); EXPECT_EQ(data_out_0_dims.size(), 2); std::vector data_target_dims = {1, 1}; EXPECT_EQ(data_out_0_dims, data_target_dims); // check peer input update auto netoutput_node = parent_graph->FindNode("netoutput"); EXPECT_NE(netoutput_node, nullptr); auto netoutput_in_0_desc = netoutput_node->GetOpDesc()->MutableInputDesc(0); auto netoutput_in_0_dims = netoutput_in_0_desc->GetShape().GetDims(); EXPECT_EQ(netoutput_in_0_dims.size(), 2); std::vector target_dims = {8,8}; EXPECT_EQ(netoutput_in_0_dims, target_dims); } } // namespace ge