You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trans_var_data_utils.cc 23 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "graph/manager/trans_var_data_utils.h"

#include <algorithm>
#include <cstring>
#include <limits>
#include <memory>

#include "common/debug/log.h"
#include "common/debug/memory_dumper.h"
#include "common/formats/formats.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "common/op/ge_op_utils.h"
#include "framework/common/debug/ge_log.h"
#include "graph/manager/graph_var_manager.h"
#include "graph/types.h"
#include "graph/utils/type_utils.h"
#include "common/thread_pool.h"
  28. namespace ge {
  29. namespace {
  30. class RtContextSwitchGuard {
  31. public:
  32. RtContextSwitchGuard(rtCtxMode_t mode, uint32_t device_id) : last_(nullptr), current_(nullptr) {
  33. auto ret = rtCtxGetCurrent(&last_);
  34. if (ret != RT_ERROR_NONE) {
  35. GELOGE(RT_FAILED, "Failed to get current context from rt, error-code %d", ret);
  36. return;
  37. }
  38. ret = rtCtxCreate(&current_, mode, static_cast<int32_t>(device_id));
  39. if (ret != RT_ERROR_NONE) {
  40. GELOGE(RT_FAILED, "Failed to create new context for device %u, error-code %d", device_id, ret);
  41. return;
  42. }
  43. ret = rtCtxSetCurrent(current_);
  44. if (ret != RT_ERROR_NONE) {
  45. GELOGE(RT_FAILED, "Failed to switch context to normal, context %p, device %u", current_, device_id);
  46. return;
  47. }
  48. GELOGD("Create and switch rt context %p type %d for device %u, backup last %p.", current_, mode, device_id, last_);
  49. }
  50. ~RtContextSwitchGuard() {
  51. if (current_ != nullptr) {
  52. auto ret = rtCtxDestroy(current_);
  53. GELOGD("Destory current context %p result %d", current_, ret);
  54. }
  55. if (last_ != nullptr) {
  56. auto ret = rtCtxSetCurrent(last_);
  57. GELOGD("Recovery last context %p result %d.", last_, ret);
  58. }
  59. }
  60. private:
  61. rtContext_t last_;
  62. rtContext_t current_;
  63. };
  64. int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) {
  65. int64_t var_size = GetSizeByDataType(desc.GetDataType());
  66. if (var_size <= 0) {
  67. GELOGE(PARAM_INVALID, "Failed to calc var data size from data type %s",
  68. TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str());
  69. return -1;
  70. }
  71. auto shape = desc.GetShape();
  72. auto dim_num = shape.GetDimNum();
  73. for (size_t dim_index = 0; dim_index < dim_num; ++dim_index) {
  74. var_size *= shape.GetDim(dim_index);
  75. }
  76. return var_size;
  77. }
  78. Status CopyVarToDevice(const NodePtr &var, const formats::TransResult &trans_result, void *var_addr) {
  79. GELOGD("Copy var %s from host to device, size %zu", var->GetName().c_str(), trans_result.length);
  80. auto ret = rtMemcpy(var_addr, trans_result.length, reinterpret_cast<void *>(trans_result.data.get()),
  81. trans_result.length, RT_MEMCPY_HOST_TO_DEVICE);
  82. if (ret != RT_ERROR_NONE) {
  83. GELOGE(RT_FAILED, "Failed to copy memory to device, size %zu", trans_result.length);
  84. return RT_FAILED;
  85. }
  86. return SUCCESS;
  87. }
  88. Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_ptr<uint8_t[]> &var_data,
  89. const GeTensorDesc &input_desc) {
  90. uint8_t *var_logic = nullptr;
  91. GE_CHECK_NOTNULL(var);
  92. auto ret = VarManager::Instance(session_id)->GetVarAddr(var->GetName(), input_desc, &var_logic);
  93. if (ret != SUCCESS) {
  94. GELOGE(INTERNAL_ERROR,
  95. "Failed to copy var %s from device, can not find it"
  96. " from var manager %u",
  97. var->GetName().c_str(), ret);
  98. return INTERNAL_ERROR;
  99. }
  100. uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
  101. if (var_addr == nullptr) {
  102. GELOGE(INTERNAL_ERROR,
  103. "Failed to copy var %s from device, cant not get "
  104. "var addr from logic addr %p",
  105. var->GetName().c_str(), var_logic);
  106. return INTERNAL_ERROR;
  107. }
  108. int64_t var_size_bytes = CalcVarSizeInBytes(input_desc);
  109. if (var_size_bytes <= 0) {
  110. return INTERNAL_ERROR;
  111. }
  112. std::unique_ptr<uint8_t[]> var_host(new (std::nothrow) uint8_t[var_size_bytes]);
  113. if (var_host == nullptr) {
  114. GELOGE(OUT_OF_MEMORY, "Failed to malloc rt-host memory, size %ld", var_size_bytes);
  115. return OUT_OF_MEMORY;
  116. }
  117. ret = rtMemcpy(reinterpret_cast<void *>(var_host.get()), var_size_bytes, reinterpret_cast<void *>(var_addr),
  118. var_size_bytes, RT_MEMCPY_DEVICE_TO_HOST);
  119. if (ret != RT_ERROR_NONE) {
  120. GELOGE(RT_FAILED,
  121. "Failed to copy var memory from device, var %s, size %ld,"
  122. " rt-error-code %u",
  123. var->GetName().c_str(), var_size_bytes, ret);
  124. return RT_FAILED;
  125. }
  126. GELOGD("Copy var %s from device to host, size %ld", var->GetName().c_str(), var_size_bytes);
  127. var_data.swap(var_host);
  128. GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr);
  129. return SUCCESS;
  130. }
  131. Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats::TransResult &result) {
  132. formats::TransResult result_last_time{};
  133. bool use_init_data = true;
  134. for (const auto &trans_info : trans_road) {
  135. if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) {
  136. GELOGD("Skip to trans variable data on the reshape/reformat node");
  137. continue;
  138. }
  139. uint8_t *src_data = nullptr;
  140. if (use_init_data) {
  141. src_data = var_data;
  142. use_init_data = false;
  143. } else {
  144. src_data = result_last_time.data.get();
  145. }
  146. formats::TransResult tmp_result{};
  147. if (trans_info.node_type == TRANSDATA || trans_info.node_type == TRANSPOSED) {
  148. auto src_format = trans_info.input.GetFormat();
  149. auto src_shape = trans_info.input.GetShape().GetDims();
  150. auto dst_format = trans_info.output.GetFormat();
  151. auto dst_shape = trans_info.output.GetShape().GetDims();
  152. auto data_type = trans_info.input.GetDataType();
  153. GELOGD("Trans format from %s to %s, shape %s to %s, data-type %s",
  154. TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
  155. formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
  156. TypeUtils::DataTypeToSerialString(data_type).c_str());
  157. auto ret = formats::TransFormat({src_data, src_format, dst_format, src_shape, dst_shape, data_type}, tmp_result);
  158. if (ret != SUCCESS) {
  159. GELOGE(INTERNAL_ERROR,
  160. "Failed to trans format from %s to %s, shape %s to %s, "
  161. "data type %s error code %u",
  162. TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
  163. formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
  164. TypeUtils::DataTypeToSerialString(data_type).c_str(), ret);
  165. return ret;
  166. }
  167. } else if (trans_info.node_type == CAST) {
  168. auto input_shape = trans_info.input.GetShape();
  169. auto src_data_size = input_shape.GetShapeSize() == 0 ? 1 : input_shape.GetShapeSize();
  170. auto src_data_type = trans_info.input.GetDataType();
  171. auto dst_data_type = trans_info.output.GetDataType();
  172. GELOGD("Trans data type from %s to %s, input shape %s, data size %ld",
  173. TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
  174. TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
  175. src_data_size);
  176. auto ret = formats::TransDataType({src_data, static_cast<size_t>(src_data_size), src_data_type, dst_data_type},
  177. tmp_result);
  178. if (ret != SUCCESS) {
  179. GELOGE(INTERNAL_ERROR, "Failed to trans data type from %s to %s, input shape %s, data size %ld, error code %u",
  180. TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
  181. TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
  182. src_data_size, ret);
  183. return ret;
  184. }
  185. } else {
  186. GELOGE(UNSUPPORTED, "Failed to trans var data, the trans type %s does not supported",
  187. trans_info.node_type.c_str());
  188. return UNSUPPORTED;
  189. }
  190. result_last_time = tmp_result;
  191. }
  192. result = result_last_time;
  193. return SUCCESS;
  194. }
  195. /// re-alloc var memory on device using var-manager
  196. /// free origin var memory(var manager does not support now)
  197. /// @param session_id
  198. /// @param var
  199. /// @param var_size_bytes
  200. /// @param var_device
  201. /// @return
  202. Status ReAssignVarAddr(uint64_t session_id, const std::string &var_name, const GeTensorDesc &tensor_desc,
  203. void **var_device) {
  204. uint8_t *var_logic = nullptr;
  205. Status ret = VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &var_logic);
  206. if (ret != SUCCESS) {
  207. GELOGE(INTERNAL_ERROR,
  208. "Failed to get var %s device addr, can not find it"
  209. " from var manager %u",
  210. var_name.c_str(), ret);
  211. return INTERNAL_ERROR;
  212. }
  213. uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
  214. if (var_addr == nullptr) {
  215. GELOGE(INTERNAL_ERROR, "Failed to convert var %s logic addr to real addr", var_name.c_str());
  216. return INTERNAL_ERROR;
  217. }
  218. *var_device = var_addr;
  219. GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr);
  220. return SUCCESS;
  221. }
  222. Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t session_id) {
  223. // do not need to do anything if only all reshape/reformat node on the trans_road
  224. GE_CHECK_NOTNULL(var);
  225. bool need_trans = false;
  226. for (auto &road : trans_road) {
  227. if (road.node_type != RESHAPE && road.node_type != REFORMAT) {
  228. need_trans = true;
  229. break;
  230. }
  231. }
  232. if (!need_trans) {
  233. return SUCCESS;
  234. }
  235. // Sync var data from device
  236. std::unique_ptr<uint8_t[]> var_data;
  237. if (trans_road.empty()) {
  238. GELOGE(INTERNAL_ERROR, "Failed to get trans_road, trans_road is empty.");
  239. return INTERNAL_ERROR;
  240. }
  241. const GeTensorDesc &input_desc = trans_road.begin()->input;
  242. auto ret = CopyVarFromDevice(session_id, var, var_data, input_desc);
  243. if (ret != SUCCESS) {
  244. return ret;
  245. }
  246. formats::TransResult trans_result{};
  247. ret = TransVarOnHost(var_data.get(), trans_road, trans_result);
  248. if (ret != SUCCESS) {
  249. GELOGE(ret, "Failed to trans var data on host, error code %u", ret);
  250. return ret;
  251. }
  252. void *var_device = nullptr;
  253. /// It is a temporary solution to use the last GeTensorDesc to assign variable memory because the variable manager
  254. /// depends on TensorDesc and it is difficult to be modified. The correct solution is to assign memory based on the
  255. /// size of the converted variable. To complete the final solution, the dependency of the variable manager on
  256. /// TensorDesc needs to be removed. This change is large and needs to be performed step by step.
  257. ret = ReAssignVarAddr(session_id, var->GetName(), trans_road.rbegin()->output, &var_device);
  258. if (ret != SUCCESS) {
  259. GELOGE(ret, "Failed to re-assign memory on device, size %zu", trans_result.length);
  260. return ret;
  261. }
  262. // sync new data to device
  263. ret = CopyVarToDevice(var, trans_result, var_device);
  264. if (ret != SUCCESS) {
  265. GELOGE(ret, "Failed to send var data to device");
  266. return ret;
  267. }
  268. return SUCCESS;
  269. }
  270. Status TransTensor(uint8_t *var_data, const NodePtr &var_src, const NodePtr &var_dst, formats::TransResult &result) {
  271. GE_CHECK_NOTNULL(var_src);
  272. GE_CHECK_NOTNULL(var_src->GetOpDesc());
  273. GE_CHECK_NOTNULL(var_dst);
  274. GE_CHECK_NOTNULL(var_dst->GetOpDesc());
  275. auto src_data_shape_size = var_src->GetOpDesc()->GetOutputDesc(0).GetShape().GetShapeSize();
  276. auto src_data_datatype = var_src->GetOpDesc()->GetOutputDesc(0).GetDataType();
  277. auto dst_data_datatype = var_dst->GetOpDesc()->GetOutputDesc(0).GetDataType();
  278. GE_IF_BOOL_EXEC(
  279. src_data_datatype != dst_data_datatype,
  280. auto ret = formats::TransDataType(
  281. {var_data, static_cast<size_t>(src_data_shape_size), src_data_datatype, dst_data_datatype}, result);
  282. if (ret != SUCCESS) {
  283. GELOGE(INTERNAL_ERROR, "trans var data on host failed");
  284. return ret;
  285. });
  286. return SUCCESS;
  287. }
  288. Status CopyTensorFromSrcVarNode(const NodePtr &var_src, const NodePtr &var_dst, uint64_t session_id,
  289. uint32_t device_id) {
  290. /// after FE fusion pass, input num of applymomentum op was changed, 0th input is var_fp32, 6th input is
  291. /// var_fp16(new).
  292. /// unlink edges between var_fp32 and "dst_node" (need fp16) of var_fp32, add edge between var_fp16 and dst_node.
  293. /// need copy value from var_fp32 to var_fp16.
  294. /// [opdesc of var_src and var_dst are checked before passed in, no need to check if they are nullptr]
  295. GE_IF_BOOL_EXEC(var_src == nullptr || var_dst == nullptr, GELOGE(FAILED, "node var is nullptr"); return FAILED);
  296. // src_node output_desc (fp32)
  297. GeTensorDesc output_desc = var_src->GetOpDesc()->GetOutputDesc(0);
  298. auto src_data_type = output_desc.GetDataType();
  299. auto src_shape = output_desc.GetShape();
  300. auto src_format = output_desc.GetFormat();
  301. GELOGI("src_node %s, src_format %s, src_shape %s, src_type %s", var_src->GetName().c_str(),
  302. TypeUtils::FormatToSerialString(src_format).c_str(), formats::ShapeToString(src_shape).c_str(),
  303. TypeUtils::DataTypeToSerialString(src_data_type).c_str());
  304. // dst_node output_desc (fp16)
  305. GeTensorDesc dst_tensor_desc = var_dst->GetOpDesc()->GetOutputDesc(0);
  306. auto data_type = dst_tensor_desc.GetDataType();
  307. auto data_shape = dst_tensor_desc.GetShape();
  308. auto data_format = dst_tensor_desc.GetFormat();
  309. GELOGI("dst_node %s, src_format %s, src_shape %s, src_type %s", var_dst->GetName().c_str(),
  310. TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(),
  311. TypeUtils::DataTypeToSerialString(data_type).c_str());
  312. // Sync var data from device
  313. std::unique_ptr<uint8_t[]> var_src_data;
  314. RtContextSwitchGuard switch_context(RT_CTX_NORMAL_MODE, device_id);
  315. // copy from src_node
  316. auto ret = CopyVarFromDevice(session_id, var_src, var_src_data, output_desc);
  317. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "Copy Var From Device failed"); return ret);
  318. // trans dtype
  319. formats::TransResult trans_result{};
  320. ret = TransTensor(var_src_data.get(), var_src, var_dst, trans_result);
  321. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "trans var data on host failed"); return ret);
  322. // reset src value.
  323. void *var_device = nullptr;
  324. ret = ReAssignVarAddr(session_id, var_dst->GetName(), dst_tensor_desc, &var_device);
  325. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "assign mem failed"); return ret);
  326. // copy to device
  327. ret = CopyVarToDevice(var_dst, trans_result, var_device);
  328. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Failed to send var data to device"); return ret);
  329. return SUCCESS;
  330. }
  331. } // namespace
  332. Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
  333. uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) {
  334. GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "dst addr is null. ");
  335. uint8_t *src_host_addr = nullptr;
  336. int64_t src_addr_size = 0;
  337. GE_MAKE_GUARD_RTMEM(src_host_addr);
  338. GE_CHK_STATUS_RET(SyncTensorToHost(var_name, src_tensor_desc, &src_host_addr, src_addr_size, session_id));
  339. GELOGI("src_addr_size: %u, dst_addr_size: %u", src_addr_size, dst_addr_size);
  340. GE_CHK_BOOL_RET_STATUS(src_addr_size == dst_addr_size, FAILED, "var data size is not equal broadcast ");
  341. GE_CHK_RT_RET(rtMemcpy(dst_addr, dst_addr_size, src_host_addr, src_addr_size, RT_MEMCPY_HOST_TO_DEVICE));
  342. return SUCCESS;
  343. }
  344. Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name,
  345. const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) {
  346. GE_CHK_BOOL_RET_STATUS(src_addr != nullptr, FAILED, "src addr is null. ");
  347. uint8_t *host_addr = nullptr;
  348. GE_MAKE_GUARD_RTMEM(host_addr);
  349. GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(&host_addr), src_addr_size));
  350. GE_CHK_RT_RET(rtMemcpy(host_addr, src_addr_size, src_addr, src_addr_size, RT_MEMCPY_DEVICE_TO_HOST));
  351. GE_CHK_STATUS_RET(
  352. SyncTensorToDevice(var_name, reinterpret_cast<uint8_t *>(host_addr), src_addr_size, dst_tensor_desc, session_id));
  353. return SUCCESS;
  354. }
  355. Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
  356. uint8_t **host_addr, int64_t &src_tensor_size, uint64_t session_id) {
  357. GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(src_tensor_desc, src_tensor_size), "get size from TensorDesc failed");
  358. uint8_t *src_addr = nullptr;
  359. GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr));
  360. uint8_t *mem_addr =
  361. src_addr -
  362. static_cast<int64_t>(reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) +
  363. static_cast<int64_t>(
  364. reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM)));
  365. GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(host_addr), src_tensor_size));
  366. GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST));
  367. GELOGI("SyncTensorToHost var_name %s, src_tensor_size %ld", var_name.c_str(), src_tensor_size);
  368. return SUCCESS;
  369. }
  370. Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size,
  371. const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) {
  372. uint8_t *dst_addr = nullptr;
  373. GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr));
  374. uint8_t *mem_addr =
  375. dst_addr -
  376. static_cast<int64_t>(reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) +
  377. static_cast<int64_t>(
  378. reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM)));
  379. GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE));
  380. GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size);
  381. return SUCCESS;
  382. }
  383. Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, uint64_t session_id,
  384. rtContext_t context, uint32_t graph_id, uint32_t thread_num) {
  385. ThreadPool executor(thread_num);
  386. std::vector<std::future<Status>> vector_future;
  387. for (auto &node : variable_nodes) {
  388. if (node == nullptr) {
  389. continue;
  390. }
  391. if (node->GetType() != VARIABLE) {
  392. continue;
  393. }
  394. std::future<Status> f = executor.commit(
  395. [](const ge::NodePtr &node, uint64_t session_id, rtContext_t ctx, uint32_t graph_id) -> Status {
  396. rtError_t rt_ret = rtCtxSetCurrent(ctx);
  397. if (rt_ret != RT_ERROR_NONE) {
  398. GELOGE(RT_FAILED, "Failed to set context, error_code is: 0x%X.", rt_ret);
  399. return RT_ERROR_TO_GE_STATUS(rt_ret);
  400. }
  401. uint32_t allocated_graph_id = 0;
  402. Status ret = VarManager::Instance(session_id)->GetAllocatedGraphId(node->GetName(), allocated_graph_id);
  403. if (ret != SUCCESS) {
  404. GELOGE(INTERNAL_ERROR, "var has not been allocated, node:%s, graph_id:%u.", node->GetName().c_str(),
  405. graph_id);
  406. return INTERNAL_ERROR;
  407. }
  408. uint32_t changed_graph_id = 0;
  409. ret = VarManager::Instance(session_id)->GetChangedGraphId(node->GetName(), changed_graph_id);
  410. bool call_trans_var =
  411. (ret == SUCCESS && changed_graph_id == graph_id && changed_graph_id != allocated_graph_id);
  412. if (call_trans_var) {
  413. GELOGI("VarManager::GetChangedGraphId() success, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id);
  414. VarTransRoad *trans_road = VarManager::Instance(session_id)->GetTransRoad(node->GetName());
  415. if (trans_road == nullptr) {
  416. GELOGI("The variable %s does not have any trans road", node->GetName().c_str());
  417. return SUCCESS;
  418. }
  419. ret = TransVarData(node, *trans_road, session_id);
  420. if (ret != SUCCESS) {
  421. GELOGE(INTERNAL_ERROR, "TransVarData failed, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id);
  422. return INTERNAL_ERROR;
  423. }
  424. VarManager::Instance(session_id)->RemoveChangedGraphId(node->GetName());
  425. }
  426. return SUCCESS;
  427. },
  428. node, session_id, context, graph_id);
  429. if (!f.valid()) {
  430. GELOGE(FAILED, "Future is invalid");
  431. return FAILED;
  432. }
  433. vector_future.push_back(std::move(f));
  434. }
  435. Status ret_status;
  436. for (size_t i = 0; i < vector_future.size(); ++i) {
  437. ret_status = vector_future[i].get();
  438. if (ret_status != SUCCESS) {
  439. GELOGE(ret_status, "TransAllVarData:: trans %zu vardata failed", i);
  440. return ret_status;
  441. }
  442. }
  443. return SUCCESS;
  444. }
  445. Status TransVarDataUtils::CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id) {
  446. GELOGI("CopyVarData start: session_id:%lu.", session_id);
  447. if (compute_graph == nullptr) {
  448. GELOGE(FAILED, "compute_graph is nullptr");
  449. return FAILED;
  450. }
  451. string cp_from_node;
  452. bool copy_value = false;
  453. for (auto &node : compute_graph->GetAllNodes()) {
  454. GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() != VARIABLE, continue);
  455. GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), "_copy_from_var_node", cp_from_node),
  456. GELOGI("Get original type of cp_from_node"));
  457. if (cp_from_node.length() != 0) {
  458. (void)ge::AttrUtils::GetBool(node->GetOpDesc(), "_copy_value", copy_value); // no need to check value
  459. if (!copy_value) {
  460. auto src_node = compute_graph->FindNode(cp_from_node);
  461. GE_CHECK_NOTNULL(src_node);
  462. GELOGI("current_var_node__: [%s] copy_from_var_node__: [%s].", node->GetName().c_str(),
  463. src_node->GetName().c_str());
  464. auto ret = CopyTensorFromSrcVarNode(src_node, node, session_id, device_id);
  465. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "copy tensor failed!"); return FAILED);
  466. // only copy once
  467. (void)ge::AttrUtils::SetBool(node->GetOpDesc(), "_copy_value", true); // no need to check value
  468. }
  469. }
  470. }
  471. return SUCCESS;
  472. }
  473. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：