You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hcom_util.h 5.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_MANAGER_UTIL_HCOM_UTIL_H_
  17. #define GE_GRAPH_MANAGER_UTIL_HCOM_UTIL_H_
  18. #include <map>
  19. #include <string>
  20. #include <vector>
  21. #include "common/debug/log.h"
  22. #include "common/opskernel/ge_task_info.h"
  23. #include "common/string_util.h"
  24. #include "common/types.h"
  25. #include "common/util.h"
  26. #include "graph/op_desc.h"
  27. #include "hccl/hcom.h"
  28. #include "proto/task.pb.h"
  29. namespace ge {
  30. using std::string;
  31. using std::vector;
  32. static std::map<int64_t, HcclDataType> kConstOpHcclDataType = {
  33. {ge::DT_FLOAT, HCCL_DATA_TYPE_FP32},
  34. {ge::DT_FLOAT16, HCCL_DATA_TYPE_FP16},
  35. {ge::DT_INT8, HCCL_DATA_TYPE_INT8},
  36. {ge::DT_INT32, HCCL_DATA_TYPE_INT32},
  37. {ge::DT_INT64, HCCL_DATA_TYPE_INT64},
  38. {ge::DT_UINT64, HCCL_DATA_TYPE_UINT64},
  39. };
  40. static std::map<HcclDataType, int32_t> kConstOpHcclDataTypeSize = {
  41. {HCCL_DATA_TYPE_FP32, sizeof(float)},
  42. {HCCL_DATA_TYPE_FP16, sizeof(float) / 2},
  43. {HCCL_DATA_TYPE_INT8, sizeof(int8_t)},
  44. {HCCL_DATA_TYPE_INT32, sizeof(int32_t)},
  45. {HCCL_DATA_TYPE_INT64, sizeof(int64_t)},
  46. {HCCL_DATA_TYPE_UINT64, sizeof(uint64_t)},
  47. };
  48. static std::map<HorovodReduceOp, HcclReduceOp> kHorovodRedOpToHcclRedOp = {
  49. {HOROVOD_REDUCE_SUM, HCCL_REDUCE_SUM}, {HOROVOD_REDUCE_MIN, HCCL_REDUCE_MIN},
  50. {HOROVOD_REDUCE_MAX, HCCL_REDUCE_MAX}, {HOROVOD_REDUCE_PROD, HCCL_REDUCE_PROD},
  51. {HOROVOD_REDUCE_RESERVED, HCCL_REDUCE_RESERVED},
  52. };
  53. class HcomOmeUtil {
  54. public:
  55. ///
  56. /// @ingroup domi_ome
  57. /// @brief GetHcclDataType
  58. /// @return SUCCESS
  59. /// @return FAIL
  60. ///
  61. static Status GetHcclDataType(const ge::ConstOpDescPtr &op_desc,
  62. std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos);
  63. ///
  64. /// @ingroup domi_ome
  65. /// @brief GetHcclTypeSize
  66. /// @return SUCCESS
  67. /// @return FAIL
  68. ///
  69. static Status GetHcclTypeSize(HcclDataType data_type, int32_t &size);
  70. ///
  71. /// @ingroup domi_ome
  72. /// @brief GetHcclCount
  73. /// @return SUCCESS
  74. /// @return FAIL
  75. ///
  76. static Status GetHcclCount(const ge::ConstOpDescPtr &op_desc, std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos);
  77. ///
  78. /// @ingroup domi_ome
  79. /// @brief GetHcclOperationType
  80. /// @return SUCCESS
  81. /// @return FAIL
  82. ///
  83. static Status GetHcclOperationType(const ge::ConstOpDescPtr &op_desc, HcclReduceOp &op_type);
  84. ///
  85. /// @ingroup domi_ome
  86. /// @brief GetHcclRootId
  87. /// @return SUCCESS
  88. /// @return FAIL
  89. ///
  90. static Status GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &root_id);
  91. ///
  92. /// @ingroup domi_ome
  93. /// @brief GetAllRootId
  94. /// @return SUCCESS
  95. /// @return FAIL
  96. ///
  97. static Status GetAllRootId(const ge::ConstOpDescPtr &op_desc, std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos);
  98. ///
  99. /// @ingroup domi_ome
  100. /// @brief check the op_type whether is hcom operator or not
  101. /// @return true
  102. /// @return false
  103. ///
  104. static bool IsHCOMOp(const string &op_type);
  105. ///
  106. /// @ingroup domi_ome
  107. /// @brief check the op_type whether is horovod operator or not
  108. /// @return true
  109. /// @return false
  110. ///
  111. static bool IsHorovodOp(const string &op_type);
  112. ///
  113. /// @ingroup domi_ome
  114. /// @brief GetHcclType
  115. /// @return void
  116. ///
  117. static void GetHcclType(const domi::TaskDef &task_def, std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos);
  118. ///
  119. /// @ingroup domi_ome
  120. /// @brief CheckKernelHcclInfo
  121. /// @return SUCCESS
  122. /// @return FAIL
  123. ///
  124. static Status CheckKernelHcclInfo(const ge::ConstOpDescPtr &op_desc,
  125. std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos);
  126. ///
  127. /// @ingroup domi_ome
  128. /// @brief GetHorovodInputs
  129. /// @return SUCCESS
  130. /// @return FAIL
  131. ///
  132. static Status GetHorovodInputs(const ge::ConstOpDescPtr &op_desc,
  133. std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos);
  134. ///
  135. /// @ingroup domi_ome
  136. /// @brief GetHcomCount
  137. /// @return SUCCESS
  138. /// @return FAIL
  139. ///
  140. static Status GetHcomCount(const ge::ConstOpDescPtr &op_desc, HcclDataType data_type, bool is_allgather,
  141. int &count);
  142. private:
  143. ///
  144. /// @ingroup domi_ome
  145. /// @brief GetHorovodCount
  146. /// @return SUCCESS
  147. /// @return FAIL
  148. ///
  149. static Status GetHorovodCount(const ge::ConstOpDescPtr &op_desc,
  150. std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos);
  151. };
  152. } // namespace ge
  153. #endif // GE_GRAPH_MANAGER_UTIL_HCOM_UTIL_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：