You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

variable_op_pass.cc 28 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago

  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/variable_op_pass.h"
  17. #include <string>
  18. #include <vector>
  19. #include "common/formats/formats.h"
  20. #include "common/formats/utils/formats_trans_utils.h"
  21. #include "graph/ge_context.h"
  22. #include "graph/graph.h"
  23. #include "graph/manager/graph_var_manager.h"
  24. #include "graph/utils/graph_utils.h"
  25. #include "graph/utils/tensor_utils.h"
  26. #include "graph/utils/type_utils.h"
  27. namespace ge {
  28. namespace {
  // Output index used when reading the output tensor desc of a trans op.
  constexpr int kTransOpOutIndex = 0;
  30. std::string GetKey(Format format, DataType type, const std::vector<int64_t> &dims) {
  31. std::stringstream key;
  32. key << static_cast<int>(format) << '-';
  33. key << static_cast<int>(type) << '-';
  34. for (auto dim : dims) {
  35. key << dim << '-';
  36. }
  37. return key.str();
  38. }
  39. Status ByPassTransNode(NodePtr &trans_node, NodePtr &ref_node) {
  40. GE_CHECK_NOTNULL(trans_node);
  41. GE_CHECK_NOTNULL(ref_node);
  42. GELOGD("Begin to bypass trans node %s", trans_node->GetName().c_str());
  43. auto ret = GraphUtils::CopyInCtrlEdges(trans_node, ref_node);
  44. if (ret != GRAPH_SUCCESS) {
  45. REPORT_CALL_ERROR("E19999", "Copy in control edge from node:%s(%s) to node:%s(%s) failed",
  46. trans_node->GetName().c_str(), trans_node->GetType().c_str(),
  47. ref_node->GetName().c_str(), ref_node->GetType().c_str());
  48. GELOGE(INTERNAL_ERROR,
  49. "Failed to move control edges from trans "
  50. "node %s to var-ref %s",
  51. trans_node->GetName().c_str(), ref_node->GetName().c_str());
  52. return INTERNAL_ERROR;
  53. }
  54. auto ref_in_anchor = ref_node->GetInDataAnchor(0);
  55. if (ref_in_anchor == nullptr) {
  56. REPORT_INNER_ERROR("E19999", "Node:%s(%s) has no input anchor, check invalid",
  57. ref_node->GetName().c_str(), ref_node->GetType().c_str());
  58. GELOGE(INTERNAL_ERROR,
  59. "The variable ref node %s does not have an "
  60. "input anchor",
  61. ref_node->GetName().c_str());
  62. return INTERNAL_ERROR;
  63. }
  64. ref_in_anchor->UnlinkAll();
  65. auto trans_in_anchor = trans_node->GetInDataAnchor(0);
  66. if (trans_in_anchor == nullptr) {
  67. REPORT_INNER_ERROR("E19999", "Node:%s(%s) has no input anchor, check invalid",
  68. trans_node->GetName().c_str(), trans_node->GetType().c_str());
  69. GELOGE(INTERNAL_ERROR,
  70. "Failed to get the in data anchor from trans"
  71. " node %s type %s",
  72. trans_node->GetName().c_str(), trans_node->GetType().c_str());
  73. return INTERNAL_ERROR;
  74. }
  75. auto prev_trans_node_out_anchor = trans_in_anchor->GetPeerOutAnchor();
  76. if (prev_trans_node_out_anchor == nullptr) {
  77. GELOGW(
  78. "The trans node %s does not have an input, so the ref node %s does"
  79. " not have any inputs after bypass",
  80. trans_node->GetName().c_str(), trans_node->GetName().c_str());
  81. } else {
  82. ret = GraphUtils::AddEdge(prev_trans_node_out_anchor, ref_in_anchor);
  83. if (ret != GRAPH_SUCCESS) {
  84. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed",
  85. prev_trans_node_out_anchor->GetOwnerNode()->GetName().c_str(),
  86. prev_trans_node_out_anchor->GetOwnerNode()->GetType().c_str(),
  87. prev_trans_node_out_anchor->GetIdx(),
  88. ref_node->GetName().c_str(), ref_node->GetType().c_str());
  89. GELOGE(INTERNAL_ERROR,
  90. "Failed to add edge between ref node %s "
  91. "and the prev node of trans node %s",
  92. ref_node->GetName().c_str(), trans_node->GetName().c_str());
  93. return INTERNAL_ERROR;
  94. }
  95. }
  96. return SUCCESS;
  97. }
  98. bool IsTransSupport(const TransNodeInfo &trans_info) {
  99. if (trans_info.output.GetShape().IsUnknownShape()) {
  100. return false;
  101. }
  102. if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) {
  103. return true;
  104. } else if (trans_info.node_type == TRANSDATA || trans_info.node_type == TRANSPOSED) {
  105. formats::TransArgs args{nullptr,
  106. trans_info.input.GetFormat(),
  107. trans_info.output.GetFormat(),
  108. trans_info.input.GetShape().GetDims(),
  109. trans_info.output.GetShape().GetDims(),
  110. trans_info.input.GetDataType()};
  111. return formats::IsTransFormatSupport(args);
  112. } else if (trans_info.node_type == CAST) {
  113. formats::CastArgs datatype_args{nullptr, static_cast<size_t>(trans_info.input.GetShape().GetShapeSize()),
  114. trans_info.input.GetDataType(), trans_info.output.GetDataType()};
  115. return formats::IsTransDataTypeSupport(datatype_args);
  116. } else {
  117. return false;
  118. }
  119. }
  120. } // namespace
  121. Status VariableOpPass::Run(ge::ComputeGraphPtr graph) {
  122. if (graph == nullptr) {
  123. REPORT_INNER_ERROR("E19999", "Param graph is nullptr, check invalid");
  124. GELOGE(INTERNAL_ERROR, "Failed to run variable op pass, null graph");
  125. return INTERNAL_ERROR;
  126. }
  127. auto graph_id = GraphUtils::FindRootGraph(graph)->GetGraphID();
  128. GELOGD("Begin to run variable op pass on graph %s, session %lu, graph id %u", graph->GetName().c_str(),
  129. GetContext().SessionId(), graph_id);
  130. if (var_accelerate_ctrl_ == nullptr) {
  131. REPORT_INNER_ERROR("E19999", "The variable accelerate control is nullptr, check invalid");
  132. GELOGE(INTERNAL_ERROR, "Failed to run var op pass, the variable accelerate control is null");
  133. return INTERNAL_ERROR;
  134. }
  135. GELOGD("Begin to generate ref map for variable and refs, graph name:%s.", graph->GetName().c_str());
  136. if (RenewVarDesc(graph) != SUCCESS) {
  137. GELOGE(INTERNAL_ERROR, "Failed to renew var desc on graph");
  138. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  139. }
  140. if (GenerateVariableVariableRefMap(graph) != SUCCESS) {
  141. GELOGE(INTERNAL_ERROR, "Failed to generate variable map for graph %s", graph->GetName().c_str());
  142. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  143. }
  144. GELOGD("Begin to fusion variables and trans nodes");
  145. for (auto &var_to_refs : var_and_var_ref_map_) {
  146. auto &node = var_to_refs.first;
  147. GE_CHECK_NOTNULL(node);
  148. GE_CHECK_NOTNULL(var_accelerate_ctrl_);
  149. if (!var_accelerate_ctrl_->IsVarPermitToChangeFormats(node->GetName())) {
  150. GELOGD("The var %s does not permit to change formats, skip it", node->GetName().c_str());
  151. continue;
  152. }
  153. VarTransRoad fusion_road;
  154. auto ret = FusionIfNeed(node, fusion_road);
  155. if (ret != SUCCESS) {
  156. return ret;
  157. }
  158. if (fusion_road.empty()) {
  159. GELOGD("No need to fusion variable and trans op for var %s", node->GetName().c_str());
  160. continue;
  161. }
  162. auto start_iter = fusion_road.begin();
  163. auto end_iter = fusion_road.rbegin();
  164. GELOGD(
  165. "Trans variable data for %s from format %s to %s, shape %s to %s "
  166. "data-type %s to %s, path len %zu success",
  167. node->GetName().c_str(), TypeUtils::FormatToSerialString(start_iter->input.GetFormat()).c_str(),
  168. TypeUtils::FormatToSerialString(end_iter->output.GetFormat()).c_str(),
  169. formats::ShapeToString(start_iter->input.GetShape().GetDims()).c_str(),
  170. formats::ShapeToString(end_iter->output.GetShape().GetDims()).c_str(),
  171. TypeUtils::DataTypeToSerialString(start_iter->input.GetDataType()).c_str(),
  172. TypeUtils::DataTypeToSerialString(end_iter->output.GetDataType()).c_str(), fusion_road.size());
  173. ret = VarManager::Instance(graph->GetSessionID())->SetTransRoad(node->GetName(), fusion_road);
  174. if (ret != SUCCESS) {
  175. REPORT_CALL_ERROR("E19999", "Set Trans road for node:%s(%s) failed, session_id:%lu",
  176. node->GetName().c_str(), node->GetType().c_str(), graph->GetSessionID());
  177. GELOGE(INTERNAL_ERROR, "Failed to update the format fusion road for var %s", node->GetName().c_str());
  178. return INTERNAL_ERROR;
  179. }
  180. ret = VarManager::Instance(graph->GetSessionID())->SetChangedGraphId(node->GetName(), graph_id);
  181. if (ret != SUCCESS) {
  182. REPORT_CALL_ERROR("E19999", "Update graph_id:%u for node:%s(%s) failed, session_id:%lu",
  183. graph_id, node->GetName().c_str(), node->GetType().c_str(), graph->GetSessionID());
  184. GELOGE(INTERNAL_ERROR, "Failed to update the graph id for var %s", node->GetName().c_str());
  185. return INTERNAL_ERROR;
  186. }
  187. var_accelerate_ctrl_->SetVarChanged(node->GetName());
  188. GELOGD("Begin to update format info for var %s.", node->GetName().c_str());
  189. std::set<ge::NodePtr> node_set({node});
  190. if (UpdateIOFormatInfo(end_iter->output, node_set) != SUCCESS) {
  191. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  192. }
  193. // renew var desc if the trans_road is all reshape or reformat
  194. ret = RenewVarDesc(graph->GetSessionID(), node, fusion_road);
  195. if (ret != SUCCESS) {
  196. GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str());
  197. return FAILED;
  198. }
  199. }
  200. return SUCCESS;
  201. }
  202. Status VariableOpPass::DealFusion(const ge::NodePtr &var_node) {
  203. GE_CHECK_NOTNULL(var_node);
  204. GELOGD("Begin to fusion var %s with trans", var_node->GetName().c_str());
  205. auto graph = var_node->GetOwnerComputeGraph();
  206. for (auto &trans_node : var_node->GetOutDataNodes()) {
  207. GELOGD("Remove node %s type %s when fusion with variable %s", trans_node->GetName().c_str(),
  208. trans_node->GetType().c_str(), var_node->GetName().c_str());
  209. if (GraphUtils::IsolateNode(trans_node, {0}) != SUCCESS) {
  210. REPORT_CALL_ERROR("E19999", "Isolate node:%s(%s) failed",
  211. trans_node->GetName().c_str(), trans_node->GetType().c_str());
  212. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  213. }
  214. if (GraphUtils::RemoveNodeWithoutRelink(graph, trans_node) != SUCCESS) {
  215. REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) without relink in graph:%s failed",
  216. trans_node->GetName().c_str(), trans_node->GetType().c_str(), graph->GetName().c_str());
  217. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  218. }
  219. }
  220. auto iterator = var_and_var_ref_map_.find(var_node);
  221. if (iterator == var_and_var_ref_map_.end()) {
  222. GELOGD("there is no var_ref of node %s", var_node->GetName().c_str());
  223. return SUCCESS;
  224. }
  225. for (auto ref_node : iterator->second) {
  226. GE_CHECK_NOTNULL(ref_node);
  227. for (auto &trans_node : ref_node->GetInDataNodes()) {
  228. GELOGD("Remove node %s type %s when fusion with variable %s", trans_node->GetName().c_str(),
  229. trans_node->GetType().c_str(), var_node->GetName().c_str());
  230. if (trans_node->GetOutDataNodes().size() > 1) {
  231. GELOGD(
  232. "The trans node %s type %s connecting with var-ref %s has more"
  233. " than one output data nodes, unlink the edge between them",
  234. trans_node->GetName().c_str(), trans_node->GetType().c_str(), ref_node->GetName().c_str());
  235. if (ByPassTransNode(trans_node, ref_node) != SUCCESS) {
  236. GELOGE(INTERNAL_ERROR, "Failed to bypass trans node %s to ref %s", trans_node->GetName().c_str(),
  237. ref_node->GetName().c_str());
  238. return INTERNAL_ERROR;
  239. }
  240. } else {
  241. GELOGD(
  242. "The trans node %s type %s connecting with var-ref %s has only"
  243. " one output data nodes, isolate and remove it.",
  244. trans_node->GetName().c_str(), trans_node->GetType().c_str(), ref_node->GetName().c_str());
  245. if (GraphUtils::IsolateNode(trans_node, {0}) != SUCCESS) {
  246. REPORT_CALL_ERROR("E19999", "Isolate node:%s(%s) failed",
  247. trans_node->GetName().c_str(), trans_node->GetType().c_str());
  248. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  249. }
  250. if (GraphUtils::RemoveNodeWithoutRelink(graph, trans_node) != SUCCESS) {
  251. REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) without relink in graph:%s failed",
  252. trans_node->GetName().c_str(), trans_node->GetType().c_str(), graph->GetName().c_str());
  253. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  254. }
  255. }
  256. }
  257. }
  258. return SUCCESS;
  259. }
  260. Status VariableOpPass::CheckSameAndTransOp(const ge::NodePtr &var_node, bool &is_matched, VarTransRoad &fusion_road) {
  261. std::set<std::string> data_type_and_formats;
  262. std::string trans_op_type;
  263. ge::NodePtr out_node;
  264. ge::GeTensorDesc output_desc;
  265. GE_CHECK_NOTNULL(var_node);
  266. for (auto &out_node_and_anchor : var_node->GetOutDataNodesAndAnchors()) {
  267. auto in_anchor = out_node_and_anchor.second;
  268. GE_CHECK_NOTNULL(in_anchor);
  269. out_node = out_node_and_anchor.first;
  270. GE_CHECK_NOTNULL(out_node);
  271. auto trans_op_desc = out_node->GetOpDesc();
  272. GE_CHECK_NOTNULL(trans_op_desc);
  273. trans_op_type = trans_op_desc->GetType();
  274. GELOGD("current node type is %s.", trans_op_type.c_str());
  275. int data_index = TransOpUtil::GetTransOpDataIndex(trans_op_type);
  276. if (data_index < 0) {
  277. GELOGD("Variables only can be fusion with trans_op, the next op is %s type %s", out_node->GetName().c_str(),
  278. out_node->GetType().c_str());
  279. return SUCCESS;
  280. }
  281. if (data_index != in_anchor->GetIdx()) {
  282. GELOGD(
  283. "Variables only can be fusion with trans nodes, the next node %s"
  284. " type %s index %d does not trans anything(correct index %d)",
  285. out_node->GetName().c_str(), out_node->GetType().c_str(), in_anchor->GetIdx(), data_index);
  286. return SUCCESS;
  287. }
  288. output_desc = trans_op_desc->GetOutputDesc(kTransOpOutIndex);
  289. auto trans_op_format = output_desc.GetFormat();
  290. auto trans_op_data_type = output_desc.GetDataType();
  291. auto shape = output_desc.GetShape().GetDims();
  292. auto datatype_and_format = GetKey(trans_op_format, trans_op_data_type, shape);
  293. data_type_and_formats.insert(datatype_and_format);
  294. }
  295. if (data_type_and_formats.empty()) {
  296. return SUCCESS;
  297. }
  298. if (data_type_and_formats.size() > 1) {
  299. std::stringstream type_and_formats_stream;
  300. bool first_time = true;
  301. for (const auto &data_type_and_format : data_type_and_formats) {
  302. if (first_time) {
  303. first_time = false;
  304. } else {
  305. type_and_formats_stream << "|";
  306. }
  307. type_and_formats_stream << data_type_and_format;
  308. }
  309. GELOGW(
  310. "trans_op type size for var Node(%s) is over 1, Currently not"
  311. " supported, dataTypeAndFormats is %s.",
  312. var_node->GetName().c_str(), type_and_formats_stream.str().c_str());
  313. return SUCCESS;
  314. }
  315. int tran_in_index = TransOpUtil::GetTransOpDataIndex(out_node->GetType());
  316. auto out_op_desc = out_node->GetOpDesc();
  317. GE_CHECK_NOTNULL(out_op_desc);
  318. TransNodeInfo trans_node_info;
  319. trans_node_info.node_type = out_node->GetType();
  320. trans_node_info.input = out_op_desc->GetInputDesc(tran_in_index);
  321. trans_node_info.output = out_op_desc->GetOutputDesc(kTransOpOutIndex);
  322. if (!IsTransSupport(trans_node_info)) {
  323. GELOGD("The trans node %s does not support, skip the variable accelerating", trans_node_info.node_type.c_str());
  324. return SUCCESS;
  325. }
  326. is_matched = true;
  327. fusion_road.emplace_back(trans_node_info);
  328. return SUCCESS;
  329. }
  330. Status VariableOpPass::CheckVariableRefLegally(const ge::NodePtr &var_node, bool &is_var_ref_legally) {
  331. is_var_ref_legally = true;
  332. GE_CHECK_NOTNULL(var_node);
  333. auto iterator = var_and_var_ref_map_.find(var_node);
  334. if (iterator == var_and_var_ref_map_.end()) {
  335. GELOGD("var name %s are not in var var_ref map", var_node->GetName().c_str());
  336. return SUCCESS;
  337. }
  338. GELOGD("var name %s, ref var count %zu.", var_node->GetName().c_str(), iterator->second.size());
  339. for (const auto &var_ref_node : iterator->second) {
  340. if (CheckVarAndVarRefAreAlike(var_node, var_ref_node, is_var_ref_legally) != SUCCESS) {
  341. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  342. }
  343. GELOGD("is_var_ref_legally is %d", is_var_ref_legally);
  344. if (!is_var_ref_legally) {
  345. return SUCCESS;
  346. }
  347. }
  348. return SUCCESS;
  349. }
  350. Status VariableOpPass::UpdateVarAndRefOutputFormatInfo(const GeTensorDesc &final_output, const ge::NodePtr &node) {
  351. if (node == nullptr || node->GetOpDesc() == nullptr) {
  352. REPORT_INNER_ERROR("E19999", "Param node or its op_desc is nullptr, check invalid");
  353. GELOGE(FAILED, "node or opdesc is nullptr");
  354. return FAILED;
  355. }
  356. const Format &format = final_output.GetFormat();
  357. const DataType &data_type = final_output.GetDataType();
  358. const GeShape &shape = final_output.GetShape();
  359. GELOGD("last ref is (%s, %s, %lu), var_ref_name is %s.", TypeUtils::DataTypeToSerialString(data_type).c_str(),
  360. TypeUtils::FormatToSerialString(format).c_str(), shape.GetDims().size(), node->GetName().c_str());
  361. auto node_desc = node->GetOpDesc()->GetOutputDesc(0);
  362. CopyVariableFormatDataTypeAndShape(final_output, node_desc);
  363. if (node->GetOpDesc()->UpdateOutputDesc(0, node_desc) != GRAPH_SUCCESS) {
  364. REPORT_CALL_ERROR("E19999", "Update ouput:0 desc in op:%s(%s) failed",
  365. node->GetName().c_str(), node->GetType().c_str());
  366. GELOGE(FAILED, "update output desc fail.");
  367. return FAILED;
  368. }
  369. GELOGD("node ref is (%s, %s, %lu), var_ref_name is %s.",
  370. TypeUtils::DataTypeToSerialString(node->GetOpDesc()->GetOutputDesc(0).GetDataType()).c_str(),
  371. TypeUtils::FormatToSerialString(node->GetOpDesc()->GetOutputDesc(0).GetFormat()).c_str(),
  372. node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims().size(), node->GetName().c_str());
  373. auto iterator = var_and_var_ref_map_.find(node);
  374. if (iterator == var_and_var_ref_map_.end()) {
  375. auto graph = node->GetOwnerComputeGraph();
  376. if (GenerateVariableVariableRefMap(graph) != SUCCESS) {
  377. GELOGE(INTERNAL_ERROR, "Failed to generate variable map for graph %s", graph->GetName().c_str());
  378. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  379. }
  380. }
  381. iterator = var_and_var_ref_map_.find(node);
  382. if (iterator == var_and_var_ref_map_.end()) {
  383. GELOGW("The var node %s which belongs to graph %s can not be found on the graph", node->GetName().c_str(),
  384. node->GetOwnerComputeGraph()->GetName().c_str());
  385. return SUCCESS;
  386. }
  387. for (const auto &var_ref_node : iterator->second) {
  388. auto var_ref_node_description = var_ref_node->GetOpDesc();
  389. GE_CHECK_NOTNULL(var_ref_node_description);
  390. GELOGD("var_ref_node before is (%s, %s, %zu), var_ref_name is %s.",
  391. TypeUtils::DataTypeToSerialString(data_type).c_str(), TypeUtils::FormatToSerialString(format).c_str(),
  392. shape.GetDims().size(), var_ref_node->GetName().c_str());
  393. if (var_ref_node_description->UpdateOutputDesc(0, node_desc) != GRAPH_SUCCESS) {
  394. GELOGW("UpdateOutputDesc fail.");
  395. }
  396. if (var_ref_node_description->UpdateInputDesc(0, node_desc) != GRAPH_SUCCESS) {
  397. GELOGW("UpdateInputDesc fail.");
  398. }
  399. const auto &input_desc = var_ref_node_description->MutableInputDesc(0);
  400. const auto &output_desc = var_ref_node_description->MutableOutputDesc(0);
  401. GE_CHECK_NOTNULL(input_desc);
  402. GE_CHECK_NOTNULL(output_desc);
  403. GELOGD("var_ref_node ref is (%s, %s, %zu), var_ref_name is %s.",
  404. TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str(),
  405. TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(), output_desc->GetShape().GetDims().size(),
  406. var_ref_node->GetName().c_str());
  407. }
  408. return SUCCESS;
  409. }
  410. Status VariableOpPass::GenerateVariableVariableRefMap(const ComputeGraphPtr &compute_graph) {
  411. std::map<std::string, NodePtr> names_to_var;
  412. std::map<std::string, std::set<NodePtr>> names_to_refs;
  413. GE_CHECK_NOTNULL(compute_graph);
  414. for (auto &node : compute_graph->GetDirectNode()) {
  415. if (node->GetType() != VARIABLE) {
  416. continue;
  417. }
  418. std::string ref_var_name;
  419. if (!ge::AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_name)) {
  420. names_to_var[node->GetName()] = node;
  421. } else {
  422. names_to_refs[ref_var_name].insert(node);
  423. }
  424. }
  425. for (auto &name_to_var : names_to_var) {
  426. var_and_var_ref_map_[name_to_var.second] = names_to_refs[name_to_var.first];
  427. }
  428. return SUCCESS;
  429. }
  430. Status VariableOpPass::CheckVarAndVarRefAreAlike(const NodePtr &var_node, const NodePtr &var_ref_node,
  431. bool &is_var_and_variable_ref_are_alike) {
  432. GE_CHECK_NOTNULL(var_node);
  433. GE_CHECK_NOTNULL(var_ref_node);
  434. GELOGD("var_node GetOutDataNodes. name is %s.", var_node->GetName().c_str());
  435. const auto &var_node_trans_nodes = var_node->GetOutDataNodes();
  436. GELOGD("var_node_trans_nodes size is %zu.", var_node_trans_nodes.size());
  437. GELOGD("var_ref_node GetOutDataNodes. name is %s.", var_ref_node->GetName().c_str());
  438. const auto &var_ref_node_trans_nodes = var_ref_node->GetInDataNodes();
  439. GELOGD("var_ref_node_trans_nodes size is %zu.", var_ref_node_trans_nodes.size());
  440. if (var_ref_node_trans_nodes.size() > 1) {
  441. REPORT_INNER_ERROR("E19999", "In data node num:%zu of node:%s(%s) bigger than 1, check invalid",
  442. var_ref_node_trans_nodes.size(),
  443. var_ref_node->GetName().c_str(), var_ref_node->GetType().c_str());
  444. GELOGE(GE_GRAPH_VARIABLE_OP_PASS_FAILED, "var_ref_node_trans_nodes.size() > 1.");
  445. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  446. }
  447. const auto &var_node_trans_node = var_node_trans_nodes.at(0);
  448. const auto &var_ref_node_trans_node = var_ref_node_trans_nodes.at(0);
  449. if (CheckTransNodeAreInverse(var_node_trans_node, var_ref_node_trans_node, is_var_and_variable_ref_are_alike) !=
  450. SUCCESS) {
  451. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  452. }
  453. return SUCCESS;
  454. }
  455. Status VariableOpPass::CheckTransNodeAreInverse(const NodePtr &node_a, const NodePtr &node_b, bool &is_same) {
  456. GELOGD("In CheckTransNodeAreInverse.");
  457. GE_CHECK_NOTNULL(node_a);
  458. GE_CHECK_NOTNULL(node_b);
  459. const auto &node_a_op_desc = node_a->GetOpDesc();
  460. const auto &node_b_op_desc = node_b->GetOpDesc();
  461. GE_CHECK_NOTNULL(node_a_op_desc);
  462. GE_CHECK_NOTNULL(node_b_op_desc);
  463. const auto &node_a_out_op_desc = node_a_op_desc->MutableOutputDesc(0);
  464. const auto &node_a_in_op_desc = node_a_op_desc->MutableInputDesc(0);
  465. GE_CHECK_NOTNULL(node_a_out_op_desc);
  466. GE_CHECK_NOTNULL(node_a_in_op_desc);
  467. const auto &node_b_out_op_desc = node_b_op_desc->MutableOutputDesc(0);
  468. const auto &node_b_in_op_desc = node_b_op_desc->MutableInputDesc(0);
  469. GE_CHECK_NOTNULL(node_b_out_op_desc);
  470. GE_CHECK_NOTNULL(node_b_in_op_desc);
  471. is_same = IsOpDescSame(node_a_out_op_desc, node_b_in_op_desc) && IsOpDescSame(node_b_out_op_desc, node_a_in_op_desc);
  472. return SUCCESS;
  473. }
  474. bool VariableOpPass::IsOpDescSame(const GeTensorDescPtr &op_desc_a, const GeTensorDescPtr &op_desc_b) {
  475. const auto &format_a = op_desc_a->GetFormat();
  476. const auto &type_a = op_desc_a->GetDataType();
  477. const auto &shape_a = op_desc_a->GetShape();
  478. const auto &format_b = op_desc_b->GetFormat();
  479. const auto &type_b = op_desc_b->GetDataType();
  480. const auto &shape_b = op_desc_b->GetShape();
  481. const auto &dims_a = shape_a.GetDims();
  482. const auto &dims_b = shape_b.GetDims();
  483. GELOGD("(format, data type, shape) = (%s, %s, %zu) (%s, %s, %zu)", TypeUtils::FormatToSerialString(format_a).c_str(),
  484. TypeUtils::DataTypeToSerialString(type_a).c_str(), dims_a.size(),
  485. TypeUtils::FormatToSerialString(format_b).c_str(), TypeUtils::DataTypeToSerialString(type_b).c_str(),
  486. dims_b.size());
  487. return (format_a == format_b) && (type_a == type_b) && (dims_a == dims_b);
  488. }
  489. void VariableOpPass::CopyVariableFormatDataTypeAndShape(const GeTensorDesc &src_tensor_desc,
  490. GeTensorDesc &dst_tensor_desc) {
  491. dst_tensor_desc.SetShape(src_tensor_desc.GetShape());
  492. dst_tensor_desc.SetFormat(src_tensor_desc.GetFormat());
  493. dst_tensor_desc.SetDataType(src_tensor_desc.GetDataType());
  494. }
  495. Status VariableOpPass::CheckIfCouldBeOptimized(const ge::NodePtr &node, bool &flag, VarTransRoad &fusion_road) {
  496. if (node == nullptr) {
  497. REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
  498. return FAILED;
  499. }
  500. bool is_matched = false;
  501. auto ret = CheckSameAndTransOp(node, is_matched, fusion_road);
  502. if (ret != SUCCESS) {
  503. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  504. }
  505. if (!is_matched) {
  506. flag = false;
  507. return SUCCESS;
  508. }
  509. bool is_var_ref_legally = false;
  510. ret = CheckVariableRefLegally(node, is_var_ref_legally);
  511. if (ret != SUCCESS) {
  512. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  513. }
  514. GELOGD("is_var_ref_legally is %d.", is_var_ref_legally);
  515. if (!is_var_ref_legally) {
  516. GELOGI("variable ref connection are illegally");
  517. flag = false;
  518. fusion_road.clear();
  519. return SUCCESS;
  520. }
  521. flag = true;
  522. GELOGD("node %s, is_matched = %d is_var_ref_legally = %d, flag = %d", node->GetName().c_str(), is_matched,
  523. is_var_ref_legally, flag);
  524. return SUCCESS;
  525. }
  526. Status VariableOpPass::FusionIfNeed(const NodePtr &var, VarTransRoad &fusion_road) {
  527. bool can_fusion = false;
  528. while (true) {
  529. auto ret = CheckIfCouldBeOptimized(var, can_fusion, fusion_road);
  530. if (ret != SUCCESS) {
  531. return ret;
  532. }
  533. if (!can_fusion) {
  534. break;
  535. }
  536. ret = DealFusion(var);
  537. if (ret != SUCCESS) {
  538. return ret;
  539. }
  540. }
  541. return SUCCESS;
  542. }
  543. Status VariableOpPass::UpdateIOFormatInfo(const GeTensorDesc &final_output, std::set<NodePtr> &nodes) {
  544. for (auto &need_set_node : nodes) {
  545. auto ret = UpdateVarAndRefOutputFormatInfo(final_output, need_set_node);
  546. if (ret != SUCCESS) {
  547. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  548. }
  549. }
  550. return SUCCESS;
  551. }
  552. Status VariableOpPass::RenewVarDesc(ge::ComputeGraphPtr &graph) {
  553. GE_CHECK_NOTNULL(graph);
  554. // renew var manager desc
  555. Status ret = SUCCESS;
  556. for (auto &node : graph->GetDirectNode()) {
  557. bool is_var_node =
  558. (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == VARHANDLEOP);
  559. if (is_var_node) {
  560. if (!ge::VarManager::Instance(graph->GetSessionID())->IsVarExist(node->GetName())) {
  561. GELOGD("var manager does not exist var node[%s]", node->GetName().c_str());
  562. continue;
  563. }
  564. GELOGD("var manager exist var node[%s], graph name[%s]", node->GetName().c_str(), graph->GetName().c_str());
  565. GE_CHECK_NOTNULL(node->GetOpDesc());
  566. ret = ge::VarManager::Instance(graph->GetSessionID())->RenewCurVarDesc(node->GetName(), node->GetOpDesc());
  567. if (ret != SUCCESS) {
  568. REPORT_CALL_ERROR("E19999", "Renew descriptor for node:%s(%s) failed, session_id:%lu",
  569. node->GetName().c_str(), node->GetType().c_str(), graph->GetSessionID());
  570. GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str());
  571. return FAILED;
  572. }
  573. }
  574. }
  575. return SUCCESS;
  576. }
  577. Status VariableOpPass::RenewVarDesc(uint64_t session_id, const NodePtr &node, const VarTransRoad &fusion_road) {
  578. // renew var desc if the trans_road is all reshape or reformat
  579. for (auto &road : fusion_road) {
  580. if (road.node_type != RESHAPE && road.node_type != REFORMAT) {
  581. return SUCCESS;
  582. }
  583. }
  584. if (!ge::VarManager::Instance(session_id)->IsVarExist(node->GetName())) {
  585. GELOGD("var manager does not exist var node[%s]", node->GetName().c_str());
  586. return SUCCESS;
  587. }
  588. GELOGD("var manager exist var node[%s]", node->GetName().c_str());
  589. GE_CHECK_NOTNULL(node->GetOpDesc());
  590. Status ret = ge::VarManager::Instance(session_id)->RenewCurVarDesc(node->GetName(), node->GetOpDesc());
  591. if (ret != SUCCESS) {
  592. REPORT_CALL_ERROR("E19999", "Renew descriptor for node:%s(%s) failed, session_id:%lu",
  593. node->GetName().c_str(), node->GetType().c_str(), session_id);
  594. GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str());
  595. return FAILED;
  596. }
  597. return SUCCESS;
  598. }
  599. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户对此并无感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。