Browse Source

fix bug that all subgraph is unknown and netoutput format is not nd bug

tags/v1.2.0
wxl 3 years ago
parent
commit
723f398670
2 changed files with 5 additions and 3 deletions
  1. +3
    -3
      ge/graph/load/model_manager/model_manager.cc
  2. +2
    -0
      ge/graph/passes/net_output_pass.cc

+ 3
- 3
ge/graph/load/model_manager/model_manager.cc View File

@@ -297,10 +297,11 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge
if (model_id == INVALID_MODEL_ID) { if (model_id == INVALID_MODEL_ID) {
GenModelId(&model_id); GenModelId(&model_id);
} }
auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel();
string model_name = ""; string model_name = "";
bool is_shape_unknown = ge_root_model->GetRootGraph()->GetGraphUnknownFlag(); bool is_shape_unknown = ge_root_model->GetRootGraph()->GetGraphUnknownFlag();
if (is_shape_unknown || GetContext().GetHostExecFlag()) {
// if multi subgraph is known, do hybrid load process
if (is_shape_unknown || GetContext().GetHostExecFlag() || (name_to_model.size() > 1)) {
return DoLoadHybridModelOnline(model_id, model_name, ge_root_model, listener); return DoLoadHybridModelOnline(model_id, model_name, ge_root_model, listener);
} }


@@ -322,7 +323,6 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge
auto root_graph = ge_root_model->GetRootGraph(); auto root_graph = ge_root_model->GetRootGraph();
GE_CHECK_NOTNULL(root_graph); GE_CHECK_NOTNULL(root_graph);
string root_model_name = root_graph->GetName(); string root_model_name = root_graph->GetName();
auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel();
GeModelPtr ge_model = name_to_model[root_model_name]; GeModelPtr ge_model = name_to_model[root_model_name];
Status ret = SUCCESS; Status ret = SUCCESS;
do { do {


+ 2
- 0
ge/graph/passes/net_output_pass.cc View File

@@ -202,6 +202,8 @@ Status NetOutputPass::UpdateNetOutputDesc(const ge::NodePtr &net_output) {
GE_CHECK_NOTNULL(src_op_desc); GE_CHECK_NOTNULL(src_op_desc);
uint32_t peer_index = static_cast<uint32_t>(in_anchor->GetPeerOutAnchor()->GetIdx()); uint32_t peer_index = static_cast<uint32_t>(in_anchor->GetPeerOutAnchor()->GetIdx());
ge::GeTensorDesc output_in_desc = src_op_desc->GetOutputDesc(peer_index); ge::GeTensorDesc output_in_desc = src_op_desc->GetOutputDesc(peer_index);
output_in_desc.SetFormat(FORMAT_ND);
output_in_desc.SetOriginFormat(FORMAT_ND);
if (net_output_desc->UpdateInputDesc(index, output_in_desc) != GRAPH_SUCCESS) { if (net_output_desc->UpdateInputDesc(index, output_in_desc) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Update input desc failed, index:%u.", index); GELOGE(INTERNAL_ERROR, "Update input desc failed, index:%u.", index);
return INTERNAL_ERROR; return INTERNAL_ERROR;


Loading…
Cancel
Save