You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

mds_utils.h 6.4 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_
  17. #define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_
  18. #include "graph/ge_context.h"
  19. #include "common/op/ge_op_utils.h"
  20. #include "graph/utils/type_utils.h"
  21. #include "graph/utils/graph_utils.h"
  22. #include "graph/debug/ge_attr_define.h"
  23. #include "ge/ge_api_types.h"
  24. #include "common/ge/ge_util.h"
  25. #include "graph/compute_graph.h"
  26. #include "graph/shape_refiner.h"
  27. #include "graph/debug/ge_op_types.h"
  28. #include "framework/common/types.h"
  29. #include "graph/utils/op_desc_utils.h"
  30. #include "../pass_utils.h"
// REQUIRE: shared failure macro for the MDS (multi-die split) passes.
// When `cond` is false it reports inner error E19999, logs the same message
// with a "[MDS]" prefix, and returns FAILED from the enclosing function —
// so it may only be used inside functions returning Status.
// NOTE: the first variadic argument must be a string literal, because it is
// concatenated at compile time with the "[MDS]" prefix in the GELOGE call.
#define REQUIRE(cond, ...) \
do { \
if (!(cond)) { \
REPORT_INNER_ERROR("E19999", __VA_ARGS__); \
GELOGE(FAILED, "[MDS]" __VA_ARGS__); \
return FAILED; \
} \
} while (0)
// Convenience wrappers: fail unless the expression is non-null / SUCCESS / GRAPH_SUCCESS.
#define MDS_REQUIRE_NOT_NULL(cond, ...) REQUIRE(((cond) != nullptr), __VA_ARGS__)
#define MDS_REQUIRE_SUCCESS(cond, ...) REQUIRE(((cond) == SUCCESS), __VA_ARGS__)
#define MDS_REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__)
  42. namespace ge {
// NOTE(review): an anonymous namespace in a header gives every including TU
// its own copy of these objects (binary bloat, distinct addresses). C++17
// `inline constexpr` or a named detail namespace would be preferable —
// confirm no code relies on per-TU addresses before changing.
namespace {
// Invalid location index
const int64_t kInvalidIndex = -1;
// Position of the N / H axis within a layout; -1 marks "no such axis".
enum NCutIndex { kNLocation0 = 0, kNLocation1, kNLocation2, kNLocation3, kNInvalidLocation = -1 };
enum HCutIndex { kHLocation0 = 0, kHLocation1, kHLocation2, kHLocation3, kHInvalidLocation = -1 };
// NCHW dim N index
const int32_t kNchwDimIdxN = 0;
// NCHW dim C index
const int32_t kNchwDimIdxC = 1;
// NCHW dim H index
const int32_t kNchwDimIdxH = 2;
// NCHW dim W index
const int32_t kNchwDimIdxW = 3;
// default die number
const uint32_t kDeployNumber = 2;
// How a tensor is split across dies: along N, along H, their dynamic-shape
// variants, or not at all.
enum CutType { kNoCut = 0, kCutN, kCutH, kDynamicCutN, kDynamicCutH, kDynamicCutAll };
// Per-tensor cut capability; value 3 deliberately skips 2 — presumably a
// bit-mask-style encoding, TODO confirm against the .cc usage.
enum TensorCutInfo { kNotSupport = 0, kSplitCutSupported, kAnyCutSupported = 3 };
const int64_t kDefaultFissionFactor = 1;
const int64_t kDefaultRankSize = 1;
// Default HCCL communication group and reduction op for inserted hcom nodes.
const std::string kDefaultGroup = "hccl_world_group";
const std::string kDefaultReduction = "sum";
const char *const kDefaultDeviceType = "DEFAULT_DEVICE_TYPE";
// NOTE(review): same literal as kDefaultDeviceType — looks like a copy-paste;
// confirm whether the intended value is a distinct "DEFAULT_EXEC_UNIT" string.
const char *const kDefaultExecUnit = "DEFAULT_DEVICE_TYPE";
// deploy info: attribute keys consumed by the load phase.
const char *const kAttrNeedReturnResult = "_need_return_result";
const char *const kAttrDeviceType = "_device_type";
const char *const kDieDeviceTypeValue = "MultiMode";
const char *const kAttrDeviceId = "_device_id";
const char *const kAttrGraphName = "_graph_name";
const char *const kAttrGraphInputs = "_graph_inputs";
// One graph's input tensors, keyed later by the device (die) id they deploy to.
using GraphInputs = vector<GeTensorPtr>;
using DeviceId = int64_t;
using GraphInputNodes = vector<NodePtr>;
}  // namespace
// MdsUtils: stateless helper collection for the MDS (multi-die split/deploy)
// graph passes. It decides whether and how tensors can be cut across dies,
// rewrites the graph with HCCL gather/reduce/slice nodes, and records the
// deploy information consumed during the loading phase.
class MdsUtils {
 public:
  // Parse the configuration file and determine whether to enable MDS based on the value of device_type.
  static bool IsMDSNeeded();
  // Map a format to the axis index of its N (batch) / H (height) dimension.
  // NOTE(review): presumably returns kNInvalidLocation / kHInvalidLocation
  // for formats without that axis — confirm in the implementation.
  static int64_t GetNLocation(Format fmt);
  static int64_t GetHLocation(Format fmt);
  // Axis index to cut for the given tensor desc under cut policy `type`
  // (kInvalidIndex expected when no axis applies — TODO confirm).
  static int64_t GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type);
  // Whether the tensor can be split across dies with cut policy `type`.
  static bool IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type);
  // Attach MDS attributes (fission factor, communication group) to an HCCL op.
  static Status SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor,
                                   const std::string &group_name = "");
  /// Split the tensor for distributed deployment.
  /// @param [in] type           how to cut (which axis), see CutType
  /// @param [in] deploy_number  the number of pieces to split into
  /// NOTE(review): the original comment documented an `index` ("axis to cut")
  /// parameter that does not exist in this signature — verify against the .cc.
  static Status DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type,
                                  int64_t deploy_number = kDeployNumber);
  // Sets the information, notifies the number of threads to be started during the
  // loading phase, the device on which each thread should run, and constructs different input data on each device.
  static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node);
  static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const std::multimap<DeviceId, GraphInputs> &deploys,
                              const std::string &device_type = kDieDeviceTypeValue);
  // Get cut policy for whole graph
  static CutType TryGetGraphCutType(const ComputeGraphPtr &compute_graph);
  // Returns a copy of the recorded graph input nodes.
  static GraphInputNodes GetInputNodes() {
    return input_nodes_;
  }
  // Record one graph input node for later deploy-info construction.
  static void AddInputNode(const NodePtr &input_node) {
    input_nodes_.push_back(input_node);
  }
  // Rewrite the src->dst data edge with a gather / reduce / slice pattern so
  // per-die partial results are combined (or whole inputs are split).
  static Status DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
  static Status DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
  static Status DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, NodePtr &input_node);

 private:
  // Shared mutable static state; NOTE(review): no synchronization visible here,
  // so concurrent pass execution would race on it — confirm single-threaded use.
  static GraphInputNodes input_nodes_;
  // Create a node of `type` with the given input/output arity and add it to `graph`.
  static NodePtr AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const string &type, const string &node_name,
                                           size_t input_num, size_t output_num);
  // Create a one-in/one-out node (optionally with an explicit tensor desc).
  static NodePtr AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type,
                                          const GeTensorDesc &tensor = GeTensorDesc());
  // Build the reduce / slice node used to rewrite the src->dst edge.
  static Status ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src,
                                    const InDataAnchorPtr &dst, NodePtr &reduce_node);
  static Status ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *node,
                                   NodePtr &slice_node);
  // Whether an HcomAllReduce must be inserted after src_node; outputs it if so.
  static bool NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node);
  // Wrap `tensor` in a Const node and add it to `graph`.
  static NodePtr AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph);
};
  120. } // namespace ge
  121. #endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示