| @@ -76,7 +76,7 @@ Status CheckOptionsValid(const std::map<string, string> &options) { | |||
| } | |||
| // Initialize GE, prepare for execution, call GELib::Initialize | |||
| Status GEInitialize(const std::map<string, string> &options) { | |||
| Status GEInitializeImpl(const std::map<string, string> &options) { | |||
| GELOGT(TRACE_INIT, "GEInitialize start"); | |||
| // 0.check init status | |||
| if (g_ge_initialized) { | |||
| @@ -127,6 +127,26 @@ Status GEInitialize(const std::map<string, string> &options) { | |||
| return ret; | |||
| } | |||
| // Initialize GE, prepare for execution, call GELib::Initialize | |||
| Status GEInitialize(const std::map<string, string> &options) { | |||
| return GEInitializeImpl(options); | |||
| } | |||
| Status GEInitialize(const std::map<ge::AscendString, ge::AscendString> &options) { | |||
| std::map<std::string, std::string> str_options; | |||
| for (auto & option : options) { | |||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||
| GELOGE(FAILED, "GEInitialize options is nullptr."); | |||
| return FAILED; | |||
| } | |||
| std::string key = option.first.GetString(); | |||
| std::string val = option.second.GetString(); | |||
| str_options[key] = val; | |||
| } | |||
| return GEInitializeImpl(str_options); | |||
| } | |||
| // GE finalize, releasing all resources | |||
| Status GEFinalize() { | |||
| GELOGT(TRACE_INIT, "GEFinalize start"); | |||
| @@ -202,6 +222,46 @@ Session::Session(const std::map<string, string> &options) { | |||
| GELOGT(TRACE_STOP, "Session Constructor finished"); | |||
| } | |||
| Session::Session(const std::map<ge::AscendString, ge::AscendString> &options) { | |||
| GELOGT(TRACE_INIT, "Session Constructor start"); | |||
| // check init status | |||
| sessionId_ = 0; | |||
| if (!g_ge_initialized) { | |||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED); | |||
| return; | |||
| } | |||
| // call Initialize | |||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Session Constructor failed"); | |||
| return; | |||
| } | |||
| GELOGT(TRACE_RUNNING, "Creating session"); | |||
| std::map<std::string, std::string> str_options; | |||
| for (auto &option : options) { | |||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||
| GELOGE(FAILED, "Session options is nullptr."); | |||
| return; | |||
| } | |||
| std::string key = option.first.GetString(); | |||
| std::string val = option.second.GetString(); | |||
| str_options[key] = val; | |||
| } | |||
| uint64_t session_id = 0; | |||
| Status ret = instance_ptr->SessionManagerObj().CreateSession(str_options, session_id); | |||
| GELOGT(TRACE_RUNNING, "Session id is %lu", session_id); | |||
| // check return status, return, update session id if success | |||
| if (ret == SUCCESS) { | |||
| sessionId_ = session_id; | |||
| } else { | |||
| GELOGE(ret, "Session constructor failed, session Id not initialized"); | |||
| return; | |||
| } | |||
| GELOGT(TRACE_STOP, "Session Constructor finished"); | |||
| } | |||
| // session destructor | |||
| Session::~Session() { | |||
| GELOGT(TRACE_INIT, "Session Destructor start"); | |||
| @@ -260,6 +320,34 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map<s | |||
| return ret; | |||
| } | |||
| Status Session::AddGraph(uint32_t graph_id, const Graph &graph, | |||
| const std::map<ge::AscendString, ge::AscendString> &options) { | |||
| GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_); | |||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "AddGraph failed in Session."); | |||
| return FAILED; | |||
| } | |||
| GELOGD("Adding graph to session"); | |||
| std::map<std::string, std::string> str_options; | |||
| for (auto &option : options) { | |||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||
| GELOGE(FAILED, "AddGraph options is nullptr."); | |||
| return FAILED; | |||
| } | |||
| std::string key = option.first.GetString(); | |||
| std::string val = option.second.GetString(); | |||
| str_options[key] = val; | |||
| } | |||
| Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, str_options); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "AddGraph failed in Session."); | |||
| return FAILED; | |||
| } | |||
| GELOGD("AddGraph finished in Session."); | |||
| return ret; | |||
| } | |||
| Status Session::RemoveGraph(uint32_t graph_id) { | |||
| GELOGT(TRACE_INIT, "Session RemoveGraph start"); | |||
| @@ -360,6 +448,14 @@ Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc | |||
| return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | |||
| } | |||
| Status Session::RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback) { | |||
| std::string str_key; | |||
| if (key != nullptr) { | |||
| str_key = key; | |||
| } | |||
| return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, str_key, callback); | |||
| } | |||
| Status Session::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | |||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||
| @@ -409,6 +505,29 @@ Status Session::GetVariables(const std::vector<std::string> &var_names, std::vec | |||
| return SUCCESS; | |||
| } | |||
| Status Session::GetVariables(const std::vector<ge::AscendString> &var_names, std::vector<Tensor> &var_values) { | |||
| auto instance_ptr = ge::GELib::GetInstance(); | |||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "SessionConstructor failed"); | |||
| return FAILED; | |||
| } | |||
| GELOGT(TRACE_RUNNING, "Get Variables"); | |||
| std::vector<ge::string> str_var_names; | |||
| for (auto &var_name : var_names) { | |||
| if (var_name.GetString() == nullptr) { | |||
| GELOGE(FAILED, "GetVariables name is nullptr."); | |||
| return FAILED; | |||
| } | |||
| str_var_names.emplace_back(var_name.GetString()); | |||
| } | |||
| Status ret = ge::GELib::GetInstance()->SessionManagerObj().GetVariables(sessionId_, str_var_names, var_values); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "SessionManager RunGraphAsync failed"); | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| bool Session::IsGraphNeedRebuild(uint32_t graph_id) { | |||
| return ge::GELib::GetInstance()->SessionManagerObj().IsGraphNeedRebuild(sessionId_, graph_id); | |||
| } | |||
| @@ -77,6 +77,7 @@ target_compile_options(ge_common PRIVATE | |||
| -fvisibility=hidden | |||
| -O2 | |||
| -Werror | |||
| -wno-deprecated-declarations | |||
| ) | |||
| target_include_directories(ge_common PRIVATE | |||
| @@ -131,6 +132,7 @@ target_compile_options(ge_common_static PRIVATE | |||
| -fvisibility=hidden | |||
| -O2 | |||
| -Werror | |||
| -wno-deprecated-declarations | |||
| ) | |||
| target_include_directories(ge_common_static PRIVATE | |||
| @@ -82,8 +82,9 @@ include $(CLEAR_VARS) | |||
| LOCAL_MODULE := libge_common | |||
| LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP | |||
| LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP -Wno-deprecated-declarations | |||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 -Dgoogle=ascend_private | |||
| ifeq ($(DEBUG), 1) | |||
| LOCAL_CFLAGS += -g -O0 | |||
| else | |||
| @@ -123,8 +124,9 @@ include $(CLEAR_VARS) | |||
| LOCAL_MODULE := libge_common | |||
| LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP | |||
| LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP -Wno-deprecated-declarations | |||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 -Dgoogle=ascend_private | |||
| ifeq ($(DEBUG), 1) | |||
| LOCAL_CFLAGS += -g -O0 | |||
| else | |||
| @@ -169,8 +171,9 @@ include $(CLEAR_VARS) | |||
| LOCAL_MODULE := libge_common | |||
| LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP | |||
| LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP -Wno-deprecated-declarations | |||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 -Dgoogle=ascend_private | |||
| ifeq ($(DEBUG), 1) | |||
| LOCAL_CFLAGS += -g -O0 | |||
| endif | |||
| @@ -211,8 +214,9 @@ include $(CLEAR_VARS) | |||
| LOCAL_MODULE := libge_common | |||
| LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP | |||
| LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP -Wno-deprecated-declarations | |||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 -Dgoogle=ascend_private | |||
| ifeq ($(DEBUG), 1) | |||
| LOCAL_CFLAGS += -g -O0 | |||
| endif | |||
| @@ -80,6 +80,7 @@ add_library(ge_executor STATIC ${SRC_LIST} ${PROTO_HDRS}) | |||
| target_compile_options(ge_executor PRIVATE | |||
| -Werror | |||
| -O2 | |||
| -Wno-deprecated-declarations | |||
| ) | |||
| target_compile_definitions(ge_executor PRIVATE | |||
| @@ -104,7 +104,7 @@ local_ge_executor_ldflags := -lrt -ldl \ | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := libge_executor | |||
| LOCAL_CFLAGS += -Werror | |||
| LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations | |||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 -DDAVINCI_SUPPORT_PROFILING -Dgoogle=ascend_private | |||
| LOCAL_SRC_FILES := $(local_ge_executor_src_files) | |||
| @@ -130,7 +130,7 @@ include $(BUILD_SHARED_LIBRARY) | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := libge_executor | |||
| LOCAL_CFLAGS += -Werror | |||
| LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations | |||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DDAVINCI_SUPPORT_PROFILING -Dgoogle=ascend_private | |||
| ifeq ($(DEBUG), 1) | |||
| LOCAL_CFLAGS += -g -O0 | |||
| @@ -166,7 +166,7 @@ include $(BUILD_HOST_SHARED_LIBRARY) | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := libge_executor | |||
| LOCAL_CFLAGS += -Werror | |||
| LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations | |||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DDAVINCI_SUPPORT_PROFILING -Dgoogle=ascend_private | |||
| ifeq ($(DEBUG), 1) | |||
| LOCAL_CFLAGS += -g -O0 | |||
| @@ -199,7 +199,7 @@ include $(BUILD_HOST_STATIC_LIBRARY) | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := libge_executor | |||
| LOCAL_CFLAGS += -Werror | |||
| LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations | |||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DDAVINCI_SUPPORT_PROFILING -Dgoogle=ascend_private | |||
| ifeq ($(DEBUG), 1) | |||
| LOCAL_CFLAGS += -g -O0 | |||
| @@ -1752,12 +1752,30 @@ Status GraphManager::RegisterCallBackFunc( | |||
| return SUCCESS; | |||
| } | |||
| Status GraphManager::RegisterCallBackFunc( | |||
| const std::string &key, | |||
| const std::function<Status(uint32_t, const std::map<ge::AscendString, ge::Tensor> &)> &callback) { | |||
| std::lock_guard<std::mutex> lock(member_mutex_); | |||
| GELOGI("[GraphManager] RegisterCallBackFunc, key=%s.", key.c_str()); | |||
| callback_map_[key] = callback; | |||
| return SUCCESS; | |||
| } | |||
| Status GraphManager::PushSummaryData2ME(const GraphId &graph_id, | |||
| const std::map<std::string, ge::Tensor> &summary_data) { | |||
| std::lock_guard<std::mutex> lock(member_mutex_); | |||
| GELOGI("[GraphManager] PushSummaryData2ME, dataSize=%zu.", summary_data.size()); | |||
| auto itr = me_callback_map_.find(kSummary); | |||
| if (itr == me_callback_map_.end()) { | |||
| auto iter = callback_map_.find(kSummary); | |||
| if (iter != callback_map_.end()) { | |||
| std::map<ge::AscendString, ge::Tensor> tmp_summary_data; | |||
| for (auto &data : summary_data) { | |||
| AscendString tmp(data.first.c_str()); | |||
| tmp_summary_data[tmp] = data.second; | |||
| } | |||
| return iter->second(graph_id, tmp_summary_data); | |||
| } | |||
| GELOGE(FAILED, "[GraphManager] PushSummaryData2ME failed, not found summary callback."); | |||
| return FAILED; | |||
| } | |||
| @@ -1769,6 +1787,15 @@ Status GraphManager::PushSaveData2ME(const GraphId &graph_id, const std::map<std | |||
| GELOGI("[GraphManager] PushSaveData2ME, dataSize=%zu.", save_data.size()); | |||
| auto itr = me_callback_map_.find(kSave); | |||
| if (itr == me_callback_map_.end()) { | |||
| auto iter = callback_map_.find(kSave); | |||
| if (iter != callback_map_.end()) { | |||
| std::map<ge::AscendString, ge::Tensor> tmp_save_data; | |||
| for (auto &data : save_data) { | |||
| AscendString tmp(data.first.c_str()); | |||
| tmp_save_data[tmp] = data.second; | |||
| } | |||
| return iter->second(graph_id, tmp_save_data); | |||
| } | |||
| GELOGE(FAILED, "[GraphManager] PushSaveData2ME failed, not found checkpoint callback."); | |||
| return FAILED; | |||
| } | |||
| @@ -152,6 +152,10 @@ class GraphManager { | |||
| const std::string &key, | |||
| const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | |||
| Status RegisterCallBackFunc( | |||
| const std::string &key, | |||
| const std::function<Status(uint32_t, const std::map<ge::AscendString, ge::Tensor> &)> &callback); | |||
| const bool GetTrainFlag() const { return options_.train_graph_flag; } | |||
| bool IsGraphNeedRebuild(uint32_t graph_id); | |||
| @@ -373,6 +377,8 @@ class GraphManager { | |||
| // summary and checkpoint callback function list for ME, key is summary or checkpoint | |||
| std::map<std::string, std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)>> me_callback_map_; | |||
| std::map<std::string, std::function<Status(uint32_t, const std::map<ge::AscendString, ge::Tensor> &)>> callback_map_; | |||
| bool init_flag_; | |||
| GraphManagerOptions options_; | |||
| @@ -62,16 +62,26 @@ Status DataPass::Run(ComputeGraphPtr compute_graph) { | |||
| node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); | |||
| } | |||
| domi::ParseSubgraphFuncV1 parse_subgraph = nullptr; | |||
| auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(parent_node->GetType()); | |||
| if (post_func == nullptr) { | |||
| GELOGW("The subgraph post func for node %s type %s is null.", | |||
| parent_node->GetName().c_str(), parent_node->GetType().c_str()); | |||
| return SUCCESS; | |||
| if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(parent_node->GetType(), parse_subgraph) != SUCCESS || | |||
| parse_subgraph == nullptr) { | |||
| GELOGW("The subgraph new post func for node[%s] type [%s] is null", | |||
| parent_node->GetName().c_str(), parent_node->GetType().c_str()); | |||
| return SUCCESS; | |||
| } | |||
| } | |||
| auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||
| auto ret = post_func(subgraph_name, graph); | |||
| Status ret = FAILED; | |||
| if (post_func != nullptr) { | |||
| ret = post_func(subgraph_name, graph); | |||
| } else if (parse_subgraph != nullptr) { | |||
| ret = parse_subgraph(subgraph_name.c_str(), graph); | |||
| } | |||
| if (ret != SUCCESS) { | |||
| GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", | |||
| graph.GetName().c_str(), parent_node->GetName().c_str(), parent_node->GetType().c_str()); | |||
| @@ -610,11 +610,17 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const | |||
| /// | |||
| Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { | |||
| auto func_desc = case_node_->GetOpDesc(); | |||
| domi::ParseSubgraphFuncV1 parse_subgraph = nullptr; | |||
| auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType()); | |||
| if (post_func == nullptr) { | |||
| GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(), | |||
| case_node_->GetType().c_str()); | |||
| return FAILED; | |||
| if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType(), parse_subgraph) != SUCCESS || | |||
| parse_subgraph == nullptr) { | |||
| GELOGW("The subgraph new post func for node[%s] type [%s] is null", case_node_->GetName().c_str(), | |||
| case_node_->GetType().c_str()); | |||
| return FAILED; | |||
| } | |||
| } | |||
| for (const auto &name : func_desc->GetSubgraphInstanceNames()) { | |||
| @@ -629,7 +635,12 @@ Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { | |||
| "Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str()); | |||
| auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph); | |||
| auto ret = post_func(subgraph_name, graph); | |||
| Status ret = FAILED; | |||
| if (post_func != nullptr) { | |||
| ret = post_func(subgraph_name, graph); | |||
| } else if (parse_subgraph != nullptr) { | |||
| ret = parse_subgraph(subgraph_name.c_str(), graph); | |||
| } | |||
| if (ret != SUCCESS) { | |||
| GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(), | |||
| case_node_->GetName().c_str(), case_node_->GetType().c_str()); | |||
| @@ -109,7 +109,7 @@ static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) { | |||
| graphStatus aclgrphBuildInitializeImpl(std::map<std::string, std::string> &global_options) { | |||
| GELOGD("Enter aclgrphInitialize start!"); | |||
| // check global options | |||
| if (CheckGlobalOptions(global_options) != GRAPH_SUCCESS) { | |||
| @@ -132,6 +132,24 @@ graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_opt | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) { | |||
| return aclgrphBuildInitializeImpl(global_options); | |||
| } | |||
| graphStatus aclgrphBuildInitialize(std::map<ge::AscendString, ge::AscendString> &global_options) { | |||
| std::map<std::string, std::string> tmp_global_options; | |||
| for (auto &option : global_options) { | |||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "AclgrphBuildInitialize option is nullptr."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| std::string key = option.first.GetString(); | |||
| std::string val = option.second.GetString(); | |||
| tmp_global_options[key] = val; | |||
| } | |||
| return aclgrphBuildInitializeImpl(tmp_global_options); | |||
| } | |||
| void aclgrphBuildFinalize() { | |||
| if (ge::GELib::GetInstance() != nullptr && ge::GELib::GetInstance()->InitFlag()) { | |||
| (void)ge::GELib::GetInstance()->Finalize(); | |||
| @@ -417,6 +435,24 @@ graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string | |||
| return builder.BuildModel(graph, build_options, model); | |||
| } | |||
| graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<ge::AscendString,ge::AscendString> &build_options, | |||
| ModelBufferData &model) { | |||
| GELOGD("Enter aclmdlBuildModel process!"); | |||
| std::map<std::string, std::string> tmp_build_options; | |||
| for (auto &option : build_options) { | |||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "AclgrphBuildInitialize option is nullptr."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| std::string key = option.first.GetString(); | |||
| std::string val = option.second.GetString(); | |||
| tmp_build_options[key] = val; | |||
| } | |||
| Impl builder; | |||
| return builder.BuildModel(graph, tmp_build_options, model); | |||
| } | |||
| graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model) { | |||
| GELOGD("Enter aclmdlSaveModel process!"); | |||
| if (model.data.get() == nullptr || model.length == 0) { | |||
| @@ -427,6 +463,21 @@ graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &m | |||
| static_cast<uint32_t>(model.length)); | |||
| } | |||
| graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &model) { | |||
| GELOGD("Enter aclmdlSaveModel process!"); | |||
| if (model.data.get() == nullptr || model.length == 0) { | |||
| GELOGE(GRAPH_PARAM_INVALID, "Input model is illegal"); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| if (output_file == nullptr) { | |||
| GELOGE(GRAPH_PARAM_INVALID, "Output file is nullptr."); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| std::string str_output_file = output_file; | |||
| return FileSaver::SaveToFile((str_output_file + ".om"), reinterpret_cast<void*>(model.data.get()), | |||
| static_cast<uint32_t>(model.length)); | |||
| } | |||
| graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version) { | |||
| GELOGD("Enter aclgrphGetIRVersion process!"); | |||
| GE_CHECK_NOTNULL(major_version); | |||
| @@ -20,6 +20,7 @@ add_executable(atc ${SRC_LIST} ${PROTO_HDRS}) | |||
| target_compile_options(atc PRIVATE | |||
| -Werror | |||
| -O2 | |||
| -Wno-deprecated-declarations | |||
| ) | |||
| target_compile_definitions(atc PRIVATE | |||
| @@ -5,7 +5,7 @@ include $(CLEAR_VARS) | |||
| LOCAL_MODULE := atc | |||
| LOCAL_CFLAGS += -Werror | |||
| LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations | |||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dgoogle=ascend_private | |||
| LOCAL_SRC_FILES := \ | |||
| @@ -236,6 +236,25 @@ Status InnerSession::RegisterCallBackFunc( | |||
| return SUCCESS; | |||
| } | |||
| Status InnerSession::RegisterCallBackFunc( | |||
| const std::string &key, | |||
| const std::function<Status(uint32_t, const std::map<ge::AscendString, ge::Tensor> &)> &callback) { | |||
| std::lock_guard<std::mutex> lock(resource_mutex_); | |||
| if (!init_flag_) { | |||
| GELOGE(GE_SESS_INIT_FAILED, "[InnerSession:%lu] initialize failed.", session_id_); | |||
| return GE_SESS_INIT_FAILED; | |||
| } | |||
| UpdateThreadContext(std::map<std::string, std::string>{}); | |||
| Status ret = graph_manager_.RegisterCallBackFunc(key, callback); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "[InnerSession:%lu] register %s callback function failed.", session_id_, key.c_str()); | |||
| return ret; | |||
| } | |||
| GELOGI("[InnerSession:%lu] register %s callback function success.", session_id_, key.c_str()); | |||
| return SUCCESS; | |||
| } | |||
| Status InnerSession::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | |||
| UpdateThreadContext(graph_id); | |||
| GELOGI("[InnerSession:%lu] build graph on session, graph_id=%u.", session_id_, graph_id); | |||
| @@ -60,6 +60,10 @@ class InnerSession { | |||
| const std::string &key, | |||
| const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | |||
| Status RegisterCallBackFunc( | |||
| const std::string &key, | |||
| const std::function<Status(uint32_t, const std::map<ge::AscendString, ge::Tensor> &)> &callback); | |||
| const GraphManager &getGraphManagerObj() const; | |||
| bool IsGraphNeedRebuild(uint32_t graph_id); | |||
| @@ -246,6 +246,26 @@ Status SessionManager::RegisterCallBackFunc( | |||
| return innerSession->RegisterCallBackFunc(key, callback); | |||
| } | |||
| Status SessionManager::RegisterCallBackFunc( | |||
| SessionId session_id, const std::string &key, | |||
| const std::function<Status(uint32_t, const std::map<ge::AscendString, ge::Tensor> &)> &callback) { | |||
| if (!init_flag_) { | |||
| GELOGE(GE_SESSION_MANAGER_NOT_INIT); | |||
| return GE_SESSION_MANAGER_NOT_INIT; | |||
| } | |||
| SessionPtr innerSession = nullptr; | |||
| { | |||
| std::lock_guard<std::mutex> lock(mutex_); | |||
| std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id); | |||
| if (it == session_manager_map_.end()) { | |||
| return GE_SESSION_NOT_EXIST; | |||
| } else { | |||
| innerSession = it->second; | |||
| } | |||
| } | |||
| return innerSession->RegisterCallBackFunc(key, callback); | |||
| } | |||
| Status SessionManager::BuildGraph(SessionId session_id, uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | |||
| if (!init_flag_) { | |||
| GELOGE(GE_SESSION_MANAGER_NOT_INIT); | |||
| @@ -146,6 +146,9 @@ class SessionManager { | |||
| Status RegisterCallBackFunc( | |||
| SessionId session_id, const std::string &key, | |||
| const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | |||
| Status RegisterCallBackFunc( | |||
| SessionId session_id, const std::string &key, | |||
| const std::function<Status(uint32_t, const std::map<ge::AscendString, ge::Tensor> &)> &callback); | |||
| bool IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id); | |||
| @@ -29,16 +29,26 @@ | |||
| namespace ge { | |||
| typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map<std::string, ge::Tensor> ¶ms_list); | |||
| namespace session { | |||
| typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map<ge::AscendString, ge::Tensor> ¶ms_list); | |||
| } | |||
| // Initialize GE | |||
| ATTRIBUTED_DEPRECATED(Status GEInitialize(const std::map<ge::AscendString, ge::AscendString> &)) | |||
| Status GEInitialize(const std::map<std::string, std::string> &options); | |||
| Status GEInitialize(const std::map<ge::AscendString, ge::AscendString> &options); | |||
| // Finalize GE, release all resources | |||
| Status GEFinalize(); | |||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||
| public: | |||
| ATTRIBUTED_DEPRECATED(Session(const std::map<ge::AscendString, ge::AscendString> &)) | |||
| explicit Session(const std::map<std::string, std::string> &options); | |||
| explicit Session(const std::map<ge::AscendString, ge::AscendString> &options); | |||
| ~Session(); | |||
| /// | |||
| @@ -57,8 +67,19 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||
| /// @param [in] options graph options | |||
| /// @return Status result of function | |||
| /// | |||
| ATTRIBUTED_DEPRECATED(Status AddGraph(uint32_t, const Graph &, const std::map<ge::AscendString, ge::AscendString> &)) | |||
| Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<std::string, std::string> &options); | |||
| /// | |||
| /// @ingroup client | |||
| /// @brief add a graph with a specific graphId and graphOptions | |||
| /// @param [in] graphId graph id | |||
| /// @param [in] graph the graph | |||
| /// @param [in] options graph options | |||
| /// @return Status result of function | |||
| /// | |||
| Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<ge::AscendString, ge::AscendString> &options); | |||
| /// | |||
| /// @ingroup ge_graph | |||
| /// @brief remove a graph of the session with specific session id | |||
| @@ -105,8 +126,18 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||
| /// @param [out] var_values: variable values | |||
| /// @return Status result of function | |||
| /// | |||
| ATTRIBUTED_DEPRECATED(Status GetVariables(const std::vector<std::string> &, std::vector<Tensor> &)) | |||
| Status GetVariables(const std::vector<std::string> &var_names, std::vector<Tensor> &var_values); | |||
| /// | |||
| /// @ingroup ge_graph | |||
| /// @brief get variables in the session with specific session id | |||
| /// @param [in] var_names: variable names | |||
| /// @param [out] var_values: variable values | |||
| /// @return Status result of function | |||
| /// | |||
| Status GetVariables(const std::vector<ge::AscendString> &var_names, std::vector<Tensor> &var_values); | |||
| /// | |||
| /// @ingroup ge_graph | |||
| /// @brief register callback func with specific summary or checkpoint by users | |||
| @@ -116,8 +147,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||
| /// Please ensure that the implementation of the function is trusted. | |||
| /// @return Status result of function | |||
| /// | |||
| ATTRIBUTED_DEPRECATED( Status RegisterCallBackFunc(const char *, const session::pCallBackFunc &)) | |||
| Status RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback); | |||
| Status RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback); | |||
| bool IsGraphNeedRebuild(uint32_t graphId); | |||
| private: | |||
| @@ -19,8 +19,15 @@ | |||
| #include <map> | |||
| #include <string> | |||
| #include "graph/ascend_string.h" | |||
| namespace ge { | |||
| #ifdef __GNUC__ | |||
| #define ATTRIBUTED_DEPRECATED(replacement) __attribute__((deprecated("Please use " #replacement " instead."))) | |||
| #else | |||
| #define ATTRIBUTED_DEPRECATED(replacement) __declspec(deprecated("Please use " #replacement " instead.")) | |||
| #endif | |||
| class StatusFactory { | |||
| public: | |||
| static StatusFactory *Instance() { | |||
| @@ -36,6 +43,17 @@ class StatusFactory { | |||
| err_desc_[err] = desc; | |||
| } | |||
| void RegisterErrorNo(uint32_t err, const char *desc) { | |||
| if (desc == nullptr) { | |||
| return; | |||
| } | |||
| std::string error_desc = desc; | |||
| if (err_desc_.find(err) != err_desc_.end()) { | |||
| return; | |||
| } | |||
| err_desc_[err] = error_desc; | |||
| } | |||
| std::string GetErrDesc(uint32_t err) { | |||
| auto iter_find = err_desc_.find(err); | |||
| if (iter_find == err_desc_.end()) { | |||
| @@ -44,6 +62,13 @@ class StatusFactory { | |||
| return iter_find->second; | |||
| } | |||
| void GetErrDesc(uint32_t err, ge::AscendString &err_desc) { | |||
| auto iter_find = err_desc_.find(err); | |||
| if (iter_find != err_desc_.end()) { | |||
| err_desc = ge::AscendString((iter_find->second).c_str()); | |||
| } | |||
| } | |||
| protected: | |||
| StatusFactory() {} | |||
| ~StatusFactory() {} | |||
| @@ -55,6 +80,7 @@ class StatusFactory { | |||
| class ErrorNoRegisterar { | |||
| public: | |||
| ErrorNoRegisterar(uint32_t err, const std::string &desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); } | |||
| ErrorNoRegisterar(uint32_t err, const char *desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); } | |||
| ~ErrorNoRegisterar() {} | |||
| }; | |||
| @@ -65,7 +65,47 @@ const char *const OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION = "ge.exec.isTailingOp | |||
| // Option key: memory init | |||
| const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; | |||
| const char *const VARIABLE_MEMORY_MAX_SIZE = "ge.variableMemoryMaxSize"; | |||
| namespace configure_option { | |||
| const char *const STREAM_NUM = "ge.streamNum"; | |||
| const char *const HEAD_STREAM = "ge.headStream"; | |||
| const char *const PERF_LEVEL = "ge.perfLevel"; | |||
| const char *const ENCRYPT_MODE = "ge.encryptMode"; | |||
| const char *const EK_FILE = "ge.ekFile"; | |||
| const char *const CERT_FILE = "ge.certFile"; | |||
| const char *const HW_KEY_FILE = "ge.hwKeyFile"; | |||
| const char *const PRIVATE_KEY_FILE = "ge.privateKeyFile"; | |||
| const char *const FRAMEWORK_TYPE = "ge.frameworkType"; | |||
| const char *const CALIBRATION_CONF_FILE = "ge.calibrationConfFile"; | |||
| const char *const INSERT_OP_FILE = "ge.insertOpFile"; | |||
| const char *const OUTPUT_NODE_NAME = "ge.outputNodeName"; | |||
| const char *const COMPRESS_FLAG = "ge.compressFlag"; | |||
| const char *const PRECISION_MODE = "ge.exec.precision_mode"; | |||
| const char *const SINGLE_OP_FLAG = "ge.exec.single_op"; | |||
| const char *const TRAIN_FLAG = "ge.trainFlag"; | |||
| const char *const RUN_FLAG = "ge.runFlag"; | |||
| const char *const LOCAL_FMKOP_FLAG = "ge.enabledLocalFmkop"; | |||
| const char *const TBE_PLUGIN_PATH_FLAG = "ge.TBE_plugin_path"; | |||
| const char *const DDK_VERSION_FLAG = "ge.DDK_version"; | |||
| const char *const GE_FE_FLAG = "ge.feFlag"; | |||
| const char *const STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; | |||
| const char *const OUTPUT_DATATYPE = "ge.outputDatatype"; | |||
| const char *const OP_SELECT_IMPL_MODE = "ge.opSelectImplmode"; | |||
| const char *const OPTYPELIST_FOR_IMPLMODE = "ge.optypelistForImplmode"; | |||
| const char *const HCOM_PARALLEL = "ge.hcomParallel"; | |||
| const char *const AUTO_TUNE_MODE = "ge.autoTuneMode"; | |||
| const char *const SOC_VERSION = "ge.socVersion"; | |||
| const char *const CORE_TYPE = "ge.engineType"; | |||
| const char *const AICORE_NUM = "ge.aicoreNum"; | |||
| const char *const L1_FUSION = "ge.l1Fusion"; | |||
| const char *const BUFFER_OPTIMIZE = "ge.bufferOptimize"; | |||
| const char *const ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; | |||
| const char *const ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; | |||
| const char *const FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; | |||
| const char *const SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; | |||
| const char *const ORIGINAL_MODEL_FILE = "ge.originalModelFile"; | |||
| const char *const INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; | |||
| const char *const OP_DEBUG_LEVEL = "ge.opDebugLevel"; | |||
| } // namespace configure_option | |||
| // Configure stream num by Session constructor options param, | |||
| // its value should be int32_t type, default value is "1" | |||
| const std::string STREAM_NUM = "ge.streamNum"; | |||
| @@ -45,8 +45,11 @@ struct ModelBufferData | |||
| * @retval GRAPH_SUCCESS The function is successfully executed. | |||
| * @retval OtherValues Failure | |||
| */ | |||
| ATTRIBUTED_DEPRECATED(graphStatus aclgrphBuildInitialize(std::map<ge::AscendString, ge::AscendString> &)) | |||
| graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options); | |||
| graphStatus aclgrphBuildInitialize(std::map<ge::AscendString, ge::AscendString> &global_options); | |||
| /** | |||
| * @ingroup AscendCL | |||
| * @brief build model.Notice the model is stored in buffer | |||
| @@ -64,8 +67,13 @@ void aclgrphBuildFinalize(); | |||
| * @retval GRAPH_SUCCESS The function is successfully executed. | |||
| * @retval OtherValues Failure | |||
| */ | |||
| ATTRIBUTED_DEPRECATED(graphStatus aclgrphBuildModel(const ge::Graph &, const std::map<ge::AscendString,ge::AscendString> &, | |||
| ModelBufferData&)) | |||
| graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options, ModelBufferData& model); | |||
| graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<ge::AscendString,ge::AscendString> &build_options, | |||
| ModelBufferData& model); | |||
| /** | |||
| * @ingroup AscendCL | |||
| * @brief save model buffer to file | |||
| @@ -75,8 +83,11 @@ graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string | |||
| * @retval GRAPH_SUCCESS The function is successfully executed. | |||
| * @retval OtherValues Failure | |||
| */ | |||
| ATTRIBUTED_DEPRECATED(graphStatus aclgrphSaveModel(const char *, const ModelBufferData&)) | |||
| graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData& model); | |||
| graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData& model); | |||
| /** | |||
| * @ingroup AscendCL | |||
| * @brief query IR interface version | |||