diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 55f3c4dd..316b94de 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -806,7 +806,7 @@ Status HybridModelBuilder::LoadGraph() { } } - GE_CHK_STATUS_RET(ParseDependentForHcclNodes(), "Failed to establish dependencies for hccl ops"); + GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), "Failed to establish dependencies for hccl ops"); GELOGI("Done loading all subgraphs successfully."); return SUCCESS; } @@ -1907,7 +1907,7 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root NodeItem *node_item = nullptr; GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); - GE_CHK_STATUS_RET_NOLOG(ParseParallelGroups(node_item)); + GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(node_item)); GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task node_item->input_start = input_start; @@ -2015,16 +2015,16 @@ Status HybridModelBuilder::CheckAicpuOpList() { return SUCCESS; } -Status HybridModelBuilder::ParseParallelGroups(NodeItem *node_item) { +Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { const auto &node = node_item->node; auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); if (executor_type == NodeExecutorManager::ExecutorType::HCCL) { std::string parallel_group; if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { GELOGD("[%s] Got parallel group = %s", node_item->NodeName().c_str(), parallel_group.c_str()); - group_to_nodes_[parallel_group].emplace(node_item); + parallel_group_to_nodes_[parallel_group].emplace(node_item); std::set group{parallel_group}; - node_to_groups_[node_item].emplace(parallel_group); + node_to_parallel_groups_[node_item].emplace(parallel_group); } } else if (executor_type == 
NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH) { std::set parallel_groups; @@ -2049,25 +2049,28 @@ Status HybridModelBuilder::ParseParallelGroups(NodeItem *node_item) { if (!parallel_groups.empty()) { for (const auto &parallel_group : parallel_groups) { - group_to_nodes_[parallel_group].emplace(node_item); + parallel_group_to_nodes_[parallel_group].emplace(node_item); GELOGD("[%s] has parallel group: %s", node_item->NodeName().c_str(), parallel_group.c_str()); } - node_to_groups_.emplace(node_item, std::move(parallel_groups)); + node_to_parallel_groups_.emplace(node_item, std::move(parallel_groups)); } } return SUCCESS; } -Status HybridModelBuilder::ParseDependentForHcclNodes() { - for (const auto &it : node_to_groups_) { +Status HybridModelBuilder::ParseDependentByParallelGroup() { + for (const auto &it : node_to_parallel_groups_) { auto node_item = it.first; auto dst_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node); for (const auto &parallel_group : it.second) { - auto &dependent_nodes = group_to_nodes_[parallel_group]; + auto &dependent_nodes = parallel_group_to_nodes_[parallel_group]; NodeItem *nearest_dep_node = nullptr; int max_id = -1; for (auto &dep_node : dependent_nodes) { + if (node_item == dep_node) { + continue; + } auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*dep_node->node); if (src_engine_type == dst_engine_type) { continue; diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 0b91afbe..1481d61e 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -63,10 +63,10 @@ class HybridModelBuilder { Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); - Status ParseParallelGroups(NodeItem *node_item); + Status CollectParallelGroups(NodeItem 
*node_item); Status ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies); Status ParseDependentForFusedSubgraph(NodeItem &node_item, std::set &dependencies); - Status ParseDependentForHcclNodes(); + Status ParseDependentByParallelGroup(); Status IndexTaskDefs(); Status IndexTaskDefs(const ComputeGraphPtr &sub_graph, const GeModelPtr &ge_model); Status IndexSpecialNodes(); @@ -102,8 +102,8 @@ class HybridModelBuilder { ComputeGraphPtr root_graph_; std::map subgraph_models_; std::map constant_op_nodes_; - std::map> group_to_nodes_; - std::map> node_to_groups_; + std::map> parallel_group_to_nodes_; + std::map> node_to_parallel_groups_; HybridModel &hybrid_model_; std::map>> node_ref_inputs_; diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index e5669d15..2166b274 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -19,10 +19,12 @@ #include #include "runtime/rt.h" +#include "graph/utils/node_utils.h" #define protected public #define private public #include "hybrid/model/hybrid_model_builder.h" #include "hybrid/model/hybrid_model.h" +#include "hybrid/node_executor/node_executor.h" #include "model/ge_model.h" #include "model/ge_root_model.h" #include "hybrid/node_executor/aicore/aicore_op_task.h" @@ -247,7 +249,7 @@ TEST_F(UtestGeHybrid, init_weight_success) { ASSERT_EQ(ret,PARAM_INVALID); } - TEST_F(UtestGeHybrid, hybrid_model_executor) { +TEST_F(UtestGeHybrid, hybrid_model_executor) { ComputeGraphPtr compute_graph = MakeShared("abc"); GeRootModelPtr root_model = MakeShared(compute_graph); HybridModel model(root_model); @@ -258,3 +260,71 @@ TEST_F(UtestGeHybrid, init_weight_success) { HybridModelExecutor executor(model_ptr, device_id, stream); executor.Init(); } + +TEST_F(UtestGeHybrid, test_parse_parallel_group) { + NodeExecutorManager::GetInstance().engine_mapping_.emplace("ops_kernel_info_hccl", + NodeExecutorManager::ExecutorType::HCCL); + 
ComputeGraphPtr compute_graph = MakeShared("test"); + OpDescPtr op_desc = CreateOpDesc("AllReduce", "AllReduce"); + op_desc->SetId(0); + ge::AttrUtils::SetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, "group_1"); + auto node = compute_graph->AddNode(op_desc); + std::unique_ptr node_item; + NodeItem::Create(node, node_item); + node_item->node_id = 0; + + op_desc->SetOpKernelLibName("ops_kernel_info_hccl"); + GeRootModelPtr root_model = MakeShared(compute_graph); + HybridModel model(root_model); + + HybridModelBuilder builder(model); + builder.root_graph_ = compute_graph; + ASSERT_EQ(builder.CollectParallelGroups(node_item.get()), SUCCESS); + + ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1); + ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1); + + OpDescPtr op_desc_1 = CreateOpDesc("subgraph", "PartitionedCall"); + op_desc_1->AddSubgraphName("subgraph"); + auto node_1 = compute_graph->AddNode(op_desc_1); + + ComputeGraphPtr subgraph = MakeShared("subgraph"); + ASSERT_EQ(NodeUtils::SetSubgraph(*node_1, 0, subgraph), GRAPH_SUCCESS); + + std::unique_ptr node_item_1; + NodeItem::Create(node_1, node_item_1); + node_item_1->node_id = 1; + + ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS); + ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1); + ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1); + + OpDescPtr op_desc_2 = CreateOpDesc("sub_node_1", "AllReduce"); + ge::AttrUtils::SetStr(op_desc_2, ATTR_NAME_PARALLEL_GROUP, "group_1"); + auto node_2 = subgraph->AddNode(op_desc_2); + ASSERT_TRUE(node_2 != nullptr); + + OpDescPtr op_desc_3 = CreateOpDesc("sub_node_2", "AllReduce2"); + ge::AttrUtils::SetStr(op_desc_3, ATTR_NAME_PARALLEL_GROUP, "group_2"); + auto node_3 = subgraph->AddNode(op_desc_3); + ASSERT_TRUE(node_3 != nullptr); + + ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS); + ASSERT_EQ(builder.node_to_parallel_groups_.size(), 2); + ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 2); + 
ASSERT_EQ(builder.parallel_group_to_nodes_["group_1"].size(), 2); + ASSERT_EQ(builder.parallel_group_to_nodes_["group_2"].size(), 1); + + ASSERT_FALSE(node_item->has_observer); + ASSERT_TRUE(node_item_1->dependents_for_execution.empty()); + ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS); + ASSERT_TRUE(node_item->has_observer); + ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); + ASSERT_EQ(node_item_1->dependents_for_execution[0], node); + + // repeat parse + ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS); + ASSERT_TRUE(node_item->has_observer); + ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); + ASSERT_EQ(node_item_1->dependents_for_execution[0], node); +} \ No newline at end of file