You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

model_manager_unittest.cc 19 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #define private public
  18. #define protected public
  19. #include "graph/load/model_manager/model_manager.h"
  20. #include "common/helper/om_file_helper.h"
  21. #include "graph/utils/graph_utils.h"
  22. #include "graph/debug/ge_attr_define.h"
  23. #include "common/op/ge_op_utils.h"
  24. #include "graph/load/graph_loader.h"
  25. #include "graph/load/model_manager/davinci_model.h"
  26. #include "graph/ops_stub.h"
  27. #include "common/profiling/profiling_manager.h"
  28. using namespace std;
  29. using namespace testing;
  30. namespace ge {
  31. const static std::string ENC_KEY = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
  32. class UtestModelManagerModelManager : public testing::Test {
  33. protected:
  34. static Status LoadStub(const uint8_t *data, size_t len, Model &model) {
  35. InitModelDefault(model);
  36. return SUCCESS;
  37. }
  38. static void InitModelDefault(Model &model) {
  39. AttrUtils::SetInt(&model, ATTR_MODEL_MEMORY_SIZE, 0);
  40. AttrUtils::SetInt(&model, ATTR_MODEL_WEIGHT_SIZE, 0);
  41. AttrUtils::SetInt(&model, ATTR_MODEL_STREAM_NUM, 0);
  42. AttrUtils::SetInt(&model, ATTR_MODEL_EVENT_NUM, 0);
  43. AttrUtils::SetStr(&model, ATTR_MODEL_TARGET_TYPE, "MINI"); // domi::MINI
  44. auto computeGraph = std::make_shared<ComputeGraph>("graph");
  45. auto graph = GraphUtils::CreateGraphFromComputeGraph(computeGraph);
  46. model.SetGraph(graph);
  47. }
  48. void SetUp() {}
  49. void TearDown() {}
  50. void CreateGraph(Graph &graph) {
  51. TensorDesc desc(ge::Shape({1, 3, 224, 224}));
  52. uint32_t size = desc.GetShape().GetShapeSize();
  53. desc.SetSize(size);
  54. auto data = op::Data("Data").set_attr_index(0);
  55. data.update_input_desc_data(desc);
  56. data.update_output_desc_out(desc);
  57. auto flatten = op::Flatten("Flatten").set_input_x(data, data.name_out_out());
  58. std::vector<Operator> inputs{data};
  59. std::vector<Operator> outputs{flatten};
  60. std::vector<Operator> targets{flatten};
  61. // Graph graph("test_graph");
  62. graph.SetInputs(inputs).SetOutputs(outputs).SetTargets(targets);
  63. }
  64. void GenUnencryptModelData(ModelData &data) {
  65. const int model_len = 10;
  66. data.model_len = sizeof(ModelFileHeader) + model_len;
  67. data.model_data = new uint8_t[data.model_len];
  68. memset(data.model_data, 0, data.model_len);
  69. ModelFileHeader *header = (ModelFileHeader *)data.model_data;
  70. header->magic = MODEL_FILE_MAGIC_NUM;
  71. header->version = MODEL_VERSION;
  72. header->is_encrypt = ModelEncryptType::UNENCRYPTED;
  73. header->length = model_len;
  74. header->is_checksum = ModelCheckType::CHECK;
  75. }
  76. void LoadStandardModelData(ModelData &data) {
  77. data.model_len = 512;
  78. data.model_data = new uint8_t[data.model_len];
  79. uint8_t *model_data = reinterpret_cast<uint8_t *>(data.model_data);
  80. uint32_t mem_offset = sizeof(ModelFileHeader);
  81. ModelPartitionTable *partition_table = reinterpret_cast<ModelPartitionTable *>(model_data + mem_offset);
  82. partition_table->num = PARTITION_SIZE;
  83. mem_offset += sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * 5;
  84. {
  85. Model model;
  86. ComputeGraphPtr graph = make_shared<ComputeGraph>("default");
  87. model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph));
  88. model.SetVersion(123);
  89. Buffer buffer;
  90. model.Save(buffer);
  91. EXPECT_TRUE(mem_offset + buffer.GetSize() < 512);
  92. memcpy(model_data + mem_offset, buffer.GetData(), buffer.GetSize());
  93. ModelPartitionMemInfo &partition_info = partition_table->partition[0];
  94. partition_info.type = ModelPartitionType::MODEL_DEF;
  95. partition_info.mem_size = buffer.GetSize();
  96. mem_offset += buffer.GetSize();
  97. }
  98. {
  99. ModelPartitionMemInfo &partition_info = partition_table->partition[1];
  100. partition_info.type = ModelPartitionType::WEIGHTS_DATA;
  101. partition_info.mem_offset = mem_offset;
  102. partition_info.mem_size = 0;
  103. }
  104. {
  105. ModelPartitionMemInfo &partition_info = partition_table->partition[2];
  106. partition_info.type = ModelPartitionType::TASK_INFO;
  107. partition_info.mem_offset = mem_offset;
  108. partition_info.mem_size = 0;
  109. }
  110. {
  111. ModelPartitionMemInfo &partition_info = partition_table->partition[3];
  112. partition_info.type = ModelPartitionType::TBE_KERNELS;
  113. partition_info.mem_offset = mem_offset;
  114. partition_info.mem_size = 0;
  115. }
  116. {
  117. ModelPartitionMemInfo &partition_info = partition_table->partition[4];
  118. partition_info.type = ModelPartitionType::CUST_AICPU_KERNELS;
  119. partition_info.mem_offset = mem_offset;
  120. partition_info.mem_size = 0;
  121. }
  122. EXPECT_TRUE(mem_offset < 512);
  123. ModelFileHeader *header = new (data.model_data) ModelFileHeader;
  124. header->length = mem_offset - sizeof(ModelFileHeader);
  125. data.model_len = mem_offset;
  126. }
  127. };
  128. class DModelListener : public ModelListener {
  129. public:
  130. DModelListener(){};
  131. uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index,
  132. uint32_t resultCode, std::vector<ge::Tensor> &outputs) { return 0; }
  133. };
  134. TEST_F(UtestModelManagerModelManager, case_is_need_hybrid_load) {
  135. ModelManager mm;
  136. uint32_t model_id = 0;
  137. ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("graph");
  138. ge::GeRootModel model;
  139. EXPECT_EQ(mm.IsNeedHybridLoad(model), false);
  140. model.SetRootGraph(root_graph);
  141. EXPECT_EQ(mm.IsNeedHybridLoad(model), false);
  142. }
  143. TEST_F(UtestModelManagerModelManager, case_load_incorrect_param) {
  144. ModelManager mm;
  145. uint32_t model_id = 0;
  146. ModelData data;
  147. // Load allow listener is null
  148. EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID);
  149. }
  150. TEST_F(UtestModelManagerModelManager, case_load_model_len_too_short) {
  151. ModelManager mm;
  152. ModelData data;
  153. data.model_len = 10;
  154. data.model_data = (void *)&data;
  155. uint32_t model_id = 1;
  156. EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_PARAM_INVALID);
  157. data.model_data = nullptr;
  158. }
  159. TEST_F(UtestModelManagerModelManager, case_load_model_len_not_match) {
  160. ModelManager mm;
  161. ModelData data;
  162. GenUnencryptModelData(data);
  163. data.model_len = sizeof(ModelFileHeader) + 1;
  164. uint32_t model_id = 1;
  165. EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_PARAM_INVALID);
  166. delete[](uint8_t *) data.model_data;
  167. }
  168. TEST_F(UtestModelManagerModelManager, case_load_model_encypt_not_match) {
  169. ModelManager mm;
  170. ModelData data;
  171. GenUnencryptModelData(data);
  172. data.key = ENC_KEY;
  173. uint32_t model_id = 1;
  174. EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_PARAM_INVALID);
  175. delete[](uint8_t *) data.model_data;
  176. }
  177. TEST_F(UtestModelManagerModelManager, case_load_model_encypt_type_unsupported) {
  178. ModelManager mm;
  179. ModelData data;
  180. GenUnencryptModelData(data);
  181. ModelFileHeader *header = (ModelFileHeader *)data.model_data;
  182. header->is_encrypt = 255;
  183. uint32_t model_id = 1;
  184. // Error for: LoadModelPartitionTable: Invalid partition_table->num:0
  185. EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_PARAM_INVALID);
  186. delete[](uint8_t *) data.model_data;
  187. }
  188. TEST_F(UtestModelManagerModelManager, case_load_model_data_success) {
  189. ModelData data;
  190. LoadStandardModelData(data);
  191. uint32_t model_id = 1;
  192. ModelManager mm;
  193. EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), SUCCESS);
  194. delete[](uint8_t *) data.model_data;
  195. }
  196. /*
  197. shared_ptr<ModelListener> LabelCallBack(new DModelListener());
  198. // test HandleCommand
  199. TEST_F(UtestModelManagerModelManager, command_success1) {
  200. ModelManager manager;
  201. Command cmd;
  202. cmd.cmd_type = "INFERENCE";
  203. EXPECT_EQ(PARAM_INVALID, manager.HandleCommand(cmd));
  204. cmd.cmd_type = "NOT SUPPORT";
  205. EXPECT_EQ(PARAM_INVALID, manager.HandleCommand(cmd));
  206. }
  207. TEST_F(UtestModelManagerModelManager, command_success2) {
  208. ModelManager manager;
  209. Command cmd;
  210. cmd.cmd_type = "dump";
  211. cmd.cmd_params.push_back("status");
  212. cmd.cmd_params.push_back("on");
  213. cmd.cmd_params.push_back("model_name");
  214. cmd.cmd_params.push_back("test_model");
  215. cmd.cmd_params.push_back("path");
  216. cmd.cmd_params.push_back("/test");
  217. cmd.cmd_params.push_back("layer");
  218. cmd.cmd_params.push_back("layer1");
  219. EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  220. }
  221. // test profile
  222. TEST_F(UtestModelManagerModelManager, command_profile_success) {
  223. ModelManager manager;
  224. Command cmd;
  225. cmd.cmd_type = "profile";
  226. cmd.cmd_params.push_back("ome");
  227. cmd.cmd_params.push_back("on");
  228. EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  229. bool ome_profile_on = PropertiesManager::Instance().GetPropertyValue(OME_PROFILE) == "1";
  230. EXPECT_EQ(true, ome_profile_on);
  231. cmd.cmd_params.clear();
  232. cmd.cmd_params.push_back("ome");
  233. cmd.cmd_params.push_back("off");
  234. EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  235. ome_profile_on = PropertiesManager::Instance().GetPropertyValue(OME_PROFILE) == "1";
  236. EXPECT_FALSE(ome_profile_on);
  237. cmd.cmd_params.clear();
  238. cmd.cmd_params.push_back("cce");
  239. cmd.cmd_params.push_back("on");
  240. EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  241. bool cce_profile_on = PropertiesManager::Instance().GetPropertyValue(CCE_PROFILE) == "1";
  242. EXPECT_EQ(true, cce_profile_on);
  243. cmd.cmd_params.clear();
  244. cmd.cmd_params.push_back("cce");
  245. cmd.cmd_params.push_back("off");
  246. EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  247. cce_profile_on = PropertiesManager::Instance().GetPropertyValue(CCE_PROFILE) == "1";
  248. EXPECT_FALSE(cce_profile_on);
  249. cmd.cmd_params.clear();
  250. cmd.cmd_params.push_back("runtime");
  251. cmd.cmd_params.push_back("on");
  252. EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  253. bool rts_profile_on = PropertiesManager::Instance().GetPropertyValue(RTS_PROFILE) == "1";
  254. EXPECT_EQ(true, rts_profile_on);
  255. cmd.cmd_params.clear();
  256. cmd.cmd_params.push_back("runtime");
  257. cmd.cmd_params.push_back("off");
  258. EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  259. rts_profile_on = PropertiesManager::Instance().GetPropertyValue(RTS_PROFILE) == "1";
  260. EXPECT_FALSE(rts_profile_on);
  261. cmd.cmd_params.clear();
  262. cmd.cmd_params.push_back("profiler_jobctx");
  263. cmd.cmd_params.push_back("jobctx");
  264. EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  265. EXPECT_EQ("jobctx", PropertiesManager::Instance().GetPropertyValue(PROFILER_JOBCTX));
  266. cmd.cmd_params.clear();
  267. cmd.cmd_params.push_back("profiler_target_path");
  268. cmd.cmd_params.push_back("/test/target");
  269. EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  270. EXPECT_EQ("/test/target", PropertiesManager::Instance().GetPropertyValue(PROFILER_TARGET_PATH));
  271. cmd.cmd_params.clear();
  272. cmd.cmd_params.push_back("RTS_PATH");
  273. cmd.cmd_params.push_back("/test/rts_path");
  274. EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  275. EXPECT_EQ("/test/rts_path", PropertiesManager::Instance().GetPropertyValue(RTS_PROFILE_PATH));
  276. }
  277. // test acl profiling
  278. TEST_F(UtestModelManagerModelManager, command_profiling) {
  279. ModelManager manager;
  280. Command cmd;
  281. cmd.cmd_type = "profiling";
  282. cmd.cmd_params.push_back("config");
  283. cmd.cmd_params.push_back("on");
  284. EXPECT_EQ(SUCCESS, manager.HandleCommand(cmd));
  285. }
  286. TEST_F(UtestModelManagerModelManager, command_profile_failed) {
  287. ModelManager manager;
  288. Command cmd;
  289. cmd.cmd_type = "profile";
  290. cmd.cmd_params.push_back("ome");
  291. EXPECT_EQ(PARAM_INVALID, manager.HandleCommand(cmd));
  292. }
  293. // test Start
  294. TEST_F(UtestModelManagerModelManager, start_fail) {
  295. ModelManager manager;
  296. manager.model_map_[2] = nullptr;
  297. EXPECT_EQ(PARAM_INVALID, manager.Start(2));
  298. }
  299. // test GetMaxUsedMemory
  300. TEST_F(UtestModelManagerModelManager, get_max_used_memory_fail) {
  301. ModelManager manager;
  302. uint64_t max_size = 0;
  303. manager.model_map_[2] = nullptr;
  304. EXPECT_EQ(PARAM_INVALID, manager.GetMaxUsedMemory(2, max_size));
  305. }
  306. // test GetInputOutputDescInfo
  307. TEST_F(UtestModelManagerModelManager, get_input_output_desc_info_fail) {
  308. ModelManager manager;
  309. manager.model_map_[2] = nullptr;
  310. vector<InputOutputDescInfo> input_shape;
  311. vector<InputOutputDescInfo> output_shape;
  312. EXPECT_EQ(PARAM_INVALID, manager.GetInputOutputDescInfo(2, input_shape, output_shape));
  313. }
  314. *//*
  315. // test GetInputOutputDescInfo fail
  316. TEST_F(UtestModelManagerModelManager, get_input_output_desc_info_zero_copy_fail) {
  317. ModelManager manager;
  318. manager.model_map_[2] = nullptr;
  319. vector<InputOutputDescInfo> input_shape;
  320. vector<InputOutputDescInfo> output_shape;
  321. EXPECT_EQ(PARAM_INVALID, manager.GetInputOutputDescInfoForZeroCopy(2, input_shape, output_shape));
  322. }
  323. *//*
  324. // test Stop
  325. TEST_F(UtestModelManagerModelManager, stop_fail) {
  326. ModelManager manager;
  327. manager.model_map_[2] = nullptr;
  328. EXPECT_EQ(PARAM_INVALID, manager.Stop(2));
  329. }
  330. // build input_data
  331. TEST_F(UtestModelManagerModelManager, check_data_len_success) {
  332. shared_ptr<ModelListener> g_label_call_back(new DModelListener());
  333. DavinciModel model(0, g_label_call_back);
  334. ModelManager model_manager;
  335. InputData input_data;
  336. DataBuffer data_buffer;
  337. data_buffer.data = new char[51200];
  338. data_buffer.length = 51200;
  339. input_data.index = 0;
  340. input_data.model_id = 1;
  341. input_data.blobs.push_back(data_buffer);
  342. delete[](char *) data_buffer.data;
  343. }
  344. // test LoadModeldef
  345. TEST_F(UtestModelManagerModelManager, destroy_aicpu_session) {
  346. ModelManager manager;
  347. manager.DestroyAicpuSession(0);
  348. manager.sess_ids_.insert(0);
  349. manager.DestroyAicpuSession(0);
  350. }*/
  351. // test DataInputTensor
  352. TEST_F(UtestModelManagerModelManager, test_data_input_tensor) {
  353. shared_ptr<ModelListener> g_label_call_back(nullptr);
  354. auto model = std::make_shared<DavinciModel>(0, g_label_call_back);
  355. ModelManager mm;
  356. uint32_t model_id = 1;
  357. mm.model_map_[1] = model;
  358. mm.hybrid_model_map_[1] = std::make_shared<hybrid::HybridDavinciModel>();
  359. ge::Tensor input_tensor;
  360. vector<ge::Tensor> inputs;
  361. inputs.emplace_back(input_tensor);
  362. auto ret = mm.DataInputTensor(model_id,inputs);
  363. EXPECT_EQ(PARAM_INVALID, ret); // HybridDavinciModel::impl_ is null.
  364. }
  365. TEST_F(UtestModelManagerModelManager, test_launch_kernel_cust_aicpu) {
  366. ModelManager mm;
  367. // cust_aicpu_so_ is empty.
  368. EXPECT_EQ(mm.LaunchKernelCustAicpuSo("empty_cust_aicpu"), SUCCESS);
  369. // deleteCustOp after Launch will deleted.
  370. uintptr_t resource_id = 1; // for rtCtxGetCurrent stub
  371. std::vector<char> kernel_bin(256);
  372. auto &cust_resource_001 = mm.cust_aicpu_so_[resource_id];
  373. auto tbe_kernel = std::shared_ptr<OpKernelBin>(new OpKernelBin("deleteCustOp", std::move(kernel_bin)));
  374. auto &cust_opkernel_001 = cust_resource_001["deleteCustOp"] = tbe_kernel;
  375. EXPECT_FALSE(mm.cust_aicpu_so_.empty());
  376. EXPECT_EQ(mm.LaunchKernelCustAicpuSo("deleteCustOp"), SUCCESS);
  377. EXPECT_TRUE(mm.cust_aicpu_so_.empty());
  378. }
