Browse Source

!1218 inference dynamic input

From: @zhengyuanhua
Reviewed-by: @wan_xuelei,@wqtshg,@xchu42
Signed-off-by: @ljl0711
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
2ece5f3b63
12 changed files with 457 additions and 62 deletions
  1. +9
    -20
      ge/hybrid/executor/hybrid_model_async_executor.cc
  2. +39
    -6
      ge/hybrid/executor/node_state.cc
  3. +2
    -0
      ge/hybrid/executor/node_state.h
  4. +14
    -15
      ge/hybrid/model/hybrid_model.cc
  5. +226
    -1
      ge/ir_build/atc_ir_common.cc
  6. +8
    -1
      ge/ir_build/atc_ir_common.h
  7. +38
    -18
      ge/ir_build/ge_ir_build.cc
  8. +7
    -1
      ge/offline/main.cc
  9. +7
    -0
      ge/session/omg.cc
  10. +4
    -0
      inc/external/ge/ge_api_types.h
  11. +3
    -0
      tests/ut/ge/CMakeLists.txt
  12. +100
    -0
      tests/ut/ge/graph_ir/ge_ir_build_unittest.cc

+ 9
- 20
ge/hybrid/executor/hybrid_model_async_executor.cc View File

@@ -444,31 +444,20 @@ Status HybridModelAsyncExecutor::Execute(const std::vector<DataBuffer> &inputs,
TensorValue tensor_value(inputs[i].data, inputs[i].length);
args.inputs[i] = tensor_value;
}
for (size_t i = 0; i < outputs.size(); ++i) {
args.outputs.emplace_back(TensorValue(outputs[i].data, outputs[i].length));
}
// usr must designate input tensorDesc when input shape is dynamic in inference
for (size_t i = 0; i < input_desc.size(); ++i) {
ConstGeTensorDescPtr tensor_desc_ptr = MakeShared<GeTensorDesc>(input_desc[i]);
args.input_desc.emplace_back(tensor_desc_ptr);
}

GE_CHK_STATUS_RET(executor_->Execute(args), "Failed to execute model.");
for (const auto &output_tensor_desc : args.output_desc) {
output_desc.emplace_back(*output_tensor_desc);
}

for (size_t i = 0; i < args.outputs.size(); ++i) {
int64_t output_real_size = 0;
ge::graphStatus graph_status = TensorUtils::GetTensorSizeInBytes(output_desc[i], output_real_size);
if (graph_status != GRAPH_SUCCESS) {
GELOGE(FAILED, "Get tensor size in bytes failed.");
return FAILED;
}
if (output_real_size > 0) {
if (outputs[i].length < static_cast<uint64_t>(output_real_size)) {
GELOGE(FAILED, "output idx[%zu], the memory size of output[%lu] given by "
"user should be greater than or equal to the real size of output[%ld]",
i, outputs[i].length, output_real_size);
return FAILED;
}
GE_CHK_RT_RET(rtMemcpy(outputs[i].data, outputs[i].length, args.outputs[i].GetData(), output_real_size,
RT_MEMCPY_DEVICE_TO_DEVICE));
}
outputs[i].length = output_real_size;
}

return SUCCESS;
}



+ 39
- 6
ge/hybrid/executor/node_state.cc View File

@@ -44,6 +44,27 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(
}
}

Status ShapeInferenceState::CheckInputShapeByShapeRange(const GeTensorDesc &tensor_desc,
const GeTensorDesc &target_tensor_desc) const {
std::vector<std::pair<int64_t, int64_t>> shape_range;
if (tensor_desc.GetShapeRange(shape_range) != SUCCESS) {
GELOGE(PARAM_INVALID, "Get shape range failed.");
return PARAM_INVALID;
}
if (shape_range.empty()) {
GELOGD("Shape range is empty, no need to check input shape.");
return SUCCESS;
}

GeShape target_shape = target_tensor_desc.GetShape();
if (TensorUtils::CheckShapeByShapeRange(target_shape, shape_range) != SUCCESS) {
GELOGE(PARAM_INVALID, "Check shape by shape range failed.");
return PARAM_INVALID;
}

return SUCCESS;
}

Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) {
if (node_item.IsInputShapeStatic(idx)) {
GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]",
@@ -54,19 +75,31 @@ Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target
return SUCCESS;
}

