|
|
@@ -270,7 +270,7 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt |
|
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
|
// not care result, if no this attr, stand for the op does not need force infershape |
|
|
|
(void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); |
|
|
|
GELOGD("node [%s] is need do infershape , flag is %d", |
|
|
|
GELOGD("node [%s] is need do infershape, flag is %d", |
|
|
|
op_desc->GetName().c_str(), |
|
|
|
node_item.is_need_force_infershape); |
|
|
|
return SUCCESS; |
|
|
@@ -537,7 +537,7 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { |
|
|
|
const auto &parent_node = graph.GetParentNode(); |
|
|
|
const NodePtr &net_output_node = graph.FindFirstNodeMatchType(NETOUTPUT); |
|
|
|
if (net_output_node == nullptr) { |
|
|
|
GELOGD("Graph has no netoutput no need to merge."); |
|
|
|
GELOGD("Graph has no netoutput no need to merge"); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
const auto &net_output_desc = net_output_node->GetOpDesc(); |
|
|
@@ -601,6 +601,7 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { |
|
|
|
|
|
|
|
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { |
|
|
|
merged_graph = MakeShared<ComputeGraph>("MergedGraph"); |
|
|
|
merged_graph->SetGraphUnknownFlag(root_graph->GetGraphUnknownFlag()); |
|
|
|
for (const auto &node : root_graph->GetDirectNode()) { |
|
|
|
GE_CHECK_NOTNULL(node); |
|
|
|
auto op_desc = node->GetOpDesc(); |
|
|
@@ -670,7 +671,7 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, |
|
|
|
GE_CHK_STATUS_RET(MergeNetOutputNode(sub_graph), |
|
|
|
"[%s] Failed to merge net output nodes for subgraph", |
|
|
|
sub_graph.GetName().c_str()); |
|
|
|
GELOGD("[%s] Done merging subgraph inputs and outputs successfully.", sub_graph.GetName().c_str()); |
|
|
|
GELOGD("[%s] Done merging subgraph inputs and outputs successfully", sub_graph.GetName().c_str()); |
|
|
|
|
|
|
|
for (auto &sub_node : sub_graph.GetDirectNode()) { |
|
|
|
auto sub_op_type = sub_node->GetType(); |
|
|
@@ -703,7 +704,7 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, |
|
|
|
sub_node->SetOwnerComputeGraph(parent_graph); |
|
|
|
} |
|
|
|
|
|
|
|
GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); |
|
|
|
GELOGD("[%s] Done merging subgraph. remove it from root graph", sub_graph.GetName().c_str()); |
|
|
|
root_graph->RemoveSubgraph(sub_graph.GetName()); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
@@ -1039,9 +1040,13 @@ Status HybridModelBuilder::InitWeights() { |
|
|
|
GELOGI("Init weight mem successfully, weight base %p, weight size = %zu", |
|
|
|
weight_base, |
|
|
|
sub_weight_buffer->GetSize()); |
|
|
|
auto root_graph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first); |
|
|
|
hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(), std::move(sub_weight_buffer)); |
|
|
|
for (auto &node : root_graph->GetDirectNode()) { |
|
|
|
auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); |
|
|
|
if (subgraph != ge_root_model_->GetRootGraph()) { |
|
|
|
subgraph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first); |
|
|
|
} |
|
|
|
GE_CHECK_NOTNULL(subgraph); |
|
|
|
hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer)); |
|
|
|
for (auto &node : subgraph->GetDirectNode()) { |
|
|
|
if (node->GetType() != CONSTANT) { |
|
|
|
continue; |
|
|
|
} |
|
|
@@ -1170,11 +1175,11 @@ Status HybridModelBuilder::IndexTaskDefs(const ComputeGraphPtr &sub_graph, const |
|
|
|
GELOGD("Skip task type: %d", static_cast<int>(task_type)); |
|
|
|
continue; |
|
|
|
} |
|
|
|
GELOGD("op_index = %u, task_type = %d.", op_index, task_type); |
|
|
|
GELOGD("op_index = %u, task_type = %d", op_index, task_type); |
|
|
|
|
|
|
|
auto iter = node_map.find(op_index); |
|
|
|
if (iter == node_map.end()) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to get node by op_index = %u.", op_index); |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to get node by op_index = %u", op_index); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
|
|
|
@@ -1183,7 +1188,7 @@ Status HybridModelBuilder::IndexTaskDefs(const ComputeGraphPtr &sub_graph, const |
|
|
|
ge_model->GetTBEKernelStore().LoadTBEKernelBinToOpDesc(node->GetOpDesc()); |
|
|
|
} |
|
|
|
|
|
|
|
GELOGD("Task loaded for node: %s, task type = %d, op_index = %u.", node->GetName().c_str(), task_type, op_index); |
|
|
|
GELOGD("Task loaded for node: %s, task type = %d, op_index = %u", node->GetName().c_str(), task_type, op_index); |
|
|
|
hybrid_model_.task_defs_[node].emplace_back(task_def); |
|
|
|
} |
|
|
|
|
|
|
|