Browse Source

ir parse input shape reange

tags/v1.3.0
zhengyuanhua 3 years ago
parent
commit
270de8ae3f
16 changed files with 350 additions and 197 deletions
  1. +2
    -2
      ge/CMakeLists.txt
  2. +1
    -1
      ge/ge_inference.mk
  3. +1
    -1
      ge/ge_runner.mk
  4. +1
    -1
      ge/graph/manager/graph_manager.cc
  5. +5
    -98
      ge/graph/preprocess/graph_preprocess.cc
  6. +6
    -1
      ge/hybrid/executor/hybrid_model_executor.cc
  7. +19
    -6
      ge/ir_build/ge_ir_build.cc
  8. +203
    -74
      ge/ir_build/option_utils.cc
  9. +7
    -3
      ge/ir_build/option_utils.h
  10. +1
    -1
      ge/offline/CMakeLists.txt
  11. +1
    -1
      ge/offline/main.cc
  12. +3
    -3
      ge/offline/module.mk
  13. +1
    -1
      ge/session/omg.cc
  14. +1
    -1
      tests/ut/ge/CMakeLists.txt
  15. +1
    -1
      tests/ut/ge/graph/manager/graph_manager_unittest.cc
  16. +97
    -2
      tests/ut/ge/graph_ir/ge_ir_build_unittest.cc

+ 2
- 2
ge/CMakeLists.txt View File

