Browse Source

add check input shape range node func

tags/v1.3.0
zhengyuanhua 3 years ago
parent
commit
4606e4b137
3 changed files with 48 additions and 7 deletions
  1. +32
    -1
      ge/ir_build/atc_ir_common.cc
  2. +6
    -6
      ge/session/omg.cc
  3. +10
    -0
      tests/ut/ge/graph_ir/ge_ir_build_unittest.cc

+ 32
- 1
ge/ir_build/atc_ir_common.cc View File

@@ -736,7 +736,9 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
}

auto tensor_input = op->MutableInputDesc(0);
auto tensor_output = op->MutableOutputDesc(0);
GE_CHECK_NOTNULL(tensor_input);
GE_CHECK_NOTNULL(tensor_output);
string data_op_name = op->GetName();
auto origin_shape = tensor_input->GetShape();
auto iter = shape_range_map.find(data_op_name);
@@ -755,6 +757,8 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
}
tensor_input->SetShape(origin_shape);
tensor_input->SetShapeRange(cur_shape_range);
tensor_output->SetShape(origin_shape);
tensor_output->SetShapeRange(cur_shape_range);
GELOGI("Update input [%s] shape range info", data_op_name.c_str());
} else {
GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str());
@@ -763,6 +767,29 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
return SUCCESS;
}

static Status CheckInputShapeRangeNode(const ComputeGraphPtr &compute_graph,
const map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) {
for (const auto &it : shape_range_map) {
std::string node_name = it.first;
ge::NodePtr node = compute_graph->FindNode(node_name);
if (node == nullptr) {
REPORT_INPUT_ERROR("E10016", std::vector<std::string>({"parameter", "opname"}),
std::vector<std::string>({"input_shape_range", node_name}));
GELOGE(PARAM_INVALID, "[Check][InputNode]Input parameter[--input_shape_range]'s opname[%s] is not exist in model",
node_name.c_str());
return PARAM_INVALID;
}
if (node->GetType() != DATA) {
REPORT_INPUT_ERROR("E10017", std::vector<std::string>({"parameter", "opname"}),
std::vector<std::string>({"input_shape_range", node_name}));
GELOGE(PARAM_INVALID, "[Check][InputNode]Input parameter[--input_shape_range]'s opname[%s] is not a input opname",
node_name.c_str());
return PARAM_INVALID;
}
}
return SUCCESS;
}

Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range) {
if (input_shape_range.empty()) {
return SUCCESS;
@@ -775,6 +802,11 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co
return PARAM_INVALID;
}

if (CheckInputShapeRangeNode(compute_graph, shape_range_map) != SUCCESS) {
GELOGE(PARAM_INVALID, "[Check][InputShapeRange]check input shape range:%s failed.", input_shape_range.c_str());
return PARAM_INVALID;
}

for (NodePtr &input_node : compute_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(input_node);
OpDescPtr op = input_node->GetOpDesc();
@@ -788,5 +820,4 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co
}
return SUCCESS;
}

} // namespace ge

+ 6
- 6
ge/session/omg.cc View File

@@ -99,8 +99,9 @@ static void ParseAtcParms(const std::map<std::string, std::string> &atc_params,
}
}

static Status CheckInputShapeNode(const ComputeGraphPtr &graph, const bool is_dynamic_input, RunMode run_mode) {
if (!is_dynamic_input && run_mode != MODEL_TO_JSON) {
static Status CheckInputShapeNode(const ComputeGraphPtr &graph, bool is_dynamic_input,
const std::string &input_shape_range, RunMode run_mode) {
if (!is_dynamic_input && run_mode != MODEL_TO_JSON && input_shape_range.empty()) {
for (auto node : graph->GetDirectNode()) {
if (node->GetType() == DATA) {
auto data_op_desc = node->GetOpDesc();
@@ -760,8 +761,9 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri
ParseAtcParms(atc_params, "is_input_adjust_hw_layout", is_input_adjust_hw_layout);
compute_graph = GraphUtils::GetComputeGraph(graph);
GE_RETURN_IF_ERROR(CheckInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout));

GE_RETURN_IF_ERROR(CheckInputShapeNode(compute_graph, is_dynamic_input, run_mode));
std::string input_shape_range;
ParseAtcParms(atc_params, INPUT_SHAPE_RANGE, input_shape_range);
GE_RETURN_IF_ERROR(CheckInputShapeNode(compute_graph, is_dynamic_input, input_shape_range, run_mode));

// Verify the contents of the op_name_map
if (op_conf != nullptr && *op_conf != '\0') {
@@ -790,8 +792,6 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri
GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "ATC weights parse ret fail.");

// parser input shape range and update op shape range
std::string input_shape_range;
ParseAtcParms(atc_params, INPUT_SHAPE_RANGE, input_shape_range);
GE_RETURN_WITH_LOG_IF_ERROR(UpdateDynamicInputShapeRange(compute_graph, input_shape_range),
"Update input shape range failed");



+ 10
- 0
tests/ut/ge/graph_ir/ge_ir_build_unittest.cc View File

@@ -97,4 +97,14 @@ TEST(UtestIrCommon, update_dynamic_shape_range_failed) {
input_shape_range = "input1:[1, 2~-3, -1]";
ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
EXPECT_EQ(ret, ge::PARAM_INVALID);

//5
input_shape_range = "input:[1, 2~3, -1]";
ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
EXPECT_EQ(ret, ge::PARAM_INVALID);

//6
input_shape_range = "addn1:[1, 2~3, -1]";
ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
EXPECT_EQ(ret, ge::PARAM_INVALID);
}

Loading…
Cancel
Save