You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

aicpu_constant_folding_pass.cc 23 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/aicpu_constant_folding_pass.h"
  17. #include <memory>
  18. #include <vector>
  19. #include "common/debug/log.h"
  20. #include "common/ge/ge_util.h"
  21. #include "common/types.h"
  22. #include "framework/common/debug/ge_log.h"
  23. #include "graph/debug/ge_attr_define.h"
  24. #include "graph/utils/attr_utils.h"
  25. #include "graph/utils/node_utils.h"
  26. #include "graph/utils/op_desc_utils.h"
  27. #include "graph/utils/type_utils.h"
  28. #include "init/gelib.h"
namespace {
// Name of the ops kernel info store used for aicpu constant folding.
const char *const kKernelLibName = "aicpu_tf_kernel";
// opsFlagCheck result prefix that marks an op as not supported for folding.
const char *const kNotSupported = "0";
// Value of DataPtrInfo::release_flag — presumably tells the aicpu memcpy task
// to release the source buffer; confirm against the aicpu adapter docs.
const uint64_t kReleaseFlag = 1;
// Number of leading characters of check_result inspected by IsSkipFold.
const uint64_t kOpsFlag = 1;
// Each node output yields two DataPtrInfo entries (raw data + shape data).
const uint64_t kDouble = 2;
}  // namespace
  36. namespace ge {
  37. Status AicpuConstantFoldingPass::Run(ge::NodePtr &node) {
  38. GE_CHECK_NOTNULL(node);
  39. GELOGD("Start aicpu constant folding on node [%s]", node->GetName().c_str());
  40. if (IsSkipFold(node)) {
  41. return SUCCESS;
  42. }
  43. vector<ConstGeTensorPtr> weight_vec;
  44. bool flag = CheckInput(node, weight_vec);
  45. if (!flag) {
  46. return SUCCESS;
  47. }
  48. OpDescPtr node_desc = node->GetOpDesc(); // checked before
  49. vector<DataPtrInfo> data_vec;
  50. vector<AddrAndType> input_addrs;
  51. vector<uint64_t> output_addrs;
  52. Status ret = GetInputAddrs(weight_vec, input_addrs);
  53. if (ret != SUCCESS) {
  54. ReleaseMemory(input_addrs, output_addrs, data_vec);
  55. return SUCCESS;
  56. }
  57. ret = GetOutputAddrs(node_desc, output_addrs);
  58. if (ret != SUCCESS) {
  59. ReleaseMemory(input_addrs, output_addrs, data_vec);
  60. return SUCCESS;
  61. }
  62. ret = LaunchSingleOpRunTask(node, input_addrs, output_addrs);
  63. if (ret != SUCCESS) {
  64. ReleaseMemory(input_addrs, output_addrs, data_vec);
  65. return SUCCESS;
  66. }
  67. GELOGI("[Node:%s] Launch singleOpRunTask success", node->GetName().c_str());
  68. vector<uint64_t> data_infos;
  69. ret = GenerateDataPtrInfo(output_addrs, data_vec, data_infos);
  70. if (ret != SUCCESS) {
  71. ReleaseMemory(input_addrs, output_addrs, data_vec);
  72. return SUCCESS;
  73. }
  74. GELOGI("[Node:%s] Generate dataPtrInfo success", node->GetName().c_str());
  75. ret = LaunchMemCopyTask(data_infos);
  76. if (ret != SUCCESS) {
  77. ReleaseMemory(input_addrs, output_addrs, data_vec);
  78. return SUCCESS;
  79. }
  80. GELOGI("[Node:%s] Launch memCopyTask success", node->GetName().c_str());
  81. vector<GeTensorPtr> outputs;
  82. ret = GenerateGeTensor(node_desc, data_vec, outputs);
  83. if (ret != SUCCESS) {
  84. ReleaseMemory(input_addrs, output_addrs, data_vec);
  85. return SUCCESS;
  86. }
  87. ReleaseMemory(input_addrs, output_addrs, data_vec);
  88. GELOGI("[Node:%s] Generate geTensor success", node->GetName().c_str());
  89. return Folding(node, outputs);
  90. }
  91. bool AicpuConstantFoldingPass::CheckInput(const NodePtr &node, vector<ConstGeTensorPtr> &weight_vec) {
  92. OpDescPtr node_desc = node->GetOpDesc();
  93. if (node_desc == nullptr) {
  94. GELOGW("Opdesc of %s is null", node->GetName().c_str());
  95. return false;
  96. }
  97. DataType data_type = node_desc->GetOutputDesc(0).GetDataType();
  98. Format format = node_desc->GetOutputDesc(0).GetFormat();
  99. GELOGD("Current [node:%s, type:%s] info: format: %s, datatype:%s", node->GetName().c_str(), node->GetType().c_str(),
  100. TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str());
  101. auto input_nodes = OpDescUtils::GetConstInputNode(*node);
  102. if (input_nodes.empty() || input_nodes.size() != node_desc->GetInputsSize()) {
  103. GELOGD("Const input nodes size is %zu, and nodeDesc inputsSize is %zu, skip fold.", input_nodes.size(),
  104. node_desc->GetInputsSize());
  105. return false;
  106. }
  107. weight_vec = OpDescUtils::GetInputData(input_nodes);
  108. return true;
  109. }
  110. Status AicpuConstantFoldingPass::GetInputAddrs(const vector<ConstGeTensorPtr> &weight_vec,
  111. vector<AddrAndType> &input_addrs) {
  112. if (weight_vec.empty()) {
  113. GELOGE(FAILED, "Weight is null");
  114. return FAILED;
  115. }
  116. for (const ConstGeTensorPtr &weight : weight_vec) {
  117. void *input_addr = nullptr;
  118. GE_CHK_RT_RET(rtMalloc(&input_addr, weight->GetData().size(), RT_MEMORY_HBM));
  119. rtError_t rt_ret = rtMemcpy(input_addr, weight->GetData().size(), weight->GetData().data(),
  120. weight->GetData().size(), RT_MEMCPY_HOST_TO_DEVICE);
  121. if (rt_ret != RT_ERROR_NONE) {
  122. GELOGE(rt_ret, "rtMemcpy error");
  123. GE_CHK_RT(rtFree(input_addr));
  124. return FAILED;
  125. }
  126. AddrAndType input_info = {static_cast<uint64_t>(reinterpret_cast<uintptr_t>(input_addr)), kData};
  127. input_addrs.emplace_back(input_info);
  128. }
  129. return SUCCESS;
  130. }
  131. Status AicpuConstantFoldingPass::GetOutputAddrs(const OpDescPtr &node_desc, vector<uint64_t> &output_addrs) {
  132. if (node_desc->GetOutputsSize() == 0) {
  133. GELOGE(FAILED, "Output size is 0 ");
  134. return FAILED;
  135. }
  136. for (size_t i = 0; i < node_desc->GetOutputsSize(); ++i) {
  137. void *summary_addr = nullptr;
  138. GE_CHK_RT_RET(rtMalloc(&summary_addr, sizeof(aicpu::FWKAdapter::ResultSummary), RT_MEMORY_HBM));
  139. output_addrs.emplace_back(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(summary_addr)));
  140. }
  141. return SUCCESS;
  142. }