@@ -403,7 +403,7 @@ set(TRAIN_SRC_LIST
"ir_build/attr_options/utils.cc" "ir_build/attr_options/utils.cc"
"ir_build/attr_options/keep_dtype_option.cc" "ir_build/attr_options/keep_dtype_option.cc"
"ir_build/attr_options/weight_compress_option.cc" "ir_build/attr_options/weight_compress_option.cc"
"ir_build/atc_ir_common.cc"
"ir_build/option_utils.cc"
"graph/build/memory/memory_assigner.cc" "graph/build/memory/memory_assigner.cc"
"graph/build/memory/graph_mem_assigner.cc" "graph/build/memory/graph_mem_assigner.cc"
"graph/build/memory/binary_block_mem_assigner.cc" "graph/build/memory/binary_block_mem_assigner.cc"
@@ -664,7 +664,7 @@ set(INFER_SRC_LIST
"ir_build/attr_options/utils.cc" "ir_build/attr_options/utils.cc"
"ir_build/attr_options/keep_dtype_option.cc" "ir_build/attr_options/keep_dtype_option.cc"
"ir_build/attr_options/weight_compress_option.cc" "ir_build/attr_options/weight_compress_option.cc"
"ir_build/atc_ir_common.cc"
"ir_build/option_utils.cc"
"graph/preprocess/insert_op/ge_aipp_op.cc" "graph/preprocess/insert_op/ge_aipp_op.cc"
"graph/preprocess/insert_op/util_insert_aipp_op.cc" "graph/preprocess/insert_op/util_insert_aipp_op.cc"
"hybrid/node_executor/aicpu/aicpu_ext_info.cc" "hybrid/node_executor/aicpu/aicpu_ext_info.cc"


+ 1
- 1
ge/ge_inference.mk View File

@@ -73,7 +73,7 @@ BUILER_SRC_FILES := \
ir_build/attr_options/utils.cc \ ir_build/attr_options/utils.cc \
ir_build/attr_options/keep_dtype_option.cc \ ir_build/attr_options/keep_dtype_option.cc \
ir_build/attr_options/weight_compress_option.cc \ ir_build/attr_options/weight_compress_option.cc \
ir_build/atc_ir_common.cc \
ir_build/option_utils.cc \


ANALYZER_SRC_FILES:= \ ANALYZER_SRC_FILES:= \
analyzer/analyzer.cc \ analyzer/analyzer.cc \


+ 1
- 1
ge/ge_runner.mk View File

@@ -316,7 +316,7 @@ LIBGE_LOCAL_SRC_FILES := \
ir_build/attr_options/utils.cc \ ir_build/attr_options/utils.cc \
ir_build/attr_options/keep_dtype_option.cc \ ir_build/attr_options/keep_dtype_option.cc \
ir_build/attr_options/weight_compress_option.cc \ ir_build/attr_options/weight_compress_option.cc \
ir_build/atc_ir_common.cc \
ir_build/option_utils.cc \


LIBCLIENT_LOCAL_SRC_FILES := \ LIBCLIENT_LOCAL_SRC_FILES := \
proto/ge_api.proto \ proto/ge_api.proto \


+ 1
- 1
ge/graph/manager/graph_manager.cc View File

@@ -101,7 +101,7 @@
#include "graph/utils/tensor_adapter.h" #include "graph/utils/tensor_adapter.h"
#include "inc/pass_manager.h" #include "inc/pass_manager.h"
#include "init/gelib.h" #include "init/gelib.h"
#include "ir_build/atc_ir_common.h"
#include "ir_build/option_utils.h"
#include "graph/common/local_context.h" #include "graph/common/local_context.h"
#include "graph/common/omg_util.h" #include "graph/common/omg_util.h"
#include "common/formats/utils/formats_trans_utils.h" #include "common/formats/utils/formats_trans_utils.h"


+ 5
- 98
ge/graph/preprocess/graph_preprocess.cc View File

@@ -27,6 +27,7 @@
#include "common/helper/model_helper.h" #include "common/helper/model_helper.h"
#include "common/math/math_util.h" #include "common/math/math_util.h"
#include "common/op/ge_op_utils.h" #include "common/op/ge_op_utils.h"
#include "ir_build/option_utils.h"
#include "graph/common/ge_call_wrapper.h" #include "graph/common/ge_call_wrapper.h"
#include "graph/common/local_context.h" #include "graph/common/local_context.h"
#include "graph/common/transop_util.h" #include "graph/common/transop_util.h"
@@ -991,101 +992,6 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node) {
} }
return SUCCESS; return SUCCESS;
} }
long StringToLongNoThrow(const string &str) {
try {
return std::stol(str);
} catch (const std::invalid_argument) {
GELOGE(PARAM_INVALID,
"Parse shape range of input failed when transfer from string to int64. Given %s, while correct example:"
"\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"",
str.c_str());
return PARAM_INVALID;
} catch (const std::out_of_range) {
GELOGE(PARAM_INVALID,
"Parse shape range of input failed when transfer from string to int64. Given %s, while correct example:"
"\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"",
str.c_str());
return PARAM_INVALID;
}
}
/**
* Parser shape_range from string to vector
* shape_range from option normally is "[1~20,3,3~6,-1],[1~20,3,3~6,-1]"
* @param shape_range
*/
Status ParseDynamicInputShapeRange(const std::string &shape_range,
std::vector<std::vector<std::pair<int64_t, int64_t>>> &range) {
if (shape_range.size() < 2) {
REPORT_INNER_ERROR("E19999", "shape_range.size:%zu < 2, check invalid", shape_range.size());
GELOGE(PARAM_INVALID, "Shape range %s is invalid.", shape_range.c_str());
return PARAM_INVALID;
}
// different shape_range of single input are split by ']'
vector<string> shape_range_set = ge::StringUtils::Split(shape_range, ']');
if (shape_range_set.empty()) {
REPORT_INNER_ERROR("E19999", "Shape range %s is not valid. Correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"",
shape_range.c_str());
GELOGE(PARAM_INVALID, "Shape range %s is not valid. Correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"",
shape_range.c_str());
return PARAM_INVALID;
}
for (auto &shape_range_str : shape_range_set) {
if (shape_range_str.size() < 3) {
// shape_range_str should be "[2~3,1"
// or ",[2~3,1". because we should trim '[' or ',['
// so shape_range_str.size() < 3 is invalid
continue;
}
// trim start bytes, after that, single input should be "1~20,3,3~6,-1"
if (ge::StringUtils::StartWith(shape_range_str, "[")) {
shape_range_str = shape_range_str.substr(1, shape_range_str.size());
}
if (ge::StringUtils::StartWith(shape_range_str, ",")) {
shape_range_str = shape_range_str.substr(2, shape_range_str.size());
}

// parse shape_range of single input. eg. "1~20,3,3~6,-1"
std::vector<std::pair<int64_t, int64_t>> range_of_single_input;
vector<string> dim_range_set = ge::StringUtils::Split(shape_range_str, ',');
for (const auto &range_pair_str : dim_range_set) {
vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~');
pair<int64_t, int64_t> range_pair;
if (range_pair_set.size() == 1) {
// fix dim
auto range_value = StringToLongNoThrow(range_pair_set.at(0).c_str());
if (range_value < 0) {
range_pair = std::make_pair(1, range_value);
} else {
range_pair = std::make_pair(range_value, range_value);
}
} else if (range_pair_set.size() == 2) {
// unknown dim, should get range.
auto range_left = StringToLongNoThrow(range_pair_set.at(0).c_str());
auto range_right = StringToLongNoThrow(range_pair_set.at(1).c_str());
if (range_left < 0 || range_right < 0) {
REPORT_INNER_ERROR("E19999", "Shape range of input is invalid. Given range pair [%ld,%ld], "
"while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", range_left, range_right);
GELOGE(PARAM_INVALID,
"Shape range of input is invalid. Given range pair [%ld,%ld], while correct example: "
"\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"",
range_left, range_right);
return PARAM_INVALID;
}
range_pair = std::make_pair(range_left, range_right);
} else {
REPORT_INNER_ERROR("E19999", "Shape range of input is invalid. Given %s, "
"while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", shape_range.c_str());
GELOGE(PARAM_INVALID,
"Shape range of input is invalid. Given %s, while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"",
shape_range.c_str());
return PARAM_INVALID;
}
range_of_single_input.emplace_back(range_pair);
}
range.emplace_back(range_of_single_input);
}
return SUCCESS;
}


Status GetDynamicInputShapeRange(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option, Status GetDynamicInputShapeRange(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option,
vector<vector<std::pair<int64_t, int64_t>>> &range_vec) { vector<vector<std::pair<int64_t, int64_t>>> &range_vec) {
@@ -1114,9 +1020,10 @@ Status GetDynamicInputShapeRange(const std::vector<GeTensor> &user_input, const
OPTION_EXEC_DYNAMIC_EXECUTE_MODE, OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE); OPTION_EXEC_DYNAMIC_EXECUTE_MODE, OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE);
return PARAM_INVALID; return PARAM_INVALID;
} }

