From 6dbe0472bb3d8a03783cc7bf4407555be281b954 Mon Sep 17 00:00:00 2001 From: wjm Date: Wed, 2 Jun 2021 17:37:07 +0800 Subject: [PATCH] fix more sc --- ge/graph/preprocess/graph_preprocess.cc | 36 ++++++++++++++----------- ge/graph/preprocess/graph_preprocess.h | 1 + 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index 2eae6023..5a491f19 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -1423,6 +1423,25 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) { return SUCCESS; } +Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag) { + auto format = desc.GetFormat(); + auto origin_format = desc.GetOriginFormat(); + bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op) && (!tune_flag); + if (need_check_internal_format) { + bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); + if (is_internal) { + ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, + {"Input format[" + TypeUtils::FormatToSerialString(format) + "] or origin_format[" + + TypeUtils::FormatToSerialString(origin_format) + "]", "it is not support"}); + GELOGE(PARAM_INVALID, "[Check][Param] Input format %s or origin_format %s is not support.", + TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::FormatToSerialString(origin_format).c_str()); + return FAILED; + } + } + return SUCCESS; +} + Status GraphPrepare::UpdateInput(const std::vector &user_input, const std::map &graph_option) { // Get shape range of input in dynamic_execute mode @@ -1454,23 +1473,10 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input, continue; } GeTensorDesc desc(user_input[index].GetTensorDesc()); - auto format = desc.GetFormat(); - auto origin_format = desc.GetOriginFormat(); // data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM. auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); - bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op) && (!tune_flag); - if (need_check_internal_format) { - bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); - if (is_internal) { - ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, - {"Input format[" + TypeUtils::FormatToSerialString(format) + "] or origin_format[" + - TypeUtils::FormatToSerialString(origin_format) + "]", "it is not support"}); - GELOGE(PARAM_INVALID, "[Check][Param] Input format %s or origin_format %s is not support.", - TypeUtils::FormatToSerialString(format).c_str(), - TypeUtils::FormatToSerialString(origin_format).c_str()); - return FAILED; - } - } + GE_CHK_STATUS_RET(CheckInternalFormat(input_node, desc, tune_flag), "[Check][InternalFormat] on %s failed.", + op->GetName().c_str()); auto data_type = desc.GetDataType(); uint32_t length = 1; diff --git a/ge/graph/preprocess/graph_preprocess.h b/ge/graph/preprocess/graph_preprocess.h index 3eb5e03a..584f4d16 100755 --- a/ge/graph/preprocess/graph_preprocess.h +++ b/ge/graph/preprocess/graph_preprocess.h @@ -63,6 +63,7 @@ class GraphPrepare { Status CheckRefOp(); Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); Status AdjustDataOpOutput(const NodePtr &node); + Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag); Status UpdateInput(const std::vector &user_input, const std::map &graph_option); Status CheckAndUpdateInput(const std::vector &user_input, const std::map &graph_option); Status CheckConstOp();