You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_var_manager.cc 36 kB

5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/manager/graph_var_manager.h"
  17. #include "graph/debug/ge_attr_define.h"
  18. #include "graph/manager/graph_mem_allocator.h"
  19. #include "graph/manager/rdma_pool_allocator.h"
  20. #include "graph/manager/trans_var_data_utils.h"
  21. #include "graph/utils/type_utils.h"
  22. using std::map;
  23. using std::string;
  24. using std::vector;
  25. namespace ge {
  26. VarResource::VarResource(uint64_t session_id) : session_id_(session_id) {}
  27. VarResource::~VarResource() {
  28. var_offset_map_.clear();
  29. var_addr_mgr_map_.clear();
  30. cur_var_tensor_desc_map_.clear();
  31. var_broad_cast_info_.clear();
  32. }
  33. ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  34. rtMemType_t &memory_type) {
  35. if (dev_ptr == nullptr) {
  36. GELOGE(FAILED, "[GetVarAddr] dev_ptr is null!");
  37. return FAILED;
  38. }
  39. std::string var_key = VarKey(var_name, tensor_desc);
  40. GELOGD("VarResource::GetVarAddr , var_key = %s", var_key.c_str());
  41. auto iter = var_addr_mgr_map_.find(var_key);
  42. if (iter == var_addr_mgr_map_.end()) {
  43. GELOGE(FAILED, "VarResource::GetVarAddr failed, var_key %s", var_key.c_str());
  44. return FAILED;
  45. }
  46. *dev_ptr = iter->second.address;
  47. memory_type = iter->second.memory_type;
  48. return SUCCESS;
  49. }
  50. void VarResource::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  51. var_addr_mgr_map = var_addr_mgr_map_;
  52. }
  53. void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  54. rtMemType_t memory_type) {
  55. std::string var_key = VarKey(var_name, tensor_desc);
  56. GELOGI("VarResource::SetVarAddr , var_key = %s, mem_type:%u", var_key.c_str(), memory_type);
  57. if (var_addr_mgr_map_.count(var_key) == 0) {
  58. GELOGI("SetVarAddr node_name %s, tensor_desc type %s, format %s", var_name.c_str(),
  59. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  60. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  61. VarAddrMgr var_addr_mgr;
  62. var_addr_mgr.address = dev_ptr;
  63. var_addr_mgr.tensor_desc = tensor_desc;
  64. var_addr_mgr_map_[var_key] = var_addr_mgr;
  65. }
  66. cur_var_tensor_desc_map_[var_name] = tensor_desc;
  67. }
  68. ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  69. rtMemType_t memory_type) {
  70. std::string var_key = VarKey(var_name, tensor_desc);
  71. GELOGD("VarResource::SaveVarAddr, var_key = %s", var_key.c_str());
  72. if (var_addr_mgr_map_.count(var_key) == 0) {
  73. uint64_t logic_address = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  74. if (memory_type != RT_MEMORY_RDMA_HBM) {
  75. logic_address += VarManager::Instance(session_id_)->GetVarMemLogicBase();
  76. }
  77. GELOGI("SaveVarAddr node_name %s, tensor_desc format %s, type %s.", var_name.c_str(),
  78. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  79. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  80. VarAddrMgr var_addr_mgr;
  81. var_addr_mgr.address = reinterpret_cast<uint8_t *>(static_cast<std::uintptr_t>(logic_address));
  82. var_addr_mgr.offset = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  83. var_addr_mgr.tensor_desc = tensor_desc;
  84. var_addr_mgr.memory_type = memory_type;
  85. var_addr_mgr_map_[var_key] = var_addr_mgr;
  86. var_offset_map_[logic_address] = memory_type;
  87. return SUCCESS;
  88. }
  89. GELOGE(FAILED, "VarResource::SaveVarAddr, var_key %s save addr conflict", var_key.c_str());
  90. return FAILED;
  91. }
  92. bool VarResource::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  93. std::string var_key = VarKey(var_name, tensor_desc);
  94. return var_addr_mgr_map_.count(var_key) != 0;
  95. }
  96. bool VarResource::IsVarExist(const std::string &var_name) { return cur_var_tensor_desc_map_.count(var_name) != 0; }
  97. std::string VarResource::VarKey(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  98. std::string var_key(var_name);
  99. var_key.append(std::to_string(static_cast<int32_t>(tensor_desc.GetFormat())))
  100. .append("_")
  101. .append(std::to_string(static_cast<int32_t>(tensor_desc.GetDataType())));
  102. return var_key;
  103. }
  104. ge::Status VarResource::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  105. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  106. return FAILED;
  107. }
  108. tensor_desc = cur_var_tensor_desc_map_[var_name];
  109. return SUCCESS;
  110. }
  111. ge::Status VarResource::RenewCurVarDesc(const std::string &var_name, const ge::OpDescPtr &op_desc) {
  112. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  113. GELOGI("There is no this node[%s] in var tensor_desc map. so no need renew!", var_name.c_str());
  114. return SUCCESS;
  115. }
  116. if (op_desc == nullptr) {
  117. GELOGE(FAILED, "[RenewCurVarDesc] renew var desc fail! input opdesc is null!");
  118. return FAILED;
  119. }
  120. ge::GeTensorDesc curr_desc;
  121. ge::Status ret = GetCurVarDesc(var_name, curr_desc);
  122. if (ret != SUCCESS) {
  123. GELOGE(FAILED, "[RenewCurVarDesc] Get var desc fail!");
  124. return FAILED;
  125. }
  126. std::string key = VarKey(var_name, curr_desc);
  127. curr_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  128. curr_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  129. cur_var_tensor_desc_map_[var_name] = curr_desc;
  130. auto iter = var_addr_mgr_map_.find(key);
  131. if (iter == var_addr_mgr_map_.end()) {
  132. GELOGE(FAILED, "[RenewCurVarDesc] can't find ele with key [%s]", key.c_str());
  133. return FAILED;
  134. }
  135. auto val = iter->second;
  136. val.tensor_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  137. val.tensor_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  138. var_addr_mgr_map_.erase(iter);
  139. key = VarKey(var_name, curr_desc);
  140. var_addr_mgr_map_[key] = val;
  141. return SUCCESS;
  142. }
  143. void VarResource::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  144. var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info;
  145. }
  146. ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  147. if (var_broad_cast_info_.count(graph_id) == 0 || var_broad_cast_info_[graph_id].count(var_name) == 0) {
  148. return FAILED;
  149. }
  150. broad_cast_info = var_broad_cast_info_[graph_id][var_name];
  151. return SUCCESS;
  152. }
  153. ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name,
  154. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  155. GE_CHECK_NOTNULL(base_ptr);
  156. GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str());
  157. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  158. uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset;
  159. return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr,
  160. var_broadcast_info.input_size, session_id_);
  161. }
  162. ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  163. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  164. GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str());
  165. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  166. // subgraph base_ptr could be nullptr, task it as base 0
  167. uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset;
  168. return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name,
  169. var_tensor_desc, session_id_);
  170. }
  171. ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name,
  172. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  173. return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr);
  174. }
  175. bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; }
  176. rtMemType_t VarResource::GetVarMemType(const int64_t &offset) {
  177. if (var_offset_map_.count(offset) > 0) {
  178. return var_offset_map_[offset];
  179. }
  180. return RT_MEMORY_RESERVED;
  181. }
  182. VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) {
  183. auto iter = var_to_trans_road_.find(var_name);
  184. if (iter == var_to_trans_road_.end()) {
  185. return nullptr;
  186. } else {
  187. return &(iter->second);
  188. }
  189. }
  190. Status VarResource::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  191. auto iter = var_names_to_changed_graph_id_.find(var_name);
  192. if (iter == var_names_to_changed_graph_id_.end()) {
  193. return FAILED;
  194. } else {
  195. graph_id = iter->second;
  196. return SUCCESS;
  197. }
  198. }
  199. Status VarResource::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  200. auto iter = var_names_to_allocated_graph_id_.find(var_name);
  201. if (iter == var_names_to_allocated_graph_id_.end()) {
  202. return FAILED;
  203. } else {
  204. graph_id = iter->second;
  205. return SUCCESS;
  206. }
  207. }
  208. Status VarResource::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  209. if (GetAllocatedGraphId(var_name, graph_id) == SUCCESS) {
  210. GELOGW("VarManager var[%s] has been allocated in graph[%d]", var_name.c_str(), graph_id);
  211. return SUCCESS;
  212. }
  213. var_names_to_allocated_graph_id_[var_name] = graph_id;
  214. return SUCCESS;
  215. }
  216. MemResource::MemResource() : total_size_(0), var_mem_size_(0) {}
  217. MemResource *MemResource::BuildMemResourceFromType(rtMemType_t mem_type) {
  218. switch (mem_type) {
  219. case RT_MEMORY_HBM:
  220. return new (std::nothrow) HbmMemResource();
  221. case RT_MEMORY_RDMA_HBM:
  222. return new (std::nothrow) RdmaMemResource();
  223. default:
  224. return nullptr;
  225. }
  226. }
  227. Status HbmMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id,
  228. size_t &mem_offset) {
  229. size = (size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  230. uint64_t real_size = size;
  231. total_size_ = VarManager::Instance(session_id)->GetVarMemMaxSize();
  232. if (total_size_ < var_mem_size_) {
  233. GELOGE(PARAM_INVALID, "total_size_: %lu is smaller than var_mem_size_: %lu", total_size_, var_mem_size_);
  234. return PARAM_INVALID;
  235. }
  236. uint64_t free_size = total_size_ - var_mem_size_;
  237. if (free_size < (size + kSessionMemAlignSize * kSessionMemAlignUnit)) {
  238. GELOGE(PARAM_INVALID, "Out of memory : current var size[%lu] exceeds total var size[%lu]",
  239. size + kSessionMemAlignSize * kSessionMemAlignUnit + var_mem_size_, total_size_);
  240. return PARAM_INVALID;
  241. }
  242. mem_offset = var_mem_size_;
  243. // offset for next, align 512 BYTE
  244. size = size + kSessionMemAlignSize;
  245. var_mem_size_ = var_mem_size_ + size;
  246. // align 512 BYTE
  247. var_mem_size_ = var_mem_size_ + kSessionMemAlignSize;
  248. GELOGI(
  249. "[IMAS]AssignVarMem Set session_%lu name[%s] output[%d]"
  250. "offset to [%zu] size[%lu] realsize[%lu].",
  251. session_id, var_name.c_str(), 0, mem_offset, (var_mem_size_ - mem_offset), real_size);
  252. return SUCCESS;
  253. }
  254. Status RdmaMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &address) {
  255. uint8_t *buffer = MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Malloc(size);
  256. if (buffer == nullptr) {
  257. GELOGE(MEMALLOC_FAILED, "Failed to malloc rdma memory for node %s, size = %lu", var_name.c_str(), size);
  258. return MEMALLOC_FAILED;
  259. }
  260. address = static_cast<size_t>(reinterpret_cast<uintptr_t>(buffer));
  261. var_mem_size_ += size;
  262. GELOGI("[IMAS]AssignVarMem Set session_%lu name[%s] output[%d] addr to [%p] size[%lu].",
  263. session_id, var_name.c_str(), 0, buffer, size);
  264. return SUCCESS;
  265. }
  266. uint64_t MemResource::GetVarMemSize() const { return var_mem_size_; }
  267. void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; };
  268. VarManager::VarManager(uint64_t session_id)
  269. : version_(SessionVersion::OTHER_VERSION),
  270. session_id_(session_id),
  271. device_id_(0),
  272. job_id_(0),
  273. graph_mem_max_size_(kGraphMemoryManagerMallocMaxSize),
  274. var_mem_max_size_(kMemoryVarManagerMallocSize),
  275. var_mem_logic_base_(kMemoryVarLogicBase),
  276. use_max_mem_size_(kUseMaxMemorySize) {}
  277. VarManager *VarManager::Instance(uint64_t session_id) {
  278. GELOGD("VarManager::Instance, session id = %lu", session_id);
  279. return VarManagerPool::Instance().GetVarManager(session_id);
  280. }
  281. void VarManager::Destory() {
  282. std::lock_guard<std::recursive_mutex> lock(mutex_);
  283. GELOGI("VarManager::Destory, session id = %lu.", session_id_);
  284. version_ = SessionVersion::OTHER_VERSION;
  285. device_id_ = 0;
  286. session_id_ = 0;
  287. for (auto &memory_resource : mem_resource_map_) {
  288. if (memory_resource.second != nullptr) {
  289. delete memory_resource.second;
  290. memory_resource.second = nullptr;
  291. }
  292. }
  293. mem_resource_map_.clear();
  294. }
  295. ge::Status VarManager::Init(const uint32_t &version, const uint64_t &session_id, const uint32_t &device_id,
  296. const uint64_t &job_id) {
  297. std::lock_guard<std::recursive_mutex> lock(mutex_);
  298. GELOGI("VarManager::Init, session id = %lu.", session_id);
  299. version_ = version;
  300. device_id_ = device_id;
  301. session_id_ = session_id;
  302. job_id_ = job_id;
  303. var_resource_ = std::unique_ptr<VarResource>(new (std::nothrow) VarResource(session_id_));
  304. if (var_resource_ == nullptr) {
  305. GELOGW("VarManager has not been init.");
  306. return ge::INTERNAL_ERROR;
  307. }
  308. return SUCCESS;
  309. }
  310. const uint64_t &VarManager::SessionId() const {
  311. std::lock_guard<std::recursive_mutex> lock(mutex_);
  312. return session_id_;
  313. }
  314. const uint32_t &VarManager::DeviceId() const {
  315. std::lock_guard<std::recursive_mutex> lock(mutex_);
  316. return device_id_;
  317. }
  318. const uint64_t &VarManager::JobId() const {
  319. std::lock_guard<std::recursive_mutex> lock(mutex_);
  320. return job_id_;
  321. }
  322. ge::Status VarManager::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  323. rtMemType_t memory_type) {
  324. GELOGI("VarManager::SetVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  325. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  326. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  327. std::lock_guard<std::recursive_mutex> lock(mutex_);
  328. if (var_resource_ == nullptr) {
  329. GELOGW("VarManager has not been init.");
  330. return ge::INTERNAL_ERROR;
  331. }
  332. var_resource_->SetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  333. return ge::SUCCESS;
  334. }
  335. ge::Status VarManager::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  336. rtMemType_t memory_type) {
  337. GELOGI("VarManager::SaveVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  338. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  339. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  340. std::lock_guard<std::recursive_mutex> lock(mutex_);
  341. if (var_resource_ == nullptr) {
  342. GELOGW("VarManager has not been init.");
  343. return ge::INTERNAL_ERROR;
  344. }
  345. var_resource_->SaveVarAddr(var_name, tensor_desc, address, memory_type);
  346. return ge::SUCCESS;
  347. }
  348. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  349. rtMemType_t &memory_type) {
  350. std::lock_guard<std::recursive_mutex> lock(mutex_);
  351. GELOGD("VarManager::GetVarAddr var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  352. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  353. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  354. if (var_resource_ == nullptr) {
  355. GELOGW("VarManager has not been init.");
  356. return ge::INTERNAL_ERROR;
  357. }
  358. auto ret = var_resource_->GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  359. if (ret != SUCCESS) {
  360. GELOGW("GetVarAddr fail.");
  361. return ge::INTERNAL_ERROR;
  362. }
  363. return SUCCESS;
  364. }
  365. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr) {
  366. std::lock_guard<std::recursive_mutex> lock(mutex_);
  367. rtMemType_t memory_type = RT_MEMORY_HBM;
  368. return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  369. }
  370. void VarManager::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  371. var_resource_->GetAllVarAddrMgr(var_addr_mgr_map);
  372. }
  373. int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) {
  374. std::lock_guard<std::recursive_mutex> lock(mutex_);
  375. MemResource *mem_resource = nullptr;
  376. auto iter = mem_resource_map_.find(memory_type);
  377. if (iter == mem_resource_map_.end()) {
  378. return 0;
  379. } else {
  380. mem_resource = iter->second;
  381. }
  382. if (mem_resource == nullptr) {
  383. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid.");
  384. return 0;
  385. }
  386. return mem_resource->GetVarMemSize();
  387. }
  388. Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) {
  389. std::lock_guard<std::recursive_mutex> lock(mutex_);
  390. MemResource *mem_resource = nullptr;
  391. auto iter = mem_resource_map_.find(memory_type);
  392. if (iter == mem_resource_map_.end()) {
  393. mem_resource = MemResource::BuildMemResourceFromType(memory_type);
  394. if (mem_resource == nullptr) {
  395. GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
  396. return ge::INTERNAL_ERROR;
  397. } else {
  398. mem_resource_map_[memory_type] = mem_resource;
  399. }
  400. } else {
  401. mem_resource = iter->second;
  402. }
  403. if (mem_resource == nullptr) {
  404. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid.");
  405. return FAILED;
  406. }
  407. mem_resource->UpdateVarMemSize(mem_size);
  408. return SUCCESS;
  409. }
  410. ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc,
  411. rtMemType_t memory_type) {
  412. std::lock_guard<std::recursive_mutex> lock(mutex_);
  413. GELOGI("VarManager::AssignVarMem var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  414. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  415. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  416. int64_t tensor_desc_size = 0;
  417. size_t mem_offset = 0;
  418. ge::Status result = TensorUtils::GetSize(tensor_desc, tensor_desc_size);
  419. if (result != ge::SUCCESS) {
  420. GELOGE(result, "get size from TensorDesc failed");
  421. return result;
  422. }
  423. MemResource *mem_resource = nullptr;
  424. auto it = mem_resource_map_.find(memory_type);
  425. if (it == mem_resource_map_.end()) {
  426. mem_resource = MemResource::BuildMemResourceFromType(memory_type);
  427. if (mem_resource == nullptr) {
  428. GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
  429. return ge::INTERNAL_ERROR;
  430. } else {
  431. mem_resource_map_[memory_type] = mem_resource;
  432. }
  433. } else {
  434. mem_resource = it->second;
  435. }
  436. if (mem_resource == nullptr) {
  437. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid, memory_type = %u.", memory_type);
  438. return ge::INTERNAL_ERROR;
  439. }
  440. result = mem_resource->AssignVarMem(var_name, tensor_desc_size, session_id_, mem_offset);
  441. if (result != SUCCESS) {
  442. GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed.");
  443. return ge::INTERNAL_ERROR;
  444. }
  445. if (var_resource_ == nullptr) {
  446. GELOGW("VarManager has not been init.");
  447. return ge::INTERNAL_ERROR;
  448. }
  449. result = var_resource_->SaveVarAddr(
  450. var_name, tensor_desc, reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  451. if (result != SUCCESS) {
  452. GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed.");
  453. return ge::INTERNAL_ERROR;
  454. }
  455. result = var_resource_->GetVarAddr(
  456. var_name, tensor_desc, reinterpret_cast<uint8_t **>(reinterpret_cast<uintptr_t>(&mem_offset)), memory_type);
  457. if (result != SUCCESS) {
  458. GELOGE(ge::INTERNAL_ERROR, "GetVarAddr by offset failed.");
  459. return ge::INTERNAL_ERROR;
  460. }
  461. ge::GeTensorDesc cur_tensor_desc;
  462. result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc);
  463. if (result != SUCCESS) {
  464. var_resource_->SetVarAddr(var_name, tensor_desc,
  465. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  466. return SUCCESS;
  467. }
  468. if (cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() ||
  469. cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() ||
  470. cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims()) {
  471. GELOGI("var %s assigned new memory (format, data type, shape) (%s, %s, %zu) from (%s, %s, %zu)", var_name.c_str(),
  472. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  473. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  474. tensor_desc.GetShape().GetDims().size(),
  475. ge::TypeUtils::DataTypeToSerialString(cur_tensor_desc.GetDataType()).c_str(),
  476. ge::TypeUtils::FormatToSerialString(cur_tensor_desc.GetFormat()).c_str(),
  477. cur_tensor_desc.GetShape().GetDims().size());
  478. var_resource_->SetVarAddr(var_name, tensor_desc,
  479. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  480. }
  481. return SUCCESS;
  482. }
  483. bool VarManager::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  484. std::lock_guard<std::recursive_mutex> lock(mutex_);
  485. GELOGD("VarManager::IsVarExist var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  486. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  487. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  488. if (var_resource_ == nullptr) {
  489. GELOGW("VarManager has not been init.");
  490. return false;
  491. }
  492. return var_resource_->IsVarExist(var_name, tensor_desc);
  493. }
  494. bool VarManager::IsVarExist(const std::string &var_name) {
  495. std::lock_guard<std::recursive_mutex> lock(mutex_);
  496. if (var_resource_ == nullptr) {
  497. GELOGW("VarManager has not been init.");
  498. return false;
  499. }
  500. return var_resource_->IsVarExist(var_name);
  501. }
  502. ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
  503. uint8_t *base_ptr) {
  504. std::lock_guard<std::recursive_mutex> lock(mutex_);
  505. if (var_resource_ == nullptr) {
  506. GELOGW("VarManager has not been init.");
  507. return ge::INTERNAL_ERROR;
  508. }
  509. return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr);
  510. }
  511. ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  512. std::lock_guard<std::recursive_mutex> lock(mutex_);
  513. GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str());
  514. if (var_resource_ == nullptr) {
  515. GELOGW("VarManager has not been init.");
  516. return ge::INTERNAL_ERROR;
  517. }
  518. return var_resource_->GetCurVarDesc(var_name, tensor_desc);
  519. }
  520. ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  521. std::lock_guard<std::recursive_mutex> lock(mutex_);
  522. GELOGI(
  523. "VarManager::SaveBroadCastInfo var_name = %s, broadcast name = %s, "
  524. "idx = %d, input_offset = %ld, input_size = %lu, output_offset = %ld, "
  525. "output_size = %lu",
  526. broad_cast_info.var_name.c_str(), broad_cast_info.broadcast_name.c_str(), broad_cast_info.idx,
  527. broad_cast_info.input_offset, broad_cast_info.input_size, broad_cast_info.output_offset,
  528. broad_cast_info.output_size);
  529. if (var_resource_ == nullptr) {
  530. GELOGW("VarManager has not been init.");
  531. return ge::INTERNAL_ERROR;
  532. }
  533. var_resource_->SaveBroadCastInfo(graph_id, broad_cast_info);
  534. return SUCCESS;
  535. }
  536. ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  537. std::lock_guard<std::recursive_mutex> lock(mutex_);
  538. if (var_resource_ == nullptr) {
  539. GELOGW("VarManager has not been init.");
  540. return ge::INTERNAL_ERROR;
  541. }
  542. return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info);
  543. }
  544. ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) {
  545. std::lock_guard<std::recursive_mutex> lock(mutex_);
  546. GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str());
  547. if (var_resource_ == nullptr) {
  548. GELOGE(ge::INTERNAL_ERROR, "VarManager has not been init.");
  549. return ge::INTERNAL_ERROR;
  550. }
  551. return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc));
  552. }
  553. ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  554. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  555. std::lock_guard<std::recursive_mutex> lock(mutex_);
  556. if (var_resource_ == nullptr) {
  557. GELOGW("VarManager has not been init.");
  558. return ge::INTERNAL_ERROR;
  559. }
  560. return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr);
  561. }
  562. bool VarManager::IsVarAddr(const int64_t &offset) {
  563. std::lock_guard<std::recursive_mutex> lock(mutex_);
  564. if (var_resource_ == nullptr) {
  565. GELOGD("VarManager has not been init.");
  566. return false;
  567. }
  568. return var_resource_->IsVarAddr(offset);
  569. }
  570. rtMemType_t VarManager::GetVarMemType(const int64_t &offset) {
  571. std::lock_guard<std::recursive_mutex> lock(mutex_);
  572. if (var_resource_ == nullptr) {
  573. GELOGW("VarManager has not been init.");
  574. return RT_MEMORY_RESERVED;
  575. }
  576. return var_resource_->GetVarMemType(offset);
  577. }
  578. ge::Status VarManager::MallocVarMemory(size_t memory_size) {
  579. std::lock_guard<std::recursive_mutex> lock(mutex_);
  580. uint8_t *var_mem_base = nullptr;
  581. string memory_key = std::to_string(session_id_);
  582. // malloc variable memory
  583. size_t var_memory_size = memory_size;
  584. // align 512 BYTE
  585. var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  586. const string purpose("variables and constant op memory in training network.");
  587. var_mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, var_memory_size);
  588. if (var_mem_base == nullptr) {
  589. GELOGE(ge::INTERNAL_ERROR,
  590. "VarManager::MallocVarMemory failed "
  591. "session_id = %s",
  592. memory_key.c_str());
  593. return ge::INTERNAL_ERROR;
  594. }
  595. return SUCCESS;
  596. }
  597. uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) {
  598. std::lock_guard<std::recursive_mutex> lock(mutex_);
  599. if (memory_type == RT_MEMORY_RDMA_HBM) {
  600. return MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).GetRdmaBaseAddr();
  601. }
  602. string memory_key = std::to_string(session_id_);
  603. return MemManager::Instance(memory_type)->GetMemoryAddr(memory_key);
  604. }
  605. uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) {
  606. std::lock_guard<std::recursive_mutex> lock(mutex_);
  607. if (memory_type == RT_MEMORY_RDMA_HBM) {
  608. return logic_addr;
  609. }
  610. string mem_key = std::to_string(session_id_);
  611. uint8_t *mem_base = MemManager::Instance(memory_type)->GetMemoryAddr(mem_key);
  612. if (mem_base == nullptr) {
  613. return nullptr;
  614. }
  615. uint8_t *mem_addr =
  616. logic_addr + reinterpret_cast<intptr_t>(mem_base) - VarManager::Instance(session_id_)->GetVarMemLogicBase();
  617. return mem_addr;
  618. }
  619. ge::Status VarManager::FreeVarMemory() {
  620. std::lock_guard<std::recursive_mutex> lock(mutex_);
  621. string memory_key = std::to_string(SessionId());
  622. return MemManager::Instance(RT_MEMORY_HBM)->FreeMemory(memory_key);
  623. }
  624. ge::Status VarManager::SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) {
  625. std::lock_guard<std::recursive_mutex> lock(mutex_);
  626. if (var_resource_ == nullptr) {
  627. GELOGW("VarManager has not been init.");
  628. return ge::INTERNAL_ERROR;
  629. }
  630. return var_resource_->SetTransRoad(var_name, trans_road);
  631. }
  632. VarTransRoad *VarManager::GetTransRoad(const std::string &var_name) {
  633. std::lock_guard<std::recursive_mutex> lock(mutex_);
  634. if (var_resource_ == nullptr) {
  635. GELOGW("VarManager has not been init.");
  636. return nullptr;
  637. }
  638. return var_resource_->GetTransRoad(var_name);
  639. }
  640. Status VarManager::SetChangedGraphId(const std::string &var_name, uint32_t graph_id) {
  641. std::lock_guard<std::recursive_mutex> lock(mutex_);
  642. if (var_resource_ == nullptr) {
  643. GELOGW("VarManager has not been init.");
  644. return INTERNAL_ERROR;
  645. }
  646. return var_resource_->SetChangedGraphId(var_name, graph_id);
  647. }
  648. Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  649. std::lock_guard<std::recursive_mutex> lock(mutex_);
  650. if (var_resource_ == nullptr) {
  651. GELOGW("VarManager has not been init.");
  652. return INTERNAL_ERROR;
  653. }
  654. return var_resource_->GetChangedGraphId(var_name, graph_id);
  655. }
  656. Status VarManager::SetMemoryMallocSize(const map<string, string> &options) {
  657. auto it = options.find(GRAPH_MEMORY_MAX_SIZE);
  658. if (it == options.end()) {
  659. graph_mem_max_size_ = kGraphMemoryManagerMallocMaxSize;
  660. } else {
  661. string graph_memory_manager_malloc_max_size = it->second;
  662. ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_);
  663. if (ret != SUCCESS) {
  664. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse graph memory manager malloc max size failed.");
  665. return ge::GE_GRAPH_OPTIONS_INVALID;
  666. }
  667. GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_);
  668. }
  669. it = options.find(VARIABLE_MEMORY_MAX_SIZE);
  670. if (it == options.end()) {
  671. var_mem_max_size_ = kMemoryVarManagerMallocSize;
  672. } else {
  673. string memory_var_manager_malloc_size = it->second;
  674. ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_);
  675. if (ret != SUCCESS) {
  676. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse memory var manager malloc size failed.");
  677. return ge::GE_GRAPH_OPTIONS_INVALID;
  678. }
  679. }
  680. var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer;
  681. if (var_mem_logic_base_ > kMaxMemorySize) {
  682. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "kMemoryVarLogicBase : %zu can not exceed max memory size : %zu.",
  683. var_mem_logic_base_, kMaxMemorySize);
  684. return ge::GE_GRAPH_OPTIONS_INVALID;
  685. }
  686. use_max_mem_size_ = graph_mem_max_size_ + var_mem_max_size_;
  687. if (use_max_mem_size_ > kMaxMemorySize) {
  688. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "kUseMaxMemorySize : %zu can not exceed max memory size : %zu.",
  689. use_max_mem_size_, kMaxMemorySize);
  690. return ge::GE_GRAPH_OPTIONS_INVALID;
  691. }
  692. GELOGI("Set memory malloc size successfully");
  693. return SUCCESS;
  694. }
  695. Status VarManager::ParseMemoryMallocSize(string &memory_size, size_t &result) {
  696. if (memory_size.empty()) {
  697. GELOGE(GE_GRAPH_OPTIONS_INVALID, "Memory malloc size input is empty.");
  698. return GE_GRAPH_OPTIONS_INVALID;
  699. }
  700. // split string by '*'
  701. vector<string> splits;
  702. std::istringstream str(memory_size);
  703. string str_split;
  704. while (getline(str, str_split, '*')) {
  705. splits.emplace_back(str_split);
  706. }
  707. result = 1;
  708. for (string split : splits) {
  709. // Trim
  710. auto it = split.find_first_not_of(" ");
  711. if (it != string::npos) {
  712. split.erase(0, it);
  713. }
  714. it = split.find_last_not_of(" ");
  715. if (it != string::npos) {
  716. split.erase(it + 1);
  717. }
  718. for (char c : split) {
  719. if (!isdigit(c)) {
  720. GELOGE(GE_GRAPH_OPTIONS_INVALID, "Memory malloc size input contains non digit.");
  721. return GE_GRAPH_OPTIONS_INVALID;
  722. }
  723. }
  724. uint64_t num = std::strtoul(split.c_str(), nullptr, 0);
  725. GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(result, static_cast<uint32_t>(num)),
  726. GELOGE(FAILED, "Input memory size is out of range.");
  727. return FAILED);
  728. if ((num > kMaxMemorySize) || (result * static_cast<size_t>(num) > kMaxMemorySize)) {
  729. GELOGE(FAILED, "Input memory size can not exceed max memory size : %zu.", kMaxMemorySize);
  730. return FAILED;
  731. }
  732. result *= static_cast<size_t>(num);
  733. }
  734. return SUCCESS;
  735. }
  736. void VarManager::RemoveChangedGraphId(const std::string &var_name) {
  737. std::lock_guard<std::recursive_mutex> lock(mutex_);
  738. if (var_resource_ == nullptr) {
  739. GELOGW("VarManager has not been init.");
  740. return;
  741. }
  742. var_resource_->RemoveChangedGraphId(var_name);
  743. }
  744. Status VarManager::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  745. std::lock_guard<std::recursive_mutex> lock(mutex_);
  746. if (var_resource_ == nullptr) {
  747. GELOGW("VarManager has not been init.");
  748. return INTERNAL_ERROR;
  749. }
  750. return var_resource_->SetAllocatedGraphId(var_name, graph_id);
  751. }
  752. Status VarManager::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  753. std::lock_guard<std::recursive_mutex> lock(mutex_);
  754. if (var_resource_ == nullptr) {
  755. GELOGW("VarManager has not been init.");
  756. return INTERNAL_ERROR;
  757. }
  758. return var_resource_->GetAllocatedGraphId(var_name, graph_id);
  759. }
  760. void VarManager::RemoveAllocatedGraphId(const std::string &var_name) {
  761. std::lock_guard<std::recursive_mutex> lock(mutex_);
  762. if (var_resource_ == nullptr) {
  763. GELOGW("VarManager has not been init.");
  764. return;
  765. }
  766. var_resource_->RemoveAllocatedGraphId(var_name);
  767. }
  768. Status VarManager::GetAllVariables(std::map<std::string, GeTensorDesc> &all_variables) {
  769. std::lock_guard<std::recursive_mutex> lock(mutex_);
  770. if (var_resource_ == nullptr) {
  771. GELOGW("VarManager has not been inited.");
  772. return INTERNAL_ERROR;
  773. }
  774. auto new_variable_desc = var_resource_->GetAllVarDesc();
  775. if (new_variable_desc.size() == 0) {
  776. GELOGW("VarManager don't have variables.");
  777. return INTERNAL_ERROR;
  778. }
  779. for (auto iter = new_variable_desc.begin(); iter != new_variable_desc.end(); ++iter) {
  780. auto trans_road = var_resource_->GetTransRoad(iter->first);
  781. if (trans_road == nullptr || trans_road->empty()) {
  782. GELOGI("The variable %s does not have any trans road", iter->first.c_str());
  783. all_variables[iter->first] = iter->second;
  784. continue;
  785. }
  786. // get origin trans info : the first trans node info
  787. auto origin_trans_node_info = trans_road->at(0);
  788. all_variables[iter->first] = origin_trans_node_info.input;
  789. }
  790. return SUCCESS;
  791. }
  792. VarManagerPool::~VarManagerPool() { Destory(); }
  793. VarManagerPool &VarManagerPool::Instance() {
  794. static VarManagerPool var_manager_pool;
  795. return var_manager_pool;
  796. }
  797. void VarManagerPool::Destory() noexcept {
  798. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  799. for (auto &it : var_manager_map_) {
  800. VarManager *var_manager = it.second;
  801. if (var_manager != nullptr) {
  802. var_manager->Destory();
  803. delete var_manager;
  804. var_manager = nullptr;
  805. }
  806. }
  807. var_manager_map_.clear();
  808. }
  809. ge::Status VarManagerPool::Init() const { return SUCCESS; }
  810. VarManager *VarManagerPool::GetVarManager(uint64_t session_id) {
  811. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  812. auto it = var_manager_map_.find(session_id);
  813. if (it != var_manager_map_.end()) {
  814. GELOGD("VarManagerPool::GetVarManager");
  815. return it->second;
  816. }
  817. VarManager *var_manager = new (std::nothrow) VarManager(session_id);
  818. if (var_manager == nullptr) {
  819. GELOGE(INTERNAL_ERROR,
  820. "VarManager::Instance find session by "
  821. "session_id[%lu] failed.",
  822. session_id);
  823. static VarManager new_var_manager(0);
  824. return &new_var_manager;
  825. }
  826. var_manager_map_[session_id] = var_manager;
  827. return var_manager;
  828. }
  829. void VarManagerPool::RemoveVarManager(uint64_t session_id) {
  830. VarManager *var_manager = nullptr;
  831. {
  832. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  833. auto it = var_manager_map_.find(session_id);
  834. if (it != var_manager_map_.end()) {
  835. var_manager = it->second;
  836. var_manager_map_.erase(it);
  837. }
  838. }
  839. if (var_manager != nullptr) {
  840. var_manager->Destory();
  841. delete var_manager;
  842. var_manager = nullptr;
  843. }
  844. }
  845. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。