auto ret = ParseDynamicInputShapeRange(iter->second, range_vec);
GE_CHK_STATUS_RET(ret, "Parse dynamic input shape range failed.");
if (ParseInputShapeRange(iter->second, range_vec) != SUCCESS) {
GELOGE(PARAM_INVALID, "[Parse][ShapeRange] Parse dynamic input shape range failed.");
return PARAM_INVALID;
}
if (range_vec.size() != user_input.size()) { if (range_vec.size() != user_input.size()) {
GELOGE(PARAM_INVALID, "Dynamic input shape range size is %zu, inputs size is %zu. Not match.", range_vec.size(), GELOGE(PARAM_INVALID, "Dynamic input shape range size is %zu, inputs size is %zu. Not match.", range_vec.size(),
user_input.size()); user_input.size());


+ 6
- 1
ge/hybrid/executor/hybrid_model_executor.cc View File

@@ -112,7 +112,12 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
HYBRID_CHK_STATUS_RET(context_.DumpExceptionInfo(exception_infos), HYBRID_CHK_STATUS_RET(context_.DumpExceptionInfo(exception_infos),
"[Execute][GraphInternal] Dump exception info failed."); "[Execute][GraphInternal] Dump exception info failed.");
} }
GELOGE(ret, "[Execute][GraphInternal] Synchronize failed.");
if (ret == ge::END_OF_SEQUENCE) {
GELOGD("Got end of sequence");
} else {
GELOGE(ret, "[Execute][GraphInternal] Synchronize failed.");
}
return ret;
} }
RECORD_MODEL_EXECUTION_EVENT(&context_, "[Synchronize] End"); RECORD_MODEL_EXECUTION_EVENT(&context_, "[Synchronize] End");
} }


+ 19
- 6
ge/ir_build/ge_ir_build.cc View File

@@ -32,7 +32,7 @@
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"
#include "graph/ge_global_options.h" #include "graph/ge_global_options.h"
#include "init/gelib.h" #include "init/gelib.h"
#include "ir_build/atc_ir_common.h"
#include "ir_build/option_utils.h"
#include "model/ge_model.h" #include "model/ge_model.h"
#include "graph/shape_refiner.h" #include "graph/shape_refiner.h"
#include "graph/opsproto_manager.h" #include "graph/opsproto_manager.h"
@@ -299,10 +299,19 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) {
GE_CHK_BOOL_EXEC(ParseInputShape(input_shape, shape_map, user_shape_map, true), GE_CHK_BOOL_EXEC(ParseInputShape(input_shape, shape_map, user_shape_map, true),
return GRAPH_PARAM_INVALID, "[Parse][InputShape] failed!"); return GRAPH_PARAM_INVALID, "[Parse][InputShape] failed!");
} }
std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map;
std::map<string, std::vector<std::pair<int64_t, int64_t>>> name_shape_range_map;
std::vector<std::vector<std::pair<int64_t, int64_t>>> index_shape_range_map;
if (!input_shape_range.empty()) { if (!input_shape_range.empty()) {
GE_CHK_BOOL_EXEC(ParseInputShapeRange(input_shape_range, shape_range_map),
return GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] failed.");
Status ret = GRAPH_PARAM_INVALID;
if (input_shape_range.find(":") != string::npos) {
ret = ParseInputShapeRange(input_shape_range, name_shape_range_map);
} else {
ret = ParseInputShapeRange(input_shape_range, index_shape_range_map);
}
if (ret != SUCCESS) {
GELOGE(GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] parse shape range[%s] failed.", input_shape_range.c_str());
return GRAPH_PARAM_INVALID;
}
} }
auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); auto compute_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph); GE_CHECK_NOTNULL(compute_graph);
@@ -315,10 +324,14 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) {
GELOGE(GRAPH_FAILED, "[Update][DataOpShape] fail for op:%s.", op->GetName().c_str()); GELOGE(GRAPH_FAILED, "[Update][DataOpShape] fail for op:%s.", op->GetName().c_str());
return GRAPH_FAILED; return GRAPH_FAILED;
} }
if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) {
if (UpdateDataOpShapeRange(op, name_shape_range_map) != SUCCESS) {
GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str()); GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str());
return GRAPH_FAILED; return GRAPH_FAILED;
}
}
if (UpdateDataOpShapeRange(op, index_shape_range_map) != SUCCESS) {
GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str());
return GRAPH_FAILED;
}
} }
} }




ge/ir_build/atc_ir_common.cc → ge/ir_build/option_utils.cc View File

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "atc_ir_common.h"
#include "option_utils.h"
#include "common/util/error_manager/error_manager.h" #include "common/util/error_manager/error_manager.h"
#include "external/ge/ge_api_types.h" #include "external/ge/ge_api_types.h"
#include "framework/common/string_util.h" #include "framework/common/string_util.h"
@@ -22,6 +22,7 @@
#include "graph/compute_graph.h" #include "graph/compute_graph.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"
#include "graph/debug/ge_attr_define.h"


using std::pair; using std::pair;
using std::string; using std::string;
@@ -58,10 +59,12 @@ const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \
const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\""; const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\"";
const char *const kKeepDtypeError = "file not found"; const char *const kKeepDtypeError = "file not found";
const char *const kInputShapeRangeInvalid = "format of shape range is invalid"; const char *const kInputShapeRangeInvalid = "format of shape range is invalid";
const char *const kInputShapeRangeSizeInvalid = " shape range size less than 2 is invalid";
const char *const kShapeRangeValueConvertError = "transfer from string to int64 error"; const char *const kShapeRangeValueConvertError = "transfer from string to int64 error";
const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\""; const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\"";
const char *const kInputShapeRangeSample2 = "\"[1~20]\""; const char *const kInputShapeRangeSample2 = "\"[1~20]\"";
const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\""; const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\"";
const char *const kInputShapeRangeSample4 = "\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"";


vector<string> SplitInputShape(const std::string &input_shape) { vector<string> SplitInputShape(const std::string &input_shape) {
vector<string> shape_pair_vec; vector<string> shape_pair_vec;
@@ -72,6 +75,67 @@ vector<string> SplitInputShape(const std::string &input_shape) {
} }
return shape_pair_vec; return shape_pair_vec;
} }

