
infershape_pass.cc 23 kB

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "graph/passes/infershape_pass.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "analyzer/analyzer.h"
#include "framework/common/util.h"
#include "graph/common/omg_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_util.h"
#include "graph/operator_factory_impl.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"

namespace ge {
namespace {
const char *const kPreOpInputShapeRange = "_pre_op_in_range";
thread_local std::unordered_map<NodePtr, InferenceContextPtr> context_map;
}  // namespace

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void InferShapePass::ClearContextMap() { context_map.clear(); }

// Build an InferenceContext for the given node: collect marks and output handle
// shapes/types from upstream nodes that already have a context in context_map.
InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map,
                                           const NodePtr &node) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "node is null");
    return nullptr;
  }
  InferenceContextPtr inference_context = std::shared_ptr<InferenceContext>(InferenceContext::Create());
  if (inference_context == nullptr) {
    REPORT_CALL_ERROR("E19999", "Failed to alloc InferenceContext, node:%s", node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Alloc][InferenceContext] failed.");
    return nullptr;
  }
  auto all_in_data_anchors = node->GetAllInDataAnchors();
  std::vector<std::vector<ShapeAndType>> input_shapes_and_types(all_in_data_anchors.size());
  std::vector<std::string> marks;
  bool has_input_shapes_and_types = false;
  for (const auto &in_anchor : all_in_data_anchors) {
    const auto &out_anchor = in_anchor->GetPeerOutAnchor();
    if (out_anchor == nullptr) {
      continue;
    }
    auto input_node = out_anchor->GetOwnerNode();
    if (input_node == nullptr) {
      continue;
    }
    auto iter = context_map.find(input_node);
    if (iter != context_map.end()) {
      const auto &src_context = iter->second;
      GE_IF_BOOL_EXEC(src_context == nullptr, REPORT_INNER_ERROR("E19999", "src_context is null.");
                      GELOGE(GRAPH_FAILED, "[Check][Param] src_context is null."); return nullptr);
      GELOGD("node:%s get %ld marks from node:%s", node->GetName().c_str(), src_context->GetMarks().size(),
             input_node->GetName().c_str());
      for (auto mark : src_context->GetMarks()) {
        marks.push_back(mark);
      }
      auto output_idx = out_anchor->GetIdx();
      auto input_idx = in_anchor->GetIdx();
      auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes();
      if (output_idx < static_cast<int>(output_shape_and_type.size())) {
        GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx,
               node->GetName().c_str(), input_idx);
        input_shapes_and_types[input_idx] = output_shape_and_type[output_idx];
        has_input_shapes_and_types = true;
      } else {
        GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx,
               output_shape_and_type.size());
      }
    }
  }
  if (has_input_shapes_and_types) {
    inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types));
  }
  inference_context->SetMarks(marks);
  return inference_context;
}

void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  desc_str += "[";
  std::vector<std::pair<int64_t, int64_t>> shape_range;
  (void)desc->GetShapeRange(shape_range);
  for (const auto &pair : shape_range) {
    desc_str += "{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
  desc_str += "]";
  shape_range.clear();
  (void)desc->GetOriginShapeRange(shape_range);
  for (const auto &pair : shape_range) {
    desc_str += ",{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
}

std::string GetInTensorInfoWithString(const ge::NodePtr &node) {
  ge::OpDescPtr op_desc = node->GetOpDesc();
  std::stringstream ss;
  ss << "{";
  int32_t in_idx = 0;
  for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
    if (input_desc == nullptr) {
      in_idx++;
      continue;
    }
    if (in_idx > 0) {
      ss << " ";
    }
    ss << "input_" << in_idx << " "
       << "tensor: [";
    ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),";
    ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),";
    ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),";
    ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),";
    ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),";
    ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),";
    string range_str;
    SerialShapeRange(input_desc, range_str);
    ss << "(shape_range:" << range_str << ")]";
    in_idx++;
  }
  return ss.str();
}

void InferShapePass::AnalyzeFailedInfo(const NodePtr &node) {
  auto graph = node->GetOwnerComputeGraph();
  if (graph == nullptr) {
    GELOGW("Owner compute graph of node %s is nullptr", node->GetName().c_str());
  }
  auto root_graph = ge::GraphUtils::FindRootGraph(graph);
  if (root_graph == nullptr) {
    GELOGW("Root compute graph of node %s is nullptr", node->GetName().c_str());
  }
  analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), analyzer::INFER_SHAPE, node,
                                  "InferShapeFailed!"};
  (void)Analyzer::GetInstance()->DoAnalyze(analyze_info);
  (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), root_graph->GetGraphID());
  REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s", node->GetName().c_str(),
                    node->GetType().c_str(), GetInTensorInfoWithString(node).c_str());
  GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. node: %s", node->GetName().c_str());
}

graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr dst, bool &changed) {
  changed = false;
  const auto &dst_dims = dst->GetShape().GetDims();
  const auto &src_dims = src->GetShape().GetDims();
  if (dst_dims == src_dims) {
    changed = true;
  }
  dst = src;
  return GRAPH_SUCCESS;
}

graphStatus InferShapePass::UpdateDescAttrForPeerInput(const GeTensorDescPtr &src, GeTensorDescPtr dst, bool &changed) {
  changed = false;
  if (dst->GetShape().GetDims() == src->GetShape().GetDims()) {
    changed = true;
  }
  dst->SetOriginShape(src->GetOriginShape());
  dst->SetShape(src->GetShape());
  dst->SetDataType(src->GetDataType());
  dst->SetOriginDataType(src->GetOriginDataType());
  std::vector<std::pair<int64_t, int64_t>> shape_range;
  (void)src->GetShapeRange(shape_range);
  dst->SetShapeRange(shape_range);
  ge::TensorUtils::SetRealDimCnt(*dst, static_cast<uint32_t>(src->GetShape().GetDims().size()));
  return GRAPH_SUCCESS;
}

graphStatus InferShapePass::Infer(NodePtr &node) {
  bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  auto opdesc = node->GetOpDesc();
  // Some ops, such as AIPP, cannot run infershape twice.
  bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified");
  if (need_update_input) {
    auto status = UpdateOpInputDesc(node);
    if (status != GRAPH_SUCCESS) {
      REPORT_CALL_ERROR("E19999", "update op input_desc failed! ret:%d, node:%s", status, node->GetName().c_str());
      GELOGE(GRAPH_FAILED, "[Update][OpInputDesc] failed! ret:%d", status);
      return status;
    }
  }
  if (node->Verify() != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  PrintInOutTensorShape(node, "before_infershape");
  Operator op = OpDescUtils::CreateOperatorFromNode(node);
  if (!is_unknown_graph) {
    auto inference_context = CreateInferenceContext(context_map, node);
    GE_CHECK_NOTNULL(inference_context);
    GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
    op.SetInferenceContext(inference_context);
  }
  graphStatus status = CallInferShapeFunc(node, op);
  if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) {
    if (is_unknown_graph) {
      PrintInOutTensorShape(node, "after_infershape when running");
      return GRAPH_SUCCESS;
    }
    UpdateInputOutputOriginAttr(node);
  } else {
    REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  if (!is_unknown_graph) {
    auto ctx_after_infer = op.GetInferenceContext();
    if (ctx_after_infer != nullptr) {
      GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
      if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
        GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
               ctx_after_infer->GetMarks().size());
        (void)context_map.emplace(node, ctx_after_infer);
      }
    }
  }
  return GRAPH_SUCCESS;
}

graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) {
  auto op_desc = node->GetOpDesc();
  const auto &op_type = op_desc->GetType();
  auto ret = op_desc->CallInferFunc(op);
  if (ret == GRAPH_PARAM_INVALID) {
    // The op IR has no infer function; try to get one from the operator factory.
    auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
    if (node_op.IsEmpty()) {
      GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
      return ret;
    }
    GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str());
    auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
    node_op.BreakConnect();
    if (temp_op_desc == nullptr) {
      REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr.");
      GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null");
      return GRAPH_FAILED;
    }
    if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
      GELOGW("InferShapeAndType UpdateInputName failed");
      for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) {
        if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) {
          break;
        }
        return GRAPH_SUCCESS;
      }
    }
    if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
      GELOGW("InferShapeAndType UpdateOutputName failed");
    }
    op_desc->AddInferFunc(temp_op_desc->GetInferFunc());
    ret = op_desc->CallInferFunc(op);
    GELOGI("op CallInferFunc second. ret: %u", ret);
  }
  return ret;
}

void InferShapePass::UpdateInputOutputOriginAttr(NodePtr &node) {
  auto op_desc = node->GetOpDesc();
  for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
    auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
    if (output_tensor == nullptr) {
      continue;
    }
    if (output_tensor->MutableShape().GetDims().empty()) {
      output_tensor->SetOriginShape(output_tensor->GetShape());
    }
    ge::TensorUtils::SetRealDimCnt(*output_tensor,
                                   static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims().size()));
    output_tensor->SetOriginDataType(output_tensor->GetDataType());
    // set output origin shape range
    std::vector<std::pair<int64_t, int64_t>> range;
    (void)output_tensor->GetShapeRange(range);
    output_tensor->SetOriginShapeRange(range);
    GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", node->GetName().c_str(),
           output_tensor->GetOriginShape().GetShapeSize(),
           TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
           TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
  }
  for (const auto &in_anchor : node->GetAllInDataAnchors()) {
    auto input_tensor = op_desc->MutableInputDesc(in_anchor->GetIdx());
    if (input_tensor == nullptr) {
      continue;
    }
    // set input origin shape range
    std::vector<std::pair<int64_t, int64_t>> range;
    (void)input_tensor->GetShapeRange(range);
    input_tensor->SetOriginShapeRange(range);
  }
}

graphStatus InferShapePass::UpdateOpInputDesc(const ConstNodePtr &node_ptr) {
  for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) {
    auto in_idx = in_anchor->GetIdx();
    auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
    if (peer_out_data_anchor == nullptr) {
      continue;
    }
    auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
    if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) {
      continue;
    }
    int peer_out_idx = peer_out_data_anchor->GetIdx();
    auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(peer_out_idx));
    // Check shape and dtype consistency with the peer output; warn but do not stop processing.
    auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_idx));
    if (in_desc == nullptr) {
      continue;
    }
    auto in_shape = in_desc->MutableShape().GetDims();
    auto in_dtype = in_desc->GetDataType();
    auto peer_out_shape = peer_out_desc->MutableShape().GetDims();
    auto peer_out_dtype = peer_out_desc->GetDataType();
    if (peer_out_dtype != in_dtype) {
      GELOGW(
          "current node [%s] [%d]\'th in_dtype is [%s].peer output node [%s] [%d]\'th "
          "output_dtype is [%s].The two dtype should be same! Please check graph and fix it",
          node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(),
          peer_out_data_node->GetName().c_str(), peer_out_idx, TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str());
    } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) {
      string in_shape_str = " ";        // Serial(in_shape);
      string peer_out_shape_str = " ";  // Serial(peer_out_shape);
      GELOGW(
          "current node [%s] [%d]\'th in_shape is [%s].peer output node [%s] [%d]\'th "
          "output_shape is [%s].The two shape should be same! Please check graph and fix it",
          node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), peer_out_data_node->GetName().c_str(), peer_out_idx,
          peer_out_shape_str.c_str());
    }
    // refresh current node input desc
    in_desc->SetOriginShape(peer_out_desc->GetOriginShape());
    in_desc->SetShape(peer_out_desc->MutableShape());
    in_desc->SetDataType(peer_out_desc->GetDataType());
    in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType());
    if (peer_out_desc->MutableShape().GetDims() != UNKNOWN_RANK) {
      std::vector<std::pair<int64_t, int64_t>> shape_range;
      (void)peer_out_desc->GetShapeRange(shape_range);
      in_desc->SetShapeRange(shape_range);
    }
    std::vector<int64_t> pre_op_in_range;
    if (ge::AttrUtils::GetListInt(*peer_out_desc, kPreOpInputShapeRange, pre_op_in_range)) {
      (void)ge::AttrUtils::SetListInt(*in_desc, kPreOpInputShapeRange, pre_op_in_range);
    }
    ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast<uint32_t>(peer_out_desc->MutableShape().GetDims().size()));
  }
  return GRAPH_SUCCESS;
}

Status InferShapePass::DoRepassForLoopNode(NodePtr &node) {
  GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node));
  bool need_repass = false;
  auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass);
  if (has_attr) {
    if (!OptionExists(kOptimizeAfterSubGraph)) {
      return SUCCESS;
    }
    if (need_repass) {
      AddImmediateRePassNode(node);
      GELOGD("Node %s need repass immediately.", node->GetName().c_str());
    } else {
      // clear attr on while
      node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
    }
  }
  return SUCCESS;
}

Status InferShapePass::RePassLoopNode(const NodePtr &node) {
  const auto RePassNode = [&](const std::set<std::string> &re_pass_types) {
    for (auto &n : node->GetOutDataNodes()) {
      GE_CHECK_NOTNULL(n);
      std::string node_type;
      GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "Get original node type failed.");
      if (re_pass_types.count(node_type) > 0) {
        AddImmediateRePassNode(n);
        (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false);
        GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str());
      }
    }
    return SUCCESS;
  };
  const auto ExProcNode = [&](const std::set<std::string> &proc_types,
                              const std::function<void(InferShapePass *, NodePtr)> &proc_func,
                              const std::string &info) {
    for (auto &n : node->GetOutDataNodes()) {
      GE_CHECK_NOTNULL(n);
      std::string node_type;
      GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "Get original node type failed.");
      if (proc_types.count(node_type) > 0) {
        proc_func(this, n);
        GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str());
      }
    }
    return SUCCESS;
  };
  std::string node_type;
  GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original node type failed.");
  if (kNextIterationOpTypes.count(node_type) > 0) {
    return RePassNode(kMergeOpTypes);  // Re-Pass Merge
  }
  if (kMergeOpTypes.count(node_type) > 0) {
    if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
      node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
      return RePassNode(kSwitchOpTypes);  // Re-Pass Switch
    }
    return SUCCESS;
  }
  if (kSwitchOpTypes.count(node_type) > 0) {
    if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
      node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
      return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume");  // Resume Exit
    } else {
      return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend");  // Suspend Exit
    }
  }
  return SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
graphStatus InferShapePass::InferShapeAndType(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  InferShapePass pass;
  std::set<NodePtr> unused_changed_nodes;
  return pass.InferAndUpdate(node, true, unused_changed_nodes);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
graphStatus InferShapePass::InferShapeAndType(NodePtr &node, bool before_subgraph) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  InferShapePass pass;
  std::set<NodePtr> unused_changed_nodes;
  return pass.InferAndUpdate(node, before_subgraph, unused_changed_nodes);
}

graphStatus InferShapeForRunning::Infer(NodePtr &node) {
  auto opdesc = node->GetOpDesc();
  vector<ge::DataType> temp_dtype;
  for (auto &tensor_desc : opdesc->GetAllOutputsDescPtr()) {
    temp_dtype.emplace_back(tensor_desc->GetDataType());
  }
  PrintInOutTensorShape(node, "before_infershape when running");
  Operator op = OpDescUtils::CreateOperatorFromNode(node);
  graphStatus status = CallInferShapeFuncForRunning(node, op);
  if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) {
    // ensure the dtype is not changed after infershape in running
    auto after_opdesc = node->GetOpDesc();
    GE_IF_BOOL_EXEC(after_opdesc == nullptr, REPORT_INNER_ERROR("E19999", "param node has no opdesc, check invalid.");
                    GELOGE(GRAPH_FAILED, "[Get][OpDesc] after_opdesc is null."); return GRAPH_FAILED);
    auto all_output_tensor = after_opdesc->GetAllOutputsDescPtr();
    for (size_t i = 0; i < all_output_tensor.size(); ++i) {
      if (all_output_tensor.at(i)->GetDataType() != temp_dtype[i]) {
        GELOGD("Op %s output %zu need reset dtype,original dtype is %s, new dtype is %s", node->GetName().c_str(), i,
               TypeUtils::DataTypeToSerialString(all_output_tensor.at(i)->GetDataType()).c_str(),
               TypeUtils::DataTypeToSerialString(temp_dtype[i]).c_str());
        all_output_tensor.at(i)->SetDataType(temp_dtype[i]);
      }
    }
    PrintInOutTensorShape(node, "after_infershape when running");
    return GRAPH_SUCCESS;
  } else {
    REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
}

graphStatus InferShapeForRunning::CallInferShapeFuncForRunning(NodePtr &node, Operator &op) {
  auto op_desc = node->GetOpDesc();
  const auto &op_type = op_desc->GetType();
  // Create InferenceContext to avoid null pointer access.
  const static std::set<std::string> force_context_op_types{"Enter", "Switch", "RefSwitch"};
  if (force_context_op_types.count(op_type) > 0) {
    GELOGD("Set InferenceContext for node [%s]", op_desc->GetName().c_str());
    op.SetInferenceContext(std::shared_ptr<InferenceContext>(InferenceContext::Create()));
  }
  // Get infer func and execute
  auto ret = op_desc->CallInferFunc(op);
  if (ret == GRAPH_PARAM_INVALID) {
    GELOGD("NodeUtils::GetNodeType return value is: [%s]", NodeUtils::GetNodeType(*node).c_str());
    auto origin_type = NodeUtils::GetNodeType(*node);
    auto infer_func = ge::OperatorFactoryImpl::GetInferShapeFunc(origin_type);
    if (infer_func == nullptr) {
      REPORT_INNER_ERROR("E19999", "Failed to Get InferFunc. type is %s", origin_type.c_str());
      GELOGE(GRAPH_FAILED, "[Get][InferFunc] failed. type is %s", origin_type.c_str());
      return GRAPH_FAILED;
    }
    op_desc->AddInferFunc(infer_func);
    ret = op_desc->CallInferFunc(op);
    GELOGI("op CallInferFunc second. ret: %u", ret);
  }
  return ret;
}

graphStatus InferShapeForRunning::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr dst, bool &changed) {
  changed = false;
  const auto &dst_dims = dst->GetShape().GetDims();
  const auto &src_dims = src->GetShape().GetDims();
  if (dst_dims == src_dims) {
    changed = true;
  }
  dst = src;
  return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
graphStatus InferShapeForRunning::InferShapeAndTypeForRunning(NodePtr &node, bool before_subgraph) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  InferShapeForRunning pass;
  std::set<NodePtr> unused_changed_nodes;
  return pass.InferAndUpdate(node, before_subgraph, unused_changed_nodes);
}
}  // namespace ge
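
For orientation, here is a minimal sketch of how the public entry point defined above might be driven from a graph-preparation step. It is not code from this repository: the wrapper function InferAllNodeShapes is hypothetical, and only InferShapePass::InferShapeAndType comes from the listing.

// Hypothetical driver: run shape inference over every node of a compute graph.
// Only InferShapePass::InferShapeAndType is taken from the file above; the rest
// is illustrative scaffolding.
#include "framework/common/debug/ge_log.h"
#include "graph/compute_graph.h"
#include "graph/passes/infershape_pass.h"

namespace ge {
graphStatus InferAllNodeShapes(const ComputeGraphPtr &graph) {
  if (graph == nullptr) {
    return GRAPH_FAILED;
  }
  // GetDirectNode() yields NodePtr values; take copies because
  // InferShapeAndType expects a non-const NodePtr reference.
  for (auto node : graph->GetDirectNode()) {
    if (InferShapePass::InferShapeAndType(node) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "Shape inference failed for node %s", node->GetName().c_str());
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
}  // namespace ge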

The Graph Engine (GE) is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware and serves as the bridge between them. GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and outputs a graph that can run efficiently on the underlying hardware. GE is optimized specifically for the hardware architecture of the Ascend AI processor so as to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture is shown in the diagram below.
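
To make the GE API / GE Core split concrete, the fragment below sketches how a client program typically drives GE: initialize, create a Session, add a graph, and run it. This is a hedged sketch, not code from this repository; the helper RunGraphOnce, the empty option maps, and the graph id are placeholders, and the GEInitialize / Session::AddGraph / Session::RunGraph / GEFinalize signatures should be checked against the ge_api headers shipped with your CANN/MindSpore version.

#include <cstdint>
#include <map>
#include <string>
#include <vector>
#include "ge/ge_api.h"  // public GE API header; the install path is an assumption

// Hypothetical helper: graph is a ge::Graph that was built or parsed elsewhere.
ge::Status RunGraphOnce(const ge::Graph &graph, const std::vector<ge::Tensor> &inputs,
                        std::vector<ge::Tensor> &outputs) {
  std::map<std::string, std::string> options;  // global and session options, left empty here
  if (ge::GEInitialize(options) != ge::SUCCESS) {
    return ge::FAILED;
  }
  ge::Session session(options);
  const uint32_t graph_id = 0;  // arbitrary id chosen by the caller
  if (session.AddGraph(graph_id, graph) != ge::SUCCESS) {
    (void)ge::GEFinalize();
    return ge::FAILED;
  }
  ge::Status ret = session.RunGraph(graph_id, inputs, outputs);  // GE Core optimizes and executes the graph
  (void)ge::GEFinalize();
  return ret;
}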