|
|
@@ -47,47 +47,13 @@ using google::protobuf::io::ZeroCopyInputStream; |
|
|
using namespace ge::parser; |
|
|
using namespace ge::parser; |
|
|
|
|
|
|
|
|
namespace { |
|
|
namespace { |
|
|
static std::map<std::string, domiTensorFormat_t> kInputFormatStrToGeformat = { |
|
|
|
|
|
{"ND", domi::DOMI_TENSOR_ND}, {"NCHW", domi::DOMI_TENSOR_NCHW}, {"NHWC", domi::DOMI_TENSOR_NHWC}, |
|
|
|
|
|
{"CHWN", domi::DOMI_TENSOR_CHWN}, {"NC1HWC0", domi::DOMI_TENSOR_NC1HWC0}, {"NHWC1C0", domi::DOMI_TENSOR_NHWC1C0}, |
|
|
|
|
|
{"NCDHW", domi::DOMI_TENSOR_NCDHW}, {"NDHWC", domi::DOMI_TENSOR_NDHWC}}; |
|
|
|
|
|
|
|
|
|
|
|
// datatype/formats from user to GE, Unified to util interface file later |
|
|
|
|
|
const std::map<std::string, ge::DataType> kOutputTypeSupportDatatype = { |
|
|
|
|
|
{"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; |
|
|
|
|
|
const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; |
|
|
|
|
|
const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; |
|
|
|
|
|
const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; |
|
|
|
|
|
const char *const kSplitError1 = "size not equal to 2 split by \":\""; |
|
|
|
|
|
const char *const kEmptyError = "can not be empty"; |
|
|
|
|
|
const char *const kFloatNumError = "exist float number"; |
|
|
|
|
|
const char *const kDigitError = "is not digit"; |
|
|
|
|
|
const std::string kGraphDefaultName = "domi_default"; |
|
|
const std::string kGraphDefaultName = "domi_default"; |
|
|
const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; |
|
|
|
|
|
const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; |
|
|
|
|
|
static std::set<std::string> kCaffeSupportInputFormatSet = {"NCHW", "ND"}; |
|
|
|
|
|
static std::set<std::string> kTfSupportInputFormatSet = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"}; |
|
|
|
|
|
const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; |
|
|
|
|
|
const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; |
|
|
|
|
|
/// The maximum length of the file. |
|
|
/// The maximum length of the file. |
|
|
/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 |
|
|
/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 |
|
|
const int kMaxFileSizeLimit = INT_MAX; |
|
|
const int kMaxFileSizeLimit = INT_MAX; |
|
|
const int kMaxBuffSize = 256; |
|
|
const int kMaxBuffSize = 256; |
|
|
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. |
|
|
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. |
|
|
const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M |
|
|
const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M |
|
|
const int kOutputTypeNode = 0; |
|
|
|
|
|
const int kOutputTypeIndex = 1; |
|
|
|
|
|
const int kOutputTypeDataType = 2; |
|
|
|
|
|
|
|
|
|
|
|
vector<string> SplitInputShape(const std::string &input_shape) { |
|
|
|
|
|
vector<string> shape_pair_vec; |
|
|
|
|
|
size_t pos = input_shape.rfind(":"); |
|
|
|
|
|
if (pos != std::string::npos) { |
|
|
|
|
|
shape_pair_vec.emplace_back(input_shape.substr(0, pos)); |
|
|
|
|
|
shape_pair_vec.emplace_back(input_shape.substr(pos + 1, input_shape.size() - pos)); |
|
|
|
|
|
} |
|
|
|
|
|
return shape_pair_vec; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
static string GetSoPath() { |
|
|
static string GetSoPath() { |
|
|
Dl_info dl_info; |
|
|
Dl_info dl_info; |
|
|
@@ -166,51 +132,6 @@ static bool CheckDigitStr(std::string &str) { |
|
|
} |
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// Remove the space and tab before and after the string |
|
|
|
|
|
std::string TrimConf(const std::string &str) { |
|
|
|
|
|
if (str.empty()) { |
|
|
|
|
|
return str; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::string::size_type start = str.find_first_not_of(" \t\r\n"); |
|
|
|
|
|
if (start == std::string::npos) { |
|
|
|
|
|
return str; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1; |
|
|
|
|
|
return str.substr(start, end); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Parsing the command line |
|
|
|
|
|
bool ParseSingleLine(const std::string &line, std::map<std::string, std::string> &op_conf_map) { |
|
|
|
|
|
std::string temp = TrimConf(line); |
|
|
|
|
|
std::string delimiter = ":"; |
|
|
|
|
|
// Comment or newline returns true directly |
|
|
|
|
|
if (temp.find_first_of('#') == 0 || *(temp.c_str()) == '\n') { |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (!temp.empty()) { |
|
|
|
|
|
std::string::size_type pos = temp.find_first_of(delimiter); |
|
|
|
|
|
if (pos == std::string::npos) { |
|
|
|
|
|
GELOGE(ge::PARAM_INVALID, "Incorrect line [%s], it must include [%s].Perhaps you use illegal chinese symbol", |
|
|
|
|
|
line.c_str(), delimiter.c_str()); |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::string map_key = TrimConf(temp.substr(0, pos)); |
|
|
|
|
|
std::string value = TrimConf(temp.substr(pos + 1)); |
|
|
|
|
|
if (map_key.empty() || value.empty()) { |
|
|
|
|
|
GELOGE(ge::PARAM_INVALID, "Map_key or value empty. %s", line.c_str()); |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
op_conf_map[map_key] = value; |
|
|
|
|
|
} |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
} // namespace |
|
|
} // namespace |
|
|
|
|
|
|
|
|
namespace ge { |
|
|
namespace ge { |
|
|
@@ -224,38 +145,6 @@ static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_p |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
static domi::Status CheckOutPutDataTypeSupport(const std::string &output_type) { |
|
|
|
|
|
auto it = kOutputTypeSupportDatatype.find(output_type); |
|
|
|
|
|
if (it == kOutputTypeSupportDatatype.end()) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, |
|
|
|
|
|
{"output_type", output_type, kOutputTypeSupport}); |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Invalid value for output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport); |
|
|
|
|
|
return domi::PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
return domi::SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
static domi::Status StringToInt(std::string &str, int32_t &value) { |
|
|
|
|
|
try { |
|
|
|
|
|
if (!CheckDigitStr(str)) { |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str()); |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, |
|
|
|
|
|
{"output_type", str, "is not positive integer"}); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
value = stoi(str); |
|
|
|
|
|
} catch (std::invalid_argument &) { |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", str.c_str()); |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"output_type", str}); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} catch (std::out_of_range &) { |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str()); |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"output_type", str}); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
static Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) { |
|
|
static Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) { |
|
|
int32_t out_size = op_desc->GetOutputsSize(); |
|
|
int32_t out_size = op_desc->GetOutputsSize(); |
|
|
if (index < 0 || index >= out_size) { |
|
|
if (index < 0 || index >= out_size) { |
|
|
@@ -272,25 +161,6 @@ static Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) { |
|
|
return domi::SUCCESS; |
|
|
return domi::SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
domi::Status VerifyOutputTypeAndOutNodes(std::vector<std::string> &out_type_vec) { |
|
|
|
|
|
std::vector<std::pair<std::string, int32_t>> user_out_nodes = ge::GetParserContext().user_out_nodes; |
|
|
|
|
|
std::set<std::string> out_nodes_info; |
|
|
|
|
|
for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { |
|
|
|
|
|
// out_nodes set should include output_type and output_format |
|
|
|
|
|
std::string tmp = user_out_nodes[i].first + ":" + to_string(user_out_nodes[i].second); |
|
|
|
|
|
out_nodes_info.emplace(tmp); |
|
|
|
|
|
} |
|
|
|
|
|
for (uint32_t i = 0; i < out_type_vec.size(); ++i) { |
|
|
|
|
|
if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, |
|
|
|
|
|
{"output_type", out_type_vec[i], kOutputTypeError}); |
|
|
|
|
|
GELOGE(domi::FAILED, "Invalid value for output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError); |
|
|
|
|
|
return domi::FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return domi::SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::LoadOpsProtoLib() { |
|
|
domi::Status AclGrphParseUtil::LoadOpsProtoLib() { |
|
|
string opsproto_path; |
|
|
string opsproto_path; |
|
|
GetOpsProtoPath(opsproto_path); |
|
|
GetOpsProtoPath(opsproto_path); |
|
|
@@ -367,148 +237,12 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, s |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AclGrphParseUtil::CheckAclInputFormat(string &input_format) { |
|
|
|
|
|
if (input_format.empty()) { |
|
|
|
|
|
// Set default format |
|
|
|
|
|
if (ge::GetParserContext().type == domi::TENSORFLOW) { |
|
|
|
|
|
input_format = "NHWC"; |
|
|
|
|
|
} else { |
|
|
|
|
|
input_format = "NCHW"; |
|
|
|
|
|
} |
|
|
|
|
|
return true; |
|
|
|
|
|
} else if (ge::GetParserContext().type == domi::CAFFE) { // caffe |
|
|
|
|
|
if (kCaffeSupportInputFormatSet.find(input_format) != kCaffeSupportInputFormatSet.end()) { |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
// only support NCHW ND |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, |
|
|
|
|
|
{"input_format", input_format, kCaffeFormatSupport}); |
|
|
|
|
|
GELOGE(ge::FAILED, "Invalid value for input_format[%s], %s.", input_format.c_str(), kCaffeFormatSupport); |
|
|
|
|
|
return false; |
|
|
|
|
|
} else if (ge::GetParserContext().type == domi::TENSORFLOW) { // tf |
|
|
|
|
|
if (kTfSupportInputFormatSet.find(input_format) != kTfSupportInputFormatSet.end()) { |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
// only support NCHW NHWC ND NCDHW NDHWC |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, |
|
|
|
|
|
{"input_format", input_format, kTFFormatSupport}); |
|
|
|
|
|
GELOGE(ge::FAILED, "Invalid value for input_format[%s], %s.", input_format.c_str(), kTFFormatSupport); |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::ParseAclFormat(string &input_format) { |
|
|
|
|
|
ge::GetParserContext().format = domi::DOMI_TENSOR_ND; |
|
|
|
|
|
if (!CheckAclInputFormat(input_format)) { |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Check input_format failed"); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
if (!input_format.empty()) { |
|
|
|
|
|
auto iter = kInputFormatStrToGeformat.find(input_format); |
|
|
|
|
|
if (iter != kInputFormatStrToGeformat.end()) { |
|
|
|
|
|
ge::GetParserContext().format = iter->second; |
|
|
|
|
|
} else { |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Input format %s not support , expect ND/NCHW/NHWC/CHWN/NC1HWC0/NHWC1C0.", |
|
|
|
|
|
input_format.c_str()); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
bool AclGrphParseUtil::ParseInputShape(const string &input_shape, |
|
|
|
|
|
std::unordered_map<string, vector<int64_t>> &shape_map, |
|
|
|
|
|
vector<pair<string, vector<int64_t>>> &user_shape_map) { |
|
|
|
|
|
vector<string> shape_vec = StringUtils::Split(input_shape, ';'); |
|
|
|
|
|
const int DEFAULT_SHAPE_PAIR_SIZE = 2; |
|
|
|
|
|
for (const auto &shape : shape_vec) { |
|
|
|
|
|
vector<string> shape_pair_vec = SplitInputShape(shape); |
|
|
|
|
|
if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, |
|
|
|
|
|
{shape, kSplitError1, kInputShapeSample1}); |
|
|
|
|
|
GELOGW("Parse input parameter [input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", shape.c_str(), |
|
|
|
|
|
kSplitError1, kInputShapeSample1); |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
if (shape_pair_vec[1].empty()) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, |
|
|
|
|
|
{shape, kEmptyError, kInputShapeSample1}); |
|
|
|
|
|
GELOGW("Parse input parameter [input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", shape.c_str(), |
|
|
|
|
|
kEmptyError, kInputShapeSample1); |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
vector<string> shape_value_strs = StringUtils::Split(shape_pair_vec[1], ','); |
|
|
|
|
|
vector<int64_t> shape_values; |
|
|
|
|
|
for (auto &shape_value_str : shape_value_strs) { |
|
|
|
|
|
// stoul: The method may throw an exception: invalid_argument/out_of_range |
|
|
|
|
|
if (std::string::npos != shape_value_str.find('.')) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, |
|
|
|
|
|
{shape, kFloatNumError, kInputShapeSample2}); |
|
|
|
|
|
GELOGW("Parse input parameter [input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", |
|
|
|
|
|
shape.c_str(), kFloatNumError, kInputShapeSample2); |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
long left_result = 0; |
|
|
|
|
|
try { |
|
|
|
|
|
left_result = stol(StringUtils::Trim(shape_value_str)); |
|
|
|
|
|
if (!shape_value_str.empty() && (shape_value_str.front() == '-')) { |
|
|
|
|
|
// The value maybe dynamic shape [-1], need substr it and verify isdigit. |
|
|
|
|
|
shape_value_str = shape_value_str.substr(1); |
|
|
|
|
|
} |
|
|
|
|
|
for (char c : shape_value_str) { |
|
|
|
|
|
if (!isdigit(c)) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, |
|
|
|
|
|
{shape, kDigitError, kInputShapeSample2}); |
|
|
|
|
|
GELOGE(PARAM_INVALID, "input_shape's shape value[%s] is not digit", shape_value_str.c_str()); |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} catch (const std::out_of_range &) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, |
|
|
|
|
|
{"input_shape", shape_value_str}); |
|
|
|
|
|
GELOGW("Input parameter[input_shape]’s value[%s] cause out of range execption!", shape_value_str.c_str()); |
|
|
|
|
|
return false; |
|
|
|
|
|
} catch (const std::invalid_argument &) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, |
|
|
|
|
|
{"input_shape", shape_value_str}); |
|
|
|
|
|
GELOGW("Input parameter[input_shape]’s value[%s] cause invalid argument!", shape_value_str.c_str()); |
|
|
|
|
|
return false; |
|
|
|
|
|
} catch (...) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"}, |
|
|
|
|
|
{"input_shape", shape_value_str}); |
|
|
|
|
|
GELOGW("Input parameter[input_shape]’s value[%s] cause unkown execption!", shape_value_str.c_str()); |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
int64_t result = left_result; |
|
|
|
|
|
shape_values.push_back(result); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); |
|
|
|
|
|
user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Parse user input shape info |
|
|
|
|
|
domi::Status AclGrphParseUtil::ParseAclShape(const string &input_shape) { |
|
|
|
|
|
ge::GetParserContext().input_dims.clear(); |
|
|
|
|
|
ge::GetParserContext().user_input_dims.clear(); |
|
|
|
|
|
|
|
|
|
|
|
if (input_shape.empty()) { |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::unordered_map<string, vector<int64_t>> &shape_map = ge::GetParserContext().input_dims; |
|
|
|
|
|
if (!ParseInputShape(input_shape, ge::GetParserContext().input_dims, ge::GetParserContext().user_input_dims) || |
|
|
|
|
|
shape_map.empty()) { |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str()); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
|
|
|
void AclGrphParseUtil::SetDefaultFormat() { |
|
|
|
|
|
if (ge::GetParserContext().type == domi::TENSORFLOW) { |
|
|
|
|
|
ge::GetParserContext().format = "NHWC"; |
|
|
|
|
|
} else { |
|
|
|
|
|
ge::GetParserContext().format = "NCHW"; |
|
|
} |
|
|
} |
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) { |
|
|
domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) { |
|
|
@@ -600,41 +334,6 @@ domi::Status AclGrphParseUtil::ParseAclOutputFp16NodesFormat(const string &is_ou |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::ParseAclOpConf(const std::string &op_conf) { |
|
|
|
|
|
if (op_conf.empty()) { |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
// Normalize the path |
|
|
|
|
|
string resolved_file_path = ge::parser::RealPath(op_conf.c_str()); |
|
|
|
|
|
if (resolved_file_path.empty()) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"op_map_conf", op_conf}); |
|
|
|
|
|
GELOGE(domi::FAILED, "Invalid input file path [%s], make sure that the file path is correct.", op_conf.c_str()); |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
std::ifstream fs(resolved_file_path, std::ifstream::in); |
|
|
|
|
|
|
|
|
|
|
|
if (!fs.is_open()) { |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Open %s failed.", op_conf.c_str()); |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::string line; |
|
|
|
|
|
|
|
|
|
|
|
while (getline(fs, line)) { // line not with \n |
|
|
|
|
|
if (!ParseSingleLine(line, ge::GetParserContext().op_conf_map)) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, |
|
|
|
|
|
{"op_map_conf_line_info", line}); |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Parse line failed. content is [%s].", line.c_str()); |
|
|
|
|
|
fs.close(); |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
fs.close(); // close the file |
|
|
|
|
|
|
|
|
|
|
|
GELOGI("LoadFileContent success."); |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fusion_passes) { |
|
|
domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fusion_passes) { |
|
|
ge::GetParserContext().enable_scope_fusion_passes.clear(); |
|
|
ge::GetParserContext().enable_scope_fusion_passes.clear(); |
|
|
if (enable_scope_fusion_passes.empty()) { |
|
|
if (enable_scope_fusion_passes.empty()) { |
|
|
@@ -746,54 +445,6 @@ domi::Status AclGrphParseUtil::ParseAclWeightCompressConf(const ComputeGraphPtr |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::ParseAclOutputType(const std::string &output_type, |
|
|
|
|
|
std::map<std::string, vector<std::string>> &output_node_dt_map) { |
|
|
|
|
|
if (output_type.find(':') == std::string::npos) { |
|
|
|
|
|
GELOGI("output_type is not multiple nodes, means all out nodes"); |
|
|
|
|
|
return CheckOutPutDataTypeSupport(output_type); |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<std::string> out_type_vec; |
|
|
|
|
|
vector<string> nodes_v = StringUtils::Split(output_type, ';'); |
|
|
|
|
|
for (const string &node : nodes_v) { |
|
|
|
|
|
vector<string> node_index_type_v = StringUtils::Split(node, ':'); |
|
|
|
|
|
if (node_index_type_v.size() != 3) { // The size must be 3. |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, |
|
|
|
|
|
{"output_type", node, kOutputTypeSample}); |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Invalid value for output_type[%s], %s.", node.c_str(), kOutputTypeSample); |
|
|
|
|
|
return domi::FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
ge::DataType tmp_dt; |
|
|
|
|
|
std::string node_name = StringUtils::Trim(node_index_type_v[kOutputTypeNode]); |
|
|
|
|
|
std::string index_str = StringUtils::Trim(node_index_type_v[kOutputTypeIndex]); |
|
|
|
|
|
int32_t index; |
|
|
|
|
|
if (StringToInt(index_str, index) != SUCCESS) { |
|
|
|
|
|
GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str()); |
|
|
|
|
|
return domi::FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
std::string dt_value = StringUtils::Trim(node_index_type_v[kOutputTypeDataType]); |
|
|
|
|
|
auto it = kOutputTypeSupportDatatype.find(dt_value); |
|
|
|
|
|
if (it == kOutputTypeSupportDatatype.end()) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, |
|
|
|
|
|
{"output_type", dt_value, kOutputTypeSupport}); |
|
|
|
|
|
GELOGE(ge::PARAM_INVALID, "Invalid value for output_type[%s], %s.", dt_value.c_str(), kOutputTypeSupport); |
|
|
|
|
|
return domi::FAILED; |
|
|
|
|
|
} else { |
|
|
|
|
|
tmp_dt = it->second; |
|
|
|
|
|
} |
|
|
|
|
|
out_type_vec.push_back(node_name + ":" + index_str); |
|
|
|
|
|
std::string index_dt_str = index_str + ":" + TypeUtils::DataTypeToSerialString(tmp_dt); |
|
|
|
|
|
auto it1 = output_node_dt_map.find(node_name); |
|
|
|
|
|
if (it1 == output_node_dt_map.end()) { |
|
|
|
|
|
vector<string> tmp_vec; |
|
|
|
|
|
tmp_vec.push_back(index_dt_str); |
|
|
|
|
|
output_node_dt_map.emplace(node_name, tmp_vec); |
|
|
|
|
|
} else { |
|
|
|
|
|
it1->second.push_back(index_dt_str); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return VerifyOutputTypeAndOutNodes(out_type_vec); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, |
|
|
void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, |
|
|
std::vector<std::string> &output_nodes_name) { |
|
|
std::vector<std::string> &output_nodes_name) { |
|
|
output_nodes_name.clear(); |
|
|
output_nodes_name.clear(); |
|
|
@@ -880,20 +531,10 @@ domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph, |
|
|
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); |
|
|
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); |
|
|
GE_CHECK_NOTNULL(compute_graph); |
|
|
GE_CHECK_NOTNULL(compute_graph); |
|
|
|
|
|
|
|
|
string output_type; |
|
|
|
|
|
GetAclParams(parser_params, ge::ir_option::OUTPUT_TYPE, output_type); |
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<std::string, int32_t>> user_out_nodes = ge::GetParserContext().user_out_nodes; |
|
|
std::vector<std::pair<std::string, int32_t>> user_out_nodes = ge::GetParserContext().user_out_nodes; |
|
|
std::vector<domiTensorFormat_t> output_formats = ge::GetParserContext().output_formats; |
|
|
std::vector<domiTensorFormat_t> output_formats = ge::GetParserContext().output_formats; |
|
|
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_info; |
|
|
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_info; |
|
|
std::vector<std::string> output_nodes_name; |
|
|
std::vector<std::string> output_nodes_name; |
|
|
std::map<std::string, vector<std::string>> output_node_dt_map; |
|
|
|
|
|
if (!output_type.empty()) { |
|
|
|
|
|
if (ParseAclOutputType(output_type, output_node_dt_map) != SUCCESS) { |
|
|
|
|
|
GELOGE(domi::FAILED, "Parse output_type failed."); |
|
|
|
|
|
return domi::FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// User declared outputs |
|
|
// User declared outputs |
|
|
for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { |
|
|
for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { |
|
|
@@ -923,11 +564,6 @@ domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph, |
|
|
(void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec); |
|
|
(void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
auto it = output_node_dt_map.find(user_out_nodes[i].first); |
|
|
|
|
|
if (it != output_node_dt_map.end()) { |
|
|
|
|
|
GELOGI("The output node [%s] need to be set output_type", user_out_nodes[i].first.c_str()); |
|
|
|
|
|
(void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_data_type", it->second); |
|
|
|
|
|
} |
|
|
|
|
|
output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second)); |
|
|
output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second)); |
|
|
} |
|
|
} |
|
|
// default output node (leaf) |
|
|
// default output node (leaf) |
|
|
@@ -944,34 +580,6 @@ domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph, |
|
|
return domi::SUCCESS; |
|
|
return domi::SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::ParseAclLogLevel(const std::string &log) { |
|
|
|
|
|
if (log.empty()) { |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
int ret = -1; |
|
|
|
|
|
if (log == "default") { |
|
|
|
|
|
ret = 0; |
|
|
|
|
|
} else if (log == "null") { |
|
|
|
|
|
ret = dlog_setlevel(-1, DLOG_NULL, 0); |
|
|
|
|
|
} else if (log == "debug") { |
|
|
|
|
|
ret = dlog_setlevel(-1, DLOG_DEBUG, 1); |
|
|
|
|
|
} else if (log == "info") { |
|
|
|
|
|
ret = dlog_setlevel(-1, DLOG_INFO, 1); |
|
|
|
|
|
} else if (log == "warning") { |
|
|
|
|
|
ret = dlog_setlevel(-1, DLOG_WARN, 1); |
|
|
|
|
|
} else if (log == "error") { |
|
|
|
|
|
ret = dlog_setlevel(-1, DLOG_ERROR, 1); |
|
|
|
|
|
} else { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"log", log}); |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Invalid value for log:%s, only support debug, info, warning, error, null", log.c_str()); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
if (ret != 0) { |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Log setlevel fail !"); |
|
|
|
|
|
} |
|
|
|
|
|
return domi::SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendString> &parser_params) { |
|
|
domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendString> &parser_params) { |
|
|
for (auto &ele : parser_params) { |
|
|
for (auto &ele : parser_params) { |
|
|
const char *key_ascend = ele.first.GetString(); |
|
|
const char *key_ascend = ele.first.GetString(); |
|
|
@@ -993,73 +601,13 @@ domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendS |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::CheckAclInputShapeNode(const ComputeGraphPtr &graph) { |
|
|
|
|
|
for (auto it : ge::GetParserContext().user_input_dims) { |
|
|
|
|
|
std::string node_name = it.first; |
|
|
|
|
|
ge::NodePtr node = graph->FindNode(node_name); |
|
|
|
|
|
if (node == nullptr) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"input_shape", node_name}); |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Input parameter[input_shape]'s opname[%s] is not exist in model", node_name.c_str()); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
if (node->GetType() != ge::parser::DATA) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, {"input_shape", node_name}); |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Input parameter[input_shape]'s opname[%s] is not a input opname", node_name.c_str()); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::CheckAclOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf) { |
|
|
|
|
|
GE_CHECK_NOTNULL(graph); |
|
|
|
|
|
unordered_map<string, string> graphNodeTypes; |
|
|
|
|
|
for (const NodePtr &node : graph->GetAllNodes()) { |
|
|
|
|
|
auto op_desc = node->GetOpDesc(); |
|
|
|
|
|
if (op_desc == nullptr) { |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Invalid parameter for opDesc."); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
graphNodeTypes[op_desc->GetType()] = ""; |
|
|
|
|
|
} |
|
|
|
|
|
std::map<std::string, std::string> &propertiesMap = ge::GetParserContext().op_conf_map; |
|
|
|
|
|
if (propertiesMap.empty()) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, |
|
|
|
|
|
{"op_name_map", op_conf, "the file content is empty"}); |
|
|
|
|
|
GELOGE(PARAM_INVALID, "op_name_map file content is empty, please check file!"); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
for (auto iter = propertiesMap.begin(); iter != propertiesMap.end(); iter++) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(graphNodeTypes.find(iter->second) == graphNodeTypes.end(), |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
|
|
|
"E10003", {"parameter", "value", "reason"}, |
|
|
|
|
|
{"op_name_map", op_conf, "type[" + iter->second + "] is not found in model"}); |
|
|
|
|
|
GELOGE(PARAM_INVALID, "Invalid parameter for op_name_map."); return PARAM_INVALID;); |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, |
|
|
domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, |
|
|
string &graph_name) { |
|
|
string &graph_name) { |
|
|
GELOGI("Parse graph user options start."); |
|
|
GELOGI("Parse graph user options start."); |
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckOptions(parser_params) != SUCCESS, return PARAM_INVALID, |
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckOptions(parser_params) != SUCCESS, return PARAM_INVALID, |
|
|
"Parse paragrams invalid."); |
|
|
"Parse paragrams invalid."); |
|
|
// support paragrams: log, input_format, input_shape, out_nodes |
|
|
|
|
|
// is_output_adjust_hw_layout, output, op_name_map, enable_scope_fusion_passes |
|
|
|
|
|
string log_level; |
|
|
|
|
|
GetAclParams(parser_params, ge::ir_option::LOG_LEVEL, log_level); |
|
|
|
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclLogLevel(log_level) != SUCCESS, return PARAM_INVALID, |
|
|
|
|
|
"Parse log_level failed"); |
|
|
|
|
|
|
|
|
|
|
|
string input_format; |
|
|
|
|
|
GetAclParams(parser_params, ge::ir_option::INPUT_FORMAT, input_format); |
|
|
|
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclFormat(input_format) != SUCCESS, return PARAM_INVALID, |
|
|
|
|
|
"Parse input_format failed"); |
|
|
|
|
|
|
|
|
|
|
|
string input_shape; |
|
|
|
|
|
GetAclParams(parser_params, ge::ir_option::INPUT_SHAPE, input_shape); |
|
|
|
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclShape(input_shape) != SUCCESS, return PARAM_INVALID, |
|
|
|
|
|
"Parse input_shape failed"); |
|
|
|
|
|
|
|
|
// support paragrams: out_nodes, is_output_adjust_hw_layout, output, enable_scope_fusion_passes |
|
|
|
|
|
SetDefaultFormat(); |
|
|
|
|
|
|
|
|
string out_nodes; |
|
|
string out_nodes; |
|
|
GetAclParams(parser_params, ge::ir_option::OUT_NODES, out_nodes); |
|
|
GetAclParams(parser_params, ge::ir_option::OUT_NODES, out_nodes); |
|
|
@@ -1071,11 +619,6 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin |
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS, |
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS, |
|
|
return PARAM_INVALID, "Parse is_output_adjust_hw_layout failed"); |
|
|
return PARAM_INVALID, "Parse is_output_adjust_hw_layout failed"); |
|
|
|
|
|
|
|
|
string op_conf_str; |
|
|
|
|
|
GetAclParams(parser_params, ge::ir_option::OP_NAME_MAP, op_conf_str); |
|
|
|
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOpConf(op_conf_str) != SUCCESS, return PARAM_INVALID, |
|
|
|
|
|
"Parse op_name_map failed"); |
|
|
|
|
|
|
|
|
|
|
|
string tmp_name; |
|
|
string tmp_name; |
|
|
GetAclParams(parser_params, ge::ir_option::OUTPUT, tmp_name); |
|
|
GetAclParams(parser_params, ge::ir_option::OUTPUT, tmp_name); |
|
|
graph_name = tmp_name.empty() ? (kGraphDefaultName + "_" + ge::parser::CurrentTimeInStr()) : tmp_name; |
|
|
graph_name = tmp_name.empty() ? (kGraphDefaultName + "_" + ge::parser::CurrentTimeInStr()) : tmp_name; |
|
|
@@ -1103,19 +646,10 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, |
|
|
ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS, |
|
|
ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS, |
|
|
return PARAM_INVALID, "Parse input_fp16_nodes failed"); |
|
|
return PARAM_INVALID, "Parse input_fp16_nodes failed"); |
|
|
|
|
|
|
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckAclInputShapeNode(compute_graph) != SUCCESS, |
|
|
|
|
|
return PARAM_INVALID, "Check nodes input_shape info failed"); |
|
|
|
|
|
|
|
|
|
|
|
string compress_weight_conf; |
|
|
string compress_weight_conf; |
|
|
GetAclParams(parser_params, ge::ir_option::COMPRESS_WEIGHT_CONF, compress_weight_conf); |
|
|
GetAclParams(parser_params, ge::ir_option::COMPRESS_WEIGHT_CONF, compress_weight_conf); |
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclWeightCompressConf(compute_graph, compress_weight_conf) != SUCCESS, |
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclWeightCompressConf(compute_graph, compress_weight_conf) != SUCCESS, |
|
|
return PARAM_INVALID, "Parse compress_weight_conf failed"); |
|
|
return PARAM_INVALID, "Parse compress_weight_conf failed"); |
|
|
string op_conf_str; |
|
|
|
|
|
GetAclParams(parser_params, ge::ir_option::OP_NAME_MAP, op_conf_str); |
|
|
|
|
|
if (!op_conf_str.empty()) { |
|
|
|
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckAclOpNameMap(compute_graph, op_conf_str) != SUCCESS, return PARAM_INVALID, |
|
|
|
|
|
"Check op_name_map info failed"); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|