std::lock_guard<std::mutex> lk(mu_);
auto &input_desc = input_tensor_desc[idx];
if (CheckInputShapeByShapeRange(input_desc, target) != SUCCESS) {
GELOGE(FAILED, "[%s] Check input shape by shape range failed.", node_item.NodeName().c_str());
return FAILED;
}
GeShape shape = target.GetShape();
input_desc.SetShape(shape);
input_desc.SetOriginShape(target.GetOriginShape());
int64_t tensor_size = -1;
(void) TensorUtils::GetSize(target, tensor_size);
if (tensor_size <= 0) {
Format format = input_desc.GetFormat();
DataType data_type = input_desc.GetDataType();
if (TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size) != GRAPH_SUCCESS) {
GELOGE(FAILED, "[%s] Calculate tensor memory size failed.", node_item.NodeName().c_str());
return FAILED;
}
}
GELOGD("[%s] Update input shape [%d] with Shape: [%s] and OriginalShape: [%s], size = %ld",
node_item.NodeName().c_str(),
idx,
target.GetShape().ToString().c_str(),
shape.ToString().c_str(),
target.GetOriginShape().ToString().c_str(),
tensor_size);

std::lock_guard<std::mutex> lk(mu_);
auto &input_desc = input_tensor_desc[idx];
input_desc.SetShape(target.GetShape());
input_desc.SetOriginShape(target.GetOriginShape());
(void) TensorUtils::SetSize(input_desc, tensor_size);
if (--num_pending_shapes_ <= 0) {
ready_cv_.notify_all();


+ 2
- 0
ge/hybrid/executor/node_state.h View File

@@ -58,6 +58,8 @@ struct ShapeInferenceState {

const vector<GeTensorDesc> &GetOutputTensorDesc() const;

Status CheckInputShapeByShapeRange(const GeTensorDesc &tensor_desc, const GeTensorDesc &target_tensor_desc) const;

const NodeItem &node_item;

private:


+ 14
- 15
ge/hybrid/model/hybrid_model.cc View File

@@ -225,23 +225,19 @@ Status HybridModel::GetInputDescInfo(vector<InputOutputDescInfo> &input_desc, st
GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(0));

Format format = op_desc->GetInputDescPtr(0)->GetFormat();
input.data_type = op_desc->GetInputDescPtr(0)->GetDataType();
DataType data_type = op_desc->GetInputDescPtr(0)->GetDataType();
input.data_type = static_cast<uint32_t>(data_type);
input.name = op_desc->GetName();

int64_t input_size = 0;
GE_CHK_STATUS_RET(TensorUtils::GetSize(*op_desc->GetInputDescPtr(0), input_size), "get input size failed.");

// support dynamic shape
if (input_size < 0) {
GELOGD("dynamic shape scene, input size is unknown. "
"format=%d, data_type=%d, input_size=%ld",
format, input.data_type, input_size);
input_size = kMemSizeUnknownShape; // -1
GeShape shape = op_desc->GetInputDescPtr(0)->GetShape();
int64_t tensor_size = 0;
if (TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Calculate tensor mem size failed.");
return FAILED;
}
// not support dynamic shape input for now, so input_size here will be not less than zero.
input.size = input_size;
if (tensor_size == kMemSizeUnknownShape) {
tensor_size = 0;
}
input.size = static_cast<uint64_t>(tensor_size);
CreateInputDimsInfo(op_desc, input);

formats.push_back(format);
@@ -284,6 +280,9 @@ void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc,
}
int64_t tensor_size = 0;
(void)TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size);
if (tensor_size == kMemSizeUnknownShape) {
tensor_size = 0;
}
output_desc_info.size = static_cast<uint64_t>(tensor_size);
output_desc_info.data_type = output_desc->GetDataType();
}


+ 226
- 1
ge/ir_build/atc_ir_common.cc View File

@@ -19,7 +19,9 @@
#include "framework/common/string_util.h"
#include "framework/common/types.h"
#include "framework/common/util.h"
#include "graph/compute_graph.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/tensor_utils.h"

using std::pair;
using std::string;
@@ -52,6 +54,11 @@ const char *const kCompressWeightError = "it must be appointed when appoint para
const char *const kSelectImplmodeError = "only support high_performance, high_precision";
const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \"";
const char *const kKeepDtypeError = "file not found";
const char *const kInputShapeRangeInvalid = "format of shape range is invalid";
const char *const kShapeRangeValueConvertError = "transfer from string to int64 error";
const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\"";
const char *const kInputShapeRangeSample2 = "\"[]\"";
const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\"";

