
mds_utils.cc 25 kB

/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "./mds_utils.h"
namespace ge {
namespace {
// thread-local counters used to build unique node name suffixes
thread_local int64_t data_slice_count = 0;
thread_local int64_t data_gather_count = 0;
thread_local int64_t data_reduce_count = 0;
const std::string kPrefix = "mds";
}  // namespace
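// Map a tensor format to the axis index of its N (batch) dimension;
// returns kNInvalidLocation for unsupported formats.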
int64_t MdsUtils::GetNLocation(Format fmt) {
  int64_t loc = kNInvalidLocation;
  switch (fmt) {
    case FORMAT_NCHW:
    case FORMAT_NHWC:
      loc = kNLocation0;
      break;
    case FORMAT_CHWN:
    case FORMAT_HWCN:
      loc = kNLocation3;
      break;
    default:
      GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str());
  }
  return loc;
}
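// Map a tensor format to the axis index of its H dimension;
// returns kHInvalidLocation for unsupported formats.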
int64_t MdsUtils::GetHLocation(Format fmt) {
  int64_t loc = kHInvalidLocation;
  switch (fmt) {
    case FORMAT_HWCN:
      loc = kHLocation0;
      break;
    case FORMAT_NHWC:
    case FORMAT_CHWN:
      loc = kHLocation1;
      break;
    case FORMAT_NCHW:
      loc = kHLocation2;
      break;
    default:
      GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str());
  }
  return loc;
}
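// Return the axis index to cut along, based on the tensor format and the requested cut type (N or H);
// returns kInvalidIndex for any other cut type.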
int64_t MdsUtils::GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type) {
  Format fmt = ge_tensor_desc->GetFormat();
  switch (type) {
    case kCutN:
      return GetNLocation(fmt);
    case kCutH:
      return GetHLocation(fmt);
    default:
      break;
  }
  GELOGE(FAILED, "[MDS]invalid CutType:%d", type);
  return kInvalidIndex;
}
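// Check whether a tensor can be split across devices for the given cut type: the cut axis must be valid,
// its dimension must be divisible by kDeployNumber, and the cut info attribute on the tensor
// must mark that axis as supporting a split cut.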
bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type) {
  if (ge_tensor_desc == nullptr) {
    REPORT_INNER_ERROR("E19999", "invalid input param: tensor is null!");
    GELOGE(FAILED, "[MDS]invalid input param: tensor is null!");
    return false;
  }
  if (type != kCutN && type != kCutH) {
    REPORT_INNER_ERROR("E19999", "invalid CutType:%d", type);
    GELOGE(FAILED, "[MDS]invalid CutType:%d", type);
    return false;
  }
  int64_t cut_index = GetIndexByFormat(ge_tensor_desc, type);
  if (cut_index == kInvalidIndex) {
    REPORT_INNER_ERROR("E19999", "invalid index param:%ld", cut_index);
    GELOGE(FAILED, "[MDS]invalid index param:%ld", cut_index);
    return false;
  }
  auto dims = ge_tensor_desc->GetShape().GetDims();
  if (cut_index < 0 || cut_index >= static_cast<int64_t>(dims.size())) {
    REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type,
                       dims.size());
    GELOGE(FAILED, "[MDS]cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type, dims.size());
    return false;
  }
  if (dims[cut_index] % kDeployNumber != 0) {
    GELOGW("[MDS] cut_index %ld for CutType %d with dim %ld can not deploy", cut_index, type, dims[cut_index]);
    return false;
  }
  vector<int64_t> cut_support_info;
  if (!(AttrUtils::GetListInt(*ge_tensor_desc, ATTR_NAME_CUT_INFO, cut_support_info))) {
    REPORT_INNER_ERROR("E19999", "call GetListInt failed");
    GELOGE(FAILED, "[MDS]call GetListInt failed");
    return false;
  }
  if (cut_index >= static_cast<int64_t>(cut_support_info.size())) {
    REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu",
                       cut_index, type, cut_support_info.size());
    GELOGE(FAILED, "[MDS]cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index, type,
           cut_support_info.size());
    return false;
  }
  if (cut_support_info[cut_index] < kNotSupport || cut_support_info[cut_index] > kAnyCutSupported) {
    REPORT_INNER_ERROR("E19999", "invalid cut info value:%ld", cut_support_info[cut_index]);
    GELOGE(FAILED, "[MDS]invalid cut info value:%ld", cut_support_info[cut_index]);
    return false;
  }
  return (cut_support_info[cut_index] & kSplitCutSupported) != 0;
}
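// Shrink the cut axis of the tensor: its dimension becomes dim / deploy_number, i.e. the per-device shape.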
Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type, int64_t deploy_number) {
  GE_CHECK_NOTNULL(ge_tensor_desc);
  auto index = MdsUtils::GetIndexByFormat(ge_tensor_desc, type);
  auto dims = ge_tensor_desc->GetShape().GetDims();
  REQUIRE(index >= 0 && index < static_cast<int64_t>(dims.size()),
          "[DistributedDeploy] failed, index %ld should be less than %zu", index, dims.size());
  auto dim_after_deploy = dims[index] / deploy_number;
  MDS_REQUIRE_SUCCESS(ge_tensor_desc->MutableShape().SetDim(index, dim_after_deploy),
                      "[DistributedDeploy] update shape failed");
  return SUCCESS;
}
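// Set the fission factor attribute (and the HCCL group name, when provided) on an HCOM op.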
Status MdsUtils::SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor, const std::string &group_name) {
  GE_CHECK_NOTNULL(hcom_op);
  REQUIRE(fission_factor > kDefaultFissionFactor, "fission_factor %ld needs to be bigger than %ld", fission_factor,
          kDefaultFissionFactor);
  REQUIRE(ge::AttrUtils::SetInt(hcom_op, ATTR_NAME_FISSION_FACTOR, fission_factor),
          "Failed to set attr fission_factor %ld for op:%s(%s)", fission_factor, hcom_op->GetName().c_str(),
          hcom_op->GetType().c_str());
  if (!group_name.empty()) {
    REQUIRE(ge::AttrUtils::SetStr(hcom_op, HCOM_ATTR_GROUP, group_name), "Failed to set attr group %s for op:%s(%s)",
            group_name.c_str(), hcom_op->GetName().c_str(), hcom_op->GetType().c_str());
  }
  return SUCCESS;
}
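// Decide whether the MDS pass is needed: skip it when the configured device type is the default one,
// otherwise require it whenever the device type differs from the system execution unit.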
bool MdsUtils::IsMDSNeeded() {
  std::string device_type;
  if (ge::GetContext().GetOption(ge::OPTION_DEVICE_TYPE, device_type) == GRAPH_SUCCESS &&
      device_type == kDefaultDeviceType) {
    GELOGI("[MDS]device type is %s, skip mds", device_type.c_str());
    return false;
  }
  // TODO: Parse the configuration file of the system to get the sys_config_exe_unit
  std::string sys_config_exe_unit = "DIE";
  return device_type != sys_config_exe_unit;
}
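// Record the deploy info on the graph as a list of named attrs, one entry per device:
// each entry carries the device id, the device type, the graph name and one graph input tensor
// whose payload is the device id.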
Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node) {
  GE_CHECK_NOTNULL(compute_graph);
  GELOGD("[MDS]%s SetDeployInfo start", compute_graph->GetName().c_str());
  // build deploy info
  vector<GeAttrValue::NAMED_ATTRS> deploy_info;
  GE_CHECK_NOTNULL(input_node);
  for (int64_t j = 0; j < kDeployNumber; j++) {
    int64_t device_id = j;
    GeAttrValue::LIST_TENSOR graph_inputs;
    GeTensorPtr graph_input = MakeShared<GeTensor>(input_node->GetOpDesc()->GetOutputDesc(0));
    GE_CHECK_NOTNULL(graph_input);
    vector<uint8_t> data{static_cast<uint8_t>(device_id)};
    graph_input->SetData(data);
    // For now, only one graph_input
    graph_inputs.push_back(graph_input);
    GeAttrValue::NAMED_ATTRS thread_instance;
    thread_instance.SetName(std::to_string(device_id));
    (void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id));
    // TODO: Change to enumeration from RTS header file
    (void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>("MultiMode"));
    (void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName()));
    (void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(graph_inputs));
    deploy_info.emplace_back(thread_instance);
    GELOGD("[MDS]%s SetDeployInfo on device id: %ld", compute_graph->GetName().c_str(), device_id);
  }
  // set deploy info
  REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info),
          "Set attr failed for graph %s", compute_graph->GetName().c_str());
  return SUCCESS;
}
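// Pick a cut type by inspecting the graph inputs: prefer cutting along N, fall back to H,
// use the dynamic variants for unknown-shape graphs, and kDynamicCutAll when both N and H dims are unknown.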
CutType MdsUtils::TryGetGraphCutType(const ComputeGraphPtr &compute_graph) {
  bool is_unknown_graph = false;
  if (GraphUtils::IsUnknownShapeGraph(compute_graph)) {
    GELOGI("Graph %s is unknown shape graph", compute_graph->GetName().c_str());
    is_unknown_graph = true;
  }
  CutType selected_cut_type = kNoCut;
  for (const auto &data : compute_graph->GetInputNodes()) {
    GELOGI("Get graph input %s %s", data->GetName().c_str(), data->GetType().c_str());
    auto data_n_index = MdsUtils::GetIndexByFormat(data->GetOpDesc()->MutableOutputDesc(0), kCutN);
    auto data_n_dim = data->GetOpDesc()->GetOutputDesc(0).GetShape().GetDim(data_n_index);
    auto data_h_index = MdsUtils::GetIndexByFormat(data->GetOpDesc()->MutableOutputDesc(0), kCutH);
    auto data_h_dim = data->GetOpDesc()->GetOutputDesc(0).GetShape().GetDim(data_h_index);
    if (data_n_dim == -1 && data_h_dim == -1) {
      selected_cut_type = kDynamicCutAll;
      break;
    }
    if (data_n_dim % kDeployNumber == 0) {
      selected_cut_type = is_unknown_graph ? kDynamicCutN : kCutN;
      break;
    }
    if (data_h_dim % kDeployNumber == 0) {
      selected_cut_type = is_unknown_graph ? kDynamicCutH : kCutH;
    }
  }
  return selected_cut_type;
}
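// Multi-device overload: record one deploy entry per (device id, graph inputs) pair;
// only the first recorded entry is marked as needing to return the execution result.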
Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph,
                               const std::multimap<DeviceId, GraphInputs> &deploys, const std::string &device_type) {
  GE_CHECK_NOTNULL(compute_graph);
  GELOGD("[MDS]%s SetDeployInfo start", compute_graph->GetName().c_str());
  // build deploy info
  vector<GeAttrValue::NAMED_ATTRS> deploy_info;
  for (const auto &pair : deploys) {
    int64_t device_id = pair.first;
    GeAttrValue::NAMED_ATTRS thread_instance;
    thread_instance.SetName(std::to_string(device_id));
    (void)thread_instance.SetAttr(kAttrNeedReturnResult,
                                  GeAttrValue::CreateFrom<GeAttrValue::BOOL>(deploy_info.empty()));
    (void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id));
    (void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>(device_type));
    (void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName()));
    (void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(pair.second));
    deploy_info.emplace_back(thread_instance);
    GELOGD("[MDS]%s SetDeployInfo on device id: %ld", compute_graph->GetName().c_str(), device_id);
  }
  // set deploy info
  REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info),
          "Set attr failed for graph %s", compute_graph->GetName().c_str());
  return SUCCESS;
}
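// Insert an HcomAllGather node between src and dst so that the per-device partial outputs
// are gathered back into the full tensor, then re-run shape inference on the new node.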
Status MdsUtils::DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst) {
  auto src_node = src->GetOwnerNode();
  GE_CHECK_NOTNULL(src_node);
  auto dst_node = dst->GetOwnerNode();
  GE_CHECK_NOTNULL(dst_node);
  auto src_graph = src_node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(src_graph);
  std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_gather_count));
  auto hcom_allgather_node =
      AddDynamicInputOutputNode(src_graph, HCOMALLGATHER, HCOMALLGATHER + node_name_suffix, 1, 1);
  GE_CHECK_NOTNULL(hcom_allgather_node);
  MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, hcom_allgather_node),
                      "[DataGather] failed between %s and %s", src_node->GetName().c_str(),
                      dst_node->GetName().c_str());
  MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(hcom_allgather_node->GetOpDesc(), kDeployNumber, kDefaultGroup),
                      "[DataGather]set attr for node for %s(%s) failed", hcom_allgather_node->GetName().c_str(),
                      hcom_allgather_node->GetType().c_str());
  REQUIRE(ge::AttrUtils::SetInt(hcom_allgather_node->GetOpDesc(), HCOM_ATTR_RANK_SIZE, kDefaultRankSize),
          "Failed to set attr rank size for op:%s(%s)", hcom_allgather_node->GetName().c_str(),
          hcom_allgather_node->GetType().c_str());
  MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(hcom_allgather_node, false),
                      "[DataGather] %s call infershape failed", hcom_allgather_node->GetName().c_str());
  data_gather_count++;
  return SUCCESS;
}
// gradients->ApplyMomentum
// we want to reduce gradients across the different devices (dies), so the graph topo changes to
// gradients->hcomallreducemean->ApplyMomentum; because 'mean' is not currently supported by hcomallreduce,
// the topo will end up like gradients->hcomallreducesum->div->ApplyMomentum
Status MdsUtils::DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst) {
  auto src_node = src->GetOwnerNode();
  GE_CHECK_NOTNULL(src_node);
  auto dst_node = dst->GetOwnerNode();
  GE_CHECK_NOTNULL(dst_node);
  auto src_graph = src_node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(src_graph);
  NodePtr all_reduce_node = nullptr;
  if (NeedInsertHcomAllReduce(src_node, all_reduce_node)) {
    MDS_REQUIRE_SUCCESS(ConstructReduceNode(src_graph, src, dst, all_reduce_node),
                        "[DataReduce] construct allreduce node after %s failed", src_node->GetName().c_str());
    GE_CHECK_NOTNULL(all_reduce_node);
  } else {
    GE_CHECK_NOTNULL(all_reduce_node);
    MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(all_reduce_node->GetOpDesc(), kDeployNumber),
                        "[DataReduce][Modify] set attr for allreduce node for %s failed",
                        all_reduce_node->GetName().c_str());
  }
  return SUCCESS;
}
// tensor t with shape like [n,c,h,w]: we want [0 : n/2, c, h, w] and [n/2 : n, c, h, w] on different
// devices; to achieve this we use slice nodes:
// slice(t, [i * n/2, 0, 0, 0], [n/2, c, h, w]) i=0,1
// slice takes three inputs: t->slice; data(0,1)->mul(n/2)->pack[i*n/2,0,0,0]->slice; const(n,c,h,w)->slice
Status MdsUtils::DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, NodePtr &input_node) {
  auto src_node = src->GetOwnerNode();
  GE_CHECK_NOTNULL(src_node);
  auto dst_node = dst->GetOwnerNode();
  GE_CHECK_NOTNULL(dst_node);
  auto src_graph = src_node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(src_graph);
  if (input_node == nullptr) {
    std::string input_node_name = std::string(DATA) + "_" + kPrefix + "_" + std::to_string(0);
    input_node = AddSingleInputOutputNode(src_graph, input_node_name, DATA);
    GE_CHECK_NOTNULL(input_node);
    AddInputNode(input_node);
  }
  GeTensorDesc tensor = src_node->GetOpDesc()->GetOutputDesc(src->GetIdx());
  NodePtr slice_node = nullptr;
  MDS_REQUIRE_SUCCESS(ConstructSliceNode(src_graph, tensor, input_node.get(), slice_node),
                      "[DataSlice] construct slice node for %s failed", src_node->GetName().c_str());
  GE_CHECK_NOTNULL(slice_node);
  MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, slice_node), "[DataSlice] failed between %s and %s",
                      src_node->GetName().c_str(), dst_node->GetName().c_str());
  MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(slice_node, false), "[DataSlice] %s call infer shape failed",
                      slice_node->GetName().c_str());
  return SUCCESS;
}
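// Build the const/mul/pack/slice subgraph described above DataSlice: the slice size is a const holding the
// per-device shape, and the slice offset is the device index (input_node) multiplied by dim0/kDeployNumber,
// packed together with zero offsets for the remaining dims.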
Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *input_node,
                                    NodePtr &slice_node) {
  vector<int64_t> slice_sizes = tensor.GetShape().GetDims();
  // TODO: Express with graph structure
  slice_sizes[0] /= kDeployNumber;
  vector<GeTensorPtr> ge_tensors;
  GeTensorDesc ge_tensor_desc;
  ge_tensor_desc.SetDataType(DT_INT64);
  MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_sizes, ge_tensors),
                      "[ConstructTensorDescWithData] failed");
  GeTensorPtr slice_size_tensor = ge_tensors[0];
  auto const_node_slice_size = AddConstNodeToGraph(slice_size_tensor, src_graph);
  GE_CHECK_NOTNULL(const_node_slice_size);
  vector<int64_t> slice_offset_other_dim{0};
  ge_tensors.clear();
  MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_offset_other_dim, ge_tensors, true),
                      "[ConstructTensorDescWithData] failed");
  GeTensorPtr slice_offset_tensor = ge_tensors[0];
  auto const_node_slice_offset = AddConstNodeToGraph(slice_offset_tensor, src_graph);
  GE_CHECK_NOTNULL(const_node_slice_offset);
  vector<int64_t> slice_offset_first_dim{slice_sizes[0]};
  ge_tensors.clear();
  MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_offset_first_dim, ge_tensors, true),
                      "[ConstructTensorDescWithData] failed");
  GeTensorPtr slice_offset_first_dim_tensor = ge_tensors[0];
  auto const_node_slice_offset_first_dim = AddConstNodeToGraph(slice_offset_first_dim_tensor, src_graph);
  GE_CHECK_NOTNULL(const_node_slice_offset_first_dim);
  std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_slice_count));
  NodePtr mul_node = AddDynamicInputOutputNode(src_graph, MUL, MUL + node_name_suffix, 2, 1);
  GE_CHECK_NOTNULL(mul_node);
  GE_CHECK_NOTNULL(input_node);
  MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(input_node->GetOutDataAnchor(0), mul_node->GetInDataAnchor(0)),
                      "[ConstructSliceNode] add edge failed");
  MDS_REQUIRE_SUCCESS(
      GraphUtils::AddEdge(const_node_slice_offset_first_dim->GetOutDataAnchor(0), mul_node->GetInDataAnchor(1)),
      "[ConstructSliceNode] add edge failed");
  MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(mul_node, false), "[DataSlice] %s call infer shape failed",
                      mul_node->GetName().c_str());
  NodePtr pack_node = AddDynamicInputOutputNode(src_graph, PACK, PACK + node_name_suffix, slice_sizes.size(), 1);
  GE_CHECK_NOTNULL(pack_node);
  bool is_first_input = true;
  for (const auto &in_anchor : pack_node->GetAllInDataAnchors()) {
    if (is_first_input) {
      MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(mul_node->GetOutDataAnchor(0), in_anchor),
                          "[ConstructSliceNode] add edge failed");
      is_first_input = false;
    } else {
      MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_slice_offset->GetOutDataAnchor(0), in_anchor),
                          "[ConstructSliceNode] add edge failed");
    }
  }
  MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(pack_node, false), "[DataSlice] %s call infer shape failed",
                      pack_node->GetName().c_str());
  slice_node = AddDynamicInputOutputNode(src_graph, SLICE, SLICE + node_name_suffix, 3, 1);
  GE_CHECK_NOTNULL(slice_node);
  MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(pack_node->GetOutDataAnchor(0), slice_node->GetInDataAnchor(1)),
                      "[ConstructSliceNode] add edge failed");
  MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_slice_size->GetOutDataAnchor(0), slice_node->GetInDataAnchor(2)),
                      "[ConstructSliceNode] add edge failed");
  ++data_slice_count;
  return SUCCESS;
}
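// Create a node with a single input "x" and a single output "y" of the given tensor desc,
// add it to the graph and return it; returns nullptr on failure.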
NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type,
                                           const GeTensorDesc &tensor) {
  GELOGI("Begin to create op: %s", name.c_str());
  OpDescBuilder op_desc_builder(name, type);
  OpDescPtr op_desc = op_desc_builder.AddInput("x", tensor).AddOutput("y", tensor).Build();
  if (op_desc == nullptr) {
    REPORT_CALL_ERROR("E19999", "Create op_desc:%s(%s) failed", name.c_str(), type.c_str());
    GELOGE(FAILED, "[Create][OpDesc] failed, name:%s(%s).", name.c_str(), type.c_str());
    return nullptr;
  }
  NodePtr node = graph->AddNode(op_desc);
  if (node == nullptr) {
    REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(),
                      op_desc->GetType().c_str(), graph->GetName().c_str());
    GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(),
           graph->GetName().c_str());
    return nullptr;
  }
  return node;
}
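// Create a node with input_num dynamic inputs "x" and output_num dynamic outputs "y",
// add it to the graph and return it; returns nullptr on failure.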
NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const std::string &type,
                                            const std::string &node_name, size_t input_num, size_t output_num) {
  GELOGI("Begin to create op: %s", node_name.c_str());
  OpDescBuilder op_desc_builder(node_name, type);
  OpDescPtr op_desc = op_desc_builder.AddDynamicInput("x", input_num).AddDynamicOutput("y", output_num).Build();
  if (op_desc == nullptr) {
    REPORT_CALL_ERROR("E19999", "Create op_desc:%s(%s) failed", node_name.c_str(), type.c_str());
    GELOGE(FAILED, "[Create][OpDesc] failed, name:%s(%s).", node_name.c_str(), type.c_str());
    return nullptr;
  }
  NodePtr node = graph->AddNode(op_desc);
  if (node == nullptr) {
    REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(),
                      op_desc->GetType().c_str(), graph->GetName().c_str());
    GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(),
           graph->GetName().c_str());
    return nullptr;
  }
  return node;
}
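// Wrap the tensor in a Const op desc and add it to the front of the graph; returns nullptr on failure.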
NodePtr MdsUtils::AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph) {
  auto const_desc = OpDescUtils::CreateConstOp(tensor);
  if (const_desc == nullptr) {
    REPORT_CALL_ERROR("E19999", "Create Const op failed");
    GELOGE(OUT_OF_MEMORY, "[Create][ConstOp] failed");
    return nullptr;
  }
  if (graph == nullptr) {
    GELOGW("input param graph is null");
    return nullptr;
  }
  return graph->AddNodeFront(const_desc);
}
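// Build the allreduce(sum) + realdiv(kDeployNumber) pair described above DataReduce and insert it
// between src and dst; together the two nodes emulate an allreduce with 'mean' reduction.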
Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src,
                                     const InDataAnchorPtr &dst, NodePtr &reduce_node) {
  std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_reduce_count));
  reduce_node = AddDynamicInputOutputNode(src_graph, HCOMALLREDUCE, HCOMALLREDUCE + node_name_suffix, 1, 1);
  GE_CHECK_NOTNULL(reduce_node);
  MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, reduce_node),
                      "[DataReduce] failed insert %s between %s and %s", reduce_node->GetName().c_str(),
                      src->GetOwnerNode()->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str());
  MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(reduce_node->GetOpDesc(), kDeployNumber, kDefaultGroup),
                      "[DataReduce][Create] set attr for allreduce node for %s failed", reduce_node->GetName().c_str());
  REQUIRE(ge::AttrUtils::SetStr(reduce_node->GetOpDesc(), HCOM_ATTR_REDUCE_TYPE, kDefaultReduction),
          "Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(),
          reduce_node->GetName().c_str(), reduce_node->GetType().c_str());
  MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(reduce_node, false), "[DataReduce] %s call infershape failed",
                      reduce_node->GetName().c_str());
  auto div_node = AddDynamicInputOutputNode(src_graph, REALDIV, REALDIV + node_name_suffix, 2, 1);
  GE_CHECK_NOTNULL(div_node);
  vector<int64_t> slice_sizes{kDeployNumber};
  vector<GeTensorPtr> ge_tensors;
  GeTensorDesc ge_tensor_desc;
  ge_tensor_desc.SetDataType(DT_INT64);
  MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_sizes, ge_tensors),
                      "[ConstructReduceNode] failed");
  REQUIRE(!ge_tensors.empty(), "[ConstructReduceNode] failed");
  auto const_node_div_input = AddConstNodeToGraph(ge_tensors[0], src_graph);
  GE_CHECK_NOTNULL(const_node_div_input);
  MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_div_input->GetOutDataAnchor(0), div_node->GetInDataAnchor(1)),
                      "[ConstructReduceNode] add edge failed");
  MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(reduce_node->GetOutDataAnchor(0), {dst}, div_node),
                      "[DataReduce] failed insert %s between %s and %s", div_node->GetName().c_str(),
                      reduce_node->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str());
  MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(div_node, false), "[DataReduce] %s call infershape failed",
                      div_node->GetName().c_str());
  ++data_reduce_count;
  return SUCCESS;
}
bool MdsUtils::NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node) {
  // TODO: recognize that the graph is originally a multi-p model, that is, there is already an allreduce node,
  // so there is no need to insert it
  return true;
}
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware and serves as the bridge between them. GE takes the graph delivered by ME as input, performs a series of deep graph optimizations, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.