You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

op_desc_utils.cc 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/op_desc_utils.h"
  17. #include <algorithm>
  18. #include "debug/ge_attr_define.h"
  19. #include "debug/ge_op_types.h"
  20. #include "debug/ge_util.h"
  21. #include "framework/common/debug/ge_log.h"
  22. #include "common/util/error_manager/error_manager.h"
  23. #include "graph/anchor.h"
  24. #include "graph/compute_graph.h"
  25. #include "graph/ge_attr_value.h"
  26. #include "utils/graph_utils.h"
  27. #include "utils/node_utils.h"
  28. using std::vector;
  29. /*lint -e512 -e737 -e752*/
  30. namespace ge {
  31. const char OP_DESC_QUANT_PARAMS[] = "quantize_factor";
  32. namespace {
  33. const int CONST_OP_NORMAL_WEIGHT_SIZE = 1;
  34. }
  35. bool OpDescUtils::ClearInputDesc(const NodePtr &node) {
  36. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
  37. GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr");
  38. vector<int> index_list;
  39. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  40. if (in_anchor->GetPeerOutAnchor() == nullptr) {
  41. index_list.push_back(in_anchor->GetIdx());
  42. }
  43. }
  44. std::sort(index_list.begin(), index_list.end());
  45. // Node's in anchor index need shrink
  46. for (size_t i = 0; i < index_list.size(); ++i) {
  47. auto iter = node->GetOpDesc()->inputs_desc_.begin() + index_list[i];
  48. if (iter < node->GetOpDesc()->inputs_desc_.end()) {
  49. (void)node->GetOpDesc()->inputs_desc_.erase(iter);
  50. } else {
  51. GELOGW("inputs_desc_ iterator out of range.");
  52. }
  53. }
  54. return true;
  55. }
  56. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearInputDesc(OpDescPtr op_desc,
  57. const uint32_t index) {
  58. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr");
  59. GE_CHK_BOOL_EXEC(index < op_desc->inputs_desc_.size(), return false, "index %u is invalid.", index);
  60. auto iter = op_desc->inputs_desc_.begin() + index;
  61. if (iter < op_desc->inputs_desc_.end()) {
  62. (void)op_desc->inputs_desc_.erase(iter);
  63. } else {
  64. GELOGW("inputs_desc_ iterator out of range.");
  65. }
  66. return true;
  67. }
  68. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::HasQuantizeFactorParams(const OpDescPtr &op_desc) {
  69. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return false, "op_desc is nullptr");
  70. return op_desc->HasAttr(OP_DESC_QUANT_PARAMS);
  71. }
  72. bool OpDescUtils::ClearOutputDesc(const NodePtr &node) {
  73. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
  74. GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr");
  75. vector<int> index_list;
  76. for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
  77. if (out_anchor->GetPeerInDataAnchors().empty()) {
  78. index_list.push_back(out_anchor->GetIdx());
  79. }
  80. }
  81. std::sort(index_list.begin(), index_list.end());
  82. // Node's out anchor index need shrink
  83. for (size_t i = 0; i < index_list.size(); ++i) {
  84. auto iter = node->GetOpDesc()->outputs_desc_.begin() + index_list[i];
  85. if (iter < node->GetOpDesc()->outputs_desc_.end()) {
  86. (void)node->GetOpDesc()->outputs_desc_.erase(iter);
  87. } else {
  88. GELOGW("outputs_desc_ iterator out of range.");
  89. }
  90. }
  91. return true;
  92. }
  93. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearOutputDesc(const OpDescPtr &op_desc,
  94. uint32_t index) {
  95. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr");
  96. GE_CHK_BOOL_EXEC(index < op_desc->outputs_desc_.size(), return false, "index %u is invalid.", index);
  97. auto iter = op_desc->outputs_desc_.begin() + index;
  98. if (iter < op_desc->outputs_desc_.end()) {
  99. (void)op_desc->outputs_desc_.erase(iter);
  100. } else {
  101. GELOGW("outputs_desc_ iterator out of range.");
  102. }
  103. return true;
  104. }
  105. bool OpDescUtils::HasQuantizeFactorParams(const OpDesc &op_desc) { return op_desc.HasAttr(OP_DESC_QUANT_PARAMS); }
  106. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  107. OpDescUtils::GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant) {
  108. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
  109. GeAttrValue attr_value;
  110. GE_CHK_BOOL_EXEC_INFO(op_desc->GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED,
  111. "GetQuantizeFactorParams failed");
  112. return attr_value.GetValue<QuantizeFactorParams>(quant);
  113. }
  114. graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant) {
  115. GeAttrValue attr_value;
  116. GE_CHK_BOOL_EXEC_INFO(op_desc.GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED,
  117. "GetQuantizeFactorParams failed");
  118. return attr_value.GetValue<QuantizeFactorParams>(quant);
  119. }
  120. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  121. OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) {
  122. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
  123. return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732
  124. }
  125. graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) {
  126. return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732
  127. }
  128. GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) {
  129. GeTensorPtr weight = nullptr;
  130. if (!AttrUtils::MutableTensor(&op_desc, ATTR_NAME_WEIGHTS, weight)) {
  131. GELOGW("MutableTensor error");
  132. }
  133. return weight;
  134. }
  135. GE_FUNC_HOST_VISIBILITY GeTensorPtr OpDescUtils::MutableWeights(OpDescPtr op_desc) {
  136. if (op_desc == nullptr) {
  137. GELOGE(GRAPH_FAILED, "op_desc is null");
  138. return nullptr;
  139. }
  140. return MutableWeights(*op_desc);
  141. }
  142. graphStatus OpDescUtils::SetWeights(OpDesc &op_desc, const GeTensorPtr weight) {
  143. if (weight == nullptr) {
  144. GELOGE(GRAPH_FAILED, "weight is null");
  145. return GRAPH_FAILED;
  146. }
  147. return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? GRAPH_SUCCESS : GRAPH_FAILED;
  148. }
  149. graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) {
  150. GE_CHECK_NOTNULL(op_desc);
  151. GE_CHECK_NOTNULL(weight);
  152. return SetWeights(*op_desc, weight);
  153. }
  154. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights(const ge::Node &node) {
  155. auto weights = MutableWeights(node);
  156. vector<ConstGeTensorPtr> ret(weights.size());
  157. std::copy(weights.begin(), weights.end(), ret.begin());
  158. return ret;
  159. }
  160. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights(
  161. const ge::ConstNodePtr &node) {
  162. if (node == nullptr) {
  163. return vector<ge::ConstGeTensorPtr>();
  164. }
  165. return GetWeights(*node);
  166. }
  167. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputNode(
  168. const ge::Node &node) {
  169. vector<ge::NodePtr> ret;
  170. auto in_anchors = node.GetAllInDataAnchors();
  171. for (const auto &in_anchor : in_anchors) {
  172. auto out_anchor = in_anchor->GetPeerOutAnchor();
  173. if (out_anchor == nullptr) {
  174. // normally out_anchor could be null, this is ok
  175. GELOGD("node %s' peer_out_anchor is null", node.GetName().c_str());
  176. continue;
  177. }
  178. auto in_node = out_anchor->GetOwnerNode();
  179. while (true) {
  180. if (in_node == nullptr) {
  181. break;
  182. }
  183. if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
  184. ret.push_back(in_node);
  185. break;
  186. } else if (in_node->GetType() == DATA) {
  187. if (NodeUtils::IsWhileVaryingInput(in_node)) {
  188. break;
  189. }
  190. in_node = NodeUtils::GetParentInput(in_node);
  191. } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) {
  192. bool is_constant = false;
  193. (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant);
  194. if (!is_constant) {
  195. break;
  196. }
  197. // Enter node has and only has one input
  198. if (in_node->GetInDataNodes().size() != 1) {
  199. GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(),
  200. in_node->GetInDataNodes().size());
  201. break;
  202. }
  203. in_node = in_node->GetInDataNodes().at(0);
  204. } else {
  205. break;
  206. }
  207. }
  208. }
  209. return ret;
  210. }
  211. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData(
  212. const vector<ge::NodePtr> &input_nodes) {
  213. vector<ConstGeTensorPtr> ret;
  214. for (const auto &input_node : input_nodes) {
  215. auto temp_weight = MutableWeights(input_node->GetOpDesc());
  216. if (temp_weight == nullptr) {
  217. GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str());
  218. return vector<ConstGeTensorPtr>();
  219. }
  220. ret.push_back(temp_weight);
  221. }
  222. return ret;
  223. }
  224. size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) {
  225. if (NodeUtils::IsAnchorStatusSet(node)) {
  226. size_t input_num = 0;
  227. for (const auto &anchor : node.GetAllInDataAnchors()) {
  228. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  229. input_num++;
  230. continue;
  231. }
  232. }
  233. return input_num; // lint !e712
  234. } else {
  235. GE_IF_BOOL_EXEC(
  236. node.GetInDataNodes().size() < GetConstInputs(node).size(),
  237. ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
  238. {"GetNonConstInputsSize", "InDataNodes size[" + std::to_string(node.GetInDataNodes().size()) +
  239. "] is smaller than ConstInputs[" + std::to_string(GetConstInputs(node).size()) + "]"});
  240. GELOGE(GRAPH_FAILED, "%zu is smaller than %zu", node.GetInDataNodes().size(), GetConstInputs(node).size());
  241. return 0);
  242. return node.GetInDataNodes().size() - GetConstInputs(node).size();
  243. }
  244. }
  245. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDescUtils::GetNonConstInputsSize(const ge::ConstNodePtr node) {
  246. if (node == nullptr) {
  247. GELOGE(GRAPH_FAILED, "Node is nullptr");
  248. return 0;
  249. }
  250. return GetNonConstInputsSize(*node);
  251. }
  252. GeTensorDesc OpDescUtils::GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const) {
  253. GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GeTensorDesc(), "node.GetOpDesc() is nullptr!");
  254. size_t i = 0;
  255. if (NodeUtils::IsAnchorStatusSet(node)) {
  256. for (const auto &anchor : node.GetAllInDataAnchors()) {
  257. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  258. if (index_non_const == i) {
  259. return node.GetOpDesc()->GetInputDesc(static_cast<uint32_t>(anchor->GetIdx()));
  260. }
  261. ++i;
  262. }
  263. }
  264. } else {
  265. for (const auto &anchor : node.GetAllInDataAnchors()) {
  266. auto peer_anchor = anchor->GetPeerOutAnchor();
  267. if (peer_anchor == nullptr) {
  268. continue;
  269. }
  270. auto owner_node = peer_anchor->GetOwnerNode();
  271. if (owner_node == nullptr) {
  272. continue;
  273. }
  274. if (owner_node->GetType() == CONSTANT) {
  275. continue;
  276. }
  277. if (index_non_const == i) {
  278. return node.GetOpDesc()->GetInputDesc(anchor->GetIdx());
  279. }
  280. ++i;
  281. }
  282. }
  283. return GeTensorDesc();
  284. }
  285. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc
  286. OpDescUtils::GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const) {
  287. CHECK_FALSE_EXEC(node != nullptr, return GeTensorDesc());
  288. return GetNonConstInputTensorDesc(*node, index_non_const);
  289. }
  290. bool OpDescUtils::GetNonConstInputIndex(const ge::Node &node, const size_t index_non_const, size_t &index) {
  291. bool ret = false;
  292. size_t i = 0;
  293. if (NodeUtils::IsAnchorStatusSet(node)) {
  294. for (const auto &anchor : node.GetAllInDataAnchors()) {
  295. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  296. if (index_non_const == i) {
  297. index = static_cast<size_t>(anchor->GetIdx());
  298. ret = true;
  299. }
  300. ++i;
  301. }
  302. }
  303. } else {
  304. for (const auto &anchor : node.GetAllInDataAnchors()) {
  305. auto peer_anchor = anchor->GetPeerOutAnchor();
  306. if (peer_anchor == nullptr) {
  307. continue;
  308. }
  309. auto owner_node = peer_anchor->GetOwnerNode();
  310. if (owner_node == nullptr) {
  311. continue;
  312. }
  313. if (owner_node->GetType() == CONSTANT) {
  314. continue;
  315. }
  316. if (index_non_const == i) {
  317. index = static_cast<size_t>(anchor->GetIdx());
  318. ret = true;
  319. }
  320. ++i;
  321. }
  322. }
  323. return ret;
  324. }
  325. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::GetNonConstInputIndex(const ge::ConstNodePtr &node,
  326. size_t index_non_const,
  327. size_t &index) {
  328. CHECK_FALSE_EXEC(node != nullptr, return false);
  329. return GetNonConstInputIndex(*node, index_non_const, index);
  330. }
  331. bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) {
  332. bool ret = false;
  333. if (index < node.GetAllInDataAnchors().size()) {
  334. if (NodeUtils::IsAnchorStatusSet(node)) {
  335. ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA); // lint !e712
  336. } else {
  337. for (const auto &anchor : node.GetAllInDataAnchors()) {
  338. if (anchor->GetIdx() != static_cast<int>(index)) {
  339. continue;
  340. }
  341. auto peer_anchor = anchor->GetPeerOutAnchor();
  342. if (peer_anchor == nullptr) {
  343. break;
  344. }
  345. auto owner_node = peer_anchor->GetOwnerNode();
  346. if (owner_node == nullptr) {
  347. break;
  348. }
  349. ret = (owner_node->GetType() != CONSTANT);
  350. }
  351. }
  352. }
  353. return ret;
  354. }
  355. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::IsNonConstInput(const ge::ConstNodePtr &node,
  356. size_t index) {
  357. CHECK_FALSE_EXEC(node != nullptr, return false);
  358. return IsNonConstInput(*node, index);
  359. }
  360. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs(
  361. const ge::ConstNodePtr &node) {
  362. if (node == nullptr) {
  363. return vector<ge::NodePtr>();
  364. }
  365. return GetConstInputs(*node);
  366. }
  367. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUtils::GetNonConstTensorDesc(
  368. const ge::ConstNodePtr &node) {
  369. if (node == nullptr || node->GetOpDesc() == nullptr) {
  370. return vector<ge::GeTensorDesc>();
  371. }
  372. vector<ge::GeTensorDesc> ret;
  373. if (NodeUtils::IsAnchorStatusSet(*node)) {
  374. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  375. if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) {
  376. ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
  377. }
  378. }
  379. } else {
  380. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  381. auto out_anchor = in_anchor->GetPeerOutAnchor();
  382. if (out_anchor == nullptr || out_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  383. continue;
  384. }
  385. if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) {
  386. ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
  387. }
  388. }
  389. }
  390. return ret;
  391. }
  392. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs(const ge::Node &node) {
  393. vector<ge::NodePtr> ret;
  394. auto in_anchors = node.GetAllInDataAnchors();
  395. for (const auto &in_anchor : in_anchors) {
  396. auto out_anchor = in_anchor->GetPeerOutAnchor();
  397. if (out_anchor == nullptr) continue;
  398. auto in_node = out_anchor->GetOwnerNode();
  399. if (in_node->GetType() == CONSTANT) {
  400. ret.push_back(in_node);
  401. } else if (in_node->GetType() == SWITCH && node.GetType() == MATMUL) {
  402. // const --> switch --> matmul
  403. auto switch_input = GetConstInputs(*in_node);
  404. if (switch_input.size() > 0) {
  405. ret.insert(ret.end(), switch_input.begin(), switch_input.end());
  406. }
  407. } else if (in_node->GetType() == DATA) {
  408. auto parent = NodeUtils::GetParentInput(in_node);
  409. if ((parent != nullptr) && (parent->GetType() == CONSTANT)) {
  410. ret.push_back(parent);
  411. }
  412. }
  413. }
  414. return ret;
  415. }
  416. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::Node &node) {
  417. vector<GeTensorPtr> ret;
  418. auto op_desc = node.GetOpDesc();
  419. GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!");
  420. // Place holder operator, try to get the weight from parent node
  421. // when parent node is const operator
  422. if (node.GetType() == PLACEHOLDER) {
  423. std::string parent_op;
  424. (void) AttrUtils::GetStr(op_desc, "parentOpType", parent_op);
  425. // This if judgment is necessary because the current subgraph optimization is multithreaded
  426. // and the parent node of the PLD operation should be a stable type, such as const
  427. if (parent_op == CONSTANT || parent_op == CONSTANTOP) {
  428. NodePtr parent_node = nullptr;
  429. parent_node = op_desc->TryGetExtAttr("parentNode", parent_node);
  430. if (parent_node != nullptr) {
  431. op_desc = parent_node->GetOpDesc();
  432. GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str());
  433. }
  434. }
  435. }
  436. // Const operator, take the weight directly
  437. if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) {
  438. auto weight = MutableWeights(op_desc);
  439. if (weight == nullptr) {
  440. GELOGI("const op has no weight, op name:%s", node.GetName().c_str());
  441. return ret;
  442. }
  443. ret.push_back(weight);
  444. return ret;
  445. }
  446. if (node.GetType() == DATA) {
  447. auto parent = NodeUtils::GetParentInput(node);
  448. if ((parent != nullptr) && NodeUtils::IsConst(*parent)) {
  449. auto weight = MutableWeights(parent->GetOpDesc());
  450. if (weight == nullptr) {
  451. GELOGI("const op has no weight, op name:%s", parent->GetName().c_str());
  452. return ret;
  453. }
  454. ret.push_back(weight);
  455. }
  456. return ret;
  457. }
  458. // Other operators, get weights from connected constop
  459. auto input_nodes = GetConstInputs(node);
  460. for (const auto &input_node : input_nodes) {
  461. auto temp_weight = MutableWeights(input_node->GetOpDesc());
  462. if (temp_weight == nullptr) {
  463. ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
  464. {"MutableWeights", "const op[" + input_node->GetName() + "]'s weight is null"});
  465. GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str());
  466. return vector<GeTensorPtr>();
  467. }
  468. ret.push_back(temp_weight);
  469. }
  470. return ret;
  471. }
  472. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::NodePtr node) {
  473. if (node == nullptr) {
  474. GELOGE(GRAPH_FAILED, "Node is nullptr");
  475. return vector<ge::GeTensorPtr>();
  476. }
  477. return MutableWeights(*node);
  478. }
  479. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  480. OpDescUtils::SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights) {
  481. GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GRAPH_PARAM_INVALID, "node.GetOpDesc is nullptr!");
  482. if (node.GetOpDesc()->GetType() == CONSTANT) {
  483. if (weights.size() == CONST_OP_NORMAL_WEIGHT_SIZE) {
  484. return SetWeights(node.GetOpDesc(), weights[0]);
  485. }
  486. GELOGI("const op weight size %zu should be 1", weights.size());
  487. return GRAPH_PARAM_INVALID;
  488. }
  489. auto input_nodes = GetConstInputs(node);
  490. if (weights.size() < input_nodes.size()) {
  491. ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
  492. {"SetWeights", "weights count can't be less than const input count"});
  493. GELOGE(GRAPH_FAILED, "weights count can't be less than const input count");
  494. return GRAPH_PARAM_INVALID;
  495. }
  496. ge::GeAttrValue::NAMED_ATTRS named_attrs;
  497. (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights);
  498. vector<ge::GeTensorPtr> copy_weights;
  499. (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights);
  500. for (size_t i = 0; i < input_nodes.size(); ++i) {
  501. if (input_nodes[i]->GetOpDesc() != nullptr) {
  502. SetWeights(input_nodes[i]->GetOpDesc(), copy_weights[i]);
  503. }
  504. }
  505. // If set more weights than constop, need to add constop
  506. for (size_t i = input_nodes.size(); i < copy_weights.size(); ++i) {
  507. // Use org weight before SetWeights Overwrite
  508. auto const_opdesc = CreateConstOp(copy_weights[i]);
  509. GE_CHECK_NOTNULL(const_opdesc);
  510. auto owner_graph = node.GetOwnerComputeGraph();
  511. if (owner_graph == nullptr) {
  512. GELOGE(GRAPH_FAILED, "node's graph is empty, name: %s", node.GetName().c_str());
  513. return GRAPH_PARAM_INVALID;
  514. }
  515. auto const_node = owner_graph->AddNodeFront(const_opdesc);
  516. GE_CHK_BOOL_EXEC(node.AddLinkFrom(const_node) == GRAPH_SUCCESS,
  517. GELOGE(GRAPH_FAILED, "graph add link failed!");
  518. return GRAPH_FAILED);
  519. std::vector<ge::NodePtr> original_nodes;
  520. ge::GraphUtils::RecordOriginalNames(original_nodes, const_node);
  521. }
  522. return GRAPH_SUCCESS;
  523. }
  524. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  525. OpDescUtils::SetWeights(ge::Node &node, const map<int, ge::GeTensorPtr> &weights_map) {
  526. GE_CHECK_NOTNULL(node.GetOpDesc());
  527. // 1. node is const
  528. if (node.GetOpDesc()->GetType() == CONSTANT) {
  529. if (weights_map.size() == CONST_OP_NORMAL_WEIGHT_SIZE) {
  530. return SetWeights(node.GetOpDesc(), weights_map.begin()->second);
  531. }
  532. ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
  533. {"SetWeights", "const op[" + node.GetName() + "] weight size[" +
  534. std::to_string(weights_map.size()) + "] should be 1"});
  535. GELOGE(GRAPH_PARAM_INVALID, "const op %s weight size %zu should be 1", node.GetName().c_str(), weights_map.size());
  536. return GRAPH_PARAM_INVALID;
  537. }
  538. // 2. node is not const
  539. for (const auto &pair:weights_map) {
  540. auto in_data_anchor = node.GetInDataAnchor(pair.first);
  541. if (in_data_anchor != nullptr && in_data_anchor->GetPeerOutAnchor() != nullptr) {
  542. // a. update const input node
  543. auto out_anchor = in_data_anchor->GetPeerOutAnchor();
  544. auto peer_node = out_anchor->GetOwnerNode();
  545. if (peer_node == nullptr) {
  546. GELOGE(GRAPH_PARAM_INVALID, "op %s [%d]'s input node is null", node.GetName().c_str(), pair.first);
  547. return GRAPH_PARAM_INVALID;
  548. }
  549. if (peer_node->GetType() != CONSTANT) {
  550. ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
  551. {"SetWeights", "op[" + node.GetName() + "] [" + std::to_string(pair.first) +
  552. "]'s input node should be const, but real op is " +
  553. peer_node->GetName() + ", type is " + peer_node->GetType()});
  554. GELOGE(GRAPH_PARAM_INVALID,
  555. " op %s [%d]'s input node should be const, but is %s type:%s ", node.GetName().c_str(),
  556. pair.first, peer_node->GetName().c_str(), peer_node->GetType().c_str());
  557. }
  558. SetWeights(peer_node->GetOpDesc(), pair.second);
  559. } else {
  560. // b. create new const input node
  561. auto const_opdesc = CreateConstOp(pair.second);
  562. GE_CHECK_NOTNULL(const_opdesc);
  563. auto owner_graph = node.GetOwnerComputeGraph();
  564. if (owner_graph == nullptr) {
  565. GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", node.GetName().c_str());
  566. return GRAPH_PARAM_INVALID;
  567. }
  568. auto const_node = owner_graph->AddNodeFront(const_opdesc);
  569. if (node.AddLinkFrom(static_cast<uint32_t>(pair.first), const_node) != GRAPH_SUCCESS) {
  570. GELOGE(GRAPH_FAILED, "op %s add const to input index[%d] failed", node.GetName().c_str(), pair.first);
  571. return GRAPH_FAILED;
  572. }
  573. }
  574. }
  575. NodeUtils::UpdateIsInputConst(node);
  576. return GRAPH_SUCCESS;
  577. }
  578. OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) {
  579. GE_CHK_BOOL_EXEC(tensor_ptr != nullptr, return nullptr, "tensor_ptr is nullptr!");
  580. shared_ptr<OpDesc> const_opdesc = ComGraphMakeShared<OpDesc>();
  581. if (const_opdesc == nullptr) {
  582. GELOGE(GRAPH_FAILED, "failed to make_shared ");
  583. return nullptr;
  584. }
  585. CHECK_FALSE_EXEC(SetWeights(const_opdesc, tensor_ptr) == ge::GRAPH_SUCCESS, return nullptr);
  586. const_opdesc->SetType(CONSTANT);
  587. thread_local int64_t const_count = 0;
  588. const_opdesc->SetName("dynamic_const_" + std::to_string(GeLog::GetTid()) + "_" + std::to_string(const_count));
  589. GELOGI("add const op: %s", const_opdesc->GetName().c_str());
  590. ++const_count;
  591. (void)const_opdesc->AddOutputDesc(tensor_ptr->GetTensorDesc());
  592. GELOGI("after add const op: %s", const_opdesc->GetName().c_str());
  593. return const_opdesc;
  594. }
  595. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  596. OpDescUtils::AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr) {
  597. GE_CHECK_NOTNULL(in_anchor);
  598. GE_CHECK_NOTNULL(tensor_ptr);
  599. auto const_opdesc = CreateConstOp(tensor_ptr);
  600. GE_CHECK_NOTNULL(const_opdesc);
  601. auto in_node = in_anchor->GetOwnerNode();
  602. GE_CHECK_NOTNULL(in_node);
  603. auto owner_graph = in_node->GetOwnerComputeGraph();
  604. if (owner_graph == nullptr) {
  605. GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", in_node->GetName().c_str());
  606. return GRAPH_PARAM_INVALID;
  607. }
  608. auto const_node = in_node->GetOwnerComputeGraph()->AddNodeFront(const_opdesc);
  609. GE_CHECK_NOTNULL(const_node);
  610. if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) {
  611. GELOGE(GRAPH_PARAM_INVALID, "Addedge const to node failed.");
  612. return GRAPH_PARAM_INVALID;
  613. }
  614. return GRAPH_SUCCESS;
  615. }
  616. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  617. OpDescUtils::SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr> &weights) {
  618. GE_CHECK_NOTNULL(node);
  619. return SetWeights(*node, weights);
  620. }
  621. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWeights(const ge::NodePtr node) {
  622. GE_CHECK_NOTNULL(node);
  623. auto const_ops = GetConstInputs(node);
  624. auto graph = node->GetOwnerComputeGraph();
  625. if (graph == nullptr) {
  626. GELOGE(GRAPH_FAILED, "Graph is nullptr");
  627. return GRAPH_PARAM_INVALID;
  628. }
  629. for (const auto &const_op : const_ops) {
  630. GE_CHK_STATUS_RET(GraphUtils::IsolateNode(const_op, {}), "Isolate removed node: %s, type: %s failed",
  631. const_op->GetName().c_str(), const_op->GetType().c_str());
  632. GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, const_op),
  633. "Remove node: %s, type: %s without relink failed", const_op->GetName().c_str(),
  634. const_op->GetType().c_str());
  635. }
  636. return GRAPH_SUCCESS;
  637. }
  638. ///
  639. /// @brief Add input
  640. /// @param [in] name
  641. /// @return OpDescBuilder
  642. ///
  643. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddInput(const std::string &name) {
  644. inputs_.emplace_back(std::make_pair(name, GeTensorDesc()));
  645. return *this;
  646. }
  647. ///
  648. /// @brief Add input
  649. /// @param [in] name
  650. /// @param [in] tensor
  651. /// @return OpDescBuilder
  652. ///
  653. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  654. OpDescBuilder& OpDescBuilder::AddInput(const std::string &name, const GeTensorDesc &tensor) {
  655. inputs_.emplace_back(std::make_pair(name, tensor));
  656. return *this;
  657. }
  658. ///
  659. /// @brief Add dynamic input
  660. /// @param [in] name
  661. /// @param [in] num
  662. /// @return OpDescBuilder
  663. ///
  664. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddDynamicInput(const std::string &name,
  665. uint32_t num) {
  666. for (uint32_t i = 0; i < num; i++) {
  667. inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc()));
  668. }
  669. return *this;
  670. }
  671. ///
  672. /// @brief Add dynamic input
  673. /// @param [in] name
  674. /// @param [in] num
  675. /// @param [in] tensor
  676. /// @return OpDescBuilder
  677. ///
  678. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  679. OpDescBuilder& OpDescBuilder::AddDynamicInput(const std::string &name, uint32_t num, const GeTensorDesc &tensor) {
  680. for (uint32_t i = 0; i < num; i++) {
  681. inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor));
  682. }
  683. return *this;
  684. }
  685. ///
  686. /// @brief Add output
  687. /// @param [in] name
  688. /// @return OpDescBuilder
  689. ///
  690. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddOutput(const std::string &name) {
  691. outputs_.emplace_back(std::make_pair(name, GeTensorDesc()));
  692. return *this;
  693. }
  694. ///
  695. /// @brief Add output
  696. /// @param [in] name
  697. /// @param [in] tensor
  698. /// @return OpDescBuilder
  699. ///
  700. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  701. OpDescBuilder& OpDescBuilder::AddOutput(const std::string &name, const GeTensorDesc &tensor) {
  702. outputs_.emplace_back(std::make_pair(name, tensor));
  703. return *this;
  704. }
  705. ///
  706. /// @brief Add dynamic output
  707. /// @param [in] name
  708. /// @param [in] num
  709. /// @return OpDescBuilder
  710. ///
  711. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddDynamicOutput(const std::string &name,
  712. uint32_t num) {
  713. for (uint32_t i = 0; i < num; i++) {
  714. outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc()));
  715. }
  716. return *this;
  717. }
  718. ///
  719. /// @brief Add dynamic output
  720. /// @param [in] name
  721. /// @param [in] num
  722. /// @param [in] tensor
  723. /// @return OpDescBuilder
  724. ///
  725. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  726. OpDescBuilder& OpDescBuilder::AddDynamicOutput(const std::string &name, uint32_t num, const GeTensorDesc &tensor) {
  727. for (uint32_t i = 0; i < num; i++) {
  728. outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor));
  729. }
  730. return *this;
  731. }
  732. ///
  733. /// @brief Build op_desc
  734. /// @return OpDescPtr
  735. ///
  736. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() {
  737. OpDescPtr op_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name_, type_));
  738. if (op_desc == nullptr) {
  739. GELOGE(GRAPH_FAILED, "OpDesc is nullptr");
  740. return nullptr;
  741. }
  742. for (auto &input : inputs_) {
  743. if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) {
  744. GELOGE(GRAPH_FAILED, "Add input_desc failed.");
  745. return nullptr;
  746. }
  747. }
  748. for (auto &output : outputs_) {
  749. if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) {
  750. GELOGE(GRAPH_FAILED, "Add output_desc failed.");
  751. return nullptr;
  752. }
  753. }
  754. return op_desc;
  755. }
  756. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  757. graphStatus OpDescUtils::SetSubgraphInstanceName(const std::string &subgraph_name,
  758. const std::string &subgraph_instance_name,
  759. OpDescPtr &op_desc) {
  760. const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes();
  761. auto iter = subgraph_names_to_index.find(subgraph_name);
  762. if (iter == subgraph_names_to_index.end()) {
  763. ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
  764. {"SetSubgraphInstanceName", "subgraph name[" + subgraph_name + "] is not exists."
  765. "The op is " + op_desc->GetName() + ", type is " + op_desc->GetType() +
  766. ", subgraph is " + subgraph_instance_name});
  767. GELOGE(GRAPH_PARAM_INVALID,
  768. "Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists",
  769. subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), subgraph_name.c_str());
  770. return GRAPH_PARAM_INVALID;
  771. }
  772. return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name);
  773. }
  774. } // namespace ge
  775. /*lint +e512 +e737 +e752*/

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。