vector<string> SplitInputShape(const std::string &input_shape) {
vector<string> shape_pair_vec;
@@ -257,8 +264,132 @@ bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims
return true;
}

bool StringToLongNoThrow(const string &str, long &val) {
try {
val = std::stol(str);
return true;
} catch (const std::invalid_argument) {
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"},
{str, kShapeRangeValueConvertError, kInputShapeRangeSample3});
GELOGE(PARAM_INVALID,
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.",
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3);
} catch (const std::out_of_range) {
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"},
{str, kShapeRangeValueConvertError, kInputShapeRangeSample3});
GELOGE(PARAM_INVALID,
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.",
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3);
}
return false;
}

bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_t>> &shape_range_vec) {
vector<char> square_brackets;
for (auto ch : shape_range) {
if (ch == '[' || ch == ']') {
square_brackets.push_back(ch);
}
}

bool is_square_brackets = (square_brackets[0] == '[') && (square_brackets[1] == ']') && (square_brackets.size() == 2);
if (!is_square_brackets) {
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"},
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample2});
GELOGE(PARAM_INVALID,
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.",
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample2);
return false;
}
// trim start bytes, after that, single input should be "1~20,3,3~6,-1"
if (ge::StringUtils::StartWith(shape_range, "[")) {
shape_range = shape_range.substr(1, shape_range.size() - 1);
}
// parse shape_range of single input. eg. "1~20,3,3~6,-1"
vector<string> dim_range_set = ge::StringUtils::Split(shape_range, ',');
for (const auto &range_pair_str : dim_range_set) {
vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~');
pair<int64_t, int64_t> range_pair;
if (range_pair_set.size() == 1) {
long range_value = 0;
if (!StringToLongNoThrow(range_pair_set.at(0), range_value)) {
return false;
}
if (range_value < 0) {
range_pair = std::make_pair(1, range_value);
} else {
range_pair = std::make_pair(range_value, range_value);
}
} else if (range_pair_set.size() == 2) {
// unknown dim, should get range.
long range_left = 0;
if (!StringToLongNoThrow(range_pair_set.at(0), range_left)) {
return false;
}
long range_right = 0;
if (!StringToLongNoThrow(range_pair_set.at(1), range_right)) {
return false;
}
if (range_left < 0 || (range_right < 0)) {
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"},
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3});
GELOGE(PARAM_INVALID,
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.",
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3);
return false;
}
range_pair = std::make_pair(range_left, range_right);
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"},
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3});
GELOGE(PARAM_INVALID,
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.",
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3);
return false;
}
shape_range_vec.emplace_back(range_pair);
}
return true;
}

bool ParseInputShapeRange(const std::string &shape_range,
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map) {
GELOGD("Input shape range %s", shape_range.c_str());

vector<string> shape_range_vec = StringUtils::Split(shape_range, ';');
const int DEFAULT_SHAPE_RANGE_PAIR_SIZE = 2;
for (const auto &shape_range_item : shape_range_vec) {
vector<string> shape_range_pair_vec = SplitInputShape(shape_range_item);
if (shape_range_pair_vec.size() != DEFAULT_SHAPE_RANGE_PAIR_SIZE) {
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"},
{shape_range, kSplitError1, kInputShapeRangeSample1});
GELOGE(PARAM_INVALID, "Parse input parameter [--input_shape_range]'s shape range[%s] failed, "
"reason: %s, correct sample is %s.", shape_range.c_str(), kSplitError1, kInputShapeRangeSample1);
return false;
}
if (shape_range_pair_vec[1].empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape", "reason", "sample"},
{shape_range, kEmptyError, kInputShapeRangeSample1});
GELOGE(PARAM_INVALID, "Parse input parameter [--input_shape_range]'s shape range[%s] failed,"
"reason: %s, correct sample is %s.", shape_range.c_str(), kEmptyError, kInputShapeRangeSample1);
return false;
}

