From 2abf8be62178511bf6c073ff821b309a8e2817ee Mon Sep 17 00:00:00 2001 From: wjm Date: Tue, 8 Jun 2021 04:28:41 +0800 Subject: [PATCH] fix sc --- ge/graph/preprocess/graph_preprocess.cc | 119 ++++++++++-------- ge/graph/preprocess/graph_preprocess.h | 3 +- ge/ir_build/ge_ir_build.cc | 2 +- .../preprocess/graph_preprocess_unittest.cc | 15 +++ 4 files changed, 84 insertions(+), 55 deletions(-) diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index 0c4adeea..a73c6a96 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -1420,9 +1420,10 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) { return SUCCESS; } -Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag) { +Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc) { auto format = desc.GetFormat(); auto origin_format = desc.GetOriginFormat(); + auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op) && (!tune_flag); if (need_check_internal_format) { bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); @@ -1439,6 +1440,63 @@ Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTens return SUCCESS; } +Status GraphPrepare::UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc) { + auto data_type = desc.GetDataType(); + uint32_t length = 1; + bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); + if (!type_ret) { + std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" + + std::to_string(index) + " input tensor is not support"; + REPORT_INPUT_ERROR("E19025", std::vector({"reason"}), std::vector({reason})); + GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not support.", + TypeUtils::DataTypeToSerialString(data_type).c_str()); + return FAILED; + } + int64_t desc_shape = desc.GetShape().GetShapeSize(); + FMK_INT64_UINT32_MULCHECK(desc_shape, length); + int64_t shape_size = desc_shape * length; + GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast(length)); + int64_t size = 0; + GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS, + REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index); + GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index); return FAILED); + bool size_check = (size != 0 && shape_size != size); + if (size_check) { + std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) + + "] != shape_size[" + std::to_string(size) + "], check invalid"; + REPORT_INPUT_ERROR("E19025", std::vector({"reason"}), std::vector({reason})); + GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size); + return FAILED; + } + ge::TensorUtils::SetSize(desc, shape_size); + + auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); + if (!tune_flag) { + graphStatus graph_ret = op->UpdateInputDesc(0, desc); + if (graph_ret != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0", + op->GetName().c_str(), op->GetType().c_str()); + GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0", + op->GetName().c_str(), op->GetType().c_str()); + return graph_ret; + } + // Size will be recalculated in the build stage + ge::TensorUtils::SetSize(desc, 0); + graph_ret = op->UpdateOutputDesc(0, desc); + if (graph_ret != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0", + op->GetName().c_str(), op->GetType().c_str()); + GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0", + op->GetName().c_str(), op->GetType().c_str()); + return graph_ret; + } + } else { + GELOGI("data %s skip update info in tune mode", op->GetName().c_str()); + } + + return SUCCESS; +} + Status GraphPrepare::UpdateInput(const std::vector &user_input, const std::map &graph_option) { // Get shape range of input in dynamic_execute mode @@ -1471,63 +1529,18 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input, } GeTensorDesc desc(user_input[index].GetTensorDesc()); // data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM. - auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); - ret = CheckInternalFormat(input_node, desc, tune_flag); + ret = CheckInternalFormat(input_node, desc); if (ret != SUCCESS) { GELOGE(INTERNAL_ERROR, "[Check][InternalFormat] on %s failed", op->GetName().c_str()); return ret; } - auto data_type = desc.GetDataType(); - uint32_t length = 1; - bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); - if (!type_ret) { - std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" + - std::to_string(index) + " input tensor is not support"; - REPORT_INPUT_ERROR("E19025", std::vector({"reason"}), std::vector({reason})); - GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not support.", - TypeUtils::DataTypeToSerialString(data_type).c_str()); - return FAILED; - } - int64_t desc_shape = desc.GetShape().GetShapeSize(); - FMK_INT64_UINT32_MULCHECK(desc_shape, length); - int64_t shape_size = desc_shape * length; - GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast(length)); - int64_t size = 0; - GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS, - REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index); - GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index); - return FAILED); - bool size_check = (size != 0 && shape_size != size); - if (size_check) { - std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) + - "] != shape_size[" + std::to_string(size) + "], check invalid"; - REPORT_INPUT_ERROR("E19025", std::vector({"reason"}), std::vector({reason})); - GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size); - return FAILED; - } - ge::TensorUtils::SetSize(desc, shape_size); - if (!tune_flag) { - graphStatus graph_ret = op->UpdateInputDesc(0, desc); - if (graph_ret != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0", - op->GetName().c_str(), op->GetType().c_str()); - GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0", - op->GetName().c_str(), op->GetType().c_str()); - return graph_ret; - } - // Size will be recalculated in the build stage - ge::TensorUtils::SetSize(desc, 0); - graph_ret = op->UpdateOutputDesc(0, desc); - if (graph_ret != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0", - op->GetName().c_str(), op->GetType().c_str()); - GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0", - op->GetName().c_str(), op->GetType().c_str()); - return graph_ret; - } - } else { - GELOGI("data %s skip update info in tune mode", op->GetName().c_str()); + + ret = UpdateDataInputOutputDesc(index, op, desc); + if (ret != SUCCESS) { + GELOGE(FAILED, "[Update][DataInputOutputDesc] on %s failed", op->GetName().c_str()); + return ret; } + if (!dynamic_shape_range_vec.empty()) { ret = UpdateDynamicInputShapeRange(index, dynamic_shape_range_vec, op, desc); GE_CHK_STATUS_RET(ret, "[Update][DynamicInputShapeRange] on %s failed.", op->GetName().c_str()); diff --git a/ge/graph/preprocess/graph_preprocess.h b/ge/graph/preprocess/graph_preprocess.h index 584f4d16..22bc566c 100755 --- a/ge/graph/preprocess/graph_preprocess.h +++ b/ge/graph/preprocess/graph_preprocess.h @@ -63,7 +63,8 @@ class GraphPrepare { Status CheckRefOp(); Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); Status AdjustDataOpOutput(const NodePtr &node); - Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag); + Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc); + Status UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc); Status UpdateInput(const std::vector &user_input, const std::map &graph_option); Status CheckAndUpdateInput(const std::vector &user_input, const std::map &graph_option); Status CheckConstOp(); diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index 21db83aa..befffa93 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -559,8 +559,8 @@ graphStatus Impl::Init(const Graph &graph, const std::map user_input = {input1}; + std::map graph_option; + auto ret = graph_prepare.UpdateInput(user_input, graph_option); + EXPECT_EQ(ret, ge::FAILED); +} + TEST_F(UtestGraphPreproces, test_check_user_input) { ge::GraphPrepare graph_prepare; graph_prepare.compute_graph_ = BuildGraph1();