You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_var_manager.cc 35 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/manager/graph_var_manager.h"
  17. #include <utility>
  18. #include "common/l2_cache_optimize.h"
  19. #include "common/types.h"
  20. #include "framework/common/debug/ge_log.h"
  21. #include "framework/common/debug/log.h"
  22. #include "ge/ge_api_types.h"
  23. #include "graph/debug/ge_attr_define.h"
  24. #include "graph/manager/graph_mem_allocator.h"
  25. #include "graph/manager/trans_var_data_utils.h"
  26. #include "graph/utils/attr_utils.h"
  27. #include "graph/utils/type_utils.h"
  28. using std::map;
  29. using std::string;
  30. using std::vector;
  31. namespace ge {
  32. VarResource::VarResource(uint64_t session_id) : session_id_(session_id) {}
  33. VarResource::~VarResource() {
  34. var_offset_set_.clear();
  35. var_addr_mgr_map_.clear();
  36. cur_var_tensor_desc_map_.clear();
  37. var_broad_cast_info_.clear();
  38. }
  39. ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  40. rtMemType_t &memory_type) {
  41. if (dev_ptr == nullptr) {
  42. GELOGE(FAILED, "[GetVarAddr] dev_ptr is null!");
  43. return FAILED;
  44. }
  45. std::string var_key = VarKey(var_name, tensor_desc);
  46. GELOGD("VarResource::GetVarAddr , var_key = %s", var_key.c_str());
  47. auto iter = var_addr_mgr_map_.find(var_key);
  48. if (iter == var_addr_mgr_map_.end()) {
  49. GELOGE(FAILED, "VarResource::GetVarAddr failed, var_key %s", var_key.c_str());
  50. return FAILED;
  51. }
  52. *dev_ptr = iter->second.address;
  53. memory_type = iter->second.memory_type;
  54. return SUCCESS;
  55. }
  56. void VarResource::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  57. var_addr_mgr_map = var_addr_mgr_map_;
  58. }
  59. void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  60. rtMemType_t memory_type) {
  61. std::string var_key = VarKey(var_name, tensor_desc);
  62. GELOGI("VarResource::SetVarAddr , var_key = %s, mem_type:%u", var_key.c_str(), memory_type);
  63. if (var_addr_mgr_map_.count(var_key) == 0) {
  64. GELOGI("SetVarAddr node_name %s, tensor_desc type %s, format %s", var_name.c_str(),
  65. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  66. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  67. VarAddrMgr var_addr_mgr;
  68. var_addr_mgr.address = dev_ptr;
  69. var_addr_mgr.tensor_desc = tensor_desc;
  70. var_addr_mgr_map_[var_key] = var_addr_mgr;
  71. }
  72. cur_var_tensor_desc_map_[var_name] = tensor_desc;
  73. }
  74. ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  75. rtMemType_t memory_type) {
  76. std::string var_key = VarKey(var_name, tensor_desc);
  77. GELOGD("VarResource::SaveVarAddr, var_key = %s", var_key.c_str());
  78. if (var_addr_mgr_map_.count(var_key) == 0) {
  79. uint64_t logic_address = VarManager::Instance(session_id_)->GetVarMemLogicBase() +
  80. static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  81. GELOGI("SaveVarAddr node_name %s, tensor_desc format %s, type %s.", var_name.c_str(),
  82. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  83. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  84. VarAddrMgr var_addr_mgr;
  85. var_addr_mgr.address = reinterpret_cast<uint8_t *>(static_cast<std::uintptr_t>(logic_address));
  86. var_addr_mgr.offset = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  87. var_addr_mgr.tensor_desc = tensor_desc;
  88. var_addr_mgr.memory_type = memory_type;
  89. var_addr_mgr_map_[var_key] = var_addr_mgr;
  90. var_offset_set_.insert(logic_address);
  91. return SUCCESS;
  92. }
  93. GELOGE(FAILED, "VarResource::SaveVarAddr, var_key %s save addr conflict", var_key.c_str());
  94. return FAILED;
  95. }
  96. bool VarResource::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  97. std::string var_key = VarKey(var_name, tensor_desc);
  98. return var_addr_mgr_map_.count(var_key) != 0;
  99. }
  100. bool VarResource::IsVarExist(const std::string &var_name) { return cur_var_tensor_desc_map_.count(var_name) != 0; }
  101. std::string VarResource::VarKey(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  102. std::string var_key(var_name);
  103. var_key.append(std::to_string(static_cast<int32_t>(tensor_desc.GetFormat())))
  104. .append("_")
  105. .append(std::to_string(static_cast<int32_t>(tensor_desc.GetDataType())));
  106. return var_key;
  107. }
  108. ge::Status VarResource::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  109. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  110. return FAILED;
  111. }
  112. tensor_desc = cur_var_tensor_desc_map_[var_name];
  113. return SUCCESS;
  114. }
  115. ge::Status VarResource::RenewCurVarDesc(const std::string &var_name, const ge::OpDescPtr &op_desc) {
  116. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  117. GELOGI("There is no this node[%s] in var tensor_desc map. so no need renew!", var_name.c_str());
  118. return SUCCESS;
  119. }
  120. if (op_desc == nullptr) {
  121. GELOGE(FAILED, "[RenewCurVarDesc] renew var desc fail! input opdesc is null!");
  122. return FAILED;
  123. }
  124. ge::GeTensorDesc curr_desc;
  125. ge::Status ret = GetCurVarDesc(var_name, curr_desc);
  126. if (ret != SUCCESS) {
  127. GELOGE(FAILED, "[RenewCurVarDesc] Get var desc fail!");
  128. return FAILED;
  129. }
  130. std::string key = VarKey(var_name, curr_desc);
  131. curr_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  132. curr_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  133. cur_var_tensor_desc_map_[var_name] = curr_desc;
  134. auto iter = var_addr_mgr_map_.find(key);
  135. if (iter == var_addr_mgr_map_.end()) {
  136. GELOGE(FAILED, "[RenewCurVarDesc] can't find ele with key [%s]", key.c_str());
  137. return FAILED;
  138. }
  139. auto val = iter->second;
  140. val.tensor_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  141. val.tensor_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  142. var_addr_mgr_map_.erase(iter);
  143. key = VarKey(var_name, curr_desc);
  144. var_addr_mgr_map_[key] = val;
  145. return SUCCESS;
  146. }
  147. void VarResource::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  148. var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info;
  149. }
  150. ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  151. if (var_broad_cast_info_.count(graph_id) == 0 || var_broad_cast_info_[graph_id].count(var_name) == 0) {
  152. return FAILED;
  153. }
  154. broad_cast_info = var_broad_cast_info_[graph_id][var_name];
  155. return SUCCESS;
  156. }
  157. ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name,
  158. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  159. GE_CHECK_NOTNULL(base_ptr);
  160. GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str());
  161. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  162. uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset;
  163. return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr,
  164. var_broadcast_info.input_size, session_id_);
  165. }
  166. ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  167. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  168. GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str());
  169. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  170. // subgraph base_ptr could be nullptr, task it as base 0
  171. uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset;
  172. return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name,
  173. var_tensor_desc, session_id_);
  174. }
  175. ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name,
  176. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  177. return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr);
  178. }
  179. bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_set_.count(offset) > 0; }
  180. VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) {
  181. auto iter = var_to_trans_road_.find(var_name);
  182. if (iter == var_to_trans_road_.end()) {
  183. return nullptr;
  184. } else {
  185. return &(iter->second);
  186. }
  187. }
  188. Status VarResource::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  189. auto iter = var_names_to_changed_graph_id_.find(var_name);
  190. if (iter == var_names_to_changed_graph_id_.end()) {
  191. return FAILED;
  192. } else {
  193. graph_id = iter->second;
  194. return SUCCESS;
  195. }
  196. }
  197. Status VarResource::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  198. auto iter = var_names_to_allocated_graph_id_.find(var_name);
  199. if (iter == var_names_to_allocated_graph_id_.end()) {
  200. return FAILED;
  201. } else {
  202. graph_id = iter->second;
  203. return SUCCESS;
  204. }
  205. }
  206. Status VarResource::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  207. if (GetAllocatedGraphId(var_name, graph_id) == SUCCESS) {
  208. GELOGW("VarManager var[%s] has been allocated in graph[%d]", var_name.c_str(), graph_id);
  209. return SUCCESS;
  210. }
  211. var_names_to_allocated_graph_id_[var_name] = graph_id;
  212. return SUCCESS;
  213. }
  214. MemResource::MemResource() : total_size_(0), var_mem_size_(0) {}
  215. Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &mem_offset) {
  216. size = (size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  217. uint64_t real_size = size;
  218. total_size_ = VarManager::Instance(session_id)->GetVarMemMaxSize();
  219. if (total_size_ < var_mem_size_) {
  220. GELOGE(PARAM_INVALID, "total_size_: %lu is smaller than var_mem_size_: %lu", total_size_, var_mem_size_);
  221. return PARAM_INVALID;
  222. }
  223. uint64_t free_size = total_size_ - var_mem_size_;
  224. if (free_size < (size + kSessionMemAlignSize * kSessionMemAlignUnit)) {
  225. GELOGE(PARAM_INVALID, "Out of memory : current var size[%lu] exceeds total var size[%lu]",
  226. size + kSessionMemAlignSize * kSessionMemAlignUnit + var_mem_size_, total_size_);
  227. return PARAM_INVALID;
  228. }
  229. mem_offset = var_mem_size_;
  230. // offset for next, align 512 BYTE
  231. size = size + kSessionMemAlignSize;
  232. var_mem_size_ = var_mem_size_ + size;
  233. // align 512 BYTE
  234. var_mem_size_ = var_mem_size_ + kSessionMemAlignSize;
  235. GELOGI(
  236. "[IMAS]AssignVarMem Set session_%lu name[%s] output[%d]"
  237. "offset to [%zu] size[%lu] realsize[%lu].",
  238. session_id, var_name.c_str(), 0, mem_offset, (var_mem_size_ - mem_offset), real_size);
  239. return SUCCESS;
  240. }
  241. uint64_t MemResource::GetVarMemSize() const { return var_mem_size_; }
  242. void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; };
  243. VarManager::VarManager(uint64_t session_id)
  244. : version_(SessionVersion::OTHER_VERSION),
  245. session_id_(session_id),
  246. device_id_(0),
  247. job_id_(0),
  248. graph_mem_max_size_(kGraphMemoryManagerMallocMaxSize),
  249. var_mem_max_size_(kMemoryVarManagerMallocSize),
  250. var_mem_logic_base_(kMemoryVarLogicBase),
  251. use_max_mem_size_(kUseMaxMemorySize) {}
  252. VarManager *VarManager::Instance(uint64_t session_id) {
  253. GELOGD("VarManager::Instance, session id = %lu", session_id);
  254. return VarManagerPool::Instance().GetVarManager(session_id);
  255. }
  256. void VarManager::Destory() {
  257. std::lock_guard<std::recursive_mutex> lock(mutex_);
  258. GELOGI("VarManager::Destory, session id = %lu.", session_id_);
  259. version_ = SessionVersion::OTHER_VERSION;
  260. device_id_ = 0;
  261. session_id_ = 0;
  262. for (auto &memory_resource : mem_resource_map_) {
  263. if (memory_resource.second != nullptr) {
  264. delete memory_resource.second;
  265. memory_resource.second = nullptr;
  266. }
  267. }
  268. mem_resource_map_.clear();
  269. }
  270. ge::Status VarManager::Init(const uint32_t &version, const uint64_t &session_id, const uint32_t &device_id,
  271. const uint64_t &job_id) {
  272. std::lock_guard<std::recursive_mutex> lock(mutex_);
  273. GELOGI("VarManager::Init, session id = %lu.", session_id);
  274. version_ = version;
  275. device_id_ = device_id;
  276. session_id_ = session_id;
  277. job_id_ = job_id;
  278. var_resource_ = std::unique_ptr<VarResource>(new (std::nothrow) VarResource(session_id_));
  279. if (var_resource_ == nullptr) {
  280. GELOGW("VarManager has not been init.");
  281. return ge::INTERNAL_ERROR;
  282. }
  283. return SUCCESS;
  284. }
  285. const uint64_t &VarManager::SessionId() const {
  286. std::lock_guard<std::recursive_mutex> lock(mutex_);
  287. return session_id_;
  288. }
  289. const uint32_t &VarManager::DeviceId() const {
  290. std::lock_guard<std::recursive_mutex> lock(mutex_);
  291. return device_id_;
  292. }
  293. const uint64_t &VarManager::JobId() const {
  294. std::lock_guard<std::recursive_mutex> lock(mutex_);
  295. return job_id_;
  296. }
  297. ge::Status VarManager::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  298. rtMemType_t memory_type) {
  299. GELOGI("VarManager::SetVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  300. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  301. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  302. std::lock_guard<std::recursive_mutex> lock(mutex_);
  303. if (var_resource_ == nullptr) {
  304. GELOGW("VarManager has not been init.");
  305. return ge::INTERNAL_ERROR;
  306. }
  307. var_resource_->SetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  308. return ge::SUCCESS;
  309. }
  310. ge::Status VarManager::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  311. rtMemType_t memory_type) {
  312. GELOGI("VarManager::SaveVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  313. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  314. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  315. std::lock_guard<std::recursive_mutex> lock(mutex_);
  316. if (var_resource_ == nullptr) {
  317. GELOGW("VarManager has not been init.");
  318. return ge::INTERNAL_ERROR;
  319. }
  320. var_resource_->SaveVarAddr(var_name, tensor_desc, address, memory_type);
  321. return ge::SUCCESS;
  322. }
  323. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  324. rtMemType_t &memory_type) {
  325. std::lock_guard<std::recursive_mutex> lock(mutex_);
  326. GELOGD("VarManager::GetVarAddr var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  327. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  328. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  329. if (var_resource_ == nullptr) {
  330. GELOGW("VarManager has not been init.");
  331. return ge::INTERNAL_ERROR;
  332. }
  333. auto ret = var_resource_->GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  334. if (ret != SUCCESS) {
  335. GELOGW("GetVarAddr fail.");
  336. return ge::INTERNAL_ERROR;
  337. }
  338. return SUCCESS;
  339. }
  340. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr) {
  341. std::lock_guard<std::recursive_mutex> lock(mutex_);
  342. rtMemType_t memory_type = RT_MEMORY_HBM;
  343. return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  344. }
  345. void VarManager::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  346. var_resource_->GetAllVarAddrMgr(var_addr_mgr_map);
  347. }
  348. int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) {
  349. std::lock_guard<std::recursive_mutex> lock(mutex_);
  350. MemResource *mem_resource = nullptr;
  351. auto iter = mem_resource_map_.find(memory_type);
  352. if (iter == mem_resource_map_.end()) {
  353. return 0;
  354. } else {
  355. mem_resource = iter->second;
  356. }
  357. if (mem_resource == nullptr) {
  358. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid.");
  359. return 0;
  360. }
  361. return mem_resource->GetVarMemSize();
  362. }
  363. Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) {
  364. std::lock_guard<std::recursive_mutex> lock(mutex_);
  365. MemResource *mem_resource = nullptr;
  366. auto iter = mem_resource_map_.find(memory_type);
  367. if (iter == mem_resource_map_.end()) {
  368. mem_resource = new (std::nothrow) MemResource();
  369. if (mem_resource == nullptr) {
  370. GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
  371. return ge::INTERNAL_ERROR;
  372. } else {
  373. mem_resource_map_[memory_type] = mem_resource;
  374. }
  375. } else {
  376. mem_resource = iter->second;
  377. }
  378. if (mem_resource == nullptr) {
  379. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid.");
  380. return FAILED;
  381. }
  382. mem_resource->UpdateVarMemSize(mem_size);
  383. return SUCCESS;
  384. }
  385. ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc,
  386. rtMemType_t memory_type) {
  387. std::lock_guard<std::recursive_mutex> lock(mutex_);
  388. GELOGI("VarManager::AssignVarMem var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  389. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  390. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  391. int64_t tensor_desc_size = 0;
  392. size_t mem_offset = 0;
  393. ge::Status result = TensorUtils::GetSize(tensor_desc, tensor_desc_size);
  394. if (result != ge::SUCCESS) {
  395. GELOGE(result, "get size from TensorDesc failed");
  396. return result;
  397. }
  398. MemResource *mem_resource = nullptr;
  399. auto it = mem_resource_map_.find(memory_type);
  400. if (it == mem_resource_map_.end()) {
  401. mem_resource = new (std::nothrow) MemResource();
  402. if (mem_resource == nullptr) {
  403. GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
  404. return ge::INTERNAL_ERROR;
  405. } else {
  406. mem_resource_map_[memory_type] = mem_resource;
  407. }
  408. } else {
  409. mem_resource = it->second;
  410. }
  411. if (mem_resource == nullptr) {
  412. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid, memory_type = %u.", memory_type);
  413. return ge::INTERNAL_ERROR;
  414. }
  415. result = mem_resource->AssignVarMem(var_name, tensor_desc_size, session_id_, mem_offset);
  416. if (result != SUCCESS) {
  417. GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed.");
  418. return ge::INTERNAL_ERROR;
  419. }
  420. if (var_resource_ == nullptr) {
  421. GELOGW("VarManager has not been init.");
  422. return ge::INTERNAL_ERROR;
  423. }
  424. result = var_resource_->SaveVarAddr(
  425. var_name, tensor_desc, reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  426. if (result != SUCCESS) {
  427. GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed.");
  428. return ge::INTERNAL_ERROR;
  429. }
  430. result = var_resource_->GetVarAddr(
  431. var_name, tensor_desc, reinterpret_cast<uint8_t **>(reinterpret_cast<uintptr_t>(&mem_offset)), memory_type);
  432. if (result != SUCCESS) {
  433. GELOGE(ge::INTERNAL_ERROR, "GetVarAddr by offset failed.");
  434. return ge::INTERNAL_ERROR;
  435. }
  436. ge::GeTensorDesc cur_tensor_desc;
  437. result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc);
  438. if (result != SUCCESS) {
  439. var_resource_->SetVarAddr(var_name, tensor_desc,
  440. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  441. return SUCCESS;
  442. }
  443. if (cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() ||
  444. cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() ||
  445. cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims()) {
  446. GELOGI("var %s assigned new memory (format, data type, shape) (%s, %s, %zu) from (%s, %s, %zu)", var_name.c_str(),
  447. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  448. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  449. tensor_desc.GetShape().GetDims().size(),
  450. ge::TypeUtils::DataTypeToSerialString(cur_tensor_desc.GetDataType()).c_str(),
  451. ge::TypeUtils::FormatToSerialString(cur_tensor_desc.GetFormat()).c_str(),
  452. cur_tensor_desc.GetShape().GetDims().size());
  453. var_resource_->SetVarAddr(var_name, tensor_desc,
  454. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  455. }
  456. return SUCCESS;
  457. }
  458. bool VarManager::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  459. std::lock_guard<std::recursive_mutex> lock(mutex_);
  460. GELOGD("VarManager::IsVarExist var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  461. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  462. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  463. if (var_resource_ == nullptr) {
  464. GELOGW("VarManager has not been init.");
  465. return false;
  466. }
  467. return var_resource_->IsVarExist(var_name, tensor_desc);
  468. }
  469. bool VarManager::IsVarExist(const std::string &var_name) {
  470. std::lock_guard<std::recursive_mutex> lock(mutex_);
  471. if (var_resource_ == nullptr) {
  472. GELOGW("VarManager has not been init.");
  473. return false;
  474. }
  475. return var_resource_->IsVarExist(var_name);
  476. }
  477. ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
  478. uint8_t *base_ptr) {
  479. std::lock_guard<std::recursive_mutex> lock(mutex_);
  480. if (var_resource_ == nullptr) {
  481. GELOGW("VarManager has not been init.");
  482. return ge::INTERNAL_ERROR;
  483. }
  484. return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr);
  485. }
  486. ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  487. std::lock_guard<std::recursive_mutex> lock(mutex_);
  488. GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str());
  489. if (var_resource_ == nullptr) {
  490. GELOGW("VarManager has not been init.");
  491. return ge::INTERNAL_ERROR;
  492. }
  493. return var_resource_->GetCurVarDesc(var_name, tensor_desc);
  494. }
  495. ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  496. std::lock_guard<std::recursive_mutex> lock(mutex_);
  497. GELOGI(
  498. "VarManager::SaveBroadCastInfo var_name = %s, broadcast name = %s, "
  499. "idx = %d, input_offset = %ld, input_size = %lu, output_offset = %ld, "
  500. "output_size = %lu",
  501. broad_cast_info.var_name.c_str(), broad_cast_info.broadcast_name.c_str(), broad_cast_info.idx,
  502. broad_cast_info.input_offset, broad_cast_info.input_size, broad_cast_info.output_offset,
  503. broad_cast_info.output_size);
  504. if (var_resource_ == nullptr) {
  505. GELOGW("VarManager has not been init.");
  506. return ge::INTERNAL_ERROR;
  507. }
  508. var_resource_->SaveBroadCastInfo(graph_id, broad_cast_info);
  509. return SUCCESS;
  510. }
  511. ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  512. std::lock_guard<std::recursive_mutex> lock(mutex_);
  513. if (var_resource_ == nullptr) {
  514. GELOGW("VarManager has not been init.");
  515. return ge::INTERNAL_ERROR;
  516. }
  517. return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info);
  518. }
  519. ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) {
  520. std::lock_guard<std::recursive_mutex> lock(mutex_);
  521. GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str());
  522. if (var_resource_ == nullptr) {
  523. GELOGE(ge::INTERNAL_ERROR, "VarManager has not been init.");
  524. return ge::INTERNAL_ERROR;
  525. }
  526. return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc));
  527. }
  528. ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  529. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  530. std::lock_guard<std::recursive_mutex> lock(mutex_);
  531. if (var_resource_ == nullptr) {
  532. GELOGW("VarManager has not been init.");
  533. return ge::INTERNAL_ERROR;
  534. }
  535. return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr);
  536. }
  537. bool VarManager::IsVarAddr(const int64_t &offset) {
  538. std::lock_guard<std::recursive_mutex> lock(mutex_);
  539. if (var_resource_ == nullptr) {
  540. GELOGD("VarManager has not been init.");
  541. return false;
  542. }
  543. return var_resource_->IsVarAddr(offset);
  544. }
  545. ge::Status VarManager::MallocVarMemory(size_t memory_size) {
  546. std::lock_guard<std::recursive_mutex> lock(mutex_);
  547. uint8_t *var_mem_base = nullptr;
  548. string memory_key = std::to_string(session_id_);
  549. // malloc variable memory
  550. size_t var_memory_size = memory_size;
  551. // align 512 BYTE
  552. var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  553. const string purpose("variables and constant op memory in training network.");
  554. var_mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, var_memory_size);
  555. if (var_mem_base == nullptr) {
  556. GELOGE(ge::INTERNAL_ERROR,
  557. "VarManager::MallocVarMemory failed "
  558. "session_id = %s",
  559. memory_key.c_str());
  560. return ge::INTERNAL_ERROR;
  561. }
  562. return SUCCESS;
  563. }
  564. uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) {
  565. std::lock_guard<std::recursive_mutex> lock(mutex_);
  566. string memory_key = std::to_string(session_id_);
  567. return MemManager::Instance(memory_type)->GetMemoryAddr(memory_key);
  568. }
  569. uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) {
  570. std::lock_guard<std::recursive_mutex> lock(mutex_);
  571. string mem_key = std::to_string(session_id_);
  572. uint8_t *mem_base = MemManager::Instance(memory_type)->GetMemoryAddr(mem_key);
  573. if (mem_base == nullptr) {
  574. return nullptr;
  575. }
  576. uint8_t *mem_addr =
  577. logic_addr + reinterpret_cast<intptr_t>(mem_base) - VarManager::Instance(session_id_)->GetVarMemLogicBase();
  578. return mem_addr;
  579. }
  580. ge::Status VarManager::FreeVarMemory() {
  581. std::lock_guard<std::recursive_mutex> lock(mutex_);
  582. string memory_key = std::to_string(SessionId());
  583. return MemManager::Instance(RT_MEMORY_HBM)->FreeMemory(memory_key);
  584. }
  585. ge::Status VarManager::SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) {
  586. std::lock_guard<std::recursive_mutex> lock(mutex_);
  587. if (var_resource_ == nullptr) {
  588. GELOGW("VarManager has not been init.");
  589. return ge::INTERNAL_ERROR;
  590. }
  591. return var_resource_->SetTransRoad(var_name, trans_road);
  592. }
  593. VarTransRoad *VarManager::GetTransRoad(const std::string &var_name) {
  594. std::lock_guard<std::recursive_mutex> lock(mutex_);
  595. if (var_resource_ == nullptr) {
  596. GELOGW("VarManager has not been init.");
  597. return nullptr;
  598. }
  599. return var_resource_->GetTransRoad(var_name);
  600. }
  601. Status VarManager::SetChangedGraphId(const std::string &var_name, uint32_t graph_id) {
  602. std::lock_guard<std::recursive_mutex> lock(mutex_);
  603. if (var_resource_ == nullptr) {
  604. GELOGW("VarManager has not been init.");
  605. return INTERNAL_ERROR;
  606. }
  607. return var_resource_->SetChangedGraphId(var_name, graph_id);
  608. }
  609. Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  610. std::lock_guard<std::recursive_mutex> lock(mutex_);
  611. if (var_resource_ == nullptr) {
  612. GELOGW("VarManager has not been init.");
  613. return INTERNAL_ERROR;
  614. }
  615. return var_resource_->GetChangedGraphId(var_name, graph_id);
  616. }
  617. Status VarManager::SetMemoryMallocSize(const map<string, string> &options) {
  618. auto it = options.find(GRAPH_MEMORY_MAX_SIZE);
  619. if (it == options.end()) {
  620. graph_mem_max_size_ = kGraphMemoryManagerMallocMaxSize;
  621. } else {
  622. string graph_memory_manager_malloc_max_size = it->second;
  623. ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_);
  624. if (ret != SUCCESS) {
  625. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse graph memory manager malloc max size failed.");
  626. return ge::GE_GRAPH_OPTIONS_INVALID;
  627. }
  628. GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_);
  629. }
  630. it = options.find(VARIABLE_MEMORY_MAX_SIZE);
  631. if (it == options.end()) {
  632. var_mem_max_size_ = kMemoryVarManagerMallocSize;
  633. } else {
  634. string memory_var_manager_malloc_size = it->second;
  635. ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_);
  636. if (ret != SUCCESS) {
  637. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse memory var manager malloc size failed.");
  638. return ge::GE_GRAPH_OPTIONS_INVALID;
  639. }
  640. }
  641. var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer;
  642. if (var_mem_logic_base_ > kMaxMemorySize) {
  643. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "kMemoryVarLogicBase : %zu can not exceed max memory size : %zu.",
  644. var_mem_logic_base_, kMaxMemorySize);
  645. return ge::GE_GRAPH_OPTIONS_INVALID;
  646. }
  647. use_max_mem_size_ = graph_mem_max_size_ + var_mem_max_size_;
  648. if (use_max_mem_size_ > kMaxMemorySize) {
  649. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "kUseMaxMemorySize : %zu can not exceed max memory size : %zu.",
  650. use_max_mem_size_, kMaxMemorySize);
  651. return ge::GE_GRAPH_OPTIONS_INVALID;
  652. }
  653. GELOGI("Set memory malloc size successfully");
  654. return SUCCESS;
  655. }
  656. Status VarManager::ParseMemoryMallocSize(string &memory_size, size_t &result) {
  657. if (memory_size.empty()) {
  658. GELOGE(GE_GRAPH_OPTIONS_INVALID, "Memory malloc size input is empty.");
  659. return GE_GRAPH_OPTIONS_INVALID;
  660. }
  661. // split string by '*'
  662. vector<string> splits;
  663. std::istringstream str(memory_size);
  664. string str_split;
  665. while (getline(str, str_split, '*')) {
  666. splits.emplace_back(str_split);
  667. }
  668. result = 1;
  669. for (string split : splits) {
  670. // Trim
  671. auto it = split.find_first_not_of(" ");
  672. if (it != string::npos) {
  673. split.erase(0, it);
  674. }
  675. it = split.find_last_not_of(" ");
  676. if (it != string::npos) {
  677. split.erase(it + 1);
  678. }
  679. for (char c : split) {
  680. if (!isdigit(c)) {
  681. GELOGE(GE_GRAPH_OPTIONS_INVALID, "Memory malloc size input contains non digit.");
  682. return GE_GRAPH_OPTIONS_INVALID;
  683. }
  684. }
  685. uint64_t num = std::strtoul(split.c_str(), nullptr, 0);
  686. GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(result, static_cast<uint32_t>(num)),
  687. GELOGE(FAILED, "Input memory size is out of range.");
  688. return FAILED);
  689. if ((num > kMaxMemorySize) || (result * static_cast<size_t>(num) > kMaxMemorySize)) {
  690. GELOGE(FAILED, "Input memory size can not exceed max memory size : %zu.", kMaxMemorySize);
  691. return FAILED;
  692. }
  693. result *= static_cast<size_t>(num);
  694. }
  695. return SUCCESS;
  696. }
  697. void VarManager::RemoveChangedGraphId(const std::string &var_name) {
  698. std::lock_guard<std::recursive_mutex> lock(mutex_);
  699. if (var_resource_ == nullptr) {
  700. GELOGW("VarManager has not been init.");
  701. return;
  702. }
  703. var_resource_->RemoveChangedGraphId(var_name);
  704. }
  705. Status VarManager::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  706. std::lock_guard<std::recursive_mutex> lock(mutex_);
  707. if (var_resource_ == nullptr) {
  708. GELOGW("VarManager has not been init.");
  709. return INTERNAL_ERROR;
  710. }
  711. return var_resource_->SetAllocatedGraphId(var_name, graph_id);
  712. }
  713. Status VarManager::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  714. std::lock_guard<std::recursive_mutex> lock(mutex_);
  715. if (var_resource_ == nullptr) {
  716. GELOGW("VarManager has not been init.");
  717. return INTERNAL_ERROR;
  718. }
  719. return var_resource_->GetAllocatedGraphId(var_name, graph_id);
  720. }
  721. void VarManager::RemoveAllocatedGraphId(const std::string &var_name) {
  722. std::lock_guard<std::recursive_mutex> lock(mutex_);
  723. if (var_resource_ == nullptr) {
  724. GELOGW("VarManager has not been init.");
  725. return;
  726. }
  727. var_resource_->RemoveAllocatedGraphId(var_name);
  728. }
  729. Status VarManager::GetAllVariables(std::map<std::string, GeTensorDesc> &all_variables) {
  730. std::lock_guard<std::recursive_mutex> lock(mutex_);
  731. if (var_resource_ == nullptr) {
  732. GELOGW("VarManager has not been inited.");
  733. return INTERNAL_ERROR;
  734. }
  735. auto new_variable_desc = var_resource_->GetAllVarDesc();
  736. if (new_variable_desc.size() == 0) {
  737. GELOGW("VarManager don't have variables.");
  738. return INTERNAL_ERROR;
  739. }
  740. for (auto iter = new_variable_desc.begin(); iter != new_variable_desc.end(); ++iter) {
  741. auto trans_road = var_resource_->GetTransRoad(iter->first);
  742. if (trans_road == nullptr || trans_road->empty()) {
  743. GELOGI("The variable %s does not have any trans road", iter->first.c_str());
  744. all_variables[iter->first] = iter->second;
  745. continue;
  746. }
  747. // get origin trans info : the first trans node info
  748. auto origin_trans_node_info = trans_road->at(0);
  749. all_variables[iter->first] = origin_trans_node_info.input;
  750. }
  751. return SUCCESS;
  752. }
  753. VarManagerPool::~VarManagerPool() { Destory(); }
  754. VarManagerPool &VarManagerPool::Instance() {
  755. static VarManagerPool var_manager_pool;
  756. return var_manager_pool;
  757. }
  758. void VarManagerPool::Destory() noexcept {
  759. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  760. for (auto &it : var_manager_map_) {
  761. VarManager *var_manager = it.second;
  762. if (var_manager != nullptr) {
  763. var_manager->Destory();
  764. delete var_manager;
  765. var_manager = nullptr;
  766. }
  767. }
  768. var_manager_map_.clear();
  769. }
  770. ge::Status VarManagerPool::Init() const { return SUCCESS; }
  771. VarManager *VarManagerPool::GetVarManager(uint64_t session_id) {
  772. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  773. auto it = var_manager_map_.find(session_id);
  774. if (it != var_manager_map_.end()) {
  775. GELOGD("VarManagerPool::GetVarManager");
  776. return it->second;
  777. }
  778. VarManager *var_manager = new (std::nothrow) VarManager(session_id);
  779. if (var_manager == nullptr) {
  780. GELOGE(INTERNAL_ERROR,
  781. "VarManager::Instance find session by "
  782. "session_id[%lu] failed.",
  783. session_id);
  784. static VarManager new_var_manager(0);
  785. return &new_var_manager;
  786. }
  787. var_manager_map_[session_id] = var_manager;
  788. return var_manager;
  789. }
  790. void VarManagerPool::RemoveVarManager(uint64_t session_id) {
  791. VarManager *var_manager = nullptr;
  792. {
  793. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  794. auto it = var_manager_map_.find(session_id);
  795. if (it != var_manager_map_.end()) {
  796. var_manager = it->second;
  797. var_manager_map_.erase(it);
  798. }
  799. }
  800. if (var_manager != nullptr) {
  801. var_manager->Destory();
  802. delete var_manager;
  803. var_manager = nullptr;
  804. }
  805. }
  806. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户对此并无感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: