Browse Source

!1955 Detach SessionManager from GELib

Merge pull request !1955 from 张晓昆/master
tags/v1.5.1
i-robot Gitee 3 years ago
parent
commit
2bfa1ba528
14 changed files with 303 additions and 273 deletions
  1. +6
    -3
      ge/CMakeLists.txt
  2. +133
    -195
      ge/client/ge_api.cc
  3. +7
    -2
      ge/graph/execute/model_executor.cc
  4. +2
    -1
      ge/graph/execute/model_executor.h
  5. +3
    -1
      ge/graph/manager/graph_manager.cc
  6. +2
    -2
      ge/hybrid/model/node_item.cc
  7. +0
    -21
      ge/init/gelib.cc
  8. +7
    -5
      ge/init/gelib.h
  9. +1
    -3
      ge/session/inner_session.cc
  10. +0
    -5
      ge/session/session_manager.cc
  11. +19
    -21
      ge/session/session_manager.h
  12. +10
    -10
      tests/ut/ge/graph/execute/model_executor_unittest.cc
  13. +3
    -2
      tests/ut/ge/graph/load/model_manager_unittest.cc
  14. +110
    -2
      tests/ut/ge/session/ge_api_unittest.cc

+ 6
- 3
ge/CMakeLists.txt View File

@@ -474,9 +474,6 @@ set(INFER_SRC_LIST
"common/ge/plugin_manager.cc" "common/ge/plugin_manager.cc"
"common/ge/op_tiling_manager.cc" "common/ge/op_tiling_manager.cc"
"init/gelib.cc" "init/gelib.cc"
"session/inner_session.cc"
"session/session_manager.cc"
"graph/execute/model_executor.cc"
"engine_manager/dnnengine_manager.cc" "engine_manager/dnnengine_manager.cc"
"opskernel_manager/ops_kernel_manager.cc" "opskernel_manager/ops_kernel_manager.cc"
"opskernel_manager/ops_kernel_builder_manager.cc" "opskernel_manager/ops_kernel_builder_manager.cc"
@@ -721,6 +718,12 @@ set(INFER_SRC_LIST
"ge_opt_info/ge_opt_info.cc" "ge_opt_info/ge_opt_info.cc"
) )


set(RUNNER_SRC_LIST
"client/ge_api.cc"
"session/inner_session.cc"
"session/session_manager.cc"
)

if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES)
message("CMAKE_CXX_COMPILER_VERSION = ${CMAKE_CXX_COMPILER_VERSION}") message("CMAKE_CXX_COMPILER_VERSION = ${CMAKE_CXX_COMPILER_VERSION}")
############ libge_runner.so ############ ############ libge_runner.so ############


+ 133
- 195
ge/client/ge_api.cc View File

@@ -47,6 +47,7 @@ const int32_t kMaxStrLen = 128;


static bool g_ge_initialized = false; static bool g_ge_initialized = false;
static std::mutex g_ge_release_mutex; // GEFinalize and ~Session use static std::mutex g_ge_release_mutex; // GEFinalize and ~Session use
static std::shared_ptr<ge::SessionManager> g_session_manager;


