Browse Source

Feature: delete several para of aclgrphParse interface

pull/150/head
l00444296 5 years ago
parent
commit
c9c671ca21
2 changed files with 8 additions and 485 deletions
  1. +7
    -473
      parser/common/acl_graph_parser_util.cc
  2. +1
    -12
      parser/common/acl_graph_parser_util.h

+ 7
- 473
parser/common/acl_graph_parser_util.cc View File

@@ -47,47 +47,13 @@ using google::protobuf::io::ZeroCopyInputStream;
using namespace ge::parser; using namespace ge::parser;


namespace { namespace {
static std::map<std::string, domiTensorFormat_t> kInputFormatStrToGeformat = {
{"ND", domi::DOMI_TENSOR_ND}, {"NCHW", domi::DOMI_TENSOR_NCHW}, {"NHWC", domi::DOMI_TENSOR_NHWC},
{"CHWN", domi::DOMI_TENSOR_CHWN}, {"NC1HWC0", domi::DOMI_TENSOR_NC1HWC0}, {"NHWC1C0", domi::DOMI_TENSOR_NHWC1C0},
{"NCDHW", domi::DOMI_TENSOR_NCDHW}, {"NDHWC", domi::DOMI_TENSOR_NDHWC}};

// datatype/formats from user to GE, Unified to util interface file later
const std::map<std::string, ge::DataType> kOutputTypeSupportDatatype = {
{"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}};
const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8";
const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\"";
const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\"";
const char *const kSplitError1 = "size not equal to 2 split by \":\"";
const char *const kEmptyError = "can not be empty";
const char *const kFloatNumError = "exist float number";
const char *const kDigitError = "is not digit";
const std::string kGraphDefaultName = "domi_default"; const std::string kGraphDefaultName = "domi_default";
const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\"";
const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes.";
static std::set<std::string> kCaffeSupportInputFormatSet = {"NCHW", "ND"};
static std::set<std::string> kTfSupportInputFormatSet = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"};
const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model";
const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model";
/// The maximum length of the file. /// The maximum length of the file.
/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 /// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1
const int kMaxFileSizeLimit = INT_MAX; const int kMaxFileSizeLimit = INT_MAX;
const int kMaxBuffSize = 256; const int kMaxBuffSize = 256;
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M
const int kOutputTypeNode = 0;
const int kOutputTypeIndex = 1;
const int kOutputTypeDataType = 2;

vector<string> SplitInputShape(const std::string &input_shape) {
vector<string> shape_pair_vec;
size_t pos = input_shape.rfind(":");
if (pos != std::string::npos) {
shape_pair_vec.emplace_back(input_shape.substr(0, pos));
shape_pair_vec.emplace_back(input_shape.substr(pos + 1, input_shape.size() - pos));
}
return shape_pair_vec;
}


static string GetSoPath() { static string GetSoPath() {
Dl_info dl_info; Dl_info dl_info;
@@ -166,51 +132,6 @@ static bool CheckDigitStr(std::string &str) {
} }
return true; return true;
} }

// Remove the space and tab before and after the string
std::string TrimConf(const std::string &str) {
if (str.empty()) {
return str;
}

std::string::size_type start = str.find_first_not_of(" \t\r\n");
if (start == std::string::npos) {
return str;
}

std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1;
return str.substr(start, end);
}

// Parsing the command line
bool ParseSingleLine(const std::string &line, std::map<std::string, std::string> &op_conf_map) {
std::string temp = TrimConf(line);
std::string delimiter = ":";
// Comment or newline returns true directly
if (temp.find_first_of('#') == 0 || *(temp.c_str()) == '\n') {
return true;
}

if (!temp.empty()) {
std::string::size_type pos = temp.find_first_of(delimiter);
if (pos == std::string::npos) {
GELOGE(ge::PARAM_INVALID, "Incorrect line [%s], it must include [%s].Perhaps you use illegal chinese symbol",
line.c_str(), delimiter.c_str());
return false;
}

std::string map_key = TrimConf(temp.substr(0, pos));
std::string value = TrimConf(temp.substr(pos + 1));
if (map_key.empty() || value.empty()) {
GELOGE(ge::PARAM_INVALID, "Map_key or value empty. %s", line.c_str());
return false;
}

op_conf_map[map_key] = value;
}
return true;
}

} // namespace } // namespace


namespace ge { namespace ge {
@@ -224,38 +145,6 @@ static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_p
} }
} }


