@@ -123,21 +123,25 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
std::pair<DataType, DataType> trans_info(args.src_data_type, args.dst_data_type); | std::pair<DataType, DataType> trans_info(args.src_data_type, args.dst_data_type); | ||||
auto iter = trans_mode_map.find(trans_info); | auto iter = trans_mode_map.find(trans_info); | ||||
if (iter == trans_mode_map.end()) { | if (iter == trans_mode_map.end()) { | ||||
GELOGE(PARAM_INVALID, "Trans data type from %s to %s is not supported.", | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | |||||
std::string error = "Failed to trans data from datatype [" + | |||||
TypeUtils::FormatToSerialString(args.src_data_type) + "] to " + "[" + | |||||
TypeUtils::FormatToSerialString(args.dst_data_type) + "], it is not supported."; | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
auto trans_mode = iter->second; | auto trans_mode = iter->second; | ||||
int size = GetSizeByDataType(args.dst_data_type); | int size = GetSizeByDataType(args.dst_data_type); | ||||
if (size <= 0) { | if (size <= 0) { | ||||
GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", | |||||
TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | |||||
std::string error = "Failed to calc size from data type[" + | |||||
TypeUtils::DataTypeToSerialString(args.dst_data_type) + "], it is not supported."; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) { | if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) { | ||||
GELOGE(PARAM_INVALID, "args.src_data_size %zu or data type size %d too big.", args.src_data_size, size); | |||||
std::string error = "args.src_data_size[" + std::to_string(args.src_data_size) + | |||||
"] or data type size[" + std::to_string(size) + "is too big"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
size_t total_size = static_cast<size_t>(args.src_data_size * size); | size_t total_size = static_cast<size_t>(args.src_data_size * size); | ||||
@@ -154,9 +158,10 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
} | } | ||||
if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) { | if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to cast data from %s to %s, data size %zu", | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str(), args.src_data_size); | |||||
std::string error = "Failed to cast data from datatype [" + | |||||
TypeUtils::FormatToSerialString(args.src_data_type) + "] to " + "[" + | |||||
TypeUtils::FormatToSerialString(args.dst_data_type) + "], data size is " + std::to_string(args.src_data_size); | |||||
GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
result.data = dst; | result.data = dst; | ||||
@@ -35,14 +35,18 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||||
auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
if (args.src_format != FORMAT_C1HWNCoC0 || args.dst_format != FORMAT_HWCN) { | if (args.src_format != FORMAT_C1HWNCoC0 || args.dst_format != FORMAT_HWCN) { | ||||
std::string error = "Dose not support trans format from " + | |||||
FmtEgStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
FmtEgStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | ||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | TypeUtils::FormatToSerialString(args.dst_format).c_str()); | ||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans shape from NC1HWNCoC0 to HWCN, invalid data type %s", | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
std::string error = "Failed to trans shape from NC1HWNCoC0 to HWCN, invalid data type[" + | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str() + "]"; | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (!CheckShapeValid(src_shape, kC1hwncoc0DimsNum)) { | if (!CheckShapeValid(src_shape, kC1hwncoc0DimsNum)) { | ||||
@@ -58,8 +62,9 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||||
src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || | src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || | ||||
src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || | src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || | ||||
src_shape.at(kC1hwncoc0C0) != cube_size) { | src_shape.at(kC1hwncoc0C0) != cube_size) { | ||||
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||||
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||||
std::string error = "Failed to check relationship between src and dst shape, src shape[" + | |||||
ShapeToString(src_shape) + "], dst shape[" + ShapeToString(dst_shape) + "]"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -69,9 +74,10 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||||
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size, int64_t total_size) { | Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size, int64_t total_size) { | ||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | |||||
std::string error = "Failed to trans format from[" + TypeUtils::FormatToSerialString(args.src_format).c_str() + | |||||
"] to [" + TypeUtils::FormatToSerialString(args.dst_format) + "], can not alloc the memory for dst buf[" + | |||||
std::to_string(total_size) + "], shape[" + ShapeToString(args.dst_shape) + "]"; | |||||
GE_ERRORLOG_AND_ERRORMSG(OUT_OF_MEMORY, error.c_str()); | |||||
return OUT_OF_MEMORY; | return OUT_OF_MEMORY; | ||||
} | } | ||||
@@ -148,11 +148,7 @@ Status FormatTransferDhwcnFractalZ3D::TransFormat(const TransArgs &args, TransRe | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
if (!args.dst_shape.empty() && args.dst_shape != expect_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
ShapeToString(expect_shape).c_str()); | |||||
if (!IsTransShapeCorrect(args, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -149,11 +149,7 @@ Status FormatTransferDhwncFractalZ3DTranspose::TransFormat(const TransArgs &args | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
if (!args.dst_shape.empty() && args.dst_shape != expect_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
ShapeToString(expect_shape).c_str()); | |||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -39,6 +39,9 @@ bool CheckShape(Format format, const ShapeVector &shape) { | |||||
case FORMAT_NHWC: | case FORMAT_NHWC: | ||||
return CheckShapeValid(shape, kDimSize4D); | return CheckShapeValid(shape, kDimSize4D); | ||||
default: | default: | ||||
std::string error = "Trans format between[" + TypeUtils::FormatToSerialString(format) + | |||||
"] and FORMAT_FRACTAL_NZ is not supported."; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
GELOGE(PARAM_INVALID, "Trans format between %s and FORMAT_FRACTAL_NZ is not supported.", | GELOGE(PARAM_INVALID, "Trans format between %s and FORMAT_FRACTAL_NZ is not supported.", | ||||
TypeUtils::FormatToSerialString(format).c_str()); | TypeUtils::FormatToSerialString(format).c_str()); | ||||
return false; | return false; | ||||
@@ -103,11 +106,7 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
if (args.src_shape != expect_src_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, invalid relationship between src shape %s and dst %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
ShapeToString(args.dst_shape).c_str()); | |||||
if (!IsTransShapeSrcCorrect(args, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -275,11 +274,7 @@ Status FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult & | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
if (args.dst_shape != expect_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
ShapeToString(expect_shape).c_str()); | |||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
return TransFormatFromNdToFracNz(args, result, hw_shape); | return TransFormatFromNdToFracNz(args, result, hw_shape); | ||||
@@ -159,8 +159,9 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||||
ret = memset_s(dst.get() + offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | ret = memset_s(dst.get() + offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | ||||
} else { | } else { | ||||
if (protected_size < size) { | if (protected_size < size) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||||
protected_size, size); | |||||
std::string error = "Failed to operate the dst memory, protected_size is[" + | |||||
std::to_string(protected_size) + "] and size is [" + std::to_string(size) + "]"; | |||||
GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
char *dst_data = reinterpret_cast<char *>(dst.get() + offset); | char *dst_data = reinterpret_cast<char *>(dst.get() + offset); | ||||
@@ -345,11 +346,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
if (!args.dst_shape.empty() && args.dst_shape != expect_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
ShapeToString(expect_shape).c_str()); | |||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -39,8 +39,9 @@ bool CheckShape(Format format, const ShapeVector &shape) { | |||||
case FORMAT_NHWC: | case FORMAT_NHWC: | ||||
return CheckShapeValid(shape, kDimSize4D); | return CheckShapeValid(shape, kDimSize4D); | ||||
default: | default: | ||||
GELOGE(PARAM_INVALID, "Not support trans format between %s and FORMAT_FRACTAL_ZZ.", | |||||
TypeUtils::FormatToSerialString(format).c_str()); | |||||
std::string error = "Trans format between[" + TypeUtils::FormatToSerialString(format) + | |||||
"] and FORMAT_FRACTAL_ZZ is not supported."; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -103,12 +104,7 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
if (args.src_shape != expect_src_shape) { | |||||
GELOGE(PARAM_INVALID, | |||||
"Failed to trans format from %s to %s, invalid relationship between src shape %s and dst shape %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
ShapeToString(args.dst_shape).c_str()); | |||||
if (!IsTransShapeSrcCorrect(args, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -289,11 +285,7 @@ Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult & | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
if (args.dst_shape != expect_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
ShapeToString(expect_shape).c_str()); | |||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
return TransFormatFromNdToFracZz(args, result, hw_shape); | return TransFormatFromNdToFracZz(args, result, hw_shape); | ||||
@@ -33,9 +33,9 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { | |||||
auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_HWCN) { | if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_HWCN) { | ||||
GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Dose not support trans format from [" + | |||||
TypeUtils::FormatToSerialString(args.src_format) + "] to " + "[" + | |||||
TypeUtils::FormatToSerialString(args.dst_format) + "]"; | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
@@ -280,11 +280,7 @@ Status FormatTransferNchwToFZC04::TransFormat(const TransArgs &args, TransResult | |||||
return ret; | return ret; | ||||
} | } | ||||
if (!args_tmp.dst_shape.empty() && args_tmp.dst_shape != expect_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
TypeUtils::FormatToSerialString(args_tmp.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args_tmp.dst_format).c_str(), ShapeToString(args_tmp.dst_shape).c_str(), | |||||
ShapeToString(expect_shape).c_str()); | |||||
if (!IsTransShapeCorrect(args_tmp, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -223,11 +223,7 @@ Status FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult & | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
if (args.dst_shape != expected_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, invalid dst shape %s, expect %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
ShapeToString(expected_shape).c_str()); | |||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -34,9 +34,10 @@ namespace formats { | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArgs &args, TransResult &result) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArgs &args, TransResult &result) { | ||||
auto transfer = BuildFormatTransfer(args); | auto transfer = BuildFormatTransfer(args); | ||||
if (transfer == nullptr) { | if (transfer == nullptr) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans data from format %s to %s, unsupport now", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Failed to trans data from format [" + | |||||
TypeUtils::FormatToSerialString(args.src_format) + "] to " + "[" + | |||||
TypeUtils::FormatToSerialString(args.dst_format) + "]"; | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
@@ -59,9 +60,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransShape(Format src_form | |||||
args.dst_format = dst_format; | args.dst_format = dst_format; | ||||
auto transfer = BuildFormatTransfer(args); | auto transfer = BuildFormatTransfer(args); | ||||
if (transfer == nullptr) { | if (transfer == nullptr) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans data from format %s to %s, unsupport now", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Failed to trans data from format [" + | |||||
TypeUtils::FormatToSerialString(args.src_format) + "] to " + "[" + | |||||
TypeUtils::FormatToSerialString(args.dst_format) + "]"; | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
@@ -71,9 +73,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransShape(Format src_form | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastArgs &args, TransResult &result) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastArgs &args, TransResult &result) { | ||||
auto transfer = BuildDataTypeTransfer(args); | auto transfer = BuildDataTypeTransfer(args); | ||||
if (transfer == nullptr) { | if (transfer == nullptr) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans data from datatype %s to %s, unsupport now", | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | |||||
std::string error = "Failed to trans data from datatype [" + | |||||
TypeUtils::FormatToSerialString(args.src_data_type) + "] to " + "[" + | |||||
TypeUtils::FormatToSerialString(args.dst_data_type) + "]"; | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
@@ -92,5 +95,29 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool IsTransFormatSupport(const T | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool IsTransDataTypeSupport(const CastArgs &args) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool IsTransDataTypeSupport(const CastArgs &args) { | ||||
return DataTypeTransferExists(args); | return DataTypeTransferExists(args); | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool IsTransShapeSrcCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape) { | |||||
if (!args.src_shape.empty() && args.src_shape != expect_shape) { | |||||
std::string error = "Failed to trans format from[" + TypeUtils::FormatToSerialString(args.src_format) + | |||||
"] to [" + TypeUtils::FormatToSerialString(args.dst_format) + "], the src shape[" + | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str() + "] is invalid, expect[" + | |||||
ShapeToString(expect_shape) + "]"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool IsTransShapeDstCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape) { | |||||
if (!args.dst_shape.empty() && args.dst_shape != expect_shape) { | |||||
std::stringstream error; | |||||
error << "Failed to trans format from[" + TypeUtils::FormatToSerialString(args.src_format) << | |||||
"] to [" << TypeUtils::FormatToSerialString(args.dst_format) << "], invalid relationship between src shape[" << | |||||
ShapeToString(args.src_shape) << "] and dst [" << ShapeToString(args.dst_shape) + "]"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.str()); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | } // namespace ge |
@@ -44,6 +44,10 @@ Status TransDataType(const CastArgs &args, TransResult &result); | |||||
bool IsTransFormatSupport(const TransArgs &args); | bool IsTransFormatSupport(const TransArgs &args); | ||||
bool IsTransDataTypeSupport(const CastArgs &args); | bool IsTransDataTypeSupport(const CastArgs &args); | ||||
bool IsTransShapeSrcCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape); | |||||
bool IsTransShapeDstCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape); | |||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_COMMON_FORMATS_FORMATS_H_ | #endif // GE_COMMON_FORMATS_FORMATS_H_ |
@@ -29,8 +29,9 @@ int64_t GetCubeSizeByDataType(DataType data_type) { | |||||
// Current cube does not support 4 bytes and longer data | // Current cube does not support 4 bytes and longer data | ||||
auto size = GetSizeByDataType(data_type); | auto size = GetSizeByDataType(data_type); | ||||
if (size <= 0) { | if (size <= 0) { | ||||
GELOGE(PARAM_INVALID, "Failed to get cube size, the data type %s is invalid", | |||||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
std::string error = "Failed to get cube size, the data type [" + | |||||
TypeUtils::DataTypeToSerialString(data_type) + "] is invalid"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return -1; | return -1; | ||||
} else if (size == 1) { | } else if (size == 1) { | ||||
return kCubeSize * 2; // 32 bytes cube size | return kCubeSize * 2; // 32 bytes cube size | ||||
@@ -57,7 +58,10 @@ int64_t GetItemNumByShape(const std::vector<int64_t> &shape) { | |||||
bool CheckShapeValid(const std::vector<int64_t> &shape, const int64_t expect_dims) { | bool CheckShapeValid(const std::vector<int64_t> &shape, const int64_t expect_dims) { | ||||
if (expect_dims <= 0 || shape.size() != static_cast<size_t>(expect_dims)) { | if (expect_dims <= 0 || shape.size() != static_cast<size_t>(expect_dims)) { | ||||
GELOGE(PARAM_INVALID, "Invalid shape, dims num %zu, expect %ld", shape.size(), expect_dims); | |||||
std::string error = " Invalid shape, dims num [" + std::to_string(shape.size()) + | |||||
"], expect [" + std::to_string(expect_dims) + "]"; | |||||
TypeUtils::DataTypeToSerialString(data_type) + "] is invalid"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
return IsShapeValid(shape); | return IsShapeValid(shape); | ||||
@@ -70,11 +74,14 @@ bool IsShapeValid(const std::vector<int64_t> &shape) { | |||||
int64_t num = 1; | int64_t num = 1; | ||||
for (auto dim : shape) { | for (auto dim : shape) { | ||||
if (dim < 0) { | if (dim < 0) { | ||||
GELOGE(PARAM_INVALID, "Invalid negative dim in the shape %s", ShapeToString(shape).c_str()); | |||||
std::string error = "Invalid negative dims in the shape [" + ShapeToString(shape).c_str() + "]"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
if (dim != 0 && kShapeItemNumMAX / dim < num) { | if (dim != 0 && kShapeItemNumMAX / dim < num) { | ||||
GELOGE(PARAM_INVALID, "Shape overflow, the total count should be less than %ld!", kShapeItemNumMAX); | |||||
std::string error = "Shape overflow, the total count should be less than [" + | |||||
std::to_string(kShapeItemNumMAX) + "]"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
num *= dim; | num *= dim; | ||||
@@ -94,5 +101,7 @@ bool IsShapeEqual(const GeShape &src, const GeShape &dst) { | |||||
} | } | ||||
return true; | return true; | ||||
} | } | ||||
bool | |||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | } // namespace ge |
@@ -18,6 +18,7 @@ | |||||
#define INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | #define INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | ||||
#include <string> | #include <string> | ||||
#include <sstream> | |||||
#include "runtime/rt.h" | #include "runtime/rt.h" | ||||
#include "common/string_util.h" | #include "common/string_util.h" | ||||
@@ -269,4 +270,12 @@ | |||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
template <typename T> | |||||
std::string FmtEgStr(const T &t) { | |||||
std::string fmt; | |||||
std::stringstream st; | |||||
st << "[" << t << "]"; | |||||
return fmt; | |||||
} | |||||
#endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | #endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ |