Browse Source

!1669 modify tensor_desc idx when insert identity

From: @chen_yemeng
Reviewed-by: @sheng-nan,@ji_chen
Signed-off-by: @ji_chen
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
efdf5e2da4
3 changed files with 133 additions and 3 deletions
  1. +2
    -2
      ge/graph/passes/subgraph_pass.cc
  2. +2
    -1
      tests/ut/ge/CMakeLists.txt
  3. +129
    -0
      tests/ut/ge/graph/passes/subgraph_pass_unittest.cc

+ 2
- 2
ge/graph/passes/subgraph_pass.cc View File

@@ -464,8 +464,8 @@ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDat
GE_CHECK_NOTNULL(out_anchor);
NodePtr in_node = out_anchor->GetOwnerNode();
OpDescBuilder op_desc_builder(name, IDENTITY);
OpDescPtr op_desc = op_desc_builder.AddInput("x", in_node->GetOpDesc()->GetOutputDesc(0))
.AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0))
OpDescPtr op_desc = op_desc_builder.AddInput("x", in_node->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()))
.AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()))
.Build();
(void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false);
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_CANNOT_BE_DELETED, true);


+ 2
- 1
tests/ut/ge/CMakeLists.txt View File

@@ -669,7 +669,7 @@ set(PASS_TEST_FILES
"graph/passes/permute_pass_unittest.cc"
"graph/passes/print_op_pass_unittest.cc"
"graph/passes/shape_operate_op_remove_pass_unittest.cc"
"graph/passes/variable_op_pass_unittest.cc"
"graph/passes/variable_op_pass_unittest.cc"
"graph/passes/base_pass_unittest.cc"
"graph/passes/addn_pass_unittest.cc"
"graph/passes/save_pass_unittest.cc"
@@ -678,6 +678,7 @@ set(PASS_TEST_FILES
"graph/passes/cond_branch_v1_unittest.cc"
"graph/passes/loop_branch_v1_unittest.cc"
"graph/passes/switch_dead_branch_elimination_unittest.cc"
"graph/passes/subgraph_pass_unittest.cc"
"graph/passes/assert_pass_unittest.cc"
"graph/passes/dropout_pass_unittest.cc"
"graph/passes/unused_const_pass_unittest.cc"


+ 129
- 0
tests/ut/ge/graph/passes/subgraph_pass_unittest.cc View File

@@ -0,0 +1,129 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cstdint>
#include <string>
#include <gtest/gtest.h>

#include "graph/passes/subgraph_pass.h"
#include "inc/pass_manager.h"

namespace ge {
namespace {
class UtestGraphPassesSubgraphPass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};

OpDescPtr CreateOpDesc(const std::string name, const std::string type, uint32_t input_num, uint32_t output_num) {
OpDescPtr op_desc = std::shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, type));
if (op_desc == nullptr) {
return nullptr;
}
for (uint32_t i = 0; i < input_num; i++) {
op_desc->AddInputDesc(GeTensorDesc());
}
for (uint32_t i = 0; i < output_num; i++) {
op_desc->AddOutputDesc(GeTensorDesc());
}
return op_desc;
}

bool CheckMemcpyExist(const ComputeGraphPtr &graph) {
for (const auto &node : graph->GetDirectNode()) {
if (node->GetType() == IDENTITY) {
return true;
}
}
return false;
}

uint32_t CheckMemcpyNum(const ComputeGraphPtr &graph) {
uint32_t num = 0;
for (const auto &node : graph->GetDirectNode()) {
if (node->GetType() == IDENTITY) {
num++;
}
}
return num;
}
} // namespace

///
/// ****** root_graph ****** * ****** subgraph branch1 ***** * ****** subgraph branch2 *****
/// * *
/// Case * Const * Data
/// / \ * | * |
/// data_0 data_1 * NetOutput * NetOutput
/// * *
/// ****** root_graph ****** * ****** subgraph branch1 ***** * ****** subgraph branch2 *****
///
TEST(UtestGraphPassesSubgraphPass, add_memcpy_success) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("add_memcpy_success");
NodePtr func_node = graph->AddNode(CreateOpDesc("Case", CASE, 2, 1));
NodePtr data_node_0 = graph->AddNode(CreateOpDesc("data_0", DATA, 1, 1));
NodePtr data_node_1 = graph->AddNode(CreateOpDesc("data_1", DATA, 1, 1));
EXPECT_EQ(GraphUtils::AddEdge(data_node_0->GetOutDataAnchor(0), func_node->GetInDataAnchor(0)), GRAPH_SUCCESS);
EXPECT_EQ(GraphUtils::AddEdge(data_node_1->GetOutDataAnchor(0), func_node->GetInDataAnchor(1)), GRAPH_SUCCESS);

std::string subgraph_name_1 = "instance_branch_1";
ComputeGraphPtr subgraph_1 = std::make_shared<ComputeGraph>(subgraph_name_1);
subgraph_1->SetParentNode(func_node);
subgraph_1->SetParentGraph(graph);
size_t index = func_node->GetOpDesc()->GetSubgraphInstanceNames().size();
EXPECT_EQ(index, 0);
func_node->GetOpDesc()->AddSubgraphName("branch1");
EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 1);
func_node->GetOpDesc()->SetSubgraphInstanceName(index, subgraph_name_1);
EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 1);

std::string subgraph_name_2 = "instance_branch_2";
ComputeGraphPtr subgraph_2 = std::make_shared<ComputeGraph>(subgraph_name_2);
subgraph_2->SetParentNode(func_node);
subgraph_2->SetParentGraph(graph);
index = func_node->GetOpDesc()->GetSubgraphInstanceNames().size();
EXPECT_EQ(index, 1);
func_node->GetOpDesc()->AddSubgraphName("branch2");
EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 2);
func_node->GetOpDesc()->SetSubgraphInstanceName(index, subgraph_name_2);
EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 2);

{
// Const->NetOutput in subgraph
NodePtr const_node = subgraph_1->AddNode(CreateOpDesc("const", CONSTANTOP, 0, 1));
NodePtr output_node = subgraph_1->AddNode(CreateOpDesc(NODE_NAME_NET_OUTPUT, NETOUTPUT, 1, 1));
EXPECT_EQ(GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)), SUCCESS);
}

{
// Data->NetOutput in subgraph but not while body
NodePtr data_node = subgraph_2->AddNode(CreateOpDesc("sata", DATA, 1, 1));
NodePtr output_node = subgraph_2->AddNode(CreateOpDesc(NODE_NAME_NET_OUTPUT, NETOUTPUT, 1, 1));
EXPECT_EQ(GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)), SUCCESS);
EXPECT_TRUE(AttrUtils::SetInt(data_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1));
}

PassManager pass_manager;
pass_manager.AddPass("SubgraphPass", new (std::nothrow) SubgraphPass);
EXPECT_EQ(pass_manager.Run(graph), SUCCESS);
EXPECT_FALSE(CheckMemcpyExist(graph));
EXPECT_EQ(pass_manager.Run(subgraph_1), SUCCESS);
EXPECT_EQ(CheckMemcpyNum(subgraph_1), 1);
EXPECT_EQ(pass_manager.Run(subgraph_2), SUCCESS);
EXPECT_EQ(CheckMemcpyNum(subgraph_2), 1);
}
} // namespace ge

Loading…
Cancel
Save