From 270de8ae3f5caee08b6c4fe4f03a1868d8659c16 Mon Sep 17 00:00:00 2001 From: zhengyuanhua Date: Thu, 15 Apr 2021 14:18:21 +0800 Subject: [PATCH] ir parse input shape reange --- ge/CMakeLists.txt | 4 +- ge/ge_inference.mk | 2 +- ge/ge_runner.mk | 2 +- ge/graph/manager/graph_manager.cc | 2 +- ge/graph/preprocess/graph_preprocess.cc | 103 +------ ge/hybrid/executor/hybrid_model_executor.cc | 7 +- ge/ir_build/ge_ir_build.cc | 25 +- .../{atc_ir_common.cc => option_utils.cc} | 277 +++++++++++++----- .../{atc_ir_common.h => option_utils.h} | 10 +- ge/offline/CMakeLists.txt | 2 +- ge/offline/main.cc | 2 +- ge/offline/module.mk | 6 +- ge/session/omg.cc | 2 +- tests/ut/ge/CMakeLists.txt | 2 +- .../graph/manager/graph_manager_unittest.cc | 2 +- tests/ut/ge/graph_ir/ge_ir_build_unittest.cc | 99 ++++++- 16 files changed, 350 insertions(+), 197 deletions(-) rename ge/ir_build/{atc_ir_common.cc => option_utils.cc} (78%) rename ge/ir_build/{atc_ir_common.h => option_utils.h} (87%) diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 6e0e9235..d28771ab 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -403,7 +403,7 @@ set(TRAIN_SRC_LIST "ir_build/attr_options/utils.cc" "ir_build/attr_options/keep_dtype_option.cc" "ir_build/attr_options/weight_compress_option.cc" - "ir_build/atc_ir_common.cc" + "ir_build/option_utils.cc" "graph/build/memory/memory_assigner.cc" "graph/build/memory/graph_mem_assigner.cc" "graph/build/memory/binary_block_mem_assigner.cc" @@ -664,7 +664,7 @@ set(INFER_SRC_LIST "ir_build/attr_options/utils.cc" "ir_build/attr_options/keep_dtype_option.cc" "ir_build/attr_options/weight_compress_option.cc" - "ir_build/atc_ir_common.cc" + "ir_build/option_utils.cc" "graph/preprocess/insert_op/ge_aipp_op.cc" "graph/preprocess/insert_op/util_insert_aipp_op.cc" "hybrid/node_executor/aicpu/aicpu_ext_info.cc" diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index 32fc206d..00dbc245 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -73,7 +73,7 @@ BUILER_SRC_FILES := \ ir_build/attr_options/utils.cc \ ir_build/attr_options/keep_dtype_option.cc \ ir_build/attr_options/weight_compress_option.cc \ - ir_build/atc_ir_common.cc \ + ir_build/option_utils.cc \ ANALYZER_SRC_FILES:= \ analyzer/analyzer.cc \ diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 49515fe4..b89071aa 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -316,7 +316,7 @@ LIBGE_LOCAL_SRC_FILES := \ ir_build/attr_options/utils.cc \ ir_build/attr_options/keep_dtype_option.cc \ ir_build/attr_options/weight_compress_option.cc \ - ir_build/atc_ir_common.cc \ + ir_build/option_utils.cc \ LIBCLIENT_LOCAL_SRC_FILES := \ proto/ge_api.proto \ diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 1315376c..202d0de4 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -101,7 +101,7 @@ #include "graph/utils/tensor_adapter.h" #include "inc/pass_manager.h" #include "init/gelib.h" -#include "ir_build/atc_ir_common.h" +#include "ir_build/option_utils.h" #include "graph/common/local_context.h" #include "graph/common/omg_util.h" #include "common/formats/utils/formats_trans_utils.h" diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index 4fb80646..2d06cd5d 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -27,6 +27,7 @@ #include "common/helper/model_helper.h" #include "common/math/math_util.h" #include "common/op/ge_op_utils.h" +#include "ir_build/option_utils.h" #include "graph/common/ge_call_wrapper.h" #include "graph/common/local_context.h" #include "graph/common/transop_util.h" @@ -991,101 +992,6 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node) { } return SUCCESS; } -long StringToLongNoThrow(const string &str) { - try { - return std::stol(str); - } catch (const std::invalid_argument) { - GELOGE(PARAM_INVALID, - "Parse shape range of input failed when transfer from string to int64. Given %s, while correct example:" - "\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", - str.c_str()); - return PARAM_INVALID; - } catch (const std::out_of_range) { - GELOGE(PARAM_INVALID, - "Parse shape range of input failed when transfer from string to int64. Given %s, while correct example:" - "\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", - str.c_str()); - return PARAM_INVALID; - } -} -/** - * Parser shape_range from string to vector - * shape_range from option normally is "[1~20,3,3~6,-1],[1~20,3,3~6,-1]" - * @param shape_range - */ -Status ParseDynamicInputShapeRange(const std::string &shape_range, - std::vector>> &range) { - if (shape_range.size() < 2) { - REPORT_INNER_ERROR("E19999", "shape_range.size:%zu < 2, check invalid", shape_range.size()); - GELOGE(PARAM_INVALID, "Shape range %s is invalid.", shape_range.c_str()); - return PARAM_INVALID; - } - // different shape_range of single input are split by ']' - vector shape_range_set = ge::StringUtils::Split(shape_range, ']'); - if (shape_range_set.empty()) { - REPORT_INNER_ERROR("E19999", "Shape range %s is not valid. Correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", - shape_range.c_str()); - GELOGE(PARAM_INVALID, "Shape range %s is not valid. Correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", - shape_range.c_str()); - return PARAM_INVALID; - } - for (auto &shape_range_str : shape_range_set) { - if (shape_range_str.size() < 3) { - // shape_range_str should be "[2~3,1" - // or ",[2~3,1". because we should trim '[' or ',[' - // so shape_range_str.size() < 3 is invalid - continue; - } - // trim start bytes, after that, single input should be "1~20,3,3~6,-1" - if (ge::StringUtils::StartWith(shape_range_str, "[")) { - shape_range_str = shape_range_str.substr(1, shape_range_str.size()); - } - if (ge::StringUtils::StartWith(shape_range_str, ",")) { - shape_range_str = shape_range_str.substr(2, shape_range_str.size()); - } - - // parse shape_range of single input. eg. "1~20,3,3~6,-1" - std::vector> range_of_single_input; - vector dim_range_set = ge::StringUtils::Split(shape_range_str, ','); - for (const auto &range_pair_str : dim_range_set) { - vector range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); - pair range_pair; - if (range_pair_set.size() == 1) { - // fix dim - auto range_value = StringToLongNoThrow(range_pair_set.at(0).c_str()); - if (range_value < 0) { - range_pair = std::make_pair(1, range_value); - } else { - range_pair = std::make_pair(range_value, range_value); - } - } else if (range_pair_set.size() == 2) { - // unknown dim, should get range. - auto range_left = StringToLongNoThrow(range_pair_set.at(0).c_str()); - auto range_right = StringToLongNoThrow(range_pair_set.at(1).c_str()); - if (range_left < 0 || range_right < 0) { - REPORT_INNER_ERROR("E19999", "Shape range of input is invalid. Given range pair [%ld,%ld], " - "while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", range_left, range_right); - GELOGE(PARAM_INVALID, - "Shape range of input is invalid. Given range pair [%ld,%ld], while correct example: " - "\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", - range_left, range_right); - return PARAM_INVALID; - } - range_pair = std::make_pair(range_left, range_right); - } else { - REPORT_INNER_ERROR("E19999", "Shape range of input is invalid. Given %s, " - "while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", shape_range.c_str()); - GELOGE(PARAM_INVALID, - "Shape range of input is invalid. Given %s, while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", - shape_range.c_str()); - return PARAM_INVALID; - } - range_of_single_input.emplace_back(range_pair); - } - range.emplace_back(range_of_single_input); - } - return SUCCESS; -} Status GetDynamicInputShapeRange(const std::vector &user_input, const std::map &graph_option, vector>> &range_vec) { @@ -1114,9 +1020,10 @@ Status GetDynamicInputShapeRange(const std::vector &user_input, const OPTION_EXEC_DYNAMIC_EXECUTE_MODE, OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE); return PARAM_INVALID; } - - auto ret = ParseDynamicInputShapeRange(iter->second, range_vec); - GE_CHK_STATUS_RET(ret, "Parse dynamic input shape range failed."); + if (ParseInputShapeRange(iter->second, range_vec) != SUCCESS) { + GELOGE(PARAM_INVALID, "[Parse][ShapeRange] Parse dynamic input shape range failed."); + return PARAM_INVALID; + } if (range_vec.size() != user_input.size()) { GELOGE(PARAM_INVALID, "Dynamic input shape range size is %zu, inputs size is %zu. Not match.", range_vec.size(), user_input.size()); diff --git a/ge/hybrid/executor/hybrid_model_executor.cc b/ge/hybrid/executor/hybrid_model_executor.cc index ea4e6912..2ab4ed5d 100755 --- a/ge/hybrid/executor/hybrid_model_executor.cc +++ b/ge/hybrid/executor/hybrid_model_executor.cc @@ -112,7 +112,12 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, HYBRID_CHK_STATUS_RET(context_.DumpExceptionInfo(exception_infos), "[Execute][GraphInternal] Dump exception info failed."); } - GELOGE(ret, "[Execute][GraphInternal] Synchronize failed."); + if (ret == ge::END_OF_SEQUENCE) { + GELOGD("Got end of sequence"); + } else { + GELOGE(ret, "[Execute][GraphInternal] Synchronize failed."); + } + return ret; } RECORD_MODEL_EXECUTION_EVENT(&context_, "[Synchronize] End"); } diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index af813a4a..b4a6c992 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -32,7 +32,7 @@ #include "graph/utils/type_utils.h" #include "graph/ge_global_options.h" #include "init/gelib.h" -#include "ir_build/atc_ir_common.h" +#include "ir_build/option_utils.h" #include "model/ge_model.h" #include "graph/shape_refiner.h" #include "graph/opsproto_manager.h" @@ -299,10 +299,19 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { GE_CHK_BOOL_EXEC(ParseInputShape(input_shape, shape_map, user_shape_map, true), return GRAPH_PARAM_INVALID, "[Parse][InputShape] failed!"); } - std::map>> shape_range_map; + std::map>> name_shape_range_map; + std::vector>> index_shape_range_map; if (!input_shape_range.empty()) { - GE_CHK_BOOL_EXEC(ParseInputShapeRange(input_shape_range, shape_range_map), - return GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] failed."); + Status ret = GRAPH_PARAM_INVALID; + if (input_shape_range.find(":") != string::npos) { + ret = ParseInputShapeRange(input_shape_range, name_shape_range_map); + } else { + ret = ParseInputShapeRange(input_shape_range, index_shape_range_map); + } + if (ret != SUCCESS) { + GELOGE(GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] parse shape range[%s] failed.", input_shape_range.c_str()); + return GRAPH_PARAM_INVALID; + } } auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); GE_CHECK_NOTNULL(compute_graph); @@ -315,10 +324,14 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { GELOGE(GRAPH_FAILED, "[Update][DataOpShape] fail for op:%s.", op->GetName().c_str()); return GRAPH_FAILED; } - if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) { + if (UpdateDataOpShapeRange(op, name_shape_range_map) != SUCCESS) { GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str()); return GRAPH_FAILED; - } + } + if (UpdateDataOpShapeRange(op, index_shape_range_map) != SUCCESS) { + GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str()); + return GRAPH_FAILED; + } } } diff --git a/ge/ir_build/atc_ir_common.cc b/ge/ir_build/option_utils.cc similarity index 78% rename from ge/ir_build/atc_ir_common.cc rename to ge/ir_build/option_utils.cc index 4d4a67f0..c7b9b11f 100755 --- a/ge/ir_build/atc_ir_common.cc +++ b/ge/ir_build/option_utils.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "atc_ir_common.h" +#include "option_utils.h" #include "common/util/error_manager/error_manager.h" #include "external/ge/ge_api_types.h" #include "framework/common/string_util.h" @@ -22,6 +22,7 @@ #include "graph/compute_graph.h" #include "graph/utils/type_utils.h" #include "graph/utils/tensor_utils.h" +#include "graph/debug/ge_attr_define.h" using std::pair; using std::string; @@ -58,10 +59,12 @@ const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \ const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\""; const char *const kKeepDtypeError = "file not found"; const char *const kInputShapeRangeInvalid = "format of shape range is invalid"; +const char *const kInputShapeRangeSizeInvalid = " shape range size less than 2 is invalid"; const char *const kShapeRangeValueConvertError = "transfer from string to int64 error"; const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\""; const char *const kInputShapeRangeSample2 = "\"[1~20]\""; const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\""; +const char *const kInputShapeRangeSample4 = "\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\""; vector SplitInputShape(const std::string &input_shape) { vector shape_pair_vec; @@ -72,6 +75,67 @@ vector SplitInputShape(const std::string &input_shape) { } return shape_pair_vec; } + +static bool StringToLongNoThrow(const string &str, long &val) { + try { + val = std::stol(str); + return true; + } catch (const std::invalid_argument) { + REPORT_INPUT_ERROR("E10048", std::vector({"shape_range", "reason", "sample"}), + std::vector({str, kShapeRangeValueConvertError, kInputShapeRangeSample3})); + GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s to long failed, reason: %s, correct sample is %s.", + str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); + } catch (const std::out_of_range) { + REPORT_INPUT_ERROR("E10048", std::vector({"shape_range", "reason", "sample"}), + std::vector({str, kShapeRangeValueConvertError, kInputShapeRangeSample3})); + GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s to long failed, reason: %s, correct sample is %s.", + str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); + } + return false; +} + +static bool ParseShapeRangePair(const string &shape_range, + const vector &range_pair_set, + std::pair &range_pair) { + if (range_pair_set.size() == 1) { + long range_value = 0; + if (!StringToLongNoThrow(range_pair_set.at(0), range_value)) { + return false; + } + if (range_value < 0) { + range_pair = std::make_pair(1, range_value); + } else { + range_pair = std::make_pair(range_value, range_value); + } + } else if (range_pair_set.size() == kRangePairSize) { + // unknown dim, should get range. + long range_left = 0; + if (!StringToLongNoThrow(range_pair_set.at(0), range_left)) { + return false; + } + long range_right = 0; + if (!StringToLongNoThrow(range_pair_set.at(1), range_right)) { + return false; + } + if ((range_left < 0) || (range_right < 0)) { + REPORT_INPUT_ERROR("E10048", std::vector({"shape_range", "reason", "sample"}), + std::vector({shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3})); + GELOGE(PARAM_INVALID, + "[Parse][InputParameter] [--input_shape_range]'s shape range[%s] failed," + "reason: %s, correct sample is %s.", + shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); + return false; + } + range_pair = std::make_pair(range_left, range_right); + } else { + REPORT_INPUT_ERROR("E10048", std::vector({"shape_range", "reason", "sample"}), + std::vector({shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3})); + GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid, reason: %s, correct sample is %s.", + shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); + return false; + } + return true; +} } // namespace Status CheckInputFormat(const string &input_format) { @@ -287,24 +351,6 @@ bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims return true; } -bool StringToLongNoThrow(const string &str, long &val) { - try { - val = std::stol(str); - return true; - } catch (const std::invalid_argument) { - ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, - {str, kShapeRangeValueConvertError, kInputShapeRangeSample3}); - GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s invalid, reason: %s, correct sample is %s.", - str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); - } catch (const std::out_of_range) { - ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, - {str, kShapeRangeValueConvertError, kInputShapeRangeSample3}); - GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s invalid, reason: %s, correct sample is %s.", - str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); - } - return false; -} - bool ParseSingleShapeRange(std::string &shape_range, vector> &shape_range_vec) { vector square_brackets; for (auto ch : shape_range) { @@ -331,41 +377,8 @@ bool ParseSingleShapeRange(std::string &shape_range, vector range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); pair range_pair; - if (range_pair_set.size() == 1) { - long range_value = 0; - if (!StringToLongNoThrow(range_pair_set.at(0), range_value)) { - return false; - } - if (range_value < 0) { - range_pair = std::make_pair(1, range_value); - } else { - range_pair = std::make_pair(range_value, range_value); - } - } else if (range_pair_set.size() == kRangePairSize) { - // unknown dim, should get range. - long range_left = 0; - if (!StringToLongNoThrow(range_pair_set.at(0), range_left)) { - return false; - } - long range_right = 0; - if (!StringToLongNoThrow(range_pair_set.at(1), range_right)) { - return false; - } - if (range_left < 0 || (range_right < 0)) { - ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, - {shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}); - GELOGE(PARAM_INVALID, - "[Parse][InputParameter] [--input_shape_range]'s shape range[%s] failed," - "reason: %s, correct sample is %s.", - shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); - return false; - } - range_pair = std::make_pair(range_left, range_right); - } else { - ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, - {shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}); - GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid, reason: %s, correct sample is %s.", - shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); + if (!ParseShapeRangePair(shape_range, range_pair_set, range_pair)) { + GELOGE(PARAM_INVALID, "[Parse][RangePair] parse range pair failed."); return false; } shape_range_vec.emplace_back(range_pair); @@ -373,8 +386,13 @@ bool ParseSingleShapeRange(std::string &shape_range, vector>> &shape_range_map) { +/** + * Parser shape_range from string to map + * shape_range from option normally is "input1:[1~20,3,3~6,-1];input2:[1~20,3,3~6,-1]" + * @param shape_range + */ +Status ParseInputShapeRange(const std::string &shape_range, + std::map>> &shape_range_map) { GELOGD("Input shape range %s", shape_range.c_str()); vector shape_range_vec = StringUtils::Split(shape_range, ';'); @@ -386,25 +404,82 @@ bool ParseInputShapeRange(const std::string &shape_range, {shape_range, kSplitError1, kInputShapeRangeSample1}); GELOGE(PARAM_INVALID, "[Parse][Parameter]--input shape_range:%s invalid, reason: %s, correct sample is %s.", shape_range.c_str(), kSplitError1, kInputShapeRangeSample1); - return false; + return PARAM_INVALID; } if (shape_range_pair_vec[1].empty()) { ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape", "reason", "sample"}, {shape_range, kEmptyError, kInputShapeRangeSample1}); GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid,reason: %s, correct sample is %s.", shape_range.c_str(), kEmptyError, kInputShapeRangeSample1); - return false; + return PARAM_INVALID; } string shape_range_str = shape_range_pair_vec[1]; vector> shape_range_val; if (!ParseSingleShapeRange(shape_range_str, shape_range_val)) { GELOGE(PARAM_INVALID, "[Parse][Parameter] shape_range_str: %s invalid.", shape_range_str.c_str()); - return false; + return PARAM_INVALID; } shape_range_map.emplace(make_pair(StringUtils::Trim(shape_range_pair_vec[0]), shape_range_val)); } - return true; + return SUCCESS; +} + +/** + * Parser shape_range from string to vector + * shape_range from option normally is "[1~20,3,3~6,-1],[1~20,3,3~6,-1]" + * @param shape_range + */ +Status ParseInputShapeRange(const std::string &shape_range, + std::vector>> &range) { + GELOGD("Input shape range %s", shape_range.c_str()); + + if (shape_range.size() < 2) { + REPORT_INPUT_ERROR("E10048", std::vector({"shape_range", "reason", "sample"}), + std::vector({shape_range, kInputShapeRangeSizeInvalid, kInputShapeRangeSample4})); + GELOGE(PARAM_INVALID, "[Parse][ShapeRange] str:%s invalid, reason: %s, correct sample is %s.", + shape_range.c_str(), kInputShapeRangeSizeInvalid, kInputShapeRangeSample4); + return PARAM_INVALID; + } + // different shape_range of single input are split by ']' + vector shape_range_set = ge::StringUtils::Split(shape_range, ']'); + if (shape_range_set.empty()) { + REPORT_INPUT_ERROR("E10048", std::vector({"shape_range", "reason", "sample"}), + std::vector({shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample4})); + GELOGE(PARAM_INVALID, "[Parse][ShapeRange] str:%s invalid, reason: %s, correct sample is %s.", + shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample4); + return PARAM_INVALID; + } + for (auto &shape_range_str : shape_range_set) { + if (shape_range_str.size() < 3) { + // shape_range_str should be "[2~3,1" + // or ",[2~3,1". because we should trim '[' or ',[' + // so shape_range_str.size() < 3 is invalid + continue; + } + // trim start bytes, after that, single input should be "1~20,3,3~6,-1" + if (ge::StringUtils::StartWith(shape_range_str, "[")) { + shape_range_str = shape_range_str.substr(1, shape_range_str.size()); + } + if (ge::StringUtils::StartWith(shape_range_str, ",")) { + shape_range_str = shape_range_str.substr(2, shape_range_str.size()); + } + + // parse shape_range of single input. eg. "1~20,3,3~6,-1" + std::vector> range_of_single_input; + vector dim_range_set = ge::StringUtils::Split(shape_range_str, ','); + for (const auto &range_pair_str : dim_range_set) { + vector range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); + pair range_pair; + if (!ParseShapeRangePair(shape_range_str, range_pair_set, range_pair)) { + GELOGE(PARAM_INVALID, "[Parse][RangePair] Parse range pair failed."); + return PARAM_INVALID; + } + range_of_single_input.emplace_back(range_pair); + } + range.emplace_back(range_of_single_input); + } + return SUCCESS; } Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims, @@ -420,11 +495,13 @@ Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_i } if (param_size == 0) { - if (!input_shape_range.empty()) { - std::map>> shape_range_map; - if (!ParseInputShapeRange(input_shape_range, shape_range_map)) { - GELOGE(ge::PARAM_INVALID, "[Parse][InputShapeRange] failed, range: %s", input_shape_range.c_str()); - return ge::PARAM_INVALID; + if (input_shape_range.find(":") != string::npos) { + if (!input_shape_range.empty()) { + std::map>> shape_range_map; + if (ParseInputShapeRange(input_shape_range, shape_range_map) != SUCCESS) { + GELOGE(ge::PARAM_INVALID, "[Parse][InputShapeRange] failed, range: %s", input_shape_range.c_str()); + return ge::PARAM_INVALID; + } } } return ge::SUCCESS; @@ -446,9 +523,11 @@ Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_i } if (!dynamic_batch_size.empty()) { - if (!CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size)) { - GELOGE(ge::PARAM_INVALID, "[Check][DynamicBatchSizeInputShape] input_shape: %s invalid.", input_shape.c_str()); - return ge::PARAM_INVALID; + if (input_shape_range.find(":") != string::npos) { + if (!CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size)) { + GELOGE(ge::PARAM_INVALID, "[Check][DynamicBatchSizeInputShape] input_shape: %s invalid.", input_shape.c_str()); + return ge::PARAM_INVALID; + } } } @@ -744,10 +823,10 @@ Status UpdateDataOpShape(const OpDescPtr &op, map> &shap } Status UpdateDataOpShapeRange(const OpDescPtr &op, - map>> &shape_range_map) { + const map>> &name_shape_range_map) { GE_CHECK_NOTNULL(op); - if (shape_range_map.empty()) { - GELOGI("Shape range map of data op [%s] is empty.", op->GetName().c_str()); + if (name_shape_range_map.empty()) { + GELOGI("Shape range name map of data op [%s] is empty.", op->GetName().c_str()); return SUCCESS; } @@ -757,8 +836,8 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, GE_CHECK_NOTNULL(tensor_output); string data_op_name = op->GetName(); auto origin_shape = tensor_input->GetShape(); - auto iter = shape_range_map.find(data_op_name); - if (iter != shape_range_map.end()) { + auto iter = name_shape_range_map.find(data_op_name); + if (iter != name_shape_range_map.end()) { auto cur_shape_range = iter->second; if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) { GELOGE(PARAM_INVALID, "[Check][OpDescPtr] Check shape by shape range failed for op:%s.", data_op_name.c_str()); @@ -783,6 +862,56 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, return SUCCESS; } +Status UpdateDataOpShapeRange(const OpDescPtr &op, + const vector>> &index_shape_range_map) { + GE_CHECK_NOTNULL(op); + if (index_shape_range_map.empty()) { + GELOGI("Shape range index map of data op [%s] is empty.", op->GetName().c_str()); + return SUCCESS; + } + + GeAttrValue::INT index = 0; + if (!AttrUtils::GetInt(op, ATTR_NAME_INDEX, index)) { + GELOGW("[%s] Get index from data attr failed.", op->GetName().c_str()); + return SUCCESS; + } + + if ((index < 0) || (static_cast(index) >= index_shape_range_map.size())) { + std::string situation = "data op index[" + std::to_string(index) + "]"; + std::string reason = "it must less than user_input size[" + std::to_string(index_shape_range_map.size()) + "]"; + REPORT_INPUT_ERROR("E19025", std::vector({"situation", "reason"}), + std::vector({situation, reason})); + GELOGE(PARAM_INVALID, "user_input size = %zu, graph data op index = %ld.", index_shape_range_map.size(), index); + return FAILED; + } + + auto tensor_input = op->MutableInputDesc(0); + auto tensor_output = op->MutableOutputDesc(0); + GE_CHECK_NOTNULL(tensor_input); + GE_CHECK_NOTNULL(tensor_output); + string data_op_name = op->GetName(); + auto origin_shape = tensor_input->GetShape(); + auto cur_shape_range = index_shape_range_map[index]; + if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) { + GELOGE(PARAM_INVALID, "[Check][OpDescPtr] Check shape by shape range failed for op:%s.", data_op_name.c_str()); + return PARAM_INVALID; + } + for (size_t idx = 0; idx < cur_shape_range.size(); ++idx) { + auto left_range = cur_shape_range[idx].first; + auto right_range = cur_shape_range[idx].second; + if (left_range != right_range) { + origin_shape.SetDim(idx, UNKNOWN_DIM); + } + } + tensor_input->SetShape(origin_shape); + tensor_input->SetShapeRange(cur_shape_range); + tensor_output->SetShape(origin_shape); + tensor_output->SetShapeRange(cur_shape_range); + GELOGI("Update input [%s] shape range info success.", data_op_name.c_str()); + + return SUCCESS; +} + static Status CheckInputShapeRangeNode(const ComputeGraphPtr &compute_graph, const map>> &shape_range_map) { for (const auto &it : shape_range_map) { @@ -813,7 +942,7 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co GE_CHECK_NOTNULL(compute_graph); map>> shape_range_map; - if (!ParseInputShapeRange(input_shape_range, shape_range_map)) { + if (ParseInputShapeRange(input_shape_range, shape_range_map) != SUCCESS) { GELOGE(PARAM_INVALID, "[Parse][InputShapeRange] input_shape_range:%s invalid.", input_shape_range.c_str()); return PARAM_INVALID; } diff --git a/ge/ir_build/atc_ir_common.h b/ge/ir_build/option_utils.h similarity index 87% rename from ge/ir_build/atc_ir_common.h rename to ge/ir_build/option_utils.h index 6ff40547..44504e35 100644 --- a/ge/ir_build/atc_ir_common.h +++ b/ge/ir_build/option_utils.h @@ -64,8 +64,10 @@ Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string bool ParseInputShape(const std::string &input_shape, std::map> &shape_map, std::vector>> &user_shape_map, bool is_dynamic_input = false); -bool ParseInputShapeRange(const std::string &shape_range, - std::map>> &shape_range_map); +Status ParseInputShapeRange(const std::string &shape_range, + std::map>> &shape_range_map); +Status ParseInputShapeRange(const std::string &shape_range, + std::vector>> &range); Status CheckOutputTypeParamValid(const std::string output_type); Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); @@ -80,8 +82,10 @@ Status CheckKeepTypeParamValid(const std::string &keep_dtype); void PrintOptionMap(std::map &options, std::string tips); void EraseEndSemicolon(std::string ¶m); Status UpdateDataOpShape(const OpDescPtr &op, std::map> &shape_map); +Status UpdateDataOpShapeRange( + const OpDescPtr &op, const std::map>> &name_shape_range_map); Status UpdateDataOpShapeRange(const OpDescPtr &op, - std::map>> &shape_range_map); + const std::vector>> &index_shape_range_map); Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range); } #endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ diff --git a/ge/offline/CMakeLists.txt b/ge/offline/CMakeLists.txt index 87589859..2a0f0ff0 100644 --- a/ge/offline/CMakeLists.txt +++ b/ge/offline/CMakeLists.txt @@ -11,7 +11,7 @@ set(SRC_LIST "main.cc" "single_op_parser.cc" "../session/omg.cc" - "../ir_build/atc_ir_common.cc" + "../ir_build/option_utils.cc" ) ############ atc_atc.bin ############ diff --git a/ge/offline/main.cc b/ge/offline/main.cc index 54a1d8fb..9e14e498 100755 --- a/ge/offline/main.cc +++ b/ge/offline/main.cc @@ -36,7 +36,7 @@ #include "graph/utils/graph_utils.h" #include "graph/utils/type_utils.h" #include "init/gelib.h" -#include "ir_build/atc_ir_common.h" +#include "ir_build/option_utils.h" #include "omg/omg.h" #include "omg/parser/parser_factory.h" #include "omg/parser/parser_inner_ctx.h" diff --git a/ge/offline/module.mk b/ge/offline/module.mk index 5c7a919c..27c5863a 100755 --- a/ge/offline/module.mk +++ b/ge/offline/module.mk @@ -12,7 +12,7 @@ LOCAL_SRC_FILES := \ main.cc \ single_op_parser.cc \ ../session/omg.cc \ - ../ir_build/atc_ir_common.cc \ + ../ir_build/option_utils.cc \ LOCAL_C_INCLUDES := \ $(LOCAL_PATH)/../ ./ \ @@ -65,7 +65,7 @@ LOCAL_SRC_FILES := \ main.cc \ single_op_parser.cc \ ../session/omg.cc \ - ../ir_build/atc_ir_common.cc \ + ../ir_build/option_utils.cc \ LOCAL_C_INCLUDES := \ $(LOCAL_PATH)/../ ./ \ @@ -118,7 +118,7 @@ LOCAL_SRC_FILES := \ main.cc \ single_op_parser.cc \ ../session/omg.cc \ - ../ir_build/atc_ir_common.cc \ + ../ir_build/option_utils.cc \ LOCAL_C_INCLUDES := \ $(LOCAL_PATH)/../ ./ \ diff --git a/ge/session/omg.cc b/ge/session/omg.cc index 1aec2ed4..ca5043b1 100755 --- a/ge/session/omg.cc +++ b/ge/session/omg.cc @@ -38,7 +38,7 @@ #include "graph/debug/ge_attr_define.h" #include "graph/optimize/common/params.h" #include "graph/utils/type_utils.h" -#include "ir_build/atc_ir_common.h" +#include "ir_build/option_utils.h" #include "omg/omg_inner_types.h" #include "omg/parser/model_parser.h" #include "omg/parser/parser_factory.h" diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index f2f08106..65d7ad74 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -308,7 +308,7 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc" "${GE_CODE_DIR}/ge/graph/partition/dynamic_shape_partition.cc" "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" - "${GE_CODE_DIR}/ge/ir_build/atc_ir_common.cc" + "${GE_CODE_DIR}/ge/ir_build/option_utils.cc" "${GE_CODE_DIR}/ge/graph/preprocess/insert_op/ge_aipp_op.cc" "${GE_CODE_DIR}/ge/graph/preprocess/multi_batch_options.cc" "${GE_CODE_DIR}/ge/graph/build/model_builder.cc" diff --git a/tests/ut/ge/graph/manager/graph_manager_unittest.cc b/tests/ut/ge/graph/manager/graph_manager_unittest.cc index dad55f3d..14aece0c 100644 --- a/tests/ut/ge/graph/manager/graph_manager_unittest.cc +++ b/tests/ut/ge/graph/manager/graph_manager_unittest.cc @@ -109,7 +109,7 @@ #include "graph/build/label_allocator.h" #include "graph/utils/tensor_adapter.h" #include "inc/pass_manager.h" -#include "ir_build/atc_ir_common.h" +#include "ir_build/option_utils.h" #include "graph/common/local_context.h" #include "graph/common/omg_util.h" #include "common/formats/utils/formats_trans_utils.h" diff --git a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc index 0c7bf651..1ca375a8 100644 --- a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc +++ b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc @@ -15,8 +15,9 @@ */ #include -#include "ir_build/atc_ir_common.h" +#include "ir_build/option_utils.h" #include "graph/testcase/ge_graph/graph_builder_utils.h" +#include "graph/debug/ge_attr_define.h" #define protected public #define private public @@ -68,6 +69,20 @@ TEST(UtestIrCommon, update_data_op_shape) { EXPECT_EQ(ret, ge::SUCCESS); } +TEST(UtestIrCommon, update_data_op_shape_range) { + ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data"); + std::vector>> index_shape_range_map; + + std::pair range_pair(1, 2); + vector> range_pair_tmp = { range_pair }; + + index_shape_range_map.push_back(range_pair_tmp); + + AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, 0); + Status ret = UpdateDataOpShapeRange(op_desc, index_shape_range_map); + EXPECT_EQ(ret, ge::SUCCESS); +} + TEST(UtestIrCommon, update_dynamic_shape_range_success) { ComputeGraphPtr graph = BuildComputeGraph(); std::string input_shape_range = "input1:[1, 2~3, -1];input2:[3~5, 10]"; @@ -117,4 +132,84 @@ TEST(UtestIrCommon, check_dynamic_image_size_fail) { bool ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size); EXPECT_EQ(ret, false); -} \ No newline at end of file +} + +TEST(UtestIrCommon, check_input_format_failed) { + std::string format = "invalid"; + Status ret = CheckInputFormat(format); + EXPECT_EQ(ret, ge::PARAM_INVALID); +} + +TEST(UtestIrCommon, check_dynamic_batch_size_input_shape_succ) { + map> shape_map; + shape_map.insert(std::pair>("data", {-1, 2, 3})); + std::string dynamic_batch_size = "11"; + + bool ret = CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size); + EXPECT_EQ(ret, true); +} + +TEST(UtestIrCommon, check_dynamic_images_size_input_shape_succ) { + map> shape_map; + shape_map.insert(std::pair>("data", {4, -1, -1, 5})); + std::string input_format = "NCHW"; + std::string dynamic_image_size = "4,5"; + + Status ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size); + EXPECT_EQ(ret, ge::SUCCESS); +} + +TEST(UtestIrCommon, check_dynamic_input_param_succ) { + string dynamic_batch_size = "1"; + string dynamic_image_size; + string dynamic_dims; + string input_shape = "data:1,3,244,244"; + string input_shape_range; + string input_format = "NCHW"; + bool is_dynamic_input = false; + + Status ret = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims, + input_shape, input_shape_range, input_format,is_dynamic_input); + EXPECT_EQ(ret, ge::SUCCESS); +} + +TEST(UtestIrCommon, check_compress_weight) { + std::string enable_compress_weight = "true"; + std::string compress_weight_conf="./"; + Status ret = CheckCompressWeightParamValid(enable_compress_weight, compress_weight_conf); + EXPECT_EQ(ret, PARAM_INVALID); + + enable_compress_weight = "yes"; + compress_weight_conf = "./"; + ret = CheckCompressWeightParamValid(enable_compress_weight, compress_weight_conf); + EXPECT_EQ(ret, PARAM_INVALID); +} + +TEST(UtestIrCommon, check_param_failed) { + std::string param_invalid = "invalid"; + + Status ret = CheckOutputTypeParamValid(param_invalid); + EXPECT_EQ(ret, PARAM_INVALID); + + ret = CheckBufferOptimizeParamValid(param_invalid); + EXPECT_EQ(ret, PARAM_INVALID); + + ret = CheckKeepTypeParamValid(param_invalid); + EXPECT_EQ(ret, PARAM_INVALID); + + ret = CheckInsertOpConfParamValid(param_invalid); + EXPECT_EQ(ret, PARAM_INVALID); + + ret = CheckDisableReuseMemoryParamValid(param_invalid); + EXPECT_EQ(ret, PARAM_INVALID); + + ret = CheckEnableSingleStreamParamValid(param_invalid); + EXPECT_EQ(ret, PARAM_INVALID); + + std::string optypelist_for_implmode; + std::string op_select_implmode = "1"; + ret = CheckImplmodeParamValid(optypelist_for_implmode, op_select_implmode); + EXPECT_EQ(ret, PARAM_INVALID); + + ret = CheckLogParamValidAndSetLogLevel(param_invalid); +}