Browse Source

!1379 Optimize HCCL op dependency

From: @xchu42
Reviewed-by: @ji_chen,@wqtshg
Signed-off-by: @ji_chen
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
7321eb0669
2 changed files with 36 additions and 26 deletions
  1. ge/hybrid/model/hybrid_model_builder.cc (+23, -18)
  2. tests/ut/ge/hybrid/ge_hybrid_unittest.cc (+13, -8)

+ 23
- 18
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -1089,14 +1089,14 @@ Status HybridModelBuilder::LoadTask(NodeItem &node_item) {


Status HybridModelBuilder::LoadTasks() { Status HybridModelBuilder::LoadTasks() {
GE_CHK_STATUS_RET(CheckAicpuOpList(), "Check Aicpu op failed."); GE_CHK_STATUS_RET(CheckAicpuOpList(), "Check Aicpu op failed.");
std::map<int64_t, NodeItem *> ordered_partitioned_calls;
std::map<int, std::map<std::string, NodeItem *>> ordered_partitioned_calls;
for (auto &it : hybrid_model_.node_items_) { for (auto &it : hybrid_model_.node_items_) {
auto &node_item = it.second; auto &node_item = it.second;
if (node_item->node_type == NETOUTPUT) { if (node_item->node_type == NETOUTPUT) {
continue; continue;
} }
if (node_item->node_type == PARTITIONEDCALL) { if (node_item->node_type == PARTITIONEDCALL) {
ordered_partitioned_calls.emplace(node_item->node_id, node_item.get());
ordered_partitioned_calls[node_item->node_id][node_item->node_name] = node_item.get();
continue; continue;
} }
GE_CHK_STATUS_RET_NOLOG(LoadTask(*node_item)); GE_CHK_STATUS_RET_NOLOG(LoadTask(*node_item));
@@ -1104,7 +1104,9 @@ Status HybridModelBuilder::LoadTasks() {


// HCCL operators need to be loaded in the same order across different processes // HCCL operators need to be loaded in the same order across different processes
for (auto &it : ordered_partitioned_calls) { for (auto &it : ordered_partitioned_calls) {
GE_CHK_STATUS_RET_NOLOG(LoadTask(*it.second));
for (auto &it2 : it.second) {
GE_CHK_STATUS_RET_NOLOG(LoadTask(*it2.second));
}
} }


return SUCCESS; return SUCCESS;
@@ -1637,6 +1639,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem
auto temp_graph = MakeShared<ComputeGraph>("temp"); auto temp_graph = MakeShared<ComputeGraph>("temp");
GE_CHECK_NOTNULL(temp_graph); GE_CHECK_NOTNULL(temp_graph);
auto wrapper_node = temp_graph->AddNode(wrapper_op_desc); auto wrapper_node = temp_graph->AddNode(wrapper_op_desc);
wrapper_op_desc->SetId(parent_node_item->node_id);
GeModelPtr ge_model = subgraph_models_[subgraph_name]; GeModelPtr ge_model = subgraph_models_[subgraph_name];
GE_CHECK_NOTNULL(ge_model); GE_CHECK_NOTNULL(ge_model);
hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model); hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model);
@@ -1916,7 +1919,6 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root
NodeItem *node_item = nullptr; NodeItem *node_item = nullptr;
GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item));
GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item));
GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(node_item));
GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task


node_item->input_start = input_start; node_item->input_start = input_start;
@@ -2069,22 +2071,17 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) {
} }


