| @@ -20,9 +20,12 @@ | |||
| #include <string> | |||
| #include "graph/load/model_manager/model_manager.h" | |||
| #include "graph/load/model_manager/davinci_model.h" | |||
| #include "omm/csa_interact.h" | |||
| namespace ge { | |||
| using Uint32Pair = pair<uint32_t, uint32_t>; | |||
| const uint32_t kInvalidModelId = UINT32_MAX; | |||
| GraphExecutor::GraphExecutor() | |||
| : init_flag_(false), | |||
| train_graph_flag_(false), | |||
| @@ -358,7 +361,8 @@ Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_ro | |||
| } | |||
| Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, | |||
| const std::vector<InputTensorInfo> &input_tensor) { | |||
| const std::vector<InputTensorInfo> &input_tensor, | |||
| const RunAsyncCallback& callback) { | |||
| GELOGI("[GraphExecutor] Start to async execute graph, graph_id=%u", graph_id); | |||
| if (graph_id != last_graph_id_) { | |||
| auto ret = FreeExecuteMemory(); | |||
| @@ -368,7 +372,7 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr & | |||
| } | |||
| last_graph_id_ = graph_id; | |||
| GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED); | |||
| Status ret = AsyncExecuteModel(ge_root_model->GetModelId(), input_tensor); | |||
| Status ret = AsyncExecuteModel(ge_root_model, input_tensor, callback); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] AsyncExecuteModel Error!"); | |||
| return GE_GRAPH_SYNC_MODEL_FAILED; | |||
| @@ -378,11 +382,81 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr & | |||
| return SUCCESS; | |||
| } | |||
| Status GraphExecutor::AsyncExecuteModel(uint32_t model_id, const std::vector<InputTensorInfo> &inputs) { | |||
| bool CompareByLoad(const Uint32Pair &lhs, const Uint32Pair &rhs) { | |||
| return lhs.second < rhs.second; | |||
| } | |||
| uint32_t GraphExecutor::GetExecuteModelId(const GeRootModelPtr &ge_root_model) { | |||
| std::vector<uint32_t> model_ids = ge_root_model->GetAllModelId(); | |||
| if (model_ids.empty()) { | |||
| return kInvalidModelId; | |||
| } | |||
| if (model_ids.size() == 1) { | |||
| return ge_root_model->GetModelId(); | |||
| } | |||
| std::vector<Uint32Pair> model_id_to_loads; | |||
| auto model_manager = ModelManager::GetInstance(); | |||
| GE_CHECK_NOTNULL(model_manager); | |||
| for (auto model_id : model_ids) { | |||
| auto davinci_model = model_manager->GetModel(model_id); | |||
| auto hybrid_model = model_manager->GetHybridModel(model_id); | |||
| if (hybrid_model == nullptr) { | |||
| GE_CHECK_NOTNULL(davinci_model); | |||
| } | |||
| uint32_t input_load = hybrid_model != nullptr ? hybrid_model->GetDataInputerSize() : | |||
| davinci_model->GetDataInputerSize(); | |||
| uint32_t running_load = hybrid_model != nullptr ? static_cast<uint32_t>(hybrid_model->GetRunningFlag()) : | |||
| static_cast<uint32_t>(davinci_model->GetRunningFlag()); | |||
| uint32_t load = input_load + running_load; | |||
| if (load == 0) { | |||
| return model_id; | |||
| } | |||
| model_id_to_loads.emplace_back(model_id, load); | |||
| } | |||
| sort(model_id_to_loads.begin(), model_id_to_loads.end(), CompareByLoad); | |||
| if (model_id_to_loads.empty()) { | |||
| return kInvalidModelId; | |||
| } | |||
| return model_id_to_loads.begin()->first; | |||
| } | |||
| Status GraphExecutor::SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model, | |||
| const RunAsyncCallback &callback) { | |||
| auto model_manager = ge::ModelManager::GetInstance(); | |||
| GE_CHECK_NOTNULL(model_manager); | |||
| if (model_manager->IsNeedHybridLoad(*ge_root_model)) { | |||
| auto model = model_manager->GetHybridModel(model_id); | |||
| GE_CHECK_NOTNULL(model); | |||
| if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) { | |||
| GELOGE(FAILED, "SetRunAsyncListenerCallback failed."); | |||
| return FAILED; | |||
| } | |||
| } else { | |||
| auto model = model_manager->GetModel(model_id); | |||
| GE_CHECK_NOTNULL(model); | |||
| if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) { | |||
| GELOGE(FAILED, "SetRunAsyncListenerCallback failed."); | |||
| return FAILED; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<InputTensorInfo> &inputs, | |||
| const RunAsyncCallback &callback) { | |||
| uint32_t model_id = GetExecuteModelId(ge_root_model); | |||
| if (model_id == kInvalidModelId) { | |||
| GELOGE(INTERNAL_ERROR, "No valid model id."); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| try { | |||
| auto model_manager = ge::ModelManager::GetInstance(); | |||
| GE_CHECK_NOTNULL(model_manager); | |||
| GELOGI("RunAsync begin.model_id %u", model_id); | |||
| if (SetCallback(model_id, ge_root_model, callback) != SUCCESS) { | |||
| GELOGE(FAILED, "RunAsync: SetCallBack for model fail"); | |||
| return FAILED; | |||
| } | |||
| Status ret = model_manager->DataInputTensor(model_id, inputs); | |||
| if (ret != SUCCESS) { | |||
| @@ -50,7 +50,7 @@ class GraphExecutor { | |||
| std::vector<GeTensor> &output_tensor); | |||
| ge::Status ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, | |||
| const std::vector<InputTensorInfo> &input_tensor); | |||
| const std::vector<InputTensorInfo> &input_tensor, const RunAsyncCallback &callback); | |||
| Status SetCondition(std::mutex *mutex, std::condition_variable *cond, std::shared_ptr<GraphModelListener> listener); | |||
| @@ -116,6 +116,8 @@ class GraphExecutor { | |||
| static Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); | |||
| uint32_t GetExecuteModelId(const GeRootModelPtr &ge_root_model); | |||
| private: | |||
| Status PrepareInputData(const std::vector<GeTensor> &input_tensor, InputData &graph_input_data, | |||
| OutputData &graph_output_data, std::vector<InputOutputDescInfo> &output_desc); | |||
| @@ -123,7 +125,8 @@ class GraphExecutor { | |||
| Status SyncExecuteModel(uint32_t model_id, const std::vector<GeTensor> &input_tensor, | |||
| std::vector<GeTensor> &output_tensor); | |||
| Status AsyncExecuteModel(uint32_t model_id, const std::vector<InputTensorInfo> &input_tensor); | |||
| Status AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<InputTensorInfo> &input_tensor, | |||
| const RunAsyncCallback &callback); | |||
| void InitModelIdInfo(std::vector<uint32_t> &out_model_id_info, std::vector<SubGraphInfoPtr> &sub_graph_vec, | |||
| uint32_t output_size); | |||
| @@ -132,6 +135,9 @@ class GraphExecutor { | |||
| Status MallocInOutBuffer(const std::vector<uint64_t> &buffer_size, std::vector<void *> &data_addr); | |||
| static Status SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model, | |||
| const RunAsyncCallback &callback); | |||
| bool init_flag_; | |||
| bool train_graph_flag_; | |||
| @@ -60,7 +60,6 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge | |||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph] GE load graph model_ptr is nullptr."); | |||
| return GE_GRAPH_PARAM_NULLPTR; | |||
| } | |||
| model_id = ge_root_model_ptr->GetModelId(); | |||
| auto model_manager = ModelManager::GetInstance(); | |||
| GE_CHECK_NOTNULL(model_manager); | |||
| @@ -134,6 +134,8 @@ class DataInputer { | |||
| /// | |||
| void Stop() { queue_.Stop(); } | |||
| uint32_t Size() { return queue_.Size(); } | |||
| private: | |||
| /// | |||
| /// @ingroup domi_ome | |||
| @@ -2586,6 +2586,8 @@ void *DavinciModel::Run(DavinciModel *model) { | |||
| ErrorManager::GetInstance().SetStage(ErrorMessage::kModelExecute, ErrorMessage::kModelExecute); | |||
| while (model->RunFlag()) { | |||
| // Model hasn't truly started runing before received data | |||
| model->SetRunningFlag(false); | |||
| bool rslt_flg = true; | |||
| if (model->GetDataInputer() == nullptr) { | |||
| GELOGW("Data inputer is nullptr."); | |||
| @@ -2595,6 +2597,8 @@ void *DavinciModel::Run(DavinciModel *model) { | |||
| std::shared_ptr<InputDataWrapper> data_wrapper; | |||
| Status ret = model->GetDataInputer()->Pop(data_wrapper); | |||
| // Model run indeedly start after received data. | |||
| model->SetRunningFlag(true); | |||
| if (data_wrapper == nullptr || ret != SUCCESS) { | |||
| GELOGI("data_wrapper is null!"); | |||
| continue; | |||
| @@ -2681,7 +2685,9 @@ void *DavinciModel::Run(DavinciModel *model) { | |||
| model->iterator_count_++; | |||
| model->is_first_execute_ = false; | |||
| GELOGI("run iterator count is %lu", model->iterator_count_); | |||
| // model run finished | |||
| model->SetRunningFlag(false); | |||
| GELOGI("run iterator count is %lu, model_id:%u", model->iterator_count_, model->model_id_); | |||
| } | |||
| CsaInteract::GetInstance().WriteInternalErrorCode(); | |||
| @@ -2739,7 +2745,7 @@ Status DavinciModel::ModelRunStart() { | |||
| error_context_ = ErrorManager::GetInstance().GetErrorContext(); | |||
| CREATE_STD_THREAD(thread_id_, DavinciModel::Run, this); | |||
| GELOGI("model tread create success, model id:%u.", model_id_); | |||
| GELOGI("model thread create success, model id:%u.", model_id_); | |||
| return SUCCESS; | |||
| } | |||
| @@ -4110,4 +4116,10 @@ Status DavinciModel::InitL1DataDumperArgs() { | |||
| return SUCCESS; | |||
| } | |||
| Status DavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { | |||
| auto listener = dynamic_cast<RunAsyncListener *>(listener_.get()); | |||
| GE_CHECK_NOTNULL(listener); | |||
| listener->SetCallback(callback); | |||
| return SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -221,6 +221,11 @@ class DavinciModel { | |||
| /// | |||
| DataInputer *const GetDataInputer() const { return data_inputer_; } | |||
| uint32_t GetDataInputerSize() { | |||
| GE_CHECK_NOTNULL(data_inputer_); | |||
| return data_inputer_->Size(); | |||
| } | |||
| // get Stream number | |||
| uint32_t StreamNum() const { return runtime_param_.stream_num; } | |||
| @@ -560,6 +565,10 @@ class DavinciModel { | |||
| return data_dumper_.GetOpDescInfo(stream_id, task_id, op_desc_info); | |||
| } | |||
| bool GetRunningFlag() const { return running_flg_; } | |||
| void SetRunningFlag(bool flag) { running_flg_ = flag; } | |||
| Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback); | |||
| private: | |||
| // memory address of weights | |||
| uint8_t *weights_mem_base_; | |||
| @@ -924,6 +933,8 @@ class DavinciModel { | |||
| shared_ptr<ModelListener> listener_; | |||
| bool run_flg_; | |||
| // check whether model is running with data | |||
| bool running_flg_ = false; | |||
| mutex mux_run_flg_; | |||
| @@ -307,6 +307,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||
| GE_CHK_BOOL_RET_STATUS(listener.get() != nullptr, PARAM_INVALID, "Param incorrect, listener is null"); | |||
| if (model_id == INVALID_MODEL_ID) { | |||
| GenModelId(&model_id); | |||
| GELOGD("Generate new model_id:%u", model_id); | |||
| } | |||
| auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); | |||
| string om_name; | |||
| @@ -339,7 +340,18 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||
| GE_IF_BOOL_EXEC(SUCCESS != (ret = davinci_model->Assign(ge_model)), GELOGW("assign model to modeldef failed."); | |||
| break;); | |||
| GE_TIMESTAMP_END(Assign, "GraphLoader::ModelAssign"); | |||
| /// In multi-threaded inference, using the same session_id among multiple threads may cause some threads to fail. | |||
| /// These session_ids come from the same model, so the values of session_id are the same. | |||
| /// Update session_id for infer in load model to avoid the same session_id. | |||
| if (!ge_root_model->GetTrainFlag()) { | |||
| uint64_t new_session_id; | |||
| ret = GenSessionId(new_session_id); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Generate session_id for infer failed."); | |||
| ret = davinci_model->UpdateSessionId(new_session_id); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Update session_id for infer failed."); | |||
| ge_model->InsertSessionMap(model_id, new_session_id); | |||
| GELOGD("Update new session id: %lu.", new_session_id); | |||
| } | |||
| GE_TIMESTAMP_START(Init); | |||
| GE_IF_BOOL_EXEC(SUCCESS != (ret = davinci_model->Init()), GELOGW("DavinciInit failed."); break;); | |||
| GE_TIMESTAMP_END(Init, "GraphLoader::ModelInit"); | |||
| @@ -352,16 +364,16 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||
| return ret; | |||
| } | |||
| void ModelManager::InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model) { | |||
| GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", id); | |||
| void ModelManager::InsertModel(uint32_t model_id, std::shared_ptr<DavinciModel> &davinci_model) { | |||
| GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", model_id); | |||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||
| model_map_[id] = davinci_model; | |||
| model_map_[model_id] = davinci_model; | |||
| } | |||
| void ModelManager::InsertModel(uint32_t id, shared_ptr<hybrid::HybridDavinciModel> &hybrid_model) { | |||
| GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", id); | |||
| void ModelManager::InsertModel(uint32_t model_id, shared_ptr<hybrid::HybridDavinciModel> &hybrid_model) { | |||
| GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", model_id); | |||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||
| hybrid_model_map_[id] = hybrid_model; | |||
| hybrid_model_map_[model_id] = hybrid_model; | |||
| } | |||
| Status ModelManager::DeleteModel(uint32_t id) { | |||
| @@ -330,8 +330,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||
| /// @ingroup domi_ome | |||
| /// @brief insert new model into model manager set | |||
| /// | |||
| void InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model); | |||
| void InsertModel(uint32_t id, std::shared_ptr<hybrid::HybridDavinciModel> &hybrid_model); | |||
| void InsertModel(uint32_t model_id, std::shared_ptr<DavinciModel> &davinci_model); | |||
| void InsertModel(uint32_t model_id, std::shared_ptr<hybrid::HybridDavinciModel> &hybrid_model); | |||
| /// | |||
| /// @ingroup domi_ome | |||
| @@ -117,6 +117,10 @@ const char *const kAIcoreEngine = "AIcoreEngine"; | |||
| const int32_t kDynamicDimsTypeIsGetNext = 0; | |||
| const int32_t kDynamicDimsTypeIsData = 1; | |||
| const char *const kGetNextName = "IteratorV2"; | |||
| const uint32_t kInitGraphCount = 1; | |||
| const uint32_t kNotAdded = 0; | |||
| const uint32_t kStartAdd = 1; | |||
| const uint32_t kDoneAdded = 2; | |||
| bool IsTailingOptimization() { | |||
| string is_tailing_optimization_option; | |||
| @@ -195,6 +199,8 @@ Status GraphManager::Initialize(const std::map<string, string> &options) { | |||
| graph_map_.clear(); | |||
| cache_helper_map_.clear(); | |||
| graph_id_to_add_graph_cond_.clear(); | |||
| graph_count_.clear(); | |||
| init_flag_ = true; | |||
| thread_run_flag_ = true; | |||
| @@ -204,6 +210,20 @@ Status GraphManager::Initialize(const std::map<string, string> &options) { | |||
| return SUCCESS; | |||
| } | |||
| Status GraphManager::UnloadModel(GeRootModelPtr ge_root_model, uint32_t graph_id) { | |||
| Status ret = SUCCESS; | |||
| for (size_t i = 0; i < ge_root_model->GetAllModelId().size(); ++i) { | |||
| uint32_t model_id = ge_root_model->GetAllModelId()[i]; | |||
| GELOGI("Unload model %u.", model_id); | |||
| ret = GraphLoader::UnloadModel(model_id); | |||
| if (ret != SUCCESS) { | |||
| GELOGW("[GraphManager] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||
| return ret; | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| Status GraphManager::Finalize() { | |||
| if (!init_flag_) { | |||
| GELOGW("GraphManager has not been initialized."); | |||
| @@ -234,7 +254,6 @@ Status GraphManager::Finalize() { | |||
| unload_model_ret = GE_GRAPH_GRAPH_IS_RUNNING; | |||
| continue; | |||
| } | |||
| // unload model | |||
| auto ge_root_model = graph_node->GetGeRootModel(); | |||
| if (ge_root_model != nullptr && ge_root_model->GetModelId() != INVALID_MODEL_ID && graph_node->GetLoadFlag()) { | |||
| @@ -244,15 +263,14 @@ Status GraphManager::Finalize() { | |||
| unload_model_ret = FAILED; | |||
| continue; | |||
| } | |||
| ret = GraphLoader::UnloadModel(ge_root_model->GetModelId()); | |||
| ret = UnloadModel(ge_root_model, iter->first); | |||
| if (ret != SUCCESS) { | |||
| GELOGW("[GraphManager] unload model failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), iter->first); | |||
| GELOGW("[GraphManager] unload model failed, graph_id=%u.", iter->first); | |||
| unload_model_ret = ret; | |||
| } | |||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGW("[GraphManager] rtDeviceReset failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), | |||
| iter->first); | |||
| GELOGW("[GraphManager] rtDeviceReset failed, graphId=%u.", iter->first); | |||
| unload_model_ret = FAILED; | |||
| continue; | |||
| } | |||
| @@ -267,6 +285,7 @@ Status GraphManager::Finalize() { | |||
| } | |||
| graph_map_.clear(); | |||
| cache_helper_map_.clear(); | |||
| graph_count_.clear(); | |||
| // graph context | |||
| if (graph_context_ != nullptr) { | |||
| @@ -317,30 +336,59 @@ Status GraphManager::InitDynamicParams(ComputeGraphPtr &compute_graph) { | |||
| return SUCCESS; | |||
| } | |||
| Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||
| const std::map<std::string, std::string> &options, | |||
| const OmgContext &omg_context) { | |||
| if (HasGraphNode(graph_id)) { | |||
| GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, "[GraphManager] graph exists, graph_id = %u.", graph_id); | |||
| return GE_GRAPH_GRAPH_ALREADY_EXIST; | |||
| void GraphManager::SetAddGraphCondition(GraphId graph_id, uint32_t cond) { | |||
| std::lock_guard<std::mutex> lock(add_graph_cond_mutex_); | |||
| graph_id_to_add_graph_cond_[graph_id] = cond; | |||
| GELOGD("Graph [id:%u] has been added.", graph_id); | |||
| } | |||
| uint32_t GraphManager::GetAddGraphCondition(GraphId graph_id) { | |||
| std::lock_guard<std::mutex> lock(add_graph_cond_mutex_); | |||
| auto it = graph_id_to_add_graph_cond_.find(graph_id); | |||
| if (it != graph_id_to_add_graph_cond_.end()) { | |||
| return it->second; | |||
| } else { | |||
| GELOGD("Graph [id:%u] has not been added.", graph_id); | |||
| return kNotAdded; | |||
| } | |||
| } | |||
| auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||
| if (compute_graph != nullptr) { | |||
| compute_graph->SetGraphID(graph_id); | |||
| bool graph_has_been_added = false; | |||
| if (AttrUtils::GetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, graph_has_been_added) | |||
| && graph_has_been_added) { | |||
| GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, | |||
| "[GraphManager] same graph object can not be added again, graph_id = %u.", graph_id); | |||
| return GE_GRAPH_GRAPH_ALREADY_EXIST; | |||
| } | |||
| (void)AttrUtils::SetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, true); | |||
| compute_graph_ = compute_graph; | |||
| void GraphManager::RemoveAddGraphCondition(GraphId graph_id) { | |||
| std::lock_guard<std::mutex> lock(add_graph_cond_mutex_); | |||
| auto it = graph_id_to_add_graph_cond_.find(graph_id); | |||
| if (it != graph_id_to_add_graph_cond_.end()) { | |||
| graph_id_to_add_graph_cond_.erase(it); | |||
| GELOGD("Successfully removed add_graph_cond of graph [id:%u].", graph_id); | |||
| } else { | |||
| GELOGE(FAILED, "compute graph is null"); | |||
| return FAILED; | |||
| GELOGD("Graph [id:%u] has not been added. no need to remove.", graph_id); | |||
| } | |||
| } | |||
| Status GraphManager::CheckRepeatAdd(uint32_t graph_id, bool &is_added) { | |||
| uint32_t count = 0; | |||
| if (GetGraphCount(graph_id, count) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed, graph might have not been added.", graph_id); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| // previous thread owns same graph_id has been in the middle of the AddGraph procession | |||
| if (count > 1 && GetAddGraphCondition(graph_id) == kStartAdd) { | |||
| std::unique_lock<std::mutex> lock(add_graph_mutex_); | |||
| GELOGD("Waitting for build end of previous thread."); | |||
| while (GetAddGraphCondition(graph_id) != kDoneAdded) { | |||
| add_graph_cv_.wait(lock); | |||
| } | |||
| GraphNodePtr graph_node; | |||
| Status ret = GetGraphNode(graph_id, graph_node); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "[AddGraph] GetGraphNode failed, graph_id = %u.", graph_id); | |||
| return ret; | |||
| } | |||
| is_added = true; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| void GraphManager::SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id) { | |||
| std::string session_graph_id; | |||
| if (!AttrUtils::GetStr(*compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id) || session_graph_id.empty()) { | |||
| session_graph_id = "-1_" + to_string(graph_id); | |||
| @@ -352,7 +400,24 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||
| } | |||
| GELOGD("Get graph session_graph_id attr failed, set session id to default value: [0]"); | |||
| } | |||
| } | |||
| Status GraphManager::NotifyWaittingGraph(uint32_t graph_id) { | |||
| uint32_t count = 0; | |||
| if (GetGraphCount(graph_id, count) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed, graph might have not been added.", graph_id); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| GELOGD("Add graph finished, graph_id:%u", graph_id); | |||
| if (count > 1) { | |||
| GELOGD("Finish addgraph, graph_id:%u, graph_count:%u, start to notify.", graph_id, count); | |||
| add_graph_cv_.notify_all(); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status GraphManager::CreateGraphNode(uint32_t graph_id, const Graph &graph, | |||
| const std::map<std::string, std::string> &options) { | |||
| GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
| GE_IF_BOOL_EXEC(graph_node == nullptr, GELOGE(FAILED, "GraphNode make shared failed"); | |||
| return FAILED); | |||
| @@ -365,7 +430,62 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||
| ParseOption(options, TUNING_PATH, options_.tuning_path); | |||
| graph_node->SetGraph(graph_ptr); | |||
| graph_node->SetOptions(options); | |||
| graph_node->IncreaseLoadCount(); | |||
| AddGraphNode(graph_id, graph_node); | |||
| return SUCCESS; | |||
| } | |||
| Status GraphManager::SetStagesOptions(uint32_t graph_id, const GraphManagerOptions &options) { | |||
| CompilerStages &stages = GetCompilerStages(graph_id); | |||
| stages.preparer.SetOptions(options_); | |||
| Status status = stages.optimizer.SetOptions(options_); | |||
| if (status != SUCCESS) { | |||
| GELOGE(status, "Graph optimizer set options failed."); | |||
| return status; | |||
| } | |||
| stages.builder.SetOptions(options_); | |||
| return SUCCESS; | |||
| } | |||
| Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||
| const std::map<std::string, std::string> &options, | |||
| const OmgContext &omg_context) { | |||
| IncreaseGraphCount(graph_id); | |||
| // validation for adding graphs of same graph_id in multi-thread secenario | |||
| // 1.previous thread owns same graph_id has finished the AddGraph procession | |||
| if (GetAddGraphCondition(graph_id) == kDoneAdded) { | |||
| GraphNodePtr graph_node; | |||
| if (GetGraphNode(graph_id, graph_node) != SUCCESS) { | |||
| GELOGE(GE_GRAPH_GRAPH_NOT_EXIST, "Graph not exist while done adding previously, graph_id = %u.", graph_id); | |||
| return GE_GRAPH_GRAPH_NOT_EXIST; | |||
| } | |||
| graph_node->IncreaseLoadCount(); | |||
| return SUCCESS; | |||
| } | |||
| // In multi-thread scenario, former thread owns same graph_id has been | |||
| // in the middle of the AddGraph procession while following threads have to wait until | |||
| // done adding graph of the former graph, avoiding repeatively adding same graph. | |||
| bool is_added = false; | |||
| if (CheckRepeatAdd(graph_id, is_added) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "CheckRepeatAdd for graph[id:%u] failed.", graph_id); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| // The former graph (from different thread) owns same graph id has been successfully added. | |||
| if (is_added) { | |||
| return SUCCESS; | |||
| } | |||
| // Do add graph | |||
| SetAddGraphCondition(graph_id, kStartAdd); | |||
| auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| compute_graph->SetGraphID(graph_id); | |||
| SetSessionGraphId(compute_graph, graph_id); | |||
| if (CreateGraphNode(graph_id, graph, options) != SUCCESS) { | |||
| GELOGE(FAILED, "Failed to create graph_node."); | |||
| return FAILED; | |||
| } | |||
| AddLocalOmgContext(graph_id, omg_context); | |||
| if (!options_.output_datatype.empty()) { | |||
| @@ -376,16 +496,18 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| CompilerStages &stages = GetCompilerStages(graph_id); | |||
| stages.preparer.SetOptions(options_); | |||
| Status status = stages.optimizer.SetOptions(options_); | |||
| if (status != SUCCESS) { | |||
| GELOGE(status, "Graph optimizer set options failed."); | |||
| return status; | |||
| if (SetStagesOptions(graph_id, options_) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Set stage options failed."); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| stages.builder.SetOptions(options_); | |||
| var_acc_ctrl_.AddGraph(graph_id, compute_graph); | |||
| SetAddGraphCondition(graph_id, kDoneAdded); | |||
| // There are threads waitting for adding same graph | |||
| if (NotifyWaittingGraph(graph_id) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "NotifyWaittingGraph failed."); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -895,6 +1017,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: | |||
| if (!graph_node->IsAsync()) { | |||
| ret = LoadGraph(ge_root_model, graph_node); | |||
| } else { | |||
| GE_CHECK_NOTNULL(ge_root_model); | |||
| ret = LoadGraphAsync(ge_root_model, graph_node); | |||
| } | |||
| if (ret != SUCCESS) { | |||
| @@ -909,6 +1032,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: | |||
| if (!graph_node->IsAsync()) { | |||
| ret = LoadGraph(ge_root_model_ptr, graph_node); | |||
| } else { | |||
| GE_CHECK_NOTNULL(ge_root_model); | |||
| ret = LoadGraphAsync(ge_root_model_ptr, graph_node); | |||
| } | |||
| if (ret != SUCCESS) { | |||
| @@ -921,6 +1045,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: | |||
| Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { | |||
| GELOGI("[LoadGraph] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); | |||
| if (options_.run_graph_flag && ge_root_model != nullptr) { | |||
| ge_root_model->SetTrainFlag(GetTrainFlag()); | |||
| // synchronization run graph with model | |||
| std::shared_ptr<GraphModelListener> model_listener = GetModelListener(); | |||
| ModelIdInfo model_id_info; | |||
| @@ -1315,54 +1440,29 @@ bool GraphManager::CheckModelLoad(const GeRootModelPtr &ge_root_model, bool load | |||
| } | |||
| Status GraphManager::RemoveGraph(const GraphId &graph_id) { | |||
| auto it = to_be_deleted_graphs_.find(graph_id); | |||
| if (it != to_be_deleted_graphs_.end()) { | |||
| to_be_deleted_graphs_.erase(it); | |||
| } | |||
| GraphNodePtr graph_node = nullptr; | |||
| Status ret = GetGraphNode(graph_id, graph_node); | |||
| if (ret != SUCCESS) { | |||
| if (ret != SUCCESS || graph_node == nullptr) { | |||
| REPORT_INNER_ERROR("E19999", "Graph:%u not exist in graph_map, check invalid when GraphManager %s", | |||
| graph_id, __FUNCTION__); | |||
| GELOGE(GE_GRAPH_GRAPH_NOT_EXIST, "[GraphManager] Id %u does not exists.", graph_id); | |||
| return GE_GRAPH_GRAPH_NOT_EXIST; | |||
| } | |||
| if ((graph_node == nullptr) || (graph_node->GetRunFlag())) { | |||
| GELOGE(GE_GRAPH_GRAPH_IS_RUNNING, "[GraphManager] Id %u is running, can't be deleted.", graph_id); | |||
| return GE_GRAPH_GRAPH_IS_RUNNING; | |||
| if (graph_node->GetRunFlag()) { | |||
| // only put graph into to-be-deleted list when exceptional scenario | |||
| to_be_deleted_graphs_.insert(graph_id); | |||
| GELOGI("[GraphManager] Trying to remove running graph[Id:%u], added into to_be_deleted_graphs_.", graph_id); | |||
| return SUCCESS; | |||
| } | |||
| std::lock_guard<std::mutex> lock(unload_model_mutex_); | |||
| Status middle_ret; | |||
| rtError_t rt_ret; | |||
| const std::vector<SubGraphInfoPtr> &all_sub_graph = graph_node->GetAllSubGraph(); | |||
| for (size_t i = 0; i < all_sub_graph.size(); ++i) { | |||
| // must free buffer firstly | |||
| middle_ret = all_sub_graph[i]->FreeInOutBuffer(); | |||
| if (middle_ret != SUCCESS) { | |||
| GELOGE(middle_ret, "[GraphManager] RemoveGraph free mem failed, graph_id=%u.", graph_id); | |||
| ret = middle_ret; | |||
| } | |||
| if (all_sub_graph[i]->GeModelIsValid() && all_sub_graph[i]->GetModelIdInfo().model_id != INVALID_MODEL_ID) { | |||
| // unload model | |||
| GELOGI("UnloadModel via new ome."); | |||
| rt_ret = rtSetDevice(GetContext().DeviceId()); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.", | |||
| all_sub_graph[i]->GetModelIdInfo().model_id, graph_id); | |||
| ret = FAILED; | |||
| continue; | |||
| } | |||
| middle_ret = GraphLoader::UnloadModel(all_sub_graph[i]->GetModelIdInfo().model_id); | |||
| if (middle_ret != SUCCESS) { | |||
| GELOGE(middle_ret, "[GraphManager:] unload model failed, modelId=%u, graph_id=%u.", | |||
| all_sub_graph[i]->GetModelIdInfo().model_id, graph_id); | |||
| ret = middle_ret; | |||
| } | |||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "[GraphManager:] unload model failed, modelId=%u, graphId=%u.", | |||
| all_sub_graph[i]->GetModelIdInfo().model_id, graph_id); | |||
| ret = FAILED; | |||
| } | |||
| } | |||
| } | |||
| var_acc_ctrl_.RemoveGraph(graph_id); | |||
| RemoveGraphNode(graph_id); | |||
| @@ -1370,28 +1470,33 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { | |||
| auto ge_root_model = graph_node->GetGeRootModel(); | |||
| if (CheckModelLoad(ge_root_model, graph_node->GetLoadFlag())) { | |||
| GELOGI("Unload model %u.", ge_root_model->GetModelId()); | |||
| rt_ret = rtSetDevice(GetContext().DeviceId()); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), | |||
| graph_id); | |||
| return FAILED; | |||
| } | |||
| middle_ret = GraphLoader::UnloadModel(ge_root_model->GetModelId()); | |||
| // same graph may be added for several times, different models were created separately, | |||
| // unload them respectively. | |||
| middle_ret = UnloadModel(ge_root_model, graph_id); | |||
| if (middle_ret != SUCCESS) { | |||
| GELOGE(middle_ret, "[GraphManager:] unload model failed, modelId=%u, graph_id=%u.", ge_root_model->GetModelId(), | |||
| graph_id); | |||
| REPORT_INNER_ERROR("E19999", "UnloadModel for graph:%u failed, check unload detail in GraphLoader %s", | |||
| graph_id, __FUNCTION__); | |||
| GELOGE(middle_ret, "[GraphManager:] unload model failed, graph_id=%u.", graph_id); | |||
| ret = middle_ret; | |||
| } | |||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), | |||
| graph_id); | |||
| REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, graph_id:%u, when GraphManager %s", | |||
| GetContext().DeviceId(), graph_id, __FUNCTION__); | |||
| GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, graphId=%u.", graph_id); | |||
| ret = FAILED; | |||
| } | |||
| } | |||
| RemoveCompilerStages(graph_id); | |||
| RemoveGraphCount(graph_id); | |||
| RemoveAddGraphCondition(graph_id); | |||
| GE_CHK_STATUS_RET(ret, "[GraphManager:] Remove graph failed, graph_id=%u.", graph_id); | |||
| GELOGI("[GraphManager] remove graph success, graph_id=%u.", graph_id); | |||
| @@ -2409,6 +2514,7 @@ void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_gr | |||
| Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { | |||
| GELOGI("[LoadGraphAsync] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); | |||
| if (options_.run_graph_flag && ge_root_model != nullptr) { | |||
| ge_root_model->SetTrainFlag(GetTrainFlag()); | |||
| // synchronization run graph with model | |||
| ModelIdInfo model_id_info; | |||
| bool is_unknown_shape = false; | |||
| @@ -2425,9 +2531,9 @@ Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const G | |||
| } | |||
| } | |||
| GE_TIMESTAMP_START(LoadGraph); | |||
| GE_CHECK_NOTNULL(graph_node->graph_run_async_listener_); | |||
| Status ret = | |||
| GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, graph_node->graph_run_async_listener_); | |||
| auto listener = MakeShared<RunAsyncListener>(); | |||
| GE_CHECK_NOTNULL(listener); | |||
| Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, listener); | |||
| GE_TIMESTAMP_EVENT_END(LoadGraph, "GraphManager::LoadGraphAsync"); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "[LoadGraphAsync] LoadGraphAsync Failed"); | |||
| @@ -2441,6 +2547,52 @@ Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const G | |||
| return SUCCESS; | |||
| } | |||
| void GraphManager::ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph_node, | |||
| const std::vector<uint32_t> &model_ids, uint32_t graph_id, uint64_t session_id) { | |||
| rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, when GraphManager %s", | |||
| GetContext().DeviceId(), __FUNCTION__); | |||
| GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, graphId=%u.", graph_id); | |||
| return; | |||
| } | |||
| for (auto model_id : model_ids) { | |||
| uint64_t max_memory_size = 0; | |||
| Status result = GraphLoader::GetMaxUsedMemory(model_id, max_memory_size); | |||
| if (result != SUCCESS) { | |||
| continue; | |||
| } | |||
| GELOGI("CheckAndReleaseMemory try to UnloadGraph[%u], model[%u] which MaxUsedMemory[%lu].", graph_id, model_id, | |||
| max_memory_size); | |||
| if (model_ids.size() > 1) { | |||
| result = ge_model->GetSessionId(model_id, session_id); | |||
| if (result != SUCCESS) { | |||
| GELOGW("[GraphManager:] get session failed when dynamic memory, modelId=%u, graphId=%u.", model_id, | |||
| graph_id); | |||
| continue; | |||
| } | |||
| } | |||
| result = GraphLoader::DestroyAicpuKernel(session_id, model_id, 0); | |||
| if (result != SUCCESS) { | |||
| GELOGW("[GraphManager:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id, | |||
| graph_id); | |||
| } | |||
| result = GraphLoader::UnloadModel(model_id); | |||
| if (result != SUCCESS) { | |||
| GELOGW("[GraphManager:] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||
| } | |||
| GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success.", graph_id, model_id); | |||
| } | |||
| graph_node->SetLoadFlag(false); | |||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, when GraphManager %s", | |||
| GetContext().DeviceId(), __FUNCTION__); | |||
| GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, graphId=%u.", graph_id); | |||
| return; | |||
| } | |||
| } | |||
| Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node) { | |||
| GELOGI("CheckAndReleaseMemory graph_id[%u]", graph_node->GetGraphId()); | |||
| int64_t value = 0; | |||
| @@ -2484,6 +2636,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra | |||
| continue; | |||
| } | |||
| auto model_id = model->GetModelId(); | |||
| auto model_ids = model->GetAllModelId(); | |||
| // unload model not release | |||
| bool is_unknown_shape = false; | |||
| GE_CHK_STATUS_RET(model->CheckIsUnknownShape(is_unknown_shape)); | |||
| @@ -2496,34 +2649,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra | |||
| GELOGI("CheckAndReleaseMemory graph[%u] has not been loaded.", graph_id); | |||
| continue; | |||
| } | |||
| uint64_t max_memory_size = 0; | |||
| result = GraphLoader::GetMaxUsedMemory(model_id, max_memory_size); | |||
| if (result != SUCCESS) { | |||
| continue; | |||
| } | |||
| GELOGI("CheckAndReleaseMemory try to UnloadGraph[%u], model[%u] which MaxUsedMemory[%lu].", graph_id, model_id, | |||
| max_memory_size); | |||
| rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||
| continue; | |||
| } | |||
| result = GraphLoader::DestroyAicpuKernel(session_id, model_id, 0); | |||
| if (result != SUCCESS) { | |||
| GELOGW("[GraphManager:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id, | |||
| graph_id); | |||
| } | |||
| result = GraphLoader::UnloadModel(model_id); | |||
| if (result != SUCCESS) { | |||
| GELOGW("[GraphManager:] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||
| } | |||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||
| continue; | |||
| } | |||
| it.second->SetLoadFlag(false); | |||
| GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success and set LoadFlag to false.", graph_id, model_id); | |||
| ReleaseMemory(ge_model, it.second, model_ids, graph_id, session_id); | |||
| } | |||
| return SUCCESS; | |||
| @@ -2659,6 +2785,38 @@ void GraphManager::ConstructGeInput(const vector<InputTensorInfo> &inputs, vecto | |||
| } | |||
| } | |||
| Status GraphManager::CheckIncreBuildAndPreRun(GraphManager *graph_manager, const PreRunArgs &args, | |||
| GraphNodePtr &graph_node, GeRootModelPtr &ge_root_model) { | |||
| if (!graph_manager->IsGraphNeedBuild(graph_node)) { | |||
| ge_root_model = graph_node->GetGeRootModel(); | |||
| return SUCCESS; | |||
| } | |||
| if (graph_node->GetBuildFlag()) { | |||
| ReturnError(graph_manager, args.callback, PARAM_INVALID, | |||
| "The graph " + std::to_string(graph_node->GetGraphId()) + | |||
| " need to re-build, you should remove it" | |||
| " from GE first, then AddGraph again and rebuild it."); | |||
| graph_node->Unlock(); | |||
| return PARAM_INVALID; | |||
| } | |||
| // check need incre build. | |||
| GeModelPtr ge_model = nullptr; | |||
| if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { | |||
| std::vector<GeTensor> ge_inputs; | |||
| ConstructGeInput(args.input_tensor, ge_inputs); | |||
| Status ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); | |||
| // release rts generate context | |||
| RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); | |||
| if (ret != SUCCESS) { | |||
| ReturnError(graph_manager, args.callback, ret, "PreRun Failed."); | |||
| return ret; | |||
| } | |||
| } | |||
| graph_node->SetBuildFlag(true); | |||
| graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); | |||
| return SUCCESS; | |||
| } | |||
| void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||
| if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { | |||
| GELOGW("Set thread name failed."); | |||
| @@ -2671,7 +2829,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||
| continue; | |||
| } | |||
| GELOGI("A new loop start."); | |||
| GELOGI("[PreRunThread] A new loop start, graph_id:%u.", args.graph_id); | |||
| ErrorManager::GetInstance().SetErrorContext(args.error_context); | |||
| ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther); | |||
| @@ -2687,7 +2845,24 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||
| "[RunGraph] graph not exist, graph_id=" + std::to_string(args.graph_id)); | |||
| return; | |||
| } | |||
| // more than one graph owns same graph_id | |||
| uint32_t count = 0; | |||
| if (graph_manager->GetGraphCount(args.graph_id, count) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed.", args.graph_id); | |||
| return; | |||
| } | |||
| // Avoid repeatively prerun for graphs owns same graph_id in online inference concurrency | |||
| if (count > 1 && graph_node->GetBuildFlag()) { | |||
| graph_node->Lock(); | |||
| GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id); | |||
| // In online inference concurrency senario, graph_node is allowed to be locked for 'count' times | |||
| graph_node->SetSemSize(count); | |||
| graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, | |||
| args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback })); | |||
| GELOGI("[PreRunThread] Loop end. Start to run with cached build model."); | |||
| continue; | |||
| } | |||
| // Cannot be put ahead of the repeatively prerun judgement | |||
| graph_node->Lock(); | |||
| if (graph_node->GetRunFlag()) { | |||
| @@ -2719,46 +2894,24 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||
| // it will not execute graph preprocess, optimize, parition, build if the graph has built successful. | |||
| GELOGI("Start for run graph async."); | |||
| GeRootModelPtr ge_root_model = nullptr; | |||
| if (graph_manager->IsGraphNeedBuild(graph_node)) { | |||
| if (graph_node->GetBuildFlag()) { | |||
| ReturnError(graph_manager, args.callback, PARAM_INVALID, | |||
| "The graph " + std::to_string(graph_node->GetGraphId()) + | |||
| " need to re-build, you should remove it" | |||
| " from GE first, then AddGraph again and rebuild it."); | |||
| ret = CheckIncreBuildAndPreRun(graph_manager, args, graph_node, ge_root_model); | |||
| if (ret != SUCCESS) { | |||
| graph_node->SetRunFlag(false); | |||
| if (!ge::Analyzer::GetInstance()->IsEnableNetAnalyzeDebug()) { | |||
| ReturnError(graph_manager, args.callback, ret, "CheckIncreBuildAndPreRun Failed, thread exit.."); | |||
| graph_node->Unlock(); | |||
| return; | |||
| } else { | |||
| ReturnError(graph_manager, graph_node, args.callback, ret, | |||
| "CheckIncreBuildAndPreRun Failed, keep geop continue!"); | |||
| graph_node->Unlock(); | |||
| continue; | |||
| } | |||
| // check need incre build. | |||
| GeModelPtr ge_model = nullptr; | |||
| if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { | |||
| std::vector<GeTensor> ge_inputs; | |||
| ConstructGeInput(args.input_tensor, ge_inputs); | |||
| ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); | |||
| // release rts generate context | |||
| RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); | |||
| if (ret != SUCCESS) { | |||
| graph_node->SetRunFlag(false); | |||
| if (!ge::Analyzer::GetInstance()->IsEnableNetAnalyzeDebug()) { | |||
| ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit.."); | |||
| graph_node->Unlock(); | |||
| return; | |||
| } else { | |||
| ReturnError(graph_manager, graph_node, args.callback, ret, "PreRun Failed, keep geop continue!"); | |||
| graph_node->Unlock(); | |||
| continue; | |||
| } | |||
| } | |||
| } | |||
| graph_node->SetBuildFlag(true); | |||
| graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); | |||
| } else { | |||
| ge_root_model = graph_node->GetGeRootModel(); | |||
| } | |||
| graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, | |||
| args.input_tensor, ge_root_model, GetThreadLocalContext(), args.callback })); | |||
| GELOGI("Loop end."); | |||
| GELOGI("[PreRunThread] Loop end."); | |||
| } | |||
| } | |||
| @@ -2855,16 +3008,13 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||
| continue; | |||
| } | |||
| GELOGI("A new loop start."); | |||
| GELOGI("[RunThread] A new loop start, graph_id:%u.", args.graph_id); | |||
| ErrorManager::GetInstance().SetErrorContext(args.error_context); | |||
| GetContext().SetSessionId(args.session_id); | |||
| GetThreadLocalContext() = args.context; | |||
| graph_manager->UpdateLocalOmgContext(args.graph_id); | |||
| if (args.graph_node->graph_run_async_listener_ != nullptr) { | |||
| args.graph_node->graph_run_async_listener_->SetCallback(args.callback); | |||
| } | |||
| Status ret; | |||
| // parse inputs.dims to vector<vector<uint64_t>> dynamic_dims | |||
| ret = graph_manager->ParseInputsDims(args.input_tensor); | |||
| @@ -2874,8 +3024,10 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||
| return; | |||
| } | |||
| args.graph_node->UpdateLoadFlag(); | |||
| if (!args.graph_node->GetLoadFlag()) { | |||
| ErrorManager::GetInstance().SetStage(ErrorMessage::kModelLoad, ErrorMessage::kModelLoad); | |||
| args.ge_root_model->SetTrainFlag(graph_manager->GetTrainFlag()); | |||
| ret = graph_manager->LoadGraphAsync(args.ge_root_model, args.graph_node); | |||
| if (ret != SUCCESS || args.ge_root_model == nullptr) { | |||
| StopQueue(graph_manager); | |||
| @@ -2883,6 +3035,10 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||
| args.graph_node->Unlock(); | |||
| return; | |||
| } | |||
| // control the times of graph loading in multi-thread scenario | |||
| args.graph_node->DecreaseLoadCount(); | |||
| args.graph_node->IncreaseLoadRecord(); | |||
| args.graph_node->SetLoadFlag(true); | |||
| GELOGI("LoadGraph[%u], model[%u] success and set LoadFlag to true.", args.graph_node->GetGraphId(), | |||
| args.ge_root_model->GetModelId()); | |||
| @@ -2898,7 +3054,7 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||
| } | |||
| ret = graph_manager->graph_executor_.ExecuteGraphAsync(args.graph_id, args.graph_node->GetGeRootModel(), | |||
| args.input_tensor); | |||
| args.input_tensor, args.callback); | |||
| args.graph_node->SetRunFlag(false); | |||
| if (ret != SUCCESS) { | |||
| ReturnError(graph_manager, args.callback, ret, "ExecuteGraphAsync failed, thread exit."); | |||
| @@ -3314,4 +3470,49 @@ void GraphManager::RemoveCompilerStages(GraphId graph_id) { | |||
| std::lock_guard<std::mutex> lock(member_mutex_); | |||
| compiler_stages_.erase(graph_id); | |||
| } | |||
| void GraphManager::IncreaseGraphCount(GraphId graph_id) { | |||
| std::lock_guard<std::mutex> lock(graph_count_mutex_); | |||
| auto it = graph_count_.find(graph_id); | |||
| if (it == graph_count_.end()) { | |||
| graph_count_.insert({graph_id, kInitGraphCount}); | |||
| GELOGD("After increaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]); | |||
| } else { | |||
| ++graph_count_[graph_id]; | |||
| GELOGD("After increaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]); | |||
| } | |||
| } | |||
| void GraphManager::RemoveGraphCount(GraphId graph_id) { | |||
| std::lock_guard<std::mutex> lock(graph_count_mutex_); | |||
| auto it = graph_count_.find(graph_id); | |||
| if (it == graph_count_.end()) { | |||
| GELOGW("Graph of id: %u has not been added, count cannot be decreased.", graph_id); | |||
| } else { | |||
| GELOGD("RemoveGraphCount success, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]); | |||
| graph_count_.erase(it); | |||
| } | |||
| } | |||
| void GraphManager::DecreaseGraphCount(GraphId graph_id) { | |||
| std::lock_guard<std::mutex> lock(graph_count_mutex_); | |||
| auto it = graph_count_.find(graph_id); | |||
| if (it == graph_count_.end()) { | |||
| GELOGW("Graph of id: %u has not been added, count cannot be decreased.", graph_id); | |||
| } else { | |||
| --it->second; | |||
| GELOGD("After DecreaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]); | |||
| } | |||
| } | |||
| Status GraphManager::GetGraphCount(GraphId graph_id, uint32_t &count) { | |||
| std::lock_guard<std::mutex> lock(graph_count_mutex_); | |||
| auto it = graph_count_.find(graph_id); | |||
| if (it == graph_count_.end()) { | |||
| GELOGW("Graph [id:%u] has not been added.", graph_id); | |||
| return FAILED; | |||
| } | |||
| count = it->second; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -184,6 +184,20 @@ class GraphManager { | |||
| Status SaveCheckPointResult(const Graph &graph, const std::vector<Tensor> &outputs, map<string, Tensor> &var_results); | |||
| void RemoveGraphCount(GraphId graph_id); | |||
| void IncreaseGraphCount(GraphId graph_id); | |||
| void DecreaseGraphCount(GraphId graph_id); | |||
| Status GetGraphCount(GraphId graph_id, uint32_t &count); | |||
| void SetAddGraphCondition(GraphId graph_id, uint32_t cond); | |||
| uint32_t GetAddGraphCondition(GraphId graph_id); | |||
| void RemoveAddGraphCondition(GraphId graph_id); | |||
| private: | |||
| struct CompilerStages { | |||
| GraphPrepare preparer; | |||
| @@ -380,6 +394,24 @@ class GraphManager { | |||
| CompilerStages &GetCompilerStages(GraphId graph_id); | |||
| void RemoveCompilerStages(GraphId graph_id); | |||
| static Status CheckIncreBuildAndPreRun(GraphManager *graph_manager, const PreRunArgs &args, GraphNodePtr &graph_node, | |||
| GeRootModelPtr &ge_root_model); | |||
| void ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph_node, const std::vector<uint32_t> &model_ids, | |||
| uint32_t graph_id, uint64_t session_id); | |||
| Status CheckRepeatAdd(uint32_t graph_id, bool &is_added); | |||
| Status NotifyWaittingGraph(uint32_t graph_id); | |||
| Status CreateGraphNode(uint32_t graph_id, const Graph &graph, const std::map<std::string, std::string> &options); | |||
| Status SetStagesOptions(uint32_t graph_id, const GraphManagerOptions &options); | |||
| Status UnloadModel(GeRootModelPtr ge_root_model, uint32_t graph_id); | |||
| void SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id); | |||
| std::atomic_bool thread_run_flag_; | |||
| BlockingQueue<PreRunArgs> prerun_args_q_{}; | |||
| BlockingQueue<RunArgs> run_args_q_{}; | |||
| @@ -415,6 +447,16 @@ class GraphManager { | |||
| std::mutex member_mutex_; | |||
| std::mutex unload_model_mutex_; | |||
| // avoid repeatively add same graph (owns same graph id) | |||
| std::mutex add_graph_mutex_; | |||
| std::mutex add_graph_cond_mutex_; | |||
| std::condition_variable add_graph_cv_; | |||
| std::map<GraphId, uint32_t> graph_id_to_add_graph_cond_; | |||
| // use for multi-thread online-infer scenario | |||
| std::set<GraphId> to_be_deleted_graphs_; | |||
| std::map<GraphId, uint32_t> graph_count_; | |||
| std::mutex graph_count_mutex_; | |||
| }; | |||
| } // namespace ge | |||
| @@ -60,6 +60,15 @@ void GraphNode::Unlock() { | |||
| sem_.Pop(unused); | |||
| } | |||
| void GraphNode::IncreaseLoadCount() { | |||
| std::unique_lock<std::mutex> lock(load_count_mu_); | |||
| if (load_record_ == kMaxLoadNum) { | |||
| GELOGW("Reach the maximum of load_count:%u", kMaxLoadNum); | |||
| return; | |||
| } | |||
| ++load_count_; | |||
| } | |||
| SubGraphInfo::SubGraphInfo() : subgraph_ptr_(nullptr), ge_model_ptr_(nullptr), malloc_flag_(false) {} | |||
| SubGraphInfo::~SubGraphInfo() { | |||
| @@ -55,6 +55,7 @@ using ConstGraphPtr = std::shared_ptr<const ge::Graph>; | |||
| using GraphPtr = std::shared_ptr<ge::Graph>; | |||
| const uint64_t INVALID_SESSION_ID = 0xffffffffffffffffULL; | |||
| const uint32_t kMaxLoadNum = 8; | |||
| struct ModelIdInfo { | |||
| uint32_t model_id{INVALID_MODEL_ID}; | |||
| @@ -162,6 +163,8 @@ class GraphNode { | |||
| bool GetBuildFlag() const { return build_flag_; } | |||
| void SetBuildFlag(bool buildFlag) { build_flag_ = buildFlag; } | |||
| bool GetLoadFlag() const { return load_flag_; } | |||
| // allow repeatively load graph owns same graph id | |||
| void UpdateLoadFlag() { load_flag_ = load_count_ == 0 || load_record_ >= kMaxLoadNum; } | |||
| void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; } | |||
| void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; } | |||
| GeModelPtr GetGeModel() const { return ge_model_; } | |||
| @@ -172,6 +175,13 @@ class GraphNode { | |||
| void Lock(); | |||
| void Unlock(); | |||
| void SetSemSize(uint32_t size) { sem_.SetMaxSize(size); } | |||
| uint32_t GetLoadCount() const { return load_count_; } | |||
| void IncreaseLoadCount(); | |||
| void DecreaseLoadCount() { --load_count_; } | |||
| void IncreaseLoadRecord() { ++load_record_; } | |||
| // run graph asynchronous listener | |||
| std::shared_ptr<RunAsyncListener> graph_run_async_listener_; | |||
| @@ -184,11 +194,17 @@ class GraphNode { | |||
| GraphPtr graph_; | |||
| ComputeGraphPtr compute_graph_; | |||
| bool build_flag_; | |||
| // load_flag_ is true if more than 1 model were loaded | |||
| bool load_flag_; | |||
| bool async_; | |||
| GeModelPtr ge_model_; | |||
| GeRootModelPtr ge_root_model_; | |||
| BlockingQueue<uint8_t> sem_; | |||
| // consist with graph_count of same graph_id in graph_manager | |||
| uint32_t load_count_ = 0; | |||
| // total times of loading a graph with same graph_id. | |||
| uint32_t load_record_ = 0; | |||
| std::mutex load_count_mu_; | |||
| }; | |||
| using GraphNodePtr = std::shared_ptr<GraphNode>; | |||
| @@ -133,8 +133,12 @@ Status HybridModelAsyncExecutor::RunInternal() { | |||
| GE_MAKE_GUARD(not_used_var, [&] { GE_CHK_RT(rtDeviceReset(device_id)); }); | |||
| while (run_flag_) { | |||
| // Model has not indeedly started running before received data | |||
| SetRunningFlag(false); | |||
| std::shared_ptr<InputDataWrapper> data_wrapper; | |||
| Status ret = data_inputer_->Pop(data_wrapper); | |||
| // Model indeedly start running | |||
| SetRunningFlag(true); | |||
| if (data_wrapper == nullptr || ret != SUCCESS) { | |||
| GELOGI("data_wrapper is null!, ret = %u", ret); | |||
| continue; | |||
| @@ -174,7 +178,8 @@ Status HybridModelAsyncExecutor::RunInternal() { | |||
| RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[RunInternal] [iteration = %d] End", iterator_count_); | |||
| iterator_count_++; | |||
| GELOGI("run iterator count is %lu", iterator_count_); | |||
| SetRunningFlag(false); | |||
| GELOGI("run iterator count is %lu, model_id:%u", iterator_count_, model_id_); | |||
| } | |||
| CsaInteract::GetInstance().WriteInternalErrorCode(); | |||
| @@ -55,6 +55,12 @@ class HybridModelAsyncExecutor { | |||
| Status EnqueueData(const std::shared_ptr<InputDataWrapper> &data); | |||
| uint32_t GetDataInputerSize() { return data_inputer_->Size(); } | |||
| bool GetRunningFlag() const { return running_flag_; } | |||
| void SetRunningFlag(bool flag) { running_flag_ = flag; } | |||
| private: | |||
| Status InitInputDesc(); | |||
| @@ -84,6 +90,8 @@ class HybridModelAsyncExecutor { | |||
| uint32_t device_id_ = 0U; | |||
| uint32_t model_id_ = 0U; | |||
| std::atomic_bool run_flag_; | |||
| // check whether model is running with data | |||
| bool running_flag_ = false; | |||
| std::unique_ptr<DataInputer> data_inputer_; | |||
| std::unique_ptr<HybridModelExecutor> executor_; | |||
| std::unique_ptr<HybridModelPipelineExecutor> pipe_executor_; | |||
| @@ -19,6 +19,7 @@ | |||
| #include "hybrid/model/hybrid_model.h" | |||
| #include "hybrid/executor/hybrid_model_async_executor.h" | |||
| #include "hybrid/node_executor/node_executor.h" | |||
| #include "graph/manager/graph_manager_utils.h" | |||
| namespace ge { | |||
| namespace hybrid { | |||
| @@ -107,6 +108,17 @@ class HybridDavinciModel::Impl { | |||
| model_.SetModelDescVersion(is_new_model_desc); | |||
| } | |||
| uint32_t GetDataInputerSize() { return executor_.GetDataInputerSize(); } | |||
| bool GetRunningFlag() const { return executor_.GetRunningFlag(); } | |||
| Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { | |||
| auto listener = dynamic_cast<RunAsyncListener *>(listener_.get()); | |||
| GE_CHECK_NOTNULL(listener); | |||
| listener->SetCallback(callback); | |||
| return SUCCESS; | |||
| } | |||
| private: | |||
| std::shared_ptr<ModelListener> listener_; | |||
| HybridModel model_; | |||
| @@ -221,5 +233,16 @@ uint64_t HybridDavinciModel::GetSessionId() { | |||
| GE_CHECK_NOTNULL(impl_); | |||
| return impl_->GetSessionId(); | |||
| } | |||
| uint32_t HybridDavinciModel::GetDataInputerSize() { | |||
| GE_CHECK_NOTNULL(impl_); | |||
| return impl_->GetDataInputerSize(); | |||
| } | |||
| bool HybridDavinciModel::GetRunningFlag() const { return impl_->GetRunningFlag(); } | |||
| Status HybridDavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { | |||
| return impl_->SetRunAsyncListenerCallback(callback); | |||
| } | |||
| } // namespace hybrid | |||
| } // namespace ge | |||
| @@ -74,6 +74,12 @@ class HybridDavinciModel { | |||
| void SetModelDescVersion(bool is_new_model_desc); | |||
| uint32_t GetDataInputerSize(); | |||
| bool GetRunningFlag() const; | |||
| Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback); | |||
| private: | |||
| HybridDavinciModel() = default; | |||
| class Impl; | |||
| @@ -68,6 +68,10 @@ uint64_t HybridDavinciModel::GetSessionId() { | |||
| return 0; | |||
| } | |||
| uint32_t HybridDavinciModel::GetDataInputerSize() { | |||
| return 0; | |||
| } | |||
| Status HybridDavinciModel::GetDynamicBatchInfo(std::vector<std::vector<int64_t>> &batch_info, int32_t &dynamic_type) { | |||
| return UNSUPPORTED; | |||
| } | |||
| @@ -87,5 +91,13 @@ Status HybridDavinciModel::GetInputOutputDescInfo(vector<InputOutputDescInfo> &i | |||
| void HybridDavinciModel::SetModelDescVersion(bool is_new_model_desc) { | |||
| } | |||
| bool HybridDavinciModel::GetRunningFlag() const { | |||
| return false; | |||
| } | |||
| Status HybridDavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { | |||
| return UNSUPPORTED; | |||
| } | |||
| } // namespace hybrid | |||
| } // namespace ge | |||
| @@ -85,4 +85,14 @@ ProtoAttrMapHelper GeModel::MutableAttrMap() { return attrs_; } | |||
| ConstProtoAttrMapHelper GeModel::GetAttrMap() const { | |||
| return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg()); | |||
| } | |||
| Status GeModel::GetSessionId(uint32_t model_id, uint64_t &session_id) const { | |||
| auto it = model_id_to_session_id_map_.find(model_id); | |||
| if (it != model_id_to_session_id_map_.end()) { | |||
| session_id = it->second; | |||
| return SUCCESS; | |||
| } | |||
| GELOGW("No session id were found with model id [%u].", model_id); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| } // namespace ge | |||
| @@ -71,6 +71,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder | |||
| void SetModelId(uint32_t model_id) { model_id_ = model_id; } | |||
| uint32_t GetModelId() const { return model_id_; } | |||
| Status GetSessionId(uint32_t model_id, uint64_t &session_id) const; | |||
| void InsertSessionMap(uint32_t model_id, uint64_t session_id) { | |||
| model_id_to_session_id_map_.insert({model_id, session_id}); | |||
| } | |||
| protected: | |||
| ConstProtoAttrMapHelper GetAttrMap() const override; | |||
| @@ -90,6 +95,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder | |||
| std::string platform_version_; | |||
| uint8_t platform_type_ = {0}; | |||
| uint32_t model_id_ = INVALID_MODEL_ID; | |||
| std::map<uint32_t, uint64_t> model_id_to_session_id_map_; | |||
| }; | |||
| } // namespace ge | |||
| using GeModelPtr = std::shared_ptr<ge::GeModel>; | |||
| @@ -32,18 +32,34 @@ class GeRootModel { | |||
| return subgraph_instance_name_to_model_; | |||
| }; | |||
| const ComputeGraphPtr &GetRootGraph() const { return root_graph_; }; | |||
| void SetModelId(uint32_t model_id) { model_id_ = model_id; } | |||
| const ComputeGraphPtr &GetRootGraph() const { return root_graph_; } | |||
| void SetModelId(uint32_t model_id) { | |||
| model_id_ = model_id; | |||
| // cached for removement | |||
| model_ids_.emplace_back(model_id); | |||
| } | |||
| uint32_t GetModelId() const { return model_id_; } | |||
| void SetModelName(const std::string &model_name) { model_name_ = model_name; } | |||
| const std::string &GetModelName() const { return model_name_; } | |||
| std::vector<uint32_t> GetAllModelId() const { return model_ids_; } | |||
| Status CheckIsUnknownShape(bool &is_dynamic_shape); | |||
| void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } | |||
| void SetTrainFlag(bool flag) { train_flag_ = flag; } | |||
| bool GetTrainFlag() const { return train_flag_; } | |||
| private: | |||
| ComputeGraphPtr root_graph_ = nullptr; | |||
| std::map<std::string, GeModelPtr> subgraph_instance_name_to_model_; | |||
| uint32_t model_id_ = 0; | |||
| std::string model_name_; | |||
| // In multithread online secenario, same graph can owns different davinci_model for for concurrency | |||
| std::vector<uint32_t> model_ids_; | |||
| bool train_flag_ = false; | |||
| }; | |||
| } // namespace ge | |||
| using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | |||
| @@ -1 +1 @@ | |||
| Subproject commit 54935e7d9d7d825eaef6f477ffb64e8e92b35153 | |||
| Subproject commit 05b2882f8d0364d295aae6ed4ab818a6dc83ad9a | |||
| @@ -588,6 +588,7 @@ set(SINGLE_OP_SRC_FILES | |||
| "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_executor.cc" | |||
| "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_async_executor.cc" | |||
| "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_execution_context.cc" | |||
| "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_pipeline_executor.cc" | |||
| "${GE_CODE_DIR}/ge/hybrid/executor/subgraph_context.cc" | |||
| "${GE_CODE_DIR}/ge/hybrid/executor/subgraph_executor.cc" | |||
| "${GE_CODE_DIR}/ge/hybrid/executor/worker/task_compile_engine.cc" | |||
| @@ -766,9 +767,11 @@ set(MULTI_PARTS_TEST_FILES | |||
| "graph/build/logical_stream_allocator_unittest.cc" | |||
| "graph/build/mem_assigner_unittest.cc" | |||
| "graph/build/task_generator_unittest.cc" | |||
| "graph/execute/graph_execute_unittest.cc" | |||
| "graph/preprocess/graph_preprocess_unittest.cc" | |||
| "graph/manager/hcom_util_unittest.cc" | |||
| "graph/manager/graph_caching_allocator_unittest.cc" | |||
| "graph/manager/graph_manager_unittest.cc" | |||
| "session/omg_omg_unittest.cc" | |||
| ) | |||
| @@ -0,0 +1,129 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <gtest/gtest.h> | |||
| #include <memory> | |||
| #define protected public | |||
| #define private public | |||
| #include "graph/execute/graph_execute.h" | |||
| #include "graph/load/model_manager/model_manager.h" | |||
| #include "graph/load/model_manager/davinci_model.h" | |||
| #include "omm/csa_interact.h" | |||
| #undef private | |||
| #undef public | |||
| #include <pthread.h> | |||
| #include <algorithm> | |||
| #include <future> | |||
| #include <set> | |||
| #include <sstream> | |||
| #include <string> | |||
| #include <thread> | |||
| #include <future> | |||
| using namespace std; | |||
| using namespace testing; | |||
| using namespace ge; | |||
| using namespace domi; | |||
| namespace ge { | |||
| namespace { | |||
| const uint32_t kInvalidModelId = UINT32_MAX; | |||
| } | |||
| class UtestGraphExecuteTest : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| }; | |||
| TEST_F(UtestGraphExecuteTest, get_execute_model_id_invalid) { | |||
| GraphExecutor executor; | |||
| ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); | |||
| GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); | |||
| auto model_id = executor.GetExecuteModelId(ge_root_model); | |||
| EXPECT_EQ(model_id, kInvalidModelId); | |||
| } | |||
| TEST_F(UtestGraphExecuteTest, get_execute_model_id_1) { | |||
| GraphExecutor executor; | |||
| ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); | |||
| GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); | |||
| auto model_manager = ModelManager::GetInstance(); | |||
| shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, nullptr); | |||
| davinci_model1->SetId(1); | |||
| model_manager->InsertModel(1, davinci_model1); | |||
| ge_root_model->SetModelId(1); | |||
| auto model_id = executor.GetExecuteModelId(ge_root_model); | |||
| EXPECT_EQ(model_id, 1); | |||
| } | |||
| TEST_F(UtestGraphExecuteTest, get_execute_model_id_2) { | |||
| GraphExecutor executor; | |||
| ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); | |||
| GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); | |||
| auto model_manager = ModelManager::GetInstance(); | |||
| // model1 with 2 load | |||
| shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, nullptr); | |||
| davinci_model1->SetId(1); | |||
| davinci_model1->data_inputer_ = new DataInputer(); | |||
| auto data = MakeShared<InputDataWrapper>(); | |||
| davinci_model1->data_inputer_->Push(data); | |||
| davinci_model1->data_inputer_->Push(data); | |||
| model_manager->InsertModel(1, davinci_model1); | |||
| // model 2 with 3 load | |||
| shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(1, nullptr); | |||
| davinci_model2->SetId(2); | |||
| davinci_model2->data_inputer_ = new DataInputer(); | |||
| davinci_model2->data_inputer_->Push(data); | |||
| davinci_model2->data_inputer_->Push(data); | |||
| davinci_model2->data_inputer_->Push(data); | |||
| model_manager->InsertModel(2, davinci_model2); | |||
| // model 3 witH 1 load | |||
| shared_ptr<DavinciModel> davinci_model3 = MakeShared<DavinciModel>(1, nullptr); | |||
| davinci_model3->SetId(3); | |||
| davinci_model3->data_inputer_ = new DataInputer(); | |||
| davinci_model3->data_inputer_->Push(data); | |||
| model_manager->InsertModel(3, davinci_model3); | |||
| ge_root_model->SetModelId(1); | |||
| ge_root_model->SetModelId(2); | |||
| ge_root_model->SetModelId(3); | |||
| auto model_id = executor.GetExecuteModelId(ge_root_model); | |||
| // model 3 is picked for having least loads | |||
| EXPECT_EQ(model_id, 3); | |||
| } | |||
| TEST_F(UtestGraphExecuteTest, test_set_callback) { | |||
| GraphExecutor executor; | |||
| ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); | |||
| // is_unknown_shape_graph_ = false | |||
| GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); | |||
| RunAsyncCallback callback = [](Status, std::vector<ge::OutputTensorInfo> &) {}; | |||
| auto model_manager = ModelManager::GetInstance(); | |||
| auto listener = MakeShared<RunAsyncListener>(); | |||
| shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener); | |||
| davinci_model1->SetId(1); | |||
| model_manager->InsertModel(1, davinci_model1); | |||
| auto status = executor.SetCallback(1, ge_root_model, callback); | |||
| EXPECT_EQ(status, SUCCESS); | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,375 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <gtest/gtest.h> | |||
| #include <memory> | |||
| #define protected public | |||
| #define private public | |||
| #include "graph/manager/graph_manager.h" | |||
| #include "graph/load/model_manager/model_manager.h" | |||
| #include "graph/load/model_manager/davinci_model.h" | |||
| #define const | |||
| #include "common/helper/model_cache_helper.h" | |||
| #undef const | |||
| #include "init/gelib.h" | |||
| #undef private | |||
| #undef public | |||
| #include <pthread.h> | |||
| #include <algorithm> | |||
| #include <future> | |||
| #include <set> | |||
| #include <sstream> | |||
| #include <string> | |||
| #include <thread> | |||
| #include <future> | |||
| #include "common/math/math_util.h" | |||
| #include "common/thread_pool.h" | |||
| #include "common/dump/dump_manager.h" | |||
| #include "analyzer/analyzer.h" | |||
| #include "graph/common/ge_call_wrapper.h" | |||
| #include "graph/common/local_context.h" | |||
| #include "graph/common/transop_util.h" | |||
| #include "graph/ge_context.h" | |||
| #include "graph/ge_global_options.h" | |||
| #include "graph/manager/util/rt_context_util.h" | |||
| #include "graph/partition/dynamic_shape_partition.h" | |||
| #include "graph/passes/enter_pass.h" | |||
| #include "graph/partition/stage_partition.h" | |||
| #include "graph/passes/addn_pass.h" | |||
| #include "graph/passes/bitcast_pass.h" | |||
| #include "graph/passes/assign_remove_pass.h" | |||
| #include "graph/passes/inplace_support_check_pass.h" | |||
| #include "graph/passes/atomic_addr_clean_pass.h" | |||
| #include "graph/passes/attach_stream_label_pass.h" | |||
| #include "graph/passes/cast_remove_pass.h" | |||
| #include "graph/passes/common_subexpression_elimination_pass.h" | |||
| #include "graph/passes/compile_nodes_pass.h" | |||
| #include "graph/passes/cond_remove_pass.h" | |||
| #include "graph/passes/constant_folding_pass.h" | |||
| #include "graph/passes/constant_fuse_same_pass.h" | |||
| #include "graph/passes/control_trigger_pass.h" | |||
| #include "graph/passes/ctrl_edge_transfer_pass.h" | |||
| #include "graph/passes/dimension_adjust_pass.h" | |||
| #include "graph/passes/dimension_compute_pass.h" | |||
| #include "graph/passes/flow_ctrl_pass.h" | |||
| #include "graph/passes/fuse_data_nodes_with_common_input_pass.h" | |||
| #include "graph/passes/identity_pass.h" | |||
| #include "graph/passes/input_output_connection_identify_pass.h" | |||
| #include "graph/passes/iterator_op_pass.h" | |||
| #include "graph/passes/link_gen_mask_nodes_pass.h" | |||
| #include "graph/passes/mark_graph_unknown_status_pass.h" | |||
| #include "graph/passes/merge_pass.h" | |||
| #include "graph/passes/merge_input_memcpy_pass.h" | |||
| #include "graph/passes/merge_to_stream_merge_pass.h" | |||
| #include "graph/passes/multi_batch_pass.h" | |||
| #include "graph/passes/next_iteration_pass.h" | |||
| #include "graph/passes/permute_pass.h" | |||
| #include "graph/passes/prune_pass.h" | |||
| #include "graph/passes/ref_identity_delete_op_pass.h" | |||
| #include "graph/passes/remove_same_const_pass.h" | |||
| #include "graph/passes/reshape_recovery_pass.h" | |||
| #include "graph/passes/reshape_remove_pass.h" | |||
| #include "graph/passes/same_transdata_breadth_fusion_pass.h" | |||
| #include "graph/passes/subgraph_pass.h" | |||
| #include "graph/passes/switch_data_edges_bypass.h" | |||
| #include "graph/passes/switch_dead_branch_elimination.h" | |||
| #include "graph/passes/switch_logic_remove_pass.h" | |||
| #include "graph/passes/switch_to_stream_switch_pass.h" | |||
| #include "graph/passes/transop_breadth_fusion_pass.h" | |||
| #include "graph/passes/transop_nearby_allreduce_fusion_pass.h" | |||
| #include "graph/passes/transop_symmetry_elimination_pass.h" | |||
| #include "graph/passes/transop_without_reshape_fusion_pass.h" | |||
| #include "graph/passes/transpose_transdata_pass.h" | |||
| #include "graph/passes/useless_control_out_remove_pass.h" | |||
| #include "graph/passes/variable_op_pass.h" | |||
| #include "graph/passes/variable_ref_delete_op_pass.h" | |||
| #include "graph/passes/variable_ref_useless_control_out_delete_pass.h" | |||
| #include "graph/passes/end_of_sequence_add_control_pass.h" | |||
| #include "graph/passes/subexpression_migration_pass.h" | |||
| #include "graph/passes/subgraph_const_migration_pass.h" | |||
| #include "graph/passes/unused_args_clean_pass.h" | |||
| #include "graph/passes/global_step_insert_pass.h" | |||
| #include "graph/passes/memcpy_addr_async_pass.h" | |||
| #include "graph/passes/hccl_continuous_memcpy_pass.h" | |||
| #include "graph/build/label_allocator.h" | |||
| #include "graph/utils/tensor_adapter.h" | |||
| #include "inc/pass_manager.h" | |||
| #include "ir_build/atc_ir_common.h" | |||
| #include "graph/common/local_context.h" | |||
| #include "graph/common/omg_util.h" | |||
| #include "common/formats/utils/formats_trans_utils.h" | |||
| #include "register/custom_pass_helper.h" | |||
| #include "graph/ops_stub.h" | |||
| using namespace std; | |||
| using namespace testing; | |||
| using namespace ge; | |||
| using namespace domi; | |||
| namespace { | |||
| const uint32_t kNotAdded = 0; | |||
| const uint32_t kStartAdd = 1; | |||
| const uint32_t kDoneAdded = 2; | |||
| } | |||
| class UtestGraphManagerTest : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| }; | |||
| void CreateGraph(Graph &graph) { | |||
| TensorDesc desc(ge::Shape({1, 3, 224, 224})); | |||
| uint32_t size = desc.GetShape().GetShapeSize(); | |||
| desc.SetSize(size); | |||
| auto data = op::Data("Data").set_attr_index(0); | |||
| data.update_input_desc_data(desc); | |||
| data.update_output_desc_out(desc); | |||
| auto flatten = op::Flatten("Flatten").set_input_x(data, data.name_out_out()); | |||
| std::vector<Operator> inputs{data}; | |||
| std::vector<Operator> outputs{flatten}; | |||
| std::vector<Operator> targets{flatten}; | |||
| // Graph graph("test_graph"); | |||
| graph.SetInputs(inputs).SetOutputs(outputs).SetTargets(targets); | |||
| } | |||
| TEST_F(UtestGraphManagerTest, set_and_get_add_graph_flag) { | |||
| GraphId graph_id = 1; | |||
| GraphManager graph_manager; | |||
| graph_manager.SetAddGraphCondition(graph_id, 1); | |||
| uint32_t res = graph_manager.GetAddGraphCondition(graph_id); | |||
| EXPECT_EQ(res, 1); | |||
| } | |||
| TEST_F(UtestGraphManagerTest, test_add_graph_1) { | |||
| GraphId graph_id = 1; | |||
| GraphManager graph_manager; | |||
| // create graph | |||
| Graph graph("test_graph"); | |||
| CreateGraph(graph); | |||
| std::map<std::string, std::string> options; | |||
| OmgContext context; | |||
| Status status = graph_manager.AddGraph(graph_id, graph, options, context); | |||
| EXPECT_EQ(status, ge::SUCCESS); | |||
| } | |||
| TEST_F(UtestGraphManagerTest, test_add_graph_2) { | |||
| GraphId graph_id = 1; | |||
| GraphManager graph_manager; | |||
| GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
| graph_manager.AddGraphNode(graph_id, graph_node); | |||
| graph_manager.SetAddGraphCondition(graph_id, kDoneAdded); | |||
| Graph graph("test_graph"); | |||
| CreateGraph(graph); | |||
| std::map<std::string, std::string> options; | |||
| OmgContext context; | |||
| Status status = graph_manager.AddGraph(graph_id, graph, options, context); | |||
| EXPECT_EQ(status, ge::SUCCESS); | |||
| } | |||
| TEST_F(UtestGraphManagerTest, test_add_graph_3) { | |||
| GraphId graph_id = 1; | |||
| GraphManager graph_manager; | |||
| Graph graph("test_graph"); | |||
| CreateGraph(graph); | |||
| std::map<std::string, std::string> options; | |||
| OmgContext context; | |||
| std::future<Status> fut1 = std::async(std::launch::async, | |||
| &GraphManager::AddGraph, &graph_manager, graph_id, graph, options, context); | |||
| std::future<Status> fut2 = std::async(std::launch::async, | |||
| &GraphManager::AddGraph, &graph_manager, graph_id, graph, options, context); | |||
| fut1.wait(); | |||
| fut2.wait(); | |||
| Status status1 = fut1.get(); | |||
| Status status2 = fut2.get(); | |||
| EXPECT_EQ(status1, ge::SUCCESS); | |||
| EXPECT_EQ(status2, ge::SUCCESS); | |||
| } | |||
| TEST_F(UtestGraphManagerTest, test_remove_graph_1) { | |||
| GraphId graph_id = 1; | |||
| GraphManager graph_manager; | |||
| GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
| Status status = graph_manager.RemoveGraph(graph_id); | |||
| EXPECT_EQ(status, ge::GE_GRAPH_GRAPH_NOT_EXIST); | |||
| graph_manager.AddGraphNode(graph_id, graph_node); | |||
| graph_node->SetRunFlag(true); | |||
| status = graph_manager.RemoveGraph(graph_id); | |||
| EXPECT_EQ(status, ge::SUCCESS); | |||
| } | |||
| TEST_F(UtestGraphManagerTest, test_remove_graph_2) { | |||
| GraphId graph_id = 1; | |||
| GraphManager graph_manager; | |||
| GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
| Graph graph("test_graph"); | |||
| CreateGraph(graph); | |||
| auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||
| GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); | |||
| auto model_manager = ModelManager::GetInstance(); | |||
| auto listener = MakeShared<RunAsyncListener>(); | |||
| shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener); | |||
| davinci_model1->SetId(1); | |||
| shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(2, listener); | |||
| davinci_model1->SetId(2); | |||
| model_manager->InsertModel(1, davinci_model1); | |||
| model_manager->InsertModel(2, davinci_model2); | |||
| ge_root_model->SetModelId(1); | |||
| ge_root_model->SetModelId(2); | |||
| graph_node->SetGeRootModel(ge_root_model); | |||
| graph_node->SetLoadFlag(true); | |||
| graph_manager.AddGraphNode(graph_id, graph_node); | |||
| Status status = graph_manager.RemoveGraph(graph_id); | |||
| EXPECT_EQ(status, ge::SUCCESS); | |||
| } | |||
| TEST_F(UtestGraphManagerTest, test_pre_run_thread) { | |||
| GraphManager graph_manager; | |||
| graph_manager.thread_run_flag_ = true; | |||
| GraphId graph_id = 1; | |||
| std::vector<ge::InputTensorInfo> input_tensor; | |||
| uint64_t session_id = 0; | |||
| ErrorMessage::Context error_context; | |||
| GEThreadLocalContext context; | |||
| RunAsyncCallback callback; | |||
| // PreRunArgs args{graph_id, input_tensor, session_id, error_context, context, callback}; | |||
| bool ret = graph_manager.prerun_args_q_.Push({graph_id, input_tensor, session_id, error_context, context, callback}); | |||
| EXPECT_EQ(ret, true); | |||
| GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
| graph_manager.AddGraphNode(graph_id, graph_node); | |||
| graph_manager.PreRunThread(&graph_manager); | |||
| // end with failed | |||
| } | |||
| TEST_F(UtestGraphManagerTest, test_pre_run_thread_2) { | |||
| GraphManager graph_manager; | |||
| graph_manager.thread_run_flag_ = true; | |||
| GraphId graph_id = 1; | |||
| GraphNodePtr graph_node_1 = MakeShared<ge::GraphNode>(graph_id); | |||
| graph_manager.AddGraphNode(graph_id, graph_node_1); | |||
| graph_manager.IncreaseGraphCount(graph_id); | |||
| graph_manager.IncreaseGraphCount(graph_id); | |||
| graph_node_1->SetBuildFlag(true); | |||
| std::vector<ge::InputTensorInfo> input_tensor; | |||
| uint64_t session_id = 0; | |||
| ErrorMessage::Context error_context; | |||
| GEThreadLocalContext context; | |||
| RunAsyncCallback callback; | |||
| // PreRunArgs args{graph_id, input_tensor, session_id, error_context, context, callback}; | |||
| bool ret = graph_manager.prerun_args_q_.Push({graph_id, input_tensor, session_id, error_context, context, callback}); | |||
| EXPECT_EQ(ret, true); | |||
| graph_id = 2; | |||
| GraphNodePtr graph_node_2 = MakeShared<ge::GraphNode>(graph_id); | |||
| graph_manager.AddGraphNode(graph_id, graph_node_2); | |||
| ret = graph_manager.prerun_args_q_.Push({graph_id, input_tensor, session_id, error_context, context, callback}); | |||
| EXPECT_EQ(ret, true); | |||
| graph_manager.PreRunThread(&graph_manager); | |||
| // end with failed | |||
| } | |||
| TEST_F(UtestGraphManagerTest, test_check_and_release_memory) { | |||
| GraphManager graph_manager; | |||
| GeModelPtr ge_model = make_shared<GeModel>(); | |||
| int64_t memory_size = 25 * 1024UL * 1024UL * 1024UL; | |||
| int64_t weight_size = 25 * 1024UL * 1024UL * 1024UL; | |||
| uint64_t session_id = 0; | |||
| ge::AttrUtils::SetInt(ge_model, ATTR_MODEL_MEMORY_SIZE, memory_size); | |||
| ge::AttrUtils::SetInt(ge_model, ATTR_MODEL_WEIGHT_SIZE, weight_size); | |||
| ge::AttrUtils::SetInt(ge_model, MODEL_ATTR_SESSION_ID, session_id); | |||
| GraphId graph_id = 1; | |||
| GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
| graph_manager.AddGraphNode(graph_id, graph_node); | |||
| graph_manager.IncreaseGraphCount(graph_id); | |||
| graph_manager.IncreaseGraphCount(graph_id); | |||
| auto model_manager = ModelManager::GetInstance(); | |||
| auto listener = MakeShared<RunAsyncListener>(); | |||
| shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener); | |||
| davinci_model1->SetId(1); | |||
| shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(2, listener); | |||
| davinci_model1->SetId(2); | |||
| model_manager->InsertModel(1, davinci_model1); | |||
| model_manager->InsertModel(2, davinci_model2); | |||
| ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||
| bool is_dynamic_shape = false; | |||
| (void)AttrUtils::GetBool(compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); | |||
| GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); | |||
| ge_root_model->SetModelId(1); | |||
| ge_root_model->SetModelId(2); | |||
| graph_node->SetGeRootModel(ge_root_model); | |||
| graph_node->SetLoadFlag(true); | |||
| Status status = graph_manager.CheckAndReleaseMemory(ge_model, graph_node); | |||
| EXPECT_EQ(status, ge::SUCCESS); | |||
| } | |||
| TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_1) { | |||
| // no need to build | |||
| GraphId graph_id = 1; | |||
| GraphManager graph_manager; | |||
| ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||
| GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); | |||
| GraphManager::PreRunArgs arg; | |||
| GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
| graph_node->SetBuildFlag(true); | |||
| Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model); | |||
| EXPECT_EQ(status, ge::SUCCESS); | |||
| } | |||
| TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_2) { | |||
| // need build while buildflag is true, var format changed | |||
| GraphId graph_id = 1; | |||
| GraphManager graph_manager; | |||
| ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||
| GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); | |||
| GraphManager::PreRunArgs arg; | |||
| arg.callback = [](Status, std::vector<ge::OutputTensorInfo> &) {}; | |||
| GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
| graph_node->SetBuildFlag(true); | |||
| graph_node->Lock(); | |||
| graph_manager.var_acc_ctrl_.graph_ids_need_rebuild_.insert(graph_id); | |||
| Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model); | |||
| EXPECT_EQ(status, ge::PARAM_INVALID); | |||
| } | |||
| TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_3) { | |||
| // need build while buildflag is false, var format unchanged | |||
| GraphId graph_id = 1; | |||
| GraphManager graph_manager; | |||
| ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||
| GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); | |||
| GraphManager::PreRunArgs arg; | |||
| arg.callback = [](Status, std::vector<ge::OutputTensorInfo> &) {}; | |||
| GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
| graph_node->SetBuildFlag(false); | |||
| graph_node->Lock(); | |||
| Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model); | |||
| EXPECT_NE(status, ge::SUCCESS); | |||
| } | |||