static domi::Status CheckOutPutDataTypeSupport(const std::string &output_type) {
auto it = kOutputTypeSupportDatatype.find(output_type);
if (it == kOutputTypeSupportDatatype.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"output_type", output_type, kOutputTypeSupport});
GELOGE(PARAM_INVALID, "Invalid value for output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport);
return domi::PARAM_INVALID;
}
return domi::SUCCESS;
}

static domi::Status StringToInt(std::string &str, int32_t &value) {
try {
if (!CheckDigitStr(str)) {
GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"output_type", str, "is not positive integer"});
return PARAM_INVALID;
}
value = stoi(str);
} catch (std::invalid_argument &) {
GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", str.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"output_type", str});
return PARAM_INVALID;
} catch (std::out_of_range &) {
GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"output_type", str});
return PARAM_INVALID;
}
return SUCCESS;
}

static Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) { static Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) {
int32_t out_size = op_desc->GetOutputsSize(); int32_t out_size = op_desc->GetOutputsSize();
if (index < 0 || index >= out_size) { if (index < 0 || index >= out_size) {
@@ -272,25 +161,6 @@ static Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) {
return domi::SUCCESS; return domi::SUCCESS;
} }


domi::Status VerifyOutputTypeAndOutNodes(std::vector<std::string> &out_type_vec) {
std::vector<std::pair<std::string, int32_t>> user_out_nodes = ge::GetParserContext().user_out_nodes;
std::set<std::string> out_nodes_info;
for (uint32_t i = 0; i < user_out_nodes.size(); ++i) {
// out_nodes set should include output_type and output_format
std::string tmp = user_out_nodes[i].first + ":" + to_string(user_out_nodes[i].second);
out_nodes_info.emplace(tmp);
}
for (uint32_t i = 0; i < out_type_vec.size(); ++i) {
if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"output_type", out_type_vec[i], kOutputTypeError});
GELOGE(domi::FAILED, "Invalid value for output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError);
return domi::FAILED;
}
}
return domi::SUCCESS;
}

domi::Status AclGrphParseUtil::LoadOpsProtoLib() { domi::Status AclGrphParseUtil::LoadOpsProtoLib() {
string opsproto_path; string opsproto_path;
GetOpsProtoPath(opsproto_path); GetOpsProtoPath(opsproto_path);
@@ -367,148 +237,12 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, s
return SUCCESS; return SUCCESS;
} }


bool AclGrphParseUtil::CheckAclInputFormat(string &input_format) {
if (input_format.empty()) {
// Set default format
if (ge::GetParserContext().type == domi::TENSORFLOW) {
input_format = "NHWC";
} else {
input_format = "NCHW";
}
return true;
} else if (ge::GetParserContext().type == domi::CAFFE) { // caffe
if (kCaffeSupportInputFormatSet.find(input_format) != kCaffeSupportInputFormatSet.end()) {
return true;
}
// only support NCHW ND
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"input_format", input_format, kCaffeFormatSupport});
GELOGE(ge::FAILED, "Invalid value for input_format[%s], %s.", input_format.c_str(), kCaffeFormatSupport);
return false;
} else if (ge::GetParserContext().type == domi::TENSORFLOW) { // tf
if (kTfSupportInputFormatSet.find(input_format) != kTfSupportInputFormatSet.end()) {
return true;
}
// only support NCHW NHWC ND NCDHW NDHWC
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"input_format", input_format, kTFFormatSupport});
GELOGE(ge::FAILED, "Invalid value for input_format[%s], %s.", input_format.c_str(), kTFFormatSupport);
return false;
}
return true;
}