static bool StringToLongNoThrow(const string &str, long &val) {
try {
val = std::stol(str);
return true;
} catch (const std::invalid_argument) {
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}),
std::vector<string>({str, kShapeRangeValueConvertError, kInputShapeRangeSample3}));
GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s to long failed, reason: %s, correct sample is %s.",
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3);
} catch (const std::out_of_range) {
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}),
std::vector<string>({str, kShapeRangeValueConvertError, kInputShapeRangeSample3}));
GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s to long failed, reason: %s, correct sample is %s.",
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3);
}
return false;
}

static bool ParseShapeRangePair(const string &shape_range,
const vector<string> &range_pair_set,
std::pair<int64_t, int64_t> &range_pair) {
if (range_pair_set.size() == 1) {
long range_value = 0;
if (!StringToLongNoThrow(range_pair_set.at(0), range_value)) {
return false;
}
if (range_value < 0) {
range_pair = std::make_pair(1, range_value);
} else {
range_pair = std::make_pair(range_value, range_value);
}
} else if (range_pair_set.size() == kRangePairSize) {
// unknown dim, should get range.
long range_left = 0;
if (!StringToLongNoThrow(range_pair_set.at(0), range_left)) {
return false;
}
long range_right = 0;
if (!StringToLongNoThrow(range_pair_set.at(1), range_right)) {
return false;
}
if ((range_left < 0) || (range_right < 0)) {
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}),
std::vector<string>({shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}));
GELOGE(PARAM_INVALID,
"[Parse][InputParameter] [--input_shape_range]'s shape range[%s] failed,"
"reason: %s, correct sample is %s.",
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3);
return false;
}
range_pair = std::make_pair(range_left, range_right);
} else {
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}),
std::vector<string>({shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}));
GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid, reason: %s, correct sample is %s.",
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3);
return false;
}
return true;
}
} // namespace } // namespace


Status CheckInputFormat(const string &input_format) { Status CheckInputFormat(const string &input_format) {
@@ -287,24 +351,6 @@ bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims
return true; return true;
} }


bool StringToLongNoThrow(const string &str, long &val) {
try {
val = std::stol(str);
return true;
} catch (const std::invalid_argument) {
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"},
{str, kShapeRangeValueConvertError, kInputShapeRangeSample3});
GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s invalid, reason: %s, correct sample is %s.",
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3);
} catch (const std::out_of_range) {
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"},
{str, kShapeRangeValueConvertError, kInputShapeRangeSample3});
GELOGE(PARAM_INVALID, "[Parse][Parameter] str:%s invalid, reason: %s, correct sample is %s.",
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3);
}
return false;
}

bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_t>> &shape_range_vec) { bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_t>> &shape_range_vec) {
vector<char> square_brackets; vector<char> square_brackets;
for (auto ch : shape_range) { for (auto ch : shape_range) {
@@ -331,41 +377,8 @@ bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_
for (const auto &range_pair_str : dim_range_set) { for (const auto &range_pair_str : dim_range_set) {
vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~');
pair<int64_t, int64_t> range_pair; pair<int64_t, int64_t> range_pair;
if (range_pair_set.size() == 1) {
long range_value = 0;
if (!StringToLongNoThrow(range_pair_set.at(0), range_value)) {
return false;
}
if (range_value < 0) {
range_pair = std::make_pair(1, range_value);
} else {
range_pair = std::make_pair(range_value, range_value);
}
} else if (range_pair_set.size() == kRangePairSize) {
// unknown dim, should get range.
long range_left = 0;
if (!StringToLongNoThrow(range_pair_set.at(0), range_left)) {
return false;
}
long range_right = 0;
if (!StringToLongNoThrow(range_pair_set.at(1), range_right)) {
return false;
}
if (range_left < 0 || (range_right < 0)) {
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"},
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3});
GELOGE(PARAM_INVALID,
"[Parse][InputParameter] [--input_shape_range]'s shape range[%s] failed,"
"reason: %s, correct sample is %s.",
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3);
return false;
}
range_pair = std::make_pair(range_left, range_right);
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"},
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3});
GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid, reason: %s, correct sample is %s.",
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3);
if (!ParseShapeRangePair(shape_range, range_pair_set, range_pair)) {
GELOGE(PARAM_INVALID, "[Parse][RangePair] parse range pair failed.");
return false; return false;
} }
shape_range_vec.emplace_back(range_pair); shape_range_vec.emplace_back(range_pair);
@@ -373,8 +386,13 @@ bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_
return true; return true;
} }


bool ParseInputShapeRange(const std::string &shape_range,
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map) {
/**
* Parser shape_range from string to map
* shape_range from option normally is "input1:[1~20,3,3~6,-1];input2:[1~20,3,3~6,-1]"
* @param shape_range
*/
Status ParseInputShapeRange(const std::string &shape_range,
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map) {
GELOGD("Input shape range %s", shape_range.c_str()); GELOGD("Input shape range %s", shape_range.c_str());


vector<string> shape_range_vec = StringUtils::Split(shape_range, ';'); vector<string> shape_range_vec = StringUtils::Split(shape_range, ';');
@@ -386,25 +404,82 @@ bool ParseInputShapeRange(const std::string &shape_range,
{shape_range, kSplitError1, kInputShapeRangeSample1}); {shape_range, kSplitError1, kInputShapeRangeSample1});
GELOGE(PARAM_INVALID, "[Parse][Parameter]--input shape_range:%s invalid, reason: %s, correct sample is %s.", GELOGE(PARAM_INVALID, "[Parse][Parameter]--input shape_range:%s invalid, reason: %s, correct sample is %s.",
shape_range.c_str(), kSplitError1, kInputShapeRangeSample1); shape_range.c_str(), kSplitError1, kInputShapeRangeSample1);
return false;
return PARAM_INVALID;
} }
if (shape_range_pair_vec[1].empty()) { if (shape_range_pair_vec[1].empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape", "reason", "sample"}, ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape", "reason", "sample"},
{shape_range, kEmptyError, kInputShapeRangeSample1}); {shape_range, kEmptyError, kInputShapeRangeSample1});
GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid,reason: %s, correct sample is %s.", GELOGE(PARAM_INVALID, "[Parse][Parameter]shape_range:%s invalid,reason: %s, correct sample is %s.",
shape_range.c_str(), kEmptyError, kInputShapeRangeSample1); shape_range.c_str(), kEmptyError, kInputShapeRangeSample1);
return false;
return PARAM_INVALID;
} }


