From: @zhengyuanhua Reviewed-by: Signed-off-by:tags/v1.3.0
| @@ -402,7 +402,7 @@ set(TRAIN_SRC_LIST | |||||
| "ir_build/attr_options/utils.cc" | "ir_build/attr_options/utils.cc" | ||||
| "ir_build/attr_options/keep_dtype_option.cc" | "ir_build/attr_options/keep_dtype_option.cc" | ||||
| "ir_build/attr_options/weight_compress_option.cc" | "ir_build/attr_options/weight_compress_option.cc" | ||||
| "ir_build/atc_ir_common.cc" | |||||
| "ir_build/option_utils.cc" | |||||
| "graph/build/memory/memory_assigner.cc" | "graph/build/memory/memory_assigner.cc" | ||||
| "graph/build/memory/graph_mem_assigner.cc" | "graph/build/memory/graph_mem_assigner.cc" | ||||
| "graph/build/memory/binary_block_mem_assigner.cc" | "graph/build/memory/binary_block_mem_assigner.cc" | ||||
| @@ -662,7 +662,7 @@ set(INFER_SRC_LIST | |||||
| "ir_build/attr_options/utils.cc" | "ir_build/attr_options/utils.cc" | ||||
| "ir_build/attr_options/keep_dtype_option.cc" | "ir_build/attr_options/keep_dtype_option.cc" | ||||
| "ir_build/attr_options/weight_compress_option.cc" | "ir_build/attr_options/weight_compress_option.cc" | ||||
| "ir_build/atc_ir_common.cc" | |||||
| "ir_build/option_utils.cc" | |||||
| "graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
| "graph/preprocess/insert_op/util_insert_aipp_op.cc" | "graph/preprocess/insert_op/util_insert_aipp_op.cc" | ||||
| "hybrid/node_executor/aicpu/aicpu_ext_info.cc" | "hybrid/node_executor/aicpu/aicpu_ext_info.cc" | ||||
| @@ -72,7 +72,7 @@ BUILER_SRC_FILES := \ | |||||
| ir_build/attr_options/utils.cc \ | ir_build/attr_options/utils.cc \ | ||||
| ir_build/attr_options/keep_dtype_option.cc \ | ir_build/attr_options/keep_dtype_option.cc \ | ||||
| ir_build/attr_options/weight_compress_option.cc \ | ir_build/attr_options/weight_compress_option.cc \ | ||||
| ir_build/atc_ir_common.cc \ | |||||
| ir_build/option_utils.cc \ | |||||
| ANALYZER_SRC_FILES:= \ | ANALYZER_SRC_FILES:= \ | ||||
| analyzer/analyzer.cc \ | analyzer/analyzer.cc \ | ||||
| @@ -315,7 +315,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| ir_build/attr_options/utils.cc \ | ir_build/attr_options/utils.cc \ | ||||
| ir_build/attr_options/keep_dtype_option.cc \ | ir_build/attr_options/keep_dtype_option.cc \ | ||||
| ir_build/attr_options/weight_compress_option.cc \ | ir_build/attr_options/weight_compress_option.cc \ | ||||
| ir_build/atc_ir_common.cc \ | |||||
| ir_build/option_utils.cc \ | |||||
| LIBCLIENT_LOCAL_SRC_FILES := \ | LIBCLIENT_LOCAL_SRC_FILES := \ | ||||
| proto/ge_api.proto \ | proto/ge_api.proto \ | ||||
| @@ -101,7 +101,7 @@ | |||||
| #include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
| #include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
| #include "init/gelib.h" | #include "init/gelib.h" | ||||
| #include "ir_build/atc_ir_common.h" | |||||
| #include "ir_build/option_utils.h" | |||||
| #include "graph/common/local_context.h" | #include "graph/common/local_context.h" | ||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "common/helper/model_helper.h" | #include "common/helper/model_helper.h" | ||||
| #include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
| #include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
| #include "ir_build/option_utils.h" | |||||
| #include "graph/common/ge_call_wrapper.h" | #include "graph/common/ge_call_wrapper.h" | ||||
| #include "graph/common/local_context.h" | #include "graph/common/local_context.h" | ||||
| #include "graph/common/transop_util.h" | #include "graph/common/transop_util.h" | ||||
| @@ -991,101 +992,6 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node) { | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| long StringToLongNoThrow(const string &str) { | |||||
| try { | |||||
| return std::stol(str); | |||||
| } catch (const std::invalid_argument) { | |||||
| GELOGE(PARAM_INVALID, | |||||
| "Parse shape range of input failed when transfer from string to int64. Given %s, while correct example:" | |||||
| "\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", | |||||
| str.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } catch (const std::out_of_range) { | |||||
| GELOGE(PARAM_INVALID, | |||||
| "Parse shape range of input failed when transfer from string to int64. Given %s, while correct example:" | |||||
| "\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", | |||||
| str.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| /** | |||||
| * Parser shape_range from string to vector | |||||
| * shape_range from option normally is "[1~20,3,3~6,-1],[1~20,3,3~6,-1]" | |||||
| * @param shape_range | |||||
| */ | |||||
| Status ParseDynamicInputShapeRange(const std::string &shape_range, | |||||
| std::vector<std::vector<std::pair<int64_t, int64_t>>> &range) { | |||||
| if (shape_range.size() < 2) { | |||||
| REPORT_INNER_ERROR("E19999", "shape_range.size:%zu < 2, check invalid", shape_range.size()); | |||||
| GELOGE(PARAM_INVALID, "Shape range %s is invalid.", shape_range.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| // different shape_range of single input are split by ']' | |||||
| vector<string> shape_range_set = ge::StringUtils::Split(shape_range, ']'); | |||||
| if (shape_range_set.empty()) { | |||||
| REPORT_INNER_ERROR("E19999", "Shape range %s is not valid. Correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", | |||||
| shape_range.c_str()); | |||||
| GELOGE(PARAM_INVALID, "Shape range %s is not valid. Correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", | |||||
| shape_range.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| for (auto &shape_range_str : shape_range_set) { | |||||
| if (shape_range_str.size() < 3) { | |||||
| // shape_range_str should be "[2~3,1" | |||||
| // or ",[2~3,1". because we should trim '[' or ',[' | |||||
| // so shape_range_str.size() < 3 is invalid | |||||
| continue; | |||||
| } | |||||
| // trim start bytes, after that, single input should be "1~20,3,3~6,-1" | |||||
| if (ge::StringUtils::StartWith(shape_range_str, "[")) { | |||||
| shape_range_str = shape_range_str.substr(1, shape_range_str.size()); | |||||
| } | |||||
| if (ge::StringUtils::StartWith(shape_range_str, ",")) { | |||||
| shape_range_str = shape_range_str.substr(2, shape_range_str.size()); | |||||
| } | |||||
| // parse shape_range of single input. eg. "1~20,3,3~6,-1" | |||||
| std::vector<std::pair<int64_t, int64_t>> range_of_single_input; | |||||
| vector<string> dim_range_set = ge::StringUtils::Split(shape_range_str, ','); | |||||
| for (const auto &range_pair_str : dim_range_set) { | |||||
| vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); | |||||
| pair<int64_t, int64_t> range_pair; | |||||
| if (range_pair_set.size() == 1) { | |||||
| // fix dim | |||||
| auto range_value = StringToLongNoThrow(range_pair_set.at(0).c_str()); | |||||
| if (range_value < 0) { | |||||
| range_pair = std::make_pair(1, range_value); | |||||
| } else { | |||||
| range_pair = std::make_pair(range_value, range_value); | |||||
| } | |||||
| } else if (range_pair_set.size() == 2) { | |||||
| // unknown dim, should get range. | |||||
| auto range_left = StringToLongNoThrow(range_pair_set.at(0).c_str()); | |||||
| auto range_right = StringToLongNoThrow(range_pair_set.at(1).c_str()); | |||||
| if (range_left < 0 || range_right < 0) { | |||||
| REPORT_INNER_ERROR("E19999", "Shape range of input is invalid. Given range pair [%ld,%ld], " | |||||
| "while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", range_left, range_right); | |||||
| GELOGE(PARAM_INVALID, | |||||
| "Shape range of input is invalid. Given range pair [%ld,%ld], while correct example: " | |||||
| "\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", | |||||
| range_left, range_right); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| range_pair = std::make_pair(range_left, range_right); | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Shape range of input is invalid. Given %s, " | |||||
| "while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", shape_range.c_str()); | |||||
| GELOGE(PARAM_INVALID, | |||||
| "Shape range of input is invalid. Given %s, while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", | |||||
| shape_range.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| range_of_single_input.emplace_back(range_pair); | |||||
| } | |||||
| range.emplace_back(range_of_single_input); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GetDynamicInputShapeRange(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option, | Status GetDynamicInputShapeRange(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option, | ||||
| vector<vector<std::pair<int64_t, int64_t>>> &range_vec) { | vector<vector<std::pair<int64_t, int64_t>>> &range_vec) { | ||||
| @@ -1114,9 +1020,10 @@ Status GetDynamicInputShapeRange(const std::vector<GeTensor> &user_input, const | |||||
| OPTION_EXEC_DYNAMIC_EXECUTE_MODE, OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE); | OPTION_EXEC_DYNAMIC_EXECUTE_MODE, OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| auto ret = ParseDynamicInputShapeRange(iter->second, range_vec); | |||||
| GE_CHK_STATUS_RET(ret, "Parse dynamic input shape range failed."); | |||||
| if (ParseInputShapeRange(iter->second, range_vec) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "[Parse][ShapeRange] Parse dynamic input shape range failed."); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| if (range_vec.size() != user_input.size()) { | if (range_vec.size() != user_input.size()) { | ||||
| GELOGE(PARAM_INVALID, "Dynamic input shape range size is %zu, inputs size is %zu. Not match.", range_vec.size(), | GELOGE(PARAM_INVALID, "Dynamic input shape range size is %zu, inputs size is %zu. Not match.", range_vec.size(), | ||||
| user_input.size()); | user_input.size()); | ||||
| @@ -112,7 +112,12 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| HYBRID_CHK_STATUS_RET(context_.DumpExceptionInfo(exception_infos), | HYBRID_CHK_STATUS_RET(context_.DumpExceptionInfo(exception_infos), | ||||
| "[Execute][GraphInternal] Dump exception info failed."); | "[Execute][GraphInternal] Dump exception info failed."); | ||||
| } | } | ||||
| GELOGE(ret, "[Execute][GraphInternal] Synchronize failed."); | |||||
| if (ret == ge::END_OF_SEQUENCE) { | |||||
| GELOGD("Got end of sequence"); | |||||
| } else { | |||||
| GELOGE(ret, "[Execute][GraphInternal] Synchronize failed."); | |||||
| } | |||||
| return ret; | |||||
| } | } | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[Synchronize] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[Synchronize] End"); | ||||
| } | } | ||||
| @@ -32,7 +32,7 @@ | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph/ge_global_options.h" | #include "graph/ge_global_options.h" | ||||
| #include "init/gelib.h" | #include "init/gelib.h" | ||||
| #include "ir_build/atc_ir_common.h" | |||||
| #include "ir_build/option_utils.h" | |||||
| #include "model/ge_model.h" | #include "model/ge_model.h" | ||||
| #include "graph/shape_refiner.h" | #include "graph/shape_refiner.h" | ||||
| #include "graph/opsproto_manager.h" | #include "graph/opsproto_manager.h" | ||||
| @@ -299,10 +299,19 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { | |||||
| GE_CHK_BOOL_EXEC(ParseInputShape(input_shape, shape_map, user_shape_map, true), | GE_CHK_BOOL_EXEC(ParseInputShape(input_shape, shape_map, user_shape_map, true), | ||||
| return GRAPH_PARAM_INVALID, "[Parse][InputShape] failed!"); | return GRAPH_PARAM_INVALID, "[Parse][InputShape] failed!"); | ||||
| } | } | ||||
| std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map; | |||||
| std::map<string, std::vector<std::pair<int64_t, int64_t>>> name_shape_range_map; | |||||
| std::vector<std::vector<std::pair<int64_t, int64_t>>> index_shape_range_map; | |||||
| if (!input_shape_range.empty()) { | if (!input_shape_range.empty()) { | ||||
| GE_CHK_BOOL_EXEC(ParseInputShapeRange(input_shape_range, shape_range_map), | |||||
| return GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] failed."); | |||||
| Status ret = GRAPH_PARAM_INVALID; | |||||
| if (input_shape_range.find(":") != string::npos) { | |||||
| ret = ParseInputShapeRange(input_shape_range, name_shape_range_map); | |||||
| } else { | |||||
| ret = ParseInputShapeRange(input_shape_range, index_shape_range_map); | |||||
| } | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] parse shape range[%s] failed.", input_shape_range.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| } | } | ||||
| auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | ||||
| GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
| @@ -315,10 +324,14 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { | |||||
| GELOGE(GRAPH_FAILED, "[Update][DataOpShape] fail for op:%s.", op->GetName().c_str()); | GELOGE(GRAPH_FAILED, "[Update][DataOpShape] fail for op:%s.", op->GetName().c_str()); | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) { | |||||
| if (UpdateDataOpShapeRange(op, name_shape_range_map) != SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str()); | GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str()); | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | |||||
| } | |||||
| if (UpdateDataOpShapeRange(op, index_shape_range_map) != SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -13,7 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "atc_ir_common.h" | |||||
| #include "option_utils.h" | |||||
| #include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
| #include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
| #include "framework/common/string_util.h" | #include "framework/common/string_util.h" | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| using std::pair; | using std::pair; | ||||
| using std::string; | using std::string; | ||||
| @@ -58,10 +59,12 @@ const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \ | |||||
| const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\""; | const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\""; | ||||
| const char *const kKeepDtypeError = "file not found"; | const char *const kKeepDtypeError = "file not found"; | ||||
| const char *const kInputShapeRangeInvalid = "format of shape range is invalid"; | const char *const kInputShapeRangeInvalid = "format of shape range is invalid"; | ||||
| const char *const kInputShapeRangeSizeInvalid = " shape range size less than 2 is invalid"; | |||||
| const char *const kShapeRangeValueConvertError = "transfer from string to int64 error"; | const char *const kShapeRangeValueConvertError = "transfer from string to int64 error"; | ||||
| const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\""; | const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\""; | ||||
| const char *const kInputShapeRangeSample2 = "\"[1~20]\""; | const char *const kInputShapeRangeSample2 = "\"[1~20]\""; | ||||
| const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\""; | const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\""; | ||||
| const char *const kInputShapeRangeSample4 = "\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\""; | |||||
| vector<string> SplitInputShape(const std::string &input_shape) { | vector<string> SplitInputShape(const std::string &input_shape) { | ||||
| vector<string> shape_pair_vec; | vector<string> shape_pair_vec; | ||||
| @@ -72,6 +75,67 @@ vector<string> SplitInputShape(const std::string &input_shape) { | |||||
| } | } | ||||
| return shape_pair_vec; | return shape_pair_vec; | ||||
| } | } | ||||
| static bool StringToLongNoThrow(const string &str, long &val) { | |||||
| try { | |||||
| val = std::stol(str); | |||||
| return true; | |||||
| } catch (const std::invalid_argument) { | |||||
| REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | |||||
| std::vector<string>({str, kShapeRangeValueConvertError, kInputShapeRangeSample3})); | |||||
| GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s to long failed, reason: %s, correct sample is %s.", | |||||
| str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); | |||||
| } catch (const std::out_of_range) { | |||||
| REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | |||||
| std::vector<string>({str, kShapeRangeValueConvertError, kInputShapeRangeSample3})); | |||||
| GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s to long failed, reason: %s, correct sample is %s.", | |||||
| str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| static bool ParseShapeRangePair(const string &shape_range, | |||||
| const vector<string> &range_pair_set, | |||||
| std::pair<int64_t, int64_t> &range_pair) { | |||||
| if (range_pair_set.size() == 1) { | |||||
| long range_value = 0; | |||||
| if (!StringToLongNoThrow(range_pair_set.at(0), range_value)) { | |||||
| return false; | |||||
| } | |||||
| if (range_value < 0) { | |||||
| range_pair = std::make_pair(1, range_value); | |||||
| } else { | |||||
| range_pair = std::make_pair(range_value, range_value); | |||||
| } | |||||
| } else if (range_pair_set.size() == kRangePairSize) { | |||||
| // unknown dim, should get range. | |||||
| long range_left = 0; | |||||
| if (!StringToLongNoThrow(range_pair_set.at(0), range_left)) { | |||||
| return false; | |||||
| } | |||||
| long range_right = 0; | |||||
| if (!StringToLongNoThrow(range_pair_set.at(1), range_right)) { | |||||
| return false; | |||||
| } | |||||
| if ((range_left < 0) || (range_right < 0)) { | |||||
| REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | |||||
| std::vector<string>({shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3})); | |||||
| GELOGE(PARAM_INVALID, | |||||
| "[Parse][InputParameter] [--input_shape_range]'s shape range[%s] failed," | |||||
| "reason: %s, correct sample is %s.", | |||||
| shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); | |||||
| return false; | |||||
| } | |||||
| range_pair = std::make_pair(range_left, range_right); | |||||
| } else { | |||||
| REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | |||||
| std::vector<string>({shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3})); | |||||
| GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid, reason: %s, correct sample is %s.", | |||||
| shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| Status CheckInputFormat(const string &input_format) { | Status CheckInputFormat(const string &input_format) { | ||||
| @@ -287,24 +351,6 @@ bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool StringToLongNoThrow(const string &str, long &val) { | |||||
| try { | |||||
| val = std::stol(str); | |||||
| return true; | |||||
| } catch (const std::invalid_argument) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
| {str, kShapeRangeValueConvertError, kInputShapeRangeSample3}); | |||||
| GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s invalid, reason: %s, correct sample is %s.", | |||||
| str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); | |||||
| } catch (const std::out_of_range) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
| {str, kShapeRangeValueConvertError, kInputShapeRangeSample3}); | |||||
| GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s invalid, reason: %s, correct sample is %s.", | |||||
| str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_t>> &shape_range_vec) { | bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_t>> &shape_range_vec) { | ||||
| vector<char> square_brackets; | vector<char> square_brackets; | ||||
| for (auto ch : shape_range) { | for (auto ch : shape_range) { | ||||
| @@ -331,41 +377,8 @@ bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_ | |||||
| for (const auto &range_pair_str : dim_range_set) { | for (const auto &range_pair_str : dim_range_set) { | ||||
| vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); | vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); | ||||
| pair<int64_t, int64_t> range_pair; | pair<int64_t, int64_t> range_pair; | ||||
| if (range_pair_set.size() == 1) { | |||||
| long range_value = 0; | |||||
| if (!StringToLongNoThrow(range_pair_set.at(0), range_value)) { | |||||
| return false; | |||||
| } | |||||
| if (range_value < 0) { | |||||
| range_pair = std::make_pair(1, range_value); | |||||
| } else { | |||||
| range_pair = std::make_pair(range_value, range_value); | |||||
| } | |||||
| } else if (range_pair_set.size() == kRangePairSize) { | |||||
| // unknown dim, should get range. | |||||
| long range_left = 0; | |||||
| if (!StringToLongNoThrow(range_pair_set.at(0), range_left)) { | |||||
| return false; | |||||
| } | |||||
| long range_right = 0; | |||||
| if (!StringToLongNoThrow(range_pair_set.at(1), range_right)) { | |||||
| return false; | |||||
| } | |||||
| if (range_left < 0 || (range_right < 0)) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
| {shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}); | |||||
| GELOGE(PARAM_INVALID, | |||||
| "[Parse][InputParameter] [--input_shape_range]'s shape range[%s] failed," | |||||
| "reason: %s, correct sample is %s.", | |||||
| shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); | |||||
| return false; | |||||
| } | |||||
| range_pair = std::make_pair(range_left, range_right); | |||||
| } else { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
| {shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}); | |||||
| GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid, reason: %s, correct sample is %s.", | |||||
| shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); | |||||
| if (!ParseShapeRangePair(shape_range, range_pair_set, range_pair)) { | |||||
| GELOGE(PARAM_INVALID, "[Parse][RangePair] parse range pair failed."); | |||||
| return false; | return false; | ||||
| } | } | ||||
| shape_range_vec.emplace_back(range_pair); | shape_range_vec.emplace_back(range_pair); | ||||
| @@ -373,8 +386,13 @@ bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_ | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool ParseInputShapeRange(const std::string &shape_range, | |||||
| std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map) { | |||||
| /** | |||||
| * Parser shape_range from string to map | |||||
| * shape_range from option normally is "input1:[1~20,3,3~6,-1];input2:[1~20,3,3~6,-1]" | |||||
| * @param shape_range | |||||
| */ | |||||
| Status ParseInputShapeRange(const std::string &shape_range, | |||||
| std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map) { | |||||
| GELOGD("Input shape range %s", shape_range.c_str()); | GELOGD("Input shape range %s", shape_range.c_str()); | ||||
| vector<string> shape_range_vec = StringUtils::Split(shape_range, ';'); | vector<string> shape_range_vec = StringUtils::Split(shape_range, ';'); | ||||
| @@ -386,25 +404,82 @@ bool ParseInputShapeRange(const std::string &shape_range, | |||||
| {shape_range, kSplitError1, kInputShapeRangeSample1}); | {shape_range, kSplitError1, kInputShapeRangeSample1}); | ||||
| GELOGE(PARAM_INVALID, "[Parse][Parameter]--input shape_range:%s invalid, reason: %s, correct sample is %s.", | GELOGE(PARAM_INVALID, "[Parse][Parameter]--input shape_range:%s invalid, reason: %s, correct sample is %s.", | ||||
| shape_range.c_str(), kSplitError1, kInputShapeRangeSample1); | shape_range.c_str(), kSplitError1, kInputShapeRangeSample1); | ||||
| return false; | |||||
| return PARAM_INVALID; | |||||
| } | } | ||||
| if (shape_range_pair_vec[1].empty()) { | if (shape_range_pair_vec[1].empty()) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape", "reason", "sample"}, | ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape", "reason", "sample"}, | ||||
| {shape_range, kEmptyError, kInputShapeRangeSample1}); | {shape_range, kEmptyError, kInputShapeRangeSample1}); | ||||
| GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid,reason: %s, correct sample is %s.", | GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid,reason: %s, correct sample is %s.", | ||||
| shape_range.c_str(), kEmptyError, kInputShapeRangeSample1); | shape_range.c_str(), kEmptyError, kInputShapeRangeSample1); | ||||
| return false; | |||||
| return PARAM_INVALID; | |||||
| } | } | ||||
| string shape_range_str = shape_range_pair_vec[1]; | string shape_range_str = shape_range_pair_vec[1]; | ||||
| vector<pair<int64_t, int64_t>> shape_range_val; | vector<pair<int64_t, int64_t>> shape_range_val; | ||||
| if (!ParseSingleShapeRange(shape_range_str, shape_range_val)) { | if (!ParseSingleShapeRange(shape_range_str, shape_range_val)) { | ||||
| GELOGE(PARAM_INVALID, "[Parse][Parameter] shape_range_str: %s invalid.", shape_range_str.c_str()); | GELOGE(PARAM_INVALID, "[Parse][Parameter] shape_range_str: %s invalid.", shape_range_str.c_str()); | ||||
| return false; | |||||
| return PARAM_INVALID; | |||||
| } | } | ||||
| shape_range_map.emplace(make_pair(StringUtils::Trim(shape_range_pair_vec[0]), shape_range_val)); | shape_range_map.emplace(make_pair(StringUtils::Trim(shape_range_pair_vec[0]), shape_range_val)); | ||||
| } | } | ||||
| return true; | |||||
| return SUCCESS; | |||||
| } | |||||
| /** | |||||
| * Parser shape_range from string to vector | |||||
| * shape_range from option normally is "[1~20,3,3~6,-1],[1~20,3,3~6,-1]" | |||||
| * @param shape_range | |||||
| */ | |||||
| Status ParseInputShapeRange(const std::string &shape_range, | |||||
| std::vector<std::vector<std::pair<int64_t, int64_t>>> &range) { | |||||
| GELOGD("Input shape range %s", shape_range.c_str()); | |||||
| if (shape_range.size() < 2) { | |||||
| REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | |||||
| std::vector<std::string>({shape_range, kInputShapeRangeSizeInvalid, kInputShapeRangeSample4})); | |||||
| GELOGE(PARAM_INVALID, "[Parse][ShapeRange] str:%s invalid, reason: %s, correct sample is %s.", | |||||
| shape_range.c_str(), kInputShapeRangeSizeInvalid, kInputShapeRangeSample4); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| // different shape_range of single input are split by ']' | |||||
| vector<string> shape_range_set = ge::StringUtils::Split(shape_range, ']'); | |||||
| if (shape_range_set.empty()) { | |||||
| REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | |||||
| std::vector<string>({shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample4})); | |||||
| GELOGE(PARAM_INVALID, "[Parse][ShapeRange] str:%s invalid, reason: %s, correct sample is %s.", | |||||
| shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample4); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| for (auto &shape_range_str : shape_range_set) { | |||||
| if (shape_range_str.size() < 3) { | |||||
| // shape_range_str should be "[2~3,1" | |||||
| // or ",[2~3,1". because we should trim '[' or ',[' | |||||
| // so shape_range_str.size() < 3 is invalid | |||||
| continue; | |||||
| } | |||||
| // trim start bytes, after that, single input should be "1~20,3,3~6,-1" | |||||
| if (ge::StringUtils::StartWith(shape_range_str, "[")) { | |||||
| shape_range_str = shape_range_str.substr(1, shape_range_str.size()); | |||||
| } | |||||
| if (ge::StringUtils::StartWith(shape_range_str, ",")) { | |||||
| shape_range_str = shape_range_str.substr(2, shape_range_str.size()); | |||||
| } | |||||
| // parse shape_range of single input. eg. "1~20,3,3~6,-1" | |||||
| std::vector<std::pair<int64_t, int64_t>> range_of_single_input; | |||||
| vector<string> dim_range_set = ge::StringUtils::Split(shape_range_str, ','); | |||||
| for (const auto &range_pair_str : dim_range_set) { | |||||
| vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); | |||||
| pair<int64_t, int64_t> range_pair; | |||||
| if (!ParseShapeRangePair(shape_range_str, range_pair_set, range_pair)) { | |||||
| GELOGE(PARAM_INVALID, "[Parse][RangePair] Parse range pair failed."); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| range_of_single_input.emplace_back(range_pair); | |||||
| } | |||||
| range.emplace_back(range_of_single_input); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | } | ||||
| Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims, | Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims, | ||||
| @@ -420,11 +495,13 @@ Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_i | |||||
| } | } | ||||
| if (param_size == 0) { | if (param_size == 0) { | ||||
| if (!input_shape_range.empty()) { | |||||
| std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map; | |||||
| if (!ParseInputShapeRange(input_shape_range, shape_range_map)) { | |||||
| GELOGE(ge::PARAM_INVALID, "[Parse][InputShapeRange] failed, range: %s", input_shape_range.c_str()); | |||||
| return ge::PARAM_INVALID; | |||||
| if (input_shape_range.find(":") != string::npos) { | |||||
| if (!input_shape_range.empty()) { | |||||
| std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map; | |||||
| if (ParseInputShapeRange(input_shape_range, shape_range_map) != SUCCESS) { | |||||
| GELOGE(ge::PARAM_INVALID, "[Parse][InputShapeRange] failed, range: %s", input_shape_range.c_str()); | |||||
| return ge::PARAM_INVALID; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| @@ -446,9 +523,11 @@ Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_i | |||||
| } | } | ||||
| if (!dynamic_batch_size.empty()) { | if (!dynamic_batch_size.empty()) { | ||||
| if (!CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size)) { | |||||
| GELOGE(ge::PARAM_INVALID, "[Check][DynamicBatchSizeInputShape] input_shape: %s invalid.", input_shape.c_str()); | |||||
| return ge::PARAM_INVALID; | |||||
| if (input_shape_range.find(":") != string::npos) { | |||||
| if (!CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size)) { | |||||
| GELOGE(ge::PARAM_INVALID, "[Check][DynamicBatchSizeInputShape] input_shape: %s invalid.", input_shape.c_str()); | |||||
| return ge::PARAM_INVALID; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -744,10 +823,10 @@ Status UpdateDataOpShape(const OpDescPtr &op, map<string, vector<int64_t>> &shap | |||||
| } | } | ||||
| Status UpdateDataOpShapeRange(const OpDescPtr &op, | Status UpdateDataOpShapeRange(const OpDescPtr &op, | ||||
| map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) { | |||||
| const map<string, vector<pair<int64_t, int64_t>>> &name_shape_range_map) { | |||||
| GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
| if (shape_range_map.empty()) { | |||||
| GELOGI("Shape range map of data op [%s] is empty.", op->GetName().c_str()); | |||||
| if (name_shape_range_map.empty()) { | |||||
| GELOGI("Shape range name map of data op [%s] is empty.", op->GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -757,8 +836,8 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, | |||||
| GE_CHECK_NOTNULL(tensor_output); | GE_CHECK_NOTNULL(tensor_output); | ||||
| string data_op_name = op->GetName(); | string data_op_name = op->GetName(); | ||||
| auto origin_shape = tensor_input->GetShape(); | auto origin_shape = tensor_input->GetShape(); | ||||
| auto iter = shape_range_map.find(data_op_name); | |||||
| if (iter != shape_range_map.end()) { | |||||
| auto iter = name_shape_range_map.find(data_op_name); | |||||
| if (iter != name_shape_range_map.end()) { | |||||
| auto cur_shape_range = iter->second; | auto cur_shape_range = iter->second; | ||||
| if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) { | if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "[Check][OpDescPtr] Check shape by shape range failed for op:%s.", data_op_name.c_str()); | GELOGE(PARAM_INVALID, "[Check][OpDescPtr] Check shape by shape range failed for op:%s.", data_op_name.c_str()); | ||||
| @@ -783,6 +862,56 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status UpdateDataOpShapeRange(const OpDescPtr &op, | |||||
| const vector<vector<pair<int64_t, int64_t>>> &index_shape_range_map) { | |||||
| GE_CHECK_NOTNULL(op); | |||||
| if (index_shape_range_map.empty()) { | |||||
| GELOGI("Shape range index map of data op [%s] is empty.", op->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| GeAttrValue::INT index = 0; | |||||
| if (!AttrUtils::GetInt(op, ATTR_NAME_INDEX, index)) { | |||||
| GELOGW("[%s] Get index from data attr failed.", op->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| if ((index < 0) || (static_cast<size_t>(index) >= index_shape_range_map.size())) { | |||||
| std::string situation = "data op index[" + std::to_string(index) + "]"; | |||||
| std::string reason = "it must less than user_input size[" + std::to_string(index_shape_range_map.size()) + "]"; | |||||
| REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"situation", "reason"}), | |||||
| std::vector<std::string>({situation, reason})); | |||||
| GELOGE(PARAM_INVALID, "user_input size = %zu, graph data op index = %ld.", index_shape_range_map.size(), index); | |||||
| return FAILED; | |||||
| } | |||||
| auto tensor_input = op->MutableInputDesc(0); | |||||
| auto tensor_output = op->MutableOutputDesc(0); | |||||
| GE_CHECK_NOTNULL(tensor_input); | |||||
| GE_CHECK_NOTNULL(tensor_output); | |||||
| string data_op_name = op->GetName(); | |||||
| auto origin_shape = tensor_input->GetShape(); | |||||
| auto cur_shape_range = index_shape_range_map[index]; | |||||
| if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "[Check][OpDescPtr] Check shape by shape range failed for op:%s.", data_op_name.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| for (size_t idx = 0; idx < cur_shape_range.size(); ++idx) { | |||||
| auto left_range = cur_shape_range[idx].first; | |||||
| auto right_range = cur_shape_range[idx].second; | |||||
| if (left_range != right_range) { | |||||
| origin_shape.SetDim(idx, UNKNOWN_DIM); | |||||
| } | |||||
| } | |||||
| tensor_input->SetShape(origin_shape); | |||||
| tensor_input->SetShapeRange(cur_shape_range); | |||||
| tensor_output->SetShape(origin_shape); | |||||
| tensor_output->SetShapeRange(cur_shape_range); | |||||
| GELOGI("Update input [%s] shape range info success.", data_op_name.c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| static Status CheckInputShapeRangeNode(const ComputeGraphPtr &compute_graph, | static Status CheckInputShapeRangeNode(const ComputeGraphPtr &compute_graph, | ||||
| const map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) { | const map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) { | ||||
| for (const auto &it : shape_range_map) { | for (const auto &it : shape_range_map) { | ||||
| @@ -813,7 +942,7 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co | |||||
| GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
| map<string, vector<pair<int64_t, int64_t>>> shape_range_map; | map<string, vector<pair<int64_t, int64_t>>> shape_range_map; | ||||
| if (!ParseInputShapeRange(input_shape_range, shape_range_map)) { | |||||
| if (ParseInputShapeRange(input_shape_range, shape_range_map) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "[Parse][InputShapeRange] input_shape_range:%s invalid.", input_shape_range.c_str()); | GELOGE(PARAM_INVALID, "[Parse][InputShapeRange] input_shape_range:%s invalid.", input_shape_range.c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -64,8 +64,10 @@ Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string | |||||
| bool ParseInputShape(const std::string &input_shape, std::map<string, std::vector<int64_t>> &shape_map, | bool ParseInputShape(const std::string &input_shape, std::map<string, std::vector<int64_t>> &shape_map, | ||||
| std::vector<std::pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input = false); | std::vector<std::pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input = false); | ||||
| bool ParseInputShapeRange(const std::string &shape_range, | |||||
| std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map); | |||||
| Status ParseInputShapeRange(const std::string &shape_range, | |||||
| std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map); | |||||
| Status ParseInputShapeRange(const std::string &shape_range, | |||||
| std::vector<std::vector<std::pair<int64_t, int64_t>>> &range); | |||||
| Status CheckOutputTypeParamValid(const std::string output_type); | Status CheckOutputTypeParamValid(const std::string output_type); | ||||
| Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); | Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); | ||||
| @@ -80,8 +82,10 @@ Status CheckKeepTypeParamValid(const std::string &keep_dtype); | |||||
| void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); | void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); | ||||
| void EraseEndSemicolon(std::string ¶m); | void EraseEndSemicolon(std::string ¶m); | ||||
| Status UpdateDataOpShape(const OpDescPtr &op, std::map<std::string, std::vector<int64_t>> &shape_map); | Status UpdateDataOpShape(const OpDescPtr &op, std::map<std::string, std::vector<int64_t>> &shape_map); | ||||
| Status UpdateDataOpShapeRange( | |||||
| const OpDescPtr &op, const std::map<std::string, std::vector<std::pair<int64_t, int64_t>>> &name_shape_range_map); | |||||
| Status UpdateDataOpShapeRange(const OpDescPtr &op, | Status UpdateDataOpShapeRange(const OpDescPtr &op, | ||||
| std::map<std::string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map); | |||||
| const std::vector<std::vector<std::pair<int64_t, int64_t>>> &index_shape_range_map); | |||||
| Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range); | Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range); | ||||
| } | } | ||||
| #endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ | #endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ | ||||
| @@ -11,7 +11,7 @@ set(SRC_LIST | |||||
| "main.cc" | "main.cc" | ||||
| "single_op_parser.cc" | "single_op_parser.cc" | ||||
| "../session/omg.cc" | "../session/omg.cc" | ||||
| "../ir_build/atc_ir_common.cc" | |||||
| "../ir_build/option_utils.cc" | |||||
| ) | ) | ||||
| ############ atc_atc.bin ############ | ############ atc_atc.bin ############ | ||||
| @@ -36,7 +36,7 @@ | |||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "init/gelib.h" | #include "init/gelib.h" | ||||
| #include "ir_build/atc_ir_common.h" | |||||
| #include "ir_build/option_utils.h" | |||||
| #include "omg/omg.h" | #include "omg/omg.h" | ||||
| #include "omg/parser/parser_factory.h" | #include "omg/parser/parser_factory.h" | ||||
| #include "omg/parser/parser_inner_ctx.h" | #include "omg/parser/parser_inner_ctx.h" | ||||
| @@ -12,7 +12,7 @@ LOCAL_SRC_FILES := \ | |||||
| main.cc \ | main.cc \ | ||||
| single_op_parser.cc \ | single_op_parser.cc \ | ||||
| ../session/omg.cc \ | ../session/omg.cc \ | ||||
| ../ir_build/atc_ir_common.cc \ | |||||
| ../ir_build/option_utils.cc \ | |||||
| LOCAL_C_INCLUDES := \ | LOCAL_C_INCLUDES := \ | ||||
| $(LOCAL_PATH)/../ ./ \ | $(LOCAL_PATH)/../ ./ \ | ||||
| @@ -65,7 +65,7 @@ LOCAL_SRC_FILES := \ | |||||
| main.cc \ | main.cc \ | ||||
| single_op_parser.cc \ | single_op_parser.cc \ | ||||
| ../session/omg.cc \ | ../session/omg.cc \ | ||||
| ../ir_build/atc_ir_common.cc \ | |||||
| ../ir_build/option_utils.cc \ | |||||
| LOCAL_C_INCLUDES := \ | LOCAL_C_INCLUDES := \ | ||||
| $(LOCAL_PATH)/../ ./ \ | $(LOCAL_PATH)/../ ./ \ | ||||
| @@ -118,7 +118,7 @@ LOCAL_SRC_FILES := \ | |||||
| main.cc \ | main.cc \ | ||||
| single_op_parser.cc \ | single_op_parser.cc \ | ||||
| ../session/omg.cc \ | ../session/omg.cc \ | ||||
| ../ir_build/atc_ir_common.cc \ | |||||
| ../ir_build/option_utils.cc \ | |||||
| LOCAL_C_INCLUDES := \ | LOCAL_C_INCLUDES := \ | ||||
| $(LOCAL_PATH)/../ ./ \ | $(LOCAL_PATH)/../ ./ \ | ||||
| @@ -38,7 +38,7 @@ | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/optimize/common/params.h" | #include "graph/optimize/common/params.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "ir_build/atc_ir_common.h" | |||||
| #include "ir_build/option_utils.h" | |||||
| #include "omg/omg_inner_types.h" | #include "omg/omg_inner_types.h" | ||||
| #include "omg/parser/model_parser.h" | #include "omg/parser/model_parser.h" | ||||
| #include "omg/parser/parser_factory.h" | #include "omg/parser/parser_factory.h" | ||||
| @@ -307,7 +307,7 @@ set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc" | "${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/partition/dynamic_shape_partition.cc" | "${GE_CODE_DIR}/ge/graph/partition/dynamic_shape_partition.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" | "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" | ||||
| "${GE_CODE_DIR}/ge/ir_build/atc_ir_common.cc" | |||||
| "${GE_CODE_DIR}/ge/ir_build/option_utils.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/preprocess/insert_op/ge_aipp_op.cc" | "${GE_CODE_DIR}/ge/graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/preprocess/multi_batch_options.cc" | "${GE_CODE_DIR}/ge/graph/preprocess/multi_batch_options.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/build/model_builder.cc" | "${GE_CODE_DIR}/ge/graph/build/model_builder.cc" | ||||
| @@ -109,7 +109,7 @@ | |||||
| #include "graph/build/label_allocator.h" | #include "graph/build/label_allocator.h" | ||||
| #include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
| #include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
| #include "ir_build/atc_ir_common.h" | |||||
| #include "ir_build/option_utils.h" | |||||
| #include "graph/common/local_context.h" | #include "graph/common/local_context.h" | ||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| @@ -15,8 +15,9 @@ | |||||
| */ | */ | ||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include "ir_build/atc_ir_common.h" | |||||
| #include "ir_build/option_utils.h" | |||||
| #include "graph/testcase/ge_graph/graph_builder_utils.h" | #include "graph/testcase/ge_graph/graph_builder_utils.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #define protected public | #define protected public | ||||
| #define private public | #define private public | ||||
| @@ -68,6 +69,20 @@ TEST(UtestIrCommon, update_data_op_shape) { | |||||
| EXPECT_EQ(ret, ge::SUCCESS); | EXPECT_EQ(ret, ge::SUCCESS); | ||||
| } | } | ||||
| TEST(UtestIrCommon, update_data_op_shape_range) { | |||||
| ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data"); | |||||
| std::vector<std::vector<std::pair<int64_t, int64_t>>> index_shape_range_map; | |||||
| std::pair<int64_t, int64_t> range_pair(1, 2); | |||||
| vector<pair<int64_t, int64_t>> range_pair_tmp = { range_pair }; | |||||
| index_shape_range_map.push_back(range_pair_tmp); | |||||
| AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, 0); | |||||
| Status ret = UpdateDataOpShapeRange(op_desc, index_shape_range_map); | |||||
| EXPECT_EQ(ret, ge::SUCCESS); | |||||
| } | |||||
| TEST(UtestIrCommon, update_dynamic_shape_range_success) { | TEST(UtestIrCommon, update_dynamic_shape_range_success) { | ||||
| ComputeGraphPtr graph = BuildComputeGraph(); | ComputeGraphPtr graph = BuildComputeGraph(); | ||||
| std::string input_shape_range = "input1:[1, 2~3, -1];input2:[3~5, 10]"; | std::string input_shape_range = "input1:[1, 2~3, -1];input2:[3~5, 10]"; | ||||
| @@ -117,4 +132,84 @@ TEST(UtestIrCommon, check_dynamic_image_size_fail) { | |||||
| bool ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size); | bool ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size); | ||||
| EXPECT_EQ(ret, false); | EXPECT_EQ(ret, false); | ||||
| } | |||||
| } | |||||
| TEST(UtestIrCommon, check_input_format_failed) { | |||||
| std::string format = "invalid"; | |||||
| Status ret = CheckInputFormat(format); | |||||
| EXPECT_EQ(ret, ge::PARAM_INVALID); | |||||
| } | |||||
| TEST(UtestIrCommon, check_dynamic_batch_size_input_shape_succ) { | |||||
| map<string, vector<int64_t>> shape_map; | |||||
| shape_map.insert(std::pair<string, vector<int64_t>>("data", {-1, 2, 3})); | |||||
| std::string dynamic_batch_size = "11"; | |||||
| bool ret = CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size); | |||||
| EXPECT_EQ(ret, true); | |||||
| } | |||||
| TEST(UtestIrCommon, check_dynamic_images_size_input_shape_succ) { | |||||
| map<string, vector<int64_t>> shape_map; | |||||
| shape_map.insert(std::pair<string, vector<int64_t>>("data", {4, -1, -1, 5})); | |||||
| std::string input_format = "NCHW"; | |||||
| std::string dynamic_image_size = "4,5"; | |||||
| Status ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size); | |||||
| EXPECT_EQ(ret, ge::SUCCESS); | |||||
| } | |||||
| TEST(UtestIrCommon, check_dynamic_input_param_succ) { | |||||
| string dynamic_batch_size = "1"; | |||||
| string dynamic_image_size; | |||||
| string dynamic_dims; | |||||
| string input_shape = "data:1,3,244,244"; | |||||
| string input_shape_range; | |||||
| string input_format = "NCHW"; | |||||
| bool is_dynamic_input = false; | |||||
| Status ret = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims, | |||||
| input_shape, input_shape_range, input_format,is_dynamic_input); | |||||
| EXPECT_EQ(ret, ge::SUCCESS); | |||||
| } | |||||
| TEST(UtestIrCommon, check_compress_weight) { | |||||
| std::string enable_compress_weight = "true"; | |||||
| std::string compress_weight_conf="./"; | |||||
| Status ret = CheckCompressWeightParamValid(enable_compress_weight, compress_weight_conf); | |||||
| EXPECT_EQ(ret, PARAM_INVALID); | |||||
| enable_compress_weight = "yes"; | |||||
| compress_weight_conf = "./"; | |||||
| ret = CheckCompressWeightParamValid(enable_compress_weight, compress_weight_conf); | |||||
| EXPECT_EQ(ret, PARAM_INVALID); | |||||
| } | |||||
| TEST(UtestIrCommon, check_param_failed) { | |||||
| std::string param_invalid = "invalid"; | |||||
| Status ret = CheckOutputTypeParamValid(param_invalid); | |||||
| EXPECT_EQ(ret, PARAM_INVALID); | |||||
| ret = CheckBufferOptimizeParamValid(param_invalid); | |||||
| EXPECT_EQ(ret, PARAM_INVALID); | |||||
| ret = CheckKeepTypeParamValid(param_invalid); | |||||
| EXPECT_EQ(ret, PARAM_INVALID); | |||||
| ret = CheckInsertOpConfParamValid(param_invalid); | |||||
| EXPECT_EQ(ret, PARAM_INVALID); | |||||
| ret = CheckDisableReuseMemoryParamValid(param_invalid); | |||||
| EXPECT_EQ(ret, PARAM_INVALID); | |||||
| ret = CheckEnableSingleStreamParamValid(param_invalid); | |||||
| EXPECT_EQ(ret, PARAM_INVALID); | |||||
| std::string optypelist_for_implmode; | |||||
| std::string op_select_implmode = "1"; | |||||
| ret = CheckImplmodeParamValid(optypelist_for_implmode, op_select_implmode); | |||||
| EXPECT_EQ(ret, PARAM_INVALID); | |||||
| ret = CheckLogParamValidAndSetLogLevel(param_invalid); | |||||
| } | |||||