domi::Status AclGrphParseUtil::ParseAclFormat(string &input_format) {
ge::GetParserContext().format = domi::DOMI_TENSOR_ND;
if (!CheckAclInputFormat(input_format)) {
GELOGE(PARAM_INVALID, "Check input_format failed");
return PARAM_INVALID;
}
if (!input_format.empty()) {
auto iter = kInputFormatStrToGeformat.find(input_format);
if (iter != kInputFormatStrToGeformat.end()) {
ge::GetParserContext().format = iter->second;
} else {
GELOGE(PARAM_INVALID, "Input format %s not support , expect ND/NCHW/NHWC/CHWN/NC1HWC0/NHWC1C0.",
input_format.c_str());
return PARAM_INVALID;
}
}
return SUCCESS;
}

bool AclGrphParseUtil::ParseInputShape(const string &input_shape,
std::unordered_map<string, vector<int64_t>> &shape_map,
vector<pair<string, vector<int64_t>>> &user_shape_map) {
vector<string> shape_vec = StringUtils::Split(input_shape, ';');
const int DEFAULT_SHAPE_PAIR_SIZE = 2;
for (const auto &shape : shape_vec) {
vector<string> shape_pair_vec = SplitInputShape(shape);
if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kSplitError1, kInputShapeSample1});
GELOGW("Parse input parameter [input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", shape.c_str(),
kSplitError1, kInputShapeSample1);
return false;
}
if (shape_pair_vec[1].empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kEmptyError, kInputShapeSample1});
GELOGW("Parse input parameter [input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", shape.c_str(),
kEmptyError, kInputShapeSample1);
return false;
}

vector<string> shape_value_strs = StringUtils::Split(shape_pair_vec[1], ',');
vector<int64_t> shape_values;
for (auto &shape_value_str : shape_value_strs) {
// stoul: The method may throw an exception: invalid_argument/out_of_range
if (std::string::npos != shape_value_str.find('.')) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kFloatNumError, kInputShapeSample2});
GELOGW("Parse input parameter [input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
shape.c_str(), kFloatNumError, kInputShapeSample2);
return false;
}

long left_result = 0;
try {
left_result = stol(StringUtils::Trim(shape_value_str));
if (!shape_value_str.empty() && (shape_value_str.front() == '-')) {
// The value maybe dynamic shape [-1], need substr it and verify isdigit.
shape_value_str = shape_value_str.substr(1);
}
for (char c : shape_value_str) {
if (!isdigit(c)) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kDigitError, kInputShapeSample2});
GELOGE(PARAM_INVALID, "input_shape's shape value[%s] is not digit", shape_value_str.c_str());
return false;
}
}
} catch (const std::out_of_range &) {
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"},
{"input_shape", shape_value_str});
GELOGW("Input parameter[input_shape]’s value[%s] cause out of range execption!", shape_value_str.c_str());
return false;
} catch (const std::invalid_argument &) {
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"},
{"input_shape", shape_value_str});
GELOGW("Input parameter[input_shape]’s value[%s] cause invalid argument!", shape_value_str.c_str());
return false;
} catch (...) {
ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"},
{"input_shape", shape_value_str});
GELOGW("Input parameter[input_shape]’s value[%s] cause unkown execption!", shape_value_str.c_str());
return false;
}
int64_t result = left_result;
shape_values.push_back(result);
}

shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values));
user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values));
}

return true;
}