namespace ge { namespace ge {
void GetOpsProtoPath(std::string &opsproto_path) { void GetOpsProtoPath(std::string &opsproto_path) {
@@ -70,8 +71,7 @@ Status CheckOptionsValid(const std::map<string, string> &options) {
auto job_id_iter = options.find(OPTION_EXEC_JOB_ID); auto job_id_iter = options.find(OPTION_EXEC_JOB_ID);
if (job_id_iter != options.end()) { if (job_id_iter != options.end()) {
if (job_id_iter->second.length() > kMaxStrLen) { if (job_id_iter->second.length() > kMaxStrLen) {
GELOGE(PARAM_INVALID, "[Check][JobId]Failed,"
"the job_id [%s] string length: %zu > max string length: %d",
GELOGE(PARAM_INVALID, "[Check][JobId]Failed, the job_id [%s] string length: %zu > max string length: %d",
job_id_iter->second.c_str(), job_id_iter->second.length(), kMaxStrLen); job_id_iter->second.c_str(), job_id_iter->second.length(), kMaxStrLen);
REPORT_INPUT_ERROR("E10051", std::vector<std::string>({"id", "length"}), REPORT_INPUT_ERROR("E10051", std::vector<std::string>({"id", "length"}),
std::vector<std::string>({job_id_iter->second, std::vector<std::string>({job_id_iter->second,
@@ -95,8 +95,7 @@ Status GEInitializeImpl(const std::map<string, string> &options) {
std::string path_base = ge::GELib::GetPath(); std::string path_base = ge::GELib::GetPath();
auto ret = ErrorManager::GetInstance().Init(path_base); auto ret = ErrorManager::GetInstance().Init(path_base);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(GE_CLI_INIT_FAILED,
"[Init][PathBase]Init failed when pass param path_base:%s", path_base.c_str());
GELOGE(GE_CLI_INIT_FAILED, "[Init][PathBase]Init failed when pass param path_base:%s", path_base.c_str());
REPORT_CALL_ERROR("E19999", "Init failed when pass param path_base:%s", path_base.c_str()); REPORT_CALL_ERROR("E19999", "Init failed when pass param path_base:%s", path_base.c_str());
return ret; return ret;
} }
@@ -117,11 +116,9 @@ Status GEInitializeImpl(const std::map<string, string> &options) {
bool is_proto_init = manager->Initialize(option_tmp); bool is_proto_init = manager->Initialize(option_tmp);
GE_TIMESTAMP_END(GEInitialize, "GEInitialize::ManagerInitialize"); GE_TIMESTAMP_END(GEInitialize, "GEInitialize::ManagerInitialize");
if (!is_proto_init) { if (!is_proto_init) {
GELOGE(GE_CLI_INIT_FAILED,
"[Init][OpsProtoPath]Loading OpsProto lib plugin failed, OpsProtoPath:%s invalid.",
GELOGE(GE_CLI_INIT_FAILED, "[Init][OpsProtoPath]Loading OpsProto lib plugin failed, OpsProtoPath:%s invalid.",
opsproto_path.c_str()); opsproto_path.c_str());
REPORT_CALL_ERROR("E19999", "Loading OpsProto lib plugin failed, OpsProtoPath:%s invalid",
opsproto_path.c_str());
REPORT_CALL_ERROR("E19999", "Loading OpsProto lib plugin failed, OpsProtoPath:%s invalid", opsproto_path.c_str());
return FAILED; return FAILED;
} }


@@ -148,6 +145,22 @@ Status GEInitializeImpl(const std::map<string, string> &options) {
return FAILED; return FAILED;
} }


ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther);
GELOGI("sessionManager initial.");
GE_TIMESTAMP_START(SessionManagerInitialize);
g_session_manager = MakeShared<ge::SessionManager>();
if (g_session_manager == nullptr) {
GELOGE(GE_CLI_INIT_FAILED, "[Init][Create]SessionManager failed");
return FAILED;
}
ret = g_session_manager->Initialize(options);
GE_TIMESTAMP_END(SessionManagerInitialize, "InnerInitialize::SessionManagerInitialize");
if (ret != SUCCESS) {
GELOGE(ret, "[Init][SessionManager] GE session manager initial failed.");
REPORT_CALL_ERROR("E19999", "SessionManager initialize failed.");
return ret;
}

// 7.check return status, return // 7.check return status, return
if (!g_ge_initialized) { if (!g_ge_initialized) {
// Initialize success, first time calling initialize // Initialize success, first time calling initialize
@@ -173,8 +186,7 @@ Status GEInitialize(const std::map<AscendString, AscendString> &options) {
for (auto &option : options) { for (auto &option : options) {
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) {
GELOGE(FAILED, "[Check][Param]Options invalid, first or second option is nullptr."); GELOGE(FAILED, "[Check][Param]Options invalid, first or second option is nullptr.");
REPORT_INNER_ERROR("E19999", "Check parameter's options invalid,"
"the first or second option is nullptr.");
REPORT_INNER_ERROR("E19999", "Check parameter's options invalid, the first or second option is nullptr.");
return FAILED; return FAILED;
} }
std::string key = option.first.GetString(); std::string key = option.first.GetString();
@@ -217,6 +229,12 @@ Status GEFinalize() {
ret = middle_ret; ret = middle_ret;
} }
} }

GELOGI("SessionManager finalization.");
if (g_session_manager != nullptr) {
(void)g_session_manager->Finalize(); // always success.
}

middle_ret = TBEPluginManager::Instance().Finalize(); middle_ret = TBEPluginManager::Instance().Finalize();
if (middle_ret != SUCCESS) { if (middle_ret != SUCCESS) {
ret = middle_ret; ret = middle_ret;
@@ -251,28 +269,18 @@ std::string GEGetWarningMsg() {
Session::Session(const std::map<string, string> &options) { Session::Session(const std::map<string, string> &options) {
ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther);
GELOGT(TRACE_INIT, "Start to construct session."); GELOGT(TRACE_INIT, "Start to construct session.");

ErrorManager::GetInstance().GenWorkStreamIdDefault(); ErrorManager::GetInstance().GenWorkStreamIdDefault();
// check init status // check init status
sessionId_ = 0; sessionId_ = 0;
if (!g_ge_initialized) { if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999",
"Creating session failed because lack GEInitialize call before.");
return;
}
// call Initialize
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Construct][Session]Failed, GELib instance is nullptr or it is not InitFlag");
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return; return;
} }


GELOGT(TRACE_RUNNING, "Creating session"); GELOGT(TRACE_RUNNING, "Creating session");
uint64_t session_id = 0; uint64_t session_id = 0;
Status ret = instance_ptr->SessionManagerObj().CreateSession(options, session_id);
Status ret = g_session_manager->CreateSession(options, session_id);
GELOGT(TRACE_RUNNING, "Session id is %lu", session_id); GELOGT(TRACE_RUNNING, "Session id is %lu", session_id);


// check return status, return, update session id if success // check return status, return, update session id if success
@@ -288,32 +296,21 @@ Session::Session(const std::map<string, string> &options) {
Session::Session(const std::map<AscendString, AscendString> &options) { Session::Session(const std::map<AscendString, AscendString> &options) {
ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther);
GELOGT(TRACE_INIT, "Session Constructor start"); GELOGT(TRACE_INIT, "Session Constructor start");

ErrorManager::GetInstance().GenWorkStreamIdDefault(); ErrorManager::GetInstance().GenWorkStreamIdDefault();
// check init status // check init status
sessionId_ = 0; sessionId_ = 0;
if (!g_ge_initialized) { if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999",
"Creating session failed because lack GEInitialize call before.");
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return; return;
} }
// call Initialize // call Initialize
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Construct][Session]Failed, the GELib instance is nullptr or is not InitFlag");
return;
}

GELOGT(TRACE_RUNNING, "Creating session"); GELOGT(TRACE_RUNNING, "Creating session");
std::map<std::string, std::string> str_options; std::map<std::string, std::string> str_options;
for (auto &option : options) { for (auto &option : options) {
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) {
GELOGE(FAILED, "[Construct][Session]Failed, the first or second option is nullptr."); GELOGE(FAILED, "[Construct][Session]Failed, the first or second option is nullptr.");
REPORT_INNER_ERROR("E19999", "Creating session's options invalid,"
"the first or second option is nullptr.");
REPORT_INNER_ERROR("E19999", "Creating session's options invalid, the first or second option is nullptr.");
return; return;
} }
std::string key = option.first.GetString(); std::string key = option.first.GetString();
@@ -321,7 +318,7 @@ Session::Session(const std::map<AscendString, AscendString> &options) {
str_options[key] = val; str_options[key] = val;
} }
uint64_t session_id = 0; uint64_t session_id = 0;
Status ret = instance_ptr->SessionManagerObj().CreateSession(str_options, session_id);
Status ret = g_session_manager->CreateSession(str_options, session_id);
GELOGT(TRACE_RUNNING, "Session id is %lu", session_id); GELOGT(TRACE_RUNNING, "Session id is %lu", session_id);


// check return status, return, update session id if success // check return status, return, update session id if success
@@ -350,19 +347,12 @@ Session::~Session() {
try { try {
uint64_t session_id = sessionId_; uint64_t session_id = sessionId_;
// call DestroySession // call DestroySession
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGW("GE is not yet initialized or is finalized.");
return;
}
GELOGT(TRACE_RUNNING, "Session id is %lu", session_id); GELOGT(TRACE_RUNNING, "Session id is %lu", session_id);

GELOGT(TRACE_RUNNING, "Destroying session"); GELOGT(TRACE_RUNNING, "Destroying session");


ret = instance_ptr->SessionManagerObj().DestroySession(session_id);
ret = g_session_manager->DestroySession(session_id);
} catch (google::protobuf::FatalException &e) { } catch (google::protobuf::FatalException &e) {
GELOGE(GE_CLI_SESS_DESTROY_FAILED, "[Destruct][Session]Failed "
"because get fatalException.");
GELOGE(GE_CLI_SESS_DESTROY_FAILED, "[Destruct][Session]Failed because get fatalException.");
REPORT_CALL_ERROR("E19999", "Destruct session failed, get fatal exception"); REPORT_CALL_ERROR("E19999", "Destruct session failed, get fatal exception");
} }


@@ -377,9 +367,7 @@ Session::~Session() {


// Add Graph // Add Graph
Status Session::AddGraph(uint32_t graph_id, const Graph &graph) { Status Session::AddGraph(uint32_t graph_id, const Graph &graph) {
ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
std::map<std::string, std::string> options; std::map<std::string, std::string> options;
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
return AddGraph(graph_id, graph, options); return AddGraph(graph_id, graph, options);
} }


@@ -388,20 +376,16 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map<s
ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_); GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_);
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Add][Graph]Failed because GELib instance is nullptr or it is not InitFlag.");
REPORT_INNER_ERROR("E19999",
"AddGraph Failed, GELib instance is nullptr or it is not InitFlag.");
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return FAILED; return FAILED;
} }

GELOGD("Adding graph to session"); GELOGD("Adding graph to session");
Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, options);
Status ret = g_session_manager->AddGraph(sessionId_, graph_id, graph, options);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret,
"[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id);
GELOGE(ret, "[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id);
return FAILED; return FAILED;
} }
GELOGD("AddGraph finished in Session."); GELOGD("AddGraph finished in Session.");
@@ -409,37 +393,31 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map<s
} }


//Add Graph //Add Graph
Status Session::AddGraph(uint32_t graph_id, const Graph &graph,
const std::map<AscendString, AscendString> &options) {
Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map<AscendString, AscendString> &options) {
ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_); GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_);
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Add][Graph]Failed, the GELib instance is nullptr or is not InitFlag.");
REPORT_INNER_ERROR("E19999",
"AddGraph Failed, GELib instance is nullptr or it is not InitFlag.");
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return FAILED; return FAILED;
} }

