|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232 |
- /**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include <gtest/gtest.h>
-
- #define protected public
- #define private public
- #include "graph/passes/infer_base_pass.h"
- #include "graph/utils/tensor_utils.h"
- #include "graph/utils/graph_utils.h"
- #include "graph_builder_utils.h"
-
- #include "inc/external/graph/operator_reg.h"
- #include "inc/external/graph/operator.h"
- #include "inc/external/graph/operator_factory.h"
- #include "inc/graph/operator_factory_impl.h"
-
-
- using namespace std;
- using namespace testing;
- namespace ge {
-
- class InferBasePassStub : public InferBasePass {
- public:
- graphStatus Infer(NodePtr &node) override{
- auto op_desc = node->GetOpDesc();
- auto input_desc = op_desc->MutableInputDesc(0);
- auto output_desc = op_desc->MutableOutputDesc(0);
- if (input_desc->GetShape().GetDims() != output_desc->GetShape().GetDims()) {
- input_desc->SetShape(output_desc->GetShape());
- return GRAPH_NODE_NEED_REPASS;
- }
- return GRAPH_SUCCESS;
- };
- private:
- std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override { return "test SerialTensorInfo"; };
- graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override {
- if (src->GetShape().GetDims() != dst->GetShape().GetDims()) {
- changed = true;
- } else {
- changed = false;
- }
- dst->SetShape(src->GetShape());
- return GRAPH_SUCCESS;
- };
- graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override {
- dst->SetShape(src[0]->GetShape());
- return GRAPH_SUCCESS;
- };
- graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
- GeTensorDescPtr &dst) override {
- dst->SetShape(src[0]->GetShape());
- return GRAPH_SUCCESS;
- };
- };
-
- class UtestGraphInferBasePassStub : public testing::Test {
- protected:
- void SetUp() {}
- void TearDown() {}
- };
-
- /*
- * data1 data2
- * \ /
- * merge
- * |
- * netoutput
- */
- ut::GraphBuilder TestSubgraphBuilder() {
- ut::GraphBuilder builder = ut::GraphBuilder("branch_graph");
- std::vector<int64_t> shape1 = {1,1};
- auto data1 = builder.AddNode("data1_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape1);
- auto data1_desc = data1->GetOpDesc();
- EXPECT_NE(data1_desc, nullptr);
- AttrUtils::SetInt(data1_desc, "_parent_node_index", 0);
- std::vector<int64_t> shape2 = {2,2};
- auto data2 = builder.AddNode("data2_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape2);
- auto data2_desc = data2->GetOpDesc();
- EXPECT_NE(data2_desc, nullptr);
- AttrUtils::SetInt(data2_desc, "_parent_node_index", 1);
-
- auto merge = builder.AddNode("merge", "Merge", 2, 1);
- std::vector<int64_t> shape7 = {8,8};
- auto netoutput = builder.AddNode("output", NETOUTPUT, 1, 0, FORMAT_NCHW, DT_INT32, shape7);
- auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0);
- EXPECT_NE(input0_desc, nullptr);
- AttrUtils::SetInt(input0_desc, "_parent_node_index", 0);
-
- builder.AddDataEdge(data1, 0, merge, 0);
- builder.AddDataEdge(data2, 0, merge, 1);
- builder.AddDataEdge(merge, 0, netoutput, 0);
- return builder;
- }
-
- /*
- * data1 data2
- * \ /
- * case1
- * |
- * netoutput
- */
- ut::GraphBuilder RootGraphBuilder() {
- ut::GraphBuilder builder = ut::GraphBuilder("root_graph");
- auto data1 = builder.AddNode("data1", "Data", 0, 1);
- auto data2 = builder.AddNode("data2", "Data", 0, 1);
- auto case1 = builder.AddNode("case1", CASE, 2, 1);
- auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
- builder.AddDataEdge(data1, 0, case1, 0);
- builder.AddDataEdge(data2, 0, case1, 1);
- builder.AddDataEdge(case1, 0, netoutput, 0);
-
- auto parent_graph = builder.GetGraph();
- auto subgraph_builder = TestSubgraphBuilder();
- auto subgraph = subgraph_builder.GetGraph();
- case1->GetOpDesc()->AddSubgraphName(subgraph->GetName());
- case1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph->GetName());
- subgraph->SetParentNode(case1);
- subgraph->SetParentGraph(parent_graph);
- EXPECT_EQ(parent_graph->AddSubgraph(subgraph->GetName(), subgraph), GRAPH_SUCCESS);
- return builder;
- }
-
- TEST_F(UtestGraphInferBasePassStub, infer_base_before_subgraph) {
- auto builder = RootGraphBuilder();
- auto parent_graph = builder.GetGraph();
- auto subgraphs = parent_graph->GetAllSubgraphs();
- EXPECT_EQ(subgraphs.size(), 1);
-
- // check base pass run
- auto case_node = parent_graph->FindNode("case1");
- EXPECT_NE(case_node, nullptr);
- InferBasePassStub base_pass;
- EXPECT_EQ(base_pass.Run(case_node), SUCCESS);
-
- // check subgraph data update
- auto data_node = subgraphs[0]->FindNode("data1_1");
- auto data_out_0_desc = data_node->GetOpDesc()->MutableOutputDesc(0);
- auto data_out_0_dims = data_out_0_desc->GetShape().GetDims();
- EXPECT_EQ(data_out_0_dims.size(), 4);
- std::vector<int64_t> data_target_dims = {1, 1, 224, 224};
- EXPECT_EQ(data_out_0_dims, data_target_dims);
-
- // check peer input update
- auto netoutput_node = parent_graph->FindNode("netoutput");
- EXPECT_NE(netoutput_node, nullptr);
- auto netoutput_in_0_desc = netoutput_node->GetOpDesc()->MutableInputDesc(0);
- auto netoutput_in_0_dims = netoutput_in_0_desc->GetShape().GetDims();
- EXPECT_EQ(netoutput_in_0_dims.size(), 4);
- std::vector<int64_t> target_dims = {1, 1, 224, 224};
- EXPECT_EQ(netoutput_in_0_dims, target_dims);
- }
-
- TEST_F(UtestGraphInferBasePassStub, infer_base_after_subgraph_need_repass) {
- auto builder = RootGraphBuilder();
- auto parent_graph = builder.GetGraph();
- auto subgraphs = parent_graph->GetAllSubgraphs();
- EXPECT_EQ(subgraphs.size(), 1);
-
- // check base pass run
- auto case_node = parent_graph->FindNode("case1");
- EXPECT_NE(case_node, nullptr);
- InferBasePassStub base_pass;
- base_pass.options_[kOptimizeAfterSubGraph] = "yes";
- EXPECT_EQ(base_pass.Run(case_node), SUCCESS);
-
- // check subgraph data update
- auto data_node = subgraphs[0]->FindNode("data1_1");
- auto data_out_0_desc = data_node->GetOpDesc()->MutableOutputDesc(0);
- auto data_out_0_dims = data_out_0_desc->GetShape().GetDims();
- EXPECT_EQ(data_out_0_dims.size(), 2);
- std::vector<int64_t> data_target_dims = {1, 1};
- EXPECT_EQ(data_out_0_dims, data_target_dims);
-
- // check peer input update
- auto netoutput_node = parent_graph->FindNode("netoutput");
- EXPECT_NE(netoutput_node, nullptr);
- auto netoutput_in_0_desc = netoutput_node->GetOpDesc()->MutableInputDesc(0);
- auto netoutput_in_0_dims = netoutput_in_0_desc->GetShape().GetDims();
- EXPECT_EQ(netoutput_in_0_dims.size(), 4);
- std::vector<int64_t> target_dims = {1, 1, 224, 224};
- EXPECT_EQ(netoutput_in_0_dims, target_dims);
- }
-
-
- TEST_F(UtestGraphInferBasePassStub, infer_base_after_subgraph_no_repass) {
- auto builder = RootGraphBuilder();
- auto parent_graph = builder.GetGraph();
- auto subgraphs = parent_graph->GetAllSubgraphs();
- EXPECT_EQ(subgraphs.size(), 1);
-
- // check base pass run
- auto case_node = parent_graph->FindNode("case1");
- EXPECT_NE(case_node, nullptr);
- // update case in shape, do not re_pass
- auto case_in_shape_no_repass = GeShape({8,8});
- case_node->GetOpDesc()->MutableInputDesc(0)->SetShape(case_in_shape_no_repass);
- InferBasePassStub base_pass;
- base_pass.options_[kOptimizeAfterSubGraph] = "yes";
- EXPECT_EQ(base_pass.Run(case_node), SUCCESS);
-
- // check subgraph data update
- auto data_node = subgraphs[0]->FindNode("data1_1");
- auto data_out_0_desc = data_node->GetOpDesc()->MutableOutputDesc(0);
- auto data_out_0_dims = data_out_0_desc->GetShape().GetDims();
- EXPECT_EQ(data_out_0_dims.size(), 2);
- std::vector<int64_t> data_target_dims = {1, 1};
- EXPECT_EQ(data_out_0_dims, data_target_dims);
-
- // check peer input update
- auto netoutput_node = parent_graph->FindNode("netoutput");
- EXPECT_NE(netoutput_node, nullptr);
- auto netoutput_in_0_desc = netoutput_node->GetOpDesc()->MutableInputDesc(0);
- auto netoutput_in_0_dims = netoutput_in_0_desc->GetShape().GetDims();
- EXPECT_EQ(netoutput_in_0_dims.size(), 2);
- std::vector<int64_t> target_dims = {8,8};
- EXPECT_EQ(netoutput_in_0_dims, target_dims);
- }
- } // namespace ge
|