|
|
@@ -251,17 +251,24 @@ class Impl { |
|
|
|
omg_context_.dynamic_batch_size.clear(); |
|
|
|
omg_context_.dynamic_image_size.clear(); |
|
|
|
omg_context_.dynamic_dims.clear(); |
|
|
|
omg_context_.user_attr_index_valid = false; |
|
|
|
}; |
|
|
|
~Impl() { (void)generator_.Finalize(); }; |
|
|
|
graphStatus CheckOptions(const std::map<std::string, std::string> &options); |
|
|
|
graphStatus CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTensor> &inputs); |
|
|
|
graphStatus UpdateDataOpAttr(const Graph &graph); |
|
|
|
graphStatus CheckDataOpAttrIndexValid(const Graph &graph, const std::string &input_shape_range); |
|
|
|
graphStatus Init(const Graph &graph, const std::map<std::string, std::string> &options); |
|
|
|
graphStatus BuildModel(const Graph &graph, const std::map<std::string, std::string> &options, |
|
|
|
ModelBufferData &ge_models); |
|
|
|
graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, |
|
|
|
bool is_dynamic_input); |
|
|
|
graphStatus GetInputShapeRange(const string &input_shape_range, |
|
|
|
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &name_shape_range_map, |
|
|
|
std::vector<std::vector<std::pair<int64_t, int64_t>>> &index_shape_range_map); |
|
|
|
static graphStatus InferShapePrepare(const ComputeGraphPtr &compute_graph); |
|
|
|
bool GetUsrAttrIndexValidFlag(); |
|
|
|
bool IsAttrIndexSetByUser(const ComputeGraphPtr &compute_graph, size_t &data_num, vector<int64_t> &attr_index); |
|
|
|
void SetRtSocVersion(); |
|
|
|
void UpdateThreadContext(); |
|
|
|
void LoadOpsProto(); |
|
|
@@ -288,11 +295,113 @@ graphStatus Impl::InferShapePrepare(const ComputeGraphPtr &compute_graph) { |
|
|
|
return GRAPH_SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
bool Impl::GetUsrAttrIndexValidFlag() { |
|
|
|
return omg_context_.user_attr_index_valid; |
|
|
|
} |
|
|
|
|
|
|
|
bool Impl::IsAttrIndexSetByUser(const ComputeGraphPtr &compute_graph, |
|
|
|
size_t &data_num, |
|
|
|
vector<int64_t> &attr_index) { |
|
|
|
bool all_zero_flag = true; |
|
|
|
for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { |
|
|
|
GE_CHECK_NOTNULL(input_node); |
|
|
|
ge::OpDescPtr op = input_node->GetOpDesc(); |
|
|
|
GE_CHECK_NOTNULL(op); |
|
|
|
if (op->GetType() == DATA) { |
|
|
|
data_num++; |
|
|
|
GeAttrValue::INT index = 0; |
|
|
|
if (AttrUtils::GetInt(op, ATTR_NAME_INDEX, index)) { |
|
|
|
if (index != 0) { |
|
|
|
all_zero_flag = false; |
|
|
|
} |
|
|
|
attr_index.push_back(index); |
|
|
|
} else { |
|
|
|
GELOGW("[Get][AttrIndex] Get index[%ld] failed for op[%s].", index, op->GetName().c_str()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (data_num > 1 && attr_index.size() == data_num && all_zero_flag) { |
|
|
|
GELOGI("Attr indexes are not set by user."); |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus Impl::GetInputShapeRange(const string &input_shape_range, |
|
|
|
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &name_shape_range_map, |
|
|
|
std::vector<std::vector<std::pair<int64_t, int64_t>>> &index_shape_range_map) { |
|
|
|
if (input_shape_range.empty()) { |
|
|
|
GELOGI("Input shape range is empty."); |
|
|
|
return GRAPH_SUCCESS; |
|
|
|
} |
|
|
|
Status ret = GRAPH_PARAM_INVALID; |
|
|
|
if (input_shape_range.find(":") != string::npos) { |
|
|
|
ret = ParseInputShapeRange(input_shape_range, name_shape_range_map); |
|
|
|
} else { |
|
|
|
ret = ParseInputShapeRange(input_shape_range, index_shape_range_map); |
|
|
|
} |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] parse shape range[%s] failed.", input_shape_range.c_str()); |
|
|
|
return GRAPH_PARAM_INVALID; |
|
|
|
} |
|
|
|
return GRAPH_SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus Impl::CheckDataOpAttrIndexValid(const Graph &graph, const std::string &input_shape_range) { |
|
|
|
auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); |
|
|
|
GE_CHECK_NOTNULL(compute_graph); |
|
|
|
|
|
|
|
// when set input shape range by index, user must set data attr index, eg. "[1, 3, 3, -1],[1, 3~5, 6, -1]" |
|
|
|
bool index_input_shape_range_flag = !input_shape_range.empty() && (input_shape_range.find(":") == string::npos); |
|
|
|
size_t data_num = 0; |
|
|
|
vector<int64_t> attr_index; |
|
|
|
if (!IsAttrIndexSetByUser(compute_graph, data_num, attr_index)) { |
|
|
|
if (index_input_shape_range_flag) { |
|
|
|
std::string situation = "Data op index"; |
|
|
|
std::string reason = "it must be set by user, total data op num[" + std::to_string(data_num) + "], " |
|
|
|
"when set input shape range by index."; |
|
|
|
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"situation", "reason"}), |
|
|
|
std::vector<std::string>({situation, reason})); |
|
|
|
GELOGE(GRAPH_FAILED, "[Check][AttrIndex] Data op index is not set, total data op num[%ld], " |
|
|
|
"when set input shape range by index.", data_num); |
|
|
|
return GRAPH_FAILED; |
|
|
|
} |
|
|
|
return GRAPH_SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
omg_context_.user_attr_index_valid = true; |
|
|
|
for (size_t i = 0; i < data_num; ++i) { |
|
|
|
if (std::find(attr_index.begin(), attr_index.end(), i) == attr_index.end()) { |
|
|
|
omg_context_.user_attr_index_valid = false; |
|
|
|
if (index_input_shape_range_flag) { |
|
|
|
std::string situation = "Data op index[" + std::to_string(i) + "]"; |
|
|
|
std::string reason = "it must be set by user, total data op num[" + std::to_string(data_num) + "], " |
|
|
|
"when set input shape range by index"; |
|
|
|
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"situation", "reason"}), |
|
|
|
std::vector<std::string>({situation, reason})); |
|
|
|
GELOGE(GRAPH_FAILED, "[Check][AttrIndex] Attr index [%ld] is not set, total data op num[%ld], " |
|
|
|
"when set input shape range by index", i, data_num); |
|
|
|
return GRAPH_FAILED; |
|
|
|
} else { |
|
|
|
GELOGW("[Check][AttrIndex] Attr index [%ld] is not set, total data op num[%ld].", i, data_num); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
GELOGI("Data op attr indexes are set by user and valid."); |
|
|
|
return GRAPH_SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { |
|
|
|
GELOGD("Enter Update Data Attr Process!"); |
|
|
|
std::string input_shape = (options_.find(kInputShape) == options_.end()) ? "" : options_[kInputShape]; |
|
|
|
std::string input_shape_range = (options_.find(kInputShapeRange) == options_.end()) ? "" : options_[kInputShapeRange]; |
|
|
|
|
|
|
|
graphStatus ret = CheckDataOpAttrIndexValid(graph, input_shape_range); |
|
|
|
if (ret != GRAPH_SUCCESS) { |
|
|
|
GELOGE(GRAPH_FAILED, "[Check][DataOpAttrIndex] fail, shape range[%s].", input_shape_range.c_str()); |
|
|
|
return GRAPH_FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
map<string, vector<int64_t>> shape_map; |
|
|
|
vector<pair<string, vector<int64_t>>> user_shape_map; |
|
|
|
if (!input_shape.empty()) { |
|
|
@@ -301,20 +410,13 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { |
|
|
|
} |
|
|
|
std::map<string, std::vector<std::pair<int64_t, int64_t>>> name_shape_range_map; |
|
|
|
std::vector<std::vector<std::pair<int64_t, int64_t>>> index_shape_range_map; |
|
|
|
if (!input_shape_range.empty()) { |
|
|
|
Status ret = GRAPH_PARAM_INVALID; |
|
|
|
if (input_shape_range.find(":") != string::npos) { |
|
|
|
ret = ParseInputShapeRange(input_shape_range, name_shape_range_map); |
|
|
|
} else { |
|
|
|
ret = ParseInputShapeRange(input_shape_range, index_shape_range_map); |
|
|
|
} |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] parse shape range[%s] failed.", input_shape_range.c_str()); |
|
|
|
return GRAPH_PARAM_INVALID; |
|
|
|
} |
|
|
|
} |
|
|
|
auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); |
|
|
|
GE_CHECK_NOTNULL(compute_graph); |
|
|
|
ret = GetInputShapeRange(input_shape_range, name_shape_range_map, index_shape_range_map); |
|
|
|
if (ret != GRAPH_SUCCESS) { |
|
|
|
GELOGE(GRAPH_FAILED, "[Get][InputShapeRange] fail, shape range[%s].", input_shape_range.c_str()); |
|
|
|
return GRAPH_FAILED; |
|
|
|
} |
|
|
|
for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { |
|
|
|
GE_CHECK_NOTNULL(input_node); |
|
|
|
ge::OpDescPtr op = input_node->GetOpDesc(); |
|
|
@@ -495,7 +597,9 @@ graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTe |
|
|
|
ge::OpDescPtr op = input_node->GetOpDesc(); |
|
|
|
GE_CHECK_NOTNULL(op); |
|
|
|
if (op->GetType() == DATA) { |
|
|
|
(void)AttrUtils::SetInt(op, ATTR_NAME_INDEX, index++); |
|
|
|
if (!GetUsrAttrIndexValidFlag()) { |
|
|
|
(void)AttrUtils::SetInt(op, ATTR_NAME_INDEX, index++); |
|
|
|
} |
|
|
|
GELOGD("Data op inputDesc size: %zu", op->GetAllInputsDesc().size()); |
|
|
|
ge::GeTensorDesc tensor = op->GetInputDesc(0); |
|
|
|
string data_op_name = op->GetName(); |
|
|
|