GELOGD("Adding graph to session"); GELOGD("Adding graph to session");
std::map<std::string, std::string> str_options; std::map<std::string, std::string> str_options;
for (auto &option : options) { for (auto &option : options) {
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) {
GELOGE(FAILED, "[Add][Graph]Failed, the first or second option is nullptr."); GELOGE(FAILED, "[Add][Graph]Failed, the first or second option is nullptr.");
REPORT_INNER_ERROR("E19999",
"Add Graph Failed, the first or second option is nullptr.");
REPORT_INNER_ERROR("E19999", "Add Graph Failed, the first or second option is nullptr.");
return FAILED; return FAILED;
} }
std::string key = option.first.GetString(); std::string key = option.first.GetString();
std::string val = option.second.GetString(); std::string val = option.second.GetString();
str_options[key] = val; str_options[key] = val;
} }
Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, str_options);
Status ret = g_session_manager->AddGraph(sessionId_, graph_id, graph, str_options);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret,
"[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id);
GELOGE(ret, "[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id);
return FAILED; return FAILED;
} }
GELOGD("AddGraph finished in Session."); GELOGD("AddGraph finished in Session.");
@@ -447,8 +425,6 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph,
} }


Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph) { Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph) {
ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
std::map<AscendString, AscendString> options; std::map<AscendString, AscendString> options;
return AddGraphWithCopy(graph_id, graph, options); return AddGraphWithCopy(graph_id, graph, options);
} }
@@ -459,24 +435,20 @@ Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph,
ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_); GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_);
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Add][Graph]Failed, the GELib instance is nullptr or is not InitFlag.");
REPORT_INNER_ERROR("E19999",
"AddGraph Failed, GELib instance is nullptr or is not InitFlag.");
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return FAILED; return FAILED;
} }

std::map<std::string, std::string> str_options; std::map<std::string, std::string> str_options;
for (auto it = options.begin(); it != options.end(); ++it) { for (auto it = options.begin(); it != options.end(); ++it) {
str_options.insert({it->first.GetString(), it->second.GetString()}); str_options.insert({it->first.GetString(), it->second.GetString()});
} }
GELOGD("Adding graph to session"); GELOGD("Adding graph to session");
Status ret = instance_ptr->SessionManagerObj().AddGraphWithCopy(sessionId_, graph_id, graph, str_options);
Status ret = g_session_manager->AddGraphWithCopy(sessionId_, graph_id, graph, str_options);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret,
"[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id);
GELOGE(ret, "[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id);
return FAILED; return FAILED;
} }
GELOGD("AddGraph finished in Session."); GELOGD("AddGraph finished in Session.");
@@ -487,29 +459,21 @@ Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph,
Status Session::RemoveGraph(uint32_t graph_id) { Status Session::RemoveGraph(uint32_t graph_id) {
ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
GELOGT(TRACE_INIT, "Session RemoveGraph start"); GELOGT(TRACE_INIT, "Session RemoveGraph start");

ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
// call RemoveGraph // call RemoveGraph
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (!instance_ptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Remove][Graph]Failed, GELib instance is nullptr or is not InitFlag, "
"session_id %lu, graph_id %u", sessionId_, graph_id);
REPORT_INNER_ERROR("E19999",
"RemoveGraph Failed, GELib instance is nullptr or is not InitFlag, "
"session_id %lu, graph_id %u", sessionId_, graph_id);
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return FAILED; return FAILED;
} }


GELOGT(TRACE_RUNNING, "Removing Graph from session"); GELOGT(TRACE_RUNNING, "Removing Graph from session");
Status ret = instance_ptr->SessionManagerObj().RemoveGraph(sessionId_, graph_id);
Status ret = g_session_manager->RemoveGraph(sessionId_, graph_id);
// check return status, return // check return status, return
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret,
"[Remove][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id);
REPORT_CALL_ERROR("E19999", "Remove graph failed, error code:%u, "
"session_id:%lu, graph_id:%u", ret, sessionId_, graph_id);
GELOGE(ret, "[Remove][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id);
REPORT_CALL_ERROR("E19999", "Remove graph failed, error code:%u, session_id:%lu, graph_id:%u",
ret, sessionId_, graph_id);
return FAILED; return FAILED;
} }
GELOGT(TRACE_STOP, "Session RemoveGraph finished"); GELOGT(TRACE_STOP, "Session RemoveGraph finished");
@@ -568,29 +532,21 @@ void PrintOutputResult(std::vector<Tensor> &outputs) {
Status Session::RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, std::vector<Tensor> &outputs) { Status Session::RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, std::vector<Tensor> &outputs) {
ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
GELOGT(TRACE_INIT, "Session RunGraph start"); GELOGT(TRACE_INIT, "Session RunGraph start");

ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
std::vector<Tensor> graph_inputs = inputs;
// call RunGraph
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Run][Graph]Failed, GELib instance is nullptr or is not InitFlag, "
"session_id %lu, graph_id %u", sessionId_, graph_id);
REPORT_INNER_ERROR("E19999",
"RunGraph Failed, GELib instance is nullptr or is not InitFlag, "
"session_id %lu, graph_id %u", sessionId_, graph_id);
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return FAILED; return FAILED;
} }

// call RunGraph
GELOGT(TRACE_RUNNING, "Running Graph"); GELOGT(TRACE_RUNNING, "Running Graph");
Status ret = instance_ptr->SessionManagerObj().RunGraph(sessionId_, graph_id, graph_inputs, outputs);
Status ret = g_session_manager->RunGraph(sessionId_, graph_id, inputs, outputs);
// check return status // check return status
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret,
"[Run][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id);
REPORT_CALL_ERROR("E19999", "Remove graph failed, error code:%u, "
"session_id:%lu, graph_id:%u", ret, sessionId_, graph_id);
GELOGE(ret, "[Run][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id);
REPORT_CALL_ERROR("E19999", "Remove graph failed, error code:%u, session_id:%lu, graph_id:%u",
ret, sessionId_, graph_id);
return FAILED; return FAILED;
} }


@@ -609,30 +565,15 @@ Status Session::RunGraphWithStreamAsync(uint32_t graph_id, void *stream, const s
std::vector<Tensor> &outputs) { std::vector<Tensor> &outputs) {
ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
GELOGT(TRACE_INIT, "Start to run graph with stream async."); GELOGT(TRACE_INIT, "Start to run graph with stream async.");

ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Run][Graph]Run graph with stream async failed, the GELib instance is nullptr,"
"session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream);
REPORT_INNER_ERROR("E19999",
"Run graph with stream async failed, the GELib instance is nullptr"
"session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream);
return FAILED;
}
if (!instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Run][Graph]Run graph with stream asyn failed, the GELib instance is not init,"
"session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream);
REPORT_INNER_ERROR("E19999",
"Run graph with stream asyn failed, the GELib instance is not init,"
"session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream);
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return FAILED; return FAILED;
} }