string shape_range_str = shape_range_pair_vec[1]; string shape_range_str = shape_range_pair_vec[1];
vector<pair<int64_t, int64_t>> shape_range_val; vector<pair<int64_t, int64_t>> shape_range_val;
if (!ParseSingleShapeRange(shape_range_str, shape_range_val)) { if (!ParseSingleShapeRange(shape_range_str, shape_range_val)) {
GELOGE(PARAM_INVALID, "[Parse][Parameter] shape_range_str: %s invalid.", shape_range_str.c_str()); GELOGE(PARAM_INVALID, "[Parse][Parameter] shape_range_str: %s invalid.", shape_range_str.c_str());
return false;
return PARAM_INVALID;
} }
shape_range_map.emplace(make_pair(StringUtils::Trim(shape_range_pair_vec[0]), shape_range_val)); shape_range_map.emplace(make_pair(StringUtils::Trim(shape_range_pair_vec[0]), shape_range_val));
} }
return true;
return SUCCESS;
}

/**
* Parser shape_range from string to vector
* shape_range from option normally is "[1~20,3,3~6,-1],[1~20,3,3~6,-1]"
* @param shape_range
*/
Status ParseInputShapeRange(const std::string &shape_range,
std::vector<std::vector<std::pair<int64_t, int64_t>>> &range) {
GELOGD("Input shape range %s", shape_range.c_str());

if (shape_range.size() < 2) {
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}),
std::vector<std::string>({shape_range, kInputShapeRangeSizeInvalid, kInputShapeRangeSample4}));
GELOGE(PARAM_INVALID, "[Parse][ShapeRange] str:%s invalid, reason: %s, correct sample is %s.",
shape_range.c_str(), kInputShapeRangeSizeInvalid, kInputShapeRangeSample4);
return PARAM_INVALID;
}
// different shape_range of single input are split by ']'
vector<string> shape_range_set = ge::StringUtils::Split(shape_range, ']');
if (shape_range_set.empty()) {
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}),
std::vector<string>({shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample4}));
GELOGE(PARAM_INVALID, "[Parse][ShapeRange] str:%s invalid, reason: %s, correct sample is %s.",
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample4);
return PARAM_INVALID;
}
for (auto &shape_range_str : shape_range_set) {
if (shape_range_str.size() < 3) {
// shape_range_str should be "[2~3,1"
// or ",[2~3,1". because we should trim '[' or ',['
// so shape_range_str.size() < 3 is invalid
continue;
}
// trim start bytes, after that, single input should be "1~20,3,3~6,-1"
if (ge::StringUtils::StartWith(shape_range_str, "[")) {
shape_range_str = shape_range_str.substr(1, shape_range_str.size());
}
if (ge::StringUtils::StartWith(shape_range_str, ",")) {
shape_range_str = shape_range_str.substr(2, shape_range_str.size());
}

// parse shape_range of single input. eg. "1~20,3,3~6,-1"
std::vector<std::pair<int64_t, int64_t>> range_of_single_input;
vector<string> dim_range_set = ge::StringUtils::Split(shape_range_str, ',');
for (const auto &range_pair_str : dim_range_set) {
vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~');
pair<int64_t, int64_t> range_pair;
if (!ParseShapeRangePair(shape_range_str, range_pair_set, range_pair)) {
GELOGE(PARAM_INVALID, "[Parse][RangePair] Parse range pair failed.");
return PARAM_INVALID;
}
range_of_single_input.emplace_back(range_pair);
}
range.emplace_back(range_of_single_input);
}
return SUCCESS;
} }


Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims, Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims,
@@ -420,11 +495,13 @@ Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_i
} }


if (param_size == 0) { if (param_size == 0) {
if (!input_shape_range.empty()) {
std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map;
if (!ParseInputShapeRange(input_shape_range, shape_range_map)) {
GELOGE(ge::PARAM_INVALID, "[Parse][InputShapeRange] failed, range: %s", input_shape_range.c_str());
return ge::PARAM_INVALID;
if (input_shape_range.find(":") != string::npos) {
if (!input_shape_range.empty()) {
std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map;
if (ParseInputShapeRange(input_shape_range, shape_range_map) != SUCCESS) {
GELOGE(ge::PARAM_INVALID, "[Parse][InputShapeRange] failed, range: %s", input_shape_range.c_str());
return ge::PARAM_INVALID;
}
} }
} }
return ge::SUCCESS; return ge::SUCCESS;
@@ -446,9 +523,11 @@ Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_i
} }


if (!dynamic_batch_size.empty()) { if (!dynamic_batch_size.empty()) {
if (!CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size)) {
GELOGE(ge::PARAM_INVALID, "[Check][DynamicBatchSizeInputShape] input_shape: %s invalid.", input_shape.c_str());
return ge::PARAM_INVALID;
if (input_shape_range.find(":") != string::npos) {
if (!CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size)) {
GELOGE(ge::PARAM_INVALID, "[Check][DynamicBatchSizeInputShape] input_shape: %s invalid.", input_shape.c_str());
return ge::PARAM_INVALID;
}
} }
} }


