|
|
@@ -147,7 +147,6 @@ Status HybridModelBuilder::Build() { |
|
|
|
GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); |
|
|
|
hybrid_model_.model_name_ = ge_root_model_->GetModelName(); |
|
|
|
GELOGI("[%s] Start to build hybrid model.", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(CopyGraph(), "[Invoke][CopyGraph] failed, model_name_:[%s]", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), |
|
|
|
"[Invoke][RecoverGraphUnknownFlag] failed, model_name_:[%s]", GetGraphName()); |
|
|
@@ -175,8 +174,8 @@ Status HybridModelBuilder::BuildForSingleOp() { |
|
|
|
hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); |
|
|
|
GELOGI("[%s] Start to build hybrid model.", GetGraphName()); |
|
|
|
auto ret = ge_root_model_->GetSubgraphInstanceNameToModel(); |
|
|
|
const GeModelPtr ge_model = ret[hybrid_model_.root_graph_->GetName()]; |
|
|
|
GE_CHK_STATUS_RET(IndexTaskDefs(hybrid_model_.root_graph_, ge_model), |
|
|
|
const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()]; |
|
|
|
GE_CHK_STATUS_RET(IndexTaskDefs(ge_root_model_->GetRootGraph(), ge_model), |
|
|
|
"[Invoke][IndexTaskDefs] failed, model_name_:[%s]", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(LoadGraph(), "[Invoke][LoadGraph] failed, model_name_:[%s]", GetGraphName()); |
|
|
|
GE_CHK_STATUS_RET(InitWeights(), "[Invoke][InitWeights] failed, model_name_:[%s]", GetGraphName()); |
|
|
@@ -191,29 +190,6 @@ Status HybridModelBuilder::ValidateParams() { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status HybridModelBuilder::CopyGraph() { |
|
|
|
GELOGD("Copy compute graph begin."); |
|
|
|
auto root_graph = ge_root_model_->GetRootGraph(); |
|
|
|
|
|
|
|
ge_root_model_->IncreaseBuildTimes(); |
|
|
|
std::string new_graph_name = ge_root_model_->GetRootGraph()->GetName() + "_" + |
|
|
|
std::to_string(ge_root_model_->GetBuildTimes()); |
|
|
|
ComputeGraphPtr new_root_graph = MakeShared<ComputeGraph>(new_graph_name); |
|
|
|
GE_CHECK_NOTNULL(new_root_graph); |
|
|
|
int32_t depth = 0; |
|
|
|
std::map<ConstNodePtr, NodePtr> node_old_2_new; |
|
|
|
std::map<ConstOpDescPtr, OpDescPtr> op_desc_old_2_new; |
|
|
|
graphStatus ret = GraphUtils::CopyComputeGraph(root_graph, new_root_graph, node_old_2_new, op_desc_old_2_new, depth); |
|
|
|
if (ret != GRAPH_SUCCESS) { |
|
|
|
GELOGE(GRAPH_FAILED, "Copy compute graph failed."); |
|
|
|
return GRAPH_FAILED; |
|
|
|
} |
|
|
|
hybrid_model_.root_graph_ = new_root_graph; |
|
|
|
|
|
|
|
GELOGD("Copy compute graph[%s] success.", new_graph_name.c_str()); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { |
|
|
|
auto op_desc = node->GetOpDesc(); |
|
|
|
GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), |
|
|
@@ -838,7 +814,7 @@ Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, |
|
|
|
} |
|
|
|
|
|
|
|
Status HybridModelBuilder::LoadGraph() { |
|
|
|
auto root_graph = hybrid_model_.root_graph_; |
|
|
|
auto root_graph = ge_root_model_->GetRootGraph(); |
|
|
|
if (!GetContext().GetHostExecFlag()) { |
|
|
|
std::shared_ptr<ComputeGraph> merged_graph; |
|
|
|
GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", |
|
|
@@ -852,6 +828,7 @@ Status HybridModelBuilder::LoadGraph() { |
|
|
|
root_graph->GetAllNodesSize()); |
|
|
|
} |
|
|
|
|
|
|
|
hybrid_model_.root_graph_ = root_graph; |
|
|
|
GE_CHK_STATUS_RET(RelinkNextIteration(), "[%s] Relink NextIteration failed", GetGraphName()); |
|
|
|
// Reset node id by topological order across all subgraphs |
|
|
|
int64_t index = 0; |
|
|
@@ -900,7 +877,6 @@ Status HybridModelBuilder::LoadGraph() { |
|
|
|
} |
|
|
|
for (auto &it : hybrid_model_.known_shape_sub_models_) { |
|
|
|
auto node_item = MutableNodeItem(it.first); |
|
|
|
GE_CHECK_NOTNULL(node_item); |
|
|
|
AscendString graph_name; |
|
|
|
GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name"); |
|
|
|
auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString()); |
|
|
@@ -1149,9 +1125,7 @@ Status HybridModelBuilder::InitWeights() { |
|
|
|
sub_weight_buffer->GetSize()); |
|
|
|
auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); |
|
|
|
if (subgraph != ge_root_model_->GetRootGraph()) { |
|
|
|
subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_model.first); |
|
|
|
} else { |
|
|
|
subgraph = hybrid_model_.root_graph_; |
|
|
|
subgraph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first); |
|
|
|
} |
|
|
|
GE_CHECK_NOTNULL(subgraph); |
|
|
|
hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer)); |
|
|
@@ -1308,7 +1282,7 @@ Status HybridModelBuilder::IndexTaskDefs(const ComputeGraphPtr &sub_graph, const |
|
|
|
} |
|
|
|
|
|
|
|
Status HybridModelBuilder::IndexTaskDefs() { |
|
|
|
const auto &root_graph = hybrid_model_.root_graph_; |
|
|
|
const auto root_graph = ge_root_model_->GetRootGraph(); |
|
|
|
const auto &root_graph_name = root_graph->GetName(); |
|
|
|
if (SetOutputNameAttr(*root_graph) != SUCCESS) { |
|
|
|
GELOGW("Set output name attr failed."); |
|
|
@@ -1342,7 +1316,7 @@ Status HybridModelBuilder::IndexTaskDefs() { |
|
|
|
|
|
|
|
Status HybridModelBuilder::IndexSpecialNodes() { |
|
|
|
GELOGD("Start to index special nodes"); |
|
|
|
const auto &root_graph = hybrid_model_.root_graph_; |
|
|
|
const auto &root_graph = ge_root_model_->GetRootGraph(); |
|
|
|
for (auto &node : root_graph->GetAllNodes()) { |
|
|
|
GE_CHECK_NOTNULL(node); |
|
|
|
GE_CHECK_NOTNULL(node->GetOpDesc()); |
|
|
@@ -1497,7 +1471,7 @@ Status HybridModelBuilder::InitRuntimeParams() { |
|
|
|
runtime_param_.session_id = ret ? static_cast<uint64_t>(value) : 0; |
|
|
|
ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_TASK_GEN_VAR_ADDR, value); |
|
|
|
runtime_param_.logic_var_base = ret ? static_cast<uint64_t>(value) : 0; |
|
|
|
runtime_param_.graph_id = hybrid_model_.root_graph_->GetGraphID(); |
|
|
|
runtime_param_.graph_id = ge_root_model_->GetRootGraph()->GetGraphID(); |
|
|
|
value = 0; |
|
|
|
for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) { |
|
|
|
(void) ge::AttrUtils::GetInt(it.second, ATTR_MODEL_VAR_SIZE, value); |
|
|
@@ -1634,7 +1608,7 @@ Status HybridModelBuilder::TransAllVarData() { |
|
|
|
} |
|
|
|
|
|
|
|
Status HybridModelBuilder::CopyVarData() { |
|
|
|
GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(hybrid_model_.root_graph_, |
|
|
|
GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(ge_root_model_->GetRootGraph(), |
|
|
|
runtime_param_.session_id, |
|
|
|
hybrid_model_.device_id_), |
|
|
|
"[Invoke][CopyVarData] failed."); |
|
|
@@ -1717,7 +1691,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem |
|
|
|
} |
|
|
|
|
|
|
|
Status HybridModelBuilder::RecoverGraphUnknownFlag() { |
|
|
|
const auto &root_graph = hybrid_model_.root_graph_; |
|
|
|
const auto &root_graph = ge_root_model_->GetRootGraph(); |
|
|
|
for (auto &sub_graph : root_graph->GetAllSubgraphs()) { |
|
|
|
GE_CHECK_NOTNULL(sub_graph); |
|
|
|
for (const auto &node : sub_graph->GetDirectNode()) { |
|
|
|