You cannot select more than 25 topics. A topic must start with a Chinese character, a letter or a number; it may include dashes ('-') and can be up to 35 characters long.

tensor_utils.cc 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/utils/tensor_utils.h"
  17. #include <cmath>
  18. #include "debug/ge_log.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "common/util/error_manager/error_manager.h"
  21. #include "graph/ge_tensor.h"
  22. #include "graph/types.h"
  23. #include "graph/utils/type_utils.h"
  24. namespace ge {
namespace {
// When nc1hwc0 dim size = 5, calc element count directly.
const uint32_t kNc1hwc0CalcByDimsSize = 5;
// Unknown shape element num
const int64_t kElementCntUnknownShape = -1;
// Unknown shape mem size
const int64_t kMemSizeUnknownShape = -1;
// Nchw and nhwc dim size must be 4
const uint32_t kDimSize4d = 4;
// C1HWNCoC0 dim size must be 6
const uint32_t kDimSizeC1hwncoc0 = 6;
// Cube size is 16
const uint32_t kTheCubeSize = 16;
// Default c0 size equals cube size.
const uint32_t kC0SizeDefault = kTheCubeSize;
// Size equals int8 cube size is 32
const uint32_t kC0SizeInt8 = 32;
// NCHW dim N index
const int32_t kNchwDimIdxN = 0;
// NCHW dim C index
const int32_t kNchwDimIdxC = 1;
// NCHW dim H index
const int32_t kNchwDimIdxH = 2;
// NCHW dim W index
const int32_t kNchwDimIdxW = 3;
// Alignment unit (bytes) used when rounding tensor memory sizes.
const int kDataMemAlignSize = 32;
// Multiplier for the extra alignment headroom added in GetTensorMemorySizeInBytes.
const int kNum2 = 2;
}  // namespace
  53. ///
  54. /// Check if a * b overflow.
  55. /// @param a multiplier
  56. /// @param b Multiplicand
  57. /// @return true: overflow
  58. /// false: not overflow
  59. ///
  60. static bool CheckMultiplyOverflowInt64(const int64_t &a, const int64_t &b) {
  61. if (a > 0) {
  62. if (b > 0) {
  63. if (a > (INT64_MAX / b)) {
  64. return true;
  65. }
  66. } else {
  67. if (b < (INT64_MIN / a)) {
  68. return true;
  69. }
  70. }
  71. } else {
  72. if (b > 0) {
  73. if (a < (INT64_MIN / b)) {
  74. return true;
  75. }
  76. } else {
  77. if ((a != 0) && (b < (INT64_MAX / a))) {
  78. return true;
  79. }
  80. }
  81. }
  82. return false;
  83. }
  84. ///
  85. /// Calculate element num by dims directly.
  86. /// @param dims dim info
  87. /// @param element_cnt element count
  88. /// @return GRAPH_SUCCESS:success
  89. /// other:failed
  90. ///
  91. static graphStatus CalcElementCntByDims(const std::vector<int64_t> &dims, int64_t &element_cnt) {
  92. element_cnt = 1;
  93. for (int64_t dim : dims) {
  94. if (CheckMultiplyOverflowInt64(element_cnt, dim)) {
  95. ErrorManager::GetInstance().ATCReportErrMessage(
  96. "E19013", {"function", "var1", "var2"},
  97. {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)});
  98. GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim);
  99. return GRAPH_FAILED;
  100. }
  101. element_cnt *= dim;
  102. }
  103. return GRAPH_SUCCESS;
  104. }
  105. ///
  106. /// Calculate fixed dims element num.
  107. /// @param dims dim info
  108. /// @param fixed_dim_size fixed dim size
  109. /// @param element_cnt element count
  110. /// @return GRAPH_SUCCESS:success
  111. /// other:failed
  112. ///
  113. static graphStatus CalcElementCntOfFixedDims(const std::vector<int64_t> &dims, Format format, uint32_t fixed_dim_size,
  114. int64_t &element_cnt) {
  115. if (dims.size() != fixed_dim_size) {
  116. GELOGW("Format %d(%s) need dim size=%u but %zu, calc as ND.", format,
  117. TypeUtils::FormatToSerialString(format).c_str(), fixed_dim_size, dims.size());
  118. }
  119. return CalcElementCntByDims(dims, element_cnt);
  120. }
  121. ///
  122. /// Get dim c0 size by type
  123. /// @param data_type data type
  124. /// @return c0 size
  125. ///
  126. static uint32_t GetDimC0(DataType &data_type) {
  127. bool is_int8_size = (data_type == DT_INT8) || (data_type == DT_UINT8) || (data_type == DT_DUAL_SUB_UINT8) ||
  128. (data_type == DT_DUAL_SUB_INT8) || (data_type == DT_BOOL) || (data_type == DT_QINT8);
  129. return is_int8_size ? kC0SizeInt8 : kC0SizeDefault;
  130. }
  131. ///
  132. /// Calculate nc1hwc0 element num.
  133. /// @param dims dim info
  134. /// @param data_type data type
  135. /// @param element_cnt element count
  136. /// @return GRAPH_SUCCESS:success
  137. /// other:failed
  138. ///
  139. static graphStatus CalcElementCntOfNc1hwc0(const std::vector<int64_t> &dims, DataType data_type, int64_t &element_cnt) {
  140. // When nc1hwc0 dims size = 5, no need split dim c
  141. if (dims.size() == kNc1hwc0CalcByDimsSize) {
  142. return CalcElementCntByDims(dims, element_cnt);
  143. } else if (dims.size() != kDimSize4d) {
  144. GELOGE(GRAPH_FAILED, "CalcElementCntOfNc1hwc0 failed as dims.size=%zu is not %u or %u.", dims.size(), kDimSize4d,
  145. kNc1hwc0CalcByDimsSize);
  146. return GRAPH_FAILED;
  147. }
  148. auto c0 = static_cast<int64_t>(GetDimC0(data_type));
  149. // Nc1hwc0 dims is according to nchw, dim c index is 1.
  150. auto c1 = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0));
  151. // Store dims is split c to c1 and c0.
  152. std::vector<int64_t> store_dims = {dims[kNchwDimIdxN], c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0};
  153. return CalcElementCntByDims(store_dims, element_cnt);
  154. }
  155. ///
  156. /// Calculate FractalZ element num.
  157. /// @param dims dim info
  158. /// @param data_type data type
  159. /// @param element_cnt element count
  160. /// @return GRAPH_SUCCESS:success
  161. /// other:failed
  162. ///
  163. static graphStatus CalcElementCntOfFractalZ(const std::vector<int64_t> &dims, DataType data_type,
  164. int64_t &element_cnt) {
  165. static char *parser_priority = std::getenv("PARSER_PRIORITY");
  166. if (parser_priority != nullptr && string(parser_priority) == "cce") {
  167. if (dims.size() != kDimSize4d) {
  168. GELOGE(GRAPH_FAILED, "CalcElementCntOfFractalZ failed as dims.size=%zu is not %u.", dims.size(), kDimSize4d);
  169. return GRAPH_FAILED;
  170. }
  171. auto c0 = static_cast<int64_t>(GetDimC0(data_type));
  172. // FractalZ dims is according to nchw, dim c index is 1.
  173. auto c1 = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0));
  174. // Spread NC1HWC0 as a two dimension array, n as column dimension,
  175. // C1HWC0 as row dimension
  176. std::vector<int64_t> r_count_vec = {c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0};
  177. int64_t r_count = 1;
  178. graphStatus graph_status = CalcElementCntByDims(r_count_vec, r_count);
  179. if (graph_status != GRAPH_SUCCESS) {
  180. GELOGE(graph_status, "Calc [%ld, %ld, %ld, %ld] element count failed.", c1, dims[kNchwDimIdxH],
  181. dims[kNchwDimIdxW], c0);
  182. return graph_status;
  183. }
  184. // Cube count in n
  185. auto nc_cnt = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxN] * 1.0 / kTheCubeSize));
  186. // Cube count in vertical direction(C1HWC0)
  187. int64_t vc_cnt = r_count / c0;
  188. // Element count in each cube
  189. int64_t cube_elem_cnt = c0 * kTheCubeSize;
  190. if (CheckMultiplyOverflowInt64(nc_cnt, vc_cnt)) {
  191. GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", nc_cnt, vc_cnt);
  192. return GRAPH_FAILED;
  193. }
  194. // Read data times needed by cube
  195. int64_t c_cnt = nc_cnt * vc_cnt;
  196. if (CheckMultiplyOverflowInt64(c_cnt, cube_elem_cnt)) {
  197. GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", c_cnt, cube_elem_cnt);
  198. return GRAPH_FAILED;
  199. }
  200. // Element count after fractal arrangement
  201. element_cnt = c_cnt * cube_elem_cnt;
  202. return GRAPH_SUCCESS;
  203. } else {
  204. return CalcElementCntByDims(dims, element_cnt);
  205. }
  206. }