@@ -744,10 +823,10 @@ Status UpdateDataOpShape(const OpDescPtr &op, map<string, vector<int64_t>> &shap
} }


Status UpdateDataOpShapeRange(const OpDescPtr &op, Status UpdateDataOpShapeRange(const OpDescPtr &op,
map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) {
const map<string, vector<pair<int64_t, int64_t>>> &name_shape_range_map) {
GE_CHECK_NOTNULL(op); GE_CHECK_NOTNULL(op);
if (shape_range_map.empty()) {
GELOGI("Shape range map of data op [%s] is empty.", op->GetName().c_str());
if (name_shape_range_map.empty()) {
GELOGI("Shape range name map of data op [%s] is empty.", op->GetName().c_str());
return SUCCESS; return SUCCESS;
} }


@@ -757,8 +836,8 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
GE_CHECK_NOTNULL(tensor_output); GE_CHECK_NOTNULL(tensor_output);
string data_op_name = op->GetName(); string data_op_name = op->GetName();
auto origin_shape = tensor_input->GetShape(); auto origin_shape = tensor_input->GetShape();
auto iter = shape_range_map.find(data_op_name);
if (iter != shape_range_map.end()) {
auto iter = name_shape_range_map.find(data_op_name);
if (iter != name_shape_range_map.end()) {
auto cur_shape_range = iter->second; auto cur_shape_range = iter->second;
if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) { if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) {
GELOGE(PARAM_INVALID, "[Check][OpDescPtr] Check shape by shape range failed for op:%s.", data_op_name.c_str()); GELOGE(PARAM_INVALID, "[Check][OpDescPtr] Check shape by shape range failed for op:%s.", data_op_name.c_str());
@@ -783,6 +862,56 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
return SUCCESS; return SUCCESS;
} }


Status UpdateDataOpShapeRange(const OpDescPtr &op,
const vector<vector<pair<int64_t, int64_t>>> &index_shape_range_map) {
GE_CHECK_NOTNULL(op);
if (index_shape_range_map.empty()) {
GELOGI("Shape range index map of data op [%s] is empty.", op->GetName().c_str());
return SUCCESS;
}

GeAttrValue::INT index = 0;
if (!AttrUtils::GetInt(op, ATTR_NAME_INDEX, index)) {
GELOGW("[%s] Get index from data attr failed.", op->GetName().c_str());
return SUCCESS;
}

if ((index < 0) || (static_cast<size_t>(index) >= index_shape_range_map.size())) {
std::string situation = "data op index[" + std::to_string(index) + "]";
std::string reason = "it must less than user_input size[" + std::to_string(index_shape_range_map.size()) + "]";
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"situation", "reason"}),
std::vector<std::string>({situation, reason}));
GELOGE(PARAM_INVALID, "user_input size = %zu, graph data op index = %ld.", index_shape_range_map.size(), index);
return FAILED;
}

auto tensor_input = op->MutableInputDesc(0);
auto tensor_output = op->MutableOutputDesc(0);
GE_CHECK_NOTNULL(tensor_input);
GE_CHECK_NOTNULL(tensor_output);
string data_op_name = op->GetName();
auto origin_shape = tensor_input->GetShape();
auto cur_shape_range = index_shape_range_map[index];
if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) {
GELOGE(PARAM_INVALID, "[Check][OpDescPtr] Check shape by shape range failed for op:%s.", data_op_name.c_str());
return PARAM_INVALID;
}
for (size_t idx = 0; idx < cur_shape_range.size(); ++idx) {
auto left_range = cur_shape_range[idx].first;
auto right_range = cur_shape_range[idx].second;
if (left_range != right_range) {
origin_shape.SetDim(idx, UNKNOWN_DIM);
}
}
tensor_input->SetShape(origin_shape);
tensor_input->SetShapeRange(cur_shape_range);
tensor_output->SetShape(origin_shape);
tensor_output->SetShapeRange(cur_shape_range);
GELOGI("Update input [%s] shape range info success.", data_op_name.c_str());

return SUCCESS;
}

static Status CheckInputShapeRangeNode(const ComputeGraphPtr &compute_graph, static Status CheckInputShapeRangeNode(const ComputeGraphPtr &compute_graph,
const map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) { const map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) {
for (const auto &it : shape_range_map) { for (const auto &it : shape_range_map) {
@@ -813,7 +942,7 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co
GE_CHECK_NOTNULL(compute_graph); GE_CHECK_NOTNULL(compute_graph);


map<string, vector<pair<int64_t, int64_t>>> shape_range_map; map<string, vector<pair<int64_t, int64_t>>> shape_range_map;
if (!ParseInputShapeRange(input_shape_range, shape_range_map)) {
if (ParseInputShapeRange(input_shape_range, shape_range_map) != SUCCESS) {
GELOGE(PARAM_INVALID, "[Parse][InputShapeRange] input_shape_range:%s invalid.", input_shape_range.c_str()); GELOGE(PARAM_INVALID, "[Parse][InputShapeRange] input_shape_range:%s invalid.", input_shape_range.c_str());
return PARAM_INVALID; return PARAM_INVALID;
} }

ge/ir_build/atc_ir_common.h → ge/ir_build/option_utils.h View File

@@ -64,8 +64,10 @@ Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string


bool ParseInputShape(const std::string &input_shape, std::map<string, std::vector<int64_t>> &shape_map, bool ParseInputShape(const std::string &input_shape, std::map<string, std::vector<int64_t>> &shape_map,
std::vector<std::pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input = false); std::vector<std::pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input = false);
bool ParseInputShapeRange(const std::string &shape_range,
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map);
Status ParseInputShapeRange(const std::string &shape_range,
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map);
Status ParseInputShapeRange(const std::string &shape_range,
std::vector<std::vector<std::pair<int64_t, int64_t>>> &range);


Status CheckOutputTypeParamValid(const std::string output_type); Status CheckOutputTypeParamValid(const std::string output_type);
Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); Status CheckBufferOptimizeParamValid(const std::string buffer_optimize);
@@ -80,8 +82,10 @@ Status CheckKeepTypeParamValid(const std::string &keep_dtype);
void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips);
void EraseEndSemicolon(std::string &param); void EraseEndSemicolon(std::string &param);
Status UpdateDataOpShape(const OpDescPtr &op, std::map<std::string, std::vector<int64_t>> &shape_map); Status UpdateDataOpShape(const OpDescPtr &op, std::map<std::string, std::vector<int64_t>> &shape_map);
Status UpdateDataOpShapeRange(
const OpDescPtr &op, const std::map<std::string, std::vector<std::pair<int64_t, int64_t>>> &name_shape_range_map);
Status UpdateDataOpShapeRange(const OpDescPtr &op, Status UpdateDataOpShapeRange(const OpDescPtr &op,
std::map<std::string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map);
const std::vector<std::vector<std::pair<int64_t, int64_t>>> &index_shape_range_map);
Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range); Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range);
} }
#endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ #endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_

