diff --git a/ge/graph/load/model_manager/model_manager.cc b/ge/graph/load/model_manager/model_manager.cc index 6114467c..cdb8e4bb 100755 --- a/ge/graph/load/model_manager/model_manager.cc +++ b/ge/graph/load/model_manager/model_manager.cc @@ -341,11 +341,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr davinci_model = MakeShared(0, listener); - if (davinci_model == nullptr) { - REPORT_CALL_ERROR("E19999", "New DavinciModel fail, model_id:%u", model_id); - GELOGE(FAILED, "davinci_model is nullptr"); - return FAILED; - } + GE_CHECK_NOTNULL(davinci_model); davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano + timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond davinci_model->SetId(model_id); diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index f2b4211d..1315376c 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -125,6 +125,7 @@ const uint32_t kInitGraphCount = 1; const uint32_t kNotAdded = 0; const uint32_t kStartAdd = 1; const uint32_t kDoneAdded = 2; +const uint32_t kNeverLoaded = 0; bool IsTailingOptimization() { string is_tailing_optimization_option; @@ -2748,6 +2749,15 @@ void GraphManager::ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success.", graph_id, model_id); } graph_node->SetLoadFlag(false); + // Allow model to be loaded agagin without adding graph again + graph_node->SetLoadCount(graph_node->GetLoadRecord()); + graph_node->SetLoadRecord(kNeverLoaded); + GeRootModelPtr ge_root_model = graph_node->GetGeRootModel(); + if (ge_root_model == nullptr) { + GELOGW("ge_root_model is null, graph_id:%u", graph_id); + return; + } + ge_root_model->ClearAllModelId(); rt_ret = rtDeviceReset(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, when GraphManager %s", diff --git a/ge/graph/manager/graph_manager_utils.h b/ge/graph/manager/graph_manager_utils.h index ffbc20cf..bebba93e 100644 --- a/ge/graph/manager/graph_manager_utils.h +++ b/ge/graph/manager/graph_manager_utils.h @@ -178,9 +178,12 @@ class GraphNode { void SetSemSize(uint32_t size) { sem_.SetMaxSize(size); } uint32_t GetLoadCount() const { return load_count_; } + void SetLoadCount(uint32_t count) { load_count_ = count; } + uint32_t GetLoadRecord() const { return load_record_; } + void SetLoadRecord(uint32_t record) { load_record_ = record; } + void IncreaseLoadRecord() { ++load_record_; } void IncreaseLoadCount(); void DecreaseLoadCount() { --load_count_; } - void IncreaseLoadRecord() { ++load_record_; } // run graph asynchronous listener std::shared_ptr graph_run_async_listener_; diff --git a/ge/hybrid/executor/hybrid_model_async_executor.cc b/ge/hybrid/executor/hybrid_model_async_executor.cc index f3f1e1f5..1fed16a5 100644 --- a/ge/hybrid/executor/hybrid_model_async_executor.cc +++ b/ge/hybrid/executor/hybrid_model_async_executor.cc @@ -150,10 +150,8 @@ Status HybridModelAsyncExecutor::RunInternal() { Status ret = data_inputer_->Pop(data_wrapper); // Model indeedly start running SetRunningFlag(true); - if (data_wrapper == nullptr || ret != SUCCESS) { - GELOGI("data_wrapper is null!, ret = %u", ret); - continue; - } + GE_IF_BOOL_EXEC(data_wrapper == nullptr || ret != SUCCESS, GELOGI("data_wrapper is null!, ret = %u", ret); + continue); GELOGI("Getting the input data, model_id:%u", model_id_); GE_IF_BOOL_EXEC(!run_flag_, break); diff --git a/ge/model/ge_root_model.h b/ge/model/ge_root_model.h index 8c44272d..b8ff7b7a 100755 --- a/ge/model/ge_root_model.h +++ b/ge/model/ge_root_model.h @@ -40,12 +40,14 @@ class GeRootModel { } uint32_t GetModelId() const { return model_id_; } - std::vector GetAllModelId() const { return model_ids_; } - void SetModelName(const std::string &model_name) { model_name_ = model_name; } - + const std::string &GetModelName() const { return model_name_; } - + + std::vector GetAllModelId() const { return model_ids_; } + + void ClearAllModelId() { model_ids_.clear(); } + Status CheckIsUnknownShape(bool &is_dynamic_shape); void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; }