From ca3b811ba1142c69ac31865e2d4df779d5dcf576 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Sat, 10 Jul 2021 15:39:27 +0800 Subject: [PATCH] Fix root node for MergeInputNodes --- ge/hybrid/model/hybrid_model_builder.cc | 5 +- ge/opskernel_manager/ops_kernel_manager.cc | 2 +- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 139 ++++++++++++++------- 3 files changed, 98 insertions(+), 48 deletions(-) diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index f8ec6db1..c722d269 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -588,7 +588,10 @@ Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { auto dst_node = peer_in_data_anchor->GetOwnerNode(); GE_CHECK_NOTNULL(dst_node); - root_nodes.emplace(dst_node); + const auto in_nodes = dst_node->GetInDataNodes(); + if (std::all_of(in_nodes.begin(), in_nodes.end(), [](const NodePtr &n) { return n->GetType() == DATA; })) { + root_nodes.emplace(dst_node); + } GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(out_data_anchor, peer_in_data_anchor)); GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, peer_in_data_anchor)); } diff --git a/ge/opskernel_manager/ops_kernel_manager.cc b/ge/opskernel_manager/ops_kernel_manager.cc index fc7bbdc2..d35ebda5 100644 --- a/ge/opskernel_manager/ops_kernel_manager.cc +++ b/ge/opskernel_manager/ops_kernel_manager.cc @@ -279,7 +279,7 @@ void OpsKernelManager::InitOpsKernelInfo() { if (it.second.empty()) { continue; } - auto comp_func = [this, &instance_ptr](const OpInfo &op_a, const OpInfo &op_b) -> bool { + auto comp_func = [&instance_ptr](const OpInfo &op_a, const OpInfo &op_b) -> bool { const string &a = op_a.engine; const string &b = op_b.engine; // check if a or b is registered diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 1d1c4fa9..b09211cb 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -43,14 +43,11 @@ #include "graph/testcase/ge_graph/graph_builder_utils.h" #include "single_op/task/build_task_utils.h" #include "graph/op_desc_impl.h" -#undef private -#undef protected using namespace std; -using namespace testing; -using namespace ge; -using namespace hybrid; +namespace ge { +using namespace hybrid; class UtestGeHybrid : public testing::Test { protected: @@ -61,16 +58,30 @@ class UtestGeHybrid : public testing::Test { } }; -static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { +static ge::OpDescPtr CreateOpDesc(string name = "", string type = "", int in_num = 0, int out_num = 0) { auto op_desc = std::make_shared(name, type); op_desc->SetStreamId(0); - op_desc->SetId(0); + static int32_t index = 0; + op_desc->SetId(index++); + + GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); + TensorUtils::SetSize(tensor, 64); + vector input_offset; + for (int i = 0; i < in_num; ++i) { + op_desc->AddInputDesc(tensor); + input_offset.emplace_back(index * 64 + i * 64); + } + op_desc->SetInputOffset(input_offset); + + vector output_offset; + for (int i = 0; i < out_num; ++i) { + op_desc->AddOutputDesc(tensor); + output_offset.emplace_back(index * 64 + in_num * 64 + i * 64); + } + op_desc->SetOutputOffset(output_offset); op_desc->SetWorkspace({}); - ; op_desc->SetWorkspaceBytes({}); - op_desc->SetInputOffset({}); - op_desc->SetOutputOffset({}); ge::AttrUtils::SetStr(op_desc, ge::TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF_AIVEC"); bool support_dynamic = true; @@ -414,49 +425,84 @@ TEST_F(UtestGeHybrid, test_parse_parallel_group) { } TEST_F(UtestGeHybrid, unfold_subgraphs_success) { - ComputeGraphPtr merged_graph = nullptr; + ComputeGraphPtr root_graph = std::make_shared("root_graph"); + auto partitioned_call_op_desc = CreateOpDesc("partitioned_call", PARTITIONEDCALL, 3, 1); + auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); + partitioned_call_op_desc->AddSubgraphName("f"); + partitioned_call_op_desc->SetSubgraphInstanceName(0, "sub_graph"); ComputeGraphPtr sub_sub_graph1 = std::make_shared("while_cond"); - OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); - NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); + { + OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); + NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); + sub_sub_graph1->SetParentGraph(root_graph); + root_graph->AddSubGraph(sub_sub_graph1); + } ComputeGraphPtr sub_sub_graph2 = std::make_shared("while_body"); - /*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT); - NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/ - OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); - NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); - sub_sub_graph2->SetGraphUnknownFlag(true); - /*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD); - NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node); - sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); - sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/ + { + OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); + NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); + sub_sub_graph2->SetGraphUnknownFlag(true); + sub_sub_graph2->SetParentGraph(root_graph); + root_graph->AddSubGraph(sub_sub_graph2); + } + // Will unfold to merged_graph. ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); - OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); - NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); - sub_graph->SetGraphUnknownFlag(true); - sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond"); - sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body"); - sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond"); - sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body"); + { + OpDescPtr sub_graph_data1_op_desc = CreateOpDesc("data1", DATA, 1, 1); + OpDescPtr sub_graph_data2_op_desc = CreateOpDesc("data2", DATA, 1, 1); + OpDescPtr sub_graph_data3_op_desc = CreateOpDesc("data3", DATA, 1, 1); + NodePtr sub_graph_data1_node = sub_graph->AddNode(sub_graph_data1_op_desc); + NodePtr sub_graph_data2_node = sub_graph->AddNode(sub_graph_data2_op_desc); + NodePtr sub_graph_data3_node = sub_graph->AddNode(sub_graph_data3_op_desc); + + AttrUtils::SetInt(sub_graph_data1_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 0); + AttrUtils::SetInt(sub_graph_data2_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 1); + AttrUtils::SetInt(sub_graph_data3_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 2); + + OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE, 2, 2); + NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); + sub_sub_graph1->SetParentNode(sub_graph_while_node); + sub_sub_graph2->SetParentNode(sub_graph_while_node); + sub_graph_while_op_desc->AddSubgraphName("while_cond"); + sub_graph_while_op_desc->SetSubgraphInstanceName(0, "while_cond"); + sub_graph_while_op_desc->AddSubgraphName("while_body"); + sub_graph_while_op_desc->SetSubgraphInstanceName(1, "while_body"); + + OpDescPtr sub_graph_matmul_op_desc = CreateOpDesc("matmul", MATMUL, 2, 1); + NodePtr sub_graph_matmul_node = sub_graph->AddNode(sub_graph_matmul_op_desc); + + OpDescPtr sub_graph_output_op_desc = CreateOpDesc("output", NETOUTPUT, 1, 1); + NodePtr sub_graph_output_node = sub_graph->AddNode(sub_graph_output_op_desc); + + GraphUtils::AddEdge(sub_graph_data1_node->GetOutDataAnchor(0), sub_graph_while_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(sub_graph_data2_node->GetOutDataAnchor(0), sub_graph_while_node->GetInDataAnchor(1)); + GraphUtils::AddEdge(sub_graph_data3_node->GetOutDataAnchor(0), sub_graph_matmul_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(sub_graph_while_node->GetOutDataAnchor(0), sub_graph_matmul_node->GetInDataAnchor(1)); + GraphUtils::AddEdge(sub_graph_matmul_node->GetOutDataAnchor(0), sub_graph_output_node->GetInDataAnchor(0)); + + sub_graph->SetGraphUnknownFlag(true); + sub_graph->SetParentNode(partitioned_call_node); + sub_graph->SetParentGraph(root_graph); + root_graph->AddSubGraph(sub_graph); + } - ComputeGraphPtr root_graph = std::make_shared("root_graph"); - auto partitioned_call_op_desc = MakeShared("partitioned_call", PARTITIONEDCALL); - auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); - partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph"); - partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); - - root_graph->AddSubGraph(sub_sub_graph1); - root_graph->AddSubGraph(sub_sub_graph2); - sub_sub_graph1->SetParentGraph(root_graph); - sub_sub_graph2->SetParentGraph(root_graph); - sub_sub_graph1->SetParentNode(sub_graph_while_node); - sub_sub_graph2->SetParentNode(sub_graph_while_node); - - root_graph->AddSubGraph(sub_graph); - sub_graph->SetParentNode(partitioned_call_node); - sub_graph->SetParentGraph(root_graph); + OpDescPtr graph_data1_op_desc = CreateOpDesc("data1", DATA, 1, 1); + OpDescPtr graph_data2_op_desc = CreateOpDesc("data2", DATA, 1, 1); + OpDescPtr graph_data3_op_desc = CreateOpDesc("data3", DATA, 1, 1); + NodePtr graph_data1_node = root_graph->AddNode(graph_data1_op_desc); + NodePtr graph_data2_node = root_graph->AddNode(graph_data2_op_desc); + NodePtr graph_data3_node = root_graph->AddNode(graph_data3_op_desc); + AttrUtils::SetInt(graph_data1_op_desc, ATTR_NAME_INDEX, 0); + AttrUtils::SetInt(graph_data2_op_desc, ATTR_NAME_INDEX, 1); + AttrUtils::SetInt(graph_data3_op_desc, ATTR_NAME_INDEX, 2); + GraphUtils::AddEdge(graph_data1_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(graph_data2_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(1)); + GraphUtils::AddEdge(graph_data3_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(2)); + ComputeGraphPtr merged_graph = nullptr; GeRootModelPtr root_model = MakeShared(root_graph); HybridModel hybrid_model(root_model); HybridModelBuilder hybrid_model_builder(hybrid_model); @@ -787,4 +833,5 @@ TEST_F(UtestGeHybrid, TestTaskExecuteAsync) { std::vector> tasks; AiCoreNodeTask node_task(std::move(tasks)); ASSERT_EQ(node_task.ExecuteAsync(task_context, nullptr), SUCCESS); -} \ No newline at end of file +} +} // namespace ge \ No newline at end of file