@@ -154,47 +154,75 @@ Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFi | |||||
Status FileSaver::SaveToBuffWithFileHeader(const ModelFileHeader &file_header, | Status FileSaver::SaveToBuffWithFileHeader(const ModelFileHeader &file_header, | ||||
ModelPartitionTable &model_partition_table, | ModelPartitionTable &model_partition_table, | ||||
const std::vector<ModelPartition> &partitionDatas, | |||||
const std::vector<ModelPartition> &partition_datas, | |||||
ge::ModelBufferData &model) { | ge::ModelBufferData &model) { | ||||
GE_CHK_BOOL_RET_STATUS( | |||||
!partitionDatas.empty() && model_partition_table.num != 0 && model_partition_table.num == partitionDatas.size(), | |||||
FAILED, "Invalid param:partition data size is (%u), model_partition_table.num is (%zu).", | |||||
model_partition_table.num, partitionDatas.size()); | |||||
uint32_t model_header_size = sizeof(ModelFileHeader); | |||||
uint32_t table_size = static_cast<uint32_t>(SIZE_OF_MODEL_PARTITION_TABLE(model_partition_table)); | |||||
uint32_t total_size = model_header_size + table_size; | |||||
for (const auto &partitionData : partitionDatas) { | |||||
auto ret = ge::CheckUint32AddOverflow(total_size, partitionData.size); | |||||
GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, FAILED, "add uint32 overflow!"); | |||||
total_size = total_size + partitionData.size; | |||||
const vector<ModelPartitionTable *> model_partition_tables = { &model_partition_table }; | |||||
const std::vector<std::vector<ModelPartition>> all_partition_datas = { partition_datas }; | |||||
return SaveToBuffWithFileHeader(file_header, model_partition_tables, all_partition_datas, model); | |||||
} | |||||
Status FileSaver::SaveToBuffWithFileHeader(const ModelFileHeader &file_header, | |||||
const vector<ModelPartitionTable *> &model_partition_tables, | |||||
const std::vector<std::vector<ModelPartition>> &all_partition_datas, | |||||
ge::ModelBufferData &model) { | |||||
GE_CHK_BOOL_RET_STATUS(model_partition_tables.size() == all_partition_datas.size(), PARAM_INVALID, | |||||
"Model table size %zu does not match partition size %zu.", | |||||
model_partition_tables.size(), all_partition_datas.size()); | |||||
for (size_t index = 0; index < model_partition_tables.size(); ++index) { | |||||
auto &cur_partiton_data = all_partition_datas[index]; | |||||
auto &cur_model_partition_table = *model_partition_tables[index]; | |||||
GE_CHK_BOOL_RET_STATUS(!cur_partiton_data.empty() && cur_model_partition_table.num != 0 | |||||
&& cur_model_partition_table.num == cur_partiton_data.size(), FAILED, | |||||
"Invalid param: partition data size is (%zu), model_partition_table.num is (%u).", | |||||
cur_partiton_data.size(), cur_model_partition_table.num); | |||||
} | } | ||||
uint64_t model_header_size = sizeof(ModelFileHeader); | |||||
uint64_t total_size = model_header_size; | |||||
for (size_t index = 0; index < model_partition_tables.size(); ++index) { | |||||
auto &cur_model_partition_table = *model_partition_tables[index]; | |||||
total_size += static_cast<uint64_t>(SIZE_OF_MODEL_PARTITION_TABLE(cur_model_partition_table)); | |||||
auto &cur_partition_data = all_partition_datas[index]; | |||||
for (const auto &partition_data : cur_partition_data) { | |||||
auto ret = ge::CheckUint64AddOverflow(total_size, partition_data.size); | |||||
GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, FAILED, "Add uint64 overflow!"); | |||||
total_size += partition_data.size; | |||||
} | |||||
} | |||||
// save to buff | |||||
auto buff = reinterpret_cast<uint8_t *>(malloc(total_size)); | auto buff = reinterpret_cast<uint8_t *>(malloc(total_size)); | ||||
GE_CHK_BOOL_RET_STATUS(buff != nullptr, FAILED, "malloc failed!"); | |||||
GE_PRINT_DYNAMIC_MEMORY(malloc, "file buffer.", total_size) | |||||
GE_CHK_BOOL_RET_STATUS(buff != nullptr, FAILED, "Malloc failed!"); | |||||
GE_PRINT_DYNAMIC_MEMORY(malloc, "File buffer.", total_size) | |||||
model.data.reset(buff, [](uint8_t *buff) { | model.data.reset(buff, [](uint8_t *buff) { | ||||
GELOGD("Free online model memory."); | GELOGD("Free online model memory."); | ||||
free(buff); | free(buff); | ||||
buff = nullptr; | buff = nullptr; | ||||
}); | }); | ||||
model.length = total_size; | model.length = total_size; | ||||
uint32_t left_space = total_size; | |||||
auto ret_mem1 = memcpy_s(buff, left_space, reinterpret_cast<void *>(const_cast<ModelFileHeader *>(&file_header)), | |||||
model_header_size); | |||||
GE_CHK_BOOL_RET_STATUS(ret_mem1 == 0, FAILED, "memcpy_s failed!"); | |||||
uint64_t left_space = total_size; | |||||
auto ret_mem = memcpy_s(buff, left_space, reinterpret_cast<void *>(const_cast<ModelFileHeader *>(&file_header)), | |||||
model_header_size); | |||||
GE_CHK_BOOL_RET_STATUS(ret_mem == EOK, FAILED, "Memcpy_s failed!"); | |||||
buff += model_header_size; | buff += model_header_size; | ||||
left_space -= model_header_size; | left_space -= model_header_size; | ||||
auto ret_mem2 = memcpy_s(buff, left_space, reinterpret_cast<void *>(&model_partition_table), table_size); | |||||
GE_CHK_BOOL_RET_STATUS(ret_mem2 == 0, FAILED, "memcpy_s failed!"); | |||||
buff += table_size; | |||||
left_space -= table_size; | |||||
for (const auto &partitionData : partitionDatas) { | |||||
auto ret_mem3 = memcpy_s(buff, left_space, reinterpret_cast<void *>(const_cast<uint8_t *>(partitionData.data)), | |||||
partitionData.size); | |||||
GE_CHK_BOOL_RET_STATUS(ret_mem3 == 0, FAILED, "memcpy failed!"); | |||||
buff += partitionData.size; | |||||
left_space -= partitionData.size; | |||||
for (size_t index = 0; index < model_partition_tables.size(); ++index) { | |||||
auto &cur_tabel = *model_partition_tables[index]; | |||||
uint64_t table_size = static_cast<uint64_t>(SIZE_OF_MODEL_PARTITION_TABLE(cur_tabel)); | |||||
ret_mem = memcpy_s(buff, left_space, reinterpret_cast<void *>(&cur_tabel), table_size); | |||||
GE_CHK_BOOL_RET_STATUS(ret_mem == EOK, FAILED, "Memcpy_s failed!"); | |||||
buff += table_size; | |||||
left_space -= table_size; | |||||
auto &cur_partition_data = all_partition_datas[index]; | |||||
for (const auto &partition_data : cur_partition_data) { | |||||
ret_mem = memcpy_s(buff, left_space, reinterpret_cast<void *>(const_cast<uint8_t *>(partition_data.data)), | |||||
partition_data.size); | |||||
GE_CHK_BOOL_RET_STATUS(ret_mem == EOK, FAILED, "Memcpy_s failed!"); | |||||
buff += partition_data.size; | |||||
left_space -= partition_data.size; | |||||
} | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -80,9 +80,14 @@ class FileSaver { | |||||
static Status SaveToBuffWithFileHeader(const ModelFileHeader &file_header, | static Status SaveToBuffWithFileHeader(const ModelFileHeader &file_header, | ||||
ModelPartitionTable &model_partition_table, | ModelPartitionTable &model_partition_table, | ||||
const std::vector<ModelPartition> &partitionDatas, | |||||
const std::vector<ModelPartition> &partition_datas, | |||||
ge::ModelBufferData& model); | ge::ModelBufferData& model); | ||||
static Status SaveToBuffWithFileHeader(const ModelFileHeader &file_header, | |||||
const std::vector<ModelPartitionTable *> &model_partition_tables, | |||||
const std::vector<std::vector<ModelPartition>> &all_partition_datas, | |||||
ge::ModelBufferData &model); | |||||
static Status SaveToFile(const string &file_path, const void *data, int len); | static Status SaveToFile(const string &file_path, const void *data, int len); | ||||
protected: | protected: | ||||
@@ -113,8 +118,8 @@ class FileSaver { | |||||
ModelPartitionTable &model_partition_table, | ModelPartitionTable &model_partition_table, | ||||
const std::vector<ModelPartition> &partition_datas); | const std::vector<ModelPartition> &partition_datas); | ||||
static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, | static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, | ||||
vector<ModelPartitionTable *> &model_partition_tables, | |||||
const vector<vector<ModelPartition>> &all_partition_datas); | |||||
std::vector<ModelPartitionTable *> &model_partition_tables, | |||||
const std::vector<std::vector<ModelPartition>> &all_partition_datas); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_COMMON_AUTH_FILE_SAVER_H_ | #endif // GE_COMMON_AUTH_FILE_SAVER_H_ |
@@ -416,8 +416,7 @@ Status OmFileSaveHelper::SaveRootModel(const SaveParam &save_param, const char * | |||||
if (is_offline) { | if (is_offline) { | ||||
ret = FileSaver::SaveToFile(output_file, model_header_, model_partition_tabels, all_model_partitions); | ret = FileSaver::SaveToFile(output_file, model_header_, model_partition_tabels, all_model_partitions); | ||||
} else { | } else { | ||||
GELOGW("do not support save ge root model to buff now"); | |||||
return FAILED; | |||||
ret = FileSaver::SaveToBuffWithFileHeader(model_header_, model_partition_tabels, all_model_partitions, model); | |||||
} | } | ||||
if (ret == SUCCESS) { | if (ret == SUCCESS) { | ||||
GELOGD("Save model success without encrypt."); | GELOGD("Save model success without encrypt."); | ||||
@@ -17,6 +17,7 @@ | |||||
#include "hybrid_model_executor.h" | #include "hybrid_model_executor.h" | ||||
#include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
#include "graph/runtime_inference_context.h" | #include "graph/runtime_inference_context.h" | ||||
#include "graph/utils/tensor_utils.h" | |||||
#include "common/dump/dump_manager.h" | #include "common/dump/dump_manager.h" | ||||
#include "common/profiling/profiling_manager.h" | #include "common/profiling/profiling_manager.h" | ||||
@@ -48,6 +49,11 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
auto root_graph_item = model_->GetRootGraphItem(); | auto root_graph_item = model_->GetRootGraphItem(); | ||||
GE_CHECK_NOTNULL(root_graph_item); | GE_CHECK_NOTNULL(root_graph_item); | ||||
if (root_graph_item->IsDynamic()) { | |||||
GE_CHK_STATUS_RET(CheckInputShapeByShapeRange(root_graph_item, args), | |||||
"[%s] check input node shape by shape range failed.", | |||||
root_graph_item->GetName().c_str()); | |||||
} | |||||
if (context_.global_step != nullptr) { | if (context_.global_step != nullptr) { | ||||
GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | ||||
sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | ||||
@@ -151,5 +157,55 @@ Status HybridModelExecutor::ResetExecutionContext(GraphExecutionContext &context | |||||
GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to Destroy RuntimeInferenceContext"); | GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to Destroy RuntimeInferenceContext"); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HybridModelExecutor::CheckInputShapeByShapeRange(const GraphItem *graph_item, | |||||
HybridModelExecutor::ExecuteArgs &args) { | |||||
GE_CHECK_NOTNULL(graph_item); | |||||
auto input_nodes = graph_item->GetInputNodes(); | |||||
if (args.input_desc.size() < input_nodes.size()) { | |||||
REPORT_INNER_ERROR("E19999", "[%s] Number of inputs [%zu] is not sufficient for graph which needs [%zu] inputs.", | |||||
graph_item->GetName().c_str(), args.input_desc.size(), input_nodes.size()); | |||||
GELOGE(INTERNAL_ERROR, "[%s] Number of inputs [%zu] is not sufficient for graph which needs [%zu] inputs.", | |||||
graph_item->GetName().c_str(), args.input_desc.size(), input_nodes.size()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
for (size_t i = 0; i < input_nodes.size(); ++i) { | |||||
auto &input_node = input_nodes[i]; | |||||
if (input_node == nullptr) { | |||||
GELOGD("[%s] Input[%zu] is not needed by graph, skip it.", graph_item->GetName().c_str(), i); | |||||
continue; | |||||
} | |||||
GeTensorDescPtr model_input_desc = input_node->MutableInputDesc(0); | |||||
GE_CHECK_NOTNULL(model_input_desc); | |||||
std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
if (model_input_desc->GetShapeRange(shape_range) != SUCCESS) { | |||||
REPORT_INNER_ERROR("E19999", "[%s] Input[%zu] get shape range failed", graph_item->GetName().c_str(), i); | |||||
GELOGE(INTERNAL_ERROR, "[%s] Input[%zu] get shape range failed", graph_item->GetName().c_str(), i); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
if (shape_range.empty()) { | |||||
GELOGD("[%s] Input[%zu] shape is not needed to check by shape range, skip it.", graph_item->GetName().c_str(), i); | |||||
continue; | |||||
} | |||||
ConstGeTensorDescPtr args_tensor_desc = args.input_desc[i]; | |||||
GE_CHECK_NOTNULL(args_tensor_desc); | |||||
GeShape shape = args_tensor_desc->GetShape(); | |||||
if (shape.IsUnknownShape()) { | |||||
REPORT_INNER_ERROR("E19999", "[%s] Input desc shape [%zu] designed by user must be static.", | |||||
graph_item->GetName().c_str(), i); | |||||
GELOGE(INTERNAL_ERROR, "[%s] Input desc shape [%zu] designed by user must be static.", | |||||
graph_item->GetName().c_str(), i); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
if (TensorUtils::CheckShapeByShapeRange(shape, shape_range) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "[Check][InputShape] [%s] check input [%zu] shape failed by shape range.", | |||||
graph_item->GetName().c_str(), i); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
} // namespace hybrid | } // namespace hybrid | ||||
} // namespace ge | } // namespace ge |
@@ -52,6 +52,7 @@ class HybridModelExecutor { | |||||
Status Cleanup(); | Status Cleanup(); | ||||
Status InitExecutionContext(); | Status InitExecutionContext(); | ||||
static Status ResetExecutionContext(GraphExecutionContext &context); | static Status ResetExecutionContext(GraphExecutionContext &context); | ||||
static Status CheckInputShapeByShapeRange(const GraphItem *graph_item, HybridModelExecutor::ExecuteArgs &args); | |||||
HybridModel *model_; | HybridModel *model_; | ||||
uint32_t device_id_; | uint32_t device_id_; | ||||
@@ -46,27 +46,6 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item( | |||||
} | } | ||||
} | } | ||||
Status ShapeInferenceState::CheckInputShapeByShapeRange(const GeTensorDesc &tensor_desc, | |||||
const GeTensorDesc &target_tensor_desc) const { | |||||
std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
if (tensor_desc.GetShapeRange(shape_range) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Get shape range failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (shape_range.empty()) { | |||||
GELOGD("Shape range is empty, no need to check input shape."); | |||||
return SUCCESS; | |||||
} | |||||
GeShape target_shape = target_tensor_desc.GetShape(); | |||||
if (TensorUtils::CheckShapeByShapeRange(target_shape, shape_range) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "Check shape by shape range failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) { | Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) { | ||||
if (node_item.IsInputShapeStatic(idx)) { | if (node_item.IsInputShapeStatic(idx)) { | ||||
GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]", | GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]", | ||||
@@ -58,8 +58,6 @@ struct ShapeInferenceState { | |||||
const vector<GeTensorDesc> &GetOutputTensorDesc() const; | const vector<GeTensorDesc> &GetOutputTensorDesc() const; | ||||
Status CheckInputShapeByShapeRange(const GeTensorDesc &tensor_desc, const GeTensorDesc &target_tensor_desc) const; | |||||
const NodeItem &node_item; | const NodeItem &node_item; | ||||
private: | private: | ||||
@@ -59,7 +59,7 @@ const char *const kKeepDtypeError = "file not found"; | |||||
const char *const kInputShapeRangeInvalid = "format of shape range is invalid"; | const char *const kInputShapeRangeInvalid = "format of shape range is invalid"; | ||||
const char *const kShapeRangeValueConvertError = "transfer from string to int64 error"; | const char *const kShapeRangeValueConvertError = "transfer from string to int64 error"; | ||||
const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\""; | const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\""; | ||||
const char *const kInputShapeRangeSample2 = "\"[]\""; | |||||
const char *const kInputShapeRangeSample2 = "\"[1~20]\""; | |||||
const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\""; | const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\""; | ||||
vector<string> SplitInputShape(const std::string &input_shape) { | vector<string> SplitInputShape(const std::string &input_shape) { | ||||
@@ -301,8 +301,8 @@ bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_ | |||||
} | } | ||||
} | } | ||||
bool is_square_brackets = (square_brackets[0] == '[') && (square_brackets[1] == ']') && | |||||
(square_brackets.size() == kSquareBracketsSize); | |||||
bool is_square_brackets = (square_brackets.size() == kSquareBracketsSize) && | |||||
(square_brackets[0] == '[') && (square_brackets[1] == ']'); | |||||
if (!is_square_brackets) { | if (!is_square_brackets) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | ||||
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample2}); | {shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample2}); | ||||
@@ -503,8 +503,17 @@ graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTe | |||||
string data_type_str = ge::TypeUtils::DataTypeToSerialString(data_type); | string data_type_str = ge::TypeUtils::DataTypeToSerialString(data_type); | ||||
GELOGD("Data op get data type:%s from InputDesc in ge ir graph.", data_type_str.c_str()); | GELOGD("Data op get data type:%s from InputDesc in ge ir graph.", data_type_str.c_str()); | ||||
std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
if (tensor.GetShapeRange(shape_range) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "[Creat][Input] Data op [%s] get shape range failed.", data_op_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
ge::GeTensor inputTensor; | ge::GeTensor inputTensor; | ||||
ge::GeTensorDesc desc(data_shape, ge::Format(data_format), data_type); | ge::GeTensorDesc desc(data_shape, ge::Format(data_format), data_type); | ||||
if (desc.SetShapeRange(shape_range) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "[Creat][Input] Data op [%s] set shape range failed.", data_op_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
inputTensor.SetTensorDesc(desc); | inputTensor.SetTensorDesc(desc); | ||||
inputs.push_back(inputTensor); | inputs.push_back(inputTensor); | ||||
} | } | ||||
@@ -770,6 +770,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
"common/format_transfer_fracz_nhwc_unittest.cc" | "common/format_transfer_fracz_nhwc_unittest.cc" | ||||
"common/format_transfer_fracz_hwcn_unittest.cc" | "common/format_transfer_fracz_hwcn_unittest.cc" | ||||
"common/ge_format_util_unittest.cc" | "common/ge_format_util_unittest.cc" | ||||
"common/ge_auth_file_saver_unittest.cc" | |||||
"graph/variable_accelerate_ctrl_unittest.cc" | "graph/variable_accelerate_ctrl_unittest.cc" | ||||
"graph/build/logical_stream_allocator_unittest.cc" | "graph/build/logical_stream_allocator_unittest.cc" | ||||
"graph/build/model_builder_unittest.cc" | "graph/build/model_builder_unittest.cc" | ||||
@@ -0,0 +1,53 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include "common/auth/file_saver.h" | |||||
namespace ge { | |||||
class UTEST_file_saver : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
TEST_F(UTEST_file_saver, save_model_data_to_buff_success) { | |||||
ModelFileHeader file_header; | |||||
std::vector<char> data; | |||||
data.resize(sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo), 0); | |||||
ModelPartitionTable *partition_table = reinterpret_cast<ModelPartitionTable*>(data.data()); | |||||
partition_table->num = 1; | |||||
partition_table->partition[0] = { MODEL_DEF, 0, 12 }; | |||||
std::vector<ModelPartitionTable *> partition_tables; | |||||
partition_tables.push_back(partition_table); | |||||
auto buff = reinterpret_cast<uint8_t *>(malloc(12)); | |||||
struct ge::ModelPartition model_partition; | |||||
model_partition.type = MODEL_DEF; | |||||
model_partition.data = buff; | |||||
model_partition.size = 12; | |||||
std::vector<ModelPartition> model_partitions = { model_partition }; | |||||
std::vector<std::vector<ModelPartition>> all_partition_datas = { model_partitions }; | |||||
ge::ModelBufferData model; | |||||
Status ret = FileSaver::SaveToBuffWithFileHeader(file_header, partition_tables, all_partition_datas, model); | |||||
EXPECT_EQ(ret, ge::SUCCESS); | |||||
free(buff); | |||||
buff = nullptr; | |||||
model_partition.data = nullptr; | |||||
} | |||||
} // namespace ge |
@@ -425,3 +425,44 @@ TEST_F(UtestGeHybrid, TestTaskContext) { | |||||
ASSERT_EQ(task_context->GetInputDesc(1, new_desc), SUCCESS); | ASSERT_EQ(task_context->GetInputDesc(1, new_desc), SUCCESS); | ||||
ASSERT_EQ(new_desc.GetShape().GetDims(), new_shape.GetDims()); | ASSERT_EQ(new_desc.GetShape().GetDims(), new_shape.GetDims()); | ||||
} | } | ||||
TEST_F(UtestGeHybrid, hybrid_model_executor_check_shape) { | |||||
HybridModelExecutor::ExecuteArgs args; | |||||
GeTensorDescPtr ge_tensor = make_shared<GeTensorDesc>(GeTensorDesc()); | |||||
vector<int64_t> dim = {2 , 3}; | |||||
ge_tensor->SetShape(GeShape(dim)); | |||||
args.input_desc.push_back(ge_tensor); | |||||
// create node | |||||
ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("God"); | |||||
OpDescPtr op_desc = std::make_shared<OpDesc>("data", DATA); | |||||
GeTensorDesc tensor_desc(GeShape({2, 3})); | |||||
std::vector<std::pair<int64_t, int64_t>> shape_range({std::pair<int64_t, int64_t>(1, 3), | |||||
std::pair<int64_t, int64_t>(2, 4)}); | |||||
tensor_desc.SetShapeRange(shape_range); | |||||
op_desc->AddInputDesc(tensor_desc); | |||||
op_desc->AddOutputDesc(tensor_desc); | |||||
NodePtr node = graph->AddNode(op_desc); | |||||
std::unique_ptr<NodeItem> new_node; | |||||
NodeItem::Create(node, new_node); | |||||
GraphItem graph_item; | |||||
graph_item.input_nodes_.emplace_back(new_node.get()); | |||||
Status ret = HybridModelExecutor::CheckInputShapeByShapeRange(&graph_item, args); | |||||
ASSERT_EQ(ret, ge::SUCCESS); | |||||
HybridModelExecutor::ExecuteArgs args1; | |||||
ret = HybridModelExecutor::CheckInputShapeByShapeRange(&graph_item, args1); | |||||
ASSERT_EQ(ret, ge::INTERNAL_ERROR); | |||||
HybridModelExecutor::ExecuteArgs args2; | |||||
GeTensorDescPtr ge_tensor2 = make_shared<GeTensorDesc>(GeTensorDesc()); | |||||
vector<int64_t> dim2 = {-1 , 3}; | |||||
ge_tensor2->SetShape(GeShape(dim2)); | |||||
args2.input_desc.push_back(ge_tensor2); | |||||
ret = HybridModelExecutor::CheckInputShapeByShapeRange(&graph_item, args1); | |||||
ASSERT_EQ(ret, ge::INTERNAL_ERROR); | |||||
} |