Browse Source

!1477 fix data direct connect netoutput scene

From: @wan_xuelei
Reviewed-by: @xchu42,@wqtshg
Signed-off-by: @wqtshg
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
7a9791450a
2 changed files with 40 additions and 4 deletions
  1. +9
    -3
      ge/hybrid/model/hybrid_model_builder.cc
  2. +31
    -1
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 9
- 3
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -1540,14 +1540,20 @@ Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) {
in_data_anchor->GetIdx(),
src_node->GetName().c_str(),
src_op_type.c_str());
uint32_t parent_index = 0;
GE_CHK_STATUS_RET_NOLOG(GetParentNodeOutputIndex(*net_output_desc, in_data_anchor->GetIdx(), parent_index));
GELOGD("Got parent output index = %u", parent_index);
if (src_op_type == DATA) {
int ref_i = 0;
(void)AttrUtils::GetInt(src_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i);
node_item.reuse_inputs.emplace(static_cast<int>(parent_index), ref_i);
GELOGD("[%s] output[%u] resues input[%d]", node_item.NodeName().c_str(), parent_index, ref_i);
}

if (src_op_type != CONSTANTOP && src_op_type != CONSTANT && src_op_type != VARIABLE) {
continue;
}

uint32_t parent_index = 0;
GE_CHK_STATUS_RET_NOLOG(GetParentNodeOutputIndex(*net_output_desc, in_data_anchor->GetIdx(), parent_index));
GELOGD("Got parent output index = %u", parent_index);
GE_CHECK_LE(parent_index, INT32_MAX);
node_item.ref_outputs.emplace(static_cast<int>(parent_index), src_node);
if (src_op_type == CONSTANTOP || src_op_type == CONSTANT) {


+ 31
- 1
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -39,7 +39,7 @@
#include "hybrid/common/npu_memory_allocator.h"
#include "graph/types.h"
#include "graph/utils/tensor_utils.h"
#include "graph/testcase/ge_graph/graph_builder_utils.h"
#undef private
#undef protected

@@ -173,6 +173,36 @@ TEST_F(UtestGeHybrid, parse_force_infershape_nodes) {
HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS);
}
static ComputeGraphPtr BuildDataDirectConnectGraph() {
const char *kRefIndex = "_parent_node_index";
ge::ut::GraphBuilder builder("subgraph");
auto data = builder.AddNode("Data", "Data", 1, 1);
auto netoutput = builder.AddNode("NetOutput", "NetOutput", 1, 1);
(void)AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), kRefIndex, 0);

builder.AddDataEdge(data, 0, netoutput, 0);
return builder.GetGraph();
}
TEST_F(UtestGeHybrid, data_direct_connect) {
std::unique_ptr<NodeItem> node_item;
auto root_graph = make_shared<ComputeGraph>("root_graph");
OpDescPtr op_desc = CreateOpDesc("PartitionedCall", "PartitionedCall");
auto node = root_graph->AddNode(op_desc);
node->SetOwnerComputeGraph(root_graph);
auto sub_graph = BuildDataDirectConnectGraph();
sub_graph->SetParentGraph(root_graph);
sub_graph->SetParentNode(node);
node->GetOpDesc()->AddSubgraphName("subgraph");
node->GetOpDesc()->SetSubgraphInstanceName(0, "subgraph");
root_graph->AddSubgraph("subgraph", sub_graph);
std::unique_ptr<NodeItem> new_node;
NodeItem::Create(node, new_node);
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(root_graph);
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get());
ASSERT_EQ(ret, SUCCESS);
}

TEST_F(UtestGeHybrid, index_taskdefs_success) {
// build aicore task


Loading…
Cancel
Save