diff --git a/tests/ut/ge/graph/load/model_manager_unittest.cc b/tests/ut/ge/graph/load/model_manager_unittest.cc index 142b5566..81d88ecd 100644 --- a/tests/ut/ge/graph/load/model_manager_unittest.cc +++ b/tests/ut/ge/graph/load/model_manager_unittest.cc @@ -85,14 +85,64 @@ class UtestModelManagerModelManager : public testing::Test { } void LoadStandardModelData(ModelData &data) { - static const std::string STANDARD_MODEL_DATA_PATH = - "llt/framework/domi/ut/ome/test/data/standard_partition_model.txt"; - proto::ModelDef model_def; - ReadProtoFromText(STANDARD_MODEL_DATA_PATH.c_str(), &model_def); - - data.model_len = model_def.ByteSizeLong(); + data.model_len = 512; data.model_data = new uint8_t[data.model_len]; - model_def.SerializePartialToArray(data.model_data, data.model_len); + uint8_t *model_data = reinterpret_cast(data.model_data); + + uint32_t mem_offset = sizeof(ModelFileHeader); + ModelPartitionTable *partition_table = reinterpret_cast(model_data + mem_offset); + partition_table->num = PARTITION_SIZE; + + mem_offset += sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * 5; + { + Model model; + ComputeGraphPtr graph = make_shared("default"); + model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); + model.SetVersion(123); + + Buffer buffer; + model.Save(buffer); + EXPECT_TRUE(mem_offset + buffer.GetSize() < 512); + memcpy(model_data + mem_offset, buffer.GetData(), buffer.GetSize()); + + ModelPartitionMemInfo &partition_info = partition_table->partition[0]; + partition_info.type = ModelPartitionType::MODEL_DEF; + partition_info.mem_size = buffer.GetSize(); + mem_offset += buffer.GetSize(); + } + + { + ModelPartitionMemInfo &partition_info = partition_table->partition[1]; + partition_info.type = ModelPartitionType::WEIGHTS_DATA; + partition_info.mem_offset = mem_offset; + partition_info.mem_size = 0; + } + + { + ModelPartitionMemInfo &partition_info = partition_table->partition[2]; + partition_info.type = ModelPartitionType::TASK_INFO; + partition_info.mem_offset = mem_offset; + partition_info.mem_size = 0; + } + + { + ModelPartitionMemInfo &partition_info = partition_table->partition[3]; + partition_info.type = ModelPartitionType::TBE_KERNELS; + partition_info.mem_offset = mem_offset; + partition_info.mem_size = 0; + } + + { + ModelPartitionMemInfo &partition_info = partition_table->partition[4]; + partition_info.type = ModelPartitionType::CUST_AICPU_KERNELS; + partition_info.mem_offset = mem_offset; + partition_info.mem_size = 0; + } + + EXPECT_TRUE(mem_offset < 512); + ModelFileHeader *header = new (data.model_data) ModelFileHeader; + header->length = mem_offset - sizeof(ModelFileHeader); + data.model_len = mem_offset; } }; @@ -151,6 +201,17 @@ TEST_F(UtestModelManagerModelManager, case_load_model_encypt_type_unsupported) { EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_PARAM_INVALID); delete[](uint8_t *) data.model_data; } + +TEST_F(UtestModelManagerModelManager, case_load_model_data_success) { + ModelData data; + LoadStandardModelData(data); + + uint32_t model_id = 1; + ModelManager mm; + EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), SUCCESS); + delete[](uint8_t *) data.model_data; +} + /* shared_ptr LabelCallBack(new DModelListener());