diff --git a/ge/hybrid/executor/hybrid_model_async_executor.cc b/ge/hybrid/executor/hybrid_model_async_executor.cc index 487b55b9..9f37e7d5 100644 --- a/ge/hybrid/executor/hybrid_model_async_executor.cc +++ b/ge/hybrid/executor/hybrid_model_async_executor.cc @@ -444,31 +444,20 @@ Status HybridModelAsyncExecutor::Execute(const std::vector &inputs, TensorValue tensor_value(inputs[i].data, inputs[i].length); args.inputs[i] = tensor_value; } + for (size_t i = 0; i < outputs.size(); ++i) { + args.outputs.emplace_back(TensorValue(outputs[i].data, outputs[i].length)); + } + // usr must designate input tensorDesc when input shape is dynamic in inference + for (size_t i = 0; i < input_desc.size(); ++i) { + ConstGeTensorDescPtr tensor_desc_ptr = MakeShared(input_desc[i]); + args.input_desc.emplace_back(tensor_desc_ptr); + } + GE_CHK_STATUS_RET(executor_->Execute(args), "Failed to execute model."); for (const auto &output_tensor_desc : args.output_desc) { output_desc.emplace_back(*output_tensor_desc); } - for (size_t i = 0; i < args.outputs.size(); ++i) { - int64_t output_real_size = 0; - ge::graphStatus graph_status = TensorUtils::GetTensorSizeInBytes(output_desc[i], output_real_size); - if (graph_status != GRAPH_SUCCESS) { - GELOGE(FAILED, "Get tensor size in bytes failed."); - return FAILED; - } - if (output_real_size > 0) { - if (outputs[i].length < static_cast(output_real_size)) { - GELOGE(FAILED, "output idx[%zu], the memory size of output[%lu] given by " - "user should be greater than or equal to the real size of output[%ld]", - i, outputs[i].length, output_real_size); - return FAILED; - } - GE_CHK_RT_RET(rtMemcpy(outputs[i].data, outputs[i].length, args.outputs[i].GetData(), output_real_size, - RT_MEMCPY_DEVICE_TO_DEVICE)); - } - outputs[i].length = output_real_size; - } - return SUCCESS; } diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index 3ec967d3..14284c0f 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -44,6 +44,27 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item( } } +Status ShapeInferenceState::CheckInputShapeByShapeRange(const GeTensorDesc &tensor_desc, + const GeTensorDesc &target_tensor_desc) const { + std::vector> shape_range; + if (tensor_desc.GetShapeRange(shape_range) != SUCCESS) { + GELOGE(PARAM_INVALID, "Get shape range failed."); + return PARAM_INVALID; + } + if (shape_range.empty()) { + GELOGD("Shape range is empty, no need to check input shape."); + return SUCCESS; + } + + GeShape target_shape = target_tensor_desc.GetShape(); + if (TensorUtils::CheckShapeByShapeRange(target_shape, shape_range) != SUCCESS) { + GELOGE(PARAM_INVALID, "Check shape by shape range failed."); + return PARAM_INVALID; + } + + return SUCCESS; +} + Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) { if (node_item.IsInputShapeStatic(idx)) { GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]", @@ -54,19 +75,31 @@ Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target return SUCCESS; } + std::lock_guard lk(mu_); + auto &input_desc = input_tensor_desc[idx]; + if (CheckInputShapeByShapeRange(input_desc, target) != SUCCESS) { + GELOGE(FAILED, "[%s] Check input shape by shape range failed.", node_item.NodeName().c_str()); + return FAILED; + } + GeShape shape = target.GetShape(); + input_desc.SetShape(shape); + input_desc.SetOriginShape(target.GetOriginShape()); int64_t tensor_size = -1; (void) TensorUtils::GetSize(target, tensor_size); + if (tensor_size <= 0) { + Format format = input_desc.GetFormat(); + DataType data_type = input_desc.GetDataType(); + if (TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size) != GRAPH_SUCCESS) { + GELOGE(FAILED, "[%s] Calculate tensor memory size failed.", node_item.NodeName().c_str()); + return FAILED; + } + } GELOGD("[%s] Update input shape [%d] with Shape: [%s] and OriginalShape: [%s], size = %ld", node_item.NodeName().c_str(), idx, - target.GetShape().ToString().c_str(), + shape.ToString().c_str(), target.GetOriginShape().ToString().c_str(), tensor_size); - - std::lock_guard lk(mu_); - auto &input_desc = input_tensor_desc[idx]; - input_desc.SetShape(target.GetShape()); - input_desc.SetOriginShape(target.GetOriginShape()); (void) TensorUtils::SetSize(input_desc, tensor_size); if (--num_pending_shapes_ <= 0) { ready_cv_.notify_all(); diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index 84a52abd..2da4184d 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -58,6 +58,8 @@ struct ShapeInferenceState { const vector &GetOutputTensorDesc() const; + Status CheckInputShapeByShapeRange(const GeTensorDesc &tensor_desc, const GeTensorDesc &target_tensor_desc) const; + const NodeItem &node_item; private: diff --git a/ge/hybrid/model/hybrid_model.cc b/ge/hybrid/model/hybrid_model.cc index 77c9be2b..a0217d52 100644 --- a/ge/hybrid/model/hybrid_model.cc +++ b/ge/hybrid/model/hybrid_model.cc @@ -225,23 +225,19 @@ Status HybridModel::GetInputDescInfo(vector &input_desc, st GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(0)); Format format = op_desc->GetInputDescPtr(0)->GetFormat(); - input.data_type = op_desc->GetInputDescPtr(0)->GetDataType(); + DataType data_type = op_desc->GetInputDescPtr(0)->GetDataType(); + input.data_type = static_cast(data_type); input.name = op_desc->GetName(); - - int64_t input_size = 0; - GE_CHK_STATUS_RET(TensorUtils::GetSize(*op_desc->GetInputDescPtr(0), input_size), "get input size failed."); - - // support dynamic shape - if (input_size < 0) { - GELOGD("dynamic shape scene, input size is unknown. " - "format=%d, data_type=%d, input_size=%ld", - format, input.data_type, input_size); - input_size = kMemSizeUnknownShape; // -1 + GeShape shape = op_desc->GetInputDescPtr(0)->GetShape(); + int64_t tensor_size = 0; + if (TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Calculate tensor mem size failed."); + return FAILED; } - - // not support dynamic shape input for now, so input_size here will be not less than zero. - input.size = input_size; - + if (tensor_size == kMemSizeUnknownShape) { + tensor_size = 0; + } + input.size = static_cast(tensor_size); CreateInputDimsInfo(op_desc, input); formats.push_back(format); @@ -284,6 +280,9 @@ void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc, } int64_t tensor_size = 0; (void)TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size); + if (tensor_size == kMemSizeUnknownShape) { + tensor_size = 0; + } output_desc_info.size = static_cast(tensor_size); output_desc_info.data_type = output_desc->GetDataType(); } diff --git a/ge/ir_build/atc_ir_common.cc b/ge/ir_build/atc_ir_common.cc index 42a78dde..ff156c75 100755 --- a/ge/ir_build/atc_ir_common.cc +++ b/ge/ir_build/atc_ir_common.cc @@ -19,7 +19,9 @@ #include "framework/common/string_util.h" #include "framework/common/types.h" #include "framework/common/util.h" +#include "graph/compute_graph.h" #include "graph/utils/type_utils.h" +#include "graph/utils/tensor_utils.h" using std::pair; using std::string; @@ -52,6 +54,11 @@ const char *const kCompressWeightError = "it must be appointed when appoint para const char *const kSelectImplmodeError = "only support high_performance, high_precision"; const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; const char *const kKeepDtypeError = "file not found"; +const char *const kInputShapeRangeInvalid = "format of shape range is invalid"; +const char *const kShapeRangeValueConvertError = "transfer from string to int64 error"; +const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\""; +const char *const kInputShapeRangeSample2 = "\"[]\""; +const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\""; vector SplitInputShape(const std::string &input_shape) { vector shape_pair_vec; @@ -257,8 +264,132 @@ bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims return true; } +bool StringToLongNoThrow(const string &str, long &val) { + try { + val = std::stol(str); + return true; + } catch (const std::invalid_argument) { + ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, + {str, kShapeRangeValueConvertError, kInputShapeRangeSample3}); + GELOGE(PARAM_INVALID, + "Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", + str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); + } catch (const std::out_of_range) { + ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, + {str, kShapeRangeValueConvertError, kInputShapeRangeSample3}); + GELOGE(PARAM_INVALID, + "Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", + str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); + } + return false; +} + +bool ParseSingleShapeRange(std::string &shape_range, vector> &shape_range_vec) { + vector square_brackets; + for (auto ch : shape_range) { + if (ch == '[' || ch == ']') { + square_brackets.push_back(ch); + } + } + + bool is_square_brackets = (square_brackets[0] == '[') && (square_brackets[1] == ']') && (square_brackets.size() == 2); + if (!is_square_brackets) { + ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, + {shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample2}); + GELOGE(PARAM_INVALID, + "Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", + shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample2); + return false; + } + // trim start bytes, after that, single input should be "1~20,3,3~6,-1" + if (ge::StringUtils::StartWith(shape_range, "[")) { + shape_range = shape_range.substr(1, shape_range.size() - 1); + } + // parse shape_range of single input. eg. "1~20,3,3~6,-1" + vector dim_range_set = ge::StringUtils::Split(shape_range, ','); + for (const auto &range_pair_str : dim_range_set) { + vector range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); + pair range_pair; + if (range_pair_set.size() == 1) { + long range_value = 0; + if (!StringToLongNoThrow(range_pair_set.at(0), range_value)) { + return false; + } + if (range_value < 0) { + range_pair = std::make_pair(1, range_value); + } else { + range_pair = std::make_pair(range_value, range_value); + } + } else if (range_pair_set.size() == 2) { + // unknown dim, should get range. + long range_left = 0; + if (!StringToLongNoThrow(range_pair_set.at(0), range_left)) { + return false; + } + long range_right = 0; + if (!StringToLongNoThrow(range_pair_set.at(1), range_right)) { + return false; + } + if (range_left < 0 || (range_right < 0)) { + ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, + {shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}); + GELOGE(PARAM_INVALID, + "Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", + shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); + return false; + } + range_pair = std::make_pair(range_left, range_right); + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, + {shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}); + GELOGE(PARAM_INVALID, + "Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", + shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); + return false; + } + shape_range_vec.emplace_back(range_pair); + } + return true; +} + +bool ParseInputShapeRange(const std::string &shape_range, + std::map>> &shape_range_map) { + GELOGD("Input shape range %s", shape_range.c_str()); + + vector shape_range_vec = StringUtils::Split(shape_range, ';'); + const int DEFAULT_SHAPE_RANGE_PAIR_SIZE = 2; + for (const auto &shape_range_item : shape_range_vec) { + vector shape_range_pair_vec = SplitInputShape(shape_range_item); + if (shape_range_pair_vec.size() != DEFAULT_SHAPE_RANGE_PAIR_SIZE) { + ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, + {shape_range, kSplitError1, kInputShapeRangeSample1}); + GELOGE(PARAM_INVALID, "Parse input parameter [--input_shape_range]'s shape range[%s] failed, " + "reason: %s, correct sample is %s.", shape_range.c_str(), kSplitError1, kInputShapeRangeSample1); + return false; + } + if (shape_range_pair_vec[1].empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape", "reason", "sample"}, + {shape_range, kEmptyError, kInputShapeRangeSample1}); + GELOGE(PARAM_INVALID, "Parse input parameter [--input_shape_range]'s shape range[%s] failed," + "reason: %s, correct sample is %s.", shape_range.c_str(), kEmptyError, kInputShapeRangeSample1); + return false; + } + + string shape_range_str = shape_range_pair_vec[1]; + vector> shape_range_val; + if (!ParseSingleShapeRange(shape_range_str, shape_range_val)) { + GELOGE(PARAM_INVALID, "Parse single shape range %s error.", shape_range_str.c_str()); + return false; + } + shape_range_map.emplace(make_pair(StringUtils::Trim(shape_range_pair_vec[0]), shape_range_val)); + } + + return true; +} + Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims, - const string input_shape, const string input_format, bool &is_dynamic_input) { + const string input_shape, const string input_shape_range, const string input_format, + bool &is_dynamic_input) { int32_t param_size = static_cast(!dynamic_batch_size.empty()) + static_cast(!dynamic_image_size.empty()) + static_cast(!dynamic_dims.empty()); if (param_size > 1) { @@ -269,6 +400,13 @@ Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_i } if (param_size == 0) { + if (!input_shape_range.empty()) { + std::map>> shape_range_map; + if(!ParseInputShapeRange(input_shape_range, shape_range_map)) { + GELOGE(ge::PARAM_INVALID, "Failed to parse input shape range: %s", input_shape_range.c_str()); + return ge::PARAM_INVALID; + } + } return ge::SUCCESS; } @@ -546,4 +684,91 @@ void EraseEndSemicolon(string ¶m) { param.erase(param.end() - 1); } } + +Status UpdateDataOpShape(const OpDescPtr &op, map> &shape_map) { + GE_CHECK_NOTNULL(op); + if (shape_map.empty()) { + GELOGI("Shape map of data op [%s] is empty, no need to update.", op->GetName().c_str()); + return SUCCESS; + } + + auto tensor_input = op->MutableInputDesc(0); + auto tensor_output = op->MutableOutputDesc(0); + GE_CHECK_NOTNULL(tensor_input); + GE_CHECK_NOTNULL(tensor_output); + string data_op_name = op->GetName(); + auto iter = shape_map.find(data_op_name); + if (iter != shape_map.end()) { + tensor_input->SetShape(ge::GeShape(iter->second)); + tensor_output->SetShape(ge::GeShape(iter->second)); + GELOGI("Update input [%s] shape info", data_op_name.c_str()); + } else { + GELOGI("No need update input [%s] attr because not found from input_shape.", data_op_name.c_str()); + } + + return SUCCESS; +} + +Status UpdateDataOpShapeRange(const OpDescPtr &op, + map>> &shape_range_map) { + GE_CHECK_NOTNULL(op); + if (shape_range_map.empty()) { + GELOGI("Shape range map of data op [%s] is empty.", op->GetName().c_str()); + return SUCCESS; + } + + auto tensor_input = op->MutableInputDesc(0); + GE_CHECK_NOTNULL(tensor_input); + string data_op_name = op->GetName(); + auto origin_shape = tensor_input->GetShape(); + auto iter = shape_range_map.find(data_op_name); + if (iter != shape_range_map.end()) { + auto cur_shape_range = iter->second; + if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) { + GELOGE(PARAM_INVALID, "[%s] Check shape by shape range failed.", op->GetName().c_str()); + return PARAM_INVALID; + } + for (size_t idx = 0; idx < cur_shape_range.size(); idx++) { + auto left_range = cur_shape_range[idx].first; + auto right_range = cur_shape_range[idx].second; + if (left_range != right_range) { + origin_shape.SetDim(idx, UNKNOWN_DIM); + } + } + tensor_input->SetShape(origin_shape); + tensor_input->SetShapeRange(cur_shape_range); + GELOGI("Update input [%s] shape range info", data_op_name.c_str()); + } else { + GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str()); + } + + return SUCCESS; +} + +Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range) { + if (input_shape_range.empty()) { + return SUCCESS; + } + GE_CHECK_NOTNULL(compute_graph); + + map>> shape_range_map; + if (!ParseInputShapeRange(input_shape_range, shape_range_map)) { + GELOGE(PARAM_INVALID, "Parse input shape range failed."); + return PARAM_INVALID; + } + + for (NodePtr &input_node : compute_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(input_node); + OpDescPtr op = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op); + if (op->GetType() == DATA) { + if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) { + GELOGE(FAILED, "Update data op [%s] input shape range failed.", op->GetName().c_str()); + return FAILED; + } + } + } + return SUCCESS; +} + } // namespace ge diff --git a/ge/ir_build/atc_ir_common.h b/ge/ir_build/atc_ir_common.h index 2ad4efa8..e8637cb9 100644 --- a/ge/ir_build/atc_ir_common.h +++ b/ge/ir_build/atc_ir_common.h @@ -59,10 +59,13 @@ bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string &dynamic_image_size, std::string &dynamic_dims, const std::string input_shape, - const std::string input_format, bool &is_dynamic_input); + const std::string input_shape_range, const std::string input_format, + bool &is_dynamic_input); bool ParseInputShape(const std::string &input_shape, std::map> &shape_map, std::vector>> &user_shape_map, bool is_dynamic_input = false); +bool ParseInputShapeRange(const std::string &shape_range, + std::map>> &shape_range_map); Status CheckOutputTypeParamValid(const std::string output_type); Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); @@ -76,5 +79,9 @@ Status CheckInputFormat(const string &input_format); Status CheckKeepTypeParamValid(const std::string &keep_dtype); void PrintOptionMap(std::map &options, std::string tips); void EraseEndSemicolon(std::string ¶m); +Status UpdateDataOpShape(const OpDescPtr &op, std::map> &shape_map); +Status UpdateDataOpShapeRange(const OpDescPtr &op, + std::map>> &shape_range_map); +Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range); } #endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index 62684e3a..cb025954 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -55,6 +55,7 @@ const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0"; const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; const std::string KEEP_DTYPE_OPTION = "keep_dtype"; const std::string kInputShape = "input_shape"; +const std::string kInputShapeRange = "input_shape_range"; const std::string kInputFormat = "input_format"; /** @@ -289,13 +290,20 @@ graphStatus Impl::InferShapePrepare(const ComputeGraphPtr &compute_graph) { graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { GELOGD("Enter Update Data Attr Process!"); - if (options_.find(kInputShape) == options_.end()) { - return GRAPH_SUCCESS; - } + std::string input_shape = (options_.find(kInputShape) == options_.end()) ? "" : options_[kInputShape]; + std::string input_shape_range = (options_.find(kInputShapeRange) == options_.end()) ? "" : options_[kInputShapeRange]; + map> shape_map; vector>> user_shape_map; - GE_CHK_BOOL_EXEC(ParseInputShape(options_[kInputShape], shape_map, user_shape_map, true), - return GRAPH_PARAM_INVALID, "parse input shape failed!"); + if (!input_shape.empty()) { + GE_CHK_BOOL_EXEC(ParseInputShape(input_shape, shape_map, user_shape_map, true), + return GRAPH_PARAM_INVALID, "Parse input shape failed!"); + } + std::map>> shape_range_map; + if (!input_shape_range.empty()) { + GE_CHK_BOOL_EXEC(ParseInputShapeRange(input_shape_range, shape_range_map), + return GRAPH_PARAM_INVALID, "Parse input shape range failed."); + } auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); GE_CHECK_NOTNULL(compute_graph); for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { @@ -303,21 +311,31 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { ge::OpDescPtr op = input_node->GetOpDesc(); GE_CHECK_NOTNULL(op); if (op->GetType() == DATA) { - auto tensor_input = op->MutableInputDesc(0); - auto tensor_output = op->MutableOutputDesc(0); - GE_CHECK_NOTNULL(tensor_input); - GE_CHECK_NOTNULL(tensor_output); - string data_op_name = op->GetName(); - auto iter = shape_map.find(data_op_name); - if (iter != shape_map.end()) { - tensor_input->SetShape(ge::GeShape(iter->second)); - tensor_output->SetShape(ge::GeShape(iter->second)); - GELOGD("update input [%s] shape info", data_op_name.c_str()); - } else { - GELOGI("no need update input [%s] attr because not found from input_shape.", data_op_name.c_str()); + if (UpdateDataOpShape(op, shape_map) != SUCCESS) { + GELOGE(GRAPH_FAILED, "Update data op [%s] shape failed.", op->GetName().c_str()); + return GRAPH_FAILED; + } + if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) { + GELOGE(GRAPH_FAILED, "Update data op [%s] shape range failed.", op->GetName().c_str()); + return GRAPH_FAILED; + } + if (shape_range_map.empty()) { + auto tensor_input = op->MutableInputDesc(0); + GE_CHECK_NOTNULL(tensor_input); + GeShape shape = tensor_input->GetShape(); + std::vector> shape_range; + if (tensor_input->GetShapeRange(shape_range) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "[%s] Get shape range failed.", op->GetName().c_str()); + return GRAPH_FAILED; + } + if (TensorUtils::CheckShapeByShapeRange(shape, shape_range) != SUCCESS) { + GELOGE(GRAPH_FAILED, "[%s] Check shape by shape range failed.", op->GetName().c_str()); + return GRAPH_FAILED; + } } } } + return GRAPH_SUCCESS; } @@ -400,9 +418,11 @@ graphStatus Impl::Init(const Graph &graph, const std::map &options, std::string output } else { std::map atc_params; atc_params.insert(std::pair("input_shape", FLAGS_input_shape)); + atc_params.insert(std::pair(ge::INPUT_SHAPE_RANGE, FLAGS_input_shape_range)); atc_params.insert(std::pair("out_nodes", FLAGS_out_nodes)); atc_params.insert(std::pair("input_format", FLAGS_input_format)); atc_params.insert(std::pair("check_report", FLAGS_check_report)); diff --git a/ge/session/omg.cc b/ge/session/omg.cc index bd1fd67c..f7072c7d 100755 --- a/ge/session/omg.cc +++ b/ge/session/omg.cc @@ -576,6 +576,7 @@ Status InitDomiOmgContext(const string &input_shape, const string &input_format, GELOGE(PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str()); return PARAM_INVALID; } + return SUCCESS; } @@ -788,6 +789,12 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map ir_builder_suppported_options = {INPUT_FORMAT, INPUT_SHAPE, + INPUT_SHAPE_RANGE, OP_NAME_MAP, DYNAMIC_BATCH_SIZE, DYNAMIC_IMAGE_SIZE, diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index c1a61c67..6c9969f4 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -45,6 +45,7 @@ include_directories(${GE_CODE_DIR}/inc) include_directories(${GE_CODE_DIR}/metadef/inc) include_directories(${GE_CODE_DIR}/ge) include_directories(${GE_CODE_DIR}/ge/inc) +include_directories(${GE_CODE_DIR}/ge/ir_build) include_directories(${GE_CODE_DIR}/metadef) include_directories(${GE_CODE_DIR}/metadef/graph) include_directories(${GE_CODE_DIR}/inc/external) @@ -61,6 +62,7 @@ include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce) include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain) include_directories(${GE_CODE_DIR}/tests/ut/ge) +include_directories(${GE_CODE_DIR}/tests/ut/common) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) @@ -732,6 +734,7 @@ set(KERNEL_TEST_FILES set(MULTI_PARTS_TEST_FILES "graph_ir/ge_operator_factory_unittest.cc" + "graph_ir/ge_ir_build_unittest.cc" "graph/transop_util_unittest.cc" "common/datatype_transfer_unittest.cc" "common/dump_manager_unittest.cc" diff --git a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc new file mode 100644 index 00000000..4b36cd34 --- /dev/null +++ b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "ir_build/atc_ir_common.h" +#include "graph/testcase/ge_graph/graph_builder_utils.h" + +#define protected public +#define private public + +#undef private +#undef protected + +const string DATA = "Data"; +const string AddNYes = "AddNYes"; +const string NETOUTPUT = "NetOutput"; + +using namespace ge; +class UtestIrCommon : public testing::Test { + protected: + void SetUp() {} + + void TearDown() {} +}; + +static ge::OpDescPtr CreateOpDesc(const std::string &name, const std::string &type) { + OpDescPtr op_desc = std::make_shared(name, type); + ge::GeTensorDesc ge_tensor_desc; + op_desc->AddInputDesc("input", ge_tensor_desc); + op_desc->AddOutputDesc("output", ge_tensor_desc); + + return op_desc; +} + +static ComputeGraphPtr BuildComputeGraph() { + auto builder = ut::GraphBuilder("test"); + auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3}); + auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10}); + auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1); + auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + + builder.AddDataEdge(data1, 0, addn1, 0); + builder.AddDataEdge(data2, 0, addn1, 1); + builder.AddDataEdge(addn1, 0,netoutput, 0); + + return builder.GetGraph(); +} + +TEST(UtestIrCommon, update_data_op_shape) { + ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data"); + map> shape_map; + shape_map["Data"] = {{1,2}}; + + Status ret = UpdateDataOpShape(op_desc, shape_map); + EXPECT_EQ(ret, ge::SUCCESS); +} + +TEST(UtestIrCommon, update_dynamic_shape_range_success) { + ComputeGraphPtr graph = BuildComputeGraph(); + std::string input_shape_range = "input1:[1, 2~3, -1];input2:[3~5, 10]"; + + Status ret = UpdateDynamicInputShapeRange(graph, input_shape_range); + EXPECT_EQ(ret, ge::SUCCESS); +} + +TEST(UtestIrCommon, update_dynamic_shape_range_failed) { + ComputeGraphPtr graph = BuildComputeGraph(); + // 1 + std::string input_shape_range = "input1;[1, 2~3, -1]"; + Status ret = UpdateDynamicInputShapeRange(graph, input_shape_range); + EXPECT_EQ(ret, ge::PARAM_INVALID); + + // 2 + input_shape_range = "input1:[1, 2~3, -1)"; + ret = UpdateDynamicInputShapeRange(graph, input_shape_range); + EXPECT_EQ(ret, ge::PARAM_INVALID); + + //3 + input_shape_range = "input1:[1, 3~2, -1];input2:[3~5, 10]"; + ret = UpdateDynamicInputShapeRange(graph, input_shape_range); + EXPECT_EQ(ret, ge::FAILED); + + //4 + input_shape_range = "input1:[1, 2~-3, -1]"; + ret = UpdateDynamicInputShapeRange(graph, input_shape_range); + EXPECT_EQ(ret, ge::PARAM_INVALID); +}