|
|
@@ -0,0 +1,129 @@ |
|
|
|
/** |
|
|
|
* Copyright 2021 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
* You may obtain a copy of the License at |
|
|
|
* |
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
* |
|
|
|
* Unless required by applicable law or agreed to in writing, software |
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include <cstdint> |
|
|
|
#include <string> |
|
|
|
#include <gtest/gtest.h> |
|
|
|
|
|
|
|
#include "graph/passes/subgraph_pass.h" |
|
|
|
#include "inc/pass_manager.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
namespace { |
|
|
|
class UtestGraphPassesSubgraphPass : public testing::Test { |
|
|
|
protected: |
|
|
|
void SetUp() {} |
|
|
|
void TearDown() {} |
|
|
|
}; |
|
|
|
|
|
|
|
OpDescPtr CreateOpDesc(const std::string name, const std::string type, uint32_t input_num, uint32_t output_num) { |
|
|
|
OpDescPtr op_desc = std::shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, type)); |
|
|
|
if (op_desc == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
for (uint32_t i = 0; i < input_num; i++) { |
|
|
|
op_desc->AddInputDesc(GeTensorDesc()); |
|
|
|
} |
|
|
|
for (uint32_t i = 0; i < output_num; i++) { |
|
|
|
op_desc->AddOutputDesc(GeTensorDesc()); |
|
|
|
} |
|
|
|
return op_desc; |
|
|
|
} |
|
|
|
|
|
|
|
bool CheckMemcpyExist(const ComputeGraphPtr &graph) { |
|
|
|
for (const auto &node : graph->GetDirectNode()) { |
|
|
|
if (node->GetType() == IDENTITY) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
uint32_t CheckMemcpyNum(const ComputeGraphPtr &graph) { |
|
|
|
uint32_t num = 0; |
|
|
|
for (const auto &node : graph->GetDirectNode()) { |
|
|
|
if (node->GetType() == IDENTITY) { |
|
|
|
num++; |
|
|
|
} |
|
|
|
} |
|
|
|
return num; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
/// |
|
|
|
/// ****** root_graph ****** * ****** subgraph branch1 ***** * ****** subgraph branch2 ***** |
|
|
|
/// * * |
|
|
|
/// Case * Const * Data |
|
|
|
/// / \ * | * | |
|
|
|
/// data_0 data_1 * NetOutput * NetOutput |
|
|
|
/// * * |
|
|
|
/// ****** root_graph ****** * ****** subgraph branch1 ***** * ****** subgraph branch2 ***** |
|
|
|
/// |
|
|
|
TEST(UtestGraphPassesSubgraphPass, add_memcpy_success) { |
|
|
|
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("add_memcpy_success"); |
|
|
|
NodePtr func_node = graph->AddNode(CreateOpDesc("Case", CASE, 2, 1)); |
|
|
|
NodePtr data_node_0 = graph->AddNode(CreateOpDesc("data_0", DATA, 1, 1)); |
|
|
|
NodePtr data_node_1 = graph->AddNode(CreateOpDesc("data_1", DATA, 1, 1)); |
|
|
|
EXPECT_EQ(GraphUtils::AddEdge(data_node_0->GetOutDataAnchor(0), func_node->GetInDataAnchor(0)), GRAPH_SUCCESS); |
|
|
|
EXPECT_EQ(GraphUtils::AddEdge(data_node_1->GetOutDataAnchor(0), func_node->GetInDataAnchor(1)), GRAPH_SUCCESS); |
|
|
|
|
|
|
|
std::string subgraph_name_1 = "instance_branch_1"; |
|
|
|
ComputeGraphPtr subgraph_1 = std::make_shared<ComputeGraph>(subgraph_name_1); |
|
|
|
subgraph_1->SetParentNode(func_node); |
|
|
|
subgraph_1->SetParentGraph(graph); |
|
|
|
size_t index = func_node->GetOpDesc()->GetSubgraphInstanceNames().size(); |
|
|
|
EXPECT_EQ(index, 0); |
|
|
|
func_node->GetOpDesc()->AddSubgraphName("branch1"); |
|
|
|
EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 1); |
|
|
|
func_node->GetOpDesc()->SetSubgraphInstanceName(index, subgraph_name_1); |
|
|
|
EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 1); |
|
|
|
|
|
|
|
std::string subgraph_name_2 = "instance_branch_2"; |
|
|
|
ComputeGraphPtr subgraph_2 = std::make_shared<ComputeGraph>(subgraph_name_2); |
|
|
|
subgraph_2->SetParentNode(func_node); |
|
|
|
subgraph_2->SetParentGraph(graph); |
|
|
|
index = func_node->GetOpDesc()->GetSubgraphInstanceNames().size(); |
|
|
|
EXPECT_EQ(index, 1); |
|
|
|
func_node->GetOpDesc()->AddSubgraphName("branch2"); |
|
|
|
EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 2); |
|
|
|
func_node->GetOpDesc()->SetSubgraphInstanceName(index, subgraph_name_2); |
|
|
|
EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 2); |
|
|
|
|
|
|
|
{ |
|
|
|
// Const->NetOutput in subgraph |
|
|
|
NodePtr const_node = subgraph_1->AddNode(CreateOpDesc("const", CONSTANTOP, 0, 1)); |
|
|
|
NodePtr output_node = subgraph_1->AddNode(CreateOpDesc(NODE_NAME_NET_OUTPUT, NETOUTPUT, 1, 1)); |
|
|
|
EXPECT_EQ(GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)), SUCCESS); |
|
|
|
} |
|
|
|
|
|
|
|
{ |
|
|
|
// Data->NetOutput in subgraph but not while body |
|
|
|
NodePtr data_node = subgraph_2->AddNode(CreateOpDesc("sata", DATA, 1, 1)); |
|
|
|
NodePtr output_node = subgraph_2->AddNode(CreateOpDesc(NODE_NAME_NET_OUTPUT, NETOUTPUT, 1, 1)); |
|
|
|
EXPECT_EQ(GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)), SUCCESS); |
|
|
|
EXPECT_TRUE(AttrUtils::SetInt(data_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1)); |
|
|
|
} |
|
|
|
|
|
|
|
PassManager pass_manager; |
|
|
|
pass_manager.AddPass("SubgraphPass", new (std::nothrow) SubgraphPass); |
|
|
|
EXPECT_EQ(pass_manager.Run(graph), SUCCESS); |
|
|
|
EXPECT_FALSE(CheckMemcpyExist(graph)); |
|
|
|
EXPECT_EQ(pass_manager.Run(subgraph_1), SUCCESS); |
|
|
|
EXPECT_EQ(CheckMemcpyNum(subgraph_1), 1); |
|
|
|
EXPECT_EQ(pass_manager.Run(subgraph_2), SUCCESS); |
|
|
|
EXPECT_EQ(CheckMemcpyNum(subgraph_2), 1); |
|
|
|
} |
|
|
|
} // namespace ge |