@@ -806,7 +806,7 @@ Status HybridModelBuilder::LoadGraph() { | |||||
} | } | ||||
} | } | ||||
GE_CHK_STATUS_RET(ParseDependentForHcclNodes(), "Failed to establish dependencies for hccl ops"); | |||||
GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), "Failed to establish dependencies for hccl ops"); | |||||
GELOGI("Done loading all subgraphs successfully."); | GELOGI("Done loading all subgraphs successfully."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -1907,7 +1907,7 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root | |||||
NodeItem *node_item = nullptr; | NodeItem *node_item = nullptr; | ||||
GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); | GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); | ||||
GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); | GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); | ||||
GE_CHK_STATUS_RET_NOLOG(ParseParallelGroups(node_item)); | |||||
GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(node_item)); | |||||
GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task | GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task | ||||
node_item->input_start = input_start; | node_item->input_start = input_start; | ||||
@@ -2015,16 +2015,16 @@ Status HybridModelBuilder::CheckAicpuOpList() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HybridModelBuilder::ParseParallelGroups(NodeItem *node_item) { | |||||
Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { | |||||
const auto &node = node_item->node; | const auto &node = node_item->node; | ||||
auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); | auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); | ||||
if (executor_type == NodeExecutorManager::ExecutorType::HCCL) { | if (executor_type == NodeExecutorManager::ExecutorType::HCCL) { | ||||
std::string parallel_group; | std::string parallel_group; | ||||
if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { | if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { | ||||
GELOGD("[%s] Got parallel group = %s", node_item->NodeName().c_str(), parallel_group.c_str()); | GELOGD("[%s] Got parallel group = %s", node_item->NodeName().c_str(), parallel_group.c_str()); | ||||
group_to_nodes_[parallel_group].emplace(node_item); | |||||
parallel_group_to_nodes_[parallel_group].emplace(node_item); | |||||
std::set<std::string> group{parallel_group}; | std::set<std::string> group{parallel_group}; | ||||
node_to_groups_[node_item].emplace(parallel_group); | |||||
node_to_parallel_groups_[node_item].emplace(parallel_group); | |||||
} | } | ||||
} else if (executor_type == NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH) { | } else if (executor_type == NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH) { | ||||
std::set<std::string> parallel_groups; | std::set<std::string> parallel_groups; | ||||
@@ -2049,25 +2049,28 @@ Status HybridModelBuilder::ParseParallelGroups(NodeItem *node_item) { | |||||
if (!parallel_groups.empty()) { | if (!parallel_groups.empty()) { | ||||
for (const auto ¶llel_group : parallel_groups) { | for (const auto ¶llel_group : parallel_groups) { | ||||
group_to_nodes_[parallel_group].emplace(node_item); | |||||
parallel_group_to_nodes_[parallel_group].emplace(node_item); | |||||
GELOGD("[%s] has parallel group: %s", node_item->NodeName().c_str(), parallel_group.c_str()); | GELOGD("[%s] has parallel group: %s", node_item->NodeName().c_str(), parallel_group.c_str()); | ||||
} | } | ||||
node_to_groups_.emplace(node_item, std::move(parallel_groups)); | |||||
node_to_parallel_groups_.emplace(node_item, std::move(parallel_groups)); | |||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HybridModelBuilder::ParseDependentForHcclNodes() { | |||||
for (const auto &it : node_to_groups_) { | |||||
Status HybridModelBuilder::ParseDependentByParallelGroup() { | |||||
for (const auto &it : node_to_parallel_groups_) { | |||||
auto node_item = it.first; | auto node_item = it.first; | ||||
auto dst_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node); | auto dst_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node); | ||||
for (const auto ¶llel_group : it.second) { | for (const auto ¶llel_group : it.second) { | ||||
auto &dependent_nodes = group_to_nodes_[parallel_group]; | |||||
auto &dependent_nodes = parallel_group_to_nodes_[parallel_group]; | |||||
NodeItem *nearest_dep_node = nullptr; | NodeItem *nearest_dep_node = nullptr; | ||||
int max_id = -1; | int max_id = -1; | ||||
for (auto &dep_node : dependent_nodes) { | for (auto &dep_node : dependent_nodes) { | ||||
if (node_item == dep_node) { | |||||
continue; | |||||
} | |||||
auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*dep_node->node); | auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*dep_node->node); | ||||
if (src_engine_type == dst_engine_type) { | if (src_engine_type == dst_engine_type) { | ||||
continue; | continue; | ||||
@@ -63,10 +63,10 @@ class HybridModelBuilder { | |||||
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | ||||
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | ||||
Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); | Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); | ||||
Status ParseParallelGroups(NodeItem *node_item); | |||||
Status CollectParallelGroups(NodeItem *node_item); | |||||
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | ||||
Status ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<ge::NodePtr> &dependencies); | Status ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<ge::NodePtr> &dependencies); | ||||
Status ParseDependentForHcclNodes(); | |||||
Status ParseDependentByParallelGroup(); | |||||
Status IndexTaskDefs(); | Status IndexTaskDefs(); | ||||
Status IndexTaskDefs(const ComputeGraphPtr &sub_graph, const GeModelPtr &ge_model); | Status IndexTaskDefs(const ComputeGraphPtr &sub_graph, const GeModelPtr &ge_model); | ||||
Status IndexSpecialNodes(); | Status IndexSpecialNodes(); | ||||
@@ -102,8 +102,8 @@ class HybridModelBuilder { | |||||
ComputeGraphPtr root_graph_; | ComputeGraphPtr root_graph_; | ||||
std::map<std::string, GeModelPtr> subgraph_models_; | std::map<std::string, GeModelPtr> subgraph_models_; | ||||
std::map<std::string, NodePtr> constant_op_nodes_; | std::map<std::string, NodePtr> constant_op_nodes_; | ||||
std::map<std::string, std::set<NodeItem *>> group_to_nodes_; | |||||
std::map<NodeItem *, std::set<std::string>> node_to_groups_; | |||||
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; | |||||
std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_; | |||||
HybridModel &hybrid_model_; | HybridModel &hybrid_model_; | ||||
std::map<NodePtr, std::vector<std::pair<int, NodePtr>>> node_ref_inputs_; | std::map<NodePtr, std::vector<std::pair<int, NodePtr>>> node_ref_inputs_; | ||||
@@ -19,10 +19,12 @@ | |||||
#include <vector> | #include <vector> | ||||
#include "runtime/rt.h" | #include "runtime/rt.h" | ||||
#include "graph/utils/node_utils.h" | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "hybrid/model/hybrid_model_builder.h" | #include "hybrid/model/hybrid_model_builder.h" | ||||
#include "hybrid/model/hybrid_model.h" | #include "hybrid/model/hybrid_model.h" | ||||
#include "hybrid/node_executor/node_executor.h" | |||||
#include "model/ge_model.h" | #include "model/ge_model.h" | ||||
#include "model/ge_root_model.h" | #include "model/ge_root_model.h" | ||||
#include "hybrid/node_executor/aicore/aicore_op_task.h" | #include "hybrid/node_executor/aicore/aicore_op_task.h" | ||||
@@ -247,7 +249,7 @@ TEST_F(UtestGeHybrid, init_weight_success) { | |||||
ASSERT_EQ(ret,PARAM_INVALID); | ASSERT_EQ(ret,PARAM_INVALID); | ||||
} | } | ||||
TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||||
TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | ||||
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | ||||
HybridModel model(root_model); | HybridModel model(root_model); | ||||
@@ -258,3 +260,71 @@ TEST_F(UtestGeHybrid, init_weight_success) { | |||||
HybridModelExecutor executor(model_ptr, device_id, stream); | HybridModelExecutor executor(model_ptr, device_id, stream); | ||||
executor.Init(); | executor.Init(); | ||||
} | } | ||||
TEST_F(UtestGeHybrid, test_parse_parallel_group) { | |||||
NodeExecutorManager::GetInstance().engine_mapping_.emplace("ops_kernel_info_hccl", | |||||
NodeExecutorManager::ExecutorType::HCCL); | |||||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test"); | |||||
OpDescPtr op_desc = CreateOpDesc("AllReduce", "AllReduce"); | |||||
op_desc->SetId(0); | |||||
ge::AttrUtils::SetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, "group_1"); | |||||
auto node = compute_graph->AddNode(op_desc); | |||||
std::unique_ptr<NodeItem> node_item; | |||||
NodeItem::Create(node, node_item); | |||||
node_item->node_id = 0; | |||||
op_desc->SetOpKernelLibName("ops_kernel_info_hccl"); | |||||
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | |||||
HybridModel model(root_model); | |||||
HybridModelBuilder builder(model); | |||||
builder.root_graph_ = compute_graph; | |||||
ASSERT_EQ(builder.CollectParallelGroups(node_item.get()), SUCCESS); | |||||
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1); | |||||
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1); | |||||
OpDescPtr op_desc_1 = CreateOpDesc("subgraph", "PartitionedCall"); | |||||
op_desc_1->AddSubgraphName("subgraph"); | |||||
auto node_1 = compute_graph->AddNode(op_desc_1); | |||||
ComputeGraphPtr subgraph = MakeShared<ComputeGraph>("subgraph"); | |||||
ASSERT_EQ(NodeUtils::SetSubgraph(*node_1, 0, subgraph), GRAPH_SUCCESS); | |||||
std::unique_ptr<NodeItem> node_item_1; | |||||
NodeItem::Create(node_1, node_item_1); | |||||
node_item_1->node_id = 1; | |||||
ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS); | |||||
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1); | |||||
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1); | |||||
OpDescPtr op_desc_2 = CreateOpDesc("sub_node_1", "AllReduce"); | |||||
ge::AttrUtils::SetStr(op_desc_2, ATTR_NAME_PARALLEL_GROUP, "group_1"); | |||||
auto node_2 = subgraph->AddNode(op_desc_2); | |||||
ASSERT_TRUE(node_2 != nullptr); | |||||
OpDescPtr op_desc_3 = CreateOpDesc("sub_node_2", "AllReduce2"); | |||||
ge::AttrUtils::SetStr(op_desc_3, ATTR_NAME_PARALLEL_GROUP, "group_2"); | |||||
auto node_3 = subgraph->AddNode(op_desc_3); | |||||
ASSERT_TRUE(node_3 != nullptr); | |||||
ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS); | |||||
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 2); | |||||
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 2); | |||||
ASSERT_EQ(builder.parallel_group_to_nodes_["group_1"].size(), 2); | |||||
ASSERT_EQ(builder.parallel_group_to_nodes_["group_2"].size(), 1); | |||||
ASSERT_FALSE(node_item->has_observer); | |||||
ASSERT_TRUE(node_item_1->dependents_for_execution.empty()); | |||||
ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS); | |||||
ASSERT_TRUE(node_item->has_observer); | |||||
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); | |||||
ASSERT_EQ(node_item_1->dependents_for_execution[0], node); | |||||
// repeat parse | |||||
ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS); | |||||
ASSERT_TRUE(node_item->has_observer); | |||||
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); | |||||
ASSERT_EQ(node_item_1->dependents_for_execution[0], node); | |||||
} |