From e20387891086ca2ae9e4c1ece29d1df17d564f31 Mon Sep 17 00:00:00 2001 From: wxl Date: Wed, 17 Mar 2021 22:42:39 +0800 Subject: [PATCH] fix bug of dynamic shape load error --- ge/graph/load/model_manager/model_manager.cc | 16 ++++++++++++---- ge/graph/load/model_manager/model_manager.h | 3 +++ tests/ut/ge/graph/load/model_manager_unittest.cc | 9 +++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/ge/graph/load/model_manager/model_manager.cc b/ge/graph/load/model_manager/model_manager.cc index aa2de7e6..27cbd526 100755 --- a/ge/graph/load/model_manager/model_manager.cc +++ b/ge/graph/load/model_manager/model_manager.cc @@ -286,6 +286,17 @@ ge::Status ModelManager::DoLoadHybridModelOnline(uint32_t model_id, const string return SUCCESS; } +bool ModelManager::IsNeedHybridLoad(ge::GeRootModel &ge_root_model) { + auto root_graph = ge_root_model.GetRootGraph(); + if (root_graph == nullptr) { + GELOGE(FAILED, "no model on root model"); + return false; + } + bool is_shape_unknown = root_graph->GetGraphUnknownFlag(); + bool is_dsp_partitioned_graph = false; + (void)AttrUtils::GetBool(root_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dsp_partitioned_graph); + return is_shape_unknown || is_dsp_partitioned_graph || GetContext().GetHostExecFlag(); +} /// /// @ingroup domi_ome /// @brief load model online @@ -299,10 +310,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrGetSubgraphInstanceNameToModel(); string model_name = ""; - bool is_shape_unknown = ge_root_model->GetRootGraph()->GetGraphUnknownFlag(); - bool is_dsp_partitioned_graph = false; - (void)AttrUtils::GetBool(ge_root_model->GetRootGraph(), ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_shape_unknown); - if (is_shape_unknown || is_dsp_partitioned_graph || GetContext().GetHostExecFlag()) { + if (IsNeedHybridLoad(*ge_root_model)) { return DoLoadHybridModelOnline(model_id, model_name, ge_root_model, listener); } diff --git a/ge/graph/load/model_manager/model_manager.h b/ge/graph/load/model_manager/model_manager.h index f2d55db7..735e4a7a 100755 --- a/ge/graph/load/model_manager/model_manager.h +++ b/ge/graph/load/model_manager/model_manager.h @@ -294,6 +294,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { std::vector &output_dims); bool IsDynamicShape(uint32_t model_id); + bool IsNeedHybridLoad(ge::GeRootModel &ge_root_model); ge::Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); ge::Status EnableExceptionDump(const std::map &options); @@ -339,6 +340,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { ge::Status DeleteModel(uint32_t id); void GenModelId(uint32_t *id); + bool IsNeedHybridLoad(); + std::map> model_map_; std::map> hybrid_model_map_; diff --git a/tests/ut/ge/graph/load/model_manager_unittest.cc b/tests/ut/ge/graph/load/model_manager_unittest.cc index 0e65954d..342f6362 100644 --- a/tests/ut/ge/graph/load/model_manager_unittest.cc +++ b/tests/ut/ge/graph/load/model_manager_unittest.cc @@ -151,6 +151,15 @@ class DModelListener : public ModelListener { uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) { return 0; } }; +TEST_F(UtestModelManagerModelManager, case_is_need_hybrid_load) { + ModelManager mm; + uint32_t model_id = 0; + ComputeGraphPtr root_graph = std::make_shared("graph"); + ge::GeRootModel model; + EXPECT_EQ(mm.IsNeedHybridLoad(model), false); + model.SetRootGraph(root_graph); + EXPECT_EQ(mm.IsNeedHybridLoad(model), false); +} TEST_F(UtestModelManagerModelManager, case_load_incorrect_param) { ModelManager mm;