You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

switch_pass_unittest.cc 16 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include <cstdint>
  18. #include <string>
  19. #define private public
  20. #include "graph/passes/switch_pass.h"
  21. #include "common/ge_inner_error_codes.h"
  22. #include "inc/pass_manager.h"
  23. #include "utils/graph_utils.h"
  24. #undef private
  25. using namespace domi;
  26. namespace ge {
  27. namespace {
  28. class UTEST_graph_passes_switch_pass : public testing::Test {
  29. protected:
  30. UTEST_graph_passes_switch_pass() {
  31. graph_ = std::make_shared<ComputeGraph>("test");
  32. vector<int64_t> shape_vec{1, 1, 1, 1};
  33. GeShape shape = GeShape(shape_vec);
  34. default_tensor_desc_ = std::make_shared<GeTensorDesc>();
  35. default_tensor_desc_->SetShape(shape);
  36. default_tensor_desc_->SetFormat(FORMAT_NCHW);
  37. default_tensor_desc_->SetDataType(DT_FLOAT);
  38. }
  39. NodePtr NewNode(const std::string &name, const std::string &type, int input_cnt, int output_cnt) {
  40. OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
  41. for (int i = 0; i < input_cnt; ++i) {
  42. op_desc->AddInputDesc(default_tensor_desc_->Clone());
  43. }
  44. for (int i = 0; i < output_cnt; ++i) {
  45. op_desc->AddOutputDesc(default_tensor_desc_->Clone());
  46. }
  47. NodePtr node = graph_->AddNode(op_desc);
  48. (void)node->SetOwnerComputeGraph(graph_);
  49. return node;
  50. }
  51. void BuildDefaultGraph(bool is_input_const, const bool *pred_value = nullptr) {
  52. /// input pred
  53. /// \ /
  54. /// Switch
  55. /// | |
  56. /// F T
  57. /// | |
  58. /// Merge
  59. ///
  60. bool is_pred_const = pred_value != nullptr;
  61. if (is_pred_const) {
  62. pred_node_ = NewNode("pred", CONSTANT, 0, 1);
  63. int32_t weight[] = {static_cast<int32_t>(*pred_value)};
  64. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
  65. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  66. OpDescUtils::SetWeights(pred_node_, {tensor});
  67. } else {
  68. pred_node_ = NewNode("pred", GREATER, 2, 1);
  69. }
  70. if (is_input_const) {
  71. int32_t weight[] = {1};
  72. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
  73. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  74. input_node_ = NewNode("input", CONSTANT, 0, 1);
  75. OpDescUtils::SetWeights(input_node_, {tensor});
  76. } else {
  77. input_node_ = NewNode("input", RELU, 0, 1);
  78. }
  79. switch_node_ = NewNode("switch", SWITCH, 2, 2);
  80. output_false_node_ = NewNode("false_output", RELU, 1, 1);
  81. output_true_node_ = NewNode("true_output", RELU, 1, 1);
  82. merge_node_ = NewNode("merge", MERGE, 2, 1);
  83. switch_node_->GetOpDesc()->SetIsInputConst({false, is_pred_const});
  84. GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(0));
  85. GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(1));
  86. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  87. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(0));
  88. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
  89. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  90. output_false_node_->GetOpDesc()->SetIsInputConst({false});
  91. output_true_node_->GetOpDesc()->SetIsInputConst({false});
  92. }
  93. void TestPickOutput(bool expect_output) {
  94. auto ret = pass_.Run(switch_node_);
  95. EXPECT_EQ(ret, SUCCESS);
  96. EXPECT_EQ(graph_->GetDirectNodesSize(), 5); // has two isolate nodes
  97. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  98. if (expect_output) {
  99. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  100. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node_->GetOutDataAnchor(0));
  101. EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  102. } else {
  103. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), output_false_node_->GetOutDataAnchor(0));
  104. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor().get(), nullptr);
  105. EXPECT_EQ(output_false_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  106. }
  107. }
  108. ComputeGraphPtr graph_;
  109. GeTensorDescPtr default_tensor_desc_;
  110. SwitchPass pass_;
  111. NodePtr pred_node_;
  112. NodePtr input_node_;
  113. NodePtr switch_node_;
  114. NodePtr output_false_node_;
  115. NodePtr output_true_node_;
  116. NodePtr merge_node_;
  117. };
  118. } // namespace
  119. TEST_F(UTEST_graph_passes_switch_pass, NullInput) {
  120. NodePtr node = nullptr;
  121. auto ret = pass_.Run(node);
  122. EXPECT_EQ(ret, PARAM_INVALID);
  123. }
  124. TEST_F(UTEST_graph_passes_switch_pass, NullPred) {
  125. BuildDefaultGraph(false);
  126. switch_node_->GetInDataAnchor(1)->UnlinkAll();
  127. auto ret = pass_.Run(switch_node_);
  128. EXPECT_EQ(ret, SUCCESS);
  129. }
  130. TEST_F(UTEST_graph_passes_switch_pass, NullData) {
  131. BuildDefaultGraph(false);
  132. switch_node_->GetInDataAnchor(0)->UnlinkAll();
  133. auto ret = pass_.Run(switch_node_);
  134. EXPECT_EQ(ret, SUCCESS);
  135. }
  136. TEST_F(UTEST_graph_passes_switch_pass, UnsupportedNodeType) {
  137. auto node = NewNode("Op1", CONSTANT, 0, 1);
  138. auto ret = pass_.Run(node);
  139. EXPECT_EQ(ret, SUCCESS);
  140. }
  141. TEST_F(UTEST_graph_passes_switch_pass, EmptyOutput) {
  142. BuildDefaultGraph(false);
  143. switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  144. switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  145. auto ret = pass_.Run(switch_node_);
  146. EXPECT_EQ(ret, SUCCESS);
  147. }
  148. TEST_F(UTEST_graph_passes_switch_pass, NonConstPred) {
  149. BuildDefaultGraph(false);
  150. auto ret = pass_.Run(switch_node_);
  151. EXPECT_EQ(ret, SUCCESS);
  152. }
  153. TEST_F(UTEST_graph_passes_switch_pass, PickOutputFalse) {
  154. bool pred_value = false;
  155. BuildDefaultGraph(false, &pred_value);
  156. TestPickOutput(false);
  157. }
  158. TEST_F(UTEST_graph_passes_switch_pass, PickOutputFalseForFloat) {
  159. bool pred_value = false;
  160. BuildDefaultGraph(false, &pred_value);
  161. float weight[] = {0.0f};
  162. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_FLOAT);
  163. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  164. OpDescUtils::SetWeights(pred_node_, {tensor});
  165. TestPickOutput(false);
  166. }
  167. TEST_F(UTEST_graph_passes_switch_pass, PickOutputFalseForBool) {
  168. bool pred_value = false;
  169. BuildDefaultGraph(false, &pred_value);
  170. bool weight[] = {false};
  171. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_BOOL);
  172. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  173. OpDescUtils::SetWeights(pred_node_, {tensor});
  174. TestPickOutput(false);
  175. }
  176. TEST_F(UTEST_graph_passes_switch_pass, PickOutputFalseForU16) {
  177. bool pred_value = false;
  178. BuildDefaultGraph(false, &pred_value);
  179. uint16_t weight[] = {0};
  180. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_UINT16);
  181. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  182. OpDescUtils::SetWeights(pred_node_, {tensor});
  183. TestPickOutput(false);
  184. }
  185. TEST_F(UTEST_graph_passes_switch_pass, PickOutputTrue) {
  186. bool pred_value = true;
  187. BuildDefaultGraph(false, &pred_value);
  188. TestPickOutput(true);
  189. }
  190. TEST_F(UTEST_graph_passes_switch_pass, PickOutputTrueForDouble) {
  191. bool pred_value = true;
  192. BuildDefaultGraph(false, &pred_value);
  193. double weight[] = {1.0};
  194. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_DOUBLE);
  195. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  196. OpDescUtils::SetWeights(pred_node_, {tensor});
  197. TestPickOutput(true);
  198. }
  199. TEST_F(UTEST_graph_passes_switch_pass, PickOutputTrueForInt64) {
  200. bool pred_value = true;
  201. BuildDefaultGraph(false, &pred_value);
  202. int64_t weight[] = {1L};
  203. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT64);
  204. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  205. OpDescUtils::SetWeights(pred_node_, {tensor});
  206. TestPickOutput(true);
  207. }
  208. TEST_F(UTEST_graph_passes_switch_pass, InactiveOutputNotExists) {
  209. /// input pred(false)
  210. /// \ /
  211. /// Switch
  212. /// |
  213. /// F
  214. /// |
  215. /// Merge
  216. bool pred_value = false;
  217. BuildDefaultGraph(false, &pred_value);
  218. output_true_node_->GetOutDataAnchor(0)->UnlinkAll();
  219. GraphUtils::RemoveNodeWithoutRelink(graph_, output_true_node_);
  220. switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  221. // switch_node_->outDataAnchors_.pop_back();
  222. /// input
  223. /// |
  224. /// F
  225. /// |
  226. /// Merge
  227. auto ret = pass_.Run(switch_node_);
  228. EXPECT_EQ(ret, SUCCESS);
  229. EXPECT_EQ(graph_->GetDirectNodesSize(), 4);
  230. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  231. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), output_false_node_->GetOutDataAnchor(0));
  232. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor().get(), nullptr);
  233. EXPECT_EQ(output_false_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  234. }
  235. TEST_F(UTEST_graph_passes_switch_pass, ConstInputPickOutputTrue) {
  236. /// const pred(true)
  237. /// \ /
  238. /// Switch
  239. /// | | \
  240. /// F T1 T2
  241. /// | | |
  242. /// | | /
  243. /// | T3
  244. /// | |
  245. /// Merge
  246. bool pred_value = true;
  247. BuildDefaultGraph(true, &pred_value);
  248. auto output_true_node2 = NewNode("true_output2", RELU, 1, 1);
  249. auto output_true_node3 = NewNode("true_output3", ADD, 2, 1);
  250. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node2->GetInDataAnchor(0));
  251. GraphUtils::RemoveEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  252. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), output_true_node3->GetInDataAnchor(0));
  253. GraphUtils::AddEdge(output_true_node2->GetOutDataAnchor(0), output_true_node3->GetInDataAnchor(1));
  254. GraphUtils::AddEdge(output_true_node3->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  255. /// pred C
  256. /// | | |
  257. /// F T1 T2
  258. /// | /
  259. /// T3
  260. /// |
  261. /// Merge
  262. auto ret = pass_.Run(switch_node_);
  263. EXPECT_EQ(ret, SUCCESS);
  264. EXPECT_EQ(graph_->GetDirectNodesSize(), 7);
  265. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  266. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  267. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node3->GetOutDataAnchor(0));
  268. EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  269. EXPECT_NE(output_true_node2->GetInDataAnchor(0)->GetPeerOutAnchor(),
  270. output_true_node3->GetInDataAnchor(0)->GetPeerOutAnchor());
  271. }
  272. TEST_F(UTEST_graph_passes_switch_pass, AfterSwitchConstTakeFalseBranch) {
  273. /// C pred(false)
  274. /// \ /
  275. /// Switch
  276. /// . .
  277. /// . .
  278. /// C_1 -> F T <- C_2
  279. /// | |
  280. /// Merge
  281. bool pred_value = false;
  282. BuildDefaultGraph(true, &pred_value);
  283. switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  284. switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  285. NodePtr const_node_1 = NewNode("const_1", CONSTANT, 0, 1);
  286. NodePtr const_node_2 = NewNode("const_2", CONSTANT, 0, 1);
  287. GraphUtils::AddEdge(const_node_1->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  288. GraphUtils::AddEdge(const_node_2->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  289. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInControlAnchor());
  290. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInControlAnchor());
  291. /// C pred(false)
  292. ///
  293. /// C_1 C_2
  294. /// | |
  295. /// F T
  296. /// |
  297. /// Merge
  298. auto ret = pass_.Run(switch_node_);
  299. EXPECT_EQ(ret, SUCCESS);
  300. EXPECT_EQ(graph_->GetDirectNodesSize(), 7);
  301. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  302. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), output_false_node_->GetOutDataAnchor(0));
  303. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor().get(), nullptr);
  304. EXPECT_EQ(output_false_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), const_node_1->GetOutDataAnchor(0));
  305. }
  306. TEST_F(UTEST_graph_passes_switch_pass, AfterSwitchConstTakeTrueBranch) {
  307. /// C pred(true)
  308. /// \ /
  309. /// Switch
  310. /// . .
  311. /// . .
  312. /// C_1 -> F T <- C_2
  313. /// | |
  314. /// Merge
  315. bool pred_value = true;
  316. BuildDefaultGraph(true, &pred_value);
  317. switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  318. switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  319. NodePtr const_node_1 = NewNode("const_1", CONSTANT, 0, 1);
  320. NodePtr const_node_2 = NewNode("const_2", CONSTANT, 0, 1);
  321. GraphUtils::AddEdge(const_node_1->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  322. GraphUtils::AddEdge(const_node_2->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  323. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInControlAnchor());
  324. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInControlAnchor());
  325. /// C_1 C_2
  326. /// | |
  327. /// F T
  328. /// |
  329. /// Merge
  330. auto ret = pass_.Run(switch_node_);
  331. EXPECT_EQ(ret, SUCCESS);
  332. EXPECT_EQ(graph_->GetDirectNodesSize(), 7);
  333. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  334. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  335. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node_->GetOutDataAnchor(0));
  336. EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), const_node_2->GetOutDataAnchor(0));
  337. }
  338. TEST_F(UTEST_graph_passes_switch_pass, DeadOutputConnectedToMerge) {
  339. /// input pred(true)
  340. /// \ /
  341. /// Switch
  342. /// | |
  343. /// | T
  344. /// | |
  345. /// Merge
  346. bool pred_value = true;
  347. BuildDefaultGraph(false, &pred_value);
  348. // graph_->RemoveNode(output_false_node_);
  349. output_false_node_->GetOutDataAnchor(0)->UnlinkAll();
  350. GraphUtils::RemoveNodeWithoutRelink(graph_, output_false_node_);
  351. switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  352. /// input pred(true)
  353. /// \ /
  354. /// Switch
  355. /// |
  356. /// T
  357. /// |
  358. /// Merge
  359. auto ret = pass_.Run(switch_node_);
  360. EXPECT_EQ(ret, SUCCESS);
  361. /// input
  362. /// |
  363. /// T
  364. /// |
  365. /// Merge
  366. EXPECT_EQ(graph_->GetDirectNodesSize(), 4);
  367. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  368. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  369. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node_->GetOutDataAnchor(0));
  370. EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  371. }
  372. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。

Contributors (1)