You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

label_maker.cc 15 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/label/label_maker.h"
  17. #include "common/util.h"
  18. #include "common/ge_inner_error_codes.h"
  19. #include "framework/common/types.h"
  20. #include "framework/common/op/ge_op_utils.h"
  21. #include "graph/debug/ge_attr_define.h"
  22. #include "graph/utils/graph_utils.h"
  namespace {
  // Sentinel stream id: used as the starting value in the SetStreamId* helpers and as the
  // value reported when no valid stream id could be found.
  constexpr int64_t kInvalidStreamId = -1;
  }  // namespace
  26. namespace ge {
  27. /**
  28. * @ingroup ge
  29. * @brief Set stream id for head node.
  30. * @param [in] graph: graph for add node.
  31. * @param [in] op_desc: OpDesc for set logical stream id.
  32. * @return: void
  33. */
  34. void LabelMaker::SetStreamIdEnter(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) {
  35. int64_t stream_id = kInvalidStreamId;
  36. const auto &node_list = graph->GetDirectNode();
  37. for (size_t i = 0; i < node_list.size(); ++i) {
  38. const auto &node = node_list.at(i);
  39. GE_CHECK_NOTNULL_EXEC(node, continue);
  40. stream_id = node->GetOpDesc()->GetStreamId();
  41. if (stream_id != kInvalidStreamId) {
  42. break;
  43. }
  44. }
  45. GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id);
  46. op_desc->SetStreamId(stream_id);
  47. }
  48. /**
  49. * @ingroup ge
  50. * @brief Set stream id for tail node.
  51. * @param [in] graph: graph for add node.
  52. * @param [in] op_desc: OpDesc for set logical stream id.
  53. * @return: void
  54. */
  55. void LabelMaker::SetStreamIdLeave(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) {
  56. int64_t stream_id = kInvalidStreamId;
  57. const auto &node_list = graph->GetDirectNode();
  58. for (size_t i = node_list.size(); i > 0; --i) {
  59. const auto &node = node_list.at(i - 1); // i from list size, need shift 1.
  60. GE_CHECK_NOTNULL_EXEC(node, continue);
  61. stream_id = node->GetOpDesc()->GetStreamId();
  62. if (stream_id != kInvalidStreamId) {
  63. break;
  64. }
  65. }
  66. GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id);
  67. op_desc->SetStreamId(stream_id);
  68. }
  69. /**
  70. * @ingroup ge
  71. * @brief Set stream id for parent node.
  72. * @param [in] graph: graph for add node.
  73. * @param [in] op_desc: OpDesc for set logical stream id.
  74. * @return: void
  75. */
  76. void LabelMaker::SetStreamIdOwner(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) {
  77. int64_t stream_id = kInvalidStreamId;
  78. const auto &node = graph->GetParentNode();
  79. if (node != nullptr) {
  80. stream_id = node->GetOpDesc()->GetStreamId();
  81. }
  82. GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id);
  83. op_desc->SetStreamId(stream_id);
  84. }
  85. /**
  86. * @ingroup ge
  87. * @brief Add StreamActive node at graph front.
  88. * @param [in] graph: graph for add node.
  89. * @param [in] name: stream active node name.
  90. * @return: NodePtr for success / nullptr for fail
  91. */
  92. NodePtr LabelMaker::AddStreamActive(const ComputeGraphPtr &graph, const std::string &name) {
  93. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  94. const auto &node_list = graph->GetDirectNode();
  95. if (node_list.empty()) {
  96. GELOGE(INTERNAL_ERROR, "LabelSet: Graph %s node is empty.", graph->GetName().c_str());
  97. return nullptr;
  98. }
  99. OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMACTIVE);
  100. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  101. SetStreamIdOwner(graph, op_desc);
  102. GELOGI("StreamActive: Create node %s.", op_desc->GetName().c_str());
  103. vector<uint32_t> active_streams;
  104. (void)AttrUtils::SetStr(op_desc, ATTR_NAME_SWITCH_BRANCH_NODE_LABEL, op_desc->GetName());
  105. (void)AttrUtils::SetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams);
  106. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_SUBGRAPH_FIRST_ACTIVE, true);
  107. NodePtr stream_active = graph->AddNodeFront(op_desc);
  108. GE_CHECK_NOTNULL_EXEC(stream_active, return nullptr);
  109. return stream_active;
  110. }
  111. /**
  112. * @ingroup ge
  113. * @brief Add LabelSet node at graph front.
  114. * @param [in] graph: graph for add node.
  115. * @param [in] name: label set node name.
  116. * @param [in] index: label id for set.
  117. * @return: NodePtr for success / nullptr for fail
  118. */
  119. NodePtr LabelMaker::AddLabelSetEnter(const ComputeGraphPtr &graph, const std::string &name, uint32_t index,
  120. NodePtr &stream_active) {
  121. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  122. GE_CHECK_NOTNULL_EXEC(stream_active, return nullptr);
  123. const auto &node_list = graph->GetDirectNode();
  124. if (node_list.empty()) {
  125. GELOGE(INTERNAL_ERROR, "LabelSet: Graph %s node is empty.", graph->GetName().c_str());
  126. return nullptr;
  127. }
  128. OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSET);
  129. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  130. SetStreamIdOwner(graph, op_desc);
  131. GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str());
  132. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index);
  133. NodePtr label_set = graph->AddNodeFront(op_desc);
  134. GE_CHECK_NOTNULL_EXEC(label_set, return nullptr);
  135. if (GraphUtils::AddEdge(label_set->GetOutControlAnchor(), stream_active->GetInControlAnchor()) != SUCCESS) {
  136. GELOGE(INTERNAL_ERROR, "Add ctrl edge from %s to %s failed.", label_set->GetName().c_str(),
  137. stream_active->GetName().c_str());
  138. return nullptr;
  139. }
  140. return label_set;
  141. }
  142. /**
  143. * @ingroup ge
  144. * @brief Add LabelSet node at graph back.
  145. * @param [in] graph: graph for add node.
  146. * @param [in] name: label set node name.
  147. * @param [in] index: label id for set.
  148. * @return: NodePtr for success / nullptr for fail
  149. */
  150. NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) {
  151. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  152. const auto &node_list = graph->GetDirectNode();
  153. auto it = node_list.end();
  154. if (it == node_list.begin()) {
  155. GELOGE(INTERNAL_ERROR, "LabelSet: Graph %s node is empty.", graph->GetName().c_str());
  156. return nullptr;
  157. }
  158. --it;
  159. const NodePtr &node = *it;
  160. GE_CHECK_NOTNULL_EXEC(node, return nullptr);
  161. OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSET);
  162. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  163. SetStreamIdOwner(graph, op_desc);
  164. GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str());
  165. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index);
  166. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_SUBGRAPH_END_NODE, true);
  167. NodePtr label_set = graph->AddNode(op_desc);
  168. GE_CHECK_NOTNULL_EXEC(label_set, return nullptr);
  169. // Link control edge to graph tail.
  170. if (GraphUtils::AddEdge(node->GetOutControlAnchor(), label_set->GetInControlAnchor()) != SUCCESS) {
  171. GELOGE(INTERNAL_ERROR, "LabelSet: Add ctrl edge to %s failed.", node->GetName().c_str());
  172. return nullptr;
  173. }
  174. return label_set;
  175. }
  176. /**
  177. * @ingroup ge
  178. * @brief Add LabelGoto node at graph front.
  179. * @param [in] graph: graph for add node.
  180. * @param [in] name: label goto node name.
  181. * @param [in] index: label id for goto.
  182. * @return: NodePtr for success / nullptr for fail
  183. */
  184. NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) {
  185. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  186. const auto &node_list = graph->GetDirectNode();
  187. auto it = node_list.begin();
  188. if (it == node_list.end()) {
  189. GELOGE(INTERNAL_ERROR, "LabelGoto: Graph %s node is empty.", graph->GetName().c_str());
  190. return nullptr;
  191. }
  192. OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX);
  193. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  194. SetStreamIdOwner(graph, op_desc);
  195. GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str());
  196. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index);
  197. NodePtr label_goto = graph->AddNodeFront(op_desc);
  198. if (label_goto == nullptr) {
  199. GELOGE(INTERNAL_ERROR, "LabelGoto: Add to graph %s failed.", graph->GetName().c_str());
  200. return nullptr;
  201. }
  202. return label_goto;
  203. }
  204. /**
  205. * @ingroup ge
  206. * @brief Add LabelGoto node at graph back.
  207. * @param [in] graph: graph for add node.
  208. * @param [in] name: label goto node name.
  209. * @param [in] index: label id for goto.
  210. * @return: NodePtr for success / nullptr for fail
  211. */
  212. NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) {
  213. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  214. const auto &node_list = graph->GetDirectNode();
  215. auto it = node_list.end();
  216. if (it == node_list.begin()) {
  217. GELOGE(INTERNAL_ERROR, "LabelGoto: Graph %s node is empty.", graph->GetName().c_str());
  218. return nullptr;
  219. }
  220. --it;
  221. const NodePtr &node = *it;
  222. GE_CHECK_NOTNULL_EXEC(node, return nullptr);
  223. OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX);
  224. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  225. SetStreamIdLeave(graph, op_desc);
  226. GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str());
  227. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index);
  228. NodePtr label_goto = graph->AddNode(op_desc);
  229. GE_CHECK_NOTNULL_EXEC(label_goto, return nullptr);
  230. SetStreamIdOwner(graph, op_desc);
  231. // Link control edge to graph tail.
  232. if (GraphUtils::AddEdge(node->GetOutControlAnchor(), label_goto->GetInControlAnchor()) != SUCCESS) {
  233. GELOGE(INTERNAL_ERROR, "LabelGoto: Add ctrl edge to %s failed.", node->GetName().c_str());
  234. return nullptr;
  235. }
  236. return label_goto;
  237. }
  238. /**
  239. * @ingroup ge
  240. * @brief Add LabelSwitch node at graph front.
  241. * @param [in] graph: graph for add node.
  242. * @param [in] name: label switch node name.
  243. * @param [in] desc: label index data desc.
  244. * @param [in] labels: label id for switch.
  245. * @return: NodePtr for success / nullptr for fail
  246. */
  247. NodePtr LabelMaker::AddLabelSwitchEnter(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &desc,
  248. const std::vector<uint32_t> &labels) {
  249. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  250. const auto &node_list = graph->GetDirectNode();
  251. auto it = node_list.begin();
  252. if (it == node_list.end()) {
  253. GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Graph %s node is empty.", graph->GetName().c_str());
  254. return nullptr;
  255. }
  256. OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSWITCHBYINDEX);
  257. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  258. SetStreamIdOwner(graph, op_desc);
  259. GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str());
  260. if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) {
  261. GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add input desc failed.");
  262. return nullptr;
  263. }
  264. if (!AttrUtils::SetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, labels)) {
  265. GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add %s failed.", ATTR_NAME_LABEL_SWITCH_INDEX.c_str());
  266. return nullptr;
  267. }
  268. NodePtr label_switch = graph->AddNodeFront(op_desc);
  269. if (label_switch == nullptr) {
  270. GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add to graph %s failed.", graph->GetName().c_str());
  271. return nullptr;
  272. }
  273. return label_switch;
  274. }
  275. /**
  276. * @ingroup ge
  277. * @brief Add LabelSwitch node at graph back.
  278. * @param [in] graph: graph for add node.
  279. * @param [in] name: label switch node name.
  280. * @param [in] desc: label index data desc.
  281. * @param [in] labels: label id for switch.
  282. * @return: NodePtr for success / nullptr for fail
  283. */
  284. NodePtr LabelMaker::AddLabelSwitchLeave(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &desc,
  285. const std::vector<uint32_t> &labels) {
  286. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  287. const auto &node_list = graph->GetDirectNode();
  288. auto it = node_list.end();
  289. if (it == node_list.begin()) {
  290. GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Graph %s node is empty.", graph->GetName().c_str());
  291. return nullptr;
  292. }
  293. --it;
  294. const NodePtr &node = *it;
  295. GE_CHECK_NOTNULL_EXEC(node, return nullptr);
  296. OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSWITCHBYINDEX);
  297. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  298. SetStreamIdOwner(graph, op_desc);
  299. GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str());
  300. if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) {
  301. GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add input desc failed.");
  302. return nullptr;
  303. }
  304. if (!AttrUtils::SetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, labels)) {
  305. GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add %s failed.", ATTR_NAME_LABEL_SWITCH_INDEX.c_str());
  306. return nullptr;
  307. }
  308. NodePtr label_switch = graph->AddNode(op_desc);
  309. GE_CHECK_NOTNULL_EXEC(label_switch, return nullptr);
  310. // Link control edge to graph tail.
  311. if (GraphUtils::AddEdge(node->GetOutControlAnchor(), label_switch->GetInControlAnchor()) != SUCCESS) {
  312. GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add ctrl edge to %s failed.", node->GetName().c_str());
  313. return nullptr;
  314. }
  315. return label_switch;
  316. }
  317. /**
  318. * @ingroup ge
  319. * @brief Add Data node at graph front for switch input.
  320. * @param [in] graph: graph for add node.
  321. * @param [in] name: label switch node name.
  322. * @param [in] desc: label index data desc.
  323. * @param [in] sw_node: switch node for add input.
  324. * @param [in] parent_index: index for parent node.
  325. * @return: NodePtr for success / nullptr for fail
  326. */
  327. NodePtr LabelMaker::AddLabelSwitchIndex(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &desc,
  328. const NodePtr &sw_node, uint32_t parent_index) {
  329. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  330. OpDescPtr op_desc = MakeShared<OpDesc>(name, DATA);
  331. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  332. op_desc->SetStreamId(kInvalidStreamId);
  333. GELOGI("Data: Create node %s.", op_desc->GetName().c_str());
  334. if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) {
  335. GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add data input desc failed.");
  336. return nullptr;
  337. }
  338. if (op_desc->AddOutputDesc(desc) != GRAPH_SUCCESS) {
  339. GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add data output desc failed.");
  340. return nullptr;
  341. }
  342. if (!AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
  343. GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add %s failed.", ATTR_NAME_PARENT_NODE_INDEX.c_str());
  344. return nullptr;
  345. }
  346. NodePtr op_data = graph->AddNodeFront(op_desc);
  347. GE_CHECK_NOTNULL_EXEC(op_data, return nullptr);
  348. GE_CHECK_NOTNULL_EXEC(graph->AddInputNode(op_data), return nullptr); // take as input node for memory assign.
  349. // Link control edge to graph head.
  350. if (GraphUtils::AddEdge(op_data->GetOutDataAnchor(0), sw_node->GetInDataAnchor(0)) != SUCCESS) {
  351. GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add input edge to %s failed.", op_data->GetName().c_str());
  352. return nullptr;
  353. }
  354. return op_data;
  355. }
  356. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：