diff --git a/ge/init/gelib.cc b/ge/init/gelib.cc index 1a97b6f8..d44568cc 100755 --- a/ge/init/gelib.cc +++ b/ge/init/gelib.cc @@ -54,6 +54,7 @@ const int kSocVersionLen = 50; const int kDefaultDeviceIdForTrain = 0; const int kDefaultDeviceIdForInfer = -1; const char *const kGlobalOptionFpCeilingModeDefault = "2"; +const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; } // namespace static std::shared_ptr instancePtr_ = nullptr; @@ -76,6 +77,13 @@ Status GELib::Initialize(const map &options) { GELOGE(ret, "GeLib initial failed."); return ret; } + + ret = instancePtr_->SetAiCoreNum(new_options); + if (ret != SUCCESS) { + GELOGE(ret, "GeLib initial: SetAiCoreNum failed."); + return ret; + } + instancePtr_->SetDefaultPrecisionMode(new_options); if (new_options.find("ge.fpCeilingMode") == new_options.end()) { @@ -251,6 +259,24 @@ Status GELib::SetRTSocVersion(const map &options, map &options) { + // Already set or get AICORE_NUM from options in offline mode + if (options.find(AICORE_NUM) != options.end()) { + return SUCCESS; + } + + uint32_t aicore_num = 0; + rtError_t ret = rtGetAiCoreCount(&aicore_num); + if (ret == ACL_ERROR_RT_FEATURE_NOT_SUPPORT) { // offline without ATC Input of AiCoreNum + return SUCCESS; + } else if (ret == RT_ERROR_NONE) { // online-mode + options.emplace(std::make_pair(AICORE_NUM, std::to_string(aicore_num))); + return SUCCESS; + } + GELOGE(FAILED, "rtGetAiCoreCount failed."); + return FAILED; +} + void GELib::InitOptions(const map &options) { this->options_.session_id = 0; auto iter = options.find(OPTION_EXEC_SESSION_ID); diff --git a/ge/init/gelib.h b/ge/init/gelib.h index 885ae867..ed6fe5d4 100644 --- a/ge/init/gelib.h +++ b/ge/init/gelib.h @@ -81,6 +81,7 @@ class GE_FUNC_VISIBILITY GELib { Status InnerInitialize(const map &options); Status SystemInitialize(const map &options); Status SetRTSocVersion(const map &options, map &new_options); + Status SetAiCoreNum(map &options); void SetDefaultPrecisionMode(map &new_options); void RollbackInit(); void InitOptions(const map &options); diff --git a/tests/depends/runtime/src/runtime_stub.cc b/tests/depends/runtime/src/runtime_stub.cc index 1a170167..1323a76a 100644 --- a/tests/depends/runtime/src/runtime_stub.cc +++ b/tests/depends/runtime/src/runtime_stub.cc @@ -354,6 +354,11 @@ rtError_t rtGetSocVersion(char *version, const uint32_t maxLen) return RT_ERROR_NONE; } +rtError_t rtGetAiCoreCount(uint32_t *aiCoreCnt) +{ + return RT_ERROR_NONE; +} + rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) { return RT_ERROR_NONE;