You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

op_desc_utils.cc 28 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/op_desc_utils.h"
  17. #include <algorithm>
  18. #include "debug/ge_attr_define.h"
  19. #include "debug/ge_op_types.h"
  20. #include "debug/ge_util.h"
  21. #include "framework/common/debug/ge_log.h"
  22. #include "graph/anchor.h"
  23. #include "graph/compute_graph.h"
  24. #include "graph/ge_attr_value.h"
  25. #include "utils/graph_utils.h"
  26. #include "utils/node_utils.h"
  27. using std::vector;
  28. /*lint -e512 -e737 -e752*/
  29. namespace ge {
  // Attribute key under which the quantize-factor parameters are stored on an op desc.
  const char OP_DESC_QUANT_PARAMS[] = "quantize_factor";
  // Number of weight tensors accepted when setting weights on a Const op.
  // Declared as size_t because it is only ever compared against vector::size().
  static const size_t CONST_OP_NORMAL_WEIGHT_SIZE = 1;
  32. bool OpDescUtils::ClearInputDesc(const NodePtr &node) {
  33. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
  34. GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr");
  35. vector<int> index_list;
  36. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  37. if (in_anchor->GetPeerOutAnchor() == nullptr) {
  38. index_list.push_back(in_anchor->GetIdx());
  39. }
  40. }
  41. std::sort(index_list.begin(), index_list.end());
  42. // Node's in anchor index need shrink
  43. for (size_t i = 0; i < index_list.size(); ++i) {
  44. auto iter = node->GetOpDesc()->inputs_desc_.begin() + index_list[i];
  45. if (iter < node->GetOpDesc()->inputs_desc_.end()) {
  46. (void)node->GetOpDesc()->inputs_desc_.erase(iter);
  47. } else {
  48. GELOGW("inputs_desc_ iterator out of range.");
  49. }
  50. }
  51. return true;
  52. }
  53. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearInputDesc(OpDescPtr op_desc,
  54. const uint32_t index) {
  55. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr");
  56. GE_CHK_BOOL_EXEC(index < op_desc->inputs_desc_.size(), return false, "index %u is invalid.", index);
  57. auto iter = op_desc->inputs_desc_.begin() + index;
  58. if (iter < op_desc->inputs_desc_.end()) {
  59. (void)op_desc->inputs_desc_.erase(iter);
  60. } else {
  61. GELOGW("inputs_desc_ iterator out of range.");
  62. }
  63. return true;
  64. }
  65. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::HasQuantizeFactorParams(const OpDescPtr &op_desc) {
  66. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return false, "op_desc is nullptr");
  67. return op_desc->HasAttr(OP_DESC_QUANT_PARAMS);
  68. }
  69. bool OpDescUtils::ClearOutputDesc(const NodePtr &node) {
  70. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
  71. GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr");
  72. vector<int> index_list;
  73. for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
  74. if (out_anchor->GetPeerInDataAnchors().empty()) {
  75. index_list.push_back(out_anchor->GetIdx());
  76. }
  77. }
  78. std::sort(index_list.begin(), index_list.end());
  79. // Node's out anchor index need shrink
  80. for (size_t i = 0; i < index_list.size(); ++i) {
  81. auto iter = node->GetOpDesc()->outputs_desc_.begin() + index_list[i];
  82. if (iter < node->GetOpDesc()->outputs_desc_.end()) {
  83. (void)node->GetOpDesc()->outputs_desc_.erase(iter);
  84. } else {
  85. GELOGW("outputs_desc_ iterator out of range.");
  86. }
  87. }
  88. return true;
  89. }
  90. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearOutputDesc(const OpDescPtr &op_desc,
  91. uint32_t index) {
  92. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr");
  93. GE_CHK_BOOL_EXEC(index < op_desc->outputs_desc_.size(), return false, "index %u is invalid.", index);
  94. auto iter = op_desc->outputs_desc_.begin() + index;
  95. if (iter < op_desc->outputs_desc_.end()) {
  96. (void)op_desc->outputs_desc_.erase(iter);
  97. } else {
  98. GELOGW("outputs_desc_ iterator out of range.");
  99. }
  100. return true;
  101. }
  102. bool OpDescUtils::HasQuantizeFactorParams(const OpDesc &op_desc) { return op_desc.HasAttr(OP_DESC_QUANT_PARAMS); }
  103. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  104. OpDescUtils::GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant) {
  105. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
  106. GeAttrValue attr_value;
  107. GE_CHK_BOOL_EXEC_INFO(op_desc->GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED,
  108. "GetQuantizeFactorParams failed");
  109. return attr_value.GetValue<QuantizeFactorParams>(quant);
  110. }
  111. graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant) {
  112. GeAttrValue attr_value;
  113. GE_CHK_BOOL_EXEC_INFO(op_desc.GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED,
  114. "GetQuantizeFactorParams failed");
  115. return attr_value.GetValue<QuantizeFactorParams>(quant);
  116. }
  117. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  118. OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) {
  119. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
  120. return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732
  121. }
  122. graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) {
  123. return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732
  124. }
  125. GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) {
  126. GeTensorPtr weight = nullptr;
  127. if (!AttrUtils::MutableTensor(&op_desc, ATTR_NAME_WEIGHTS, weight)) {
  128. GELOGW("MutableTensor error");
  129. }
  130. return weight;
  131. }
  132. GE_FUNC_HOST_VISIBILITY GeTensorPtr OpDescUtils::MutableWeights(OpDescPtr op_desc) {
  133. if (op_desc == nullptr) {
  134. GELOGE(GRAPH_FAILED, "op_desc is null");
  135. return nullptr;
  136. }
  137. return MutableWeights(*op_desc);
  138. }
  139. graphStatus OpDescUtils::SetWeights(OpDesc &op_desc, const GeTensorPtr weight) {
  140. if (weight == nullptr) {
  141. GELOGE(GRAPH_FAILED, "weight is null");
  142. return GRAPH_FAILED;
  143. }
  144. return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? GRAPH_SUCCESS : GRAPH_FAILED;
  145. }
  146. graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) {
  147. GE_CHECK_NOTNULL(op_desc);
  148. GE_CHECK_NOTNULL(weight);
  149. return SetWeights(*op_desc, weight);
  150. }
  151. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights(const ge::Node &node) {
  152. auto weights = MutableWeights(node);
  153. vector<ConstGeTensorPtr> ret(weights.size());
  154. std::copy(weights.begin(), weights.end(), ret.begin());
  155. return ret;
  156. }
  157. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights(
  158. const ge::ConstNodePtr &node) {
  159. if (node == nullptr) {
  160. return vector<ge::ConstGeTensorPtr>();
  161. }
  162. return GetWeights(*node);
  163. }
  164. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputNode(
  165. const ge::Node &node) {
  166. vector<ge::NodePtr> ret;
  167. auto in_anchors = node.GetAllInDataAnchors();
  168. for (const auto &in_anchor : in_anchors) {
  169. auto out_anchor = in_anchor->GetPeerOutAnchor();
  170. if (out_anchor == nullptr) {
  171. // normally out_anchor could be null, this is ok
  172. GELOGD("node %s' peer_out_anchor is null", node.GetName().c_str());
  173. continue;
  174. }
  175. auto in_node = out_anchor->GetOwnerNode();
  176. while (true) {
  177. if (in_node == nullptr) {
  178. break;
  179. }
  180. if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
  181. ret.push_back(in_node);
  182. break;
  183. } else if (in_node->GetType() == DATA) {
  184. if (NodeUtils::IsWhileVaryingInput(in_node)) {
  185. break;
  186. }
  187. in_node = NodeUtils::GetParentInput(in_node);
  188. } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) {
  189. bool is_constant = false;
  190. (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant);
  191. if (!is_constant) {
  192. break;
  193. }
  194. // Enter node has and only has one input
  195. if (in_node->GetInDataNodes().size() != 1) {
  196. GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(),
  197. in_node->GetInDataNodes().size());
  198. break;
  199. }
  200. in_node = in_node->GetInDataNodes().at(0);
  201. } else {
  202. break;
  203. }
  204. }
  205. }
  206. return ret;
  207. }
  208. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData(
  209. const vector<ge::NodePtr> &input_nodes) {
  210. vector<ConstGeTensorPtr> ret;
  211. for (const auto &input_node : input_nodes) {
  212. auto temp_weight = MutableWeights(input_node->GetOpDesc());
  213. if (temp_weight == nullptr) {
  214. GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str());
  215. return vector<ConstGeTensorPtr>();
  216. }
  217. ret.push_back(temp_weight);
  218. }
  219. return ret;
  220. }
  221. size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) {
  222. if (NodeUtils::IsAnchorStatusSet(node)) {
  223. size_t input_num = 0;
  224. for (const auto &anchor : node.GetAllInDataAnchors()) {
  225. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  226. input_num++;
  227. continue;
  228. }
  229. }
  230. return input_num; // lint !e712
  231. } else {
  232. GE_IF_BOOL_EXEC(
  233. node.GetInDataNodes().size() < GetConstInputs(node).size(),
  234. GELOGE(GRAPH_FAILED, "%zu is smaller than %zu", node.GetInDataNodes().size(), GetConstInputs(node).size());
  235. return 0);
  236. return node.GetInDataNodes().size() - GetConstInputs(node).size();
  237. }
  238. }
  239. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDescUtils::GetNonConstInputsSize(const ge::ConstNodePtr node) {
  240. if (node == nullptr) {
  241. GELOGE(GRAPH_FAILED, "Node is nullptr");
  242. return 0;
  243. }
  244. return GetNonConstInputsSize(*node);
  245. }
  246. GeTensorDesc OpDescUtils::GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const) {
  247. GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GeTensorDesc(), "node.GetOpDesc() is nullptr!");
  248. size_t i = 0;
  249. if (NodeUtils::IsAnchorStatusSet(node)) {
  250. for (const auto &anchor : node.GetAllInDataAnchors()) {
  251. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  252. if (index_non_const == i) {
  253. return node.GetOpDesc()->GetInputDesc(static_cast<uint32_t>(anchor->GetIdx()));
  254. }
  255. ++i;
  256. }
  257. }
  258. } else {
  259. for (const auto &anchor : node.GetAllInDataAnchors()) {
  260. auto peer_anchor = anchor->GetPeerOutAnchor();
  261. if (peer_anchor == nullptr) {
  262. continue;
  263. }
  264. auto owner_node = peer_anchor->GetOwnerNode();
  265. if (owner_node == nullptr) {
  266. continue;
  267. }
  268. if (owner_node->GetType() == CONSTANT) {
  269. continue;
  270. }
  271. if (index_non_const == i) {
  272. return node.GetOpDesc()->GetInputDesc(anchor->GetIdx());
  273. }
  274. ++i;
  275. }
  276. }
  277. return GeTensorDesc();
  278. }
  279. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc
  280. OpDescUtils::GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const) {
  281. CHECK_FALSE_EXEC(node != nullptr, return GeTensorDesc());
  282. return GetNonConstInputTensorDesc(*node, index_non_const);
  283. }
  284. bool OpDescUtils::GetNonConstInputIndex(const ge::Node &node, const size_t index_non_const, size_t &index) {
  285. bool ret = false;
  286. size_t i = 0;
  287. if (NodeUtils::IsAnchorStatusSet(node)) {
  288. for (const auto &anchor : node.GetAllInDataAnchors()) {
  289. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  290. if (index_non_const == i) {
  291. index = static_cast<size_t>(anchor->GetIdx());
  292. ret = true;
  293. }
  294. ++i;
  295. }
  296. }
  297. } else {
  298. for (const auto &anchor : node.GetAllInDataAnchors()) {
  299. auto peer_anchor = anchor->GetPeerOutAnchor();
  300. if (peer_anchor == nullptr) {
  301. continue;
  302. }
  303. auto owner_node = peer_anchor->GetOwnerNode();
  304. if (owner_node == nullptr) {
  305. continue;
  306. }
  307. if (owner_node->GetType() == CONSTANT) {
  308. continue;
  309. }
  310. if (index_non_const == i) {
  311. index = static_cast<size_t>(anchor->GetIdx());
  312. ret = true;
  313. }
  314. ++i;
  315. }
  316. }
  317. return ret;
  318. }
  319. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::GetNonConstInputIndex(const ge::ConstNodePtr &node,
  320. size_t index_non_const,
  321. size_t &index) {
  322. CHECK_FALSE_EXEC(node != nullptr, return false);
  323. return GetNonConstInputIndex(*node, index_non_const, index);
  324. }
  325. bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) {
  326. bool ret = false;
  327. if (index < node.GetAllInDataAnchors().size()) {
  328. if (NodeUtils::IsAnchorStatusSet(node)) {
  329. ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA); // lint !e712
  330. } else {
  331. for (const auto &anchor : node.GetAllInDataAnchors()) {
  332. if (anchor->GetIdx() != static_cast<int>(index)) {
  333. continue;
  334. }
  335. auto peer_anchor = anchor->GetPeerOutAnchor();
  336. if (peer_anchor == nullptr) {
  337. break;
  338. }
  339. auto owner_node = peer_anchor->GetOwnerNode();
  340. if (owner_node == nullptr) {
  341. break;
  342. }
  343. ret = (owner_node->GetType() != CONSTANT);
  344. }
  345. }
  346. }
  347. return ret;
  348. }
  349. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::IsNonConstInput(const ge::ConstNodePtr &node,
  350. size_t index) {
  351. CHECK_FALSE_EXEC(node != nullptr, return false);
  352. return IsNonConstInput(*node, index);
  353. }
  354. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs(
  355. const ge::ConstNodePtr &node) {
  356. if (node == nullptr) {
  357. return vector<ge::NodePtr>();
  358. }
  359. return GetConstInputs(*node);
  360. }
  361. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUtils::GetNonConstTensorDesc(
  362. const ge::ConstNodePtr &node) {
  363. if (node == nullptr || node->GetOpDesc() == nullptr) {
  364. return vector<ge::GeTensorDesc>();
  365. }
  366. vector<ge::GeTensorDesc> ret;
  367. if (NodeUtils::IsAnchorStatusSet(*node)) {
  368. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  369. if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) {
  370. ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
  371. }
  372. }
  373. } else {
  374. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  375. auto out_anchor = in_anchor->GetPeerOutAnchor();
  376. if (out_anchor == nullptr || out_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  377. continue;
  378. }
  379. if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) {
  380. ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
  381. }
  382. }
  383. }
  384. return ret;
  385. }
  386. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs(const ge::Node &node) {
  387. vector<ge::NodePtr> ret;
  388. auto in_anchors = node.GetAllInDataAnchors();
  389. for (const auto &in_anchor : in_anchors) {
  390. auto out_anchor = in_anchor->GetPeerOutAnchor();
  391. if (out_anchor == nullptr) continue;
  392. auto in_node = out_anchor->GetOwnerNode();
  393. if (in_node->GetType() == CONSTANT) {
  394. ret.push_back(in_node);
  395. } else if (in_node->GetType() == SWITCH && node.GetType() == MATMUL) {
  396. // const --> switch --> matmul
  397. auto switch_input = GetConstInputs(*in_node);
  398. if (switch_input.size() > 0) {
  399. ret.insert(ret.end(), switch_input.begin(), switch_input.end());
  400. }
  401. } else if (in_node->GetType() == DATA) {
  402. auto parent = NodeUtils::GetParentInput(in_node);
  403. if ((parent != nullptr) && (parent->GetType() == CONSTANT)) {
  404. ret.push_back(parent);
  405. }
  406. }
  407. }
  408. return ret;
  409. }
  410. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::Node &node) {
  411. vector<GeTensorPtr> ret;
  412. auto op_desc = node.GetOpDesc();
  413. GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!");
  414. // Place holder operator, try to get the weight from parent node
  415. // when parent node is const operator
  416. if (node.GetType() == PLACEHOLDER) {
  417. std::string parent_op;
  418. (void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op);
  419. // This if judgment is necessary because the current subgraph optimization is multithreaded
  420. // and the parent node of the PLD operation should be a stable type, such as const
  421. if (parent_op == CONSTANT || parent_op == CONSTANTOP) {
  422. NodePtr parent_node = nullptr;
  423. parent_node = op_desc->TryGetExtAttr("parentNode", parent_node);
  424. if (parent_node != nullptr) {
  425. op_desc = parent_node->GetOpDesc();
  426. GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str());
  427. }
  428. }
  429. }
  430. // Const operator, take the weight directly
  431. if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) {
  432. auto weight = MutableWeights(op_desc);
  433. if (weight == nullptr) {
  434. GELOGI("const op has no weight, op name:%s", node.GetName().c_str());
  435. return ret;
  436. }
  437. ret.push_back(weight);
  438. return ret;
  439. }
  440. if (node.GetType() == DATA) {
  441. auto parent = NodeUtils::GetParentInput(node);
  442. if ((parent != nullptr) && NodeUtils::IsConst(*parent)) {
  443. auto weight = MutableWeights(parent->GetOpDesc());
  444. if (weight == nullptr) {
  445. GELOGI("const op has no weight, op name:%s", parent->GetName().c_str());
  446. return ret;
  447. }
  448. ret.push_back(weight);
  449. }
  450. return ret;
  451. }
  452. // Other operators, get weights from connected constop
  453. auto input_nodes = GetConstInputs(node);
  454. for (const auto &input_node : input_nodes) {
  455. auto temp_weight = MutableWeights(input_node->GetOpDesc());
  456. if (temp_weight == nullptr) {
  457. GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str());
  458. return vector<GeTensorPtr>();
  459. }
  460. ret.push_back(temp_weight);
  461. }
  462. return ret;
  463. }
  464. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::NodePtr node) {
  465. if (node == nullptr) {
  466. GELOGE(GRAPH_FAILED, "Node is nullptr");
  467. return vector<ge::GeTensorPtr>();
  468. }
  469. return MutableWeights(*node);
  470. }
  471. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  472. OpDescUtils::SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights) {
  473. GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GRAPH_PARAM_INVALID, "node.GetOpDesc is nullptr!");
  474. if (node.GetOpDesc()->GetType() == CONSTANT) {
  475. if (weights.size() == CONST_OP_NORMAL_WEIGHT_SIZE) {
  476. return SetWeights(node.GetOpDesc(), weights[0]);
  477. }
  478. GELOGI("const op weight size %zu should be 1", weights.size());
  479. return GRAPH_PARAM_INVALID;
  480. }
  481. auto input_nodes = GetConstInputs(node);
  482. if (weights.size() < input_nodes.size()) {
  483. GELOGE(GRAPH_FAILED, "weights count can't be less than const input count");
  484. return GRAPH_PARAM_INVALID;
  485. }
  486. ge::GeAttrValue::NAMED_ATTRS named_attrs;
  487. (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights);
  488. vector<ge::GeTensorPtr> copy_weights;
  489. (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights);
  490. for (size_t i = 0; i < input_nodes.size(); ++i) {
  491. if (input_nodes[i]->GetOpDesc() != nullptr) {
  492. SetWeights(input_nodes[i]->GetOpDesc(), copy_weights[i]);
  493. }
  494. }
  495. // If set more weights than constop, need to add constop
  496. for (size_t i = input_nodes.size(); i < copy_weights.size(); ++i) {
  497. // Use org weight before SetWeights Overwrite
  498. auto const_opdesc = CreateConstOp(copy_weights[i]);
  499. GE_CHECK_NOTNULL(const_opdesc);
  500. auto owner_graph = node.GetOwnerComputeGraph();
  501. if (owner_graph == nullptr) {
  502. GELOGE(GRAPH_FAILED, "node's graph is empty, name: %s", node.GetName().c_str());
  503. return GRAPH_PARAM_INVALID;
  504. }
  505. auto const_node = owner_graph->AddNodeFront(const_opdesc);
  506. GE_CHK_BOOL_EXEC(node.AddLinkFrom(const_node) == GRAPH_SUCCESS, return GRAPH_FAILED, "graph add link failed!");
  507. std::vector<ge::NodePtr> original_nodes;
  508. ge::GraphUtils::RecordOriginalNames(original_nodes, const_node);
  509. }
  510. return GRAPH_SUCCESS;
  511. }
  512. OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) {
  513. GE_CHK_BOOL_EXEC(tensor_ptr != nullptr, return nullptr, "tensor_ptr is nullptr!");
  514. shared_ptr<OpDesc> const_opdesc = ComGraphMakeShared<OpDesc>();
  515. if (const_opdesc == nullptr) {
  516. GELOGE(GRAPH_FAILED, "failed to make_shared ");
  517. return nullptr;
  518. }
  519. CHECK_FALSE_EXEC(SetWeights(const_opdesc, tensor_ptr) == ge::GRAPH_SUCCESS, return nullptr);
  520. const_opdesc->SetType(CONSTANT);
  521. thread_local int64_t const_count = 0;
  522. const_opdesc->SetName("dynamic_const_" + std::to_string(GetTid()) + "_" + std::to_string(const_count));
  523. GELOGI("add const op: %s", const_opdesc->GetName().c_str());
  524. ++const_count;
  525. (void)const_opdesc->AddOutputDesc(tensor_ptr->GetTensorDesc());
  526. GELOGI("after add const op: %s", const_opdesc->GetName().c_str());
  527. return const_opdesc;
  528. }
  529. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  530. OpDescUtils::AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr) {
  531. GE_CHECK_NOTNULL(in_anchor);
  532. GE_CHECK_NOTNULL(tensor_ptr);
  533. auto const_opdesc = CreateConstOp(tensor_ptr);
  534. GE_CHECK_NOTNULL(const_opdesc);
  535. auto in_node = in_anchor->GetOwnerNode();
  536. GE_CHECK_NOTNULL(in_node);
  537. auto owner_graph = in_node->GetOwnerComputeGraph();
  538. if (owner_graph == nullptr) {
  539. GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", in_node->GetName().c_str());
  540. return GRAPH_PARAM_INVALID;
  541. }
  542. auto const_node = in_node->GetOwnerComputeGraph()->AddNodeFront(const_opdesc);
  543. GE_CHECK_NOTNULL(const_node);
  544. if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) {
  545. GELOGE(GRAPH_PARAM_INVALID, "Addedge const to node failed.");
  546. return GRAPH_PARAM_INVALID;
  547. }
  548. return GRAPH_SUCCESS;
  549. }
  550. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  551. OpDescUtils::SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr> &weights) {
  552. GE_CHECK_NOTNULL(node);
  553. return SetWeights(*node, weights);
  554. }
  555. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWeights(const ge::NodePtr node) {
  556. GE_CHECK_NOTNULL(node);
  557. auto const_ops = GetConstInputs(node);
  558. auto graph = node->GetOwnerComputeGraph();
  559. if (graph == nullptr) {
  560. GELOGE(GRAPH_FAILED, "Graph is nullptr");
  561. return GRAPH_PARAM_INVALID;
  562. }
  563. for (const auto &const_op : const_ops) {
  564. GE_CHK_STATUS_RET(GraphUtils::IsolateNode(const_op, {}), "Isolate removed node: %s, type: %s failed",
  565. const_op->GetName().c_str(), const_op->GetType().c_str());
  566. GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, const_op),
  567. "Remove node: %s, type: %s without relink failed", const_op->GetName().c_str(),
  568. const_op->GetType().c_str());
  569. }
  570. return GRAPH_SUCCESS;
  571. }
  572. ///
  573. /// @brief Add input
  574. /// @param [in] name
  575. /// @return OpDescBuilder
  576. ///
  577. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) {
  578. inputs_.emplace_back(std::make_pair(name, GeTensorDesc()));
  579. return *this;
  580. }
  581. ///
  582. /// @brief Add input
  583. /// @param [in] name
  584. /// @param [in] tensor
  585. /// @return OpDescBuilder
  586. ///
  587. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name,
  588. const GeTensorDesc &tensor) {
  589. inputs_.emplace_back(std::make_pair(name, tensor));
  590. return *this;
  591. }
  592. ///
  593. /// @brief Add dynamic input
  594. /// @param [in] name
  595. /// @param [in] num
  596. /// @return OpDescBuilder
  597. ///
  598. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name,
  599. uint32_t num) {
  600. for (uint32_t i = 0; i < num; i++) {
  601. inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc()));
  602. }
  603. return *this;
  604. }
  605. ///
  606. /// @brief Add dynamic input
  607. /// @param [in] name
  608. /// @param [in] num
  609. /// @param [in] tensor
  610. /// @return OpDescBuilder
  611. ///
  612. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(
  613. const std::string &name, uint32_t num, const GeTensorDesc &tensor) {
  614. for (uint32_t i = 0; i < num; i++) {
  615. inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor));
  616. }
  617. return *this;
  618. }
  619. ///
  620. /// @brief Add output
  621. /// @param [in] name
  622. /// @return OpDescBuilder
  623. ///
  624. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) {
  625. outputs_.emplace_back(std::make_pair(name, GeTensorDesc()));
  626. return *this;
  627. }
  628. ///
  629. /// @brief Add output
  630. /// @param [in] name
  631. /// @param [in] tensor
  632. /// @return OpDescBuilder
  633. ///
  634. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name,
  635. const GeTensorDesc &tensor) {
  636. outputs_.emplace_back(std::make_pair(name, tensor));
  637. return *this;
  638. }
  639. ///
  640. /// @brief Add dynamic output
  641. /// @param [in] name
  642. /// @param [in] num
  643. /// @return OpDescBuilder
  644. ///
  645. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name,
  646. uint32_t num) {
  647. for (uint32_t i = 0; i < num; i++) {
  648. outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc()));
  649. }
  650. return *this;
  651. }
  652. ///
  653. /// @brief Add dynamic output
  654. /// @param [in] name
  655. /// @param [in] num
  656. /// @param [in] tensor
  657. /// @return OpDescBuilder
  658. ///
  659. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(
  660. const std::string &name, uint32_t num, const GeTensorDesc &tensor) {
  661. for (uint32_t i = 0; i < num; i++) {
  662. outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor));
  663. }
  664. return *this;
  665. }
  666. ///
  667. /// @brief Build op_desc
  668. /// @return OpDescPtr
  669. ///
  670. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() {
  671. OpDescPtr op_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name_, type_));
  672. if (op_desc == nullptr) {
  673. GELOGE(GRAPH_FAILED, "OpDesc is nullptr");
  674. return nullptr;
  675. }
  676. for (auto &input : inputs_) {
  677. if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) {
  678. GELOGE(GRAPH_FAILED, "Add input_desc failed.");
  679. return nullptr;
  680. }
  681. }
  682. for (auto &output : outputs_) {
  683. if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) {
  684. GELOGE(GRAPH_FAILED, "Add output_desc failed.");
  685. return nullptr;
  686. }
  687. }
  688. return op_desc;
  689. }
  690. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName(
  691. const std::string &subgraph_name, const std::string &subgraph_instance_name, OpDescPtr &op_desc) {
  692. const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes();
  693. auto iter = subgraph_names_to_index.find(subgraph_name);
  694. if (iter == subgraph_names_to_index.end()) {
  695. GELOGE(GRAPH_PARAM_INVALID,
  696. "Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists",
  697. subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(),
  698. subgraph_name.c_str());
  699. return GRAPH_PARAM_INVALID;
  700. }
  701. return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name);
  702. }
  703. } // namespace ge
  704. /*lint +e512 +e737 +e752*/

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示