From cbfc856b3e3849db5059780713fc516652ea1c02 Mon Sep 17 00:00:00 2001 From: wxl Date: Thu, 8 Apr 2021 15:47:35 +0800 Subject: [PATCH 1/3] fix data directlly connect netoutput scene --- ge/hybrid/model/hybrid_model_builder.cc | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index ad1dae7a..fc5c65d9 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -1539,14 +1539,20 @@ Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { in_data_anchor->GetIdx(), src_node->GetName().c_str(), src_op_type.c_str()); + uint32_t parent_index = 0; + GE_CHK_STATUS_RET_NOLOG(GetParentNodeOutputIndex(*net_output_desc, in_data_anchor->GetIdx(), parent_index)); + GELOGD("Got parent output index = %u", parent_index); + if (src_op_type == DATA) { + int ref_i = 0; + (void)AttrUtils::GetInt(src_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i); + node_item.reuse_inputs.emplace(static_cast(parent_index), ref_i); + GELOGD("[%s] output[%u] resues input[%d]", node_item.NodeName().c_str(), parent_index, ref_i); + } if (src_op_type != CONSTANTOP && src_op_type != CONSTANT && src_op_type != VARIABLE) { continue; } - uint32_t parent_index = 0; - GE_CHK_STATUS_RET_NOLOG(GetParentNodeOutputIndex(*net_output_desc, in_data_anchor->GetIdx(), parent_index)); - GELOGD("Got parent output index = %u", parent_index); GE_CHECK_LE(parent_index, INT32_MAX); node_item.ref_outputs.emplace(static_cast(parent_index), src_node); if (src_op_type == CONSTANTOP || src_op_type == CONSTANT) { From 93b6dff0d76f36d5fb89fb85f860d8b8840bb0b6 Mon Sep 17 00:00:00 2001 From: wxl Date: Fri, 9 Apr 2021 17:57:57 +0800 Subject: [PATCH 2/3] fix data directlly connect netoutput scene --- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 29 +++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 274cc56f..bc706165 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -39,7 +39,7 @@ #include "hybrid/common/npu_memory_allocator.h" #include "graph/types.h" #include "graph/utils/tensor_utils.h" - +#include "graph/testcase/ge_graph/graph_builder_utils.h" #undef private #undef protected @@ -173,6 +173,33 @@ TEST_F(UtestGeHybrid, parse_force_infershape_nodes) { HybridModelBuilder hybrid_model_builder(hybrid_model); ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS); } +static ComputeGraphPtr BuildDataDirectConnectGraph() { + ge::ut::GraphBuilder builder("subgraph"); + auto data = builder.AddNode("Data", "Data", 1, 1); + auto netoutput = builder.AddNode("Netoutput", "Netoutput", 1, 1); + + builder.AddDataEdge(data, 0, netoutput, 0); + return builder.GetGraph(); +} +TEST_F(UtestGeHybrid, data_direct_connect) { + std::unique_ptr node_item; + auto root_graph = make_shared("root_graph"); + OpDescPtr op_desc = CreateOpDesc("PartitionedCall", "PartitionedCall"); + auto node = root_graph->AddNode(op_desc); + auto sub_graph = BuildDataDirectConnectGraph(); + sub_graph->SetParentGraph(root_graph); + sub_graph->SetParentNode(node); + node->GetOpDesc()->AddSubgraphName("subgraph"); + node->GetOpDesc()->SetSubgraphInstanceName(0, "subgraph"); + root_graph->AddSubgraph("subgraph", sub_graph); + std::unique_ptr new_node; + NodeItem::Create(node, new_node); + GeRootModelPtr ge_root_model = make_shared(root_graph); + HybridModel hybrid_model(ge_root_model); + HybridModelBuilder hybrid_model_builder(hybrid_model); + auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get()); + ASSERT_EQ(ret, SUCCESS); +} TEST_F(UtestGeHybrid, index_taskdefs_success) { // build aicore task From 91fe55a571c117af766cf6d638209be7e6be1a8e Mon Sep 17 00:00:00 2001 From: wxl Date: Fri, 9 Apr 2021 19:24:22 +0800 Subject: [PATCH 3/3] fix data directlly connect netoutput scene --- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index bc706165..c424bdb4 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -174,9 +174,11 @@ TEST_F(UtestGeHybrid, parse_force_infershape_nodes) { ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS); } static ComputeGraphPtr BuildDataDirectConnectGraph() { + const char *kRefIndex = "_parent_node_index"; ge::ut::GraphBuilder builder("subgraph"); auto data = builder.AddNode("Data", "Data", 1, 1); - auto netoutput = builder.AddNode("Netoutput", "Netoutput", 1, 1); + auto netoutput = builder.AddNode("NetOutput", "NetOutput", 1, 1); + (void)AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), kRefIndex, 0); builder.AddDataEdge(data, 0, netoutput, 0); return builder.GetGraph(); @@ -186,6 +188,7 @@ TEST_F(UtestGeHybrid, data_direct_connect) { auto root_graph = make_shared("root_graph"); OpDescPtr op_desc = CreateOpDesc("PartitionedCall", "PartitionedCall"); auto node = root_graph->AddNode(op_desc); + node->SetOwnerComputeGraph(root_graph); auto sub_graph = BuildDataDirectConnectGraph(); sub_graph->SetParentGraph(root_graph); sub_graph->SetParentNode(node);