
gather_v2_kernel.cc 20 kB

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "host_kernels/gather_v2_kernel.h"

#include <memory>
#include <set>

#include "common/fp16_t.h"
#include "framework/common/ge_inner_error_codes.h"
#include "framework/common/op/ge_op_utils.h"
#include "framework/common/types.h"
#include "framework/common/util.h"
#include "framework/common/debug/ge_log.h"
#include "host_kernels/kernel_utils.h"
#include "graph/utils/type_utils.h"
#include "inc/kernel_factory.h"

namespace ge {
namespace {
const size_t kGatherV2InputIndexZero = 0;
const size_t kGatherV2InputIndexOne = 1;
const size_t kGatherV2InputIndexTwo = 2;
const size_t kGatherV2InputIndexThree = 3;
const size_t kGatherV2DimOne = 1;
const size_t kGatherV2InputNum = 3;
const size_t kMaxIndicatesDims = 1;  // only scalar and 1-D indices are supported
const std::set<DataType> supported_type = {DT_FLOAT16, DT_DOUBLE, DT_INT8,   DT_INT16,  DT_INT32,
                                           DT_INT64,   DT_UINT8,  DT_UINT16, DT_UINT32, DT_UINT64};
const int64_t DIM_AXIS_0 = 0;
const int64_t DIM_AXIS_1 = 1;
const int64_t DIM_AXIS_2 = 2;
const int64_t DIM_AXIS_3 = 3;
}  // namespace
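
// ProcessAxis0..ProcessAxis3 perform the gather copy for gather axis 0..3.
// For each output position along the gather axis, a contiguous block of
// xstride_[axis] elements is copied from the input slice selected by the
// cached index in indicates_ into the corresponding output slice.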
template <typename T>
Status GatherV2Kernel::ProcessAxis0(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  Status ret = SUCCESS;
  T *data_ptr_x = reinterpret_cast<T *>(const_cast<unsigned char *>(tensor_x->GetData().data()));
  T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  // indices have already been validated to lie within [0, x_shape.GetDim(0))
  size_t output_size = output->GetData().size();
  for (int64_t i = 0; i < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexZero); i++) {
    T *data_ptr_x_tmp = data_ptr_x + indicates_[i] * xstride_[kGatherV2InputIndexZero];
    T *data_ptr_y_tmp = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
    size_t size = sizeof(T) * xstride_[kGatherV2InputIndexZero];
    if (data_ptr_y_tmp - data_ptr_y < 0) {
      GELOGE(PARAM_INVALID, "data_ptr_y_tmp - data_ptr_y is less than zero");
      return PARAM_INVALID;
    }
    size_t offset_size = (data_ptr_y_tmp - data_ptr_y) * sizeof(T);
    auto ret_mem = memcpy_s(reinterpret_cast<void *>(data_ptr_y_tmp), output_size - offset_size,
                            reinterpret_cast<void *>(data_ptr_x_tmp), size);
    if (ret_mem != 0) {
      GELOGE(MEMALLOC_FAILED, "memcpy failed!");
      return MEMALLOC_FAILED;
    }
  }
  return ret;
}
template <typename T>
Status GatherV2Kernel::ProcessAxis1(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  Status ret = SUCCESS;
  T *data_ptr_x = reinterpret_cast<T *>(const_cast<unsigned char *>(tensor_x->GetData().data()));
  T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  // indices have already been validated to lie within [0, x_shape.GetDim(1))
  size_t output_size = output->GetData().size();
  for (int64_t i = 0; i < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexZero); i++) {
    T *data_ptr_x_i = data_ptr_x + i * xstride_[kGatherV2InputIndexZero];
    T *data_ptr_y_i = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
    for (int64_t j = 0; j < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexOne); j++) {
      T *data_ptr_x_tmp = data_ptr_x_i + indicates_[j] * xstride_[kGatherV2InputIndexOne];
      T *data_ptr_y_tmp = data_ptr_y_i + j * ystride_[kGatherV2InputIndexOne];
      size_t size = sizeof(T) * xstride_[kGatherV2InputIndexOne];
      if (data_ptr_y_tmp - data_ptr_y < 0) {
        GELOGE(PARAM_INVALID, "data_ptr_y_tmp - data_ptr_y is less than zero");
        return PARAM_INVALID;
      }
      size_t offset_size = (data_ptr_y_tmp - data_ptr_y) * sizeof(T);
      auto ret_mem = memcpy_s(reinterpret_cast<void *>(data_ptr_y_tmp), output_size - offset_size,
                              reinterpret_cast<void *>(data_ptr_x_tmp), size);
      if (ret_mem != 0) {
        GELOGE(MEMALLOC_FAILED, "memcpy failed!");
        return MEMALLOC_FAILED;
      }
    }
  }
  return ret;
}
template <typename T>
Status GatherV2Kernel::ProcessAxis2(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  Status ret = SUCCESS;
  T *data_ptr_x = reinterpret_cast<T *>(const_cast<unsigned char *>(tensor_x->GetData().data()));
  T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  // indices have already been validated to lie within [0, x_shape.GetDim(2))
  size_t output_size = output->GetData().size();
  for (int64_t i = 0; i < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexZero); i++) {
    T *data_ptr_x_i = data_ptr_x + i * xstride_[kGatherV2InputIndexZero];
    T *data_ptr_y_i = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
    for (int64_t j = 0; j < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexOne); j++) {
      T *data_ptr_x_j = data_ptr_x_i + j * xstride_[kGatherV2InputIndexOne];
      T *data_ptr_y_j = data_ptr_y_i + j * ystride_[kGatherV2InputIndexOne];
      for (int64_t m = 0; m < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexTwo); m++) {
        T *data_ptr_x_tmp = data_ptr_x_j + indicates_[m] * xstride_[kGatherV2InputIndexTwo];
        T *data_ptr_y_tmp = data_ptr_y_j + m * ystride_[kGatherV2InputIndexTwo];
        size_t size = sizeof(T) * xstride_[kGatherV2InputIndexTwo];
        if (data_ptr_y_tmp - data_ptr_y < 0) {
          GELOGE(PARAM_INVALID, "data_ptr_y_tmp - data_ptr_y is less than zero");
          return PARAM_INVALID;
        }
        size_t offset_size = (data_ptr_y_tmp - data_ptr_y) * sizeof(T);
        auto ret_mem = memcpy_s(reinterpret_cast<void *>(data_ptr_y_tmp), output_size - offset_size,
                                reinterpret_cast<void *>(data_ptr_x_tmp), size);
        if (ret_mem != 0) {
          GELOGE(MEMALLOC_FAILED, "memcpy failed!");
          return MEMALLOC_FAILED;
        }
      }
    }
  }
  return ret;
}
template <typename T>
Status GatherV2Kernel::ProcessAxis3(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  Status ret = SUCCESS;
  T *data_ptr_x = reinterpret_cast<T *>(const_cast<unsigned char *>(tensor_x->GetData().data()));
  T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  // indices have already been validated to lie within [0, x_shape.GetDim(3))
  size_t output_size = output->GetData().size();
  for (int64_t i = 0; i < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexZero); i++) {
    T *data_ptr_x_i = data_ptr_x + i * xstride_[kGatherV2InputIndexZero];
    T *data_ptr_y_i = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
    for (int64_t j = 0; j < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexOne); j++) {
      T *data_ptr_x_j = data_ptr_x_i + j * xstride_[kGatherV2InputIndexOne];
      T *data_ptr_y_j = data_ptr_y_i + j * ystride_[kGatherV2InputIndexOne];
      for (int64_t m = 0; m < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexTwo); m++) {
        T *data_ptr_x_m = data_ptr_x_j + m * xstride_[kGatherV2InputIndexTwo];
        T *data_ptr_y_m = data_ptr_y_j + m * ystride_[kGatherV2InputIndexTwo];
        for (int64_t n = 0; n < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexThree); n++) {
          T *data_ptr_x_tmp = data_ptr_x_m + indicates_[n] * xstride_[kGatherV2InputIndexThree];
          T *data_ptr_y_tmp = data_ptr_y_m + n * ystride_[kGatherV2InputIndexThree];
          size_t size = sizeof(T) * xstride_[kGatherV2InputIndexThree];
          if (data_ptr_y_tmp - data_ptr_y < 0) {
            GELOGE(PARAM_INVALID, "data_ptr_y_tmp - data_ptr_y is less than zero");
            return PARAM_INVALID;
          }
          size_t offset_size = (data_ptr_y_tmp - data_ptr_y) * sizeof(T);
          auto ret_mem = memcpy_s(reinterpret_cast<void *>(data_ptr_y_tmp), output_size - offset_size,
                                  reinterpret_cast<void *>(data_ptr_x_tmp), size);
          if (ret_mem != 0) {
            GELOGE(MEMALLOC_FAILED, "memcpy failed!");
            return MEMALLOC_FAILED;
          }
        }
      }
    }
  }
  return ret;
}
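
// Allocates the output buffer (data_num elements of T), attaches it to the
// output tensor, and dispatches to the ProcessAxisN routine matching axis.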
template <typename T>
Status GatherV2Kernel::GenData(const int64_t data_num, ConstGeTensorPtr tensor_x, int64_t axis, GeTensorPtr output) {
  if (data_num <= 0) {
    return PARAM_INVALID;
  }
  if (!CheckInt64MulOverflow(data_num, sizeof(T))) {
    GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num:%ld, type_len:%zu.", data_num, sizeof(T));
    return PARAM_INVALID;
  }
  std::unique_ptr<T[]> buf(new (std::nothrow) T[data_num]());
  if (buf == nullptr) {
    GELOGE(MEMALLOC_FAILED, "New sizeof(T) * data_num(%zu) memory failed", static_cast<size_t>(sizeof(T) * data_num));
    return MEMALLOC_FAILED;
  }
  GE_IF_BOOL_EXEC(output->SetData(reinterpret_cast<uint8_t *>(buf.get()),
                                  static_cast<size_t>(data_num * sizeof(T))) != GRAPH_SUCCESS,
                  GELOGE(INTERNAL_ERROR, "set data failed");
                  return INTERNAL_ERROR);
  Status ret = SUCCESS;
  switch (axis) {
    case DIM_AXIS_0:
      ret = ProcessAxis0<T>(tensor_x, output);
      break;
    case DIM_AXIS_1:
      ret = ProcessAxis1<T>(tensor_x, output);
      break;
    case DIM_AXIS_2:
      ret = ProcessAxis2<T>(tensor_x, output);
      break;
    case DIM_AXIS_3:
      ret = ProcessAxis3<T>(tensor_x, output);
      break;
    default:
      GELOGI("Only tensors with 4 dims or fewer are supported, but input axis is %ld", axis);
      return NOT_CHANGED;
  }
  return ret;
}
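
// Computes row-major (C-order) strides: stride[n-1] = 1 and
// stride[i] = stride[i+1] * dims[i+1], so stride[i] is the number of elements
// spanned by one step along dimension i.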
Status GatherV2Kernel::CalcStride(std::vector<int64_t> &stride, std::vector<int64_t> dims) {
  if (stride.size() != dims.size() || dims.size() == 0) {
    return PARAM_INVALID;
  }
  int i = static_cast<int>(dims.size() - kGatherV2DimOne);
  stride[static_cast<size_t>(i)] = static_cast<int64_t>(kGatherV2DimOne);
  i--;
  while (i >= 0) {
    size_t index = static_cast<size_t>(i) + kGatherV2DimOne;
    if (!CheckInt64MulOverflow(stride[index], dims[index])) {
      GELOGE(PARAM_INVALID, "Int64MulOverflow, stride(%ld) dim(%ld)", stride[index], dims[index]);
      return PARAM_INVALID;
    }
    stride[static_cast<size_t>(i)] = stride[index] * dims[index];
    i--;
  }
  return SUCCESS;
}
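
// Dispatches GenData<T> on the element type of the input tensor; unsupported
// types leave the graph unchanged (NOT_CHANGED).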
Status GatherV2Kernel::Process(int64_t axis, DataType data_type, ConstGeTensorPtr input_tensor_ptr,
                               GeTensorPtr output_ptr) {
  Status ret = SUCCESS;
  int64_t data_num = output_ptr->GetTensorDesc().GetShape().GetShapeSize();
  switch (data_type) {
    case DT_FLOAT16:
      ret = GenData<fp16_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_DOUBLE:
      ret = GenData<double>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_INT8:
      ret = GenData<int8_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_INT16:
      ret = GenData<int16_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_INT32:
      ret = GenData<int32_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_INT64:
      ret = GenData<int64_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_UINT8:
      ret = GenData<uint8_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_UINT16:
      ret = GenData<uint16_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_UINT32:
      ret = GenData<uint32_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_UINT64:
      ret = GenData<uint64_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    default:
      GELOGI("GatherV2Kernel does not support this Data type:%s", TypeUtils::DataTypeToSerialString(data_type).c_str());
      return NOT_CHANGED;
  }
  return ret;
}
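
// Reads the indices tensor (int32 or int64), validates that every index lies
// within [0, x_shape.GetDim(axis)), and caches the values in indicates_.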
Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr, GeShape &x_shape,
                                             GeShape &indices_shape, DataType indices_data_type, size_t axis) {
  if (indices_data_type == DT_INT32) {
    auto indices_ptr = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(indices_tensor_ptr->GetData().data()));
    for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
      if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
        GELOGW("indices[%ld] value is not in range [0, %ld).", i, x_shape.GetDim(axis));
        return NOT_CHANGED;
      }
      indicates_.push_back(*(indices_ptr + i));
    }
  } else {
    // int64
    auto indices_ptr = const_cast<int64_t *>(reinterpret_cast<const int64_t *>(indices_tensor_ptr->GetData().data()));
    for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
      if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
        GELOGW("indices[%ld] value is not in range [0, %ld).", i, x_shape.GetDim(axis));
        return NOT_CHANGED;
      }
      indicates_.push_back(*(indices_ptr + i));
    }
  }
  return SUCCESS;
}
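
// Validates the op descriptor and the three inputs (x, indices, axis):
// non-null tensors with non-empty data, a scalar int32/int64 axis, and
// int32/int64 indices of at most one dimension.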
Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector<ConstGeTensorPtr> &input,
                             vector<GeTensorPtr> &v_output) const {
  if (op_desc_ptr == nullptr) {
    GELOGW("input opdesc is nullptr.");
    return NOT_CHANGED;
  }
  if (input.size() != kGatherV2InputNum) {
    GELOGW("The number of inputs for GatherV2 must be %zu.", kGatherV2InputNum);
    return NOT_CHANGED;
  }
  bool is_null = (input[kGatherV2InputIndexZero] == nullptr || input[kGatherV2InputIndexOne] == nullptr ||
                  input[kGatherV2InputIndexTwo] == nullptr);
  if (is_null) {
    GELOGW("some input is nullptr.");
    return NOT_CHANGED;
  }
  ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
  ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne);
  ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo);
  bool size_is_zero =
      ((tensor0->GetData().size() == 0) || (tensor1->GetData().size() == 0) || (tensor2->GetData().size() == 0));
  if (size_is_zero) {
    GELOGW("some input size is zero.");
    return NOT_CHANGED;
  }
  auto indices_shape = tensor1->GetTensorDesc().GetShape();
  auto axis_shape = tensor2->GetTensorDesc().GetShape();
  // axis must be scalar
  if (axis_shape.GetDimNum() != 0) {
    GELOGW("axis must be scalar, but its dim num is %zu.", axis_shape.GetDimNum());
    return NOT_CHANGED;
  }
  auto axis_data_type = tensor2->GetTensorDesc().GetDataType();
  bool is_valid_axis_data_type = axis_data_type == DT_INT32 || axis_data_type == DT_INT64;
  if (!is_valid_axis_data_type) {
    GELOGW("axis datatype must be DT_INT32 or DT_INT64.");
    return NOT_CHANGED;
  }
  // check indices data_type && dims && every element
  auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
  bool is_valid_indices_data_type = indices_data_type == DT_INT32 || indices_data_type == DT_INT64;
  if (!is_valid_indices_data_type) {
    GELOGW("indices datatype must be DT_INT32 or DT_INT64.");
    return NOT_CHANGED;
  }
  if (indices_shape.GetDimNum() > kMaxIndicatesDims) {
    GELOGW("indices input only supports 0 or 1 dims.");
    return NOT_CHANGED;
  }
  return SUCCESS;
}
void GatherV2Kernel::DebugPrint(int64_t axis, const GeShape &x_shape, const GeShape &indices_shape,
                                const std::vector<int64_t> &y_shape) {
  GELOGD("GatherV2Kernel axis:%ld x_shape:%zu indices_shape:%zu y_shape:%zu.", axis, x_shape.GetDimNum(),
         indices_shape.GetDimNum(), y_shape.size());
  for (size_t i = 0; i < x_shape.GetDimNum(); i++) {
    GELOGD("GatherV2Kernel x_shape[%zu]: %ld.", i, x_shape.GetDim(i));
  }
  for (size_t i = 0; i < indices_shape.GetDimNum(); i++) {
    GELOGD("GatherV2Kernel indices_shape[%zu]: %ld.", i, indices_shape.GetDim(i));
  }
  for (size_t i = 0; i < y_shape.size(); i++) {
    GELOGD("GatherV2Kernel y_shape[%zu]: %ld.", i, y_shape[i]);
  }
  for (auto ele : indicates_) {
    GELOGD("GatherV2Kernel indices:%ld.", ele);
  }
}
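
// Entry point of the kernel: validates inputs, normalizes a negative axis,
// validates and caches the indices, computes the output shape and strides,
// then performs the gather via Process().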
Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGeTensorPtr> &input,
                               vector<GeTensorPtr> &v_output) {
  GELOGI("Enter GatherV2Kernel Process.");
  Status ret = Check(op_desc_ptr, input, v_output);
  if (ret != SUCCESS) {
    GELOGW("param check failed");
    return NOT_CHANGED;
  }
  GELOGI("GatherV2Kernel[%s] start Process", op_desc_ptr->GetName().c_str());
  ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
  ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne);
  ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo);
  auto x_shape = tensor0->GetTensorDesc().GetShape();
  auto indices_shape = tensor1->GetTensorDesc().GetShape();
  auto axis_data_type = tensor2->GetTensorDesc().GetDataType();
  int64_t axis = axis_data_type == DT_INT32
                     ? *(const_cast<int32_t *>(reinterpret_cast<const int32_t *>(tensor2->GetData().data())))
                     : *(const_cast<int64_t *>(reinterpret_cast<const int64_t *>(tensor2->GetData().data())));
  axis = axis >= 0 ? axis : axis + x_shape.GetDimNum();
  // check axis value
  if (axis < 0 || (axis + 1) > static_cast<int64_t>(x_shape.GetDimNum())) {
    GELOGW("axis is invalid!");
    return NOT_CHANGED;
  }
  auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
  ret = SaveIndicesByDataType(tensor1, x_shape, indices_shape, indices_data_type, static_cast<size_t>(axis));
  if (ret != SUCCESS) {
    GELOGW("Save indices by data type failed!");
    return ret;
  }
  // check input data type
  auto x_data_type = tensor0->GetTensorDesc().GetDataType();
  if (supported_type.find(x_data_type) == supported_type.end()) {
    GELOGI("GatherV2Kernel does not support this Data type:%s.",
           TypeUtils::DataTypeToSerialString(x_data_type).c_str());
    return NOT_CHANGED;
  }
  // calc output shape
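  // y_shape = x_shape[0:axis] ++ indices_shape ++ x_shape[axis+1:], i.e. the
  // gathered dimension is replaced by the shape of the indices.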
  std::vector<int64_t> y_shape;
  for (size_t i = 0; i < static_cast<size_t>(axis); i++) {
    y_shape.push_back(x_shape.GetDim(i));
  }
  for (size_t i = 0; i < indices_shape.GetDimNum(); i++) {
    y_shape.push_back(indices_shape.GetDim(i));
  }
  for (size_t i = static_cast<size_t>(axis) + 1; i < x_shape.GetDimNum(); i++) {
    y_shape.push_back(x_shape.GetDim(i));
  }
  GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
  if (output_ptr == nullptr) {
    GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str());
    return NOT_CHANGED;
  }
  output_ptr->MutableTensorDesc().SetShape(GeShape(y_shape));
  output_ptr->MutableTensorDesc().SetDataType(x_data_type);
  // added for debug
  DebugPrint(axis, x_shape, indices_shape, y_shape);
  // calc stride
  xstride_ = std::vector<int64_t>(x_shape.GetDimNum());
  ystride_ = std::vector<int64_t>(y_shape.size());
  auto ret_x = CalcStride(xstride_, x_shape.GetDims());
  auto ret_y = CalcStride(ystride_, y_shape);
  ret = (ret_x == SUCCESS && ret_y == SUCCESS) ? SUCCESS : NOT_CHANGED;
  if (ret != SUCCESS) {
    GELOGE(ret, "CalcStride failed");
    return ret;
  }
  ret = Process(axis, x_data_type, tensor0, output_ptr);
  if (ret != SUCCESS) {
    GELOGE(ret, "GenData failed, data_type: %s", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
    return ret;
  }
  GELOGI("GatherV2Kernel Process Success.");
  v_output.push_back(output_ptr);
  return SUCCESS;
}

REGISTER_KERNEL(GATHERV2, GatherV2Kernel);
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore. Implemented in C++, it sits between the front-end module ME and the underlying hardware, serving as the bridge between them. GE takes the graph delivered by ME as input, applies a series of deep graph-optimization passes, and finally outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.