Status AicpuConstantFoldingPass::GenerateDataPtrInfo(const vector<uint64_t> &output_addrs,
                                                     vector<DataPtrInfo> &data_vec, vector<uint64_t> &data_infos) {
  // For each device-side ResultSummary: read it back to host, allocate device
  // destination buffers sized from it, and append two DataPtrInfo entries to
  // data_vec — raw data first, then shape data (matching kDouble indexing in
  // GenerateGeTensor).
  for (uint64_t output_addr : output_addrs) {
    aicpu::FWKAdapter::ResultSummary result_summary;
    GE_CHK_RT_RET(rtMemcpy(&result_summary, sizeof(aicpu::FWKAdapter::ResultSummary),
                           reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(output_addr)),
                           sizeof(aicpu::FWKAdapter::ResultSummary), RT_MEMCPY_DEVICE_TO_HOST));
    void *raw_data_addr = nullptr;
    GE_CHK_RT_RET(rtMalloc(&raw_data_addr, result_summary.raw_data_size, RT_MEMORY_HBM));
    void *shape_data_addr = nullptr;
    // shape_data_size = 0 means scalar — no shape buffer is allocated and the
    // shape entry keeps dst_ptr == nullptr (checked later in ReleaseMemory).
    if (result_summary.shape_data_size != 0) {
      rtError_t rt_ret = rtMalloc(&shape_data_addr, result_summary.shape_data_size, RT_MEMORY_HBM);
      if (rt_ret != RT_ERROR_NONE) {
        GELOGE(rt_ret, "rtMalloc error");
        GE_CHK_RT(rtFree(raw_data_addr));
        return FAILED;
      }
    }
    DataPtrInfo raw_data_info;
    raw_data_info.release_flag = kReleaseFlag;
    raw_data_info.data_size = result_summary.raw_data_size;
    raw_data_info.src_ptr = result_summary.raw_data_ptr;
    raw_data_info.dst_ptr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(raw_data_addr));
    data_vec.emplace_back(raw_data_info);
    DataPtrInfo shape_data_info;
    shape_data_info.release_flag = kReleaseFlag;
    shape_data_info.data_size = result_summary.shape_data_size;
    shape_data_info.src_ptr = result_summary.shape_data_ptr;
    shape_data_info.dst_ptr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(shape_data_addr));
    data_vec.emplace_back(shape_data_info);
  }
  // NOTE: data_infos holds HOST addresses of the elements of data_vec. They
  // stay valid only while data_vec is alive and unmodified — callers must not
  // grow or shrink data_vec while data_infos is in use (Run keeps data_vec
  // alive until ReleaseMemory). Taken only after data_vec is fully built, so
  // vector reallocation cannot invalidate them here.
  for (const DataPtrInfo &data_info : data_vec) {
    data_infos.emplace_back(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(&data_info)));
  }
  return SUCCESS;
}
  180. Status AicpuConstantFoldingPass::UpdateWorkSpaceAddr(string &task_info, STR_FWK_OP_KERNEL &task) {
  181. // Update the workspace_addr
  182. if (task_info.empty()) {
  183. GELOGE(FAILED, "task_info is empty ");
  184. return FAILED;
  185. }
  186. void *workspace_addr = nullptr;
  187. GE_CHK_RT_RET(rtMalloc(&workspace_addr, task_info.size(), RT_MEMORY_HBM));
  188. rtError_t rt_ret =
  189. rtMemcpy(workspace_addr, task_info.size(), task_info.data(), task_info.size(), RT_MEMCPY_HOST_TO_DEVICE);
  190. if (rt_ret != RT_ERROR_NONE) {
  191. GELOGE(rt_ret, "rtMemcpy error");
  192. GE_CHK_RT(rtFree(workspace_addr));
  193. return FAILED;
  194. }
  195. uint64_t workspace_base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(workspace_addr));
  196. task.fwkKernelBase.fwk_kernel.workspaceBaseAddr = workspace_base_addr;
  197. return SUCCESS;
  198. }
  199. Status AicpuConstantFoldingPass::UpdateInputAndOutputAddr(const vector<uint64_t> &io_addrs, STR_FWK_OP_KERNEL &task) {
  200. auto addrs_size = sizeof(uint64_t) * (io_addrs.size());
  201. if (addrs_size <= 0) {
  202. GELOGE(FAILED, "addrs_size is less than 1 ");
  203. return FAILED;
  204. }
  205. void *input_output_addr = nullptr;
  206. GE_CHK_RT_RET(rtMalloc(&input_output_addr, addrs_size, RT_MEMORY_HBM));
  207. rtError_t rt_ret = rtMemcpy(input_output_addr, addrs_size, io_addrs.data(), addrs_size, RT_MEMCPY_HOST_TO_DEVICE);
  208. if (rt_ret != RT_ERROR_NONE) {
  209. GELOGE(rt_ret, "rtMemcpy error");
  210. GE_CHK_RT(rtFree(input_output_addr));
  211. return FAILED;
  212. }
  213. uint64_t in_out_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(input_output_addr));
  214. task.fwkKernelBase.fwk_kernel.inputOutputAddr = in_out_addr;
  215. return SUCCESS;
  216. }
  217. Status AicpuConstantFoldingPass::UpdateSingleOpAddr(string &task_info, const vector<AddrAndType> &input_addrs,
  218. const vector<uint64_t> &outputs_addr_vec, STR_FWK_OP_KERNEL &task) {
  219. // Build the SingleOpAddr
  220. vector<uint64_t> inputs_addr_vec;
  221. for (const auto &item : input_addrs) {
  222. inputs_addr_vec.push_back(item.input_addr);
  223. }
  224. vector<uint64_t> io_addrs;
  225. io_addrs.insert(io_addrs.end(), inputs_addr_vec.begin(), inputs_addr_vec.end());
  226. io_addrs.insert(io_addrs.end(), outputs_addr_vec.begin(), outputs_addr_vec.end());
  227. Status ret = UpdateInputAndOutputAddr(io_addrs, task);
  228. if (ret != SUCCESS) {
  229. GELOGE(ret, "UpdateInputAndOutputAddr error");
  230. return ret;
  231. }
  232. ret = UpdateWorkSpaceAddr(task_info, task);
  233. if (ret != SUCCESS) {
  234. GELOGE(ret, "UpdateWorkSpaceAddr error");
  235. return ret;
  236. }
  237. return SUCCESS;
  238. }
  239. Status AicpuConstantFoldingPass::UpdateMemCopyAddr(string &task_info, const vector<uint64_t> &data_infos,
  240. vector<uint64_t> &internal_addrs, STR_FWK_OP_KERNEL &task) {
  241. vector<uint64_t> release_flags;
  242. vector<uint64_t> data_sizes;
  243. vector<uint64_t> src_addrs;
  244. vector<uint64_t> dst_addrs;
  245. for (auto item : data_infos) {
  246. auto *data_info_ptr = reinterpret_cast<DataPtrInfo *>(reinterpret_cast<uintptr_t>(item)); // pointer cannot be null
  247. release_flags.push_back(data_info_ptr->release_flag);
  248. data_sizes.push_back(data_info_ptr->data_size);
  249. src_addrs.push_back(data_info_ptr->src_ptr);
  250. dst_addrs.push_back(data_info_ptr->dst_ptr);
  251. }
  252. vector<vector<uint64_t>> inputs = {release_flags, data_sizes, src_addrs, dst_addrs};
  253. auto data_size = sizeof(uint64_t) * (data_infos.size());
  254. vector<uint64_t> io_addrs;
  255. if (data_infos.size() > 0) {
  256. for (const auto &item : inputs) {
  257. void *input_addr_ptr = nullptr;
  258. GE_CHK_RT_RET(rtMalloc(&input_addr_ptr, data_size, RT_MEMORY_HBM));
  259. rtError_t rt_ret = rtMemcpy(input_addr_ptr, data_size, item.data(), data_size, RT_MEMCPY_HOST_TO_DEVICE);
  260. if (rt_ret != RT_ERROR_NONE) {
  261. GELOGE(rt_ret, "rtMemcpy error");
  262. GE_CHK_RT(rtFree(input_addr_ptr));
  263. return FAILED;
  264. }
  265. uint64_t input_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(input_addr_ptr));
  266. io_addrs.push_back(input_addr);
  267. }
  268. }
  269. internal_addrs = io_addrs;
  270. Status ret = UpdateInputAndOutputAddr(io_addrs, task);
  271. if (ret != SUCCESS) {
  272. GELOGE(ret, "UpdateInputAndOutputAddr error");
  273. return ret;
  274. }
  275. ret = UpdateWorkSpaceAddr(task_info, task);
  276. if (ret != SUCCESS) {
  277. GELOGE(ret, "UpdateWorkSpaceAddr error");
  278. return ret;
  279. }
  280. return SUCCESS;
  281. }
