From: @wan_xuelei Reviewed-by: @xchu42 Signed-off-by:tags/v1.3.0
@@ -286,6 +286,17 @@ ge::Status ModelManager::DoLoadHybridModelOnline(uint32_t model_id, const string | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
bool ModelManager::IsNeedHybridLoad(ge::GeRootModel &ge_root_model) { | |||||
auto root_graph = ge_root_model.GetRootGraph(); | |||||
if (root_graph == nullptr) { | |||||
GELOGE(FAILED, "no model on root model"); | |||||
return false; | |||||
} | |||||
bool is_shape_unknown = root_graph->GetGraphUnknownFlag(); | |||||
bool is_dsp_partitioned_graph = false; | |||||
(void)AttrUtils::GetBool(root_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dsp_partitioned_graph); | |||||
return is_shape_unknown || is_dsp_partitioned_graph || GetContext().GetHostExecFlag(); | |||||
} | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief load model online | /// @brief load model online | ||||
@@ -299,9 +310,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||||
} | } | ||||
auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); | auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); | ||||
string model_name = ""; | string model_name = ""; | ||||
bool is_shape_unknown = ge_root_model->GetRootGraph()->GetGraphUnknownFlag(); | |||||
// if multi subgraph is known, do hybrid load process | |||||
if (is_shape_unknown || GetContext().GetHostExecFlag() || (name_to_model.size() > 1)) { | |||||
if (IsNeedHybridLoad(*ge_root_model)) { | |||||
return DoLoadHybridModelOnline(model_id, model_name, ge_root_model, listener); | return DoLoadHybridModelOnline(model_id, model_name, ge_root_model, listener); | ||||
} | } | ||||
@@ -294,6 +294,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
std::vector<InputOutputDims> &output_dims); | std::vector<InputOutputDims> &output_dims); | ||||
bool IsDynamicShape(uint32_t model_id); | bool IsDynamicShape(uint32_t model_id); | ||||
bool IsNeedHybridLoad(ge::GeRootModel &ge_root_model); | |||||
ge::Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); | ge::Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); | ||||
ge::Status EnableExceptionDump(const std::map<string, string> &options); | ge::Status EnableExceptionDump(const std::map<string, string> &options); | ||||
@@ -340,6 +341,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
void GenModelId(uint32_t *id); | void GenModelId(uint32_t *id); | ||||
std::map<uint32_t, std::shared_ptr<DavinciModel>> model_map_; | std::map<uint32_t, std::shared_ptr<DavinciModel>> model_map_; | ||||
std::map<uint32_t, std::shared_ptr<hybrid::HybridDavinciModel>> hybrid_model_map_; | std::map<uint32_t, std::shared_ptr<hybrid::HybridDavinciModel>> hybrid_model_map_; | ||||
std::map<std::string, std::vector<uint64_t>> model_aicpu_kernel_; | std::map<std::string, std::vector<uint64_t>> model_aicpu_kernel_; | ||||
@@ -151,6 +151,15 @@ class DModelListener : public ModelListener { | |||||
uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) { return 0; } | uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) { return 0; } | ||||
}; | }; | ||||
TEST_F(UtestModelManagerModelManager, case_is_need_hybrid_load) { | |||||
ModelManager mm; | |||||
uint32_t model_id = 0; | |||||
ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("graph"); | |||||
ge::GeRootModel model; | |||||
EXPECT_EQ(mm.IsNeedHybridLoad(model), false); | |||||
model.SetRootGraph(root_graph); | |||||
EXPECT_EQ(mm.IsNeedHybridLoad(model), false); | |||||
} | |||||
TEST_F(UtestModelManagerModelManager, case_load_incorrect_param) { | TEST_F(UtestModelManagerModelManager, case_load_incorrect_param) { | ||||
ModelManager mm; | ModelManager mm; | ||||