diff --git a/inc/framework/ge_runtime/model_runner.h b/inc/framework/ge_runtime/model_runner.h index 8e312b09..e495dfdf 100644 --- a/inc/framework/ge_runtime/model_runner.h +++ b/inc/framework/ge_runtime/model_runner.h @@ -35,6 +35,9 @@ class ModelRunner { bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, std::shared_ptr davinci_model, std::shared_ptr listener); + + bool DistributeTask(uint32_t model_id); + bool LoadModelComplete(uint32_t model_id); const std::vector &GetTaskIdList(uint32_t model_id) const; @@ -43,6 +46,8 @@ class ModelRunner { const std::map> &GetRuntimeInfoMap(uint32_t model_id) const; + void *GetModelHandle(uint32_t model_id) const; + bool UnloadModel(uint32_t model_id); bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); diff --git a/src/ge/ge_runtime/model_runner.cc b/src/ge/ge_runtime/model_runner.cc index b6e43dd5..9961ab4e 100644 --- a/src/ge/ge_runtime/model_runner.cc +++ b/src/ge/ge_runtime/model_runner.cc @@ -49,6 +49,15 @@ bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint return true; } +bool ModelRunner::DistributeTask(uint32_t model_id) { + auto model_iter = runtime_models_.find(model_id); + if (model_iter == runtime_models_.end()) { + GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); + return false; + } + return model_iter->second->DistributeTask(); +} + bool ModelRunner::LoadModelComplete(uint32_t model_id) { auto model_iter = runtime_models_.find(model_id); if (model_iter == runtime_models_.end()) { @@ -91,6 +100,16 @@ const std::map> &ModelRunner::GetRunti return model_iter->second->GetRuntimeInfoMap(); } +void *ModelRunner::GetModelHandle(uint32_t model_id) const { + auto model_iter = runtime_models_.find(model_id); + if (model_iter == runtime_models_.end()) { + GELOGW("Model id %u not found.", model_id); + return nullptr; + } + + return model_iter->second->GetModelHandle(); +} + bool ModelRunner::UnloadModel(uint32_t model_id) { auto iter = runtime_models_.find(model_id); if (iter != runtime_models_.end()) { diff --git a/src/ge/ge_runtime/runtime_model.cc b/src/ge/ge_runtime/runtime_model.cc index bdf8f2a6..f0405056 100644 --- a/src/ge/ge_runtime/runtime_model.cc +++ b/src/ge/ge_runtime/runtime_model.cc @@ -283,14 +283,16 @@ bool RuntimeModel::Load(uint32_t device_id, uint64_t session_id, std::shared_ptr } GenerateTask(device_id, session_id, davinci_model); + return status; +} - status = LoadTask(); +bool RuntimeModel::DistributeTask() { + bool status = LoadTask(); if (!status) { GELOGE(FAILED, "DistributeTask failed"); - return status; + return false; } - - return status; + return true; } bool RuntimeModel::Run() { diff --git a/src/ge/ge_runtime/runtime_model.h b/src/ge/ge_runtime/runtime_model.h index 67535296..d0c466d4 100644 --- a/src/ge/ge_runtime/runtime_model.h +++ b/src/ge/ge_runtime/runtime_model.h @@ -35,10 +35,12 @@ class RuntimeModel { ~RuntimeModel(); bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr &davinci_model); + bool DistributeTask(); bool LoadComplete(); const std::vector &GetTaskIdList() const; const std::vector &GetStreamIdList() const; const std::map> &GetRuntimeInfoMap() const { return runtime_info_map_; } + rtModel_t GetModelHandle() const { return rt_model_handle_; } bool Run(); bool CopyInputData(const InputData &input_data); bool GetInputOutputDescInfo(bool zero_copy, std::vector *input_desc,