Status HybridModelBuilder::ParseDependentByParallelGroup() { Status HybridModelBuilder::ParseDependentByParallelGroup() {
for (auto &it : hybrid_model_.node_items_) {
GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(it.second.get()));
}
for (const auto &it : node_to_parallel_groups_) { for (const auto &it : node_to_parallel_groups_) {
auto node_item = it.first; auto node_item = it.first;
auto dst_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node);
auto dst_executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node);
for (const auto &parallel_group : it.second) { for (const auto &parallel_group : it.second) {
auto &dependent_nodes = parallel_group_to_nodes_[parallel_group]; auto &dependent_nodes = parallel_group_to_nodes_[parallel_group];
NodeItem *nearest_dep_node = nullptr; NodeItem *nearest_dep_node = nullptr;
int max_id = -1; int max_id = -1;
for (auto &dep_node : dependent_nodes) { for (auto &dep_node : dependent_nodes) {
if (node_item == dep_node) {
continue;
}
auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*dep_node->node);
if (src_engine_type == dst_engine_type) {
continue;
}

if (dep_node->node_id < node_item->node_id && dep_node->node_id > max_id) { if (dep_node->node_id < node_item->node_id && dep_node->node_id > max_id) {
nearest_dep_node = dep_node; nearest_dep_node = dep_node;
max_id = dep_node->node_id; max_id = dep_node->node_id;
@@ -2092,17 +2089,25 @@ Status HybridModelBuilder::ParseDependentByParallelGroup() {
} }


if (nearest_dep_node != nullptr) { if (nearest_dep_node != nullptr) {
GELOGD("Add dependency for nodes of same parallel group[%s], src = [%s], dst = [%s]",
parallel_group.c_str(),
nearest_dep_node->NodeName().c_str(),
node_item->NodeName().c_str());
GELOGD("[%s] Nearest node = [%s]", node_item->NodeName().c_str(), nearest_dep_node->NodeName().c_str());
auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*nearest_dep_node->node);
if (src_engine_type == dst_executor_type) {
GELOGD("No need to add dependency for nodes with same executor type");
continue;
}
auto &deps = node_item->dependents_for_execution; auto &deps = node_item->dependents_for_execution;
if (std::find(deps.begin(), deps.end(), nearest_dep_node->node) != deps.end()) { if (std::find(deps.begin(), deps.end(), nearest_dep_node->node) != deps.end()) {
GELOGD("Already has dependency, skip it");
GELOGD("%s->%s Already has dependency, skip it",
nearest_dep_node->node->GetName().c_str(),
node_item->NodeName().c_str());
continue; continue;
} }
nearest_dep_node->has_observer = true; nearest_dep_node->has_observer = true;
deps.emplace_back(nearest_dep_node->node); deps.emplace_back(nearest_dep_node->node);
GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]",
parallel_group.c_str(),
nearest_dep_node->NodeName().c_str(),
node_item->NodeName().c_str());
} }
} }
} }


+ 13
- 8
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -315,16 +315,21 @@ TEST_F(UtestGeHybrid, test_parse_parallel_group) {
ASSERT_EQ(builder.parallel_group_to_nodes_["group_1"].size(), 2); ASSERT_EQ(builder.parallel_group_to_nodes_["group_1"].size(), 2);
ASSERT_EQ(builder.parallel_group_to_nodes_["group_2"].size(), 1); ASSERT_EQ(builder.parallel_group_to_nodes_["group_2"].size(), 1);


ASSERT_FALSE(node_item->has_observer);
ASSERT_TRUE(node_item_1->dependents_for_execution.empty());
builder.parallel_group_to_nodes_.clear();
builder.node_ref_inputs_.clear();
model.node_items_[node] = std::move(node_item);
model.node_items_[node_1] = std::move(node_item_1);

ASSERT_FALSE(model.node_items_[node]->has_observer);
ASSERT_TRUE(model.node_items_[node_1]->dependents_for_execution.empty());
ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS); ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS);
ASSERT_TRUE(node_item->has_observer);
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1);
ASSERT_EQ(node_item_1->dependents_for_execution[0], node);
ASSERT_TRUE(model.node_items_[node]->has_observer);
ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution.size(), 1);
ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution[0], node);


// repeat parse // repeat parse
ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS); ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS);
ASSERT_TRUE(node_item->has_observer);
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1);
ASSERT_EQ(node_item_1->dependents_for_execution[0], node);
ASSERT_TRUE(model.node_items_[node]->has_observer);
ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution.size(), 1);
ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution[0], node);
} }

Loading…
Cancel
Save