From 723f39867052ee3ef1a24691501e493526b791c0 Mon Sep 17 00:00:00 2001 From: wxl Date: Sat, 13 Mar 2021 14:04:07 +0800 Subject: [PATCH] fix bug that all subgraph is unknown and netoutput format is not nd bug --- ge/graph/load/model_manager/model_manager.cc | 6 +++--- ge/graph/passes/net_output_pass.cc | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ge/graph/load/model_manager/model_manager.cc b/ge/graph/load/model_manager/model_manager.cc index 0273b77e..97ad0054 100755 --- a/ge/graph/load/model_manager/model_manager.cc +++ b/ge/graph/load/model_manager/model_manager.cc @@ -297,10 +297,11 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrGetSubgraphInstanceNameToModel(); string model_name = ""; bool is_shape_unknown = ge_root_model->GetRootGraph()->GetGraphUnknownFlag(); - if (is_shape_unknown || GetContext().GetHostExecFlag()) { + // if multi subgraph is known, do hybrid load process + if (is_shape_unknown || GetContext().GetHostExecFlag() || (name_to_model.size() > 1)) { return DoLoadHybridModelOnline(model_id, model_name, ge_root_model, listener); } @@ -322,7 +323,6 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrGetRootGraph(); GE_CHECK_NOTNULL(root_graph); string root_model_name = root_graph->GetName(); - auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); GeModelPtr ge_model = name_to_model[root_model_name]; Status ret = SUCCESS; do { diff --git a/ge/graph/passes/net_output_pass.cc b/ge/graph/passes/net_output_pass.cc index c553607f..37de2af9 100644 --- a/ge/graph/passes/net_output_pass.cc +++ b/ge/graph/passes/net_output_pass.cc @@ -202,6 +202,8 @@ Status NetOutputPass::UpdateNetOutputDesc(const ge::NodePtr &net_output) { GE_CHECK_NOTNULL(src_op_desc); uint32_t peer_index = static_cast(in_anchor->GetPeerOutAnchor()->GetIdx()); ge::GeTensorDesc output_in_desc = src_op_desc->GetOutputDesc(peer_index); + output_in_desc.SetFormat(FORMAT_ND); + output_in_desc.SetOriginFormat(FORMAT_ND); if (net_output_desc->UpdateInputDesc(index, output_in_desc) != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "Update input desc failed, index:%u.", index); return INTERNAL_ERROR;