You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

variable_op_pass.cc 30 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago

  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/variable_op_pass.h"
  17. #include <string>
  18. #include <vector>
  19. #include "common/formats/formats.h"
  20. #include "common/formats/utils/formats_trans_utils.h"
  21. #include "graph/ge_context.h"
  22. #include "graph/graph.h"
  23. #include "graph/manager/graph_var_manager.h"
  24. #include "graph/utils/graph_utils.h"
  25. #include "graph/utils/tensor_utils.h"
  26. #include "graph/utils/type_utils.h"
  27. namespace ge {
  28. namespace {
  // Output index passed to GetOutputDesc() when reading a trans op's output description.
  constexpr int kTransOpOutIndex = 0;
  30. std::string GetKey(Format format, DataType type, const std::vector<int64_t> &dims) {
  31. std::stringstream key;
  32. key << static_cast<int>(format) << '-';
  33. key << static_cast<int>(type) << '-';
  34. for (auto dim : dims) {
  35. key << dim << '-';
  36. }
  37. return key.str();
  38. }
  39. Status ByPassTransNode(NodePtr &trans_node, NodePtr &ref_node) {
  40. GE_CHECK_NOTNULL(trans_node);
  41. GE_CHECK_NOTNULL(ref_node);
  42. GELOGD("Begin to bypass trans node %s", trans_node->GetName().c_str());
  43. auto ret = GraphUtils::CopyInCtrlEdges(trans_node, ref_node);
  44. if (ret != GRAPH_SUCCESS) {
  45. REPORT_CALL_ERROR("E19999", "Copy in control edge from node:%s(%s) to node:%s(%s) failed",
  46. trans_node->GetName().c_str(), trans_node->GetType().c_str(),
  47. ref_node->GetName().c_str(), ref_node->GetType().c_str());
  48. GELOGE(INTERNAL_ERROR, "[Copy][InCtrlEdges] from node:%s(%s) to node:%s(%s) failed",
  49. trans_node->GetName().c_str(), trans_node->GetType().c_str(),
  50. ref_node->GetName().c_str(), ref_node->GetType().c_str());
  51. return INTERNAL_ERROR;
  52. }
  53. auto ref_in_anchor = ref_node->GetInDataAnchor(0);
  54. if (ref_in_anchor == nullptr) {
  55. REPORT_INNER_ERROR("E19999", "Node:%s(%s) has no input anchor, check invalid",
  56. ref_node->GetName().c_str(), ref_node->GetType().c_str());
  57. GELOGE(INTERNAL_ERROR, "[Get][InDataAnchor] failed, The variable ref node %s does not have an input anchor",
  58. ref_node->GetName().c_str());
  59. return INTERNAL_ERROR;
  60. }
  61. ref_in_anchor->UnlinkAll();
  62. auto trans_in_anchor = trans_node->GetInDataAnchor(0);
  63. if (trans_in_anchor == nullptr) {
  64. REPORT_INNER_ERROR("E19999", "Node:%s(%s) has no input anchor, check invalid",
  65. trans_node->GetName().c_str(), trans_node->GetType().c_str());
  66. GELOGE(INTERNAL_ERROR, "[Get][InDataAnchor] failed, Node:%s(%s) has no input anchor",
  67. trans_node->GetName().c_str(), trans_node->GetType().c_str());
  68. return INTERNAL_ERROR;
  69. }
  70. auto prev_trans_node_out_anchor = trans_in_anchor->GetPeerOutAnchor();
  71. if (prev_trans_node_out_anchor == nullptr) {
  72. GELOGW(
  73. "The trans node %s does not have an input, so the ref node %s does"
  74. " not have any inputs after bypass",
  75. trans_node->GetName().c_str(), trans_node->GetName().c_str());
  76. } else {
  77. ret = GraphUtils::AddEdge(prev_trans_node_out_anchor, ref_in_anchor);
  78. if (ret != GRAPH_SUCCESS) {
  79. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed",
  80. prev_trans_node_out_anchor->GetOwnerNode()->GetName().c_str(),
  81. prev_trans_node_out_anchor->GetOwnerNode()->GetType().c_str(),
  82. prev_trans_node_out_anchor->GetIdx(),
  83. ref_node->GetName().c_str(), ref_node->GetType().c_str());
  84. GELOGE(INTERNAL_ERROR, "[Add][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed",
  85. prev_trans_node_out_anchor->GetOwnerNode()->GetName().c_str(),
  86. prev_trans_node_out_anchor->GetOwnerNode()->GetType().c_str(),
  87. prev_trans_node_out_anchor->GetIdx(), ref_node->GetName().c_str(), ref_node->GetType().c_str());
  88. return INTERNAL_ERROR;
  89. }
  90. }
  91. return SUCCESS;
  92. }
  93. bool IsTransSupport(const TransNodeInfo &trans_info) {
  94. if (trans_info.output.GetShape().IsUnknownShape()) {
  95. return false;
  96. }
  97. if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) {
  98. return true;
  99. } else if (trans_info.node_type == TRANSDATA || trans_info.node_type == TRANSPOSED) {
  100. formats::TransArgs args{nullptr,
  101. trans_info.input.GetFormat(),
  102. trans_info.output.GetFormat(),
  103. trans_info.input.GetShape().GetDims(),
  104. trans_info.output.GetShape().GetDims(),
  105. trans_info.input.GetDataType()};
  106. return formats::IsTransFormatSupport(args);
  107. } else if (trans_info.node_type == CAST) {
  108. formats::CastArgs datatype_args{nullptr, static_cast<size_t>(trans_info.input.GetShape().GetShapeSize()),
  109. trans_info.input.GetDataType(), trans_info.output.GetDataType()};
  110. return formats::IsTransDataTypeSupport(datatype_args);
  111. } else {
  112. return false;
  113. }
  114. }
  115. } // namespace
  116. Status VariableOpPass::Run(ge::ComputeGraphPtr graph) {
  117. if (graph == nullptr) {
  118. REPORT_INNER_ERROR("E19999", "Param graph is nullptr, check invalid");
  119. GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to run variable op pass, null graph");
  120. return INTERNAL_ERROR;
  121. }
  122. auto graph_id = GraphUtils::FindRootGraph(graph)->GetGraphID();
  123. GELOGD("Begin to run variable op pass on graph %s, session %lu, graph id %u", graph->GetName().c_str(),
  124. GetContext().SessionId(), graph_id);
  125. if (var_accelerate_ctrl_ == nullptr) {
  126. REPORT_INNER_ERROR("E19999", "The variable accelerate control is nullptr, check invalid");
  127. GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to run var op pass, the variable accelerate control is null");
  128. return INTERNAL_ERROR;
  129. }
  130. GELOGD("Begin to generate ref map for variable and refs, graph name:%s.", graph->GetName().c_str());
  131. if (RenewVarDesc(graph) != SUCCESS) {
  132. GELOGE(INTERNAL_ERROR, "[Renew][VarDesc] on graph:%s failed", graph->GetName().c_str());
  133. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  134. }
  135. if (GenerateVariableVariableRefMap(graph) != SUCCESS) {
  136. GELOGE(INTERNAL_ERROR, "[Generate][VariableMap] for graph:%s failed", graph->GetName().c_str());
  137. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  138. }
  139. GELOGD("Begin to fusion variables and trans nodes");
  140. for (auto &var_to_refs : var_and_var_ref_map_) {
  141. auto &node = var_to_refs.first;
  142. GE_CHECK_NOTNULL(node);
  143. GE_CHECK_NOTNULL(var_accelerate_ctrl_);
  144. if (!var_accelerate_ctrl_->IsVarPermitToChangeFormats(node->GetName())) {
  145. GELOGD("The var %s does not permit to change formats, skip it", node->GetName().c_str());
  146. continue;
  147. }
  148. VarTransRoad fusion_road;
  149. auto ret = FusionIfNeed(node, fusion_road);
  150. if (ret != SUCCESS) {
  151. GELOGE(FAILED, "[Call][FusionIfNeed] for node:%s failed", node->GetName().c_str());
  152. return ret;
  153. }
  154. if (fusion_road.empty()) {
  155. GELOGD("No need to fusion variable and trans op for var %s", node->GetName().c_str());
  156. continue;
  157. }
  158. auto start_iter = fusion_road.begin();
  159. auto end_iter = fusion_road.rbegin();
  160. GELOGD(
  161. "Trans variable data for %s from format %s to %s, shape %s to %s "
  162. "data-type %s to %s, path len %zu success",
  163. node->GetName().c_str(), TypeUtils::FormatToSerialString(start_iter->input.GetFormat()).c_str(),
  164. TypeUtils::FormatToSerialString(end_iter->output.GetFormat()).c_str(),
  165. formats::ShapeToString(start_iter->input.GetShape().GetDims()).c_str(),
  166. formats::ShapeToString(end_iter->output.GetShape().GetDims()).c_str(),
  167. TypeUtils::DataTypeToSerialString(start_iter->input.GetDataType()).c_str(),
  168. TypeUtils::DataTypeToSerialString(end_iter->output.GetDataType()).c_str(), fusion_road.size());
  169. ret = VarManager::Instance(graph->GetSessionID())->SetTransRoad(node->GetName(), fusion_road);
  170. if (ret != SUCCESS) {
  171. REPORT_CALL_ERROR("E19999", "Set Trans road for node:%s(%s) failed, session_id:%lu",
  172. node->GetName().c_str(), node->GetType().c_str(), graph->GetSessionID());
  173. GELOGE(INTERNAL_ERROR, "[Set][TransRoad] for node:%s(%s) failed, session_id:%lu",
  174. node->GetName().c_str(), node->GetType().c_str(), graph->GetSessionID());
  175. return INTERNAL_ERROR;
  176. }
  177. ret = VarManager::Instance(graph->GetSessionID())->SetChangedGraphId(node->GetName(), graph_id);
  178. if (ret != SUCCESS) {
  179. REPORT_CALL_ERROR("E19999", "Update graph_id:%u for node:%s(%s) failed, session_id:%lu",
  180. graph_id, node->GetName().c_str(), node->GetType().c_str(), graph->GetSessionID());
  181. GELOGE(INTERNAL_ERROR, "[Update][GraphId] %u for node:%s(%s) failed, session_id:%lu",
  182. graph_id, node->GetName().c_str(), node->GetType().c_str(), graph->GetSessionID());
  183. return INTERNAL_ERROR;
  184. }
  185. var_accelerate_ctrl_->SetVarChanged(node->GetName());
  186. GELOGD("Begin to update format info for var %s.", node->GetName().c_str());
  187. std::set<ge::NodePtr> node_set({node});
  188. if (UpdateIOFormatInfo(end_iter->output, node_set) != SUCCESS) {
  189. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  190. }
  191. // renew var desc if the trans_road is all reshape or reformat
  192. ret = RenewVarDesc(graph->GetSessionID(), node, fusion_road);
  193. if (ret != SUCCESS) {
  194. GELOGE(FAILED, "[Renew][VarDesc] for var[%s] failed!", node->GetName().c_str());
  195. return FAILED;
  196. }
  197. }
  198. return SUCCESS;
  199. }
  200. Status VariableOpPass::DealFusion(const ge::NodePtr &var_node) {
  201. GE_CHECK_NOTNULL(var_node);
  202. GELOGD("Begin to fusion var %s with trans", var_node->GetName().c_str());
  203. auto graph = var_node->GetOwnerComputeGraph();
  204. for (auto &trans_node : var_node->GetOutDataNodes()) {
  205. GELOGD("Remove node %s type %s when fusion with variable %s", trans_node->GetName().c_str(),
  206. trans_node->GetType().c_str(), var_node->GetName().c_str());
  207. if (GraphUtils::IsolateNode(trans_node, {0}) != SUCCESS) {
  208. REPORT_CALL_ERROR("E19999", "Isolate node:%s(%s) failed",
  209. trans_node->GetName().c_str(), trans_node->GetType().c_str());
  210. GELOGE(GE_GRAPH_VARIABLE_OP_PASS_FAILED, "[Isolate][Node] %s(%s) failed",
  211. trans_node->GetName().c_str(), trans_node->GetType().c_str());
  212. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  213. }
  214. if (GraphUtils::RemoveNodeWithoutRelink(graph, trans_node) != SUCCESS) {
  215. REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) without relink in graph:%s failed",
  216. trans_node->GetName().c_str(), trans_node->GetType().c_str(), graph->GetName().c_str());
  217. GELOGE(GE_GRAPH_VARIABLE_OP_PASS_FAILED, "[Remove][Node] %s(%s) without relink in graph:%s failed",
  218. trans_node->GetName().c_str(), trans_node->GetType().c_str(), graph->GetName().c_str());
  219. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  220. }
  221. }
  222. auto iterator = var_and_var_ref_map_.find(var_node);
  223. if (iterator == var_and_var_ref_map_.end()) {
  224. GELOGD("there is no var_ref of node %s", var_node->GetName().c_str());
  225. return SUCCESS;
  226. }
  227. for (auto ref_node : iterator->second) {
  228. GE_CHECK_NOTNULL(ref_node);
  229. for (auto &trans_node : ref_node->GetInDataNodes()) {
  230. GELOGD("Remove node %s type %s when fusion with variable %s", trans_node->GetName().c_str(),
  231. trans_node->GetType().c_str(), var_node->GetName().c_str());
  232. if (trans_node->GetOutDataNodes().size() > 1) {
  233. GELOGD(
  234. "The trans node %s type %s connecting with var-ref %s has more"
  235. " than one output data nodes, unlink the edge between them",
  236. trans_node->GetName().c_str(), trans_node->GetType().c_str(), ref_node->GetName().c_str());
  237. if (ByPassTransNode(trans_node, ref_node) != SUCCESS) {
  238. GELOGE(INTERNAL_ERROR, "[ByPass][TransNode] %s to ref %s failed", trans_node->GetName().c_str(),
  239. ref_node->GetName().c_str());
  240. return INTERNAL_ERROR;
  241. }
  242. } else {
  243. GELOGD(
  244. "The trans node %s type %s connecting with var-ref %s has only"
  245. " one output data nodes, isolate and remove it.",
  246. trans_node->GetName().c_str(), trans_node->GetType().c_str(), ref_node->GetName().c_str());
  247. if (GraphUtils::IsolateNode(trans_node, {0}) != SUCCESS) {
  248. REPORT_CALL_ERROR("E19999", "Isolate node:%s(%s) failed",
  249. trans_node->GetName().c_str(), trans_node->GetType().c_str());
  250. GELOGE(GE_GRAPH_VARIABLE_OP_PASS_FAILED, "[Isolate][Node] %s(%s) failed",
  251. trans_node->GetName().c_str(), trans_node->GetType().c_str());
  252. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  253. }
  254. if (GraphUtils::RemoveNodeWithoutRelink(graph, trans_node) != SUCCESS) {
  255. REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) without relink in graph:%s failed",
  256. trans_node->GetName().c_str(), trans_node->GetType().c_str(), graph->GetName().c_str());
  257. GELOGE(GE_GRAPH_VARIABLE_OP_PASS_FAILED, "[Remove][Node] %s(%s) without relink in graph:%s failed",
  258. trans_node->GetName().c_str(), trans_node->GetType().c_str(), graph->GetName().c_str());
  259. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  260. }
  261. }
  262. }
  263. }
  264. return SUCCESS;
  265. }
  266. Status VariableOpPass::CheckSameAndTransOp(const ge::NodePtr &var_node, bool &is_matched, VarTransRoad &fusion_road) {
  267. std::set<std::string> data_type_and_formats;
  268. std::string trans_op_type;
  269. ge::NodePtr out_node;
  270. ge::GeTensorDesc output_desc;
  271. GE_CHECK_NOTNULL(var_node);
  272. for (auto &out_node_and_anchor : var_node->GetOutDataNodesAndAnchors()) {
  273. auto in_anchor = out_node_and_anchor.second;
  274. GE_CHECK_NOTNULL(in_anchor);
  275. out_node = out_node_and_anchor.first;
  276. GE_CHECK_NOTNULL(out_node);
  277. auto trans_op_desc = out_node->GetOpDesc();
  278. GE_CHECK_NOTNULL(trans_op_desc);
  279. trans_op_type = trans_op_desc->GetType();
  280. GELOGD("current node type is %s.", trans_op_type.c_str());
  281. int data_index = TransOpUtil::GetTransOpDataIndex(trans_op_type);
  282. if (data_index < 0) {
  283. GELOGD("Variables only can be fusion with trans_op, the next op is %s type %s", out_node->GetName().c_str(),
  284. out_node->GetType().c_str());
  285. return SUCCESS;
  286. }
  287. if (data_index != in_anchor->GetIdx()) {
  288. GELOGD(
  289. "Variables only can be fusion with trans nodes, the next node %s"
  290. " type %s index %d does not trans anything(correct index %d)",
  291. out_node->GetName().c_str(), out_node->GetType().c_str(), in_anchor->GetIdx(), data_index);
  292. return SUCCESS;
  293. }
  294. output_desc = trans_op_desc->GetOutputDesc(kTransOpOutIndex);
  295. auto trans_op_format = output_desc.GetFormat();
  296. auto trans_op_data_type = output_desc.GetDataType();
  297. auto shape = output_desc.GetShape().GetDims();
  298. auto datatype_and_format = GetKey(trans_op_format, trans_op_data_type, shape);
  299. data_type_and_formats.insert(datatype_and_format);
  300. }
  301. if (data_type_and_formats.empty()) {
  302. return SUCCESS;
  303. }
  304. if (data_type_and_formats.size() > 1) {
  305. std::stringstream type_and_formats_stream;
  306. bool first_time = true;
  307. for (const auto &data_type_and_format : data_type_and_formats) {
  308. if (first_time) {
  309. first_time = false;
  310. } else {
  311. type_and_formats_stream << "|";
  312. }
  313. type_and_formats_stream << data_type_and_format;
  314. }
  315. GELOGW(
  316. "trans_op type size for var Node(%s) is over 1, Currently not"
  317. " supported, dataTypeAndFormats is %s.",
  318. var_node->GetName().c_str(), type_and_formats_stream.str().c_str());
  319. return SUCCESS;
  320. }
  321. int tran_in_index = TransOpUtil::GetTransOpDataIndex(out_node->GetType());
  322. auto out_op_desc = out_node->GetOpDesc();
  323. GE_CHECK_NOTNULL(out_op_desc);
  324. TransNodeInfo trans_node_info;
  325. trans_node_info.node_type = out_node->GetType();
  326. trans_node_info.input = out_op_desc->GetInputDesc(tran_in_index);
  327. trans_node_info.output = out_op_desc->GetOutputDesc(kTransOpOutIndex);
  328. if (!IsTransSupport(trans_node_info)) {
  329. GELOGD("The trans node %s does not support, skip the variable accelerating", trans_node_info.node_type.c_str());
  330. return SUCCESS;
  331. }
  332. is_matched = true;
  333. fusion_road.emplace_back(trans_node_info);
  334. return SUCCESS;
  335. }
  336. Status VariableOpPass::CheckVariableRefLegally(const ge::NodePtr &var_node, bool &is_var_ref_legally) {
  337. is_var_ref_legally = true;
  338. GE_CHECK_NOTNULL(var_node);
  339. auto iterator = var_and_var_ref_map_.find(var_node);
  340. if (iterator == var_and_var_ref_map_.end()) {
  341. GELOGD("var name %s are not in var var_ref map", var_node->GetName().c_str());
  342. return SUCCESS;
  343. }
  344. GELOGD("var name %s, ref var count %zu.", var_node->GetName().c_str(), iterator->second.size());
  345. for (const auto &var_ref_node : iterator->second) {
  346. if (CheckVarAndVarRefAreAlike(var_node, var_ref_node, is_var_ref_legally) != SUCCESS) {
  347. GELOGE(FAILED, "[Call][CheckVarAndVarRefAreAlike] for node:%s failed", var_node->GetName().c_str());
  348. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  349. }
  350. GELOGD("is_var_ref_legally is %d", is_var_ref_legally);
  351. if (!is_var_ref_legally) {
  352. return SUCCESS;
  353. }
  354. }
  355. return SUCCESS;
  356. }
  357. Status VariableOpPass::UpdateVarAndRefOutputFormatInfo(const GeTensorDesc &final_output, const ge::NodePtr &node) {
  358. if (node == nullptr || node->GetOpDesc() == nullptr) {
  359. REPORT_INNER_ERROR("E19999", "Param node or its op_desc is nullptr, check invalid");
  360. GELOGE(FAILED, "[Check][Param] node or its opdesc is nullptr");
  361. return FAILED;
  362. }
  363. const Format &format = final_output.GetFormat();
  364. const DataType &data_type = final_output.GetDataType();
  365. const GeShape &shape = final_output.GetShape();
  366. GELOGD("last ref is (%s, %s, %lu), var_ref_name is %s.", TypeUtils::DataTypeToSerialString(data_type).c_str(),
  367. TypeUtils::FormatToSerialString(format).c_str(), shape.GetDims().size(), node->GetName().c_str());
  368. auto node_desc = node->GetOpDesc()->GetOutputDesc(0);
  369. CopyVariableFormatDataTypeAndShape(final_output, node_desc);
  370. if (node->GetOpDesc()->UpdateOutputDesc(0, node_desc) != GRAPH_SUCCESS) {
  371. REPORT_CALL_ERROR("E19999", "Update ouput:0 desc in op:%s(%s) failed",
  372. node->GetName().c_str(), node->GetType().c_str());
  373. GELOGE(FAILED, "[Update][OutputDesc] in op:%s(%s) failed, index:0",
  374. node->GetName().c_str(), node->GetType().c_str());
  375. return FAILED;
  376. }
  377. GELOGD("node ref is (%s, %s, %lu), var_ref_name is %s.",
  378. TypeUtils::DataTypeToSerialString(node->GetOpDesc()->GetOutputDesc(0).GetDataType()).c_str(),
  379. TypeUtils::FormatToSerialString(node->GetOpDesc()->GetOutputDesc(0).GetFormat()).c_str(),
  380. node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims().size(), node->GetName().c_str());
  381. auto iterator = var_and_var_ref_map_.find(node);
  382. if (iterator == var_and_var_ref_map_.end()) {
  383. auto graph = node->GetOwnerComputeGraph();
  384. if (GenerateVariableVariableRefMap(graph) != SUCCESS) {
  385. GELOGE(INTERNAL_ERROR, "[Generate][VariableMap] for graph:%s failed", graph->GetName().c_str());
  386. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  387. }
  388. }
  389. iterator = var_and_var_ref_map_.find(node);
  390. if (iterator == var_and_var_ref_map_.end()) {
  391. GELOGW("The var node %s which belongs to graph %s can not be found on the graph", node->GetName().c_str(),
  392. node->GetOwnerComputeGraph()->GetName().c_str());
  393. return SUCCESS;
  394. }
  395. for (const auto &var_ref_node : iterator->second) {
  396. auto var_ref_node_description = var_ref_node->GetOpDesc();
  397. GE_CHECK_NOTNULL(var_ref_node_description);
  398. GELOGD("var_ref_node before is (%s, %s, %zu), var_ref_name is %s.",
  399. TypeUtils::DataTypeToSerialString(data_type).c_str(), TypeUtils::FormatToSerialString(format).c_str(),
  400. shape.GetDims().size(), var_ref_node->GetName().c_str());
  401. if (var_ref_node_description->UpdateOutputDesc(0, node_desc) != GRAPH_SUCCESS) {
  402. GELOGW("UpdateOutputDesc fail.");
  403. }
  404. if (var_ref_node_description->UpdateInputDesc(0, node_desc) != GRAPH_SUCCESS) {
  405. GELOGW("UpdateInputDesc fail.");
  406. }
  407. const auto &input_desc = var_ref_node_description->MutableInputDesc(0);
  408. const auto &output_desc = var_ref_node_description->MutableOutputDesc(0);
  409. GE_CHECK_NOTNULL(input_desc);
  410. GE_CHECK_NOTNULL(output_desc);
  411. GELOGD("var_ref_node ref is (%s, %s, %zu), var_ref_name is %s.",
  412. TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str(),
  413. TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(), output_desc->GetShape().GetDims().size(),
  414. var_ref_node->GetName().c_str());
  415. }
  416. return SUCCESS;
  417. }
  418. Status VariableOpPass::GenerateVariableVariableRefMap(const ComputeGraphPtr &compute_graph) {
  419. std::map<std::string, NodePtr> names_to_var;
  420. std::map<std::string, std::set<NodePtr>> names_to_refs;
  421. GE_CHECK_NOTNULL(compute_graph);
  422. for (auto &node : compute_graph->GetDirectNode()) {
  423. if (node->GetType() != VARIABLE) {
  424. continue;
  425. }
  426. std::string ref_var_name;
  427. if (!ge::AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_name)) {
  428. names_to_var[node->GetName()] = node;
  429. } else {
  430. names_to_refs[ref_var_name].insert(node);
  431. }
  432. }
  433. for (auto &name_to_var : names_to_var) {
  434. var_and_var_ref_map_[name_to_var.second] = names_to_refs[name_to_var.first];
  435. }
  436. return SUCCESS;
  437. }
  438. Status VariableOpPass::CheckVarAndVarRefAreAlike(const NodePtr &var_node, const NodePtr &var_ref_node,
  439. bool &is_var_and_variable_ref_are_alike) {
  440. GE_CHECK_NOTNULL(var_node);
  441. GE_CHECK_NOTNULL(var_ref_node);
  442. GELOGD("var_node GetOutDataNodes. name is %s.", var_node->GetName().c_str());
  443. const auto &var_node_trans_nodes = var_node->GetOutDataNodes();
  444. GELOGD("var_node_trans_nodes size is %zu.", var_node_trans_nodes.size());
  445. GELOGD("var_ref_node GetOutDataNodes. name is %s.", var_ref_node->GetName().c_str());
  446. const auto &var_ref_node_trans_nodes = var_ref_node->GetInDataNodes();
  447. GELOGD("var_ref_node_trans_nodes size is %zu.", var_ref_node_trans_nodes.size());
  448. if (var_ref_node_trans_nodes.size() > 1) {
  449. REPORT_INNER_ERROR("E19999", "In data node num:%zu of node:%s(%s) bigger than 1, check invalid",
  450. var_ref_node_trans_nodes.size(),
  451. var_ref_node->GetName().c_str(), var_ref_node->GetType().c_str());
  452. GELOGE(GE_GRAPH_VARIABLE_OP_PASS_FAILED, "[Check][Param] In data node num:%zu of node:%s(%s) bigger than 1.",
  453. var_ref_node_trans_nodes.size(), var_ref_node->GetName().c_str(), var_ref_node->GetType().c_str());
  454. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  455. }
  456. const auto &var_node_trans_node = var_node_trans_nodes.at(0);
  457. const auto &var_ref_node_trans_node = var_ref_node_trans_nodes.at(0);
  458. if (CheckTransNodeAreInverse(var_node_trans_node, var_ref_node_trans_node, is_var_and_variable_ref_are_alike) !=
  459. SUCCESS) {
  460. GELOGE(FAILED, "[Call][CheckTransNodeAreInverse] failed");
  461. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  462. }
  463. return SUCCESS;
  464. }
  465. Status VariableOpPass::CheckTransNodeAreInverse(const NodePtr &node_a, const NodePtr &node_b, bool &is_same) {
  466. GELOGD("In CheckTransNodeAreInverse.");
  467. GE_CHECK_NOTNULL(node_a);
  468. GE_CHECK_NOTNULL(node_b);
  469. const auto &node_a_op_desc = node_a->GetOpDesc();
  470. const auto &node_b_op_desc = node_b->GetOpDesc();
  471. GE_CHECK_NOTNULL(node_a_op_desc);
  472. GE_CHECK_NOTNULL(node_b_op_desc);
  473. const auto &node_a_out_op_desc = node_a_op_desc->MutableOutputDesc(0);
  474. const auto &node_a_in_op_desc = node_a_op_desc->MutableInputDesc(0);
  475. GE_CHECK_NOTNULL(node_a_out_op_desc);
  476. GE_CHECK_NOTNULL(node_a_in_op_desc);
  477. const auto &node_b_out_op_desc = node_b_op_desc->MutableOutputDesc(0);
  478. const auto &node_b_in_op_desc = node_b_op_desc->MutableInputDesc(0);
  479. GE_CHECK_NOTNULL(node_b_out_op_desc);
  480. GE_CHECK_NOTNULL(node_b_in_op_desc);
  481. is_same = IsOpDescSame(node_a_out_op_desc, node_b_in_op_desc) && IsOpDescSame(node_b_out_op_desc, node_a_in_op_desc);
  482. return SUCCESS;
  483. }
  484. bool VariableOpPass::IsOpDescSame(const GeTensorDescPtr &op_desc_a, const GeTensorDescPtr &op_desc_b) {
  485. const auto &format_a = op_desc_a->GetFormat();
  486. const auto &type_a = op_desc_a->GetDataType();
  487. const auto &shape_a = op_desc_a->GetShape();
  488. const auto &format_b = op_desc_b->GetFormat();
  489. const auto &type_b = op_desc_b->GetDataType();
  490. const auto &shape_b = op_desc_b->GetShape();
  491. const auto &dims_a = shape_a.GetDims();
  492. const auto &dims_b = shape_b.GetDims();
  493. GELOGD("(format, data type, shape) = (%s, %s, %zu) (%s, %s, %zu)", TypeUtils::FormatToSerialString(format_a).c_str(),
  494. TypeUtils::DataTypeToSerialString(type_a).c_str(), dims_a.size(),
  495. TypeUtils::FormatToSerialString(format_b).c_str(), TypeUtils::DataTypeToSerialString(type_b).c_str(),
  496. dims_b.size());
  497. return (format_a == format_b) && (type_a == type_b) && (dims_a == dims_b);
  498. }
  499. void VariableOpPass::CopyVariableFormatDataTypeAndShape(const GeTensorDesc &src_tensor_desc,
  500. GeTensorDesc &dst_tensor_desc) {
  501. dst_tensor_desc.SetShape(src_tensor_desc.GetShape());
  502. dst_tensor_desc.SetFormat(src_tensor_desc.GetFormat());
  503. dst_tensor_desc.SetDataType(src_tensor_desc.GetDataType());
  504. }
  505. Status VariableOpPass::CheckIfCouldBeOptimized(const ge::NodePtr &node, bool &flag, VarTransRoad &fusion_road) {
  506. if (node == nullptr) {
  507. REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
  508. GELOGE(FAILED, "[Check][Param] param node is nullptr.");
  509. return FAILED;
  510. }
  511. bool is_matched = false;
  512. auto ret = CheckSameAndTransOp(node, is_matched, fusion_road);
  513. if (ret != SUCCESS) {
  514. GELOGE(FAILED, "[Call][CheckSameAndTransOp] failed, node:%s", node->GetName().c_str());
  515. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  516. }
  517. if (!is_matched) {
  518. flag = false;
  519. return SUCCESS;
  520. }
  521. bool is_var_ref_legally = false;
  522. ret = CheckVariableRefLegally(node, is_var_ref_legally);
  523. if (ret != SUCCESS) {
  524. GELOGE(FAILED, "[Call][CheckVariableRefLegally] failed, node:%s", node->GetName().c_str());
  525. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  526. }
  527. GELOGD("is_var_ref_legally is %d.", is_var_ref_legally);
  528. if (!is_var_ref_legally) {
  529. GELOGI("variable ref connection are illegally");
  530. flag = false;
  531. fusion_road.clear();
  532. return SUCCESS;
  533. }
  534. flag = true;
  535. GELOGD("node %s, is_matched = %d is_var_ref_legally = %d, flag = %d", node->GetName().c_str(), is_matched,
  536. is_var_ref_legally, flag);
  537. return SUCCESS;
  538. }
  539. Status VariableOpPass::FusionIfNeed(const NodePtr &var, VarTransRoad &fusion_road) {
  540. bool can_fusion = false;
  541. while (true) {
  542. auto ret = CheckIfCouldBeOptimized(var, can_fusion, fusion_road);
  543. if (ret != SUCCESS) {
  544. GELOGE(FAILED, "[Call][CheckIfCouldBeOptimized] failed");
  545. return ret;
  546. }
  547. if (!can_fusion) {
  548. break;
  549. }
  550. ret = DealFusion(var);
  551. if (ret != SUCCESS) {
  552. GELOGE(FAILED, "[Call][DealFusion] failed");
  553. return ret;
  554. }
  555. }
  556. return SUCCESS;
  557. }
  558. Status VariableOpPass::UpdateIOFormatInfo(const GeTensorDesc &final_output, std::set<NodePtr> &nodes) {
  559. for (auto &need_set_node : nodes) {
  560. auto ret = UpdateVarAndRefOutputFormatInfo(final_output, need_set_node);
  561. if (ret != SUCCESS) {
  562. GELOGE(FAILED, "[Call][UpdateVarAndRefOutputFormatInfo] failed");
  563. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  564. }
  565. }
  566. return SUCCESS;
  567. }
  568. Status VariableOpPass::RenewVarDesc(ge::ComputeGraphPtr &graph) {
  569. GE_CHECK_NOTNULL(graph);
  570. // renew var manager desc
  571. Status ret = SUCCESS;
  572. for (auto &node : graph->GetDirectNode()) {
  573. bool is_var_node =
  574. (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == VARHANDLEOP);
  575. if (is_var_node) {
  576. if (!ge::VarManager::Instance(graph->GetSessionID())->IsVarExist(node->GetName())) {
  577. GELOGD("var manager does not exist var node[%s]", node->GetName().c_str());
  578. continue;
  579. }
  580. GELOGD("var manager exist var node[%s], graph name[%s]", node->GetName().c_str(), graph->GetName().c_str());
  581. GE_CHECK_NOTNULL(node->GetOpDesc());
  582. ret = ge::VarManager::Instance(graph->GetSessionID())->RenewCurVarDesc(node->GetName(), node->GetOpDesc());
  583. if (ret != SUCCESS) {
  584. REPORT_CALL_ERROR("E19999", "Renew descriptor for node:%s(%s) failed, session_id:%lu",
  585. node->GetName().c_str(), node->GetType().c_str(), graph->GetSessionID());
  586. GELOGE(FAILED, "[Renew][Descriptor] for node:%s(%s) failed, session_id:%lu",
  587. node->GetName().c_str(), node->GetType().c_str(), graph->GetSessionID());
  588. return FAILED;
  589. }
  590. }
  591. }
  592. return SUCCESS;
  593. }
  594. Status VariableOpPass::RenewVarDesc(uint64_t session_id, const NodePtr &node, const VarTransRoad &fusion_road) {
  595. // renew var desc if the trans_road is all reshape or reformat
  596. for (auto &road : fusion_road) {
  597. if (road.node_type != RESHAPE && road.node_type != REFORMAT) {
  598. return SUCCESS;
  599. }
  600. }
  601. if (!ge::VarManager::Instance(session_id)->IsVarExist(node->GetName())) {
  602. GELOGD("var manager does not exist var node[%s]", node->GetName().c_str());
  603. return SUCCESS;
  604. }
  605. GELOGD("var manager exist var node[%s]", node->GetName().c_str());
  606. GE_CHECK_NOTNULL(node->GetOpDesc());
  607. Status ret = ge::VarManager::Instance(session_id)->RenewCurVarDesc(node->GetName(), node->GetOpDesc());
  608. if (ret != SUCCESS) {
  609. REPORT_CALL_ERROR("E19999", "Renew descriptor for node:%s(%s) failed, session_id:%lu",
  610. node->GetName().c_str(), node->GetType().c_str(), session_id);
  611. GELOGE(FAILED, "[Renew][Descriptor] for node:%s(%s) failed, session_id:%lu",
  612. node->GetName().c_str(), node->GetType().c_str(), session_id);
  613. return FAILED;
  614. }
  615. return SUCCESS;
  616. }
  617. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。