You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

memcpy_addr_async_pass.cc 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/memcpy_addr_async_pass.h"
  17. #include "common/ge/ge_util.h"
  18. #include "framework/common/debug/log.h"
  19. #include "graph/utils/node_utils.h"
  20. #include "graph/utils/op_desc_utils.h"
  21. #include "graph/utils/tensor_utils.h"
  22. namespace ge {
  23. Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) {
  24. GE_CHECK_NOTNULL(graph);
  25. for (const auto &node : graph->GetAllNodes()) {
  26. if (node->GetType() == STREAMSWITCH) {
  27. auto sub_graph = node->GetOwnerComputeGraph();
  28. if (sub_graph != nullptr && !sub_graph->GetGraphUnknownFlag()) {
  29. GE_CHK_STATUS_RET(AddMemcpyAsyncNode(node), "Add memcpyasync node failed in known subgraph.");
  30. }
  31. }
  32. }
  33. if (graph->GetGraphUnknownFlag()) {
  34. GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str());
  35. return SUCCESS;
  36. }
  37. int64_t value = 0;
  38. rtError_t rt_ret = rtGetRtCapability(FEATURE_TYPE_MEMCPY, MEMCPY_INFO_SUPPORT_ZEROCOPY, &value);
  39. if (rt_ret != RT_ERROR_NONE) {
  40. GELOGE(RT_FAILED, "rtGetRtCapability failed, error=0x%x.", rt_ret);
  41. return RT_FAILED;
  42. }
  43. for (auto &node : graph->GetAllNodes()) {
  44. auto op_desc = node->GetOpDesc();
  45. GE_IF_BOOL_EXEC(op_desc == nullptr, continue);
  46. if (op_desc->GetType() == STREAMSWITCHN || op_desc->GetType() == STREAMMERGE) {
  47. Status ret = AddMemcpyAddrAsyncNode(graph, node);
  48. if (ret != SUCCESS) {
  49. GELOGE(ret, "AddMemcpyAddrAsyncNode failed.");
  50. return ret;
  51. }
  52. }
  53. // handle data->netoutput, const->netoutput in root graph, use mem_addr_async to improve performance
  54. if (op_desc->GetType() == NETOUTPUT) {
  55. // check this netoutput is on root graph
  56. if (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) {
  57. Status ret = InsertMemAddrAsyncNodeBeforeNetoutput(node->GetOwnerComputeGraph(), node);
  58. if (ret != SUCCESS) {
  59. GELOGE(ret, "AddMemcpyAddrAsyncNode failed.");
  60. return ret;
  61. }
  62. }
  63. }
  64. }
  65. return SUCCESS;
  66. }
  67. Status MemcpyAddrAsyncPass::AddMemcpyAsyncNode(const NodePtr &node) {
  68. GE_CHECK_NOTNULL(node);
  69. GELOGI("Start add memcpyasync node in front of node %s", node->GetName().c_str());
  70. known_sub_graph_ = true;
  71. auto sub_graph = node->GetOwnerComputeGraph();
  72. for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
  73. OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  74. GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
  75. auto memcpy_async_node = CreateMemcpyAddrAsyncNode(sub_graph, peer_out_anchor, node);
  76. if (memcpy_async_node == nullptr) {
  77. GELOGE(INTERNAL_ERROR, "Create memcpyasync node failed.");
  78. return INTERNAL_ERROR;
  79. }
  80. Status ret = InsertMemcpyAddrAsyncNode(peer_out_anchor, in_data_anchor, memcpy_async_node);
  81. if (ret != SUCCESS) {
  82. GELOGE(ret, "Insert memcpyasync node failed.");
  83. return ret;
  84. }
  85. }
  86. return SUCCESS;
  87. }
  88. Status MemcpyAddrAsyncPass::AddMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const NodePtr &node) {
  89. GELOGI("Start AddMemcpyAddrAsyncNode for %s.", node->GetName().c_str());
  90. for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
  91. OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  92. GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
  93. NodePtr in_node = peer_out_anchor->GetOwnerNode();
  94. if (in_node->GetType() == DATA) {
  95. ComputeGraphPtr owner_graph = in_node->GetOwnerComputeGraph();
  96. GE_CHECK_NOTNULL(owner_graph);
  97. // Data is in parent_graph
  98. if (owner_graph->GetParentGraph() == nullptr) {
  99. GELOGI("Need to insert MemcpyAddrAsync directly when data in parent graph.");
  100. NodePtr memcpy_addr_async_node = CreateMemcpyAddrAsyncNode(graph, peer_out_anchor, node);
  101. GE_IF_BOOL_EXEC(memcpy_addr_async_node == nullptr, GELOGE(INTERNAL_ERROR, "CreateMemcpyAddrAsyncNode failed.");
  102. return INTERNAL_ERROR);
  103. Status ret = InsertMemcpyAddrAsyncNode(peer_out_anchor, in_data_anchor, memcpy_addr_async_node);
  104. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "InsertMemcpyAddrAsyncNode failed."); return ret);
  105. } else {
  106. uint32_t parent_index = 0;
  107. if (!AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
  108. GELOGE(INTERNAL_ERROR, "Failed to get parent index of %s", in_node->GetName().c_str());
  109. return INTERNAL_ERROR;
  110. }
  111. // Data is in sub_graph
  112. GELOGI("Need to find data in parent graph, then insert MemcpyAddrAsync.");
  113. NodePtr parent_node = owner_graph->GetParentNode();
  114. user_data_for_known_ = in_node;
  115. out_of_user_data_for_known_ = node;
  116. peer_out_anchor_for_known_ = peer_out_anchor;
  117. in_anchor_for_known_ = in_data_anchor;
  118. FindUserData(parent_node, parent_index);
  119. if (find_user_data_) {
  120. GELOGI("Insert memcpy_addr_async for non_dynamic.");
  121. GE_CHECK_NOTNULL(peer_out_anchor_);
  122. NodePtr memcpy_addr_async_node = CreateMemcpyAddrAsyncNode(graph, peer_out_anchor_, out_of_user_data_);
  123. GE_IF_BOOL_EXEC(memcpy_addr_async_node == nullptr,
  124. GELOGE(INTERNAL_ERROR, "CreateMemcpyAddrAsyncNode failed.");
  125. return INTERNAL_ERROR);
  126. Status ret = InsertMemcpyAddrAsyncNode(peer_out_anchor_, in_anchor_, memcpy_addr_async_node);
  127. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "InsertMemcpyAddrAsyncNode failed."); return ret);
  128. }
  129. if (find_user_data_for_known_) {
  130. GELOGI("Insert memcpy_addr_async for known graph.");
  131. auto sub_graph = user_data_for_known_->GetOwnerComputeGraph();
  132. NodePtr memcpy_addr_async_node =
  133. CreateMemcpyAddrAsyncNode(sub_graph, peer_out_anchor_for_known_, out_of_user_data_for_known_);
  134. GE_IF_BOOL_EXEC(memcpy_addr_async_node == nullptr,
  135. GELOGE(INTERNAL_ERROR, "CreateMemcpyAddrAsyncNode for known failed.");
  136. return INTERNAL_ERROR);
  137. Status ret =
  138. InsertMemcpyAddrAsyncNode(peer_out_anchor_for_known_, in_anchor_for_known_, memcpy_addr_async_node);
  139. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "InsertMemcpyAddrAsyncNode for known failed."); return ret);
  140. }
  141. }
  142. }
  143. }
  144. return SUCCESS;
  145. }
  146. void MemcpyAddrAsyncPass::FindUserDataForKnown(const NodePtr &parent_node, uint32_t &parent_index) {
  147. GELOGI("Start FindUserDataForKnown of %s.", parent_node->GetName().c_str());
  148. if (user_data_for_known_->GetOpDesc() == nullptr) {
  149. GELOGI("Cannot get op_desc of %s.", user_data_for_known_->GetName().c_str());
  150. return;
  151. }
  152. string src_var_name;
  153. if (ge::AttrUtils::GetStr(user_data_for_known_->GetOpDesc(), REF_VAR_SRC_VAR_NAME, src_var_name)) {
  154. GELOGI("The data in known graph is variable, no need to insert memcpy_addr_async.");
  155. find_user_data_for_known_ = false;
  156. return;
  157. } else {
  158. find_user_data_for_known_ = true;
  159. }
  160. }
  161. void MemcpyAddrAsyncPass::FindUserDataForNonDynamic(const ge::NodePtr &parent_node, uint32_t &parent_index) {
  162. GELOGI("Start to FindUserDataForNonDynamic of %s.", parent_node->GetName().c_str());
  163. InDataAnchorPtr in_data_anchor = parent_node->GetInDataAnchor(parent_index);
  164. OutDataAnchorPtr out_anchor = in_data_anchor->GetPeerOutAnchor();
  165. GE_IF_BOOL_EXEC(out_anchor == nullptr,
  166. GELOGE(INTERNAL_ERROR, "Cannot find out_anchor of %s.", parent_node->GetName().c_str());
  167. return);
  168. NodePtr in_node = out_anchor->GetOwnerNode();
  169. GELOGI("in_node of parent_node is %s.", in_node->GetName().c_str());
  170. if (in_node->GetType() == DATA) {
  171. if (in_node->GetOwnerComputeGraph()->GetParentGraph() != nullptr) {
  172. // DATA is in sub graph again, update user_data of known firstly
  173. user_data_for_known_ = in_node;
  174. out_of_user_data_for_known_ = parent_node;
  175. peer_out_anchor_for_known_ = out_anchor;
  176. in_anchor_for_known_ = in_data_anchor;
  177. NodePtr pre_in_node = in_node->GetOwnerComputeGraph()->GetParentNode();
  178. if (!AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
  179. GELOGE(INTERNAL_ERROR, "Failed to refresh parent index of %s", in_node->GetName().c_str());
  180. return;
  181. }
  182. FindUserData(pre_in_node, parent_index);
  183. } else {
  184. // DATA is in parent graph and not has input
  185. user_data_ = in_node;
  186. out_of_user_data_ = parent_node;
  187. peer_out_anchor_ = out_anchor;
  188. in_anchor_ = in_data_anchor;
  189. find_user_data_ = true;
  190. GELOGI("%s connect with %s, will insert memcpyaddr.", user_data_->GetName().c_str(),
  191. out_of_user_data_->GetName().c_str());
  192. }
  193. } else if (in_node->GetType() == IF || in_node->GetType() == WHILE || in_node->GetType() == CASE) {
  194. if (!AttrUtils::GetInt(parent_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
  195. GELOGE(INTERNAL_ERROR, "Failed to refresh parent index of %s", in_node->GetName().c_str());
  196. return;
  197. }
  198. FindUserData(in_node, parent_index);
  199. } else {
  200. GELOGI("%s connect with %s, which is not user_data.", parent_node->GetName().c_str(), in_node->GetName().c_str());
  201. find_user_data_ = false;
  202. }
  203. }
  204. void MemcpyAddrAsyncPass::FindUserData(const NodePtr &parent_node, uint32_t &parent_index) {
  205. auto parent_op_desc = parent_node->GetOpDesc();
  206. if (parent_op_desc == nullptr) {
  207. GELOGI("Cannot get op_desc of %s.", parent_node->GetName().c_str());
  208. return;
  209. }
  210. bool is_unknown_shape = false;
  211. if (parent_node->GetType() == PARTITIONEDCALL &&
  212. AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape) && !is_unknown_shape) {
  213. FindUserDataForKnown(parent_node, parent_index);
  214. } else {
  215. FindUserDataForNonDynamic(parent_node, parent_index);
  216. }
  217. }
  218. NodePtr MemcpyAddrAsyncPass::CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &graph,
  219. const OutDataAnchorPtr &out_data_anchor,
  220. const NodePtr &out_of_user_data) {
  221. GELOGD("Start CreateMemcpyAddrAsyncNode.");
  222. static uint32_t new_node_index = 0;
  223. OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc();
  224. GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "Op_desc of pre node is invalid.");
  225. OpDescPtr op_desc = nullptr;
  226. if (known_sub_graph_) { // insert memcpyasync node when known sub graph
  227. string node_name = pre_op_desc->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(new_node_index++);
  228. op_desc = MakeShared<OpDesc>(node_name, MEMCPYASYNC);
  229. } else {
  230. string node_name = pre_op_desc->GetName() + "_" + MEMCPYADDRASYNC + "_" + std::to_string(new_node_index++);
  231. op_desc = MakeShared<OpDesc>(node_name, MEMCPYADDRASYNC);
  232. }
  233. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  234. if (op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) != GRAPH_SUCCESS) {
  235. GELOGE(INTERNAL_ERROR, "Add memcpy_addr_async input desc failed.");
  236. return nullptr;
  237. }
  238. if (op_desc->AddOutputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) != GRAPH_SUCCESS) {
  239. GELOGE(INTERNAL_ERROR, "Add memcpy_addr_async output desc failed.");
  240. return nullptr;
  241. }
  242. string stream_label;
  243. if (AttrUtils::GetStr(out_of_user_data->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) {
  244. (void)AttrUtils::SetStr(op_desc, ATTR_NAME_STREAM_LABEL, stream_label);
  245. GELOGD("Node %s set stream label: %s", op_desc->GetName().c_str(), stream_label.c_str());
  246. }
  247. bool rts_label_node = false;
  248. if (AttrUtils::GetBool(out_of_user_data->GetOpDesc(), ATTR_NAME_RTS_LABEL_NODE, rts_label_node)) {
  249. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, rts_label_node);
  250. GELOGD("Node %s set rts label node attribute", op_desc->GetName().c_str());
  251. }
  252. bool labeled_input = false;
  253. (void)ge::AttrUtils::GetBool(out_of_user_data->GetOpDesc(), ATTR_NAME_NODE_CONNECT_INPUT, labeled_input);
  254. if (labeled_input) {
  255. if (!ge::AttrUtils::SetBool(out_of_user_data->GetOpDesc(), ATTR_NAME_NODE_CONNECT_INPUT, false)) {
  256. GELOGE(FAILED, "Failed to unset attr %s for node %s.", ATTR_NAME_NODE_CONNECT_INPUT.c_str(),
  257. out_of_user_data->GetName().c_str());
  258. return nullptr;
  259. }
  260. if (!ge::AttrUtils::SetBool(op_desc, ATTR_NAME_NODE_CONNECT_INPUT, true)) {
  261. GELOGE(FAILED, "Failed to set attr %s for node %s.", ATTR_NAME_NODE_CONNECT_INPUT.c_str(),
  262. op_desc->GetName().c_str());
  263. return nullptr;
  264. }
  265. }
  266. NodePtr memcpy_addr_async_node = graph->AddNode(op_desc);
  267. GE_CHECK_NOTNULL_EXEC(memcpy_addr_async_node, return nullptr);
  268. return memcpy_addr_async_node;
  269. }
  270. Status MemcpyAddrAsyncPass::InsertMemcpyAddrAsyncNode(const OutDataAnchorPtr &out_anchor,
  271. const InDataAnchorPtr &in_anchor, const NodePtr &node) {
  272. // insert memcpy_addr of each user_data and out_of_user_data
  273. if (GraphUtils::RemoveEdge(out_anchor, in_anchor) != GRAPH_SUCCESS) {
  274. GELOGE(INTERNAL_ERROR, "Remove edge of %s and %s failed.", out_anchor->GetOwnerNode()->GetName().c_str(),
  275. in_anchor->GetOwnerNode()->GetName().c_str());
  276. return INTERNAL_ERROR;
  277. }
  278. if (GraphUtils::AddEdge(out_anchor, node->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
  279. GELOGE(INTERNAL_ERROR, "Add edge of %s and %s failed.", out_anchor->GetOwnerNode()->GetName().c_str(),
  280. node->GetName().c_str());
  281. return INTERNAL_ERROR;
  282. }
  283. if (GraphUtils::AddEdge(node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) {
  284. GELOGE(INTERNAL_ERROR, "Add edge of %s and %s failed.", node->GetName().c_str(),
  285. in_anchor->GetOwnerNode()->GetName().c_str());
  286. return INTERNAL_ERROR;
  287. }
  288. return SUCCESS;
  289. }
  290. Status MemcpyAddrAsyncPass::InsertMemAddrAsyncNodeBeforeNetoutput(const ComputeGraphPtr &graph, const NodePtr &node) {
  291. GELOGD("Start AddMemcpyAddrAsyncNode for %s.", node->GetName().c_str());
  292. for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
  293. auto in_node = NodeUtils::GetInDataNodeByIndex(*node, in_data_anchor->GetIdx());
  294. GE_CHECK_NOTNULL(in_node);
  295. auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  296. if ((in_node->GetType() != CONSTANT) &&
  297. (in_node->GetType() != CONSTANTOP) &&
  298. (in_node->GetType() != DATA)) {
  299. continue;
  300. }
  301. auto desc = in_node->GetOpDesc();
  302. GE_CHECK_NOTNULL(desc);
  303. if (IsEmptyTenor(desc->GetOutputDesc(peer_out_anchor->GetIdx()).GetShape())) {
  304. continue;
  305. }
  306. GELOGI("Need to insert MemcpyAddrAsync before netoutput on parent graph.");
  307. NodePtr memcpy_addr_async_node = CreateMemcpyAddrAsyncNode(graph, peer_out_anchor, in_node);
  308. GE_IF_BOOL_EXEC(memcpy_addr_async_node == nullptr, GELOGE(INTERNAL_ERROR, "CreateMemcpyAddrAsyncNode failed.");
  309. return INTERNAL_ERROR);
  310. Status ret = InsertMemcpyAddrAsyncNode(peer_out_anchor, in_data_anchor, memcpy_addr_async_node);
  311. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "InsertMemcpyAddrAsyncNode failed."); return ret);
  312. GELOGI("Insert mem_addr_async node %s success between %s and %s.", memcpy_addr_async_node->GetName().c_str(),
  313. in_node->GetName().c_str(), node->GetName().c_str());
  314. // if src node is const, need to update attr and offset here because this pass process is after offset set.
  315. if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
  316. NodeUtils::UpdateIsInputConst(memcpy_addr_async_node);
  317. auto output_desc = node->GetOpDesc();
  318. GE_CHECK_NOTNULL(output_desc);
  319. auto output_tensor_desc = output_desc->MutableInputDesc(static_cast<uint32_t>(in_data_anchor->GetIdx()));
  320. int64_t data_offset = 0;
  321. (void)TensorUtils::GetDataOffset(*output_tensor_desc, data_offset);
  322. auto input_tensor = memcpy_addr_async_node->GetOpDesc()->MutableInputDesc(0);
  323. GELOGI("Need update const Offset %ld to op [%s]", data_offset, memcpy_addr_async_node->GetName().c_str());
  324. TensorUtils::SetDataOffset(*input_tensor, data_offset);
  325. TensorUtils::SetDataOffset(*output_tensor_desc, 0);
  326. }
  327. }
  328. NodeUtils::UpdateIsInputConst(node);
  329. return SUCCESS;
  330. }
  331. bool MemcpyAddrAsyncPass::IsEmptyTenor(const GeShape &shape) const {
  332. for (const auto dim : shape.GetDims()) {
  333. if (dim == 0) {
  334. return true;
  335. }
  336. }
  337. return false;
  338. }
  339. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。