You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

axis_util.h 6.8 kB

Lines 1–180
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. /*!
  17. * \file axis_util.h
  18. * \brief get the axis value
  19. */
  20. #ifndef COMMON_UTILS_TRANSFER_AXIS_UTIL_H_
  21. #define COMMON_UTILS_TRANSFER_AXIS_UTIL_H_
  #include <memory.h>
  #include <cstdint>
  #include <functional>
  #include <map>
  #include <memory>
  #include <vector>
  #include "external/graph/ge_error_codes.h"
  #include "external/graph/types.h"
  #include "framework/common/debug/ge_log.h"
  28. namespace common {
  29. namespace transformer {
  // Generic dimension counts (both equal 4; their consumers live outside this header).
  constexpr int32_t DIM_DEFAULT_SIZE = 4;
  constexpr uint32_t NCHW_DIMENSION_NUM = 4;

  // Index of each axis inside an NCHW shape.
  constexpr int32_t AXIS_NCHW_DIM_N = 0;
  constexpr int32_t AXIS_NCHW_DIM_C = 1;
  constexpr int32_t AXIS_NCHW_DIM_H = 2;
  constexpr int32_t AXIS_NCHW_DIM_W = 3;

  // Index of each axis inside an NHWC shape.
  constexpr int32_t AXIS_NHWC_DIM_N = 0;
  constexpr int32_t AXIS_NHWC_DIM_H = 1;
  constexpr int32_t AXIS_NHWC_DIM_W = 2;
  constexpr int32_t AXIS_NHWC_DIM_C = 3;

  // Index of each axis inside an NC1HWC0 shape (listed in index order).
  constexpr int32_t AXIS_NC1HWC0_DIM_N = 0;
  constexpr int32_t AXIS_NC1HWC0_DIM_C1 = 1;
  constexpr int32_t AXIS_NC1HWC0_DIM_H = 2;
  constexpr int32_t AXIS_NC1HWC0_DIM_W = 3;
  constexpr int32_t AXIS_NC1HWC0_DIM_C0 = 4;

  // Index of each axis inside an HWCN shape.
  constexpr int32_t AXIS_HWCN_DIM_H = 0;
  constexpr int32_t AXIS_HWCN_DIM_W = 1;
  constexpr int32_t AXIS_HWCN_DIM_C = 2;
  constexpr int32_t AXIS_HWCN_DIM_N = 3;

  // Index of each axis inside a C1HWNCoC0 shape.
  constexpr int32_t AXIS_C1HWNCoC0_DIM_C1 = 0;
  constexpr int32_t AXIS_C1HWNCoC0_DIM_H = 1;
  constexpr int32_t AXIS_C1HWNCoC0_DIM_W = 2;
  constexpr int32_t AXIS_C1HWNCoC0_DIM_N = 3;
  constexpr int32_t AXIS_C1HWNCoC0_DIM_Co = 4;
  constexpr int32_t AXIS_C1HWNCoC0_DIM_C0 = 5;

  // Index of each axis inside an NDHWC shape.
  constexpr int32_t NDHWC_DIM_N = 0;
  constexpr int32_t NDHWC_DIM_D = 1;
  constexpr int32_t NDHWC_DIM_H = 2;
  constexpr int32_t NDHWC_DIM_W = 3;
  constexpr int32_t NDHWC_DIM_C = 4;

  // Index of each axis inside an NCDHW shape.
  constexpr int32_t NCDHW_DIM_N = 0;
  constexpr int32_t NCDHW_DIM_C = 1;
  constexpr int32_t NCDHW_DIM_D = 2;
  constexpr int32_t NCDHW_DIM_H = 3;
  constexpr int32_t NCDHW_DIM_W = 4;

  // Index of each axis inside a DHWCN shape.
  constexpr int32_t DHWCN_DIM_D = 0;
  constexpr int32_t DHWCN_DIM_H = 1;
  constexpr int32_t DHWCN_DIM_W = 2;
  constexpr int32_t DHWCN_DIM_C = 3;
  constexpr int32_t DHWCN_DIM_N = 4;

  // Index of each axis inside a DHWNC shape.
  constexpr int32_t DHWNC_DIM_D = 0;
  constexpr int32_t DHWNC_DIM_H = 1;
  constexpr int32_t DHWNC_DIM_W = 2;
  constexpr int32_t DHWNC_DIM_N = 3;
  constexpr int32_t DHWNC_DIM_C = 4;
  /* CHECK_NOTNULL(val): when `val` is null, logs GRAPH_FAILED through GELOGE (including
   * the stringized expression) and executes `return false;` in the *calling* function,
   * so it can only be used inside functions that return bool.
   * NOTE(review): macros ignore the surrounding namespaces; CHECK_NOTNULL and CHECK are
   * generic names that may collide with same-named macros from other libraries
   * (e.g. glog) — confirm none are included in the same translation unit. */
  #define CHECK_NOTNULL(val)                                                \
    do {                                                                    \
      if ((val) != nullptr) {                                               \
        break;                                                              \
      }                                                                     \
      GELOGE(GRAPH_FAILED, "[ERROR]Parameter[%s] must not be null.", #val); \
      return false;                                                         \
    } while (false)

  /* CHECK(cond, log_func, return_expr): when `cond` is true, evaluates `log_func` and
   * then `return_expr` (presumably a `return ...;` statement leaving the caller);
   * when `cond` is false nothing happens. */
  #define CHECK(cond, log_func, return_expr) \
    do {                                     \
      if (!(cond)) {                         \
        break;                               \
      }                                      \
      log_func;                              \
      return_expr;                           \
    } while (false)
  /* Identifiers for logical axes, numbered consecutively from AXIS_N = 0 up to
   * AXIS_D = 7; AXIS_BOTTOM = 8 follows the last real axis (presumably a count /
   * end sentinel — confirm against users). Kept as an unscoped enum because
   * callers rely on the unqualified enumerator names. */
  enum AxisValueType {
    AXIS_N = 0,
    AXIS_C,
    AXIS_H,
    AXIS_W,
    AXIS_C1,
    AXIS_C0,
    AXIS_Co,
    AXIS_D,
    AXIS_BOTTOM
  };
  // Defined outside this header; presumably returns ceil(dividend / divisor) —
  // confirm against the definition, including how divisor == 0 is handled.
  int64_t DivisionCeiling(int64_t dividend, int64_t divisor);

  /* Signature of a per-format handler that derives axis values from a shape.
   *   originalDimVec - dimensions of the original (old) shape;
   *   c0             - the C0 value;
   *   axisValue      - filled with axis values, arranged as {N, C, H, W, C1, C0, ...};
   *   ndValue        - second output vector (its contents are defined by the handlers).
   * The bool result presumably signals success — confirm in the implementations. */
  using GetAxisValueInfoByFormat =
      std::function<bool(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
                         std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue)>;

  // Shared-ownership handle to a GetAxisValueInfoByFormat handler.
  using GetAxisValueInfoByFormatPtr = std::shared_ptr<GetAxisValueInfoByFormat>;
  107. class AxisUtil {
  108. public:
  109. AxisUtil();
  110. ~AxisUtil(){};
  111. bool GetAxisValueByOriginFormat(const ge::Format& format, const std::vector<int64_t>& dimVec, const uint32_t& c0,
  112. std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  113. bool HasAxisValueFunc(const ge::Format& format);
  114. private:
  115. static bool CheckParams(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
  116. std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  117. static bool GetAxisValueByNCHW(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
  118. std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  119. static bool GetAxisValueByNHWC(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
  120. std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  121. static bool GetAxisValueByNC1HWC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
  122. std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  123. static bool GetAxisValueByFz(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
  124. std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  125. static bool GetAxisValueByHWCN(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
  126. std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  127. static bool GetAxisValueByND(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
  128. std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  129. static bool GetAxisValueByC1HWNCoC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
  130. std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  131. static bool GetAxisValueByNDHWC(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
  132. std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);
  133. static bool GetAxisValueByNCDHW(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
  134. std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);
  135. static bool GetAxisValueByDHWCN(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
  136. std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);
  137. static bool GetAxisValueByDHWNC(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
  138. std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);
  139. /* map of GetAxisValueInfoByFormat, get axis value by different original
  140. * formats. */
  141. std::map<ge::Format, GetAxisValueInfoByFormatPtr> getAxisValueFuncMap;
  142. };
  143. } // namespace transformer
  144. } // namespace common
  145. #endif // COMMON_UTILS_TRANSFER_AXIS_UTIL_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：