
model_manager_unittest.cc 14 kB

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <gtest/gtest.h>

// Expose private/protected members of the classes under test for white-box testing.
#define private public
#define protected public
#include "graph/load/model_manager/model_manager.h"
#include "common/helper/om_file_helper.h"
#include "graph/utils/graph_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "common/op/ge_op_utils.h"
#include "graph/load/graph_loader.h"
#include "graph/load/model_manager/davinci_model.h"

using namespace std;
using namespace testing;

namespace ge {
const static std::string ENC_KEY = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";

class UtestModelManagerModelManager : public testing::Test {
 protected:
  static Status LoadStub(const uint8_t *data, size_t len, Model &model) {
    InitModelDefault(model);
    return SUCCESS;
  }

  static void InitModelDefault(Model &model) {
    AttrUtils::SetInt(&model, ATTR_MODEL_MEMORY_SIZE, 0);
    AttrUtils::SetInt(&model, ATTR_MODEL_WEIGHT_SIZE, 0);
    AttrUtils::SetInt(&model, ATTR_MODEL_STREAM_NUM, 0);
    AttrUtils::SetInt(&model, ATTR_MODEL_EVENT_NUM, 0);
    AttrUtils::SetStr(&model, ATTR_MODEL_TARGET_TYPE, "MINI");  // domi::MINI
    auto computeGraph = std::make_shared<ComputeGraph>("graph");
    auto graph = GraphUtils::CreateGraphFromComputeGraph(computeGraph);
    model.SetGraph(graph);
  }

  void SetUp() {}
  void TearDown() {}

  // Build a minimal in-memory model blob with a valid, unencrypted file header.
  void GenUnencryptModelData(ModelData &data) {
    const int model_len = 10;
    data.model_len = sizeof(ModelFileHeader) + model_len;
    data.model_data = new uint8_t[data.model_len];
    memset((uint8_t *)data.model_data + sizeof(ModelFileHeader), 10, model_len);

    ModelFileHeader *header = (ModelFileHeader *)data.model_data;
    header->magic = MODEL_FILE_MAGIC_NUM;
    header->version = MODEL_VERSION;
    header->is_encrypt = ModelEncryptType::UNENCRYPTED;
    header->length = model_len;
    header->is_checksum = ModelCheckType::CHECK;
  }
  // Build a model blob whose header marks it as encrypted.
  void GenEncryptModelData(ModelData &data) {
    const int model_len = 10;
    data.key = ENC_KEY;
    data.model_len = sizeof(ModelFileHeader) + model_len;
    data.model_data = new uint8_t[data.model_len];
    uint8_t data_ori[model_len];
    memset(data_ori, 10, model_len);

    ModelFileHeader *header = (ModelFileHeader *)data.model_data;
    header->magic = MODEL_FILE_MAGIC_NUM;
    header->version = MODEL_VERSION;
    header->is_encrypt = ModelEncryptType::ENCRYPTED;
    header->length = 10;  // encrypt_len;
  }
  // Assemble a well-formed 512-byte offline model in memory: file header, partition
  // table, a serialized Model as MODEL_DEF, and empty weight/task/kernel partitions.
  void LoadStandardModelData(ModelData &data) {
    data.model_len = 512;
    data.model_data = new uint8_t[data.model_len];
    uint8_t *model_data = reinterpret_cast<uint8_t *>(data.model_data);
    uint32_t mem_offset = sizeof(ModelFileHeader);
    ModelPartitionTable *partition_table = reinterpret_cast<ModelPartitionTable *>(model_data + mem_offset);
    partition_table->num = PARTITION_SIZE;
    mem_offset += sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * 5;

    {
      Model model;
      ComputeGraphPtr graph = make_shared<ComputeGraph>("default");
      model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph));
      model.SetVersion(123);

      Buffer buffer;
      model.Save(buffer);
      EXPECT_TRUE(mem_offset + buffer.GetSize() < 512);
      memcpy(model_data + mem_offset, buffer.GetData(), buffer.GetSize());

      ModelPartitionMemInfo &partition_info = partition_table->partition[0];
      partition_info.type = ModelPartitionType::MODEL_DEF;
      partition_info.mem_size = buffer.GetSize();
      mem_offset += buffer.GetSize();
    }

    {
      ModelPartitionMemInfo &partition_info = partition_table->partition[1];
      partition_info.type = ModelPartitionType::WEIGHTS_DATA;
      partition_info.mem_offset = mem_offset;
      partition_info.mem_size = 0;
    }

    {
      ModelPartitionMemInfo &partition_info = partition_table->partition[2];
      partition_info.type = ModelPartitionType::TASK_INFO;
      partition_info.mem_offset = mem_offset;
      partition_info.mem_size = 0;
    }

    {
      ModelPartitionMemInfo &partition_info = partition_table->partition[3];
      partition_info.type = ModelPartitionType::TBE_KERNELS;
      partition_info.mem_offset = mem_offset;
      partition_info.mem_size = 0;
    }

    {
      ModelPartitionMemInfo &partition_info = partition_table->partition[4];
      partition_info.type = ModelPartitionType::CUST_AICPU_KERNELS;
      partition_info.mem_offset = mem_offset;
      partition_info.mem_size = 0;
    }

    EXPECT_TRUE(mem_offset < 512);
    ModelFileHeader *header = new (data.model_data) ModelFileHeader;
    header->length = mem_offset - sizeof(ModelFileHeader);
    data.model_len = mem_offset;
  }
};
class DModelListener : public ModelListener {
 public:
  DModelListener() {}
  uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) { return 0; }
};

TEST_F(UtestModelManagerModelManager, case_is_need_hybrid_load) {
  ModelManager mm;
  uint32_t model_id = 0;
  ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("graph");
  ge::GeRootModel model;
  EXPECT_EQ(mm.IsNeedHybridLoad(model), false);
  model.SetRootGraph(root_graph);
  EXPECT_EQ(mm.IsNeedHybridLoad(model), false);
}

TEST_F(UtestModelManagerModelManager, case_load_incorrect_param) {
  ModelManager mm;
  uint32_t model_id = 0;
  ModelData data;
  // The listener is allowed to be null; empty model data is rejected by size check.
  EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID);
}

TEST_F(UtestModelManagerModelManager, case_load_model_len_too_short) {
  ModelManager mm;
  ModelData data;
  data.model_len = 10;
  data.model_data = (void *)&data;
  uint32_t model_id = 1;
  EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_PARAM_INVALID);
  data.model_data = nullptr;
}

TEST_F(UtestModelManagerModelManager, case_load_model_len_not_match) {
  ModelManager mm;
  ModelData data;
  GenUnencryptModelData(data);
  data.model_len = sizeof(ModelFileHeader) + 1;
  uint32_t model_id = 1;
  EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_PARAM_INVALID);
  delete[] (uint8_t *)data.model_data;
}

TEST_F(UtestModelManagerModelManager, case_load_model_encrypt_not_match) {
  ModelManager mm;
  ModelData data;
  GenUnencryptModelData(data);
  data.key = ENC_KEY;
  uint32_t model_id = 1;
  EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_PARAM_INVALID);
  delete[] (uint8_t *)data.model_data;
}

TEST_F(UtestModelManagerModelManager, case_load_model_encrypt_type_unsupported) {
  ModelManager mm;
  ModelData data;
  GenUnencryptModelData(data);
  ModelFileHeader *header = (ModelFileHeader *)data.model_data;
  header->is_encrypt = 255;
  uint32_t model_id = 1;
  EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_PARAM_INVALID);
  delete[] (uint8_t *)data.model_data;
}

TEST_F(UtestModelManagerModelManager, case_load_model_data_success) {
  ModelData data;
  LoadStandardModelData(data);
  uint32_t model_id = 1;
  ModelManager mm;
  EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), SUCCESS);
  delete[] (uint8_t *)data.model_data;
}
/*
shared_ptr<ModelListener> LabelCallBack(new DModelListener());

// test HandleCommand
TEST_F(UtestModelManagerModelManager, command_success1) {
  ModelManager manager;
  Command cmd;
  cmd.cmd_type = "INFERENCE";
  EXPECT_EQ(PARAM_INVALID, manager.HandleCommand(cmd));

  cmd.cmd_type = "NOT SUPPORT";
  EXPECT_EQ(PARAM_INVALID, manager.HandleCommand(cmd));
}

TEST_F(UtestModelManagerModelManager, command_success2) {
  ModelManager manager;
  Command cmd;
  cmd.cmd_type = "dump";
  cmd.cmd_params.push_back("status");
  cmd.cmd_params.push_back("on");
  cmd.cmd_params.push_back("model_name");
  cmd.cmd_params.push_back("test_model");
  cmd.cmd_params.push_back("path");
  cmd.cmd_params.push_back("/test");
  cmd.cmd_params.push_back("layer");
  cmd.cmd_params.push_back("layer1");
  EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
}

// test profile
TEST_F(UtestModelManagerModelManager, command_profile_success) {
  ModelManager manager;
  Command cmd;
  cmd.cmd_type = "profile";
  cmd.cmd_params.push_back("ome");
  cmd.cmd_params.push_back("on");
  EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  bool ome_profile_on = PropertiesManager::Instance().GetPropertyValue(OME_PROFILE) == "1";
  EXPECT_EQ(true, ome_profile_on);

  cmd.cmd_params.clear();
  cmd.cmd_params.push_back("ome");
  cmd.cmd_params.push_back("off");
  EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  ome_profile_on = PropertiesManager::Instance().GetPropertyValue(OME_PROFILE) == "1";
  EXPECT_FALSE(ome_profile_on);

  cmd.cmd_params.clear();
  cmd.cmd_params.push_back("cce");
  cmd.cmd_params.push_back("on");
  EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  bool cce_profile_on = PropertiesManager::Instance().GetPropertyValue(CCE_PROFILE) == "1";
  EXPECT_EQ(true, cce_profile_on);

  cmd.cmd_params.clear();
  cmd.cmd_params.push_back("cce");
  cmd.cmd_params.push_back("off");
  EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  cce_profile_on = PropertiesManager::Instance().GetPropertyValue(CCE_PROFILE) == "1";
  EXPECT_FALSE(cce_profile_on);

  cmd.cmd_params.clear();
  cmd.cmd_params.push_back("runtime");
  cmd.cmd_params.push_back("on");
  EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  bool rts_profile_on = PropertiesManager::Instance().GetPropertyValue(RTS_PROFILE) == "1";
  EXPECT_EQ(true, rts_profile_on);

  cmd.cmd_params.clear();
  cmd.cmd_params.push_back("runtime");
  cmd.cmd_params.push_back("off");
  EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  rts_profile_on = PropertiesManager::Instance().GetPropertyValue(RTS_PROFILE) == "1";
  EXPECT_FALSE(rts_profile_on);

  cmd.cmd_params.clear();
  cmd.cmd_params.push_back("profiler_jobctx");
  cmd.cmd_params.push_back("jobctx");
  EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  EXPECT_EQ("jobctx", PropertiesManager::Instance().GetPropertyValue(PROFILER_JOBCTX));

  cmd.cmd_params.clear();
  cmd.cmd_params.push_back("profiler_target_path");
  cmd.cmd_params.push_back("/test/target");
  EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  EXPECT_EQ("/test/target", PropertiesManager::Instance().GetPropertyValue(PROFILER_TARGET_PATH));

  cmd.cmd_params.clear();
  cmd.cmd_params.push_back("RTS_PATH");
  cmd.cmd_params.push_back("/test/rts_path");
  EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  EXPECT_EQ("/test/rts_path", PropertiesManager::Instance().GetPropertyValue(RTS_PROFILE_PATH));
}

// test acl profiling
TEST_F(UtestModelManagerModelManager, command_profiling) {
  ModelManager manager;
  Command cmd;
  cmd.cmd_type = "profiling";
  cmd.cmd_params.push_back("config");
  cmd.cmd_params.push_back("on");
  EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
}

TEST_F(UtestModelManagerModelManager, command_profile_failed) {
  ModelManager manager;
  Command cmd;
  cmd.cmd_type = "profile";
  cmd.cmd_params.push_back("ome");
  EXPECT_EQ(PARAM_INVALID, manager.HandleCommand(cmd));
}

// test Start
TEST_F(UtestModelManagerModelManager, start_fail) {
  ModelManager manager;
  manager.model_map_[2] = nullptr;
  EXPECT_EQ(PARAM_INVALID, manager.Start(2));
}

// test GetMaxUsedMemory
TEST_F(UtestModelManagerModelManager, get_max_used_memory_fail) {
  ModelManager manager;
  uint64_t max_size = 0;
  manager.model_map_[2] = nullptr;
  EXPECT_EQ(PARAM_INVALID, manager.GetMaxUsedMemory(2, max_size));
}

// test GetInputOutputDescInfo
TEST_F(UtestModelManagerModelManager, get_input_output_desc_info_fail) {
  ModelManager manager;
  manager.model_map_[2] = nullptr;
  vector<InputOutputDescInfo> input_shape;
  vector<InputOutputDescInfo> output_shape;
  EXPECT_EQ(PARAM_INVALID, manager.GetInputOutputDescInfo(2, input_shape, output_shape));
}
*//*
// test GetInputOutputDescInfo fail
TEST_F(UtestModelManagerModelManager, get_input_output_desc_info_zero_copy_fail) {
  ModelManager manager;
  manager.model_map_[2] = nullptr;
  vector<InputOutputDescInfo> input_shape;
  vector<InputOutputDescInfo> output_shape;
  EXPECT_EQ(PARAM_INVALID, manager.GetInputOutputDescInfoForZeroCopy(2, input_shape, output_shape));
}
*//*
// test Stop
TEST_F(UtestModelManagerModelManager, stop_fail) {
  ModelManager manager;
  manager.model_map_[2] = nullptr;
  EXPECT_EQ(PARAM_INVALID, manager.Stop(2));
}

// build input_data
TEST_F(UtestModelManagerModelManager, check_data_len_success) {
  shared_ptr<ModelListener> g_label_call_back(new DModelListener());
  DavinciModel model(0, g_label_call_back);
  ModelManager model_manager;
  InputData input_data;
  DataBuffer data_buffer;
  data_buffer.data = new char[51200];
  data_buffer.length = 51200;
  input_data.index = 0;
  input_data.model_id = 1;
  input_data.blobs.push_back(data_buffer);
  delete[] (char *)data_buffer.data;
}

// test DestroyAicpuSession
TEST_F(UtestModelManagerModelManager, destroy_aicpu_session) {
  ModelManager manager;
  manager.DestroyAicpuSession(0);

  manager.sess_ids_.insert(0);
  manager.DestroyAicpuSession(0);
}*/
// test DataInputTensor
TEST_F(UtestModelManagerModelManager, test_data_input_tensor) {
  shared_ptr<ModelListener> g_label_call_back(nullptr);
  auto model = std::make_shared<DavinciModel>(0, g_label_call_back);
  ModelManager mm;
  uint32_t model_id = 1;
  mm.model_map_[1] = model;
  mm.hybrid_model_map_[1] = std::make_shared<hybrid::HybridDavinciModel>();
  auto input_tensor = InputTensorInfo();
  vector<InputTensorInfo> inputs;
  inputs.emplace_back(input_tensor);
  auto ret = mm.DataInputTensor(model_id, inputs);
  EXPECT_EQ(UNSUPPORTED, ret);
}
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware and bridges the two: GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and outputs a graph that runs efficiently on the target hardware. GE is tuned specifically for the hardware architecture of the Ascend AI processor so as to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists of two main parts, GE API and GE Core; the detailed architecture diagram is shown below.
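To make the GE API / GE Core split concrete, here is a minimal sketch of how a client would drive GE through the public GE API (GEInitialize, Session::AddGraph, Session::RunGraph, GEFinalize). The include path, option keys, and error handling are illustrative assumptions, not code taken from the test file above.

#include <map>
#include <string>
#include <vector>
#include "ge/ge_api.h"  // assumed include path for the external GE API header

// Sketch only: run one ME-built graph end to end through the GE API layer.
ge::Status RunGraphOnce(const ge::Graph &graph, const std::vector<ge::Tensor> &inputs,
                        std::vector<ge::Tensor> &outputs) {
  std::map<std::string, std::string> options;  // global/session options, left empty here
  ge::Status ret = ge::GEInitialize(options);  // bring up GE Core
  if (ret != ge::SUCCESS) {
    return ret;
  }
  {
    ge::Session session(options);             // each client works through a Session
    const uint32_t graph_id = 1;
    ret = session.AddGraph(graph_id, graph);  // hand the ME graph over to GE
    if (ret == ge::SUCCESS) {
      // GE Core optimizes the graph for the Ascend target and then executes it.
      ret = session.RunGraph(graph_id, inputs, outputs);
    }
  }  // the Session goes out of scope before GE is finalized
  (void)ge::GEFinalize();
  return ret;
}

The unit test above exercises the layer below this API: ModelManager::LoadModelOffline in GE Core, which parses the in-memory offline model (file header, partition table, serialized model definition) that the fixture's LoadStandardModelData helper assembles.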