Status AicpuConstantFoldingPass::LaunchSingleOpRunTask(const NodePtr &node, const vector<AddrAndType> &input_addrs,
                                                       const vector<uint64_t> &output_addrs) {
  // Generate a single-op aicpu run task for `node`, wire its io/workspace
  // device buffers, copy the task struct to device memory, and execute it.
  void *task_buf = nullptr;
  auto instance_ptr = ge::GELib::GetInstance();
  if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
    GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GE is not initialized");
    return GE_CLI_GE_NOT_INITIALIZED;
  }
  OpsKernelInfoStorePtr kernel_info = instance_ptr->OpsKernelManagerObj().GetOpsKernelInfoStore(kKernelLibName);
  if (kernel_info == nullptr) {
    GELOGE(FAILED, "Get op kernel info store failed");
    return FAILED;
  }
  STR_FWK_OP_KERNEL aicpu_task;
  // Zero the device-address fields up front so the cleanup callback below can
  // tell whether each buffer was ever allocated.
  aicpu_task.fwkKernelBase.fwk_kernel.inputOutputAddr = 0;
  aicpu_task.fwkKernelBase.fwk_kernel.workspaceBaseAddr = 0;
  aicpu_task.fwkKernelBase.fwk_kernel.extInfoAddr = 0;
  aicpu_task.fwkKernelBase.fwk_kernel.extInfoLen = 0;
  std::string task_info;
  Status ret = kernel_info->GenSingleOpRunTask(node, aicpu_task, task_info);
  if (ret != SUCCESS) {
    return ret;
  }
  // Scope guard: frees the io/workspace device buffers on every exit path
  // below. Captures aicpu_task by reference so it sees the addresses filled in
  // later by UpdateSingleOpAddr. task_buf itself is freed inside KernelLaunch
  // (on its guard) or by GenerateTaskForLaunch on its own failure.
  std::function<void()> callback = [&]() {
    void *input_output_ptr =
        reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(aicpu_task.fwkKernelBase.fwk_kernel.inputOutputAddr));
    if (input_output_ptr != nullptr) {
      GE_CHK_RT(rtFree(input_output_ptr));
    }
    void *workspace_addr_ptr =
        reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(aicpu_task.fwkKernelBase.fwk_kernel.workspaceBaseAddr));
    if (workspace_addr_ptr != nullptr) {
      GE_CHK_RT(rtFree(workspace_addr_ptr));
    }
  };
  GE_MAKE_GUARD(release, callback);
  ret = UpdateSingleOpAddr(task_info, input_addrs, output_addrs, aicpu_task);
  if (ret != SUCCESS) {
    GELOGE(ret, "UpdateSingleOpAddr error");
    return ret;
  }
  ret = GenerateTaskForLaunch(aicpu_task, task_buf);
  if (ret != SUCCESS) {
    GELOGE(ret, "GenerateTaskForLaunch error");
    return ret;
  }
  ret = KernelLaunch(task_buf);
  if (ret != SUCCESS) {
    GELOGE(ret, "KernelLaunch error");
    return ret;
  }
  return SUCCESS;
}
Status AicpuConstantFoldingPass::LaunchMemCopyTask(const vector<uint64_t> &data_infos) {
  // Generate and execute the aicpu memcpy task that moves each output's raw
  // and shape data from the kernel-owned source buffers into the destination
  // buffers allocated by GenerateDataPtrInfo.
  void *task_buf = nullptr;
  auto instance_ptr = ge::GELib::GetInstance();
  if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
    GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GE is not initialized");
    return GE_CLI_GE_NOT_INITIALIZED;
  }
  OpsKernelInfoStorePtr kernel_info = instance_ptr->OpsKernelManagerObj().GetOpsKernelInfoStore(kKernelLibName);
  if (kernel_info == nullptr) {
    GELOGE(FAILED, "Get op kernel info store failed");
    return FAILED;
  }
  STR_FWK_OP_KERNEL aicpu_task;
  // Zero the device-address fields so the cleanup callback can distinguish
  // "never allocated" from a real buffer.
  aicpu_task.fwkKernelBase.fwk_kernel.inputOutputAddr = 0;
  aicpu_task.fwkKernelBase.fwk_kernel.workspaceBaseAddr = 0;
  aicpu_task.fwkKernelBase.fwk_kernel.extInfoAddr = 0;
  aicpu_task.fwkKernelBase.fwk_kernel.extInfoLen = 0;
  std::string task_info;
  Status ret = kernel_info->GenMemCopyTask(data_infos.size(), aicpu_task, task_info);
  if (ret != SUCCESS) {
    return ret;
  }
  // internal_addrs receives the device buffers UpdateMemCopyAddr allocates for
  // the four input arrays; captured by reference so the guard frees whatever
  // was recorded, on every exit path.
  vector<uint64_t> internal_addrs;
  std::function<void()> callback = [&]() {
    for (auto item : internal_addrs) {
      GE_CHK_RT(rtFree(reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(item))));  // pointer cannot be null
    }
    void *input_output_ptr =
        reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(aicpu_task.fwkKernelBase.fwk_kernel.inputOutputAddr));
    if (input_output_ptr != nullptr) {
      GE_CHK_RT(rtFree(input_output_ptr));
    }
    void *workspace_addr_ptr =
        reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(aicpu_task.fwkKernelBase.fwk_kernel.workspaceBaseAddr));
    if (workspace_addr_ptr != nullptr) {
      GE_CHK_RT(rtFree(workspace_addr_ptr));
    }
  };
  GE_MAKE_GUARD(release, callback);
  ret = UpdateMemCopyAddr(task_info, data_infos, internal_addrs, aicpu_task);
  if (ret != SUCCESS) {
    GELOGE(ret, "UpdateMemCopyAddr error");
    return ret;
  }
  ret = GenerateTaskForLaunch(aicpu_task, task_buf);
  if (ret != SUCCESS) {
    GELOGE(ret, "GenerateTaskForLaunch error");
    return ret;
  }
  ret = KernelLaunch(task_buf);
  if (ret != SUCCESS) {
    GELOGE(ret, "KernelLaunch error");
    return ret;
  }
  return SUCCESS;
}
  391. Status AicpuConstantFoldingPass::GenerateTaskForLaunch(STR_FWK_OP_KERNEL &aicpu_task, void *&task_buf) {
  392. GE_CHK_RT_RET(rtMalloc(&task_buf, sizeof(STR_FWK_OP_KERNEL), RT_MEMORY_HBM));
  393. rtError_t rt_ret = rtMemcpy(task_buf, sizeof(STR_FWK_OP_KERNEL), reinterpret_cast<void *>(&aicpu_task),
  394. sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_DEVICE);
  395. if (rt_ret != RT_ERROR_NONE) {
  396. GELOGE(rt_ret, "rtMemcpy error");
  397. GE_CHK_RT(rtFree(task_buf));
  398. return FAILED;
  399. }
  400. return SUCCESS;
  401. }
