From 3e6b21f6c17b54a40cd0e59c7b321f71775a402f Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Thu, 31 Dec 2020 14:47:44 +0800 Subject: [PATCH] modified: ge/graph/preprocess/graph_preprocess.cc --- ge/graph/preprocess/graph_preprocess.cc | 114 ++++++++++++++++-------- 1 file changed, 75 insertions(+), 39 deletions(-) diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index c45f4db6..57c2542a 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -901,51 +901,74 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node) { } /** * Parser shape_range from string to vector - * shape_range from option normally is "[1~20],[3],[3~6],[-1]" + * shape_range from option normally is "[1~20,3,3~6,-1],[1~20,3,3~6,-1]" * @param shape_range */ -void ParseDynamicInputShapeRange(const std::string &shape_range, - std::vector>> &range) { - if (shape_range.empty() || shape_range.size() < 2) { +Status ParseDynamicInputShapeRange(const std::string &shape_range, + std::vector>> &range) { + if (shape_range.size() < 2) { GELOGW("Shape range %s is invalid.", shape_range.c_str()); return; } - // different parameter sets are split by ';' - vector shape_set = ge::StringUtils::Split(shape_range, ']'); - if (shape_set.empty()) { - return; + // different shape_ragne of single input are split by ']' + vector shape_range_set = ge::StringUtils::Split(shape_range, ']'); + if (shape_range_set.empty()) { + GELOGE("Shape range %s is not valid. Correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", shape_range.c_str()); + return PARAM_INVALID; } - for (auto shape_str : shape_set) { - if (shape_str.empty()) { - continue; - } - if (ge::StringUtils::StartWith(shape_str, "[")) { - shape_str = shape_str.substr(1, shape_str.size()); + for (const auto &shape_range_str : shape_range_set) { + if (shape_range_str.empty()) { + GELOGE("Shape range of input is empty. Given %s, while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", + shape_range.c_str()); + return PARAM_INVALID; } - if (ge::StringUtils::StartWith(shape_str, ",")) { - shape_str = shape_str.substr(2, shape_str.size()); + // trim start bytes, after that, single input should be "1~20,3,3~6,-1" + if (ge::StringUtils::StartWith(shape_range_str, "[")) { + shape_range_str = shape_range_str.substr(1, shape_range_str.size()); + } else if (ge::StringUtils::StartWith(shape_range_str, ",")) { + shape_range_str = shape_range_str.substr(2, shape_range_str.size()); + } else { + GELOGE("Shape range of input is invalid. Given %s, while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", + shape_range.c_str()); + return PARAM_INVALID; } - std::vector> range_of_single; - vector range_set = ge::StringUtils::Split(shape_str, ','); - for (auto range_str : range_set) { - vector pair_set = ge::StringUtils::Split(range_str, '~'); + // parse shape_range of single input. eg. "1~20,3,3~6,-1" + std::vector> range_of_single_input; + vector dim_range_set = ge::StringUtils::Split(shape_range_str, ','); + for (const auto &range_pair_str : dim_range_set) { + vector range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); pair range_pair; - if (pair_set.size() == 1) { - auto range_value = atoi(pair_set.at(0).c_str()); + if (range_pair_set.size() == 1) { + // fix dim + auto range_value = stol(range_pair_set.at(0).c_str()); if (range_value < 0) { range_pair = std::make_pair(1, range_value); } else { range_pair = std::make_pair(range_value, range_value); } - } else if (pair_set.size() == 2) { - auto range_left = atoi(pair_set.at(0).c_str()); - auto range_right = atoi(pair_set.at(1).c_str()); - range_pair = std::make_pair(range_left, range_right); + } else if (range_pair_set.size() == 2) { + // unknown dim, should get range. + try { + auto range_left = stol(range_pair_set.at(0).c_str()); + auto range_right = stol(range_pair_set.at(1).c_str()); + range_pair = std::make_pair(range_left, range_right); + } catch (const std::invalid_argument) { + GELOGE( + "Parse shape range of input failed when transfer from string to int64. Given %s, while correct example: " + "\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", + shape_range.c_str()); + return PARAM_INVALID; + } + } else { + GELOGE("Shape range of input is invalid. Given %s, while correct example: \"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", + shape_range.c_str()); + return PARAM_INVALID; } - range_of_single.emplace_back(range_pair); + range_of_single_input.emplace_back(range_pair); } - range.emplace_back(range_of_single); + range.emplace_back(range_of_single_input); } + return SUCCESS; } Status GetDynamicInputShapeRange(const std::vector &user_input, const std::map &graph_option, @@ -966,7 +989,8 @@ Status GetDynamicInputShapeRange(const std::vector &user_input, const return PARAM_INVALID; } GELOGD("GraphOption: dynamic_inputs_shape_range value is %s.", iter->second.c_str()); - ParseDynamicInputShapeRange(iter->second, range_vec); + auto ret = ParseDynamicInputShapeRange(iter->second, range_vec); + GE_CHK_STATUS_RET(ret, "Parse dynamic input shape range failed."); if (range_vec.size() != user_input.size()) { GELOGE(PARAM_INVALID, "Dynamic input shape range size is %zu, inputs size is %zu. Not match.", range_vec.size(), user_input.size()); @@ -978,18 +1002,30 @@ Status GetDynamicInputShapeRange(const std::vector &user_input, const Status UpdateDynamicInputShapeRange(const ge::GeAttrValue::INT index, const vector>> &range_vec, OpDescPtr &op, GeTensorDesc &desc) { - auto unkown_shape = desc.GetShape(); - auto shape_range = range_vec.at(index); - for (size_t i = 0; i < unkown_shape.GetDimNum(); ++i) { - if (shape_range.at(i).first == shape_range.at(i).second) { - unkown_shape.SetDim(i, shape_range.at(i).first); + auto origin_shape = desc.GetShape(); + auto current_shape_range_vec = range_vec.at(index); + if (current_shape_range_vec.size() != origin_shape.GetDimNum()) { + GELOGE(PARAM_INVALID, "Given shape_range dim num is %zu, current dim num is %zu, not match.Pleace Check.", + current_shape_range_vec.size(), origin_shape.GetDimNum()); + return PARAM_INVALID; + } + for (size_t i = 0; i < origin_shape.GetDimNum(); ++i) { + if (current_shape_range_vec.at(i).first == current_shape_range_vec.at(i).second) { + // given shape_range is known dim, check is same as origin or not + if (origin_shape.GetDim(i) != current_shape_range_vec.at(i).first) { + GELOGE(PARAM_INVALID, "Given shape range is %ld, current dim shape is %ld, not match.Pleace Check.", + current_shape_range_vec.at(i).first, origin_shape.GetDim(i)); + return PARAM_INVALID; + } + origin_shape.SetDim(i, current_shape_range_vec.at(i).first); } else { - unkown_shape.SetDim(i, -1); + origin_shape.SetDim(i, -1); } } - desc.SetShape(unkown_shape); - desc.SetShapeRange(shape_range); - int64_t dynamic_shape_size = 1; + desc.SetShape(origin_shape); + desc.SetShapeRange(current_shape_range_vec); + + /*int64_t dynamic_shape_size = 1; for (const auto range_pair : range_vec.at(index)) { FMK_INT64_MULCHECK(dynamic_shape_size, range_pair.second); dynamic_shape_size *= range_pair.second; @@ -1003,7 +1039,7 @@ Status UpdateDynamicInputShapeRange(const ge::GeAttrValue::INT index, FMK_INT64_MULCHECK(dynamic_shape_size, data_type_size); dynamic_shape_size *= data_type_size; GELOGI("In dynamic_execute mode ,set input %s shape range size %ld", op->GetName().c_str(), dynamic_shape_size); - ge::TensorUtils::SetSize(desc, dynamic_shape_size); + ge::TensorUtils::SetSize(desc, dynamic_shape_size);*/ graphStatus graph_ret = op->UpdateInputDesc(0, desc); GE_CHK_STATUS_RET(graph_ret, "UpdateInputDesc fail, graph ret: %u", graph_ret); graph_ret = op->UpdateOutputDesc(0, desc);