GELOGT(TRACE_RUNNING, "Run Graph Run graph with stream asyn."); GELOGT(TRACE_RUNNING, "Run Graph Run graph with stream asyn.");
Status ret = instance_ptr->SessionManagerObj().RunGraphWithStreamAsync(sessionId_, graph_id, stream, inputs,
outputs);
Status ret = g_session_manager->RunGraphWithStreamAsync(sessionId_, graph_id, stream, inputs, outputs);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "[Run][Graph]Run graph with stream asyn Failed," GELOGE(ret, "[Run][Graph]Run graph with stream asyn Failed,"
"error code = %u, session id = %lu, graph id = %u, stream = %p.", ret, sessionId_, graph_id, stream); "error code = %u, session id = %lu, graph id = %u, stream = %p.", ret, sessionId_, graph_id, stream);
@@ -648,40 +589,46 @@ Status Session::RunGraphWithStreamAsync(uint32_t graph_id, void *stream, const s
// Register Call Back // Register Call Back
Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback) { Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback) {
ErrorManager::GetInstance().GenWorkStreamIdDefault(); ErrorManager::GetInstance().GenWorkStreamIdDefault();
return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback);
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return FAILED;
}

return g_session_manager->RegisterCallBackFunc(sessionId_, key, callback);
} }


Status Session::RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback) { Status Session::RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback) {
ErrorManager::GetInstance().GenWorkStreamIdDefault(); ErrorManager::GetInstance().GenWorkStreamIdDefault();
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return FAILED;
}

std::string str_key; std::string str_key;
if (key != nullptr) { if (key != nullptr) {
str_key = key; str_key = key;
} }
return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, str_key, callback);
return g_session_manager->RegisterCallBackFunc(sessionId_, str_key, callback);
} }


// Build Graph // Build Graph
Status Session::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { Status Session::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) {
ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Build][Graph]Failed, the GELib instance is nullptr or is not InitFlag, "
"session_id %lu, graph_id %u", sessionId_, graph_id);
REPORT_INNER_ERROR("E19999",
"Build graph failed, the GELib instance is nullptr or is not InitFlag, "
"session_id %lu, graph_id %u", sessionId_, graph_id);
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return FAILED; return FAILED;
} }

GELOGT(TRACE_RUNNING, "Building Graph"); GELOGT(TRACE_RUNNING, "Building Graph");
Status ret = instance_ptr->SessionManagerObj().BuildGraph(sessionId_, graph_id, inputs);
Status ret = g_session_manager->BuildGraph(sessionId_, graph_id, inputs);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret,
"[Build][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id);
REPORT_CALL_ERROR("E19999", "Build graph failed , error code:%u, "
"session_id:%lu, graph_id:%u", ret, sessionId_, graph_id);
GELOGE(ret, "[Build][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id);
REPORT_CALL_ERROR("E19999", "Build graph failed , error code:%u, session_id:%lu, graph_id:%u",
ret, sessionId_, graph_id);
return FAILED; return FAILED;
} }
return SUCCESS; return SUCCESS;
@@ -691,24 +638,18 @@ Status Session::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo>
Status Session::BuildGraph(uint32_t graph_id, const std::vector<ge::Tensor> &inputs) { Status Session::BuildGraph(uint32_t graph_id, const std::vector<ge::Tensor> &inputs) {
ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Build][Graph]Failed, the GELib instance is nullptr or is not InitFlag, "
"session_id %lu, graph_id %u", sessionId_, graph_id);
REPORT_INNER_ERROR("E19999",
"Build graph failed, the GELib instance is nullptr or is not InitFlag, "
"session_id %lu, graph_id %u", sessionId_, graph_id);
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return FAILED; return FAILED;
} }

GELOGT(TRACE_RUNNING, "Building Graph"); GELOGT(TRACE_RUNNING, "Building Graph");
Status ret = instance_ptr->SessionManagerObj().BuildGraph(sessionId_, graph_id, inputs);
Status ret = g_session_manager->BuildGraph(sessionId_, graph_id, inputs);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret,
"[Build][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id);
REPORT_CALL_ERROR("E19999", "Build graph failed , error code:%u, "
"session_id:%lu, graph_id:%u", ret, sessionId_, graph_id);
GELOGE(ret, "[Build][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id);
REPORT_CALL_ERROR("E19999", "Build graph failed , error code:%u, session_id:%lu, graph_id:%u",
ret, sessionId_, graph_id);
return FAILED; return FAILED;
} }
return SUCCESS; return SUCCESS;
@@ -719,26 +660,22 @@ Status Session::RunGraphAsync(uint32_t graph_id, const std::vector<ge::Tensor> &
RunAsyncCallback callback) { RunAsyncCallback callback) {
ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute);
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Run][Graph]RunGraphAsyncFailed, the GELib instance is nullptr or is not InitFlag, "
"session_id %lu, graph_id %u", sessionId_, graph_id);
REPORT_INNER_ERROR("E19999",
"RunGraphAsync Failed, the GELib instance is nullptr or is not InitFlag, "
"session_id %lu, graph_id %u", sessionId_, graph_id);
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return FAILED; return FAILED;
} }

GELOGT(TRACE_RUNNING, "Run Graph Asynchronously"); GELOGT(TRACE_RUNNING, "Run Graph Asynchronously");
GELOGW( GELOGW(
"The callback function will not be checked. Please ensure that the implementation of the function is trusted."); "The callback function will not be checked. Please ensure that the implementation of the function is trusted.");


Status ret = ge::GELib::GetInstance()->SessionManagerObj().RunGraphAsync(sessionId_, graph_id, inputs, callback);
Status ret = g_session_manager->RunGraphAsync(sessionId_, graph_id, inputs, callback);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "[Run][Graph]RunGraphAsync Failed, error code:%u, session_id:%lu, graph_id:%u.", GELOGE(ret, "[Run][Graph]RunGraphAsync Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id); ret, sessionId_, graph_id);
REPORT_CALL_ERROR("E19999", "RunGraphAsync Failed, error code:%u, session_id:%lu, "
"graph_id:%u", ret, sessionId_, graph_id);
REPORT_CALL_ERROR("E19999", "RunGraphAsync Failed, error code:%u, session_id:%lu, graph_id:%u",
ret, sessionId_, graph_id);
return FAILED; return FAILED;
} }
return SUCCESS; return SUCCESS;
@@ -748,16 +685,14 @@ Status Session::RunGraphAsync(uint32_t graph_id, const std::vector<ge::Tensor> &
Status Session::GetVariables(const std::vector<std::string> &var_names, std::vector<Tensor> &var_values) { Status Session::GetVariables(const std::vector<std::string> &var_names, std::vector<Tensor> &var_values) {
ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute);
ErrorManager::GetInstance().GenWorkStreamIdDefault(); ErrorManager::GetInstance().GenWorkStreamIdDefault();
auto instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Get][Variables]Failed, the GELib instance is nullptr or is not InitFlag.");
REPORT_INNER_ERROR("E19999",
"GetVariables failed, the GELib instance is nullptr or is not InitFlag.");
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return FAILED; return FAILED;
} }

GELOGT(TRACE_RUNNING, "Get Variables"); GELOGT(TRACE_RUNNING, "Get Variables");
Status ret = ge::GELib::GetInstance()->SessionManagerObj().GetVariables(sessionId_, var_names, var_values);
Status ret = g_session_manager->GetVariables(sessionId_, var_names, var_values);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "[Get][Variables]Failed, error code:%u, session_id:%lu.", ret, sessionId_); GELOGE(ret, "[Get][Variables]Failed, error code:%u, session_id:%lu.", ret, sessionId_);
return FAILED; return FAILED;
@@ -769,14 +704,12 @@ Status Session::GetVariables(const std::vector<std::string> &var_names, std::vec
Status Session::GetVariables(const std::vector<AscendString> &var_names, std::vector<Tensor> &var_values) { Status Session::GetVariables(const std::vector<AscendString> &var_names, std::vector<Tensor> &var_values) {
ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute);
ErrorManager::GetInstance().GenWorkStreamIdDefault(); ErrorManager::GetInstance().GenWorkStreamIdDefault();
auto instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Get][Variables]Failed, the GELib instance is nullptr or is not InitFlag.");
REPORT_INNER_ERROR("E19999",
"GetVariables failed, the GELib instance is nullptr or is not InitFlag.");
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return FAILED; return FAILED;
} }

