You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_utils.cc 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/utils/tensor_utils.h"
  17. #include <cmath>
  18. #include "debug/ge_log.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "common/util/error_manager/error_manager.h"
  21. #include "graph/ge_tensor.h"
  22. #include "graph/types.h"
  23. #include "graph/utils/type_utils.h"
  24. #include "mmpa/mmpa_api.h"
  25. namespace ge {
  namespace {
  // ---- Expected dim counts per layout ----
  // NC1HWC0 tensors that already have this many dims (C split into C1/C0) are counted directly.
  constexpr uint32_t kNc1hwc0CalcByDimsSize = 5;
  // NCHW / NHWC / HWCN / CHWN tensors are expected to have this many dims.
  constexpr uint32_t kDimSize4d = 4;
  // C1HWNCoC0 tensors are expected to have this many dims.
  constexpr uint32_t kDimSizeC1hwncoc0 = 6;

  // ---- Unknown-shape markers ----
  // Element count reported for a tensor whose shape contains a negative (unknown) dim.
  constexpr int64_t kElementCntUnknownShape = -1;
  // Memory size reported for a tensor whose element count is unknown.
  constexpr int64_t kMemSizeUnknownShape = -1;

  // ---- Cube / C0 sizes ----
  constexpr uint32_t kTheCubeSize = 16;
  // The default C0 is the cube size; int8-sized data types use a wider C0.
  constexpr uint32_t kC0SizeDefault = kTheCubeSize;
  constexpr uint32_t kC0SizeInt8 = 32;

  // ---- Positions of N, C, H, W inside an NCHW-ordered dim vector ----
  constexpr int32_t kNchwDimIdxN = 0;
  constexpr int32_t kNchwDimIdxC = 1;
  constexpr int32_t kNchwDimIdxH = 2;
  constexpr int32_t kNchwDimIdxW = 3;

  // ---- Memory alignment ----
  // Alignment unit (bytes) used when rounding tensor memory sizes.
  constexpr int kDataMemAlignSize = 32;
  // Multiplier applied to kDataMemAlignSize when padding memory sizes.
  constexpr int kNum2 = 2;
  }  // namespace
  ///
  /// Report whether the signed 64-bit product a * b would overflow, without performing it.
  /// Each sign combination is compared against the matching INT64_MAX / INT64_MIN quotient,
  /// so no intermediate expression can itself overflow.
  /// @param a multiplier
  /// @param b multiplicand
  /// @return true if a * b is outside [INT64_MIN, INT64_MAX], false otherwise
  ///
  static bool CheckMultiplyOverflowInt64(const int64_t &a, const int64_t &b) {
    // Anything times zero is zero.
    if ((a == 0) || (b == 0)) {
      return false;
    }
    if (a > 0) {
      // Positive * positive can exceed INT64_MAX; positive * negative can go below INT64_MIN.
      return (b > 0) ? (a > (INT64_MAX / b)) : (b < (INT64_MIN / a));
    }
    // a is negative here: negative * positive can go below INT64_MIN;
    // negative * negative can exceed INT64_MAX.
    return (b > 0) ? (a < (INT64_MIN / b)) : (b < (INT64_MAX / a));
  }
  85. ///
  86. /// Calculate element num by dims directly.
  87. /// @param dims dim info
  88. /// @param element_cnt element count
  89. /// @return GRAPH_SUCCESS:success
  90. /// other:failed
  91. ///
  92. static graphStatus CalcElementCntByDims(const std::vector<int64_t> &dims, int64_t &element_cnt) {
  93. element_cnt = 1;
  94. for (int64_t dim : dims) {
  95. if (CheckMultiplyOverflowInt64(element_cnt, dim)) {
  96. ErrorManager::GetInstance().ATCReportErrMessage(
  97. "E19013", {"function", "var1", "var2"},
  98. {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)});
  99. GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim);
  100. return GRAPH_FAILED;
  101. }
  102. element_cnt *= dim;
  103. }
  104. return GRAPH_SUCCESS;
  105. }
  106. ///
  107. /// Calculate fixed dims element num.
  108. /// @param dims dim info
  109. /// @param fixed_dim_size fixed dim size
  110. /// @param element_cnt element count
  111. /// @return GRAPH_SUCCESS:success
  112. /// other:failed
  113. ///
  114. static graphStatus CalcElementCntOfFixedDims(const std::vector<int64_t> &dims, Format format, uint32_t fixed_dim_size,
  115. int64_t &element_cnt) {
  116. if (dims.size() != fixed_dim_size) {
  117. GELOGW("Format %d(%s) need dim size=%u but %zu, calc as ND.",
  118. format, TypeUtils::FormatToSerialString(format).c_str(), fixed_dim_size, dims.size());
  119. }
  120. return CalcElementCntByDims(dims, element_cnt);
  121. }
  122. ///
  123. /// Get dim c0 size by type
  124. /// @param data_type data type
  125. /// @return c0 size
  126. ///
  127. static uint32_t GetDimC0(DataType &data_type) {
  128. bool is_int8_size = (data_type == DT_INT8) || (data_type == DT_UINT8) || (data_type == DT_DUAL_SUB_UINT8) ||
  129. (data_type == DT_DUAL_SUB_INT8) || (data_type == DT_BOOL) || (data_type == DT_QINT8);
  130. return is_int8_size ? kC0SizeInt8 : kC0SizeDefault;
  131. }
  132. ///
  133. /// Calculate nc1hwc0 element num.
  134. /// @param dims dim info
  135. /// @param data_type data type
  136. /// @param element_cnt element count
  137. /// @return GRAPH_SUCCESS:success
  138. /// other:failed
  139. ///
  140. static graphStatus CalcElementCntOfNc1hwc0(const std::vector<int64_t> &dims, DataType data_type, int64_t &element_cnt) {
  141. // When nc1hwc0 dims size = 5, no need split dim c
  142. if (dims.size() == kNc1hwc0CalcByDimsSize) {
  143. return CalcElementCntByDims(dims, element_cnt);
  144. } else if (dims.size() != kDimSize4d) {
  145. ErrorManager::GetInstance().ATCReportErrMessage(
  146. "E19012", {"function", "reason"},
  147. {"CalcElementCntOfNc1hwc0", "dims.size[" + std::to_string(dims.size()) + "] is not " +
  148. std::to_string(kDimSize4d) + " or " + std::to_string(kNc1hwc0CalcByDimsSize)});
  149. GELOGE(GRAPH_FAILED, "CalcElementCntOfNc1hwc0 failed as dims.size=%zu is not %u or %u.", dims.size(), kDimSize4d,
  150. kNc1hwc0CalcByDimsSize);
  151. return GRAPH_FAILED;
  152. }
  153. auto c0 = static_cast<int64_t>(GetDimC0(data_type));
  154. // Nc1hwc0 dims is according to nchw, dim c index is 1.
  155. auto c1 = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0));
  156. // Store dims is split c to c1 and c0.
  157. std::vector<int64_t> store_dims = {dims[kNchwDimIdxN], c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0};
  158. return CalcElementCntByDims(store_dims, element_cnt);
  159. }
  160. ///
  161. /// Calculate FractalZ element num.
  162. /// @param dims dim info
  163. /// @param data_type data type
  164. /// @param element_cnt element count
  165. /// @return GRAPH_SUCCESS:success
  166. /// other:failed
  167. ///
  168. static graphStatus CalcElementCntOfFractalZ(const std::vector<int64_t> &dims, DataType data_type,
  169. int64_t &element_cnt) {
  170. static char parser_priority[MMPA_MAX_PATH] = { 0x00 };
  171. INT32 res = mmGetEnv("PARSER_PRIORITY", parser_priority, MMPA_MAX_PATH);
  172. if (res == EN_OK && string(parser_priority) == "cce") {
  173. if (dims.size() != kDimSize4d) {
  174. ErrorManager::GetInstance().ATCReportErrMessage(
  175. "E19012", {"function", "reason"},
  176. {"CalcElementCntOfFractalZ",
  177. "dims.size[" + std::to_string(dims.size()) + "] is not " + std::to_string(kDimSize4d)});
  178. GELOGE(GRAPH_FAILED, "CalcElementCntOfFractalZ failed as dims.size=%zu is not %u.", dims.size(), kDimSize4d);
  179. return GRAPH_FAILED;
  180. }
  181. auto c0 = static_cast<int64_t>(GetDimC0(data_type));
  182. // FractalZ dims is according to nchw, dim c index is 1.
  183. auto c1 = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0));
  184. // Spread NC1HWC0 as a two dimension array, n as column dimension,
  185. // C1HWC0 as row dimension
  186. std::vector<int64_t> r_count_vec = {c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0};
  187. int64_t r_count = 1;
  188. graphStatus graph_status = CalcElementCntByDims(r_count_vec, r_count);
  189. if (graph_status != GRAPH_SUCCESS) {
  190. GELOGE(graph_status, "Calc [%ld, %ld, %ld, %ld] element count failed.",
  191. c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0);
  192. return graph_status;
  193. }
  194. // Cube count in n
  195. auto nc_cnt = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxN] * 1.0 / kTheCubeSize));
  196. // Cube count in vertical direction(C1HWC0)
  197. int64_t vc_cnt = r_count / c0;
  198. // Element count in each cube
  199. int64_t cube_elem_cnt = c0 * kTheCubeSize;
  200. if (CheckMultiplyOverflowInt64(nc_cnt, vc_cnt)) {
  201. ErrorManager::GetInstance().ATCReportErrMessage(
  202. "E19013", {"function", "var1", "var2"},
  203. {"CheckMultiplyOverflowInt64", std::to_string(nc_cnt), std::to_string(vc_cnt)});
  204. GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", nc_cnt, vc_cnt);
  205. return GRAPH_FAILED;
  206. }
  207. // Read data times needed by cube
  208. int64_t c_cnt = nc_cnt * vc_cnt;
  209. if (CheckMultiplyOverflowInt64(c_cnt, cube_elem_cnt)) {
  210. ErrorManager::GetInstance().ATCReportErrMessage(
  211. "E19013", {"function", "var1", "var2"},
  212. {"CheckMultiplyOverflowInt64", std::to_string(c_cnt), std::to_string(cube_elem_cnt)});
  213. GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", c_cnt, cube_elem_cnt);
  214. return GRAPH_FAILED;
  215. }
  216. // Element count after fractal arrangement
  217. element_cnt = c_cnt * cube_elem_cnt;
  218. return GRAPH_SUCCESS;
  219. } else {
  220. return CalcElementCntByDims(dims, element_cnt);
  221. }
  222. }
  223. ///
  224. /// Calculate tensor element num.
  225. /// @param dims dim info
  226. /// @param format tensor format
  227. /// @param data_type data type
  228. /// @param element_cnt element count
  229. /// @return GRAPH_SUCCESS:success
  230. /// other:failed
  231. ///
  232. static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format format, DataType data_type,
  233. int64_t &element_cnt) {
  234. const string format_str = TypeUtils::FormatToSerialString(format);
  235. // Check dims
  236. for (size_t i = 0; i < dims.size(); ++i) {
  237. int64_t dim = dims[i];
  238. if (dim < 0) {
  239. GELOGI("It's unknown shape, as dims[%zu]=%ld negative, format=%d(%s).", i, dim, format, format_str.c_str());
  240. element_cnt = kElementCntUnknownShape;
  241. return GRAPH_SUCCESS;
  242. } else if (dim == 0) {
  243. GELOGI("No need calc element count, as dims[%zu]=%ld, format=%d(%s).", i, dim, format, format_str.c_str());
  244. element_cnt = 0;
  245. return GRAPH_SUCCESS;
  246. }
  247. }
  248. graphStatus graph_status;
  249. switch (format) {
  250. case FORMAT_ND:
  251. case FORMAT_MD:
  252. graph_status = CalcElementCntByDims(dims, element_cnt);
  253. break;
  254. case FORMAT_NCHW:
  255. case FORMAT_HWCN:
  256. case FORMAT_NHWC:
  257. case FORMAT_CHWN:
  258. graph_status = CalcElementCntOfFixedDims(dims, format, kDimSize4d, element_cnt);
  259. break;
  260. case FORMAT_C1HWNCoC0:
  261. graph_status = CalcElementCntOfFixedDims(dims, format, kDimSizeC1hwncoc0, element_cnt);
  262. break;
  263. case FORMAT_NC1HWC0:
  264. graph_status = CalcElementCntOfNc1hwc0(dims, data_type, element_cnt);
  265. break;
  266. case FORMAT_FRACTAL_Z:
  267. graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt);
  268. break;
  269. case FORMAT_FRACTAL_NZ:
  270. case FORMAT_FRACTAL_ZZ:
  271. case FORMAT_NDHWC:
  272. case FORMAT_NCDHW:
  273. case FORMAT_DHWCN:
  274. case FORMAT_DHWNC:
  275. case FORMAT_FRACTAL_Z_3D:
  276. case FORMAT_FRACTAL_Z_3D_TRANSPOSE:
  277. case FORMAT_NDC1HWC0:
  278. case FORMAT_FRACTAL_Z_C04:
  279. case FORMAT_FRACTAL_ZN_LSTM:
  280. case FORMAT_NC1HWC0_C04:
  281. graph_status = CalcElementCntByDims(dims, element_cnt);
  282. break;
  283. default:
  284. ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
  285. {"CalcTensorElementCnt", "format[" + format_str + "] is not support"});
  286. GELOGE(GRAPH_FAILED, "unsupported format, format=%d(%s).", format, format_str.c_str());
  287. graph_status = GRAPH_FAILED;
  288. break;
  289. }
  290. const string type_str = TypeUtils::DataTypeToSerialString(data_type);
  291. if (graph_status == GRAPH_SUCCESS) {
  292. GELOGD(
  293. "CalcTensorElementCnt end, format=%d(%s),"
  294. " data_type=%d(%s), element_cnt=%ld.",
  295. format, format_str.c_str(), data_type, type_str.c_str(), element_cnt);
  296. } else {
  297. GELOGE(GRAPH_FAILED, "CalcTensorElementCnt failed, format=%d(%s), data_type=%d(%s).",
  298. format, format_str.c_str(), data_type, type_str.c_str());
  299. }
  300. return graph_status;
  301. }
  302. ///
  303. /// Calculate tensor mem size.
  304. /// @param shape tensor shape
  305. /// @param format tensor format
  306. /// @param data_type tensor data type
  307. /// @param mem_size -1 means unknown shape,other means mem size
  308. /// @return GRAPH_SUCCESS:success, other:failed
  309. ///
  310. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTensorMemSize(const GeShape &shape,
  311. Format format,
  312. DataType data_type,
  313. int64_t &mem_size) {
  314. const string format_str = TypeUtils::FormatToSerialString(format);
  315. const string type_str = TypeUtils::DataTypeToSerialString(data_type);
  316. uint32_t type_size = 0;
  317. bool result = TypeUtils::GetDataTypeLength(data_type, type_size);
  318. if (!result) {
  319. GELOGE(GRAPH_FAILED, "GetDataTypeLength failed, data_type=%d(%s).", data_type, type_str.c_str());
  320. return GRAPH_FAILED;
  321. }
  322. std::vector<int64_t> dims = shape.GetDims();
  323. int64_t element_cnt = 0;
  324. graphStatus status = CalcTensorElementCnt(dims, format, data_type, element_cnt);
  325. if (status != GRAPH_SUCCESS) {
  326. GELOGE(status, "CalcTensorElementCnt failed, status=%u format=%d(%s) data_type=%d(%s).",
  327. status, format, format_str.c_str(), data_type, type_str.c_str());
  328. return status;
  329. }
  330. // Support unknown shape
  331. if (element_cnt < 0) {
  332. mem_size = kMemSizeUnknownShape;
  333. GELOGD(
  334. "element_cnt is unknown. "
  335. "format=%d(%s), data_type=%d(%s), mem_size=%ld",
  336. format, format_str.c_str(), data_type, type_str.c_str(), mem_size);
  337. return GRAPH_SUCCESS;
  338. }
  339. auto type_size_int64 = static_cast<int64_t>(type_size);
  340. if (CheckMultiplyOverflowInt64(element_cnt, type_size_int64)) {
  341. ErrorManager::GetInstance().ATCReportErrMessage(
  342. "E19013", {"function", "var1", "var2"},
  343. {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(type_size_int64)});
  344. GELOGE(GRAPH_FAILED, "CalcTensorMemSize overflow, when multiplying %ld and %ld, format=%d(%s), data_type=%d(%s).",
  345. element_cnt, type_size_int64, format, format_str.c_str(), data_type, type_str.c_str());
  346. return GRAPH_FAILED;
  347. }
  348. mem_size = element_cnt * type_size_int64;
  349. GELOGD(
  350. "CalcTensorMemSize end, "
  351. "format=%d(%s), data_type=%d(%s), mem_size=%ld",
  352. format, format_str.c_str(), data_type, type_str.c_str(), mem_size);
  353. return GRAPH_SUCCESS;
  354. }
  355. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  356. TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) {
  357. graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp);
  358. if (graph_status != GRAPH_SUCCESS) {
  359. return GRAPH_FAILED;
  360. }
  361. // 64-byte alignment, if size is 0, align to 32 bytes
  362. if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) {
  363. GELOGW("The updated mem size %ld is bigger than INT64_MAX",size_temp);
  364. } else {
  365. size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize;
  366. }
  367. return GRAPH_SUCCESS;
  368. }
  369. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  370. TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) {
  371. GeShape output_shape = desc_temp.GetShape();
  372. Format format = desc_temp.GetFormat();
  373. DataType data_type = desc_temp.GetDataType();
  374. int64_t output_mem_size = 0;
  375. graphStatus graph_status = CalcTensorMemSize(output_shape, format, data_type, output_mem_size);
  376. if (graph_status != GRAPH_SUCCESS) {
  377. GELOGE(GRAPH_FAILED, "CalcTensorMemSize failed!");
  378. return GRAPH_FAILED;
  379. }
  380. if (output_mem_size < 0) {
  381. ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
  382. {"GetTensorSizeInBytes", "output_mem_size is out of data range [0, " + std::to_string(INT64_MAX) + "]"});
  383. GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %ld]",
  384. output_mem_size, INT64_MAX);
  385. return GRAPH_FAILED;
  386. }
  387. size_temp = output_mem_size;
  388. return GRAPH_SUCCESS;
  389. }
  390. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：