// Parse user input shape info
domi::Status AclGrphParseUtil::ParseAclShape(const string &input_shape) {
ge::GetParserContext().input_dims.clear();
ge::GetParserContext().user_input_dims.clear();

if (input_shape.empty()) {
return SUCCESS;
}

std::unordered_map<string, vector<int64_t>> &shape_map = ge::GetParserContext().input_dims;
if (!ParseInputShape(input_shape, ge::GetParserContext().input_dims, ge::GetParserContext().user_input_dims) ||
shape_map.empty()) {
GELOGE(PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str());
return PARAM_INVALID;
void AclGrphParseUtil::SetDefaultFormat() {
if (ge::GetParserContext().type == domi::TENSORFLOW) {
ge::GetParserContext().format = "NHWC";
} else {
ge::GetParserContext().format = "NCHW";
} }
return SUCCESS;
} }


domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) { domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) {
@@ -600,41 +334,6 @@ domi::Status AclGrphParseUtil::ParseAclOutputFp16NodesFormat(const string &is_ou
return SUCCESS; return SUCCESS;
} }


domi::Status AclGrphParseUtil::ParseAclOpConf(const std::string &op_conf) {
if (op_conf.empty()) {
return SUCCESS;
}
// Normalize the path
string resolved_file_path = ge::parser::RealPath(op_conf.c_str());
if (resolved_file_path.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"op_map_conf", op_conf});
GELOGE(domi::FAILED, "Invalid input file path [%s], make sure that the file path is correct.", op_conf.c_str());
return FAILED;
}
std::ifstream fs(resolved_file_path, std::ifstream::in);

if (!fs.is_open()) {
GELOGE(PARAM_INVALID, "Open %s failed.", op_conf.c_str());
return FAILED;
}

std::string line;

while (getline(fs, line)) { // line not with \n
if (!ParseSingleLine(line, ge::GetParserContext().op_conf_map)) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"op_map_conf_line_info", line});
GELOGE(PARAM_INVALID, "Parse line failed. content is [%s].", line.c_str());
fs.close();
return FAILED;
}
}
fs.close(); // close the file

GELOGI("LoadFileContent success.");
return SUCCESS;
}

domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fusion_passes) { domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fusion_passes) {
ge::GetParserContext().enable_scope_fusion_passes.clear(); ge::GetParserContext().enable_scope_fusion_passes.clear();
if (enable_scope_fusion_passes.empty()) { if (enable_scope_fusion_passes.empty()) {
@@ -746,54 +445,6 @@ domi::Status AclGrphParseUtil::ParseAclWeightCompressConf(const ComputeGraphPtr
return SUCCESS; return SUCCESS;
} }


domi::Status AclGrphParseUtil::ParseAclOutputType(const std::string &output_type,
std::map<std::string, vector<std::string>> &output_node_dt_map) {
if (output_type.find(':') == std::string::npos) {
GELOGI("output_type is not multiple nodes, means all out nodes");
return CheckOutPutDataTypeSupport(output_type);
}
std::vector<std::string> out_type_vec;
vector<string> nodes_v = StringUtils::Split(output_type, ';');
for (const string &node : nodes_v) {
vector<string> node_index_type_v = StringUtils::Split(node, ':');
if (node_index_type_v.size() != 3) { // The size must be 3.
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"output_type", node, kOutputTypeSample});
GELOGE(PARAM_INVALID, "Invalid value for output_type[%s], %s.", node.c_str(), kOutputTypeSample);
return domi::FAILED;
}
ge::DataType tmp_dt;
std::string node_name = StringUtils::Trim(node_index_type_v[kOutputTypeNode]);
std::string index_str = StringUtils::Trim(node_index_type_v[kOutputTypeIndex]);
int32_t index;
if (StringToInt(index_str, index) != SUCCESS) {
GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str());
return domi::FAILED;
}
std::string dt_value = StringUtils::Trim(node_index_type_v[kOutputTypeDataType]);
auto it = kOutputTypeSupportDatatype.find(dt_value);
if (it == kOutputTypeSupportDatatype.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"output_type", dt_value, kOutputTypeSupport});
GELOGE(ge::PARAM_INVALID, "Invalid value for output_type[%s], %s.", dt_value.c_str(), kOutputTypeSupport);
return domi::FAILED;
} else {
tmp_dt = it->second;
}
out_type_vec.push_back(node_name + ":" + index_str);
std::string index_dt_str = index_str + ":" + TypeUtils::DataTypeToSerialString(tmp_dt);
auto it1 = output_node_dt_map.find(node_name);
if (it1 == output_node_dt_map.end()) {
vector<string> tmp_vec;
tmp_vec.push_back(index_dt_str);
output_node_dt_map.emplace(node_name, tmp_vec);
} else {
it1->second.push_back(index_dt_str);
}
}
return VerifyOutputTypeAndOutNodes(out_type_vec);
}