GELOGT(TRACE_RUNNING, "Get Variables"); GELOGT(TRACE_RUNNING, "Get Variables");
std::vector<ge::string> str_var_names; std::vector<ge::string> str_var_names;
for (auto &var_name : var_names) { for (auto &var_name : var_names) {
@@ -787,17 +720,22 @@ Status Session::GetVariables(const std::vector<AscendString> &var_names, std::ve
} }
str_var_names.emplace_back(var_name.GetString()); str_var_names.emplace_back(var_name.GetString());
} }
Status ret = ge::GELib::GetInstance()->SessionManagerObj().GetVariables(sessionId_, str_var_names, var_values);
Status ret = g_session_manager->GetVariables(sessionId_, str_var_names, var_values);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "[Get][Variables]Failed, error code:%u, session_id:%lu.", ret, sessionId_); GELOGE(ret, "[Get][Variables]Failed, error code:%u, session_id:%lu.", ret, sessionId_);
REPORT_CALL_ERROR("E19999", "Get variables failed, error code:%u, session_id:%lu.",
ret, sessionId_);
REPORT_CALL_ERROR("E19999", "Get variables failed, error code:%u, session_id:%lu.", ret, sessionId_);
return FAILED; return FAILED;
} }
return SUCCESS; return SUCCESS;
} }


bool Session::IsGraphNeedRebuild(uint32_t graph_id) { bool Session::IsGraphNeedRebuild(uint32_t graph_id) {
return ge::GELib::GetInstance()->SessionManagerObj().IsGraphNeedRebuild(sessionId_, graph_id);
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before.");
REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before.");
return false;
}

return g_session_manager->IsGraphNeedRebuild(sessionId_, graph_id);
} }
} // namespace ge } // namespace ge

+ 7
- 2
ge/graph/execute/model_executor.cc View File

@@ -23,6 +23,7 @@
#include "graph/manager/graph_var_manager.h" #include "graph/manager/graph_var_manager.h"
#include "graph/utils/tensor_adapter.h" #include "graph/utils/tensor_adapter.h"
#include "graph/load/graph_loader.h" #include "graph/load/graph_loader.h"
#include "graph/load/model_manager/model_manager.h"
#include "common/math/math_util.h" #include "common/math/math_util.h"
#include "common/formats/utils/formats_trans_utils.h" #include "common/formats/utils/formats_trans_utils.h"


@@ -38,7 +39,7 @@ namespace ge {
/// @param [in] options user config params /// @param [in] options user config params
/// @return Status result of function /// @return Status result of function
/// ///
Status ModelExecutor::Initialize(const map<string, string> &options) {
Status ModelExecutor::Initialize(const map<string, string> &options, uint64_t session_id) {
graph_run_listener_ = MakeShared<GraphModelListener>(sync_run_mutex_, condition_); graph_run_listener_ = MakeShared<GraphModelListener>(sync_run_mutex_, condition_);
if (graph_run_listener_ == nullptr) { if (graph_run_listener_ == nullptr) {
REPORT_CALL_ERROR("E19999", "New GraphModelListener fail"); REPORT_CALL_ERROR("E19999", "New GraphModelListener fail");
@@ -46,6 +47,7 @@ Status ModelExecutor::Initialize(const map<string, string> &options) {
return MEMALLOC_FAILED; return MEMALLOC_FAILED;
} }


session_id_ = session_id;
train_graph_flag_ = ParseTrainGraphFlag(); train_graph_flag_ = ParseTrainGraphFlag();
thread_run_flag_.store(true); thread_run_flag_.store(true);
run_thread_ = std::thread(&ModelExecutor::RunThread, this); run_thread_ = std::thread(&ModelExecutor::RunThread, this);
@@ -74,6 +76,7 @@ Status ModelExecutor::Finalize() {
GELOGW("Graph executor FreeExecuteMemory failed, resources may not be released correctly."); GELOGW("Graph executor FreeExecuteMemory failed, resources may not be released correctly.");
} }


ModelManager::GetInstance()->DestroyAicpuSession(session_id_);
return SUCCESS; return SUCCESS;
} }


@@ -168,7 +171,9 @@ void ModelExecutor::ReturnError(RunAsyncCallback callback, Status ret, const str
StopQueue(); StopQueue();
GELOGE(ret, "%s.", log.c_str()); GELOGE(ret, "%s.", log.c_str());
std::vector<ge::Tensor> outputs; std::vector<ge::Tensor> outputs;
callback(ret, outputs);
if (callback != nullptr) {
callback(ret, outputs);
}
} }


