diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 316b94de..dfd6ac6b 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -1072,30 +1072,39 @@ Status HybridModelBuilder::InitWeights() { return SUCCESS; } +Status HybridModelBuilder::LoadTask(NodeItem &node_item) { + auto &node_ptr = node_item.node; + GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str()); + auto load_ret = node_item.node_executor->LoadTask(hybrid_model_, + node_ptr, + node_item.kernel_task); + if (load_ret != UNSUPPORTED && load_ret != SUCCESS) { + GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str()); + return load_ret; + } + + GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str()); + return SUCCESS; +} + Status HybridModelBuilder::LoadTasks() { GE_CHK_STATUS_RET(CheckAicpuOpList(), "Check Aicpu op failed."); - std::map ordered_node_items; + std::map ordered_partitioned_calls; for (auto &it : hybrid_model_.node_items_) { - auto &node_item = it.second; - ordered_node_items.emplace(node_item->node_id, node_item.get()); - } - for (auto &it : ordered_node_items) { auto &node_item = it.second; auto &node_ptr = node_item->node; if (node_item->node_type == NETOUTPUT) { continue; } - - GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str()); - auto load_ret = node_item->node_executor->LoadTask(hybrid_model_, - node_ptr, - node_item->kernel_task); - if (load_ret != UNSUPPORTED && load_ret != SUCCESS) { - GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str()); - return load_ret; + if (node_item->node_type == PARTITIONEDCALL) { + ordered_partitioned_calls.emplace(node_item->node_id, node_item.get()); } + GE_CHK_STATUS_RET_NOLOG(LoadTask(*node_item)); + } - GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str()); + // HCCL operators need to be loaded in the same order across different processes + for (auto &it : ordered_partitioned_calls) { + GE_CHK_STATUS_RET_NOLOG(LoadTask(*it.second)); } return SUCCESS; diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 1481d61e..a59a282a 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -57,6 +57,7 @@ class HybridModelBuilder { Status ValidateParams(); Status LoadGraph(); Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); + Status LoadTask(NodeItem &node_item); Status LoadTasks(); Status IdentifyVariableOutputs(NodeItem &node_item); Status IdentifySameInputs(NodeItem &node_item);