You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

variable_op_pass.cc 25 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "graph/passes/variable_op_pass.h"

#include <map>
#include <set>
#include <sstream>
#include <string>
#include <vector>

#include "common/formats/formats.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "graph/ge_context.h"
#include "graph/graph.h"
#include "graph/manager/graph_var_manager.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
  27. namespace ge {
  28. namespace {
  29. const int kTransOpOutIndex = 0;
  30. std::string GetKey(Format format, DataType type, const std::vector<int64_t> &dims) {
  31. std::stringstream key;
  32. key << static_cast<int>(format) << '-';
  33. key << static_cast<int>(type) << '-';
  34. for (auto dim : dims) {
  35. key << dim << '-';
  36. }
  37. return key.str();
  38. }
  39. Status ByPassTransNode(NodePtr &trans_node, NodePtr &ref_node) {
  40. GE_CHECK_NOTNULL(trans_node);
  41. GE_CHECK_NOTNULL(ref_node);
  42. GELOGD("Begin to bypass trans node %s", trans_node->GetName().c_str());
  43. auto ret = GraphUtils::CopyInCtrlEdges(trans_node, ref_node);
  44. if (ret != GRAPH_SUCCESS) {
  45. GELOGE(INTERNAL_ERROR,
  46. "Failed to move control edges from trans "
  47. "node %s to var-ref %s",
  48. trans_node->GetName().c_str(), ref_node->GetName().c_str());
  49. return INTERNAL_ERROR;
  50. }
  51. auto ref_in_anchor = ref_node->GetInDataAnchor(0);
  52. if (ref_in_anchor == nullptr) {
  53. GELOGE(INTERNAL_ERROR,
  54. "The variable ref node %s does not have an "
  55. "input anchor",
  56. ref_node->GetName().c_str());
  57. return INTERNAL_ERROR;
  58. }
  59. ref_in_anchor->UnlinkAll();
  60. auto trans_in_anchor = trans_node->GetInDataAnchor(0);
  61. if (trans_in_anchor == nullptr) {
  62. GELOGE(INTERNAL_ERROR,
  63. "Failed to get the in data anchor from trans"
  64. " node %s type %s",
  65. trans_node->GetName().c_str(), trans_node->GetType().c_str());
  66. return INTERNAL_ERROR;
  67. }
  68. auto prev_trans_node_out_anchor = trans_in_anchor->GetPeerOutAnchor();
  69. if (prev_trans_node_out_anchor == nullptr) {
  70. GELOGW(
  71. "The trans node %s does not have an input, so the ref node %s does"
  72. " not have any inputs after bypass",
  73. trans_node->GetName().c_str(), trans_node->GetName().c_str());
  74. } else {
  75. ret = GraphUtils::AddEdge(prev_trans_node_out_anchor, ref_in_anchor);
  76. if (ret != GRAPH_SUCCESS) {
  77. GELOGE(INTERNAL_ERROR,
  78. "Failed to add edge between ref node %s "
  79. "and the prev node of trans node %s",
  80. ref_node->GetName().c_str(), trans_node->GetName().c_str());
  81. return INTERNAL_ERROR;
  82. }
  83. }
  84. return SUCCESS;
  85. }
  86. bool IsTransSupport(const TransNodeInfo &trans_info) {
  87. if (trans_info.output.GetShape().IsUnknownShape()) {
  88. return false;
  89. }
  90. if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) {
  91. return true;
  92. } else if (trans_info.node_type == TRANSDATA || trans_info.node_type == TRANSPOSED) {
  93. formats::TransArgs args{nullptr,
  94. trans_info.input.GetFormat(),
  95. trans_info.output.GetFormat(),
  96. trans_info.input.GetShape().GetDims(),
  97. trans_info.output.GetShape().GetDims(),
  98. trans_info.input.GetDataType()};
  99. return formats::IsTransFormatSupport(args);
  100. } else if (trans_info.node_type == CAST) {
  101. formats::CastArgs datatype_args{nullptr, static_cast<size_t>(trans_info.input.GetShape().GetShapeSize()),
  102. trans_info.input.GetDataType(), trans_info.output.GetDataType()};
  103. return formats::IsTransDataTypeSupport(datatype_args);
  104. } else {
  105. return false;
  106. }
  107. }
  108. } // namespace
  109. Status VariableOpPass::Run(ge::ComputeGraphPtr graph) {
  110. if (graph == nullptr) {
  111. GELOGE(INTERNAL_ERROR, "Failed to run variable op pass, null graph");
  112. return INTERNAL_ERROR;
  113. }
  114. auto graph_id = GraphUtils::FindRootGraph(graph)->GetGraphID();
  115. GELOGD("Begin to run variable op pass on graph %s, session %lu, graph id %u", graph->GetName().c_str(),
  116. GetContext().SessionId(), graph_id);
  117. if (var_accelerate_ctrl_ == nullptr) {
  118. GELOGE(INTERNAL_ERROR, "Failed to run var op pass, the variable accelerate control is null");
  119. return INTERNAL_ERROR;
  120. }
  121. GELOGD("Begin to generate ref map for variable and refs, graph name:%s.", graph->GetName().c_str());
  122. if (RenewVarDesc(graph) != SUCCESS) {
  123. GELOGE(INTERNAL_ERROR, "Failed to renew var desc on graph");
  124. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  125. }
  126. if (GenerateVariableVariableRefMap(graph) != SUCCESS) {
  127. GELOGE(INTERNAL_ERROR, "Failed to generate variable map for graph %s", graph->GetName().c_str());
  128. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  129. }
  130. GELOGD("Begin to fusion variables and trans nodes");
  131. for (auto &var_to_refs : var_and_var_ref_map_) {
  132. auto &node = var_to_refs.first;
  133. GE_CHECK_NOTNULL(node);
  134. GE_CHECK_NOTNULL(var_accelerate_ctrl_);
  135. if (!var_accelerate_ctrl_->IsVarPermitToChangeFormats(node->GetName())) {
  136. GELOGD("The var %s does not permit to change formats, skip it", node->GetName().c_str());
  137. continue;
  138. }
  139. VarTransRoad fusion_road;
  140. auto ret = FusionIfNeed(node, fusion_road);
  141. if (ret != SUCCESS) {
  142. return ret;
  143. }
  144. if (fusion_road.empty()) {
  145. GELOGD("No need to fusion variable and trans op for var %s", node->GetName().c_str());
  146. continue;
  147. }
  148. auto start_iter = fusion_road.begin();
  149. auto end_iter = fusion_road.rbegin();
  150. GELOGD(
  151. "Trans variable data for %s from format %s to %s, shape %s to %s "
  152. "data-type %s to %s, path len %zu success",
  153. node->GetName().c_str(), TypeUtils::FormatToSerialString(start_iter->input.GetFormat()).c_str(),
  154. TypeUtils::FormatToSerialString(end_iter->output.GetFormat()).c_str(),
  155. formats::ShapeToString(start_iter->input.GetShape().GetDims()).c_str(),
  156. formats::ShapeToString(end_iter->output.GetShape().GetDims()).c_str(),
  157. TypeUtils::DataTypeToSerialString(start_iter->input.GetDataType()).c_str(),
  158. TypeUtils::DataTypeToSerialString(end_iter->output.GetDataType()).c_str(), fusion_road.size());
  159. ret = VarManager::Instance(graph->GetSessionID())->SetTransRoad(node->GetName(), fusion_road);
  160. if (ret != SUCCESS) {
  161. GELOGE(INTERNAL_ERROR, "Failed to update the format fusion road for var %s", node->GetName().c_str());
  162. return INTERNAL_ERROR;
  163. }
  164. ret = VarManager::Instance(graph->GetSessionID())->SetChangedGraphId(node->GetName(), graph_id);
  165. if (ret != SUCCESS) {
  166. GELOGE(INTERNAL_ERROR, "Failed to update the graph id for var %s", node->GetName().c_str());
  167. return INTERNAL_ERROR;
  168. }
  169. var_accelerate_ctrl_->SetVarChanged(node->GetName());
  170. GELOGD("Begin to update format info for var %s.", node->GetName().c_str());
  171. std::set<ge::NodePtr> node_set({node});
  172. if (UpdateIOFormatInfo(end_iter->output, node_set) != SUCCESS) {
  173. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  174. }
  175. // renew var desc if the trans_road is all reshape or reformat
  176. ret = RenewVarDesc(graph->GetSessionID(), node, fusion_road);
  177. if (ret != SUCCESS) {
  178. GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str());
  179. return FAILED;
  180. }
  181. }
  182. return SUCCESS;
  183. }
  184. Status VariableOpPass::DealFusion(const ge::NodePtr &var_node) {
  185. GE_CHECK_NOTNULL(var_node);
  186. GELOGD("Begin to fusion var %s with trans", var_node->GetName().c_str());
  187. auto graph = var_node->GetOwnerComputeGraph();
  188. for (auto &trans_node : var_node->GetOutDataNodes()) {
  189. GELOGD("Remove node %s type %s when fusion with variable %s", trans_node->GetName().c_str(),
  190. trans_node->GetType().c_str(), var_node->GetName().c_str());
  191. if (GraphUtils::IsolateNode(trans_node, {0}) != SUCCESS) {
  192. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  193. }
  194. if (GraphUtils::RemoveNodeWithoutRelink(graph, trans_node) != SUCCESS) {
  195. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  196. }
  197. }
  198. auto iterator = var_and_var_ref_map_.find(var_node);
  199. if (iterator == var_and_var_ref_map_.end()) {
  200. GELOGD("there is no var_ref of node %s", var_node->GetName().c_str());
  201. return SUCCESS;
  202. }
  203. for (auto ref_node : iterator->second) {
  204. GE_CHECK_NOTNULL(ref_node);
  205. for (auto &trans_node : ref_node->GetInDataNodes()) {
  206. GELOGD("Remove node %s type %s when fusion with variable %s", trans_node->GetName().c_str(),
  207. trans_node->GetType().c_str(), var_node->GetName().c_str());
  208. if (trans_node->GetOutDataNodes().size() > 1) {
  209. GELOGD(
  210. "The trans node %s type %s connecting with var-ref %s has more"
  211. " than one output data nodes, unlink the edge between them",
  212. trans_node->GetName().c_str(), trans_node->GetType().c_str(), ref_node->GetName().c_str());
  213. if (ByPassTransNode(trans_node, ref_node) != SUCCESS) {
  214. GELOGE(INTERNAL_ERROR, "Failed to bypass trans node %s to ref %s", trans_node->GetName().c_str(),
  215. ref_node->GetName().c_str());
  216. return INTERNAL_ERROR;
  217. }
  218. } else {
  219. GELOGD(
  220. "The trans node %s type %s connecting with var-ref %s has only"
  221. " one output data nodes, isolate and remove it.",
  222. trans_node->GetName().c_str(), trans_node->GetType().c_str(), ref_node->GetName().c_str());
  223. if (GraphUtils::IsolateNode(trans_node, {0}) != SUCCESS) {
  224. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  225. }
  226. if (GraphUtils::RemoveNodeWithoutRelink(graph, trans_node) != SUCCESS) {
  227. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  228. }
  229. }
  230. }
  231. }
  232. return SUCCESS;
  233. }
  234. Status VariableOpPass::CheckSameAndTransOp(const ge::NodePtr &var_node, bool &is_matched, VarTransRoad &fusion_road) {
  235. std::set<std::string> data_type_and_formats;
  236. std::string trans_op_type;
  237. ge::NodePtr out_node;
  238. ge::GeTensorDesc output_desc;
  239. GE_CHECK_NOTNULL(var_node);
  240. for (auto &out_node_and_anchor : var_node->GetOutDataNodesAndAnchors()) {
  241. auto in_anchor = out_node_and_anchor.second;
  242. GE_CHECK_NOTNULL(in_anchor);
  243. out_node = out_node_and_anchor.first;
  244. GE_CHECK_NOTNULL(out_node);
  245. auto trans_op_desc = out_node->GetOpDesc();
  246. GE_CHECK_NOTNULL(trans_op_desc);
  247. trans_op_type = trans_op_desc->GetType();
  248. GELOGD("current node type is %s.", trans_op_type.c_str());
  249. int data_index = TransOpUtil::GetTransOpDataIndex(trans_op_type);
  250. if (data_index < 0) {
  251. GELOGD("Variables only can be fusion with trans_op, the next op is %s type %s", out_node->GetName().c_str(),
  252. out_node->GetType().c_str());
  253. return SUCCESS;
  254. }
  255. if (data_index != in_anchor->GetIdx()) {
  256. GELOGD(
  257. "Variables only can be fusion with trans nodes, the next node %s"
  258. " type %s index %d does not trans anything(correct index %d)",
  259. out_node->GetName().c_str(), out_node->GetType().c_str(), in_anchor->GetIdx(), data_index);
  260. return SUCCESS;
  261. }
  262. output_desc = trans_op_desc->GetOutputDesc(kTransOpOutIndex);
  263. auto trans_op_format = output_desc.GetFormat();
  264. auto trans_op_data_type = output_desc.GetDataType();
  265. auto shape = output_desc.GetShape().GetDims();
  266. auto datatype_and_format = GetKey(trans_op_format, trans_op_data_type, shape);
  267. data_type_and_formats.insert(datatype_and_format);
  268. }
  269. if (data_type_and_formats.empty()) {
  270. return SUCCESS;
  271. }
  272. if (data_type_and_formats.size() > 1) {
  273. std::stringstream type_and_formats_stream;
  274. bool first_time = true;
  275. for (const auto &data_type_and_format : data_type_and_formats) {
  276. if (first_time) {
  277. first_time = false;
  278. } else {
  279. type_and_formats_stream << "|";
  280. }
  281. type_and_formats_stream << data_type_and_format;
  282. }
  283. GELOGW(
  284. "trans_op type size for var Node(%s) is over 1, Currently not"
  285. " supported, dataTypeAndFormats is %s.",
  286. var_node->GetName().c_str(), type_and_formats_stream.str().c_str());
  287. return SUCCESS;
  288. }
  289. int tran_in_index = TransOpUtil::GetTransOpDataIndex(out_node->GetType());
  290. auto out_op_desc = out_node->GetOpDesc();
  291. GE_CHECK_NOTNULL(out_op_desc);
  292. TransNodeInfo trans_node_info;
  293. trans_node_info.node_type = out_node->GetType();
  294. trans_node_info.input = out_op_desc->GetInputDesc(tran_in_index);
  295. trans_node_info.output = out_op_desc->GetOutputDesc(kTransOpOutIndex);
  296. if (!IsTransSupport(trans_node_info)) {
  297. GELOGD("The trans node %s does not support, skip the variable accelerating", trans_node_info.node_type.c_str());
  298. return SUCCESS;
  299. }
  300. is_matched = true;
  301. fusion_road.emplace_back(trans_node_info);
  302. return SUCCESS;
  303. }
  304. Status VariableOpPass::CheckVariableRefLegally(const ge::NodePtr &var_node, bool &is_var_ref_legally) {
  305. is_var_ref_legally = true;
  306. GE_CHECK_NOTNULL(var_node);
  307. auto iterator = var_and_var_ref_map_.find(var_node);
  308. if (iterator == var_and_var_ref_map_.end()) {
  309. GELOGD("var name %s are not in var var_ref map", var_node->GetName().c_str());
  310. return SUCCESS;
  311. }
  312. GELOGD("var name %s, ref var count %zu.", var_node->GetName().c_str(), iterator->second.size());
  313. for (const auto &var_ref_node : iterator->second) {
  314. if (CheckVarAndVarRefAreAlike(var_node, var_ref_node, is_var_ref_legally) != SUCCESS) {
  315. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  316. }
  317. GELOGD("is_var_ref_legally is %d", is_var_ref_legally);
  318. if (!is_var_ref_legally) {
  319. return SUCCESS;
  320. }
  321. }
  322. return SUCCESS;
  323. }
  324. Status VariableOpPass::UpdateVarAndRefOutputFormatInfo(const GeTensorDesc &final_output, const ge::NodePtr &node) {
  325. if (node == nullptr || node->GetOpDesc() == nullptr) {
  326. GELOGE(FAILED, "node or opdesc is nullptr");
  327. return FAILED;
  328. }
  329. const Format &format = final_output.GetFormat();
  330. const DataType &data_type = final_output.GetDataType();
  331. const GeShape &shape = final_output.GetShape();
  332. GELOGD("last ref is (%s, %s, %lu), var_ref_name is %s.", TypeUtils::DataTypeToSerialString(data_type).c_str(),
  333. TypeUtils::FormatToSerialString(format).c_str(), shape.GetDims().size(), node->GetName().c_str());
  334. auto node_desc = node->GetOpDesc()->GetOutputDesc(0);
  335. CopyVariableFormatDataTypeAndShape(final_output, node_desc);
  336. if (node->GetOpDesc()->UpdateOutputDesc(0, node_desc) != GRAPH_SUCCESS) {
  337. GELOGE(FAILED, "update output desc fail.");
  338. return FAILED;
  339. }
  340. GELOGD("node ref is (%s, %s, %lu), var_ref_name is %s.",
  341. TypeUtils::DataTypeToSerialString(node->GetOpDesc()->GetOutputDesc(0).GetDataType()).c_str(),
  342. TypeUtils::FormatToSerialString(node->GetOpDesc()->GetOutputDesc(0).GetFormat()).c_str(),
  343. node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims().size(), node->GetName().c_str());
  344. auto iterator = var_and_var_ref_map_.find(node);
  345. if (iterator == var_and_var_ref_map_.end()) {
  346. auto graph = node->GetOwnerComputeGraph();
  347. if (GenerateVariableVariableRefMap(graph) != SUCCESS) {
  348. GELOGE(INTERNAL_ERROR, "Failed to generate variable map for graph %s", graph->GetName().c_str());
  349. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  350. }
  351. }
  352. iterator = var_and_var_ref_map_.find(node);
  353. if (iterator == var_and_var_ref_map_.end()) {
  354. GELOGW("The var node %s which belongs to graph %s can not be found on the graph", node->GetName().c_str(),
  355. node->GetOwnerComputeGraph()->GetName().c_str());
  356. return SUCCESS;
  357. }
  358. for (const auto &var_ref_node : iterator->second) {
  359. auto var_ref_node_description = var_ref_node->GetOpDesc();
  360. GE_CHECK_NOTNULL(var_ref_node_description);
  361. GELOGD("var_ref_node before is (%s, %s, %zu), var_ref_name is %s.",
  362. TypeUtils::DataTypeToSerialString(data_type).c_str(), TypeUtils::FormatToSerialString(format).c_str(),
  363. shape.GetDims().size(), var_ref_node->GetName().c_str());
  364. if (var_ref_node_description->UpdateOutputDesc(0, node_desc) != GRAPH_SUCCESS) {
  365. GELOGW("UpdateOutputDesc fail.");
  366. }
  367. if (var_ref_node_description->UpdateInputDesc(0, node_desc) != GRAPH_SUCCESS) {
  368. GELOGW("UpdateInputDesc fail.");
  369. }
  370. const auto &input_desc = var_ref_node_description->MutableInputDesc(0);
  371. const auto &output_desc = var_ref_node_description->MutableOutputDesc(0);
  372. GE_CHECK_NOTNULL(input_desc);
  373. GE_CHECK_NOTNULL(output_desc);
  374. GELOGD("var_ref_node ref is (%s, %s, %zu), var_ref_name is %s.",
  375. TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str(),
  376. TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(), output_desc->GetShape().GetDims().size(),
  377. var_ref_node->GetName().c_str());
  378. }
  379. return SUCCESS;
  380. }
  381. Status VariableOpPass::GenerateVariableVariableRefMap(const ComputeGraphPtr &compute_graph) {
  382. std::map<std::string, NodePtr> names_to_var;
  383. std::map<std::string, std::set<NodePtr>> names_to_refs;
  384. GE_CHECK_NOTNULL(compute_graph);
  385. for (auto &node : compute_graph->GetDirectNode()) {
  386. if (node->GetType() != VARIABLE) {
  387. continue;
  388. }
  389. std::string ref_var_name;
  390. if (!ge::AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_name)) {
  391. names_to_var[node->GetName()] = node;
  392. } else {
  393. names_to_refs[ref_var_name].insert(node);
  394. }
  395. }
  396. for (auto &name_to_var : names_to_var) {
  397. var_and_var_ref_map_[name_to_var.second] = names_to_refs[name_to_var.first];
  398. }
  399. return SUCCESS;
  400. }
  401. Status VariableOpPass::CheckVarAndVarRefAreAlike(const NodePtr &var_node, const NodePtr &var_ref_node,
  402. bool &is_var_and_variable_ref_are_alike) {
  403. GE_CHECK_NOTNULL(var_node);
  404. GE_CHECK_NOTNULL(var_ref_node);
  405. GELOGD("var_node GetOutDataNodes. name is %s.", var_node->GetName().c_str());
  406. const auto &var_node_trans_nodes = var_node->GetOutDataNodes();
  407. GELOGD("var_node_trans_nodes size is %zu.", var_node_trans_nodes.size());
  408. GELOGD("var_ref_node GetOutDataNodes. name is %s.", var_ref_node->GetName().c_str());
  409. const auto &var_ref_node_trans_nodes = var_ref_node->GetInDataNodes();
  410. GELOGD("var_ref_node_trans_nodes size is %zu.", var_ref_node_trans_nodes.size());
  411. if (var_ref_node_trans_nodes.size() > 1) {
  412. GELOGE(GE_GRAPH_VARIABLE_OP_PASS_FAILED, "var_ref_node_trans_nodes.size() > 1.");
  413. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  414. }
  415. const auto &var_node_trans_node = var_node_trans_nodes.at(0);
  416. const auto &var_ref_node_trans_node = var_ref_node_trans_nodes.at(0);
  417. if (CheckTransNodeAreInverse(var_node_trans_node, var_ref_node_trans_node, is_var_and_variable_ref_are_alike) !=
  418. SUCCESS) {
  419. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  420. }
  421. return SUCCESS;
  422. }
  423. Status VariableOpPass::CheckTransNodeAreInverse(const NodePtr &node_a, const NodePtr &node_b, bool &is_same) {
  424. GELOGD("In CheckTransNodeAreInverse.");
  425. GE_CHECK_NOTNULL(node_a);
  426. GE_CHECK_NOTNULL(node_b);
  427. const auto &node_a_op_desc = node_a->GetOpDesc();
  428. const auto &node_b_op_desc = node_b->GetOpDesc();
  429. GE_CHECK_NOTNULL(node_a_op_desc);
  430. GE_CHECK_NOTNULL(node_b_op_desc);
  431. const auto &node_a_out_op_desc = node_a_op_desc->MutableOutputDesc(0);
  432. const auto &node_a_in_op_desc = node_a_op_desc->MutableInputDesc(0);
  433. GE_CHECK_NOTNULL(node_a_out_op_desc);
  434. GE_CHECK_NOTNULL(node_a_in_op_desc);
  435. const auto &node_b_out_op_desc = node_b_op_desc->MutableOutputDesc(0);
  436. const auto &node_b_in_op_desc = node_b_op_desc->MutableInputDesc(0);
  437. GE_CHECK_NOTNULL(node_b_out_op_desc);
  438. GE_CHECK_NOTNULL(node_b_in_op_desc);
  439. is_same = IsOpDescSame(node_a_out_op_desc, node_b_in_op_desc) && IsOpDescSame(node_b_out_op_desc, node_a_in_op_desc);
  440. return SUCCESS;
  441. }
  442. bool VariableOpPass::IsOpDescSame(const GeTensorDescPtr &op_desc_a, const GeTensorDescPtr &op_desc_b) {
  443. const auto &format_a = op_desc_a->GetFormat();
  444. const auto &type_a = op_desc_a->GetDataType();
  445. const auto &shape_a = op_desc_a->GetShape();
  446. const auto &format_b = op_desc_b->GetFormat();
  447. const auto &type_b = op_desc_b->GetDataType();
  448. const auto &shape_b = op_desc_b->GetShape();
  449. const auto &dims_a = shape_a.GetDims();
  450. const auto &dims_b = shape_b.GetDims();
  451. GELOGD("(format, data type, shape) = (%s, %s, %zu) (%s, %s, %zu)", TypeUtils::FormatToSerialString(format_a).c_str(),
  452. TypeUtils::DataTypeToSerialString(type_a).c_str(), dims_a.size(),
  453. TypeUtils::FormatToSerialString(format_b).c_str(), TypeUtils::DataTypeToSerialString(type_b).c_str(),
  454. dims_b.size());
  455. return (format_a == format_b) && (type_a == type_b) && (dims_a == dims_b);
  456. }
  457. void VariableOpPass::CopyVariableFormatDataTypeAndShape(const GeTensorDesc &src_tensor_desc,
  458. GeTensorDesc &dst_tensor_desc) {
  459. dst_tensor_desc.SetShape(src_tensor_desc.GetShape());
  460. dst_tensor_desc.SetFormat(src_tensor_desc.GetFormat());
  461. dst_tensor_desc.SetDataType(src_tensor_desc.GetDataType());
  462. }
  463. Status VariableOpPass::CheckIfCouldBeOptimized(const ge::NodePtr &node, bool &flag, VarTransRoad &fusion_road) {
  464. if (node == nullptr) {
  465. return FAILED;
  466. }
  467. bool is_matched = false;
  468. auto ret = CheckSameAndTransOp(node, is_matched, fusion_road);
  469. if (ret != SUCCESS) {
  470. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  471. }
  472. if (!is_matched) {
  473. flag = false;
  474. return SUCCESS;
  475. }
  476. bool is_var_ref_legally = false;
  477. ret = CheckVariableRefLegally(node, is_var_ref_legally);
  478. if (ret != SUCCESS) {
  479. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  480. }
  481. GELOGD("is_var_ref_legally is %d.", is_var_ref_legally);
  482. if (!is_var_ref_legally) {
  483. GELOGI("variable ref connection are illegally");
  484. flag = false;
  485. fusion_road.clear();
  486. return SUCCESS;
  487. }
  488. flag = true;
  489. GELOGD("node %s, is_matched = %d is_var_ref_legally = %d, flag = %d", node->GetName().c_str(), is_matched,
  490. is_var_ref_legally, flag);
  491. return SUCCESS;
  492. }
  493. Status VariableOpPass::FusionIfNeed(const NodePtr &var, VarTransRoad &fusion_road) {
  494. bool can_fusion = false;
  495. while (true) {
  496. auto ret = CheckIfCouldBeOptimized(var, can_fusion, fusion_road);
  497. if (ret != SUCCESS) {
  498. return ret;
  499. }
  500. if (!can_fusion) {
  501. break;
  502. }
  503. ret = DealFusion(var);
  504. if (ret != SUCCESS) {
  505. return ret;
  506. }
  507. }
  508. return SUCCESS;
  509. }
  510. Status VariableOpPass::UpdateIOFormatInfo(const GeTensorDesc &final_output, std::set<NodePtr> &nodes) {
  511. for (auto &need_set_node : nodes) {
  512. auto ret = UpdateVarAndRefOutputFormatInfo(final_output, need_set_node);
  513. if (ret != SUCCESS) {
  514. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  515. }
  516. }
  517. return SUCCESS;
  518. }
  519. Status VariableOpPass::RenewVarDesc(ge::ComputeGraphPtr &graph) {
  520. GE_CHECK_NOTNULL(graph);
  521. // renew var manager desc
  522. Status ret = SUCCESS;
  523. for (auto &node : graph->GetDirectNode()) {
  524. bool is_var_node =
  525. (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == VARHANDLEOP);
  526. if (is_var_node) {
  527. if (!ge::VarManager::Instance(graph->GetSessionID())->IsVarExist(node->GetName())) {
  528. GELOGD("var manager does not exist var node[%s]", node->GetName().c_str());
  529. continue;
  530. }
  531. GELOGD("var manager exist var node[%s], graph name[%s]", node->GetName().c_str(), graph->GetName().c_str());
  532. GE_CHECK_NOTNULL(node->GetOpDesc());
  533. ret = ge::VarManager::Instance(graph->GetSessionID())->RenewCurVarDesc(node->GetName(), node->GetOpDesc());
  534. if (ret != SUCCESS) {
  535. GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str());
  536. return FAILED;
  537. }
  538. }
  539. }
  540. return SUCCESS;
  541. }
  542. Status VariableOpPass::RenewVarDesc(uint64_t session_id, const NodePtr &node, const VarTransRoad &fusion_road) {
  543. // renew var desc if the trans_road is all reshape or reformat
  544. for (auto &road : fusion_road) {
  545. if (road.node_type != RESHAPE && road.node_type != REFORMAT) {
  546. return SUCCESS;
  547. }
  548. }
  549. if (!ge::VarManager::Instance(session_id)->IsVarExist(node->GetName())) {
  550. GELOGD("var manager does not exist var node[%s]", node->GetName().c_str());
  551. return SUCCESS;
  552. }
  553. GELOGD("var manager exist var node[%s]", node->GetName().c_str());
  554. GE_CHECK_NOTNULL(node->GetOpDesc());
  555. Status ret = ge::VarManager::Instance(session_id)->RenewCurVarDesc(node->GetName(), node->GetOpDesc());
  556. if (ret != SUCCESS) {
  557. GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str());
  558. return FAILED;
  559. }
  560. return SUCCESS;
  561. }
  562. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：