From: @zhengyuanhua Reviewed-by: @wan_xuelei,@wqtshg,@xchu42 Signed-off-by: @ljl0711tags/v1.2.0
@@ -444,31 +444,20 @@ Status HybridModelAsyncExecutor::Execute(const std::vector<DataBuffer> &inputs, | |||
TensorValue tensor_value(inputs[i].data, inputs[i].length); | |||
args.inputs[i] = tensor_value; | |||
} | |||
for (size_t i = 0; i < outputs.size(); ++i) { | |||
args.outputs.emplace_back(TensorValue(outputs[i].data, outputs[i].length)); | |||
} | |||
// usr must designate input tensorDesc when input shape is dynamic in inference | |||
for (size_t i = 0; i < input_desc.size(); ++i) { | |||
ConstGeTensorDescPtr tensor_desc_ptr = MakeShared<GeTensorDesc>(input_desc[i]); | |||
args.input_desc.emplace_back(tensor_desc_ptr); | |||
} | |||
GE_CHK_STATUS_RET(executor_->Execute(args), "Failed to execute model."); | |||
for (const auto &output_tensor_desc : args.output_desc) { | |||
output_desc.emplace_back(*output_tensor_desc); | |||
} | |||
for (size_t i = 0; i < args.outputs.size(); ++i) { | |||
int64_t output_real_size = 0; | |||
ge::graphStatus graph_status = TensorUtils::GetTensorSizeInBytes(output_desc[i], output_real_size); | |||
if (graph_status != GRAPH_SUCCESS) { | |||
GELOGE(FAILED, "Get tensor size in bytes failed."); | |||
return FAILED; | |||
} | |||
if (output_real_size > 0) { | |||
if (outputs[i].length < static_cast<uint64_t>(output_real_size)) { | |||
GELOGE(FAILED, "output idx[%zu], the memory size of output[%lu] given by " | |||
"user should be greater than or equal to the real size of output[%ld]", | |||
i, outputs[i].length, output_real_size); | |||
return FAILED; | |||
} | |||
GE_CHK_RT_RET(rtMemcpy(outputs[i].data, outputs[i].length, args.outputs[i].GetData(), output_real_size, | |||
RT_MEMCPY_DEVICE_TO_DEVICE)); | |||
} | |||
outputs[i].length = output_real_size; | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -44,6 +44,27 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item( | |||
} | |||
} | |||
Status ShapeInferenceState::CheckInputShapeByShapeRange(const GeTensorDesc &tensor_desc, | |||
const GeTensorDesc &target_tensor_desc) const { | |||
std::vector<std::pair<int64_t, int64_t>> shape_range; | |||
if (tensor_desc.GetShapeRange(shape_range) != SUCCESS) { | |||
GELOGE(PARAM_INVALID, "Get shape range failed."); | |||
return PARAM_INVALID; | |||
} | |||
if (shape_range.empty()) { | |||
GELOGD("Shape range is empty, no need to check input shape."); | |||
return SUCCESS; | |||
} | |||
GeShape target_shape = target_tensor_desc.GetShape(); | |||
if (TensorUtils::CheckShapeByShapeRange(target_shape, shape_range) != SUCCESS) { | |||
GELOGE(PARAM_INVALID, "Check shape by shape range failed."); | |||
return PARAM_INVALID; | |||
} | |||
return SUCCESS; | |||
} | |||
Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) { | |||
if (node_item.IsInputShapeStatic(idx)) { | |||
GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]", | |||
@@ -54,19 +75,31 @@ Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target | |||
return SUCCESS; | |||
} | |||
std::lock_guard<std::mutex> lk(mu_); | |||
auto &input_desc = input_tensor_desc[idx]; | |||
if (CheckInputShapeByShapeRange(input_desc, target) != SUCCESS) { | |||
GELOGE(FAILED, "[%s] Check input shape by shape range failed.", node_item.NodeName().c_str()); | |||
return FAILED; | |||
} | |||
GeShape shape = target.GetShape(); | |||
input_desc.SetShape(shape); | |||
input_desc.SetOriginShape(target.GetOriginShape()); | |||
int64_t tensor_size = -1; | |||
(void) TensorUtils::GetSize(target, tensor_size); | |||
if (tensor_size <= 0) { | |||
Format format = input_desc.GetFormat(); | |||
DataType data_type = input_desc.GetDataType(); | |||
if (TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size) != GRAPH_SUCCESS) { | |||
GELOGE(FAILED, "[%s] Calculate tensor memory size failed.", node_item.NodeName().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
GELOGD("[%s] Update input shape [%d] with Shape: [%s] and OriginalShape: [%s], size = %ld", | |||
node_item.NodeName().c_str(), | |||
idx, | |||
target.GetShape().ToString().c_str(), | |||
shape.ToString().c_str(), | |||
target.GetOriginShape().ToString().c_str(), | |||
tensor_size); | |||
std::lock_guard<std::mutex> lk(mu_); | |||
auto &input_desc = input_tensor_desc[idx]; | |||
input_desc.SetShape(target.GetShape()); | |||
input_desc.SetOriginShape(target.GetOriginShape()); | |||
(void) TensorUtils::SetSize(input_desc, tensor_size); | |||
if (--num_pending_shapes_ <= 0) { | |||
ready_cv_.notify_all(); | |||
@@ -58,6 +58,8 @@ struct ShapeInferenceState { | |||
const vector<GeTensorDesc> &GetOutputTensorDesc() const; | |||
Status CheckInputShapeByShapeRange(const GeTensorDesc &tensor_desc, const GeTensorDesc &target_tensor_desc) const; | |||
const NodeItem &node_item; | |||
private: | |||
@@ -225,23 +225,19 @@ Status HybridModel::GetInputDescInfo(vector<InputOutputDescInfo> &input_desc, st | |||
GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(0)); | |||
Format format = op_desc->GetInputDescPtr(0)->GetFormat(); | |||
input.data_type = op_desc->GetInputDescPtr(0)->GetDataType(); | |||
DataType data_type = op_desc->GetInputDescPtr(0)->GetDataType(); | |||
input.data_type = static_cast<uint32_t>(data_type); | |||
input.name = op_desc->GetName(); | |||
int64_t input_size = 0; | |||
GE_CHK_STATUS_RET(TensorUtils::GetSize(*op_desc->GetInputDescPtr(0), input_size), "get input size failed."); | |||
// support dynamic shape | |||
if (input_size < 0) { | |||
GELOGD("dynamic shape scene, input size is unknown. " | |||
"format=%d, data_type=%d, input_size=%ld", | |||
format, input.data_type, input_size); | |||
input_size = kMemSizeUnknownShape; // -1 | |||
GeShape shape = op_desc->GetInputDescPtr(0)->GetShape(); | |||
int64_t tensor_size = 0; | |||
if (TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size) != GRAPH_SUCCESS) { | |||
GELOGE(FAILED, "Calculate tensor mem size failed."); | |||
return FAILED; | |||
} | |||
// not support dynamic shape input for now, so input_size here will be not less than zero. | |||
input.size = input_size; | |||
if (tensor_size == kMemSizeUnknownShape) { | |||
tensor_size = 0; | |||
} | |||
input.size = static_cast<uint64_t>(tensor_size); | |||
CreateInputDimsInfo(op_desc, input); | |||
formats.push_back(format); | |||
@@ -284,6 +280,9 @@ void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc, | |||
} | |||
int64_t tensor_size = 0; | |||
(void)TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size); | |||
if (tensor_size == kMemSizeUnknownShape) { | |||
tensor_size = 0; | |||
} | |||
output_desc_info.size = static_cast<uint64_t>(tensor_size); | |||
output_desc_info.data_type = output_desc->GetDataType(); | |||
} | |||
@@ -19,7 +19,9 @@ | |||
#include "framework/common/string_util.h" | |||
#include "framework/common/types.h" | |||
#include "framework/common/util.h" | |||
#include "graph/compute_graph.h" | |||
#include "graph/utils/type_utils.h" | |||
#include "graph/utils/tensor_utils.h" | |||
using std::pair; | |||
using std::string; | |||
@@ -52,6 +54,11 @@ const char *const kCompressWeightError = "it must be appointed when appoint para | |||
const char *const kSelectImplmodeError = "only support high_performance, high_precision"; | |||
const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; | |||
const char *const kKeepDtypeError = "file not found"; | |||
const char *const kInputShapeRangeInvalid = "format of shape range is invalid"; | |||
const char *const kShapeRangeValueConvertError = "transfer from string to int64 error"; | |||
const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\""; | |||
const char *const kInputShapeRangeSample2 = "\"[]\""; | |||
const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\""; | |||
vector<string> SplitInputShape(const std::string &input_shape) { | |||
vector<string> shape_pair_vec; | |||
@@ -257,8 +264,132 @@ bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims | |||
return true; | |||
} | |||
bool StringToLongNoThrow(const string &str, long &val) { | |||
try { | |||
val = std::stol(str); | |||
return true; | |||
} catch (const std::invalid_argument) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||
{str, kShapeRangeValueConvertError, kInputShapeRangeSample3}); | |||
GELOGE(PARAM_INVALID, | |||
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", | |||
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); | |||
} catch (const std::out_of_range) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||
{str, kShapeRangeValueConvertError, kInputShapeRangeSample3}); | |||
GELOGE(PARAM_INVALID, | |||
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", | |||
str.c_str(), kShapeRangeValueConvertError, kInputShapeRangeSample3); | |||
} | |||
return false; | |||
} | |||
bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_t>> &shape_range_vec) { | |||
vector<char> square_brackets; | |||
for (auto ch : shape_range) { | |||
if (ch == '[' || ch == ']') { | |||
square_brackets.push_back(ch); | |||
} | |||
} | |||
bool is_square_brackets = (square_brackets[0] == '[') && (square_brackets[1] == ']') && (square_brackets.size() == 2); | |||
if (!is_square_brackets) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample2}); | |||
GELOGE(PARAM_INVALID, | |||
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", | |||
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample2); | |||
return false; | |||
} | |||
// trim start bytes, after that, single input should be "1~20,3,3~6,-1" | |||
if (ge::StringUtils::StartWith(shape_range, "[")) { | |||
shape_range = shape_range.substr(1, shape_range.size() - 1); | |||
} | |||
// parse shape_range of single input. eg. "1~20,3,3~6,-1" | |||
vector<string> dim_range_set = ge::StringUtils::Split(shape_range, ','); | |||
for (const auto &range_pair_str : dim_range_set) { | |||
vector<string> range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); | |||
pair<int64_t, int64_t> range_pair; | |||
if (range_pair_set.size() == 1) { | |||
long range_value = 0; | |||
if (!StringToLongNoThrow(range_pair_set.at(0), range_value)) { | |||
return false; | |||
} | |||
if (range_value < 0) { | |||
range_pair = std::make_pair(1, range_value); | |||
} else { | |||
range_pair = std::make_pair(range_value, range_value); | |||
} | |||
} else if (range_pair_set.size() == 2) { | |||
// unknown dim, should get range. | |||
long range_left = 0; | |||
if (!StringToLongNoThrow(range_pair_set.at(0), range_left)) { | |||
return false; | |||
} | |||
long range_right = 0; | |||
if (!StringToLongNoThrow(range_pair_set.at(1), range_right)) { | |||
return false; | |||
} | |||
if (range_left < 0 || (range_right < 0)) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}); | |||
GELOGE(PARAM_INVALID, | |||
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", | |||
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); | |||
return false; | |||
} | |||
range_pair = std::make_pair(range_left, range_right); | |||
} else { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample3}); | |||
GELOGE(PARAM_INVALID, | |||
"Parse input parameter [--input_shape_range]'s shape range[%s] failed, reason: %s, correct sample is %s.", | |||
shape_range.c_str(), kInputShapeRangeInvalid, kInputShapeRangeSample3); | |||
return false; | |||
} | |||
shape_range_vec.emplace_back(range_pair); | |||
} | |||
return true; | |||
} | |||
bool ParseInputShapeRange(const std::string &shape_range, | |||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map) { | |||
GELOGD("Input shape range %s", shape_range.c_str()); | |||
vector<string> shape_range_vec = StringUtils::Split(shape_range, ';'); | |||
const int DEFAULT_SHAPE_RANGE_PAIR_SIZE = 2; | |||
for (const auto &shape_range_item : shape_range_vec) { | |||
vector<string> shape_range_pair_vec = SplitInputShape(shape_range_item); | |||
if (shape_range_pair_vec.size() != DEFAULT_SHAPE_RANGE_PAIR_SIZE) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"}, | |||
{shape_range, kSplitError1, kInputShapeRangeSample1}); | |||
GELOGE(PARAM_INVALID, "Parse input parameter [--input_shape_range]'s shape range[%s] failed, " | |||
"reason: %s, correct sample is %s.", shape_range.c_str(), kSplitError1, kInputShapeRangeSample1); | |||
return false; | |||
} | |||
if (shape_range_pair_vec[1].empty()) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape", "reason", "sample"}, | |||
{shape_range, kEmptyError, kInputShapeRangeSample1}); | |||
GELOGE(PARAM_INVALID, "Parse input parameter [--input_shape_range]'s shape range[%s] failed," | |||
"reason: %s, correct sample is %s.", shape_range.c_str(), kEmptyError, kInputShapeRangeSample1); | |||
return false; | |||
} | |||
string shape_range_str = shape_range_pair_vec[1]; | |||
vector<pair<int64_t, int64_t>> shape_range_val; | |||
if (!ParseSingleShapeRange(shape_range_str, shape_range_val)) { | |||
GELOGE(PARAM_INVALID, "Parse single shape range %s error.", shape_range_str.c_str()); | |||
return false; | |||
} | |||
shape_range_map.emplace(make_pair(StringUtils::Trim(shape_range_pair_vec[0]), shape_range_val)); | |||
} | |||
return true; | |||
} | |||
Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims, | |||
const string input_shape, const string input_format, bool &is_dynamic_input) { | |||
const string input_shape, const string input_shape_range, const string input_format, | |||
bool &is_dynamic_input) { | |||
int32_t param_size = static_cast<int32_t>(!dynamic_batch_size.empty()) + | |||
static_cast<int32_t>(!dynamic_image_size.empty()) + static_cast<int32_t>(!dynamic_dims.empty()); | |||
if (param_size > 1) { | |||
@@ -269,6 +400,13 @@ Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_i | |||
} | |||
if (param_size == 0) { | |||
if (!input_shape_range.empty()) { | |||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map; | |||
if(!ParseInputShapeRange(input_shape_range, shape_range_map)) { | |||
GELOGE(ge::PARAM_INVALID, "Failed to parse input shape range: %s", input_shape_range.c_str()); | |||
return ge::PARAM_INVALID; | |||
} | |||
} | |||
return ge::SUCCESS; | |||
} | |||
@@ -546,4 +684,91 @@ void EraseEndSemicolon(string ¶m) { | |||
param.erase(param.end() - 1); | |||
} | |||
} | |||
Status UpdateDataOpShape(const OpDescPtr &op, map<string, vector<int64_t>> &shape_map) { | |||
GE_CHECK_NOTNULL(op); | |||
if (shape_map.empty()) { | |||
GELOGI("Shape map of data op [%s] is empty, no need to update.", op->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
auto tensor_input = op->MutableInputDesc(0); | |||
auto tensor_output = op->MutableOutputDesc(0); | |||
GE_CHECK_NOTNULL(tensor_input); | |||
GE_CHECK_NOTNULL(tensor_output); | |||
string data_op_name = op->GetName(); | |||
auto iter = shape_map.find(data_op_name); | |||
if (iter != shape_map.end()) { | |||
tensor_input->SetShape(ge::GeShape(iter->second)); | |||
tensor_output->SetShape(ge::GeShape(iter->second)); | |||
GELOGI("Update input [%s] shape info", data_op_name.c_str()); | |||
} else { | |||
GELOGI("No need update input [%s] attr because not found from input_shape.", data_op_name.c_str()); | |||
} | |||
return SUCCESS; | |||
} | |||
Status UpdateDataOpShapeRange(const OpDescPtr &op, | |||
map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) { | |||
GE_CHECK_NOTNULL(op); | |||
if (shape_range_map.empty()) { | |||
GELOGI("Shape range map of data op [%s] is empty.", op->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
auto tensor_input = op->MutableInputDesc(0); | |||
GE_CHECK_NOTNULL(tensor_input); | |||
string data_op_name = op->GetName(); | |||
auto origin_shape = tensor_input->GetShape(); | |||
auto iter = shape_range_map.find(data_op_name); | |||
if (iter != shape_range_map.end()) { | |||
auto cur_shape_range = iter->second; | |||
if (TensorUtils::CheckShapeByShapeRange(origin_shape, cur_shape_range) != SUCCESS) { | |||
GELOGE(PARAM_INVALID, "[%s] Check shape by shape range failed.", op->GetName().c_str()); | |||
return PARAM_INVALID; | |||
} | |||
for (size_t idx = 0; idx < cur_shape_range.size(); idx++) { | |||
auto left_range = cur_shape_range[idx].first; | |||
auto right_range = cur_shape_range[idx].second; | |||
if (left_range != right_range) { | |||
origin_shape.SetDim(idx, UNKNOWN_DIM); | |||
} | |||
} | |||
tensor_input->SetShape(origin_shape); | |||
tensor_input->SetShapeRange(cur_shape_range); | |||
GELOGI("Update input [%s] shape range info", data_op_name.c_str()); | |||
} else { | |||
GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str()); | |||
} | |||
return SUCCESS; | |||
} | |||
Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range) { | |||
if (input_shape_range.empty()) { | |||
return SUCCESS; | |||
} | |||
GE_CHECK_NOTNULL(compute_graph); | |||
map<string, vector<pair<int64_t, int64_t>>> shape_range_map; | |||
if (!ParseInputShapeRange(input_shape_range, shape_range_map)) { | |||
GELOGE(PARAM_INVALID, "Parse input shape range failed."); | |||
return PARAM_INVALID; | |||
} | |||
for (NodePtr &input_node : compute_graph->GetDirectNode()) { | |||
GE_CHECK_NOTNULL(input_node); | |||
OpDescPtr op = input_node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op); | |||
if (op->GetType() == DATA) { | |||
if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) { | |||
GELOGE(FAILED, "Update data op [%s] input shape range failed.", op->GetName().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
} // namespace ge |
@@ -59,10 +59,13 @@ bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims | |||
Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string &dynamic_image_size, | |||
std::string &dynamic_dims, const std::string input_shape, | |||
const std::string input_format, bool &is_dynamic_input); | |||
const std::string input_shape_range, const std::string input_format, | |||
bool &is_dynamic_input); | |||
bool ParseInputShape(const std::string &input_shape, std::map<string, std::vector<int64_t>> &shape_map, | |||
std::vector<std::pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input = false); | |||
bool ParseInputShapeRange(const std::string &shape_range, | |||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map); | |||
Status CheckOutputTypeParamValid(const std::string output_type); | |||
Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); | |||
@@ -76,5 +79,9 @@ Status CheckInputFormat(const string &input_format); | |||
Status CheckKeepTypeParamValid(const std::string &keep_dtype); | |||
void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); | |||
void EraseEndSemicolon(std::string ¶m); | |||
Status UpdateDataOpShape(const OpDescPtr &op, std::map<std::string, std::vector<int64_t>> &shape_map); | |||
Status UpdateDataOpShapeRange(const OpDescPtr &op, | |||
std::map<std::string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map); | |||
Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range); | |||
} | |||
#endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ |
@@ -55,6 +55,7 @@ const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0"; | |||
const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; | |||
const std::string KEEP_DTYPE_OPTION = "keep_dtype"; | |||
const std::string kInputShape = "input_shape"; | |||
const std::string kInputShapeRange = "input_shape_range"; | |||
const std::string kInputFormat = "input_format"; | |||
/** | |||
@@ -289,13 +290,20 @@ graphStatus Impl::InferShapePrepare(const ComputeGraphPtr &compute_graph) { | |||
graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { | |||
GELOGD("Enter Update Data Attr Process!"); | |||
if (options_.find(kInputShape) == options_.end()) { | |||
return GRAPH_SUCCESS; | |||
} | |||
std::string input_shape = (options_.find(kInputShape) == options_.end()) ? "" : options_[kInputShape]; | |||
std::string input_shape_range = (options_.find(kInputShapeRange) == options_.end()) ? "" : options_[kInputShapeRange]; | |||
map<string, vector<int64_t>> shape_map; | |||
vector<pair<string, vector<int64_t>>> user_shape_map; | |||
GE_CHK_BOOL_EXEC(ParseInputShape(options_[kInputShape], shape_map, user_shape_map, true), | |||
return GRAPH_PARAM_INVALID, "parse input shape failed!"); | |||
if (!input_shape.empty()) { | |||
GE_CHK_BOOL_EXEC(ParseInputShape(input_shape, shape_map, user_shape_map, true), | |||
return GRAPH_PARAM_INVALID, "Parse input shape failed!"); | |||
} | |||
std::map<string, std::vector<std::pair<int64_t, int64_t>>> shape_range_map; | |||
if (!input_shape_range.empty()) { | |||
GE_CHK_BOOL_EXEC(ParseInputShapeRange(input_shape_range, shape_range_map), | |||
return GRAPH_PARAM_INVALID, "Parse input shape range failed."); | |||
} | |||
auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
GE_CHECK_NOTNULL(compute_graph); | |||
for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { | |||
@@ -303,21 +311,31 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { | |||
ge::OpDescPtr op = input_node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op); | |||
if (op->GetType() == DATA) { | |||
auto tensor_input = op->MutableInputDesc(0); | |||
auto tensor_output = op->MutableOutputDesc(0); | |||
GE_CHECK_NOTNULL(tensor_input); | |||
GE_CHECK_NOTNULL(tensor_output); | |||
string data_op_name = op->GetName(); | |||
auto iter = shape_map.find(data_op_name); | |||
if (iter != shape_map.end()) { | |||
tensor_input->SetShape(ge::GeShape(iter->second)); | |||
tensor_output->SetShape(ge::GeShape(iter->second)); | |||
GELOGD("update input [%s] shape info", data_op_name.c_str()); | |||
} else { | |||
GELOGI("no need update input [%s] attr because not found from input_shape.", data_op_name.c_str()); | |||
if (UpdateDataOpShape(op, shape_map) != SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Update data op [%s] shape failed.", op->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Update data op [%s] shape range failed.", op->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
if (shape_range_map.empty()) { | |||
auto tensor_input = op->MutableInputDesc(0); | |||
GE_CHECK_NOTNULL(tensor_input); | |||
GeShape shape = tensor_input->GetShape(); | |||
std::vector<std::pair<int64_t, int64_t>> shape_range; | |||
if (tensor_input->GetShapeRange(shape_range) != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "[%s] Get shape range failed.", op->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
if (TensorUtils::CheckShapeByShapeRange(shape, shape_range) != SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "[%s] Check shape by shape range failed.", op->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
@@ -400,9 +418,11 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri | |||
: options_[ge::ir_option::DYNAMIC_IMAGE_SIZE]; | |||
string dynamic_dims = | |||
options_.find(ge::ir_option::DYNAMIC_DIMS) == options_.end() ? "" : options_[ge::ir_option::DYNAMIC_DIMS]; | |||
string input_shape_range = | |||
options_.find(ge::INPUT_SHAPE_RANGE) == options_.end() ? "" : options_[ge::INPUT_SHAPE_RANGE]; | |||
auto status = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape, | |||
input_format, is_dynamic_input_); | |||
input_shape_range, input_format, is_dynamic_input_); | |||
if (status != ge::SUCCESS) { | |||
GELOGE(GRAPH_PARAM_INVALID, "Check dynamic input size failed!"); | |||
return GRAPH_PARAM_INVALID; | |||
@@ -84,6 +84,10 @@ DEFINE_string(input_shape, "", | |||
"Optional; shape of input data. Required when framework is caffe " | |||
"or TensorFLow or MindSpore or Onnx. " | |||
"Format: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\""); | |||
DEFINE_string(input_shape_range, "", | |||
"Optional; shape range of input data. Required when framework is caffe " | |||
"or TensorFLow or Onnx. " | |||
"Format: \"input_name1:[n1~n2,c1,h1,w1];input_name2:[n2~n3,c2,h2,w2]\""); | |||
DEFINE_bool(h, false, "show this help message"); | |||
DEFINE_string(cal_conf, "", "Optional; the calibration config file."); | |||
@@ -240,6 +244,7 @@ class GFlagUtils { | |||
" --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx\n" | |||
" --input_format Format of input data. E.g.: \"NCHW\"\n" | |||
" --input_shape Shape of input data. Separate multiple nodes with semicolons (;). " | |||
" --input_shape_range Shape range of input data. Separate multiple nodes with semicolons (;)." | |||
"Use double quotation marks (\") to enclose each argument.\n" | |||
" E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n" | |||
" --dynamic_batch_size Set dynamic batch size. E.g.: \"batchsize1,batchsize2,batchsize3\"\n" | |||
@@ -373,7 +378,7 @@ class GFlagUtils { | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
ge::CheckDynamicInputParamValid(FLAGS_dynamic_batch_size, FLAGS_dynamic_image_size, | |||
FLAGS_dynamic_dims, FLAGS_input_shape, | |||
FLAGS_dynamic_dims, FLAGS_input_shape, FLAGS_input_shape_range, | |||
FLAGS_input_format, is_dynamic_input) != ge::SUCCESS, | |||
ret = ge::FAILED, "check dynamic size(batch size, image size or dims) failed!"); | |||
@@ -985,6 +990,7 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output | |||
} else { | |||
std::map<string, string> atc_params; | |||
atc_params.insert(std::pair<string, string>("input_shape", FLAGS_input_shape)); | |||
atc_params.insert(std::pair<string, string>(ge::INPUT_SHAPE_RANGE, FLAGS_input_shape_range)); | |||
atc_params.insert(std::pair<string, string>("out_nodes", FLAGS_out_nodes)); | |||
atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format)); | |||
atc_params.insert(std::pair<string, string>("check_report", FLAGS_check_report)); | |||
@@ -576,6 +576,7 @@ Status InitDomiOmgContext(const string &input_shape, const string &input_format, | |||
GELOGE(PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str()); | |||
return PARAM_INVALID; | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -788,6 +789,12 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri | |||
GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "ATC weights parse ret fail."); | |||
// parser input shape range and update op shape range | |||
std::string input_shape_range; | |||
ParseAtcParms(atc_params, INPUT_SHAPE_RANGE, input_shape_range); | |||
GE_RETURN_WITH_LOG_IF_ERROR(UpdateDynamicInputShapeRange(compute_graph, input_shape_range), | |||
"Update input shape range failed"); | |||
GELOGI("ATC parser success."); | |||
return SUCCESS; | |||
@@ -311,6 +311,9 @@ const std::string OP_BANK_UPDATE_FLAG = "ge.op_bank_update"; | |||
// 0: data multi; 1: model multi; | |||
const std::string HCOM_MULTI_MODE = "ge.hcomMultiMode"; | |||
// atc and ir option | |||
const char *const INPUT_SHAPE_RANGE = "input_shape_range"; | |||
// Graph run mode | |||
enum GraphRunMode { PREDICTION = 0, TRAIN }; | |||
@@ -390,6 +393,7 @@ static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str(); | |||
#ifdef __GNUC__ | |||
const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, | |||
INPUT_SHAPE, | |||
INPUT_SHAPE_RANGE, | |||
OP_NAME_MAP, | |||
DYNAMIC_BATCH_SIZE, | |||
DYNAMIC_IMAGE_SIZE, | |||
@@ -45,6 +45,7 @@ include_directories(${GE_CODE_DIR}/inc) | |||
include_directories(${GE_CODE_DIR}/metadef/inc) | |||
include_directories(${GE_CODE_DIR}/ge) | |||
include_directories(${GE_CODE_DIR}/ge/inc) | |||
include_directories(${GE_CODE_DIR}/ge/ir_build) | |||
include_directories(${GE_CODE_DIR}/metadef) | |||
include_directories(${GE_CODE_DIR}/metadef/graph) | |||
include_directories(${GE_CODE_DIR}/inc/external) | |||
@@ -61,6 +62,7 @@ include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain) | |||
include_directories(${GE_CODE_DIR}/tests/ut/ge) | |||
include_directories(${GE_CODE_DIR}/tests/ut/common) | |||
include_directories(${CMAKE_BINARY_DIR}) | |||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) | |||
@@ -732,6 +734,7 @@ set(KERNEL_TEST_FILES | |||
set(MULTI_PARTS_TEST_FILES | |||
"graph_ir/ge_operator_factory_unittest.cc" | |||
"graph_ir/ge_ir_build_unittest.cc" | |||
"graph/transop_util_unittest.cc" | |||
"common/datatype_transfer_unittest.cc" | |||
"common/dump_manager_unittest.cc" | |||
@@ -0,0 +1,100 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include "ir_build/atc_ir_common.h" | |||
#include "graph/testcase/ge_graph/graph_builder_utils.h" | |||
#define protected public | |||
#define private public | |||
#undef private | |||
#undef protected | |||
const string DATA = "Data"; | |||
const string AddNYes = "AddNYes"; | |||
const string NETOUTPUT = "NetOutput"; | |||
using namespace ge; | |||
class UtestIrCommon : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
static ge::OpDescPtr CreateOpDesc(const std::string &name, const std::string &type) { | |||
OpDescPtr op_desc = std::make_shared<ge::OpDesc>(name, type); | |||
ge::GeTensorDesc ge_tensor_desc; | |||
op_desc->AddInputDesc("input", ge_tensor_desc); | |||
op_desc->AddOutputDesc("output", ge_tensor_desc); | |||
return op_desc; | |||
} | |||
static ComputeGraphPtr BuildComputeGraph() { | |||
auto builder = ut::GraphBuilder("test"); | |||
auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3}); | |||
auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10}); | |||
auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1); | |||
auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||
builder.AddDataEdge(data1, 0, addn1, 0); | |||
builder.AddDataEdge(data2, 0, addn1, 1); | |||
builder.AddDataEdge(addn1, 0,netoutput, 0); | |||
return builder.GetGraph(); | |||
} | |||
TEST(UtestIrCommon, update_data_op_shape) { | |||
ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data"); | |||
map<string, vector<int64_t>> shape_map; | |||
shape_map["Data"] = {{1,2}}; | |||
Status ret = UpdateDataOpShape(op_desc, shape_map); | |||
EXPECT_EQ(ret, ge::SUCCESS); | |||
} | |||
TEST(UtestIrCommon, update_dynamic_shape_range_success) { | |||
ComputeGraphPtr graph = BuildComputeGraph(); | |||
std::string input_shape_range = "input1:[1, 2~3, -1];input2:[3~5, 10]"; | |||
Status ret = UpdateDynamicInputShapeRange(graph, input_shape_range); | |||
EXPECT_EQ(ret, ge::SUCCESS); | |||
} | |||
TEST(UtestIrCommon, update_dynamic_shape_range_failed) { | |||
ComputeGraphPtr graph = BuildComputeGraph(); | |||
// 1 | |||
std::string input_shape_range = "input1;[1, 2~3, -1]"; | |||
Status ret = UpdateDynamicInputShapeRange(graph, input_shape_range); | |||
EXPECT_EQ(ret, ge::PARAM_INVALID); | |||
// 2 | |||
input_shape_range = "input1:[1, 2~3, -1)"; | |||
ret = UpdateDynamicInputShapeRange(graph, input_shape_range); | |||
EXPECT_EQ(ret, ge::PARAM_INVALID); | |||
//3 | |||
input_shape_range = "input1:[1, 3~2, -1];input2:[3~5, 10]"; | |||
ret = UpdateDynamicInputShapeRange(graph, input_shape_range); | |||
EXPECT_EQ(ret, ge::FAILED); | |||
//4 | |||
input_shape_range = "input1:[1, 2~-3, -1]"; | |||
ret = UpdateDynamicInputShapeRange(graph, input_shape_range); | |||
EXPECT_EQ(ret, ge::PARAM_INVALID); | |||
} |