@@ -806,7 +806,7 @@ Status HybridModelBuilder::LoadGraph() { | |||
} | |||
} | |||
GE_CHK_STATUS_RET(ParseDependentForHcclNodes(), "Failed to establish dependencies for hccl ops"); | |||
GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), "Failed to establish dependencies for hccl ops"); | |||
GELOGI("Done loading all subgraphs successfully."); | |||
return SUCCESS; | |||
} | |||
@@ -1907,7 +1907,7 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root | |||
NodeItem *node_item = nullptr; | |||
GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); | |||
GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); | |||
GE_CHK_STATUS_RET_NOLOG(ParseParallelGroups(node_item)); | |||
GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(node_item)); | |||
GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task | |||
node_item->input_start = input_start; | |||
@@ -2015,16 +2015,16 @@ Status HybridModelBuilder::CheckAicpuOpList() { | |||
return SUCCESS; | |||
} | |||
Status HybridModelBuilder::ParseParallelGroups(NodeItem *node_item) { | |||
Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { | |||
const auto &node = node_item->node; | |||
auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); | |||
if (executor_type == NodeExecutorManager::ExecutorType::HCCL) { | |||
std::string parallel_group; | |||
if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { | |||
GELOGD("[%s] Got parallel group = %s", node_item->NodeName().c_str(), parallel_group.c_str()); | |||
group_to_nodes_[parallel_group].emplace(node_item); | |||
parallel_group_to_nodes_[parallel_group].emplace(node_item); | |||
std::set<std::string> group{parallel_group}; | |||
node_to_groups_[node_item].emplace(parallel_group); | |||
node_to_parallel_groups_[node_item].emplace(parallel_group); | |||
} | |||
} else if (executor_type == NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH) { | |||
std::set<std::string> parallel_groups; | |||
@@ -2049,25 +2049,28 @@ Status HybridModelBuilder::ParseParallelGroups(NodeItem *node_item) { | |||
if (!parallel_groups.empty()) { | |||
for (const auto ¶llel_group : parallel_groups) { | |||
group_to_nodes_[parallel_group].emplace(node_item); | |||
parallel_group_to_nodes_[parallel_group].emplace(node_item); | |||
GELOGD("[%s] has parallel group: %s", node_item->NodeName().c_str(), parallel_group.c_str()); | |||
} | |||
node_to_groups_.emplace(node_item, std::move(parallel_groups)); | |||
node_to_parallel_groups_.emplace(node_item, std::move(parallel_groups)); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
Status HybridModelBuilder::ParseDependentForHcclNodes() { | |||
for (const auto &it : node_to_groups_) { | |||
Status HybridModelBuilder::ParseDependentByParallelGroup() { | |||
for (const auto &it : node_to_parallel_groups_) { | |||
auto node_item = it.first; | |||
auto dst_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node); | |||
for (const auto ¶llel_group : it.second) { | |||
auto &dependent_nodes = group_to_nodes_[parallel_group]; | |||
auto &dependent_nodes = parallel_group_to_nodes_[parallel_group]; | |||
NodeItem *nearest_dep_node = nullptr; | |||
int max_id = -1; | |||
for (auto &dep_node : dependent_nodes) { | |||
if (node_item == dep_node) { | |||
continue; | |||
} | |||
auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*dep_node->node); | |||
if (src_engine_type == dst_engine_type) { | |||
continue; | |||
@@ -63,10 +63,10 @@ class HybridModelBuilder { | |||
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | |||
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | |||
Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); | |||
Status ParseParallelGroups(NodeItem *node_item); | |||
Status CollectParallelGroups(NodeItem *node_item); | |||
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | |||
Status ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<ge::NodePtr> &dependencies); | |||
Status ParseDependentForHcclNodes(); | |||
Status ParseDependentByParallelGroup(); | |||
Status IndexTaskDefs(); | |||
Status IndexTaskDefs(const ComputeGraphPtr &sub_graph, const GeModelPtr &ge_model); | |||
Status IndexSpecialNodes(); | |||
@@ -102,8 +102,8 @@ class HybridModelBuilder { | |||
ComputeGraphPtr root_graph_; | |||
std::map<std::string, GeModelPtr> subgraph_models_; | |||
std::map<std::string, NodePtr> constant_op_nodes_; | |||
std::map<std::string, std::set<NodeItem *>> group_to_nodes_; | |||
std::map<NodeItem *, std::set<std::string>> node_to_groups_; | |||
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; | |||
std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_; | |||
HybridModel &hybrid_model_; | |||
std::map<NodePtr, std::vector<std::pair<int, NodePtr>>> node_ref_inputs_; | |||
@@ -19,10 +19,12 @@ | |||
#include <vector> | |||
#include "runtime/rt.h" | |||
#include "graph/utils/node_utils.h" | |||
#define protected public | |||
#define private public | |||
#include "hybrid/model/hybrid_model_builder.h" | |||
#include "hybrid/model/hybrid_model.h" | |||
#include "hybrid/node_executor/node_executor.h" | |||
#include "model/ge_model.h" | |||
#include "model/ge_root_model.h" | |||
#include "hybrid/node_executor/aicore/aicore_op_task.h" | |||
@@ -247,7 +249,7 @@ TEST_F(UtestGeHybrid, init_weight_success) { | |||
ASSERT_EQ(ret,PARAM_INVALID); | |||
} | |||
TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||
TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | |||
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | |||
HybridModel model(root_model); | |||
@@ -258,3 +260,71 @@ TEST_F(UtestGeHybrid, init_weight_success) { | |||
HybridModelExecutor executor(model_ptr, device_id, stream); | |||
executor.Init(); | |||
} | |||
TEST_F(UtestGeHybrid, test_parse_parallel_group) { | |||
NodeExecutorManager::GetInstance().engine_mapping_.emplace("ops_kernel_info_hccl", | |||
NodeExecutorManager::ExecutorType::HCCL); | |||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test"); | |||
OpDescPtr op_desc = CreateOpDesc("AllReduce", "AllReduce"); | |||
op_desc->SetId(0); | |||
ge::AttrUtils::SetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, "group_1"); | |||
auto node = compute_graph->AddNode(op_desc); | |||
std::unique_ptr<NodeItem> node_item; | |||
NodeItem::Create(node, node_item); | |||
node_item->node_id = 0; | |||
op_desc->SetOpKernelLibName("ops_kernel_info_hccl"); | |||
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | |||
HybridModel model(root_model); | |||
HybridModelBuilder builder(model); | |||
builder.root_graph_ = compute_graph; | |||
ASSERT_EQ(builder.CollectParallelGroups(node_item.get()), SUCCESS); | |||
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1); | |||
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1); | |||
OpDescPtr op_desc_1 = CreateOpDesc("subgraph", "PartitionedCall"); | |||
op_desc_1->AddSubgraphName("subgraph"); | |||
auto node_1 = compute_graph->AddNode(op_desc_1); | |||
ComputeGraphPtr subgraph = MakeShared<ComputeGraph>("subgraph"); | |||
ASSERT_EQ(NodeUtils::SetSubgraph(*node_1, 0, subgraph), GRAPH_SUCCESS); | |||
std::unique_ptr<NodeItem> node_item_1; | |||
NodeItem::Create(node_1, node_item_1); | |||
node_item_1->node_id = 1; | |||
ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS); | |||
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1); | |||
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1); | |||
OpDescPtr op_desc_2 = CreateOpDesc("sub_node_1", "AllReduce"); | |||
ge::AttrUtils::SetStr(op_desc_2, ATTR_NAME_PARALLEL_GROUP, "group_1"); | |||
auto node_2 = subgraph->AddNode(op_desc_2); | |||
ASSERT_TRUE(node_2 != nullptr); | |||
OpDescPtr op_desc_3 = CreateOpDesc("sub_node_2", "AllReduce2"); | |||
ge::AttrUtils::SetStr(op_desc_3, ATTR_NAME_PARALLEL_GROUP, "group_2"); | |||
auto node_3 = subgraph->AddNode(op_desc_3); | |||
ASSERT_TRUE(node_3 != nullptr); | |||
ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS); | |||
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 2); | |||
ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 2); | |||
ASSERT_EQ(builder.parallel_group_to_nodes_["group_1"].size(), 2); | |||
ASSERT_EQ(builder.parallel_group_to_nodes_["group_2"].size(), 1); | |||
ASSERT_FALSE(node_item->has_observer); | |||
ASSERT_TRUE(node_item_1->dependents_for_execution.empty()); | |||
ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS); | |||
ASSERT_TRUE(node_item->has_observer); | |||
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); | |||
ASSERT_EQ(node_item_1->dependents_for_execution[0], node); | |||
// repeat parse | |||
ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS); | |||
ASSERT_TRUE(node_item->has_observer); | |||
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); | |||
ASSERT_EQ(node_item_1->dependents_for_execution[0], node); | |||
} |