| @@ -21,6 +21,13 @@ set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) | |||||
| set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) | set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) | ||||
| set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) | set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) | ||||
| set(ASCEND_MS_RUNTIME_PATH ${ASCEND_RUNTIME_DIR} ${ASCEND_ACL_DIR} ${ASCEND_ATC_DIR}) | |||||
| set(ASCEND_MS_DRIVER_PATH ${ASCEND_DRIVER_DIR} ${ASCEND_DRIVER_COMMON_DIR}) | |||||
| set(ATLAS_RUNTIME_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64) | |||||
| set(ATLAS_ACL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/lib64) | |||||
| set(ATLAS_ATC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/atc/lib64) | |||||
| set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR}) | |||||
| option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) | option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) | ||||
| if (ENABLE_OPEN_SRC) | if (ENABLE_OPEN_SRC) | ||||
| @@ -129,14 +136,6 @@ if (ENABLE_OPEN_SRC) | |||||
| #add_subdirectory(metadef/graph) | #add_subdirectory(metadef/graph) | ||||
| #add_subdirectory(metadef/register) | #add_subdirectory(metadef/register) | ||||
| elseif (ENABLE_D OR ENABLE_ACL) | elseif (ENABLE_D OR ENABLE_ACL) | ||||
| set(ASCEND_MS_RUNTIME_PATH ${ASCEND_RUNTIME_DIR} ${ASCEND_ACL_DIR} ${ASCEND_ATC_DIR}) | |||||
| set(ASCEND_MS_DRIVER_PATH ${ASCEND_DRIVER_DIR} ${ASCEND_DRIVER_COMMON_DIR}) | |||||
| set(ATLAS_RUNTIME_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64) | |||||
| set(ATLAS_ACL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/lib64) | |||||
| set(ATLAS_ATC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/atc/lib64) | |||||
| set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR}) | |||||
| # compiling with MindSpore | # compiling with MindSpore | ||||
| include(cmake/external_libs/protobuf_static.cmake) | include(cmake/external_libs/protobuf_static.cmake) | ||||
| include(cmake/external_libs/protoc.cmake) | include(cmake/external_libs/protoc.cmake) | ||||
| @@ -158,11 +157,18 @@ elseif (ENABLE_D OR ENABLE_ACL) | |||||
| set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | ||||
| add_subdirectory(metadef) | add_subdirectory(metadef) | ||||
| elseif(ENABLE_MS_TESTCASE) | |||||
| elseif(ENABLE_MS_TESTCASES) | |||||
| include(cmake/external_libs/protobuf_static.cmake) | include(cmake/external_libs/protobuf_static.cmake) | ||||
| include(cmake/external_libs/protoc.cmake) | |||||
| include(cmake/external_libs/securec.cmake) | include(cmake/external_libs/securec.cmake) | ||||
| include(cmake/FindModule.cmake) | |||||
| include(cmake/intf_pub_linux.cmake) | include(cmake/intf_pub_linux.cmake) | ||||
| # common libraries | |||||
| find_module(slog libslog.so ${ASCEND_MS_DRIVER_PATH}) | |||||
| find_module(error_manager liberror_manager.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) | |||||
| find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) | |||||
| set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | ||||
| add_subdirectory(metadef) | add_subdirectory(metadef) | ||||
| else() | else() | ||||
| @@ -42,7 +42,7 @@ include(GNUInstallDirs) | |||||
| add_library(ascend_protobuf_static_lib STATIC IMPORTED) | add_library(ascend_protobuf_static_lib STATIC IMPORTED) | ||||
| set_target_properties(ascend_protobuf_static_lib PROPERTIES | set_target_properties(ascend_protobuf_static_lib PROPERTIES | ||||
| IMPORTED_LOCATION ${PROTOBUF_STATIC_PKG_DIR}/lib64/libascend_protobuf.a | |||||
| IMPORTED_LOCATION ${PROTOBUF_STATIC_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/libascend_protobuf.a | |||||
| ) | ) | ||||
| add_library(ascend_protobuf_static INTERFACE) | add_library(ascend_protobuf_static INTERFACE) | ||||
| @@ -76,7 +76,7 @@ Status CheckOptionsValid(const std::map<string, string> &options) { | |||||
| } | } | ||||
| // Initialize GE, prepare for execution, call GELib::Initialize | // Initialize GE, prepare for execution, call GELib::Initialize | ||||
| Status GEInitialize(const std::map<string, string> &options) { | |||||
| Status GEInitializeImpl(const std::map<string, string> &options) { | |||||
| GELOGT(TRACE_INIT, "GEInitialize start"); | GELOGT(TRACE_INIT, "GEInitialize start"); | ||||
| // 0.check init status | // 0.check init status | ||||
| if (g_ge_initialized) { | if (g_ge_initialized) { | ||||
| @@ -127,6 +127,26 @@ Status GEInitialize(const std::map<string, string> &options) { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| // Initialize GE, prepare for execution, call GELib::Initialize | |||||
| Status GEInitialize(const std::map<string, string> &options) { | |||||
| return GEInitializeImpl(options); | |||||
| } | |||||
| Status GEInitialize(const std::map<AscendString, AscendString> &options) { | |||||
| std::map<std::string, std::string> str_options; | |||||
| for (auto & option : options) { | |||||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||||
| GELOGE(FAILED, "GEInitialize options is nullptr."); | |||||
| return FAILED; | |||||
| } | |||||
| std::string key = option.first.GetString(); | |||||
| std::string val = option.second.GetString(); | |||||
| str_options[key] = val; | |||||
| } | |||||
| return GEInitializeImpl(str_options); | |||||
| } | |||||
| // GE finalize, releasing all resources | // GE finalize, releasing all resources | ||||
| Status GEFinalize() { | Status GEFinalize() { | ||||
| GELOGT(TRACE_INIT, "GEFinalize start"); | GELOGT(TRACE_INIT, "GEFinalize start"); | ||||
| @@ -202,6 +222,46 @@ Session::Session(const std::map<string, string> &options) { | |||||
| GELOGT(TRACE_STOP, "Session Constructor finished"); | GELOGT(TRACE_STOP, "Session Constructor finished"); | ||||
| } | } | ||||
| Session::Session(const std::map<AscendString, AscendString> &options) { | |||||
| GELOGT(TRACE_INIT, "Session Constructor start"); | |||||
| // check init status | |||||
| sessionId_ = 0; | |||||
| if (!g_ge_initialized) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GE is not initialized."); | |||||
| return; | |||||
| } | |||||
| // call Initialize | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Session Constructor failed"); | |||||
| return; | |||||
| } | |||||
| GELOGT(TRACE_RUNNING, "Creating session"); | |||||
| std::map<std::string, std::string> str_options; | |||||
| for (auto &option : options) { | |||||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||||
| GELOGE(FAILED, "Session options is nullptr."); | |||||
| return; | |||||
| } | |||||
| std::string key = option.first.GetString(); | |||||
| std::string val = option.second.GetString(); | |||||
| str_options[key] = val; | |||||
| } | |||||
| uint64_t session_id = 0; | |||||
| Status ret = instance_ptr->SessionManagerObj().CreateSession(str_options, session_id); | |||||
| GELOGT(TRACE_RUNNING, "Session id is %lu", session_id); | |||||
| // check return status, return, update session id if success | |||||
| if (ret == SUCCESS) { | |||||
| sessionId_ = session_id; | |||||
| } else { | |||||
| GELOGE(ret, "Session constructor failed, session Id not initialized"); | |||||
| return; | |||||
| } | |||||
| GELOGT(TRACE_STOP, "Session Constructor finished"); | |||||
| } | |||||
| // session destructor | // session destructor | ||||
| Session::~Session() { | Session::~Session() { | ||||
| GELOGT(TRACE_INIT, "Session Destructor start"); | GELOGT(TRACE_INIT, "Session Destructor start"); | ||||
| @@ -260,6 +320,34 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map<s | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| Status Session::AddGraph(uint32_t graph_id, const Graph &graph, | |||||
| const std::map<AscendString, AscendString> &options) { | |||||
| GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_); | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "AddGraph failed in Session."); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("Adding graph to session"); | |||||
| std::map<std::string, std::string> str_options; | |||||
| for (auto &option : options) { | |||||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||||
| GELOGE(FAILED, "AddGraph options is nullptr."); | |||||
| return FAILED; | |||||
| } | |||||
| std::string key = option.first.GetString(); | |||||
| std::string val = option.second.GetString(); | |||||
| str_options[key] = val; | |||||
| } | |||||
| Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, str_options); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "AddGraph failed in Session."); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("AddGraph finished in Session."); | |||||
| return ret; | |||||
| } | |||||
| Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph) { | Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph) { | ||||
| std::map<AscendString, AscendString> options; | std::map<AscendString, AscendString> options; | ||||
| return AddGraphWithCopy(graph_id, graph, options); | return AddGraphWithCopy(graph_id, graph, options); | ||||
| @@ -387,6 +475,14 @@ Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc | |||||
| return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | ||||
| } | } | ||||
| Status Session::RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback) { | |||||
| std::string str_key; | |||||
| if (key != nullptr) { | |||||
| str_key = key; | |||||
| } | |||||
| return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, str_key, callback); | |||||
| } | |||||
| Status Session::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | Status Session::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | ||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | ||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | ||||
| @@ -436,6 +532,29 @@ Status Session::GetVariables(const std::vector<std::string> &var_names, std::vec | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status Session::GetVariables(const std::vector<AscendString> &var_names, std::vector<Tensor> &var_values) { | |||||
| auto instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "SessionConstructor failed"); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGT(TRACE_RUNNING, "Get Variables"); | |||||
| std::vector<ge::string> str_var_names; | |||||
| for (auto &var_name : var_names) { | |||||
| if (var_name.GetString() == nullptr) { | |||||
| GELOGE(FAILED, "GetVariables name is nullptr."); | |||||
| return FAILED; | |||||
| } | |||||
| str_var_names.emplace_back(var_name.GetString()); | |||||
| } | |||||
| Status ret = ge::GELib::GetInstance()->SessionManagerObj().GetVariables(sessionId_, str_var_names, var_values); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "SessionManager RunGraphAsync failed"); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| bool Session::IsGraphNeedRebuild(uint32_t graph_id) { | bool Session::IsGraphNeedRebuild(uint32_t graph_id) { | ||||
| return ge::GELib::GetInstance()->SessionManagerObj().IsGraphNeedRebuild(sessionId_, graph_id); | return ge::GELib::GetInstance()->SessionManagerObj().IsGraphNeedRebuild(sessionId_, graph_id); | ||||
| } | } | ||||
| @@ -548,7 +548,7 @@ Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_gr | |||||
| (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | ||||
| } | } | ||||
| std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, | std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, | ||||
| compute_graph->GetGraphID(), subgraph, session_id, GetThreadLocalContext()); | |||||
| compute_graph->GetGraphID(), subgraph, compute_graph, session_id, GetThreadLocalContext()); | |||||
| if (!f.valid()) { | if (!f.valid()) { | ||||
| GELOGE(FAILED, "Future is invalid"); | GELOGE(FAILED, "Future is invalid"); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -563,7 +563,7 @@ Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_gr | |||||
| (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | ||||
| } | } | ||||
| std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, | std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, | ||||
| compute_graph->GetGraphID(), subgraph, session_id, | |||||
| compute_graph->GetGraphID(), subgraph, compute_graph, session_id, | |||||
| GetThreadLocalContext()); | GetThreadLocalContext()); | ||||
| if (!f.valid()) { | if (!f.valid()) { | ||||
| GELOGE(FAILED, "Future is invalid"); | GELOGE(FAILED, "Future is invalid"); | ||||
| @@ -1865,12 +1865,30 @@ Status GraphManager::RegisterCallBackFunc( | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphManager::RegisterCallBackFunc( | |||||
| const std::string &key, | |||||
| const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback) { | |||||
| std::lock_guard<std::mutex> lock(member_mutex_); | |||||
| GELOGI("[GraphManager] RegisterCallBackFunc, key=%s.", key.c_str()); | |||||
| callback_map_[key] = callback; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphManager::PushSummaryData2ME(const GraphId &graph_id, | Status GraphManager::PushSummaryData2ME(const GraphId &graph_id, | ||||
| const std::map<std::string, ge::Tensor> &summary_data) { | const std::map<std::string, ge::Tensor> &summary_data) { | ||||
| std::lock_guard<std::mutex> lock(member_mutex_); | std::lock_guard<std::mutex> lock(member_mutex_); | ||||
| GELOGI("[GraphManager] PushSummaryData2ME, dataSize=%zu.", summary_data.size()); | GELOGI("[GraphManager] PushSummaryData2ME, dataSize=%zu.", summary_data.size()); | ||||
| auto itr = me_callback_map_.find(kSummary); | auto itr = me_callback_map_.find(kSummary); | ||||
| if (itr == me_callback_map_.end()) { | if (itr == me_callback_map_.end()) { | ||||
| auto iter = callback_map_.find(kSummary); | |||||
| if (iter != callback_map_.end()) { | |||||
| std::map<AscendString, ge::Tensor> tmp_summary_data; | |||||
| for (auto &data : summary_data) { | |||||
| AscendString tmp(data.first.c_str()); | |||||
| tmp_summary_data[tmp] = data.second; | |||||
| } | |||||
| return iter->second(graph_id, tmp_summary_data); | |||||
| } | |||||
| GELOGE(FAILED, "[GraphManager] PushSummaryData2ME failed, not found summary callback."); | GELOGE(FAILED, "[GraphManager] PushSummaryData2ME failed, not found summary callback."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -1882,6 +1900,15 @@ Status GraphManager::PushSaveData2ME(const GraphId &graph_id, const std::map<std | |||||
| GELOGI("[GraphManager] PushSaveData2ME, dataSize=%zu.", save_data.size()); | GELOGI("[GraphManager] PushSaveData2ME, dataSize=%zu.", save_data.size()); | ||||
| auto itr = me_callback_map_.find(kSave); | auto itr = me_callback_map_.find(kSave); | ||||
| if (itr == me_callback_map_.end()) { | if (itr == me_callback_map_.end()) { | ||||
| auto iter = callback_map_.find(kSave); | |||||
| if (iter != callback_map_.end()) { | |||||
| std::map<AscendString, ge::Tensor> tmp_save_data; | |||||
| for (auto &data : save_data) { | |||||
| AscendString tmp(data.first.c_str()); | |||||
| tmp_save_data[tmp] = data.second; | |||||
| } | |||||
| return iter->second(graph_id, tmp_save_data); | |||||
| } | |||||
| GELOGE(FAILED, "[GraphManager] PushSaveData2ME failed, not found checkpoint callback."); | GELOGE(FAILED, "[GraphManager] PushSaveData2ME failed, not found checkpoint callback."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -2478,7 +2505,8 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra | |||||
| } | } | ||||
| Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, | Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, | ||||
| const SubGraphInfoPtr &sub_graph_info_ptr, uint64_t session_id, | |||||
| const SubGraphInfoPtr &sub_graph_info_ptr, | |||||
| const ComputeGraphPtr &compute_graph, uint64_t session_id, | |||||
| const GEThreadLocalContext &ge_context) { | const GEThreadLocalContext &ge_context) { | ||||
| if (sub_graph_info_ptr != nullptr && graph_manager != nullptr) { | if (sub_graph_info_ptr != nullptr && graph_manager != nullptr) { | ||||
| GetContext().SetSessionId(session_id); | GetContext().SetSessionId(session_id); | ||||
| @@ -2494,6 +2522,7 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager | |||||
| GE_CHECK_NOTNULL(compute_graph_tmp); | GE_CHECK_NOTNULL(compute_graph_tmp); | ||||
| compute_graph_tmp->SetSessionID(session_id); | compute_graph_tmp->SetSessionID(session_id); | ||||
| Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp, | Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp, | ||||
| compute_graph, | |||||
| engine_name); | engine_name); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "SubGraph optimize Failed %s", engine_name.c_str()); | GELOGE(ret, "SubGraph optimize Failed %s", engine_name.c_str()); | ||||
| @@ -163,6 +163,10 @@ class GraphManager { | |||||
| const std::string &key, | const std::string &key, | ||||
| const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | ||||
| Status RegisterCallBackFunc( | |||||
| const std::string &key, | |||||
| const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback); | |||||
| const bool GetTrainFlag() const { return options_.train_graph_flag; } | const bool GetTrainFlag() const { return options_.train_graph_flag; } | ||||
| bool IsGraphNeedRebuild(uint32_t graph_id); | bool IsGraphNeedRebuild(uint32_t graph_id); | ||||
| @@ -214,7 +218,8 @@ class GraphManager { | |||||
| std::shared_ptr<GraphModelListener> GetModelListener() const { return graph_run_listener_; } | std::shared_ptr<GraphModelListener> GetModelListener() const { return graph_run_listener_; } | ||||
| static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, | static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, | ||||
| const SubGraphInfoPtr &sub_graph_info_ptr, uint64_t session_id, | |||||
| const SubGraphInfoPtr &sub_graph_info_ptr, | |||||
| const ComputeGraphPtr &compute_graph, uint64_t session_id, | |||||
| const GEThreadLocalContext &ge_context); | const GEThreadLocalContext &ge_context); | ||||
| Status ParseInputsDims(const std::vector<InputTensorInfo> &input_tensor); | Status ParseInputsDims(const std::vector<InputTensorInfo> &input_tensor); | ||||
| Status DistinguishGetNextAndData(ComputeGraphPtr &graph, vector<NodePtr> &data_nodes, | Status DistinguishGetNextAndData(ComputeGraphPtr &graph, vector<NodePtr> &data_nodes, | ||||
| @@ -390,6 +395,8 @@ class GraphManager { | |||||
| // summary and checkpoint callback function list for ME, key is summary or checkpoint | // summary and checkpoint callback function list for ME, key is summary or checkpoint | ||||
| std::map<std::string, std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)>> me_callback_map_; | std::map<std::string, std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)>> me_callback_map_; | ||||
| std::map<std::string, std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)>> callback_map_; | |||||
| bool init_flag_; | bool init_flag_; | ||||
| GraphManagerOptions options_; | GraphManagerOptions options_; | ||||
| @@ -76,7 +76,8 @@ void AddNodeInputProperty(ComputeGraphPtr &compute_graph) { | |||||
| } | } | ||||
| } | } | ||||
| Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std::string &engine_name) { | |||||
| Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const ComputeGraphPtr &parent_graph, | |||||
| const std::string &engine_name) { | |||||
| if (compute_graph == nullptr) { | if (compute_graph == nullptr) { | ||||
| GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeSubGraph]: compute_graph is nullptr."); | GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeSubGraph]: compute_graph is nullptr."); | ||||
| return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; | return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; | ||||
| @@ -105,6 +106,10 @@ Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std | |||||
| for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) { | for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) { | ||||
| Status ret = (*iter)->OptimizeFusedGraphAfterGraphSlice(*(compute_graph)); | Status ret = (*iter)->OptimizeFusedGraphAfterGraphSlice(*(compute_graph)); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| auto root_graph = ge::GraphUtils::FindRootGraph(parent_graph); | |||||
| if (root_graph != nullptr) { | |||||
| ErrorManager::GetInstance().SaveMstuneCompileFailedMsg(root_graph->GetName()); | |||||
| } | |||||
| GELOGE(ret, "[OptimizeSubGraph][OptimizeFusedGraphAfterGraphSlice]: graph optimize failed, ret:%d", ret); | GELOGE(ret, "[OptimizeSubGraph][OptimizeFusedGraphAfterGraphSlice]: graph optimize failed, ret:%d", ret); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -42,7 +42,8 @@ class GraphOptimize { | |||||
| ~GraphOptimize() = default; | ~GraphOptimize() = default; | ||||
| // subgraph optimize | // subgraph optimize | ||||
| Status OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std::string &engine_name); | |||||
| Status OptimizeSubGraph(ComputeGraphPtr &compute_graph, const ComputeGraphPtr &parent_graph, | |||||
| const std::string &engine_name); | |||||
| // original graph optimize | // original graph optimize | ||||
| Status OptimizeOriginalGraph(ComputeGraphPtr &compute_graph); | Status OptimizeOriginalGraph(ComputeGraphPtr &compute_graph); | ||||
| @@ -382,11 +382,18 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr | |||||
| GELOGW("SetInt anchorIndex failed");) | GELOGW("SetInt anchorIndex failed");) | ||||
| GE_IF_BOOL_EXEC(!pld_op_desc->SetExtAttr("parentNode", src_node), | GE_IF_BOOL_EXEC(!pld_op_desc->SetExtAttr("parentNode", src_node), | ||||
| GELOGW("SetPldExtAttr parentNode failed");) | GELOGW("SetPldExtAttr parentNode failed");) | ||||
| OpDescPtr src_node_op_desc = src_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(src_node_op_desc); | |||||
| GE_IF_BOOL_EXEC(!AttrUtils::SetStr(pld_op_desc, ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME, | GE_IF_BOOL_EXEC(!AttrUtils::SetStr(pld_op_desc, ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME, | ||||
| src_node_op_desc->GetOpEngineName()), GELOGW("SetStr frontNodeEngineName failed");) | |||||
| src_node_opdesc->GetOpEngineName()), GELOGW("SetStr frontNodeEngineName failed");) | |||||
| std::string l2_info_attr; | |||||
| if (AttrUtils::GetStr(src_node_opdesc, "_task_L2FusionInfo", l2_info_attr)) { | |||||
| GE_IF_BOOL_EXEC(!AttrUtils::SetStr(pld_op_desc, "_task_L2FusionInfo", l2_info_attr), | |||||
| GELOGW("SetStr l2_info_attr failed");) | |||||
| } | |||||
| int64_t anchor_index_for_lxfusion; | |||||
| if (AttrUtils::GetInt(src_node_opdesc, "_data_anchor_index_for_lxfusion", anchor_index_for_lxfusion)) { | |||||
| GE_IF_BOOL_EXEC(!AttrUtils::SetInt(pld_op_desc, "_data_anchor_index_for_lxfusion", anchor_index_for_lxfusion), | |||||
| GELOGW("SetInt anchor_index_for_lxfusion failed");) | |||||
| } | |||||
| // do not care over flow | // do not care over flow | ||||
| graph_info_.num_of_pld_end_++; | graph_info_.num_of_pld_end_++; | ||||
| // replace output_desc of pld with input node's output desc | // replace output_desc of pld with input node's output desc | ||||
| @@ -30,6 +30,11 @@ | |||||
| namespace ge { | namespace ge { | ||||
| const int kValueIndexOutputIndex = 1; | const int kValueIndexOutputIndex = 1; | ||||
| bool IsEmptyTensor(const GeShape &shape) { | |||||
| const auto &dims = shape.GetDims(); | |||||
| return std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim == 0; }); | |||||
| } | |||||
| Status MergePass::Run(NodePtr &node) { | Status MergePass::Run(NodePtr &node) { | ||||
| GELOGD("MergePass running"); | GELOGD("MergePass running"); | ||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| @@ -48,6 +53,11 @@ Status MergePass::Run(NodePtr &node) { | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| if (OptimizeEmptyTensorInput(node) != SUCCESS) { | |||||
| GELOGE(FAILED, "[%s] remove empty_tensor inputs failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| const auto &in_data_nodes = node->GetInDataNodes(); | const auto &in_data_nodes = node->GetInDataNodes(); | ||||
| switch (in_data_nodes.size()) { | switch (in_data_nodes.size()) { | ||||
| case 0: { | case 0: { | ||||
| @@ -197,4 +207,30 @@ bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const { | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| Status MergePass::OptimizeEmptyTensorInput(const NodePtr &node) { | |||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| if (peer_data_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if ((peer_data_anchor->GetOwnerNode() == nullptr) || | |||||
| (peer_data_anchor->GetOwnerNode()->GetOpDesc() == nullptr)) { | |||||
| continue; | |||||
| } | |||||
| const auto &op_desc = peer_data_anchor->GetOwnerNode()->GetOpDesc(); | |||||
| if (IsEmptyTensor(op_desc->GetOutputDesc(peer_data_anchor->GetIdx()).GetShape())) { | |||||
| if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", | |||||
| op_desc->GetName().c_str(), peer_data_anchor->GetIdx(), | |||||
| node->GetName().c_str(), in_data_anchor->GetIdx()); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("Remove data edge %s:%d->%s:%d", | |||||
| op_desc->GetName().c_str(), peer_data_anchor->GetIdx(), | |||||
| node->GetName().c_str(), in_data_anchor->GetIdx()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -29,6 +29,7 @@ class MergePass : public BaseNodePass { | |||||
| Status ChangeIndexToConstant(NodePtr &node, int &value_index); | Status ChangeIndexToConstant(NodePtr &node, int &value_index); | ||||
| Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc); | Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc); | ||||
| bool IsMergeInputNeedOptimized(NodePtr &node) const; | bool IsMergeInputNeedOptimized(NodePtr &node) const; | ||||
| static Status OptimizeEmptyTensorInput(const NodePtr &node); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_PASSES_MERGE_PASS_H_ | #endif // GE_GRAPH_PASSES_MERGE_PASS_H_ | ||||
| @@ -610,11 +610,17 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const | |||||
| /// | /// | ||||
| Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { | Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { | ||||
| auto func_desc = case_node_->GetOpDesc(); | auto func_desc = case_node_->GetOpDesc(); | ||||
| domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr; | |||||
| auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType()); | auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType()); | ||||
| if (post_func == nullptr) { | if (post_func == nullptr) { | ||||
| GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(), | GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(), | ||||
| case_node_->GetType().c_str()); | case_node_->GetType().c_str()); | ||||
| return FAILED; | |||||
| if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType(), parse_func_v2) != SUCCESS || | |||||
| parse_func_v2 == nullptr) { | |||||
| GELOGW("The subgraph new post func v2 for node %s type %s is null", case_node_->GetName().c_str(), | |||||
| case_node_->GetType().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | } | ||||
| for (const auto &name : func_desc->GetSubgraphInstanceNames()) { | for (const auto &name : func_desc->GetSubgraphInstanceNames()) { | ||||
| @@ -629,7 +635,12 @@ Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { | |||||
| "Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str()); | "Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str()); | ||||
| auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph); | auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph); | ||||
| auto ret = post_func(subgraph_name, graph); | |||||
| Status ret = FAILED; | |||||
| if (post_func != nullptr) { | |||||
| ret = post_func(subgraph_name, graph); | |||||
| } else if (parse_func_v2 != nullptr) { | |||||
| ret = parse_func_v2(subgraph_name.c_str(), graph); | |||||
| } | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(), | GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(), | ||||
| case_node_->GetName().c_str(), case_node_->GetType().c_str()); | case_node_->GetName().c_str(), case_node_->GetType().c_str()); | ||||
| @@ -141,7 +141,7 @@ static void LoadOpsProto() { | |||||
| (void)manager->Initialize(option_tmp); | (void)manager->Initialize(option_tmp); | ||||
| } | } | ||||
| graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) { | |||||
| graphStatus aclgrphBuildInitializeImpl(std::map<std::string, std::string> &global_options) { | |||||
| GELOGD("Enter aclgrphInitialize start!"); | GELOGD("Enter aclgrphInitialize start!"); | ||||
| // check global options | // check global options | ||||
| if (CheckGlobalOptions(global_options) != GRAPH_SUCCESS) { | if (CheckGlobalOptions(global_options) != GRAPH_SUCCESS) { | ||||
| @@ -164,9 +164,34 @@ graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_opt | |||||
| } | } | ||||
| } | } | ||||
| GELOGW("gelib has been initialized!"); | GELOGW("gelib has been initialized!"); | ||||
| std::string path_base = ge::GELib::GetPath(); | |||||
| int ret = ErrorManager::GetInstance().Init(path_base); | |||||
| if (ret != 0) { | |||||
| DOMI_LOGE("ErrorManager init fail !"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) { | |||||
| return aclgrphBuildInitializeImpl(global_options); | |||||
| } | |||||
| graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &global_options) { | |||||
| std::map<std::string, std::string> tmp_global_options; | |||||
| for (auto &option : global_options) { | |||||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "AclgrphBuildInitialize option is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string key = option.first.GetString(); | |||||
| std::string val = option.second.GetString(); | |||||
| tmp_global_options[key] = val; | |||||
| } | |||||
| return aclgrphBuildInitializeImpl(tmp_global_options); | |||||
| } | |||||
| void aclgrphBuildFinalize() { | void aclgrphBuildFinalize() { | ||||
| if (ge::GELib::GetInstance() != nullptr && ge::GELib::GetInstance()->InitFlag()) { | if (ge::GELib::GetInstance() != nullptr && ge::GELib::GetInstance()->InitFlag()) { | ||||
| (void)ge::GELib::GetInstance()->Finalize(); | (void)ge::GELib::GetInstance()->Finalize(); | ||||
| @@ -453,6 +478,24 @@ graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string | |||||
| return builder.BuildModel(graph, build_options, model); | return builder.BuildModel(graph, build_options, model); | ||||
| } | } | ||||
| graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<AscendString, AscendString> &build_options, | |||||
| ModelBufferData &model) { | |||||
| GELOGD("Enter aclmdlBuildModel process!"); | |||||
| std::map<std::string, std::string> tmp_build_options; | |||||
| for (auto &option : build_options) { | |||||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "AclgrphBuildInitialize option is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string key = option.first.GetString(); | |||||
| std::string val = option.second.GetString(); | |||||
| tmp_build_options[key] = val; | |||||
| } | |||||
| Impl builder; | |||||
| return builder.BuildModel(graph, tmp_build_options, model); | |||||
| } | |||||
| graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model) { | graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model) { | ||||
| GELOGD("Enter aclmdlSaveModel process!"); | GELOGD("Enter aclmdlSaveModel process!"); | ||||
| if (model.data.get() == nullptr || model.length == 0) { | if (model.data.get() == nullptr || model.length == 0) { | ||||
| @@ -463,6 +506,21 @@ graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &m | |||||
| static_cast<uint32_t>(model.length)); | static_cast<uint32_t>(model.length)); | ||||
| } | } | ||||
| graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &model) { | |||||
| GELOGD("Enter aclmdlSaveModel process!"); | |||||
| if (model.data.get() == nullptr || model.length == 0) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Input model is illegal"); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (output_file == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Output file is nullptr."); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| std::string str_output_file = output_file; | |||||
| return FileSaver::SaveToFile((str_output_file + ".om"), reinterpret_cast<void*>(model.data.get()), | |||||
| static_cast<uint32_t>(model.length)); | |||||
| } | |||||
| graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version) { | graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version) { | ||||
| GELOGD("Enter aclgrphGetIRVersion process!"); | GELOGD("Enter aclgrphGetIRVersion process!"); | ||||
| GE_CHECK_NOTNULL(major_version); | GE_CHECK_NOTNULL(major_version); | ||||
| @@ -254,6 +254,25 @@ Status InnerSession::RegisterCallBackFunc( | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status InnerSession::RegisterCallBackFunc( | |||||
| const std::string &key, | |||||
| const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback) { | |||||
| std::lock_guard<std::mutex> lock(resource_mutex_); | |||||
| if (!init_flag_) { | |||||
| GELOGE(GE_SESS_INIT_FAILED, "[InnerSession:%lu] initialize failed.", session_id_); | |||||
| return GE_SESS_INIT_FAILED; | |||||
| } | |||||
| UpdateThreadContext(std::map<std::string, std::string>{}); | |||||
| Status ret = graph_manager_.RegisterCallBackFunc(key, callback); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[InnerSession:%lu] register %s callback function failed.", session_id_, key.c_str()); | |||||
| return ret; | |||||
| } | |||||
| GELOGI("[InnerSession:%lu] register %s callback function success.", session_id_, key.c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status InnerSession::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | Status InnerSession::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | ||||
| UpdateThreadContext(graph_id); | UpdateThreadContext(graph_id); | ||||
| GELOGI("[InnerSession:%lu] build graph on session, graph_id=%u.", session_id_, graph_id); | GELOGI("[InnerSession:%lu] build graph on session, graph_id=%u.", session_id_, graph_id); | ||||
| @@ -62,6 +62,10 @@ class InnerSession { | |||||
| const std::string &key, | const std::string &key, | ||||
| const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | ||||
| Status RegisterCallBackFunc( | |||||
| const std::string &key, | |||||
| const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback); | |||||
| const GraphManager &getGraphManagerObj() const; | const GraphManager &getGraphManagerObj() const; | ||||
| bool IsGraphNeedRebuild(uint32_t graph_id); | bool IsGraphNeedRebuild(uint32_t graph_id); | ||||
| @@ -276,6 +276,26 @@ Status SessionManager::RegisterCallBackFunc( | |||||
| return innerSession->RegisterCallBackFunc(key, callback); | return innerSession->RegisterCallBackFunc(key, callback); | ||||
| } | } | ||||
| Status SessionManager::RegisterCallBackFunc( | |||||
| SessionId session_id, const std::string &key, | |||||
| const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback) { | |||||
| if (!init_flag_) { | |||||
| GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized."); | |||||
| return GE_SESSION_MANAGER_NOT_INIT; | |||||
| } | |||||
| SessionPtr innerSession = nullptr; | |||||
| { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id); | |||||
| if (it == session_manager_map_.end()) { | |||||
| return GE_SESSION_NOT_EXIST; | |||||
| } else { | |||||
| innerSession = it->second; | |||||
| } | |||||
| } | |||||
| return innerSession->RegisterCallBackFunc(key, callback); | |||||
| } | |||||
| Status SessionManager::BuildGraph(SessionId session_id, uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | Status SessionManager::BuildGraph(SessionId session_id, uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | ||||
| if (!init_flag_) { | if (!init_flag_) { | ||||
| GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized."); | GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized."); | ||||
| @@ -158,6 +158,9 @@ class SessionManager { | |||||
| Status RegisterCallBackFunc( | Status RegisterCallBackFunc( | ||||
| SessionId session_id, const std::string &key, | SessionId session_id, const std::string &key, | ||||
| const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | ||||
| Status RegisterCallBackFunc( | |||||
| SessionId session_id, const std::string &key, | |||||
| const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback); | |||||
| bool IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id); | bool IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id); | ||||
| @@ -29,16 +29,26 @@ | |||||
| namespace ge { | namespace ge { | ||||
| typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map<std::string, ge::Tensor> ¶ms_list); | typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map<std::string, ge::Tensor> ¶ms_list); | ||||
| namespace session { | |||||
| typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map<AscendString, ge::Tensor> ¶ms_list); | |||||
| } | |||||
| // Initialize GE | // Initialize GE | ||||
| ATTRIBUTED_DEPRECATED(Status GEInitialize(const std::map<AscendString, AscendString> &)) | |||||
| Status GEInitialize(const std::map<std::string, std::string> &options); | Status GEInitialize(const std::map<std::string, std::string> &options); | ||||
| Status GEInitialize(const std::map<AscendString, AscendString> &options); | |||||
| // Finalize GE, release all resources | // Finalize GE, release all resources | ||||
| Status GEFinalize(); | Status GEFinalize(); | ||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | ||||
| public: | public: | ||||
| ATTRIBUTED_DEPRECATED(Session(const std::map<AscendString, AscendString> &)) | |||||
| explicit Session(const std::map<std::string, std::string> &options); | explicit Session(const std::map<std::string, std::string> &options); | ||||
| explicit Session(const std::map<AscendString, AscendString> &options); | |||||
| ~Session(); | ~Session(); | ||||
| /// | /// | ||||
| @@ -57,8 +67,19 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||||
| /// @param [in] options graph options | /// @param [in] options graph options | ||||
| /// @return Status result of function | /// @return Status result of function | ||||
| /// | /// | ||||
| ATTRIBUTED_DEPRECATED(Status AddGraph(uint32_t, const Graph &, const std::map<AscendString, AscendString> &)) | |||||
| Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<std::string, std::string> &options); | Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<std::string, std::string> &options); | ||||
| /// | |||||
| /// @ingroup client | |||||
| /// @brief add a graph with a specific graphId and graphOptions | |||||
| /// @param [in] graphId graph id | |||||
| /// @param [in] graph the graph | |||||
| /// @param [in] options graph options | |||||
| /// @return Status result of function | |||||
| /// | |||||
| Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<AscendString, AscendString> &options); | |||||
| /// | /// | ||||
| /// @ingroup client | /// @ingroup client | ||||
| /// @brief add a copy graph with a specific graphId | /// @brief add a copy graph with a specific graphId | ||||
| @@ -124,8 +145,18 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||||
| /// @param [out] var_values: variable values | /// @param [out] var_values: variable values | ||||
| /// @return Status result of function | /// @return Status result of function | ||||
| /// | /// | ||||
| ATTRIBUTED_DEPRECATED(Status GetVariables(const std::vector<std::string> &, std::vector<Tensor> &)) | |||||
| Status GetVariables(const std::vector<std::string> &var_names, std::vector<Tensor> &var_values); | Status GetVariables(const std::vector<std::string> &var_names, std::vector<Tensor> &var_values); | ||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @brief get variables in the session with specific session id | |||||
| /// @param [in] var_names: variable names | |||||
| /// @param [out] var_values: variable values | |||||
| /// @return Status result of function | |||||
| /// | |||||
| Status GetVariables(const std::vector<AscendString> &var_names, std::vector<Tensor> &var_values); | |||||
| /// | /// | ||||
| /// @ingroup ge_graph | /// @ingroup ge_graph | ||||
| /// @brief register callback func with specific summary or checkpoint by users | /// @brief register callback func with specific summary or checkpoint by users | ||||
| @@ -135,8 +166,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||||
| /// Please ensure that the implementation of the function is trusted. | /// Please ensure that the implementation of the function is trusted. | ||||
| /// @return Status result of function | /// @return Status result of function | ||||
| /// | /// | ||||
| ATTRIBUTED_DEPRECATED(Status RegisterCallBackFunc(const char *, const session::pCallBackFunc &)) | |||||
| Status RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback); | Status RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback); | ||||
| Status RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback); | |||||
| bool IsGraphNeedRebuild(uint32_t graphId); | bool IsGraphNeedRebuild(uint32_t graphId); | ||||
| private: | private: | ||||
| @@ -22,6 +22,12 @@ | |||||
| #include "ge_error_codes.h" | #include "ge_error_codes.h" | ||||
| namespace ge { | namespace ge { | ||||
| #ifdef __GNUC__ | |||||
| #define ATTRIBUTED_DEPRECATED(replacement) __attribute__((deprecated("Please use " #replacement " instead."))) | |||||
| #else | |||||
| #define ATTRIBUTED_DEPRECATED(replacement) __declspec(deprecated("Please use " #replacement " instead.")) | |||||
| #endif | |||||
| class StatusFactory { | class StatusFactory { | ||||
| public: | public: | ||||
| static StatusFactory *Instance() { | static StatusFactory *Instance() { | ||||
| @@ -37,6 +43,17 @@ class StatusFactory { | |||||
| err_desc_[err] = desc; | err_desc_[err] = desc; | ||||
| } | } | ||||
| void RegisterErrorNo(uint32_t err, const char *desc) { | |||||
| if (desc == nullptr) { | |||||
| return; | |||||
| } | |||||
| std::string error_desc = desc; | |||||
| if (err_desc_.find(err) != err_desc_.end()) { | |||||
| return; | |||||
| } | |||||
| err_desc_[err] = error_desc; | |||||
| } | |||||
| std::string GetErrDesc(uint32_t err) { | std::string GetErrDesc(uint32_t err) { | ||||
| auto iter_find = err_desc_.find(err); | auto iter_find = err_desc_.find(err); | ||||
| if (iter_find == err_desc_.end()) { | if (iter_find == err_desc_.end()) { | ||||
| @@ -56,6 +73,7 @@ class StatusFactory { | |||||
| class ErrorNoRegisterar { | class ErrorNoRegisterar { | ||||
| public: | public: | ||||
| ErrorNoRegisterar(uint32_t err, const std::string &desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); } | ErrorNoRegisterar(uint32_t err, const std::string &desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); } | ||||
| ErrorNoRegisterar(uint32_t err, const char *desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); } | |||||
| ~ErrorNoRegisterar() {} | ~ErrorNoRegisterar() {} | ||||
| }; | }; | ||||
| @@ -65,7 +65,47 @@ const char *const OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION = "ge.exec.isTailingOp | |||||
| // Option key: memory init | // Option key: memory init | ||||
| const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; | const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; | ||||
| const char *const VARIABLE_MEMORY_MAX_SIZE = "ge.variableMemoryMaxSize"; | const char *const VARIABLE_MEMORY_MAX_SIZE = "ge.variableMemoryMaxSize"; | ||||
| namespace configure_option { | |||||
| const char *const STREAM_NUM = "ge.streamNum"; | |||||
| const char *const HEAD_STREAM = "ge.headStream"; | |||||
| const char *const PERF_LEVEL = "ge.perfLevel"; | |||||
| const char *const ENCRYPT_MODE = "ge.encryptMode"; | |||||
| const char *const EK_FILE = "ge.ekFile"; | |||||
| const char *const CERT_FILE = "ge.certFile"; | |||||
| const char *const HW_KEY_FILE = "ge.hwKeyFile"; | |||||
| const char *const PRIVATE_KEY_FILE = "ge.privateKeyFile"; | |||||
| const char *const FRAMEWORK_TYPE = "ge.frameworkType"; | |||||
| const char *const CALIBRATION_CONF_FILE = "ge.calibrationConfFile"; | |||||
| const char *const INSERT_OP_FILE = "ge.insertOpFile"; | |||||
| const char *const OUTPUT_NODE_NAME = "ge.outputNodeName"; | |||||
| const char *const COMPRESS_FLAG = "ge.compressFlag"; | |||||
| const char *const PRECISION_MODE = "ge.exec.precision_mode"; | |||||
| const char *const SINGLE_OP_FLAG = "ge.exec.single_op"; | |||||
| const char *const TRAIN_FLAG = "ge.trainFlag"; | |||||
| const char *const RUN_FLAG = "ge.runFlag"; | |||||
| const char *const LOCAL_FMKOP_FLAG = "ge.enabledLocalFmkop"; | |||||
| const char *const TBE_PLUGIN_PATH_FLAG = "ge.TBE_plugin_path"; | |||||
| const char *const DDK_VERSION_FLAG = "ge.DDK_version"; | |||||
| const char *const GE_FE_FLAG = "ge.feFlag"; | |||||
| const char *const STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; | |||||
| const char *const OUTPUT_DATATYPE = "ge.outputDatatype"; | |||||
| const char *const OP_SELECT_IMPL_MODE = "ge.opSelectImplmode"; | |||||
| const char *const OPTYPELIST_FOR_IMPLMODE = "ge.optypelistForImplmode"; | |||||
| const char *const HCOM_PARALLEL = "ge.hcomParallel"; | |||||
| const char *const AUTO_TUNE_MODE = "ge.autoTuneMode"; | |||||
| const char *const SOC_VERSION = "ge.socVersion"; | |||||
| const char *const CORE_TYPE = "ge.engineType"; | |||||
| const char *const AICORE_NUM = "ge.aicoreNum"; | |||||
| const char *const L1_FUSION = "ge.l1Fusion"; | |||||
| const char *const BUFFER_OPTIMIZE = "ge.bufferOptimize"; | |||||
| const char *const ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; | |||||
| const char *const ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; | |||||
| const char *const FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; | |||||
| const char *const SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; | |||||
| const char *const ORIGINAL_MODEL_FILE = "ge.originalModelFile"; | |||||
| const char *const INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; | |||||
| const char *const OP_DEBUG_LEVEL = "ge.opDebugLevel"; | |||||
| } // namespace configure_option | |||||
| // Configure stream num by Session constructor options param, | // Configure stream num by Session constructor options param, | ||||
| // its value should be int32_t type, default value is "1" | // its value should be int32_t type, default value is "1" | ||||
| const std::string STREAM_NUM = "ge.streamNum"; | const std::string STREAM_NUM = "ge.streamNum"; | ||||
| @@ -324,8 +364,8 @@ static const char *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c | |||||
| static const char *const DEBUG_DIR = ge::DEBUG_DIR; | static const char *const DEBUG_DIR = ge::DEBUG_DIR; | ||||
| static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR; | static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR; | ||||
| static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE; | static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE; | ||||
| static const char *const MDL_BANK_PATH_FLAG = ge::MDL_BANK_PATH_FLAG.c_str(); | |||||
| static const char *const OP_BANK_PATH_FLAG = ge::OP_BANK_PATH_FLAG.c_str(); | |||||
| static const char *const MDL_BANK_PATH = ge::MDL_BANK_PATH_FLAG.c_str(); | |||||
| static const char *const OP_BANK_PATH = ge::OP_BANK_PATH_FLAG.c_str(); | |||||
| static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str(); | static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str(); | ||||
| // for interface: aclgrphBuildModel | // for interface: aclgrphBuildModel | ||||
| @@ -347,8 +387,8 @@ const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, | |||||
| DEBUG_DIR, | DEBUG_DIR, | ||||
| OP_COMPILER_CACHE_DIR, | OP_COMPILER_CACHE_DIR, | ||||
| OP_COMPILER_CACHE_MODE, | OP_COMPILER_CACHE_MODE, | ||||
| MDL_BANK_PATH_FLAG, | |||||
| OP_BANK_PATH_FLAG}; | |||||
| MDL_BANK_PATH, | |||||
| OP_BANK_PATH}; | |||||
| // for interface: aclgrphParse | // for interface: aclgrphParse | ||||
| const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT, | const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT, | ||||
| @@ -44,8 +44,11 @@ struct ModelBufferData { | |||||
| * @retval GRAPH_SUCCESS The function is successfully executed. | * @retval GRAPH_SUCCESS The function is successfully executed. | ||||
| * @retval OtherValues Failure | * @retval OtherValues Failure | ||||
| */ | */ | ||||
| ATTRIBUTED_DEPRECATED(graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &)) | |||||
| graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options); | graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options); | ||||
| graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &global_options); | |||||
| /** | /** | ||||
| * @ingroup AscendCL | * @ingroup AscendCL | ||||
| * @brief build model.Notice the model is stored in buffer | * @brief build model.Notice the model is stored in buffer | ||||
| @@ -63,9 +66,14 @@ void aclgrphBuildFinalize(); | |||||
| * @retval GRAPH_SUCCESS The function is successfully executed. | * @retval GRAPH_SUCCESS The function is successfully executed. | ||||
| * @retval OtherValues Failure | * @retval OtherValues Failure | ||||
| */ | */ | ||||
| ATTRIBUTED_DEPRECATED(graphStatus aclgrphBuildModel(const ge::Graph &, const std::map<AscendString, AscendString> &, | |||||
| ModelBufferData &)) | |||||
| graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options, | graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options, | ||||
| ModelBufferData &model); | ModelBufferData &model); | ||||
| graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<AscendString, AscendString> &build_options, | |||||
| ModelBufferData &model); | |||||
| /** | /** | ||||
| * @ingroup AscendCL | * @ingroup AscendCL | ||||
| * @brief save model buffer to file | * @brief save model buffer to file | ||||
| @@ -75,8 +83,11 @@ graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string | |||||
| * @retval GRAPH_SUCCESS The function is successfully executed. | * @retval GRAPH_SUCCESS The function is successfully executed. | ||||
| * @retval OtherValues Failure | * @retval OtherValues Failure | ||||
| */ | */ | ||||
| ATTRIBUTED_DEPRECATED(graphStatus aclgrphSaveModel(const char *, const ModelBufferData &)) | |||||
| graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model); | graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model); | ||||
| graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &model); | |||||
| /** | /** | ||||
| * @ingroup AscendCL | * @ingroup AscendCL | ||||
| * @brief query IR interface version | * @brief query IR interface version | ||||
| @@ -110,6 +121,5 @@ graphStatus aclgrphInferShapeAndType(ge::Graph &graph); | |||||
| * @retval OtherValues Failure | * @retval OtherValues Failure | ||||
| */ | */ | ||||
| graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len); | graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len); | ||||
| }; // namespace ge | |||||
| }; // namespace ge | |||||
| #endif // INC_EXTERNAL_GE_IR_BUILD_H_ | #endif // INC_EXTERNAL_GE_IR_BUILD_H_ | ||||
| @@ -1 +1 @@ | |||||
| Subproject commit 57e72aac24a35e40799e342fdacca362a66395c4 | |||||
| Subproject commit 0f5ddb10ce79ea2c01b8b9cab5ec3102879610bb | |||||
| @@ -1 +1 @@ | |||||
| Subproject commit bb6424dc6d9252a3ac70650cde2f547761237681 | |||||
| Subproject commit cf60b0c02d1a6e844fcec4202d18a069e9502b0f | |||||