From: @HW_KK Reviewed-by: @wqtshg Signed-off-by:tags/v1.3.0
@@ -20,9 +20,12 @@ | |||||
#include <string> | #include <string> | ||||
#include "graph/load/model_manager/model_manager.h" | #include "graph/load/model_manager/model_manager.h" | ||||
#include "graph/load/model_manager/davinci_model.h" | |||||
#include "omm/csa_interact.h" | #include "omm/csa_interact.h" | ||||
namespace ge { | namespace ge { | ||||
using Uint32Pair = pair<uint32_t, uint32_t>; | |||||
const uint32_t kInvalidModelId = UINT32_MAX; | |||||
GraphExecutor::GraphExecutor() | GraphExecutor::GraphExecutor() | ||||
: init_flag_(false), | : init_flag_(false), | ||||
train_graph_flag_(false), | train_graph_flag_(false), | ||||
@@ -380,7 +383,8 @@ Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_ro | |||||
} | } | ||||
Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, | Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, | ||||
const std::vector<InputTensorInfo> &input_tensor) { | |||||
const std::vector<InputTensorInfo> &input_tensor, | |||||
const RunAsyncCallback& callback) { | |||||
GELOGI("[GraphExecutor] Start to async execute graph, graph_id=%u", graph_id); | GELOGI("[GraphExecutor] Start to async execute graph, graph_id=%u", graph_id); | ||||
if (graph_id != last_graph_id_) { | if (graph_id != last_graph_id_) { | ||||
auto ret = FreeExecuteMemory(); | auto ret = FreeExecuteMemory(); | ||||
@@ -390,7 +394,7 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr & | |||||
} | } | ||||
last_graph_id_ = graph_id; | last_graph_id_ = graph_id; | ||||
GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED); | GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED); | ||||
Status ret = AsyncExecuteModel(ge_root_model->GetModelId(), input_tensor); | |||||
Status ret = AsyncExecuteModel(ge_root_model, input_tensor, callback); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] AsyncExecuteModel Error!"); | GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] AsyncExecuteModel Error!"); | ||||
return GE_GRAPH_SYNC_MODEL_FAILED; | return GE_GRAPH_SYNC_MODEL_FAILED; | ||||
@@ -400,11 +404,81 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr & | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphExecutor::AsyncExecuteModel(uint32_t model_id, const std::vector<InputTensorInfo> &inputs) { | |||||
bool CompareByLoad(const Uint32Pair &lhs, const Uint32Pair &rhs) { | |||||
return lhs.second < rhs.second; | |||||
} | |||||
uint32_t GraphExecutor::GetExecuteModelId(const GeRootModelPtr &ge_root_model) { | |||||
std::vector<uint32_t> model_ids = ge_root_model->GetAllModelId(); | |||||
if (model_ids.empty()) { | |||||
return kInvalidModelId; | |||||
} | |||||
if (model_ids.size() == 1) { | |||||
return ge_root_model->GetModelId(); | |||||
} | |||||
std::vector<Uint32Pair> model_id_to_loads; | |||||
auto model_manager = ModelManager::GetInstance(); | |||||
GE_CHECK_NOTNULL(model_manager); | |||||
for (auto model_id : model_ids) { | |||||
auto davinci_model = model_manager->GetModel(model_id); | |||||
auto hybrid_model = model_manager->GetHybridModel(model_id); | |||||
if (hybrid_model == nullptr) { | |||||
GE_CHECK_NOTNULL(davinci_model); | |||||
} | |||||
uint32_t input_load = hybrid_model != nullptr ? hybrid_model->GetDataInputerSize() : | |||||
davinci_model->GetDataInputerSize(); | |||||
uint32_t running_load = hybrid_model != nullptr ? static_cast<uint32_t>(hybrid_model->GetRunningFlag()) : | |||||
static_cast<uint32_t>(davinci_model->GetRunningFlag()); | |||||
uint32_t load = input_load + running_load; | |||||
if (load == 0) { | |||||
return model_id; | |||||
} | |||||
model_id_to_loads.emplace_back(model_id, load); | |||||
} | |||||
sort(model_id_to_loads.begin(), model_id_to_loads.end(), CompareByLoad); | |||||
if (model_id_to_loads.empty()) { | |||||
return kInvalidModelId; | |||||
} | |||||
return model_id_to_loads.begin()->first; | |||||
} | |||||
Status GraphExecutor::SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model, | |||||
const RunAsyncCallback &callback) { | |||||
auto model_manager = ge::ModelManager::GetInstance(); | |||||
GE_CHECK_NOTNULL(model_manager); | |||||
if (model_manager->IsNeedHybridLoad(*ge_root_model)) { | |||||
auto model = model_manager->GetHybridModel(model_id); | |||||
GE_CHECK_NOTNULL(model); | |||||
if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) { | |||||
GELOGE(FAILED, "SetRunAsyncListenerCallback failed."); | |||||
return FAILED; | |||||
} | |||||
} else { | |||||
auto model = model_manager->GetModel(model_id); | |||||
GE_CHECK_NOTNULL(model); | |||||
if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) { | |||||
GELOGE(FAILED, "SetRunAsyncListenerCallback failed."); | |||||
return FAILED; | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<InputTensorInfo> &inputs, | |||||
const RunAsyncCallback &callback) { | |||||
uint32_t model_id = GetExecuteModelId(ge_root_model); | |||||
if (model_id == kInvalidModelId) { | |||||
GELOGE(INTERNAL_ERROR, "No valid model id."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
try { | try { | ||||
auto model_manager = ge::ModelManager::GetInstance(); | auto model_manager = ge::ModelManager::GetInstance(); | ||||
GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
GELOGI("RunAsync begin.model_id %u", model_id); | GELOGI("RunAsync begin.model_id %u", model_id); | ||||
if (SetCallback(model_id, ge_root_model, callback) != SUCCESS) { | |||||
GELOGE(FAILED, "RunAsync: SetCallBack for model fail"); | |||||
return FAILED; | |||||
} | |||||
Status ret = model_manager->DataInputTensor(model_id, inputs); | Status ret = model_manager->DataInputTensor(model_id, inputs); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
@@ -50,7 +50,7 @@ class GraphExecutor { | |||||
std::vector<GeTensor> &output_tensor); | std::vector<GeTensor> &output_tensor); | ||||
ge::Status ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, | ge::Status ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, | ||||
const std::vector<InputTensorInfo> &input_tensor); | |||||
const std::vector<InputTensorInfo> &input_tensor, const RunAsyncCallback &callback); | |||||
Status SetCondition(std::mutex *mutex, std::condition_variable *cond, std::shared_ptr<GraphModelListener> listener); | Status SetCondition(std::mutex *mutex, std::condition_variable *cond, std::shared_ptr<GraphModelListener> listener); | ||||
@@ -116,6 +116,8 @@ class GraphExecutor { | |||||
static Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); | static Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); | ||||
uint32_t GetExecuteModelId(const GeRootModelPtr &ge_root_model); | |||||
private: | private: | ||||
Status PrepareInputData(const std::vector<GeTensor> &input_tensor, InputData &graph_input_data, | Status PrepareInputData(const std::vector<GeTensor> &input_tensor, InputData &graph_input_data, | ||||
OutputData &graph_output_data, std::vector<InputOutputDescInfo> &output_desc); | OutputData &graph_output_data, std::vector<InputOutputDescInfo> &output_desc); | ||||
@@ -123,7 +125,8 @@ class GraphExecutor { | |||||
Status SyncExecuteModel(uint32_t model_id, const std::vector<GeTensor> &input_tensor, | Status SyncExecuteModel(uint32_t model_id, const std::vector<GeTensor> &input_tensor, | ||||
std::vector<GeTensor> &output_tensor); | std::vector<GeTensor> &output_tensor); | ||||
Status AsyncExecuteModel(uint32_t model_id, const std::vector<InputTensorInfo> &input_tensor); | |||||
Status AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<InputTensorInfo> &input_tensor, | |||||
const RunAsyncCallback &callback); | |||||
void InitModelIdInfo(std::vector<uint32_t> &out_model_id_info, std::vector<SubGraphInfoPtr> &sub_graph_vec, | void InitModelIdInfo(std::vector<uint32_t> &out_model_id_info, std::vector<SubGraphInfoPtr> &sub_graph_vec, | ||||
uint32_t output_size); | uint32_t output_size); | ||||
@@ -132,6 +135,9 @@ class GraphExecutor { | |||||
Status MallocInOutBuffer(const std::vector<uint64_t> &buffer_size, std::vector<void *> &data_addr); | Status MallocInOutBuffer(const std::vector<uint64_t> &buffer_size, std::vector<void *> &data_addr); | ||||
static Status SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model, | |||||
const RunAsyncCallback &callback); | |||||
bool init_flag_; | bool init_flag_; | ||||
bool train_graph_flag_; | bool train_graph_flag_; | ||||
@@ -63,7 +63,6 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge | |||||
GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph] GE load graph model_ptr is nullptr."); | GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph] GE load graph model_ptr is nullptr."); | ||||
return GE_GRAPH_PARAM_NULLPTR; | return GE_GRAPH_PARAM_NULLPTR; | ||||
} | } | ||||
model_id = ge_root_model_ptr->GetModelId(); | |||||
auto model_manager = ModelManager::GetInstance(); | auto model_manager = ModelManager::GetInstance(); | ||||
GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
@@ -134,6 +134,8 @@ class DataInputer { | |||||
/// | /// | ||||
void Stop() { queue_.Stop(); } | void Stop() { queue_.Stop(); } | ||||
uint32_t Size() { return queue_.Size(); } | |||||
private: | private: | ||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
@@ -2737,6 +2737,8 @@ void *DavinciModel::Run(DavinciModel *model) { | |||||
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelExecute, ErrorMessage::kModelExecute); | ErrorManager::GetInstance().SetStage(ErrorMessage::kModelExecute, ErrorMessage::kModelExecute); | ||||
while (model->RunFlag()) { | while (model->RunFlag()) { | ||||
// Model hasn't truly started runing before received data | |||||
model->SetRunningFlag(false); | |||||
bool rslt_flg = true; | bool rslt_flg = true; | ||||
if (model->GetDataInputer() == nullptr) { | if (model->GetDataInputer() == nullptr) { | ||||
GELOGW("Data inputer is nullptr."); | GELOGW("Data inputer is nullptr."); | ||||
@@ -2746,6 +2748,8 @@ void *DavinciModel::Run(DavinciModel *model) { | |||||
std::shared_ptr<InputDataWrapper> data_wrapper; | std::shared_ptr<InputDataWrapper> data_wrapper; | ||||
Status ret = model->GetDataInputer()->Pop(data_wrapper); | Status ret = model->GetDataInputer()->Pop(data_wrapper); | ||||
// Model run indeedly start after received data. | |||||
model->SetRunningFlag(true); | |||||
if (data_wrapper == nullptr || ret != SUCCESS) { | if (data_wrapper == nullptr || ret != SUCCESS) { | ||||
GELOGI("data_wrapper is null!"); | GELOGI("data_wrapper is null!"); | ||||
continue; | continue; | ||||
@@ -2832,7 +2836,9 @@ void *DavinciModel::Run(DavinciModel *model) { | |||||
model->iterator_count_++; | model->iterator_count_++; | ||||
model->is_first_execute_ = false; | model->is_first_execute_ = false; | ||||
GELOGI("run iterator count is %lu", model->iterator_count_); | |||||
// model run finished | |||||
model->SetRunningFlag(false); | |||||
GELOGI("run iterator count is %lu, model_id:%u", model->iterator_count_, model->model_id_); | |||||
} | } | ||||
CsaInteract::GetInstance().WriteInternalErrorCode(); | CsaInteract::GetInstance().WriteInternalErrorCode(); | ||||
@@ -2890,7 +2896,7 @@ Status DavinciModel::ModelRunStart() { | |||||
error_context_ = ErrorManager::GetInstance().GetErrorContext(); | error_context_ = ErrorManager::GetInstance().GetErrorContext(); | ||||
CREATE_STD_THREAD(thread_id_, DavinciModel::Run, this); | CREATE_STD_THREAD(thread_id_, DavinciModel::Run, this); | ||||
GELOGI("model tread create success, model id:%u.", model_id_); | |||||
GELOGI("model thread create success, model id:%u.", model_id_); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -4340,4 +4346,10 @@ Status DavinciModel::InitL1DataDumperArgs() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status DavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { | |||||
auto listener = dynamic_cast<RunAsyncListener *>(listener_.get()); | |||||
GE_CHECK_NOTNULL(listener); | |||||
listener->SetCallback(callback); | |||||
return SUCCESS; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -221,6 +221,11 @@ class DavinciModel { | |||||
/// | /// | ||||
DataInputer *const GetDataInputer() const { return data_inputer_; } | DataInputer *const GetDataInputer() const { return data_inputer_; } | ||||
uint32_t GetDataInputerSize() { | |||||
GE_CHECK_NOTNULL(data_inputer_); | |||||
return data_inputer_->Size(); | |||||
} | |||||
// get Stream number | // get Stream number | ||||
uint32_t StreamNum() const { return runtime_param_.stream_num; } | uint32_t StreamNum() const { return runtime_param_.stream_num; } | ||||
@@ -560,6 +565,10 @@ class DavinciModel { | |||||
return data_dumper_.GetOpDescInfo(stream_id, task_id, op_desc_info); | return data_dumper_.GetOpDescInfo(stream_id, task_id, op_desc_info); | ||||
} | } | ||||
bool GetRunningFlag() const { return running_flg_; } | |||||
void SetRunningFlag(bool flag) { running_flg_ = flag; } | |||||
Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback); | |||||
private: | private: | ||||
// memory address of weights | // memory address of weights | ||||
uint8_t *weights_mem_base_; | uint8_t *weights_mem_base_; | ||||
@@ -924,6 +933,8 @@ class DavinciModel { | |||||
shared_ptr<ModelListener> listener_; | shared_ptr<ModelListener> listener_; | ||||
bool run_flg_; | bool run_flg_; | ||||
// check whether model is running with data | |||||
bool running_flg_ = false; | |||||
mutex mux_run_flg_; | mutex mux_run_flg_; | ||||
@@ -330,6 +330,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||||
GE_CHK_BOOL_RET_STATUS(listener.get() != nullptr, PARAM_INVALID, "Param incorrect, listener is null"); | GE_CHK_BOOL_RET_STATUS(listener.get() != nullptr, PARAM_INVALID, "Param incorrect, listener is null"); | ||||
if (model_id == INVALID_MODEL_ID) { | if (model_id == INVALID_MODEL_ID) { | ||||
GenModelId(&model_id); | GenModelId(&model_id); | ||||
GELOGD("Generate new model_id:%u", model_id); | |||||
} | } | ||||
auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); | auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); | ||||
string om_name; | string om_name; | ||||
@@ -363,7 +364,18 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||||
GE_IF_BOOL_EXEC(SUCCESS != (ret = davinci_model->Assign(ge_model)), GELOGW("assign model to modeldef failed."); | GE_IF_BOOL_EXEC(SUCCESS != (ret = davinci_model->Assign(ge_model)), GELOGW("assign model to modeldef failed."); | ||||
break;); | break;); | ||||
GE_TIMESTAMP_END(Assign, "GraphLoader::ModelAssign"); | GE_TIMESTAMP_END(Assign, "GraphLoader::ModelAssign"); | ||||
/// In multi-threaded inference, using the same session_id among multiple threads may cause some threads to fail. | |||||
/// These session_ids come from the same model, so the values of session_id are the same. | |||||
/// Update session_id for infer in load model to avoid the same session_id. | |||||
if (!ge_root_model->GetTrainFlag()) { | |||||
uint64_t new_session_id; | |||||
ret = GenSessionId(new_session_id); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Generate session_id for infer failed."); | |||||
ret = davinci_model->UpdateSessionId(new_session_id); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Update session_id for infer failed."); | |||||
ge_model->InsertSessionMap(model_id, new_session_id); | |||||
GELOGD("Update new session id: %lu.", new_session_id); | |||||
} | |||||
GE_TIMESTAMP_START(Init); | GE_TIMESTAMP_START(Init); | ||||
GE_IF_BOOL_EXEC(SUCCESS != (ret = davinci_model->Init()), GELOGW("DavinciInit failed."); break;); | GE_IF_BOOL_EXEC(SUCCESS != (ret = davinci_model->Init()), GELOGW("DavinciInit failed."); break;); | ||||
GE_TIMESTAMP_END(Init, "GraphLoader::ModelInit"); | GE_TIMESTAMP_END(Init, "GraphLoader::ModelInit"); | ||||
@@ -376,16 +388,16 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||||
return ret; | return ret; | ||||
} | } | ||||
void ModelManager::InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model) { | |||||
GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", id); | |||||
void ModelManager::InsertModel(uint32_t model_id, std::shared_ptr<DavinciModel> &davinci_model) { | |||||
GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", model_id); | |||||
std::lock_guard<std::recursive_mutex> lock(map_mutex_); | std::lock_guard<std::recursive_mutex> lock(map_mutex_); | ||||
model_map_[id] = davinci_model; | |||||
model_map_[model_id] = davinci_model; | |||||
} | } | ||||
void ModelManager::InsertModel(uint32_t id, shared_ptr<hybrid::HybridDavinciModel> &hybrid_model) { | |||||
GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", id); | |||||
void ModelManager::InsertModel(uint32_t model_id, shared_ptr<hybrid::HybridDavinciModel> &hybrid_model) { | |||||
GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", model_id); | |||||
std::lock_guard<std::recursive_mutex> lock(map_mutex_); | std::lock_guard<std::recursive_mutex> lock(map_mutex_); | ||||
hybrid_model_map_[id] = hybrid_model; | |||||
hybrid_model_map_[model_id] = hybrid_model; | |||||
} | } | ||||
Status ModelManager::DeleteModel(uint32_t id) { | Status ModelManager::DeleteModel(uint32_t id) { | ||||
@@ -330,8 +330,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief insert new model into model manager set | /// @brief insert new model into model manager set | ||||
/// | /// | ||||
void InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model); | |||||
void InsertModel(uint32_t id, std::shared_ptr<hybrid::HybridDavinciModel> &hybrid_model); | |||||
void InsertModel(uint32_t model_id, std::shared_ptr<DavinciModel> &davinci_model); | |||||
void InsertModel(uint32_t model_id, std::shared_ptr<hybrid::HybridDavinciModel> &hybrid_model); | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
@@ -121,6 +121,10 @@ const char *const kAIcoreEngine = "AIcoreEngine"; | |||||
const int32_t kDynamicDimsTypeIsGetNext = 0; | const int32_t kDynamicDimsTypeIsGetNext = 0; | ||||
const int32_t kDynamicDimsTypeIsData = 1; | const int32_t kDynamicDimsTypeIsData = 1; | ||||
const char *const kGetNextName = "IteratorV2"; | const char *const kGetNextName = "IteratorV2"; | ||||
const uint32_t kInitGraphCount = 1; | |||||
const uint32_t kNotAdded = 0; | |||||
const uint32_t kStartAdd = 1; | |||||
const uint32_t kDoneAdded = 2; | |||||
bool IsTailingOptimization() { | bool IsTailingOptimization() { | ||||
string is_tailing_optimization_option; | string is_tailing_optimization_option; | ||||
@@ -202,6 +206,8 @@ Status GraphManager::Initialize(const std::map<string, string> &options) { | |||||
graph_map_.clear(); | graph_map_.clear(); | ||||
cache_helper_map_.clear(); | cache_helper_map_.clear(); | ||||
graph_id_to_add_graph_cond_.clear(); | |||||
graph_count_.clear(); | |||||
init_flag_ = true; | init_flag_ = true; | ||||
thread_run_flag_ = true; | thread_run_flag_ = true; | ||||
@@ -211,6 +217,20 @@ Status GraphManager::Initialize(const std::map<string, string> &options) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphManager::UnloadModel(GeRootModelPtr ge_root_model, uint32_t graph_id) { | |||||
Status ret = SUCCESS; | |||||
for (size_t i = 0; i < ge_root_model->GetAllModelId().size(); ++i) { | |||||
uint32_t model_id = ge_root_model->GetAllModelId()[i]; | |||||
GELOGI("Unload model %u.", model_id); | |||||
ret = GraphLoader::UnloadModel(model_id); | |||||
if (ret != SUCCESS) { | |||||
GELOGW("[GraphManager] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||||
return ret; | |||||
} | |||||
} | |||||
return ret; | |||||
} | |||||
Status GraphManager::Finalize() { | Status GraphManager::Finalize() { | ||||
if (!init_flag_) { | if (!init_flag_) { | ||||
GELOGW("GraphManager has not been initialized."); | GELOGW("GraphManager has not been initialized."); | ||||
@@ -241,7 +261,6 @@ Status GraphManager::Finalize() { | |||||
unload_model_ret = GE_GRAPH_GRAPH_IS_RUNNING; | unload_model_ret = GE_GRAPH_GRAPH_IS_RUNNING; | ||||
continue; | continue; | ||||
} | } | ||||
// unload model | // unload model | ||||
auto ge_root_model = graph_node->GetGeRootModel(); | auto ge_root_model = graph_node->GetGeRootModel(); | ||||
if (ge_root_model != nullptr && ge_root_model->GetModelId() != INVALID_MODEL_ID && graph_node->GetLoadFlag()) { | if (ge_root_model != nullptr && ge_root_model->GetModelId() != INVALID_MODEL_ID && graph_node->GetLoadFlag()) { | ||||
@@ -251,15 +270,14 @@ Status GraphManager::Finalize() { | |||||
unload_model_ret = FAILED; | unload_model_ret = FAILED; | ||||
continue; | continue; | ||||
} | } | ||||
ret = GraphLoader::UnloadModel(ge_root_model->GetModelId()); | |||||
ret = UnloadModel(ge_root_model, iter->first); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGW("[GraphManager] unload model failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), iter->first); | |||||
GELOGW("[GraphManager] unload model failed, graph_id=%u.", iter->first); | |||||
unload_model_ret = ret; | unload_model_ret = ret; | ||||
} | } | ||||
rt_ret = rtDeviceReset(GetContext().DeviceId()); | rt_ret = rtDeviceReset(GetContext().DeviceId()); | ||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
GELOGW("[GraphManager] rtDeviceReset failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), | |||||
iter->first); | |||||
GELOGW("[GraphManager] rtDeviceReset failed, graphId=%u.", iter->first); | |||||
unload_model_ret = FAILED; | unload_model_ret = FAILED; | ||||
continue; | continue; | ||||
} | } | ||||
@@ -274,6 +292,7 @@ Status GraphManager::Finalize() { | |||||
} | } | ||||
graph_map_.clear(); | graph_map_.clear(); | ||||
cache_helper_map_.clear(); | cache_helper_map_.clear(); | ||||
graph_count_.clear(); | |||||
// graph context | // graph context | ||||
if (graph_context_ != nullptr) { | if (graph_context_ != nullptr) { | ||||
@@ -326,35 +345,59 @@ Status GraphManager::InitDynamicParams(ComputeGraphPtr &compute_graph) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||||
const std::map<std::string, std::string> &options, | |||||
const OmgContext &omg_context) { | |||||
if (HasGraphNode(graph_id)) { | |||||
REPORT_INNER_ERROR("E19999", "graph_id:%u is exist, check invalid", graph_id); | |||||
GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, "[GraphManager] graph exists, graph_id = %u.", graph_id); | |||||
return GE_GRAPH_GRAPH_ALREADY_EXIST; | |||||
void GraphManager::SetAddGraphCondition(GraphId graph_id, uint32_t cond) { | |||||
std::lock_guard<std::mutex> lock(add_graph_cond_mutex_); | |||||
graph_id_to_add_graph_cond_[graph_id] = cond; | |||||
GELOGD("Graph [id:%u] has been added.", graph_id); | |||||
} | |||||
uint32_t GraphManager::GetAddGraphCondition(GraphId graph_id) { | |||||
std::lock_guard<std::mutex> lock(add_graph_cond_mutex_); | |||||
auto it = graph_id_to_add_graph_cond_.find(graph_id); | |||||
if (it != graph_id_to_add_graph_cond_.end()) { | |||||
return it->second; | |||||
} else { | |||||
GELOGD("Graph [id:%u] has not been added.", graph_id); | |||||
return kNotAdded; | |||||
} | } | ||||
} | |||||
auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
if (compute_graph != nullptr) { | |||||
compute_graph->SetGraphID(graph_id); | |||||
bool graph_has_been_added = false; | |||||
if (AttrUtils::GetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, graph_has_been_added) | |||||
&& graph_has_been_added) { | |||||
REPORT_INNER_ERROR("E19999", "Get Attr:%s from graph:%u fail", | |||||
ATTR_NAME_GRAPH_HAS_BEEN_ADDED.c_str(), graph_id); | |||||
GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, | |||||
"[GraphManager] same graph object can not be added again, graph_id = %u.", graph_id); | |||||
return GE_GRAPH_GRAPH_ALREADY_EXIST; | |||||
} | |||||
(void)AttrUtils::SetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, true); | |||||
compute_graph_ = compute_graph; | |||||
void GraphManager::RemoveAddGraphCondition(GraphId graph_id) { | |||||
std::lock_guard<std::mutex> lock(add_graph_cond_mutex_); | |||||
auto it = graph_id_to_add_graph_cond_.find(graph_id); | |||||
if (it != graph_id_to_add_graph_cond_.end()) { | |||||
graph_id_to_add_graph_cond_.erase(it); | |||||
GELOGD("Successfully removed add_graph_cond of graph [id:%u].", graph_id); | |||||
} else { | } else { | ||||
REPORT_INNER_ERROR("E19999", "compute_graph from graph:%u is nullptr, check invalid", | |||||
graph_id); | |||||
GELOGE(FAILED, "compute graph is null"); | |||||
return FAILED; | |||||
GELOGD("Graph [id:%u] has not been added. no need to remove.", graph_id); | |||||
} | } | ||||
} | |||||
Status GraphManager::CheckRepeatAdd(uint32_t graph_id, bool &is_added) { | |||||
uint32_t count = 0; | |||||
if (GetGraphCount(graph_id, count) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed, graph might have not been added.", graph_id); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
// previous thread owns same graph_id has been in the middle of the AddGraph procession | |||||
if (count > 1 && GetAddGraphCondition(graph_id) == kStartAdd) { | |||||
std::unique_lock<std::mutex> lock(add_graph_mutex_); | |||||
GELOGD("Waitting for build end of previous thread."); | |||||
while (GetAddGraphCondition(graph_id) != kDoneAdded) { | |||||
add_graph_cv_.wait(lock); | |||||
} | |||||
GraphNodePtr graph_node; | |||||
Status ret = GetGraphNode(graph_id, graph_node); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[AddGraph] GetGraphNode failed, graph_id = %u.", graph_id); | |||||
return ret; | |||||
} | |||||
is_added = true; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
void GraphManager::SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id) { | |||||
std::string session_graph_id; | std::string session_graph_id; | ||||
if (!AttrUtils::GetStr(*compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id) || session_graph_id.empty()) { | if (!AttrUtils::GetStr(*compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id) || session_graph_id.empty()) { | ||||
session_graph_id = "-1_" + to_string(graph_id); | session_graph_id = "-1_" + to_string(graph_id); | ||||
@@ -366,7 +409,24 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||||
} | } | ||||
GELOGD("Get graph session_graph_id attr failed, set session id to default value: [0]"); | GELOGD("Get graph session_graph_id attr failed, set session id to default value: [0]"); | ||||
} | } | ||||
} | |||||
Status GraphManager::NotifyWaittingGraph(uint32_t graph_id) { | |||||
uint32_t count = 0; | |||||
if (GetGraphCount(graph_id, count) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed, graph might have not been added.", graph_id); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
GELOGD("Add graph finished, graph_id:%u", graph_id); | |||||
if (count > 1) { | |||||
GELOGD("Finish addgraph, graph_id:%u, graph_count:%u, start to notify.", graph_id, count); | |||||
add_graph_cv_.notify_all(); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::CreateGraphNode(uint32_t graph_id, const Graph &graph, | |||||
const std::map<std::string, std::string> &options) { | |||||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | ||||
GE_IF_BOOL_EXEC(graph_node == nullptr, | GE_IF_BOOL_EXEC(graph_node == nullptr, | ||||
REPORT_CALL_ERROR("E19999", "New GraphNode fail, graph_id:%u", | REPORT_CALL_ERROR("E19999", "New GraphNode fail, graph_id:%u", | ||||
@@ -385,7 +445,62 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||||
ParseOption(options, TUNING_PATH, options_.tuning_path); | ParseOption(options, TUNING_PATH, options_.tuning_path); | ||||
graph_node->SetGraph(graph_ptr); | graph_node->SetGraph(graph_ptr); | ||||
graph_node->SetOptions(options); | graph_node->SetOptions(options); | ||||
graph_node->IncreaseLoadCount(); | |||||
AddGraphNode(graph_id, graph_node); | AddGraphNode(graph_id, graph_node); | ||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::SetStagesOptions(uint32_t graph_id, const GraphManagerOptions &options) { | |||||
CompilerStages &stages = GetCompilerStages(graph_id); | |||||
stages.preparer.SetOptions(options_); | |||||
Status status = stages.optimizer.SetOptions(options_); | |||||
if (status != SUCCESS) { | |||||
GELOGE(status, "Graph optimizer set options failed."); | |||||
return status; | |||||
} | |||||
stages.builder.SetOptions(options_); | |||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||||
const std::map<std::string, std::string> &options, | |||||
const OmgContext &omg_context) { | |||||
IncreaseGraphCount(graph_id); | |||||
// validation for adding graphs of same graph_id in multi-thread secenario | |||||
// 1.previous thread owns same graph_id has finished the AddGraph procession | |||||
if (GetAddGraphCondition(graph_id) == kDoneAdded) { | |||||
GraphNodePtr graph_node; | |||||
if (GetGraphNode(graph_id, graph_node) != SUCCESS) { | |||||
GELOGE(GE_GRAPH_GRAPH_NOT_EXIST, "Graph not exist while done adding previously, graph_id = %u.", graph_id); | |||||
return GE_GRAPH_GRAPH_NOT_EXIST; | |||||
} | |||||
graph_node->IncreaseLoadCount(); | |||||
return SUCCESS; | |||||
} | |||||
// In multi-thread scenario, former thread owns same graph_id has been | |||||
// in the middle of the AddGraph procession while following threads have to wait until | |||||
// done adding graph of the former graph, avoiding repeatively adding same graph. | |||||
bool is_added = false; | |||||
if (CheckRepeatAdd(graph_id, is_added) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "CheckRepeatAdd for graph[id:%u] failed.", graph_id); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
// The former graph (from different thread) owns same graph id has been successfully added. | |||||
if (is_added) { | |||||
return SUCCESS; | |||||
} | |||||
// Do add graph | |||||
SetAddGraphCondition(graph_id, kStartAdd); | |||||
auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
GE_CHECK_NOTNULL(compute_graph); | |||||
compute_graph->SetGraphID(graph_id); | |||||
SetSessionGraphId(compute_graph, graph_id); | |||||
if (CreateGraphNode(graph_id, graph, options) != SUCCESS) { | |||||
GELOGE(FAILED, "Failed to create graph_node."); | |||||
return FAILED; | |||||
} | |||||
AddLocalOmgContext(graph_id, omg_context); | AddLocalOmgContext(graph_id, omg_context); | ||||
if (!options_.output_datatype.empty()) { | if (!options_.output_datatype.empty()) { | ||||
@@ -396,16 +511,18 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||||
return GRAPH_PARAM_INVALID; | return GRAPH_PARAM_INVALID; | ||||
} | } | ||||
CompilerStages &stages = GetCompilerStages(graph_id); | |||||
stages.preparer.SetOptions(options_); | |||||
Status status = stages.optimizer.SetOptions(options_); | |||||
if (status != SUCCESS) { | |||||
GELOGE(status, "Graph optimizer set options failed."); | |||||
return status; | |||||
if (SetStagesOptions(graph_id, options_) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Set stage options failed."); | |||||
return INTERNAL_ERROR; | |||||
} | } | ||||
stages.builder.SetOptions(options_); | |||||
var_acc_ctrl_.AddGraph(graph_id, compute_graph); | var_acc_ctrl_.AddGraph(graph_id, compute_graph); | ||||
SetAddGraphCondition(graph_id, kDoneAdded); | |||||
// There are threads waiting to add the same graph
if (NotifyWaittingGraph(graph_id) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "NotifyWaittingGraph failed."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -962,6 +1079,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: | |||||
if (!graph_node->IsAsync()) { | if (!graph_node->IsAsync()) { | ||||
ret = LoadGraph(ge_root_model, graph_node); | ret = LoadGraph(ge_root_model, graph_node); | ||||
} else { | } else { | ||||
GE_CHECK_NOTNULL(ge_root_model); | |||||
ret = LoadGraphAsync(ge_root_model, graph_node); | ret = LoadGraphAsync(ge_root_model, graph_node); | ||||
} | } | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
@@ -976,6 +1094,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: | |||||
if (!graph_node->IsAsync()) { | if (!graph_node->IsAsync()) { | ||||
ret = LoadGraph(ge_root_model_ptr, graph_node); | ret = LoadGraph(ge_root_model_ptr, graph_node); | ||||
} else { | } else { | ||||
GE_CHECK_NOTNULL(ge_root_model); | |||||
ret = LoadGraphAsync(ge_root_model_ptr, graph_node); | ret = LoadGraphAsync(ge_root_model_ptr, graph_node); | ||||
} | } | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
@@ -988,6 +1107,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: | |||||
Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { | Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { | ||||
GELOGI("[LoadGraph] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); | GELOGI("[LoadGraph] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); | ||||
if (options_.run_graph_flag && ge_root_model != nullptr) { | if (options_.run_graph_flag && ge_root_model != nullptr) { | ||||
ge_root_model->SetTrainFlag(GetTrainFlag()); | |||||
// synchronization run graph with model | // synchronization run graph with model | ||||
std::shared_ptr<GraphModelListener> model_listener = GetModelListener(); | std::shared_ptr<GraphModelListener> model_listener = GetModelListener(); | ||||
ModelIdInfo model_id_info; | ModelIdInfo model_id_info; | ||||
@@ -1413,62 +1533,29 @@ bool GraphManager::CheckModelLoad(const GeRootModelPtr &ge_root_model, bool load | |||||
} | } | ||||
Status GraphManager::RemoveGraph(const GraphId &graph_id) { | Status GraphManager::RemoveGraph(const GraphId &graph_id) { | ||||
auto it = to_be_deleted_graphs_.find(graph_id); | |||||
if (it != to_be_deleted_graphs_.end()) { | |||||
to_be_deleted_graphs_.erase(it); | |||||
} | |||||
GraphNodePtr graph_node = nullptr; | GraphNodePtr graph_node = nullptr; | ||||
Status ret = GetGraphNode(graph_id, graph_node); | Status ret = GetGraphNode(graph_id, graph_node); | ||||
if (ret != SUCCESS) { | |||||
REPORT_INNER_ERROR("E19999", "Graph:%u not exist in graph_map, check invalid", | |||||
graph_id); | |||||
if (ret != SUCCESS || graph_node == nullptr) { | |||||
REPORT_INNER_ERROR("E19999", "Graph:%u not exist in graph_map, check invalid when GraphManager %s", | |||||
graph_id, __FUNCTION__); | |||||
GELOGE(GE_GRAPH_GRAPH_NOT_EXIST, "[GraphManager] Id %u does not exists.", graph_id); | GELOGE(GE_GRAPH_GRAPH_NOT_EXIST, "[GraphManager] Id %u does not exists.", graph_id); | ||||
return GE_GRAPH_GRAPH_NOT_EXIST; | return GE_GRAPH_GRAPH_NOT_EXIST; | ||||
} | } | ||||
if ((graph_node == nullptr) || (graph_node->GetRunFlag())) { | |||||
REPORT_INNER_ERROR("E19999", "Graph:%u is running, can't be remove, check invalid", | |||||
graph_id); | |||||
GELOGE(GE_GRAPH_GRAPH_IS_RUNNING, "[GraphManager] Id %u is running, can't be deleted.", graph_id); | |||||
return GE_GRAPH_GRAPH_IS_RUNNING; | |||||
if (graph_node->GetRunFlag()) { | |||||
// only put graph into to-be-deleted list when exceptional scenario | |||||
to_be_deleted_graphs_.insert(graph_id); | |||||
GELOGI("[GraphManager] Trying to remove running graph[Id:%u], added into to_be_deleted_graphs_.", graph_id); | |||||
return SUCCESS; | |||||
} | } | ||||
std::lock_guard<std::mutex> lock(unload_model_mutex_); | std::lock_guard<std::mutex> lock(unload_model_mutex_); | ||||
Status middle_ret; | Status middle_ret; | ||||
rtError_t rt_ret; | rtError_t rt_ret; | ||||
const std::vector<SubGraphInfoPtr> &all_sub_graph = graph_node->GetAllSubGraph(); | |||||
for (size_t i = 0; i < all_sub_graph.size(); ++i) { | |||||
// must free buffer firstly | |||||
middle_ret = all_sub_graph[i]->FreeInOutBuffer(); | |||||
if (middle_ret != SUCCESS) { | |||||
GELOGE(middle_ret, "[GraphManager] RemoveGraph free mem failed, graph_id=%u.", graph_id); | |||||
ret = middle_ret; | |||||
} | |||||
if (all_sub_graph[i]->GeModelIsValid() && all_sub_graph[i]->GetModelIdInfo().model_id != INVALID_MODEL_ID) { | |||||
// unload model | |||||
GELOGI("UnloadModel via new ome."); | |||||
rt_ret = rtSetDevice(GetContext().DeviceId()); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, graph_id:%u", | |||||
GetContext().DeviceId(), graph_id); | |||||
GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.", | |||||
all_sub_graph[i]->GetModelIdInfo().model_id, graph_id); | |||||
ret = FAILED; | |||||
continue; | |||||
} | |||||
middle_ret = GraphLoader::UnloadModel(all_sub_graph[i]->GetModelIdInfo().model_id); | |||||
if (middle_ret != SUCCESS) { | |||||
GELOGE(middle_ret, "[GraphManager:] unload model failed, modelId=%u, graph_id=%u.", | |||||
all_sub_graph[i]->GetModelIdInfo().model_id, graph_id); | |||||
ret = middle_ret; | |||||
} | |||||
rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset fail, device_id:%u, graph_id:%u", | |||||
GetContext().DeviceId(), graph_id); | |||||
GELOGE(RT_FAILED, "[GraphManager:] unload model failed, modelId=%u, graphId=%u.", | |||||
all_sub_graph[i]->GetModelIdInfo().model_id, graph_id); | |||||
ret = FAILED; | |||||
} | |||||
} | |||||
} | |||||
var_acc_ctrl_.RemoveGraph(graph_id); | var_acc_ctrl_.RemoveGraph(graph_id); | ||||
RemoveGraphNode(graph_id); | RemoveGraphNode(graph_id); | ||||
@@ -1476,7 +1563,6 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { | |||||
auto ge_root_model = graph_node->GetGeRootModel(); | auto ge_root_model = graph_node->GetGeRootModel(); | ||||
if (CheckModelLoad(ge_root_model, graph_node->GetLoadFlag())) { | if (CheckModelLoad(ge_root_model, graph_node->GetLoadFlag())) { | ||||
GELOGI("Unload model %u.", ge_root_model->GetModelId()); | |||||
rt_ret = rtSetDevice(GetContext().DeviceId()); | rt_ret = rtSetDevice(GetContext().DeviceId()); | ||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, graph_id:%u", | REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, graph_id:%u", | ||||
@@ -1485,23 +1571,27 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { | |||||
graph_id); | graph_id); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
middle_ret = GraphLoader::UnloadModel(ge_root_model->GetModelId()); | |||||
// same graph may be added for several times, different models were created separately, | |||||
// unload them respectively. | |||||
middle_ret = UnloadModel(ge_root_model, graph_id); | |||||
if (middle_ret != SUCCESS) { | if (middle_ret != SUCCESS) { | ||||
GELOGE(middle_ret, "[GraphManager:] unload model failed, modelId=%u, graph_id=%u.", ge_root_model->GetModelId(), | |||||
graph_id); | |||||
REPORT_INNER_ERROR("E19999", "UnloadModel for graph:%u failed, check unload detail in GraphLoader %s", | |||||
graph_id, __FUNCTION__); | |||||
GELOGE(middle_ret, "[GraphManager:] unload model failed, graph_id=%u.", graph_id); | |||||
ret = middle_ret; | ret = middle_ret; | ||||
} | } | ||||
rt_ret = rtDeviceReset(GetContext().DeviceId()); | rt_ret = rtDeviceReset(GetContext().DeviceId()); | ||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, graph_id:%u", | |||||
GetContext().DeviceId(), graph_id); | |||||
GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), | |||||
graph_id); | |||||
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, graph_id:%u, when GraphManager %s", | |||||
GetContext().DeviceId(), graph_id, __FUNCTION__); | |||||
GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, graphId=%u.", graph_id); | |||||
ret = FAILED; | ret = FAILED; | ||||
} | } | ||||
} | } | ||||
RemoveCompilerStages(graph_id); | RemoveCompilerStages(graph_id); | ||||
RemoveGraphCount(graph_id); | |||||
RemoveAddGraphCondition(graph_id); | |||||
GE_CHK_STATUS_RET(ret, "[GraphManager:] Remove graph failed, graph_id=%u.", graph_id); | GE_CHK_STATUS_RET(ret, "[GraphManager:] Remove graph failed, graph_id=%u.", graph_id); | ||||
GELOGI("[GraphManager] remove graph success, graph_id=%u.", graph_id); | GELOGI("[GraphManager] remove graph success, graph_id=%u.", graph_id); | ||||
@@ -2588,6 +2678,7 @@ void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_gr | |||||
Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { | Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { | ||||
GELOGI("[LoadGraphAsync] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); | GELOGI("[LoadGraphAsync] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); | ||||
if (options_.run_graph_flag && ge_root_model != nullptr) { | if (options_.run_graph_flag && ge_root_model != nullptr) { | ||||
ge_root_model->SetTrainFlag(GetTrainFlag()); | |||||
// synchronization run graph with model | // synchronization run graph with model | ||||
ModelIdInfo model_id_info; | ModelIdInfo model_id_info; | ||||
bool is_unknown_shape = false; | bool is_unknown_shape = false; | ||||
@@ -2604,9 +2695,9 @@ Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const G | |||||
} | } | ||||
} | } | ||||
GE_TIMESTAMP_START(LoadGraph); | GE_TIMESTAMP_START(LoadGraph); | ||||
GE_CHECK_NOTNULL(graph_node->graph_run_async_listener_); | |||||
Status ret = | |||||
GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, graph_node->graph_run_async_listener_); | |||||
auto listener = MakeShared<RunAsyncListener>(); | |||||
GE_CHECK_NOTNULL(listener); | |||||
Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, listener); | |||||
GE_TIMESTAMP_EVENT_END(LoadGraph, "GraphManager::LoadGraphAsync"); | GE_TIMESTAMP_EVENT_END(LoadGraph, "GraphManager::LoadGraphAsync"); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "[LoadGraphAsync] LoadGraphAsync Failed"); | GELOGE(ret, "[LoadGraphAsync] LoadGraphAsync Failed"); | ||||
@@ -2620,6 +2711,52 @@ Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const G | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void GraphManager::ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph_node, | |||||
const std::vector<uint32_t> &model_ids, uint32_t graph_id, uint64_t session_id) { | |||||
rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, when GraphManager %s", | |||||
GetContext().DeviceId(), __FUNCTION__); | |||||
GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, graphId=%u.", graph_id); | |||||
return; | |||||
} | |||||
for (auto model_id : model_ids) { | |||||
uint64_t max_memory_size = 0; | |||||
Status result = GraphLoader::GetMaxUsedMemory(model_id, max_memory_size); | |||||
if (result != SUCCESS) { | |||||
continue; | |||||
} | |||||
GELOGI("CheckAndReleaseMemory try to UnloadGraph[%u], model[%u] which MaxUsedMemory[%lu].", graph_id, model_id, | |||||
max_memory_size); | |||||
if (model_ids.size() > 1) { | |||||
result = ge_model->GetSessionId(model_id, session_id); | |||||
if (result != SUCCESS) { | |||||
GELOGW("[GraphManager:] get session failed when dynamic memory, modelId=%u, graphId=%u.", model_id, | |||||
graph_id); | |||||
continue; | |||||
} | |||||
} | |||||
result = GraphLoader::DestroyAicpuKernel(session_id, model_id, 0); | |||||
if (result != SUCCESS) { | |||||
GELOGW("[GraphManager:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id, | |||||
graph_id); | |||||
} | |||||
result = GraphLoader::UnloadModel(model_id); | |||||
if (result != SUCCESS) { | |||||
GELOGW("[GraphManager:] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||||
} | |||||
GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success.", graph_id, model_id); | |||||
} | |||||
graph_node->SetLoadFlag(false); | |||||
rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, when GraphManager %s", | |||||
GetContext().DeviceId(), __FUNCTION__); | |||||
GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, graphId=%u.", graph_id); | |||||
return; | |||||
} | |||||
} | |||||
Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node) { | Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node) { | ||||
GELOGI("CheckAndReleaseMemory graph_id[%u]", graph_node->GetGraphId()); | GELOGI("CheckAndReleaseMemory graph_id[%u]", graph_node->GetGraphId()); | ||||
int64_t value = 0; | int64_t value = 0; | ||||
@@ -2665,6 +2802,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra | |||||
continue; | continue; | ||||
} | } | ||||
auto model_id = model->GetModelId(); | auto model_id = model->GetModelId(); | ||||
auto model_ids = model->GetAllModelId(); | |||||
// unload model not release | // unload model not release | ||||
bool is_unknown_shape = false; | bool is_unknown_shape = false; | ||||
GE_CHK_STATUS_RET(model->CheckIsUnknownShape(is_unknown_shape)); | GE_CHK_STATUS_RET(model->CheckIsUnknownShape(is_unknown_shape)); | ||||
@@ -2677,38 +2815,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra | |||||
GELOGI("CheckAndReleaseMemory graph[%u] has not been loaded.", graph_id); | GELOGI("CheckAndReleaseMemory graph[%u] has not been loaded.", graph_id); | ||||
continue; | continue; | ||||
} | } | ||||
uint64_t max_memory_size = 0; | |||||
result = GraphLoader::GetMaxUsedMemory(model_id, max_memory_size); | |||||
if (result != SUCCESS) { | |||||
continue; | |||||
} | |||||
GELOGI("CheckAndReleaseMemory try to UnloadGraph[%u], model[%u] which MaxUsedMemory[%lu].", graph_id, model_id, | |||||
max_memory_size); | |||||
rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u", | |||||
GetContext().DeviceId()); | |||||
GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||||
continue; | |||||
} | |||||
result = GraphLoader::DestroyAicpuKernel(session_id, model_id, 0); | |||||
if (result != SUCCESS) { | |||||
GELOGW("[GraphManager:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id, | |||||
graph_id); | |||||
} | |||||
result = GraphLoader::UnloadModel(model_id); | |||||
if (result != SUCCESS) { | |||||
GELOGW("[GraphManager:] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||||
} | |||||
rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u", | |||||
GetContext().DeviceId()); | |||||
GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||||
continue; | |||||
} | |||||
it.second->SetLoadFlag(false); | |||||
GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success and set LoadFlag to false.", graph_id, model_id); | |||||
ReleaseMemory(ge_model, it.second, model_ids, graph_id, session_id); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -2849,6 +2956,38 @@ void GraphManager::ConstructGeInput(const vector<InputTensorInfo> &inputs, vecto | |||||
} | } | ||||
} | } | ||||
// Decide whether the graph needs (re)building before an async run.
// Fast path: if no build is needed, reuse the cached root model.
// Otherwise try incremental build first, falling back to a full PreRun.
// On success the node is marked built; returns a non-SUCCESS status after
// reporting the error through args.callback otherwise.
Status GraphManager::CheckIncreBuildAndPreRun(GraphManager *graph_manager, const PreRunArgs &args,
                                              GraphNodePtr &graph_node, GeRootModelPtr &ge_root_model) {
  if (!graph_manager->IsGraphNeedBuild(graph_node)) {
    // Already built and still valid: hand back the cached model.
    ge_root_model = graph_node->GetGeRootModel();
    return SUCCESS;
  }
  if (graph_node->GetBuildFlag()) {
    // Built once but now needs a rebuild: GE requires remove + re-add.
    ReturnError(graph_manager, args.callback, PARAM_INVALID,
                "The graph " + std::to_string(graph_node->GetGraphId()) +
                " need to re-build, you should remove it"
                " from GE first, then AddGraph again and rebuild it.");
    // NOTE(review): this path unlocks the node while the PreRun-failure path
    // below leaves unlocking to the caller — confirm the caller does not also
    // Unlock() after a PARAM_INVALID return (risk of double release).
    graph_node->Unlock();
    return PARAM_INVALID;
  }
  // check need incre build.
  GeModelPtr ge_model = nullptr;
  if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) {
    // Incremental build unavailable: run the full prepare/optimize/build flow.
    std::vector<GeTensor> ge_inputs;
    ConstructGeInput(args.input_tensor, ge_inputs);
    Status ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id);
    // release rts generate context
    RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId());
    if (ret != SUCCESS) {
      ReturnError(graph_manager, args.callback, ret, "PreRun Failed.");
      return ret;
    }
  }
  graph_node->SetBuildFlag(true);
  graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId());
  return SUCCESS;
}
void GraphManager::PreRunThread(GraphManager *graph_manager) { | void GraphManager::PreRunThread(GraphManager *graph_manager) { | ||||
if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { | if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { | ||||
GELOGW("Set thread name failed."); | GELOGW("Set thread name failed."); | ||||
@@ -2861,7 +3000,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||||
continue; | continue; | ||||
} | } | ||||
GELOGI("A new loop start."); | |||||
GELOGI("[PreRunThread] A new loop start, graph_id:%u.", args.graph_id); | |||||
ErrorManager::GetInstance().SetErrorContext(args.error_context); | ErrorManager::GetInstance().SetErrorContext(args.error_context); | ||||
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther); | ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther); | ||||
@@ -2877,7 +3016,24 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||||
"[RunGraph] graph not exist, graph_id=" + std::to_string(args.graph_id)); | "[RunGraph] graph not exist, graph_id=" + std::to_string(args.graph_id)); | ||||
return; | return; | ||||
} | } | ||||
// more than one graph owns same graph_id | |||||
uint32_t count = 0; | |||||
if (graph_manager->GetGraphCount(args.graph_id, count) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed.", args.graph_id); | |||||
return; | |||||
} | |||||
// Avoid repeated prerun for graphs owning the same graph_id in online inference concurrency
if (count > 1 && graph_node->GetBuildFlag()) { | |||||
graph_node->Lock(); | |||||
GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id); | |||||
// In the online inference concurrency scenario, graph_node is allowed to be locked 'count' times
graph_node->SetSemSize(count); | |||||
graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, | |||||
args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback })); | |||||
GELOGI("[PreRunThread] Loop end. Start to run with cached build model."); | |||||
continue; | |||||
} | |||||
// Cannot be moved ahead of the repeated-prerun judgement above
graph_node->Lock(); | graph_node->Lock(); | ||||
if (graph_node->GetRunFlag()) { | if (graph_node->GetRunFlag()) { | ||||
@@ -2909,46 +3065,24 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||||
// it will not execute graph preprocess, optimize, parition, build if the graph has built successful. | // it will not execute graph preprocess, optimize, parition, build if the graph has built successful. | ||||
GELOGI("Start for run graph async."); | GELOGI("Start for run graph async."); | ||||
GeRootModelPtr ge_root_model = nullptr; | GeRootModelPtr ge_root_model = nullptr; | ||||
if (graph_manager->IsGraphNeedBuild(graph_node)) { | |||||
if (graph_node->GetBuildFlag()) { | |||||
ReturnError(graph_manager, args.callback, PARAM_INVALID, | |||||
"The graph " + std::to_string(graph_node->GetGraphId()) + | |||||
" need to re-build, you should remove it" | |||||
" from GE first, then AddGraph again and rebuild it."); | |||||
ret = CheckIncreBuildAndPreRun(graph_manager, args, graph_node, ge_root_model); | |||||
if (ret != SUCCESS) { | |||||
graph_node->SetRunFlag(false); | |||||
if (!ge::Analyzer::GetInstance()->IsEnableNetAnalyzeDebug()) { | |||||
ReturnError(graph_manager, args.callback, ret, "CheckIncreBuildAndPreRun Failed, thread exit.."); | |||||
graph_node->Unlock(); | graph_node->Unlock(); | ||||
return; | return; | ||||
} else { | |||||
ReturnError(graph_manager, graph_node, args.callback, ret, | |||||
"CheckIncreBuildAndPreRun Failed, keep geop continue!"); | |||||
graph_node->Unlock(); | |||||
continue; | |||||
} | } | ||||
// check need incre build. | |||||
GeModelPtr ge_model = nullptr; | |||||
if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { | |||||
std::vector<GeTensor> ge_inputs; | |||||
ConstructGeInput(args.input_tensor, ge_inputs); | |||||
ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); | |||||
// release rts generate context | |||||
RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); | |||||
if (ret != SUCCESS) { | |||||
graph_node->SetRunFlag(false); | |||||
if (!ge::Analyzer::GetInstance()->IsEnableNetAnalyzeDebug()) { | |||||
ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit.."); | |||||
graph_node->Unlock(); | |||||
return; | |||||
} else { | |||||
ReturnError(graph_manager, graph_node, args.callback, ret, "PreRun Failed, keep geop continue!"); | |||||
graph_node->Unlock(); | |||||
continue; | |||||
} | |||||
} | |||||
} | |||||
graph_node->SetBuildFlag(true); | |||||
graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); | |||||
} else { | |||||
ge_root_model = graph_node->GetGeRootModel(); | |||||
} | } | ||||
graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, | graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, | ||||
args.input_tensor, ge_root_model, GetThreadLocalContext(), args.callback })); | args.input_tensor, ge_root_model, GetThreadLocalContext(), args.callback })); | ||||
GELOGI("Loop end."); | |||||
GELOGI("[PreRunThread] Loop end."); | |||||
} | } | ||||
} | } | ||||
@@ -3051,16 +3185,13 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||||
continue; | continue; | ||||
} | } | ||||
GELOGI("A new loop start."); | |||||
GELOGI("[RunThread] A new loop start, graph_id:%u.", args.graph_id); | |||||
ErrorManager::GetInstance().SetErrorContext(args.error_context); | ErrorManager::GetInstance().SetErrorContext(args.error_context); | ||||
GetContext().SetSessionId(args.session_id); | GetContext().SetSessionId(args.session_id); | ||||
GetThreadLocalContext() = args.context; | GetThreadLocalContext() = args.context; | ||||
graph_manager->UpdateLocalOmgContext(args.graph_id); | graph_manager->UpdateLocalOmgContext(args.graph_id); | ||||
if (args.graph_node->graph_run_async_listener_ != nullptr) { | |||||
args.graph_node->graph_run_async_listener_->SetCallback(args.callback); | |||||
} | |||||
Status ret; | Status ret; | ||||
// parse inputs.dims to vector<vector<uint64_t>> dynamic_dims | // parse inputs.dims to vector<vector<uint64_t>> dynamic_dims | ||||
ret = graph_manager->ParseInputsDims(args.input_tensor); | ret = graph_manager->ParseInputsDims(args.input_tensor); | ||||
@@ -3070,8 +3201,10 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||||
return; | return; | ||||
} | } | ||||
args.graph_node->UpdateLoadFlag(); | |||||
if (!args.graph_node->GetLoadFlag()) { | if (!args.graph_node->GetLoadFlag()) { | ||||
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelLoad, ErrorMessage::kModelLoad); | ErrorManager::GetInstance().SetStage(ErrorMessage::kModelLoad, ErrorMessage::kModelLoad); | ||||
args.ge_root_model->SetTrainFlag(graph_manager->GetTrainFlag()); | |||||
ret = graph_manager->LoadGraphAsync(args.ge_root_model, args.graph_node); | ret = graph_manager->LoadGraphAsync(args.ge_root_model, args.graph_node); | ||||
if (ret != SUCCESS || args.ge_root_model == nullptr) { | if (ret != SUCCESS || args.ge_root_model == nullptr) { | ||||
StopQueue(graph_manager); | StopQueue(graph_manager); | ||||
@@ -3079,6 +3212,10 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||||
args.graph_node->Unlock(); | args.graph_node->Unlock(); | ||||
return; | return; | ||||
} | } | ||||
// control the times of graph loading in multi-thread scenario | |||||
args.graph_node->DecreaseLoadCount(); | |||||
args.graph_node->IncreaseLoadRecord(); | |||||
args.graph_node->SetLoadFlag(true); | args.graph_node->SetLoadFlag(true); | ||||
GELOGI("LoadGraph[%u], model[%u] success and set LoadFlag to true.", args.graph_node->GetGraphId(), | GELOGI("LoadGraph[%u], model[%u] success and set LoadFlag to true.", args.graph_node->GetGraphId(), | ||||
args.ge_root_model->GetModelId()); | args.ge_root_model->GetModelId()); | ||||
@@ -3093,9 +3230,9 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||||
graph_manager->graph_executor_.SetTrainFlag(graph_manager->options_.train_graph_flag); | graph_manager->graph_executor_.SetTrainFlag(graph_manager->options_.train_graph_flag); | ||||
} | } | ||||
args.graph_node->SetRunFlag(false); | |||||
ret = graph_manager->graph_executor_.ExecuteGraphAsync(args.graph_id, args.graph_node->GetGeRootModel(), | ret = graph_manager->graph_executor_.ExecuteGraphAsync(args.graph_id, args.graph_node->GetGeRootModel(), | ||||
args.input_tensor); | |||||
args.input_tensor, args.callback); | |||||
args.graph_node->SetRunFlag(false); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
ReturnError(graph_manager, args.callback, ret, "ExecuteGraphAsync failed, thread exit."); | ReturnError(graph_manager, args.callback, ret, "ExecuteGraphAsync failed, thread exit."); | ||||
args.graph_node->Unlock(); | args.graph_node->Unlock(); | ||||
@@ -3546,4 +3683,49 @@ void GraphManager::RemoveCompilerStages(GraphId graph_id) { | |||||
std::lock_guard<std::mutex> lock(member_mutex_); | std::lock_guard<std::mutex> lock(member_mutex_); | ||||
compiler_stages_.erase(graph_id); | compiler_stages_.erase(graph_id); | ||||
} | } | ||||
void GraphManager::IncreaseGraphCount(GraphId graph_id) {
  // Bump the usage count of graph_id under the count lock, creating the
  // entry with the initial count the first time the id is seen.
  std::lock_guard<std::mutex> lock(graph_count_mutex_);
  auto iter = graph_count_.find(graph_id);
  if (iter != graph_count_.end()) {
    ++iter->second;
  } else {
    graph_count_.insert({graph_id, kInitGraphCount});
  }
  GELOGD("After increaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]);
}
// Erase the usage-count entry of graph_id entirely (called on graph removal).
// Logs a warning when the id was never counted.
void GraphManager::RemoveGraphCount(GraphId graph_id) {
  std::lock_guard<std::mutex> lock(graph_count_mutex_);
  auto it = graph_count_.find(graph_id);
  if (it == graph_count_.end()) {
    // Fix: previous message was copy-pasted from DecreaseGraphCount and
    // wrongly said "decreased" in the remove path.
    GELOGW("Graph of id: %u has not been added, count cannot be removed.", graph_id);
  } else {
    // Log the count via the iterator (single lookup) before erasing it.
    GELOGD("RemoveGraphCount success, graph count of id[%u] is %u.", graph_id, it->second);
    graph_count_.erase(it);
  }
}
void GraphManager::DecreaseGraphCount(GraphId graph_id) {
  // Decrement the usage count of graph_id; warn (and do nothing) when the
  // id was never registered.
  std::lock_guard<std::mutex> lock(graph_count_mutex_);
  auto iter = graph_count_.find(graph_id);
  if (iter == graph_count_.end()) {
    GELOGW("Graph of id: %u has not been added, count cannot be decreased.", graph_id);
    return;
  }
  --iter->second;
  GELOGD("After DecreaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]);
}
Status GraphManager::GetGraphCount(GraphId graph_id, uint32_t &count) {
  // Read the usage count of graph_id into `count`; FAILED when the id is
  // unknown (count left untouched in that case).
  std::lock_guard<std::mutex> lock(graph_count_mutex_);
  const auto iter = graph_count_.find(graph_id);
  if (iter != graph_count_.end()) {
    count = iter->second;
    return SUCCESS;
  }
  GELOGW("Graph [id:%u] has not been added.", graph_id);
  return FAILED;
}
} // namespace ge | } // namespace ge |
@@ -184,6 +184,20 @@ class GraphManager { | |||||
Status SaveCheckPointResult(const Graph &graph, const std::vector<Tensor> &outputs, map<string, Tensor> &var_results); | Status SaveCheckPointResult(const Graph &graph, const std::vector<Tensor> &outputs, map<string, Tensor> &var_results); | ||||
void RemoveGraphCount(GraphId graph_id); | |||||
void IncreaseGraphCount(GraphId graph_id); | |||||
void DecreaseGraphCount(GraphId graph_id); | |||||
Status GetGraphCount(GraphId graph_id, uint32_t &count); | |||||
void SetAddGraphCondition(GraphId graph_id, uint32_t cond); | |||||
uint32_t GetAddGraphCondition(GraphId graph_id); | |||||
void RemoveAddGraphCondition(GraphId graph_id); | |||||
private: | private: | ||||
struct CompilerStages { | struct CompilerStages { | ||||
GraphPrepare preparer; | GraphPrepare preparer; | ||||
@@ -381,6 +395,24 @@ class GraphManager { | |||||
CompilerStages &GetCompilerStages(GraphId graph_id); | CompilerStages &GetCompilerStages(GraphId graph_id); | ||||
void RemoveCompilerStages(GraphId graph_id); | void RemoveCompilerStages(GraphId graph_id); | ||||
static Status CheckIncreBuildAndPreRun(GraphManager *graph_manager, const PreRunArgs &args, GraphNodePtr &graph_node, | |||||
GeRootModelPtr &ge_root_model); | |||||
void ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph_node, const std::vector<uint32_t> &model_ids, | |||||
uint32_t graph_id, uint64_t session_id); | |||||
Status CheckRepeatAdd(uint32_t graph_id, bool &is_added); | |||||
Status NotifyWaittingGraph(uint32_t graph_id); | |||||
Status CreateGraphNode(uint32_t graph_id, const Graph &graph, const std::map<std::string, std::string> &options); | |||||
Status SetStagesOptions(uint32_t graph_id, const GraphManagerOptions &options); | |||||
Status UnloadModel(GeRootModelPtr ge_root_model, uint32_t graph_id); | |||||
void SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id); | |||||
std::atomic_bool thread_run_flag_; | std::atomic_bool thread_run_flag_; | ||||
BlockingQueue<PreRunArgs> prerun_args_q_{}; | BlockingQueue<PreRunArgs> prerun_args_q_{}; | ||||
BlockingQueue<RunArgs> run_args_q_{}; | BlockingQueue<RunArgs> run_args_q_{}; | ||||
@@ -416,6 +448,16 @@ class GraphManager { | |||||
std::mutex member_mutex_; | std::mutex member_mutex_; | ||||
std::mutex unload_model_mutex_; | std::mutex unload_model_mutex_; | ||||
// avoid repeatively add same graph (owns same graph id) | |||||
std::mutex add_graph_mutex_; | |||||
std::mutex add_graph_cond_mutex_; | |||||
std::condition_variable add_graph_cv_; | |||||
std::map<GraphId, uint32_t> graph_id_to_add_graph_cond_; | |||||
// use for multi-thread online-infer scenario | |||||
std::set<GraphId> to_be_deleted_graphs_; | |||||
std::map<GraphId, uint32_t> graph_count_; | |||||
std::mutex graph_count_mutex_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -60,6 +60,15 @@ void GraphNode::Unlock() { | |||||
sem_.Pop(unused); | sem_.Pop(unused); | ||||
} | } | ||||
void GraphNode::IncreaseLoadCount() { | |||||
std::unique_lock<std::mutex> lock(load_count_mu_); | |||||
if (load_record_ == kMaxLoadNum) { | |||||
GELOGW("Reach the maximum of load_count:%u", kMaxLoadNum); | |||||
return; | |||||
} | |||||
++load_count_; | |||||
} | |||||
SubGraphInfo::SubGraphInfo() : subgraph_ptr_(nullptr), ge_model_ptr_(nullptr), malloc_flag_(false) {} | SubGraphInfo::SubGraphInfo() : subgraph_ptr_(nullptr), ge_model_ptr_(nullptr), malloc_flag_(false) {} | ||||
SubGraphInfo::~SubGraphInfo() { | SubGraphInfo::~SubGraphInfo() { | ||||
@@ -55,6 +55,7 @@ using ConstGraphPtr = std::shared_ptr<const ge::Graph>; | |||||
using GraphPtr = std::shared_ptr<ge::Graph>; | using GraphPtr = std::shared_ptr<ge::Graph>; | ||||
const uint64_t INVALID_SESSION_ID = 0xffffffffffffffffULL; | const uint64_t INVALID_SESSION_ID = 0xffffffffffffffffULL; | ||||
const uint32_t kMaxLoadNum = 8; | |||||
struct ModelIdInfo { | struct ModelIdInfo { | ||||
uint32_t model_id{INVALID_MODEL_ID}; | uint32_t model_id{INVALID_MODEL_ID}; | ||||
@@ -162,6 +163,8 @@ class GraphNode { | |||||
bool GetBuildFlag() const { return build_flag_; } | bool GetBuildFlag() const { return build_flag_; } | ||||
void SetBuildFlag(bool buildFlag) { build_flag_ = buildFlag; } | void SetBuildFlag(bool buildFlag) { build_flag_ = buildFlag; } | ||||
bool GetLoadFlag() const { return load_flag_; } | bool GetLoadFlag() const { return load_flag_; } | ||||
// allow repeatively load graph owns same graph id | |||||
void UpdateLoadFlag() { load_flag_ = load_count_ == 0 || load_record_ >= kMaxLoadNum; } | |||||
void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; } | void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; } | ||||
void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; } | void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; } | ||||
GeModelPtr GetGeModel() const { return ge_model_; } | GeModelPtr GetGeModel() const { return ge_model_; } | ||||
@@ -172,6 +175,13 @@ class GraphNode { | |||||
void Lock(); | void Lock(); | ||||
void Unlock(); | void Unlock(); | ||||
void SetSemSize(uint32_t size) { sem_.SetMaxSize(size); } | |||||
uint32_t GetLoadCount() const { return load_count_; } | |||||
void IncreaseLoadCount(); | |||||
void DecreaseLoadCount() { --load_count_; } | |||||
void IncreaseLoadRecord() { ++load_record_; } | |||||
// run graph asynchronous listener | // run graph asynchronous listener | ||||
std::shared_ptr<RunAsyncListener> graph_run_async_listener_; | std::shared_ptr<RunAsyncListener> graph_run_async_listener_; | ||||
@@ -184,11 +194,17 @@ class GraphNode { | |||||
GraphPtr graph_; | GraphPtr graph_; | ||||
ComputeGraphPtr compute_graph_; | ComputeGraphPtr compute_graph_; | ||||
bool build_flag_; | bool build_flag_; | ||||
// load_flag_ is true if more than 1 model were loaded | |||||
bool load_flag_; | bool load_flag_; | ||||
bool async_; | bool async_; | ||||
GeModelPtr ge_model_; | GeModelPtr ge_model_; | ||||
GeRootModelPtr ge_root_model_; | GeRootModelPtr ge_root_model_; | ||||
BlockingQueue<uint8_t> sem_; | BlockingQueue<uint8_t> sem_; | ||||
// consist with graph_count of same graph_id in graph_manager | |||||
uint32_t load_count_ = 0; | |||||
// total times of loading a graph with same graph_id. | |||||
uint32_t load_record_ = 0; | |||||
std::mutex load_count_mu_; | |||||
}; | }; | ||||
using GraphNodePtr = std::shared_ptr<GraphNode>; | using GraphNodePtr = std::shared_ptr<GraphNode>; | ||||
@@ -144,8 +144,12 @@ Status HybridModelAsyncExecutor::RunInternal() { | |||||
GE_MAKE_GUARD(not_used_var, [&] { GE_CHK_RT(rtDeviceReset(device_id)); }); | GE_MAKE_GUARD(not_used_var, [&] { GE_CHK_RT(rtDeviceReset(device_id)); }); | ||||
while (run_flag_) { | while (run_flag_) { | ||||
// Model has not indeedly started running before received data | |||||
SetRunningFlag(false); | |||||
std::shared_ptr<InputDataWrapper> data_wrapper; | std::shared_ptr<InputDataWrapper> data_wrapper; | ||||
Status ret = data_inputer_->Pop(data_wrapper); | Status ret = data_inputer_->Pop(data_wrapper); | ||||
// Model indeedly start running | |||||
SetRunningFlag(true); | |||||
if (data_wrapper == nullptr || ret != SUCCESS) { | if (data_wrapper == nullptr || ret != SUCCESS) { | ||||
GELOGI("data_wrapper is null!, ret = %u", ret); | GELOGI("data_wrapper is null!, ret = %u", ret); | ||||
continue; | continue; | ||||
@@ -185,7 +189,8 @@ Status HybridModelAsyncExecutor::RunInternal() { | |||||
RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[RunInternal] [iteration = %d] End", iterator_count_); | RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[RunInternal] [iteration = %d] End", iterator_count_); | ||||
iterator_count_++; | iterator_count_++; | ||||
GELOGI("run iterator count is %lu", iterator_count_); | |||||
SetRunningFlag(false); | |||||
GELOGI("run iterator count is %lu, model_id:%u", iterator_count_, model_id_); | |||||
} | } | ||||
CsaInteract::GetInstance().WriteInternalErrorCode(); | CsaInteract::GetInstance().WriteInternalErrorCode(); | ||||
@@ -55,6 +55,12 @@ class HybridModelAsyncExecutor { | |||||
Status EnqueueData(const std::shared_ptr<InputDataWrapper> &data); | Status EnqueueData(const std::shared_ptr<InputDataWrapper> &data); | ||||
uint32_t GetDataInputerSize() { return data_inputer_->Size(); } | |||||
bool GetRunningFlag() const { return running_flag_; } | |||||
void SetRunningFlag(bool flag) { running_flag_ = flag; } | |||||
private: | private: | ||||
Status InitInputDesc(); | Status InitInputDesc(); | ||||
@@ -84,6 +90,8 @@ class HybridModelAsyncExecutor { | |||||
uint32_t device_id_ = 0U; | uint32_t device_id_ = 0U; | ||||
uint32_t model_id_ = 0U; | uint32_t model_id_ = 0U; | ||||
std::atomic_bool run_flag_; | std::atomic_bool run_flag_; | ||||
// check whether model is running with data | |||||
bool running_flag_ = false; | |||||
std::unique_ptr<DataInputer> data_inputer_; | std::unique_ptr<DataInputer> data_inputer_; | ||||
std::unique_ptr<HybridModelExecutor> executor_; | std::unique_ptr<HybridModelExecutor> executor_; | ||||
std::unique_ptr<HybridModelPipelineExecutor> pipe_executor_; | std::unique_ptr<HybridModelPipelineExecutor> pipe_executor_; | ||||
@@ -19,6 +19,7 @@ | |||||
#include "hybrid/model/hybrid_model.h" | #include "hybrid/model/hybrid_model.h" | ||||
#include "hybrid/executor/hybrid_model_async_executor.h" | #include "hybrid/executor/hybrid_model_async_executor.h" | ||||
#include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
#include "graph/manager/graph_manager_utils.h" | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
@@ -108,6 +109,17 @@ class HybridDavinciModel::Impl { | |||||
model_.SetModelDescVersion(is_new_model_desc); | model_.SetModelDescVersion(is_new_model_desc); | ||||
} | } | ||||
uint32_t GetDataInputerSize() { return executor_.GetDataInputerSize(); } | |||||
bool GetRunningFlag() const { return executor_.GetRunningFlag(); } | |||||
Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { | |||||
auto listener = dynamic_cast<RunAsyncListener *>(listener_.get()); | |||||
GE_CHECK_NOTNULL(listener); | |||||
listener->SetCallback(callback); | |||||
return SUCCESS; | |||||
} | |||||
private: | private: | ||||
std::shared_ptr<ModelListener> listener_; | std::shared_ptr<ModelListener> listener_; | ||||
HybridModel model_; | HybridModel model_; | ||||
@@ -222,5 +234,16 @@ uint64_t HybridDavinciModel::GetSessionId() { | |||||
GE_CHECK_NOTNULL(impl_); | GE_CHECK_NOTNULL(impl_); | ||||
return impl_->GetSessionId(); | return impl_->GetSessionId(); | ||||
} | } | ||||
uint32_t HybridDavinciModel::GetDataInputerSize() { | |||||
GE_CHECK_NOTNULL(impl_); | |||||
return impl_->GetDataInputerSize(); | |||||
} | |||||
bool HybridDavinciModel::GetRunningFlag() const { return impl_->GetRunningFlag(); } | |||||
Status HybridDavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { | |||||
return impl_->SetRunAsyncListenerCallback(callback); | |||||
} | |||||
} // namespace hybrid | } // namespace hybrid | ||||
} // namespace ge | } // namespace ge |
@@ -74,6 +74,12 @@ class HybridDavinciModel { | |||||
void SetModelDescVersion(bool is_new_model_desc); | void SetModelDescVersion(bool is_new_model_desc); | ||||
uint32_t GetDataInputerSize(); | |||||
bool GetRunningFlag() const; | |||||
Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback); | |||||
private: | private: | ||||
HybridDavinciModel() = default; | HybridDavinciModel() = default; | ||||
class Impl; | class Impl; | ||||
@@ -68,6 +68,10 @@ uint64_t HybridDavinciModel::GetSessionId() { | |||||
return 0; | return 0; | ||||
} | } | ||||
uint32_t HybridDavinciModel::GetDataInputerSize() { | |||||
return 0; | |||||
} | |||||
Status HybridDavinciModel::GetDynamicBatchInfo(std::vector<std::vector<int64_t>> &batch_info, int32_t &dynamic_type) { | Status HybridDavinciModel::GetDynamicBatchInfo(std::vector<std::vector<int64_t>> &batch_info, int32_t &dynamic_type) { | ||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
@@ -87,5 +91,13 @@ Status HybridDavinciModel::GetInputOutputDescInfo(vector<InputOutputDescInfo> &i | |||||
void HybridDavinciModel::SetModelDescVersion(bool is_new_model_desc) { | void HybridDavinciModel::SetModelDescVersion(bool is_new_model_desc) { | ||||
} | } | ||||
bool HybridDavinciModel::GetRunningFlag() const { | |||||
return false; | |||||
} | |||||
Status HybridDavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { | |||||
return UNSUPPORTED; | |||||
} | |||||
} // namespace hybrid | } // namespace hybrid | ||||
} // namespace ge | } // namespace ge |
@@ -85,4 +85,14 @@ ProtoAttrMapHelper GeModel::MutableAttrMap() { return attrs_; } | |||||
ConstProtoAttrMapHelper GeModel::GetAttrMap() const { | ConstProtoAttrMapHelper GeModel::GetAttrMap() const { | ||||
return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg()); | return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg()); | ||||
} | } | ||||
Status GeModel::GetSessionId(uint32_t model_id, uint64_t &session_id) const { | |||||
auto it = model_id_to_session_id_map_.find(model_id); | |||||
if (it != model_id_to_session_id_map_.end()) { | |||||
session_id = it->second; | |||||
return SUCCESS; | |||||
} | |||||
GELOGW("No session id were found with model id [%u].", model_id); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -71,6 +71,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder | |||||
void SetModelId(uint32_t model_id) { model_id_ = model_id; } | void SetModelId(uint32_t model_id) { model_id_ = model_id; } | ||||
uint32_t GetModelId() const { return model_id_; } | uint32_t GetModelId() const { return model_id_; } | ||||
Status GetSessionId(uint32_t model_id, uint64_t &session_id) const; | |||||
void InsertSessionMap(uint32_t model_id, uint64_t session_id) { | |||||
model_id_to_session_id_map_.insert({model_id, session_id}); | |||||
} | |||||
protected: | protected: | ||||
ConstProtoAttrMapHelper GetAttrMap() const override; | ConstProtoAttrMapHelper GetAttrMap() const override; | ||||
@@ -90,6 +95,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder | |||||
std::string platform_version_; | std::string platform_version_; | ||||
uint8_t platform_type_ = {0}; | uint8_t platform_type_ = {0}; | ||||
uint32_t model_id_ = INVALID_MODEL_ID; | uint32_t model_id_ = INVALID_MODEL_ID; | ||||
std::map<uint32_t, uint64_t> model_id_to_session_id_map_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
using GeModelPtr = std::shared_ptr<ge::GeModel>; | using GeModelPtr = std::shared_ptr<ge::GeModel>; | ||||
@@ -32,15 +32,31 @@ class GeRootModel { | |||||
return subgraph_instance_name_to_model_; | return subgraph_instance_name_to_model_; | ||||
}; | }; | ||||
const ComputeGraphPtr &GetRootGraph() const { return root_graph_; }; | |||||
void SetModelId(uint32_t model_id) { model_id_ = model_id; } | |||||
const ComputeGraphPtr &GetRootGraph() const { return root_graph_; } | |||||
void SetModelId(uint32_t model_id) { | |||||
model_id_ = model_id; | |||||
// cached for removement | |||||
model_ids_.emplace_back(model_id); | |||||
} | |||||
uint32_t GetModelId() const { return model_id_; } | uint32_t GetModelId() const { return model_id_; } | ||||
std::vector<uint32_t> GetAllModelId() const { return model_ids_; } | |||||
Status CheckIsUnknownShape(bool &is_dynamic_shape); | Status CheckIsUnknownShape(bool &is_dynamic_shape); | ||||
void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } | void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } | ||||
void SetTrainFlag(bool flag) { train_flag_ = flag; } | |||||
bool GetTrainFlag() const { return train_flag_; } | |||||
private: | private: | ||||
ComputeGraphPtr root_graph_ = nullptr; | ComputeGraphPtr root_graph_ = nullptr; | ||||
std::map<std::string, GeModelPtr> subgraph_instance_name_to_model_; | std::map<std::string, GeModelPtr> subgraph_instance_name_to_model_; | ||||
uint32_t model_id_ = 0; | uint32_t model_id_ = 0; | ||||
// In multithread online secenario, same graph can owns different davinci_model for for concurrency | |||||
std::vector<uint32_t> model_ids_; | |||||
bool train_flag_ = false; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | ||||
@@ -1 +1 @@ | |||||
Subproject commit 1e88df1d6bfe60faae0aa9fa2d87f273b793aeb0 | |||||
Subproject commit fcebf37d7428caf4e0bd6e6c3a4f8143f6eac8b7 |
@@ -593,6 +593,7 @@ set(SINGLE_OP_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_executor.cc" | "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_executor.cc" | ||||
"${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_async_executor.cc" | "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_async_executor.cc" | ||||
"${GE_CODE_DIR}/ge/hybrid/executor/hybrid_execution_context.cc" | "${GE_CODE_DIR}/ge/hybrid/executor/hybrid_execution_context.cc" | ||||
"${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_pipeline_executor.cc" | |||||
"${GE_CODE_DIR}/ge/hybrid/executor/subgraph_context.cc" | "${GE_CODE_DIR}/ge/hybrid/executor/subgraph_context.cc" | ||||
"${GE_CODE_DIR}/ge/hybrid/executor/subgraph_executor.cc" | "${GE_CODE_DIR}/ge/hybrid/executor/subgraph_executor.cc" | ||||
"${GE_CODE_DIR}/ge/hybrid/executor/worker/task_compile_engine.cc" | "${GE_CODE_DIR}/ge/hybrid/executor/worker/task_compile_engine.cc" | ||||
@@ -780,10 +781,12 @@ set(MULTI_PARTS_TEST_FILES | |||||
"graph/build/mem_assigner_unittest.cc" | "graph/build/mem_assigner_unittest.cc" | ||||
"graph/build/task_generator_unittest.cc" | "graph/build/task_generator_unittest.cc" | ||||
"graph/build/buffer_pool_mem_assigner_unittest.cc" | "graph/build/buffer_pool_mem_assigner_unittest.cc" | ||||
"graph/execute/graph_execute_unittest.cc" | |||||
"graph/preprocess/graph_preprocess_unittest.cc" | "graph/preprocess/graph_preprocess_unittest.cc" | ||||
"graph/manager/hcom_util_unittest.cc" | "graph/manager/hcom_util_unittest.cc" | ||||
"graph/manager/graph_caching_allocator_unittest.cc" | "graph/manager/graph_caching_allocator_unittest.cc" | ||||
"graph/partition/dynamic_shape_partition_unittest.cc" | "graph/partition/dynamic_shape_partition_unittest.cc" | ||||
"graph/manager/graph_manager_unittest.cc" | |||||
"session/omg_omg_unittest.cc" | "session/omg_omg_unittest.cc" | ||||
) | ) | ||||
@@ -0,0 +1,129 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <memory> | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/execute/graph_execute.h" | |||||
#include "graph/load/model_manager/model_manager.h" | |||||
#include "graph/load/model_manager/davinci_model.h" | |||||
#include "omm/csa_interact.h" | |||||
#undef private | |||||
#undef public | |||||
#include <pthread.h> | |||||
#include <algorithm> | |||||
#include <future> | |||||
#include <set> | |||||
#include <sstream> | |||||
#include <string> | |||||
#include <thread> | |||||
#include <future> | |||||
using namespace std; | |||||
using namespace testing; | |||||
using namespace ge; | |||||
using namespace domi; | |||||
namespace ge { | |||||
namespace { | |||||
const uint32_t kInvalidModelId = UINT32_MAX; | |||||
} | |||||
class UtestGraphExecuteTest : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
TEST_F(UtestGraphExecuteTest, get_execute_model_id_invalid) { | |||||
GraphExecutor executor; | |||||
ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); | |||||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); | |||||
auto model_id = executor.GetExecuteModelId(ge_root_model); | |||||
EXPECT_EQ(model_id, kInvalidModelId); | |||||
} | |||||
TEST_F(UtestGraphExecuteTest, get_execute_model_id_1) { | |||||
GraphExecutor executor; | |||||
ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); | |||||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); | |||||
auto model_manager = ModelManager::GetInstance(); | |||||
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, nullptr); | |||||
davinci_model1->SetId(1); | |||||
model_manager->InsertModel(1, davinci_model1); | |||||
ge_root_model->SetModelId(1); | |||||
auto model_id = executor.GetExecuteModelId(ge_root_model); | |||||
EXPECT_EQ(model_id, 1); | |||||
} | |||||
TEST_F(UtestGraphExecuteTest, get_execute_model_id_2) { | |||||
GraphExecutor executor; | |||||
ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); | |||||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); | |||||
auto model_manager = ModelManager::GetInstance(); | |||||
// model1 with 2 load | |||||
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, nullptr); | |||||
davinci_model1->SetId(1); | |||||
davinci_model1->data_inputer_ = new DataInputer(); | |||||
auto data = MakeShared<InputDataWrapper>(); | |||||
davinci_model1->data_inputer_->Push(data); | |||||
davinci_model1->data_inputer_->Push(data); | |||||
model_manager->InsertModel(1, davinci_model1); | |||||
// model 2 with 3 load | |||||
shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(1, nullptr); | |||||
davinci_model2->SetId(2); | |||||
davinci_model2->data_inputer_ = new DataInputer(); | |||||
davinci_model2->data_inputer_->Push(data); | |||||
davinci_model2->data_inputer_->Push(data); | |||||
davinci_model2->data_inputer_->Push(data); | |||||
model_manager->InsertModel(2, davinci_model2); | |||||
// model 3 witH 1 load | |||||
shared_ptr<DavinciModel> davinci_model3 = MakeShared<DavinciModel>(1, nullptr); | |||||
davinci_model3->SetId(3); | |||||
davinci_model3->data_inputer_ = new DataInputer(); | |||||
davinci_model3->data_inputer_->Push(data); | |||||
model_manager->InsertModel(3, davinci_model3); | |||||
ge_root_model->SetModelId(1); | |||||
ge_root_model->SetModelId(2); | |||||
ge_root_model->SetModelId(3); | |||||
auto model_id = executor.GetExecuteModelId(ge_root_model); | |||||
// model 3 is picked for having least loads | |||||
EXPECT_EQ(model_id, 3); | |||||
} | |||||
TEST_F(UtestGraphExecuteTest, test_set_callback) { | |||||
GraphExecutor executor; | |||||
ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); | |||||
// is_unknown_shape_graph_ = false | |||||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); | |||||
RunAsyncCallback callback = [](Status, std::vector<ge::OutputTensorInfo> &) {}; | |||||
auto model_manager = ModelManager::GetInstance(); | |||||
auto listener = MakeShared<RunAsyncListener>(); | |||||
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener); | |||||
davinci_model1->SetId(1); | |||||
model_manager->InsertModel(1, davinci_model1); | |||||
auto status = executor.SetCallback(1, ge_root_model, callback); | |||||
EXPECT_EQ(status, SUCCESS); | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,375 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <memory> | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/manager/graph_manager.h" | |||||
#include "graph/load/model_manager/model_manager.h" | |||||
#include "graph/load/model_manager/davinci_model.h" | |||||
#define const | |||||
#include "common/helper/model_cache_helper.h" | |||||
#undef const | |||||
#include "init/gelib.h" | |||||
#undef private | |||||
#undef public | |||||
#include <pthread.h> | |||||
#include <algorithm> | |||||
#include <future> | |||||
#include <set> | |||||
#include <sstream> | |||||
#include <string> | |||||
#include <thread> | |||||
#include <future> | |||||
#include "common/math/math_util.h" | |||||
#include "common/thread_pool.h" | |||||
#include "common/dump/dump_manager.h" | |||||
#include "analyzer/analyzer.h" | |||||
#include "graph/common/ge_call_wrapper.h" | |||||
#include "graph/common/local_context.h" | |||||
#include "graph/common/transop_util.h" | |||||
#include "graph/ge_context.h" | |||||
#include "graph/ge_global_options.h" | |||||
#include "graph/manager/util/rt_context_util.h" | |||||
#include "graph/partition/dynamic_shape_partition.h" | |||||
#include "graph/passes/enter_pass.h" | |||||
#include "graph/partition/stage_partition.h" | |||||
#include "graph/passes/addn_pass.h" | |||||
#include "graph/passes/bitcast_pass.h" | |||||
#include "graph/passes/assign_remove_pass.h" | |||||
#include "graph/passes/inplace_support_check_pass.h" | |||||
#include "graph/passes/atomic_addr_clean_pass.h" | |||||
#include "graph/passes/attach_stream_label_pass.h" | |||||
#include "graph/passes/cast_remove_pass.h" | |||||
#include "graph/passes/common_subexpression_elimination_pass.h" | |||||
#include "graph/passes/compile_nodes_pass.h" | |||||
#include "graph/passes/cond_remove_pass.h" | |||||
#include "graph/passes/constant_folding_pass.h" | |||||
#include "graph/passes/constant_fuse_same_pass.h" | |||||
#include "graph/passes/control_trigger_pass.h" | |||||
#include "graph/passes/ctrl_edge_transfer_pass.h" | |||||
#include "graph/passes/dimension_adjust_pass.h" | |||||
#include "graph/passes/dimension_compute_pass.h" | |||||
#include "graph/passes/flow_ctrl_pass.h" | |||||
#include "graph/passes/fuse_data_nodes_with_common_input_pass.h" | |||||
#include "graph/passes/identity_pass.h" | |||||
#include "graph/passes/input_output_connection_identify_pass.h" | |||||
#include "graph/passes/iterator_op_pass.h" | |||||
#include "graph/passes/link_gen_mask_nodes_pass.h" | |||||
#include "graph/passes/mark_graph_unknown_status_pass.h" | |||||
#include "graph/passes/merge_pass.h" | |||||
#include "graph/passes/merge_input_memcpy_pass.h" | |||||
#include "graph/passes/merge_to_stream_merge_pass.h" | |||||
#include "graph/passes/multi_batch_pass.h" | |||||
#include "graph/passes/next_iteration_pass.h" | |||||
#include "graph/passes/permute_pass.h" | |||||
#include "graph/passes/prune_pass.h" | |||||
#include "graph/passes/ref_identity_delete_op_pass.h" | |||||
#include "graph/passes/remove_same_const_pass.h" | |||||
#include "graph/passes/reshape_recovery_pass.h" | |||||
#include "graph/passes/reshape_remove_pass.h" | |||||
#include "graph/passes/same_transdata_breadth_fusion_pass.h" | |||||
#include "graph/passes/subgraph_pass.h" | |||||
#include "graph/passes/switch_data_edges_bypass.h" | |||||
#include "graph/passes/switch_dead_branch_elimination.h" | |||||
#include "graph/passes/switch_logic_remove_pass.h" | |||||
#include "graph/passes/switch_to_stream_switch_pass.h" | |||||
#include "graph/passes/transop_breadth_fusion_pass.h" | |||||
#include "graph/passes/transop_nearby_allreduce_fusion_pass.h" | |||||
#include "graph/passes/transop_symmetry_elimination_pass.h" | |||||
#include "graph/passes/transop_without_reshape_fusion_pass.h" | |||||
#include "graph/passes/transpose_transdata_pass.h" | |||||
#include "graph/passes/useless_control_out_remove_pass.h" | |||||
#include "graph/passes/variable_op_pass.h" | |||||
#include "graph/passes/variable_ref_delete_op_pass.h" | |||||
#include "graph/passes/variable_ref_useless_control_out_delete_pass.h" | |||||
#include "graph/passes/end_of_sequence_add_control_pass.h" | |||||
#include "graph/passes/subexpression_migration_pass.h" | |||||
#include "graph/passes/subgraph_const_migration_pass.h" | |||||
#include "graph/passes/unused_args_clean_pass.h" | |||||
#include "graph/passes/global_step_insert_pass.h" | |||||
#include "graph/passes/memcpy_addr_async_pass.h" | |||||
#include "graph/passes/hccl_continuous_memcpy_pass.h" | |||||
#include "graph/build/label_allocator.h" | |||||
#include "graph/utils/tensor_adapter.h" | |||||
#include "inc/pass_manager.h" | |||||
#include "ir_build/atc_ir_common.h" | |||||
#include "graph/common/local_context.h" | |||||
#include "graph/common/omg_util.h" | |||||
#include "common/formats/utils/formats_trans_utils.h" | |||||
#include "register/custom_pass_helper.h" | |||||
#include "graph/ops_stub.h" | |||||
using namespace std; | |||||
using namespace testing; | |||||
using namespace ge; | |||||
using namespace domi; | |||||
namespace { | |||||
const uint32_t kNotAdded = 0; | |||||
const uint32_t kStartAdd = 1; | |||||
const uint32_t kDoneAdded = 2; | |||||
} | |||||
class UtestGraphManagerTest : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
void CreateGraph(Graph &graph) { | |||||
TensorDesc desc(ge::Shape({1, 3, 224, 224})); | |||||
uint32_t size = desc.GetShape().GetShapeSize(); | |||||
desc.SetSize(size); | |||||
auto data = op::Data("Data").set_attr_index(0); | |||||
data.update_input_desc_data(desc); | |||||
data.update_output_desc_out(desc); | |||||
auto flatten = op::Flatten("Flatten").set_input_x(data, data.name_out_out()); | |||||
std::vector<Operator> inputs{data}; | |||||
std::vector<Operator> outputs{flatten}; | |||||
std::vector<Operator> targets{flatten}; | |||||
// Graph graph("test_graph"); | |||||
graph.SetInputs(inputs).SetOutputs(outputs).SetTargets(targets); | |||||
} | |||||
TEST_F(UtestGraphManagerTest, set_and_get_add_graph_flag) { | |||||
GraphId graph_id = 1; | |||||
GraphManager graph_manager; | |||||
graph_manager.SetAddGraphCondition(graph_id, 1); | |||||
uint32_t res = graph_manager.GetAddGraphCondition(graph_id); | |||||
EXPECT_EQ(res, 1); | |||||
} | |||||
TEST_F(UtestGraphManagerTest, test_add_graph_1) { | |||||
GraphId graph_id = 1; | |||||
GraphManager graph_manager; | |||||
// create graph | |||||
Graph graph("test_graph"); | |||||
CreateGraph(graph); | |||||
std::map<std::string, std::string> options; | |||||
OmgContext context; | |||||
Status status = graph_manager.AddGraph(graph_id, graph, options, context); | |||||
EXPECT_EQ(status, ge::SUCCESS); | |||||
} | |||||
TEST_F(UtestGraphManagerTest, test_add_graph_2) { | |||||
GraphId graph_id = 1; | |||||
GraphManager graph_manager; | |||||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||||
graph_manager.AddGraphNode(graph_id, graph_node); | |||||
graph_manager.SetAddGraphCondition(graph_id, kDoneAdded); | |||||
Graph graph("test_graph"); | |||||
CreateGraph(graph); | |||||
std::map<std::string, std::string> options; | |||||
OmgContext context; | |||||
Status status = graph_manager.AddGraph(graph_id, graph, options, context); | |||||
EXPECT_EQ(status, ge::SUCCESS); | |||||
} | |||||
TEST_F(UtestGraphManagerTest, test_add_graph_3) { | |||||
GraphId graph_id = 1; | |||||
GraphManager graph_manager; | |||||
Graph graph("test_graph"); | |||||
CreateGraph(graph); | |||||
std::map<std::string, std::string> options; | |||||
OmgContext context; | |||||
std::future<Status> fut1 = std::async(std::launch::async, | |||||
&GraphManager::AddGraph, &graph_manager, graph_id, graph, options, context); | |||||
std::future<Status> fut2 = std::async(std::launch::async, | |||||
&GraphManager::AddGraph, &graph_manager, graph_id, graph, options, context); | |||||
fut1.wait(); | |||||
fut2.wait(); | |||||
Status status1 = fut1.get(); | |||||
Status status2 = fut2.get(); | |||||
EXPECT_EQ(status1, ge::SUCCESS); | |||||
EXPECT_EQ(status2, ge::SUCCESS); | |||||
} | |||||
TEST_F(UtestGraphManagerTest, test_remove_graph_1) { | |||||
GraphId graph_id = 1; | |||||
GraphManager graph_manager; | |||||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||||
Status status = graph_manager.RemoveGraph(graph_id); | |||||
EXPECT_EQ(status, ge::GE_GRAPH_GRAPH_NOT_EXIST); | |||||
graph_manager.AddGraphNode(graph_id, graph_node); | |||||
graph_node->SetRunFlag(true); | |||||
status = graph_manager.RemoveGraph(graph_id); | |||||
EXPECT_EQ(status, ge::SUCCESS); | |||||
} | |||||
TEST_F(UtestGraphManagerTest, test_remove_graph_2) { | |||||
GraphId graph_id = 1; | |||||
GraphManager graph_manager; | |||||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||||
Graph graph("test_graph"); | |||||
CreateGraph(graph); | |||||
auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); | |||||
auto model_manager = ModelManager::GetInstance(); | |||||
auto listener = MakeShared<RunAsyncListener>(); | |||||
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener); | |||||
davinci_model1->SetId(1); | |||||
shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(2, listener); | |||||
davinci_model1->SetId(2); | |||||
model_manager->InsertModel(1, davinci_model1); | |||||
model_manager->InsertModel(2, davinci_model2); | |||||
ge_root_model->SetModelId(1); | |||||
ge_root_model->SetModelId(2); | |||||
graph_node->SetGeRootModel(ge_root_model); | |||||
graph_node->SetLoadFlag(true); | |||||
graph_manager.AddGraphNode(graph_id, graph_node); | |||||
Status status = graph_manager.RemoveGraph(graph_id); | |||||
EXPECT_EQ(status, ge::SUCCESS); | |||||
} | |||||
// PreRunThread consumes one queued PreRunArgs entry for a graph node that was
// never built; the thread is expected to terminate on the resulting failure.
TEST_F(UtestGraphManagerTest, test_pre_run_thread) {
  GraphManager graph_manager;
  graph_manager.thread_run_flag_ = true;

  GraphId graph_id = 1;
  std::vector<ge::InputTensorInfo> tensors;
  uint64_t session = 0;
  ErrorMessage::Context err_context;
  GEThreadLocalContext thread_context;
  RunAsyncCallback run_callback;
  // Queue a single PreRunArgs {graph_id, inputs, session, error ctx, tls ctx, callback}.
  bool pushed = graph_manager.prerun_args_q_.Push({graph_id, tensors, session, err_context, thread_context, run_callback});
  EXPECT_EQ(pushed, true);

  graph_manager.AddGraphNode(graph_id, MakeShared<ge::GraphNode>(graph_id));
  graph_manager.PreRunThread(&graph_manager);  // ends with failed
}
// PreRunThread with two queued entries: graph 1 is already built (count raised
// twice), graph 2 is not; the thread processes both and terminates on failure.
TEST_F(UtestGraphManagerTest, test_pre_run_thread_2) {
  GraphManager graph_manager;
  graph_manager.thread_run_flag_ = true;

  GraphId first_id = 1;
  GraphNodePtr first_node = MakeShared<ge::GraphNode>(first_id);
  graph_manager.AddGraphNode(first_id, first_node);
  graph_manager.IncreaseGraphCount(first_id);
  graph_manager.IncreaseGraphCount(first_id);
  first_node->SetBuildFlag(true);

  std::vector<ge::InputTensorInfo> tensors;
  uint64_t session = 0;
  ErrorMessage::Context err_context;
  GEThreadLocalContext thread_context;
  RunAsyncCallback run_callback;
  // Queue PreRunArgs {graph_id, inputs, session, error ctx, tls ctx, callback} for graph 1.
  bool pushed = graph_manager.prerun_args_q_.Push({first_id, tensors, session, err_context, thread_context, run_callback});
  EXPECT_EQ(pushed, true);

  GraphId second_id = 2;
  GraphNodePtr second_node = MakeShared<ge::GraphNode>(second_id);
  graph_manager.AddGraphNode(second_id, second_node);
  pushed = graph_manager.prerun_args_q_.Push({second_id, tensors, session, err_context, thread_context, run_callback});
  EXPECT_EQ(pushed, true);

  graph_manager.PreRunThread(&graph_manager);  // ends with failed
}
// CheckAndReleaseMemory for a model whose declared memory/weight sizes are huge
// (25 GiB each), on a graph node that already holds two loaded models; the call
// is expected to release the loaded models and return SUCCESS.
TEST_F(UtestGraphManagerTest, test_check_and_release_memory) {
  GraphManager graph_manager;
  GeModelPtr ge_model = make_shared<GeModel>();
  // 25 GiB each — large enough to force the release path in the test stubs.
  int64_t memory_size = 25 * 1024UL * 1024UL * 1024UL;
  int64_t weight_size = 25 * 1024UL * 1024UL * 1024UL;
  uint64_t session_id = 0;
  ge::AttrUtils::SetInt(ge_model, ATTR_MODEL_MEMORY_SIZE, memory_size);
  ge::AttrUtils::SetInt(ge_model, ATTR_MODEL_WEIGHT_SIZE, weight_size);
  ge::AttrUtils::SetInt(ge_model, MODEL_ATTR_SESSION_ID, session_id);
  GraphId graph_id = 1;
  GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
  graph_manager.AddGraphNode(graph_id, graph_node);
  graph_manager.IncreaseGraphCount(graph_id);
  graph_manager.IncreaseGraphCount(graph_id);
  auto model_manager = ModelManager::GetInstance();
  auto listener = MakeShared<RunAsyncListener>();
  shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener);
  davinci_model1->SetId(1);
  shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(2, listener);
  // Fix: id 2 belongs to davinci_model2 (was mistakenly set on davinci_model1 twice,
  // leaving model2's id inconsistent with the key it is inserted under below).
  davinci_model2->SetId(2);
  model_manager->InsertModel(1, davinci_model1);
  model_manager->InsertModel(2, davinci_model2);
  ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
  bool is_dynamic_shape = false;
  (void)AttrUtils::GetBool(compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape);
  GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph);
  // Register both loaded model ids on the root model so both can be released.
  ge_root_model->SetModelId(1);
  ge_root_model->SetModelId(2);
  graph_node->SetGeRootModel(ge_root_model);
  graph_node->SetLoadFlag(true);
  Status status = graph_manager.CheckAndReleaseMemory(ge_model, graph_node);
  EXPECT_EQ(status, ge::SUCCESS);
}
// Build flag already set and no rebuild requested: CheckIncreBuildAndPreRun
// should skip building entirely and report SUCCESS.
TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_1) {
  GraphId graph_id = 1;
  GraphManager graph_manager;
  ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
  GeRootModelPtr root_model = MakeShared<GeRootModel>(compute_graph);
  GraphManager::PreRunArgs run_args;
  GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
  graph_node->SetBuildFlag(true);
  EXPECT_EQ(graph_manager.CheckIncreBuildAndPreRun(&graph_manager, run_args, graph_node, root_model),
            ge::SUCCESS);
}
// Build flag is set but the variable-format controller marks the graph for
// rebuild; the forced rebuild path is expected to fail with PARAM_INVALID.
TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_2) {
  GraphId graph_id = 1;
  GraphManager graph_manager;
  ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
  GeRootModelPtr root_model = MakeShared<GeRootModel>(compute_graph);
  GraphManager::PreRunArgs run_args;
  // No-op async callback; the failure path still needs one to invoke.
  run_args.callback = [](Status, std::vector<ge::OutputTensorInfo> &) {};
  GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
  graph_node->SetBuildFlag(true);
  graph_node->Lock();
  // Mark the graph as needing a rebuild despite the build flag.
  graph_manager.var_acc_ctrl_.graph_ids_need_rebuild_.insert(graph_id);
  EXPECT_EQ(graph_manager.CheckIncreBuildAndPreRun(&graph_manager, run_args, graph_node, root_model),
            ge::PARAM_INVALID);
}
// Build flag unset: CheckIncreBuildAndPreRun must attempt a real build, which
// cannot succeed for this empty test graph, so any non-SUCCESS status is expected.
TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_3) {
  GraphId graph_id = 1;
  GraphManager graph_manager;
  ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
  GeRootModelPtr root_model = MakeShared<GeRootModel>(compute_graph);
  GraphManager::PreRunArgs run_args;
  // No-op async callback; the failure path still needs one to invoke.
  run_args.callback = [](Status, std::vector<ge::OutputTensorInfo> &) {};
  GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
  graph_node->SetBuildFlag(false);
  graph_node->Lock();
  EXPECT_NE(graph_manager.CheckIncreBuildAndPreRun(&graph_manager, run_args, graph_node, root_model),
            ge::SUCCESS);
}