// Listener shared by the tests below (the "listerner" spelling is historical).
shared_ptr<ModelListener> listerner(new DModelListener());
  380. TEST_F(UtestModelManagerModelManager, test_load_model_online) {
  381. ModelManager mm;
  382. uint32_t model_id = 1;
  383. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  384. GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
  385. auto &profiling_manager = ge::ProfilingManager::Instance();
  386. profiling_manager.SetSubscribeInfo(0, model_id, true);
  387. Status ret = mm.LoadModelOnline(model_id, ge_root_model, listerner);
  388. profiling_manager.CleanSubscribeInfo();
  389. }
  390. TEST_F(UtestModelManagerModelManager, command_profiling) {
  391. ModelManager manager;
  392. uint32_t model_id = 1;
  393. Command cmd;
  394. auto model = std::make_shared<DavinciModel>(1, listerner);
  395. model->SetId(model_id);
  396. cmd.cmd_params.push_back("modelId");
  397. cmd.cmd_params.push_back(to_string(model_id));
  398. auto &profiling_manager = ge::ProfilingManager::Instance();
  399. profiling_manager.SetSubscribeInfo(0, model_id, true);
  400. Status ret = manager.HandleProfModelUnsubscribeCommand(cmd);
  401. profiling_manager.CleanSubscribeInfo();
  402. }
  403. TEST_F(UtestModelManagerModelManager, Cal_follow_stream_sum) {
  404. std::multimap<int64_t, int64_t> hccl_stream_map = {{1,10}, {1,20}, {2,10}, {2,5}};
  405. int64_t result = ModelUtils::CalFollowStreamSum(hccl_stream_map);
  406. EXPECT_EQ(result, 30);
  407. }
  408. TEST_F(UtestModelManagerModelManager, get_max_stream_and_event) {
  409. ModelManager mm;
  410. auto model1 = std::make_shared<DavinciModel> (1, nullptr);
  411. auto model2 = std::make_shared<DavinciModel> (2, nullptr);
  412. rtStream_t stream = nullptr;
  413. rtStream_t stream2 = nullptr;
  414. rtStream_t stream3 = nullptr;
  415. rtStream_t stream4 = nullptr;
  416. rtEvent_t event = nullptr;
  417. rtEvent_t event2 = nullptr;
  418. rtEvent_t event3 = nullptr;
  419. model1->stream_list_ = {stream, stream2, stream3, stream4};
  420. model1->event_list_ = {event, event2};
  421. model2->stream_list_ = {stream, stream2};
  422. model2->event_list_ = {event, event2, event3};
  423. mm.InsertModel(1, model1);
  424. mm.InsertModel(2, model2);
  425. uint32_t max_stream_model;
  426. uint32_t max_event_model;
  427. mm.GetMaxStreamAndEventModel(max_stream_model, max_event_model);
  428. EXPECT_EQ(max_stream_model, 1);
  429. EXPECT_EQ(max_event_model, 2);
  430. int64_t free_stream;
  431. int64_t free_event;
  432. Status ret = mm.GetFreeStream(free_stream);
  433. EXPECT_EQ(ret, SUCCESS);
  434. }
  435. TEST_F(UtestModelManagerModelManager, release_resource_stream) {
  436. ModelManager mm;
  437. auto model1 = std::make_shared<DavinciModel> (1, nullptr);
  438. auto model2 = std::make_shared<DavinciModel> (2, nullptr);
  439. rtStream_t stream = nullptr;
  440. rtStream_t stream2 = nullptr;
  441. rtStream_t stream3 = nullptr;
  442. rtStream_t stream4 = nullptr;
  443. rtEvent_t event = nullptr;
  444. rtEvent_t event2 = nullptr;
  445. rtEvent_t event3 = nullptr;
  446. model1->stream_list_ = {stream, stream2, stream3, stream4};
  447. model1->event_list_ = {event, event2};
  448. model2->stream_list_ = {stream, stream2};
  449. model2->event_list_ = {event, event2, event3};
  450. mm.InsertModel(1, model1);
  451. mm.InsertModel(2, model2);
  452. string kind = "stream";
  453. Status ret = mm.ReleaseResource(110, 109, kind);
  454. EXPECT_EQ(ret, SUCCESS);
  455. string kind2 = "event";
  456. Status ret2 = mm.ReleaseResource(110, 109, kind2);
  457. EXPECT_EQ(ret2, SUCCESS);
  458. }
  459. TEST_F(UtestModelManagerModelManager, check_stream_and_event_resource) {
  460. ModelManager mm;
  461. auto ge_model = make_shared<GeModel>();
  462. Status ret = mm.CheckAndReleaseStreamEventResource(ge_model, 1);
  463. EXPECT_EQ(ret, FAILED);
  464. }
  465. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示