string shape_range_str = shape_range_pair_vec[1];
vector<pair<int64_t, int64_t>> shape_range_val;
if (!ParseSingleShapeRange(shape_range_str, shape_range_val)) {
GELOGE(PARAM_INVALID, "Parse single shape range %s error.", shape_range_str.c_str());
return false;
}
shape_range_map.emplace(make_pair(StringUtils::Trim(shape_range_pair_vec[0]), shape_range_val));
}
return true;
}

Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims,
const string input_shape, const string input_format, bool &is_dynamic_input) {
const string input_shape, const string input_shape_range, const string input_format,
bool &is_dynamic_input) {
int32_t param_size = static_cast<int32_t>(!dynamic_batch_size.empty()) +
static_cast<int32_t>(!dynamic_image_size.empty()) + static_cast<int32_t>(!dynamic_dims.empty());
if (param_size > 1) {
@@ -269,6 +400,13 @@ Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_i
}

if (param_size == 0) {
if (!input_shape_range.empty()) {
std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map;
if(!ParseInputShapeRange(input_shape_range, shape_range_map)) {
GELOGE(ge::PARAM_INVALID, "Failed to parse input shape range: %s", input_shape_range.c_str());
return ge::PARAM_INVALID;
}
}
return ge::SUCCESS;
}

@@ -546,4 +684,91 @@ void EraseEndSemicolon(string &param) {
param.erase(param.end() - 1);
}
}

Status UpdateDataOpShape(const OpDescPtr &op, map<string, vector<int64_t>> &shape_map) {
GE_CHECK_NOTNULL(op);
if (shape_map.empty()) {
GELOGI("Shape map of data op [%s] is empty, no need to update.", op->GetName().c_str());
return SUCCESS;
}

auto tensor_input = op->MutableInputDesc(0);
auto tensor_output = op->MutableOutputDesc(0);
GE_CHECK_NOTNULL(tensor_input);
GE_CHECK_NOTNULL(tensor_output);
string data_op_name = op->GetName();
auto iter = shape_map.find(data_op_name);
if (iter != shape_map.end()) {
tensor_input->SetShape(ge::GeShape(iter->second));
tensor_output->SetShape(ge::GeShape(iter->second));
GELOGI("Update input [%s] shape info", data_op_name.c_str());
} else {
GELOGI("No need update input [%s] attr because not found from input_shape.", data_op_name.c_str());
}

return SUCCESS;
}

Status UpdateDataOpShapeRange(const OpDescPtr &op,
map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) {
GE_CHECK_NOTNULL(op);
if (shape_range_map.empty()) {
GELOGI("Shape range map of data op [%s] is empty.", op->GetName().c_str());
return SUCCESS;
}

auto tensor_input = op->MutableInputDesc(0);
GE_CHECK_NOTNULL(tensor_input);
string data_op_name = op->GetName();
auto origin_shape = tensor_input->GetShape();
auto iter = shape_range_map.find(data_op_name);
if (iter != shape_range_map.end()) {
auto cur_shape_range = iter->second;
if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) {
GELOGE(PARAM_INVALID, "[%s] Check shape by shape range failed.", op->GetName().c_str());
return PARAM_INVALID;
}
for (size_t idx = 0; idx < cur_shape_range.size(); idx++) {
auto left_range = cur_shape_range[idx].first;
auto right_range = cur_shape_range[idx].second;
if (left_range != right_range) {
origin_shape.SetDim(idx, UNKNOWN_DIM);
}
}
tensor_input->SetShape(origin_shape);
tensor_input->SetShapeRange(cur_shape_range);
GELOGI("Update input [%s] shape range info", data_op_name.c_str());
} else {
GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str());
}

return SUCCESS;
}

Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range) {
if (input_shape_range.empty()) {
return SUCCESS;
}
GE_CHECK_NOTNULL(compute_graph);

map<string, vector<pair<int64_t, int64_t>>> shape_range_map;
if (!ParseInputShapeRange(input_shape_range, shape_range_map)) {
GELOGE(PARAM_INVALID, "Parse input shape range failed.");
return PARAM_INVALID;
}

for (NodePtr &input_node : compute_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(input_node);
OpDescPtr op = input_node->GetOpDesc();
GE_CHECK_NOTNULL(op);
if (op->GetType() == DATA) {
if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) {
GELOGE(FAILED, "Update data op [%s] input shape range failed.", op->GetName().c_str());
return FAILED;
}
}
}
return SUCCESS;
}

} // namespace ge

+ 8
- 1
ge/ir_build/atc_ir_common.h View File

