From 4606e4b1374474fb82d7e2a3d30dcb4a0e3edb16 Mon Sep 17 00:00:00 2001 From: zhengyuanhua Date: Sat, 20 Mar 2021 11:16:03 +0800 Subject: [PATCH] add check input shape range node func --- ge/ir_build/atc_ir_common.cc | 33 +++++++++++++++++++- ge/session/omg.cc | 12 +++---- tests/ut/ge/graph_ir/ge_ir_build_unittest.cc | 10 ++++++ 3 files changed, 48 insertions(+), 7 deletions(-) diff --git a/ge/ir_build/atc_ir_common.cc b/ge/ir_build/atc_ir_common.cc index 0fe027df..5c18fa7a 100755 --- a/ge/ir_build/atc_ir_common.cc +++ b/ge/ir_build/atc_ir_common.cc @@ -736,7 +736,9 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, } auto tensor_input = op->MutableInputDesc(0); + auto tensor_output = op->MutableOutputDesc(0); GE_CHECK_NOTNULL(tensor_input); + GE_CHECK_NOTNULL(tensor_output); string data_op_name = op->GetName(); auto origin_shape = tensor_input->GetShape(); auto iter = shape_range_map.find(data_op_name); @@ -755,6 +757,8 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, } tensor_input->SetShape(origin_shape); tensor_input->SetShapeRange(cur_shape_range); + tensor_output->SetShape(origin_shape); + tensor_output->SetShapeRange(cur_shape_range); GELOGI("Update input [%s] shape range info", data_op_name.c_str()); } else { GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str()); @@ -763,6 +767,29 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, return SUCCESS; } +static Status CheckInputShapeRangeNode(const ComputeGraphPtr &compute_graph, + const map>> &shape_range_map) { + for (const auto &it : shape_range_map) { + std::string node_name = it.first; + ge::NodePtr node = compute_graph->FindNode(node_name); + if (node == nullptr) { + REPORT_INPUT_ERROR("E10016", std::vector({"parameter", "opname"}), + std::vector({"input_shape_range", node_name})); + GELOGE(PARAM_INVALID, "[Check][InputNode]Input parameter[--input_shape_range]'s opname[%s] is not exist in model", + node_name.c_str()); + return PARAM_INVALID; + } + if (node->GetType() != DATA) { + REPORT_INPUT_ERROR("E10017", std::vector({"parameter", "opname"}), + std::vector({"input_shape_range", node_name})); + GELOGE(PARAM_INVALID, "[Check][InputNode]Input parameter[--input_shape_range]'s opname[%s] is not a input opname", + node_name.c_str()); + return PARAM_INVALID; + } + } + return SUCCESS; +} + Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range) { if (input_shape_range.empty()) { return SUCCESS; @@ -775,6 +802,11 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co return PARAM_INVALID; } + if (CheckInputShapeRangeNode(compute_graph, shape_range_map) != SUCCESS) { + GELOGE(PARAM_INVALID, "[Check][InputShapeRange]check input shape range:%s failed.", input_shape_range.c_str()); + return PARAM_INVALID; + } + for (NodePtr &input_node : compute_graph->GetDirectNode()) { GE_CHECK_NOTNULL(input_node); OpDescPtr op = input_node->GetOpDesc(); @@ -788,5 +820,4 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co } return SUCCESS; } - } // namespace ge diff --git a/ge/session/omg.cc b/ge/session/omg.cc index 63be4913..961bc8c7 100755 --- a/ge/session/omg.cc +++ b/ge/session/omg.cc @@ -99,8 +99,9 @@ static void ParseAtcParms(const std::map &atc_params, } } -static Status CheckInputShapeNode(const ComputeGraphPtr &graph, const bool is_dynamic_input, RunMode run_mode) { - if (!is_dynamic_input && run_mode != MODEL_TO_JSON) { +static Status CheckInputShapeNode(const ComputeGraphPtr &graph, bool is_dynamic_input, + const std::string &input_shape_range, RunMode run_mode) { + if (!is_dynamic_input && run_mode != MODEL_TO_JSON && input_shape_range.empty()) { for (auto node : graph->GetDirectNode()) { if (node->GetType() == DATA) { auto data_op_desc = node->GetOpDesc(); @@ -760,8 +761,9 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map