You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_var_manager.cc 39 kB

5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/manager/graph_var_manager.h"
  17. #include "graph/debug/ge_attr_define.h"
  18. #include "graph/manager/trans_var_data_utils.h"
  19. #include "graph/utils/type_utils.h"
  20. #include "graph/ge_context.h"
  21. using std::map;
  22. using std::string;
  23. using std::vector;
  24. namespace ge {
  25. VarResource::VarResource(uint64_t session_id) : session_id_(session_id) {}
  26. VarResource::~VarResource() {
  27. var_offset_map_.clear();
  28. var_addr_mgr_map_.clear();
  29. cur_var_tensor_desc_map_.clear();
  30. var_broad_cast_info_.clear();
  31. }
  32. ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  33. rtMemType_t &memory_type) {
  34. if (dev_ptr == nullptr) {
  35. REPORT_INNER_ERROR("E19999", "Param dev_ptr is nullptr, var_name:%s, session_id:%lu, "
  36. "check invalid", var_name.c_str(), session_id_);
  37. GELOGE(FAILED, "[Check][Param] Param dev_ptr is nullptr, var_name:%s, session_id:%lu",
  38. var_name.c_str(), session_id_);
  39. return FAILED;
  40. }
  41. std::string var_key = VarKey(var_name, tensor_desc);
  42. GELOGD("VarResource::GetVarAddr , var_key = %s", var_key.c_str());
  43. auto iter = var_addr_mgr_map_.find(var_key);
  44. if (iter == var_addr_mgr_map_.end()) {
  45. REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, "
  46. "check invalid", var_key.c_str(), var_name.c_str(),
  47. session_id_);
  48. GELOGE(FAILED, "[Check][Param] var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu",
  49. var_key.c_str(), var_name.c_str(), session_id_);
  50. return FAILED;
  51. }
  52. *dev_ptr = iter->second.address;
  53. memory_type = iter->second.memory_type;
  54. return SUCCESS;
  55. }
  56. void VarResource::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  57. var_addr_mgr_map = var_addr_mgr_map_;
  58. }
  59. void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  60. rtMemType_t memory_type) {
  61. std::string var_key = VarKey(var_name, tensor_desc);
  62. GELOGI("VarResource::SetVarAddr , var_key = %s, mem_type:%u", var_key.c_str(), memory_type);
  63. if (var_addr_mgr_map_.count(var_key) == 0) {
  64. GELOGI("SetVarAddr node_name %s, tensor_desc type %s, format %s", var_name.c_str(),
  65. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  66. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  67. VarAddrMgr var_addr_mgr;
  68. var_addr_mgr.address = dev_ptr;
  69. var_addr_mgr.tensor_desc = tensor_desc;
  70. var_addr_mgr_map_[var_key] = var_addr_mgr;
  71. }
  72. cur_var_tensor_desc_map_[var_name] = tensor_desc;
  73. }
  74. ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  75. rtMemType_t memory_type) {
  76. std::string var_key = VarKey(var_name, tensor_desc);
  77. GELOGD("VarResource::SaveVarAddr, var_key = %s", var_key.c_str());
  78. if (var_addr_mgr_map_.count(var_key) == 0) {
  79. uint64_t logic_address = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  80. if (memory_type != RT_MEMORY_RDMA_HBM) {
  81. logic_address += VarManager::Instance(session_id_)->GetVarMemLogicBase();
  82. }
  83. GELOGI("SaveVarAddr node_name %s, tensor_desc format %s, type %s.", var_name.c_str(),
  84. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  85. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  86. VarAddrMgr var_addr_mgr;
  87. var_addr_mgr.address = reinterpret_cast<uint8_t *>(static_cast<std::uintptr_t>(logic_address));
  88. var_addr_mgr.offset = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  89. var_addr_mgr.tensor_desc = tensor_desc;
  90. var_addr_mgr.memory_type = memory_type;
  91. var_addr_mgr_map_[var_key] = var_addr_mgr;
  92. var_offset_map_[logic_address] = memory_type;
  93. return SUCCESS;
  94. }
  95. REPORT_INNER_ERROR("E19999", "var_key:%s conflict in var_addr_mgr_map_, var_name:%s, session_id:%lu, "
  96. "check invalid", var_key.c_str(), var_name.c_str(),
  97. session_id_);
  98. GELOGE(FAILED, "[Check][Param] var_key:%s conflict in var_addr_mgr_map_, var_name:%s, session_id:%lu",
  99. var_key.c_str(), var_name.c_str(), session_id_);
  100. return FAILED;
  101. }
  102. bool VarResource::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  103. std::string var_key = VarKey(var_name, tensor_desc);
  104. return var_addr_mgr_map_.count(var_key) != 0;
  105. }
  106. bool VarResource::IsVarExist(const std::string &var_name) { return cur_var_tensor_desc_map_.count(var_name) != 0; }
  107. std::string VarResource::VarKey(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  108. std::string var_key(var_name);
  109. var_key.append(std::to_string(static_cast<int32_t>(tensor_desc.GetFormat())))
  110. .append("_")
  111. .append(std::to_string(static_cast<int32_t>(tensor_desc.GetDataType())));
  112. return var_key;
  113. }
  114. ge::Status VarResource::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  115. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  116. return FAILED;
  117. }
  118. tensor_desc = cur_var_tensor_desc_map_[var_name];
  119. return SUCCESS;
  120. }
  121. ge::Status VarResource::RenewCurVarDesc(const std::string &var_name, const ge::OpDescPtr &op_desc) {
  122. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  123. GELOGI("There is no this node[%s] in var tensor_desc map. so no need renew!", var_name.c_str());
  124. return SUCCESS;
  125. }
  126. if (op_desc == nullptr) {
  127. REPORT_INNER_ERROR("E19999", "Param op_desc is nullptr, var_name:%s, session_id:%lu, check invalid",
  128. var_name.c_str(), session_id_);
  129. GELOGE(FAILED, "[Check][Param] input opdesc is nullptr, var_name:%s, session_id:%lu",
  130. var_name.c_str(), session_id_);
  131. return FAILED;
  132. }
  133. ge::GeTensorDesc curr_desc;
  134. ge::Status ret = GetCurVarDesc(var_name, curr_desc);
  135. if (ret != SUCCESS) {
  136. GELOGE(FAILED, "[Get][CurVarDesc] fail, var_name:%s, session_id:%lu", var_name.c_str(), session_id_);
  137. return FAILED;
  138. }
  139. std::string key = VarKey(var_name, curr_desc);
  140. curr_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  141. curr_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  142. cur_var_tensor_desc_map_[var_name] = curr_desc;
  143. auto iter = var_addr_mgr_map_.find(key);
  144. if (iter == var_addr_mgr_map_.end()) {
  145. REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, op:%s(%s), "
  146. "check invalid", key.c_str(), var_name.c_str(),
  147. session_id_, op_desc->GetName().c_str(), op_desc->GetType().c_str());
  148. GELOGE(FAILED, "[Check][Param] var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, op:%s(%s)",
  149. key.c_str(), var_name.c_str(), session_id_, op_desc->GetName().c_str(), op_desc->GetType().c_str());
  150. return FAILED;
  151. }
  152. auto val = iter->second;
  153. val.tensor_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  154. val.tensor_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  155. var_addr_mgr_map_.erase(iter);
  156. key = VarKey(var_name, curr_desc);
  157. var_addr_mgr_map_[key] = val;
  158. return SUCCESS;
  159. }
  160. void VarResource::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  161. var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info;
  162. }
  163. ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  164. if (var_broad_cast_info_.count(graph_id) == 0 || var_broad_cast_info_[graph_id].count(var_name) == 0) {
  165. return FAILED;
  166. }
  167. broad_cast_info = var_broad_cast_info_[graph_id][var_name];
  168. return SUCCESS;
  169. }
  170. bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; }
  171. rtMemType_t VarResource::GetVarMemType(const int64_t &offset) {
  172. if (var_offset_map_.count(offset) > 0) {
  173. return var_offset_map_[offset];
  174. }
  175. return RT_MEMORY_RESERVED;
  176. }
  177. VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) {
  178. auto iter = var_to_trans_road_.find(var_name);
  179. if (iter == var_to_trans_road_.end()) {
  180. return nullptr;
  181. } else {
  182. return &(iter->second);
  183. }
  184. }
  185. Status VarResource::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  186. auto iter = var_names_to_changed_graph_id_.find(var_name);
  187. if (iter == var_names_to_changed_graph_id_.end()) {
  188. return FAILED;
  189. } else {
  190. graph_id = iter->second;
  191. return SUCCESS;
  192. }
  193. }
  194. Status VarResource::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  195. auto iter = var_names_to_allocated_graph_id_.find(var_name);
  196. if (iter == var_names_to_allocated_graph_id_.end()) {
  197. return FAILED;
  198. } else {
  199. graph_id = iter->second;
  200. return SUCCESS;
  201. }
  202. }
  203. Status VarResource::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  204. if (GetAllocatedGraphId(var_name, graph_id) == SUCCESS) {
  205. GELOGW("VarManager var[%s] has been allocated in graph[%d]", var_name.c_str(), graph_id);
  206. return SUCCESS;
  207. }
  208. var_names_to_allocated_graph_id_[var_name] = graph_id;
  209. return SUCCESS;
  210. }
  211. MemResource::MemResource() : total_size_(0), var_mem_size_(0) {}
  212. MemResource *MemResource::BuildMemResourceFromType(rtMemType_t mem_type) {
  213. switch (mem_type) {
  214. case RT_MEMORY_HBM:
  215. return new (std::nothrow) HbmMemResource();
  216. case RT_MEMORY_RDMA_HBM:
  217. return new (std::nothrow) RdmaMemResource();
  218. default:
  219. return nullptr;
  220. }
  221. }
  222. Status HbmMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id,
  223. size_t &mem_offset) {
  224. size = (size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  225. uint64_t real_size = size;
  226. total_size_ = VarManager::Instance(session_id)->GetVarMemMaxSize();
  227. if (total_size_ < var_mem_size_) {
  228. REPORT_INNER_ERROR("E19999", "VarMemMaxSize:%lu < var_mem_size_:%lu, var_size:%lu, var_name:%s, check invalid"
  229. "", total_size_, var_mem_size_, size, var_name.c_str());
  230. GELOGE(PARAM_INVALID, "[Check][Param] total_size_:%lu is smaller than var_mem_size_:%lu, var_name:%s",
  231. total_size_, var_mem_size_, var_name.c_str());
  232. return PARAM_INVALID;
  233. }
  234. uint64_t free_size = total_size_ - var_mem_size_;
  235. if (free_size < (size + kSessionMemAlignSize * kSessionMemAlignUnit)) {
  236. REPORT_INNER_ERROR("E19999", "free_size:%lu not enough, var_align_size:%lu, var_name:%s, check invalid",
  237. free_size, size, var_name.c_str());
  238. GELOGE(PARAM_INVALID, "[Check][Param] Out of memory: current var size[%lu] exceeds total var size[%lu]",
  239. size + kSessionMemAlignSize * kSessionMemAlignUnit + var_mem_size_, total_size_);
  240. return PARAM_INVALID;
  241. }
  242. mem_offset = var_mem_size_;
  243. // offset for next, align 512 BYTE
  244. size = size + kSessionMemAlignSize;
  245. var_mem_size_ = var_mem_size_ + size;
  246. // align 512 BYTE
  247. var_mem_size_ = var_mem_size_ + kSessionMemAlignSize;
  248. GELOGI(
  249. "[IMAS]AssignVarMem Set session_%lu name[%s] output[%d]"
  250. "offset to [%zu] size[%lu] realsize[%lu].",
  251. session_id, var_name.c_str(), 0, mem_offset, (var_mem_size_ - mem_offset), real_size);
  252. return SUCCESS;
  253. }
  254. Status RdmaMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &address) {
  255. uint8_t *buffer = VarManager::Instance(session_id)->GetPoolMemory(RT_MEMORY_HBM, size);
  256. if (buffer == nullptr) {
  257. REPORT_CALL_ERROR("E19999", "malloc rdma memory fail, var_size:%lu, var_name:%s",
  258. size, var_name.c_str());
  259. GELOGE(MEMALLOC_FAILED, "[Malloc][RdmaMemory] for node %s failed, size = %lu", var_name.c_str(), size);
  260. return MEMALLOC_FAILED;
  261. }
  262. address = static_cast<size_t>(reinterpret_cast<uintptr_t>(buffer));
  263. var_mem_size_ += size;
  264. GELOGI("[IMAS]AssignVarMem Set session_%lu name[%s] output[%d] addr to [%p] size[%lu].",
  265. session_id, var_name.c_str(), 0, buffer, size);
  266. return SUCCESS;
  267. }
  268. uint64_t MemResource::GetVarMemSize() const { return var_mem_size_; }
  269. void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; };
  270. VarManager::VarManager(uint64_t session_id)
  271. : version_(SessionVersion::OTHER_VERSION),
  272. session_id_(session_id),
  273. device_id_(0),
  274. job_id_(0),
  275. graph_mem_max_size_(kGraphMemoryManagerMallocMaxSize),
  276. var_mem_max_size_(kMemoryVarManagerMallocSize),
  277. var_mem_logic_base_(kMemoryVarLogicBase),
  278. use_max_mem_size_(kUseMaxMemorySize) {}
  279. VarManager *VarManager::Instance(uint64_t session_id) {
  280. GELOGD("VarManager::Instance, session id = %lu", session_id);
  281. return VarManagerPool::Instance().GetVarManager(session_id);
  282. }
  283. void VarManager::Destory() {
  284. std::lock_guard<std::recursive_mutex> lock(mutex_);
  285. GELOGI("VarManager::Destory, session id = %lu.", session_id_);
  286. version_ = SessionVersion::OTHER_VERSION;
  287. device_id_ = 0;
  288. session_id_ = 0;
  289. for (auto &memory_resource : mem_resource_map_) {
  290. if (memory_resource.second != nullptr) {
  291. delete memory_resource.second;
  292. memory_resource.second = nullptr;
  293. }
  294. }
  295. mem_resource_map_.clear();
  296. }
  297. Status VarManager::Init(uint32_t version, uint64_t session_id, uint32_t device_id, uint64_t job_id) {
  298. std::lock_guard<std::recursive_mutex> lock(mutex_);
  299. GELOGI("VarManager::Init, session id = %lu.", session_id);
  300. if (var_resource_ == nullptr) {
  301. version_ = version;
  302. device_id_ = device_id;
  303. session_id_ = session_id;
  304. job_id_ = job_id;
  305. var_resource_ = std::unique_ptr<VarResource>(new (std::nothrow) VarResource(session_id_));
  306. if (var_resource_ == nullptr) {
  307. GELOGW("VarManager init failed session id = %lu.", session_id);
  308. return ge::INTERNAL_ERROR;
  309. }
  310. } else {
  311. GELOGW("VarManager::has been inited, session id = %lu.", session_id);
  312. }
  313. return SUCCESS;
  314. }
  315. const uint64_t &VarManager::SessionId() const {
  316. std::lock_guard<std::recursive_mutex> lock(mutex_);
  317. return session_id_;
  318. }
  319. const uint32_t &VarManager::DeviceId() const {
  320. std::lock_guard<std::recursive_mutex> lock(mutex_);
  321. return device_id_;
  322. }
  323. const uint64_t &VarManager::JobId() const {
  324. std::lock_guard<std::recursive_mutex> lock(mutex_);
  325. return job_id_;
  326. }
  327. ge::Status VarManager::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  328. rtMemType_t memory_type) {
  329. GELOGI("VarManager::SetVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  330. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  331. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  332. std::lock_guard<std::recursive_mutex> lock(mutex_);
  333. if (var_resource_ == nullptr) {
  334. GELOGW("VarManager has not been init.");
  335. return ge::INTERNAL_ERROR;
  336. }
  337. var_resource_->SetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  338. return ge::SUCCESS;
  339. }
  340. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  341. rtMemType_t &memory_type) {
  342. std::lock_guard<std::recursive_mutex> lock(mutex_);
  343. GELOGD("VarManager::GetVarAddr var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  344. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  345. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  346. if (var_resource_ == nullptr) {
  347. GELOGW("VarManager has not been init.");
  348. return ge::INTERNAL_ERROR;
  349. }
  350. auto ret = var_resource_->GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  351. if (ret != SUCCESS) {
  352. GELOGW("GetVarAddr fail.");
  353. return ge::INTERNAL_ERROR;
  354. }
  355. return SUCCESS;
  356. }
  357. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr) {
  358. std::lock_guard<std::recursive_mutex> lock(mutex_);
  359. rtMemType_t memory_type = RT_MEMORY_HBM;
  360. return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  361. }
  362. int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) {
  363. std::lock_guard<std::recursive_mutex> lock(mutex_);
  364. MemResource *mem_resource = nullptr;
  365. auto iter = mem_resource_map_.find(memory_type);
  366. if (iter == mem_resource_map_.end()) {
  367. return 0;
  368. } else {
  369. mem_resource = iter->second;
  370. }
  371. if (mem_resource == nullptr) {
  372. REPORT_INNER_ERROR("E19999", "Find no mem_resource in map, memory_type:%d, session_id:%lu",
  373. memory_type, session_id_);
  374. GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%d, session_id:%lu",
  375. memory_type, session_id_);
  376. return 0;
  377. }
  378. return mem_resource->GetVarMemSize();
  379. }
  380. ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc,
  381. rtMemType_t memory_type) {
  382. std::lock_guard<std::recursive_mutex> lock(mutex_);
  383. GELOGI("VarManager::AssignVarMem var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  384. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  385. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  386. int64_t tensor_desc_size = 0;
  387. size_t mem_offset = 0;
  388. ge::Status result = TensorUtils::GetSize(tensor_desc, tensor_desc_size);
  389. if (result != ge::SUCCESS) {
  390. REPORT_CALL_ERROR("E19999", "Get size from tensor fail, var_name:%s, memory_type:%d, session_id:%lu",
  391. var_name.c_str(), memory_type, session_id_);
  392. GELOGE(result, "[Get][Size] from tensor fail, var_name:%s, memory_type:%u, session_id:%lu",
  393. var_name.c_str(), memory_type, session_id_);
  394. return result;
  395. }
  396. MemResource *mem_resource = nullptr;
  397. auto it = mem_resource_map_.find(memory_type);
  398. if (it == mem_resource_map_.end()) {
  399. mem_resource = MemResource::BuildMemResourceFromType(memory_type);
  400. if (mem_resource == nullptr) {
  401. REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu",
  402. memory_type, session_id_);
  403. GELOGE(ge::INTERNAL_ERROR, "[Alloc][MemResource] failed, memory_type:%u, session_id:%lu.",
  404. memory_type, session_id_);
  405. return ge::INTERNAL_ERROR;
  406. } else {
  407. mem_resource_map_[memory_type] = mem_resource;
  408. }
  409. } else {
  410. mem_resource = it->second;
  411. }
  412. if (mem_resource == nullptr) {
  413. REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu",
  414. memory_type, session_id_);
  415. GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%u, session_id:%lu.",
  416. memory_type, session_id_);
  417. return ge::INTERNAL_ERROR;
  418. }
  419. if (var_resource_ == nullptr) {
  420. REPORT_INNER_ERROR("E19999", "VarManager has not been init, memory_type:%d, session_id:%lu, "
  421. "check invalid", memory_type, session_id_);
  422. GELOGW("VarManager has not been init.");
  423. return ge::INTERNAL_ERROR;
  424. }
  425. ge::GeTensorDesc cur_tensor_desc;
  426. int64_t cur_tensor_desc_size = 0;
  427. result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc);
  428. // reuse old format variable memory
  429. if (result == SUCCESS) {
  430. result = var_resource_->GetVarAddr(
  431. var_name, cur_tensor_desc, reinterpret_cast<uint8_t **>(reinterpret_cast<uintptr_t>(&mem_offset)), memory_type);
  432. if (result == SUCCESS) {
  433. result = TensorUtils::GetSize(cur_tensor_desc, cur_tensor_desc_size);
  434. GELOGD("tensor_desc_size is %ld, cur_tensor_desc_size is %ld, memoffset is %zu", tensor_desc_size,
  435. cur_tensor_desc_size, mem_offset);
  436. }
  437. }
  438. bool can_not_reuse_old_memory = (result != SUCCESS) || (tensor_desc_size > cur_tensor_desc_size);
  439. if (can_not_reuse_old_memory) {
  440. result = mem_resource->AssignVarMem(var_name, tensor_desc_size, session_id_, mem_offset);
  441. if (result != SUCCESS) {
  442. GELOGE(ge::INTERNAL_ERROR, "[Assign][VarMem] by offset failed, session_id:%lu.", session_id_);
  443. return ge::INTERNAL_ERROR;
  444. }
  445. result = var_resource_->SaveVarAddr(
  446. var_name, tensor_desc, reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  447. if (result != SUCCESS) {
  448. GELOGE(ge::INTERNAL_ERROR, "[Save][VarAddr] by offset failed, memory type:%u, session_id:%lu.",
  449. memory_type, session_id_);
  450. return ge::INTERNAL_ERROR;
  451. }
  452. }
  453. // old not exist only save new tensor
  454. result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc);
  455. if (result != SUCCESS) {
  456. var_resource_->SetVarAddr(var_name, tensor_desc,
  457. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  458. return SUCCESS;
  459. }
  460. bool format_changed = cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() ||
  461. cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() ||
  462. cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims();
  463. if (format_changed) {
  464. GELOGI("var %s assigned new memory (format, data type, shape) (%s, %s, %zu) from (%s, %s, %zu)", var_name.c_str(),
  465. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  466. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  467. tensor_desc.GetShape().GetDims().size(),
  468. ge::TypeUtils::DataTypeToSerialString(cur_tensor_desc.GetDataType()).c_str(),
  469. ge::TypeUtils::FormatToSerialString(cur_tensor_desc.GetFormat()).c_str(),
  470. cur_tensor_desc.GetShape().GetDims().size());
  471. var_resource_->SetVarAddr(var_name, tensor_desc,
  472. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  473. }
  474. return SUCCESS;
  475. }
  476. bool VarManager::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  477. std::lock_guard<std::recursive_mutex> lock(mutex_);
  478. GELOGD("VarManager::IsVarExist var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  479. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  480. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  481. if (var_resource_ == nullptr) {
  482. GELOGW("VarManager has not been init.");
  483. return false;
  484. }
  485. return var_resource_->IsVarExist(var_name, tensor_desc);
  486. }
  487. bool VarManager::IsVarExist(const std::string &var_name) {
  488. std::lock_guard<std::recursive_mutex> lock(mutex_);
  489. if (var_resource_ == nullptr) {
  490. GELOGW("VarManager has not been init.");
  491. return false;
  492. }
  493. return var_resource_->IsVarExist(var_name);
  494. }
  495. ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  496. std::lock_guard<std::recursive_mutex> lock(mutex_);
  497. GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str());
  498. if (var_resource_ == nullptr) {
  499. GELOGW("VarManager has not been init.");
  500. return ge::INTERNAL_ERROR;
  501. }
  502. return var_resource_->GetCurVarDesc(var_name, tensor_desc);
  503. }
  504. ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  505. std::lock_guard<std::recursive_mutex> lock(mutex_);
  506. GELOGI(
  507. "VarManager::SaveBroadCastInfo var_name = %s, broadcast name = %s, "
  508. "idx = %d, input_offset = %ld, input_size = %lu, output_offset = %ld, "
  509. "output_size = %lu",
  510. broad_cast_info.var_name.c_str(), broad_cast_info.broadcast_name.c_str(), broad_cast_info.idx,
  511. broad_cast_info.input_offset, broad_cast_info.input_size, broad_cast_info.output_offset,
  512. broad_cast_info.output_size);
  513. if (var_resource_ == nullptr) {
  514. GELOGW("VarManager has not been init.");
  515. return ge::INTERNAL_ERROR;
  516. }
  517. var_resource_->SaveBroadCastInfo(graph_id, broad_cast_info);
  518. return SUCCESS;
  519. }
  520. ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) {
  521. std::lock_guard<std::recursive_mutex> lock(mutex_);
  522. GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str());
  523. if (var_resource_ == nullptr) {
  524. REPORT_INNER_ERROR("E19999", "VarManager has not been init, op:%s(%s), session_id:%lu, check invalid",
  525. op_desc->GetName().c_str(), op_desc->GetType().c_str(),
  526. session_id_);
  527. GELOGE(ge::INTERNAL_ERROR, "[Check][Param] VarManager has not been init, op:%s(%s), session_id:%lu",
  528. op_desc->GetName().c_str(), op_desc->GetType().c_str(), session_id_);
  529. return ge::INTERNAL_ERROR;
  530. }
  531. return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc));
  532. }
  533. bool VarManager::IsVarAddr(const int64_t &offset) {
  534. std::lock_guard<std::recursive_mutex> lock(mutex_);
  535. if (var_resource_ == nullptr) {
  536. GELOGD("VarManager has not been init.");
  537. return false;
  538. }
  539. return var_resource_->IsVarAddr(offset);
  540. }
  541. rtMemType_t VarManager::GetVarMemType(const int64_t &offset) {
  542. std::lock_guard<std::recursive_mutex> lock(mutex_);
  543. if (var_resource_ == nullptr) {
  544. GELOGW("VarManager has not been init.");
  545. return RT_MEMORY_RESERVED;
  546. }
  547. return var_resource_->GetVarMemType(offset);
  548. }
  549. void VarManager::SetMemManager(MemoryManager *mem_manager) {
  550. // Better use shared_ptr instead, reconsitution later.
  551. GELOGI("Set MemManager to VarManager.");
  552. std::lock_guard<std::recursive_mutex> lock(mutex_);
  553. mem_manager_ = mem_manager;
  554. }
  555. ge::Status VarManager::MallocVarMemory(size_t memory_size) {
  556. std::lock_guard<std::recursive_mutex> lock(mutex_);
  557. if (mem_manager_ == nullptr) {
  558. GELOGE(FAILED, "MemManager has not been init.");
  559. REPORT_INNER_ERROR("E19999", "MemManager has not been init, session_id: %lu", session_id_);
  560. return FAILED;
  561. }
  562. uint8_t *var_mem_base = nullptr;
  563. string memory_key = std::to_string(session_id_);
  564. // malloc variable memory
  565. size_t var_memory_size = memory_size;
  566. // align 512 BYTE
  567. var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  568. const string purpose("variables and constant op memory in training network.");
  569. var_mem_base = mem_manager_->MallocMemory(RT_MEMORY_HBM, purpose, memory_key, var_memory_size, device_id_);
  570. if (var_mem_base == nullptr) {
  571. GELOGE(ge::INTERNAL_ERROR, "[Malloc][VarMemory] failed, size:%zu, session_id:%s",
  572. var_memory_size, memory_key.c_str());
  573. return ge::INTERNAL_ERROR;
  574. }
  575. return SUCCESS;
  576. }
  577. uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) {
  578. std::lock_guard<std::recursive_mutex> lock(mutex_);
  579. if (mem_manager_ == nullptr) {
  580. GELOGE(FAILED, "MemManager has not been init.");
  581. REPORT_INNER_ERROR("E19999", "MemManager has not been init, session_id: %lu", session_id_);
  582. return nullptr;
  583. }
  584. string memory_key = std::to_string(session_id_);
  585. return mem_manager_->GetMemoryBase(memory_type, memory_key, device_id_);
  586. }
  587. uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) {
  588. std::lock_guard<std::recursive_mutex> lock(mutex_);
  589. if (mem_manager_ == nullptr) {
  590. GELOGE(FAILED, "MemManager has not been init.");
  591. REPORT_INNER_ERROR("E19999", "MemManager has not been init, session_id: %lu", session_id_);
  592. return nullptr;
  593. }
  594. if (memory_type == RT_MEMORY_RDMA_HBM) {
  595. return logic_addr;
  596. }
  597. string mem_key = std::to_string(session_id_);
  598. uint8_t *mem_base = mem_manager_->GetMemoryAddr(memory_type, mem_key, device_id_);
  599. if (mem_base == nullptr) {
  600. return nullptr;
  601. }
  602. uint8_t *mem_addr =
  603. logic_addr + reinterpret_cast<intptr_t>(mem_base) - VarManager::Instance(session_id_)->GetVarMemLogicBase();
  604. return mem_addr;
  605. }
  606. ge::Status VarManager::FreeVarMemory() {
  607. std::lock_guard<std::recursive_mutex> lock(mutex_);
  608. if (mem_manager_ == nullptr) {
  609. GELOGE(FAILED, "MemManager has not been init.");
  610. REPORT_INNER_ERROR("E19999", "MemManager has not been init, session_id: %lu", session_id_);
  611. return FAILED;
  612. }
  613. string memory_key = std::to_string(SessionId());
  614. return mem_manager_->FreeMemory(RT_MEMORY_HBM, memory_key, device_id_);
  615. }
  616. uint8_t *VarManager::GetPoolMemory(rtMemType_t memory_type, size_t mem_size) {
  617. std::lock_guard<std::recursive_mutex> lock(mutex_);
  618. if (mem_manager_ == nullptr) {
  619. GELOGE(FAILED, "MemManager has not been init.");
  620. REPORT_INNER_ERROR("E19999", "MemManager has not been init, session_id: %lu", session_id_);
  621. return nullptr;
  622. }
  623. return mem_manager_->GetPoolMemory(memory_type, mem_size, device_id_);
  624. }
  625. ge::Status VarManager::SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) {
  626. std::lock_guard<std::recursive_mutex> lock(mutex_);
  627. if (var_resource_ == nullptr) {
  628. GELOGW("VarManager has not been init.");
  629. return ge::INTERNAL_ERROR;
  630. }
  631. return var_resource_->SetTransRoad(var_name, trans_road);
  632. }
  633. VarTransRoad *VarManager::GetTransRoad(const std::string &var_name) {
  634. std::lock_guard<std::recursive_mutex> lock(mutex_);
  635. if (var_resource_ == nullptr) {
  636. GELOGW("VarManager has not been init.");
  637. return nullptr;
  638. }
  639. return var_resource_->GetTransRoad(var_name);
  640. }
  641. Status VarManager::SetChangedGraphId(const std::string &var_name, uint32_t graph_id) {
  642. std::lock_guard<std::recursive_mutex> lock(mutex_);
  643. if (var_resource_ == nullptr) {
  644. GELOGW("VarManager has not been init.");
  645. return INTERNAL_ERROR;
  646. }
  647. return var_resource_->SetChangedGraphId(var_name, graph_id);
  648. }
  649. Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  650. std::lock_guard<std::recursive_mutex> lock(mutex_);
  651. if (var_resource_ == nullptr) {
  652. GELOGW("VarManager has not been init.");
  653. return INTERNAL_ERROR;
  654. }
  655. return var_resource_->GetChangedGraphId(var_name, graph_id);
  656. }
  657. Status VarManager::SetMemoryMallocSize(const map<string, string> &options, size_t total_mem_size) {
  658. GEEVENT("Total memory size is %zu", total_mem_size);
  659. graph_mem_max_size_ = floor(total_mem_size * kGraphMemoryManagerMallocRatio);
  660. var_mem_max_size_ = floor(total_mem_size * kVarMemoryManagerMallocRatio);
  661. auto it1 = options.find(GRAPH_MEMORY_MAX_SIZE);
  662. if (it1 != options.end()) {
  663. string graph_memory_manager_malloc_max_size = it1->second;
  664. ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_);
  665. if (ret != SUCCESS) {
  666. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_);
  667. return ge::GE_GRAPH_OPTIONS_INVALID;
  668. }
  669. }
  670. auto it2 = options.find(VARIABLE_MEMORY_MAX_SIZE);
  671. if (it2 != options.end()) {
  672. string memory_var_manager_malloc_size = it2->second;
  673. ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_);
  674. if (ret != SUCCESS) {
  675. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_);
  676. return ge::GE_GRAPH_OPTIONS_INVALID;
  677. }
  678. }
  679. GEEVENT("The graph_mem_max_size is %zu and the var_mem_max_size is %zu", graph_mem_max_size_, var_mem_max_size_);
  680. var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer;
  681. if (var_mem_logic_base_ > kMaxMemorySize) {
  682. REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid",
  683. var_mem_logic_base_, kMaxMemorySize, session_id_);
  684. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Check][Param] kMemoryVarLogicBase:%zu can not exceed "
  685. "max memory size:%zu, session_id:%lu.", var_mem_logic_base_, kMaxMemorySize, session_id_);
  686. return ge::GE_GRAPH_OPTIONS_INVALID;
  687. }
  688. use_max_mem_size_ = graph_mem_max_size_ + var_mem_max_size_;
  689. if (use_max_mem_size_ > kMaxMemorySize) {
  690. REPORT_INNER_ERROR("E19999", "all mem_use size:%zu can not exeed limit:%zu, session_id:%lu, check invalid",
  691. use_max_mem_size_, kMaxMemorySize, session_id_);
  692. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Check][Param] kUseMaxMemorySize:%zu can not exceed "
  693. "max memory size:%zu, session_id:%lu.", use_max_mem_size_, kMaxMemorySize, session_id_);
  694. return ge::GE_GRAPH_OPTIONS_INVALID;
  695. }
  696. GELOGI("Set memory malloc size successfully");
  697. return SUCCESS;
  698. }
  699. Status VarManager::ParseMemoryMallocSize(string &memory_size, size_t &result) {
  700. if (memory_size.empty()) {
  701. REPORT_INNER_ERROR("E19999", "Param memory_size is empty, session_id:%lu, check invalid",
  702. session_id_);
  703. GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Check][Param] Memory malloc size input is empty, session_id:%lu.", session_id_);
  704. return GE_GRAPH_OPTIONS_INVALID;
  705. }
  706. // split string by '*'
  707. vector<string> splits;
  708. std::istringstream str(memory_size);
  709. string str_split;
  710. while (getline(str, str_split, '*')) {
  711. splits.emplace_back(str_split);
  712. }
  713. result = 1;
  714. for (string split : splits) {
  715. // Trim
  716. auto it = split.find_first_not_of(" ");
  717. if (it != string::npos) {
  718. split.erase(0, it);
  719. }
  720. it = split.find_last_not_of(" ");
  721. if (it != string::npos) {
  722. split.erase(it + 1);
  723. }
  724. for (char c : split) {
  725. if (!isdigit(c)) {
  726. REPORT_INNER_ERROR("E19999", "Param memory_size:%s contains non digit, session_id:%lu, check invalid",
  727. memory_size.c_str(), session_id_);
  728. GELOGE(GE_GRAPH_OPTIONS_INVALID,
  729. "[Check][Param] Memory malloc size:%s input contains non digit, session_id:%lu.",
  730. memory_size.c_str(), session_id_);
  731. return GE_GRAPH_OPTIONS_INVALID;
  732. }
  733. }
  734. uint64_t num = std::strtoul(split.c_str(), nullptr, 0);
  735. GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(result, static_cast<uint32_t>(num)),
  736. REPORT_INNER_ERROR("E19999", "Param memory_size:%s will overflow after multi all, session_id:%lu, "
  737. "check invalid", memory_size.c_str(),
  738. session_id_);
  739. GELOGE(FAILED, "[Check][Param] Param memory_size:%s will overflow after multi all, session_id:%lu",
  740. memory_size.c_str(), session_id_);
  741. return FAILED);
  742. if ((num > kMaxMemorySize) || (result * static_cast<size_t>(num) > kMaxMemorySize)) {
  743. REPORT_INNER_ERROR("E19999", "Param memory_size:%s after multi will exceed limit:%lu, session_id:%lu, "
  744. "check invalid", memory_size.c_str(), kMaxMemorySize,
  745. session_id_);
  746. GELOGE(FAILED, "[Check][Param] Input memory size can not exceed max memory size:%zu, session_id:%lu.",
  747. kMaxMemorySize, session_id_);
  748. return FAILED;
  749. }
  750. result *= static_cast<size_t>(num);
  751. }
  752. return SUCCESS;
  753. }
  754. void VarManager::RemoveChangedGraphId(const std::string &var_name) {
  755. std::lock_guard<std::recursive_mutex> lock(mutex_);
  756. if (var_resource_ == nullptr) {
  757. GELOGW("VarManager has not been init.");
  758. return;
  759. }
  760. var_resource_->RemoveChangedGraphId(var_name);
  761. }
  762. Status VarManager::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  763. std::lock_guard<std::recursive_mutex> lock(mutex_);
  764. if (var_resource_ == nullptr) {
  765. GELOGW("VarManager has not been init.");
  766. return INTERNAL_ERROR;
  767. }
  768. return var_resource_->SetAllocatedGraphId(var_name, graph_id);
  769. }
  770. Status VarManager::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  771. std::lock_guard<std::recursive_mutex> lock(mutex_);
  772. if (var_resource_ == nullptr) {
  773. GELOGW("VarManager has not been init.");
  774. return INTERNAL_ERROR;
  775. }
  776. return var_resource_->GetAllocatedGraphId(var_name, graph_id);
  777. }
  778. void VarManager::RemoveAllocatedGraphId(const std::string &var_name) {
  779. std::lock_guard<std::recursive_mutex> lock(mutex_);
  780. if (var_resource_ == nullptr) {
  781. GELOGW("VarManager has not been init.");
  782. return;
  783. }
  784. var_resource_->RemoveAllocatedGraphId(var_name);
  785. }
  786. Status VarManager::GetAllVariables(std::map<std::string, GeTensorDesc> &all_variables) {
  787. std::lock_guard<std::recursive_mutex> lock(mutex_);
  788. if (var_resource_ == nullptr) {
  789. GELOGW("VarManager has not been inited.");
  790. return INTERNAL_ERROR;
  791. }
  792. auto new_variable_desc = var_resource_->GetAllVarDesc();
  793. if (new_variable_desc.size() == 0) {
  794. GELOGW("VarManager don't have variables.");
  795. return INTERNAL_ERROR;
  796. }
  797. for (auto iter = new_variable_desc.begin(); iter != new_variable_desc.end(); ++iter) {
  798. auto trans_road = var_resource_->GetTransRoad(iter->first);
  799. if (trans_road == nullptr || trans_road->empty()) {
  800. GELOGI("The variable %s does not have any trans road", iter->first.c_str());
  801. all_variables[iter->first] = iter->second;
  802. continue;
  803. }
  804. // get origin trans info : the first trans node info
  805. auto origin_trans_node_info = trans_road->at(0);
  806. all_variables[iter->first] = origin_trans_node_info.input;
  807. }
  808. return SUCCESS;
  809. }
  810. VarManagerPool::~VarManagerPool() { Destory(); }
  811. VarManagerPool &VarManagerPool::Instance() {
  812. static VarManagerPool var_manager_pool;
  813. return var_manager_pool;
  814. }
  815. void VarManagerPool::Destory() noexcept {
  816. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  817. for (auto &it : var_manager_map_) {
  818. VarManager *var_manager = it.second;
  819. if (var_manager != nullptr) {
  820. var_manager->Destory();
  821. delete var_manager;
  822. var_manager = nullptr;
  823. }
  824. }
  825. var_manager_map_.clear();
  826. }
  827. ge::Status VarManagerPool::Init() const { return SUCCESS; }
  828. VarManager *VarManagerPool::GetVarManager(uint64_t session_id) {
  829. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  830. auto it = var_manager_map_.find(session_id);
  831. if (it != var_manager_map_.end()) {
  832. GELOGD("VarManagerPool::GetVarManager");
  833. return it->second;
  834. }
  835. VarManager *var_manager = new (std::nothrow) VarManager(session_id);
  836. if (var_manager == nullptr) {
  837. REPORT_INNER_ERROR("E19999", "New VarManager fail, session_id:%lu", session_id);
  838. GELOGE(INTERNAL_ERROR, "[New][VarManager] fail, session_id:%lu", session_id);
  839. static VarManager new_var_manager(0);
  840. return &new_var_manager;
  841. }
  842. var_manager_map_[session_id] = var_manager;
  843. return var_manager;
  844. }
  845. void VarManagerPool::RemoveVarManager(uint64_t session_id) {
  846. VarManager *var_manager = nullptr;
  847. {
  848. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  849. auto it = var_manager_map_.find(session_id);
  850. if (it != var_manager_map_.end()) {
  851. var_manager = it->second;
  852. var_manager_map_.erase(it);
  853. }
  854. }
  855. if (var_manager != nullptr) {
  856. var_manager->Destory();
  857. delete var_manager;
  858. var_manager = nullptr;
  859. }
  860. }
  861. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户无需感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：