|
|
@@ -736,7 +736,9 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, |
|
|
|
} |
|
|
|
|
|
|
|
auto tensor_input = op->MutableInputDesc(0); |
|
|
|
auto tensor_output = op->MutableOutputDesc(0); |
|
|
|
GE_CHECK_NOTNULL(tensor_input); |
|
|
|
GE_CHECK_NOTNULL(tensor_output); |
|
|
|
string data_op_name = op->GetName(); |
|
|
|
auto origin_shape = tensor_input->GetShape(); |
|
|
|
auto iter = shape_range_map.find(data_op_name); |
|
|
@@ -755,6 +757,8 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, |
|
|
|
} |
|
|
|
tensor_input->SetShape(origin_shape); |
|
|
|
tensor_input->SetShapeRange(cur_shape_range); |
|
|
|
tensor_output->SetShape(origin_shape); |
|
|
|
tensor_output->SetShapeRange(cur_shape_range); |
|
|
|
GELOGI("Update input [%s] shape range info", data_op_name.c_str()); |
|
|
|
} else { |
|
|
|
GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str()); |
|
|
@@ -763,6 +767,29 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
static Status CheckInputShapeRangeNode(const ComputeGraphPtr &compute_graph, |
|
|
|
const map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) { |
|
|
|
for (const auto &it : shape_range_map) { |
|
|
|
std::string node_name = it.first; |
|
|
|
ge::NodePtr node = compute_graph->FindNode(node_name); |
|
|
|
if (node == nullptr) { |
|
|
|
REPORT_INPUT_ERROR("E10016", std::vector<std::string>({"parameter", "opname"}), |
|
|
|
std::vector<std::string>({"input_shape_range", node_name})); |
|
|
|
GELOGE(PARAM_INVALID, "[Check][InputNode]Input parameter[--input_shape_range]'s opname[%s] is not exist in model", |
|
|
|
node_name.c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
if (node->GetType() != DATA) { |
|
|
|
REPORT_INPUT_ERROR("E10017", std::vector<std::string>({"parameter", "opname"}), |
|
|
|
std::vector<std::string>({"input_shape_range", node_name})); |
|
|
|
GELOGE(PARAM_INVALID, "[Check][InputNode]Input parameter[--input_shape_range]'s opname[%s] is not a input opname", |
|
|
|
node_name.c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range) { |
|
|
|
if (input_shape_range.empty()) { |
|
|
|
return SUCCESS; |
|
|
@@ -775,6 +802,11 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
|
|
|
|
if (CheckInputShapeRangeNode(compute_graph, shape_range_map) != SUCCESS) { |
|
|
|
GELOGE(PARAM_INVALID, "[Check][InputShapeRange]check input shape range:%s failed.", input_shape_range.c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
|
|
|
|
for (NodePtr &input_node : compute_graph->GetDirectNode()) { |
|
|
|
GE_CHECK_NOTNULL(input_node); |
|
|
|
OpDescPtr op = input_node->GetOpDesc(); |
|
|
@@ -788,5 +820,4 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace ge |