void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name) { std::vector<std::string> &output_nodes_name) {
output_nodes_name.clear(); output_nodes_name.clear();
@@ -880,20 +531,10 @@ domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph,
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph); GE_CHECK_NOTNULL(compute_graph);


string output_type;
GetAclParams(parser_params, ge::ir_option::OUTPUT_TYPE, output_type);

std::vector<std::pair<std::string, int32_t>> user_out_nodes = ge::GetParserContext().user_out_nodes; std::vector<std::pair<std::string, int32_t>> user_out_nodes = ge::GetParserContext().user_out_nodes;
std::vector<domiTensorFormat_t> output_formats = ge::GetParserContext().output_formats; std::vector<domiTensorFormat_t> output_formats = ge::GetParserContext().output_formats;
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_info; std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_info;
std::vector<std::string> output_nodes_name; std::vector<std::string> output_nodes_name;
std::map<std::string, vector<std::string>> output_node_dt_map;
if (!output_type.empty()) {
if (ParseAclOutputType(output_type, output_node_dt_map) != SUCCESS) {
GELOGE(domi::FAILED, "Parse output_type failed.");
return domi::FAILED;
}
}


// User declared outputs // User declared outputs
for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { for (uint32_t i = 0; i < user_out_nodes.size(); ++i) {
@@ -923,11 +564,6 @@ domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph,
(void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec); (void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec);
} }
} }
auto it = output_node_dt_map.find(user_out_nodes[i].first);
if (it != output_node_dt_map.end()) {
GELOGI("The output node [%s] need to be set output_type", user_out_nodes[i].first.c_str());
(void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_data_type", it->second);
}
output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second)); output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second));
} }
// default output node (leaf) // default output node (leaf)
@@ -944,34 +580,6 @@ domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph,
return domi::SUCCESS; return domi::SUCCESS;
} }


domi::Status AclGrphParseUtil::ParseAclLogLevel(const std::string &log) {
if (log.empty()) {
return SUCCESS;
}
int ret = -1;
if (log == "default") {
ret = 0;
} else if (log == "null") {
ret = dlog_setlevel(-1, DLOG_NULL, 0);
} else if (log == "debug") {
ret = dlog_setlevel(-1, DLOG_DEBUG, 1);
} else if (log == "info") {
ret = dlog_setlevel(-1, DLOG_INFO, 1);
} else if (log == "warning") {
ret = dlog_setlevel(-1, DLOG_WARN, 1);
} else if (log == "error") {
ret = dlog_setlevel(-1, DLOG_ERROR, 1);
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"log", log});
GELOGE(PARAM_INVALID, "Invalid value for log:%s, only support debug, info, warning, error, null", log.c_str());
return PARAM_INVALID;
}
if (ret != 0) {
GELOGE(PARAM_INVALID, "Log setlevel fail !");
}
return domi::SUCCESS;
}

domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendString> &parser_params) { domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendString> &parser_params) {
for (auto &ele : parser_params) { for (auto &ele : parser_params) {
const char *key_ascend = ele.first.GetString(); const char *key_ascend = ele.first.GetString();
@@ -993,73 +601,13 @@ domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendS
return SUCCESS; return SUCCESS;
} }


domi::Status AclGrphParseUtil::CheckAclInputShapeNode(const ComputeGraphPtr &graph) {
for (auto it : ge::GetParserContext().user_input_dims) {
std::string node_name = it.first;
ge::NodePtr node = graph->FindNode(node_name);
if (node == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"input_shape", node_name});
GELOGE(PARAM_INVALID, "Input parameter[input_shape]'s opname[%s] is not exist in model", node_name.c_str());
return PARAM_INVALID;
}
if (node->GetType() != ge::parser::DATA) {
ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, {"input_shape", node_name});
GELOGE(PARAM_INVALID, "Input parameter[input_shape]'s opname[%s] is not a input opname", node_name.c_str());
return PARAM_INVALID;
}
}
return SUCCESS;
}

domi::Status AclGrphParseUtil::CheckAclOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf) {
GE_CHECK_NOTNULL(graph);
unordered_map<string, string> graphNodeTypes;
for (const NodePtr &node : graph->GetAllNodes()) {
auto op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(PARAM_INVALID, "Invalid parameter for opDesc.");
return PARAM_INVALID;
}
graphNodeTypes[op_desc->GetType()] = "";
}
std::map<std::string, std::string> &propertiesMap = ge::GetParserContext().op_conf_map;
if (propertiesMap.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"},
{"op_name_map", op_conf, "the file content is empty"});
GELOGE(PARAM_INVALID, "op_name_map file content is empty, please check file!");
return PARAM_INVALID;
}
for (auto iter = propertiesMap.begin(); iter != propertiesMap.end(); iter++) {
GE_IF_BOOL_EXEC(graphNodeTypes.find(iter->second) == graphNodeTypes.end(),
ErrorManager::GetInstance().ATCReportErrMessage(
"E10003", {"parameter", "value", "reason"},
{"op_name_map", op_conf, "type[" + iter->second + "] is not found in model"});
GELOGE(PARAM_INVALID, "Invalid parameter for op_name_map."); return PARAM_INVALID;);
}
return SUCCESS;
}

domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params,
string &graph_name) { string &graph_name) {
GELOGI("Parse graph user options start."); GELOGI("Parse graph user options start.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckOptions(parser_params) != SUCCESS, return PARAM_INVALID, GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckOptions(parser_params) != SUCCESS, return PARAM_INVALID,
"Parse paragrams invalid."); "Parse paragrams invalid.");
// support paragrams: log, input_format, input_shape, out_nodes
// is_output_adjust_hw_layout, output, op_name_map, enable_scope_fusion_passes
string log_level;
GetAclParams(parser_params, ge::ir_option::LOG_LEVEL, log_level);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclLogLevel(log_level) != SUCCESS, return PARAM_INVALID,
"Parse log_level failed");

string input_format;
GetAclParams(parser_params, ge::ir_option::INPUT_FORMAT, input_format);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclFormat(input_format) != SUCCESS, return PARAM_INVALID,
"Parse input_format failed");

string input_shape;
GetAclParams(parser_params, ge::ir_option::INPUT_SHAPE, input_shape);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclShape(input_shape) != SUCCESS, return PARAM_INVALID,
"Parse input_shape failed");
// support paragrams: out_nodes, is_output_adjust_hw_layout, output, enable_scope_fusion_passes
SetDefaultFormat();


string out_nodes; string out_nodes;
GetAclParams(parser_params, ge::ir_option::OUT_NODES, out_nodes); GetAclParams(parser_params, ge::ir_option::OUT_NODES, out_nodes);
@@ -1071,11 +619,6 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS, GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS,
return PARAM_INVALID, "Parse is_output_adjust_hw_layout failed"); return PARAM_INVALID, "Parse is_output_adjust_hw_layout failed");


string op_conf_str;
GetAclParams(parser_params, ge::ir_option::OP_NAME_MAP, op_conf_str);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOpConf(op_conf_str) != SUCCESS, return PARAM_INVALID,
"Parse op_name_map failed");

