@@ -403,7 +403,7 @@ set(TRAIN_SRC_LIST | |||||
"ir_build/attr_options/utils.cc" | "ir_build/attr_options/utils.cc" | ||||
"ir_build/attr_options/keep_dtype_option.cc" | "ir_build/attr_options/keep_dtype_option.cc" | ||||
"ir_build/attr_options/weight_compress_option.cc" | "ir_build/attr_options/weight_compress_option.cc" | ||||
"ir_build/atc_ir_common.cc" | |||||
"ir_build/option_utils.cc" | |||||
"graph/build/memory/memory_assigner.cc" | "graph/build/memory/memory_assigner.cc" | ||||
"graph/build/memory/graph_mem_assigner.cc" | "graph/build/memory/graph_mem_assigner.cc" | ||||
"graph/build/memory/binary_block_mem_assigner.cc" | "graph/build/memory/binary_block_mem_assigner.cc" | ||||
@@ -664,7 +664,7 @@ set(INFER_SRC_LIST | |||||
"ir_build/attr_options/utils.cc" | "ir_build/attr_options/utils.cc" | ||||
"ir_build/attr_options/keep_dtype_option.cc" | "ir_build/attr_options/keep_dtype_option.cc" | ||||
"ir_build/attr_options/weight_compress_option.cc" | "ir_build/attr_options/weight_compress_option.cc" | ||||
"ir_build/atc_ir_common.cc" | |||||
"ir_build/option_utils.cc" | |||||
"graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
"graph/preprocess/insert_op/util_insert_aipp_op.cc" | "graph/preprocess/insert_op/util_insert_aipp_op.cc" | ||||
"hybrid/node_executor/aicpu/aicpu_ext_info.cc" | "hybrid/node_executor/aicpu/aicpu_ext_info.cc" | ||||
@@ -73,7 +73,7 @@ BUILER_SRC_FILES := \ | |||||
ir_build/attr_options/utils.cc \ | ir_build/attr_options/utils.cc \ | ||||
ir_build/attr_options/keep_dtype_option.cc \ | ir_build/attr_options/keep_dtype_option.cc \ | ||||
ir_build/attr_options/weight_compress_option.cc \ | ir_build/attr_options/weight_compress_option.cc \ | ||||
ir_build/atc_ir_common.cc \ | |||||
ir_build/option_utils.cc \ | |||||
ANALYZER_SRC_FILES:= \ | ANALYZER_SRC_FILES:= \ | ||||
analyzer/analyzer.cc \ | analyzer/analyzer.cc \ | ||||
@@ -316,7 +316,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
ir_build/attr_options/utils.cc \ | ir_build/attr_options/utils.cc \ | ||||
ir_build/attr_options/keep_dtype_option.cc \ | ir_build/attr_options/keep_dtype_option.cc \ | ||||
ir_build/attr_options/weight_compress_option.cc \ | ir_build/attr_options/weight_compress_option.cc \ | ||||
ir_build/atc_ir_common.cc \ | |||||
ir_build/option_utils.cc \ | |||||
LIBCLIENT_LOCAL_SRC_FILES := \ | LIBCLIENT_LOCAL_SRC_FILES := \ | ||||
proto/ge_api.proto \ | proto/ge_api.proto \ | ||||
@@ -101,7 +101,7 @@ | |||||
#include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
#include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
#include "init/gelib.h" | #include "init/gelib.h" | ||||
#include "ir_build/atc_ir_common.h" | |||||
#include "ir_build/option_utils.h" | |||||
#include "graph/common/local_context.h" | #include "graph/common/local_context.h" | ||||
#include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
@@ -27,6 +27,7 @@ | |||||
#include "common/helper/model_helper.h" | #include "common/helper/model_helper.h" | ||||
#include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
#include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
#include "ir_build/option_utils.h" | |||||
#include "graph/common/ge_call_wrapper.h" | #include "graph/common/ge_call_wrapper.h" | ||||
#include "graph/common/local_context.h" | #include "graph/common/local_context.h" | ||||
#include "graph/common/transop_util.h" | #include "graph/common/transop_util.h" | ||||
@@ -991,101 +992,6 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node) { | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
long StringToLongNoThrow(const string &str) { | |||||
try { | |||||
return std::stol(str); | |||||
} catch (const std::invalid_argument) { | |||||
GELOGE(PARAM_INVALID, | |||||
"Parse shape range of input failed when transfer from string to int64. Given %s, while correct example:" | |||||
"\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", | |||||
str.c_str()); | |||||
return PARAM_INVALID; | |||||
} catch (const std::out_of_range) { | |||||
GELOGE(PARAM_INVALID, | |||||
"Parse shape range of input failed when transfer from string to int64. Given %s, while correct example:" | |||||
"\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", | |||||
str.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
/** | |||||
* Parser shape_range from string to vector | |||||
* shape_range from option normally is "[1~20,3,3~6,-1],[1~20,3,3~6,-1]" | |||||
* @param shape_range | |||||
*/ | |||||
Status ParseDynamicInputShapeRange(const std::string &shape_range, | |||||
std::vector<std::vector<std::pair<int64_t, int64_t>>> &range) { | |||||
if (shape_range.size() < 2) { | |||||
REPORT_INNER_ERROR("E19999", "shape_range.size:%zu < 2, check invalid", shape_range.size()); | |||||
GELOGE(PARAM_INVALID, "Shape range %s is invalid.", shape_range.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
// different shape_range of single input are split by ']' | |||||
vector<string> shape_range_set = ge::StringUtils::Split(shape_range, ']'); | |||||
if (shape_range_set.empty()) { | |||||
REPORT_INNER_ERROR("E19999", "Shape range %s is not valid. Correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", | |||||
shape_range.c_str()); | |||||
GELOGE(PARAM_INVALID, "Shape range %s is not valid. Correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", | |||||
shape_range.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
for (auto &shape_range_str : shape_range_set) { | |||||
if (shape_range_str.size() < 3) { | |||||
// shape_range_str should be "[2~3,1" | |||||
// or ",[2~3,1". because we should trim '[' or ',[' | |||||
// so shape_range_str.size() < 3 is invalid | |||||
continue; | |||||
} | |||||
// trim start bytes, after that, single input should be "1~20,3,3~6,-1" | |||||
if (ge::StringUtils::StartWith(shape_range_str, "[")) { | |||||
shape_range_str = shape_range_str.substr(1, shape_range_str.size()); | |||||
} | |||||
if (ge::StringUtils::StartWith(shape_range_str, ",")) { | |||||
shape_range_str = shape_range_str.substr(2, shape_range_str.size()); | |||||
} | |||||
// parse shape_range of single input. eg. "1~20,3,3~6,-1" | |||||
std::vector<std::pair<int64_t, int64_t>> range_of_single_input; | |||||
vector<string> dim_range_set = ge::StringUtils::Split(shape_range_str, ','); | |||||
for (const auto &range_pair_str : dim_range_set) { | |||||
vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); | |||||
pair<int64_t, int64_t> range_pair; | |||||
if (range_pair_set.size() == 1) { | |||||
// fix dim | |||||
auto range_value = StringToLongNoThrow(range_pair_set.at(0).c_str()); | |||||
if (range_value < 0) { | |||||
range_pair = std::make_pair(1, range_value); | |||||
} else { | |||||
range_pair = std::make_pair(range_value, range_value); | |||||
} | |||||
} else if (range_pair_set.size() == 2) { | |||||
// unknown dim, should get range. | |||||
auto range_left = StringToLongNoThrow(range_pair_set.at(0).c_str()); | |||||
auto range_right = StringToLongNoThrow(range_pair_set.at(1).c_str()); | |||||
if (range_left < 0 || range_right < 0) { | |||||
REPORT_INNER_ERROR("E19999", "Shape range of input is invalid. Given range pair [%ld,%ld], " | |||||
"while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", range_left, range_right); | |||||
GELOGE(PARAM_INVALID, | |||||
"Shape range of input is invalid. Given range pair [%ld,%ld], while correct example: " | |||||
"\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", | |||||
range_left, range_right); | |||||
return PARAM_INVALID; | |||||
} | |||||
range_pair = std::make_pair(range_left, range_right); | |||||
} else { | |||||
REPORT_INNER_ERROR("E19999", "Shape range of input is invalid. Given %s, " | |||||
"while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", shape_range.c_str()); | |||||
GELOGE(PARAM_INVALID, | |||||
"Shape range of input is invalid. Given %s, while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", | |||||
shape_range.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
range_of_single_input.emplace_back(range_pair); | |||||
} | |||||
range.emplace_back(range_of_single_input); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status GetDynamicInputShapeRange(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option, | Status GetDynamicInputShapeRange(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option, | ||||
vector<vector<std::pair<int64_t, int64_t>>> &range_vec) { | vector<vector<std::pair<int64_t, int64_t>>> &range_vec) { | ||||
@@ -1114,9 +1020,10 @@ Status GetDynamicInputShapeRange(const std::vector<GeTensor> &user_input, const | |||||
OPTION_EXEC_DYNAMIC_EXECUTE_MODE, OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE); | OPTION_EXEC_DYNAMIC_EXECUTE_MODE, OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
auto ret = ParseDynamicInputShapeRange(iter->second, range_vec); | |||||
GE_CHK_STATUS_RET(ret, "Parse dynamic input shape range failed."); | |||||
if (ParseInputShapeRange(iter->second, range_vec) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "[Parse][ShapeRange] Parse dynamic input shape range failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (range_vec.size() != user_input.size()) { | if (range_vec.size() != user_input.size()) { | ||||
GELOGE(PARAM_INVALID, "Dynamic input shape range size is %zu, inputs size is %zu. Not match.", range_vec.size(), | GELOGE(PARAM_INVALID, "Dynamic input shape range size is %zu, inputs size is %zu. Not match.", range_vec.size(), | ||||
user_input.size()); | user_input.size()); | ||||
@@ -112,7 +112,12 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
HYBRID_CHK_STATUS_RET(context_.DumpExceptionInfo(exception_infos), | HYBRID_CHK_STATUS_RET(context_.DumpExceptionInfo(exception_infos), | ||||
"[Execute][GraphInternal] Dump exception info failed."); | "[Execute][GraphInternal] Dump exception info failed."); | ||||
} | } | ||||
GELOGE(ret, "[Execute][GraphInternal] Synchronize failed."); | |||||
if (ret == ge::END_OF_SEQUENCE) { | |||||
GELOGD("Got end of sequence"); | |||||
} else { | |||||
GELOGE(ret, "[Execute][GraphInternal] Synchronize failed."); | |||||
} | |||||
return ret; | |||||
} | } | ||||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[Synchronize] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[Synchronize] End"); | ||||
} | } | ||||
@@ -32,7 +32,7 @@ | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "graph/ge_global_options.h" | #include "graph/ge_global_options.h" | ||||
#include "init/gelib.h" | #include "init/gelib.h" | ||||
#include "ir_build/atc_ir_common.h" | |||||
#include "ir_build/option_utils.h" | |||||
#include "model/ge_model.h" | #include "model/ge_model.h" | ||||
#include "graph/shape_refiner.h" | #include "graph/shape_refiner.h" | ||||
#include "graph/opsproto_manager.h" | #include "graph/opsproto_manager.h" | ||||
@@ -299,10 +299,19 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { | |||||
GE_CHK_BOOL_EXEC(ParseInputShape(input_shape, shape_map, user_shape_map, true), | GE_CHK_BOOL_EXEC(ParseInputShape(input_shape, shape_map, user_shape_map, true), | ||||
return GRAPH_PARAM_INVALID, "[Parse][InputShape] failed!"); | return GRAPH_PARAM_INVALID, "[Parse][InputShape] failed!"); | ||||
} | } | ||||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map; | |||||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> name_shape_range_map; | |||||
std::vector<std::vector<std::pair<int64_t, int64_t>>> index_shape_range_map; | |||||
if (!input_shape_range.empty()) { | if (!input_shape_range.empty()) { | ||||
GE_CHK_BOOL_EXEC(ParseInputShapeRange(input_shape_range, shape_range_map), | |||||
return GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] failed."); | |||||
Status ret = GRAPH_PARAM_INVALID; | |||||
if (input_shape_range.find(":") != string::npos) { | |||||
ret = ParseInputShapeRange(input_shape_range, name_shape_range_map); | |||||
} else { | |||||
ret = ParseInputShapeRange(input_shape_range, index_shape_range_map); | |||||
} | |||||
if (ret != SUCCESS) { | |||||
GELOGE(GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] parse shape range[%s] failed.", input_shape_range.c_str()); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
} | } | ||||
auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | ||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
@@ -315,10 +324,14 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { | |||||
GELOGE(GRAPH_FAILED, "[Update][DataOpShape] fail for op:%s.", op->GetName().c_str()); | GELOGE(GRAPH_FAILED, "[Update][DataOpShape] fail for op:%s.", op->GetName().c_str()); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) { | |||||
if (UpdateDataOpShapeRange(op, name_shape_range_map) != SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str()); | GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str()); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | |||||
} | |||||
if (UpdateDataOpShapeRange(op, index_shape_range_map) != SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -13,7 +13,7 @@ | |||||
* See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#include "atc_ir_common.h" | |||||
#include "option_utils.h" | |||||
#include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
#include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
#include "framework/common/string_util.h" | #include "framework/common/string_util.h" | ||||
@@ -22,6 +22,7 @@ | |||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "graph/debug/ge_attr_define.h" | |||||
using std::pair; | using std::pair; | ||||
using std::string; | using std::string; | ||||
@@ -58,10 +59,12 @@ const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \ | |||||
const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\""; | const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\""; | ||||
const char *const kKeepDtypeError = "file not found"; | const char *const kKeepDtypeError = "file not found"; | ||||
const char *const kInputShapeRangeInvalid = "format of shape range is invalid"; | const char *const kInputShapeRangeInvalid = "format of shape range is invalid"; | ||||
const char *const kInputShapeRangeSizeInvalid = " shape range size less than 2 is invalid"; | |||||
const char *const kShapeRangeValueConvertError = "transfer from string to int64 error"; | const char *const kShapeRangeValueConvertError = "transfer from string to int64 error"; | ||||
const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\""; | const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\""; | ||||
const char *const kInputShapeRangeSample2 = "\"[1~20]\""; | const char *const kInputShapeRangeSample2 = "\"[1~20]\""; | ||||
const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\""; | const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\""; | ||||
const char *const kInputShapeRangeSample4 = "\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\""; | |||||
vector<string> SplitInputShape(const std::string &input_shape) { | vector<string> SplitInputShape(const std::string &input_shape) { | ||||
vector<string> shape_pair_vec; | vector<string> shape_pair_vec; | ||||
@@ -72,6 +75,67 @@ vector<string> SplitInputShape(const std::string &input_shape) { | |||||
} | } | ||||
return shape_pair_vec; | return shape_pair_vec; | ||||
} | } | ||||
static bool StringToLongNoThrow(const string &str, long &val) { | |||||
try { | |||||
val = std::stol(str); | |||||
return true; | |||||
} catch (const std::invalid_argument) { | |||||
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | |||||
std::vector<string>({str, kShapeRangeValueConvertError, kInputShapeRangeSample3})); | |||||
GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s to long failed, reason: %s, correct sample is %s.", | |||||
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); | |||||
} catch (const std::out_of_range) { | |||||
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | |||||
std::vector<string>({str, kShapeRangeValueConvertError, kInputShapeRangeSample3})); | |||||
GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s to long failed, reason: %s, correct sample is %s.", | |||||
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); | |||||
} | |||||
return false; | |||||
} | |||||
static bool ParseShapeRangePair(const string &shape_range, | |||||
const vector<string> &range_pair_set, | |||||
std::pair<int64_t, int64_t> &range_pair) { | |||||
if (range_pair_set.size() == 1) { | |||||
long range_value = 0; | |||||
if (!StringToLongNoThrow(range_pair_set.at(0), range_value)) { | |||||
return false; | |||||
} | |||||
if (range_value < 0) { | |||||
range_pair = std::make_pair(1, range_value); | |||||
} else { | |||||
range_pair = std::make_pair(range_value, range_value); | |||||
} | |||||
} else if (range_pair_set.size() == kRangePairSize) { | |||||
// unknown dim, should get range. | |||||
long range_left = 0; | |||||
if (!StringToLongNoThrow(range_pair_set.at(0), range_left)) { | |||||
return false; | |||||
} | |||||
long range_right = 0; | |||||
if (!StringToLongNoThrow(range_pair_set.at(1), range_right)) { | |||||
return false; | |||||
} | |||||
if ((range_left < 0) || (range_right < 0)) { | |||||
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | |||||
std::vector<string>({shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3})); | |||||
GELOGE(PARAM_INVALID, | |||||
"[Parse][InputParameter] [--input_shape_range]'s shape range[%s] failed," | |||||
"reason: %s, correct sample is %s.", | |||||
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); | |||||
return false; | |||||
} | |||||
range_pair = std::make_pair(range_left, range_right); | |||||
} else { | |||||
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | |||||
std::vector<string>({shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3})); | |||||
GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid, reason: %s, correct sample is %s.", | |||||
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
} // namespace | } // namespace | ||||
Status CheckInputFormat(const string &input_format) { | Status CheckInputFormat(const string &input_format) { | ||||
@@ -287,24 +351,6 @@ bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims | |||||
return true; | return true; | ||||
} | } | ||||
bool StringToLongNoThrow(const string &str, long &val) { | |||||
try { | |||||
val = std::stol(str); | |||||
return true; | |||||
} catch (const std::invalid_argument) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
{str, kShapeRangeValueConvertError, kInputShapeRangeSample3}); | |||||
GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s invalid, reason: %s, correct sample is %s.", | |||||
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); | |||||
} catch (const std::out_of_range) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
{str, kShapeRangeValueConvertError, kInputShapeRangeSample3}); | |||||
GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s invalid, reason: %s, correct sample is %s.", | |||||
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); | |||||
} | |||||
return false; | |||||
} | |||||
bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_t>> &shape_range_vec) { | bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_t>> &shape_range_vec) { | ||||
vector<char> square_brackets; | vector<char> square_brackets; | ||||
for (auto ch : shape_range) { | for (auto ch : shape_range) { | ||||
@@ -331,41 +377,8 @@ bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_ | |||||
for (const auto &range_pair_str : dim_range_set) { | for (const auto &range_pair_str : dim_range_set) { | ||||
vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); | vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); | ||||
pair<int64_t, int64_t> range_pair; | pair<int64_t, int64_t> range_pair; | ||||
if (range_pair_set.size() == 1) { | |||||
long range_value = 0; | |||||
if (!StringToLongNoThrow(range_pair_set.at(0), range_value)) { | |||||
return false; | |||||
} | |||||
if (range_value < 0) { | |||||
range_pair = std::make_pair(1, range_value); | |||||
} else { | |||||
range_pair = std::make_pair(range_value, range_value); | |||||
} | |||||
} else if (range_pair_set.size() == kRangePairSize) { | |||||
// unknown dim, should get range. | |||||
long range_left = 0; | |||||
if (!StringToLongNoThrow(range_pair_set.at(0), range_left)) { | |||||
return false; | |||||
} | |||||
long range_right = 0; | |||||
if (!StringToLongNoThrow(range_pair_set.at(1), range_right)) { | |||||
return false; | |||||
} | |||||
if (range_left < 0 || (range_right < 0)) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}); | |||||
GELOGE(PARAM_INVALID, | |||||
"[Parse][InputParameter] [--input_shape_range]'s shape range[%s] failed," | |||||
"reason: %s, correct sample is %s.", | |||||
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); | |||||
return false; | |||||
} | |||||
range_pair = std::make_pair(range_left, range_right); | |||||
} else { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||||
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}); | |||||
GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid, reason: %s, correct sample is %s.", | |||||
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); | |||||
if (!ParseShapeRangePair(shape_range, range_pair_set, range_pair)) { | |||||
GELOGE(PARAM_INVALID, "[Parse][RangePair] parse range pair failed."); | |||||
return false; | return false; | ||||
} | } | ||||
shape_range_vec.emplace_back(range_pair); | shape_range_vec.emplace_back(range_pair); | ||||
@@ -373,8 +386,13 @@ bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_ | |||||
return true; | return true; | ||||
} | } | ||||
bool ParseInputShapeRange(const std::string &shape_range, | |||||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map) { | |||||
/** | |||||
* Parser shape_range from string to map | |||||
* shape_range from option normally is "input1:[1~20,3,3~6,-1];input2:[1~20,3,3~6,-1]" | |||||
* @param shape_range | |||||
*/ | |||||
Status ParseInputShapeRange(const std::string &shape_range, | |||||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map) { | |||||
GELOGD("Input shape range %s", shape_range.c_str()); | GELOGD("Input shape range %s", shape_range.c_str()); | ||||
vector<string> shape_range_vec = StringUtils::Split(shape_range, ';'); | vector<string> shape_range_vec = StringUtils::Split(shape_range, ';'); | ||||
@@ -386,25 +404,82 @@ bool ParseInputShapeRange(const std::string &shape_range, | |||||
{shape_range, kSplitError1, kInputShapeRangeSample1}); | {shape_range, kSplitError1, kInputShapeRangeSample1}); | ||||
GELOGE(PARAM_INVALID, "[Parse][Parameter]--input shape_range:%s invalid, reason: %s, correct sample is %s.", | GELOGE(PARAM_INVALID, "[Parse][Parameter]--input shape_range:%s invalid, reason: %s, correct sample is %s.", | ||||
shape_range.c_str(), kSplitError1, kInputShapeRangeSample1); | shape_range.c_str(), kSplitError1, kInputShapeRangeSample1); | ||||
return false; | |||||
return PARAM_INVALID; | |||||
} | } | ||||
if (shape_range_pair_vec[1].empty()) { | if (shape_range_pair_vec[1].empty()) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape", "reason", "sample"}, | ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape", "reason", "sample"}, | ||||
{shape_range, kEmptyError, kInputShapeRangeSample1}); | {shape_range, kEmptyError, kInputShapeRangeSample1}); | ||||
GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid,reason: %s, correct sample is %s.", | GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid,reason: %s, correct sample is %s.", | ||||
shape_range.c_str(), kEmptyError, kInputShapeRangeSample1); | shape_range.c_str(), kEmptyError, kInputShapeRangeSample1); | ||||
return false; | |||||
return PARAM_INVALID; | |||||
} | } | ||||
string shape_range_str = shape_range_pair_vec[1]; | string shape_range_str = shape_range_pair_vec[1]; | ||||
vector<pair<int64_t, int64_t>> shape_range_val; | vector<pair<int64_t, int64_t>> shape_range_val; | ||||
if (!ParseSingleShapeRange(shape_range_str, shape_range_val)) { | if (!ParseSingleShapeRange(shape_range_str, shape_range_val)) { | ||||
GELOGE(PARAM_INVALID, "[Parse][Parameter] shape_range_str: %s invalid.", shape_range_str.c_str()); | GELOGE(PARAM_INVALID, "[Parse][Parameter] shape_range_str: %s invalid.", shape_range_str.c_str()); | ||||
return false; | |||||
return PARAM_INVALID; | |||||
} | } | ||||
shape_range_map.emplace(make_pair(StringUtils::Trim(shape_range_pair_vec[0]), shape_range_val)); | shape_range_map.emplace(make_pair(StringUtils::Trim(shape_range_pair_vec[0]), shape_range_val)); | ||||
} | } | ||||
return true; | |||||
return SUCCESS; | |||||
} | |||||
/** | |||||
* Parser shape_range from string to vector | |||||
* shape_range from option normally is "[1~20,3,3~6,-1],[1~20,3,3~6,-1]" | |||||
* @param shape_range | |||||
*/ | |||||
Status ParseInputShapeRange(const std::string &shape_range, | |||||
std::vector<std::vector<std::pair<int64_t, int64_t>>> &range) { | |||||
GELOGD("Input shape range %s", shape_range.c_str()); | |||||
if (shape_range.size() < 2) { | |||||
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | |||||
std::vector<std::string>({shape_range, kInputShapeRangeSizeInvalid, kInputShapeRangeSample4})); | |||||
GELOGE(PARAM_INVALID, "[Parse][ShapeRange] str:%s invalid, reason: %s, correct sample is %s.", | |||||
shape_range.c_str(), kInputShapeRangeSizeInvalid, kInputShapeRangeSample4); | |||||
return PARAM_INVALID; | |||||
} | |||||
// different shape_range of single input are split by ']' | |||||
vector<string> shape_range_set = ge::StringUtils::Split(shape_range, ']'); | |||||
if (shape_range_set.empty()) { | |||||
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}), | |||||
std::vector<string>({shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample4})); | |||||
GELOGE(PARAM_INVALID, "[Parse][ShapeRange] str:%s invalid, reason: %s, correct sample is %s.", | |||||
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample4); | |||||
return PARAM_INVALID; | |||||
} | |||||
for (auto &shape_range_str : shape_range_set) { | |||||
if (shape_range_str.size() < 3) { | |||||
// shape_range_str should be "[2~3,1" | |||||
// or ",[2~3,1". because we should trim '[' or ',[' | |||||
// so shape_range_str.size() < 3 is invalid | |||||
continue; | |||||
} | |||||
// trim start bytes, after that, single input should be "1~20,3,3~6,-1" | |||||
if (ge::StringUtils::StartWith(shape_range_str, "[")) { | |||||
shape_range_str = shape_range_str.substr(1, shape_range_str.size()); | |||||
} | |||||
if (ge::StringUtils::StartWith(shape_range_str, ",")) { | |||||
shape_range_str = shape_range_str.substr(2, shape_range_str.size()); | |||||
} | |||||
// parse shape_range of single input. eg. "1~20,3,3~6,-1" | |||||
std::vector<std::pair<int64_t, int64_t>> range_of_single_input; | |||||
vector<string> dim_range_set = ge::StringUtils::Split(shape_range_str, ','); | |||||
for (const auto &range_pair_str : dim_range_set) { | |||||
vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); | |||||
pair<int64_t, int64_t> range_pair; | |||||
if (!ParseShapeRangePair(shape_range_str, range_pair_set, range_pair)) { | |||||
GELOGE(PARAM_INVALID, "[Parse][RangePair] Parse range pair failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
range_of_single_input.emplace_back(range_pair); | |||||
} | |||||
range.emplace_back(range_of_single_input); | |||||
} | |||||
return SUCCESS; | |||||
} | } | ||||
Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims, | Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims, | ||||
@@ -420,11 +495,13 @@ Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_i | |||||
} | } | ||||
if (param_size == 0) { | if (param_size == 0) { | ||||
if (!input_shape_range.empty()) { | |||||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map; | |||||
if (!ParseInputShapeRange(input_shape_range, shape_range_map)) { | |||||
GELOGE(ge::PARAM_INVALID, "[Parse][InputShapeRange] failed, range: %s", input_shape_range.c_str()); | |||||
return ge::PARAM_INVALID; | |||||
if (input_shape_range.find(":") != string::npos) { | |||||
if (!input_shape_range.empty()) { | |||||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map; | |||||
if (ParseInputShapeRange(input_shape_range, shape_range_map) != SUCCESS) { | |||||
GELOGE(ge::PARAM_INVALID, "[Parse][InputShapeRange] failed, range: %s", input_shape_range.c_str()); | |||||
return ge::PARAM_INVALID; | |||||
} | |||||
} | } | ||||
} | } | ||||
return ge::SUCCESS; | return ge::SUCCESS; | ||||
@@ -446,9 +523,11 @@ Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_i | |||||
} | } | ||||
if (!dynamic_batch_size.empty()) { | if (!dynamic_batch_size.empty()) { | ||||
if (!CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size)) { | |||||
GELOGE(ge::PARAM_INVALID, "[Check][DynamicBatchSizeInputShape] input_shape: %s invalid.", input_shape.c_str()); | |||||
return ge::PARAM_INVALID; | |||||
if (input_shape_range.find(":") != string::npos) { | |||||
if (!CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size)) { | |||||
GELOGE(ge::PARAM_INVALID, "[Check][DynamicBatchSizeInputShape] input_shape: %s invalid.", input_shape.c_str()); | |||||
return ge::PARAM_INVALID; | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -744,10 +823,10 @@ Status UpdateDataOpShape(const OpDescPtr &op, map<string, vector<int64_t>> &shap | |||||
} | } | ||||
Status UpdateDataOpShapeRange(const OpDescPtr &op, | Status UpdateDataOpShapeRange(const OpDescPtr &op, | ||||
map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) { | |||||
const map<string, vector<pair<int64_t, int64_t>>> &name_shape_range_map) { | |||||
GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
if (shape_range_map.empty()) { | |||||
GELOGI("Shape range map of data op [%s] is empty.", op->GetName().c_str()); | |||||
if (name_shape_range_map.empty()) { | |||||
GELOGI("Shape range name map of data op [%s] is empty.", op->GetName().c_str()); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -757,8 +836,8 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, | |||||
GE_CHECK_NOTNULL(tensor_output); | GE_CHECK_NOTNULL(tensor_output); | ||||
string data_op_name = op->GetName(); | string data_op_name = op->GetName(); | ||||
auto origin_shape = tensor_input->GetShape(); | auto origin_shape = tensor_input->GetShape(); | ||||
auto iter = shape_range_map.find(data_op_name); | |||||
if (iter != shape_range_map.end()) { | |||||
auto iter = name_shape_range_map.find(data_op_name); | |||||
if (iter != name_shape_range_map.end()) { | |||||
auto cur_shape_range = iter->second; | auto cur_shape_range = iter->second; | ||||
if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) { | if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) { | ||||
GELOGE(PARAM_INVALID, "[Check][OpDescPtr] Check shape by shape range failed for op:%s.", data_op_name.c_str()); | GELOGE(PARAM_INVALID, "[Check][OpDescPtr] Check shape by shape range failed for op:%s.", data_op_name.c_str()); | ||||
@@ -783,6 +862,56 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status UpdateDataOpShapeRange(const OpDescPtr &op, | |||||
const vector<vector<pair<int64_t, int64_t>>> &index_shape_range_map) { | |||||
GE_CHECK_NOTNULL(op); | |||||
if (index_shape_range_map.empty()) { | |||||
GELOGI("Shape range index map of data op [%s] is empty.", op->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
GeAttrValue::INT index = 0; | |||||
if (!AttrUtils::GetInt(op, ATTR_NAME_INDEX, index)) { | |||||
GELOGW("[%s] Get index from data attr failed.", op->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
if ((index < 0) || (static_cast<size_t>(index) >= index_shape_range_map.size())) { | |||||
std::string situation = "data op index[" + std::to_string(index) + "]"; | |||||
std::string reason = "it must less than user_input size[" + std::to_string(index_shape_range_map.size()) + "]"; | |||||
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"situation", "reason"}), | |||||
std::vector<std::string>({situation, reason})); | |||||
GELOGE(PARAM_INVALID, "user_input size = %zu, graph data op index = %ld.", index_shape_range_map.size(), index); | |||||
return FAILED; | |||||
} | |||||
auto tensor_input = op->MutableInputDesc(0); | |||||
auto tensor_output = op->MutableOutputDesc(0); | |||||
GE_CHECK_NOTNULL(tensor_input); | |||||
GE_CHECK_NOTNULL(tensor_output); | |||||
string data_op_name = op->GetName(); | |||||
auto origin_shape = tensor_input->GetShape(); | |||||
auto cur_shape_range = index_shape_range_map[index]; | |||||
if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "[Check][OpDescPtr] Check shape by shape range failed for op:%s.", data_op_name.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
for (size_t idx = 0; idx < cur_shape_range.size(); ++idx) { | |||||
auto left_range = cur_shape_range[idx].first; | |||||
auto right_range = cur_shape_range[idx].second; | |||||
if (left_range != right_range) { | |||||
origin_shape.SetDim(idx, UNKNOWN_DIM); | |||||
} | |||||
} | |||||
tensor_input->SetShape(origin_shape); | |||||
tensor_input->SetShapeRange(cur_shape_range); | |||||
tensor_output->SetShape(origin_shape); | |||||
tensor_output->SetShapeRange(cur_shape_range); | |||||
GELOGI("Update input [%s] shape range info success.", data_op_name.c_str()); | |||||
return SUCCESS; | |||||
} | |||||
static Status CheckInputShapeRangeNode(const ComputeGraphPtr &compute_graph, | static Status CheckInputShapeRangeNode(const ComputeGraphPtr &compute_graph, | ||||
const map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) { | const map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) { | ||||
for (const auto &it : shape_range_map) { | for (const auto &it : shape_range_map) { | ||||
@@ -813,7 +942,7 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co | |||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
map<string, vector<pair<int64_t, int64_t>>> shape_range_map; | map<string, vector<pair<int64_t, int64_t>>> shape_range_map; | ||||
if (!ParseInputShapeRange(input_shape_range, shape_range_map)) { | |||||
if (ParseInputShapeRange(input_shape_range, shape_range_map) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "[Parse][InputShapeRange] input_shape_range:%s invalid.", input_shape_range.c_str()); | GELOGE(PARAM_INVALID, "[Parse][InputShapeRange] input_shape_range:%s invalid.", input_shape_range.c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } |
@@ -64,8 +64,10 @@ Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string | |||||
bool ParseInputShape(const std::string &input_shape, std::map<string, std::vector<int64_t>> &shape_map, | bool ParseInputShape(const std::string &input_shape, std::map<string, std::vector<int64_t>> &shape_map, | ||||
std::vector<std::pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input = false); | std::vector<std::pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input = false); | ||||
bool ParseInputShapeRange(const std::string &shape_range, | |||||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map); | |||||
Status ParseInputShapeRange(const std::string &shape_range, | |||||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map); | |||||
Status ParseInputShapeRange(const std::string &shape_range, | |||||
std::vector<std::vector<std::pair<int64_t, int64_t>>> &range); | |||||
Status CheckOutputTypeParamValid(const std::string output_type); | Status CheckOutputTypeParamValid(const std::string output_type); | ||||
Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); | Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); | ||||
@@ -80,8 +82,10 @@ Status CheckKeepTypeParamValid(const std::string &keep_dtype); | |||||
void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); | void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); | ||||
void EraseEndSemicolon(std::string ¶m); | void EraseEndSemicolon(std::string ¶m); | ||||
Status UpdateDataOpShape(const OpDescPtr &op, std::map<std::string, std::vector<int64_t>> &shape_map); | Status UpdateDataOpShape(const OpDescPtr &op, std::map<std::string, std::vector<int64_t>> &shape_map); | ||||
Status UpdateDataOpShapeRange( | |||||
const OpDescPtr &op, const std::map<std::string, std::vector<std::pair<int64_t, int64_t>>> &name_shape_range_map); | |||||
Status UpdateDataOpShapeRange(const OpDescPtr &op, | Status UpdateDataOpShapeRange(const OpDescPtr &op, | ||||
std::map<std::string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map); | |||||
const std::vector<std::vector<std::pair<int64_t, int64_t>>> &index_shape_range_map); | |||||
Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range); | Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range); | ||||
} | } | ||||
#endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ | #endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ |
@@ -11,7 +11,7 @@ set(SRC_LIST | |||||
"main.cc" | "main.cc" | ||||
"single_op_parser.cc" | "single_op_parser.cc" | ||||
"../session/omg.cc" | "../session/omg.cc" | ||||
"../ir_build/atc_ir_common.cc" | |||||
"../ir_build/option_utils.cc" | |||||
) | ) | ||||
############ atc_atc.bin ############ | ############ atc_atc.bin ############ | ||||
@@ -36,7 +36,7 @@ | |||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "init/gelib.h" | #include "init/gelib.h" | ||||
#include "ir_build/atc_ir_common.h" | |||||
#include "ir_build/option_utils.h" | |||||
#include "omg/omg.h" | #include "omg/omg.h" | ||||
#include "omg/parser/parser_factory.h" | #include "omg/parser/parser_factory.h" | ||||
#include "omg/parser/parser_inner_ctx.h" | #include "omg/parser/parser_inner_ctx.h" | ||||
@@ -12,7 +12,7 @@ LOCAL_SRC_FILES := \ | |||||
main.cc \ | main.cc \ | ||||
single_op_parser.cc \ | single_op_parser.cc \ | ||||
../session/omg.cc \ | ../session/omg.cc \ | ||||
../ir_build/atc_ir_common.cc \ | |||||
../ir_build/option_utils.cc \ | |||||
LOCAL_C_INCLUDES := \ | LOCAL_C_INCLUDES := \ | ||||
$(LOCAL_PATH)/../ ./ \ | $(LOCAL_PATH)/../ ./ \ | ||||
@@ -65,7 +65,7 @@ LOCAL_SRC_FILES := \ | |||||
main.cc \ | main.cc \ | ||||
single_op_parser.cc \ | single_op_parser.cc \ | ||||
../session/omg.cc \ | ../session/omg.cc \ | ||||
../ir_build/atc_ir_common.cc \ | |||||
../ir_build/option_utils.cc \ | |||||
LOCAL_C_INCLUDES := \ | LOCAL_C_INCLUDES := \ | ||||
$(LOCAL_PATH)/../ ./ \ | $(LOCAL_PATH)/../ ./ \ | ||||
@@ -118,7 +118,7 @@ LOCAL_SRC_FILES := \ | |||||
main.cc \ | main.cc \ | ||||
single_op_parser.cc \ | single_op_parser.cc \ | ||||
../session/omg.cc \ | ../session/omg.cc \ | ||||
../ir_build/atc_ir_common.cc \ | |||||
../ir_build/option_utils.cc \ | |||||
LOCAL_C_INCLUDES := \ | LOCAL_C_INCLUDES := \ | ||||
$(LOCAL_PATH)/../ ./ \ | $(LOCAL_PATH)/../ ./ \ | ||||
@@ -38,7 +38,7 @@ | |||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/optimize/common/params.h" | #include "graph/optimize/common/params.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "ir_build/atc_ir_common.h" | |||||
#include "ir_build/option_utils.h" | |||||
#include "omg/omg_inner_types.h" | #include "omg/omg_inner_types.h" | ||||
#include "omg/parser/model_parser.h" | #include "omg/parser/model_parser.h" | ||||
#include "omg/parser/parser_factory.h" | #include "omg/parser/parser_factory.h" | ||||
@@ -308,7 +308,7 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc" | "${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc" | ||||
"${GE_CODE_DIR}/ge/graph/partition/dynamic_shape_partition.cc" | "${GE_CODE_DIR}/ge/graph/partition/dynamic_shape_partition.cc" | ||||
"${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" | "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" | ||||
"${GE_CODE_DIR}/ge/ir_build/atc_ir_common.cc" | |||||
"${GE_CODE_DIR}/ge/ir_build/option_utils.cc" | |||||
"${GE_CODE_DIR}/ge/graph/preprocess/insert_op/ge_aipp_op.cc" | "${GE_CODE_DIR}/ge/graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
"${GE_CODE_DIR}/ge/graph/preprocess/multi_batch_options.cc" | "${GE_CODE_DIR}/ge/graph/preprocess/multi_batch_options.cc" | ||||
"${GE_CODE_DIR}/ge/graph/build/model_builder.cc" | "${GE_CODE_DIR}/ge/graph/build/model_builder.cc" | ||||
@@ -109,7 +109,7 @@ | |||||
#include "graph/build/label_allocator.h" | #include "graph/build/label_allocator.h" | ||||
#include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
#include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
#include "ir_build/atc_ir_common.h" | |||||
#include "ir_build/option_utils.h" | |||||
#include "graph/common/local_context.h" | #include "graph/common/local_context.h" | ||||
#include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
@@ -15,8 +15,9 @@ | |||||
*/ | */ | ||||
#include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
#include "ir_build/atc_ir_common.h" | |||||
#include "ir_build/option_utils.h" | |||||
#include "graph/testcase/ge_graph/graph_builder_utils.h" | #include "graph/testcase/ge_graph/graph_builder_utils.h" | ||||
#include "graph/debug/ge_attr_define.h" | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
@@ -68,6 +69,20 @@ TEST(UtestIrCommon, update_data_op_shape) { | |||||
EXPECT_EQ(ret, ge::SUCCESS); | EXPECT_EQ(ret, ge::SUCCESS); | ||||
} | } | ||||
TEST(UtestIrCommon, update_data_op_shape_range) { | |||||
ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data"); | |||||
std::vector<std::vector<std::pair<int64_t, int64_t>>> index_shape_range_map; | |||||
std::pair<int64_t, int64_t> range_pair(1, 2); | |||||
vector<pair<int64_t, int64_t>> range_pair_tmp = { range_pair }; | |||||
index_shape_range_map.push_back(range_pair_tmp); | |||||
AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, 0); | |||||
Status ret = UpdateDataOpShapeRange(op_desc, index_shape_range_map); | |||||
EXPECT_EQ(ret, ge::SUCCESS); | |||||
} | |||||
TEST(UtestIrCommon, update_dynamic_shape_range_success) { | TEST(UtestIrCommon, update_dynamic_shape_range_success) { | ||||
ComputeGraphPtr graph = BuildComputeGraph(); | ComputeGraphPtr graph = BuildComputeGraph(); | ||||
std::string input_shape_range = "input1:[1, 2~3, -1];input2:[3~5, 10]"; | std::string input_shape_range = "input1:[1, 2~3, -1];input2:[3~5, 10]"; | ||||
@@ -117,4 +132,84 @@ TEST(UtestIrCommon, check_dynamic_image_size_fail) { | |||||
bool ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size); | bool ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size); | ||||
EXPECT_EQ(ret, false); | EXPECT_EQ(ret, false); | ||||
} | |||||
} | |||||
TEST(UtestIrCommon, check_input_format_failed) { | |||||
std::string format = "invalid"; | |||||
Status ret = CheckInputFormat(format); | |||||
EXPECT_EQ(ret, ge::PARAM_INVALID); | |||||
} | |||||
TEST(UtestIrCommon, check_dynamic_batch_size_input_shape_succ) { | |||||
map<string, vector<int64_t>> shape_map; | |||||
shape_map.insert(std::pair<string, vector<int64_t>>("data", {-1, 2, 3})); | |||||
std::string dynamic_batch_size = "11"; | |||||
bool ret = CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size); | |||||
EXPECT_EQ(ret, true); | |||||
} | |||||
TEST(UtestIrCommon, check_dynamic_images_size_input_shape_succ) { | |||||
map<string, vector<int64_t>> shape_map; | |||||
shape_map.insert(std::pair<string, vector<int64_t>>("data", {4, -1, -1, 5})); | |||||
std::string input_format = "NCHW"; | |||||
std::string dynamic_image_size = "4,5"; | |||||
Status ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size); | |||||
EXPECT_EQ(ret, ge::SUCCESS); | |||||
} | |||||
TEST(UtestIrCommon, check_dynamic_input_param_succ) { | |||||
string dynamic_batch_size = "1"; | |||||
string dynamic_image_size; | |||||
string dynamic_dims; | |||||
string input_shape = "data:1,3,244,244"; | |||||
string input_shape_range; | |||||
string input_format = "NCHW"; | |||||
bool is_dynamic_input = false; | |||||
Status ret = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims, | |||||
input_shape, input_shape_range, input_format,is_dynamic_input); | |||||
EXPECT_EQ(ret, ge::SUCCESS); | |||||
} | |||||
TEST(UtestIrCommon, check_compress_weight) { | |||||
std::string enable_compress_weight = "true"; | |||||
std::string compress_weight_conf="./"; | |||||
Status ret = CheckCompressWeightParamValid(enable_compress_weight, compress_weight_conf); | |||||
EXPECT_EQ(ret, PARAM_INVALID); | |||||
enable_compress_weight = "yes"; | |||||
compress_weight_conf = "./"; | |||||
ret = CheckCompressWeightParamValid(enable_compress_weight, compress_weight_conf); | |||||
EXPECT_EQ(ret, PARAM_INVALID); | |||||
} | |||||
TEST(UtestIrCommon, check_param_failed) { | |||||
std::string param_invalid = "invalid"; | |||||
Status ret = CheckOutputTypeParamValid(param_invalid); | |||||
EXPECT_EQ(ret, PARAM_INVALID); | |||||
ret = CheckBufferOptimizeParamValid(param_invalid); | |||||
EXPECT_EQ(ret, PARAM_INVALID); | |||||
ret = CheckKeepTypeParamValid(param_invalid); | |||||
EXPECT_EQ(ret, PARAM_INVALID); | |||||
ret = CheckInsertOpConfParamValid(param_invalid); | |||||
EXPECT_EQ(ret, PARAM_INVALID); | |||||
ret = CheckDisableReuseMemoryParamValid(param_invalid); | |||||
EXPECT_EQ(ret, PARAM_INVALID); | |||||
ret = CheckEnableSingleStreamParamValid(param_invalid); | |||||
EXPECT_EQ(ret, PARAM_INVALID); | |||||
std::string optypelist_for_implmode; | |||||
std::string op_select_implmode = "1"; | |||||
ret = CheckImplmodeParamValid(optypelist_for_implmode, op_select_implmode); | |||||
EXPECT_EQ(ret, PARAM_INVALID); | |||||
ret = CheckLogParamValidAndSetLogLevel(param_invalid); | |||||
} |