Status AicpuConstantFoldingPass::KernelLaunch(void *task_buf) {
  // Execute the device-resident task synchronously: create a model, bind a
  // stream, queue the kernel, load the model, then run it on a second stream
  // and wait for completion. The rt call sequence is order-dependent.
  rtModel_t model = nullptr;
  rtStream_t stream = nullptr;
  rtStream_t stream_run = nullptr;
  // Scope guard: releases task_buf, the model and both streams on every exit
  // path, whichever of them were created before a failure.
  std::function<void()> callback = [&]() {
    if (task_buf != nullptr) {
      GE_CHK_RT(rtFree(task_buf));
    }
    if (model != nullptr) {
      GE_CHK_RT(rtModelDestroy(model));
    }
    if (stream != nullptr) {
      GE_CHK_RT(rtStreamDestroy(stream));
    }
    if (stream_run != nullptr) {
      GE_CHK_RT(rtStreamDestroy(stream_run));
    }
  };
  GE_MAKE_GUARD(release, callback);
  rtError_t rt_ret = rtModelCreate(&model, 0);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(rt_ret, "create model failed.");
    return FAILED;
  }
  rt_ret = rtStreamCreate(&stream, 0);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(rt_ret, "create stream failed.");
    return FAILED;
  }
  rt_ret = rtModelBindStream(model, stream, 0);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(rt_ret, "rtModelBindStream failed.");
    return FAILED;
  }
  // Queue the task on the bound stream before completing the model load.
  rt_ret = rtKernelLaunchEx(task_buf, sizeof(STR_FWK_OP_KERNEL), 0, stream);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(rt_ret, "rtKernelLaunchEx failed.");
    return FAILED;
  }
  rt_ret = rtModelLoadComplete(model);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(rt_ret, "rtModelLoadComplete failed.");
    return FAILED;
  }
  // Execution uses a separate stream from the one bound to the model.
  rt_ret = rtStreamCreate(&stream_run, 0);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(rt_ret, "create run stream failed.");
    return FAILED;
  }
  rt_ret = rtModelExecute(model, stream_run, 0);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(rt_ret, "rtModelExecute failed.");
    return FAILED;
  }
  // Block until the model run finishes so outputs are ready to read back.
  rt_ret = rtStreamSynchronize(stream_run);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(rt_ret, "rtStreamSynchronize failed.");
    return FAILED;
  }
  return SUCCESS;
}
Status AicpuConstantFoldingPass::GenerateGeTensor(const OpDescPtr &node_desc, const vector<DataPtrInfo> &data_vec,
                                                  vector<GeTensorPtr> &outputs) {
  // Build one GeTensor per node output from the device buffers filled by the
  // memcpy task. data_vec holds kDouble (= 2) entries per output: index
  // i * kDouble is the raw data, i * kDouble + 1 is the shape data.
  if ((node_desc->GetOutputsSize() * kDouble) != data_vec.size()) {
    GELOGE(FAILED, "node[%s] something wrong with output size", node_desc->GetName().c_str());
    return FAILED;
  }
  for (size_t i = 0; i < node_desc->GetOutputsSize(); i++) {
    auto output_tensor_desc = node_desc->GetOutputDesc(static_cast<uint32_t>(i));
    GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc);
    if (output_ptr == nullptr) {
      GELOGE(FAILED, "node[%s] something wrong with construct GeTensor", node_desc->GetName().c_str());
      return FAILED;
    }
    // Copy the raw output bytes device -> host and attach them to the tensor.
    const DataPtrInfo &raw_data_info = data_vec.at(i * kDouble);
    uint64_t raw_data_size = raw_data_info.data_size;
    std::unique_ptr<uint8_t[]> data_addr(new (std::nothrow) uint8_t[raw_data_size]());
    if (data_addr == nullptr) {
      GELOGE(MEMALLOC_FAILED, "new data_addr failed");
      return INTERNAL_ERROR;
    }
    GE_CHK_RT_RET(rtMemcpy(data_addr.get(), raw_data_size,
                           reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(raw_data_info.dst_ptr)), raw_data_size,
                           RT_MEMCPY_DEVICE_TO_HOST));
    GE_IF_BOOL_EXEC(output_ptr->SetData(data_addr.get(), raw_data_size) != GRAPH_SUCCESS,
                    GELOGE(FAILED, "set data failed");
                    return FAILED);
    GELOGD("GenerateGeTensor: raw_data_size %lu", raw_data_size);
    const DataPtrInfo &shape_data_info = data_vec.at(i * kDouble + 1);
    uint64_t shape_data_size = shape_data_info.data_size;
    GELOGD("GenerateGeTensor: shape_data_size %lu", shape_data_size);
    // A zero shape_data_size marks a scalar output (see GenerateDataPtrInfo):
    // there is no shape buffer to read, so use an empty GeShape.
    if (shape_data_size == 0) {
      GELOGW("node[%s] outshape is scalar, skip copy shape", node_desc->GetName().c_str());
      output_ptr->MutableTensorDesc().SetShape(GeShape());
      outputs.emplace_back(output_ptr);
      continue;
    }
    // The shape buffer is an array of uint64_t dims; copy it back and convert
    // to the int64_t dims GeShape expects.
    uint64_t dim_num = shape_data_size / sizeof(uint64_t);
    std::unique_ptr<int64_t[]> shape_addr(new (std::nothrow) int64_t[dim_num]());
    if (shape_addr == nullptr) {
      GELOGE(MEMALLOC_FAILED, "new shape_addr failed");
      return INTERNAL_ERROR;
    }
    GE_CHK_RT_RET(rtMemcpy(shape_addr.get(), shape_data_size,
                           reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(shape_data_info.dst_ptr)),
                           shape_data_size, RT_MEMCPY_DEVICE_TO_HOST));
    std::vector<int64_t> shape_dims;
    for (size_t j = 0; j < dim_num; j++) {
      shape_dims.push_back(shape_addr[j]);
      GELOGD("GenerateGeTensor: dim %ld", shape_addr[j]);
    }
    output_ptr->MutableTensorDesc().SetShape(GeShape(shape_dims));
    outputs.emplace_back(output_ptr);
  }
  return SUCCESS;
}
  518. void AicpuConstantFoldingPass::ReleaseMemory(const vector<AddrAndType> &input_addrs,
  519. const vector<uint64_t> &output_addrs,
  520. const vector<DataPtrInfo> &data_vec) {
  521. for (const auto &item : input_addrs) {
  522. GE_CHK_RT(rtFree(reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(item.input_addr))));
  523. }
  524. for (auto item : output_addrs) {
  525. GE_CHK_RT(rtFree(reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(item))));
  526. }
  527. for (const auto &item : data_vec) {
  528. auto dst_ptr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(item.dst_ptr));
  529. if (dst_ptr != nullptr) {
  530. GE_CHK_RT(rtFree(dst_ptr));
  531. }
  532. }
  533. }
  534. bool AicpuConstantFoldingPass::IsSkipFold(const ge::NodePtr &node) {
  535. GE_CHECK_NOTNULL(node);
  536. string type = node->GetType();
  537. if (type == ge::FRAMEWORKOP) {
  538. if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) {
  539. GELOGW("Skip aicpu constant folding on frameworkop node [%s]", node->GetName().c_str());
  540. return true;
  541. }
  542. }
  543. auto instance_ptr = ge::GELib::GetInstance();
  544. if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
  545. GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GE is not initialized");
  546. return true;
  547. }
  548. OpsKernelInfoStorePtr kernel_info = instance_ptr->OpsKernelManagerObj().GetOpsKernelInfoStore(kKernelLibName);
  549. if (kernel_info == nullptr) {
  550. GELOGE(FAILED, "Get op kernel info store failed");
  551. return true;
  552. }
  553. std::string check_result;
  554. kernel_info->opsFlagCheck(*node, check_result);
  555. if (check_result.empty()) {
  556. GELOGE(FAILED, "Get op check_result failed");
  557. return true;
  558. }
  559. return check_result.substr(0, kOpsFlag) == kNotSupported;
  560. }
  561. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示