@@ -59,10 +59,13 @@ bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims

Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string &dynamic_image_size,
std::string &dynamic_dims, const std::string input_shape,
const std::string input_format, bool &is_dynamic_input);
const std::string input_shape_range, const std::string input_format,
bool &is_dynamic_input);

bool ParseInputShape(const std::string &input_shape, std::map<string, std::vector<int64_t>> &shape_map,
std::vector<std::pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input = false);
bool ParseInputShapeRange(const std::string &shape_range,
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map);

Status CheckOutputTypeParamValid(const std::string output_type);
Status CheckBufferOptimizeParamValid(const std::string buffer_optimize);
@@ -76,5 +79,9 @@ Status CheckInputFormat(const string &input_format);
Status CheckKeepTypeParamValid(const std::string &keep_dtype);
void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips);
void EraseEndSemicolon(std::string &param);
Status UpdateDataOpShape(const OpDescPtr &op, std::map<std::string, std::vector<int64_t>> &shape_map);
Status UpdateDataOpShapeRange(const OpDescPtr &op,
std::map<std::string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map);
Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range);
}
#endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_

+ 38
- 18
ge/ir_build/ge_ir_build.cc View File

@@ -55,6 +55,7 @@ const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0";
const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false";
const std::string KEEP_DTYPE_OPTION = "keep_dtype";
const std::string kInputShape = "input_shape";
const std::string kInputShapeRange = "input_shape_range";
const std::string kInputFormat = "input_format";

