From: @wan_xuelei Reviewed-by: @xchu42 Signed-off-by:tags/v1.3.0
| @@ -286,6 +286,17 @@ ge::Status ModelManager::DoLoadHybridModelOnline(uint32_t model_id, const string | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| bool ModelManager::IsNeedHybridLoad(ge::GeRootModel &ge_root_model) { | |||||
| auto root_graph = ge_root_model.GetRootGraph(); | |||||
| if (root_graph == nullptr) { | |||||
| GELOGE(FAILED, "no model on root model"); | |||||
| return false; | |||||
| } | |||||
| bool is_shape_unknown = root_graph->GetGraphUnknownFlag(); | |||||
| bool is_dsp_partitioned_graph = false; | |||||
| (void)AttrUtils::GetBool(root_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dsp_partitioned_graph); | |||||
| return is_shape_unknown || is_dsp_partitioned_graph || GetContext().GetHostExecFlag(); | |||||
| } | |||||
| /// | /// | ||||
| /// @ingroup domi_ome | /// @ingroup domi_ome | ||||
| /// @brief load model online | /// @brief load model online | ||||
| @@ -299,9 +310,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||||
| } | } | ||||
| auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); | auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); | ||||
| string model_name = ""; | string model_name = ""; | ||||
| bool is_shape_unknown = ge_root_model->GetRootGraph()->GetGraphUnknownFlag(); | |||||
| // if multi subgraph is known, do hybrid load process | |||||
| if (is_shape_unknown || GetContext().GetHostExecFlag() || (name_to_model.size() > 1)) { | |||||
| if (IsNeedHybridLoad(*ge_root_model)) { | |||||
| return DoLoadHybridModelOnline(model_id, model_name, ge_root_model, listener); | return DoLoadHybridModelOnline(model_id, model_name, ge_root_model, listener); | ||||
| } | } | ||||
| @@ -294,6 +294,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
| std::vector<InputOutputDims> &output_dims); | std::vector<InputOutputDims> &output_dims); | ||||
| bool IsDynamicShape(uint32_t model_id); | bool IsDynamicShape(uint32_t model_id); | ||||
| bool IsNeedHybridLoad(ge::GeRootModel &ge_root_model); | |||||
| ge::Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); | ge::Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); | ||||
| ge::Status EnableExceptionDump(const std::map<string, string> &options); | ge::Status EnableExceptionDump(const std::map<string, string> &options); | ||||
| @@ -340,6 +341,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
| void GenModelId(uint32_t *id); | void GenModelId(uint32_t *id); | ||||
| std::map<uint32_t, std::shared_ptr<DavinciModel>> model_map_; | std::map<uint32_t, std::shared_ptr<DavinciModel>> model_map_; | ||||
| std::map<uint32_t, std::shared_ptr<hybrid::HybridDavinciModel>> hybrid_model_map_; | std::map<uint32_t, std::shared_ptr<hybrid::HybridDavinciModel>> hybrid_model_map_; | ||||
| std::map<std::string, std::vector<uint64_t>> model_aicpu_kernel_; | std::map<std::string, std::vector<uint64_t>> model_aicpu_kernel_; | ||||
| @@ -151,6 +151,15 @@ class DModelListener : public ModelListener { | |||||
| uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) { return 0; } | uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) { return 0; } | ||||
| }; | }; | ||||
| TEST_F(UtestModelManagerModelManager, case_is_need_hybrid_load) { | |||||
| ModelManager mm; | |||||
| uint32_t model_id = 0; | |||||
| ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("graph"); | |||||
| ge::GeRootModel model; | |||||
| EXPECT_EQ(mm.IsNeedHybridLoad(model), false); | |||||
| model.SetRootGraph(root_graph); | |||||
| EXPECT_EQ(mm.IsNeedHybridLoad(model), false); | |||||
| } | |||||
| TEST_F(UtestModelManagerModelManager, case_load_incorrect_param) { | TEST_F(UtestModelManagerModelManager, case_load_incorrect_param) { | ||||
| ModelManager mm; | ModelManager mm; | ||||