diff --git a/ge/graph/execute/graph_execute.cc b/ge/graph/execute/graph_execute.cc index d924302c..5142e347 100755 --- a/ge/graph/execute/graph_execute.cc +++ b/ge/graph/execute/graph_execute.cc @@ -20,9 +20,12 @@ #include #include "graph/load/model_manager/model_manager.h" +#include "graph/load/model_manager/davinci_model.h" #include "omm/csa_interact.h" namespace ge { +using Uint32Pair = pair; +const uint32_t kInvalidModelId = UINT32_MAX; GraphExecutor::GraphExecutor() : init_flag_(false), train_graph_flag_(false), @@ -380,7 +383,8 @@ Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_ro } Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, - const std::vector &input_tensor) { + const std::vector &input_tensor, + const RunAsyncCallback& callback) { GELOGI("[GraphExecutor] Start to async execute graph, graph_id=%u", graph_id); if (graph_id != last_graph_id_) { auto ret = FreeExecuteMemory(); @@ -390,7 +394,7 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr & } last_graph_id_ = graph_id; GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED); - Status ret = AsyncExecuteModel(ge_root_model->GetModelId(), input_tensor); + Status ret = AsyncExecuteModel(ge_root_model, input_tensor, callback); if (ret != SUCCESS) { GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] AsyncExecuteModel Error!"); return GE_GRAPH_SYNC_MODEL_FAILED; @@ -400,11 +404,81 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr & return SUCCESS; } -Status GraphExecutor::AsyncExecuteModel(uint32_t model_id, const std::vector &inputs) { +bool CompareByLoad(const Uint32Pair &lhs, const Uint32Pair &rhs) { + return lhs.second < rhs.second; +} + +uint32_t GraphExecutor::GetExecuteModelId(const GeRootModelPtr &ge_root_model) { + std::vector model_ids = ge_root_model->GetAllModelId(); + if (model_ids.empty()) { + return kInvalidModelId; + } + if (model_ids.size() == 1) { + return ge_root_model->GetModelId(); + } + std::vector model_id_to_loads; + auto model_manager = ModelManager::GetInstance(); + GE_CHECK_NOTNULL(model_manager); + for (auto model_id : model_ids) { + auto davinci_model = model_manager->GetModel(model_id); + auto hybrid_model = model_manager->GetHybridModel(model_id); + if (hybrid_model == nullptr) { + GE_CHECK_NOTNULL(davinci_model); + } + uint32_t input_load = hybrid_model != nullptr ? hybrid_model->GetDataInputerSize() : + davinci_model->GetDataInputerSize(); + uint32_t running_load = hybrid_model != nullptr ? 
static_cast(hybrid_model->GetRunningFlag()) : + static_cast(davinci_model->GetRunningFlag()); + uint32_t load = input_load + running_load; + if (load == 0) { + return model_id; + } + model_id_to_loads.emplace_back(model_id, load); + } + sort(model_id_to_loads.begin(), model_id_to_loads.end(), CompareByLoad); + if (model_id_to_loads.empty()) { + return kInvalidModelId; + } + return model_id_to_loads.begin()->first; +} + +Status GraphExecutor::SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model, + const RunAsyncCallback &callback) { + auto model_manager = ge::ModelManager::GetInstance(); + GE_CHECK_NOTNULL(model_manager); + if (model_manager->IsNeedHybridLoad(*ge_root_model)) { + auto model = model_manager->GetHybridModel(model_id); + GE_CHECK_NOTNULL(model); + if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) { + GELOGE(FAILED, "SetRunAsyncListenerCallback failed."); + return FAILED; + } + } else { + auto model = model_manager->GetModel(model_id); + GE_CHECK_NOTNULL(model); + if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) { + GELOGE(FAILED, "SetRunAsyncListenerCallback failed."); + return FAILED; + } + } + return SUCCESS; +} + +Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector &inputs, + const RunAsyncCallback &callback) { + uint32_t model_id = GetExecuteModelId(ge_root_model); + if (model_id == kInvalidModelId) { + GELOGE(INTERNAL_ERROR, "No valid model id."); + return INTERNAL_ERROR; + } try { auto model_manager = ge::ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); GELOGI("RunAsync begin.model_id %u", model_id); + if (SetCallback(model_id, ge_root_model, callback) != SUCCESS) { + GELOGE(FAILED, "RunAsync: SetCallBack for model fail"); + return FAILED; + } Status ret = model_manager->DataInputTensor(model_id, inputs); if (ret != SUCCESS) { diff --git a/ge/graph/execute/graph_execute.h b/ge/graph/execute/graph_execute.h index d2a92e47..2add453f 100755 --- a/ge/graph/execute/graph_execute.h +++ b/ge/graph/execute/graph_execute.h @@ -50,7 +50,7 @@ class GraphExecutor { std::vector &output_tensor); ge::Status ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, - const std::vector &input_tensor); + const std::vector &input_tensor, const RunAsyncCallback &callback); Status SetCondition(std::mutex *mutex, std::condition_variable *cond, std::shared_ptr listener); @@ -116,6 +116,8 @@ class GraphExecutor { static Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); + uint32_t GetExecuteModelId(const GeRootModelPtr &ge_root_model); + private: Status PrepareInputData(const std::vector &input_tensor, InputData &graph_input_data, OutputData &graph_output_data, std::vector &output_desc); @@ -123,7 +125,8 @@ class GraphExecutor { Status SyncExecuteModel(uint32_t model_id, const std::vector &input_tensor, std::vector &output_tensor); - Status AsyncExecuteModel(uint32_t model_id, const std::vector &input_tensor); + Status AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector &input_tensor, + const RunAsyncCallback &callback); void InitModelIdInfo(std::vector &out_model_id_info, std::vector &sub_graph_vec, uint32_t output_size); @@ -132,6 +135,9 @@ class GraphExecutor { Status MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr); + static Status SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model, + const RunAsyncCallback &callback); + bool init_flag_; bool train_graph_flag_; diff 
--git a/ge/graph/load/graph_loader.cc b/ge/graph/load/graph_loader.cc index cf95b271..bdf415a3 100755 --- a/ge/graph/load/graph_loader.cc +++ b/ge/graph/load/graph_loader.cc @@ -63,7 +63,6 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptrGetModelId(); auto model_manager = ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); diff --git a/ge/graph/load/model_manager/data_inputer.h b/ge/graph/load/model_manager/data_inputer.h index 14ebcea5..b8d145d4 100755 --- a/ge/graph/load/model_manager/data_inputer.h +++ b/ge/graph/load/model_manager/data_inputer.h @@ -134,6 +134,8 @@ class DataInputer { /// void Stop() { queue_.Stop(); } + uint32_t Size() { return queue_.Size(); } + private: /// /// @ingroup domi_ome diff --git a/ge/graph/load/model_manager/davinci_model.cc b/ge/graph/load/model_manager/davinci_model.cc index 78f4a64c..2811d0a1 100755 --- a/ge/graph/load/model_manager/davinci_model.cc +++ b/ge/graph/load/model_manager/davinci_model.cc @@ -2737,6 +2737,8 @@ void *DavinciModel::Run(DavinciModel *model) { ErrorManager::GetInstance().SetStage(ErrorMessage::kModelExecute, ErrorMessage::kModelExecute); while (model->RunFlag()) { + // Model hasn't truly started running before receiving data + model->SetRunningFlag(false); bool rslt_flg = true; if (model->GetDataInputer() == nullptr) { GELOGW("Data inputer is nullptr."); @@ -2746,6 +2748,8 @@ void *DavinciModel::Run(DavinciModel *model) { std::shared_ptr data_wrapper; Status ret = model->GetDataInputer()->Pop(data_wrapper); + // Model actually starts running after data is received. + model->SetRunningFlag(true); if (data_wrapper == nullptr || ret != SUCCESS) { GELOGI("data_wrapper is null!"); continue; @@ -2832,7 +2836,9 @@ void *DavinciModel::Run(DavinciModel *model) { model->iterator_count_++; model->is_first_execute_ = false; - GELOGI("run iterator count is %lu", model->iterator_count_); + // model run finished + model->SetRunningFlag(false); + GELOGI("run iterator count is %lu, model_id:%u", model->iterator_count_, model->model_id_); } CsaInteract::GetInstance().WriteInternalErrorCode(); @@ -2890,7 +2896,7 @@ Status DavinciModel::ModelRunStart() { error_context_ = ErrorManager::GetInstance().GetErrorContext(); CREATE_STD_THREAD(thread_id_, DavinciModel::Run, this); - GELOGI("model tread create success, model id:%u.", model_id_); + GELOGI("model thread create success, model id:%u.", model_id_); return SUCCESS; } @@ -4340,4 +4346,10 @@ Status DavinciModel::InitL1DataDumperArgs() { return SUCCESS; } +Status DavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { + auto listener = dynamic_cast(listener_.get()); + GE_CHECK_NOTNULL(listener); + listener->SetCallback(callback); + return SUCCESS; +} } // namespace ge diff --git a/ge/graph/load/model_manager/davinci_model.h b/ge/graph/load/model_manager/davinci_model.h index 30240f25..c28ed4d0 100755 --- a/ge/graph/load/model_manager/davinci_model.h +++ b/ge/graph/load/model_manager/davinci_model.h @@ -221,6 +221,11 @@ class DavinciModel { /// DataInputer *const GetDataInputer() const { return data_inputer_; } + uint32_t GetDataInputerSize() { + GE_CHECK_NOTNULL(data_inputer_); + return data_inputer_->Size(); + } + // get Stream number uint32_t StreamNum() const { return runtime_param_.stream_num; } @@ -560,6 +565,10 @@ class DavinciModel { return data_dumper_.GetOpDescInfo(stream_id, task_id, op_desc_info); } + bool GetRunningFlag() const { return running_flg_; } + void SetRunningFlag(bool flag) { running_flg_ = flag; } + Status 
SetRunAsyncListenerCallback(const RunAsyncCallback &callback); + private: // memory address of weights uint8_t *weights_mem_base_; @@ -924,6 +933,8 @@ class DavinciModel { shared_ptr listener_; bool run_flg_; + // check whether model is running with data + bool running_flg_ = false; mutex mux_run_flg_; diff --git a/ge/graph/load/model_manager/model_manager.cc b/ge/graph/load/model_manager/model_manager.cc index 84259731..df86291d 100755 --- a/ge/graph/load/model_manager/model_manager.cc +++ b/ge/graph/load/model_manager/model_manager.cc @@ -330,6 +330,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrGetSubgraphInstanceNameToModel(); string om_name; @@ -363,7 +364,18 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrAssign(ge_model)), GELOGW("assign model to modeldef failed."); break;); GE_TIMESTAMP_END(Assign, "GraphLoader::ModelAssign"); - + /// In multi-threaded inference, using the same session_id among multiple threads may cause some threads to fail. + /// These session_ids come from the same model, so the values of session_id are the same. + /// Update session_id for infer in load model to avoid the same session_id. + if (!ge_root_model->GetTrainFlag()) { + uint64_t new_session_id; + ret = GenSessionId(new_session_id); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Generate session_id for infer failed."); + ret = davinci_model->UpdateSessionId(new_session_id); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Update session_id for infer failed."); + ge_model->InsertSessionMap(model_id, new_session_id); + GELOGD("Update new session id: %lu.", new_session_id); + } GE_TIMESTAMP_START(Init); GE_IF_BOOL_EXEC(SUCCESS != (ret = davinci_model->Init()), GELOGW("DavinciInit failed."); break;); GE_TIMESTAMP_END(Init, "GraphLoader::ModelInit"); @@ -376,16 +388,16 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr &davinci_model) { - GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", id); +void ModelManager::InsertModel(uint32_t model_id, std::shared_ptr &davinci_model) { + GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", model_id); std::lock_guard lock(map_mutex_); - model_map_[id] = davinci_model; + model_map_[model_id] = davinci_model; } -void ModelManager::InsertModel(uint32_t id, shared_ptr &hybrid_model) { - GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", id); +void ModelManager::InsertModel(uint32_t model_id, shared_ptr &hybrid_model) { + GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", model_id); std::lock_guard lock(map_mutex_); - hybrid_model_map_[id] = hybrid_model; + hybrid_model_map_[model_id] = hybrid_model; } Status ModelManager::DeleteModel(uint32_t id) { diff --git a/ge/graph/load/model_manager/model_manager.h b/ge/graph/load/model_manager/model_manager.h index b537943b..1d52696a 100755 --- a/ge/graph/load/model_manager/model_manager.h +++ b/ge/graph/load/model_manager/model_manager.h @@ -330,8 +330,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { /// @ingroup domi_ome /// @brief insert new model into model manager set /// - void InsertModel(uint32_t id, std::shared_ptr &davinci_model); - void InsertModel(uint32_t id, std::shared_ptr &hybrid_model); + void InsertModel(uint32_t model_id, std::shared_ptr &davinci_model); + void InsertModel(uint32_t model_id, std::shared_ptr &hybrid_model); /// /// 
@ingroup domi_ome diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 82da6257..f2b4211d 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -121,6 +121,10 @@ const char *const kAIcoreEngine = "AIcoreEngine"; const int32_t kDynamicDimsTypeIsGetNext = 0; const int32_t kDynamicDimsTypeIsData = 1; const char *const kGetNextName = "IteratorV2"; +const uint32_t kInitGraphCount = 1; +const uint32_t kNotAdded = 0; +const uint32_t kStartAdd = 1; +const uint32_t kDoneAdded = 2; bool IsTailingOptimization() { string is_tailing_optimization_option; @@ -202,6 +206,8 @@ Status GraphManager::Initialize(const std::map &options) { graph_map_.clear(); cache_helper_map_.clear(); + graph_id_to_add_graph_cond_.clear(); + graph_count_.clear(); init_flag_ = true; thread_run_flag_ = true; @@ -211,6 +217,20 @@ Status GraphManager::Initialize(const std::map &options) { return SUCCESS; } +Status GraphManager::UnloadModel(GeRootModelPtr ge_root_model, uint32_t graph_id) { + Status ret = SUCCESS; + for (size_t i = 0; i < ge_root_model->GetAllModelId().size(); ++i) { + uint32_t model_id = ge_root_model->GetAllModelId()[i]; + GELOGI("Unload model %u.", model_id); + ret = GraphLoader::UnloadModel(model_id); + if (ret != SUCCESS) { + GELOGW("[GraphManager] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); + return ret; + } + } + return ret; +} + Status GraphManager::Finalize() { if (!init_flag_) { GELOGW("GraphManager has not been initialized."); @@ -241,7 +261,6 @@ Status GraphManager::Finalize() { unload_model_ret = GE_GRAPH_GRAPH_IS_RUNNING; continue; } - // unload model auto ge_root_model = graph_node->GetGeRootModel(); if (ge_root_model != nullptr && ge_root_model->GetModelId() != INVALID_MODEL_ID && graph_node->GetLoadFlag()) { @@ -251,15 +270,14 @@ Status GraphManager::Finalize() { unload_model_ret = FAILED; continue; } - ret = GraphLoader::UnloadModel(ge_root_model->GetModelId()); + ret = UnloadModel(ge_root_model, iter->first); if (ret != SUCCESS) { - GELOGW("[GraphManager] unload model failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), iter->first); + GELOGW("[GraphManager] unload model failed, graph_id=%u.", iter->first); unload_model_ret = ret; } rt_ret = rtDeviceReset(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { - GELOGW("[GraphManager] rtDeviceReset failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), - iter->first); + GELOGW("[GraphManager] rtDeviceReset failed, graphId=%u.", iter->first); unload_model_ret = FAILED; continue; } @@ -274,6 +292,7 @@ Status GraphManager::Finalize() { } graph_map_.clear(); cache_helper_map_.clear(); + graph_count_.clear(); // graph context if (graph_context_ != nullptr) { @@ -326,35 +345,59 @@ Status GraphManager::InitDynamicParams(ComputeGraphPtr &compute_graph) { return SUCCESS; } -Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, - const std::map &options, - const OmgContext &omg_context) { - if (HasGraphNode(graph_id)) { - REPORT_INNER_ERROR("E19999", "graph_id:%u is exist, check invalid", graph_id); - GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, "[GraphManager] graph exists, graph_id = %u.", graph_id); - return GE_GRAPH_GRAPH_ALREADY_EXIST; +void GraphManager::SetAddGraphCondition(GraphId graph_id, uint32_t cond) { + std::lock_guard lock(add_graph_cond_mutex_); + graph_id_to_add_graph_cond_[graph_id] = cond; + GELOGD("Graph [id:%u] has been added.", graph_id); +} + +uint32_t GraphManager::GetAddGraphCondition(GraphId graph_id) { + 
std::lock_guard lock(add_graph_cond_mutex_); + auto it = graph_id_to_add_graph_cond_.find(graph_id); + if (it != graph_id_to_add_graph_cond_.end()) { + return it->second; + } else { + GELOGD("Graph [id:%u] has not been added.", graph_id); + return kNotAdded; } +} - auto compute_graph = GraphUtils::GetComputeGraph(graph); - if (compute_graph != nullptr) { - compute_graph->SetGraphID(graph_id); - bool graph_has_been_added = false; - if (AttrUtils::GetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, graph_has_been_added) - && graph_has_been_added) { - REPORT_INNER_ERROR("E19999", "Get Attr:%s from graph:%u fail", - ATTR_NAME_GRAPH_HAS_BEEN_ADDED.c_str(), graph_id); - GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, - "[GraphManager] same graph object can not be added again, graph_id = %u.", graph_id); - return GE_GRAPH_GRAPH_ALREADY_EXIST; - } - (void)AttrUtils::SetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, true); - compute_graph_ = compute_graph; +void GraphManager::RemoveAddGraphCondition(GraphId graph_id) { + std::lock_guard lock(add_graph_cond_mutex_); + auto it = graph_id_to_add_graph_cond_.find(graph_id); + if (it != graph_id_to_add_graph_cond_.end()) { + graph_id_to_add_graph_cond_.erase(it); + GELOGD("Successfully removed add_graph_cond of graph [id:%u].", graph_id); } else { - REPORT_INNER_ERROR("E19999", "compute_graph from graph:%u is nullptr, check invalid", - graph_id); - GELOGE(FAILED, "compute graph is null"); - return FAILED; + GELOGD("Graph [id:%u] has not been added. no need to remove.", graph_id); } +} + +Status GraphManager::CheckRepeatAdd(uint32_t graph_id, bool &is_added) { + uint32_t count = 0; + if (GetGraphCount(graph_id, count) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed, graph might have not been added.", graph_id); + return INTERNAL_ERROR; + } + // previous thread owns same graph_id has been in the middle of the AddGraph procession + if (count > 1 && GetAddGraphCondition(graph_id) == kStartAdd) { + std::unique_lock lock(add_graph_mutex_); + GELOGD("Waitting for build end of previous thread."); + while (GetAddGraphCondition(graph_id) != kDoneAdded) { + add_graph_cv_.wait(lock); + } + GraphNodePtr graph_node; + Status ret = GetGraphNode(graph_id, graph_node); + if (ret != SUCCESS) { + GELOGE(ret, "[AddGraph] GetGraphNode failed, graph_id = %u.", graph_id); + return ret; + } + is_added = true; + } + return SUCCESS; +} + +void GraphManager::SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id) { std::string session_graph_id; if (!AttrUtils::GetStr(*compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id) || session_graph_id.empty()) { session_graph_id = "-1_" + to_string(graph_id); @@ -366,7 +409,24 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, } GELOGD("Get graph session_graph_id attr failed, set session id to default value: [0]"); } +} + +Status GraphManager::NotifyWaittingGraph(uint32_t graph_id) { + uint32_t count = 0; + if (GetGraphCount(graph_id, count) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed, graph might have not been added.", graph_id); + return INTERNAL_ERROR; + } + GELOGD("Add graph finished, graph_id:%u", graph_id); + if (count > 1) { + GELOGD("Finish addgraph, graph_id:%u, graph_count:%u, start to notify.", graph_id, count); + add_graph_cv_.notify_all(); + } + return SUCCESS; +} +Status GraphManager::CreateGraphNode(uint32_t graph_id, const Graph &graph, + const std::map &options) { GraphNodePtr graph_node = MakeShared(graph_id); 
GE_IF_BOOL_EXEC(graph_node == nullptr, REPORT_CALL_ERROR("E19999", "New GraphNode fail, graph_id:%u", @@ -385,7 +445,62 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, ParseOption(options, TUNING_PATH, options_.tuning_path); graph_node->SetGraph(graph_ptr); graph_node->SetOptions(options); + graph_node->IncreaseLoadCount(); AddGraphNode(graph_id, graph_node); + return SUCCESS; +} + +Status GraphManager::SetStagesOptions(uint32_t graph_id, const GraphManagerOptions &options) { + CompilerStages &stages = GetCompilerStages(graph_id); + stages.preparer.SetOptions(options_); + Status status = stages.optimizer.SetOptions(options_); + if (status != SUCCESS) { + GELOGE(status, "Graph optimizer set options failed."); + return status; + } + stages.builder.SetOptions(options_); + return SUCCESS; +} + +Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, + const std::map &options, + const OmgContext &omg_context) { + IncreaseGraphCount(graph_id); + // validation for adding graphs with the same graph_id in multi-thread scenarios + // 1. a previous thread owning the same graph_id has finished the AddGraph process + if (GetAddGraphCondition(graph_id) == kDoneAdded) { + GraphNodePtr graph_node; + if (GetGraphNode(graph_id, graph_node) != SUCCESS) { + GELOGE(GE_GRAPH_GRAPH_NOT_EXIST, "Graph not exist while done adding previously, graph_id = %u.", graph_id); + return GE_GRAPH_GRAPH_NOT_EXIST; + } + graph_node->IncreaseLoadCount(); + return SUCCESS; + } + // In multi-thread scenarios, a former thread owning the same graph_id may be + // in the middle of the AddGraph process, so following threads have to wait until + // it has finished adding the graph, avoiding repeatedly adding the same graph. + bool is_added = false; + if (CheckRepeatAdd(graph_id, is_added) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "CheckRepeatAdd for graph[id:%u] failed.", graph_id); + return INTERNAL_ERROR; + } + // The former graph (from a different thread) owning the same graph id has been successfully added. 
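+  // is_added is set only after CheckRepeatAdd has waited on add_graph_cv_ for the former thread to reach
+  // kDoneAdded, so the GraphNode created by that thread already exists when we return here.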
+ if (is_added) { + return SUCCESS; + } + // Do add graph + SetAddGraphCondition(graph_id, kStartAdd); + auto compute_graph = GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + compute_graph->SetGraphID(graph_id); + + SetSessionGraphId(compute_graph, graph_id); + + if (CreateGraphNode(graph_id, graph, options) != SUCCESS) { + GELOGE(FAILED, "Failed to create graph_node."); + return FAILED; + } AddLocalOmgContext(graph_id, omg_context); if (!options_.output_datatype.empty()) { @@ -396,16 +511,18 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, return GRAPH_PARAM_INVALID; } - CompilerStages &stages = GetCompilerStages(graph_id); - stages.preparer.SetOptions(options_); - Status status = stages.optimizer.SetOptions(options_); - if (status != SUCCESS) { - GELOGE(status, "Graph optimizer set options failed."); - return status; + if (SetStagesOptions(graph_id, options_) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Set stage options failed."); + return INTERNAL_ERROR; } - stages.builder.SetOptions(options_); var_acc_ctrl_.AddGraph(graph_id, compute_graph); + SetAddGraphCondition(graph_id, kDoneAdded); + // There are threads waitting for adding same graph + if (NotifyWaittingGraph(graph_id) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "NotifyWaittingGraph failed."); + return INTERNAL_ERROR; + } return SUCCESS; } @@ -962,6 +1079,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: if (!graph_node->IsAsync()) { ret = LoadGraph(ge_root_model, graph_node); } else { + GE_CHECK_NOTNULL(ge_root_model); ret = LoadGraphAsync(ge_root_model, graph_node); } if (ret != SUCCESS) { @@ -976,6 +1094,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: if (!graph_node->IsAsync()) { ret = LoadGraph(ge_root_model_ptr, graph_node); } else { + GE_CHECK_NOTNULL(ge_root_model); ret = LoadGraphAsync(ge_root_model_ptr, graph_node); } if (ret != SUCCESS) { @@ -988,6 +1107,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { GELOGI("[LoadGraph] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); if (options_.run_graph_flag && ge_root_model != nullptr) { + ge_root_model->SetTrainFlag(GetTrainFlag()); // synchronization run graph with model std::shared_ptr model_listener = GetModelListener(); ModelIdInfo model_id_info; @@ -1413,62 +1533,29 @@ bool GraphManager::CheckModelLoad(const GeRootModelPtr &ge_root_model, bool load } Status GraphManager::RemoveGraph(const GraphId &graph_id) { + auto it = to_be_deleted_graphs_.find(graph_id); + if (it != to_be_deleted_graphs_.end()) { + to_be_deleted_graphs_.erase(it); + } GraphNodePtr graph_node = nullptr; Status ret = GetGraphNode(graph_id, graph_node); - if (ret != SUCCESS) { - REPORT_INNER_ERROR("E19999", "Graph:%u not exist in graph_map, check invalid", - graph_id); + if (ret != SUCCESS || graph_node == nullptr) { + REPORT_INNER_ERROR("E19999", "Graph:%u not exist in graph_map, check invalid when GraphManager %s", + graph_id, __FUNCTION__); GELOGE(GE_GRAPH_GRAPH_NOT_EXIST, "[GraphManager] Id %u does not exists.", graph_id); return GE_GRAPH_GRAPH_NOT_EXIST; } - - if ((graph_node == nullptr) || (graph_node->GetRunFlag())) { - REPORT_INNER_ERROR("E19999", "Graph:%u is running, can't be remove, check invalid", - graph_id); - GELOGE(GE_GRAPH_GRAPH_IS_RUNNING, "[GraphManager] Id %u is running, can't be deleted.", 
graph_id); - return GE_GRAPH_GRAPH_IS_RUNNING; + if (graph_node->GetRunFlag()) { + // only put graph into to-be-deleted list when exceptional scenario + to_be_deleted_graphs_.insert(graph_id); + GELOGI("[GraphManager] Trying to remove running graph[Id:%u], added into to_be_deleted_graphs_.", graph_id); + return SUCCESS; } std::lock_guard lock(unload_model_mutex_); Status middle_ret; rtError_t rt_ret; - const std::vector &all_sub_graph = graph_node->GetAllSubGraph(); - for (size_t i = 0; i < all_sub_graph.size(); ++i) { - // must free buffer firstly - middle_ret = all_sub_graph[i]->FreeInOutBuffer(); - if (middle_ret != SUCCESS) { - GELOGE(middle_ret, "[GraphManager] RemoveGraph free mem failed, graph_id=%u.", graph_id); - ret = middle_ret; - } - if (all_sub_graph[i]->GeModelIsValid() && all_sub_graph[i]->GetModelIdInfo().model_id != INVALID_MODEL_ID) { - // unload model - GELOGI("UnloadModel via new ome."); - rt_ret = rtSetDevice(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, graph_id:%u", - GetContext().DeviceId(), graph_id); - GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.", - all_sub_graph[i]->GetModelIdInfo().model_id, graph_id); - ret = FAILED; - continue; - } - middle_ret = GraphLoader::UnloadModel(all_sub_graph[i]->GetModelIdInfo().model_id); - if (middle_ret != SUCCESS) { - GELOGE(middle_ret, "[GraphManager:] unload model failed, modelId=%u, graph_id=%u.", - all_sub_graph[i]->GetModelIdInfo().model_id, graph_id); - ret = middle_ret; - } - rt_ret = rtDeviceReset(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtDeviceReset fail, device_id:%u, graph_id:%u", - GetContext().DeviceId(), graph_id); - GELOGE(RT_FAILED, "[GraphManager:] unload model failed, modelId=%u, graphId=%u.", - all_sub_graph[i]->GetModelIdInfo().model_id, graph_id); - ret = FAILED; - } - } - } var_acc_ctrl_.RemoveGraph(graph_id); RemoveGraphNode(graph_id); @@ -1476,7 +1563,6 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { auto ge_root_model = graph_node->GetGeRootModel(); if (CheckModelLoad(ge_root_model, graph_node->GetLoadFlag())) { - GELOGI("Unload model %u.", ge_root_model->GetModelId()); rt_ret = rtSetDevice(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, graph_id:%u", @@ -1485,23 +1571,27 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { graph_id); return FAILED; } - middle_ret = GraphLoader::UnloadModel(ge_root_model->GetModelId()); + // same graph may be added for several times, different models were created separately, + // unload them respectively. 
+ middle_ret = UnloadModel(ge_root_model, graph_id); if (middle_ret != SUCCESS) { - GELOGE(middle_ret, "[GraphManager:] unload model failed, modelId=%u, graph_id=%u.", ge_root_model->GetModelId(), - graph_id); + REPORT_INNER_ERROR("E19999", "UnloadModel for graph:%u failed, check unload detail in GraphLoader %s", + graph_id, __FUNCTION__); + GELOGE(middle_ret, "[GraphManager:] unload model failed, graph_id=%u.", graph_id); ret = middle_ret; } rt_ret = rtDeviceReset(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, graph_id:%u", - GetContext().DeviceId(), graph_id); - GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), - graph_id); + REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, graph_id:%u, when GraphManager %s", + GetContext().DeviceId(), graph_id, __FUNCTION__); + GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, graphId=%u.", graph_id); ret = FAILED; } } RemoveCompilerStages(graph_id); + RemoveGraphCount(graph_id); + RemoveAddGraphCondition(graph_id); GE_CHK_STATUS_RET(ret, "[GraphManager:] Remove graph failed, graph_id=%u.", graph_id); GELOGI("[GraphManager] remove graph success, graph_id=%u.", graph_id); @@ -2588,6 +2678,7 @@ void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_gr Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { GELOGI("[LoadGraphAsync] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); if (options_.run_graph_flag && ge_root_model != nullptr) { + ge_root_model->SetTrainFlag(GetTrainFlag()); // synchronization run graph with model ModelIdInfo model_id_info; bool is_unknown_shape = false; @@ -2604,9 +2695,9 @@ Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const G } } GE_TIMESTAMP_START(LoadGraph); - GE_CHECK_NOTNULL(graph_node->graph_run_async_listener_); - Status ret = - GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, graph_node->graph_run_async_listener_); + auto listener = MakeShared(); + GE_CHECK_NOTNULL(listener); + Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, listener); GE_TIMESTAMP_EVENT_END(LoadGraph, "GraphManager::LoadGraphAsync"); if (ret != SUCCESS) { GELOGE(ret, "[LoadGraphAsync] LoadGraphAsync Failed"); @@ -2620,6 +2711,52 @@ Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const G return SUCCESS; } +void GraphManager::ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph_node, + const std::vector &model_ids, uint32_t graph_id, uint64_t session_id) { + rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, when GraphManager %s", + GetContext().DeviceId(), __FUNCTION__); + GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, graphId=%u.", graph_id); + return; + } + for (auto model_id : model_ids) { + uint64_t max_memory_size = 0; + Status result = GraphLoader::GetMaxUsedMemory(model_id, max_memory_size); + if (result != SUCCESS) { + continue; + } + GELOGI("CheckAndReleaseMemory try to UnloadGraph[%u], model[%u] which MaxUsedMemory[%lu].", graph_id, model_id, + max_memory_size); + if (model_ids.size() > 1) { + result = ge_model->GetSessionId(model_id, session_id); + if (result != SUCCESS) { + GELOGW("[GraphManager:] get session failed when dynamic memory, 
modelId=%u, graphId=%u.", model_id, + graph_id); + continue; + } + } + result = GraphLoader::DestroyAicpuKernel(session_id, model_id, 0); + if (result != SUCCESS) { + GELOGW("[GraphManager:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id, + graph_id); + } + result = GraphLoader::UnloadModel(model_id); + if (result != SUCCESS) { + GELOGW("[GraphManager:] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); + } + GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success.", graph_id, model_id); + } + graph_node->SetLoadFlag(false); + rt_ret = rtDeviceReset(GetContext().DeviceId()); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, when GraphManager %s", + GetContext().DeviceId(), __FUNCTION__); + GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, graphId=%u.", graph_id); + return; + } +} + Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node) { GELOGI("CheckAndReleaseMemory graph_id[%u]", graph_node->GetGraphId()); int64_t value = 0; @@ -2665,6 +2802,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra continue; } auto model_id = model->GetModelId(); + auto model_ids = model->GetAllModelId(); // unload model not release bool is_unknown_shape = false; GE_CHK_STATUS_RET(model->CheckIsUnknownShape(is_unknown_shape)); @@ -2677,38 +2815,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra GELOGI("CheckAndReleaseMemory graph[%u] has not been loaded.", graph_id); continue; } - uint64_t max_memory_size = 0; - result = GraphLoader::GetMaxUsedMemory(model_id, max_memory_size); - if (result != SUCCESS) { - continue; - } - GELOGI("CheckAndReleaseMemory try to UnloadGraph[%u], model[%u] which MaxUsedMemory[%lu].", graph_id, model_id, - max_memory_size); - rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u", - GetContext().DeviceId()); - GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.", model_id, graph_id); - continue; - } - result = GraphLoader::DestroyAicpuKernel(session_id, model_id, 0); - if (result != SUCCESS) { - GELOGW("[GraphManager:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id, - graph_id); - } - result = GraphLoader::UnloadModel(model_id); - if (result != SUCCESS) { - GELOGW("[GraphManager:] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); - } - rt_ret = rtDeviceReset(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u", - GetContext().DeviceId()); - GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, modelId=%u, graphId=%u.", model_id, graph_id); - continue; - } - it.second->SetLoadFlag(false); - GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success and set LoadFlag to false.", graph_id, model_id); + ReleaseMemory(ge_model, it.second, model_ids, graph_id, session_id); } return SUCCESS; @@ -2849,6 +2956,38 @@ void GraphManager::ConstructGeInput(const vector &inputs, vecto } } +Status GraphManager::CheckIncreBuildAndPreRun(GraphManager *graph_manager, const PreRunArgs &args, + GraphNodePtr &graph_node, GeRootModelPtr &ge_root_model) { + if (!graph_manager->IsGraphNeedBuild(graph_node)) { + ge_root_model = graph_node->GetGeRootModel(); + return SUCCESS; + } + if 
(graph_node->GetBuildFlag()) { + ReturnError(graph_manager, args.callback, PARAM_INVALID, + "The graph " + std::to_string(graph_node->GetGraphId()) + + " need to re-build, you should remove it" + " from GE first, then AddGraph again and rebuild it."); + graph_node->Unlock(); + return PARAM_INVALID; + } + // check need incre build. + GeModelPtr ge_model = nullptr; + if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { + std::vector ge_inputs; + ConstructGeInput(args.input_tensor, ge_inputs); + Status ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); + // release rts generate context + RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); + if (ret != SUCCESS) { + ReturnError(graph_manager, args.callback, ret, "PreRun Failed."); + return ret; + } + } + graph_node->SetBuildFlag(true); + graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); + return SUCCESS; +} + void GraphManager::PreRunThread(GraphManager *graph_manager) { if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { GELOGW("Set thread name failed."); @@ -2861,7 +3000,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { continue; } - GELOGI("A new loop start."); + GELOGI("[PreRunThread] A new loop start, graph_id:%u.", args.graph_id); ErrorManager::GetInstance().SetErrorContext(args.error_context); ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther); @@ -2877,7 +3016,24 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { "[RunGraph] graph not exist, graph_id=" + std::to_string(args.graph_id)); return; } - + // more than one graph may own the same graph_id + uint32_t count = 0; + if (graph_manager->GetGraphCount(args.graph_id, count) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed.", args.graph_id); + return; + } + // Avoid repeated prerun for graphs owning the same graph_id in online inference concurrency + if (count > 1 && graph_node->GetBuildFlag()) { + graph_node->Lock(); + GELOGD("Avoid repeated prerun, graph_id:%u.", args.graph_id); + // In the online inference concurrency scenario, graph_node is allowed to be locked for 'count' times + graph_node->SetSemSize(count); + graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, + args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback })); + GELOGI("[PreRunThread] Loop end. Start to run with cached build model."); + continue; + } + // Cannot be put ahead of the repeated-prerun judgement graph_node->Lock(); if (graph_node->GetRunFlag()) { @@ -2909,46 +3065,24 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { // it will not execute graph preprocess, optimize, parition, build if the graph has built successful. 
GELOGI("Start for run graph async."); GeRootModelPtr ge_root_model = nullptr; - if (graph_manager->IsGraphNeedBuild(graph_node)) { - if (graph_node->GetBuildFlag()) { - ReturnError(graph_manager, args.callback, PARAM_INVALID, - "The graph " + std::to_string(graph_node->GetGraphId()) + - " need to re-build, you should remove it" - " from GE first, then AddGraph again and rebuild it."); + + ret = CheckIncreBuildAndPreRun(graph_manager, args, graph_node, ge_root_model); + if (ret != SUCCESS) { + graph_node->SetRunFlag(false); + if (!ge::Analyzer::GetInstance()->IsEnableNetAnalyzeDebug()) { + ReturnError(graph_manager, args.callback, ret, "CheckIncreBuildAndPreRun Failed, thread exit.."); graph_node->Unlock(); return; + } else { + ReturnError(graph_manager, graph_node, args.callback, ret, + "CheckIncreBuildAndPreRun Failed, keep geop continue!"); + graph_node->Unlock(); + continue; } - - // check need incre build. - GeModelPtr ge_model = nullptr; - if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { - std::vector ge_inputs; - ConstructGeInput(args.input_tensor, ge_inputs); - ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); - // release rts generate context - RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); - if (ret != SUCCESS) { - graph_node->SetRunFlag(false); - if (!ge::Analyzer::GetInstance()->IsEnableNetAnalyzeDebug()) { - ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit.."); - graph_node->Unlock(); - return; - } else { - ReturnError(graph_manager, graph_node, args.callback, ret, "PreRun Failed, keep geop continue!"); - graph_node->Unlock(); - continue; - } - } - } - graph_node->SetBuildFlag(true); - graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); - } else { - ge_root_model = graph_node->GetGeRootModel(); } - graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, args.input_tensor, ge_root_model, GetThreadLocalContext(), args.callback })); - GELOGI("Loop end."); + GELOGI("[PreRunThread] Loop end."); } } @@ -3051,16 +3185,13 @@ void GraphManager::RunThread(GraphManager *graph_manager) { continue; } - GELOGI("A new loop start."); + GELOGI("[RunThread] A new loop start, graph_id:%u.", args.graph_id); ErrorManager::GetInstance().SetErrorContext(args.error_context); GetContext().SetSessionId(args.session_id); GetThreadLocalContext() = args.context; graph_manager->UpdateLocalOmgContext(args.graph_id); - if (args.graph_node->graph_run_async_listener_ != nullptr) { - args.graph_node->graph_run_async_listener_->SetCallback(args.callback); - } Status ret; // parse inputs.dims to vector> dynamic_dims ret = graph_manager->ParseInputsDims(args.input_tensor); @@ -3070,8 +3201,10 @@ void GraphManager::RunThread(GraphManager *graph_manager) { return; } + args.graph_node->UpdateLoadFlag(); if (!args.graph_node->GetLoadFlag()) { ErrorManager::GetInstance().SetStage(ErrorMessage::kModelLoad, ErrorMessage::kModelLoad); + args.ge_root_model->SetTrainFlag(graph_manager->GetTrainFlag()); ret = graph_manager->LoadGraphAsync(args.ge_root_model, args.graph_node); if (ret != SUCCESS || args.ge_root_model == nullptr) { StopQueue(graph_manager); @@ -3079,6 +3212,10 @@ void GraphManager::RunThread(GraphManager *graph_manager) { args.graph_node->Unlock(); return; } + // control the times of graph loading in multi-thread scenario + args.graph_node->DecreaseLoadCount(); + args.graph_node->IncreaseLoadRecord(); + 
args.graph_node->SetLoadFlag(true); GELOGI("LoadGraph[%u], model[%u] success and set LoadFlag to true.", args.graph_node->GetGraphId(), args.ge_root_model->GetModelId()); @@ -3093,9 +3230,9 @@ void GraphManager::RunThread(GraphManager *graph_manager) { graph_manager->graph_executor_.SetTrainFlag(graph_manager->options_.train_graph_flag); } - args.graph_node->SetRunFlag(false); ret = graph_manager->graph_executor_.ExecuteGraphAsync(args.graph_id, args.graph_node->GetGeRootModel(), - args.input_tensor); + args.input_tensor, args.callback); + args.graph_node->SetRunFlag(false); if (ret != SUCCESS) { ReturnError(graph_manager, args.callback, ret, "ExecuteGraphAsync failed, thread exit."); args.graph_node->Unlock(); @@ -3546,4 +3683,49 @@ void GraphManager::RemoveCompilerStages(GraphId graph_id) { std::lock_guard lock(member_mutex_); compiler_stages_.erase(graph_id); } + +void GraphManager::IncreaseGraphCount(GraphId graph_id) { + std::lock_guard lock(graph_count_mutex_); + auto it = graph_count_.find(graph_id); + if (it == graph_count_.end()) { + graph_count_.insert({graph_id, kInitGraphCount}); + GELOGD("After increaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]); + } else { + ++graph_count_[graph_id]; + GELOGD("After increaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]); + } +} + +void GraphManager::RemoveGraphCount(GraphId graph_id) { + std::lock_guard lock(graph_count_mutex_); + auto it = graph_count_.find(graph_id); + if (it == graph_count_.end()) { + GELOGW("Graph of id: %u has not been added, count cannot be decreased.", graph_id); + } else { + GELOGD("RemoveGraphCount success, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]); + graph_count_.erase(it); + } +} + +void GraphManager::DecreaseGraphCount(GraphId graph_id) { + std::lock_guard lock(graph_count_mutex_); + auto it = graph_count_.find(graph_id); + if (it == graph_count_.end()) { + GELOGW("Graph of id: %u has not been added, count cannot be decreased.", graph_id); + } else { + --it->second; + GELOGD("After DecreaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]); + } +} + +Status GraphManager::GetGraphCount(GraphId graph_id, uint32_t &count) { + std::lock_guard lock(graph_count_mutex_); + auto it = graph_count_.find(graph_id); + if (it == graph_count_.end()) { + GELOGW("Graph [id:%u] has not been added.", graph_id); + return FAILED; + } + count = it->second; + return SUCCESS; +} } // namespace ge diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index b63b138a..0533a0b6 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -184,6 +184,20 @@ class GraphManager { Status SaveCheckPointResult(const Graph &graph, const std::vector &outputs, map &var_results); + void RemoveGraphCount(GraphId graph_id); + + void IncreaseGraphCount(GraphId graph_id); + + void DecreaseGraphCount(GraphId graph_id); + + Status GetGraphCount(GraphId graph_id, uint32_t &count); + + void SetAddGraphCondition(GraphId graph_id, uint32_t cond); + + uint32_t GetAddGraphCondition(GraphId graph_id); + + void RemoveAddGraphCondition(GraphId graph_id); + private: struct CompilerStages { GraphPrepare preparer; @@ -381,6 +395,24 @@ class GraphManager { CompilerStages &GetCompilerStages(GraphId graph_id); void RemoveCompilerStages(GraphId graph_id); + static Status CheckIncreBuildAndPreRun(GraphManager *graph_manager, const PreRunArgs &args, GraphNodePtr &graph_node, + GeRootModelPtr &ge_root_model); + + 
void ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph_node, const std::vector &model_ids, + uint32_t graph_id, uint64_t session_id); + + Status CheckRepeatAdd(uint32_t graph_id, bool &is_added); + + Status NotifyWaittingGraph(uint32_t graph_id); + + Status CreateGraphNode(uint32_t graph_id, const Graph &graph, const std::map &options); + + Status SetStagesOptions(uint32_t graph_id, const GraphManagerOptions &options); + + Status UnloadModel(GeRootModelPtr ge_root_model, uint32_t graph_id); + + void SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id); + std::atomic_bool thread_run_flag_; BlockingQueue prerun_args_q_{}; BlockingQueue run_args_q_{}; @@ -416,6 +448,16 @@ class GraphManager { std::mutex member_mutex_; std::mutex unload_model_mutex_; + // avoid repeatively add same graph (owns same graph id) + std::mutex add_graph_mutex_; + std::mutex add_graph_cond_mutex_; + std::condition_variable add_graph_cv_; + + std::map graph_id_to_add_graph_cond_; + // use for multi-thread online-infer scenario + std::set to_be_deleted_graphs_; + std::map graph_count_; + std::mutex graph_count_mutex_; }; } // namespace ge diff --git a/ge/graph/manager/graph_manager_utils.cc b/ge/graph/manager/graph_manager_utils.cc index 3a8d577c..e9d72bd8 100644 --- a/ge/graph/manager/graph_manager_utils.cc +++ b/ge/graph/manager/graph_manager_utils.cc @@ -60,6 +60,15 @@ void GraphNode::Unlock() { sem_.Pop(unused); } +void GraphNode::IncreaseLoadCount() { + std::unique_lock lock(load_count_mu_); + if (load_record_ == kMaxLoadNum) { + GELOGW("Reach the maximum of load_count:%u", kMaxLoadNum); + return; + } + ++load_count_; +} + SubGraphInfo::SubGraphInfo() : subgraph_ptr_(nullptr), ge_model_ptr_(nullptr), malloc_flag_(false) {} SubGraphInfo::~SubGraphInfo() { diff --git a/ge/graph/manager/graph_manager_utils.h b/ge/graph/manager/graph_manager_utils.h index cfe6588f..ffbc20cf 100644 --- a/ge/graph/manager/graph_manager_utils.h +++ b/ge/graph/manager/graph_manager_utils.h @@ -55,6 +55,7 @@ using ConstGraphPtr = std::shared_ptr; using GraphPtr = std::shared_ptr; const uint64_t INVALID_SESSION_ID = 0xffffffffffffffffULL; +const uint32_t kMaxLoadNum = 8; struct ModelIdInfo { uint32_t model_id{INVALID_MODEL_ID}; @@ -162,6 +163,8 @@ class GraphNode { bool GetBuildFlag() const { return build_flag_; } void SetBuildFlag(bool buildFlag) { build_flag_ = buildFlag; } bool GetLoadFlag() const { return load_flag_; } + // allow repeatively load graph owns same graph id + void UpdateLoadFlag() { load_flag_ = load_count_ == 0 || load_record_ >= kMaxLoadNum; } void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; } void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; } GeModelPtr GetGeModel() const { return ge_model_; } @@ -172,6 +175,13 @@ class GraphNode { void Lock(); void Unlock(); + void SetSemSize(uint32_t size) { sem_.SetMaxSize(size); } + + uint32_t GetLoadCount() const { return load_count_; } + void IncreaseLoadCount(); + void DecreaseLoadCount() { --load_count_; } + void IncreaseLoadRecord() { ++load_record_; } + // run graph asynchronous listener std::shared_ptr graph_run_async_listener_; @@ -184,11 +194,17 @@ class GraphNode { GraphPtr graph_; ComputeGraphPtr compute_graph_; bool build_flag_; + // load_flag_ is true if more than 1 model were loaded bool load_flag_; bool async_; GeModelPtr ge_model_; GeRootModelPtr ge_root_model_; BlockingQueue sem_; + // consist with graph_count of same graph_id in graph_manager + uint32_t load_count_ = 0; + // total times of loading a graph with 
same graph_id. + uint32_t load_record_ = 0; + std::mutex load_count_mu_; }; using GraphNodePtr = std::shared_ptr; diff --git a/ge/hybrid/executor/hybrid_model_async_executor.cc b/ge/hybrid/executor/hybrid_model_async_executor.cc index ca505618..f3f1e1f5 100644 --- a/ge/hybrid/executor/hybrid_model_async_executor.cc +++ b/ge/hybrid/executor/hybrid_model_async_executor.cc @@ -144,8 +144,12 @@ Status HybridModelAsyncExecutor::RunInternal() { GE_MAKE_GUARD(not_used_var, [&] { GE_CHK_RT(rtDeviceReset(device_id)); }); while (run_flag_) { + // Model has not indeedly started running before received data + SetRunningFlag(false); std::shared_ptr data_wrapper; Status ret = data_inputer_->Pop(data_wrapper); + // Model indeedly start running + SetRunningFlag(true); if (data_wrapper == nullptr || ret != SUCCESS) { GELOGI("data_wrapper is null!, ret = %u", ret); continue; @@ -185,7 +189,8 @@ Status HybridModelAsyncExecutor::RunInternal() { RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[RunInternal] [iteration = %d] End", iterator_count_); iterator_count_++; - GELOGI("run iterator count is %lu", iterator_count_); + SetRunningFlag(false); + GELOGI("run iterator count is %lu, model_id:%u", iterator_count_, model_id_); } CsaInteract::GetInstance().WriteInternalErrorCode(); diff --git a/ge/hybrid/executor/hybrid_model_async_executor.h b/ge/hybrid/executor/hybrid_model_async_executor.h index b6942b10..d3fd3d2a 100644 --- a/ge/hybrid/executor/hybrid_model_async_executor.h +++ b/ge/hybrid/executor/hybrid_model_async_executor.h @@ -55,6 +55,12 @@ class HybridModelAsyncExecutor { Status EnqueueData(const std::shared_ptr &data); + uint32_t GetDataInputerSize() { return data_inputer_->Size(); } + + bool GetRunningFlag() const { return running_flag_; } + + void SetRunningFlag(bool flag) { running_flag_ = flag; } + private: Status InitInputDesc(); @@ -84,6 +90,8 @@ class HybridModelAsyncExecutor { uint32_t device_id_ = 0U; uint32_t model_id_ = 0U; std::atomic_bool run_flag_; + // check whether model is running with data + bool running_flag_ = false; std::unique_ptr data_inputer_; std::unique_ptr executor_; std::unique_ptr pipe_executor_; diff --git a/ge/hybrid/hybrid_davinci_model.cc b/ge/hybrid/hybrid_davinci_model.cc index e06b9446..58432031 100755 --- a/ge/hybrid/hybrid_davinci_model.cc +++ b/ge/hybrid/hybrid_davinci_model.cc @@ -19,6 +19,7 @@ #include "hybrid/model/hybrid_model.h" #include "hybrid/executor/hybrid_model_async_executor.h" #include "hybrid/node_executor/node_executor.h" +#include "graph/manager/graph_manager_utils.h" namespace ge { namespace hybrid { @@ -108,6 +109,17 @@ class HybridDavinciModel::Impl { model_.SetModelDescVersion(is_new_model_desc); } + uint32_t GetDataInputerSize() { return executor_.GetDataInputerSize(); } + + bool GetRunningFlag() const { return executor_.GetRunningFlag(); } + + Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { + auto listener = dynamic_cast(listener_.get()); + GE_CHECK_NOTNULL(listener); + listener->SetCallback(callback); + return SUCCESS; + } + private: std::shared_ptr listener_; HybridModel model_; @@ -222,5 +234,16 @@ uint64_t HybridDavinciModel::GetSessionId() { GE_CHECK_NOTNULL(impl_); return impl_->GetSessionId(); } + +uint32_t HybridDavinciModel::GetDataInputerSize() { + GE_CHECK_NOTNULL(impl_); + return impl_->GetDataInputerSize(); +} + +bool HybridDavinciModel::GetRunningFlag() const { return impl_->GetRunningFlag(); } + +Status HybridDavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { + return 
impl_->SetRunAsyncListenerCallback(callback); +} } // namespace hybrid } // namespace ge diff --git a/ge/hybrid/hybrid_davinci_model.h b/ge/hybrid/hybrid_davinci_model.h index 3b3473ff..449dd73e 100644 --- a/ge/hybrid/hybrid_davinci_model.h +++ b/ge/hybrid/hybrid_davinci_model.h @@ -74,6 +74,12 @@ class HybridDavinciModel { void SetModelDescVersion(bool is_new_model_desc); + uint32_t GetDataInputerSize(); + + bool GetRunningFlag() const; + + Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback); + private: HybridDavinciModel() = default; class Impl; diff --git a/ge/hybrid/hybrid_davinci_model_stub.cc b/ge/hybrid/hybrid_davinci_model_stub.cc index 67a7a101..f30fe5cc 100644 --- a/ge/hybrid/hybrid_davinci_model_stub.cc +++ b/ge/hybrid/hybrid_davinci_model_stub.cc @@ -68,6 +68,10 @@ uint64_t HybridDavinciModel::GetSessionId() { return 0; } +uint32_t HybridDavinciModel::GetDataInputerSize() { + return 0; +} + Status HybridDavinciModel::GetDynamicBatchInfo(std::vector> &batch_info, int32_t &dynamic_type) { return UNSUPPORTED; } @@ -87,5 +91,13 @@ Status HybridDavinciModel::GetInputOutputDescInfo(vector &i void HybridDavinciModel::SetModelDescVersion(bool is_new_model_desc) { } + +bool HybridDavinciModel::GetRunningFlag() const { + return false; +} + +Status HybridDavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { + return UNSUPPORTED; +} } // namespace hybrid } // namespace ge \ No newline at end of file diff --git a/ge/model/ge_model.cc b/ge/model/ge_model.cc index acaeff0d..bcccc6f8 100755 --- a/ge/model/ge_model.cc +++ b/ge/model/ge_model.cc @@ -85,4 +85,14 @@ ProtoAttrMapHelper GeModel::MutableAttrMap() { return attrs_; } ConstProtoAttrMapHelper GeModel::GetAttrMap() const { return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg()); } + +Status GeModel::GetSessionId(uint32_t model_id, uint64_t &session_id) const { + auto it = model_id_to_session_id_map_.find(model_id); + if (it != model_id_to_session_id_map_.end()) { + session_id = it->second; + return SUCCESS; + } + GELOGW("No session id were found with model id [%u].", model_id); + return INTERNAL_ERROR; +} } // namespace ge diff --git a/ge/model/ge_model.h b/ge/model/ge_model.h index 5676c3b6..08db8cc3 100755 --- a/ge/model/ge_model.h +++ b/ge/model/ge_model.h @@ -71,6 +71,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder void SetModelId(uint32_t model_id) { model_id_ = model_id; } uint32_t GetModelId() const { return model_id_; } + Status GetSessionId(uint32_t model_id, uint64_t &session_id) const; + void InsertSessionMap(uint32_t model_id, uint64_t session_id) { + model_id_to_session_id_map_.insert({model_id, session_id}); + } + protected: ConstProtoAttrMapHelper GetAttrMap() const override; @@ -90,6 +95,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder std::string platform_version_; uint8_t platform_type_ = {0}; uint32_t model_id_ = INVALID_MODEL_ID; + std::map model_id_to_session_id_map_; }; } // namespace ge using GeModelPtr = std::shared_ptr; diff --git a/ge/model/ge_root_model.h b/ge/model/ge_root_model.h index aa5a4d47..0747d77c 100755 --- a/ge/model/ge_root_model.h +++ b/ge/model/ge_root_model.h @@ -32,15 +32,31 @@ class GeRootModel { return subgraph_instance_name_to_model_; }; - const ComputeGraphPtr &GetRootGraph() const { return root_graph_; }; - void SetModelId(uint32_t model_id) { model_id_ = model_id; } + const ComputeGraphPtr &GetRootGraph() const { return root_graph_; } + void 
SetModelId(uint32_t model_id) { + model_id_ = model_id; + // cached for removement + model_ids_.emplace_back(model_id); + } uint32_t GetModelId() const { return model_id_; } + + std::vector GetAllModelId() const { return model_ids_; } + Status CheckIsUnknownShape(bool &is_dynamic_shape); + void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } + + void SetTrainFlag(bool flag) { train_flag_ = flag; } + + bool GetTrainFlag() const { return train_flag_; } + private: ComputeGraphPtr root_graph_ = nullptr; std::map subgraph_instance_name_to_model_; uint32_t model_id_ = 0; + // In multithread online secenario, same graph can owns different davinci_model for for concurrency + std::vector model_ids_; + bool train_flag_ = false; }; } // namespace ge using GeRootModelPtr = std::shared_ptr; diff --git a/metadef b/metadef index 1e88df1d..fcebf37d 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 1e88df1d6bfe60faae0aa9fa2d87f273b793aeb0 +Subproject commit fcebf37d7428caf4e0bd6e6c3a4f8143f6eac8b7 diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 75985e4c..07b10dac 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -593,6 +593,7 @@ set(SINGLE_OP_SRC_FILES "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_executor.cc" "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_async_executor.cc" "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_execution_context.cc" + "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_pipeline_executor.cc" "${GE_CODE_DIR}/ge/hybrid/executor/subgraph_context.cc" "${GE_CODE_DIR}/ge/hybrid/executor/subgraph_executor.cc" "${GE_CODE_DIR}/ge/hybrid/executor/worker/task_compile_engine.cc" @@ -780,10 +781,12 @@ set(MULTI_PARTS_TEST_FILES "graph/build/mem_assigner_unittest.cc" "graph/build/task_generator_unittest.cc" "graph/build/buffer_pool_mem_assigner_unittest.cc" + "graph/execute/graph_execute_unittest.cc" "graph/preprocess/graph_preprocess_unittest.cc" "graph/manager/hcom_util_unittest.cc" "graph/manager/graph_caching_allocator_unittest.cc" "graph/partition/dynamic_shape_partition_unittest.cc" + "graph/manager/graph_manager_unittest.cc" "session/omg_omg_unittest.cc" ) diff --git a/tests/ut/ge/graph/execute/graph_execute_unittest.cc b/tests/ut/ge/graph/execute/graph_execute_unittest.cc new file mode 100644 index 00000000..b24985be --- /dev/null +++ b/tests/ut/ge/graph/execute/graph_execute_unittest.cc @@ -0,0 +1,129 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#define protected public +#define private public +#include "graph/execute/graph_execute.h" +#include "graph/load/model_manager/model_manager.h" +#include "graph/load/model_manager/davinci_model.h" +#include "omm/csa_interact.h" +#undef private +#undef public + + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace testing; +using namespace ge; +using namespace domi; + +namespace ge { +namespace { +const uint32_t kInvalidModelId = UINT32_MAX; +} + +class UtestGraphExecuteTest : public testing::Test { + protected: + void SetUp() {} + + void TearDown() {} +}; + +TEST_F(UtestGraphExecuteTest, get_execute_model_id_invalid) { + GraphExecutor executor; + ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); + GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); + auto model_id = executor.GetExecuteModelId(ge_root_model); + EXPECT_EQ(model_id, kInvalidModelId); +} + +TEST_F(UtestGraphExecuteTest, get_execute_model_id_1) { + GraphExecutor executor; + ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); + GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); + auto model_manager = ModelManager::GetInstance(); + shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, nullptr); + davinci_model1->SetId(1); + model_manager->InsertModel(1, davinci_model1); + ge_root_model->SetModelId(1); + auto model_id = executor.GetExecuteModelId(ge_root_model); + EXPECT_EQ(model_id, 1); +} + +TEST_F(UtestGraphExecuteTest, get_execute_model_id_2) { + GraphExecutor executor; + ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); + GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); + auto model_manager = ModelManager::GetInstance(); + // model 1 with 2 loads + shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, nullptr); + davinci_model1->SetId(1); + davinci_model1->data_inputer_ = new DataInputer(); + auto data = MakeShared<InputDataWrapper>(); + davinci_model1->data_inputer_->Push(data); + davinci_model1->data_inputer_->Push(data); + model_manager->InsertModel(1, davinci_model1); + // model 2 with 3 loads + shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(1, nullptr); + davinci_model2->SetId(2); + davinci_model2->data_inputer_ = new DataInputer(); + davinci_model2->data_inputer_->Push(data); + davinci_model2->data_inputer_->Push(data); + davinci_model2->data_inputer_->Push(data); + model_manager->InsertModel(2, davinci_model2); + // model 3 with 1 load + shared_ptr<DavinciModel> davinci_model3 = MakeShared<DavinciModel>(1, nullptr); + davinci_model3->SetId(3); + davinci_model3->data_inputer_ = new DataInputer(); + davinci_model3->data_inputer_->Push(data); + model_manager->InsertModel(3, davinci_model3); + + ge_root_model->SetModelId(1); + ge_root_model->SetModelId(2); + ge_root_model->SetModelId(3); + + auto model_id = executor.GetExecuteModelId(ge_root_model); + // model 3 is picked for having the least load + EXPECT_EQ(model_id, 3); +} + +TEST_F(UtestGraphExecuteTest, test_set_callback) { + GraphExecutor executor; + ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); + // is_unknown_shape_graph_ = false + GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); + RunAsyncCallback callback = [](Status, std::vector<ge::Tensor> &) {}; + + auto model_manager = ModelManager::GetInstance(); + auto listener = MakeShared<RunAsyncListener>(); + shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener); + davinci_model1->SetId(1); + model_manager->InsertModel(1, davinci_model1); + auto status = executor.SetCallback(1, ge_root_model, callback); + EXPECT_EQ(status, SUCCESS); +} +} // namespace ge \ No newline at end of file diff --git a/tests/ut/ge/graph/manager/graph_manager_unittest.cc
b/tests/ut/ge/graph/manager/graph_manager_unittest.cc new file mode 100644 index 00000000..dad55f3d --- /dev/null +++ b/tests/ut/ge/graph/manager/graph_manager_unittest.cc @@ -0,0 +1,375 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#define protected public +#define private public +#include "graph/manager/graph_manager.h" +#include "graph/load/model_manager/model_manager.h" +#include "graph/load/model_manager/davinci_model.h" +#define const +#include "common/helper/model_cache_helper.h" +#undef const +#include "init/gelib.h" +#undef private +#undef public + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/math/math_util.h" +#include "common/thread_pool.h" +#include "common/dump/dump_manager.h" +#include "analyzer/analyzer.h" +#include "graph/common/ge_call_wrapper.h" +#include "graph/common/local_context.h" +#include "graph/common/transop_util.h" +#include "graph/ge_context.h" +#include "graph/ge_global_options.h" +#include "graph/manager/util/rt_context_util.h" +#include "graph/partition/dynamic_shape_partition.h" +#include "graph/passes/enter_pass.h" +#include "graph/partition/stage_partition.h" +#include "graph/passes/addn_pass.h" +#include "graph/passes/bitcast_pass.h" +#include "graph/passes/assign_remove_pass.h" +#include "graph/passes/inplace_support_check_pass.h" +#include "graph/passes/atomic_addr_clean_pass.h" +#include "graph/passes/attach_stream_label_pass.h" +#include "graph/passes/cast_remove_pass.h" +#include "graph/passes/common_subexpression_elimination_pass.h" +#include "graph/passes/compile_nodes_pass.h" +#include "graph/passes/cond_remove_pass.h" +#include "graph/passes/constant_folding_pass.h" +#include "graph/passes/constant_fuse_same_pass.h" +#include "graph/passes/control_trigger_pass.h" +#include "graph/passes/ctrl_edge_transfer_pass.h" +#include "graph/passes/dimension_adjust_pass.h" +#include "graph/passes/dimension_compute_pass.h" +#include "graph/passes/flow_ctrl_pass.h" +#include "graph/passes/fuse_data_nodes_with_common_input_pass.h" +#include "graph/passes/identity_pass.h" +#include "graph/passes/input_output_connection_identify_pass.h" +#include "graph/passes/iterator_op_pass.h" +#include "graph/passes/link_gen_mask_nodes_pass.h" +#include "graph/passes/mark_graph_unknown_status_pass.h" +#include "graph/passes/merge_pass.h" +#include "graph/passes/merge_input_memcpy_pass.h" +#include "graph/passes/merge_to_stream_merge_pass.h" +#include "graph/passes/multi_batch_pass.h" +#include "graph/passes/next_iteration_pass.h" +#include "graph/passes/permute_pass.h" +#include "graph/passes/prune_pass.h" +#include "graph/passes/ref_identity_delete_op_pass.h" +#include "graph/passes/remove_same_const_pass.h" +#include "graph/passes/reshape_recovery_pass.h" +#include "graph/passes/reshape_remove_pass.h" +#include "graph/passes/same_transdata_breadth_fusion_pass.h" +#include "graph/passes/subgraph_pass.h" +#include 
"graph/passes/switch_data_edges_bypass.h" +#include "graph/passes/switch_dead_branch_elimination.h" +#include "graph/passes/switch_logic_remove_pass.h" +#include "graph/passes/switch_to_stream_switch_pass.h" +#include "graph/passes/transop_breadth_fusion_pass.h" +#include "graph/passes/transop_nearby_allreduce_fusion_pass.h" +#include "graph/passes/transop_symmetry_elimination_pass.h" +#include "graph/passes/transop_without_reshape_fusion_pass.h" +#include "graph/passes/transpose_transdata_pass.h" +#include "graph/passes/useless_control_out_remove_pass.h" +#include "graph/passes/variable_op_pass.h" +#include "graph/passes/variable_ref_delete_op_pass.h" +#include "graph/passes/variable_ref_useless_control_out_delete_pass.h" +#include "graph/passes/end_of_sequence_add_control_pass.h" +#include "graph/passes/subexpression_migration_pass.h" +#include "graph/passes/subgraph_const_migration_pass.h" +#include "graph/passes/unused_args_clean_pass.h" +#include "graph/passes/global_step_insert_pass.h" +#include "graph/passes/memcpy_addr_async_pass.h" +#include "graph/passes/hccl_continuous_memcpy_pass.h" +#include "graph/build/label_allocator.h" +#include "graph/utils/tensor_adapter.h" +#include "inc/pass_manager.h" +#include "ir_build/atc_ir_common.h" +#include "graph/common/local_context.h" +#include "graph/common/omg_util.h" +#include "common/formats/utils/formats_trans_utils.h" +#include "register/custom_pass_helper.h" +#include "graph/ops_stub.h" + +using namespace std; +using namespace testing; +using namespace ge; +using namespace domi; + +namespace { +const uint32_t kNotAdded = 0; +const uint32_t kStartAdd = 1; +const uint32_t kDoneAdded = 2; +} +class UtestGraphManagerTest : public testing::Test { + protected: + void SetUp() {} + + void TearDown() {} +}; + +void CreateGraph(Graph &graph) { + TensorDesc desc(ge::Shape({1, 3, 224, 224})); + uint32_t size = desc.GetShape().GetShapeSize(); + desc.SetSize(size); + auto data = op::Data("Data").set_attr_index(0); + data.update_input_desc_data(desc); + data.update_output_desc_out(desc); + + auto flatten = op::Flatten("Flatten").set_input_x(data, data.name_out_out()); + + std::vector inputs{data}; + std::vector outputs{flatten}; + std::vector targets{flatten}; + // Graph graph("test_graph"); + graph.SetInputs(inputs).SetOutputs(outputs).SetTargets(targets); +} + +TEST_F(UtestGraphManagerTest, set_and_get_add_graph_flag) { + GraphId graph_id = 1; + GraphManager graph_manager; + graph_manager.SetAddGraphCondition(graph_id, 1); + uint32_t res = graph_manager.GetAddGraphCondition(graph_id); + EXPECT_EQ(res, 1); +} + +TEST_F(UtestGraphManagerTest, test_add_graph_1) { + GraphId graph_id = 1; + GraphManager graph_manager; + // create graph + Graph graph("test_graph"); + CreateGraph(graph); + + std::map options; + OmgContext context; + Status status = graph_manager.AddGraph(graph_id, graph, options, context); + EXPECT_EQ(status, ge::SUCCESS); +} + +TEST_F(UtestGraphManagerTest, test_add_graph_2) { + GraphId graph_id = 1; + GraphManager graph_manager; + GraphNodePtr graph_node = MakeShared(graph_id); + graph_manager.AddGraphNode(graph_id, graph_node); + graph_manager.SetAddGraphCondition(graph_id, kDoneAdded); + Graph graph("test_graph"); + CreateGraph(graph); + std::map options; + OmgContext context; + Status status = graph_manager.AddGraph(graph_id, graph, options, context); + EXPECT_EQ(status, ge::SUCCESS); +} + +TEST_F(UtestGraphManagerTest, test_add_graph_3) { + GraphId graph_id = 1; + GraphManager graph_manager; + Graph graph("test_graph"); + 
CreateGraph(graph); + + std::map<std::string, std::string> options; + OmgContext context; + + std::future<Status> fut1 = std::async(std::launch::async, + &GraphManager::AddGraph, &graph_manager, graph_id, graph, options, context); + std::future<Status> fut2 = std::async(std::launch::async, + &GraphManager::AddGraph, &graph_manager, graph_id, graph, options, context); + fut1.wait(); + fut2.wait(); + Status status1 = fut1.get(); + Status status2 = fut2.get(); + EXPECT_EQ(status1, ge::SUCCESS); + EXPECT_EQ(status2, ge::SUCCESS); +} + +TEST_F(UtestGraphManagerTest, test_remove_graph_1) { + GraphId graph_id = 1; + GraphManager graph_manager; + GraphNodePtr graph_node = MakeShared<GraphNode>(graph_id); + Status status = graph_manager.RemoveGraph(graph_id); + EXPECT_EQ(status, ge::GE_GRAPH_GRAPH_NOT_EXIST); + graph_manager.AddGraphNode(graph_id, graph_node); + graph_node->SetRunFlag(true); + status = graph_manager.RemoveGraph(graph_id); + EXPECT_EQ(status, ge::SUCCESS); +} + +TEST_F(UtestGraphManagerTest, test_remove_graph_2) { + GraphId graph_id = 1; + GraphManager graph_manager; + GraphNodePtr graph_node = MakeShared<GraphNode>(graph_id); + Graph graph("test_graph"); + CreateGraph(graph); + auto compute_graph = GraphUtils::GetComputeGraph(graph); + GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); + auto model_manager = ModelManager::GetInstance(); + auto listener = MakeShared<RunAsyncListener>(); + shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener); + davinci_model1->SetId(1); + shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(2, listener); + davinci_model2->SetId(2); + model_manager->InsertModel(1, davinci_model1); + model_manager->InsertModel(2, davinci_model2); + ge_root_model->SetModelId(1); + ge_root_model->SetModelId(2); + graph_node->SetGeRootModel(ge_root_model); + graph_node->SetLoadFlag(true); + graph_manager.AddGraphNode(graph_id, graph_node); + Status status = graph_manager.RemoveGraph(graph_id); + EXPECT_EQ(status, ge::SUCCESS); +} + +TEST_F(UtestGraphManagerTest, test_pre_run_thread) { + + GraphManager graph_manager; + graph_manager.thread_run_flag_ = true; + + GraphId graph_id = 1; + std::vector<ge::Tensor> input_tensor; + uint64_t session_id = 0; + ErrorMessage::Context error_context; + GEThreadLocalContext context; + RunAsyncCallback callback; + // PreRunArgs args{graph_id, input_tensor, session_id, error_context, context, callback}; + bool ret = graph_manager.prerun_args_q_.Push({graph_id, input_tensor, session_id, error_context, context, callback}); + EXPECT_EQ(ret, true); + + GraphNodePtr graph_node = MakeShared<GraphNode>(graph_id); + graph_manager.AddGraphNode(graph_id, graph_node); + graph_manager.PreRunThread(&graph_manager); + // ends with failure +} + +TEST_F(UtestGraphManagerTest, test_pre_run_thread_2) { + + GraphManager graph_manager; + graph_manager.thread_run_flag_ = true; + + GraphId graph_id = 1; + GraphNodePtr graph_node_1 = MakeShared<GraphNode>(graph_id); + graph_manager.AddGraphNode(graph_id, graph_node_1); + graph_manager.IncreaseGraphCount(graph_id); + graph_manager.IncreaseGraphCount(graph_id); + graph_node_1->SetBuildFlag(true); + std::vector<ge::Tensor> input_tensor; + uint64_t session_id = 0; + ErrorMessage::Context error_context; + GEThreadLocalContext context; + RunAsyncCallback callback; + // PreRunArgs args{graph_id, input_tensor, session_id, error_context, context, callback}; + bool ret = graph_manager.prerun_args_q_.Push({graph_id, input_tensor, session_id, error_context, context, callback}); + EXPECT_EQ(ret, true); + graph_id = 2; + GraphNodePtr graph_node_2 = MakeShared<GraphNode>(graph_id); + graph_manager.AddGraphNode(graph_id, graph_node_2); + ret =
graph_manager.prerun_args_q_.Push({graph_id, input_tensor, session_id, error_context, context, callback}); + EXPECT_EQ(ret, true); + graph_manager.PreRunThread(&graph_manager); + // ends with failure +} + +TEST_F(UtestGraphManagerTest, test_check_and_release_memory) { + + GraphManager graph_manager; + GeModelPtr ge_model = make_shared<GeModel>(); + int64_t memory_size = 25 * 1024UL * 1024UL * 1024UL; + int64_t weight_size = 25 * 1024UL * 1024UL * 1024UL; + uint64_t session_id = 0; + ge::AttrUtils::SetInt(ge_model, ATTR_MODEL_MEMORY_SIZE, memory_size); + ge::AttrUtils::SetInt(ge_model, ATTR_MODEL_WEIGHT_SIZE, weight_size); + ge::AttrUtils::SetInt(ge_model, MODEL_ATTR_SESSION_ID, session_id); + + + GraphId graph_id = 1; + GraphNodePtr graph_node = MakeShared<GraphNode>(graph_id); + graph_manager.AddGraphNode(graph_id, graph_node); + graph_manager.IncreaseGraphCount(graph_id); + graph_manager.IncreaseGraphCount(graph_id); + + auto model_manager = ModelManager::GetInstance(); + auto listener = MakeShared<RunAsyncListener>(); + shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener); + davinci_model1->SetId(1); + shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(2, listener); + davinci_model2->SetId(2); + model_manager->InsertModel(1, davinci_model1); + model_manager->InsertModel(2, davinci_model2); + ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); + bool is_dynamic_shape = false; + (void)AttrUtils::GetBool(compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); + GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); + ge_root_model->SetModelId(1); + ge_root_model->SetModelId(2); + graph_node->SetGeRootModel(ge_root_model); + graph_node->SetLoadFlag(true); + Status status = graph_manager.CheckAndReleaseMemory(ge_model, graph_node); + EXPECT_EQ(status, ge::SUCCESS); +} + +TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_1) { + // no need to build + GraphId graph_id = 1; + GraphManager graph_manager; + ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); + GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); + GraphManager::PreRunArgs arg; + GraphNodePtr graph_node = MakeShared<GraphNode>(graph_id); + graph_node->SetBuildFlag(true); + Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model); + EXPECT_EQ(status, ge::SUCCESS); +} + +TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_2) { + // need to build: build flag is true but the var format changed + GraphId graph_id = 1; + GraphManager graph_manager; + ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); + GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); + GraphManager::PreRunArgs arg; + arg.callback = [](Status, std::vector<ge::Tensor> &) {}; + GraphNodePtr graph_node = MakeShared<GraphNode>(graph_id); + graph_node->SetBuildFlag(true); + graph_node->Lock(); + graph_manager.var_acc_ctrl_.graph_ids_need_rebuild_.insert(graph_id); + Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model); + EXPECT_EQ(status, ge::PARAM_INVALID); +} + +TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_3) { + // need to build: build flag is false and the var format is unchanged + GraphId graph_id = 1; + GraphManager graph_manager; + ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); + GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); + GraphManager::PreRunArgs arg; + arg.callback = [](Status, std::vector<ge::Tensor> &) {}; + GraphNodePtr graph_node = MakeShared<GraphNode>(graph_id); + graph_node->SetBuildFlag(false); + graph_node->Lock(); + Status status =
graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model); + EXPECT_NE(status, ge::SUCCESS); +}