You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

op_desc_utils.cc 28 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "utils/op_desc_utils.h"
#include <algorithm>
#include <atomic>
#include <cstdint>
#include "debug/ge_attr_define.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/anchor.h"
#include "graph/compute_graph.h"
#include "graph/ge_attr_value.h"
#include "utils/graph_utils.h"
#include "utils/node_utils.h"
using std::vector;
  28. namespace ge {
  29. const char OP_DESC_QUANT_PARAMS[] = "quantize_factor";
  30. static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1;
  31. bool OpDescUtils::ClearInputDesc(const NodePtr &node) {
  32. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
  33. GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr");
  34. vector<int> index_list;
  35. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  36. if (in_anchor->GetPeerOutAnchor() == nullptr) {
  37. index_list.push_back(in_anchor->GetIdx());
  38. }
  39. }
  40. std::sort(index_list.begin(), index_list.end());
  41. // Node's in anchor index need shrink
  42. for (size_t i = 0; i < index_list.size(); ++i) {
  43. auto iter = node->GetOpDesc()->inputs_desc_.begin() + index_list[i];
  44. if (iter < node->GetOpDesc()->inputs_desc_.end()) {
  45. (void)node->GetOpDesc()->inputs_desc_.erase(iter);
  46. } else {
  47. GELOGW("inputs_desc_ iterator out of range.");
  48. }
  49. }
  50. return true;
  51. }
  52. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearInputDesc(OpDescPtr op_desc,
  53. const uint32_t index) {
  54. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr");
  55. GE_CHK_BOOL_EXEC(index < op_desc->inputs_desc_.size(), return false, "index %u is invalid.", index);
  56. auto iter = op_desc->inputs_desc_.begin() + index;
  57. if (iter < op_desc->inputs_desc_.end()) {
  58. (void)op_desc->inputs_desc_.erase(iter);
  59. } else {
  60. GELOGW("inputs_desc_ iterator out of range.");
  61. }
  62. return true;
  63. }
  64. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::HasQuantizeFactorParams(const OpDescPtr &op_desc) {
  65. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return false, "op_desc is nullptr");
  66. return op_desc->HasAttr(OP_DESC_QUANT_PARAMS);
  67. }
  68. bool OpDescUtils::ClearOutputDesc(const NodePtr &node) {
  69. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
  70. GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr");
  71. vector<int> index_list;
  72. for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
  73. if (out_anchor->GetPeerInDataAnchors().empty()) {
  74. index_list.push_back(out_anchor->GetIdx());
  75. }
  76. }
  77. std::sort(index_list.begin(), index_list.end());
  78. // Node's out anchor index need shrink
  79. for (size_t i = 0; i < index_list.size(); ++i) {
  80. auto iter = node->GetOpDesc()->outputs_desc_.begin() + index_list[i];
  81. if (iter < node->GetOpDesc()->outputs_desc_.end()) {
  82. (void)node->GetOpDesc()->outputs_desc_.erase(iter);
  83. } else {
  84. GELOGW("outputs_desc_ iterator out of range.");
  85. }
  86. }
  87. return true;
  88. }
  89. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearOutputDesc(const OpDescPtr &op_desc,
  90. uint32_t index) {
  91. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr");
  92. GE_CHK_BOOL_EXEC(index < op_desc->outputs_desc_.size(), return false, "index %u is invalid.", index);
  93. auto iter = op_desc->outputs_desc_.begin() + index;
  94. if (iter < op_desc->outputs_desc_.end()) {
  95. (void)op_desc->outputs_desc_.erase(iter);
  96. } else {
  97. GELOGW("outputs_desc_ iterator out of range.");
  98. }
  99. return true;
  100. }
  101. bool OpDescUtils::HasQuantizeFactorParams(const OpDesc &op_desc) { return op_desc.HasAttr(OP_DESC_QUANT_PARAMS); }
  102. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  103. OpDescUtils::GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant) {
  104. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
  105. GeAttrValue attr_value;
  106. GE_CHK_BOOL_EXEC_INFO(op_desc->GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED,
  107. "GetQuantizeFactorParams failed");
  108. return attr_value.GetValue<QuantizeFactorParams>(quant);
  109. }
  110. graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant) {
  111. GeAttrValue attr_value;
  112. GE_CHK_BOOL_EXEC_INFO(op_desc.GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED,
  113. "GetQuantizeFactorParams failed");
  114. return attr_value.GetValue<QuantizeFactorParams>(quant);
  115. }
  116. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  117. OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) {
  118. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
  119. return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant));
  120. }
  121. graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) {
  122. return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant));
  123. }
  124. GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) {
  125. GeTensorPtr weight = nullptr;
  126. if (!AttrUtils::MutableTensor(&op_desc, ATTR_NAME_WEIGHTS, weight)) {
  127. GELOGW("MutableTensor error");
  128. }
  129. return weight;
  130. }
  131. GE_FUNC_HOST_VISIBILITY GeTensorPtr OpDescUtils::MutableWeights(OpDescPtr op_desc) {
  132. if (op_desc == nullptr) {
  133. GELOGE(GRAPH_FAILED, "op_desc is null");
  134. return nullptr;
  135. }
  136. return MutableWeights(*op_desc);
  137. }
  138. graphStatus OpDescUtils::SetWeights(OpDesc &op_desc, const GeTensorPtr weight) {
  139. if (weight == nullptr) {
  140. GELOGE(GRAPH_FAILED, "weight is null");
  141. return GRAPH_FAILED;
  142. }
  143. return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? GRAPH_SUCCESS : GRAPH_FAILED;
  144. }
  145. graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) {
  146. GE_CHECK_NOTNULL(op_desc);
  147. GE_CHECK_NOTNULL(weight);
  148. return SetWeights(*op_desc, weight);
  149. }
  150. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights(const ge::Node &node) {
  151. auto weights = MutableWeights(node);
  152. vector<ConstGeTensorPtr> ret(weights.size());
  153. std::copy(weights.begin(), weights.end(), ret.begin());
  154. return ret;
  155. }
  156. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights(
  157. const ge::ConstNodePtr &node) {
  158. if (node == nullptr) {
  159. return vector<ge::ConstGeTensorPtr>();
  160. }
  161. return GetWeights(*node);
  162. }
  163. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputNode(
  164. const ge::Node &node) {
  165. vector<ge::NodePtr> ret;
  166. auto in_anchors = node.GetAllInDataAnchors();
  167. for (const auto &in_anchor : in_anchors) {
  168. auto out_anchor = in_anchor->GetPeerOutAnchor();
  169. if (out_anchor == nullptr) {
  170. // normally out_anchor could be null, this is ok
  171. GELOGD("node %s' peer_out_anchor is null", node.GetName().c_str());
  172. continue;
  173. }
  174. auto in_node = out_anchor->GetOwnerNode();
  175. while (true) {
  176. if (in_node == nullptr) {
  177. break;
  178. }
  179. if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
  180. ret.push_back(in_node);
  181. break;
  182. } else if (in_node->GetType() == DATA) {
  183. if (NodeUtils::IsWhileVaryingInput(in_node)) {
  184. break;
  185. }
  186. in_node = NodeUtils::GetParentInput(in_node);
  187. } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) {
  188. bool is_constant = false;
  189. (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant);
  190. if (!is_constant) {
  191. break;
  192. }
  193. // Enter node has and only has one input
  194. if (in_node->GetInDataNodes().size() != 1) {
  195. GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(),
  196. in_node->GetInDataNodes().size());
  197. break;
  198. }
  199. in_node = in_node->GetInDataNodes().at(0);
  200. } else {
  201. break;
  202. }
  203. }
  204. }
  205. return ret;
  206. }
  207. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData(
  208. const vector<ge::NodePtr> &input_nodes) {
  209. vector<ConstGeTensorPtr> ret;
  210. for (const auto &input_node : input_nodes) {
  211. auto temp_weight = MutableWeights(input_node->GetOpDesc());
  212. if (temp_weight == nullptr) {
  213. GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str());
  214. return vector<ConstGeTensorPtr>();
  215. }
  216. ret.push_back(temp_weight);
  217. }
  218. return ret;
  219. }
  220. size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) {
  221. if (NodeUtils::IsAnchorStatusSet(node)) {
  222. size_t input_num = 0;
  223. for (const auto &anchor : node.GetAllInDataAnchors()) {
  224. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  225. input_num++;
  226. continue;
  227. }
  228. }
  229. return input_num;
  230. } else {
  231. GE_IF_BOOL_EXEC(
  232. node.GetInDataNodes().size() < GetConstInputs(node).size(),
  233. GELOGE(GRAPH_FAILED, "%zu is smaller than %zu", node.GetInDataNodes().size(), GetConstInputs(node).size());
  234. return 0);
  235. return node.GetInDataNodes().size() - GetConstInputs(node).size();
  236. }
  237. }
  238. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDescUtils::GetNonConstInputsSize(const ge::ConstNodePtr node) {
  239. if (node == nullptr) {
  240. GELOGE(GRAPH_FAILED, "Node is nullptr");
  241. return 0;
  242. }
  243. return GetNonConstInputsSize(*node);
  244. }
  245. GeTensorDesc OpDescUtils::GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const) {
  246. GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GeTensorDesc(), "node.GetOpDesc() is nullptr!");
  247. size_t i = 0;
  248. if (NodeUtils::IsAnchorStatusSet(node)) {
  249. for (const auto &anchor : node.GetAllInDataAnchors()) {
  250. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  251. if (index_non_const == i) {
  252. return node.GetOpDesc()->GetInputDesc(static_cast<uint32_t>(anchor->GetIdx()));
  253. }
  254. ++i;
  255. }
  256. }
  257. } else {
  258. for (const auto &anchor : node.GetAllInDataAnchors()) {
  259. auto peer_anchor = anchor->GetPeerOutAnchor();
  260. if (peer_anchor == nullptr) {
  261. continue;
  262. }
  263. auto owner_node = peer_anchor->GetOwnerNode();
  264. if (owner_node == nullptr) {
  265. continue;
  266. }
  267. if (owner_node->GetType() == CONSTANT) {
  268. continue;
  269. }
  270. if (index_non_const == i) {
  271. return node.GetOpDesc()->GetInputDesc(anchor->GetIdx());
  272. }
  273. ++i;
  274. }
  275. }
  276. return GeTensorDesc();
  277. }
  278. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc
  279. OpDescUtils::GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const) {
  280. CHECK_FALSE_EXEC(node != nullptr, return GeTensorDesc());
  281. return GetNonConstInputTensorDesc(*node, index_non_const);
  282. }
  283. bool OpDescUtils::GetNonConstInputIndex(const ge::Node &node, const size_t index_non_const, size_t &index) {
  284. bool ret = false;
  285. size_t i = 0;
  286. if (NodeUtils::IsAnchorStatusSet(node)) {
  287. for (const auto &anchor : node.GetAllInDataAnchors()) {
  288. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  289. if (index_non_const == i) {
  290. index = static_cast<size_t>(anchor->GetIdx());
  291. ret = true;
  292. }
  293. ++i;
  294. }
  295. }
  296. } else {
  297. for (const auto &anchor : node.GetAllInDataAnchors()) {
  298. auto peer_anchor = anchor->GetPeerOutAnchor();
  299. if (peer_anchor == nullptr) {
  300. continue;
  301. }
  302. auto owner_node = peer_anchor->GetOwnerNode();
  303. if (owner_node == nullptr) {
  304. continue;
  305. }
  306. if (owner_node->GetType() == CONSTANT) {
  307. continue;
  308. }
  309. if (index_non_const == i) {
  310. index = static_cast<size_t>(anchor->GetIdx());
  311. ret = true;
  312. }
  313. ++i;
  314. }
  315. }
  316. return ret;
  317. }
  318. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::GetNonConstInputIndex(const ge::ConstNodePtr &node,
  319. size_t index_non_const,
  320. size_t &index) {
  321. CHECK_FALSE_EXEC(node != nullptr, return false);
  322. return GetNonConstInputIndex(*node, index_non_const, index);
  323. }
  324. bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) {
  325. bool ret = false;
  326. if (index < node.GetAllInDataAnchors().size()) {
  327. if (NodeUtils::IsAnchorStatusSet(node)) {
  328. ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA);
  329. } else {
  330. for (const auto &anchor : node.GetAllInDataAnchors()) {
  331. if (anchor->GetIdx() != static_cast<int>(index)) {
  332. continue;
  333. }
  334. auto peer_anchor = anchor->GetPeerOutAnchor();
  335. if (peer_anchor == nullptr) {
  336. break;
  337. }
  338. auto owner_node = peer_anchor->GetOwnerNode();
  339. if (owner_node == nullptr) {
  340. break;
  341. }
  342. ret = (owner_node->GetType() != CONSTANT);
  343. }
  344. }
  345. }
  346. return ret;
  347. }
  348. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::IsNonConstInput(const ge::ConstNodePtr &node,
  349. size_t index) {
  350. CHECK_FALSE_EXEC(node != nullptr, return false);
  351. return IsNonConstInput(*node, index);
  352. }
  353. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs(
  354. const ge::ConstNodePtr &node) {
  355. if (node == nullptr) {
  356. return vector<ge::NodePtr>();
  357. }
  358. return GetConstInputs(*node);
  359. }
  360. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUtils::GetNonConstTensorDesc(
  361. const ge::ConstNodePtr &node) {
  362. if (node == nullptr || node->GetOpDesc() == nullptr) {
  363. return vector<ge::GeTensorDesc>();
  364. }
  365. vector<ge::GeTensorDesc> ret;
  366. if (NodeUtils::IsAnchorStatusSet(*node)) {
  367. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  368. if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) {
  369. ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
  370. }
  371. }
  372. } else {
  373. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  374. auto out_anchor = in_anchor->GetPeerOutAnchor();
  375. if (out_anchor == nullptr || out_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  376. continue;
  377. }
  378. if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) {
  379. ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
  380. }
  381. }
  382. }
  383. return ret;
  384. }
  385. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs(const ge::Node &node) {
  386. vector<ge::NodePtr> ret;
  387. auto in_anchors = node.GetAllInDataAnchors();
  388. for (const auto &in_anchor : in_anchors) {
  389. auto out_anchor = in_anchor->GetPeerOutAnchor();
  390. if (out_anchor == nullptr) continue;
  391. auto in_node = out_anchor->GetOwnerNode();
  392. if (in_node->GetType() == CONSTANT) {
  393. ret.push_back(in_node);
  394. } else if (in_node->GetType() == SWITCH && node.GetType() == MATMUL) {
  395. // const --> switch --> matmul
  396. auto switch_input = GetConstInputs(*in_node);
  397. if (switch_input.size() > 0) {
  398. ret.insert(ret.end(), switch_input.begin(), switch_input.end());
  399. }
  400. }
  401. }
  402. return ret;
  403. }
  404. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::Node &node) {
  405. vector<GeTensorPtr> ret;
  406. auto op_desc = node.GetOpDesc();
  407. GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!");
  408. // Place holder operator, try to get the weight from parent node
  409. // when parent node is const operator
  410. if (node.GetType() == PLACEHOLDER) {
  411. std::string parent_op;
  412. (void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op);
  413. // This if judgment is necessary because the current subgraph optimization is multithreaded
  414. // and the parent node of the PLD operation should be a stable type, such as const
  415. if (parent_op == CONSTANT || parent_op == CONSTANTOP) {
  416. NodePtr parent_node = nullptr;
  417. parent_node = op_desc->TryGetExtAttr("parentNode", parent_node);
  418. if (parent_node != nullptr) {
  419. op_desc = parent_node->GetOpDesc();
  420. GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str());
  421. }
  422. }
  423. }
  424. // Const operator, take the weight directly
  425. if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) {
  426. auto weight = MutableWeights(op_desc);
  427. if (weight == nullptr) {
  428. GELOGI("const op has no weight, op name:%s", node.GetName().c_str());
  429. return ret;
  430. }
  431. ret.push_back(weight);
  432. return ret;
  433. }
  434. // Other operators, get weights from connected constop
  435. auto input_nodes = GetConstInputs(node);
  436. for (const auto &input_node : input_nodes) {
  437. auto temp_weight = MutableWeights(input_node->GetOpDesc());
  438. if (temp_weight == nullptr) {
  439. GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str());
  440. return vector<GeTensorPtr>();
  441. }
  442. ret.push_back(temp_weight);
  443. }
  444. return ret;
  445. }
  446. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::NodePtr node) {
  447. if (node == nullptr) {
  448. GELOGE(GRAPH_FAILED, "Node is nullptr");
  449. return vector<ge::GeTensorPtr>();
  450. }
  451. return MutableWeights(*node);
  452. }
  453. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  454. OpDescUtils::SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights) {
  455. GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GRAPH_PARAM_INVALID, "node.GetOpDesc is nullptr!");
  456. if (node.GetOpDesc()->GetType() == CONSTANT) {
  457. if (weights.size() == CONST_OP_NORMAL_WEIGHT_SIZE) {
  458. return SetWeights(node.GetOpDesc(), weights[0]);
  459. }
  460. GELOGI("const op weight size %zu should be 1", weights.size());
  461. return GRAPH_PARAM_INVALID;
  462. }
  463. auto input_nodes = GetConstInputs(node);
  464. if (weights.size() < input_nodes.size()) {
  465. GELOGE(GRAPH_FAILED, "weights count can't be less than const input count");
  466. return GRAPH_PARAM_INVALID;
  467. }
  468. ge::GeAttrValue::NAMED_ATTRS named_attrs;
  469. (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights);
  470. vector<ge::GeTensorPtr> copy_weights;
  471. (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights);
  472. for (size_t i = 0; i < input_nodes.size(); ++i) {
  473. if (input_nodes[i]->GetOpDesc() != nullptr) {
  474. SetWeights(input_nodes[i]->GetOpDesc(), copy_weights[i]);
  475. }
  476. }
  477. // If set more weights than constop, need to add constop
  478. for (size_t i = input_nodes.size(); i < copy_weights.size(); ++i) {
  479. // Use org weight before SetWeights Overwrite
  480. auto const_opdesc = CreateConstOp(copy_weights[i]);
  481. GE_CHECK_NOTNULL(const_opdesc);
  482. auto owner_graph = node.GetOwnerComputeGraph();
  483. if (owner_graph == nullptr) {
  484. GELOGE(GRAPH_FAILED, "node's graph is empty, name: %s", node.GetName().c_str());
  485. return GRAPH_PARAM_INVALID;
  486. }
  487. auto const_node = owner_graph->AddNodeFront(const_opdesc);
  488. GE_CHK_BOOL_EXEC(node.AddLinkFrom(const_node) == GRAPH_SUCCESS, return GRAPH_FAILED, "graph add link failed!");
  489. std::vector<ge::NodePtr> original_nodes;
  490. ge::GraphUtils::RecordOriginalNames(original_nodes, const_node);
  491. }
  492. return GRAPH_SUCCESS;
  493. }
  494. OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) {
  495. GE_CHK_BOOL_EXEC(tensor_ptr != nullptr, return nullptr, "tensor_ptr is nullptr!");
  496. shared_ptr<OpDesc> const_opdesc = ComGraphMakeShared<OpDesc>();
  497. if (const_opdesc == nullptr) {
  498. GELOGE(GRAPH_FAILED, "failed to make_shared ");
  499. return nullptr;
  500. }
  501. CHECK_FALSE_EXEC(SetWeights(const_opdesc, tensor_ptr) == ge::GRAPH_SUCCESS, return nullptr);
  502. const_opdesc->SetType(CONSTANT);
  503. static int const_count = 0;
  504. const_opdesc->SetName("dynamic_const_" + std::to_string(const_count));
  505. GELOGI("add const op: %s", const_opdesc->GetName().c_str());
  506. ++const_count;
  507. (void)const_opdesc->AddOutputDesc(tensor_ptr->GetTensorDesc());
  508. GELOGI("after add const op: %s", const_opdesc->GetName().c_str());
  509. return const_opdesc;
  510. }
  511. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  512. OpDescUtils::AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr) {
  513. GE_CHECK_NOTNULL(in_anchor);
  514. GE_CHECK_NOTNULL(tensor_ptr);
  515. auto const_opdesc = CreateConstOp(tensor_ptr);
  516. GE_CHECK_NOTNULL(const_opdesc);
  517. auto in_node = in_anchor->GetOwnerNode();
  518. GE_CHECK_NOTNULL(in_node);
  519. auto owner_graph = in_node->GetOwnerComputeGraph();
  520. if (owner_graph == nullptr) {
  521. GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", in_node->GetName().c_str());
  522. return GRAPH_PARAM_INVALID;
  523. }
  524. auto const_node = in_node->GetOwnerComputeGraph()->AddNodeFront(const_opdesc);
  525. GE_CHECK_NOTNULL(const_node);
  526. if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) {
  527. GELOGE(GRAPH_PARAM_INVALID, "Addedge const to node failed.");
  528. return GRAPH_PARAM_INVALID;
  529. }
  530. return GRAPH_SUCCESS;
  531. }
  532. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  533. OpDescUtils::SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr> &weights) {
  534. GE_CHECK_NOTNULL(node);
  535. return SetWeights(*node, weights);
  536. }
  537. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWeights(const ge::NodePtr node) {
  538. GE_CHECK_NOTNULL(node);
  539. auto const_ops = GetConstInputs(node);
  540. auto graph = node->GetOwnerComputeGraph();
  541. if (graph == nullptr) {
  542. GELOGE(GRAPH_FAILED, "Graph is nullptr");
  543. return GRAPH_PARAM_INVALID;
  544. }
  545. for (const auto &const_op : const_ops) {
  546. GE_CHK_STATUS_RET(GraphUtils::IsolateNode(const_op, {}), "Isolate removed node: %s, type: %s failed",
  547. const_op->GetName().c_str(), const_op->GetType().c_str());
  548. GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, const_op),
  549. "Remove node: %s, type: %s without relink failed", const_op->GetName().c_str(),
  550. const_op->GetType().c_str());
  551. }
  552. return GRAPH_SUCCESS;
  553. }
  554. ///
  555. /// @brief Add input
  556. /// @param [in] name
  557. /// @return OpDescBuilder
  558. ///
  559. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) {
  560. inputs_.emplace_back(std::make_pair(name, GeTensorDesc()));
  561. return *this;
  562. }
  563. ///
  564. /// @brief Add input
  565. /// @param [in] name
  566. /// @param [in] tensor
  567. /// @return OpDescBuilder
  568. ///
  569. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name,
  570. const GeTensorDesc &tensor) {
  571. inputs_.emplace_back(std::make_pair(name, tensor));
  572. return *this;
  573. }
  574. ///
  575. /// @brief Add dynamic input
  576. /// @param [in] name
  577. /// @param [in] num
  578. /// @return OpDescBuilder
  579. ///
  580. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name,
  581. uint32_t num) {
  582. for (uint32_t i = 0; i < num; i++) {
  583. inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc()));
  584. }
  585. return *this;
  586. }
  587. ///
  588. /// @brief Add dynamic input
  589. /// @param [in] name
  590. /// @param [in] num
  591. /// @param [in] tensor
  592. /// @return OpDescBuilder
  593. ///
  594. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(
  595. const std::string &name, uint32_t num, const GeTensorDesc &tensor) {
  596. for (uint32_t i = 0; i < num; i++) {
  597. inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor));
  598. }
  599. return *this;
  600. }
  601. ///
  602. /// @brief Add output
  603. /// @param [in] name
  604. /// @return OpDescBuilder
  605. ///
  606. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) {
  607. outputs_.emplace_back(std::make_pair(name, GeTensorDesc()));
  608. return *this;
  609. }
  610. ///
  611. /// @brief Add output
  612. /// @param [in] name
  613. /// @param [in] tensor
  614. /// @return OpDescBuilder
  615. ///
  616. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name,
  617. const GeTensorDesc &tensor) {
  618. outputs_.emplace_back(std::make_pair(name, tensor));
  619. return *this;
  620. }
  621. ///
  622. /// @brief Add dynamic output
  623. /// @param [in] name
  624. /// @param [in] num
  625. /// @return OpDescBuilder
  626. ///
  627. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name,
  628. uint32_t num) {
  629. for (uint32_t i = 0; i < num; i++) {
  630. outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc()));
  631. }
  632. return *this;
  633. }
  634. ///
  635. /// @brief Add dynamic output
  636. /// @param [in] name
  637. /// @param [in] num
  638. /// @param [in] tensor
  639. /// @return OpDescBuilder
  640. ///
  641. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(
  642. const std::string &name, uint32_t num, const GeTensorDesc &tensor) {
  643. for (uint32_t i = 0; i < num; i++) {
  644. outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor));
  645. }
  646. return *this;
  647. }
  648. ///
  649. /// @brief Build op_desc
  650. /// @return OpDescPtr
  651. ///
  652. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() {
  653. OpDescPtr op_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name_, type_));
  654. if (op_desc == nullptr) {
  655. GELOGE(GRAPH_FAILED, "OpDesc is nullptr");
  656. return nullptr;
  657. }
  658. for (auto &input : inputs_) {
  659. if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) {
  660. GELOGE(GRAPH_FAILED, "Add input_desc failed.");
  661. return nullptr;
  662. }
  663. }
  664. for (auto &output : outputs_) {
  665. if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) {
  666. GELOGE(GRAPH_FAILED, "Add output_desc failed.");
  667. return nullptr;
  668. }
  669. }
  670. return op_desc;
  671. }
  672. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName(
  673. const std::string &subgraph_name, const std::string &subgraph_instance_name, OpDescPtr &op_desc) {
  674. const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes();
  675. auto iter = subgraph_names_to_index.find(subgraph_name);
  676. if (iter == subgraph_names_to_index.end()) {
  677. GELOGE(GRAPH_PARAM_INVALID,
  678. "Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists",
  679. subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(),
  680. subgraph_name.c_str());
  681. return GRAPH_PARAM_INVALID;
  682. }
  683. return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name);
  684. }
  685. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。