string tmp_name; string tmp_name;
GetAclParams(parser_params, ge::ir_option::OUTPUT, tmp_name); GetAclParams(parser_params, ge::ir_option::OUTPUT, tmp_name);
graph_name = tmp_name.empty() ? (kGraphDefaultName + "_" + ge::parser::CurrentTimeInStr()) : tmp_name; graph_name = tmp_name.empty() ? (kGraphDefaultName + "_" + ge::parser::CurrentTimeInStr()) : tmp_name;
@@ -1103,19 +646,10 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph,
ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS, ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS,
return PARAM_INVALID, "Parse input_fp16_nodes failed"); return PARAM_INVALID, "Parse input_fp16_nodes failed");


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckAclInputShapeNode(compute_graph) != SUCCESS,
return PARAM_INVALID, "Check nodes input_shape info failed");

string compress_weight_conf; string compress_weight_conf;
GetAclParams(parser_params, ge::ir_option::COMPRESS_WEIGHT_CONF, compress_weight_conf); GetAclParams(parser_params, ge::ir_option::COMPRESS_WEIGHT_CONF, compress_weight_conf);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclWeightCompressConf(compute_graph, compress_weight_conf) != SUCCESS, GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclWeightCompressConf(compute_graph, compress_weight_conf) != SUCCESS,
return PARAM_INVALID, "Parse compress_weight_conf failed"); return PARAM_INVALID, "Parse compress_weight_conf failed");
string op_conf_str;
GetAclParams(parser_params, ge::ir_option::OP_NAME_MAP, op_conf_str);
if (!op_conf_str.empty()) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckAclOpNameMap(compute_graph, op_conf_str) != SUCCESS, return PARAM_INVALID,
"Check op_name_map info failed");
}


return SUCCESS; return SUCCESS;
} }


+ 1
- 12
parser/common/acl_graph_parser_util.h View File

@@ -45,7 +45,6 @@ class AclGrphParseUtil {
domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params,
std::string &graph_name); std::string &graph_name);
domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params); domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params);
domi::Status ParseOutputInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params);


private: private:
bool parser_initialized = false; bool parser_initialized = false;
@@ -53,27 +52,17 @@ class AclGrphParseUtil {
domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);
void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name); std::vector<std::string> &output_nodes_name);
domi::Status ParseAclLogLevel(const std::string &log);
bool CheckAclInputFormat(string &input_format);
domi::Status ParseAclFormat(std::string &input_format);
bool ParseInputShape(const std::string &input_shape, std::unordered_map<std::string, vector<int64_t>> &shape_map,
vector<pair<std::string, vector<int64_t>>> &user_shape_map);
domi::Status ParseAclShape(const std::string &input_shape);
void SetDefaultFormat();
domi::Status ParseAclOutputNodes(const std::string &out_nodes); domi::Status ParseAclOutputNodes(const std::string &out_nodes);
domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16); domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16);
domi::Status ParseAclOpConf(const std::string &op_conf);
domi::Status ParseAclEnableScope(const std::string &enable_scope_fusion_passes); domi::Status ParseAclEnableScope(const std::string &enable_scope_fusion_passes);
static void AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec, const string &fp16_nodes_name, static void AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec, const string &fp16_nodes_name,
uint32_t index, OpDescPtr &op_desc); uint32_t index, OpDescPtr &op_desc);
domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes,
const string &is_input_adjust_hw_layout); const string &is_input_adjust_hw_layout);
domi::Status ParseAclWeightCompressConf(const ComputeGraphPtr &graph, const string &compress_weight_conf); domi::Status ParseAclWeightCompressConf(const ComputeGraphPtr &graph, const string &compress_weight_conf);
domi::Status ParseAclOutputType(const std::string &output_type,
std::map<std::string, vector<std::string>> &output_node_dt_map);
domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);
domi::Status CheckAclInputShapeNode(const ComputeGraphPtr &graph);
domi::Status CheckAclOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf);
}; };


namespace parser { namespace parser {


Loading…
Cancel
Save