From b90d8105ecfac0e0c8dfb8220aad68b31bbc2c2f Mon Sep 17 00:00:00 2001 From: chenyemeng Date: Wed, 19 May 2021 15:00:48 +0800 Subject: [PATCH] modify tensor_desc idx when insert identity --- ge/graph/passes/subgraph_pass.cc | 4 +- tests/ut/ge/CMakeLists.txt | 3 +- .../ge/graph/passes/subgraph_pass_unittest.cc | 129 ++++++++++++++++++ 3 files changed, 133 insertions(+), 3 deletions(-) create mode 100644 tests/ut/ge/graph/passes/subgraph_pass_unittest.cc diff --git a/ge/graph/passes/subgraph_pass.cc b/ge/graph/passes/subgraph_pass.cc index b931eea8..401dee54 100755 --- a/ge/graph/passes/subgraph_pass.cc +++ b/ge/graph/passes/subgraph_pass.cc @@ -464,8 +464,8 @@ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDat GE_CHECK_NOTNULL(out_anchor); NodePtr in_node = out_anchor->GetOwnerNode(); OpDescBuilder op_desc_builder(name, IDENTITY); - OpDescPtr op_desc = op_desc_builder.AddInput("x", in_node->GetOpDesc()->GetOutputDesc(0)) - .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0)) + OpDescPtr op_desc = op_desc_builder.AddInput("x", in_node->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx())) + .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx())) .Build(); (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); (void)AttrUtils::SetBool(op_desc, ATTR_NAME_CANNOT_BE_DELETED, true); diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 895c33df..d24b440d 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -665,7 +665,7 @@ set(PASS_TEST_FILES "graph/passes/permute_pass_unittest.cc" "graph/passes/print_op_pass_unittest.cc" "graph/passes/shape_operate_op_remove_pass_unittest.cc" - "graph/passes/variable_op_pass_unittest.cc" + "graph/passes/variable_op_pass_unittest.cc" "graph/passes/base_pass_unittest.cc" "graph/passes/addn_pass_unittest.cc" "graph/passes/save_pass_unittest.cc" @@ -674,6 +674,7 @@ set(PASS_TEST_FILES "graph/passes/cond_branch_v1_unittest.cc" "graph/passes/loop_branch_v1_unittest.cc" "graph/passes/switch_dead_branch_elimination_unittest.cc" + "graph/passes/subgraph_pass_unittest.cc" "graph/passes/assert_pass_unittest.cc" "graph/passes/dropout_pass_unittest.cc" "graph/passes/unused_const_pass_unittest.cc" diff --git a/tests/ut/ge/graph/passes/subgraph_pass_unittest.cc b/tests/ut/ge/graph/passes/subgraph_pass_unittest.cc new file mode 100644 index 00000000..f11882e1 --- /dev/null +++ b/tests/ut/ge/graph/passes/subgraph_pass_unittest.cc @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "graph/passes/subgraph_pass.h" +#include "inc/pass_manager.h" + +namespace ge { +namespace { +class UtestGraphPassesSubgraphPass : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +OpDescPtr CreateOpDesc(const std::string name, const std::string type, uint32_t input_num, uint32_t output_num) { + OpDescPtr op_desc = std::shared_ptr(new (std::nothrow) OpDesc(name, type)); + if (op_desc == nullptr) { + return nullptr; + } + for (uint32_t i = 0; i < input_num; i++) { + op_desc->AddInputDesc(GeTensorDesc()); + } + for (uint32_t i = 0; i < output_num; i++) { + op_desc->AddOutputDesc(GeTensorDesc()); + } + return op_desc; +} + +bool CheckMemcpyExist(const ComputeGraphPtr &graph) { + for (const auto &node : graph->GetDirectNode()) { + if (node->GetType() == IDENTITY) { + return true; + } + } + return false; +} + +uint32_t CheckMemcpyNum(const ComputeGraphPtr &graph) { + uint32_t num = 0; + for (const auto &node : graph->GetDirectNode()) { + if (node->GetType() == IDENTITY) { + num++; + } + } + return num; +} +} // namespace + +/// +/// ****** root_graph ****** * ****** subgraph branch1 ***** * ****** subgraph branch2 ***** +/// * * +/// Case * Const * Data +/// / \ * | * | +/// data_0 data_1 * NetOutput * NetOutput +/// * * +/// ****** root_graph ****** * ****** subgraph branch1 ***** * ****** subgraph branch2 ***** +/// +TEST(UtestGraphPassesSubgraphPass, add_memcpy_success) { + ComputeGraphPtr graph = std::make_shared("add_memcpy_success"); + NodePtr func_node = graph->AddNode(CreateOpDesc("Case", CASE, 2, 1)); + NodePtr data_node_0 = graph->AddNode(CreateOpDesc("data_0", DATA, 1, 1)); + NodePtr data_node_1 = graph->AddNode(CreateOpDesc("data_1", DATA, 1, 1)); + EXPECT_EQ(GraphUtils::AddEdge(data_node_0->GetOutDataAnchor(0), func_node->GetInDataAnchor(0)), GRAPH_SUCCESS); + EXPECT_EQ(GraphUtils::AddEdge(data_node_1->GetOutDataAnchor(0), func_node->GetInDataAnchor(1)), GRAPH_SUCCESS); + + std::string subgraph_name_1 = "instance_branch_1"; + ComputeGraphPtr subgraph_1 = std::make_shared(subgraph_name_1); + subgraph_1->SetParentNode(func_node); + subgraph_1->SetParentGraph(graph); + size_t index = func_node->GetOpDesc()->GetSubgraphInstanceNames().size(); + EXPECT_EQ(index, 0); + func_node->GetOpDesc()->AddSubgraphName("branch1"); + EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 1); + func_node->GetOpDesc()->SetSubgraphInstanceName(index, subgraph_name_1); + EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 1); + + std::string subgraph_name_2 = "instance_branch_2"; + ComputeGraphPtr subgraph_2 = std::make_shared(subgraph_name_2); + subgraph_2->SetParentNode(func_node); + subgraph_2->SetParentGraph(graph); + index = func_node->GetOpDesc()->GetSubgraphInstanceNames().size(); + EXPECT_EQ(index, 1); + func_node->GetOpDesc()->AddSubgraphName("branch2"); + EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 2); + func_node->GetOpDesc()->SetSubgraphInstanceName(index, subgraph_name_2); + EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 2); + + { + // Const->NetOutput in subgraph + NodePtr const_node = subgraph_1->AddNode(CreateOpDesc("const", CONSTANTOP, 0, 1)); + NodePtr output_node = subgraph_1->AddNode(CreateOpDesc(NODE_NAME_NET_OUTPUT, NETOUTPUT, 1, 1)); + EXPECT_EQ(GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)), SUCCESS); + } + + { + // Data->NetOutput in subgraph but not while body + NodePtr data_node = subgraph_2->AddNode(CreateOpDesc("sata", DATA, 1, 1)); + NodePtr output_node = subgraph_2->AddNode(CreateOpDesc(NODE_NAME_NET_OUTPUT, NETOUTPUT, 1, 1)); + EXPECT_EQ(GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)), SUCCESS); + EXPECT_TRUE(AttrUtils::SetInt(data_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1)); + } + + PassManager pass_manager; + pass_manager.AddPass("SubgraphPass", new (std::nothrow) SubgraphPass); + EXPECT_EQ(pass_manager.Run(graph), SUCCESS); + EXPECT_FALSE(CheckMemcpyExist(graph)); + EXPECT_EQ(pass_manager.Run(subgraph_1), SUCCESS); + EXPECT_EQ(CheckMemcpyNum(subgraph_1), 1); + EXPECT_EQ(pass_manager.Run(subgraph_2), SUCCESS); + EXPECT_EQ(CheckMemcpyNum(subgraph_2), 1); +} +} // namespace ge