Browse Source

!1973 Fix root node for MergeInputNodes

Merge pull request !1973 from 张晓昆/master
tags/v1.5.1
i-robot Gitee 3 years ago
parent
commit
3d7177155c
3 changed files with 98 additions and 48 deletions
  1. +4
    -1
      ge/hybrid/model/hybrid_model_builder.cc
  2. +1
    -1
      ge/opskernel_manager/ops_kernel_manager.cc
  3. +93
    -46
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 4
- 1
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -588,7 +588,10 @@ Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) {
for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
auto dst_node = peer_in_data_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);
root_nodes.emplace(dst_node);
const auto in_nodes = dst_node->GetInDataNodes();
if (std::all_of(in_nodes.begin(), in_nodes.end(), [](const NodePtr &n) { return n->GetType() == DATA; })) {
root_nodes.emplace(dst_node);
}
GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(out_data_anchor, peer_in_data_anchor));
GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, peer_in_data_anchor));
}


+ 1
- 1
ge/opskernel_manager/ops_kernel_manager.cc View File

@@ -279,7 +279,7 @@ void OpsKernelManager::InitOpsKernelInfo() {
if (it.second.empty()) {
continue;
}
auto comp_func = [this, &instance_ptr](const OpInfo &op_a, const OpInfo &op_b) -> bool {
auto comp_func = [&instance_ptr](const OpInfo &op_a, const OpInfo &op_b) -> bool {
const string &a = op_a.engine;
const string &b = op_b.engine;
// check if a or b is registered


+ 93
- 46
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -43,14 +43,11 @@
#include "graph/testcase/ge_graph/graph_builder_utils.h"
#include "single_op/task/build_task_utils.h"
#include "graph/op_desc_impl.h"
#undef private
#undef protected

using namespace std;
using namespace testing;
using namespace ge;
using namespace hybrid;

namespace ge {
using namespace hybrid;

class UtestGeHybrid : public testing::Test {
protected:
@@ -61,16 +58,30 @@ class UtestGeHybrid : public testing::Test {
}
};

static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") {
static ge::OpDescPtr CreateOpDesc(string name = "", string type = "", int in_num = 0, int out_num = 0) {
auto op_desc = std::make_shared<ge::OpDesc>(name, type);
op_desc->SetStreamId(0);
op_desc->SetId(0);
static int32_t index = 0;
op_desc->SetId(index++);

GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64);
TensorUtils::SetSize(tensor, 64);
vector<int64_t> input_offset;
for (int i = 0; i < in_num; ++i) {
op_desc->AddInputDesc(tensor);
input_offset.emplace_back(index * 64 + i * 64);
}
op_desc->SetInputOffset(input_offset);

vector<int64_t> output_offset;
for (int i = 0; i < out_num; ++i) {
op_desc->AddOutputDesc(tensor);
output_offset.emplace_back(index * 64 + in_num * 64 + i * 64);
}
op_desc->SetOutputOffset(output_offset);

op_desc->SetWorkspace({});
;
op_desc->SetWorkspaceBytes({});
op_desc->SetInputOffset({});
op_desc->SetOutputOffset({});

ge::AttrUtils::SetStr(op_desc, ge::TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF_AIVEC");
bool support_dynamic = true;
@@ -414,49 +425,84 @@ TEST_F(UtestGeHybrid, test_parse_parallel_group) {
}

TEST_F(UtestGeHybrid, unfold_subgraphs_success) {
ComputeGraphPtr merged_graph = nullptr;
ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("root_graph");
auto partitioned_call_op_desc = CreateOpDesc("partitioned_call", PARTITIONEDCALL, 3, 1);
auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc);
partitioned_call_op_desc->AddSubgraphName("f");
partitioned_call_op_desc->SetSubgraphInstanceName(0, "sub_graph");

ComputeGraphPtr sub_sub_graph1 = std::make_shared<ComputeGraph>("while_cond");
OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA);
NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc);
{
OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA);
NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc);
sub_sub_graph1->SetParentGraph(root_graph);
root_graph->AddSubGraph(sub_sub_graph1);
}

ComputeGraphPtr sub_sub_graph2 = std::make_shared<ComputeGraph>("while_body");
/*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT);
NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/
OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA);
NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc);
sub_sub_graph2->SetGraphUnknownFlag(true);
/*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD);
NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node);
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node);
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/
{
OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA);
NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc);
sub_sub_graph2->SetGraphUnknownFlag(true);
sub_sub_graph2->SetParentGraph(root_graph);
root_graph->AddSubGraph(sub_sub_graph2);
}

// Will unfold to merged_graph.
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE);
NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc);
sub_graph->SetGraphUnknownFlag(true);
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond");
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body");
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond");
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body");
{
OpDescPtr sub_graph_data1_op_desc = CreateOpDesc("data1", DATA, 1, 1);
OpDescPtr sub_graph_data2_op_desc = CreateOpDesc("data2", DATA, 1, 1);
OpDescPtr sub_graph_data3_op_desc = CreateOpDesc("data3", DATA, 1, 1);
NodePtr sub_graph_data1_node = sub_graph->AddNode(sub_graph_data1_op_desc);
NodePtr sub_graph_data2_node = sub_graph->AddNode(sub_graph_data2_op_desc);
NodePtr sub_graph_data3_node = sub_graph->AddNode(sub_graph_data3_op_desc);

AttrUtils::SetInt(sub_graph_data1_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 0);
AttrUtils::SetInt(sub_graph_data2_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 1);
AttrUtils::SetInt(sub_graph_data3_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 2);

OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE, 2, 2);
NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc);
sub_sub_graph1->SetParentNode(sub_graph_while_node);
sub_sub_graph2->SetParentNode(sub_graph_while_node);
sub_graph_while_op_desc->AddSubgraphName("while_cond");
sub_graph_while_op_desc->SetSubgraphInstanceName(0, "while_cond");
sub_graph_while_op_desc->AddSubgraphName("while_body");
sub_graph_while_op_desc->SetSubgraphInstanceName(1, "while_body");

OpDescPtr sub_graph_matmul_op_desc = CreateOpDesc("matmul", MATMUL, 2, 1);
NodePtr sub_graph_matmul_node = sub_graph->AddNode(sub_graph_matmul_op_desc);

OpDescPtr sub_graph_output_op_desc = CreateOpDesc("output", NETOUTPUT, 1, 1);
NodePtr sub_graph_output_node = sub_graph->AddNode(sub_graph_output_op_desc);

GraphUtils::AddEdge(sub_graph_data1_node->GetOutDataAnchor(0), sub_graph_while_node->GetInDataAnchor(0));
GraphUtils::AddEdge(sub_graph_data2_node->GetOutDataAnchor(0), sub_graph_while_node->GetInDataAnchor(1));
GraphUtils::AddEdge(sub_graph_data3_node->GetOutDataAnchor(0), sub_graph_matmul_node->GetInDataAnchor(0));
GraphUtils::AddEdge(sub_graph_while_node->GetOutDataAnchor(0), sub_graph_matmul_node->GetInDataAnchor(1));
GraphUtils::AddEdge(sub_graph_matmul_node->GetOutDataAnchor(0), sub_graph_output_node->GetInDataAnchor(0));

sub_graph->SetGraphUnknownFlag(true);
sub_graph->SetParentNode(partitioned_call_node);
sub_graph->SetParentGraph(root_graph);
root_graph->AddSubGraph(sub_graph);
}

ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("root_graph");
auto partitioned_call_op_desc = MakeShared<OpDesc>("partitioned_call", PARTITIONEDCALL);
auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc);
partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph");
partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph");

root_graph->AddSubGraph(sub_sub_graph1);
root_graph->AddSubGraph(sub_sub_graph2);
sub_sub_graph1->SetParentGraph(root_graph);
sub_sub_graph2->SetParentGraph(root_graph);
sub_sub_graph1->SetParentNode(sub_graph_while_node);
sub_sub_graph2->SetParentNode(sub_graph_while_node);

root_graph->AddSubGraph(sub_graph);
sub_graph->SetParentNode(partitioned_call_node);
sub_graph->SetParentGraph(root_graph);
OpDescPtr graph_data1_op_desc = CreateOpDesc("data1", DATA, 1, 1);
OpDescPtr graph_data2_op_desc = CreateOpDesc("data2", DATA, 1, 1);
OpDescPtr graph_data3_op_desc = CreateOpDesc("data3", DATA, 1, 1);
NodePtr graph_data1_node = root_graph->AddNode(graph_data1_op_desc);
NodePtr graph_data2_node = root_graph->AddNode(graph_data2_op_desc);
NodePtr graph_data3_node = root_graph->AddNode(graph_data3_op_desc);
AttrUtils::SetInt(graph_data1_op_desc, ATTR_NAME_INDEX, 0);
AttrUtils::SetInt(graph_data2_op_desc, ATTR_NAME_INDEX, 1);
AttrUtils::SetInt(graph_data3_op_desc, ATTR_NAME_INDEX, 2);
GraphUtils::AddEdge(graph_data1_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(0));
GraphUtils::AddEdge(graph_data2_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(1));
GraphUtils::AddEdge(graph_data3_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(2));

ComputeGraphPtr merged_graph = nullptr;
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(root_graph);
HybridModel hybrid_model(root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
@@ -787,4 +833,5 @@ TEST_F(UtestGeHybrid, TestTaskExecuteAsync) {
std::vector<std::unique_ptr<AiCoreOpTask>> tasks;
AiCoreNodeTask node_task(std::move(tasks));
ASSERT_EQ(node_task.ExecuteAsync(task_context, nullptr), SUCCESS);
}
}
} // namespace ge

Loading…
Cancel
Save