You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_var_manager.cc 41 kB

5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/manager/graph_var_manager.h"
  17. #include "graph/debug/ge_attr_define.h"
  18. #include "graph/manager/graph_mem_allocator.h"
  19. #include "graph/manager/rdma_pool_allocator.h"
  20. #include "graph/manager/trans_var_data_utils.h"
  21. #include "graph/utils/type_utils.h"
  22. using std::map;
  23. using std::string;
  24. using std::vector;
  25. namespace ge {
  26. VarResource::VarResource(uint64_t session_id) : session_id_(session_id) {}
  27. VarResource::~VarResource() {
  28. var_offset_map_.clear();
  29. var_addr_mgr_map_.clear();
  30. cur_var_tensor_desc_map_.clear();
  31. var_broad_cast_info_.clear();
  32. }
  33. ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  34. rtMemType_t &memory_type) {
  35. if (dev_ptr == nullptr) {
  36. REPORT_INNER_ERROR("E19999", "Param dev_ptr is nullptr, var_name:%s, session_id:%lu, "
  37. "check invalid", var_name.c_str(), session_id_);
  38. GELOGE(FAILED, "[GetVarAddr] dev_ptr is null!");
  39. return FAILED;
  40. }
  41. std::string var_key = VarKey(var_name, tensor_desc);
  42. GELOGD("VarResource::GetVarAddr , var_key = %s", var_key.c_str());
  43. auto iter = var_addr_mgr_map_.find(var_key);
  44. if (iter == var_addr_mgr_map_.end()) {
  45. REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, "
  46. "check invalid", var_key.c_str(), var_name.c_str(),
  47. session_id_);
  48. GELOGE(FAILED, "VarResource::GetVarAddr failed, var_key %s", var_key.c_str());
  49. return FAILED;
  50. }
  51. *dev_ptr = iter->second.address;
  52. memory_type = iter->second.memory_type;
  53. return SUCCESS;
  54. }
  55. void VarResource::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  56. var_addr_mgr_map = var_addr_mgr_map_;
  57. }
  58. void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  59. rtMemType_t memory_type) {
  60. std::string var_key = VarKey(var_name, tensor_desc);
  61. GELOGI("VarResource::SetVarAddr , var_key = %s, mem_type:%u", var_key.c_str(), memory_type);
  62. if (var_addr_mgr_map_.count(var_key) == 0) {
  63. GELOGI("SetVarAddr node_name %s, tensor_desc type %s, format %s", var_name.c_str(),
  64. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  65. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  66. VarAddrMgr var_addr_mgr;
  67. var_addr_mgr.address = dev_ptr;
  68. var_addr_mgr.tensor_desc = tensor_desc;
  69. var_addr_mgr_map_[var_key] = var_addr_mgr;
  70. }
  71. cur_var_tensor_desc_map_[var_name] = tensor_desc;
  72. }
  73. ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  74. rtMemType_t memory_type) {
  75. std::string var_key = VarKey(var_name, tensor_desc);
  76. GELOGD("VarResource::SaveVarAddr, var_key = %s", var_key.c_str());
  77. if (var_addr_mgr_map_.count(var_key) == 0) {
  78. uint64_t logic_address = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  79. if (memory_type != RT_MEMORY_RDMA_HBM) {
  80. logic_address += VarManager::Instance(session_id_)->GetVarMemLogicBase();
  81. }
  82. GELOGI("SaveVarAddr node_name %s, tensor_desc format %s, type %s.", var_name.c_str(),
  83. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  84. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  85. VarAddrMgr var_addr_mgr;
  86. var_addr_mgr.address = reinterpret_cast<uint8_t *>(static_cast<std::uintptr_t>(logic_address));
  87. var_addr_mgr.offset = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  88. var_addr_mgr.tensor_desc = tensor_desc;
  89. var_addr_mgr.memory_type = memory_type;
  90. var_addr_mgr_map_[var_key] = var_addr_mgr;
  91. var_offset_map_[logic_address] = memory_type;
  92. return SUCCESS;
  93. }
  94. REPORT_INNER_ERROR("E19999", "var_key:%s conflict in var_addr_mgr_map_, var_name:%s, session_id:%lu, "
  95. "check invalid", var_key.c_str(), var_name.c_str(),
  96. session_id_);
  97. GELOGE(FAILED, "VarResource::SaveVarAddr, var_key %s save addr conflict", var_key.c_str());
  98. return FAILED;
  99. }
  100. bool VarResource::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  101. std::string var_key = VarKey(var_name, tensor_desc);
  102. return var_addr_mgr_map_.count(var_key) != 0;
  103. }
  104. bool VarResource::IsVarExist(const std::string &var_name) { return cur_var_tensor_desc_map_.count(var_name) != 0; }
  105. std::string VarResource::VarKey(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  106. std::string var_key(var_name);
  107. var_key.append(std::to_string(static_cast<int32_t>(tensor_desc.GetFormat())))
  108. .append("_")
  109. .append(std::to_string(static_cast<int32_t>(tensor_desc.GetDataType())));
  110. return var_key;
  111. }
  112. ge::Status VarResource::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  113. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  114. return FAILED;
  115. }
  116. tensor_desc = cur_var_tensor_desc_map_[var_name];
  117. return SUCCESS;
  118. }
  119. ge::Status VarResource::RenewCurVarDesc(const std::string &var_name, const ge::OpDescPtr &op_desc) {
  120. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  121. GELOGI("There is no this node[%s] in var tensor_desc map. so no need renew!", var_name.c_str());
  122. return SUCCESS;
  123. }
  124. if (op_desc == nullptr) {
  125. REPORT_INNER_ERROR("E19999", "Param op_desc is nullptr, var_name:%s, session_id:%lu, check invalid",
  126. var_name.c_str(), session_id_);
  127. GELOGE(FAILED, "[RenewCurVarDesc] renew var desc fail! input opdesc is null!");
  128. return FAILED;
  129. }
  130. ge::GeTensorDesc curr_desc;
  131. ge::Status ret = GetCurVarDesc(var_name, curr_desc);
  132. if (ret != SUCCESS) {
  133. GELOGE(FAILED, "[RenewCurVarDesc] Get var desc fail!");
  134. return FAILED;
  135. }
  136. std::string key = VarKey(var_name, curr_desc);
  137. curr_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  138. curr_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  139. cur_var_tensor_desc_map_[var_name] = curr_desc;
  140. auto iter = var_addr_mgr_map_.find(key);
  141. if (iter == var_addr_mgr_map_.end()) {
  142. REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, op:%s(%s), "
  143. "check invalid", key.c_str(), var_name.c_str(),
  144. session_id_, op_desc->GetName().c_str(), op_desc->GetType().c_str());
  145. GELOGE(FAILED, "[RenewCurVarDesc] can't find ele with key [%s]", key.c_str());
  146. return FAILED;
  147. }
  148. auto val = iter->second;
  149. val.tensor_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  150. val.tensor_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  151. var_addr_mgr_map_.erase(iter);
  152. key = VarKey(var_name, curr_desc);
  153. var_addr_mgr_map_[key] = val;
  154. return SUCCESS;
  155. }
  156. void VarResource::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  157. var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info;
  158. }
  159. ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  160. if (var_broad_cast_info_.count(graph_id) == 0 || var_broad_cast_info_[graph_id].count(var_name) == 0) {
  161. return FAILED;
  162. }
  163. broad_cast_info = var_broad_cast_info_[graph_id][var_name];
  164. return SUCCESS;
  165. }
  166. ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name,
  167. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  168. GE_CHECK_NOTNULL(base_ptr);
  169. GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str());
  170. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  171. uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset;
  172. return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr,
  173. var_broadcast_info.input_size, session_id_);
  174. }
  175. ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  176. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  177. GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str());
  178. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  179. // subgraph base_ptr could be nullptr, task it as base 0
  180. uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset;
  181. return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name,
  182. var_tensor_desc, session_id_);
  183. }
  184. ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name,
  185. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  186. return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr);
  187. }
  188. bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; }
  189. rtMemType_t VarResource::GetVarMemType(const int64_t &offset) {
  190. if (var_offset_map_.count(offset) > 0) {
  191. return var_offset_map_[offset];
  192. }
  193. return RT_MEMORY_RESERVED;
  194. }
  195. VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) {
  196. auto iter = var_to_trans_road_.find(var_name);
  197. if (iter == var_to_trans_road_.end()) {
  198. return nullptr;
  199. } else {
  200. return &(iter->second);
  201. }
  202. }
  203. Status VarResource::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  204. auto iter = var_names_to_changed_graph_id_.find(var_name);
  205. if (iter == var_names_to_changed_graph_id_.end()) {
  206. return FAILED;
  207. } else {
  208. graph_id = iter->second;
  209. return SUCCESS;
  210. }
  211. }
  212. Status VarResource::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  213. auto iter = var_names_to_allocated_graph_id_.find(var_name);
  214. if (iter == var_names_to_allocated_graph_id_.end()) {
  215. return FAILED;
  216. } else {
  217. graph_id = iter->second;
  218. return SUCCESS;
  219. }
  220. }
  221. Status VarResource::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  222. if (GetAllocatedGraphId(var_name, graph_id) == SUCCESS) {
  223. GELOGW("VarManager var[%s] has been allocated in graph[%d]", var_name.c_str(), graph_id);
  224. return SUCCESS;
  225. }
  226. var_names_to_allocated_graph_id_[var_name] = graph_id;
  227. return SUCCESS;
  228. }
  229. MemResource::MemResource() : total_size_(0), var_mem_size_(0) {}
  230. MemResource *MemResource::BuildMemResourceFromType(rtMemType_t mem_type) {
  231. switch (mem_type) {
  232. case RT_MEMORY_HBM:
  233. return new (std::nothrow) HbmMemResource();
  234. case RT_MEMORY_RDMA_HBM:
  235. return new (std::nothrow) RdmaMemResource();
  236. default:
  237. return nullptr;
  238. }
  239. }
  240. Status HbmMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id,
  241. size_t &mem_offset) {
  242. size = (size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  243. uint64_t real_size = size;
  244. total_size_ = VarManager::Instance(session_id)->GetVarMemMaxSize();
  245. if (total_size_ < var_mem_size_) {
  246. REPORT_INNER_ERROR("E19999", "VarMemMaxSize:%lu < var_mem_size_:%lu, var_size:%lu, var_name:%s, check invalid"
  247. "", total_size_, var_mem_size_, size, var_name.c_str());
  248. GELOGE(PARAM_INVALID, "total_size_: %lu is smaller than var_mem_size_: %lu", total_size_, var_mem_size_);
  249. return PARAM_INVALID;
  250. }
  251. uint64_t free_size = total_size_ - var_mem_size_;
  252. if (free_size < (size + kSessionMemAlignSize * kSessionMemAlignUnit)) {
  253. REPORT_INNER_ERROR("E19999", "free_size:%lu not enough, var_align_size:%lu, var_name:%s, check invalid",
  254. free_size, size, var_name.c_str());
  255. GELOGE(PARAM_INVALID, "Out of memory : current var size[%lu] exceeds total var size[%lu]",
  256. size + kSessionMemAlignSize * kSessionMemAlignUnit + var_mem_size_, total_size_);
  257. return PARAM_INVALID;
  258. }
  259. mem_offset = var_mem_size_;
  260. // offset for next, align 512 BYTE
  261. size = size + kSessionMemAlignSize;
  262. var_mem_size_ = var_mem_size_ + size;
  263. // align 512 BYTE
  264. var_mem_size_ = var_mem_size_ + kSessionMemAlignSize;
  265. GELOGI(
  266. "[IMAS]AssignVarMem Set session_%lu name[%s] output[%d]"
  267. "offset to [%zu] size[%lu] realsize[%lu].",
  268. session_id, var_name.c_str(), 0, mem_offset, (var_mem_size_ - mem_offset), real_size);
  269. return SUCCESS;
  270. }
  271. Status RdmaMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &address) {
  272. uint8_t *buffer = MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Malloc(size);
  273. if (buffer == nullptr) {
  274. REPORT_CALL_ERROR("E19999", "malloc rdma memory fail, var_size:%lu, var_name:%s",
  275. size, var_name.c_str());
  276. GELOGE(MEMALLOC_FAILED, "Failed to malloc rdma memory for node %s, size = %lu", var_name.c_str(), size);
  277. return MEMALLOC_FAILED;
  278. }
  279. address = static_cast<size_t>(reinterpret_cast<uintptr_t>(buffer));
  280. var_mem_size_ += size;
  281. GELOGI("[IMAS]AssignVarMem Set session_%lu name[%s] output[%d] addr to [%p] size[%lu].",
  282. session_id, var_name.c_str(), 0, buffer, size);
  283. return SUCCESS;
  284. }
  285. uint64_t MemResource::GetVarMemSize() const { return var_mem_size_; }
  286. void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; };
  287. VarManager::VarManager(uint64_t session_id)
  288. : version_(SessionVersion::OTHER_VERSION),
  289. session_id_(session_id),
  290. device_id_(0),
  291. job_id_(0),
  292. graph_mem_max_size_(kGraphMemoryManagerMallocMaxSize),
  293. var_mem_max_size_(kMemoryVarManagerMallocSize),
  294. var_mem_logic_base_(kMemoryVarLogicBase),
  295. use_max_mem_size_(kUseMaxMemorySize) {}
  296. VarManager *VarManager::Instance(uint64_t session_id) {
  297. GELOGD("VarManager::Instance, session id = %lu", session_id);
  298. return VarManagerPool::Instance().GetVarManager(session_id);
  299. }
  300. void VarManager::Destory() {
  301. std::lock_guard<std::recursive_mutex> lock(mutex_);
  302. GELOGI("VarManager::Destory, session id = %lu.", session_id_);
  303. version_ = SessionVersion::OTHER_VERSION;
  304. device_id_ = 0;
  305. session_id_ = 0;
  306. for (auto &memory_resource : mem_resource_map_) {
  307. if (memory_resource.second != nullptr) {
  308. delete memory_resource.second;
  309. memory_resource.second = nullptr;
  310. }
  311. }
  312. mem_resource_map_.clear();
  313. }
  314. ge::Status VarManager::Init(const uint32_t &version, const uint64_t &session_id, const uint32_t &device_id,
  315. const uint64_t &job_id) {
  316. std::lock_guard<std::recursive_mutex> lock(mutex_);
  317. GELOGI("VarManager::Init, session id = %lu.", session_id);
  318. if (var_resource_ == nullptr) {
  319. version_ = version;
  320. device_id_ = device_id;
  321. session_id_ = session_id;
  322. job_id_ = job_id;
  323. var_resource_ = std::unique_ptr<VarResource>(new (std::nothrow) VarResource(session_id_));
  324. if (var_resource_ == nullptr) {
  325. GELOGW("VarManager init failed session id = %lu.", session_id);
  326. return ge::INTERNAL_ERROR;
  327. }
  328. } else {
  329. GELOGW("VarManager::has been inited, session id = %lu.", session_id);
  330. }
  331. return SUCCESS;
  332. }
  333. const uint64_t &VarManager::SessionId() const {
  334. std::lock_guard<std::recursive_mutex> lock(mutex_);
  335. return session_id_;
  336. }
  337. const uint32_t &VarManager::DeviceId() const {
  338. std::lock_guard<std::recursive_mutex> lock(mutex_);
  339. return device_id_;
  340. }
  341. const uint64_t &VarManager::JobId() const {
  342. std::lock_guard<std::recursive_mutex> lock(mutex_);
  343. return job_id_;
  344. }
  345. ge::Status VarManager::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  346. rtMemType_t memory_type) {
  347. GELOGI("VarManager::SetVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  348. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  349. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  350. std::lock_guard<std::recursive_mutex> lock(mutex_);
  351. if (var_resource_ == nullptr) {
  352. GELOGW("VarManager has not been init.");
  353. return ge::INTERNAL_ERROR;
  354. }
  355. var_resource_->SetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  356. return ge::SUCCESS;
  357. }
  358. ge::Status VarManager::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  359. rtMemType_t memory_type) {
  360. GELOGI("VarManager::SaveVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  361. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  362. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  363. std::lock_guard<std::recursive_mutex> lock(mutex_);
  364. if (var_resource_ == nullptr) {
  365. GELOGW("VarManager has not been init.");
  366. return ge::INTERNAL_ERROR;
  367. }
  368. var_resource_->SaveVarAddr(var_name, tensor_desc, address, memory_type);
  369. return ge::SUCCESS;
  370. }
  371. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  372. rtMemType_t &memory_type) {
  373. std::lock_guard<std::recursive_mutex> lock(mutex_);
  374. GELOGD("VarManager::GetVarAddr var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  375. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  376. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  377. if (var_resource_ == nullptr) {
  378. GELOGW("VarManager has not been init.");
  379. return ge::INTERNAL_ERROR;
  380. }
  381. auto ret = var_resource_->GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  382. if (ret != SUCCESS) {
  383. GELOGW("GetVarAddr fail.");
  384. return ge::INTERNAL_ERROR;
  385. }
  386. return SUCCESS;
  387. }
  388. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr) {
  389. std::lock_guard<std::recursive_mutex> lock(mutex_);
  390. rtMemType_t memory_type = RT_MEMORY_HBM;
  391. return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  392. }
  393. void VarManager::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  394. var_resource_->GetAllVarAddrMgr(var_addr_mgr_map);
  395. }
  396. int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) {
  397. std::lock_guard<std::recursive_mutex> lock(mutex_);
  398. MemResource *mem_resource = nullptr;
  399. auto iter = mem_resource_map_.find(memory_type);
  400. if (iter == mem_resource_map_.end()) {
  401. return 0;
  402. } else {
  403. mem_resource = iter->second;
  404. }
  405. if (mem_resource == nullptr) {
  406. REPORT_INNER_ERROR("E19999", "Find no mem_resource in map, memory_type:%d, session_id:%lu",
  407. memory_type, session_id_);
  408. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid.");
  409. return 0;
  410. }
  411. return mem_resource->GetVarMemSize();
  412. }
  413. Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) {
  414. std::lock_guard<std::recursive_mutex> lock(mutex_);
  415. MemResource *mem_resource = nullptr;
  416. auto iter = mem_resource_map_.find(memory_type);
  417. if (iter == mem_resource_map_.end()) {
  418. mem_resource = MemResource::BuildMemResourceFromType(memory_type);
  419. if (mem_resource == nullptr) {
  420. REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu",
  421. memory_type, session_id_);
  422. GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
  423. return ge::INTERNAL_ERROR;
  424. } else {
  425. mem_resource_map_[memory_type] = mem_resource;
  426. }
  427. } else {
  428. mem_resource = iter->second;
  429. }
  430. if (mem_resource == nullptr) {
  431. REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu",
  432. memory_type, session_id_);
  433. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid.");
  434. return FAILED;
  435. }
  436. mem_resource->UpdateVarMemSize(mem_size);
  437. return SUCCESS;
  438. }
  439. ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc,
  440. rtMemType_t memory_type) {
  441. std::lock_guard<std::recursive_mutex> lock(mutex_);
  442. GELOGI("VarManager::AssignVarMem var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  443. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  444. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  445. int64_t tensor_desc_size = 0;
  446. size_t mem_offset = 0;
  447. ge::Status result = TensorUtils::GetSize(tensor_desc, tensor_desc_size);
  448. if (result != ge::SUCCESS) {
  449. REPORT_CALL_ERROR("E19999", "Get size from tensor fail, var_name:%s, memory_type:%d, session_id:%lu",
  450. var_name.c_str(), memory_type, session_id_);
  451. GELOGE(result, "get size from TensorDesc failed");
  452. return result;
  453. }
  454. MemResource *mem_resource = nullptr;
  455. auto it = mem_resource_map_.find(memory_type);
  456. if (it == mem_resource_map_.end()) {
  457. mem_resource = MemResource::BuildMemResourceFromType(memory_type);
  458. if (mem_resource == nullptr) {
  459. REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu",
  460. memory_type, session_id_);
  461. GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
  462. return ge::INTERNAL_ERROR;
  463. } else {
  464. mem_resource_map_[memory_type] = mem_resource;
  465. }
  466. } else {
  467. mem_resource = it->second;
  468. }
  469. if (mem_resource == nullptr) {
  470. REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu",
  471. memory_type, session_id_);
  472. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid, memory_type = %u.", memory_type);
  473. return ge::INTERNAL_ERROR;
  474. }
  475. if (var_resource_ == nullptr) {
  476. REPORT_INNER_ERROR("E19999", "VarManager has not been init, memory_type:%d, session_id:%lu, "
  477. "check invalid", memory_type, session_id_);
  478. GELOGW("VarManager has not been init.");
  479. return ge::INTERNAL_ERROR;
  480. }
  481. ge::GeTensorDesc cur_tensor_desc;
  482. int64_t cur_tensor_desc_size = 0;
  483. result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc);
  484. // reuse old format variable memory
  485. if (result == SUCCESS) {
  486. result = var_resource_->GetVarAddr(
  487. var_name, cur_tensor_desc, reinterpret_cast<uint8_t **>(reinterpret_cast<uintptr_t>(&mem_offset)), memory_type);
  488. if (result == SUCCESS) {
  489. result = TensorUtils::GetSize(cur_tensor_desc, cur_tensor_desc_size);
  490. GELOGD("tensor_desc_size is %ld, cur_tensor_desc_size is %ld, memoffset is %zu", tensor_desc_size,
  491. cur_tensor_desc_size, mem_offset);
  492. }
  493. }
  494. bool can_not_reuse_old_memory = (result != SUCCESS) || (tensor_desc_size > cur_tensor_desc_size);
  495. if (can_not_reuse_old_memory) {
  496. result = mem_resource->AssignVarMem(var_name, tensor_desc_size, session_id_, mem_offset);
  497. if (result != SUCCESS) {
  498. GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed.");
  499. return ge::INTERNAL_ERROR;
  500. }
  501. result = var_resource_->SaveVarAddr(
  502. var_name, tensor_desc, reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  503. if (result != SUCCESS) {
  504. GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed.");
  505. return ge::INTERNAL_ERROR;
  506. }
  507. }
  508. // old not exist only save new tensor
  509. result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc);
  510. if (result != SUCCESS) {
  511. var_resource_->SetVarAddr(var_name, tensor_desc,
  512. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  513. return SUCCESS;
  514. }
  515. bool format_changed = cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() ||
  516. cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() ||
  517. cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims();
  518. if (format_changed) {
  519. GELOGI("var %s assigned new memory (format, data type, shape) (%s, %s, %zu) from (%s, %s, %zu)", var_name.c_str(),
  520. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  521. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  522. tensor_desc.GetShape().GetDims().size(),
  523. ge::TypeUtils::DataTypeToSerialString(cur_tensor_desc.GetDataType()).c_str(),
  524. ge::TypeUtils::FormatToSerialString(cur_tensor_desc.GetFormat()).c_str(),
  525. cur_tensor_desc.GetShape().GetDims().size());
  526. var_resource_->SetVarAddr(var_name, tensor_desc,
  527. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  528. }
  529. return SUCCESS;
  530. }
  531. bool VarManager::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  532. std::lock_guard<std::recursive_mutex> lock(mutex_);
  533. GELOGD("VarManager::IsVarExist var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  534. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  535. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  536. if (var_resource_ == nullptr) {
  537. GELOGW("VarManager has not been init.");
  538. return false;
  539. }
  540. return var_resource_->IsVarExist(var_name, tensor_desc);
  541. }
  542. bool VarManager::IsVarExist(const std::string &var_name) {
  543. std::lock_guard<std::recursive_mutex> lock(mutex_);
  544. if (var_resource_ == nullptr) {
  545. GELOGW("VarManager has not been init.");
  546. return false;
  547. }
  548. return var_resource_->IsVarExist(var_name);
  549. }
  550. ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
  551. uint8_t *base_ptr) {
  552. std::lock_guard<std::recursive_mutex> lock(mutex_);
  553. if (var_resource_ == nullptr) {
  554. GELOGW("VarManager has not been init.");
  555. return ge::INTERNAL_ERROR;
  556. }
  557. return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr);
  558. }
  559. ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  560. std::lock_guard<std::recursive_mutex> lock(mutex_);
  561. GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str());
  562. if (var_resource_ == nullptr) {
  563. GELOGW("VarManager has not been init.");
  564. return ge::INTERNAL_ERROR;
  565. }
  566. return var_resource_->GetCurVarDesc(var_name, tensor_desc);
  567. }
  568. ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  569. std::lock_guard<std::recursive_mutex> lock(mutex_);
  570. GELOGI(
  571. "VarManager::SaveBroadCastInfo var_name = %s, broadcast name = %s, "
  572. "idx = %d, input_offset = %ld, input_size = %lu, output_offset = %ld, "
  573. "output_size = %lu",
  574. broad_cast_info.var_name.c_str(), broad_cast_info.broadcast_name.c_str(), broad_cast_info.idx,
  575. broad_cast_info.input_offset, broad_cast_info.input_size, broad_cast_info.output_offset,
  576. broad_cast_info.output_size);
  577. if (var_resource_ == nullptr) {
  578. GELOGW("VarManager has not been init.");
  579. return ge::INTERNAL_ERROR;
  580. }
  581. var_resource_->SaveBroadCastInfo(graph_id, broad_cast_info);
  582. return SUCCESS;
  583. }
  584. ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  585. std::lock_guard<std::recursive_mutex> lock(mutex_);
  586. if (var_resource_ == nullptr) {
  587. GELOGW("VarManager has not been init.");
  588. return ge::INTERNAL_ERROR;
  589. }
  590. return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info);
  591. }
  592. ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) {
  593. std::lock_guard<std::recursive_mutex> lock(mutex_);
  594. GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str());
  595. if (var_resource_ == nullptr) {
  596. REPORT_INNER_ERROR("E19999", "VarManager has not been init, op:%s(%s), session_id:%lu, check invalid",
  597. op_desc->GetName().c_str(), op_desc->GetType().c_str(),
  598. session_id_);
  599. GELOGE(ge::INTERNAL_ERROR, "VarManager has not been init.");
  600. return ge::INTERNAL_ERROR;
  601. }
  602. return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc));
  603. }
  604. ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  605. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  606. std::lock_guard<std::recursive_mutex> lock(mutex_);
  607. if (var_resource_ == nullptr) {
  608. GELOGW("VarManager has not been init.");
  609. return ge::INTERNAL_ERROR;
  610. }
  611. return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr);
  612. }
  613. bool VarManager::IsVarAddr(const int64_t &offset) {
  614. std::lock_guard<std::recursive_mutex> lock(mutex_);
  615. if (var_resource_ == nullptr) {
  616. GELOGD("VarManager has not been init.");
  617. return false;
  618. }
  619. return var_resource_->IsVarAddr(offset);
  620. }
  621. rtMemType_t VarManager::GetVarMemType(const int64_t &offset) {
  622. std::lock_guard<std::recursive_mutex> lock(mutex_);
  623. if (var_resource_ == nullptr) {
  624. GELOGW("VarManager has not been init.");
  625. return RT_MEMORY_RESERVED;
  626. }
  627. return var_resource_->GetVarMemType(offset);
  628. }
  629. ge::Status VarManager::MallocVarMemory(size_t memory_size) {
  630. std::lock_guard<std::recursive_mutex> lock(mutex_);
  631. uint8_t *var_mem_base = nullptr;
  632. string memory_key = std::to_string(session_id_);
  633. // malloc variable memory
  634. size_t var_memory_size = memory_size;
  635. // align 512 BYTE
  636. var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  637. const string purpose("variables and constant op memory in training network.");
  638. var_mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, var_memory_size);
  639. if (var_mem_base == nullptr) {
  640. GELOGE(ge::INTERNAL_ERROR,
  641. "VarManager::MallocVarMemory failed "
  642. "session_id = %s",
  643. memory_key.c_str());
  644. return ge::INTERNAL_ERROR;
  645. }
  646. return SUCCESS;
  647. }
  648. uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) {
  649. std::lock_guard<std::recursive_mutex> lock(mutex_);
  650. if (memory_type == RT_MEMORY_RDMA_HBM) {
  651. return MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).GetRdmaBaseAddr();
  652. }
  653. string memory_key = std::to_string(session_id_);
  654. return MemManager::Instance(memory_type)->GetMemoryAddr(memory_key);
  655. }
  656. uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) {
  657. std::lock_guard<std::recursive_mutex> lock(mutex_);
  658. if (memory_type == RT_MEMORY_RDMA_HBM) {
  659. return logic_addr;
  660. }
  661. string mem_key = std::to_string(session_id_);
  662. uint8_t *mem_base = MemManager::Instance(memory_type)->GetMemoryAddr(mem_key);
  663. if (mem_base == nullptr) {
  664. return nullptr;
  665. }
  666. uint8_t *mem_addr =
  667. logic_addr + reinterpret_cast<intptr_t>(mem_base) - VarManager::Instance(session_id_)->GetVarMemLogicBase();
  668. return mem_addr;
  669. }
  670. ge::Status VarManager::FreeVarMemory() {
  671. std::lock_guard<std::recursive_mutex> lock(mutex_);
  672. string memory_key = std::to_string(SessionId());
  673. return MemManager::Instance(RT_MEMORY_HBM)->FreeMemory(memory_key);
  674. }
  675. ge::Status VarManager::SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) {
  676. std::lock_guard<std::recursive_mutex> lock(mutex_);
  677. if (var_resource_ == nullptr) {
  678. GELOGW("VarManager has not been init.");
  679. return ge::INTERNAL_ERROR;
  680. }
  681. return var_resource_->SetTransRoad(var_name, trans_road);
  682. }
  683. VarTransRoad *VarManager::GetTransRoad(const std::string &var_name) {
  684. std::lock_guard<std::recursive_mutex> lock(mutex_);
  685. if (var_resource_ == nullptr) {
  686. GELOGW("VarManager has not been init.");
  687. return nullptr;
  688. }
  689. return var_resource_->GetTransRoad(var_name);
  690. }
  691. Status VarManager::SetChangedGraphId(const std::string &var_name, uint32_t graph_id) {
  692. std::lock_guard<std::recursive_mutex> lock(mutex_);
  693. if (var_resource_ == nullptr) {
  694. GELOGW("VarManager has not been init.");
  695. return INTERNAL_ERROR;
  696. }
  697. return var_resource_->SetChangedGraphId(var_name, graph_id);
  698. }
  699. Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  700. std::lock_guard<std::recursive_mutex> lock(mutex_);
  701. if (var_resource_ == nullptr) {
  702. GELOGW("VarManager has not been init.");
  703. return INTERNAL_ERROR;
  704. }
  705. return var_resource_->GetChangedGraphId(var_name, graph_id);
  706. }
  707. Status VarManager::SetMemoryMallocSize(const map<string, string> &options) {
  708. auto it = options.find(GRAPH_MEMORY_MAX_SIZE);
  709. if (it == options.end()) {
  710. graph_mem_max_size_ = kGraphMemoryManagerMallocMaxSize;
  711. } else {
  712. string graph_memory_manager_malloc_max_size = it->second;
  713. ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_);
  714. if (ret != SUCCESS) {
  715. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse graph memory manager malloc max size failed.");
  716. return ge::GE_GRAPH_OPTIONS_INVALID;
  717. }
  718. GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_);
  719. }
  720. it = options.find(VARIABLE_MEMORY_MAX_SIZE);
  721. if (it == options.end()) {
  722. var_mem_max_size_ = kMemoryVarManagerMallocSize;
  723. } else {
  724. string memory_var_manager_malloc_size = it->second;
  725. ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_);
  726. if (ret != SUCCESS) {
  727. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse memory var manager malloc size failed.");
  728. return ge::GE_GRAPH_OPTIONS_INVALID;
  729. }
  730. }
  731. var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer;
  732. if (var_mem_logic_base_ > kMaxMemorySize) {
  733. REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid",
  734. var_mem_logic_base_, kMaxMemorySize, session_id_);
  735. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "kMemoryVarLogicBase : %zu can not exceed max memory size : %zu.",
  736. var_mem_logic_base_, kMaxMemorySize);
  737. return ge::GE_GRAPH_OPTIONS_INVALID;
  738. }
  739. use_max_mem_size_ = graph_mem_max_size_ + var_mem_max_size_;
  740. if (use_max_mem_size_ > kMaxMemorySize) {
  741. REPORT_INNER_ERROR("E19999", "all mem_use size:%zu can not exeed limit:%zu, session_id:%lu, check invalid",
  742. use_max_mem_size_, kMaxMemorySize, session_id_);
  743. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "kUseMaxMemorySize : %zu can not exceed max memory size : %zu.",
  744. use_max_mem_size_, kMaxMemorySize);
  745. return ge::GE_GRAPH_OPTIONS_INVALID;
  746. }
  747. GELOGI("Set memory malloc size successfully");
  748. return SUCCESS;
  749. }
  750. Status VarManager::ParseMemoryMallocSize(string &memory_size, size_t &result) {
  751. if (memory_size.empty()) {
  752. REPORT_INNER_ERROR("E19999", "Param memory_size is empty, session_id:%lu, check invalid",
  753. session_id_);
  754. GELOGE(GE_GRAPH_OPTIONS_INVALID, "Memory malloc size input is empty.");
  755. return GE_GRAPH_OPTIONS_INVALID;
  756. }
  757. // split string by '*'
  758. vector<string> splits;
  759. std::istringstream str(memory_size);
  760. string str_split;
  761. while (getline(str, str_split, '*')) {
  762. splits.emplace_back(str_split);
  763. }
  764. result = 1;
  765. for (string split : splits) {
  766. // Trim
  767. auto it = split.find_first_not_of(" ");
  768. if (it != string::npos) {
  769. split.erase(0, it);
  770. }
  771. it = split.find_last_not_of(" ");
  772. if (it != string::npos) {
  773. split.erase(it + 1);
  774. }
  775. for (char c : split) {
  776. if (!isdigit(c)) {
  777. REPORT_INNER_ERROR("E19999", "Param memory_size:%s contains non digit, session_id:%lu, check invalid",
  778. memory_size.c_str(), session_id_);
  779. GELOGE(GE_GRAPH_OPTIONS_INVALID, "Memory malloc size input contains non digit.");
  780. return GE_GRAPH_OPTIONS_INVALID;
  781. }
  782. }
  783. uint64_t num = std::strtoul(split.c_str(), nullptr, 0);
  784. GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(result, static_cast<uint32_t>(num)),
  785. REPORT_INNER_ERROR("E19999", "Param memory_size:%s will overflow after multi all, session_id:%lu, "
  786. "check invalid", memory_size.c_str(),
  787. session_id_);
  788. GELOGE(FAILED, "Input memory size is out of range.");
  789. return FAILED);
  790. if ((num > kMaxMemorySize) || (result * static_cast<size_t>(num) > kMaxMemorySize)) {
  791. REPORT_INNER_ERROR("E19999", "Param memory_size:%s after multi will exceed limit:%lu, session_id:%lu, "
  792. "check invalid", memory_size.c_str(), kMaxMemorySize,
  793. session_id_);
  794. GELOGE(FAILED, "Input memory size can not exceed max memory size : %zu.", kMaxMemorySize);
  795. return FAILED;
  796. }
  797. result *= static_cast<size_t>(num);
  798. }
  799. return SUCCESS;
  800. }
  801. void VarManager::RemoveChangedGraphId(const std::string &var_name) {
  802. std::lock_guard<std::recursive_mutex> lock(mutex_);
  803. if (var_resource_ == nullptr) {
  804. GELOGW("VarManager has not been init.");
  805. return;
  806. }
  807. var_resource_->RemoveChangedGraphId(var_name);
  808. }
  809. Status VarManager::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  810. std::lock_guard<std::recursive_mutex> lock(mutex_);
  811. if (var_resource_ == nullptr) {
  812. GELOGW("VarManager has not been init.");
  813. return INTERNAL_ERROR;
  814. }
  815. return var_resource_->SetAllocatedGraphId(var_name, graph_id);
  816. }
  817. Status VarManager::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  818. std::lock_guard<std::recursive_mutex> lock(mutex_);
  819. if (var_resource_ == nullptr) {
  820. GELOGW("VarManager has not been init.");
  821. return INTERNAL_ERROR;
  822. }
  823. return var_resource_->GetAllocatedGraphId(var_name, graph_id);
  824. }
  825. void VarManager::RemoveAllocatedGraphId(const std::string &var_name) {
  826. std::lock_guard<std::recursive_mutex> lock(mutex_);
  827. if (var_resource_ == nullptr) {
  828. GELOGW("VarManager has not been init.");
  829. return;
  830. }
  831. var_resource_->RemoveAllocatedGraphId(var_name);
  832. }
  833. Status VarManager::GetAllVariables(std::map<std::string, GeTensorDesc> &all_variables) {
  834. std::lock_guard<std::recursive_mutex> lock(mutex_);
  835. if (var_resource_ == nullptr) {
  836. GELOGW("VarManager has not been inited.");
  837. return INTERNAL_ERROR;
  838. }
  839. auto new_variable_desc = var_resource_->GetAllVarDesc();
  840. if (new_variable_desc.size() == 0) {
  841. GELOGW("VarManager don't have variables.");
  842. return INTERNAL_ERROR;
  843. }
  844. for (auto iter = new_variable_desc.begin(); iter != new_variable_desc.end(); ++iter) {
  845. auto trans_road = var_resource_->GetTransRoad(iter->first);
  846. if (trans_road == nullptr || trans_road->empty()) {
  847. GELOGI("The variable %s does not have any trans road", iter->first.c_str());
  848. all_variables[iter->first] = iter->second;
  849. continue;
  850. }
  851. // get origin trans info : the first trans node info
  852. auto origin_trans_node_info = trans_road->at(0);
  853. all_variables[iter->first] = origin_trans_node_info.input;
  854. }
  855. return SUCCESS;
  856. }
  857. VarManagerPool::~VarManagerPool() { Destory(); }
  858. VarManagerPool &VarManagerPool::Instance() {
  859. static VarManagerPool var_manager_pool;
  860. return var_manager_pool;
  861. }
  862. void VarManagerPool::Destory() noexcept {
  863. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  864. for (auto &it : var_manager_map_) {
  865. VarManager *var_manager = it.second;
  866. if (var_manager != nullptr) {
  867. var_manager->Destory();
  868. delete var_manager;
  869. var_manager = nullptr;
  870. }
  871. }
  872. var_manager_map_.clear();
  873. }
  874. ge::Status VarManagerPool::Init() const { return SUCCESS; }
  875. VarManager *VarManagerPool::GetVarManager(uint64_t session_id) {
  876. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  877. auto it = var_manager_map_.find(session_id);
  878. if (it != var_manager_map_.end()) {
  879. GELOGD("VarManagerPool::GetVarManager");
  880. return it->second;
  881. }
  882. VarManager *var_manager = new (std::nothrow) VarManager(session_id);
  883. if (var_manager == nullptr) {
  884. REPORT_INNER_ERROR("E19999", "New VarManager fail, session_id:%lu", session_id);
  885. GELOGE(INTERNAL_ERROR,
  886. "VarManager::Instance find session by "
  887. "session_id[%lu] failed.",
  888. session_id);
  889. static VarManager new_var_manager(0);
  890. return &new_var_manager;
  891. }
  892. var_manager_map_[session_id] = var_manager;
  893. return var_manager;
  894. }
  895. void VarManagerPool::RemoveVarManager(uint64_t session_id) {
  896. VarManager *var_manager = nullptr;
  897. {
  898. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  899. auto it = var_manager_map_.find(session_id);
  900. if (it != var_manager_map_.end()) {
  901. var_manager = it->second;
  902. var_manager_map_.erase(it);
  903. }
  904. }
  905. if (var_manager != nullptr) {
  906. var_manager->Destory();
  907. delete var_manager;
  908. var_manager = nullptr;
  909. }
  910. }
  911. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示