void ModelExecutor::UpdateLocalOmeContext(const GraphNodePtr &graph_node) { void ModelExecutor::UpdateLocalOmeContext(const GraphNodePtr &graph_node) {


+ 2
- 1
ge/graph/execute/model_executor.h View File

@@ -30,7 +30,7 @@ class ModelExecutor : public Executor {
/// @param [in] options user config params /// @param [in] options user config params
/// @return Status result of function /// @return Status result of function
/// ///
Status Initialize(const map<string, string> &options);
Status Initialize(const map<string, string> &options, uint64_t session_id);


/// ///
/// @ingroup ge /// @ingroup ge
@@ -120,6 +120,7 @@ class ModelExecutor : public Executor {


bool init_flag_{false}; bool init_flag_{false};
bool train_graph_flag_{false}; bool train_graph_flag_{false};
uint64_t session_id_{0};
GraphExecutor graph_executor_; GraphExecutor graph_executor_;


std::mutex mutex_; std::mutex mutex_;


+ 3
- 1
ge/graph/manager/graph_manager.cc View File

@@ -2939,7 +2939,9 @@ void GraphManager::ReturnError(RunAsyncCallback callback, Status ret, const stri
StopQueue(); StopQueue();
GELOGE(ret, "%s.", log.c_str()); GELOGE(ret, "%s.", log.c_str());
std::vector<ge::Tensor> outputs; std::vector<ge::Tensor> outputs;
callback(ret, outputs);
if (callback != nullptr) {
callback(ret, outputs);
}
} }


bool GraphManager::IsGraphNeedRebuild(uint32_t graph_id) { bool GraphManager::IsGraphNeedRebuild(uint32_t graph_id) {


+ 2
- 2
ge/hybrid/model/node_item.cc View File

@@ -25,7 +25,7 @@ namespace ge {
namespace hybrid { namespace hybrid {
namespace { namespace {
const uint8_t kMaxTransCount = 3; const uint8_t kMaxTransCount = 3;
const uint32_t kTransOpIoSize = 1;
const uint8_t kTransOpIoSize = 1;
const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph";
const char *const kNodeTypeRetVal = "_RetVal"; const char *const kNodeTypeRetVal = "_RetVal";
const std::set<std::string> kControlOpTypes{ const std::set<std::string> kControlOpTypes{
@@ -47,7 +47,7 @@ bool IsEnterFeedNode(NodePtr node) {
// For: Enter -> TransData -> Cast -> node // For: Enter -> TransData -> Cast -> node
for (uint8_t i = 0; i < kMaxTransCount; ++i) { for (uint8_t i = 0; i < kMaxTransCount; ++i) {
if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) {
GELOGD("Node[%u] is Enter feed node.", node->GetName().c_str());
GELOGD("Node[%s] is Enter feed node.", node->GetName().c_str());
return true; return true;
} }




+ 0
- 21
ge/init/gelib.cc View File

@@ -160,18 +160,6 @@ Status GELib::InnerInitialize(const map<string, string> &options) {
return initOpsBuilderStatus; return initOpsBuilderStatus;
} }


ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther);
GELOGI("sessionManager initial.");
GE_TIMESTAMP_START(SessionManagerInitialize);
Status initSmStatus = sessionManager_.Initialize(options);
GE_TIMESTAMP_END(SessionManagerInitialize, "InnerInitialize::SessionManagerInitialize");
if (initSmStatus != SUCCESS) {
GELOGE(initSmStatus, "[Init][SessionManager] GE session manager initial failed.");
REPORT_CALL_ERROR("E19999", "SessionManager initialize failed.");
RollbackInit();
return initSmStatus;
}

GELOGI("Start to initialize HostCpuEngine"); GELOGI("Start to initialize HostCpuEngine");
GE_TIMESTAMP_START(HostCpuEngineInitialize); GE_TIMESTAMP_START(HostCpuEngineInitialize);
Status initHostCpuEngineStatus = HostCpuEngine::GetInstance().Initialize(); Status initHostCpuEngineStatus = HostCpuEngine::GetInstance().Initialize();
@@ -454,12 +442,6 @@ Status GELib::Finalize() {
GELOGW("engineManager finalize failed"); GELOGW("engineManager finalize failed");
final_state = mid_state; final_state = mid_state;
} }
GELOGI("sessionManager finalization.");
mid_state = sessionManager_.Finalize();
if (mid_state != SUCCESS) {
GELOGW("sessionManager finalize failed");
final_state = mid_state;
}


GELOGI("opsBuilderManager finalization."); GELOGI("opsBuilderManager finalization.");
mid_state = OpsKernelBuilderManager::Instance().Finalize(); mid_state = OpsKernelBuilderManager::Instance().Finalize();
@@ -539,9 +521,6 @@ void GELib::RollbackInit() {
if (opsManager_.init_flag_) { if (opsManager_.init_flag_) {
(void)opsManager_.Finalize(); (void)opsManager_.Finalize();
} }
if (sessionManager_.init_flag_) {
(void)sessionManager_.Finalize();
}
MemManager::Instance().Finalize(); MemManager::Instance().Finalize();
HostMemManager::Instance().Finalize(); HostMemManager::Instance().Finalize();
VarManagerPool::Instance().Destory(); VarManagerPool::Instance().Destory();


+ 7
- 5
ge/init/gelib.h View File

@@ -22,7 +22,13 @@
#include <vector> #include <vector>
#include "engine_manager/dnnengine_manager.h" #include "engine_manager/dnnengine_manager.h"
#include "opskernel_manager/ops_kernel_manager.h" #include "opskernel_manager/ops_kernel_manager.h"
#include "session/session_manager.h"
#include "graph/tuning_utils.h"
#include "graph/operator_factory.h"
#include "graph/ge_local_context.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/anchor_utils.h"
#include "graph/manager/graph_var_manager.h"
#include "framework/common/ge_inner_error_codes.h" #include "framework/common/ge_inner_error_codes.h"
#include "framework/common/ge_types.h" #include "framework/common/ge_types.h"


@@ -53,9 +59,6 @@ class GE_FUNC_VISIBILITY GELib {
// get OpsKernelManager object // get OpsKernelManager object
OpsKernelManager &OpsKernelManagerObj() { return opsManager_; } OpsKernelManager &OpsKernelManagerObj() { return opsManager_; }


// get SessionManager object
SessionManager &SessionManagerObj() { return sessionManager_; }

// get Initial flag // get Initial flag
bool InitFlag() const { return init_flag_; } bool InitFlag() const { return init_flag_; }


@@ -90,7 +93,6 @@ class GE_FUNC_VISIBILITY GELib {


DNNEngineManager engineManager_; DNNEngineManager engineManager_;
OpsKernelManager opsManager_; OpsKernelManager opsManager_;
SessionManager sessionManager_;
std::mutex status_mutex_; std::mutex status_mutex_;
bool init_flag_ = false; bool init_flag_ = false;
Options options_; Options options_;


+ 1
- 3
ge/session/inner_session.cc View File

@@ -30,7 +30,6 @@
#include "graph/ge_global_options.h" #include "graph/ge_global_options.h"
#include "graph/ge_local_context.h" #include "graph/ge_local_context.h"
#include "graph/common/local_context.h" #include "graph/common/local_context.h"
#include "graph/load/model_manager/model_manager.h"
#include "graph/manager/graph_var_manager.h" #include "graph/manager/graph_var_manager.h"
#include "graph/manager/graph_mem_manager.h" #include "graph/manager/graph_mem_manager.h"
#include "graph/utils/tensor_adapter.h" #include "graph/utils/tensor_adapter.h"
@@ -169,7 +168,6 @@ Status InnerSession::Finalize() {
REPORT_CALL_ERROR("E19999", "GraphManager Finalize failed, InnerSession:%lu.", session_id_); REPORT_CALL_ERROR("E19999", "GraphManager Finalize failed, InnerSession:%lu.", session_id_);
} }


ModelManager::GetInstance()->DestroyAicpuSession(session_id_);
init_flag_ = false; init_flag_ = false;
// release var memory // release var memory
GELOGI("VarManager free var memory."); GELOGI("VarManager free var memory.");
@@ -189,7 +187,7 @@ Status InnerSession::Finalize() {
} }


Status InnerSession::InnerInitialize() { Status InnerSession::InnerInitialize() {
Status ret = model_executor_.Initialize(options_);
Status ret = model_executor_.Initialize(options_, session_id_);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "[Init][GraphExecutor] failed, InnerSession:%lu.", session_id_); GELOGE(ret, "[Init][GraphExecutor] failed, InnerSession:%lu.", session_id_);
REPORT_CALL_ERROR("E19999", "GraphExecutor initialize failed, InnerSession:%lu.", session_id_); REPORT_CALL_ERROR("E19999", "GraphExecutor initialize failed, InnerSession:%lu.", session_id_);


+ 0
- 5
ge/session/session_manager.cc View File

@@ -20,7 +20,6 @@
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "graph/ge_context.h" #include "graph/ge_context.h"
#include "graph/load/model_manager/model_manager.h"
#include "graph/manager/util/rt_context_util.h" #include "graph/manager/util/rt_context_util.h"


using std::map; using std::map;
@@ -105,10 +104,6 @@ Status SessionManager::DestroySession(SessionId session_id) {
return GE_SESSION_NOT_EXIST; return GE_SESSION_NOT_EXIST;
} }


if (ModelManager::GetInstance() != nullptr) {
ModelManager::GetInstance()->DestroyAicpuSession(session_id);
}

// Unified destruct rt_context // Unified destruct rt_context
RtContextUtil::GetInstance().DestroyRtContexts(session_id); RtContextUtil::GetInstance().DestroyRtContexts(session_id);




+ 19
- 21
ge/session/session_manager.h View File

@@ -31,9 +31,26 @@ namespace ge {
using SessionPtr = std::shared_ptr<InnerSession>; using SessionPtr = std::shared_ptr<InnerSession>;


class SessionManager { class SessionManager {
friend class GELib;

public: public:
SessionManager() = default;

~SessionManager() = default;

///
/// @ingroup ge_session
/// @brief initialize session manager
/// @param [in] options session manager config options
/// @return Status result of function
///
Status Initialize(const std::map<std::string, std::string> &options);

///
/// @ingroup ge_session
/// @brief finalize session manager
/// @return Status result of function
///
Status Finalize();

/// ///
/// @ingroup ge_session /// @ingroup ge_session
/// @brief create session /// @brief create session
@@ -181,25 +198,6 @@ class SessionManager {
bool IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id); bool IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id);


private: private:
SessionManager() = default;

~SessionManager() = default;

///
/// @ingroup ge_session
/// @brief initialize session manager
/// @param [in] options session manager config options
/// @return Status result of function
///
Status Initialize(const std::map<std::string, std::string> &options);

///
/// @ingroup ge_session
/// @brief finalize session manager
/// @return Status result of function
///
Status Finalize();

bool HasSession(SessionId session_id); bool HasSession(SessionId session_id);


Status GetNextSessionId(SessionId &next_session_id); Status GetNextSessionId(SessionId &next_session_id);


+ 10
- 10
tests/ut/ge/graph/execute/model_executor_unittest.cc View File

@@ -63,7 +63,7 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string


TEST_F(UtestModelExecutorTest, test_load_graph_sync) { TEST_F(UtestModelExecutorTest, test_load_graph_sync) {
ModelExecutor model_executor; ModelExecutor model_executor;
EXPECT_EQ(model_executor.Initialize({}), SUCCESS);
EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);


auto compute_graph = MakeShared<ComputeGraph>("test_graph"); auto compute_graph = MakeShared<ComputeGraph>("test_graph");
GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph); GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph);
@@ -86,7 +86,7 @@ TEST_F(UtestModelExecutorTest, test_load_graph_sync) {


TEST_F(UtestModelExecutorTest, test_load_graph_async) { TEST_F(UtestModelExecutorTest, test_load_graph_async) {
ModelExecutor model_executor; ModelExecutor model_executor;
EXPECT_EQ(model_executor.Initialize({}), SUCCESS);
EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);


Graph graph("test_graph"); Graph graph("test_graph");
auto compute_graph = MakeShared<ComputeGraph>("test_graph"); auto compute_graph = MakeShared<ComputeGraph>("test_graph");
@@ -111,7 +111,7 @@ TEST_F(UtestModelExecutorTest, test_load_graph_async) {


TEST_F(UtestModelExecutorTest, test_load_graph_failed) { TEST_F(UtestModelExecutorTest, test_load_graph_failed) {
ModelExecutor model_executor; ModelExecutor model_executor;
EXPECT_EQ(model_executor.Initialize({}), SUCCESS);
EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);


Graph graph("test_graph"); Graph graph("test_graph");
auto compute_graph = MakeShared<ComputeGraph>("test_graph"); auto compute_graph = MakeShared<ComputeGraph>("test_graph");
@@ -144,7 +144,7 @@ TEST_F(UtestModelExecutorTest, test_check_and_release_memory) {
} }


ModelExecutor model_executor; ModelExecutor model_executor;
EXPECT_EQ(model_executor.Initialize({}), SUCCESS);
EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);


GeModelPtr ge_model = make_shared<GeModel>(); GeModelPtr ge_model = make_shared<GeModel>();
int64_t memory_size = 25 * 1024UL * 1024UL * 1024UL; int64_t memory_size = 25 * 1024UL * 1024UL * 1024UL;
@@ -171,7 +171,7 @@ TEST_F(UtestModelExecutorTest, test_check_and_release_memory) {


TEST_F(UtestModelExecutorTest, parse_inputs_dims_data) { TEST_F(UtestModelExecutorTest, parse_inputs_dims_data) {
ModelExecutor model_executor; ModelExecutor model_executor;
EXPECT_EQ(model_executor.Initialize({}), SUCCESS);
EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);


OmeContext context; OmeContext context;
SetLocalOmeContext(context); SetLocalOmeContext(context);
@@ -195,7 +195,7 @@ TEST_F(UtestModelExecutorTest, parse_inputs_dims_data) {


TEST_F(UtestModelExecutorTest, parse_inputs_dims_getnext) { TEST_F(UtestModelExecutorTest, parse_inputs_dims_getnext) {
ModelExecutor model_executor; ModelExecutor model_executor;
EXPECT_EQ(model_executor.Initialize({}), SUCCESS);
EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);


OmeContext context; OmeContext context;
SetLocalOmeContext(context); SetLocalOmeContext(context);
@@ -223,7 +223,7 @@ TEST_F(UtestModelExecutorTest, parse_inputs_dims_getnext) {


TEST_F(UtestModelExecutorTest, test_run_thread) { TEST_F(UtestModelExecutorTest, test_run_thread) {
ModelExecutor model_executor; ModelExecutor model_executor;
EXPECT_EQ(model_executor.Initialize({}), SUCCESS);
EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);


GraphId graph_id = 1; GraphId graph_id = 1;
uint64_t session_id = 0; uint64_t session_id = 0;
@@ -281,7 +281,7 @@ static void test_run_graph(ModelExecutor &model_executor) {
TEST_F(UtestModelExecutorTest, test_run_graph_train) { TEST_F(UtestModelExecutorTest, test_run_graph_train) {
GetThreadLocalContext().SetGlobalOption({{OPTION_GRAPH_RUN_MODE, "1"}}); GetThreadLocalContext().SetGlobalOption({{OPTION_GRAPH_RUN_MODE, "1"}});
ModelExecutor model_executor; ModelExecutor model_executor;
EXPECT_EQ(model_executor.Initialize({}), SUCCESS);
EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);
test_run_graph(model_executor); test_run_graph(model_executor);
EXPECT_EQ(model_executor.Finalize(), SUCCESS); EXPECT_EQ(model_executor.Finalize(), SUCCESS);
} }
@@ -291,14 +291,14 @@ TEST_F(UtestModelExecutorTest, test_run_graph_infer) {
GetThreadLocalContext().SetSessionOption({}); GetThreadLocalContext().SetSessionOption({});
GetThreadLocalContext().SetGraphOption({}); GetThreadLocalContext().SetGraphOption({});
ModelExecutor model_executor; ModelExecutor model_executor;
EXPECT_EQ(model_executor.Initialize({}), SUCCESS);
EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);
test_run_graph(model_executor); test_run_graph(model_executor);
EXPECT_EQ(model_executor.Finalize(), SUCCESS); EXPECT_EQ(model_executor.Finalize(), SUCCESS);
} }


TEST_F(UtestModelExecutorTest, test_run_graph_with_stream) { TEST_F(UtestModelExecutorTest, test_run_graph_with_stream) {
ModelExecutor model_executor; ModelExecutor model_executor;
EXPECT_EQ(model_executor.Initialize({}), SUCCESS);
EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);


GraphId graph_id = 1; GraphId graph_id = 1;
auto compute_graph = MakeShared<ComputeGraph>("test_graph"); auto compute_graph = MakeShared<ComputeGraph>("test_graph");


+ 3
- 2
tests/ut/ge/graph/load/model_manager_unittest.cc View File

@@ -78,7 +78,7 @@ class UtestModelManagerModelManager : public testing::Test {
const int model_len = 10; const int model_len = 10;
data.model_len = sizeof(ModelFileHeader) + model_len; data.model_len = sizeof(ModelFileHeader) + model_len;
data.model_data = new uint8_t[data.model_len]; data.model_data = new uint8_t[data.model_len];
memset((uint8_t *)data.model_data + sizeof(ModelFileHeader), 10, model_len);
memset((uint8_t *)data.model_data + sizeof(ModelFileHeader), 0, model_len);


ModelFileHeader *header = (ModelFileHeader *)data.model_data; ModelFileHeader *header = (ModelFileHeader *)data.model_data;
header->magic = MODEL_FILE_MAGIC_NUM; header->magic = MODEL_FILE_MAGIC_NUM;
@@ -93,7 +93,7 @@ class UtestModelManagerModelManager : public testing::Test {
data.key = ENC_KEY; data.key = ENC_KEY;
data.model_data = new uint8_t[data.model_len]; data.model_data = new uint8_t[data.model_len];
uint8_t data_ori[model_len]; uint8_t data_ori[model_len];
memset(data_ori, 10, model_len);
memset(data_ori, 0, model_len);
ModelFileHeader *header = (ModelFileHeader *)data.model_data; ModelFileHeader *header = (ModelFileHeader *)data.model_data;
header->magic = MODEL_FILE_MAGIC_NUM; header->magic = MODEL_FILE_MAGIC_NUM;
header->version = MODEL_VERSION; header->version = MODEL_VERSION;
@@ -224,6 +224,7 @@ TEST_F(UtestModelManagerModelManager, case_load_model_encypt_type_unsupported) {
ModelFileHeader *header = (ModelFileHeader *)data.model_data; ModelFileHeader *header = (ModelFileHeader *)data.model_data;
header->is_encrypt = 255; header->is_encrypt = 255;
uint32_t model_id = 1; uint32_t model_id = 1;
// Error for: LoadModelPartitionTable: Invalid partition_table->num:0
EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_PARAM_INVALID); EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_PARAM_INVALID);
delete[](uint8_t *) data.model_data; delete[](uint8_t *) data.model_data;
} }


+ 110
- 2
tests/ut/ge/session/ge_api_unittest.cc View File

@@ -26,8 +26,6 @@
#include "proto/ge_ir.pb.h" #include "proto/ge_ir.pb.h"
#include "inc/external/ge/ge_api.h" #include "inc/external/ge/ge_api.h"
#include "session/session_manager.h" #include "session/session_manager.h"
#undef protected
#undef private
using namespace std; using namespace std;
@@ -71,4 +69,114 @@ TEST_F(UtestGeApi, ge_initialize_modify_mixlist) {
auto ret = GEInitialize(options); auto ret = GEInitialize(options);
ASSERT_NE(ret, SUCCESS); ASSERT_NE(ret, SUCCESS);
} }
TEST_F(UtestGeApi, ge_not_initialized) {
EXPECT_EQ(GEFinalize(), SUCCESS);
std::map<std::string, std::string> options;
std::map<AscendString, AscendString> ascend_options;
Session session(options);
GraphId graph_id = 1;
const auto compute_graph = MakeShared<ComputeGraph>("test_graph");
Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
EXPECT_EQ(session.AddGraph(graph_id, graph), FAILED);
EXPECT_EQ(session.AddGraph(graph_id, graph, ascend_options), FAILED);
EXPECT_EQ(session.AddGraphWithCopy(graph_id, graph), FAILED);
EXPECT_EQ(session.AddGraphWithCopy(graph_id, graph, ascend_options), FAILED);
vector<Tensor> inputs;
vector<InputTensorInfo> tensors;
EXPECT_EQ(session.BuildGraph(graph_id, inputs), FAILED);
EXPECT_EQ(session.BuildGraph(graph_id, tensors), FAILED);
vector<Tensor> outputs;
EXPECT_EQ(session.RunGraph(graph_id, inputs, outputs), FAILED);
EXPECT_EQ(session.RunGraphWithStreamAsync(graph_id, nullptr, inputs, outputs), FAILED);
EXPECT_EQ(session.RunGraphAsync(graph_id, inputs, nullptr), FAILED);
vector<string> var_inputs;
EXPECT_EQ(session.GetVariables(var_inputs, outputs), FAILED);
vector<AscendString> var_names;
EXPECT_EQ(session.GetVariables(var_names, outputs), FAILED);
std::string key;
pCallBackFunc ge_callback;
EXPECT_EQ(session.RegisterCallBackFunc(key, ge_callback), FAILED);
session::pCallBackFunc session_callback;
EXPECT_EQ(session.RegisterCallBackFunc(key.c_str(), session_callback), FAILED);
EXPECT_FALSE(session.IsGraphNeedRebuild(graph_id));
EXPECT_EQ(session.RemoveGraph(graph_id), FAILED);
EXPECT_EQ(GEFinalize(), SUCCESS);
}
TEST_F(UtestGeApi, ge_session_ascend_string) {
std::map<AscendString, AscendString> options;
EXPECT_EQ(GEInitialize(options), SUCCESS);
Session session(options);
GraphId graph_id = 1;
const auto compute_graph = MakeShared<ComputeGraph>("test_graph");
EXPECT_EQ(session.AddGraph(graph_id, GraphUtils::CreateGraphFromComputeGraph(compute_graph)), SUCCESS);
EXPECT_TRUE(session.IsGraphNeedRebuild(graph_id));
EXPECT_EQ(session.RemoveGraph(graph_id), SUCCESS);
EXPECT_EQ(GEFinalize(), SUCCESS);
}
TEST_F(UtestGeApi, ge_session_test) {
std::map<std::string, std::string> options;
EXPECT_EQ(GEInitialize(options), SUCCESS);
std::map<AscendString, AscendString> ascend_options;
Session session(options);
GraphId graph_id = 1;
const auto compute_graph = MakeShared<ComputeGraph>("test_graph");
Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
EXPECT_EQ(session.AddGraph(graph_id, graph), SUCCESS);
EXPECT_EQ(session.AddGraph(graph_id, graph, ascend_options), SUCCESS);
EXPECT_EQ(session.AddGraphWithCopy(graph_id, graph), FAILED);
EXPECT_EQ(session.AddGraphWithCopy(graph_id, graph, ascend_options), FAILED);
vector<Tensor> inputs;
vector<InputTensorInfo> tensors;
EXPECT_EQ(session.BuildGraph(graph_id, inputs), FAILED);
EXPECT_EQ(session.BuildGraph(graph_id, tensors), FAILED);
vector<Tensor> outputs;
EXPECT_EQ(session.RunGraph(graph_id, inputs, outputs), FAILED);
EXPECT_EQ(session.RunGraphWithStreamAsync(graph_id, nullptr, inputs, outputs), FAILED);
EXPECT_EQ(session.RunGraphAsync(graph_id, inputs, nullptr), SUCCESS); // Push to queue.
vector<string> var_inputs;
EXPECT_EQ(session.GetVariables(var_inputs, outputs), FAILED);
vector<AscendString> var_names;
EXPECT_EQ(session.GetVariables(var_names, outputs), FAILED);
std::string key;
pCallBackFunc ge_callback;
EXPECT_EQ(session.RegisterCallBackFunc(key, ge_callback), SUCCESS);
session::pCallBackFunc session_callback;
EXPECT_EQ(session.RegisterCallBackFunc(key.c_str(), session_callback), SUCCESS);
EXPECT_TRUE(session.IsGraphNeedRebuild(graph_id));
EXPECT_EQ(session.RemoveGraph(graph_id), SUCCESS);
EXPECT_EQ(GEFinalize(), SUCCESS);
}
} // namespace ge } // namespace ge

Loading…
Cancel
Save