|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- /**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_
- #define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_
-
#include <cstdint>
#include <map>
#include <string>
#include <vector>

#include "graph/ge_context.h"
#include "common/op/ge_op_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "ge/ge_api_types.h"
#include "common/ge/ge_util.h"
#include "graph/compute_graph.h"
#include "graph/shape_refiner.h"
#include "graph/debug/ge_op_types.h"
#include "framework/common/types.h"
#include "graph/utils/op_desc_utils.h"
#include "../pass_utils.h"
-
// Early-return guard. When `cond` evaluates to false the macro:
//   1. forwards the variadic arguments to REPORT_INNER_ERROR with error code "E19999",
//   2. forwards them to GELOGE with status FAILED, prefixing the message with "[MDS]",
//   3. returns FAILED from the enclosing function.
// When `cond` is true nothing happens.
//
// The first variadic argument has to be a string literal: it is glued to "[MDS]" by adjacent
// string-literal concatenation. The do/while wrapper makes the expansion a single statement, so
// the macro can be followed by `;` and used safely inside unbraced if/else bodies.
#define REQUIRE(cond, ...)                       \
  do {                                           \
    if (!(cond)) {                               \
      REPORT_INNER_ERROR("E19999", __VA_ARGS__); \
      GELOGE(FAILED, "[MDS]" __VA_ARGS__);       \
      return FAILED;                             \
    }                                            \
  } while (false)
-
// Thin wrappers over REQUIRE: each one builds the condition from its first argument and passes
// the remaining (message) arguments through unchanged.

// Passes when `ptr` compares unequal to nullptr.
#define MDS_REQUIRE_NOT_NULL(ptr, ...) REQUIRE(((ptr) != nullptr), __VA_ARGS__)
// Passes when `expr` compares equal to SUCCESS.
#define MDS_REQUIRE_SUCCESS(expr, ...) REQUIRE(((expr) == SUCCESS), __VA_ARGS__)
// Passes when `expr` compares equal to GRAPH_SUCCESS.
#define MDS_REQUIRE_GRAPH_SUCCESS(expr, ...) REQUIRE(((expr) == GRAPH_SUCCESS), __VA_ARGS__)
- namespace ge {
- namespace {
- // Invalid location index
- const int64_t kInvalidIndex = -1;
- enum NCutIndex { kNLocation0 = 0, kNLocation1, kNLocation2, kNLocation3, kNInvalidLocation = -1 };
- enum HCutIndex { kHLocation0 = 0, kHLocation1, kHLocation2, kHLocation3, kHInvalidLocation = -1 };
-
- // NCHW dim N index
- const int32_t kNchwDimIdxN = 0;
- // NCHW dim C index
- const int32_t kNchwDimIdxC = 1;
- // NCHW dim H index
- const int32_t kNchwDimIdxH = 2;
- // NCHW dim W index
- const int32_t kNchwDimIdxW = 3;
- // default die number
- const uint32_t kDeployNumber = 2;
- enum CutType { kNoCut = 0, kCutN, kCutH, kDynamicCutN, kDynamicCutH, kDynamicCutAll };
- enum TensorCutInfo { kNotSupport = 0, kSplitCutSupported, kAnyCutSupported = 3 };
-
- const int64_t kDefaultFissionFactor = 1;
- const int64_t kDefaultRankSize = 1;
- const std::string kDefaultGroup = "hccl_world_group";
- const std::string kDefaultReduction = "sum";
- const char *const kDefaultDeviceType = "DEFAULT_DEVICE_TYPE";
- const char *const kDefaultExecUnit = "DEFAULT_DEVICE_TYPE";
-
- // deploy info
- const char *const kAttrNeedReturnResult = "_need_return_result";
- const char *const kAttrDeviceType = "_device_type";
- const char *const kDieDeviceTypeValue = "MultiMode";
- const char *const kAttrDeviceId = "_device_id";
- const char *const kAttrGraphName = "_graph_name";
- const char *const kAttrGraphInputs = "_graph_inputs";
- using GraphInputs = vector<GeTensorPtr>;
- using DeviceId = int64_t;
- using GraphInputNodes = vector<NodePtr>;
- } // namespace
- class MdsUtils {
- public:
- // Parse the configuration file and determine whether to enable MDS based on the value of device_type.
- static bool IsMDSNeeded();
- static int64_t GetNLocation(Format fmt);
- static int64_t GetHLocation(Format fmt);
-
- static int64_t GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type);
- static bool IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type);
- static Status SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor,
- const std::string &group_name = "");
- /// @param [in] index 切分的轴
- /// @param [in] deploy_number 切分的份数
- static Status DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type,
- int64_t deploy_number = kDeployNumber);
- // Sets the information, notifies the number of threads to be started during the
- // loading phase, the device on which each thread should run, and constructs different input data on each device.
- static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node);
- static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const std::multimap<DeviceId, GraphInputs> &deploys,
- const std::string &device_type = kDieDeviceTypeValue);
- // Get cut policy for whole graph
- static CutType TryGetGraphCutType(const ComputeGraphPtr &compute_graph);
- static GraphInputNodes GetInputNodes() {
- return input_nodes_;
- }
- static void AddInputNode(const NodePtr &input_node) {
- input_nodes_.push_back(input_node);
- }
-
- static Status DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
- static Status DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
- static Status DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, NodePtr &input_node);
-
- private:
- static GraphInputNodes input_nodes_;
- static NodePtr AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const string &type, const string &node_name,
- size_t input_num, size_t output_num);
- static NodePtr AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type,
- const GeTensorDesc &tensor = GeTensorDesc());
- static Status ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src,
- const InDataAnchorPtr &dst, NodePtr &reduce_node);
- static Status ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *node,
- NodePtr &slice_node);
- static bool NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node);
- static NodePtr AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph);
- };
- } // namespace ge
- #endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_
|