/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include "graph/passes/subgraph_pass.h" #include "inc/pass_manager.h" namespace ge { namespace { class UtestGraphPassesSubgraphPass : public testing::Test { protected: void SetUp() {} void TearDown() {} }; OpDescPtr CreateOpDesc(const std::string name, const std::string type, uint32_t input_num, uint32_t output_num) { OpDescPtr op_desc = std::shared_ptr(new (std::nothrow) OpDesc(name, type)); if (op_desc == nullptr) { return nullptr; } for (uint32_t i = 0; i < input_num; i++) { op_desc->AddInputDesc(GeTensorDesc()); } for (uint32_t i = 0; i < output_num; i++) { op_desc->AddOutputDesc(GeTensorDesc()); } return op_desc; } bool CheckMemcpyExist(const ComputeGraphPtr &graph) { for (const auto &node : graph->GetDirectNode()) { if (node->GetType() == IDENTITY) { return true; } } return false; } uint32_t CheckMemcpyNum(const ComputeGraphPtr &graph) { uint32_t num = 0; for (const auto &node : graph->GetDirectNode()) { if (node->GetType() == IDENTITY) { num++; } } return num; } } // namespace /// /// ****** root_graph ****** * ****** subgraph branch1 ***** * ****** subgraph branch2 ***** /// * * /// Case * Const * Data /// / \ * | * | /// data_0 data_1 * NetOutput * NetOutput /// * * /// ****** root_graph ****** * ****** subgraph branch1 ***** * ****** subgraph branch2 ***** /// TEST(UtestGraphPassesSubgraphPass, add_memcpy_success) { ComputeGraphPtr graph = std::make_shared("add_memcpy_success"); NodePtr func_node = graph->AddNode(CreateOpDesc("Case", CASE, 2, 1)); NodePtr data_node_0 = graph->AddNode(CreateOpDesc("data_0", DATA, 1, 1)); NodePtr data_node_1 = graph->AddNode(CreateOpDesc("data_1", DATA, 1, 1)); EXPECT_EQ(GraphUtils::AddEdge(data_node_0->GetOutDataAnchor(0), func_node->GetInDataAnchor(0)), GRAPH_SUCCESS); EXPECT_EQ(GraphUtils::AddEdge(data_node_1->GetOutDataAnchor(0), func_node->GetInDataAnchor(1)), GRAPH_SUCCESS); std::string subgraph_name_1 = "instance_branch_1"; ComputeGraphPtr subgraph_1 = std::make_shared(subgraph_name_1); subgraph_1->SetParentNode(func_node); subgraph_1->SetParentGraph(graph); size_t index = func_node->GetOpDesc()->GetSubgraphInstanceNames().size(); EXPECT_EQ(index, 0); func_node->GetOpDesc()->AddSubgraphName("branch1"); EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 1); func_node->GetOpDesc()->SetSubgraphInstanceName(index, subgraph_name_1); EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 1); std::string subgraph_name_2 = "instance_branch_2"; ComputeGraphPtr subgraph_2 = std::make_shared(subgraph_name_2); subgraph_2->SetParentNode(func_node); subgraph_2->SetParentGraph(graph); index = func_node->GetOpDesc()->GetSubgraphInstanceNames().size(); EXPECT_EQ(index, 1); func_node->GetOpDesc()->AddSubgraphName("branch2"); EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 2); func_node->GetOpDesc()->SetSubgraphInstanceName(index, subgraph_name_2); EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 2); { // Const->NetOutput in subgraph NodePtr const_node = subgraph_1->AddNode(CreateOpDesc("const", CONSTANTOP, 0, 1)); NodePtr output_node = subgraph_1->AddNode(CreateOpDesc(NODE_NAME_NET_OUTPUT, NETOUTPUT, 1, 1)); EXPECT_EQ(GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)), SUCCESS); } { // Data->NetOutput in subgraph but not while body NodePtr data_node = subgraph_2->AddNode(CreateOpDesc("sata", DATA, 1, 1)); NodePtr output_node = subgraph_2->AddNode(CreateOpDesc(NODE_NAME_NET_OUTPUT, NETOUTPUT, 1, 1)); EXPECT_EQ(GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)), SUCCESS); EXPECT_TRUE(AttrUtils::SetInt(data_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1)); } PassManager pass_manager; pass_manager.AddPass("SubgraphPass", new (std::nothrow) SubgraphPass); EXPECT_EQ(pass_manager.Run(graph), SUCCESS); EXPECT_FALSE(CheckMemcpyExist(graph)); EXPECT_EQ(pass_manager.Run(subgraph_1), SUCCESS); EXPECT_EQ(CheckMemcpyNum(subgraph_1), 1); EXPECT_EQ(pass_manager.Run(subgraph_2), SUCCESS); EXPECT_EQ(CheckMemcpyNum(subgraph_2), 1); } } // namespace ge