+ 1
- 1
ge/offline/CMakeLists.txt View File

@@ -11,7 +11,7 @@ set(SRC_LIST
"main.cc" "main.cc"
"single_op_parser.cc" "single_op_parser.cc"
"../session/omg.cc" "../session/omg.cc"
"../ir_build/atc_ir_common.cc"
"../ir_build/option_utils.cc"
) )


############ atc_atc.bin ############ ############ atc_atc.bin ############


+ 1
- 1
ge/offline/main.cc View File

@@ -36,7 +36,7 @@
#include "graph/utils/graph_utils.h" #include "graph/utils/graph_utils.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"
#include "init/gelib.h" #include "init/gelib.h"
#include "ir_build/atc_ir_common.h"
#include "ir_build/option_utils.h"
#include "omg/omg.h" #include "omg/omg.h"
#include "omg/parser/parser_factory.h" #include "omg/parser/parser_factory.h"
#include "omg/parser/parser_inner_ctx.h" #include "omg/parser/parser_inner_ctx.h"


+ 3
- 3
ge/offline/module.mk View File

@@ -12,7 +12,7 @@ LOCAL_SRC_FILES := \
main.cc \ main.cc \
single_op_parser.cc \ single_op_parser.cc \
../session/omg.cc \ ../session/omg.cc \
../ir_build/atc_ir_common.cc \
../ir_build/option_utils.cc \


LOCAL_C_INCLUDES := \ LOCAL_C_INCLUDES := \
$(LOCAL_PATH)/../ ./ \ $(LOCAL_PATH)/../ ./ \
@@ -65,7 +65,7 @@ LOCAL_SRC_FILES := \
main.cc \ main.cc \
single_op_parser.cc \ single_op_parser.cc \
../session/omg.cc \ ../session/omg.cc \
../ir_build/atc_ir_common.cc \
../ir_build/option_utils.cc \


LOCAL_C_INCLUDES := \ LOCAL_C_INCLUDES := \
$(LOCAL_PATH)/../ ./ \ $(LOCAL_PATH)/../ ./ \
@@ -118,7 +118,7 @@ LOCAL_SRC_FILES := \
main.cc \ main.cc \
single_op_parser.cc \ single_op_parser.cc \
../session/omg.cc \ ../session/omg.cc \
../ir_build/atc_ir_common.cc \
../ir_build/option_utils.cc \


LOCAL_C_INCLUDES := \ LOCAL_C_INCLUDES := \
$(LOCAL_PATH)/../ ./ \ $(LOCAL_PATH)/../ ./ \


+ 1
- 1
ge/session/omg.cc View File

@@ -38,7 +38,7 @@
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "graph/optimize/common/params.h" #include "graph/optimize/common/params.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"
#include "ir_build/atc_ir_common.h"
#include "ir_build/option_utils.h"
#include "omg/omg_inner_types.h" #include "omg/omg_inner_types.h"
#include "omg/parser/model_parser.h" #include "omg/parser/model_parser.h"
#include "omg/parser/parser_factory.h" #include "omg/parser/parser_factory.h"


+ 1
- 1
tests/ut/ge/CMakeLists.txt View File

