You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trans_var_data_utils.cc 28 kB

5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "graph/manager/trans_var_data_utils.h"

  #include <algorithm>
  #include <limits>
  #include <utility>

  #include "common/debug/log.h"
  #include "common/debug/memory_dumper.h"
  #include "common/formats/formats.h"
  #include "common/formats/utils/formats_trans_utils.h"
  #include "common/op/ge_op_utils.h"
  #include "framework/common/debug/ge_log.h"
  #include "graph/manager/graph_var_manager.h"
  #include "graph/types.h"
  #include "graph/utils/type_utils.h"
  #include "common/thread_pool.h"
  28. namespace ge {
  29. namespace {
  30. class RtContextSwitchGuard {
  31. public:
  32. RtContextSwitchGuard(rtCtxMode_t mode, uint32_t device_id) : last_(nullptr), current_(nullptr) {
  33. auto ret = rtCtxGetCurrent(&last_);
  34. if (ret != RT_ERROR_NONE) {
  35. REPORT_CALL_ERROR("E19999", "Call rtCtxGetCurrent failed, device_id:%u, ret:0x%X, when %s",
  36. device_id, ret, __FUNCTION__);
  37. GELOGE(RT_FAILED, "Failed to get current context from rt, error-code %d", ret);
  38. return;
  39. }
  40. ret = rtCtxCreate(&current_, mode, static_cast<int32_t>(device_id));
  41. if (ret != RT_ERROR_NONE) {
  42. REPORT_CALL_ERROR("E19999", "Call rtCtxCreate failed, device_id:%u, ret:0x%X, when %s",
  43. device_id, ret, __FUNCTION__);
  44. GELOGE(RT_FAILED, "Failed to create new context for device %u, error-code %d", device_id, ret);
  45. return;
  46. }
  47. ret = rtCtxSetCurrent(current_);
  48. if (ret != RT_ERROR_NONE) {
  49. REPORT_CALL_ERROR("E19999", "Call rtCtxSetCurrent failed, device_id:%u, ret:0x%X, when %s",
  50. device_id, ret, __FUNCTION__);
  51. GELOGE(RT_FAILED, "Failed to switch context to normal, context %p, device %u", current_, device_id);
  52. return;
  53. }
  54. GELOGD("Create and switch rt context %p type %d for device %u, backup last %p.", current_, mode, device_id, last_);
  55. }
  56. ~RtContextSwitchGuard() {
  57. if (current_ != nullptr) {
  58. auto ret = rtCtxDestroy(current_);
  59. GELOGD("Destory current context %p result %d", current_, ret);
  60. }
  61. if (last_ != nullptr) {
  62. auto ret = rtCtxSetCurrent(last_);
  63. GELOGD("Recovery last context %p result %d.", last_, ret);
  64. }
  65. }
  66. private:
  67. rtContext_t last_;
  68. rtContext_t current_;
  69. };
  70. int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) {
  71. int64_t var_size = GetSizeByDataType(desc.GetDataType());
  72. if (var_size <= 0) {
  73. REPORT_INNER_ERROR("E19999", "Data type:%s in desc, it's size:%ld < 0, check invalid when %s",
  74. TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str(), var_size, __FUNCTION__);
  75. GELOGE(PARAM_INVALID, "Failed to calc var data size from data type %s",
  76. TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str());
  77. return -1;
  78. }
  79. auto shape = desc.GetShape();
  80. auto dim_num = shape.GetDimNum();
  81. for (size_t dim_index = 0; dim_index < dim_num; ++dim_index) {
  82. var_size *= shape.GetDim(dim_index);
  83. }
  84. return var_size;
  85. }
  86. Status CopyVarToDevice(const NodePtr &var, const formats::TransResult &trans_result, void *var_addr) {
  87. GELOGD("Copy var %s from host to device, size %zu", var->GetName().c_str(), trans_result.length);
  88. auto ret = rtMemcpy(var_addr, trans_result.length, reinterpret_cast<void *>(trans_result.data.get()),
  89. trans_result.length, RT_MEMCPY_HOST_TO_DEVICE);
  90. if (ret != RT_ERROR_NONE) {
  91. REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, op:%s(%s), size:%lu, ret:0x%X, when %s", var->GetName().c_str(),
  92. var->GetType().c_str(), trans_result.length, ret, __FUNCTION__);
  93. GELOGE(RT_FAILED, "Failed to copy memory to device, size %zu", trans_result.length);
  94. return RT_FAILED;
  95. }
  96. return SUCCESS;
  97. }
  98. Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_ptr<uint8_t[]> &var_data,
  99. const GeTensorDesc &input_desc) {
  100. uint8_t *var_logic = nullptr;
  101. GE_CHECK_NOTNULL(var);
  102. auto ret = VarManager::Instance(session_id)->GetVarAddr(var->GetName(), input_desc, &var_logic);
  103. if (ret != SUCCESS) {
  104. GELOGE(INTERNAL_ERROR,
  105. "Failed to copy var %s from device, can not find it"
  106. " from var manager %u",
  107. var->GetName().c_str(), ret);
  108. return INTERNAL_ERROR;
  109. }
  110. uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
  111. if (var_addr == nullptr) {
  112. REPORT_CALL_ERROR("E19999", "Get variable memory addr failed, mem_type:%d, op:%s(%s), session_id:%lu, when %s",
  113. RT_MEMORY_HBM, var->GetName().c_str(), var->GetType().c_str(), session_id, __FUNCTION__);
  114. GELOGE(INTERNAL_ERROR,
  115. "Failed to copy var %s from device, cant not get "
  116. "var addr from logic addr %p",
  117. var->GetName().c_str(), var_logic);
  118. return INTERNAL_ERROR;
  119. }
  120. int64_t var_size_bytes = CalcVarSizeInBytes(input_desc);
  121. if (var_size_bytes <= 0) {
  122. return INTERNAL_ERROR;
  123. }
  124. std::unique_ptr<uint8_t[]> var_host(new(std::nothrow) uint8_t[var_size_bytes]);
  125. if (var_host == nullptr) {
  126. REPORT_CALL_ERROR("E19999", "New host memory failed, size:%ld, op:%s(%s), session_id:%lu, when %s",
  127. var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id, __FUNCTION__);
  128. GELOGE(OUT_OF_MEMORY, "Failed to malloc rt-host memory, size %ld", var_size_bytes);
  129. return OUT_OF_MEMORY;
  130. }
  131. ret = rtMemcpy(reinterpret_cast<void *>(var_host.get()), var_size_bytes, reinterpret_cast<void *>(var_addr),
  132. var_size_bytes, RT_MEMCPY_DEVICE_TO_HOST);
  133. if (ret != RT_ERROR_NONE) {
  134. REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, size:%ld, op:%s(%s), session_id:%lu, ret:0x%X when %s",
  135. var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id, ret, __FUNCTION__);
  136. GELOGE(RT_FAILED,
  137. "Failed to copy var memory from device, var %s, size %ld,"
  138. " rt-error-code %u",
  139. var->GetName().c_str(), var_size_bytes, ret);
  140. return RT_FAILED;
  141. }
  142. GELOGD("Copy var %s from device to host, size %ld", var->GetName().c_str(), var_size_bytes);
  143. var_data.swap(var_host);
  144. GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr);
  145. return SUCCESS;
  146. }
  147. Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats::TransResult &result) {
  148. formats::TransResult result_last_time{};
  149. bool use_init_data = true;
  150. for (const auto &trans_info : trans_road) {
  151. if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) {
  152. GELOGD("Skip to trans variable data on the reshape/reformat node");
  153. continue;
  154. }
  155. uint8_t *src_data = nullptr;
  156. if (use_init_data) {
  157. src_data = var_data;
  158. use_init_data = false;
  159. } else {
  160. src_data = result_last_time.data.get();
  161. }
  162. formats::TransResult tmp_result{};
  163. if (trans_info.node_type == TRANSDATA || trans_info.node_type == TRANSPOSED) {
  164. auto src_format = trans_info.input.GetFormat();
  165. auto src_shape = trans_info.input.GetShape().GetDims();
  166. auto dst_format = trans_info.output.GetFormat();
  167. auto dst_shape = trans_info.output.GetShape().GetDims();
  168. auto data_type = trans_info.input.GetDataType();
  169. GELOGD("Trans format from %s to %s, shape %s to %s, data-type %s",
  170. TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
  171. formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
  172. TypeUtils::DataTypeToSerialString(data_type).c_str());
  173. auto ret = formats::TransFormat({src_data, src_format, dst_format, src_shape, dst_shape, data_type}, tmp_result);
  174. if (ret != SUCCESS) {
  175. REPORT_CALL_ERROR("E19999", "Trans format from %s to %s, shape %s to %s failed, data type:%s, ret:%u, when %s",
  176. TypeUtils::FormatToSerialString(src_format).c_str(),
  177. TypeUtils::FormatToSerialString(dst_format).c_str(),
  178. formats::ShapeToString(src_shape).c_str(),
  179. formats::ShapeToString(dst_shape).c_str(),
  180. TypeUtils::DataTypeToSerialString(data_type).c_str(), ret, __FUNCTION__);
  181. GELOGE(INTERNAL_ERROR,
  182. "Failed to trans format from %s to %s, shape %s to %s, "
  183. "data type %s error code %u",
  184. TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
  185. formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
  186. TypeUtils::DataTypeToSerialString(data_type).c_str(), ret);
  187. return ret;
  188. }
  189. } else if (trans_info.node_type == CAST) {
  190. auto input_shape = trans_info.input.GetShape();
  191. auto src_data_size = input_shape.GetShapeSize() == 0 ? 1 : input_shape.GetShapeSize();
  192. auto src_data_type = trans_info.input.GetDataType();
  193. auto dst_data_type = trans_info.output.GetDataType();
  194. GELOGD("Trans data type from %s to %s, input shape %s, data size %ld",
  195. TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
  196. TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
  197. src_data_size);
  198. auto ret = formats::TransDataType({src_data, static_cast<size_t>(src_data_size), src_data_type, dst_data_type},
  199. tmp_result);
  200. if (ret != SUCCESS) {
  201. REPORT_CALL_ERROR("E19999", "Trans data type from %s to %s failed, input shape %s, data size %ld, ret:%u, "
  202. "when %s", TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
  203. TypeUtils::DataTypeToSerialString(dst_data_type).c_str(),
  204. formats::ShapeToString(input_shape).c_str(), src_data_size, ret, __FUNCTION__);
  205. GELOGE(INTERNAL_ERROR, "Failed to trans data type from %s to %s, input shape %s, data size %ld, error code %u",
  206. TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
  207. TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
  208. src_data_size, ret);
  209. return ret;
  210. }
  211. } else {
  212. REPORT_INNER_ERROR("E19999", "Trans var data failed, the trans type %s does not supported, check invalid when %s",
  213. trans_info.node_type.c_str(), __FUNCTION__);
  214. GELOGE(UNSUPPORTED, "Failed to trans var data, the trans type %s does not supported",
  215. trans_info.node_type.c_str());
  216. return UNSUPPORTED;
  217. }
  218. result_last_time = tmp_result;
  219. }
  220. result = result_last_time;
  221. return SUCCESS;
  222. }
  223. /// re-alloc var memory on device using var-manager
  224. /// free origin var memory(var manager does not support now)
  225. /// @param session_id
  226. /// @param var
  227. /// @param var_size_bytes
  228. /// @param var_device
  229. /// @return
  230. Status ReAssignVarAddr(uint64_t session_id,
  231. const std::string &var_name,
  232. const GeTensorDesc &tensor_desc,
  233. void **var_device) {
  234. uint8_t *var_logic = nullptr;
  235. Status ret = VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &var_logic);
  236. if (ret != SUCCESS) {
  237. GELOGE(INTERNAL_ERROR,
  238. "Failed to get var %s device addr, can not find it"
  239. " from var manager %u",
  240. var_name.c_str(), ret);
  241. return INTERNAL_ERROR;
  242. }
  243. uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
  244. if (var_addr == nullptr) {
  245. REPORT_CALL_ERROR("E19999", "Get variable memory addr failed, mem_type:%d, var_name:%s, session_id:%lu, when %s",
  246. RT_MEMORY_HBM, var_name.c_str(), session_id, __FUNCTION__);
  247. GELOGE(INTERNAL_ERROR, "Failed to convert var %s logic addr to real addr", var_name.c_str());
  248. return INTERNAL_ERROR;
  249. }
  250. *var_device = var_addr;
  251. GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr);
  252. return SUCCESS;
  253. }
  254. Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t session_id) {
  255. // do not need to do anything if only all reshape/reformat node on the trans_road
  256. GE_CHECK_NOTNULL(var);
  257. bool need_trans = false;
  258. for (auto &road : trans_road) {
  259. if (road.node_type != RESHAPE && road.node_type != REFORMAT) {
  260. need_trans = true;
  261. break;
  262. }
  263. }
  264. if (!need_trans) {
  265. return SUCCESS;
  266. }
  267. // Sync var data from device
  268. std::unique_ptr<uint8_t[]> var_data;
  269. if (trans_road.empty()) {
  270. REPORT_INNER_ERROR("E19999", "Param trans_road is empty, session_id:%lu, check invalid when %s",
  271. session_id, __FUNCTION__);
  272. GELOGE(INTERNAL_ERROR, "Failed to get trans_road, trans_road is empty.");
  273. return INTERNAL_ERROR;
  274. }
  275. const GeTensorDesc &input_desc = trans_road.begin()->input;
  276. auto ret = CopyVarFromDevice(session_id, var, var_data, input_desc);
  277. if (ret != SUCCESS) {
  278. return ret;
  279. }
  280. formats::TransResult trans_result{};
  281. ret = TransVarOnHost(var_data.get(), trans_road, trans_result);
  282. if (ret != SUCCESS) {
  283. GELOGE(ret, "Failed to trans var data on host, error code %u", ret);
  284. return ret;
  285. }
  286. void *var_device = nullptr;
  287. /// It is a temporary solution to use the last GeTensorDesc to assign variable memory because the variable manager
  288. /// depends on TensorDesc and it is difficult to be modified. The correct solution is to assign memory based on the
  289. /// size of the converted variable. To complete the final solution, the dependency of the variable manager on
  290. /// TensorDesc needs to be removed. This change is large and needs to be performed step by step.
  291. ret = ReAssignVarAddr(session_id, var->GetName(), trans_road.rbegin()->output, &var_device);
  292. if (ret != SUCCESS) {
  293. GELOGE(ret, "Failed to re-assign memory on device, size %zu", trans_result.length);
  294. return ret;
  295. }
  296. // sync new data to device
  297. ret = CopyVarToDevice(var, trans_result, var_device);
  298. if (ret != SUCCESS) {
  299. GELOGE(ret, "Failed to send var data to device");
  300. return ret;
  301. }
  302. return SUCCESS;
  303. }
  304. Status TransTensor(uint8_t *var_data, const NodePtr &var_src, const NodePtr &var_dst, formats::TransResult &result) {
  305. GE_CHECK_NOTNULL(var_src);
  306. GE_CHECK_NOTNULL(var_src->GetOpDesc());
  307. GE_CHECK_NOTNULL(var_dst);
  308. GE_CHECK_NOTNULL(var_dst->GetOpDesc());
  309. auto src_data_shape_size = var_src->GetOpDesc()->GetOutputDesc(0).GetShape().GetShapeSize();
  310. auto src_data_datatype = var_src->GetOpDesc()->GetOutputDesc(0).GetDataType();
  311. auto dst_data_datatype = var_dst->GetOpDesc()->GetOutputDesc(0).GetDataType();
  312. GE_IF_BOOL_EXEC(
  313. src_data_datatype != dst_data_datatype,
  314. auto ret = formats::TransDataType(
  315. {var_data, static_cast<size_t>(src_data_shape_size), src_data_datatype, dst_data_datatype}, result);
  316. if (ret != SUCCESS) {
  317. REPORT_CALL_ERROR("E19999", "Trans data type from %s to %s failed, data size %ld, ret:%u, "
  318. "when %s", TypeUtils::DataTypeToSerialString(src_data_datatype).c_str(),
  319. TypeUtils::DataTypeToSerialString(dst_data_datatype).c_str(),
  320. src_data_shape_size, ret, __FUNCTION__);
  321. GELOGE(INTERNAL_ERROR, "trans var data on host failed");
  322. return ret;
  323. });
  324. return SUCCESS;
  325. }
  326. Status CopyTensorFromSrcVarNode(const NodePtr &var_src,
  327. const NodePtr &var_dst,
  328. uint64_t session_id,
  329. uint32_t device_id) {
  330. /// after FE fusion pass, input num of applymomentum op was changed, 0th input is var_fp32, 6th input is
  331. /// var_fp16(new).
  332. /// unlink edges between var_fp32 and "dst_node" (need fp16) of var_fp32, add edge between var_fp16 and dst_node.
  333. /// need copy value from var_fp32 to var_fp16.
  334. /// [opdesc of var_src and var_dst are checked before passed in, no need to check if they are nullptr]
  335. GE_IF_BOOL_EXEC(var_src == nullptr || var_dst == nullptr,
  336. REPORT_INNER_ERROR("E19999", "Param var_src or var_dst is empty, session_id:%lu, device_id:%u, "
  337. "check invalid when %s", session_id, device_id, __FUNCTION__);
  338. GELOGE(FAILED, "node var is nullptr"); return FAILED);
  339. // src_node output_desc (fp32)
  340. GeTensorDesc output_desc = var_src->GetOpDesc()->GetOutputDesc(0);
  341. auto src_data_type = output_desc.GetDataType();
  342. auto src_shape = output_desc.GetShape();
  343. auto src_format = output_desc.GetFormat();
  344. GELOGI("src_node %s, src_format %s, src_shape %s, src_type %s", var_src->GetName().c_str(),
  345. TypeUtils::FormatToSerialString(src_format).c_str(), formats::ShapeToString(src_shape).c_str(),
  346. TypeUtils::DataTypeToSerialString(src_data_type).c_str());
  347. // dst_node output_desc (fp16)
  348. GeTensorDesc dst_tensor_desc = var_dst->GetOpDesc()->GetOutputDesc(0);
  349. auto data_type = dst_tensor_desc.GetDataType();
  350. auto data_shape = dst_tensor_desc.GetShape();
  351. auto data_format = dst_tensor_desc.GetFormat();
  352. GELOGI("dst_node %s, src_format %s, src_shape %s, src_type %s", var_dst->GetName().c_str(),
  353. TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(),
  354. TypeUtils::DataTypeToSerialString(data_type).c_str());
  355. // Sync var data from device
  356. std::unique_ptr<uint8_t[]> var_src_data;
  357. RtContextSwitchGuard switch_context(RT_CTX_NORMAL_MODE, device_id);
  358. // copy from src_node
  359. auto ret = CopyVarFromDevice(session_id, var_src, var_src_data, output_desc);
  360. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "Copy Var From Device failed"); return ret);
  361. // trans dtype
  362. formats::TransResult trans_result{};
  363. ret = TransTensor(var_src_data.get(), var_src, var_dst, trans_result);
  364. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "trans var data on host failed"); return ret);
  365. // reset src value.
  366. void *var_device = nullptr;
  367. ret = ReAssignVarAddr(session_id, var_dst->GetName(), dst_tensor_desc, &var_device);
  368. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "assign mem failed"); return ret);
  369. // copy to device
  370. ret = CopyVarToDevice(var_dst, trans_result, var_device);
  371. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Failed to send var data to device"); return ret);
  372. return SUCCESS;
  373. }
  374. } // namespace
  /// Copies variable `var_name` from its device (HBM) location into the device buffer `dst_addr`,
  /// staging the bytes through a host buffer filled by SyncTensorToHost.
  ///
  /// @param var_name        variable to read
  /// @param src_tensor_desc tensor description used by SyncTensorToHost to size and locate the variable
  /// @param dst_addr        destination device buffer; must not be null
  /// @param dst_addr_size   size of `dst_addr` in bytes; must equal the variable's size
  /// @param session_id      session that owns the variable
  /// @return SUCCESS; FAILED on a null destination or a size mismatch; otherwise the error status
  ///         produced (via the GE_CHK_* macros) by a failing SyncTensorToHost or rtMemcpy
  375. Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
  376. uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) {
  377. GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "dst addr is null. ");
  378. uint8_t *src_host_addr = nullptr;
  379. int64_t src_addr_size = 0;
  // NOTE(review): presumably releases src_host_addr (host memory obtained with rtMallocHost inside
  // SyncTensorToHost) on every exit path; it must be declared before the buffer is filled.
  380. GE_MAKE_GUARD_RTMEM(src_host_addr);
  381. GE_CHK_STATUS_RET(SyncTensorToHost(var_name, src_tensor_desc, &src_host_addr, src_addr_size, session_id));
  382. GELOGI("src_addr_size: %ld, dst_addr_size: %ld", src_addr_size, dst_addr_size);
  // The variable and the broadcast buffer must have exactly the same byte size.
  383. GE_CHK_BOOL_RET_STATUS(src_addr_size == dst_addr_size, FAILED, "var data size is not equal broadcast ");
  384. GE_CHK_RT_RET(rtMemcpy(dst_addr, dst_addr_size, src_host_addr, src_addr_size, RT_MEMCPY_HOST_TO_DEVICE));
  385. return SUCCESS;
  386. }
  387. Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name,
  388. const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) {
  389. GE_CHK_BOOL_RET_STATUS(src_addr != nullptr, FAILED, "src addr is null. ");
  390. uint8_t *host_addr = nullptr;
  391. GE_MAKE_GUARD_RTMEM(host_addr);
  392. GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(&host_addr), src_addr_size));
  393. GE_CHK_RT_RET(rtMemcpy(host_addr, src_addr_size, src_addr, src_addr_size, RT_MEMCPY_DEVICE_TO_HOST));
  394. GE_CHK_STATUS_RET(
  395. SyncTensorToDevice(var_name, reinterpret_cast<uint8_t *>(host_addr), src_addr_size, dst_tensor_desc, session_id));
  396. return SUCCESS;
  397. }
  398. Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
  399. uint8_t **host_addr, int64_t &src_tensor_size, uint64_t session_id) {
  400. GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(src_tensor_desc, src_tensor_size), "get size from TensorDesc failed");
  401. uint8_t *src_addr = nullptr;
  402. GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr));
  403. uint8_t *mem_addr =
  404. src_addr -
  405. static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) +
  406. static_cast<int64_t>(
  407. reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM)));
  408. GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(host_addr), src_tensor_size));
  409. GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST));
  410. GELOGI("SyncTensorToHost var_name %s, src_tensor_size %ld", var_name.c_str(), src_tensor_size);
  411. return SUCCESS;
  412. }
  413. Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size,
  414. const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) {
  415. uint8_t *dst_addr = nullptr;
  416. GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr));
  417. uint8_t *mem_addr =
  418. dst_addr -
  419. static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) +
  420. static_cast<int64_t>(
  421. reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM)));
  422. GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE));
  423. GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size);
  424. return SUCCESS;
  425. }
  426. Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes,
  427. uint64_t session_id,
  428. rtContext_t context,
  429. uint32_t graph_id,
  430. uint32_t thread_num) {
  431. ThreadPool executor(thread_num);
  432. std::vector<std::future<Status>> vector_future;
  433. for (auto &node : variable_nodes) {
  434. if (node == nullptr) {
  435. continue;
  436. }
  437. if (node->GetType() != VARIABLE) {
  438. continue;
  439. }
  440. std::future<Status> f = executor.commit(
  441. [](const ge::NodePtr &node, uint64_t session_id, rtContext_t ctx, uint32_t graph_id,
  442. const struct ErrorMessage::Context &error_context) -> Status {
  443. ErrorManager::GetInstance().SetErrorContext(error_context);
  444. rtError_t rt_ret = rtCtxSetCurrent(ctx);
  445. if (rt_ret != RT_ERROR_NONE) {
  446. REPORT_CALL_ERROR("E19999", "Call rtCtxSetCurrent failed, session_id:%lu, graph_id:%u, ret:0x%X, when %s",
  447. session_id, graph_id, rt_ret, __FUNCTION__);
  448. GELOGE(RT_FAILED, "Failed to set context, error_code is: 0x%X.", rt_ret);
  449. return RT_ERROR_TO_GE_STATUS(rt_ret);
  450. }
  451. uint32_t allocated_graph_id = 0;
  452. Status ret = VarManager::Instance(session_id)->GetAllocatedGraphId(node->GetName(), allocated_graph_id);
  453. if (ret != SUCCESS) {
  454. REPORT_CALL_ERROR("E19999", "Get allocated GraphId failed, session_id:%lu, graph_id:%u, ret:0x%X, when %s",
  455. session_id, graph_id, ret, __FUNCTION__);
  456. GELOGE(INTERNAL_ERROR, "var has not been allocated, node:%s, graph_id:%u.", node->GetName().c_str(),
  457. graph_id);
  458. return INTERNAL_ERROR;
  459. }
  460. uint32_t changed_graph_id = 0;
  461. ret = VarManager::Instance(session_id)->GetChangedGraphId(node->GetName(), changed_graph_id);
  462. bool call_trans_var =
  463. (ret == SUCCESS && changed_graph_id == graph_id && changed_graph_id != allocated_graph_id);
  464. if (call_trans_var) {
  465. GELOGI("VarManager::GetChangedGraphId() success, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id);
  466. VarTransRoad *trans_road = VarManager::Instance(session_id)->GetTransRoad(node->GetName());
  467. if (trans_road == nullptr) {
  468. GELOGI("The variable %s does not have any trans road", node->GetName().c_str());
  469. return SUCCESS;
  470. }
  471. ret = TransVarData(node, *trans_road, session_id);
  472. if (ret != SUCCESS) {
  473. GELOGE(INTERNAL_ERROR, "TransVarData failed, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id);
  474. return INTERNAL_ERROR;
  475. }
  476. VarManager::Instance(session_id)->RemoveChangedGraphId(node->GetName());
  477. }
  478. return SUCCESS;
  479. },
  480. node, session_id, context, graph_id, ErrorManager::GetInstance().GetErrorContext());
  481. if (!f.valid()) {
  482. GELOGE(FAILED, "Future is invalid");
  483. return FAILED;
  484. }
  485. vector_future.push_back(std::move(f));
  486. }
  487. Status ret_status;
  488. for (size_t i = 0; i < vector_future.size(); ++i) {
  489. ret_status = vector_future[i].get();
  490. if (ret_status != SUCCESS) {
  491. GELOGE(ret_status, "TransAllVarData:: trans %zu vardata failed", i);
  492. return ret_status;
  493. }
  494. }
  495. return SUCCESS;
  496. }
  497. Status TransVarDataUtils::CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id) {
  498. GELOGD("CopyVarData start: session_id:%lu.", session_id);
  499. if (compute_graph == nullptr) {
  500. REPORT_INNER_ERROR("E19999", "Param compute_graph is nullptr, session_id:%lu, device_id:%u, check invalid when %s",
  501. session_id, device_id, __FUNCTION__);
  502. GELOGE(FAILED, "compute_graph is nullptr");
  503. return FAILED;
  504. }
  505. string cp_from_node;
  506. bool copy_value = false;
  507. for (auto &node : compute_graph->GetAllNodes()) {
  508. GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() != VARIABLE, continue);
  509. GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), "_copy_from_var_node", cp_from_node),
  510. GELOGI("Get original type of cp_from_node"));
  511. if (cp_from_node.length() != 0) {
  512. (void) ge::AttrUtils::GetBool(node->GetOpDesc(), "_copy_value", copy_value); // no need to check value
  513. if (!copy_value) {
  514. auto src_node = compute_graph->FindNode(cp_from_node);
  515. GE_CHECK_NOTNULL(src_node);
  516. GELOGI("current_var_node__: [%s] copy_from_var_node__: [%s].", node->GetName().c_str(),
  517. src_node->GetName().c_str());
  518. auto ret = CopyTensorFromSrcVarNode(src_node, node, session_id, device_id);
  519. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "copy tensor failed!"); return FAILED);
  520. // only copy once
  521. (void) ge::AttrUtils::SetBool(node->GetOpDesc(), "_copy_value", true); // no need to check value
  522. }
  523. }
  524. }
  525. return SUCCESS;
  526. }
  527. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：