/**
@@ -289,13 +290,20 @@ graphStatus Impl::InferShapePrepare(const ComputeGraphPtr &compute_graph) {

graphStatus Impl::UpdateDataOpAttr(const Graph &graph) {
GELOGD("Enter Update Data Attr Process!");
if (options_.find(kInputShape) == options_.end()) {
return GRAPH_SUCCESS;
}
std::string input_shape = (options_.find(kInputShape) == options_.end()) ? "" : options_[kInputShape];
std::string input_shape_range = (options_.find(kInputShapeRange) == options_.end()) ? "" : options_[kInputShapeRange];
map<string, vector<int64_t>> shape_map;
vector<pair<string, vector<int64_t>>> user_shape_map;
GE_CHK_BOOL_EXEC(ParseInputShape(options_[kInputShape], shape_map, user_shape_map, true),
return GRAPH_PARAM_INVALID, "parse input shape failed!");
if (!input_shape.empty()) {
GE_CHK_BOOL_EXEC(ParseInputShape(input_shape, shape_map, user_shape_map, true),
return GRAPH_PARAM_INVALID, "Parse input shape failed!");
}
std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map;
if (!input_shape_range.empty()) {
GE_CHK_BOOL_EXEC(ParseInputShapeRange(input_shape_range, shape_range_map),
return GRAPH_PARAM_INVALID, "Parse input shape range failed.");
}
auto compute_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) {
@@ -303,21 +311,31 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) {
ge::OpDescPtr op = input_node->GetOpDesc();
GE_CHECK_NOTNULL(op);
if (op->GetType() == DATA) {
auto tensor_input = op->MutableInputDesc(0);
auto tensor_output = op->MutableOutputDesc(0);
GE_CHECK_NOTNULL(tensor_input);
GE_CHECK_NOTNULL(tensor_output);
string data_op_name = op->GetName();
auto iter = shape_map.find(data_op_name);
if (iter != shape_map.end()) {
tensor_input->SetShape(ge::GeShape(iter->second));
tensor_output->SetShape(ge::GeShape(iter->second));
GELOGD("update input [%s] shape info", data_op_name.c_str());
} else {
GELOGI("no need update input [%s] attr because not found from input_shape.", data_op_name.c_str());
if (UpdateDataOpShape(op, shape_map) != SUCCESS) {
GELOGE(GRAPH_FAILED, "Update data op [%s] shape failed.", op->GetName().c_str());
return GRAPH_FAILED;
}
if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) {
GELOGE(GRAPH_FAILED, "Update data op [%s] shape range failed.", op->GetName().c_str());
return GRAPH_FAILED;
}
if (shape_range_map.empty()) {
auto tensor_input = op->MutableInputDesc(0);
GE_CHECK_NOTNULL(tensor_input);
GeShape shape = tensor_input->GetShape();
std::vector<std::pair<int64_t, int64_t>> shape_range;
if (tensor_input->GetShapeRange(shape_range) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "[%s] Get shape range failed.", op->GetName().c_str());
return GRAPH_FAILED;
}
if (TensorUtils::CheckShapeByShapeRange(shape, shape_range) != SUCCESS) {
GELOGE(GRAPH_FAILED, "[%s] Check shape by shape range failed.", op->GetName().c_str());
return GRAPH_FAILED;
}
}
}
}

return GRAPH_SUCCESS;
}

@@ -400,9 +418,11 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri
: options_[ge::ir_option::DYNAMIC_IMAGE_SIZE];
string dynamic_dims =
options_.find(ge::ir_option::DYNAMIC_DIMS) == options_.end() ? "" : options_[ge::ir_option::DYNAMIC_DIMS];
string input_shape_range =
options_.find(ge::INPUT_SHAPE_RANGE) == options_.end() ? "" : options_[ge::INPUT_SHAPE_RANGE];

auto status = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape,
input_format, is_dynamic_input_);
input_shape_range, input_format, is_dynamic_input_);
if (status != ge::SUCCESS) {
GELOGE(GRAPH_PARAM_INVALID, "Check dynamic input size failed!");
return GRAPH_PARAM_INVALID;


+ 7
- 1
ge/offline/main.cc View File

@@ -84,6 +84,10 @@ DEFINE_string(input_shape, "",
"Optional; shape of input data. Required when framework is caffe "
"or TensorFLow or MindSpore or Onnx. "
"Format: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"");
DEFINE_string(input_shape_range, "",
"Optional; shape range of input data. Required when framework is caffe "
"or TensorFLow or Onnx. "
"Format: \"input_name1:[n1~n2,c1,h1,w1];input_name2:[n2~n3,c2,h2,w2]\"");
DEFINE_bool(h, false, "show this help message");
DEFINE_string(cal_conf, "", "Optional; the calibration config file.");

@@ -240,6 +244,7 @@ class GFlagUtils {
" --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx\n"
" --input_format Format of input data. E.g.: \"NCHW\"\n"
" --input_shape Shape of input data. Separate multiple nodes with semicolons (;). "
" --input_shape_range Shape range of input data. Separate multiple nodes with semicolons (;)."
"Use double quotation marks (\") to enclose each argument.\n"
" E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n"
" --dynamic_batch_size Set dynamic batch size. E.g.: \"batchsize1,batchsize2,batchsize3\"\n"
@@ -373,7 +378,7 @@ class GFlagUtils {

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
ge::CheckDynamicInputParamValid(FLAGS_dynamic_batch_size, FLAGS_dynamic_image_size,
FLAGS_dynamic_dims, FLAGS_input_shape,
FLAGS_dynamic_dims, FLAGS_input_shape, FLAGS_input_shape_range,
FLAGS_input_format, is_dynamic_input) != ge::SUCCESS,
ret = ge::FAILED, "check dynamic size(batch size, image size or dims) failed!");

@@ -985,6 +990,7 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output
} else {
std::map<string, string> atc_params;
atc_params.insert(std::pair<string, string>("input_shape", FLAGS_input_shape));
atc_params.insert(std::pair<string, string>(ge::INPUT_SHAPE_RANGE, FLAGS_input_shape_range));
atc_params.insert(std::pair<string, string>("out_nodes", FLAGS_out_nodes));
atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format));
atc_params.insert(std::pair<string, string>("check_report", FLAGS_check_report));


+ 7
- 0
ge/session/omg.cc View File

@@ -576,6 +576,7 @@ Status InitDomiOmgContext(const string &input_shape, const string &input_format,
GELOGE(PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str());
return PARAM_INVALID;
}

return SUCCESS;
}

@@ -788,6 +789,12 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri

GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "ATC weights parse ret fail.");

// parser input shape range and update op shape range
std::string input_shape_range;
ParseAtcParms(atc_params, INPUT_SHAPE_RANGE, input_shape_range);
GE_RETURN_WITH_LOG_IF_ERROR(UpdateDynamicInputShapeRange(compute_graph, input_shape_range),
"Update input shape range failed");

GELOGI("ATC parser success.");

return SUCCESS;


+ 4
- 0
inc/external/ge/ge_api_types.h View File

@@ -311,6 +311,9 @@ const std::string OP_BANK_UPDATE_FLAG = "ge.op_bank_update";
// 0: data multi; 1: model multi;
const std::string HCOM_MULTI_MODE = "ge.hcomMultiMode";

// atc and ir option
const char *const INPUT_SHAPE_RANGE = "input_shape_range";

// Graph run mode
enum GraphRunMode { PREDICTION = 0, TRAIN };

@@ -390,6 +393,7 @@ static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str();
#ifdef __GNUC__
const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT,
INPUT_SHAPE,
INPUT_SHAPE_RANGE,
OP_NAME_MAP,
DYNAMIC_BATCH_SIZE,
DYNAMIC_IMAGE_SIZE,


+ 3
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -45,6 +45,7 @@ include_directories(${GE_CODE_DIR}/inc)
include_directories(${GE_CODE_DIR}/metadef/inc)
include_directories(${GE_CODE_DIR}/ge)
include_directories(${GE_CODE_DIR}/ge/inc)
include_directories(${GE_CODE_DIR}/ge/ir_build)
include_directories(${GE_CODE_DIR}/metadef)
include_directories(${GE_CODE_DIR}/metadef/graph)
include_directories(${GE_CODE_DIR}/inc/external)
@@ -61,6 +62,7 @@ include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce)
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops)
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain)
include_directories(${GE_CODE_DIR}/tests/ut/ge)
include_directories(${GE_CODE_DIR}/tests/ut/common)
include_directories(${CMAKE_BINARY_DIR})
include_directories(${CMAKE_BINARY_DIR}/proto/ge)
include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto)
@@ -732,6 +734,7 @@ set(KERNEL_TEST_FILES

set(MULTI_PARTS_TEST_FILES
"graph_ir/ge_operator_factory_unittest.cc"
"graph_ir/ge_ir_build_unittest.cc"
"graph/transop_util_unittest.cc"
"common/datatype_transfer_unittest.cc"
"common/dump_manager_unittest.cc"


+ 100
- 0
tests/ut/ge/graph_ir/ge_ir_build_unittest.cc View File

@@ -0,0 +1,100 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include "ir_build/atc_ir_common.h"
#include "graph/testcase/ge_graph/graph_builder_utils.h"

#define protected public
#define private public

#undef private
#undef protected

const string DATA = "Data";
const string AddNYes = "AddNYes";
const string NETOUTPUT = "NetOutput";

using namespace ge;
class UtestIrCommon : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}
};

static ge::OpDescPtr CreateOpDesc(const std::string &name, const std::string &type) {
OpDescPtr op_desc = std::make_shared<ge::OpDesc>(name, type);
ge::GeTensorDesc ge_tensor_desc;
op_desc->AddInputDesc("input", ge_tensor_desc);
op_desc->AddOutputDesc("output", ge_tensor_desc);

return op_desc;
}

static ComputeGraphPtr BuildComputeGraph() {
auto builder = ut::GraphBuilder("test");
auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3});
auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10});
auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);

builder.AddDataEdge(data1, 0, addn1, 0);
builder.AddDataEdge(data2, 0, addn1, 1);
builder.AddDataEdge(addn1, 0,netoutput, 0);

return builder.GetGraph();
}

TEST(UtestIrCommon, update_data_op_shape) {
ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data");
map<string, vector<int64_t>> shape_map;
shape_map["Data"] = {{1,2}};

Status ret = UpdateDataOpShape(op_desc, shape_map);
EXPECT_EQ(ret, ge::SUCCESS);
}

TEST(UtestIrCommon, update_dynamic_shape_range_success) {
ComputeGraphPtr graph = BuildComputeGraph();
std::string input_shape_range = "input1:[1, 2~3, -1];input2:[3~5, 10]";

Status ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
EXPECT_EQ(ret, ge::SUCCESS);
}

TEST(UtestIrCommon, update_dynamic_shape_range_failed) {
ComputeGraphPtr graph = BuildComputeGraph();
// 1
std::string input_shape_range = "input1;[1, 2~3, -1]";
Status ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
EXPECT_EQ(ret, ge::PARAM_INVALID);

// 2
input_shape_range = "input1:[1, 2~3, -1)";
ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
EXPECT_EQ(ret, ge::PARAM_INVALID);

//3
input_shape_range = "input1:[1, 3~2, -1];input2:[3~5, 10]";
ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
EXPECT_EQ(ret, ge::FAILED);

//4
input_shape_range = "input1:[1, 2~-3, -1]";
ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
EXPECT_EQ(ret, ge::PARAM_INVALID);
}

Loading…
Cancel
Save