@@ -308,7 +308,7 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc" "${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc"
"${GE_CODE_DIR}/ge/graph/partition/dynamic_shape_partition.cc" "${GE_CODE_DIR}/ge/graph/partition/dynamic_shape_partition.cc"
"${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc"
"${GE_CODE_DIR}/ge/ir_build/atc_ir_common.cc"
"${GE_CODE_DIR}/ge/ir_build/option_utils.cc"
"${GE_CODE_DIR}/ge/graph/preprocess/insert_op/ge_aipp_op.cc" "${GE_CODE_DIR}/ge/graph/preprocess/insert_op/ge_aipp_op.cc"
"${GE_CODE_DIR}/ge/graph/preprocess/multi_batch_options.cc" "${GE_CODE_DIR}/ge/graph/preprocess/multi_batch_options.cc"
"${GE_CODE_DIR}/ge/graph/build/model_builder.cc" "${GE_CODE_DIR}/ge/graph/build/model_builder.cc"


+ 1
- 1
tests/ut/ge/graph/manager/graph_manager_unittest.cc View File

@@ -109,7 +109,7 @@
#include "graph/build/label_allocator.h" #include "graph/build/label_allocator.h"
#include "graph/utils/tensor_adapter.h" #include "graph/utils/tensor_adapter.h"
#include "inc/pass_manager.h" #include "inc/pass_manager.h"
#include "ir_build/atc_ir_common.h"
#include "ir_build/option_utils.h"
#include "graph/common/local_context.h" #include "graph/common/local_context.h"
#include "graph/common/omg_util.h" #include "graph/common/omg_util.h"
#include "common/formats/utils/formats_trans_utils.h" #include "common/formats/utils/formats_trans_utils.h"


+ 97
- 2
tests/ut/ge/graph_ir/ge_ir_build_unittest.cc View File

@@ -15,8 +15,9 @@
*/ */


#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "ir_build/atc_ir_common.h"
#include "ir_build/option_utils.h"
#include "graph/testcase/ge_graph/graph_builder_utils.h" #include "graph/testcase/ge_graph/graph_builder_utils.h"
#include "graph/debug/ge_attr_define.h"


#define protected public #define protected public
#define private public #define private public
@@ -68,6 +69,20 @@ TEST(UtestIrCommon, update_data_op_shape) {
EXPECT_EQ(ret, ge::SUCCESS); EXPECT_EQ(ret, ge::SUCCESS);
} }


TEST(UtestIrCommon, update_data_op_shape_range) {
ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data");
std::vector<std::vector<std::pair<int64_t, int64_t>>> index_shape_range_map;

std::pair<int64_t, int64_t> range_pair(1, 2);
vector<pair<int64_t, int64_t>> range_pair_tmp = { range_pair };

index_shape_range_map.push_back(range_pair_tmp);

AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, 0);
Status ret = UpdateDataOpShapeRange(op_desc, index_shape_range_map);
EXPECT_EQ(ret, ge::SUCCESS);
}

TEST(UtestIrCommon, update_dynamic_shape_range_success) { TEST(UtestIrCommon, update_dynamic_shape_range_success) {
ComputeGraphPtr graph = BuildComputeGraph(); ComputeGraphPtr graph = BuildComputeGraph();
std::string input_shape_range = "input1:[1, 2~3, -1];input2:[3~5, 10]"; std::string input_shape_range = "input1:[1, 2~3, -1];input2:[3~5, 10]";
@@ -117,4 +132,84 @@ TEST(UtestIrCommon, check_dynamic_image_size_fail) {


bool ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size); bool ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size);
EXPECT_EQ(ret, false); EXPECT_EQ(ret, false);
}
}

TEST(UtestIrCommon, check_input_format_failed) {
std::string format = "invalid";
Status ret = CheckInputFormat(format);
EXPECT_EQ(ret, ge::PARAM_INVALID);
}

TEST(UtestIrCommon, check_dynamic_batch_size_input_shape_succ) {
map<string, vector<int64_t>> shape_map;
shape_map.insert(std::pair<string, vector<int64_t>>("data", {-1, 2, 3}));
std::string dynamic_batch_size = "11";

bool ret = CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size);
EXPECT_EQ(ret, true);
}

TEST(UtestIrCommon, check_dynamic_images_size_input_shape_succ) {
map<string, vector<int64_t>> shape_map;
shape_map.insert(std::pair<string, vector<int64_t>>("data", {4, -1, -1, 5}));
std::string input_format = "NCHW";
std::string dynamic_image_size = "4,5";

Status ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size);
EXPECT_EQ(ret, ge::SUCCESS);
}

TEST(UtestIrCommon, check_dynamic_input_param_succ) {
string dynamic_batch_size = "1";
string dynamic_image_size;
string dynamic_dims;
string input_shape = "data:1,3,244,244";
string input_shape_range;
string input_format = "NCHW";
bool is_dynamic_input = false;

Status ret = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims,
input_shape, input_shape_range, input_format,is_dynamic_input);
EXPECT_EQ(ret, ge::SUCCESS);
}

TEST(UtestIrCommon, check_compress_weight) {
std::string enable_compress_weight = "true";
std::string compress_weight_conf="./";
Status ret = CheckCompressWeightParamValid(enable_compress_weight, compress_weight_conf);
EXPECT_EQ(ret, PARAM_INVALID);

enable_compress_weight = "yes";
compress_weight_conf = "./";
ret = CheckCompressWeightParamValid(enable_compress_weight, compress_weight_conf);
EXPECT_EQ(ret, PARAM_INVALID);
}

TEST(UtestIrCommon, check_param_failed) {
std::string param_invalid = "invalid";

Status ret = CheckOutputTypeParamValid(param_invalid);
EXPECT_EQ(ret, PARAM_INVALID);

ret = CheckBufferOptimizeParamValid(param_invalid);
EXPECT_EQ(ret, PARAM_INVALID);

ret = CheckKeepTypeParamValid(param_invalid);
EXPECT_EQ(ret, PARAM_INVALID);

ret = CheckInsertOpConfParamValid(param_invalid);
EXPECT_EQ(ret, PARAM_INVALID);

ret = CheckDisableReuseMemoryParamValid(param_invalid);
EXPECT_EQ(ret, PARAM_INVALID);

ret = CheckEnableSingleStreamParamValid(param_invalid);
EXPECT_EQ(ret, PARAM_INVALID);

std::string optypelist_for_implmode;
std::string op_select_implmode = "1";
ret = CheckImplmodeParamValid(optypelist_for_implmode, op_select_implmode);
EXPECT_EQ(ret, PARAM_INVALID);

ret = CheckLogParamValidAndSetLogLevel(param_invalid);
}

Loading…
Cancel
Save