
!1201 support multi-thread in online infer

From: @HW_KK
Reviewed-by: @wqtshg
Signed-off-by:
tags/v1.3.0
mindspore-ci-bot committed 3 years ago
commit 1deef05f49
24 changed files with 1146 additions and 188 deletions
  1. ge/graph/execute/graph_execute.cc (+77, -3)
  2. ge/graph/execute/graph_execute.h (+8, -2)
  3. ge/graph/load/graph_loader.cc (+0, -1)
  4. ge/graph/load/model_manager/data_inputer.h (+2, -0)
  5. ge/graph/load/model_manager/davinci_model.cc (+14, -2)
  6. ge/graph/load/model_manager/davinci_model.h (+11, -0)
  7. ge/graph/load/model_manager/model_manager.cc (+19, -7)
  8. ge/graph/load/model_manager/model_manager.h (+2, -2)
  9. ge/graph/manager/graph_manager.cc (+349, -167)
  10. ge/graph/manager/graph_manager.h (+42, -0)
  11. ge/graph/manager/graph_manager_utils.cc (+9, -0)
  12. ge/graph/manager/graph_manager_utils.h (+16, -0)
  13. ge/hybrid/executor/hybrid_model_async_executor.cc (+6, -1)
  14. ge/hybrid/executor/hybrid_model_async_executor.h (+8, -0)
  15. ge/hybrid/hybrid_davinci_model.cc (+23, -0)
  16. ge/hybrid/hybrid_davinci_model.h (+6, -0)
  17. ge/hybrid/hybrid_davinci_model_stub.cc (+12, -0)
  18. ge/model/ge_model.cc (+10, -0)
  19. ge/model/ge_model.h (+6, -0)
  20. ge/model/ge_root_model.h (+18, -2)
  21. metadef (+1, -1)
  22. tests/ut/ge/CMakeLists.txt (+3, -0)
  23. tests/ut/ge/graph/execute/graph_execute_unittest.cc (+129, -0)
  24. tests/ut/ge/graph/manager/graph_manager_unittest.cc (+375, -0)

ge/graph/execute/graph_execute.cc (+77, -3)

@@ -20,9 +20,12 @@
#include <string>

#include "graph/load/model_manager/model_manager.h"
#include "graph/load/model_manager/davinci_model.h"
#include "omm/csa_interact.h"

namespace ge {
using Uint32Pair = pair<uint32_t, uint32_t>;
const uint32_t kInvalidModelId = UINT32_MAX;
GraphExecutor::GraphExecutor()
: init_flag_(false),
train_graph_flag_(false),
@@ -380,7 +383,8 @@ Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_ro
}

Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model,
const std::vector<InputTensorInfo> &input_tensor) {
const std::vector<InputTensorInfo> &input_tensor,
const RunAsyncCallback& callback) {
GELOGI("[GraphExecutor] Start to async execute graph, graph_id=%u", graph_id);
if (graph_id != last_graph_id_) {
auto ret = FreeExecuteMemory();
@@ -390,7 +394,7 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &
}
last_graph_id_ = graph_id;
GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED);
Status ret = AsyncExecuteModel(ge_root_model->GetModelId(), input_tensor);
Status ret = AsyncExecuteModel(ge_root_model, input_tensor, callback);
if (ret != SUCCESS) {
GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] AsyncExecuteModel Error!");
return GE_GRAPH_SYNC_MODEL_FAILED;
@@ -400,11 +404,81 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &
return SUCCESS;
}

Status GraphExecutor::AsyncExecuteModel(uint32_t model_id, const std::vector<InputTensorInfo> &inputs) {
bool CompareByLoad(const Uint32Pair &lhs, const Uint32Pair &rhs) {
return lhs.second < rhs.second;
}

uint32_t GraphExecutor::GetExecuteModelId(const GeRootModelPtr &ge_root_model) {
std::vector<uint32_t> model_ids = ge_root_model->GetAllModelId();
if (model_ids.empty()) {
return kInvalidModelId;
}
if (model_ids.size() == 1) {
return ge_root_model->GetModelId();
}
std::vector<Uint32Pair> model_id_to_loads;
auto model_manager = ModelManager::GetInstance();
GE_CHECK_NOTNULL(model_manager);
for (auto model_id : model_ids) {
auto davinci_model = model_manager->GetModel(model_id);
auto hybrid_model = model_manager->GetHybridModel(model_id);
if (hybrid_model == nullptr) {
GE_CHECK_NOTNULL(davinci_model);
}
uint32_t input_load = hybrid_model != nullptr ? hybrid_model->GetDataInputerSize() :
davinci_model->GetDataInputerSize();
uint32_t running_load = hybrid_model != nullptr ? static_cast<uint32_t>(hybrid_model->GetRunningFlag()) :
static_cast<uint32_t>(davinci_model->GetRunningFlag());
uint32_t load = input_load + running_load;
if (load == 0) {
return model_id;
}
model_id_to_loads.emplace_back(model_id, load);
}
sort(model_id_to_loads.begin(), model_id_to_loads.end(), CompareByLoad);
if (model_id_to_loads.empty()) {
return kInvalidModelId;
}
return model_id_to_loads.begin()->first;
}

Status GraphExecutor::SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model,
const RunAsyncCallback &callback) {
auto model_manager = ge::ModelManager::GetInstance();
GE_CHECK_NOTNULL(model_manager);
if (model_manager->IsNeedHybridLoad(*ge_root_model)) {
auto model = model_manager->GetHybridModel(model_id);
GE_CHECK_NOTNULL(model);
if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) {
GELOGE(FAILED, "SetRunAsyncListenerCallback failed.");
return FAILED;
}
} else {
auto model = model_manager->GetModel(model_id);
GE_CHECK_NOTNULL(model);
if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) {
GELOGE(FAILED, "SetRunAsyncListenerCallback failed.");
return FAILED;
}
}
return SUCCESS;
}

Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<InputTensorInfo> &inputs,
const RunAsyncCallback &callback) {
uint32_t model_id = GetExecuteModelId(ge_root_model);
if (model_id == kInvalidModelId) {
GELOGE(INTERNAL_ERROR, "No valid model id.");
return INTERNAL_ERROR;
}
try {
auto model_manager = ge::ModelManager::GetInstance();
GE_CHECK_NOTNULL(model_manager);
GELOGI("RunAsync begin.model_id %u", model_id);
if (SetCallback(model_id, ge_root_model, callback) != SUCCESS) {
GELOGE(FAILED, "RunAsync: SetCallBack for model fail");
return FAILED;
}


Status ret = model_manager->DataInputTensor(model_id, inputs);
if (ret != SUCCESS) {


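The key scheduling change in graph_execute.cc is GetExecuteModelId: when one graph has been loaded into several model instances, the executor dispatches to the instance with the smallest load, where load is the number of pending entries in the instance's data inputer plus one if the instance is currently running. The snippet below is a minimal, self-contained sketch of that selection policy; ModelLoad and PickLeastLoaded are illustrative stand-ins, not GE types.

#include <cstdint>
#include <vector>

// Illustrative stand-in for a loaded model instance (not a GE type).
struct ModelLoad {
  uint32_t model_id;
  uint32_t queued_inputs;  // pending requests in the instance's input queue
  bool running;            // instance is currently executing a request
};

// Pick the instance with the smallest load; an idle instance wins immediately.
uint32_t PickLeastLoaded(const std::vector<ModelLoad> &models, uint32_t invalid_id) {
  uint32_t best_id = invalid_id;
  uint32_t best_load = UINT32_MAX;
  for (const auto &m : models) {
    uint32_t load = m.queued_inputs + (m.running ? 1U : 0U);
    if (load == 0) {
      return m.model_id;  // idle instance, take it right away
    }
    if (load < best_load) {
      best_load = load;
      best_id = m.model_id;
    }
  }
  return best_id;
}

A single pass is enough here because only the minimum is needed; the change above sorts the candidate list instead, which yields the same pick.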
ge/graph/execute/graph_execute.h (+8, -2)

@@ -50,7 +50,7 @@ class GraphExecutor {
std::vector<GeTensor> &output_tensor);

ge::Status ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model,
const std::vector<InputTensorInfo> &input_tensor);
const std::vector<InputTensorInfo> &input_tensor, const RunAsyncCallback &callback);


Status SetCondition(std::mutex *mutex, std::condition_variable *cond, std::shared_ptr<GraphModelListener> listener);

@@ -116,6 +116,8 @@ class GraphExecutor {

static Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info);

uint32_t GetExecuteModelId(const GeRootModelPtr &ge_root_model);

private:
Status PrepareInputData(const std::vector<GeTensor> &input_tensor, InputData &graph_input_data,
OutputData &graph_output_data, std::vector<InputOutputDescInfo> &output_desc);
@@ -123,7 +125,8 @@ class GraphExecutor {
Status SyncExecuteModel(uint32_t model_id, const std::vector<GeTensor> &input_tensor,
std::vector<GeTensor> &output_tensor);

Status AsyncExecuteModel(uint32_t model_id, const std::vector<InputTensorInfo> &input_tensor);
Status AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<InputTensorInfo> &input_tensor,
const RunAsyncCallback &callback);

void InitModelIdInfo(std::vector<uint32_t> &out_model_id_info, std::vector<SubGraphInfoPtr> &sub_graph_vec,
uint32_t output_size);
@@ -132,6 +135,9 @@ class GraphExecutor {

Status MallocInOutBuffer(const std::vector<uint64_t> &buffer_size, std::vector<void *> &data_addr);

static Status SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model,
const RunAsyncCallback &callback);

bool init_flag_;

bool train_graph_flag_;


ge/graph/load/graph_loader.cc (+0, -1)

@@ -63,7 +63,6 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge
GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph] GE load graph model_ptr is nullptr.");
return GE_GRAPH_PARAM_NULLPTR;
}
model_id = ge_root_model_ptr->GetModelId();


auto model_manager = ModelManager::GetInstance();
GE_CHECK_NOTNULL(model_manager);


ge/graph/load/model_manager/data_inputer.h (+2, -0)

@@ -134,6 +134,8 @@ class DataInputer {
///
void Stop() { queue_.Stop(); }


uint32_t Size() { return queue_.Size(); }

private:
///
/// @ingroup domi_ome


ge/graph/load/model_manager/davinci_model.cc (+14, -2)

@@ -2737,6 +2737,8 @@ void *DavinciModel::Run(DavinciModel *model) {


ErrorManager::GetInstance().SetStage(ErrorMessage::kModelExecute, ErrorMessage::kModelExecute);
while (model->RunFlag()) {
// The model has not truly started running before data is received
model->SetRunningFlag(false);
bool rslt_flg = true;
if (model->GetDataInputer() == nullptr) {
GELOGW("Data inputer is nullptr.");
@@ -2746,6 +2748,8 @@ void *DavinciModel::Run(DavinciModel *model) {


std::shared_ptr<InputDataWrapper> data_wrapper;
Status ret = model->GetDataInputer()->Pop(data_wrapper);
// The model run actually starts after data is received.
model->SetRunningFlag(true);
if (data_wrapper == nullptr || ret != SUCCESS) {
GELOGI("data_wrapper is null!");
continue;
@@ -2832,7 +2836,9 @@ void *DavinciModel::Run(DavinciModel *model) {


model->iterator_count_++;
model->is_first_execute_ = false;
GELOGI("run iterator count is %lu", model->iterator_count_);
// model run finished
model->SetRunningFlag(false);
GELOGI("run iterator count is %lu, model_id:%u", model->iterator_count_, model->model_id_);
}

CsaInteract::GetInstance().WriteInternalErrorCode();
@@ -2890,7 +2896,7 @@ Status DavinciModel::ModelRunStart() {


error_context_ = ErrorManager::GetInstance().GetErrorContext();
CREATE_STD_THREAD(thread_id_, DavinciModel::Run, this);
GELOGI("model tread create success, model id:%u.", model_id_);
GELOGI("model thread create success, model id:%u.", model_id_);
return SUCCESS;
}

@@ -4340,4 +4346,10 @@ Status DavinciModel::InitL1DataDumperArgs() {
return SUCCESS;
}


Status DavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) {
auto listener = dynamic_cast<RunAsyncListener *>(listener_.get());
GE_CHECK_NOTNULL(listener);
listener->SetCallback(callback);
return SUCCESS;
}
} // namespace ge
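The Run-loop edits in davinci_model.cc feed that load metric: running_flg_ stays false while the model thread blocks on its data inputer and is true only while a request is being serviced. Below is a compact, hypothetical version of the pattern; BlockingQueue and Worker are stand-ins, not the GE DataInputer/DavinciModel classes.

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <deque>
#include <mutex>

// Minimal blocking queue, a stand-in for GE's DataInputer (illustrative only).
template <typename T>
class BlockingQueue {
 public:
  void Push(T value) {
    { std::lock_guard<std::mutex> lk(mu_); items_.push_back(std::move(value)); }
    cv_.notify_one();
  }
  bool Pop(T &out) {  // blocks until an item arrives or Stop() is called
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return stopped_ || !items_.empty(); });
    if (items_.empty()) return false;
    out = std::move(items_.front());
    items_.pop_front();
    return true;
  }
  void Stop() {
    { std::lock_guard<std::mutex> lk(mu_); stopped_ = true; }
    cv_.notify_all();
  }
  uint32_t Size() {
    std::lock_guard<std::mutex> lk(mu_);
    return static_cast<uint32_t>(items_.size());
  }
 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> items_;
  bool stopped_ = false;
};

struct Worker {
  std::atomic<bool> running{false};
  BlockingQueue<int> inputs;  // requests are plain ints in this sketch

  void Loop() {
    int request = 0;
    while (inputs.Pop(request)) {  // idle while blocked here: does not count as load
      running = true;              // busy: counts toward this instance's load
      // ... run the model on `request` ...
      running = false;             // request finished
    }
  }
  // Load as seen by a scheduler: queued requests plus one if currently busy.
  uint32_t Load() { return inputs.Size() + (running ? 1U : 0U); }
};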

ge/graph/load/model_manager/davinci_model.h (+11, -0)

@@ -221,6 +221,11 @@ class DavinciModel {
///
DataInputer *const GetDataInputer() const { return data_inputer_; }

uint32_t GetDataInputerSize() {
GE_CHECK_NOTNULL(data_inputer_);
return data_inputer_->Size();
}

// get Stream number
uint32_t StreamNum() const { return runtime_param_.stream_num; }


@@ -560,6 +565,10 @@ class DavinciModel {
return data_dumper_.GetOpDescInfo(stream_id, task_id, op_desc_info);
}

bool GetRunningFlag() const { return running_flg_; }
void SetRunningFlag(bool flag) { running_flg_ = flag; }
Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback);

private:
// memory address of weights
uint8_t *weights_mem_base_;
@@ -924,6 +933,8 @@ class DavinciModel {
shared_ptr<ModelListener> listener_;

bool run_flg_;
// check whether model is running with data
bool running_flg_ = false;

mutex mux_run_flg_;




ge/graph/load/model_manager/model_manager.cc (+19, -7)

@@ -330,6 +330,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge
GE_CHK_BOOL_RET_STATUS(listener.get() != nullptr, PARAM_INVALID, "Param incorrect, listener is null");
if (model_id == INVALID_MODEL_ID) {
GenModelId(&model_id);
GELOGD("Generate new model_id:%u", model_id);
}
auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel();
string om_name;
@@ -363,7 +364,18 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge
GE_IF_BOOL_EXEC(SUCCESS != (ret = davinci_model->Assign(ge_model)), GELOGW("assign model to modeldef failed.");
break;);
GE_TIMESTAMP_END(Assign, "GraphLoader::ModelAssign");

/// In multi-threaded inference, using the same session_id among multiple threads may cause some threads to fail.
/// These session_ids come from the same model, so the values of session_id are the same.
/// Update session_id for infer in load model to avoid the same session_id.
if (!ge_root_model->GetTrainFlag()) {
uint64_t new_session_id;
ret = GenSessionId(new_session_id);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Generate session_id for infer failed.");
ret = davinci_model->UpdateSessionId(new_session_id);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Update session_id for infer failed.");
ge_model->InsertSessionMap(model_id, new_session_id);
GELOGD("Update new session id: %lu.", new_session_id);
}
GE_TIMESTAMP_START(Init);
GE_IF_BOOL_EXEC(SUCCESS != (ret = davinci_model->Init()), GELOGW("DavinciInit failed."); break;);
GE_TIMESTAMP_END(Init, "GraphLoader::ModelInit");
@@ -376,16 +388,16 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge
return ret;
}


void ModelManager::InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model) {
GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", id);
void ModelManager::InsertModel(uint32_t model_id, std::shared_ptr<DavinciModel> &davinci_model) {
GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", model_id);
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
model_map_[id] = davinci_model;
model_map_[model_id] = davinci_model;
}


void ModelManager::InsertModel(uint32_t id, shared_ptr<hybrid::HybridDavinciModel> &hybrid_model) {
GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", id);
void ModelManager::InsertModel(uint32_t model_id, shared_ptr<hybrid::HybridDavinciModel> &hybrid_model) {
GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", model_id);
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
hybrid_model_map_[id] = hybrid_model;
hybrid_model_map_[model_id] = hybrid_model;
}


Status ModelManager::DeleteModel(uint32_t id) {


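The session_id block added to LoadModelOnline matters because, for online inference, the same GeModel can now be loaded into several instances that run concurrently, and sharing one session id across them can make some threads fail. GE regenerates the id through its own GenSessionId/UpdateSessionId path and records it with InsertSessionMap; the sketch below only illustrates the general idea with a plain atomic counter and is not the GE implementation.

#include <atomic>
#include <cstdint>
#include <map>

// Illustrative only: hand out a distinct session id to every loaded instance.
class SessionIdAllocator {
 public:
  uint64_t Next() { return ++next_id_; }  // ids start at 1 and are never reused

 private:
  std::atomic<uint64_t> next_id_{0};
};

int main() {
  SessionIdAllocator allocator;
  std::map<uint32_t, uint64_t> session_of_model;  // model_id -> session_id (hypothetical)
  for (uint32_t model_id = 1; model_id <= 3; ++model_id) {
    session_of_model[model_id] = allocator.Next();  // each instance gets its own id
  }
  return 0;
}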
ge/graph/load/model_manager/model_manager.h (+2, -2)

@@ -330,8 +330,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager {
/// @ingroup domi_ome
/// @brief insert new model into model manager set
///
void InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model);
void InsertModel(uint32_t id, std::shared_ptr<hybrid::HybridDavinciModel> &hybrid_model);
void InsertModel(uint32_t model_id, std::shared_ptr<DavinciModel> &davinci_model);
void InsertModel(uint32_t model_id, std::shared_ptr<hybrid::HybridDavinciModel> &hybrid_model);


///
/// @ingroup domi_ome


ge/graph/manager/graph_manager.cc (+349, -167)

@@ -121,6 +121,10 @@ const char *const kAIcoreEngine = "AIcoreEngine";
const int32_t kDynamicDimsTypeIsGetNext = 0;
const int32_t kDynamicDimsTypeIsData = 1;
const char *const kGetNextName = "IteratorV2";
const uint32_t kInitGraphCount = 1;
const uint32_t kNotAdded = 0;
const uint32_t kStartAdd = 1;
const uint32_t kDoneAdded = 2;

bool IsTailingOptimization() {
string is_tailing_optimization_option;
@@ -202,6 +206,8 @@ Status GraphManager::Initialize(const std::map<string, string> &options) {


graph_map_.clear();
cache_helper_map_.clear();
graph_id_to_add_graph_cond_.clear();
graph_count_.clear();
init_flag_ = true;

thread_run_flag_ = true;
@@ -211,6 +217,20 @@ Status GraphManager::Initialize(const std::map<string, string> &options) {
return SUCCESS;
}


Status GraphManager::UnloadModel(GeRootModelPtr ge_root_model, uint32_t graph_id) {
Status ret = SUCCESS;
for (size_t i = 0; i < ge_root_model->GetAllModelId().size(); ++i) {
uint32_t model_id = ge_root_model->GetAllModelId()[i];
GELOGI("Unload model %u.", model_id);
ret = GraphLoader::UnloadModel(model_id);
if (ret != SUCCESS) {
GELOGW("[GraphManager] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id);
return ret;
}
}
return ret;
}

Status GraphManager::Finalize() {
if (!init_flag_) {
GELOGW("GraphManager has not been initialized.");
@@ -241,7 +261,6 @@ Status GraphManager::Finalize() {
unload_model_ret = GE_GRAPH_GRAPH_IS_RUNNING;
continue;
}
// unload model
auto ge_root_model = graph_node->GetGeRootModel();
if (ge_root_model != nullptr && ge_root_model->GetModelId() != INVALID_MODEL_ID && graph_node->GetLoadFlag()) {
@@ -251,15 +270,14 @@ Status GraphManager::Finalize() {
unload_model_ret = FAILED;
continue;
}
ret = GraphLoader::UnloadModel(ge_root_model->GetModelId());
ret = UnloadModel(ge_root_model, iter->first);
if (ret != SUCCESS) {
GELOGW("[GraphManager] unload model failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), iter->first);
GELOGW("[GraphManager] unload model failed, graph_id=%u.", iter->first);
unload_model_ret = ret;
}
rt_ret = rtDeviceReset(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
GELOGW("[GraphManager] rtDeviceReset failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(),
iter->first);
GELOGW("[GraphManager] rtDeviceReset failed, graphId=%u.", iter->first);
unload_model_ret = FAILED;
continue;
}
@@ -274,6 +292,7 @@ Status GraphManager::Finalize() {
}
graph_map_.clear();
cache_helper_map_.clear();
graph_count_.clear();

// graph context
if (graph_context_ != nullptr) {
@@ -326,35 +345,59 @@ Status GraphManager::InitDynamicParams(ComputeGraphPtr &compute_graph) {
return SUCCESS;
}


Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph,
const std::map<std::string, std::string> &options,
const OmgContext &omg_context) {
if (HasGraphNode(graph_id)) {
REPORT_INNER_ERROR("E19999", "graph_id:%u is exist, check invalid", graph_id);
GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, "[GraphManager] graph exists, graph_id = %u.", graph_id);
return GE_GRAPH_GRAPH_ALREADY_EXIST;
void GraphManager::SetAddGraphCondition(GraphId graph_id, uint32_t cond) {
std::lock_guard<std::mutex> lock(add_graph_cond_mutex_);
graph_id_to_add_graph_cond_[graph_id] = cond;
GELOGD("Graph [id:%u] has been added.", graph_id);
}

uint32_t GraphManager::GetAddGraphCondition(GraphId graph_id) {
std::lock_guard<std::mutex> lock(add_graph_cond_mutex_);
auto it = graph_id_to_add_graph_cond_.find(graph_id);
if (it != graph_id_to_add_graph_cond_.end()) {
return it->second;
} else {
GELOGD("Graph [id:%u] has not been added.", graph_id);
return kNotAdded;
}
}


auto compute_graph = GraphUtils::GetComputeGraph(graph);
if (compute_graph != nullptr) {
compute_graph->SetGraphID(graph_id);
bool graph_has_been_added = false;
if (AttrUtils::GetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, graph_has_been_added)
&& graph_has_been_added) {
REPORT_INNER_ERROR("E19999", "Get Attr:%s from graph:%u fail",
ATTR_NAME_GRAPH_HAS_BEEN_ADDED.c_str(), graph_id);
GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST,
"[GraphManager] same graph object can not be added again, graph_id = %u.", graph_id);
return GE_GRAPH_GRAPH_ALREADY_EXIST;
}
(void)AttrUtils::SetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, true);
compute_graph_ = compute_graph;
void GraphManager::RemoveAddGraphCondition(GraphId graph_id) {
std::lock_guard<std::mutex> lock(add_graph_cond_mutex_);
auto it = graph_id_to_add_graph_cond_.find(graph_id);
if (it != graph_id_to_add_graph_cond_.end()) {
graph_id_to_add_graph_cond_.erase(it);
GELOGD("Successfully removed add_graph_cond of graph [id:%u].", graph_id);
} else {
REPORT_INNER_ERROR("E19999", "compute_graph from graph:%u is nullptr, check invalid",
graph_id);
GELOGE(FAILED, "compute graph is null");
return FAILED;
GELOGD("Graph [id:%u] has not been added. no need to remove.", graph_id);
}
}

Status GraphManager::CheckRepeatAdd(uint32_t graph_id, bool &is_added) {
uint32_t count = 0;
if (GetGraphCount(graph_id, count) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed, graph might have not been added.", graph_id);
return INTERNAL_ERROR;
}
// a previous thread owning the same graph_id is in the middle of the AddGraph process
if (count > 1 && GetAddGraphCondition(graph_id) == kStartAdd) {
std::unique_lock<std::mutex> lock(add_graph_mutex_);
GELOGD("Waiting for build end of previous thread.");
while (GetAddGraphCondition(graph_id) != kDoneAdded) {
add_graph_cv_.wait(lock);
}
GraphNodePtr graph_node;
Status ret = GetGraphNode(graph_id, graph_node);
if (ret != SUCCESS) {
GELOGE(ret, "[AddGraph] GetGraphNode failed, graph_id = %u.", graph_id);
return ret;
}
is_added = true;
}
return SUCCESS;
}

void GraphManager::SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id) {
std::string session_graph_id;
if (!AttrUtils::GetStr(*compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id) || session_graph_id.empty()) {
session_graph_id = "-1_" + to_string(graph_id);
@@ -366,7 +409,24 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph,
}
GELOGD("Get graph session_graph_id attr failed, set session id to default value: [0]");
}
}

Status GraphManager::NotifyWaittingGraph(uint32_t graph_id) {
uint32_t count = 0;
if (GetGraphCount(graph_id, count) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed, graph might have not been added.", graph_id);
return INTERNAL_ERROR;
}
GELOGD("Add graph finished, graph_id:%u", graph_id);
if (count > 1) {
GELOGD("Finish addgraph, graph_id:%u, graph_count:%u, start to notify.", graph_id, count);
add_graph_cv_.notify_all();
}
return SUCCESS;
}


Status GraphManager::CreateGraphNode(uint32_t graph_id, const Graph &graph,
const std::map<std::string, std::string> &options) {
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
GE_IF_BOOL_EXEC(graph_node == nullptr,
REPORT_CALL_ERROR("E19999", "New GraphNode fail, graph_id:%u",
@@ -385,7 +445,62 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph,
ParseOption(options, TUNING_PATH, options_.tuning_path);
graph_node->SetGraph(graph_ptr);
graph_node->SetOptions(options);
graph_node->IncreaseLoadCount();
AddGraphNode(graph_id, graph_node);
return SUCCESS;
}

Status GraphManager::SetStagesOptions(uint32_t graph_id, const GraphManagerOptions &options) {
CompilerStages &stages = GetCompilerStages(graph_id);
stages.preparer.SetOptions(options_);
Status status = stages.optimizer.SetOptions(options_);
if (status != SUCCESS) {
GELOGE(status, "Graph optimizer set options failed.");
return status;
}
stages.builder.SetOptions(options_);
return SUCCESS;
}

Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph,
const std::map<std::string, std::string> &options,
const OmgContext &omg_context) {
IncreaseGraphCount(graph_id);
// validation for adding graphs with the same graph_id in a multi-thread scenario
// 1. a previous thread owning the same graph_id has finished the AddGraph process
if (GetAddGraphCondition(graph_id) == kDoneAdded) {
GraphNodePtr graph_node;
if (GetGraphNode(graph_id, graph_node) != SUCCESS) {
GELOGE(GE_GRAPH_GRAPH_NOT_EXIST, "Graph not exist while done adding previously, graph_id = %u.", graph_id);
return GE_GRAPH_GRAPH_NOT_EXIST;
}
graph_node->IncreaseLoadCount();
return SUCCESS;
}
// In a multi-thread scenario, a former thread owning the same graph_id may be
// in the middle of the AddGraph process, so the following threads have to wait until
// the former thread is done adding the graph, avoiding repeatedly adding the same graph.
bool is_added = false;
if (CheckRepeatAdd(graph_id, is_added) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "CheckRepeatAdd for graph[id:%u] failed.", graph_id);
return INTERNAL_ERROR;
}
// The former graph (from a different thread) owning the same graph id has been successfully added.
if (is_added) {
return SUCCESS;
}
// Do add graph
SetAddGraphCondition(graph_id, kStartAdd);
auto compute_graph = GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
compute_graph->SetGraphID(graph_id);

SetSessionGraphId(compute_graph, graph_id);

if (CreateGraphNode(graph_id, graph, options) != SUCCESS) {
GELOGE(FAILED, "Failed to create graph_node.");
return FAILED;
}


AddLocalOmgContext(graph_id, omg_context);
if (!options_.output_datatype.empty()) {
@@ -396,16 +511,18 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph,
return GRAPH_PARAM_INVALID;
}


CompilerStages &stages = GetCompilerStages(graph_id);
stages.preparer.SetOptions(options_);
Status status = stages.optimizer.SetOptions(options_);
if (status != SUCCESS) {
GELOGE(status, "Graph optimizer set options failed.");
return status;
if (SetStagesOptions(graph_id, options_) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Set stage options failed.");
return INTERNAL_ERROR;
}
stages.builder.SetOptions(options_);


var_acc_ctrl_.AddGraph(graph_id, compute_graph);
SetAddGraphCondition(graph_id, kDoneAdded);
// There are threads waiting to add the same graph
if (NotifyWaittingGraph(graph_id) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "NotifyWaittingGraph failed.");
return INTERNAL_ERROR;
}
return SUCCESS;
}


@@ -962,6 +1079,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std:
if (!graph_node->IsAsync()) {
ret = LoadGraph(ge_root_model, graph_node);
} else {
GE_CHECK_NOTNULL(ge_root_model);
ret = LoadGraphAsync(ge_root_model, graph_node);
}
if (ret != SUCCESS) {
@@ -976,6 +1094,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std:
if (!graph_node->IsAsync()) {
ret = LoadGraph(ge_root_model_ptr, graph_node);
} else {
GE_CHECK_NOTNULL(ge_root_model);
ret = LoadGraphAsync(ge_root_model_ptr, graph_node);
}
if (ret != SUCCESS) {
@@ -988,6 +1107,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std:
Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) {
GELOGI("[LoadGraph] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId());
if (options_.run_graph_flag && ge_root_model != nullptr) {
ge_root_model->SetTrainFlag(GetTrainFlag());
// synchronization run graph with model
std::shared_ptr<GraphModelListener> model_listener = GetModelListener();
ModelIdInfo model_id_info;
@@ -1413,62 +1533,29 @@ bool GraphManager::CheckModelLoad(const GeRootModelPtr &ge_root_model, bool load
}


Status GraphManager::RemoveGraph(const GraphId &graph_id) {
auto it = to_be_deleted_graphs_.find(graph_id);
if (it != to_be_deleted_graphs_.end()) {
to_be_deleted_graphs_.erase(it);
}
GraphNodePtr graph_node = nullptr;
Status ret = GetGraphNode(graph_id, graph_node);
if (ret != SUCCESS) {
REPORT_INNER_ERROR("E19999", "Graph:%u not exist in graph_map, check invalid",
graph_id);
if (ret != SUCCESS || graph_node == nullptr) {
REPORT_INNER_ERROR("E19999", "Graph:%u not exist in graph_map, check invalid when GraphManager %s",
graph_id, __FUNCTION__);
GELOGE(GE_GRAPH_GRAPH_NOT_EXIST, "[GraphManager] Id %u does not exists.", graph_id);
return GE_GRAPH_GRAPH_NOT_EXIST;
}

if ((graph_node == nullptr) || (graph_node->GetRunFlag())) {
REPORT_INNER_ERROR("E19999", "Graph:%u is running, can't be remove, check invalid",
graph_id);
GELOGE(GE_GRAPH_GRAPH_IS_RUNNING, "[GraphManager] Id %u is running, can't be deleted.", graph_id);
return GE_GRAPH_GRAPH_IS_RUNNING;
if (graph_node->GetRunFlag()) {
// only put the graph into the to-be-deleted list in an exceptional scenario
to_be_deleted_graphs_.insert(graph_id);
GELOGI("[GraphManager] Trying to remove running graph[Id:%u], added into to_be_deleted_graphs_.", graph_id);
return SUCCESS;
}


std::lock_guard<std::mutex> lock(unload_model_mutex_);


Status middle_ret;
rtError_t rt_ret;
const std::vector<SubGraphInfoPtr> &all_sub_graph = graph_node->GetAllSubGraph();
for (size_t i = 0; i < all_sub_graph.size(); ++i) {
// must free buffer firstly
middle_ret = all_sub_graph[i]->FreeInOutBuffer();
if (middle_ret != SUCCESS) {
GELOGE(middle_ret, "[GraphManager] RemoveGraph free mem failed, graph_id=%u.", graph_id);
ret = middle_ret;
}
if (all_sub_graph[i]->GeModelIsValid() && all_sub_graph[i]->GetModelIdInfo().model_id != INVALID_MODEL_ID) {
// unload model
GELOGI("UnloadModel via new ome.");
rt_ret = rtSetDevice(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, graph_id:%u",
GetContext().DeviceId(), graph_id);
GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.",
all_sub_graph[i]->GetModelIdInfo().model_id, graph_id);
ret = FAILED;
continue;
}
middle_ret = GraphLoader::UnloadModel(all_sub_graph[i]->GetModelIdInfo().model_id);
if (middle_ret != SUCCESS) {
GELOGE(middle_ret, "[GraphManager:] unload model failed, modelId=%u, graph_id=%u.",
all_sub_graph[i]->GetModelIdInfo().model_id, graph_id);
ret = middle_ret;
}
rt_ret = rtDeviceReset(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset fail, device_id:%u, graph_id:%u",
GetContext().DeviceId(), graph_id);
GELOGE(RT_FAILED, "[GraphManager:] unload model failed, modelId=%u, graphId=%u.",
all_sub_graph[i]->GetModelIdInfo().model_id, graph_id);
ret = FAILED;
}
}
}
var_acc_ctrl_.RemoveGraph(graph_id);
RemoveGraphNode(graph_id);


@@ -1476,7 +1563,6 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) {


auto ge_root_model = graph_node->GetGeRootModel();
if (CheckModelLoad(ge_root_model, graph_node->GetLoadFlag())) {
GELOGI("Unload model %u.", ge_root_model->GetModelId());
rt_ret = rtSetDevice(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, graph_id:%u",
@@ -1485,23 +1571,27 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) {
graph_id);
return FAILED;
}
middle_ret = GraphLoader::UnloadModel(ge_root_model->GetModelId());
// the same graph may be added several times and different models are created separately,
// unload them respectively.
middle_ret = UnloadModel(ge_root_model, graph_id);
if (middle_ret != SUCCESS) {
GELOGE(middle_ret, "[GraphManager:] unload model failed, modelId=%u, graph_id=%u.", ge_root_model->GetModelId(),
graph_id);
REPORT_INNER_ERROR("E19999", "UnloadModel for graph:%u failed, check unload detail in GraphLoader %s",
graph_id, __FUNCTION__);
GELOGE(middle_ret, "[GraphManager:] unload model failed, graph_id=%u.", graph_id);
ret = middle_ret;
}
rt_ret = rtDeviceReset(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, graph_id:%u",
GetContext().DeviceId(), graph_id);
GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(),
graph_id);
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, graph_id:%u, when GraphManager %s",
GetContext().DeviceId(), graph_id, __FUNCTION__);
GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, graphId=%u.", graph_id);
ret = FAILED;
}
}


RemoveCompilerStages(graph_id);
RemoveGraphCount(graph_id);
RemoveAddGraphCondition(graph_id);


GE_CHK_STATUS_RET(ret, "[GraphManager:] Remove graph failed, graph_id=%u.", graph_id);
GELOGI("[GraphManager] remove graph success, graph_id=%u.", graph_id);
@@ -2588,6 +2678,7 @@ void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_gr
Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) {
GELOGI("[LoadGraphAsync] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId());
if (options_.run_graph_flag && ge_root_model != nullptr) {
ge_root_model->SetTrainFlag(GetTrainFlag());
// synchronization run graph with model
ModelIdInfo model_id_info;
bool is_unknown_shape = false;
@@ -2604,9 +2695,9 @@ Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const G
}
}
GE_TIMESTAMP_START(LoadGraph);
GE_CHECK_NOTNULL(graph_node->graph_run_async_listener_);
Status ret =
GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, graph_node->graph_run_async_listener_);
auto listener = MakeShared<RunAsyncListener>();
GE_CHECK_NOTNULL(listener);
Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, listener);
GE_TIMESTAMP_EVENT_END(LoadGraph, "GraphManager::LoadGraphAsync");
if (ret != SUCCESS) {
GELOGE(ret, "[LoadGraphAsync] LoadGraphAsync Failed");
@@ -2620,6 +2711,52 @@ Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const G
return SUCCESS;
}


void GraphManager::ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph_node,
const std::vector<uint32_t> &model_ids, uint32_t graph_id, uint64_t session_id) {
rtError_t rt_ret = rtSetDevice(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, when GraphManager %s",
GetContext().DeviceId(), __FUNCTION__);
GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, graphId=%u.", graph_id);
return;
}
for (auto model_id : model_ids) {
uint64_t max_memory_size = 0;
Status result = GraphLoader::GetMaxUsedMemory(model_id, max_memory_size);
if (result != SUCCESS) {
continue;
}
GELOGI("CheckAndReleaseMemory try to UnloadGraph[%u], model[%u] which MaxUsedMemory[%lu].", graph_id, model_id,
max_memory_size);
if (model_ids.size() > 1) {
result = ge_model->GetSessionId(model_id, session_id);
if (result != SUCCESS) {
GELOGW("[GraphManager:] get session failed when dynamic memory, modelId=%u, graphId=%u.", model_id,
graph_id);
continue;
}
}
result = GraphLoader::DestroyAicpuKernel(session_id, model_id, 0);
if (result != SUCCESS) {
GELOGW("[GraphManager:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id,
graph_id);
}
result = GraphLoader::UnloadModel(model_id);
if (result != SUCCESS) {
GELOGW("[GraphManager:] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id);
}
GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success.", graph_id, model_id);
}
graph_node->SetLoadFlag(false);
rt_ret = rtDeviceReset(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, when GraphManager %s",
GetContext().DeviceId(), __FUNCTION__);
GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, graphId=%u.", graph_id);
return;
}
}

Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node) {
GELOGI("CheckAndReleaseMemory graph_id[%u]", graph_node->GetGraphId());
int64_t value = 0;
@@ -2665,6 +2802,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra
continue;
}
auto model_id = model->GetModelId();
auto model_ids = model->GetAllModelId();
// unload model not release
bool is_unknown_shape = false;
GE_CHK_STATUS_RET(model->CheckIsUnknownShape(is_unknown_shape));
@@ -2677,38 +2815,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra
GELOGI("CheckAndReleaseMemory graph[%u] has not been loaded.", graph_id);
continue;
}
uint64_t max_memory_size = 0;
result = GraphLoader::GetMaxUsedMemory(model_id, max_memory_size);
if (result != SUCCESS) {
continue;
}
GELOGI("CheckAndReleaseMemory try to UnloadGraph[%u], model[%u] which MaxUsedMemory[%lu].", graph_id, model_id,
max_memory_size);
rtError_t rt_ret = rtSetDevice(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u",
GetContext().DeviceId());
GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.", model_id, graph_id);
continue;
}
result = GraphLoader::DestroyAicpuKernel(session_id, model_id, 0);
if (result != SUCCESS) {
GELOGW("[GraphManager:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id,
graph_id);
}
result = GraphLoader::UnloadModel(model_id);
if (result != SUCCESS) {
GELOGW("[GraphManager:] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id);
}
rt_ret = rtDeviceReset(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u",
GetContext().DeviceId());
GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, modelId=%u, graphId=%u.", model_id, graph_id);
continue;
}
it.second->SetLoadFlag(false);
GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success and set LoadFlag to false.", graph_id, model_id);
ReleaseMemory(ge_model, it.second, model_ids, graph_id, session_id);
}


return SUCCESS;
@@ -2849,6 +2956,38 @@ void GraphManager::ConstructGeInput(const vector<InputTensorInfo> &inputs, vecto
}
}


Status GraphManager::CheckIncreBuildAndPreRun(GraphManager *graph_manager, const PreRunArgs &args,
GraphNodePtr &graph_node, GeRootModelPtr &ge_root_model) {
if (!graph_manager->IsGraphNeedBuild(graph_node)) {
ge_root_model = graph_node->GetGeRootModel();
return SUCCESS;
}
if (graph_node->GetBuildFlag()) {
ReturnError(graph_manager, args.callback, PARAM_INVALID,
"The graph " + std::to_string(graph_node->GetGraphId()) +
" need to re-build, you should remove it"
" from GE first, then AddGraph again and rebuild it.");
graph_node->Unlock();
return PARAM_INVALID;
}
// check need incre build.
GeModelPtr ge_model = nullptr;
if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) {
std::vector<GeTensor> ge_inputs;
ConstructGeInput(args.input_tensor, ge_inputs);
Status ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id);
// release rts generate context
RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId());
if (ret != SUCCESS) {
ReturnError(graph_manager, args.callback, ret, "PreRun Failed.");
return ret;
}
}
graph_node->SetBuildFlag(true);
graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId());
return SUCCESS;
}

void GraphManager::PreRunThread(GraphManager *graph_manager) {
if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) {
GELOGW("Set thread name failed.");
@@ -2861,7 +3000,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) {
continue;
}


GELOGI("A new loop start.");
GELOGI("[PreRunThread] A new loop start, graph_id:%u.", args.graph_id);


ErrorManager::GetInstance().SetErrorContext(args.error_context);
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
@@ -2877,7 +3016,24 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) {
"[RunGraph] graph not exist, graph_id=" + std::to_string(args.graph_id));
return;
}

// more than one graph may own the same graph_id
uint32_t count = 0;
if (graph_manager->GetGraphCount(args.graph_id, count) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed.", args.graph_id);
return;
}
// Avoid repeated prerun for graphs owning the same graph_id in online inference concurrency
if (count > 1 && graph_node->GetBuildFlag()) {
graph_node->Lock();
GELOGD("Avoid repeated prerun, graph_id:%u.", args.graph_id);
// In the online inference concurrency scenario, graph_node is allowed to be locked 'count' times
graph_node->SetSemSize(count);
graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context,
args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback }));
GELOGI("[PreRunThread] Loop end. Start to run with cached build model.");
continue;
}
// Cannot be put ahead of the repeated-prerun judgement above
graph_node->Lock();


if (graph_node->GetRunFlag()) {
@@ -2909,46 +3065,24 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) {
// it will not execute graph preprocess, optimize, partition, build if the graph has been built successfully.
GELOGI("Start for run graph async.");
GeRootModelPtr ge_root_model = nullptr;
if (graph_manager->IsGraphNeedBuild(graph_node)) {
if (graph_node->GetBuildFlag()) {
ReturnError(graph_manager, args.callback, PARAM_INVALID,
"The graph " + std::to_string(graph_node->GetGraphId()) +
" need to re-build, you should remove it"
" from GE first, then AddGraph again and rebuild it.");
ret = CheckIncreBuildAndPreRun(graph_manager, args, graph_node, ge_root_model);
if (ret != SUCCESS) {
graph_node->SetRunFlag(false);
if (!ge::Analyzer::GetInstance()->IsEnableNetAnalyzeDebug()) {
ReturnError(graph_manager, args.callback, ret, "CheckIncreBuildAndPreRun Failed, thread exit..");
graph_node->Unlock();
return;
} else {
ReturnError(graph_manager, graph_node, args.callback, ret,
"CheckIncreBuildAndPreRun Failed, keep geop continue!");
graph_node->Unlock();
continue;
}

// check need incre build.
GeModelPtr ge_model = nullptr;
if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) {
std::vector<GeTensor> ge_inputs;
ConstructGeInput(args.input_tensor, ge_inputs);
ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id);
// release rts generate context
RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId());
if (ret != SUCCESS) {
graph_node->SetRunFlag(false);
if (!ge::Analyzer::GetInstance()->IsEnableNetAnalyzeDebug()) {
ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit..");
graph_node->Unlock();
return;
} else {
ReturnError(graph_manager, graph_node, args.callback, ret, "PreRun Failed, keep geop continue!");
graph_node->Unlock();
continue;
}
}
}
graph_node->SetBuildFlag(true);
graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId());
} else {
ge_root_model = graph_node->GetGeRootModel();
}

graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context,
args.input_tensor, ge_root_model, GetThreadLocalContext(), args.callback }));
GELOGI("Loop end.");
GELOGI("[PreRunThread] Loop end.");
}
}


@@ -3051,16 +3185,13 @@ void GraphManager::RunThread(GraphManager *graph_manager) {
continue;
}


GELOGI("A new loop start.");
GELOGI("[RunThread] A new loop start, graph_id:%u.", args.graph_id);


ErrorManager::GetInstance().SetErrorContext(args.error_context);
GetContext().SetSessionId(args.session_id);
GetThreadLocalContext() = args.context;
graph_manager->UpdateLocalOmgContext(args.graph_id);


if (args.graph_node->graph_run_async_listener_ != nullptr) {
args.graph_node->graph_run_async_listener_->SetCallback(args.callback);
}
Status ret;
// parse inputs.dims to vector<vector<uint64_t>> dynamic_dims
ret = graph_manager->ParseInputsDims(args.input_tensor);
@@ -3070,8 +3201,10 @@ void GraphManager::RunThread(GraphManager *graph_manager) {
return;
}


args.graph_node->UpdateLoadFlag();
if (!args.graph_node->GetLoadFlag()) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelLoad, ErrorMessage::kModelLoad);
args.ge_root_model->SetTrainFlag(graph_manager->GetTrainFlag());
ret = graph_manager->LoadGraphAsync(args.ge_root_model, args.graph_node);
if (ret != SUCCESS || args.ge_root_model == nullptr) {
StopQueue(graph_manager);
@@ -3079,6 +3212,10 @@ void GraphManager::RunThread(GraphManager *graph_manager) {
args.graph_node->Unlock();
return;
}
// control the times of graph loading in multi-thread scenario
args.graph_node->DecreaseLoadCount();
args.graph_node->IncreaseLoadRecord();

args.graph_node->SetLoadFlag(true);
GELOGI("LoadGraph[%u], model[%u] success and set LoadFlag to true.", args.graph_node->GetGraphId(),
args.ge_root_model->GetModelId());
@@ -3093,9 +3230,9 @@ void GraphManager::RunThread(GraphManager *graph_manager) {
graph_manager->graph_executor_.SetTrainFlag(graph_manager->options_.train_graph_flag);
}


args.graph_node->SetRunFlag(false);
ret = graph_manager->graph_executor_.ExecuteGraphAsync(args.graph_id, args.graph_node->GetGeRootModel(),
args.input_tensor);
args.input_tensor, args.callback);
args.graph_node->SetRunFlag(false);
if (ret != SUCCESS) {
ReturnError(graph_manager, args.callback, ret, "ExecuteGraphAsync failed, thread exit.");
args.graph_node->Unlock();
@@ -3546,4 +3683,49 @@ void GraphManager::RemoveCompilerStages(GraphId graph_id) {
std::lock_guard<std::mutex> lock(member_mutex_);
compiler_stages_.erase(graph_id);
}

void GraphManager::IncreaseGraphCount(GraphId graph_id) {
std::lock_guard<std::mutex> lock(graph_count_mutex_);
auto it = graph_count_.find(graph_id);
if (it == graph_count_.end()) {
graph_count_.insert({graph_id, kInitGraphCount});
GELOGD("After increaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]);
} else {
++graph_count_[graph_id];
GELOGD("After increaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]);
}
}

void GraphManager::RemoveGraphCount(GraphId graph_id) {
std::lock_guard<std::mutex> lock(graph_count_mutex_);
auto it = graph_count_.find(graph_id);
if (it == graph_count_.end()) {
GELOGW("Graph of id: %u has not been added, count cannot be decreased.", graph_id);
} else {
GELOGD("RemoveGraphCount success, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]);
graph_count_.erase(it);
}
}

void GraphManager::DecreaseGraphCount(GraphId graph_id) {
std::lock_guard<std::mutex> lock(graph_count_mutex_);
auto it = graph_count_.find(graph_id);
if (it == graph_count_.end()) {
GELOGW("Graph of id: %u has not been added, count cannot be decreased.", graph_id);
} else {
--it->second;
GELOGD("After DecreaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]);
}
}

Status GraphManager::GetGraphCount(GraphId graph_id, uint32_t &count) {
std::lock_guard<std::mutex> lock(graph_count_mutex_);
auto it = graph_count_.find(graph_id);
if (it == graph_count_.end()) {
GELOGW("Graph [id:%u] has not been added.", graph_id);
return FAILED;
}
count = it->second;
return SUCCESS;
}
} // namespace ge
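The reworked AddGraph serializes concurrent AddGraph calls for the same graph_id with a three-state condition (kNotAdded / kStartAdd / kDoneAdded) plus the per-graph count kept in graph_count_. The class below is a simplified, hypothetical guard that shows only the wait/notify part of that protocol; it is not the GE implementation, which additionally reuses the existing GraphNode and bumps its load count.

#include <condition_variable>
#include <cstdint>
#include <map>
#include <mutex>

// Hypothetical guard: the first thread to add a graph id does the real work,
// any concurrent thread with the same id waits until the first one finishes.
class AddGuard {
 public:
  // Returns true if the caller should perform the add itself.
  bool BeginAdd(uint32_t graph_id) {
    std::unique_lock<std::mutex> lk(mu_);
    uint32_t &state = states_[graph_id];  // defaults to kNotAdded (0)
    if (state == kNotAdded) {
      state = kStartAdd;
      return true;                        // this thread owns the add
    }
    cv_.wait(lk, [&] { return states_[graph_id] == kDoneAdded; });
    return false;                         // another thread finished the add
  }
  void FinishAdd(uint32_t graph_id) {
    { std::lock_guard<std::mutex> lk(mu_); states_[graph_id] = kDoneAdded; }
    cv_.notify_all();
  }

 private:
  enum : uint32_t { kNotAdded = 0, kStartAdd = 1, kDoneAdded = 2 };
  std::mutex mu_;
  std::condition_variable cv_;
  std::map<uint32_t, uint32_t> states_;
};

The first caller gets true from BeginAdd, builds and registers the graph, then calls FinishAdd; every concurrent caller with the same id blocks in BeginAdd until then and simply reuses the graph that was already added.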

ge/graph/manager/graph_manager.h (+42, -0)

@@ -184,6 +184,20 @@ class GraphManager {


Status SaveCheckPointResult(const Graph &graph, const std::vector<Tensor> &outputs, map<string, Tensor> &var_results);


void RemoveGraphCount(GraphId graph_id);

void IncreaseGraphCount(GraphId graph_id);

void DecreaseGraphCount(GraphId graph_id);

Status GetGraphCount(GraphId graph_id, uint32_t &count);

void SetAddGraphCondition(GraphId graph_id, uint32_t cond);

uint32_t GetAddGraphCondition(GraphId graph_id);

void RemoveAddGraphCondition(GraphId graph_id);

private:
struct CompilerStages {
GraphPrepare preparer;
@@ -381,6 +395,24 @@ class GraphManager {
CompilerStages &GetCompilerStages(GraphId graph_id);
void RemoveCompilerStages(GraphId graph_id);


static Status CheckIncreBuildAndPreRun(GraphManager *graph_manager, const PreRunArgs &args, GraphNodePtr &graph_node,
GeRootModelPtr &ge_root_model);

void ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph_node, const std::vector<uint32_t> &model_ids,
uint32_t graph_id, uint64_t session_id);

Status CheckRepeatAdd(uint32_t graph_id, bool &is_added);

Status NotifyWaittingGraph(uint32_t graph_id);

Status CreateGraphNode(uint32_t graph_id, const Graph &graph, const std::map<std::string, std::string> &options);

Status SetStagesOptions(uint32_t graph_id, const GraphManagerOptions &options);

Status UnloadModel(GeRootModelPtr ge_root_model, uint32_t graph_id);

void SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id);

std::atomic_bool thread_run_flag_;
BlockingQueue<PreRunArgs> prerun_args_q_{};
BlockingQueue<RunArgs> run_args_q_{};
@@ -416,6 +448,16 @@ class GraphManager {


std::mutex member_mutex_;
std::mutex unload_model_mutex_;
// avoid repeatedly adding the same graph (with the same graph id)
std::mutex add_graph_mutex_;
std::mutex add_graph_cond_mutex_;
std::condition_variable add_graph_cv_;

std::map<GraphId, uint32_t> graph_id_to_add_graph_cond_;
// used in the multi-thread online-infer scenario
std::set<GraphId> to_be_deleted_graphs_;
std::map<GraphId, uint32_t> graph_count_;
std::mutex graph_count_mutex_;
};
} // namespace ge
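
The new add_graph_cond_mutex_, add_graph_cv_ and graph_id_to_add_graph_cond_ members suggest how concurrent AddGraph calls on the same graph id are coordinated: the first caller marks the id as being added, later callers block on the condition variable until the first one finishes. The sketch below only illustrates that pattern; it is not the committed CheckRepeatAdd implementation, and the state values kNotAdded/kStartAdd/kDoneAdded are taken from the unit test at the end of this change.

// Illustrative only: the shape of the wait-for-adder pattern, written as if it
// were a GraphManager member. The committed CheckRepeatAdd may differ in detail.
Status GraphManager::CheckRepeatAddSketch(GraphId graph_id, bool &is_added) {
  std::unique_lock<std::mutex> lock(add_graph_cond_mutex_);
  auto it = graph_id_to_add_graph_cond_.find(graph_id);
  uint32_t cond = (it == graph_id_to_add_graph_cond_.end()) ? kNotAdded : it->second;
  if (cond == kStartAdd) {
    // Another thread is adding the same graph id: wait until it reports kDoneAdded.
    add_graph_cv_.wait(lock, [this, graph_id] {
      auto iter = graph_id_to_add_graph_cond_.find(graph_id);
      return (iter != graph_id_to_add_graph_cond_.end()) && (iter->second == kDoneAdded);
    });
    is_added = true;
    return SUCCESS;
  }
  is_added = (cond == kDoneAdded);
  return SUCCESS;
}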




+ 9
- 0
ge/graph/manager/graph_manager_utils.cc View File

@@ -60,6 +60,15 @@ void GraphNode::Unlock() {
sem_.Pop(unused);
}


void GraphNode::IncreaseLoadCount() {
std::unique_lock<std::mutex> lock(load_count_mu_);
if (load_record_ == kMaxLoadNum) {
GELOGW("Reach the maximum of load_count:%u", kMaxLoadNum);
return;
}
++load_count_;
}

SubGraphInfo::SubGraphInfo() : subgraph_ptr_(nullptr), ge_model_ptr_(nullptr), malloc_flag_(false) {}


SubGraphInfo::~SubGraphInfo() {


+ 16
- 0
ge/graph/manager/graph_manager_utils.h View File

@@ -55,6 +55,7 @@ using ConstGraphPtr = std::shared_ptr<const ge::Graph>;
using GraphPtr = std::shared_ptr<ge::Graph>;


const uint64_t INVALID_SESSION_ID = 0xffffffffffffffffULL;
const uint32_t kMaxLoadNum = 8;


struct ModelIdInfo {
uint32_t model_id{INVALID_MODEL_ID};
@@ -162,6 +163,8 @@ class GraphNode {
bool GetBuildFlag() const { return build_flag_; }
void SetBuildFlag(bool buildFlag) { build_flag_ = buildFlag; }
bool GetLoadFlag() const { return load_flag_; }
// allow repeatedly loading a graph with the same graph id
void UpdateLoadFlag() { load_flag_ = load_count_ == 0 || load_record_ >= kMaxLoadNum; }
void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; }
void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; }
GeModelPtr GetGeModel() const { return ge_model_; }
@@ -172,6 +175,13 @@ class GraphNode {
void Lock();
void Unlock();


void SetSemSize(uint32_t size) { sem_.SetMaxSize(size); }

uint32_t GetLoadCount() const { return load_count_; }
void IncreaseLoadCount();
void DecreaseLoadCount() { --load_count_; }
void IncreaseLoadRecord() { ++load_record_; }

// run graph asynchronous listener
std::shared_ptr<RunAsyncListener> graph_run_async_listener_;


@@ -184,11 +194,17 @@ class GraphNode {
GraphPtr graph_;
ComputeGraphPtr compute_graph_;
bool build_flag_;
// load_flag_ is true if more than one model has been loaded
bool load_flag_;
bool async_;
GeModelPtr ge_model_;
GeRootModelPtr ge_root_model_;
BlockingQueue<uint8_t> sem_;
// consistent with the graph_count of the same graph_id in graph_manager
uint32_t load_count_ = 0;
// total number of times a graph with the same graph_id has been loaded.
uint32_t load_record_ = 0;
std::mutex load_count_mu_;
};


using GraphNodePtr = std::shared_ptr<GraphNode>;
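
Read together, load_count_ tracks how many loaded copies of this graph are currently alive, load_record_ counts how many times the graph has ever been loaded (bounded by kMaxLoadNum), and UpdateLoadFlag() recomputes load_flag_ from the two. A rough sketch of how a loader would be expected to drive these counters; the surrounding load/unload hooks are hypothetical, only the GraphNode calls are real:

// Hypothetical hooks around loading/unloading one more copy of the graph.
void OnModelLoaded(const GraphNodePtr &graph_node) {
  graph_node->IncreaseLoadCount();   // one more live model for this graph id
  graph_node->IncreaseLoadRecord();  // total loads ever, capped at kMaxLoadNum
  graph_node->SetLoadFlag(true);
}

void OnModelUnloaded(const GraphNodePtr &graph_node) {
  graph_node->DecreaseLoadCount();
  // load_flag_ is recomputed as (load_count_ == 0 || load_record_ >= kMaxLoadNum).
  graph_node->UpdateLoadFlag();
}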


+ 6
- 1
ge/hybrid/executor/hybrid_model_async_executor.cc View File

@@ -144,8 +144,12 @@ Status HybridModelAsyncExecutor::RunInternal() {
GE_MAKE_GUARD(not_used_var, [&] { GE_CHK_RT(rtDeviceReset(device_id)); });


while (run_flag_) {
// The model has not actually started running until data is received
SetRunningFlag(false);
std::shared_ptr<InputDataWrapper> data_wrapper;
Status ret = data_inputer_->Pop(data_wrapper);
// The model actually starts running once data has been popped
SetRunningFlag(true);
if (data_wrapper == nullptr || ret != SUCCESS) {
GELOGI("data_wrapper is null!, ret = %u", ret);
continue;
@@ -185,7 +189,8 @@ Status HybridModelAsyncExecutor::RunInternal() {


RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[RunInternal] [iteration = %d] End", iterator_count_);
iterator_count_++;
GELOGI("run iterator count is %lu", iterator_count_);
SetRunningFlag(false);
GELOGI("run iterator count is %lu, model_id:%u", iterator_count_, model_id_);
}


CsaInteract::GetInstance().WriteInternalErrorCode();


+ 8
- 0
ge/hybrid/executor/hybrid_model_async_executor.h View File

@@ -55,6 +55,12 @@ class HybridModelAsyncExecutor {


Status EnqueueData(const std::shared_ptr<InputDataWrapper> &data);


uint32_t GetDataInputerSize() { return data_inputer_->Size(); }

bool GetRunningFlag() const { return running_flag_; }

void SetRunningFlag(bool flag) { running_flag_ = flag; }

private:
Status InitInputDesc();


@@ -84,6 +90,8 @@ class HybridModelAsyncExecutor {
uint32_t device_id_ = 0U;
uint32_t model_id_ = 0U;
std::atomic_bool run_flag_;
// check whether model is running with data
bool running_flag_ = false;
std::unique_ptr<DataInputer> data_inputer_;
std::unique_ptr<HybridModelExecutor> executor_;
std::unique_ptr<HybridModelPipelineExecutor> pipe_executor_;
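
running_flag_ and GetDataInputerSize() expose how busy each loaded executor is: whether an iteration is currently in flight, and how many inputs are still queued. A hedged sketch of how a scheduler could use the queue depth to pick the least-loaded of several loaded copies of the same graph, mirroring what the GetExecuteModelId unit test below checks for the non-hybrid path (the helper itself is hypothetical):

// Hypothetical helper: choose the model whose input queue is shortest.
uint32_t PickLeastBusyModel(const std::vector<uint32_t> &model_ids,
                            const std::map<uint32_t, HybridModelAsyncExecutor *> &executors) {
  uint32_t best_id = UINT32_MAX;
  uint32_t best_backlog = UINT32_MAX;
  for (const uint32_t model_id : model_ids) {
    auto it = executors.find(model_id);
    if (it == executors.end() || it->second == nullptr) {
      continue;
    }
    const uint32_t backlog = it->second->GetDataInputerSize();
    if (backlog < best_backlog) {
      best_backlog = backlog;
      best_id = model_id;
    }
  }
  return best_id;  // UINT32_MAX means no usable model was found
}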


+ 23
- 0
ge/hybrid/hybrid_davinci_model.cc View File

@@ -19,6 +19,7 @@
#include "hybrid/model/hybrid_model.h" #include "hybrid/model/hybrid_model.h"
#include "hybrid/executor/hybrid_model_async_executor.h" #include "hybrid/executor/hybrid_model_async_executor.h"
#include "hybrid/node_executor/node_executor.h" #include "hybrid/node_executor/node_executor.h"
#include "graph/manager/graph_manager_utils.h"


namespace ge {
namespace hybrid {
@@ -108,6 +109,17 @@ class HybridDavinciModel::Impl {
model_.SetModelDescVersion(is_new_model_desc);
}


uint32_t GetDataInputerSize() { return executor_.GetDataInputerSize(); }

bool GetRunningFlag() const { return executor_.GetRunningFlag(); }

Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback) {
auto listener = dynamic_cast<RunAsyncListener *>(listener_.get());
GE_CHECK_NOTNULL(listener);
listener->SetCallback(callback);
return SUCCESS;
}

private:
std::shared_ptr<ModelListener> listener_;
HybridModel model_;
@@ -222,5 +234,16 @@ uint64_t HybridDavinciModel::GetSessionId() {
GE_CHECK_NOTNULL(impl_);
return impl_->GetSessionId();
}

uint32_t HybridDavinciModel::GetDataInputerSize() {
GE_CHECK_NOTNULL(impl_);
return impl_->GetDataInputerSize();
}

bool HybridDavinciModel::GetRunningFlag() const { return impl_->GetRunningFlag(); }

Status HybridDavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) {
return impl_->SetRunAsyncListenerCallback(callback);
}
} // namespace hybrid
} // namespace ge
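
SetRunAsyncListenerCallback lets each async run re-bind its own callback on the shared RunAsyncListener before the next batch of inputs is fed, which is what allows one loaded hybrid model to serve several callers. A small hedged sketch of a call site (the wrapper function is hypothetical; the three accessors are the ones added above):

// Hypothetical call site: bind this request's callback before dispatching it.
Status BindRequestCallback(hybrid::HybridDavinciModel &model, const RunAsyncCallback &callback) {
  if (model.GetRunningFlag()) {
    GELOGI("Selected model is busy, queued inputs = %u.", model.GetDataInputerSize());
  }
  return model.SetRunAsyncListenerCallback(callback);
}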

+ 6
- 0
ge/hybrid/hybrid_davinci_model.h View File

@@ -74,6 +74,12 @@ class HybridDavinciModel {


void SetModelDescVersion(bool is_new_model_desc);


uint32_t GetDataInputerSize();

bool GetRunningFlag() const;

Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback);

private:
HybridDavinciModel() = default;
class Impl;


+ 12
- 0
ge/hybrid/hybrid_davinci_model_stub.cc View File

@@ -68,6 +68,10 @@ uint64_t HybridDavinciModel::GetSessionId() {
return 0;
}


uint32_t HybridDavinciModel::GetDataInputerSize() {
return 0;
}

Status HybridDavinciModel::GetDynamicBatchInfo(std::vector<std::vector<int64_t>> &batch_info, int32_t &dynamic_type) {
return UNSUPPORTED;
}
@@ -87,5 +91,13 @@ Status HybridDavinciModel::GetInputOutputDescInfo(vector<InputOutputDescInfo> &i


void HybridDavinciModel::SetModelDescVersion(bool is_new_model_desc) {
}

bool HybridDavinciModel::GetRunningFlag() const {
return false;
}

Status HybridDavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) {
return UNSUPPORTED;
}
} // namespace hybrid
} // namespace ge

+ 10
- 0
ge/model/ge_model.cc View File

@@ -85,4 +85,14 @@ ProtoAttrMapHelper GeModel::MutableAttrMap() { return attrs_; }
ConstProtoAttrMapHelper GeModel::GetAttrMap() const {
return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg());
}

Status GeModel::GetSessionId(uint32_t model_id, uint64_t &session_id) const {
auto it = model_id_to_session_id_map_.find(model_id);
if (it != model_id_to_session_id_map_.end()) {
session_id = it->second;
return SUCCESS;
}
GELOGW("No session id were found with model id [%u].", model_id);
return INTERNAL_ERROR;
}
} // namespace ge
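
GeModel now remembers which session each loaded model instance belongs to, so session-scoped resources can be released when that particular instance is unloaded. A minimal sketch of the intended pairing of InsertSessionMap and GetSessionId (the two wrapper functions are hypothetical):

// Hypothetical bookkeeping around load/unload of one model instance.
void RecordModelSession(const GeModelPtr &ge_model, uint32_t model_id, uint64_t session_id) {
  ge_model->InsertSessionMap(model_id, session_id);
}

Status ReleaseModelSession(const GeModelPtr &ge_model, uint32_t model_id) {
  uint64_t session_id = 0;
  if (ge_model->GetSessionId(model_id, session_id) != SUCCESS) {
    return INTERNAL_ERROR;  // nothing was recorded for this model id
  }
  GELOGI("Model %u was loaded under session %lu.", model_id, session_id);
  // release session-scoped resources here (omitted)
  return SUCCESS;
}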

+ 6
- 0
ge/model/ge_model.h View File

@@ -71,6 +71,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder
void SetModelId(uint32_t model_id) { model_id_ = model_id; }
uint32_t GetModelId() const { return model_id_; }


Status GetSessionId(uint32_t model_id, uint64_t &session_id) const;
void InsertSessionMap(uint32_t model_id, uint64_t session_id) {
model_id_to_session_id_map_.insert({model_id, session_id});
}

protected:
ConstProtoAttrMapHelper GetAttrMap() const override;


@@ -90,6 +95,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder
std::string platform_version_;
uint8_t platform_type_ = {0};
uint32_t model_id_ = INVALID_MODEL_ID;
std::map<uint32_t, uint64_t> model_id_to_session_id_map_;
};
} // namespace ge
using GeModelPtr = std::shared_ptr<ge::GeModel>;


+ 18
- 2
ge/model/ge_root_model.h View File

@@ -32,15 +32,31 @@ class GeRootModel {
return subgraph_instance_name_to_model_;
};


const ComputeGraphPtr &GetRootGraph() const { return root_graph_; };
void SetModelId(uint32_t model_id) { model_id_ = model_id; }
const ComputeGraphPtr &GetRootGraph() const { return root_graph_; }
void SetModelId(uint32_t model_id) {
model_id_ = model_id;
// cached for later removal
model_ids_.emplace_back(model_id);
}
uint32_t GetModelId() const { return model_id_; }

std::vector<uint32_t> GetAllModelId() const { return model_ids_; }

Status CheckIsUnknownShape(bool &is_dynamic_shape);

void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; }

void SetTrainFlag(bool flag) { train_flag_ = flag; }

bool GetTrainFlag() const { return train_flag_; }

private:
ComputeGraphPtr root_graph_ = nullptr;
std::map<std::string, GeModelPtr> subgraph_instance_name_to_model_;
uint32_t model_id_ = 0;
// In the multi-thread online scenario, the same graph can own different davinci_models for concurrency
std::vector<uint32_t> model_ids_;
bool train_flag_ = false;
};
} // namespace ge
using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>;
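
Because the same GeRootModel can now be loaded several times for concurrency, every model id it has ever been given is cached in model_ids_, and teardown has to walk the whole list rather than just the latest id. A hedged sketch, assuming ModelManager::Unload as the unload entry point (the wrapper function is hypothetical):

// Hypothetical teardown: unload every davinci model loaded for this root model.
Status UnloadAllModels(const GeRootModelPtr &ge_root_model) {
  auto model_manager = ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  for (const uint32_t model_id : ge_root_model->GetAllModelId()) {
    GE_CHK_STATUS_RET(model_manager->Unload(model_id), "Unload model %u failed.", model_id);
  }
  return SUCCESS;
}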


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 1e88df1d6bfe60faae0aa9fa2d87f273b793aeb0
Subproject commit fcebf37d7428caf4e0bd6e6c3a4f8143f6eac8b7

+ 3
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -593,6 +593,7 @@ set(SINGLE_OP_SRC_FILES
"${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_executor.cc" "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_executor.cc"
"${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_async_executor.cc" "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_async_executor.cc"
"${GE_CODE_DIR}/ge/hybrid/executor/hybrid_execution_context.cc" "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_execution_context.cc"
"${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_pipeline_executor.cc"
"${GE_CODE_DIR}/ge/hybrid/executor/subgraph_context.cc" "${GE_CODE_DIR}/ge/hybrid/executor/subgraph_context.cc"
"${GE_CODE_DIR}/ge/hybrid/executor/subgraph_executor.cc" "${GE_CODE_DIR}/ge/hybrid/executor/subgraph_executor.cc"
"${GE_CODE_DIR}/ge/hybrid/executor/worker/task_compile_engine.cc" "${GE_CODE_DIR}/ge/hybrid/executor/worker/task_compile_engine.cc"
@@ -780,10 +781,12 @@ set(MULTI_PARTS_TEST_FILES
"graph/build/mem_assigner_unittest.cc" "graph/build/mem_assigner_unittest.cc"
"graph/build/task_generator_unittest.cc" "graph/build/task_generator_unittest.cc"
"graph/build/buffer_pool_mem_assigner_unittest.cc" "graph/build/buffer_pool_mem_assigner_unittest.cc"
"graph/execute/graph_execute_unittest.cc"
"graph/preprocess/graph_preprocess_unittest.cc" "graph/preprocess/graph_preprocess_unittest.cc"
"graph/manager/hcom_util_unittest.cc" "graph/manager/hcom_util_unittest.cc"
"graph/manager/graph_caching_allocator_unittest.cc" "graph/manager/graph_caching_allocator_unittest.cc"
"graph/partition/dynamic_shape_partition_unittest.cc" "graph/partition/dynamic_shape_partition_unittest.cc"
"graph/manager/graph_manager_unittest.cc"
"session/omg_omg_unittest.cc" "session/omg_omg_unittest.cc"
) )




+ 129
- 0
tests/ut/ge/graph/execute/graph_execute_unittest.cc View File

@@ -0,0 +1,129 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <memory>

#define protected public
#define private public
#include "graph/execute/graph_execute.h"
#include "graph/load/model_manager/model_manager.h"
#include "graph/load/model_manager/davinci_model.h"
#include "omm/csa_interact.h"
#undef private
#undef public


#include <pthread.h>
#include <algorithm>
#include <future>
#include <set>
#include <sstream>
#include <string>
#include <thread>
#include <future>

using namespace std;
using namespace testing;
using namespace ge;
using namespace domi;

namespace ge {
namespace {
const uint32_t kInvalidModelId = UINT32_MAX;
}

class UtestGraphExecuteTest : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}
};

TEST_F(UtestGraphExecuteTest, get_execute_model_id_invalid) {
GraphExecutor executor;
ComputeGraphPtr graph = MakeShared<ComputeGraph>("test");
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph);
auto model_id = executor.GetExecuteModelId(ge_root_model);
EXPECT_EQ(model_id, kInvalidModelId);
}

TEST_F(UtestGraphExecuteTest, get_execute_model_id_1) {
GraphExecutor executor;
ComputeGraphPtr graph = MakeShared<ComputeGraph>("test");
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph);
auto model_manager = ModelManager::GetInstance();
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, nullptr);
davinci_model1->SetId(1);
model_manager->InsertModel(1, davinci_model1);
ge_root_model->SetModelId(1);
auto model_id = executor.GetExecuteModelId(ge_root_model);
EXPECT_EQ(model_id, 1);
}

TEST_F(UtestGraphExecuteTest, get_execute_model_id_2) {
GraphExecutor executor;
ComputeGraphPtr graph = MakeShared<ComputeGraph>("test");
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph);
auto model_manager = ModelManager::GetInstance();
// model1 with 2 load
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, nullptr);
davinci_model1->SetId(1);
davinci_model1->data_inputer_ = new DataInputer();
auto data = MakeShared<InputDataWrapper>();
davinci_model1->data_inputer_->Push(data);
davinci_model1->data_inputer_->Push(data);
model_manager->InsertModel(1, davinci_model1);
// model 2 with 3 load
shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(1, nullptr);
davinci_model2->SetId(2);
davinci_model2->data_inputer_ = new DataInputer();
davinci_model2->data_inputer_->Push(data);
davinci_model2->data_inputer_->Push(data);
davinci_model2->data_inputer_->Push(data);
model_manager->InsertModel(2, davinci_model2);
// model 3 with 1 load
shared_ptr<DavinciModel> davinci_model3 = MakeShared<DavinciModel>(1, nullptr);
davinci_model3->SetId(3);
davinci_model3->data_inputer_ = new DataInputer();
davinci_model3->data_inputer_->Push(data);
model_manager->InsertModel(3, davinci_model3);

ge_root_model->SetModelId(1);
ge_root_model->SetModelId(2);
ge_root_model->SetModelId(3);

auto model_id = executor.GetExecuteModelId(ge_root_model);
// model 3 is picked for having the fewest loads
EXPECT_EQ(model_id, 3);
}

TEST_F(UtestGraphExecuteTest, test_set_callback) {
GraphExecutor executor;
ComputeGraphPtr graph = MakeShared<ComputeGraph>("test");
// is_unknown_shape_graph_ = false
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph);
RunAsyncCallback callback = [](Status, std::vector<ge::OutputTensorInfo> &) {};

auto model_manager = ModelManager::GetInstance();
auto listener = MakeShared<RunAsyncListener>();
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener);
davinci_model1->SetId(1);
model_manager->InsertModel(1, davinci_model1);
auto status = executor.SetCallback(1, ge_root_model, callback);
EXPECT_EQ(status, SUCCESS);
}
} // namespace ge

+ 375
- 0
tests/ut/ge/graph/manager/graph_manager_unittest.cc View File

@@ -0,0 +1,375 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <memory>
#define protected public
#define private public
#include "graph/manager/graph_manager.h"
#include "graph/load/model_manager/model_manager.h"
#include "graph/load/model_manager/davinci_model.h"
#define const
#include "common/helper/model_cache_helper.h"
#undef const
#include "init/gelib.h"
#undef private
#undef public

#include <pthread.h>
#include <algorithm>
#include <future>
#include <set>
#include <sstream>
#include <string>
#include <thread>
#include <future>

#include "common/math/math_util.h"
#include "common/thread_pool.h"
#include "common/dump/dump_manager.h"
#include "analyzer/analyzer.h"
#include "graph/common/ge_call_wrapper.h"
#include "graph/common/local_context.h"
#include "graph/common/transop_util.h"
#include "graph/ge_context.h"
#include "graph/ge_global_options.h"
#include "graph/manager/util/rt_context_util.h"
#include "graph/partition/dynamic_shape_partition.h"
#include "graph/passes/enter_pass.h"
#include "graph/partition/stage_partition.h"
#include "graph/passes/addn_pass.h"
#include "graph/passes/bitcast_pass.h"
#include "graph/passes/assign_remove_pass.h"
#include "graph/passes/inplace_support_check_pass.h"
#include "graph/passes/atomic_addr_clean_pass.h"
#include "graph/passes/attach_stream_label_pass.h"
#include "graph/passes/cast_remove_pass.h"
#include "graph/passes/common_subexpression_elimination_pass.h"
#include "graph/passes/compile_nodes_pass.h"
#include "graph/passes/cond_remove_pass.h"
#include "graph/passes/constant_folding_pass.h"
#include "graph/passes/constant_fuse_same_pass.h"
#include "graph/passes/control_trigger_pass.h"
#include "graph/passes/ctrl_edge_transfer_pass.h"
#include "graph/passes/dimension_adjust_pass.h"
#include "graph/passes/dimension_compute_pass.h"
#include "graph/passes/flow_ctrl_pass.h"
#include "graph/passes/fuse_data_nodes_with_common_input_pass.h"
#include "graph/passes/identity_pass.h"
#include "graph/passes/input_output_connection_identify_pass.h"
#include "graph/passes/iterator_op_pass.h"
#include "graph/passes/link_gen_mask_nodes_pass.h"
#include "graph/passes/mark_graph_unknown_status_pass.h"
#include "graph/passes/merge_pass.h"
#include "graph/passes/merge_input_memcpy_pass.h"
#include "graph/passes/merge_to_stream_merge_pass.h"
#include "graph/passes/multi_batch_pass.h"
#include "graph/passes/next_iteration_pass.h"
#include "graph/passes/permute_pass.h"
#include "graph/passes/prune_pass.h"
#include "graph/passes/ref_identity_delete_op_pass.h"
#include "graph/passes/remove_same_const_pass.h"
#include "graph/passes/reshape_recovery_pass.h"
#include "graph/passes/reshape_remove_pass.h"
#include "graph/passes/same_transdata_breadth_fusion_pass.h"
#include "graph/passes/subgraph_pass.h"
#include "graph/passes/switch_data_edges_bypass.h"
#include "graph/passes/switch_dead_branch_elimination.h"
#include "graph/passes/switch_logic_remove_pass.h"
#include "graph/passes/switch_to_stream_switch_pass.h"
#include "graph/passes/transop_breadth_fusion_pass.h"
#include "graph/passes/transop_nearby_allreduce_fusion_pass.h"
#include "graph/passes/transop_symmetry_elimination_pass.h"
#include "graph/passes/transop_without_reshape_fusion_pass.h"
#include "graph/passes/transpose_transdata_pass.h"
#include "graph/passes/useless_control_out_remove_pass.h"
#include "graph/passes/variable_op_pass.h"
#include "graph/passes/variable_ref_delete_op_pass.h"
#include "graph/passes/variable_ref_useless_control_out_delete_pass.h"
#include "graph/passes/end_of_sequence_add_control_pass.h"
#include "graph/passes/subexpression_migration_pass.h"
#include "graph/passes/subgraph_const_migration_pass.h"
#include "graph/passes/unused_args_clean_pass.h"
#include "graph/passes/global_step_insert_pass.h"
#include "graph/passes/memcpy_addr_async_pass.h"
#include "graph/passes/hccl_continuous_memcpy_pass.h"
#include "graph/build/label_allocator.h"
#include "graph/utils/tensor_adapter.h"
#include "inc/pass_manager.h"
#include "ir_build/atc_ir_common.h"
#include "graph/common/local_context.h"
#include "graph/common/omg_util.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "register/custom_pass_helper.h"
#include "graph/ops_stub.h"

using namespace std;
using namespace testing;
using namespace ge;
using namespace domi;

namespace {
const uint32_t kNotAdded = 0;
const uint32_t kStartAdd = 1;
const uint32_t kDoneAdded = 2;
}
class UtestGraphManagerTest : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}
};

void CreateGraph(Graph &graph) {
TensorDesc desc(ge::Shape({1, 3, 224, 224}));
uint32_t size = desc.GetShape().GetShapeSize();
desc.SetSize(size);
auto data = op::Data("Data").set_attr_index(0);
data.update_input_desc_data(desc);
data.update_output_desc_out(desc);

auto flatten = op::Flatten("Flatten").set_input_x(data, data.name_out_out());

std::vector<Operator> inputs{data};
std::vector<Operator> outputs{flatten};
std::vector<Operator> targets{flatten};
// Graph graph("test_graph");
graph.SetInputs(inputs).SetOutputs(outputs).SetTargets(targets);
}

TEST_F(UtestGraphManagerTest, set_and_get_add_graph_flag) {
GraphId graph_id = 1;
GraphManager graph_manager;
graph_manager.SetAddGraphCondition(graph_id, 1);
uint32_t res = graph_manager.GetAddGraphCondition(graph_id);
EXPECT_EQ(res, 1);
}

TEST_F(UtestGraphManagerTest, test_add_graph_1) {
GraphId graph_id = 1;
GraphManager graph_manager;
// create graph
Graph graph("test_graph");
CreateGraph(graph);

std::map<std::string, std::string> options;
OmgContext context;
Status status = graph_manager.AddGraph(graph_id, graph, options, context);
EXPECT_EQ(status, ge::SUCCESS);
}

TEST_F(UtestGraphManagerTest, test_add_graph_2) {
GraphId graph_id = 1;
GraphManager graph_manager;
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
graph_manager.AddGraphNode(graph_id, graph_node);
graph_manager.SetAddGraphCondition(graph_id, kDoneAdded);
Graph graph("test_graph");
CreateGraph(graph);
std::map<std::string, std::string> options;
OmgContext context;
Status status = graph_manager.AddGraph(graph_id, graph, options, context);
EXPECT_EQ(status, ge::SUCCESS);
}

TEST_F(UtestGraphManagerTest, test_add_graph_3) {
GraphId graph_id = 1;
GraphManager graph_manager;
Graph graph("test_graph");
CreateGraph(graph);

std::map<std::string, std::string> options;
OmgContext context;

std::future<Status> fut1 = std::async(std::launch::async,
&GraphManager::AddGraph, &graph_manager, graph_id, graph, options, context);
std::future<Status> fut2 = std::async(std::launch::async,
&GraphManager::AddGraph, &graph_manager, graph_id, graph, options, context);
fut1.wait();
fut2.wait();
Status status1 = fut1.get();
Status status2 = fut2.get();
EXPECT_EQ(status1, ge::SUCCESS);
EXPECT_EQ(status2, ge::SUCCESS);
}

TEST_F(UtestGraphManagerTest, test_remove_graph_1) {
GraphId graph_id = 1;
GraphManager graph_manager;
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
Status status = graph_manager.RemoveGraph(graph_id);
EXPECT_EQ(status, ge::GE_GRAPH_GRAPH_NOT_EXIST);
graph_manager.AddGraphNode(graph_id, graph_node);
graph_node->SetRunFlag(true);
status = graph_manager.RemoveGraph(graph_id);
EXPECT_EQ(status, ge::SUCCESS);
}

TEST_F(UtestGraphManagerTest, test_remove_graph_2) {
GraphId graph_id = 1;
GraphManager graph_manager;
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
Graph graph("test_graph");
CreateGraph(graph);
auto compute_graph = GraphUtils::GetComputeGraph(graph);
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph);
auto model_manager = ModelManager::GetInstance();
auto listener = MakeShared<RunAsyncListener>();
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener);
davinci_model1->SetId(1);
shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(2, listener);
davinci_model2->SetId(2);
model_manager->InsertModel(1, davinci_model1);
model_manager->InsertModel(2, davinci_model2);
ge_root_model->SetModelId(1);
ge_root_model->SetModelId(2);
graph_node->SetGeRootModel(ge_root_model);
graph_node->SetLoadFlag(true);
graph_manager.AddGraphNode(graph_id, graph_node);
Status status = graph_manager.RemoveGraph(graph_id);
EXPECT_EQ(status, ge::SUCCESS);
}

TEST_F(UtestGraphManagerTest, test_pre_run_thread) {
GraphManager graph_manager;
graph_manager.thread_run_flag_ = true;

GraphId graph_id = 1;
std::vector<ge::InputTensorInfo> input_tensor;
uint64_t session_id = 0;
ErrorMessage::Context error_context;
GEThreadLocalContext context;
RunAsyncCallback callback;
// PreRunArgs args{graph_id, input_tensor, session_id, error_context, context, callback};
bool ret = graph_manager.prerun_args_q_.Push({graph_id, input_tensor, session_id, error_context, context, callback});
EXPECT_EQ(ret, true);

GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
graph_manager.AddGraphNode(graph_id, graph_node);
graph_manager.PreRunThread(&graph_manager);
// end with failed
}

TEST_F(UtestGraphManagerTest, test_pre_run_thread_2) {
GraphManager graph_manager;
graph_manager.thread_run_flag_ = true;

GraphId graph_id = 1;
GraphNodePtr graph_node_1 = MakeShared<ge::GraphNode>(graph_id);
graph_manager.AddGraphNode(graph_id, graph_node_1);
graph_manager.IncreaseGraphCount(graph_id);
graph_manager.IncreaseGraphCount(graph_id);
graph_node_1->SetBuildFlag(true);
std::vector<ge::InputTensorInfo> input_tensor;
uint64_t session_id = 0;
ErrorMessage::Context error_context;
GEThreadLocalContext context;
RunAsyncCallback callback;
// PreRunArgs args{graph_id, input_tensor, session_id, error_context, context, callback};
bool ret = graph_manager.prerun_args_q_.Push({graph_id, input_tensor, session_id, error_context, context, callback});
EXPECT_EQ(ret, true);
graph_id = 2;
GraphNodePtr graph_node_2 = MakeShared<ge::GraphNode>(graph_id);
graph_manager.AddGraphNode(graph_id, graph_node_2);
ret = graph_manager.prerun_args_q_.Push({graph_id, input_tensor, session_id, error_context, context, callback});
EXPECT_EQ(ret, true);
graph_manager.PreRunThread(&graph_manager);
// end with failed
}

TEST_F(UtestGraphManagerTest, test_check_and_release_memory) {
GraphManager graph_manager;
GeModelPtr ge_model = make_shared<GeModel>();
int64_t memory_size = 25 * 1024UL * 1024UL * 1024UL;
int64_t weight_size = 25 * 1024UL * 1024UL * 1024UL;
uint64_t session_id = 0;
ge::AttrUtils::SetInt(ge_model, ATTR_MODEL_MEMORY_SIZE, memory_size);
ge::AttrUtils::SetInt(ge_model, ATTR_MODEL_WEIGHT_SIZE, weight_size);
ge::AttrUtils::SetInt(ge_model, MODEL_ATTR_SESSION_ID, session_id);

GraphId graph_id = 1;
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
graph_manager.AddGraphNode(graph_id, graph_node);
graph_manager.IncreaseGraphCount(graph_id);
graph_manager.IncreaseGraphCount(graph_id);

auto model_manager = ModelManager::GetInstance();
auto listener = MakeShared<RunAsyncListener>();
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener);
davinci_model1->SetId(1);
shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(2, listener);
davinci_model2->SetId(2);
model_manager->InsertModel(1, davinci_model1);
model_manager->InsertModel(2, davinci_model2);
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
bool is_dynamic_shape = false;
(void)AttrUtils::GetBool(compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape);
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph);
ge_root_model->SetModelId(1);
ge_root_model->SetModelId(2);
graph_node->SetGeRootModel(ge_root_model);
graph_node->SetLoadFlag(true);
Status status = graph_manager.CheckAndReleaseMemory(ge_model, graph_node);
EXPECT_EQ(status, ge::SUCCESS);
}

TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_1) {
// no need to build
GraphId graph_id = 1;
GraphManager graph_manager;
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph);
GraphManager::PreRunArgs arg;
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
graph_node->SetBuildFlag(true);
Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model);
EXPECT_EQ(status, ge::SUCCESS);
}

TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_2) {
// need build while buildflag is true, var format changed
GraphId graph_id = 1;
GraphManager graph_manager;
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph);
GraphManager::PreRunArgs arg;
arg.callback = [](Status, std::vector<ge::OutputTensorInfo> &) {};
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
graph_node->SetBuildFlag(true);
graph_node->Lock();
graph_manager.var_acc_ctrl_.graph_ids_need_rebuild_.insert(graph_id);
Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model);
EXPECT_EQ(status, ge::PARAM_INVALID);
}

TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_3) {
// need build while buildflag is false, var format unchanged
GraphId graph_id = 1;
GraphManager graph_manager;
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph);
GraphManager::PreRunArgs arg;
arg.callback = [](Status, std::vector<ge::OutputTensorInfo> &) {};
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
graph_node->SetBuildFlag(false);
graph_node->Lock();
Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model);
EXPECT_NE(status, ge::SUCCESS);
}
