From: @wangwenhua1 Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -24,6 +24,7 @@ | |||||
| #include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "securec.h" | #include "securec.h" | ||||
| @@ -123,21 +124,25 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
| std::pair<DataType, DataType> trans_info(args.src_data_type, args.dst_data_type); | std::pair<DataType, DataType> trans_info(args.src_data_type, args.dst_data_type); | ||||
| auto iter = trans_mode_map.find(trans_info); | auto iter = trans_mode_map.find(trans_info); | ||||
| if (iter == trans_mode_map.end()) { | if (iter == trans_mode_map.end()) { | ||||
| GELOGE(PARAM_INVALID, "Trans data type from %s to %s is not supported.", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | |||||
| std::string error = "Failed to trans data from datatype " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + " , it is not supported."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| auto trans_mode = iter->second; | auto trans_mode = iter->second; | ||||
| int size = GetSizeByDataType(args.dst_data_type); | int size = GetSizeByDataType(args.dst_data_type); | ||||
| if (size <= 0) { | if (size <= 0) { | ||||
| GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", | |||||
| TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | |||||
| std::string error = "Failed to calc size from data type" + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", it is not supported."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) { | if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) { | ||||
| GELOGE(PARAM_INVALID, "args.src_data_size %zu or data type size %d too big.", args.src_data_size, size); | |||||
| std::string error = "args.src_data_size" + FmtToStr(args.src_data_size) + | |||||
| " or data type size" + FmtToStr(size) + " is too big"; | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| size_t total_size = static_cast<size_t>(args.src_data_size * size); | size_t total_size = static_cast<size_t>(args.src_data_size * size); | ||||
| @@ -154,9 +159,11 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
| } | } | ||||
| if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) { | if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, "Failed to cast data from %s to %s, data size %zu", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str(), args.src_data_size); | |||||
| std::string error = "Failed to cast data from datatype " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", data size is " + | |||||
| FmtToStr(std::to_string(args.src_data_size)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str()); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| result.data = dst; | result.data = dst; | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -35,14 +36,16 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||||
| auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
| auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
| if (args.src_format != FORMAT_C1HWNCoC0 || args.dst_format != FORMAT_HWCN) { | if (args.src_format != FORMAT_C1HWNCoC0 || args.dst_format != FORMAT_HWCN) { | ||||
| GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
| GELOGE(UNSUPPORTED, "Failed to trans shape from NC1HWNCoC0 to HWCN, invalid data type %s", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| std::string error = "Failed to trans shape from NC1HWNCoC0 to HWCN, invalid data type" + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| if (!CheckShapeValid(src_shape, kC1hwncoc0DimsNum)) { | if (!CheckShapeValid(src_shape, kC1hwncoc0DimsNum)) { | ||||
| @@ -58,8 +61,9 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||||
| src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || | src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || | ||||
| src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || | src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || | ||||
| src_shape.at(kC1hwncoc0C0) != cube_size) { | src_shape.at(kC1hwncoc0C0) != cube_size) { | ||||
| GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||||
| ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||||
| std::string error = "Failed to check relationship between src and dst shape, src shape" + | |||||
| FmtToStr(ShapeToString(src_shape)) + ", dst shape" + FmtToStr(ShapeToString(dst_shape)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -148,11 +148,7 @@ Status FormatTransferDhwcnFractalZ3D::TransFormat(const TransArgs &args, TransRe | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (!args.dst_shape.empty() && args.dst_shape != expect_shape) { | |||||
| GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(expect_shape).c_str()); | |||||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -149,11 +149,7 @@ Status FormatTransferDhwncFractalZ3DTranspose::TransFormat(const TransArgs &args | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (!args.dst_shape.empty() && args.dst_shape != expect_shape) { | |||||
| GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(expect_shape).c_str()); | |||||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -39,8 +40,9 @@ bool CheckShape(Format format, const ShapeVector &shape) { | |||||
| case FORMAT_NHWC: | case FORMAT_NHWC: | ||||
| return CheckShapeValid(shape, kDimSize4D); | return CheckShapeValid(shape, kDimSize4D); | ||||
| default: | default: | ||||
| GELOGE(PARAM_INVALID, "Trans format between %s and FORMAT_FRACTAL_NZ is not supported.", | |||||
| TypeUtils::FormatToSerialString(format).c_str()); | |||||
| std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | |||||
| " and FORMAT_FRACTAL_NZ is not supported."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -103,11 +105,7 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| if (args.src_shape != expect_src_shape) { | |||||
| GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, invalid relationship between src shape %s and dst %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str()); | |||||
| if (!IsTransShapeSrcCorrect(args, expect_src_shape)) { | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -275,11 +273,7 @@ Status FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult & | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (args.dst_shape != expect_shape) { | |||||
| GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(expect_shape).c_str()); | |||||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| return TransFormatFromNdToFracNz(args, result, hw_shape); | return TransFormatFromNdToFracNz(args, result, hw_shape); | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -159,8 +160,9 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||||
| ret = memset_s(dst.get() + offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | ret = memset_s(dst.get() + offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | ||||
| } else { | } else { | ||||
| if (protected_size < size) { | if (protected_size < size) { | ||||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||||
| protected_size, size); | |||||
| std::string error = "Failed to operate the dst memory, protected_size is " + | |||||
| FmtToStr(protected_size) + " and size is " + FmtToStr(size); | |||||
| GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str()); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| char *dst_data = reinterpret_cast<char *>(dst.get() + offset); | char *dst_data = reinterpret_cast<char *>(dst.get() + offset); | ||||
| @@ -345,11 +347,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (!args.dst_shape.empty() && args.dst_shape != expect_shape) { | |||||
| GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(expect_shape).c_str()); | |||||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -39,8 +40,9 @@ bool CheckShape(Format format, const ShapeVector &shape) { | |||||
| case FORMAT_NHWC: | case FORMAT_NHWC: | ||||
| return CheckShapeValid(shape, kDimSize4D); | return CheckShapeValid(shape, kDimSize4D); | ||||
| default: | default: | ||||
| GELOGE(PARAM_INVALID, "Not support trans format between %s and FORMAT_FRACTAL_ZZ.", | |||||
| TypeUtils::FormatToSerialString(format).c_str()); | |||||
| std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | |||||
| " and FORMAT_FRACTAL_ZZ is not supported."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -103,12 +105,7 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| if (args.src_shape != expect_src_shape) { | |||||
| GELOGE(PARAM_INVALID, | |||||
| "Failed to trans format from %s to %s, invalid relationship between src shape %s and dst shape %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str()); | |||||
| if (!IsTransShapeSrcCorrect(args, expect_src_shape)) { | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -289,11 +286,7 @@ Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult & | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (args.dst_shape != expect_shape) { | |||||
| GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(expect_shape).c_str()); | |||||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| return TransFormatFromNdToFracZz(args, result, hw_shape); | return TransFormatFromNdToFracZz(args, result, hw_shape); | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -33,9 +34,10 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { | |||||
| auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
| auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
| if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_HWCN) { | if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_HWCN) { | ||||
| GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
| @@ -59,10 +61,12 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { | |||||
| int64_t n0 = Ceil(dst_shape.at(kHwcnN), static_cast<int64_t>(kNiSize)); | int64_t n0 = Ceil(dst_shape.at(kHwcnN), static_cast<int64_t>(kNiSize)); | ||||
| if (src_shape.at(kFracZHWC1) != dst_shape.at(kHwcnH) * dst_shape.at(kHwcnW) * c1 || src_shape.at(kFracZC0) != c0 || | if (src_shape.at(kFracZHWC1) != dst_shape.at(kHwcnH) * dst_shape.at(kHwcnW) * c1 || src_shape.at(kFracZC0) != c0 || | ||||
| src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { | src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { | ||||
| GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||||
| ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||||
| std::string error = "Failed to check relationship between src shape" + | |||||
| FmtToStr(ShapeToString(src_shape)) + " and dst shape" + | |||||
| FmtToStr(ShapeToString(dst_shape)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -33,9 +34,10 @@ Status CheckArgsForFracZToNchw(const TransArgs &args) { | |||||
| auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
| auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
| if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_NCHW) { | if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_NCHW) { | ||||
| GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -33,9 +34,10 @@ Status CheckArgsForFracZToNhwc(const TransArgs &args) { | |||||
| auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
| auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
| if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_NHWC) { | if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_NHWC) { | ||||
| GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -50,9 +51,10 @@ Status TransShapeHwcnToC1hwncoc0(const DataType &data_type, const std::vector<in | |||||
| Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { | Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { | ||||
| if (args.src_format != FORMAT_HWCN || args.dst_format != FORMAT_C1HWNCoC0) { | if (args.src_format != FORMAT_HWCN || args.dst_format != FORMAT_C1HWNCoC0) { | ||||
| GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -33,9 +34,10 @@ Status CheckArgsForNc1hwc0ToNchw(const TransArgs &args) { | |||||
| auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
| auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
| if (args.src_format != FORMAT_NC1HWC0 || args.dst_format != FORMAT_NCHW) { | if (args.src_format != FORMAT_NC1HWC0 || args.dst_format != FORMAT_NCHW) { | ||||
| GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -33,9 +34,10 @@ Status CheckArgsForNc1hwc0ToNhwc(const TransArgs &args) { | |||||
| auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
| auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
| if (args.src_format != FORMAT_NC1HWC0 || args.dst_format != FORMAT_NHWC) { | if (args.src_format != FORMAT_NC1HWC0 || args.dst_format != FORMAT_NHWC) { | ||||
| GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
| @@ -280,11 +280,7 @@ Status FormatTransferNchwToFZC04::TransFormat(const TransArgs &args, TransResult | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (!args_tmp.dst_shape.empty() && args_tmp.dst_shape != expect_shape) { | |||||
| GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
| TypeUtils::FormatToSerialString(args_tmp.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args_tmp.dst_format).c_str(), ShapeToString(args_tmp.dst_shape).c_str(), | |||||
| ShapeToString(expect_shape).c_str()); | |||||
| if (!IsTransShapeDstCorrect(args_tmp, expect_shape)) { | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -53,9 +54,10 @@ Status TransShapeNchwToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
| Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | ||||
| if (args.src_format != FORMAT_NCHW || args.dst_format != FORMAT_NC1HWC0) { | if (args.src_format != FORMAT_NCHW || args.dst_format != FORMAT_NC1HWC0) { | ||||
| GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| std::vector<int64_t> expect_5d_shape; | std::vector<int64_t> expect_5d_shape; | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -51,9 +52,10 @@ Status TransShapeNhwcToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
| Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { | Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { | ||||
| if (args.src_format != FORMAT_NHWC || args.dst_format != FORMAT_NC1HWC0) { | if (args.src_format != FORMAT_NHWC || args.dst_format != FORMAT_NC1HWC0) { | ||||
| GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
| @@ -48,28 +48,31 @@ std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{ | |||||
| bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | ||||
| if (src_shape.empty()) { | if (src_shape.empty()) { | ||||
| std::string error = "Failed to transpose, empty src shape"; | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| GELOGE(PARAM_INVALID, "Failed to transpose, empty src shape"); | GELOGE(PARAM_INVALID, "Failed to transpose, empty src shape"); | ||||
| return false; | return false; | ||||
| } | } | ||||
| for (auto dim : src_shape) { | for (auto dim : src_shape) { | ||||
| if (dim < 0) { | if (dim < 0) { | ||||
| GELOGE(PARAM_INVALID, "Failed to transpose, negative dim in src shape %s", ShapeToString(src_shape).c_str()); | |||||
| std::string error = "Failed to transpose, negative dim in src shape " + FmtToStr(ShapeToString(src_shape)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| if (perm_arg.size() != src_shape.size()) { | if (perm_arg.size() != src_shape.size()) { | ||||
| GELOGE(PARAM_INVALID, | |||||
| "Failed to transpose, the size of src shape(%zu) and" | |||||
| " perm arg(%zu) are different", | |||||
| src_shape.size(), perm_arg.size()); | |||||
| std::string error = "Failed to transpose, the size of src shape" + FmtToStr(src_shape.size()) + | |||||
| " and perm arg" + FmtToStr(perm_arg.size()) + " are different"; | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| std::vector<int64_t> exists(perm_arg.size()); | std::vector<int64_t> exists(perm_arg.size()); | ||||
| for (auto perm : perm_arg) { | for (auto perm : perm_arg) { | ||||
| if (perm < 0 || static_cast<size_t>(perm) >= perm_arg.size() || ++exists[perm] > 1) { | if (perm < 0 || static_cast<size_t>(perm) >= perm_arg.size() || ++exists[perm] > 1) { | ||||
| GELOGE(PARAM_INVALID, "Failed to transpose, duplicated perm arg %ld, perm arg %s", perm, | |||||
| JoinToString(perm_arg).c_str()); | |||||
| std::string error = "Failed to transpose, duplicated perm arg " + FmtToStr(perm) + | |||||
| ", perm arg " + FmtToStr(JoinToString(perm_arg)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -192,9 +195,10 @@ Status TransposeWithShapeCheck(const uint8_t *data, const std::vector<int64_t> & | |||||
| } | } | ||||
| auto expected_shape = TransShapeByPerm(src_shape, perm_arg); | auto expected_shape = TransShapeByPerm(src_shape, perm_arg); | ||||
| if (dst_shape != expected_shape) { | if (dst_shape != expected_shape) { | ||||
| GELOGE(PARAM_INVALID, "Failed to trans axis for perm_arg %s, invalid dst shape %s, expect %s", | |||||
| ShapeToString(perm_arg).c_str(), ShapeToString(dst_shape).c_str(), ShapeToString(expected_shape).c_str()); | |||||
| return PARAM_INVALID; | |||||
| std::string error = "Failed to trans axis for perm_arg" + | |||||
| FmtToStr(ShapeToString(perm_arg)) + ", invalid dst shape" + | |||||
| FmtToStr(ShapeToString(dst_shape)) + ", expect" + FmtToStr(ShapeToString(expected_shape)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| } | } | ||||
| return Transpose(data, src_shape, src_data_type, perm_arg, result); | return Transpose(data, src_shape, src_data_type, perm_arg, result); | ||||
| @@ -203,14 +207,18 @@ Status TransposeWithShapeCheck(const uint8_t *data, const std::vector<int64_t> & | |||||
| Status GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> &perm) { | Status GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> &perm) { | ||||
| auto dst_iter = perm_args.find(src_format); | auto dst_iter = perm_args.find(src_format); | ||||
| if (dst_iter == perm_args.end()) { | if (dst_iter == perm_args.end()) { | ||||
| GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s", | |||||
| TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | |||||
| std::string error = "Failed to trans shape, do not support transpose from format " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| auto iter = dst_iter->second.find(dst_format); | auto iter = dst_iter->second.find(dst_format); | ||||
| if (iter == dst_iter->second.end()) { | if (iter == dst_iter->second.end()) { | ||||
| GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s", | |||||
| TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | |||||
| std::string error = "Failed to trans shape, do not support transpose from format " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| perm = iter->second; | perm = iter->second; | ||||
| @@ -223,11 +231,7 @@ Status FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult & | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (args.dst_shape != expected_shape) { | |||||
| GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, invalid dst shape %s, expect %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(expected_shape).c_str()); | |||||
| if (!IsTransShapeDstCorrect(args, expected_shape)) { | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -26,6 +26,7 @@ | |||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| @@ -34,9 +35,10 @@ namespace formats { | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArgs &args, TransResult &result) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArgs &args, TransResult &result) { | ||||
| auto transfer = BuildFormatTransfer(args); | auto transfer = BuildFormatTransfer(args); | ||||
| if (transfer == nullptr) { | if (transfer == nullptr) { | ||||
| GELOGE(UNSUPPORTED, "Failed to trans data from format %s to %s, unsupport now", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| std::string error = "Failed to trans data from format " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| @@ -59,9 +61,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransShape(Format src_form | |||||
| args.dst_format = dst_format; | args.dst_format = dst_format; | ||||
| auto transfer = BuildFormatTransfer(args); | auto transfer = BuildFormatTransfer(args); | ||||
| if (transfer == nullptr) { | if (transfer == nullptr) { | ||||
| GELOGE(UNSUPPORTED, "Failed to trans data from format %s to %s, unsupport now", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| std::string error = "Failed to trans data from format " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| @@ -71,9 +74,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransShape(Format src_form | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastArgs &args, TransResult &result) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastArgs &args, TransResult &result) { | ||||
| auto transfer = BuildDataTypeTransfer(args); | auto transfer = BuildDataTypeTransfer(args); | ||||
| if (transfer == nullptr) { | if (transfer == nullptr) { | ||||
| GELOGE(UNSUPPORTED, "Failed to trans data from datatype %s to %s, unsupport now", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | |||||
| std::string error = "Failed to trans data from datatype " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| @@ -29,8 +30,9 @@ int64_t GetCubeSizeByDataType(DataType data_type) { | |||||
| // Current cube does not support 4 bytes and longer data | // Current cube does not support 4 bytes and longer data | ||||
| auto size = GetSizeByDataType(data_type); | auto size = GetSizeByDataType(data_type); | ||||
| if (size <= 0) { | if (size <= 0) { | ||||
| GELOGE(PARAM_INVALID, "Failed to get cube size, the data type %s is invalid", | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| std::string error = "Failed to get cube size, the data type " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(data_type)) + " is invalid"; | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return -1; | return -1; | ||||
| } else if (size == 1) { | } else if (size == 1) { | ||||
| return kCubeSize * 2; // 32 bytes cube size | return kCubeSize * 2; // 32 bytes cube size | ||||
| @@ -57,7 +59,9 @@ int64_t GetItemNumByShape(const std::vector<int64_t> &shape) { | |||||
| bool CheckShapeValid(const std::vector<int64_t> &shape, const int64_t expect_dims) { | bool CheckShapeValid(const std::vector<int64_t> &shape, const int64_t expect_dims) { | ||||
| if (expect_dims <= 0 || shape.size() != static_cast<size_t>(expect_dims)) { | if (expect_dims <= 0 || shape.size() != static_cast<size_t>(expect_dims)) { | ||||
| GELOGE(PARAM_INVALID, "Invalid shape, dims num %zu, expect %ld", shape.size(), expect_dims); | |||||
| std::string error = "Invalid shape, dims num " + FmtToStr(shape.size()) + | |||||
| ", expect " + FmtToStr(expect_dims); | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| return IsShapeValid(shape); | return IsShapeValid(shape); | ||||
| @@ -70,11 +74,13 @@ bool IsShapeValid(const std::vector<int64_t> &shape) { | |||||
| int64_t num = 1; | int64_t num = 1; | ||||
| for (auto dim : shape) { | for (auto dim : shape) { | ||||
| if (dim < 0) { | if (dim < 0) { | ||||
| GELOGE(PARAM_INVALID, "Invalid negative dim in the shape %s", ShapeToString(shape).c_str()); | |||||
| std::string error = "Invalid negative dims in the shape " + FmtToStr(ShapeToString(shape)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (dim != 0 && kShapeItemNumMAX / dim < num) { | if (dim != 0 && kShapeItemNumMAX / dim < num) { | ||||
| GELOGE(PARAM_INVALID, "Shape overflow, the total count should be less than %ld!", kShapeItemNumMAX); | |||||
| std::string error = "Shape overflow, the total count should be less than " + FmtToStr(kShapeItemNumMAX); | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| num *= dim; | num *= dim; | ||||
| @@ -94,5 +100,31 @@ bool IsShapeEqual(const GeShape &src, const GeShape &dst) { | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool IsTransShapeSrcCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape) { | |||||
| if (args.src_shape != expect_shape) { | |||||
| std::string error = "Failed to trans format from" + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)) + ", invalid relationship between src shape " + | |||||
| FmtToStr(ShapeToString(args.src_shape)) + " and dst " + | |||||
| FmtToStr(ShapeToString(args.dst_shape)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool IsTransShapeDstCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape) { | |||||
| if (!args.dst_shape.empty() && args.dst_shape != expect_shape) { | |||||
| std::string error = "Failed to trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)) + ", the dst shape" + | |||||
| FmtToStr(ShapeToString(args.dst_shape)) + " is invalid, expect" + | |||||
| FmtToStr(ShapeToString(expect_shape)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace formats | } // namespace formats | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "external/graph/types.h" | #include "external/graph/types.h" | ||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace formats { | namespace formats { | ||||
| @@ -61,6 +62,10 @@ bool IsShapeValid(const std::vector<int64_t> &shape); | |||||
| bool IsShapeEqual(const GeShape &src, const GeShape &dst); | bool IsShapeEqual(const GeShape &src, const GeShape &dst); | ||||
| bool IsTransShapeSrcCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape); | |||||
| bool IsTransShapeDstCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape); | |||||
| template <typename T> | template <typename T> | ||||
| T Ceil(T n1, T n2) { | T Ceil(T n1, T n2) { | ||||
| if (n1 == 0) { | if (n1 == 0) { | ||||
| @@ -18,10 +18,12 @@ | |||||
| #define INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | #define INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | ||||
| #include <string> | #include <string> | ||||
| #include <sstream> | |||||
| #include "runtime/rt.h" | #include "runtime/rt.h" | ||||
| #include "common/string_util.h" | #include "common/string_util.h" | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "ge/ge_api_error_codes.h" | #include "ge/ge_api_error_codes.h" | ||||
| @@ -269,4 +271,13 @@ | |||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| template <typename T> | |||||
| std::string FmtToStr(const T &t) { | |||||
| std::string fmt; | |||||
| std::stringstream st; | |||||
| st << "[" << t << "]"; | |||||
| fmt = st.str(); | |||||
| return fmt; | |||||
| } | |||||
| #endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | #endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | ||||