You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

var_is_initialized_op_pass.cc 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/var_is_initialized_op_pass.h"
  17. #include <memory>
  18. #include <utility>
  19. #include "framework/common/debug/ge_log.h"
  20. #include "common/ge/ge_util.h"
  21. #include "graph/anchor.h"
  22. #include "graph/debug/ge_attr_define.h"
  23. #include "graph/manager/graph_var_manager.h"
  24. #include "graph/node.h"
  25. #include "graph/utils/graph_utils.h"
  26. #include "graph/utils/node_utils.h"
  27. namespace ge {
  namespace {
  // Data-input index of the ref (variable) operand on an Assign node.
  constexpr int kAssignVarRefIndex = 0;
  // Number of data inputs, and of data outputs, a VarIsInitializedOp node must have.
  constexpr int kVarIsInitializedIOCnt = 1;
  // Data-input index of the variable operand on a VarIsInitializedOp node.
  constexpr int kVarIsInitVarInputIndex = 0;
  }  // namespace
  33. Status VarIsInitializedOpPass::Run(NodePtr &node) {
  34. GE_CHECK_NOTNULL(node);
  35. auto ret = UpdateInitedVars(node);
  36. if (ret != SUCCESS) {
  37. GELOGE(ret, "[Call][UpdateInitedVars] for node:%s failed", node->GetName().c_str());
  38. return ret;
  39. }
  40. if (node->GetType() != VARISINITIALIZEDOP) {
  41. return SUCCESS;
  42. }
  43. bool inited = false;
  44. if (CheckSrcNode(node, inited) != SUCCESS) {
  45. GELOGE(ret, "[Call][CheckSrcNode] for node:%s failed", node->GetName().c_str());
  46. return FAILED;
  47. }
  48. GELOGI("The variable inited status %s on node %s",
  49. inited ? "true" : "false", node->GetName().c_str());
  50. ret = ChangeNodeToConstant(node, inited);
  51. GELOGI("Change VarIsInitializedOp %s to be Constant %s end.",
  52. node->GetName().c_str(), inited ? "true" : "false");
  53. return ret;
  54. }
  55. Status VarIsInitializedOpPass::CheckSrcNode(const NodePtr &node, bool &inited) const {
  56. GE_CHECK_NOTNULL(node);
  57. auto input_nodes = node->GetInDataNodes();
  58. if (input_nodes.size() != kVarIsInitializedIOCnt) {
  59. REPORT_INNER_ERROR("E19999", "In data node num:%zu of node:%s(%s) not equal to %d, check invalid",
  60. input_nodes.size(), node->GetName().c_str(), node->GetType().c_str(), kVarIsInitializedIOCnt);
  61. GELOGE(FAILED, "[Check][Param] In data node num:%zu of node:%s(%s) not equal to %d.",
  62. input_nodes.size(), node->GetName().c_str(), node->GetType().c_str(), kVarIsInitializedIOCnt);
  63. return FAILED;
  64. }
  65. auto &input_node = input_nodes.at(kVarIsInitVarInputIndex);
  66. GE_CHECK_NOTNULL(input_node);
  67. auto input_node_name = input_node->GetName();
  68. auto input_node_type = input_node->GetType();
  69. if (input_node_type != VARIABLE) {
  70. REPORT_INNER_ERROR("E19999", "Index:%d In data node of node:%s(%s), type:%s not %s, check invalid",
  71. kVarIsInitVarInputIndex, node->GetName().c_str(), node->GetType().c_str(),
  72. input_node_type.c_str(), VARIABLE);
  73. GELOGE(FAILED, "[Check][Param] Index:%d In data node of node:%s(%s), type:%s not equal to %s.",
  74. kVarIsInitVarInputIndex, node->GetName().c_str(), node->GetType().c_str(),
  75. input_node_type.c_str(), VARIABLE);
  76. return FAILED;
  77. }
  78. // initialized and initialized check graph must not be in the same graph
  79. ComputeGraphPtr compute_graph = node->GetOwnerComputeGraph();
  80. auto session_id = compute_graph->GetSessionID();
  81. if (VarManager::Instance(session_id)->IsVarExist(input_node_name)) {
  82. inited = true;
  83. return SUCCESS;
  84. }
  85. GE_CHECK_NOTNULL(input_node->GetOpDesc());
  86. inited = IsVarInitedOnTheGraphAndNode(node, input_node->GetOpDesc()->GetId());
  87. return SUCCESS;
  88. }
  89. Status VarIsInitializedOpPass::CreateConstant(NodePtr &node, OpDescPtr &op_desc, bool inited) {
  90. GE_CHECK_NOTNULL(node);
  91. // 1. get OpDesc of VarIsInitializedOp
  92. OpDescPtr original_op_desc = node->GetOpDesc();
  93. if (original_op_desc == nullptr) {
  94. REPORT_INNER_ERROR("E19999", "OpDesc in node is nullptr, check invalid");
  95. GELOGE(FAILED, "[Get][OpDesc] failed, Op desc of node must not be null.");
  96. return FAILED;
  97. }
  98. GeTensorDesc original_desc = original_op_desc->GetOutputDesc(0);
  99. // 2. create Constant OpDesc
  100. op_desc = MakeShared<OpDesc>(node->GetName().c_str(), CONSTANT);
  101. if (op_desc == nullptr) {
  102. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  103. GELOGE(FAILED, "[New][OpDesc] failed.");
  104. return FAILED;
  105. }
  106. // 3. create attr value of Constant, is a tensor
  107. bool val = inited;
  108. GeTensorPtr const_tensor_ptr = MakeShared<GeTensor>(original_desc, reinterpret_cast<uint8_t *>(&val), sizeof(bool));
  109. if (const_tensor_ptr == nullptr) {
  110. REPORT_CALL_ERROR("E19999", "New GeTensor failed");
  111. GELOGE(FAILED, "[New][GeTensor] failed.");
  112. return FAILED;
  113. }
  114. if (!AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, const_tensor_ptr)) {
  115. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
  116. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  117. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
  118. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  119. return FAILED;
  120. }
  121. // 4. set Constant output desc
  122. GE_CHK_STATUS_RET(op_desc->AddOutputDesc(original_desc),
  123. "[Add][OutputDesc] to op:%s(%s) failed",
  124. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  125. return SUCCESS;
  126. }
  127. Status VarIsInitializedOpPass::ProcessInAnchor(NodePtr &node, NodePtr &new_node) {
  128. GE_CHECK_NOTNULL(node);
  129. GE_CHECK_NOTNULL(new_node);
  130. auto in_anchors = node->GetAllInDataAnchors();
  131. auto out_anchors = node->GetAllOutDataAnchors();
  132. if ((in_anchors.size() != kVarIsInitializedIOCnt) ||
  133. (out_anchors.size() != kVarIsInitializedIOCnt)) {
  134. REPORT_INNER_ERROR("E19999", "In data anchor num:%zu and out data anchor num:%zu of node:%s(%s), "
  135. "must be equal to %d, check invalid", in_anchors.size(), out_anchors.size(),
  136. node->GetName().c_str(), node->GetType().c_str(), kVarIsInitializedIOCnt);
  137. GELOGE(FAILED, "[Check][Param] In data anchor num:%zu and out data anchor num:%zu of node:%s(%s), "
  138. "must be equal to %d.", in_anchors.size(), out_anchors.size(),
  139. node->GetName().c_str(), node->GetType().c_str(), kVarIsInitializedIOCnt);
  140. return FAILED;
  141. }
  142. // 1. delete in data anchor of VarIsInitializedOp node
  143. auto &in_anchor = in_anchors.at(kVarIsInitVarInputIndex);
  144. GE_CHECK_NOTNULL(in_anchor);
  145. auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
  146. GE_CHECK_NOTNULL(peer_out_anchor);
  147. if (GraphUtils::RemoveEdge(in_anchor, peer_out_anchor) != GRAPH_SUCCESS) {
  148. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  149. in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetOwnerNode()->GetType().c_str(),
  150. in_anchor->GetIdx(), peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  151. peer_out_anchor->GetOwnerNode()->GetType().c_str(), peer_out_anchor->GetIdx());
  152. GELOGE(FAILED, "[Remove][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  153. in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetOwnerNode()->GetType().c_str(),
  154. in_anchor->GetIdx(), peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  155. peer_out_anchor->GetOwnerNode()->GetType().c_str(), peer_out_anchor->GetIdx());
  156. return FAILED;
  157. }
  158. auto src_node = peer_out_anchor->GetOwnerNode();
  159. if (GraphUtils::AddEdge(src_node->GetOutControlAnchor(), new_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
  160. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  161. src_node->GetName().c_str(), src_node->GetType().c_str(),
  162. new_node->GetName().c_str(), new_node->GetType().c_str());
  163. GELOGE(FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  164. src_node->GetName().c_str(), src_node->GetType().c_str(),
  165. new_node->GetName().c_str(), new_node->GetType().c_str());
  166. return FAILED;
  167. }
  168. if (GraphUtils::MoveInCtrlEdges(node, new_node) != GRAPH_SUCCESS) {
  169. REPORT_CALL_ERROR("E19999", "Move in control edge from node:%s(%s) to node:%s(%s) failed",
  170. node->GetName().c_str(), node->GetType().c_str(),
  171. new_node->GetName().c_str(), new_node->GetType().c_str());
  172. GELOGE(FAILED, "[Move][InCtrlEdges] from node:%s(%s) to node:%s(%s) failed",
  173. node->GetName().c_str(), node->GetType().c_str(),
  174. new_node->GetName().c_str(), new_node->GetType().c_str());
  175. return FAILED;
  176. }
  177. if (GraphUtils::MoveOutCtrlEdges(node, new_node) != GRAPH_SUCCESS) {
  178. REPORT_CALL_ERROR("E19999", "Move out control edge from node:%s(%s) to node:%s(%s) failed",
  179. node->GetName().c_str(), node->GetType().c_str(),
  180. new_node->GetName().c_str(), new_node->GetType().c_str());
  181. GELOGE(FAILED, "[Move][OutCtrlEdges] from node:%s(%s) to node:%s(%s) failed",
  182. node->GetName().c_str(), node->GetType().c_str(),
  183. new_node->GetName().c_str(), new_node->GetType().c_str());
  184. return FAILED;
  185. }
  186. return SUCCESS;
  187. }
  188. Status VarIsInitializedOpPass::ChangeNodeToConstant(NodePtr &node, bool inited) {
  189. GE_CHECK_NOTNULL(node);
  190. ComputeGraphPtr graph = node->GetOwnerComputeGraph();
  191. OpDescPtr constant_op_desc = nullptr;
  192. if (CreateConstant(node, constant_op_desc, inited) != SUCCESS) {
  193. GELOGE(FAILED, "[Create][Constant] failed, node:%s", node->GetName().c_str());
  194. return FAILED;
  195. }
  196. NodePtr const_node = graph->AddNodeFront(constant_op_desc);
  197. if (const_node == nullptr) {
  198. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s front failed",
  199. constant_op_desc->GetName().c_str(), constant_op_desc->GetType().c_str(),
  200. graph->GetName().c_str());
  201. GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s front failed",
  202. constant_op_desc->GetName().c_str(), constant_op_desc->GetType().c_str(), graph->GetName().c_str());
  203. return FAILED;
  204. }
  205. if (ProcessInAnchor(node, const_node) != SUCCESS) {
  206. GELOGE(FAILED, "[Process][InAnchor] failed, node:%s", node->GetName().c_str());
  207. return FAILED;
  208. }
  209. if (NodeUtils::MoveOutputEdges(node, const_node) != GRAPH_SUCCESS) {
  210. REPORT_CALL_ERROR("E19999", "Move out edge from node:%s(%s) to node:%s(%s) failed",
  211. node->GetName().c_str(), node->GetType().c_str(),
  212. const_node->GetName().c_str(), const_node->GetType().c_str());
  213. GELOGE(FAILED, "[Move][OutputEdges] from node:%s(%s) to node:%s(%s) failed",
  214. node->GetName().c_str(), node->GetType().c_str(),
  215. const_node->GetName().c_str(), const_node->GetType().c_str());
  216. return FAILED;
  217. }
  218. if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != SUCCESS) {
  219. REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) without relink in graph:%s failed",
  220. node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
  221. GELOGE(FAILED, "[Remove][Node] %s(%s) without relink in graph:%s failed",
  222. node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
  223. return FAILED;
  224. }
  225. AddRePassNodesWithInOut(const_node);
  226. // delete VarIsInitializedOp node from the graph
  227. AddNodeDeleted(node);
  228. return SUCCESS;
  229. }
  230. Status VarIsInitializedOpPass::UpdateInitedVars(const NodePtr &node) {
  231. GE_CHECK_NOTNULL(node);
  232. std::set<int64_t> *inited_vars = nullptr;
  233. bool inited_vars_merged = false;
  234. bool init_var = false;
  235. int64_t inited_var_id;
  236. auto ret = CheckAndSetVarInited(node, init_var, inited_var_id);
  237. if (ret != SUCCESS) {
  238. return ret;
  239. }
  240. if (init_var) {
  241. inited_vars = CreateInitedVars();
  242. if (inited_vars == nullptr) {
  243. return OUT_OF_MEMORY;
  244. }
  245. inited_vars_merged = true;
  246. inited_vars->insert(inited_var_id);
  247. }
  248. for (auto &in_node : node->GetInNodes()) {
  249. GE_CHECK_NOTNULL(in_node->GetOpDesc());
  250. auto iter = nodes_to_inited_vars_.find(in_node->GetOpDesc()->GetId());
  251. if (iter == nodes_to_inited_vars_.end()) {
  252. continue;
  253. }
  254. if (inited_vars == nullptr) {
  255. inited_vars = iter->second;
  256. continue;
  257. }
  258. if (inited_vars == iter->second) {
  259. continue;
  260. }
  261. // if there are multiple different inited_vars set, we should merge them to a new one
  262. if (inited_vars_merged) {
  263. inited_vars->insert(iter->second->begin(), iter->second->end());
  264. } else {
  265. auto origin_inited_vars = inited_vars;
  266. inited_vars = CreateInitedVars();
  267. if (inited_vars == nullptr) {
  268. return OUT_OF_MEMORY;
  269. }
  270. inited_vars_merged = true;
  271. inited_vars->insert(origin_inited_vars->begin(), origin_inited_vars->end());
  272. inited_vars->insert(iter->second->begin(), iter->second->end());
  273. }
  274. }
  275. if (inited_vars != nullptr) {
  276. GE_CHECK_NOTNULL(node->GetOpDesc());
  277. nodes_to_inited_vars_[node->GetOpDesc()->GetId()] = inited_vars;
  278. GELOGD("Inited vars on this graph when node %s, inited vars count %zu",
  279. node->GetName().c_str(), inited_vars->size());
  280. }
  281. return SUCCESS;
  282. }
  283. std::set<int64_t> *VarIsInitializedOpPass::CreateInitedVars() {
  284. std::unique_ptr<std::set<int64_t>> inited_vars_keeper(new(std::nothrow) std::set<int64_t>());
  285. if (inited_vars_keeper == nullptr) {
  286. REPORT_CALL_ERROR("E19999", "New set failed");
  287. GELOGE(OUT_OF_MEMORY, "[New][Set] failed");
  288. return nullptr;
  289. }
  290. auto inited_vars = inited_vars_keeper.get();
  291. var_inited_keeper_.emplace_back(std::move(inited_vars_keeper));
  292. return inited_vars;
  293. }
  294. bool VarIsInitializedOpPass::IsVarInitedOnTheGraphAndNode(const NodePtr &node, int64_t var_id) const {
  295. if (node == nullptr || node->GetOpDesc() == nullptr) {
  296. return false;
  297. }
  298. auto iter = nodes_to_inited_vars_.find(node->GetOpDesc()->GetId());
  299. if (iter == nodes_to_inited_vars_.end()) {
  300. return false;
  301. }
  302. return iter->second->count(var_id) > 0;
  303. }
  304. Status VarIsInitializedOpPass::CheckAndSetVarInited(const NodePtr &node, bool &inited, int64_t &inited_var) {
  305. GE_CHECK_NOTNULL(node);
  306. inited = false;
  307. if (node->GetType() != ASSIGN) {
  308. return SUCCESS;
  309. }
  310. auto ref_in_anchor = node->GetInDataAnchor(kAssignVarRefIndex);
  311. if (ref_in_anchor == nullptr) {
  312. GELOGW("Invalid assign node on graph, no ref input. name %s", node->GetName().c_str());
  313. return PARAM_INVALID;
  314. }
  315. auto var_out_anchor = ref_in_anchor->GetPeerOutAnchor();
  316. if (var_out_anchor == nullptr) {
  317. GELOGW("Invalid assign node on graph, no variable peer. name %s", node->GetName().c_str());
  318. return PARAM_INVALID;
  319. }
  320. auto var = var_out_anchor->GetOwnerNode();
  321. if (var == nullptr) {
  322. GELOGW("Invalid assign node on graph, no variable peer. name %s", node->GetName().c_str());
  323. return PARAM_INVALID;
  324. }
  325. inited = true;
  326. GE_CHECK_NOTNULL(var->GetOpDesc());
  327. inited_var = var->GetOpDesc()->GetId();
  328. return SUCCESS;
  329. }
  330. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。