From: @HW_KK Reviewed-by: @wqtshg Signed-off-by:tags/v1.3.0
@@ -20,9 +20,12 @@ | |||
#include <string> | |||
#include "graph/load/model_manager/model_manager.h" | |||
#include "graph/load/model_manager/davinci_model.h" | |||
#include "omm/csa_interact.h" | |||
namespace ge { | |||
using Uint32Pair = pair<uint32_t, uint32_t>; | |||
const uint32_t kInvalidModelId = UINT32_MAX; | |||
GraphExecutor::GraphExecutor() | |||
: init_flag_(false), | |||
train_graph_flag_(false), | |||
@@ -380,7 +383,8 @@ Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_ro | |||
} | |||
Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, | |||
const std::vector<InputTensorInfo> &input_tensor) { | |||
const std::vector<InputTensorInfo> &input_tensor, | |||
const RunAsyncCallback& callback) { | |||
GELOGI("[GraphExecutor] Start to async execute graph, graph_id=%u", graph_id); | |||
if (graph_id != last_graph_id_) { | |||
auto ret = FreeExecuteMemory(); | |||
@@ -390,7 +394,7 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr & | |||
} | |||
last_graph_id_ = graph_id; | |||
GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED); | |||
Status ret = AsyncExecuteModel(ge_root_model->GetModelId(), input_tensor); | |||
Status ret = AsyncExecuteModel(ge_root_model, input_tensor, callback); | |||
if (ret != SUCCESS) { | |||
GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] AsyncExecuteModel Error!"); | |||
return GE_GRAPH_SYNC_MODEL_FAILED; | |||
@@ -400,11 +404,81 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr & | |||
return SUCCESS; | |||
} | |||
Status GraphExecutor::AsyncExecuteModel(uint32_t model_id, const std::vector<InputTensorInfo> &inputs) { | |||
bool CompareByLoad(const Uint32Pair &lhs, const Uint32Pair &rhs) { | |||
return lhs.second < rhs.second; | |||
} | |||
uint32_t GraphExecutor::GetExecuteModelId(const GeRootModelPtr &ge_root_model) { | |||
std::vector<uint32_t> model_ids = ge_root_model->GetAllModelId(); | |||
if (model_ids.empty()) { | |||
return kInvalidModelId; | |||
} | |||
if (model_ids.size() == 1) { | |||
return ge_root_model->GetModelId(); | |||
} | |||
std::vector<Uint32Pair> model_id_to_loads; | |||
auto model_manager = ModelManager::GetInstance(); | |||
GE_CHECK_NOTNULL(model_manager); | |||
for (auto model_id : model_ids) { | |||
auto davinci_model = model_manager->GetModel(model_id); | |||
auto hybrid_model = model_manager->GetHybridModel(model_id); | |||
if (hybrid_model == nullptr) { | |||
GE_CHECK_NOTNULL(davinci_model); | |||
} | |||
uint32_t input_load = hybrid_model != nullptr ? hybrid_model->GetDataInputerSize() : | |||
davinci_model->GetDataInputerSize(); | |||
uint32_t running_load = hybrid_model != nullptr ? static_cast<uint32_t>(hybrid_model->GetRunningFlag()) : | |||
static_cast<uint32_t>(davinci_model->GetRunningFlag()); | |||
uint32_t load = input_load + running_load; | |||
if (load == 0) { | |||
return model_id; | |||
} | |||
model_id_to_loads.emplace_back(model_id, load); | |||
} | |||
sort(model_id_to_loads.begin(), model_id_to_loads.end(), CompareByLoad); | |||
if (model_id_to_loads.empty()) { | |||
return kInvalidModelId; | |||
} | |||
return model_id_to_loads.begin()->first; | |||
} | |||
Status GraphExecutor::SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model, | |||
const RunAsyncCallback &callback) { | |||
auto model_manager = ge::ModelManager::GetInstance(); | |||
GE_CHECK_NOTNULL(model_manager); | |||
if (model_manager->IsNeedHybridLoad(*ge_root_model)) { | |||
auto model = model_manager->GetHybridModel(model_id); | |||
GE_CHECK_NOTNULL(model); | |||
if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) { | |||
GELOGE(FAILED, "SetRunAsyncListenerCallback failed."); | |||
return FAILED; | |||
} | |||
} else { | |||
auto model = model_manager->GetModel(model_id); | |||
GE_CHECK_NOTNULL(model); | |||
if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) { | |||
GELOGE(FAILED, "SetRunAsyncListenerCallback failed."); | |||
return FAILED; | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<InputTensorInfo> &inputs, | |||
const RunAsyncCallback &callback) { | |||
uint32_t model_id = GetExecuteModelId(ge_root_model); | |||
if (model_id == kInvalidModelId) { | |||
GELOGE(INTERNAL_ERROR, "No valid model id."); | |||
return INTERNAL_ERROR; | |||
} | |||
try { | |||
auto model_manager = ge::ModelManager::GetInstance(); | |||
GE_CHECK_NOTNULL(model_manager); | |||
GELOGI("RunAsync begin.model_id %u", model_id); | |||
if (SetCallback(model_id, ge_root_model, callback) != SUCCESS) { | |||
GELOGE(FAILED, "RunAsync: SetCallBack for model fail"); | |||
return FAILED; | |||
} | |||
Status ret = model_manager->DataInputTensor(model_id, inputs); | |||
if (ret != SUCCESS) { | |||
@@ -50,7 +50,7 @@ class GraphExecutor { | |||
std::vector<GeTensor> &output_tensor); | |||
ge::Status ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, | |||
const std::vector<InputTensorInfo> &input_tensor); | |||
const std::vector<InputTensorInfo> &input_tensor, const RunAsyncCallback &callback); | |||
Status SetCondition(std::mutex *mutex, std::condition_variable *cond, std::shared_ptr<GraphModelListener> listener); | |||
@@ -116,6 +116,8 @@ class GraphExecutor { | |||
static Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); | |||
uint32_t GetExecuteModelId(const GeRootModelPtr &ge_root_model); | |||
private: | |||
Status PrepareInputData(const std::vector<GeTensor> &input_tensor, InputData &graph_input_data, | |||
OutputData &graph_output_data, std::vector<InputOutputDescInfo> &output_desc); | |||
@@ -123,7 +125,8 @@ class GraphExecutor { | |||
Status SyncExecuteModel(uint32_t model_id, const std::vector<GeTensor> &input_tensor, | |||
std::vector<GeTensor> &output_tensor); | |||
Status AsyncExecuteModel(uint32_t model_id, const std::vector<InputTensorInfo> &input_tensor); | |||
Status AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<InputTensorInfo> &input_tensor, | |||
const RunAsyncCallback &callback); | |||
void InitModelIdInfo(std::vector<uint32_t> &out_model_id_info, std::vector<SubGraphInfoPtr> &sub_graph_vec, | |||
uint32_t output_size); | |||
@@ -132,6 +135,9 @@ class GraphExecutor { | |||
Status MallocInOutBuffer(const std::vector<uint64_t> &buffer_size, std::vector<void *> &data_addr); | |||
static Status SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model, | |||
const RunAsyncCallback &callback); | |||
bool init_flag_; | |||
bool train_graph_flag_; | |||
@@ -63,7 +63,6 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge | |||
GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph] GE load graph model_ptr is nullptr."); | |||
return GE_GRAPH_PARAM_NULLPTR; | |||
} | |||
model_id = ge_root_model_ptr->GetModelId(); | |||
auto model_manager = ModelManager::GetInstance(); | |||
GE_CHECK_NOTNULL(model_manager); | |||
@@ -134,6 +134,8 @@ class DataInputer { | |||
/// | |||
void Stop() { queue_.Stop(); } | |||
uint32_t Size() { return queue_.Size(); } | |||
private: | |||
/// | |||
/// @ingroup domi_ome | |||
@@ -2737,6 +2737,8 @@ void *DavinciModel::Run(DavinciModel *model) { | |||
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelExecute, ErrorMessage::kModelExecute); | |||
while (model->RunFlag()) { | |||
// Model hasn't truly started runing before received data | |||
model->SetRunningFlag(false); | |||
bool rslt_flg = true; | |||
if (model->GetDataInputer() == nullptr) { | |||
GELOGW("Data inputer is nullptr."); | |||
@@ -2746,6 +2748,8 @@ void *DavinciModel::Run(DavinciModel *model) { | |||
std::shared_ptr<InputDataWrapper> data_wrapper; | |||
Status ret = model->GetDataInputer()->Pop(data_wrapper); | |||
// Model run indeedly start after received data. | |||
model->SetRunningFlag(true); | |||
if (data_wrapper == nullptr || ret != SUCCESS) { | |||
GELOGI("data_wrapper is null!"); | |||
continue; | |||
@@ -2832,7 +2836,9 @@ void *DavinciModel::Run(DavinciModel *model) { | |||
model->iterator_count_++; | |||
model->is_first_execute_ = false; | |||
GELOGI("run iterator count is %lu", model->iterator_count_); | |||
// model run finished | |||
model->SetRunningFlag(false); | |||
GELOGI("run iterator count is %lu, model_id:%u", model->iterator_count_, model->model_id_); | |||
} | |||
CsaInteract::GetInstance().WriteInternalErrorCode(); | |||
@@ -2890,7 +2896,7 @@ Status DavinciModel::ModelRunStart() { | |||
error_context_ = ErrorManager::GetInstance().GetErrorContext(); | |||
CREATE_STD_THREAD(thread_id_, DavinciModel::Run, this); | |||
GELOGI("model tread create success, model id:%u.", model_id_); | |||
GELOGI("model thread create success, model id:%u.", model_id_); | |||
return SUCCESS; | |||
} | |||
@@ -4340,4 +4346,10 @@ Status DavinciModel::InitL1DataDumperArgs() { | |||
return SUCCESS; | |||
} | |||
Status DavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { | |||
auto listener = dynamic_cast<RunAsyncListener *>(listener_.get()); | |||
GE_CHECK_NOTNULL(listener); | |||
listener->SetCallback(callback); | |||
return SUCCESS; | |||
} | |||
} // namespace ge |
@@ -221,6 +221,11 @@ class DavinciModel { | |||
/// | |||
DataInputer *const GetDataInputer() const { return data_inputer_; } | |||
uint32_t GetDataInputerSize() { | |||
GE_CHECK_NOTNULL(data_inputer_); | |||
return data_inputer_->Size(); | |||
} | |||
// get Stream number | |||
uint32_t StreamNum() const { return runtime_param_.stream_num; } | |||
@@ -560,6 +565,10 @@ class DavinciModel { | |||
return data_dumper_.GetOpDescInfo(stream_id, task_id, op_desc_info); | |||
} | |||
bool GetRunningFlag() const { return running_flg_; } | |||
void SetRunningFlag(bool flag) { running_flg_ = flag; } | |||
Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback); | |||
private: | |||
// memory address of weights | |||
uint8_t *weights_mem_base_; | |||
@@ -924,6 +933,8 @@ class DavinciModel { | |||
shared_ptr<ModelListener> listener_; | |||
bool run_flg_; | |||
// check whether model is running with data | |||
bool running_flg_ = false; | |||
mutex mux_run_flg_; | |||
@@ -330,6 +330,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||
GE_CHK_BOOL_RET_STATUS(listener.get() != nullptr, PARAM_INVALID, "Param incorrect, listener is null"); | |||
if (model_id == INVALID_MODEL_ID) { | |||
GenModelId(&model_id); | |||
GELOGD("Generate new model_id:%u", model_id); | |||
} | |||
auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); | |||
string om_name; | |||
@@ -363,7 +364,18 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||
GE_IF_BOOL_EXEC(SUCCESS != (ret = davinci_model->Assign(ge_model)), GELOGW("assign model to modeldef failed."); | |||
break;); | |||
GE_TIMESTAMP_END(Assign, "GraphLoader::ModelAssign"); | |||
/// In multi-threaded inference, using the same session_id among multiple threads may cause some threads to fail. | |||
/// These session_ids come from the same model, so the values of session_id are the same. | |||
/// Update session_id for infer in load model to avoid the same session_id. | |||
if (!ge_root_model->GetTrainFlag()) { | |||
uint64_t new_session_id; | |||
ret = GenSessionId(new_session_id); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Generate session_id for infer failed."); | |||
ret = davinci_model->UpdateSessionId(new_session_id); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Update session_id for infer failed."); | |||
ge_model->InsertSessionMap(model_id, new_session_id); | |||
GELOGD("Update new session id: %lu.", new_session_id); | |||
} | |||
GE_TIMESTAMP_START(Init); | |||
GE_IF_BOOL_EXEC(SUCCESS != (ret = davinci_model->Init()), GELOGW("DavinciInit failed."); break;); | |||
GE_TIMESTAMP_END(Init, "GraphLoader::ModelInit"); | |||
@@ -376,16 +388,16 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||
return ret; | |||
} | |||
void ModelManager::InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model) { | |||
GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", id); | |||
void ModelManager::InsertModel(uint32_t model_id, std::shared_ptr<DavinciModel> &davinci_model) { | |||
GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", model_id); | |||
std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||
model_map_[id] = davinci_model; | |||
model_map_[model_id] = davinci_model; | |||
} | |||
void ModelManager::InsertModel(uint32_t id, shared_ptr<hybrid::HybridDavinciModel> &hybrid_model) { | |||
GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", id); | |||
void ModelManager::InsertModel(uint32_t model_id, shared_ptr<hybrid::HybridDavinciModel> &hybrid_model) { | |||
GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", model_id); | |||
std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||
hybrid_model_map_[id] = hybrid_model; | |||
hybrid_model_map_[model_id] = hybrid_model; | |||
} | |||
Status ModelManager::DeleteModel(uint32_t id) { | |||
@@ -330,8 +330,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||
/// @ingroup domi_ome | |||
/// @brief insert new model into model manager set | |||
/// | |||
void InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model); | |||
void InsertModel(uint32_t id, std::shared_ptr<hybrid::HybridDavinciModel> &hybrid_model); | |||
void InsertModel(uint32_t model_id, std::shared_ptr<DavinciModel> &davinci_model); | |||
void InsertModel(uint32_t model_id, std::shared_ptr<hybrid::HybridDavinciModel> &hybrid_model); | |||
/// | |||
/// @ingroup domi_ome | |||
@@ -121,6 +121,10 @@ const char *const kAIcoreEngine = "AIcoreEngine"; | |||
const int32_t kDynamicDimsTypeIsGetNext = 0; | |||
const int32_t kDynamicDimsTypeIsData = 1; | |||
const char *const kGetNextName = "IteratorV2"; | |||
const uint32_t kInitGraphCount = 1; | |||
const uint32_t kNotAdded = 0; | |||
const uint32_t kStartAdd = 1; | |||
const uint32_t kDoneAdded = 2; | |||
bool IsTailingOptimization() { | |||
string is_tailing_optimization_option; | |||
@@ -202,6 +206,8 @@ Status GraphManager::Initialize(const std::map<string, string> &options) { | |||
graph_map_.clear(); | |||
cache_helper_map_.clear(); | |||
graph_id_to_add_graph_cond_.clear(); | |||
graph_count_.clear(); | |||
init_flag_ = true; | |||
thread_run_flag_ = true; | |||
@@ -211,6 +217,20 @@ Status GraphManager::Initialize(const std::map<string, string> &options) { | |||
return SUCCESS; | |||
} | |||
Status GraphManager::UnloadModel(GeRootModelPtr ge_root_model, uint32_t graph_id) { | |||
Status ret = SUCCESS; | |||
for (size_t i = 0; i < ge_root_model->GetAllModelId().size(); ++i) { | |||
uint32_t model_id = ge_root_model->GetAllModelId()[i]; | |||
GELOGI("Unload model %u.", model_id); | |||
ret = GraphLoader::UnloadModel(model_id); | |||
if (ret != SUCCESS) { | |||
GELOGW("[GraphManager] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||
return ret; | |||
} | |||
} | |||
return ret; | |||
} | |||
Status GraphManager::Finalize() { | |||
if (!init_flag_) { | |||
GELOGW("GraphManager has not been initialized."); | |||
@@ -241,7 +261,6 @@ Status GraphManager::Finalize() { | |||
unload_model_ret = GE_GRAPH_GRAPH_IS_RUNNING; | |||
continue; | |||
} | |||
// unload model | |||
auto ge_root_model = graph_node->GetGeRootModel(); | |||
if (ge_root_model != nullptr && ge_root_model->GetModelId() != INVALID_MODEL_ID && graph_node->GetLoadFlag()) { | |||
@@ -251,15 +270,14 @@ Status GraphManager::Finalize() { | |||
unload_model_ret = FAILED; | |||
continue; | |||
} | |||
ret = GraphLoader::UnloadModel(ge_root_model->GetModelId()); | |||
ret = UnloadModel(ge_root_model, iter->first); | |||
if (ret != SUCCESS) { | |||
GELOGW("[GraphManager] unload model failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), iter->first); | |||
GELOGW("[GraphManager] unload model failed, graph_id=%u.", iter->first); | |||
unload_model_ret = ret; | |||
} | |||
rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGW("[GraphManager] rtDeviceReset failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), | |||
iter->first); | |||
GELOGW("[GraphManager] rtDeviceReset failed, graphId=%u.", iter->first); | |||
unload_model_ret = FAILED; | |||
continue; | |||
} | |||
@@ -274,6 +292,7 @@ Status GraphManager::Finalize() { | |||
} | |||
graph_map_.clear(); | |||
cache_helper_map_.clear(); | |||
graph_count_.clear(); | |||
// graph context | |||
if (graph_context_ != nullptr) { | |||
@@ -326,35 +345,59 @@ Status GraphManager::InitDynamicParams(ComputeGraphPtr &compute_graph) { | |||
return SUCCESS; | |||
} | |||
Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||
const std::map<std::string, std::string> &options, | |||
const OmgContext &omg_context) { | |||
if (HasGraphNode(graph_id)) { | |||
REPORT_INNER_ERROR("E19999", "graph_id:%u is exist, check invalid", graph_id); | |||
GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, "[GraphManager] graph exists, graph_id = %u.", graph_id); | |||
return GE_GRAPH_GRAPH_ALREADY_EXIST; | |||
void GraphManager::SetAddGraphCondition(GraphId graph_id, uint32_t cond) { | |||
std::lock_guard<std::mutex> lock(add_graph_cond_mutex_); | |||
graph_id_to_add_graph_cond_[graph_id] = cond; | |||
GELOGD("Graph [id:%u] has been added.", graph_id); | |||
} | |||
uint32_t GraphManager::GetAddGraphCondition(GraphId graph_id) { | |||
std::lock_guard<std::mutex> lock(add_graph_cond_mutex_); | |||
auto it = graph_id_to_add_graph_cond_.find(graph_id); | |||
if (it != graph_id_to_add_graph_cond_.end()) { | |||
return it->second; | |||
} else { | |||
GELOGD("Graph [id:%u] has not been added.", graph_id); | |||
return kNotAdded; | |||
} | |||
} | |||
auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||
if (compute_graph != nullptr) { | |||
compute_graph->SetGraphID(graph_id); | |||
bool graph_has_been_added = false; | |||
if (AttrUtils::GetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, graph_has_been_added) | |||
&& graph_has_been_added) { | |||
REPORT_INNER_ERROR("E19999", "Get Attr:%s from graph:%u fail", | |||
ATTR_NAME_GRAPH_HAS_BEEN_ADDED.c_str(), graph_id); | |||
GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, | |||
"[GraphManager] same graph object can not be added again, graph_id = %u.", graph_id); | |||
return GE_GRAPH_GRAPH_ALREADY_EXIST; | |||
} | |||
(void)AttrUtils::SetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, true); | |||
compute_graph_ = compute_graph; | |||
void GraphManager::RemoveAddGraphCondition(GraphId graph_id) { | |||
std::lock_guard<std::mutex> lock(add_graph_cond_mutex_); | |||
auto it = graph_id_to_add_graph_cond_.find(graph_id); | |||
if (it != graph_id_to_add_graph_cond_.end()) { | |||
graph_id_to_add_graph_cond_.erase(it); | |||
GELOGD("Successfully removed add_graph_cond of graph [id:%u].", graph_id); | |||
} else { | |||
REPORT_INNER_ERROR("E19999", "compute_graph from graph:%u is nullptr, check invalid", | |||
graph_id); | |||
GELOGE(FAILED, "compute graph is null"); | |||
return FAILED; | |||
GELOGD("Graph [id:%u] has not been added. no need to remove.", graph_id); | |||
} | |||
} | |||
Status GraphManager::CheckRepeatAdd(uint32_t graph_id, bool &is_added) { | |||
uint32_t count = 0; | |||
if (GetGraphCount(graph_id, count) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed, graph might have not been added.", graph_id); | |||
return INTERNAL_ERROR; | |||
} | |||
// previous thread owns same graph_id has been in the middle of the AddGraph procession | |||
if (count > 1 && GetAddGraphCondition(graph_id) == kStartAdd) { | |||
std::unique_lock<std::mutex> lock(add_graph_mutex_); | |||
GELOGD("Waitting for build end of previous thread."); | |||
while (GetAddGraphCondition(graph_id) != kDoneAdded) { | |||
add_graph_cv_.wait(lock); | |||
} | |||
GraphNodePtr graph_node; | |||
Status ret = GetGraphNode(graph_id, graph_node); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "[AddGraph] GetGraphNode failed, graph_id = %u.", graph_id); | |||
return ret; | |||
} | |||
is_added = true; | |||
} | |||
return SUCCESS; | |||
} | |||
void GraphManager::SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id) { | |||
std::string session_graph_id; | |||
if (!AttrUtils::GetStr(*compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id) || session_graph_id.empty()) { | |||
session_graph_id = "-1_" + to_string(graph_id); | |||
@@ -366,7 +409,24 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||
} | |||
GELOGD("Get graph session_graph_id attr failed, set session id to default value: [0]"); | |||
} | |||
} | |||
Status GraphManager::NotifyWaittingGraph(uint32_t graph_id) { | |||
uint32_t count = 0; | |||
if (GetGraphCount(graph_id, count) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed, graph might have not been added.", graph_id); | |||
return INTERNAL_ERROR; | |||
} | |||
GELOGD("Add graph finished, graph_id:%u", graph_id); | |||
if (count > 1) { | |||
GELOGD("Finish addgraph, graph_id:%u, graph_count:%u, start to notify.", graph_id, count); | |||
add_graph_cv_.notify_all(); | |||
} | |||
return SUCCESS; | |||
} | |||
Status GraphManager::CreateGraphNode(uint32_t graph_id, const Graph &graph, | |||
const std::map<std::string, std::string> &options) { | |||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
GE_IF_BOOL_EXEC(graph_node == nullptr, | |||
REPORT_CALL_ERROR("E19999", "New GraphNode fail, graph_id:%u", | |||
@@ -385,7 +445,62 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||
ParseOption(options, TUNING_PATH, options_.tuning_path); | |||
graph_node->SetGraph(graph_ptr); | |||
graph_node->SetOptions(options); | |||
graph_node->IncreaseLoadCount(); | |||
AddGraphNode(graph_id, graph_node); | |||
return SUCCESS; | |||
} | |||
Status GraphManager::SetStagesOptions(uint32_t graph_id, const GraphManagerOptions &options) { | |||
CompilerStages &stages = GetCompilerStages(graph_id); | |||
stages.preparer.SetOptions(options_); | |||
Status status = stages.optimizer.SetOptions(options_); | |||
if (status != SUCCESS) { | |||
GELOGE(status, "Graph optimizer set options failed."); | |||
return status; | |||
} | |||
stages.builder.SetOptions(options_); | |||
return SUCCESS; | |||
} | |||
Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||
const std::map<std::string, std::string> &options, | |||
const OmgContext &omg_context) { | |||
IncreaseGraphCount(graph_id); | |||
// validation for adding graphs of same graph_id in multi-thread secenario | |||
// 1.previous thread owns same graph_id has finished the AddGraph procession | |||
if (GetAddGraphCondition(graph_id) == kDoneAdded) { | |||
GraphNodePtr graph_node; | |||
if (GetGraphNode(graph_id, graph_node) != SUCCESS) { | |||
GELOGE(GE_GRAPH_GRAPH_NOT_EXIST, "Graph not exist while done adding previously, graph_id = %u.", graph_id); | |||
return GE_GRAPH_GRAPH_NOT_EXIST; | |||
} | |||
graph_node->IncreaseLoadCount(); | |||
return SUCCESS; | |||
} | |||
// In multi-thread scenario, former thread owns same graph_id has been | |||
// in the middle of the AddGraph procession while following threads have to wait until | |||
// done adding graph of the former graph, avoiding repeatively adding same graph. | |||
bool is_added = false; | |||
if (CheckRepeatAdd(graph_id, is_added) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "CheckRepeatAdd for graph[id:%u] failed.", graph_id); | |||
return INTERNAL_ERROR; | |||
} | |||
// The former graph (from different thread) owns same graph id has been successfully added. | |||
if (is_added) { | |||
return SUCCESS; | |||
} | |||
// Do add graph | |||
SetAddGraphCondition(graph_id, kStartAdd); | |||
auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||
GE_CHECK_NOTNULL(compute_graph); | |||
compute_graph->SetGraphID(graph_id); | |||
SetSessionGraphId(compute_graph, graph_id); | |||
if (CreateGraphNode(graph_id, graph, options) != SUCCESS) { | |||
GELOGE(FAILED, "Failed to create graph_node."); | |||
return FAILED; | |||
} | |||
AddLocalOmgContext(graph_id, omg_context); | |||
if (!options_.output_datatype.empty()) { | |||
@@ -396,16 +511,18 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
CompilerStages &stages = GetCompilerStages(graph_id); | |||
stages.preparer.SetOptions(options_); | |||
Status status = stages.optimizer.SetOptions(options_); | |||
if (status != SUCCESS) { | |||
GELOGE(status, "Graph optimizer set options failed."); | |||
return status; | |||
if (SetStagesOptions(graph_id, options_) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Set stage options failed."); | |||
return INTERNAL_ERROR; | |||
} | |||
stages.builder.SetOptions(options_); | |||
var_acc_ctrl_.AddGraph(graph_id, compute_graph); | |||
SetAddGraphCondition(graph_id, kDoneAdded); | |||
// There are threads waitting for adding same graph | |||
if (NotifyWaittingGraph(graph_id) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "NotifyWaittingGraph failed."); | |||
return INTERNAL_ERROR; | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -962,6 +1079,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: | |||
if (!graph_node->IsAsync()) { | |||
ret = LoadGraph(ge_root_model, graph_node); | |||
} else { | |||
GE_CHECK_NOTNULL(ge_root_model); | |||
ret = LoadGraphAsync(ge_root_model, graph_node); | |||
} | |||
if (ret != SUCCESS) { | |||
@@ -976,6 +1094,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: | |||
if (!graph_node->IsAsync()) { | |||
ret = LoadGraph(ge_root_model_ptr, graph_node); | |||
} else { | |||
GE_CHECK_NOTNULL(ge_root_model); | |||
ret = LoadGraphAsync(ge_root_model_ptr, graph_node); | |||
} | |||
if (ret != SUCCESS) { | |||
@@ -988,6 +1107,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: | |||
Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { | |||
GELOGI("[LoadGraph] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); | |||
if (options_.run_graph_flag && ge_root_model != nullptr) { | |||
ge_root_model->SetTrainFlag(GetTrainFlag()); | |||
// synchronization run graph with model | |||
std::shared_ptr<GraphModelListener> model_listener = GetModelListener(); | |||
ModelIdInfo model_id_info; | |||
@@ -1413,62 +1533,29 @@ bool GraphManager::CheckModelLoad(const GeRootModelPtr &ge_root_model, bool load | |||
} | |||
Status GraphManager::RemoveGraph(const GraphId &graph_id) { | |||
auto it = to_be_deleted_graphs_.find(graph_id); | |||
if (it != to_be_deleted_graphs_.end()) { | |||
to_be_deleted_graphs_.erase(it); | |||
} | |||
GraphNodePtr graph_node = nullptr; | |||
Status ret = GetGraphNode(graph_id, graph_node); | |||
if (ret != SUCCESS) { | |||
REPORT_INNER_ERROR("E19999", "Graph:%u not exist in graph_map, check invalid", | |||
graph_id); | |||
if (ret != SUCCESS || graph_node == nullptr) { | |||
REPORT_INNER_ERROR("E19999", "Graph:%u not exist in graph_map, check invalid when GraphManager %s", | |||
graph_id, __FUNCTION__); | |||
GELOGE(GE_GRAPH_GRAPH_NOT_EXIST, "[GraphManager] Id %u does not exists.", graph_id); | |||
return GE_GRAPH_GRAPH_NOT_EXIST; | |||
} | |||
if ((graph_node == nullptr) || (graph_node->GetRunFlag())) { | |||
REPORT_INNER_ERROR("E19999", "Graph:%u is running, can't be remove, check invalid", | |||
graph_id); | |||
GELOGE(GE_GRAPH_GRAPH_IS_RUNNING, "[GraphManager] Id %u is running, can't be deleted.", graph_id); | |||
return GE_GRAPH_GRAPH_IS_RUNNING; | |||
if (graph_node->GetRunFlag()) { | |||
// only put graph into to-be-deleted list when exceptional scenario | |||
to_be_deleted_graphs_.insert(graph_id); | |||
GELOGI("[GraphManager] Trying to remove running graph[Id:%u], added into to_be_deleted_graphs_.", graph_id); | |||
return SUCCESS; | |||
} | |||
std::lock_guard<std::mutex> lock(unload_model_mutex_); | |||
Status middle_ret; | |||
rtError_t rt_ret; | |||
const std::vector<SubGraphInfoPtr> &all_sub_graph = graph_node->GetAllSubGraph(); | |||
for (size_t i = 0; i < all_sub_graph.size(); ++i) { | |||
// must free buffer firstly | |||
middle_ret = all_sub_graph[i]->FreeInOutBuffer(); | |||
if (middle_ret != SUCCESS) { | |||
GELOGE(middle_ret, "[GraphManager] RemoveGraph free mem failed, graph_id=%u.", graph_id); | |||
ret = middle_ret; | |||
} | |||
if (all_sub_graph[i]->GeModelIsValid() && all_sub_graph[i]->GetModelIdInfo().model_id != INVALID_MODEL_ID) { | |||
// unload model | |||
GELOGI("UnloadModel via new ome."); | |||
rt_ret = rtSetDevice(GetContext().DeviceId()); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, graph_id:%u", | |||
GetContext().DeviceId(), graph_id); | |||
GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.", | |||
all_sub_graph[i]->GetModelIdInfo().model_id, graph_id); | |||
ret = FAILED; | |||
continue; | |||
} | |||
middle_ret = GraphLoader::UnloadModel(all_sub_graph[i]->GetModelIdInfo().model_id); | |||
if (middle_ret != SUCCESS) { | |||
GELOGE(middle_ret, "[GraphManager:] unload model failed, modelId=%u, graph_id=%u.", | |||
all_sub_graph[i]->GetModelIdInfo().model_id, graph_id); | |||
ret = middle_ret; | |||
} | |||
rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset fail, device_id:%u, graph_id:%u", | |||
GetContext().DeviceId(), graph_id); | |||
GELOGE(RT_FAILED, "[GraphManager:] unload model failed, modelId=%u, graphId=%u.", | |||
all_sub_graph[i]->GetModelIdInfo().model_id, graph_id); | |||
ret = FAILED; | |||
} | |||
} | |||
} | |||
var_acc_ctrl_.RemoveGraph(graph_id); | |||
RemoveGraphNode(graph_id); | |||
@@ -1476,7 +1563,6 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { | |||
auto ge_root_model = graph_node->GetGeRootModel(); | |||
if (CheckModelLoad(ge_root_model, graph_node->GetLoadFlag())) { | |||
GELOGI("Unload model %u.", ge_root_model->GetModelId()); | |||
rt_ret = rtSetDevice(GetContext().DeviceId()); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, graph_id:%u", | |||
@@ -1485,23 +1571,27 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { | |||
graph_id); | |||
return FAILED; | |||
} | |||
middle_ret = GraphLoader::UnloadModel(ge_root_model->GetModelId()); | |||
// same graph may be added for several times, different models were created separately, | |||
// unload them respectively. | |||
middle_ret = UnloadModel(ge_root_model, graph_id); | |||
if (middle_ret != SUCCESS) { | |||
GELOGE(middle_ret, "[GraphManager:] unload model failed, modelId=%u, graph_id=%u.", ge_root_model->GetModelId(), | |||
graph_id); | |||
REPORT_INNER_ERROR("E19999", "UnloadModel for graph:%u failed, check unload detail in GraphLoader %s", | |||
graph_id, __FUNCTION__); | |||
GELOGE(middle_ret, "[GraphManager:] unload model failed, graph_id=%u.", graph_id); | |||
ret = middle_ret; | |||
} | |||
rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, graph_id:%u", | |||
GetContext().DeviceId(), graph_id); | |||
GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), | |||
graph_id); | |||
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, graph_id:%u, when GraphManager %s", | |||
GetContext().DeviceId(), graph_id, __FUNCTION__); | |||
GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, graphId=%u.", graph_id); | |||
ret = FAILED; | |||
} | |||
} | |||
RemoveCompilerStages(graph_id); | |||
RemoveGraphCount(graph_id); | |||
RemoveAddGraphCondition(graph_id); | |||
GE_CHK_STATUS_RET(ret, "[GraphManager:] Remove graph failed, graph_id=%u.", graph_id); | |||
GELOGI("[GraphManager] remove graph success, graph_id=%u.", graph_id); | |||
@@ -2588,6 +2678,7 @@ void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_gr | |||
Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { | |||
GELOGI("[LoadGraphAsync] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); | |||
if (options_.run_graph_flag && ge_root_model != nullptr) { | |||
ge_root_model->SetTrainFlag(GetTrainFlag()); | |||
// synchronization run graph with model | |||
ModelIdInfo model_id_info; | |||
bool is_unknown_shape = false; | |||
@@ -2604,9 +2695,9 @@ Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const G | |||
} | |||
} | |||
GE_TIMESTAMP_START(LoadGraph); | |||
GE_CHECK_NOTNULL(graph_node->graph_run_async_listener_); | |||
Status ret = | |||
GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, graph_node->graph_run_async_listener_); | |||
auto listener = MakeShared<RunAsyncListener>(); | |||
GE_CHECK_NOTNULL(listener); | |||
Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, listener); | |||
GE_TIMESTAMP_EVENT_END(LoadGraph, "GraphManager::LoadGraphAsync"); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "[LoadGraphAsync] LoadGraphAsync Failed"); | |||
@@ -2620,6 +2711,52 @@ Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const G | |||
return SUCCESS; | |||
} | |||
void GraphManager::ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph_node, | |||
const std::vector<uint32_t> &model_ids, uint32_t graph_id, uint64_t session_id) { | |||
rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, when GraphManager %s", | |||
GetContext().DeviceId(), __FUNCTION__); | |||
GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, graphId=%u.", graph_id); | |||
return; | |||
} | |||
for (auto model_id : model_ids) { | |||
uint64_t max_memory_size = 0; | |||
Status result = GraphLoader::GetMaxUsedMemory(model_id, max_memory_size); | |||
if (result != SUCCESS) { | |||
continue; | |||
} | |||
GELOGI("CheckAndReleaseMemory try to UnloadGraph[%u], model[%u] which MaxUsedMemory[%lu].", graph_id, model_id, | |||
max_memory_size); | |||
if (model_ids.size() > 1) { | |||
result = ge_model->GetSessionId(model_id, session_id); | |||
if (result != SUCCESS) { | |||
GELOGW("[GraphManager:] get session failed when dynamic memory, modelId=%u, graphId=%u.", model_id, | |||
graph_id); | |||
continue; | |||
} | |||
} | |||
result = GraphLoader::DestroyAicpuKernel(session_id, model_id, 0); | |||
if (result != SUCCESS) { | |||
GELOGW("[GraphManager:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id, | |||
graph_id); | |||
} | |||
result = GraphLoader::UnloadModel(model_id); | |||
if (result != SUCCESS) { | |||
GELOGW("[GraphManager:] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||
} | |||
GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success.", graph_id, model_id); | |||
} | |||
graph_node->SetLoadFlag(false); | |||
rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, when GraphManager %s", | |||
GetContext().DeviceId(), __FUNCTION__); | |||
GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, graphId=%u.", graph_id); | |||
return; | |||
} | |||
} | |||
Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node) { | |||
GELOGI("CheckAndReleaseMemory graph_id[%u]", graph_node->GetGraphId()); | |||
int64_t value = 0; | |||
@@ -2665,6 +2802,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra | |||
continue; | |||
} | |||
auto model_id = model->GetModelId(); | |||
auto model_ids = model->GetAllModelId(); | |||
// unload model not release | |||
bool is_unknown_shape = false; | |||
GE_CHK_STATUS_RET(model->CheckIsUnknownShape(is_unknown_shape)); | |||
@@ -2677,38 +2815,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra | |||
GELOGI("CheckAndReleaseMemory graph[%u] has not been loaded.", graph_id); | |||
continue; | |||
} | |||
uint64_t max_memory_size = 0; | |||
result = GraphLoader::GetMaxUsedMemory(model_id, max_memory_size); | |||
if (result != SUCCESS) { | |||
continue; | |||
} | |||
GELOGI("CheckAndReleaseMemory try to UnloadGraph[%u], model[%u] which MaxUsedMemory[%lu].", graph_id, model_id, | |||
max_memory_size); | |||
rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u", | |||
GetContext().DeviceId()); | |||
GELOGE(RT_FAILED, "[GraphManager:] rtSetDevice failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||
continue; | |||
} | |||
result = GraphLoader::DestroyAicpuKernel(session_id, model_id, 0); | |||
if (result != SUCCESS) { | |||
GELOGW("[GraphManager:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id, | |||
graph_id); | |||
} | |||
result = GraphLoader::UnloadModel(model_id); | |||
if (result != SUCCESS) { | |||
GELOGW("[GraphManager:] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||
} | |||
rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u", | |||
GetContext().DeviceId()); | |||
GELOGE(RT_FAILED, "[GraphManager:] rtDeviceReset failed, modelId=%u, graphId=%u.", model_id, graph_id); | |||
continue; | |||
} | |||
it.second->SetLoadFlag(false); | |||
GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success and set LoadFlag to false.", graph_id, model_id); | |||
ReleaseMemory(ge_model, it.second, model_ids, graph_id, session_id); | |||
} | |||
return SUCCESS; | |||
@@ -2849,6 +2956,38 @@ void GraphManager::ConstructGeInput(const vector<InputTensorInfo> &inputs, vecto | |||
} | |||
} | |||
Status GraphManager::CheckIncreBuildAndPreRun(GraphManager *graph_manager, const PreRunArgs &args, | |||
GraphNodePtr &graph_node, GeRootModelPtr &ge_root_model) { | |||
if (!graph_manager->IsGraphNeedBuild(graph_node)) { | |||
ge_root_model = graph_node->GetGeRootModel(); | |||
return SUCCESS; | |||
} | |||
if (graph_node->GetBuildFlag()) { | |||
ReturnError(graph_manager, args.callback, PARAM_INVALID, | |||
"The graph " + std::to_string(graph_node->GetGraphId()) + | |||
" need to re-build, you should remove it" | |||
" from GE first, then AddGraph again and rebuild it."); | |||
graph_node->Unlock(); | |||
return PARAM_INVALID; | |||
} | |||
// check need incre build. | |||
GeModelPtr ge_model = nullptr; | |||
if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { | |||
std::vector<GeTensor> ge_inputs; | |||
ConstructGeInput(args.input_tensor, ge_inputs); | |||
Status ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); | |||
// release rts generate context | |||
RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); | |||
if (ret != SUCCESS) { | |||
ReturnError(graph_manager, args.callback, ret, "PreRun Failed."); | |||
return ret; | |||
} | |||
} | |||
graph_node->SetBuildFlag(true); | |||
graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); | |||
return SUCCESS; | |||
} | |||
void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||
if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { | |||
GELOGW("Set thread name failed."); | |||
@@ -2861,7 +3000,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||
continue; | |||
} | |||
GELOGI("A new loop start."); | |||
GELOGI("[PreRunThread] A new loop start, graph_id:%u.", args.graph_id); | |||
ErrorManager::GetInstance().SetErrorContext(args.error_context); | |||
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther); | |||
@@ -2877,7 +3016,24 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||
"[RunGraph] graph not exist, graph_id=" + std::to_string(args.graph_id)); | |||
return; | |||
} | |||
// more than one graph owns same graph_id | |||
uint32_t count = 0; | |||
if (graph_manager->GetGraphCount(args.graph_id, count) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Get graph [id:%u] count failed.", args.graph_id); | |||
return; | |||
} | |||
// Avoid repeatively prerun for graphs owns same graph_id in online inference concurrency | |||
if (count > 1 && graph_node->GetBuildFlag()) { | |||
graph_node->Lock(); | |||
GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id); | |||
// In online inference concurrency senario, graph_node is allowed to be locked for 'count' times | |||
graph_node->SetSemSize(count); | |||
graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, | |||
args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback })); | |||
GELOGI("[PreRunThread] Loop end. Start to run with cached build model."); | |||
continue; | |||
} | |||
// Cannot be put ahead of the repeatively prerun judgement | |||
graph_node->Lock(); | |||
if (graph_node->GetRunFlag()) { | |||
@@ -2909,46 +3065,24 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||
// it will not execute graph preprocess, optimize, parition, build if the graph has built successful. | |||
GELOGI("Start for run graph async."); | |||
GeRootModelPtr ge_root_model = nullptr; | |||
if (graph_manager->IsGraphNeedBuild(graph_node)) { | |||
if (graph_node->GetBuildFlag()) { | |||
ReturnError(graph_manager, args.callback, PARAM_INVALID, | |||
"The graph " + std::to_string(graph_node->GetGraphId()) + | |||
" need to re-build, you should remove it" | |||
" from GE first, then AddGraph again and rebuild it."); | |||
ret = CheckIncreBuildAndPreRun(graph_manager, args, graph_node, ge_root_model); | |||
if (ret != SUCCESS) { | |||
graph_node->SetRunFlag(false); | |||
if (!ge::Analyzer::GetInstance()->IsEnableNetAnalyzeDebug()) { | |||
ReturnError(graph_manager, args.callback, ret, "CheckIncreBuildAndPreRun Failed, thread exit.."); | |||
graph_node->Unlock(); | |||
return; | |||
} else { | |||
ReturnError(graph_manager, graph_node, args.callback, ret, | |||
"CheckIncreBuildAndPreRun Failed, keep geop continue!"); | |||
graph_node->Unlock(); | |||
continue; | |||
} | |||
// check need incre build. | |||
GeModelPtr ge_model = nullptr; | |||
if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { | |||
std::vector<GeTensor> ge_inputs; | |||
ConstructGeInput(args.input_tensor, ge_inputs); | |||
ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); | |||
// release rts generate context | |||
RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); | |||
if (ret != SUCCESS) { | |||
graph_node->SetRunFlag(false); | |||
if (!ge::Analyzer::GetInstance()->IsEnableNetAnalyzeDebug()) { | |||
ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit.."); | |||
graph_node->Unlock(); | |||
return; | |||
} else { | |||
ReturnError(graph_manager, graph_node, args.callback, ret, "PreRun Failed, keep geop continue!"); | |||
graph_node->Unlock(); | |||
continue; | |||
} | |||
} | |||
} | |||
graph_node->SetBuildFlag(true); | |||
graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); | |||
} else { | |||
ge_root_model = graph_node->GetGeRootModel(); | |||
} | |||
graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, | |||
args.input_tensor, ge_root_model, GetThreadLocalContext(), args.callback })); | |||
GELOGI("Loop end."); | |||
GELOGI("[PreRunThread] Loop end."); | |||
} | |||
} | |||
@@ -3051,16 +3185,13 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||
continue; | |||
} | |||
GELOGI("A new loop start."); | |||
GELOGI("[RunThread] A new loop start, graph_id:%u.", args.graph_id); | |||
ErrorManager::GetInstance().SetErrorContext(args.error_context); | |||
GetContext().SetSessionId(args.session_id); | |||
GetThreadLocalContext() = args.context; | |||
graph_manager->UpdateLocalOmgContext(args.graph_id); | |||
if (args.graph_node->graph_run_async_listener_ != nullptr) { | |||
args.graph_node->graph_run_async_listener_->SetCallback(args.callback); | |||
} | |||
Status ret; | |||
// parse inputs.dims to vector<vector<uint64_t>> dynamic_dims | |||
ret = graph_manager->ParseInputsDims(args.input_tensor); | |||
@@ -3070,8 +3201,10 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||
return; | |||
} | |||
args.graph_node->UpdateLoadFlag(); | |||
if (!args.graph_node->GetLoadFlag()) { | |||
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelLoad, ErrorMessage::kModelLoad); | |||
args.ge_root_model->SetTrainFlag(graph_manager->GetTrainFlag()); | |||
ret = graph_manager->LoadGraphAsync(args.ge_root_model, args.graph_node); | |||
if (ret != SUCCESS || args.ge_root_model == nullptr) { | |||
StopQueue(graph_manager); | |||
@@ -3079,6 +3212,10 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||
args.graph_node->Unlock(); | |||
return; | |||
} | |||
// control the times of graph loading in multi-thread scenario | |||
args.graph_node->DecreaseLoadCount(); | |||
args.graph_node->IncreaseLoadRecord(); | |||
args.graph_node->SetLoadFlag(true); | |||
GELOGI("LoadGraph[%u], model[%u] success and set LoadFlag to true.", args.graph_node->GetGraphId(), | |||
args.ge_root_model->GetModelId()); | |||
@@ -3093,9 +3230,9 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||
graph_manager->graph_executor_.SetTrainFlag(graph_manager->options_.train_graph_flag); | |||
} | |||
args.graph_node->SetRunFlag(false); | |||
ret = graph_manager->graph_executor_.ExecuteGraphAsync(args.graph_id, args.graph_node->GetGeRootModel(), | |||
args.input_tensor); | |||
args.input_tensor, args.callback); | |||
args.graph_node->SetRunFlag(false); | |||
if (ret != SUCCESS) { | |||
ReturnError(graph_manager, args.callback, ret, "ExecuteGraphAsync failed, thread exit."); | |||
args.graph_node->Unlock(); | |||
@@ -3546,4 +3683,49 @@ void GraphManager::RemoveCompilerStages(GraphId graph_id) { | |||
std::lock_guard<std::mutex> lock(member_mutex_); | |||
compiler_stages_.erase(graph_id); | |||
} | |||
void GraphManager::IncreaseGraphCount(GraphId graph_id) { | |||
std::lock_guard<std::mutex> lock(graph_count_mutex_); | |||
auto it = graph_count_.find(graph_id); | |||
if (it == graph_count_.end()) { | |||
graph_count_.insert({graph_id, kInitGraphCount}); | |||
GELOGD("After increaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]); | |||
} else { | |||
++graph_count_[graph_id]; | |||
GELOGD("After increaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]); | |||
} | |||
} | |||
void GraphManager::RemoveGraphCount(GraphId graph_id) { | |||
std::lock_guard<std::mutex> lock(graph_count_mutex_); | |||
auto it = graph_count_.find(graph_id); | |||
if (it == graph_count_.end()) { | |||
GELOGW("Graph of id: %u has not been added, count cannot be decreased.", graph_id); | |||
} else { | |||
GELOGD("RemoveGraphCount success, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]); | |||
graph_count_.erase(it); | |||
} | |||
} | |||
void GraphManager::DecreaseGraphCount(GraphId graph_id) { | |||
std::lock_guard<std::mutex> lock(graph_count_mutex_); | |||
auto it = graph_count_.find(graph_id); | |||
if (it == graph_count_.end()) { | |||
GELOGW("Graph of id: %u has not been added, count cannot be decreased.", graph_id); | |||
} else { | |||
--it->second; | |||
GELOGD("After DecreaseGraphCount, graph count of id[%u] is %u.", graph_id, graph_count_[graph_id]); | |||
} | |||
} | |||
Status GraphManager::GetGraphCount(GraphId graph_id, uint32_t &count) { | |||
std::lock_guard<std::mutex> lock(graph_count_mutex_); | |||
auto it = graph_count_.find(graph_id); | |||
if (it == graph_count_.end()) { | |||
GELOGW("Graph [id:%u] has not been added.", graph_id); | |||
return FAILED; | |||
} | |||
count = it->second; | |||
return SUCCESS; | |||
} | |||
} // namespace ge |
@@ -184,6 +184,20 @@ class GraphManager { | |||
Status SaveCheckPointResult(const Graph &graph, const std::vector<Tensor> &outputs, map<string, Tensor> &var_results); | |||
void RemoveGraphCount(GraphId graph_id); | |||
void IncreaseGraphCount(GraphId graph_id); | |||
void DecreaseGraphCount(GraphId graph_id); | |||
Status GetGraphCount(GraphId graph_id, uint32_t &count); | |||
void SetAddGraphCondition(GraphId graph_id, uint32_t cond); | |||
uint32_t GetAddGraphCondition(GraphId graph_id); | |||
void RemoveAddGraphCondition(GraphId graph_id); | |||
private: | |||
struct CompilerStages { | |||
GraphPrepare preparer; | |||
@@ -381,6 +395,24 @@ class GraphManager { | |||
CompilerStages &GetCompilerStages(GraphId graph_id); | |||
void RemoveCompilerStages(GraphId graph_id); | |||
static Status CheckIncreBuildAndPreRun(GraphManager *graph_manager, const PreRunArgs &args, GraphNodePtr &graph_node, | |||
GeRootModelPtr &ge_root_model); | |||
void ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph_node, const std::vector<uint32_t> &model_ids, | |||
uint32_t graph_id, uint64_t session_id); | |||
Status CheckRepeatAdd(uint32_t graph_id, bool &is_added); | |||
Status NotifyWaittingGraph(uint32_t graph_id); | |||
Status CreateGraphNode(uint32_t graph_id, const Graph &graph, const std::map<std::string, std::string> &options); | |||
Status SetStagesOptions(uint32_t graph_id, const GraphManagerOptions &options); | |||
Status UnloadModel(GeRootModelPtr ge_root_model, uint32_t graph_id); | |||
void SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id); | |||
std::atomic_bool thread_run_flag_; | |||
BlockingQueue<PreRunArgs> prerun_args_q_{}; | |||
BlockingQueue<RunArgs> run_args_q_{}; | |||
@@ -416,6 +448,16 @@ class GraphManager { | |||
std::mutex member_mutex_; | |||
std::mutex unload_model_mutex_; | |||
// avoid repeatively add same graph (owns same graph id) | |||
std::mutex add_graph_mutex_; | |||
std::mutex add_graph_cond_mutex_; | |||
std::condition_variable add_graph_cv_; | |||
std::map<GraphId, uint32_t> graph_id_to_add_graph_cond_; | |||
// use for multi-thread online-infer scenario | |||
std::set<GraphId> to_be_deleted_graphs_; | |||
std::map<GraphId, uint32_t> graph_count_; | |||
std::mutex graph_count_mutex_; | |||
}; | |||
} // namespace ge | |||
@@ -60,6 +60,15 @@ void GraphNode::Unlock() { | |||
sem_.Pop(unused); | |||
} | |||
void GraphNode::IncreaseLoadCount() { | |||
std::unique_lock<std::mutex> lock(load_count_mu_); | |||
if (load_record_ == kMaxLoadNum) { | |||
GELOGW("Reach the maximum of load_count:%u", kMaxLoadNum); | |||
return; | |||
} | |||
++load_count_; | |||
} | |||
SubGraphInfo::SubGraphInfo() : subgraph_ptr_(nullptr), ge_model_ptr_(nullptr), malloc_flag_(false) {} | |||
SubGraphInfo::~SubGraphInfo() { | |||
@@ -55,6 +55,7 @@ using ConstGraphPtr = std::shared_ptr<const ge::Graph>; | |||
using GraphPtr = std::shared_ptr<ge::Graph>; | |||
const uint64_t INVALID_SESSION_ID = 0xffffffffffffffffULL; | |||
const uint32_t kMaxLoadNum = 8; | |||
struct ModelIdInfo { | |||
uint32_t model_id{INVALID_MODEL_ID}; | |||
@@ -162,6 +163,8 @@ class GraphNode { | |||
bool GetBuildFlag() const { return build_flag_; } | |||
void SetBuildFlag(bool buildFlag) { build_flag_ = buildFlag; } | |||
bool GetLoadFlag() const { return load_flag_; } | |||
// allow repeatively load graph owns same graph id | |||
void UpdateLoadFlag() { load_flag_ = load_count_ == 0 || load_record_ >= kMaxLoadNum; } | |||
void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; } | |||
void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; } | |||
GeModelPtr GetGeModel() const { return ge_model_; } | |||
@@ -172,6 +175,13 @@ class GraphNode { | |||
void Lock(); | |||
void Unlock(); | |||
void SetSemSize(uint32_t size) { sem_.SetMaxSize(size); } | |||
uint32_t GetLoadCount() const { return load_count_; } | |||
void IncreaseLoadCount(); | |||
void DecreaseLoadCount() { --load_count_; } | |||
void IncreaseLoadRecord() { ++load_record_; } | |||
// run graph asynchronous listener | |||
std::shared_ptr<RunAsyncListener> graph_run_async_listener_; | |||
@@ -184,11 +194,17 @@ class GraphNode { | |||
GraphPtr graph_; | |||
ComputeGraphPtr compute_graph_; | |||
bool build_flag_; | |||
// load_flag_ is true if more than 1 model were loaded | |||
bool load_flag_; | |||
bool async_; | |||
GeModelPtr ge_model_; | |||
GeRootModelPtr ge_root_model_; | |||
BlockingQueue<uint8_t> sem_; | |||
// consist with graph_count of same graph_id in graph_manager | |||
uint32_t load_count_ = 0; | |||
// total times of loading a graph with same graph_id. | |||
uint32_t load_record_ = 0; | |||
std::mutex load_count_mu_; | |||
}; | |||
using GraphNodePtr = std::shared_ptr<GraphNode>; | |||
@@ -144,8 +144,12 @@ Status HybridModelAsyncExecutor::RunInternal() { | |||
GE_MAKE_GUARD(not_used_var, [&] { GE_CHK_RT(rtDeviceReset(device_id)); }); | |||
while (run_flag_) { | |||
// Model has not indeedly started running before received data | |||
SetRunningFlag(false); | |||
std::shared_ptr<InputDataWrapper> data_wrapper; | |||
Status ret = data_inputer_->Pop(data_wrapper); | |||
// Model indeedly start running | |||
SetRunningFlag(true); | |||
if (data_wrapper == nullptr || ret != SUCCESS) { | |||
GELOGI("data_wrapper is null!, ret = %u", ret); | |||
continue; | |||
@@ -185,7 +189,8 @@ Status HybridModelAsyncExecutor::RunInternal() { | |||
RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[RunInternal] [iteration = %d] End", iterator_count_); | |||
iterator_count_++; | |||
GELOGI("run iterator count is %lu", iterator_count_); | |||
SetRunningFlag(false); | |||
GELOGI("run iterator count is %lu, model_id:%u", iterator_count_, model_id_); | |||
} | |||
CsaInteract::GetInstance().WriteInternalErrorCode(); | |||
@@ -55,6 +55,12 @@ class HybridModelAsyncExecutor { | |||
Status EnqueueData(const std::shared_ptr<InputDataWrapper> &data); | |||
uint32_t GetDataInputerSize() { return data_inputer_->Size(); } | |||
bool GetRunningFlag() const { return running_flag_; } | |||
void SetRunningFlag(bool flag) { running_flag_ = flag; } | |||
private: | |||
Status InitInputDesc(); | |||
@@ -84,6 +90,8 @@ class HybridModelAsyncExecutor { | |||
uint32_t device_id_ = 0U; | |||
uint32_t model_id_ = 0U; | |||
std::atomic_bool run_flag_; | |||
// check whether model is running with data | |||
bool running_flag_ = false; | |||
std::unique_ptr<DataInputer> data_inputer_; | |||
std::unique_ptr<HybridModelExecutor> executor_; | |||
std::unique_ptr<HybridModelPipelineExecutor> pipe_executor_; | |||
@@ -19,6 +19,7 @@ | |||
#include "hybrid/model/hybrid_model.h" | |||
#include "hybrid/executor/hybrid_model_async_executor.h" | |||
#include "hybrid/node_executor/node_executor.h" | |||
#include "graph/manager/graph_manager_utils.h" | |||
namespace ge { | |||
namespace hybrid { | |||
@@ -108,6 +109,17 @@ class HybridDavinciModel::Impl { | |||
model_.SetModelDescVersion(is_new_model_desc); | |||
} | |||
uint32_t GetDataInputerSize() { return executor_.GetDataInputerSize(); } | |||
bool GetRunningFlag() const { return executor_.GetRunningFlag(); } | |||
Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { | |||
auto listener = dynamic_cast<RunAsyncListener *>(listener_.get()); | |||
GE_CHECK_NOTNULL(listener); | |||
listener->SetCallback(callback); | |||
return SUCCESS; | |||
} | |||
private: | |||
std::shared_ptr<ModelListener> listener_; | |||
HybridModel model_; | |||
@@ -222,5 +234,16 @@ uint64_t HybridDavinciModel::GetSessionId() { | |||
GE_CHECK_NOTNULL(impl_); | |||
return impl_->GetSessionId(); | |||
} | |||
uint32_t HybridDavinciModel::GetDataInputerSize() { | |||
GE_CHECK_NOTNULL(impl_); | |||
return impl_->GetDataInputerSize(); | |||
} | |||
bool HybridDavinciModel::GetRunningFlag() const { return impl_->GetRunningFlag(); } | |||
Status HybridDavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { | |||
return impl_->SetRunAsyncListenerCallback(callback); | |||
} | |||
} // namespace hybrid | |||
} // namespace ge |
@@ -74,6 +74,12 @@ class HybridDavinciModel { | |||
void SetModelDescVersion(bool is_new_model_desc); | |||
uint32_t GetDataInputerSize(); | |||
bool GetRunningFlag() const; | |||
Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback); | |||
private: | |||
HybridDavinciModel() = default; | |||
class Impl; | |||
@@ -68,6 +68,10 @@ uint64_t HybridDavinciModel::GetSessionId() { | |||
return 0; | |||
} | |||
uint32_t HybridDavinciModel::GetDataInputerSize() { | |||
return 0; | |||
} | |||
Status HybridDavinciModel::GetDynamicBatchInfo(std::vector<std::vector<int64_t>> &batch_info, int32_t &dynamic_type) { | |||
return UNSUPPORTED; | |||
} | |||
@@ -87,5 +91,13 @@ Status HybridDavinciModel::GetInputOutputDescInfo(vector<InputOutputDescInfo> &i | |||
void HybridDavinciModel::SetModelDescVersion(bool is_new_model_desc) { | |||
} | |||
bool HybridDavinciModel::GetRunningFlag() const { | |||
return false; | |||
} | |||
Status HybridDavinciModel::SetRunAsyncListenerCallback(const RunAsyncCallback &callback) { | |||
return UNSUPPORTED; | |||
} | |||
} // namespace hybrid | |||
} // namespace ge |
@@ -85,4 +85,14 @@ ProtoAttrMapHelper GeModel::MutableAttrMap() { return attrs_; } | |||
ConstProtoAttrMapHelper GeModel::GetAttrMap() const { | |||
return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg()); | |||
} | |||
Status GeModel::GetSessionId(uint32_t model_id, uint64_t &session_id) const { | |||
auto it = model_id_to_session_id_map_.find(model_id); | |||
if (it != model_id_to_session_id_map_.end()) { | |||
session_id = it->second; | |||
return SUCCESS; | |||
} | |||
GELOGW("No session id were found with model id [%u].", model_id); | |||
return INTERNAL_ERROR; | |||
} | |||
} // namespace ge |
@@ -71,6 +71,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder | |||
void SetModelId(uint32_t model_id) { model_id_ = model_id; } | |||
uint32_t GetModelId() const { return model_id_; } | |||
Status GetSessionId(uint32_t model_id, uint64_t &session_id) const; | |||
void InsertSessionMap(uint32_t model_id, uint64_t session_id) { | |||
model_id_to_session_id_map_.insert({model_id, session_id}); | |||
} | |||
protected: | |||
ConstProtoAttrMapHelper GetAttrMap() const override; | |||
@@ -90,6 +95,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder | |||
std::string platform_version_; | |||
uint8_t platform_type_ = {0}; | |||
uint32_t model_id_ = INVALID_MODEL_ID; | |||
std::map<uint32_t, uint64_t> model_id_to_session_id_map_; | |||
}; | |||
} // namespace ge | |||
using GeModelPtr = std::shared_ptr<ge::GeModel>; | |||
@@ -32,15 +32,31 @@ class GeRootModel { | |||
return subgraph_instance_name_to_model_; | |||
}; | |||
const ComputeGraphPtr &GetRootGraph() const { return root_graph_; }; | |||
void SetModelId(uint32_t model_id) { model_id_ = model_id; } | |||
const ComputeGraphPtr &GetRootGraph() const { return root_graph_; } | |||
void SetModelId(uint32_t model_id) { | |||
model_id_ = model_id; | |||
// cached for removement | |||
model_ids_.emplace_back(model_id); | |||
} | |||
uint32_t GetModelId() const { return model_id_; } | |||
std::vector<uint32_t> GetAllModelId() const { return model_ids_; } | |||
Status CheckIsUnknownShape(bool &is_dynamic_shape); | |||
void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } | |||
void SetTrainFlag(bool flag) { train_flag_ = flag; } | |||
bool GetTrainFlag() const { return train_flag_; } | |||
private: | |||
ComputeGraphPtr root_graph_ = nullptr; | |||
std::map<std::string, GeModelPtr> subgraph_instance_name_to_model_; | |||
uint32_t model_id_ = 0; | |||
// In multithread online secenario, same graph can owns different davinci_model for for concurrency | |||
std::vector<uint32_t> model_ids_; | |||
bool train_flag_ = false; | |||
}; | |||
} // namespace ge | |||
using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | |||
@@ -1 +1 @@ | |||
Subproject commit 1e88df1d6bfe60faae0aa9fa2d87f273b793aeb0 | |||
Subproject commit fcebf37d7428caf4e0bd6e6c3a4f8143f6eac8b7 |
@@ -593,6 +593,7 @@ set(SINGLE_OP_SRC_FILES | |||
"${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_executor.cc" | |||
"${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_async_executor.cc" | |||
"${GE_CODE_DIR}/ge/hybrid/executor/hybrid_execution_context.cc" | |||
"${GE_CODE_DIR}/ge/hybrid/executor/hybrid_model_pipeline_executor.cc" | |||
"${GE_CODE_DIR}/ge/hybrid/executor/subgraph_context.cc" | |||
"${GE_CODE_DIR}/ge/hybrid/executor/subgraph_executor.cc" | |||
"${GE_CODE_DIR}/ge/hybrid/executor/worker/task_compile_engine.cc" | |||
@@ -780,10 +781,12 @@ set(MULTI_PARTS_TEST_FILES | |||
"graph/build/mem_assigner_unittest.cc" | |||
"graph/build/task_generator_unittest.cc" | |||
"graph/build/buffer_pool_mem_assigner_unittest.cc" | |||
"graph/execute/graph_execute_unittest.cc" | |||
"graph/preprocess/graph_preprocess_unittest.cc" | |||
"graph/manager/hcom_util_unittest.cc" | |||
"graph/manager/graph_caching_allocator_unittest.cc" | |||
"graph/partition/dynamic_shape_partition_unittest.cc" | |||
"graph/manager/graph_manager_unittest.cc" | |||
"session/omg_omg_unittest.cc" | |||
) | |||
@@ -0,0 +1,129 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include <memory> | |||
#define protected public | |||
#define private public | |||
#include "graph/execute/graph_execute.h" | |||
#include "graph/load/model_manager/model_manager.h" | |||
#include "graph/load/model_manager/davinci_model.h" | |||
#include "omm/csa_interact.h" | |||
#undef private | |||
#undef public | |||
#include <pthread.h> | |||
#include <algorithm> | |||
#include <future> | |||
#include <set> | |||
#include <sstream> | |||
#include <string> | |||
#include <thread> | |||
#include <future> | |||
using namespace std; | |||
using namespace testing; | |||
using namespace ge; | |||
using namespace domi; | |||
namespace ge { | |||
namespace { | |||
const uint32_t kInvalidModelId = UINT32_MAX; | |||
} | |||
class UtestGraphExecuteTest : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
TEST_F(UtestGraphExecuteTest, get_execute_model_id_invalid) { | |||
GraphExecutor executor; | |||
ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); | |||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); | |||
auto model_id = executor.GetExecuteModelId(ge_root_model); | |||
EXPECT_EQ(model_id, kInvalidModelId); | |||
} | |||
TEST_F(UtestGraphExecuteTest, get_execute_model_id_1) { | |||
GraphExecutor executor; | |||
ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); | |||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); | |||
auto model_manager = ModelManager::GetInstance(); | |||
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, nullptr); | |||
davinci_model1->SetId(1); | |||
model_manager->InsertModel(1, davinci_model1); | |||
ge_root_model->SetModelId(1); | |||
auto model_id = executor.GetExecuteModelId(ge_root_model); | |||
EXPECT_EQ(model_id, 1); | |||
} | |||
TEST_F(UtestGraphExecuteTest, get_execute_model_id_2) { | |||
GraphExecutor executor; | |||
ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); | |||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); | |||
auto model_manager = ModelManager::GetInstance(); | |||
// model1 with 2 load | |||
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, nullptr); | |||
davinci_model1->SetId(1); | |||
davinci_model1->data_inputer_ = new DataInputer(); | |||
auto data = MakeShared<InputDataWrapper>(); | |||
davinci_model1->data_inputer_->Push(data); | |||
davinci_model1->data_inputer_->Push(data); | |||
model_manager->InsertModel(1, davinci_model1); | |||
// model 2 with 3 load | |||
shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(1, nullptr); | |||
davinci_model2->SetId(2); | |||
davinci_model2->data_inputer_ = new DataInputer(); | |||
davinci_model2->data_inputer_->Push(data); | |||
davinci_model2->data_inputer_->Push(data); | |||
davinci_model2->data_inputer_->Push(data); | |||
model_manager->InsertModel(2, davinci_model2); | |||
// model 3 witH 1 load | |||
shared_ptr<DavinciModel> davinci_model3 = MakeShared<DavinciModel>(1, nullptr); | |||
davinci_model3->SetId(3); | |||
davinci_model3->data_inputer_ = new DataInputer(); | |||
davinci_model3->data_inputer_->Push(data); | |||
model_manager->InsertModel(3, davinci_model3); | |||
ge_root_model->SetModelId(1); | |||
ge_root_model->SetModelId(2); | |||
ge_root_model->SetModelId(3); | |||
auto model_id = executor.GetExecuteModelId(ge_root_model); | |||
// model 3 is picked for having least loads | |||
EXPECT_EQ(model_id, 3); | |||
} | |||
TEST_F(UtestGraphExecuteTest, test_set_callback) { | |||
GraphExecutor executor; | |||
ComputeGraphPtr graph = MakeShared<ComputeGraph>("test"); | |||
// is_unknown_shape_graph_ = false | |||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(graph); | |||
RunAsyncCallback callback = [](Status, std::vector<ge::OutputTensorInfo> &) {}; | |||
auto model_manager = ModelManager::GetInstance(); | |||
auto listener = MakeShared<RunAsyncListener>(); | |||
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener); | |||
davinci_model1->SetId(1); | |||
model_manager->InsertModel(1, davinci_model1); | |||
auto status = executor.SetCallback(1, ge_root_model, callback); | |||
EXPECT_EQ(status, SUCCESS); | |||
} | |||
} // namespace ge |
@@ -0,0 +1,375 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include <memory> | |||
#define protected public | |||
#define private public | |||
#include "graph/manager/graph_manager.h" | |||
#include "graph/load/model_manager/model_manager.h" | |||
#include "graph/load/model_manager/davinci_model.h" | |||
#define const | |||
#include "common/helper/model_cache_helper.h" | |||
#undef const | |||
#include "init/gelib.h" | |||
#undef private | |||
#undef public | |||
#include <pthread.h> | |||
#include <algorithm> | |||
#include <future> | |||
#include <set> | |||
#include <sstream> | |||
#include <string> | |||
#include <thread> | |||
#include <future> | |||
#include "common/math/math_util.h" | |||
#include "common/thread_pool.h" | |||
#include "common/dump/dump_manager.h" | |||
#include "analyzer/analyzer.h" | |||
#include "graph/common/ge_call_wrapper.h" | |||
#include "graph/common/local_context.h" | |||
#include "graph/common/transop_util.h" | |||
#include "graph/ge_context.h" | |||
#include "graph/ge_global_options.h" | |||
#include "graph/manager/util/rt_context_util.h" | |||
#include "graph/partition/dynamic_shape_partition.h" | |||
#include "graph/passes/enter_pass.h" | |||
#include "graph/partition/stage_partition.h" | |||
#include "graph/passes/addn_pass.h" | |||
#include "graph/passes/bitcast_pass.h" | |||
#include "graph/passes/assign_remove_pass.h" | |||
#include "graph/passes/inplace_support_check_pass.h" | |||
#include "graph/passes/atomic_addr_clean_pass.h" | |||
#include "graph/passes/attach_stream_label_pass.h" | |||
#include "graph/passes/cast_remove_pass.h" | |||
#include "graph/passes/common_subexpression_elimination_pass.h" | |||
#include "graph/passes/compile_nodes_pass.h" | |||
#include "graph/passes/cond_remove_pass.h" | |||
#include "graph/passes/constant_folding_pass.h" | |||
#include "graph/passes/constant_fuse_same_pass.h" | |||
#include "graph/passes/control_trigger_pass.h" | |||
#include "graph/passes/ctrl_edge_transfer_pass.h" | |||
#include "graph/passes/dimension_adjust_pass.h" | |||
#include "graph/passes/dimension_compute_pass.h" | |||
#include "graph/passes/flow_ctrl_pass.h" | |||
#include "graph/passes/fuse_data_nodes_with_common_input_pass.h" | |||
#include "graph/passes/identity_pass.h" | |||
#include "graph/passes/input_output_connection_identify_pass.h" | |||
#include "graph/passes/iterator_op_pass.h" | |||
#include "graph/passes/link_gen_mask_nodes_pass.h" | |||
#include "graph/passes/mark_graph_unknown_status_pass.h" | |||
#include "graph/passes/merge_pass.h" | |||
#include "graph/passes/merge_input_memcpy_pass.h" | |||
#include "graph/passes/merge_to_stream_merge_pass.h" | |||
#include "graph/passes/multi_batch_pass.h" | |||
#include "graph/passes/next_iteration_pass.h" | |||
#include "graph/passes/permute_pass.h" | |||
#include "graph/passes/prune_pass.h" | |||
#include "graph/passes/ref_identity_delete_op_pass.h" | |||
#include "graph/passes/remove_same_const_pass.h" | |||
#include "graph/passes/reshape_recovery_pass.h" | |||
#include "graph/passes/reshape_remove_pass.h" | |||
#include "graph/passes/same_transdata_breadth_fusion_pass.h" | |||
#include "graph/passes/subgraph_pass.h" | |||
#include "graph/passes/switch_data_edges_bypass.h" | |||
#include "graph/passes/switch_dead_branch_elimination.h" | |||
#include "graph/passes/switch_logic_remove_pass.h" | |||
#include "graph/passes/switch_to_stream_switch_pass.h" | |||
#include "graph/passes/transop_breadth_fusion_pass.h" | |||
#include "graph/passes/transop_nearby_allreduce_fusion_pass.h" | |||
#include "graph/passes/transop_symmetry_elimination_pass.h" | |||
#include "graph/passes/transop_without_reshape_fusion_pass.h" | |||
#include "graph/passes/transpose_transdata_pass.h" | |||
#include "graph/passes/useless_control_out_remove_pass.h" | |||
#include "graph/passes/variable_op_pass.h" | |||
#include "graph/passes/variable_ref_delete_op_pass.h" | |||
#include "graph/passes/variable_ref_useless_control_out_delete_pass.h" | |||
#include "graph/passes/end_of_sequence_add_control_pass.h" | |||
#include "graph/passes/subexpression_migration_pass.h" | |||
#include "graph/passes/subgraph_const_migration_pass.h" | |||
#include "graph/passes/unused_args_clean_pass.h" | |||
#include "graph/passes/global_step_insert_pass.h" | |||
#include "graph/passes/memcpy_addr_async_pass.h" | |||
#include "graph/passes/hccl_continuous_memcpy_pass.h" | |||
#include "graph/build/label_allocator.h" | |||
#include "graph/utils/tensor_adapter.h" | |||
#include "inc/pass_manager.h" | |||
#include "ir_build/atc_ir_common.h" | |||
#include "graph/common/local_context.h" | |||
#include "graph/common/omg_util.h" | |||
#include "common/formats/utils/formats_trans_utils.h" | |||
#include "register/custom_pass_helper.h" | |||
#include "graph/ops_stub.h" | |||
using namespace std; | |||
using namespace testing; | |||
using namespace ge; | |||
using namespace domi; | |||
namespace { | |||
const uint32_t kNotAdded = 0; | |||
const uint32_t kStartAdd = 1; | |||
const uint32_t kDoneAdded = 2; | |||
} | |||
class UtestGraphManagerTest : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
void CreateGraph(Graph &graph) { | |||
TensorDesc desc(ge::Shape({1, 3, 224, 224})); | |||
uint32_t size = desc.GetShape().GetShapeSize(); | |||
desc.SetSize(size); | |||
auto data = op::Data("Data").set_attr_index(0); | |||
data.update_input_desc_data(desc); | |||
data.update_output_desc_out(desc); | |||
auto flatten = op::Flatten("Flatten").set_input_x(data, data.name_out_out()); | |||
std::vector<Operator> inputs{data}; | |||
std::vector<Operator> outputs{flatten}; | |||
std::vector<Operator> targets{flatten}; | |||
// Graph graph("test_graph"); | |||
graph.SetInputs(inputs).SetOutputs(outputs).SetTargets(targets); | |||
} | |||
TEST_F(UtestGraphManagerTest, set_and_get_add_graph_flag) { | |||
GraphId graph_id = 1; | |||
GraphManager graph_manager; | |||
graph_manager.SetAddGraphCondition(graph_id, 1); | |||
uint32_t res = graph_manager.GetAddGraphCondition(graph_id); | |||
EXPECT_EQ(res, 1); | |||
} | |||
TEST_F(UtestGraphManagerTest, test_add_graph_1) { | |||
GraphId graph_id = 1; | |||
GraphManager graph_manager; | |||
// create graph | |||
Graph graph("test_graph"); | |||
CreateGraph(graph); | |||
std::map<std::string, std::string> options; | |||
OmgContext context; | |||
Status status = graph_manager.AddGraph(graph_id, graph, options, context); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
} | |||
TEST_F(UtestGraphManagerTest, test_add_graph_2) { | |||
GraphId graph_id = 1; | |||
GraphManager graph_manager; | |||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
graph_manager.AddGraphNode(graph_id, graph_node); | |||
graph_manager.SetAddGraphCondition(graph_id, kDoneAdded); | |||
Graph graph("test_graph"); | |||
CreateGraph(graph); | |||
std::map<std::string, std::string> options; | |||
OmgContext context; | |||
Status status = graph_manager.AddGraph(graph_id, graph, options, context); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
} | |||
TEST_F(UtestGraphManagerTest, test_add_graph_3) { | |||
GraphId graph_id = 1; | |||
GraphManager graph_manager; | |||
Graph graph("test_graph"); | |||
CreateGraph(graph); | |||
std::map<std::string, std::string> options; | |||
OmgContext context; | |||
std::future<Status> fut1 = std::async(std::launch::async, | |||
&GraphManager::AddGraph, &graph_manager, graph_id, graph, options, context); | |||
std::future<Status> fut2 = std::async(std::launch::async, | |||
&GraphManager::AddGraph, &graph_manager, graph_id, graph, options, context); | |||
fut1.wait(); | |||
fut2.wait(); | |||
Status status1 = fut1.get(); | |||
Status status2 = fut2.get(); | |||
EXPECT_EQ(status1, ge::SUCCESS); | |||
EXPECT_EQ(status2, ge::SUCCESS); | |||
} | |||
TEST_F(UtestGraphManagerTest, test_remove_graph_1) { | |||
GraphId graph_id = 1; | |||
GraphManager graph_manager; | |||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
Status status = graph_manager.RemoveGraph(graph_id); | |||
EXPECT_EQ(status, ge::GE_GRAPH_GRAPH_NOT_EXIST); | |||
graph_manager.AddGraphNode(graph_id, graph_node); | |||
graph_node->SetRunFlag(true); | |||
status = graph_manager.RemoveGraph(graph_id); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
} | |||
TEST_F(UtestGraphManagerTest, test_remove_graph_2) { | |||
GraphId graph_id = 1; | |||
GraphManager graph_manager; | |||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
Graph graph("test_graph"); | |||
CreateGraph(graph); | |||
auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); | |||
auto model_manager = ModelManager::GetInstance(); | |||
auto listener = MakeShared<RunAsyncListener>(); | |||
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener); | |||
davinci_model1->SetId(1); | |||
shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(2, listener); | |||
davinci_model1->SetId(2); | |||
model_manager->InsertModel(1, davinci_model1); | |||
model_manager->InsertModel(2, davinci_model2); | |||
ge_root_model->SetModelId(1); | |||
ge_root_model->SetModelId(2); | |||
graph_node->SetGeRootModel(ge_root_model); | |||
graph_node->SetLoadFlag(true); | |||
graph_manager.AddGraphNode(graph_id, graph_node); | |||
Status status = graph_manager.RemoveGraph(graph_id); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
} | |||
TEST_F(UtestGraphManagerTest, test_pre_run_thread) { | |||
GraphManager graph_manager; | |||
graph_manager.thread_run_flag_ = true; | |||
GraphId graph_id = 1; | |||
std::vector<ge::InputTensorInfo> input_tensor; | |||
uint64_t session_id = 0; | |||
ErrorMessage::Context error_context; | |||
GEThreadLocalContext context; | |||
RunAsyncCallback callback; | |||
// PreRunArgs args{graph_id, input_tensor, session_id, error_context, context, callback}; | |||
bool ret = graph_manager.prerun_args_q_.Push({graph_id, input_tensor, session_id, error_context, context, callback}); | |||
EXPECT_EQ(ret, true); | |||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
graph_manager.AddGraphNode(graph_id, graph_node); | |||
graph_manager.PreRunThread(&graph_manager); | |||
// end with failed | |||
} | |||
TEST_F(UtestGraphManagerTest, test_pre_run_thread_2) { | |||
GraphManager graph_manager; | |||
graph_manager.thread_run_flag_ = true; | |||
GraphId graph_id = 1; | |||
GraphNodePtr graph_node_1 = MakeShared<ge::GraphNode>(graph_id); | |||
graph_manager.AddGraphNode(graph_id, graph_node_1); | |||
graph_manager.IncreaseGraphCount(graph_id); | |||
graph_manager.IncreaseGraphCount(graph_id); | |||
graph_node_1->SetBuildFlag(true); | |||
std::vector<ge::InputTensorInfo> input_tensor; | |||
uint64_t session_id = 0; | |||
ErrorMessage::Context error_context; | |||
GEThreadLocalContext context; | |||
RunAsyncCallback callback; | |||
// PreRunArgs args{graph_id, input_tensor, session_id, error_context, context, callback}; | |||
bool ret = graph_manager.prerun_args_q_.Push({graph_id, input_tensor, session_id, error_context, context, callback}); | |||
EXPECT_EQ(ret, true); | |||
graph_id = 2; | |||
GraphNodePtr graph_node_2 = MakeShared<ge::GraphNode>(graph_id); | |||
graph_manager.AddGraphNode(graph_id, graph_node_2); | |||
ret = graph_manager.prerun_args_q_.Push({graph_id, input_tensor, session_id, error_context, context, callback}); | |||
EXPECT_EQ(ret, true); | |||
graph_manager.PreRunThread(&graph_manager); | |||
// end with failed | |||
} | |||
TEST_F(UtestGraphManagerTest, test_check_and_release_memory) { | |||
GraphManager graph_manager; | |||
GeModelPtr ge_model = make_shared<GeModel>(); | |||
int64_t memory_size = 25 * 1024UL * 1024UL * 1024UL; | |||
int64_t weight_size = 25 * 1024UL * 1024UL * 1024UL; | |||
uint64_t session_id = 0; | |||
ge::AttrUtils::SetInt(ge_model, ATTR_MODEL_MEMORY_SIZE, memory_size); | |||
ge::AttrUtils::SetInt(ge_model, ATTR_MODEL_WEIGHT_SIZE, weight_size); | |||
ge::AttrUtils::SetInt(ge_model, MODEL_ATTR_SESSION_ID, session_id); | |||
GraphId graph_id = 1; | |||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
graph_manager.AddGraphNode(graph_id, graph_node); | |||
graph_manager.IncreaseGraphCount(graph_id); | |||
graph_manager.IncreaseGraphCount(graph_id); | |||
auto model_manager = ModelManager::GetInstance(); | |||
auto listener = MakeShared<RunAsyncListener>(); | |||
shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener); | |||
davinci_model1->SetId(1); | |||
shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(2, listener); | |||
davinci_model1->SetId(2); | |||
model_manager->InsertModel(1, davinci_model1); | |||
model_manager->InsertModel(2, davinci_model2); | |||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||
bool is_dynamic_shape = false; | |||
(void)AttrUtils::GetBool(compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); | |||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); | |||
ge_root_model->SetModelId(1); | |||
ge_root_model->SetModelId(2); | |||
graph_node->SetGeRootModel(ge_root_model); | |||
graph_node->SetLoadFlag(true); | |||
Status status = graph_manager.CheckAndReleaseMemory(ge_model, graph_node); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
} | |||
TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_1) { | |||
// no need to build | |||
GraphId graph_id = 1; | |||
GraphManager graph_manager; | |||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); | |||
GraphManager::PreRunArgs arg; | |||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
graph_node->SetBuildFlag(true); | |||
Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
} | |||
TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_2) { | |||
// need build while buildflag is true, var format changed | |||
GraphId graph_id = 1; | |||
GraphManager graph_manager; | |||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); | |||
GraphManager::PreRunArgs arg; | |||
arg.callback = [](Status, std::vector<ge::OutputTensorInfo> &) {}; | |||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
graph_node->SetBuildFlag(true); | |||
graph_node->Lock(); | |||
graph_manager.var_acc_ctrl_.graph_ids_need_rebuild_.insert(graph_id); | |||
Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model); | |||
EXPECT_EQ(status, ge::PARAM_INVALID); | |||
} | |||
TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_3) { | |||
// need build while buildflag is false, var format unchanged | |||
GraphId graph_id = 1; | |||
GraphManager graph_manager; | |||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph"); | |||
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); | |||
GraphManager::PreRunArgs arg; | |||
arg.callback = [](Status, std::vector<ge::OutputTensorInfo> &) {}; | |||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||
graph_node->SetBuildFlag(false); | |||
graph_node->Lock(); | |||
Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model); | |||
EXPECT_NE(status, ge::SUCCESS); | |||
} |