/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_ #define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_ #include "graph/ge_context.h" #include "common/op/ge_op_utils.h" #include "graph/utils/type_utils.h" #include "graph/utils/graph_utils.h" #include "graph/debug/ge_attr_define.h" #include "ge/ge_api_types.h" #include "common/ge/ge_util.h" #include "graph/compute_graph.h" #include "graph/shape_refiner.h" #include "graph/debug/ge_op_types.h" #include "framework/common/types.h" #include "graph/utils/op_desc_utils.h" #include "../pass_utils.h" #define REQUIRE(cond, ...) \ do { \ if (!(cond)) { \ REPORT_INNER_ERROR("E19999", __VA_ARGS__); \ GELOGE(FAILED, "[MDS]" __VA_ARGS__); \ return FAILED; \ } \ } while (0) #define MDS_REQUIRE_NOT_NULL(cond, ...) REQUIRE(((cond) != nullptr), __VA_ARGS__) #define MDS_REQUIRE_SUCCESS(cond, ...) REQUIRE(((cond) == SUCCESS), __VA_ARGS__) #define MDS_REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) namespace ge { namespace { // Invalid location index const int64_t kInvalidIndex = -1; enum NCutIndex { kNLocation0 = 0, kNLocation1, kNLocation2, kNLocation3, kNInvalidLocation = -1 }; enum HCutIndex { kHLocation0 = 0, kHLocation1, kHLocation2, kHLocation3, kHInvalidLocation = -1 }; // NCHW dim N index const int32_t kNchwDimIdxN = 0; // NCHW dim C index const int32_t kNchwDimIdxC = 1; // NCHW dim H index const int32_t kNchwDimIdxH = 2; // NCHW dim W index const int32_t kNchwDimIdxW = 3; // default die number const uint32_t kDeployNumber = 2; enum CutType { kNoCut = 0, kCutN, kCutH, kDynamicCutN, kDynamicCutH, kDynamicCutAll }; enum TensorCutInfo { kNotSupport = 0, kSplitCutSupported, kAnyCutSupported = 3 }; const int64_t kDefaultFissionFactor = 1; const int64_t kDefaultRankSize = 1; const std::string kDefaultGroup = "hccl_world_group"; const std::string kDefaultReduction = "sum"; const char *const kDefaultDeviceType = "DEFAULT_DEVICE_TYPE"; const char *const kDefaultExecUnit = "DEFAULT_DEVICE_TYPE"; // deploy info const char *const kAttrNeedReturnResult = "_need_return_result"; const char *const kAttrDeviceType = "_device_type"; const char *const kDieDeviceTypeValue = "MultiMode"; const char *const kAttrDeviceId = "_device_id"; const char *const kAttrGraphName = "_graph_name"; const char *const kAttrGraphInputs = "_graph_inputs"; using GraphInputs = vector; using DeviceId = int64_t; using GraphInputNodes = vector; } // namespace class MdsUtils { public: // Parse the configuration file and determine whether to enable MDS based on the value of device_type. static bool IsMDSNeeded(); static int64_t GetNLocation(Format fmt); static int64_t GetHLocation(Format fmt); static int64_t GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type); static bool IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type); static Status SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor, const std::string &group_name = ""); /// @param [in] index 切分的轴 /// @param [in] deploy_number 切分的份数 static Status DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type, int64_t deploy_number = kDeployNumber); // Sets the information, notifies the number of threads to be started during the // loading phase, the device on which each thread should run, and constructs different input data on each device. static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node); static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const std::multimap &deploys, const std::string &device_type = kDieDeviceTypeValue); // Get cut policy for whole graph static CutType TryGetGraphCutType(const ComputeGraphPtr &compute_graph); static GraphInputNodes GetInputNodes() { return input_nodes_; } static void AddInputNode(const NodePtr &input_node) { input_nodes_.push_back(input_node); } static Status DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); static Status DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); static Status DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, NodePtr &input_node); private: static GraphInputNodes input_nodes_; static NodePtr AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const string &type, const string &node_name, size_t input_num, size_t output_num); static NodePtr AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type, const GeTensorDesc &tensor = GeTensorDesc()); static Status ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, NodePtr &reduce_node); static Status ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *node, NodePtr &slice_node); static bool NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node); static NodePtr AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph); }; } // namespace ge #endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_