///
/// Calculate tensor element num.
/// Dispatches on format: ND-like formats multiply dims directly, fixed-rank
/// formats are validated (with a warning) first, and NC1HWC0 / FRACTAL_Z get
/// format-specific handling. Unknown or zero shapes short-circuit to success.
/// @param dims dim info
/// @param format tensor format
/// @param data_type data type
/// @param element_cnt element count; kElementCntUnknownShape (-1) when any
///        dim is negative (unknown shape), 0 when any dim is zero
/// @return GRAPH_SUCCESS:success
///         other:failed
///
static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format format, DataType data_type,
                                        int64_t &element_cnt) {
  const string format_str = TypeUtils::FormatToSerialString(format);
  // Check dims: a negative dim marks the whole shape unknown, a zero dim
  // makes the element count zero; both cases return success immediately,
  // so the per-format helpers below only ever see strictly positive dims.
  for (size_t i = 0; i < dims.size(); ++i) {
    int64_t dim = dims[i];
    if (dim < 0) {
      GELOGI("It's unknown shape, as dims[%zu]=%ld negative, format=%d(%s).", i, dim, format, format_str.c_str());
      element_cnt = kElementCntUnknownShape;
      return GRAPH_SUCCESS;
    } else if (dim == 0) {
      GELOGI("No need calc element count, as dims[%zu]=%ld, format=%d(%s).", i, dim, format, format_str.c_str());
      element_cnt = 0;
      return GRAPH_SUCCESS;
    }
  }
  graphStatus graph_status;
  switch (format) {
    // Rank-free formats: multiply dims as-is.
    case FORMAT_ND:
    case FORMAT_MD:
      graph_status = CalcElementCntByDims(dims, element_cnt);
      break;
    // 4-D layouts: warn if rank differs from 4, then calc as ND.
    case FORMAT_NCHW:
    case FORMAT_HWCN:
    case FORMAT_NHWC:
    case FORMAT_CHWN:
      graph_status = CalcElementCntOfFixedDims(dims, format, kDimSize4d, element_cnt);
      break;
    // 6-D layout: warn if rank differs from 6, then calc as ND.
    case FORMAT_C1HWNCoC0:
      graph_status = CalcElementCntOfFixedDims(dims, format, kDimSizeC1hwncoc0, element_cnt);
      break;
    // NC1HWC0 may arrive as 4-D NCHW and needs its C dim split into C1 x C0.
    case FORMAT_NC1HWC0:
      graph_status = CalcElementCntOfNc1hwc0(dims, data_type, element_cnt);
      break;
    // FRACTAL_Z depends on the PARSER_PRIORITY environment variable.
    case FORMAT_FRACTAL_Z:
      graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt);
      break;
    // Remaining known formats store dims already expanded: calc as ND.
    case FORMAT_FRACTAL_NZ:
    case FORMAT_FRACTAL_ZZ:
    case FORMAT_NDHWC:
    case FORMAT_NCDHW:
    case FORMAT_DHWCN:
    case FORMAT_DHWNC:
    case FORMAT_FRACTAL_Z_3D:
    case FORMAT_FRACTAL_Z_3D_TRANSPOSE:
    case FORMAT_NDC1HWC0:
    case FORMAT_FRACTAL_Z_C04:
    case FORMAT_FRACTAL_ZN_LSTM:
    case FORMAT_NC1HWC0_C04:
      graph_status = CalcElementCntByDims(dims, element_cnt);
      break;
    default:
      GELOGE(GRAPH_FAILED, "unsupported format, format=%d(%s).", format, format_str.c_str());
      graph_status = GRAPH_FAILED;
      break;
  }
  const string type_str = TypeUtils::DataTypeToSerialString(data_type);
  if (graph_status == GRAPH_SUCCESS) {
    GELOGD(
        "CalcTensorElementCnt end, format=%d(%s),"
        " data_type=%d(%s), element_cnt=%ld.",
        format, format_str.c_str(), data_type, type_str.c_str(), element_cnt);
  } else {
    GELOGE(GRAPH_FAILED, "CalcTensorElementCnt failed, format=%d(%s), data_type=%d(%s).", format, format_str.c_str(),
           data_type, type_str.c_str());
  }
  return graph_status;
}
  284. ///
  285. /// Calculate tensor mem size.
  286. /// @param shape tensor shape
  287. /// @param format tensor format
  288. /// @param data_type tensor data type
  289. /// @param mem_size -1 means unknown shape,other means mem size
  290. /// @return GRAPH_SUCCESS:success, other:failed
  291. ///
  292. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTensorMemSize(const GeShape &shape,
  293. Format format,
  294. DataType data_type,
  295. int64_t &mem_size) {
  296. const string format_str = TypeUtils::FormatToSerialString(format);
  297. const string type_str = TypeUtils::DataTypeToSerialString(data_type);
  298. uint32_t type_size = 0;
  299. bool result = TypeUtils::GetDataTypeLength(data_type, type_size);
  300. if (!result) {
  301. GELOGE(GRAPH_FAILED, "GetDataTypeLength failed, data_type=%d(%s).", data_type, type_str.c_str());
  302. return GRAPH_FAILED;
  303. }
  304. std::vector<int64_t> dims = shape.GetDims();
  305. int64_t element_cnt = 0;
  306. graphStatus status = CalcTensorElementCnt(dims, format, data_type, element_cnt);
  307. if (status != GRAPH_SUCCESS) {
  308. GELOGE(status, "CalcTensorElementCnt failed, status=%u format=%d(%s) data_type=%d(%s).", status, format,
  309. format_str.c_str(), data_type, type_str.c_str());
  310. return status;
  311. }
  312. // Support unknown shape
  313. if (element_cnt < 0) {
  314. mem_size = kMemSizeUnknownShape;
  315. GELOGD(
  316. "element_cnt is unknown. "
  317. "format=%d(%s), data_type=%d(%s), mem_size=%ld",
  318. format, format_str.c_str(), data_type, type_str.c_str(), mem_size);
  319. return GRAPH_SUCCESS;
  320. }
  321. auto type_size_int64 = static_cast<int64_t>(type_size);
  322. if (CheckMultiplyOverflowInt64(element_cnt, type_size_int64)) {
  323. GELOGE(GRAPH_FAILED, "CalcTensorMemSize overflow, when multiplying %ld and %ld, format=%d(%s), data_type=%d(%s).",
  324. element_cnt, type_size_int64, format, format_str.c_str(), data_type, type_str.c_str());
  325. return GRAPH_FAILED;
  326. }
  327. mem_size = element_cnt * type_size_int64;
  328. GELOGD(
  329. "CalcTensorMemSize end, "
  330. "format=%d(%s), data_type=%d(%s), mem_size=%ld",
  331. format, format_str.c_str(), data_type, type_str.c_str(), mem_size);
  332. return GRAPH_SUCCESS;
  333. }
  334. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  335. TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) {
  336. graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp);
  337. if (graph_status != GRAPH_SUCCESS) {
  338. return GRAPH_FAILED;
  339. }
  340. // 64-byte alignment, if size is 0, align to 32 bytes
  341. if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) {
  342. GELOGW("The updated mem size %ld is bigger than INT64_MAX", size_temp);
  343. } else {
  344. size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize;
  345. }
  346. return GRAPH_SUCCESS;
  347. }
  348. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  349. TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) {
  350. GeShape output_shape = desc_temp.GetShape();
  351. Format format = desc_temp.GetFormat();
  352. DataType data_type = desc_temp.GetDataType();
  353. int64_t output_mem_size = 0;
  354. graphStatus graph_status = CalcTensorMemSize(output_shape, format, data_type, output_mem_size);
  355. if (graph_status != GRAPH_SUCCESS) {
  356. GELOGE(GRAPH_FAILED, "CalcTensorMemSize failed!");
  357. return GRAPH_FAILED;
  358. }
  359. if (output_mem_size < 0) {
  360. GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %ld]",
  361. output_mem_size, INT64_MAX);
  362. return GRAPH_FAILED;
  363. }
  364. size_temp = output_mem_size;
  365. return GRAPH_SUCCESS;
  366. }
  367. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示