You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ge_ir_build_unittest.cc 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <stdio.h>
  17. #include <gtest/gtest.h>
  18. #include "ir_build/option_utils.h"
  19. #include "graph/testcase/ge_graph/graph_builder_utils.h"
  20. #include "graph/debug/ge_attr_define.h"
  21. #include "graph/utils/graph_utils.h"
  22. #include "ge/ge_ir_build.h"
  23. #include "graph/ops_stub.h"
  24. #define protected public
  25. #define private public
  26. #undef private
  27. #undef protected
  // String constants passed as the second argument of ut::GraphBuilder::AddNode
  // in BuildComputeGraph (presumably the op type of each node).
  // Spelled std::string explicitly so the declarations do not depend on a
  // using-declaration leaking out of an included header.
  const std::string DATA = "Data";
  const std::string AddNYes = "AddNYes";
  const std::string NETOUTPUT = "NetOutput";
  31. using namespace ge;
  32. class UtestIrCommon : public testing::Test {
  33. protected:
  34. void SetUp() {}
  35. void TearDown() {}
  36. };
  37. class UtestIrBuild : public testing::Test {
  38. protected:
  39. void SetUp() {}
  40. void TearDown() {}
  41. };
  42. static ge::OpDescPtr CreateOpDesc(const std::string &name, const std::string &type) {
  43. OpDescPtr op_desc = std::make_shared<ge::OpDesc>(name, type);
  44. ge::GeTensorDesc ge_tensor_desc;
  45. op_desc->AddInputDesc("input", ge_tensor_desc);
  46. op_desc->AddOutputDesc("output", ge_tensor_desc);
  47. return op_desc;
  48. }
  49. static ComputeGraphPtr BuildComputeGraph() {
  50. auto builder = ut::GraphBuilder("test");
  51. auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3});
  52. auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10});
  53. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  54. auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  55. builder.AddDataEdge(data1, 0, addn1, 0);
  56. builder.AddDataEdge(data2, 0, addn1, 1);
  57. builder.AddDataEdge(addn1, 0,netoutput, 0);
  58. return builder.GetGraph();
  59. }
  60. // data not set attr index;
  61. // but becasue of op proto, register attr index. so all data index is zero;
  62. static Graph BuildIrGraph() {
  63. auto data1 = op::Data("data1");
  64. auto data2 = op::Data("data2");
  65. auto data3 = op::Data("data3");
  66. std::vector<Operator> inputs {data1, data2, data3};
  67. std::vector<Operator> outputs;
  68. Graph graph("test_graph");
  69. graph.SetInputs(inputs).SetOutputs(outputs);
  70. return graph;
  71. }
  72. // data set attr index, but is not valid
  73. static Graph BuildIrGraph1() {
  74. auto data1 = op::Data("data1").set_attr_index(0);
  75. auto data2 = op::Data("data2").set_attr_index(1);
  76. auto data3 = op::Data("data3");
  77. std::vector<Operator> inputs {data1, data2, data3};
  78. std::vector<Operator> outputs;
  79. Graph graph("test_graph");
  80. graph.SetInputs(inputs).SetOutputs(outputs);
  81. return graph;
  82. }
  83. // data set attr index, but is not valid
  84. static Graph BuildIrGraph2() {
  85. auto data1 = op::Data("data1").set_attr_index(0);
  86. auto data2 = op::Data("data2");
  87. auto data3 = op::Data("data3").set_attr_index(2);
  88. std::vector<Operator> inputs {data1, data2, data3};
  89. std::vector<Operator> outputs;
  90. Graph graph("test_graph");
  91. graph.SetInputs(inputs).SetOutputs(outputs);
  92. return graph;
  93. }
  94. // data set attr index
  95. static Graph BuildIrGraph3() {
  96. auto data1 = op::Data("data1").set_attr_index(0);
  97. auto data2 = op::Data("data2").set_attr_index(1);
  98. auto data3 = op::Data("data3").set_attr_index(2);
  99. std::vector<Operator> inputs {data1, data2, data3};
  100. std::vector<Operator> outputs;
  101. Graph graph("test_graph");
  102. graph.SetInputs(inputs).SetOutputs(outputs);
  103. return graph;
  104. }
  105. TEST(UtestIrCommon, update_data_op_shape) {
  106. ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data");
  107. map<string, vector<int64_t>> shape_map;
  108. shape_map["Data"] = {{1,2}};
  109. Status ret = UpdateDataOpShape(op_desc, shape_map);
  110. EXPECT_EQ(ret, ge::SUCCESS);
  111. }
  112. TEST(UtestIrCommon, update_data_op_shape_range) {
  113. ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data");
  114. std::vector<std::vector<std::pair<int64_t, int64_t>>> index_shape_range_map;
  115. std::pair<int64_t, int64_t> range_pair(1, 2);
  116. vector<pair<int64_t, int64_t>> range_pair_tmp = { range_pair };
  117. index_shape_range_map.push_back(range_pair_tmp);
  118. AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, 0);
  119. Status ret = UpdateDataOpShapeRange(op_desc, index_shape_range_map);
  120. EXPECT_EQ(ret, ge::SUCCESS);
  121. }
  122. TEST(UtestIrCommon, update_dynamic_shape_range_success) {
  123. ComputeGraphPtr graph = BuildComputeGraph();
  124. std::string input_shape_range = "input1:[1, 2~3, -1];input2:[3~5, 10]";
  125. Status ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  126. EXPECT_EQ(ret, ge::SUCCESS);
  127. }
  128. TEST(UtestIrCommon, update_dynamic_shape_range_failed) {
  129. ComputeGraphPtr graph = BuildComputeGraph();
  130. // 1
  131. std::string input_shape_range = "input1;[1, 2~3, -1]";
  132. Status ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  133. EXPECT_EQ(ret, ge::PARAM_INVALID);
  134. // 2
  135. input_shape_range = "input1:[1, 2~3, -1)";
  136. ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  137. EXPECT_EQ(ret, ge::PARAM_INVALID);
  138. //3
  139. input_shape_range = "input1:[1, 3~2, -1];input2:[3~5, 10]";
  140. ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  141. EXPECT_EQ(ret, ge::FAILED);
  142. //4
  143. input_shape_range = "input1:[1, 2~-3, -1]";
  144. ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  145. EXPECT_EQ(ret, ge::PARAM_INVALID);
  146. //5
  147. input_shape_range = "input:[1, 2~3, -1]";
  148. ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  149. EXPECT_EQ(ret, ge::PARAM_INVALID);
  150. //6
  151. input_shape_range = "addn1:[1, 2~3, -1]";
  152. ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  153. EXPECT_EQ(ret, ge::PARAM_INVALID);
  154. }
  155. TEST(UtestIrCommon, check_dynamic_image_size_fail) {
  156. map<string, vector<int64_t>> shape_map;
  157. shape_map["input1"] = {8, 3, -1, -1};
  158. string input_format = "NCHW";
  159. string dynamic_image_size = "@64,64;128,128;";
  160. bool ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size);
  161. EXPECT_EQ(ret, false);
  162. }
  163. TEST(UtestIrCommon, check_input_format_failed) {
  164. std::string format = "invalid";
  165. Status ret = CheckInputFormat(format);
  166. EXPECT_EQ(ret, ge::PARAM_INVALID);
  167. }
  168. TEST(UtestIrCommon, check_dynamic_batch_size_input_shape_succ) {
  169. map<string, vector<int64_t>> shape_map;
  170. shape_map.insert(std::pair<string, vector<int64_t>>("data", {-1, 2, 3}));
  171. std::string dynamic_batch_size = "11";
  172. bool ret = CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size);
  173. EXPECT_EQ(ret, true);
  174. }
  175. TEST(UtestIrCommon, check_dynamic_images_size_input_shape_succ) {
  176. map<string, vector<int64_t>> shape_map;
  177. shape_map.insert(std::pair<string, vector<int64_t>>("data", {4, -1, -1, 5}));
  178. std::string input_format = "NCHW";
  179. std::string dynamic_image_size = "4,5";
  180. Status ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size);
  181. EXPECT_EQ(ret, ge::SUCCESS);
  182. }
  183. TEST(UtestIrCommon, check_dynamic_input_param_succ) {
  184. string dynamic_batch_size = "1";
  185. string dynamic_image_size;
  186. string dynamic_dims;
  187. string input_shape = "data:-1,3,244,244";
  188. string input_shape_range;
  189. string input_format = "NCHW";
  190. bool is_dynamic_input = false;
  191. Status ret = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims,
  192. input_shape, input_shape_range, input_format,is_dynamic_input);
  193. EXPECT_EQ(ret, ge::SUCCESS);
  194. }
  195. TEST(UtestIrCommon, check_dynamic_input_param_failed) {
  196. string dynamic_batch_size = "1";
  197. string dynamic_image_size;
  198. string dynamic_dims;
  199. string input_shape = "data:1,3,244,244";
  200. string input_shape_range;
  201. string input_format = "NCHW";
  202. bool is_dynamic_input = false;
  203. Status ret = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims,
  204. input_shape, input_shape_range, input_format,is_dynamic_input);
  205. EXPECT_EQ(ret, ge::PARAM_INVALID);
  206. }
  207. TEST(UtestIrCommon, check_modify_mixlist_param) {
  208. std::string precision_mode = "allow_mix_precision";
  209. std::string modify_mixlist = "/mixlist.json";
  210. Status ret = CheckModifyMixlistParamValid(precision_mode, modify_mixlist);
  211. EXPECT_EQ(ret, ge::SUCCESS);
  212. precision_mode = "";
  213. ret = CheckModifyMixlistParamValid(precision_mode, modify_mixlist);
  214. EXPECT_EQ(ret, ge::PARAM_INVALID);
  215. }
  216. TEST(UtestIrCommon, check_compress_weight) {
  217. std::string enable_compress_weight = "true";
  218. std::string compress_weight_conf="./";
  219. Status ret = CheckCompressWeightParamValid(enable_compress_weight, compress_weight_conf);
  220. EXPECT_EQ(ret, PARAM_INVALID);
  221. enable_compress_weight = "yes";
  222. compress_weight_conf = "./";
  223. ret = CheckCompressWeightParamValid(enable_compress_weight, compress_weight_conf);
  224. EXPECT_EQ(ret, PARAM_INVALID);
  225. }
  226. TEST(UtestIrCommon, check_param_failed) {
  227. std::string param_invalid = "invalid";
  228. Status ret = CheckOutputTypeParamValid(param_invalid);
  229. EXPECT_EQ(ret, PARAM_INVALID);
  230. ret = CheckBufferOptimizeParamValid(param_invalid);
  231. EXPECT_EQ(ret, PARAM_INVALID);
  232. ret = CheckKeepTypeParamValid(param_invalid);
  233. EXPECT_EQ(ret, PARAM_INVALID);
  234. ret = CheckInsertOpConfParamValid(param_invalid);
  235. EXPECT_EQ(ret, PARAM_INVALID);
  236. ret = CheckDisableReuseMemoryParamValid(param_invalid);
  237. EXPECT_EQ(ret, PARAM_INVALID);
  238. ret = CheckEnableSingleStreamParamValid(param_invalid);
  239. EXPECT_EQ(ret, PARAM_INVALID);
  240. std::string optypelist_for_implmode;
  241. std::string op_select_implmode = "1";
  242. ret = CheckImplmodeParamValid(optypelist_for_implmode, op_select_implmode);
  243. EXPECT_EQ(ret, PARAM_INVALID);
  244. ret = CheckLogParamValidAndSetLogLevel(param_invalid);
  245. }
  246. // Get attr index failed, when set input shape range
  247. TEST(UtestIrBuild, check_data_op_attr_index_invalid_0) {
  248. ComputeGraphPtr compute_graph = BuildComputeGraph();
  249. Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  250. const map<string, string> build_options = {
  251. {"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"}
  252. };
  253. ModelBufferData model;
  254. graphStatus ret = aclgrphBuildModel(graph, build_options, model);
  255. EXPECT_EQ(ret, GRAPH_FAILED);
  256. }
  257. // not set attr index, when set input shape range
  258. TEST(UtestIrBuild, check_data_op_attr_index_invalid_1) {
  259. Graph graph = BuildIrGraph();
  260. const map<string, string> build_options = {
  261. {"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"}
  262. };
  263. ModelBufferData model;
  264. graphStatus ret = aclgrphBuildModel(graph, build_options, model);
  265. EXPECT_EQ(ret, GRAPH_FAILED);
  266. }
  267. // set attr index, but not valid, when set input shape range
  268. TEST(UtestIrBuild, check_data_op_attr_index_invalid_2) {
  269. Graph graph = BuildIrGraph1();
  270. const map<string, string> build_options = {
  271. {"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"}
  272. };
  273. ModelBufferData model;
  274. graphStatus ret = aclgrphBuildModel(graph, build_options, model);
  275. EXPECT_EQ(ret, GRAPH_FAILED);
  276. Graph graph2 = BuildIrGraph2();
  277. ret = aclgrphBuildModel(graph2, build_options, model);
  278. EXPECT_EQ(ret, GRAPH_FAILED);
  279. }
  280. // set attr index valid, when set input shape range
  281. // only check data op attr index valid func.
  282. TEST(UtestIrBuild, check_data_op_attr_index_valid) {
  283. Graph graph = BuildIrGraph3();
  284. const map<string, string> build_options = {
  285. {"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"}
  286. };
  287. ModelBufferData model;
  288. graphStatus ret = aclgrphBuildModel(graph, build_options, model);
  289. EXPECT_EQ(ret, GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED);
  290. }
  291. // set attr index invalid, when not set input shape range
  292. // only check data op attr index valid func.
  293. TEST(UtestIrBuild, check_data_attr_index_succ_no_input_range) {
  294. Graph graph = BuildIrGraph1();
  295. const map<string, string> build_options;
  296. ModelBufferData model;
  297. graphStatus ret = aclgrphBuildModel(graph, build_options, model);
  298. EXPECT_EQ(ret, GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED);
  299. }
  300. TEST(UtestIrBuild, check_modify_mixlist_param) {
  301. Graph graph = BuildIrGraph1();
  302. const std::map<std::string, std::string> build_options = {
  303. {"ge.exec.modify_mixlist", "/modify.json"}
  304. };
  305. ModelBufferData model;
  306. auto ret = aclgrphBuildModel(graph, build_options, model);
  307. EXPECT_EQ(ret, GRAPH_PARAM_INVALID);
  308. }
  309. TEST(UtestIrBuild, check_cfg_optype_param) {
  310. Graph graph = BuildIrGraph1();
  311. FILE *fp = fopen("./keep.txt", "w+");
  312. if (fp) {
  313. fprintf(fp, "Test\n");
  314. fprintf(fp, "OpType::Mul\n");
  315. fprintf(fp, "Optype::Sub\n");
  316. fclose(fp);
  317. }
  318. auto ret = aclgrphSetOpAttr(graph, ATTR_TYPE_KEEP_DTYPE, "./keep.txt");
  319. (void)remove("./keep.txt");
  320. EXPECT_EQ(ret, GRAPH_PARAM_INVALID);
  321. }

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示