You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_var_manager.cc 43 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/manager/graph_var_manager.h"
  17. #include "graph/debug/ge_attr_define.h"
  18. #include "graph/manager/graph_mem_manager.h"
  19. #include "graph/manager/trans_var_data_utils.h"
  20. #include "graph/utils/type_utils.h"
  21. using std::map;
  22. using std::string;
  23. using std::vector;
  24. namespace ge {
  25. VarResource::VarResource(uint64_t session_id) : session_id_(session_id) {}
  26. VarResource::~VarResource() {
  27. var_offset_map_.clear();
  28. var_addr_mgr_map_.clear();
  29. cur_var_tensor_desc_map_.clear();
  30. var_broad_cast_info_.clear();
  31. }
  32. ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  33. rtMemType_t &memory_type) {
  34. if (dev_ptr == nullptr) {
  35. REPORT_INNER_ERROR("E19999", "Param dev_ptr is nullptr, var_name:%s, session_id:%lu, "
  36. "check invalid", var_name.c_str(), session_id_);
  37. GELOGE(FAILED, "[Check][Param] Param dev_ptr is nullptr, var_name:%s, session_id:%lu",
  38. var_name.c_str(), session_id_);
  39. return FAILED;
  40. }
  41. std::string var_key = VarKey(var_name, tensor_desc);
  42. GELOGD("VarResource::GetVarAddr , var_key = %s", var_key.c_str());
  43. auto iter = var_addr_mgr_map_.find(var_key);
  44. if (iter == var_addr_mgr_map_.end()) {
  45. REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, "
  46. "check invalid", var_key.c_str(), var_name.c_str(),
  47. session_id_);
  48. GELOGE(FAILED, "[Check][Param] var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu",
  49. var_key.c_str(), var_name.c_str(), session_id_);
  50. return FAILED;
  51. }
  52. *dev_ptr = iter->second.address;
  53. memory_type = iter->second.memory_type;
  54. return SUCCESS;
  55. }
  56. void VarResource::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  57. var_addr_mgr_map = var_addr_mgr_map_;
  58. }
  59. void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  60. rtMemType_t memory_type) {
  61. std::string var_key = VarKey(var_name, tensor_desc);
  62. GELOGI("VarResource::SetVarAddr , var_key = %s, mem_type:%u", var_key.c_str(), memory_type);
  63. if (var_addr_mgr_map_.count(var_key) == 0) {
  64. GELOGI("SetVarAddr node_name %s, tensor_desc type %s, format %s", var_name.c_str(),
  65. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  66. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  67. VarAddrMgr var_addr_mgr;
  68. var_addr_mgr.address = dev_ptr;
  69. var_addr_mgr.tensor_desc = tensor_desc;
  70. var_addr_mgr_map_[var_key] = var_addr_mgr;
  71. }
  72. cur_var_tensor_desc_map_[var_name] = tensor_desc;
  73. }
  74. ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  75. rtMemType_t memory_type) {
  76. std::string var_key = VarKey(var_name, tensor_desc);
  77. GELOGD("VarResource::SaveVarAddr, var_key = %s", var_key.c_str());
  78. if (var_addr_mgr_map_.count(var_key) == 0) {
  79. uint64_t logic_address = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  80. if (memory_type != RT_MEMORY_RDMA_HBM) {
  81. logic_address += VarManager::Instance(session_id_)->GetVarMemLogicBase();
  82. }
  83. GELOGI("SaveVarAddr node_name %s, tensor_desc format %s, type %s.", var_name.c_str(),
  84. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  85. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  86. VarAddrMgr var_addr_mgr;
  87. var_addr_mgr.address = reinterpret_cast<uint8_t *>(static_cast<std::uintptr_t>(logic_address));
  88. var_addr_mgr.offset = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  89. var_addr_mgr.tensor_desc = tensor_desc;
  90. var_addr_mgr.memory_type = memory_type;
  91. var_addr_mgr_map_[var_key] = var_addr_mgr;
  92. var_offset_map_[logic_address] = memory_type;
  93. return SUCCESS;
  94. }
  95. REPORT_INNER_ERROR("E19999", "var_key:%s conflict in var_addr_mgr_map_, var_name:%s, session_id:%lu, "
  96. "check invalid", var_key.c_str(), var_name.c_str(),
  97. session_id_);
  98. GELOGE(FAILED, "[Check][Param] var_key:%s conflict in var_addr_mgr_map_, var_name:%s, session_id:%lu",
  99. var_key.c_str(), var_name.c_str(), session_id_);
  100. return FAILED;
  101. }
  102. bool VarResource::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  103. std::string var_key = VarKey(var_name, tensor_desc);
  104. return var_addr_mgr_map_.count(var_key) != 0;
  105. }
  106. bool VarResource::IsVarExist(const std::string &var_name) { return cur_var_tensor_desc_map_.count(var_name) != 0; }
  107. std::string VarResource::VarKey(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  108. std::string var_key(var_name);
  109. var_key.append(std::to_string(static_cast<int32_t>(tensor_desc.GetFormat())))
  110. .append("_")
  111. .append(std::to_string(static_cast<int32_t>(tensor_desc.GetDataType())));
  112. return var_key;
  113. }
  114. ge::Status VarResource::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  115. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  116. return FAILED;
  117. }
  118. tensor_desc = cur_var_tensor_desc_map_[var_name];
  119. return SUCCESS;
  120. }
  121. ge::Status VarResource::RenewCurVarDesc(const std::string &var_name, const ge::OpDescPtr &op_desc) {
  122. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  123. GELOGI("There is no this node[%s] in var tensor_desc map. so no need renew!", var_name.c_str());
  124. return SUCCESS;
  125. }
  126. if (op_desc == nullptr) {
  127. REPORT_INNER_ERROR("E19999", "Param op_desc is nullptr, var_name:%s, session_id:%lu, check invalid",
  128. var_name.c_str(), session_id_);
  129. GELOGE(FAILED, "[Check][Param] input opdesc is nullptr, var_name:%s, session_id:%lu",
  130. var_name.c_str(), session_id_);
  131. return FAILED;
  132. }
  133. ge::GeTensorDesc curr_desc;
  134. ge::Status ret = GetCurVarDesc(var_name, curr_desc);
  135. if (ret != SUCCESS) {
  136. GELOGE(FAILED, "[Get][CurVarDesc] fail, var_name:%s, session_id:%lu", var_name.c_str(), session_id_);
  137. return FAILED;
  138. }
  139. std::string key = VarKey(var_name, curr_desc);
  140. curr_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  141. curr_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  142. cur_var_tensor_desc_map_[var_name] = curr_desc;
  143. auto iter = var_addr_mgr_map_.find(key);
  144. if (iter == var_addr_mgr_map_.end()) {
  145. REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, op:%s(%s), "
  146. "check invalid", key.c_str(), var_name.c_str(),
  147. session_id_, op_desc->GetName().c_str(), op_desc->GetType().c_str());
  148. GELOGE(FAILED, "[Check][Param] var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, op:%s(%s)",
  149. key.c_str(), var_name.c_str(), session_id_, op_desc->GetName().c_str(), op_desc->GetType().c_str());
  150. return FAILED;
  151. }
  152. auto val = iter->second;
  153. val.tensor_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  154. val.tensor_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  155. var_addr_mgr_map_.erase(iter);
  156. key = VarKey(var_name, curr_desc);
  157. var_addr_mgr_map_[key] = val;
  158. return SUCCESS;
  159. }
  160. void VarResource::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  161. var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info;
  162. }
  163. ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  164. if (var_broad_cast_info_.count(graph_id) == 0 || var_broad_cast_info_[graph_id].count(var_name) == 0) {
  165. return FAILED;
  166. }
  167. broad_cast_info = var_broad_cast_info_[graph_id][var_name];
  168. return SUCCESS;
  169. }
  170. ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name,
  171. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  172. GE_CHECK_NOTNULL(base_ptr);
  173. GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str());
  174. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  175. uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset;
  176. return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr,
  177. var_broadcast_info.input_size, session_id_);
  178. }
  179. ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  180. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  181. GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str());
  182. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  183. // subgraph base_ptr could be nullptr, task it as base 0
  184. uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset;
  185. return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name,
  186. var_tensor_desc, session_id_);
  187. }
  188. ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name,
  189. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  190. return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr);
  191. }
  192. bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; }
  193. rtMemType_t VarResource::GetVarMemType(const int64_t &offset) {
  194. if (var_offset_map_.count(offset) > 0) {
  195. return var_offset_map_[offset];
  196. }
  197. return RT_MEMORY_RESERVED;
  198. }
  199. VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) {
  200. auto iter = var_to_trans_road_.find(var_name);
  201. if (iter == var_to_trans_road_.end()) {
  202. return nullptr;
  203. } else {
  204. return &(iter->second);
  205. }
  206. }
  207. Status VarResource::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  208. auto iter = var_names_to_changed_graph_id_.find(var_name);
  209. if (iter == var_names_to_changed_graph_id_.end()) {
  210. return FAILED;
  211. } else {
  212. graph_id = iter->second;
  213. return SUCCESS;
  214. }
  215. }
  216. Status VarResource::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  217. auto iter = var_names_to_allocated_graph_id_.find(var_name);
  218. if (iter == var_names_to_allocated_graph_id_.end()) {
  219. return FAILED;
  220. } else {
  221. graph_id = iter->second;
  222. return SUCCESS;
  223. }
  224. }
  225. Status VarResource::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  226. if (GetAllocatedGraphId(var_name, graph_id) == SUCCESS) {
  227. GELOGW("VarManager var[%s] has been allocated in graph[%d]", var_name.c_str(), graph_id);
  228. return SUCCESS;
  229. }
  230. var_names_to_allocated_graph_id_[var_name] = graph_id;
  231. return SUCCESS;
  232. }
  233. MemResource::MemResource() : total_size_(0), var_mem_size_(0) {}
  234. MemResource *MemResource::BuildMemResourceFromType(rtMemType_t mem_type) {
  235. switch (mem_type) {
  236. case RT_MEMORY_HBM:
  237. return new (std::nothrow) HbmMemResource();
  238. case RT_MEMORY_RDMA_HBM:
  239. return new (std::nothrow) RdmaMemResource();
  240. default:
  241. return nullptr;
  242. }
  243. }
  244. Status HbmMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id,
  245. size_t &mem_offset) {
  246. size = (size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  247. uint64_t real_size = size;
  248. total_size_ = VarManager::Instance(session_id)->GetVarMemMaxSize();
  249. if (total_size_ < var_mem_size_) {
  250. REPORT_INNER_ERROR("E19999", "VarMemMaxSize:%lu < var_mem_size_:%lu, var_size:%lu, var_name:%s, check invalid"
  251. "", total_size_, var_mem_size_, size, var_name.c_str());
  252. GELOGE(PARAM_INVALID, "[Check][Param] total_size_:%lu is smaller than var_mem_size_:%lu, var_name:%s",
  253. total_size_, var_mem_size_, var_name.c_str());
  254. return PARAM_INVALID;
  255. }
  256. uint64_t free_size = total_size_ - var_mem_size_;
  257. if (free_size < (size + kSessionMemAlignSize * kSessionMemAlignUnit)) {
  258. REPORT_INNER_ERROR("E19999", "free_size:%lu not enough, var_align_size:%lu, var_name:%s, check invalid",
  259. free_size, size, var_name.c_str());
  260. GELOGE(PARAM_INVALID, "[Check][Param] Out of memory: current var size[%lu] exceeds total var size[%lu]",
  261. size + kSessionMemAlignSize * kSessionMemAlignUnit + var_mem_size_, total_size_);
  262. return PARAM_INVALID;
  263. }
  264. mem_offset = var_mem_size_;
  265. // offset for next, align 512 BYTE
  266. size = size + kSessionMemAlignSize;
  267. var_mem_size_ = var_mem_size_ + size;
  268. // align 512 BYTE
  269. var_mem_size_ = var_mem_size_ + kSessionMemAlignSize;
  270. GELOGI(
  271. "[IMAS]AssignVarMem Set session_%lu name[%s] output[%d]"
  272. "offset to [%zu] size[%lu] realsize[%lu].",
  273. session_id, var_name.c_str(), 0, mem_offset, (var_mem_size_ - mem_offset), real_size);
  274. return SUCCESS;
  275. }
  276. Status RdmaMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &address) {
  277. uint8_t *buffer = MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Malloc(size);
  278. if (buffer == nullptr) {
  279. REPORT_CALL_ERROR("E19999", "malloc rdma memory fail, var_size:%lu, var_name:%s",
  280. size, var_name.c_str());
  281. GELOGE(MEMALLOC_FAILED, "[Malloc][RdmaMemory] for node %s failed, size = %lu", var_name.c_str(), size);
  282. return MEMALLOC_FAILED;
  283. }
  284. address = static_cast<size_t>(reinterpret_cast<uintptr_t>(buffer));
  285. var_mem_size_ += size;
  286. GELOGI("[IMAS]AssignVarMem Set session_%lu name[%s] output[%d] addr to [%p] size[%lu].",
  287. session_id, var_name.c_str(), 0, buffer, size);
  288. return SUCCESS;
  289. }
  290. uint64_t MemResource::GetVarMemSize() const { return var_mem_size_; }
  291. void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; };
  292. VarManager::VarManager(uint64_t session_id)
  293. : version_(SessionVersion::OTHER_VERSION),
  294. session_id_(session_id),
  295. device_id_(0),
  296. job_id_(0),
  297. graph_mem_max_size_(kGraphMemoryManagerMallocMaxSize),
  298. var_mem_max_size_(kMemoryVarManagerMallocSize),
  299. var_mem_logic_base_(kMemoryVarLogicBase),
  300. use_max_mem_size_(kUseMaxMemorySize) {}
  301. VarManager *VarManager::Instance(uint64_t session_id) {
  302. GELOGD("VarManager::Instance, session id = %lu", session_id);
  303. return VarManagerPool::Instance().GetVarManager(session_id);
  304. }
  305. void VarManager::Destory() {
  306. std::lock_guard<std::recursive_mutex> lock(mutex_);
  307. GELOGI("VarManager::Destory, session id = %lu.", session_id_);
  308. version_ = SessionVersion::OTHER_VERSION;
  309. device_id_ = 0;
  310. session_id_ = 0;
  311. for (auto &memory_resource : mem_resource_map_) {
  312. if (memory_resource.second != nullptr) {
  313. delete memory_resource.second;
  314. memory_resource.second = nullptr;
  315. }
  316. }
  317. mem_resource_map_.clear();
  318. }
  319. ge::Status VarManager::Init(const uint32_t &version, const uint64_t &session_id, const uint32_t &device_id,
  320. const uint64_t &job_id) {
  321. std::lock_guard<std::recursive_mutex> lock(mutex_);
  322. GELOGI("VarManager::Init, session id = %lu.", session_id);
  323. if (var_resource_ == nullptr) {
  324. version_ = version;
  325. device_id_ = device_id;
  326. session_id_ = session_id;
  327. job_id_ = job_id;
  328. var_resource_ = std::unique_ptr<VarResource>(new (std::nothrow) VarResource(session_id_));
  329. if (var_resource_ == nullptr) {
  330. GELOGW("VarManager init failed session id = %lu.", session_id);
  331. return ge::INTERNAL_ERROR;
  332. }
  333. } else {
  334. GELOGW("VarManager::has been inited, session id = %lu.", session_id);
  335. }
  336. return SUCCESS;
  337. }
  338. const uint64_t &VarManager::SessionId() const {
  339. std::lock_guard<std::recursive_mutex> lock(mutex_);
  340. return session_id_;
  341. }
  342. const uint32_t &VarManager::DeviceId() const {
  343. std::lock_guard<std::recursive_mutex> lock(mutex_);
  344. return device_id_;
  345. }
  346. const uint64_t &VarManager::JobId() const {
  347. std::lock_guard<std::recursive_mutex> lock(mutex_);
  348. return job_id_;
  349. }
  350. ge::Status VarManager::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  351. rtMemType_t memory_type) {
  352. GELOGI("VarManager::SetVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  353. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  354. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  355. std::lock_guard<std::recursive_mutex> lock(mutex_);
  356. if (var_resource_ == nullptr) {
  357. GELOGW("VarManager has not been init.");
  358. return ge::INTERNAL_ERROR;
  359. }
  360. var_resource_->SetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  361. return ge::SUCCESS;
  362. }
  363. ge::Status VarManager::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  364. rtMemType_t memory_type) {
  365. GELOGI("VarManager::SaveVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  366. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  367. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  368. std::lock_guard<std::recursive_mutex> lock(mutex_);
  369. if (var_resource_ == nullptr) {
  370. GELOGW("VarManager has not been init.");
  371. return ge::INTERNAL_ERROR;
  372. }
  373. var_resource_->SaveVarAddr(var_name, tensor_desc, address, memory_type);
  374. return ge::SUCCESS;
  375. }
  376. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  377. rtMemType_t &memory_type) {
  378. std::lock_guard<std::recursive_mutex> lock(mutex_);
  379. GELOGD("VarManager::GetVarAddr var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  380. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  381. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  382. if (var_resource_ == nullptr) {
  383. GELOGW("VarManager has not been init.");
  384. return ge::INTERNAL_ERROR;
  385. }
  386. auto ret = var_resource_->GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  387. if (ret != SUCCESS) {
  388. GELOGW("GetVarAddr fail.");
  389. return ge::INTERNAL_ERROR;
  390. }
  391. return SUCCESS;
  392. }
  393. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr) {
  394. std::lock_guard<std::recursive_mutex> lock(mutex_);
  395. rtMemType_t memory_type = RT_MEMORY_HBM;
  396. return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  397. }
  398. void VarManager::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  399. var_resource_->GetAllVarAddrMgr(var_addr_mgr_map);
  400. }
  401. int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) {
  402. std::lock_guard<std::recursive_mutex> lock(mutex_);
  403. MemResource *mem_resource = nullptr;
  404. auto iter = mem_resource_map_.find(memory_type);
  405. if (iter == mem_resource_map_.end()) {
  406. return 0;
  407. } else {
  408. mem_resource = iter->second;
  409. }
  410. if (mem_resource == nullptr) {
  411. REPORT_INNER_ERROR("E19999", "Find no mem_resource in map, memory_type:%d, session_id:%lu",
  412. memory_type, session_id_);
  413. GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%d, session_id:%lu",
  414. memory_type, session_id_);
  415. return 0;
  416. }
  417. return mem_resource->GetVarMemSize();
  418. }
  419. Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) {
  420. std::lock_guard<std::recursive_mutex> lock(mutex_);
  421. MemResource *mem_resource = nullptr;
  422. auto iter = mem_resource_map_.find(memory_type);
  423. if (iter == mem_resource_map_.end()) {
  424. mem_resource = MemResource::BuildMemResourceFromType(memory_type);
  425. if (mem_resource == nullptr) {
  426. REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu",
  427. memory_type, session_id_);
  428. GELOGE(ge::INTERNAL_ERROR, "[Alloc][MemResource] failed, memory_type:%u, session_id:%lu",
  429. memory_type, session_id_);
  430. return ge::INTERNAL_ERROR;
  431. } else {
  432. mem_resource_map_[memory_type] = mem_resource;
  433. }
  434. } else {
  435. mem_resource = iter->second;
  436. }
  437. if (mem_resource == nullptr) {
  438. REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu",
  439. memory_type, session_id_);
  440. GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%u, session_id:%lu",
  441. memory_type, session_id_);
  442. return FAILED;
  443. }
  444. mem_resource->UpdateVarMemSize(mem_size);
  445. return SUCCESS;
  446. }
  447. ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc,
  448. rtMemType_t memory_type) {
  449. std::lock_guard<std::recursive_mutex> lock(mutex_);
  450. GELOGI("VarManager::AssignVarMem var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  451. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  452. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  453. int64_t tensor_desc_size = 0;
  454. size_t mem_offset = 0;
  455. ge::Status result = TensorUtils::GetSize(tensor_desc, tensor_desc_size);
  456. if (result != ge::SUCCESS) {
  457. REPORT_CALL_ERROR("E19999", "Get size from tensor fail, var_name:%s, memory_type:%d, session_id:%lu",
  458. var_name.c_str(), memory_type, session_id_);
  459. GELOGE(result, "[Get][Size] from tensor fail, var_name:%s, memory_type:%u, session_id:%lu",
  460. var_name.c_str(), memory_type, session_id_);
  461. return result;
  462. }
  463. MemResource *mem_resource = nullptr;
  464. auto it = mem_resource_map_.find(memory_type);
  465. if (it == mem_resource_map_.end()) {
  466. mem_resource = MemResource::BuildMemResourceFromType(memory_type);
  467. if (mem_resource == nullptr) {
  468. REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu",
  469. memory_type, session_id_);
  470. GELOGE(ge::INTERNAL_ERROR, "[Alloc][MemResource] failed, memory_type:%u, session_id:%lu.",
  471. memory_type, session_id_);
  472. return ge::INTERNAL_ERROR;
  473. } else {
  474. mem_resource_map_[memory_type] = mem_resource;
  475. }
  476. } else {
  477. mem_resource = it->second;
  478. }
  479. if (mem_resource == nullptr) {
  480. REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu",
  481. memory_type, session_id_);
  482. GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%u, session_id:%lu.",
  483. memory_type, session_id_);
  484. return ge::INTERNAL_ERROR;
  485. }
  486. if (var_resource_ == nullptr) {
  487. REPORT_INNER_ERROR("E19999", "VarManager has not been init, memory_type:%d, session_id:%lu, "
  488. "check invalid", memory_type, session_id_);
  489. GELOGW("VarManager has not been init.");
  490. return ge::INTERNAL_ERROR;
  491. }
  492. ge::GeTensorDesc cur_tensor_desc;
  493. int64_t cur_tensor_desc_size = 0;
  494. result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc);
  495. // reuse old format variable memory
  496. if (result == SUCCESS) {
  497. result = var_resource_->GetVarAddr(
  498. var_name, cur_tensor_desc, reinterpret_cast<uint8_t **>(reinterpret_cast<uintptr_t>(&mem_offset)), memory_type);
  499. if (result == SUCCESS) {
  500. result = TensorUtils::GetSize(cur_tensor_desc, cur_tensor_desc_size);
  501. GELOGD("tensor_desc_size is %ld, cur_tensor_desc_size is %ld, memoffset is %zu", tensor_desc_size,
  502. cur_tensor_desc_size, mem_offset);
  503. }
  504. }
  505. bool can_not_reuse_old_memory = (result != SUCCESS) || (tensor_desc_size > cur_tensor_desc_size);
  506. if (can_not_reuse_old_memory) {
  507. result = mem_resource->AssignVarMem(var_name, tensor_desc_size, session_id_, mem_offset);
  508. if (result != SUCCESS) {
  509. GELOGE(ge::INTERNAL_ERROR, "[Assign][VarMem] by offset failed, session_id:%lu.", session_id_);
  510. return ge::INTERNAL_ERROR;
  511. }
  512. result = var_resource_->SaveVarAddr(
  513. var_name, tensor_desc, reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  514. if (result != SUCCESS) {
  515. GELOGE(ge::INTERNAL_ERROR, "[Save][VarAddr] by offset failed, memory type:%u, session_id:%lu.",
  516. memory_type, session_id_);
  517. return ge::INTERNAL_ERROR;
  518. }
  519. }
  520. // old not exist only save new tensor
  521. result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc);
  522. if (result != SUCCESS) {
  523. var_resource_->SetVarAddr(var_name, tensor_desc,
  524. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  525. return SUCCESS;
  526. }
  527. bool format_changed = cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() ||
  528. cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() ||
  529. cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims();
  530. if (format_changed) {
  531. GELOGI("var %s assigned new memory (format, data type, shape) (%s, %s, %zu) from (%s, %s, %zu)", var_name.c_str(),
  532. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  533. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  534. tensor_desc.GetShape().GetDims().size(),
  535. ge::TypeUtils::DataTypeToSerialString(cur_tensor_desc.GetDataType()).c_str(),
  536. ge::TypeUtils::FormatToSerialString(cur_tensor_desc.GetFormat()).c_str(),
  537. cur_tensor_desc.GetShape().GetDims().size());
  538. var_resource_->SetVarAddr(var_name, tensor_desc,
  539. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  540. }
  541. return SUCCESS;
  542. }
  543. bool VarManager::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  544. std::lock_guard<std::recursive_mutex> lock(mutex_);
  545. GELOGD("VarManager::IsVarExist var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  546. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  547. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  548. if (var_resource_ == nullptr) {
  549. GELOGW("VarManager has not been init.");
  550. return false;
  551. }
  552. return var_resource_->IsVarExist(var_name, tensor_desc);
  553. }
  554. bool VarManager::IsVarExist(const std::string &var_name) {
  555. std::lock_guard<std::recursive_mutex> lock(mutex_);
  556. if (var_resource_ == nullptr) {
  557. GELOGW("VarManager has not been init.");
  558. return false;
  559. }
  560. return var_resource_->IsVarExist(var_name);
  561. }
  562. ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
  563. uint8_t *base_ptr) {
  564. std::lock_guard<std::recursive_mutex> lock(mutex_);
  565. if (var_resource_ == nullptr) {
  566. GELOGW("VarManager has not been init.");
  567. return ge::INTERNAL_ERROR;
  568. }
  569. return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr);
  570. }
  571. ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  572. std::lock_guard<std::recursive_mutex> lock(mutex_);
  573. GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str());
  574. if (var_resource_ == nullptr) {
  575. GELOGW("VarManager has not been init.");
  576. return ge::INTERNAL_ERROR;
  577. }
  578. return var_resource_->GetCurVarDesc(var_name, tensor_desc);
  579. }
  580. ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  581. std::lock_guard<std::recursive_mutex> lock(mutex_);
  582. GELOGI(
  583. "VarManager::SaveBroadCastInfo var_name = %s, broadcast name = %s, "
  584. "idx = %d, input_offset = %ld, input_size = %lu, output_offset = %ld, "
  585. "output_size = %lu",
  586. broad_cast_info.var_name.c_str(), broad_cast_info.broadcast_name.c_str(), broad_cast_info.idx,
  587. broad_cast_info.input_offset, broad_cast_info.input_size, broad_cast_info.output_offset,
  588. broad_cast_info.output_size);
  589. if (var_resource_ == nullptr) {
  590. GELOGW("VarManager has not been init.");
  591. return ge::INTERNAL_ERROR;
  592. }
  593. var_resource_->SaveBroadCastInfo(graph_id, broad_cast_info);
  594. return SUCCESS;
  595. }
  596. ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  597. std::lock_guard<std::recursive_mutex> lock(mutex_);
  598. if (var_resource_ == nullptr) {
  599. GELOGW("VarManager has not been init.");
  600. return ge::INTERNAL_ERROR;
  601. }
  602. return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info);
  603. }
  604. ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) {
  605. std::lock_guard<std::recursive_mutex> lock(mutex_);
  606. GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str());
  607. if (var_resource_ == nullptr) {
  608. REPORT_INNER_ERROR("E19999", "VarManager has not been init, op:%s(%s), session_id:%lu, check invalid",
  609. op_desc->GetName().c_str(), op_desc->GetType().c_str(),
  610. session_id_);
  611. GELOGE(ge::INTERNAL_ERROR, "[Check][Param] VarManager has not been init, op:%s(%s), session_id:%lu",
  612. op_desc->GetName().c_str(), op_desc->GetType().c_str(), session_id_);
  613. return ge::INTERNAL_ERROR;
  614. }
  615. return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc));
  616. }
  617. ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  618. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  619. std::lock_guard<std::recursive_mutex> lock(mutex_);
  620. if (var_resource_ == nullptr) {
  621. GELOGW("VarManager has not been init.");
  622. return ge::INTERNAL_ERROR;
  623. }
  624. return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr);
  625. }
  626. bool VarManager::IsVarAddr(const int64_t &offset) {
  627. std::lock_guard<std::recursive_mutex> lock(mutex_);
  628. if (var_resource_ == nullptr) {
  629. GELOGD("VarManager has not been init.");
  630. return false;
  631. }
  632. return var_resource_->IsVarAddr(offset);
  633. }
  634. rtMemType_t VarManager::GetVarMemType(const int64_t &offset) {
  635. std::lock_guard<std::recursive_mutex> lock(mutex_);
  636. if (var_resource_ == nullptr) {
  637. GELOGW("VarManager has not been init.");
  638. return RT_MEMORY_RESERVED;
  639. }
  640. return var_resource_->GetVarMemType(offset);
  641. }
  642. ge::Status VarManager::MallocVarMemory(size_t memory_size) {
  643. std::lock_guard<std::recursive_mutex> lock(mutex_);
  644. uint8_t *var_mem_base = nullptr;
  645. string memory_key = std::to_string(session_id_);
  646. // malloc variable memory
  647. size_t var_memory_size = memory_size;
  648. // align 512 BYTE
  649. var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  650. const string purpose("variables and constant op memory in training network.");
  651. var_mem_base = MemManager::Instance().MemInstance(RT_MEMORY_HBM).MallocMemory(purpose, memory_key, var_memory_size);
  652. if (var_mem_base == nullptr) {
  653. GELOGE(ge::INTERNAL_ERROR, "[Malloc][VarMemory] failed, size:%zu, session_id:%s",
  654. var_memory_size, memory_key.c_str());
  655. return ge::INTERNAL_ERROR;
  656. }
  657. return SUCCESS;
  658. }
  659. uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) {
  660. std::lock_guard<std::recursive_mutex> lock(mutex_);
  661. if (memory_type == RT_MEMORY_RDMA_HBM) {
  662. return MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).GetRdmaBaseAddr();
  663. }
  664. string memory_key = std::to_string(session_id_);
  665. return MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(memory_key);
  666. }
  667. uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) {
  668. std::lock_guard<std::recursive_mutex> lock(mutex_);
  669. if (memory_type == RT_MEMORY_RDMA_HBM) {
  670. return logic_addr;
  671. }
  672. string mem_key = std::to_string(session_id_);
  673. uint8_t *mem_base = MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(mem_key);
  674. if (mem_base == nullptr) {
  675. return nullptr;
  676. }
  677. uint8_t *mem_addr =
  678. logic_addr + reinterpret_cast<intptr_t>(mem_base) - VarManager::Instance(session_id_)->GetVarMemLogicBase();
  679. return mem_addr;
  680. }
  681. ge::Status VarManager::FreeVarMemory() {
  682. std::lock_guard<std::recursive_mutex> lock(mutex_);
  683. string memory_key = std::to_string(SessionId());
  684. return MemManager::Instance().MemInstance(RT_MEMORY_HBM).FreeMemory(memory_key);
  685. }
  686. ge::Status VarManager::SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) {
  687. std::lock_guard<std::recursive_mutex> lock(mutex_);
  688. if (var_resource_ == nullptr) {
  689. GELOGW("VarManager has not been init.");
  690. return ge::INTERNAL_ERROR;
  691. }
  692. return var_resource_->SetTransRoad(var_name, trans_road);
  693. }
  694. VarTransRoad *VarManager::GetTransRoad(const std::string &var_name) {
  695. std::lock_guard<std::recursive_mutex> lock(mutex_);
  696. if (var_resource_ == nullptr) {
  697. GELOGW("VarManager has not been init.");
  698. return nullptr;
  699. }
  700. return var_resource_->GetTransRoad(var_name);
  701. }
  702. Status VarManager::SetChangedGraphId(const std::string &var_name, uint32_t graph_id) {
  703. std::lock_guard<std::recursive_mutex> lock(mutex_);
  704. if (var_resource_ == nullptr) {
  705. GELOGW("VarManager has not been init.");
  706. return INTERNAL_ERROR;
  707. }
  708. return var_resource_->SetChangedGraphId(var_name, graph_id);
  709. }
  710. Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  711. std::lock_guard<std::recursive_mutex> lock(mutex_);
  712. if (var_resource_ == nullptr) {
  713. GELOGW("VarManager has not been init.");
  714. return INTERNAL_ERROR;
  715. }
  716. return var_resource_->GetChangedGraphId(var_name, graph_id);
  717. }
  718. Status VarManager::SetMemoryMallocSize(const map<string, string> &options) {
  719. auto it = options.find(GRAPH_MEMORY_MAX_SIZE);
  720. if (it == options.end()) {
  721. graph_mem_max_size_ = kGraphMemoryManagerMallocMaxSize;
  722. } else {
  723. string graph_memory_manager_malloc_max_size = it->second;
  724. ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_);
  725. if (ret != SUCCESS) {
  726. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_);
  727. return ge::GE_GRAPH_OPTIONS_INVALID;
  728. }
  729. GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_);
  730. }
  731. it = options.find(VARIABLE_MEMORY_MAX_SIZE);
  732. if (it == options.end()) {
  733. var_mem_max_size_ = kMemoryVarManagerMallocSize;
  734. } else {
  735. string memory_var_manager_malloc_size = it->second;
  736. ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_);
  737. if (ret != SUCCESS) {
  738. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_);
  739. return ge::GE_GRAPH_OPTIONS_INVALID;
  740. }
  741. }
  742. var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer;
  743. if (var_mem_logic_base_ > kMaxMemorySize) {
  744. REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid",
  745. var_mem_logic_base_, kMaxMemorySize, session_id_);
  746. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Check][Param] kMemoryVarLogicBase:%zu can not exceed "
  747. "max memory size:%zu, session_id:%lu.", var_mem_logic_base_, kMaxMemorySize, session_id_);
  748. return ge::GE_GRAPH_OPTIONS_INVALID;
  749. }
  750. use_max_mem_size_ = graph_mem_max_size_ + var_mem_max_size_;
  751. if (use_max_mem_size_ > kMaxMemorySize) {
  752. REPORT_INNER_ERROR("E19999", "all mem_use size:%zu can not exeed limit:%zu, session_id:%lu, check invalid",
  753. use_max_mem_size_, kMaxMemorySize, session_id_);
  754. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Check][Param] kUseMaxMemorySize:%zu can not exceed "
  755. "max memory size:%zu, session_id:%lu.", use_max_mem_size_, kMaxMemorySize, session_id_);
  756. return ge::GE_GRAPH_OPTIONS_INVALID;
  757. }
  758. GELOGI("Set memory malloc size successfully");
  759. return SUCCESS;
  760. }
  761. Status VarManager::ParseMemoryMallocSize(string &memory_size, size_t &result) {
  762. if (memory_size.empty()) {
  763. REPORT_INNER_ERROR("E19999", "Param memory_size is empty, session_id:%lu, check invalid",
  764. session_id_);
  765. GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Check][Param] Memory malloc size input is empty, session_id:%lu.", session_id_);
  766. return GE_GRAPH_OPTIONS_INVALID;
  767. }
  768. // split string by '*'
  769. vector<string> splits;
  770. std::istringstream str(memory_size);
  771. string str_split;
  772. while (getline(str, str_split, '*')) {
  773. splits.emplace_back(str_split);
  774. }
  775. result = 1;
  776. for (string split : splits) {
  777. // Trim
  778. auto it = split.find_first_not_of(" ");
  779. if (it != string::npos) {
  780. split.erase(0, it);
  781. }
  782. it = split.find_last_not_of(" ");
  783. if (it != string::npos) {
  784. split.erase(it + 1);
  785. }
  786. for (char c : split) {
  787. if (!isdigit(c)) {
  788. REPORT_INNER_ERROR("E19999", "Param memory_size:%s contains non digit, session_id:%lu, check invalid",
  789. memory_size.c_str(), session_id_);
  790. GELOGE(GE_GRAPH_OPTIONS_INVALID,
  791. "[Check][Param] Memory malloc size:%s input contains non digit, session_id:%lu.",
  792. memory_size.c_str(), session_id_);
  793. return GE_GRAPH_OPTIONS_INVALID;
  794. }
  795. }
  796. uint64_t num = std::strtoul(split.c_str(), nullptr, 0);
  797. GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(result, static_cast<uint32_t>(num)),
  798. REPORT_INNER_ERROR("E19999", "Param memory_size:%s will overflow after multi all, session_id:%lu, "
  799. "check invalid", memory_size.c_str(),
  800. session_id_);
  801. GELOGE(FAILED, "[Check][Param] Param memory_size:%s will overflow after multi all, session_id:%lu",
  802. memory_size.c_str(), session_id_);
  803. return FAILED);
  804. if ((num > kMaxMemorySize) || (result * static_cast<size_t>(num) > kMaxMemorySize)) {
  805. REPORT_INNER_ERROR("E19999", "Param memory_size:%s after multi will exceed limit:%lu, session_id:%lu, "
  806. "check invalid", memory_size.c_str(), kMaxMemorySize,
  807. session_id_);
  808. GELOGE(FAILED, "[Check][Param] Input memory size can not exceed max memory size:%zu, session_id:%lu.",
  809. kMaxMemorySize, session_id_);
  810. return FAILED;
  811. }
  812. result *= static_cast<size_t>(num);
  813. }
  814. return SUCCESS;
  815. }
  816. void VarManager::RemoveChangedGraphId(const std::string &var_name) {
  817. std::lock_guard<std::recursive_mutex> lock(mutex_);
  818. if (var_resource_ == nullptr) {
  819. GELOGW("VarManager has not been init.");
  820. return;
  821. }
  822. var_resource_->RemoveChangedGraphId(var_name);
  823. }
  824. Status VarManager::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  825. std::lock_guard<std::recursive_mutex> lock(mutex_);
  826. if (var_resource_ == nullptr) {
  827. GELOGW("VarManager has not been init.");
  828. return INTERNAL_ERROR;
  829. }
  830. return var_resource_->SetAllocatedGraphId(var_name, graph_id);
  831. }
  832. Status VarManager::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  833. std::lock_guard<std::recursive_mutex> lock(mutex_);
  834. if (var_resource_ == nullptr) {
  835. GELOGW("VarManager has not been init.");
  836. return INTERNAL_ERROR;
  837. }
  838. return var_resource_->GetAllocatedGraphId(var_name, graph_id);
  839. }
  840. void VarManager::RemoveAllocatedGraphId(const std::string &var_name) {
  841. std::lock_guard<std::recursive_mutex> lock(mutex_);
  842. if (var_resource_ == nullptr) {
  843. GELOGW("VarManager has not been init.");
  844. return;
  845. }
  846. var_resource_->RemoveAllocatedGraphId(var_name);
  847. }
  848. Status VarManager::GetAllVariables(std::map<std::string, GeTensorDesc> &all_variables) {
  849. std::lock_guard<std::recursive_mutex> lock(mutex_);
  850. if (var_resource_ == nullptr) {
  851. GELOGW("VarManager has not been inited.");
  852. return INTERNAL_ERROR;
  853. }
  854. auto new_variable_desc = var_resource_->GetAllVarDesc();
  855. if (new_variable_desc.size() == 0) {
  856. GELOGW("VarManager don't have variables.");
  857. return INTERNAL_ERROR;
  858. }
  859. for (auto iter = new_variable_desc.begin(); iter != new_variable_desc.end(); ++iter) {
  860. auto trans_road = var_resource_->GetTransRoad(iter->first);
  861. if (trans_road == nullptr || trans_road->empty()) {
  862. GELOGI("The variable %s does not have any trans road", iter->first.c_str());
  863. all_variables[iter->first] = iter->second;
  864. continue;
  865. }
  866. // get origin trans info : the first trans node info
  867. auto origin_trans_node_info = trans_road->at(0);
  868. all_variables[iter->first] = origin_trans_node_info.input;
  869. }
  870. return SUCCESS;
  871. }
  872. VarManagerPool::~VarManagerPool() { Destory(); }
  873. VarManagerPool &VarManagerPool::Instance() {
  874. static VarManagerPool var_manager_pool;
  875. return var_manager_pool;
  876. }
  877. void VarManagerPool::Destory() noexcept {
  878. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  879. for (auto &it : var_manager_map_) {
  880. VarManager *var_manager = it.second;
  881. if (var_manager != nullptr) {
  882. var_manager->Destory();
  883. delete var_manager;
  884. var_manager = nullptr;
  885. }
  886. }
  887. var_manager_map_.clear();
  888. }
  889. ge::Status VarManagerPool::Init() const { return SUCCESS; }
  890. VarManager *VarManagerPool::GetVarManager(uint64_t session_id) {
  891. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  892. auto it = var_manager_map_.find(session_id);
  893. if (it != var_manager_map_.end()) {
  894. GELOGD("VarManagerPool::GetVarManager");
  895. return it->second;
  896. }
  897. VarManager *var_manager = new (std::nothrow) VarManager(session_id);
  898. if (var_manager == nullptr) {
  899. REPORT_INNER_ERROR("E19999", "New VarManager fail, session_id:%lu", session_id);
  900. GELOGE(INTERNAL_ERROR, "[New][VarManager] fail, session_id:%lu", session_id);
  901. static VarManager new_var_manager(0);
  902. return &new_var_manager;
  903. }
  904. var_manager_map_[session_id] = var_manager;
  905. return var_manager;
  906. }
  907. void VarManagerPool::RemoveVarManager(uint64_t session_id) {
  908. VarManager *var_manager = nullptr;
  909. {
  910. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  911. auto it = var_manager_map_.find(session_id);
  912. if (it != var_manager_map_.end()) {
  913. var_manager = it->second;
  914. var_manager_map_.erase(it);
  915. }
  916. }
  917. if (var_manager != nullptr) {
  918. var_manager->Destory();
  919. delete var_manager;
  920. var_manager = nullptr;
  921. }
  922. }
  923. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示