@@ -111,7 +111,7 @@ Status CastKernel(const CastArgs &args, uint8_t *dst, const size_t data_size, co | |||||
}; | }; | ||||
auto it = transfer_handle.find(trans_mode); | auto it = transfer_handle.find(trans_mode); | ||||
if (it == transfer_handle.end()) { | if (it == transfer_handle.end()) { | ||||
return UNSUPPORTED; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} else { | } else { | ||||
return (it->second)(args, dst, data_size); | return (it->second)(args, dst, data_size); | ||||
} | } | ||||
@@ -127,8 +127,8 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
std::string error = "Failed to trans data from datatype " + | std::string error = "Failed to trans data from datatype " + | ||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | ||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + " , it is not supported."; | FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + " , it is not supported."; | ||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_DATATYPE_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
auto trans_mode = iter->second; | auto trans_mode = iter->second; | ||||
@@ -136,14 +136,14 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
if (size <= 0) { | if (size <= 0) { | ||||
std::string error = "Failed to calc size from data type" + | std::string error = "Failed to calc size from data type" + | ||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", it is not supported."; | FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", it is not supported."; | ||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return PARAM_INVALID; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_DATATYPE_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) { | if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) { | ||||
std::string error = "args.src_data_size" + FmtToStr(args.src_data_size) + | std::string error = "args.src_data_size" + FmtToStr(args.src_data_size) + | ||||
" or data type size" + FmtToStr(size) + " is too big"; | " or data type size" + FmtToStr(size) + " is too big"; | ||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return PARAM_INVALID; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_PARAM_INVALID; | |||||
} | } | ||||
size_t total_size = static_cast<size_t>(args.src_data_size * size); | size_t total_size = static_cast<size_t>(args.src_data_size * size); | ||||
result.length = total_size; | result.length = total_size; | ||||
@@ -154,8 +154,8 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); | |||||
return OUT_OF_MEMORY; | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) { | if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) { | ||||
@@ -163,8 +163,8 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | ||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", data size is " + | FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", data size is " + | ||||
FmtToStr(std::to_string(args.src_data_size)); | FmtToStr(std::to_string(args.src_data_size)); | ||||
GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str()); | |||||
return INTERNAL_ERROR; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_INTERNAL_ERROR, error.c_str()); | |||||
return ACL_ERROR_GE_INTERNAL_ERROR; | |||||
} | } | ||||
result.data = dst; | result.data = dst; | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -39,22 +39,22 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||||
std::string error = "Dose not support trans format from " + | std::string error = "Dose not support trans format from " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | ||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
std::string error = "Failed to trans shape from NC1HWNCoC0 to HWCN, invalid data type" + | std::string error = "Failed to trans shape from NC1HWNCoC0 to HWCN, invalid data type" + | ||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)); | FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)); | ||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_DATATYPE_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (!CheckShapeValid(src_shape, kC1hwncoc0DimsNum)) { | if (!CheckShapeValid(src_shape, kC1hwncoc0DimsNum)) { | ||||
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
if (!CheckShapeValid(dst_shape, kHwcnDimsNum)) { | if (!CheckShapeValid(dst_shape, kHwcnDimsNum)) { | ||||
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
auto cube_size = GetCubeSizeByDataType(args.src_data_type); | auto cube_size = GetCubeSizeByDataType(args.src_data_type); | ||||
if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / cube_size + 1 || | if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / cube_size + 1 || | ||||
@@ -63,8 +63,8 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||||
src_shape.at(kC1hwncoc0C0) != cube_size) { | src_shape.at(kC1hwncoc0C0) != cube_size) { | ||||
std::string error = "Failed to check relationship between src and dst shape, src shape" + | std::string error = "Failed to check relationship between src and dst shape, src shape" + | ||||
FmtToStr(ShapeToString(src_shape)) + ", dst shape" + FmtToStr(ShapeToString(dst_shape)); | FmtToStr(ShapeToString(src_shape)) + ", dst shape" + FmtToStr(ShapeToString(dst_shape)); | ||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return PARAM_INVALID; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -73,10 +73,10 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||||
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size, int64_t total_size) { | Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size, int64_t total_size) { | ||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
auto h = args.src_shape.at(kC1hwncoc0H); | auto h = args.src_shape.at(kC1hwncoc0H); | ||||
@@ -114,12 +114,12 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
"Failed to copy data from C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld to " | "Failed to copy data from C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld to " | ||||
"HWCN[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | "HWCN[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | ||||
c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, src_offset, h_idx, w_idx, c_idx, n_idx, dst_offset, | c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, src_offset, h_idx, w_idx, c_idx, n_idx, dst_offset, | ||||
ret); | ret); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -132,8 +132,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||||
} // namespace | } // namespace | ||||
Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResult &result) { | Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResult &result) { | ||||
if (CheckArgsForC1hwncoc0ToHwcn(args) != SUCCESS) { | |||||
return PARAM_INVALID; | |||||
Status ret = CheckArgsForC1hwncoc0ToHwcn(args); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | } | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t total_size = GetItemNumByShape(args.dst_shape) * size; | int64_t total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
@@ -143,18 +144,19 @@ Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResu | |||||
result.length = static_cast<size_t>(total_size); | result.length = static_cast<size_t>(total_size); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
GELOGD("Begin to trans format from C1HWNCoC0 to HWCN, src shape %s, data type %s, dst shape %s, memory size %ld", | GELOGD("Begin to trans format from C1HWNCoC0 to HWCN, src shape %s, data type %s, dst shape %s, memory size %ld", | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
return INTERNAL_ERROR; | |||||
return ret; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -162,7 +164,7 @@ Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResu | |||||
Status FormatTransferC1hwncoc0Hwcn::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | Status FormatTransferC1hwncoc0Hwcn::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | ||||
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | ||||
GELOGD("The shape derivation from C1HWNCoC0 to HWCN is not unique. Trans shape in this direction is not supported"); | GELOGD("The shape derivation from C1HWNCoC0 to HWCN is not unique. Trans shape in this direction is not supported"); | ||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
REGISTER_FORMAT_TRANSFER(FormatTransferC1hwncoc0Hwcn, FORMAT_C1HWNCoC0, FORMAT_HWCN) | REGISTER_FORMAT_TRANSFER(FormatTransferC1hwncoc0Hwcn, FORMAT_C1HWNCoC0, FORMAT_HWCN) | ||||
@@ -32,7 +32,7 @@ Status TransShapeToFz(int64_t d, int64_t n, int64_t c, int64_t h, int64_t w, Dat | |||||
std::vector<int64_t> &dst_shape) { | std::vector<int64_t> &dst_shape) { | ||||
auto c0 = GetCubeSizeByDataType(data_type); | auto c0 = GetCubeSizeByDataType(data_type); | ||||
if (c0 < 0) { | if (c0 < 0) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
auto c1 = Ceil(c, c0); | auto c1 = Ceil(c, c0); | ||||
@@ -50,7 +50,7 @@ Status TransShapeToFz(int64_t d, int64_t n, int64_t c, int64_t h, int64_t w, Dat | |||||
Status TransShapeDhwckToFz3D(const std::vector<int64_t> &src_shape, DataType data_type, | Status TransShapeDhwckToFz3D(const std::vector<int64_t> &src_shape, DataType data_type, | ||||
std::vector<int64_t> &dst_shape) { | std::vector<int64_t> &dst_shape) { | ||||
if (!CheckShapeValid(src_shape, kDhwcnDimsNum)) { | if (!CheckShapeValid(src_shape, kDhwcnDimsNum)) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
auto d = src_shape.at(kDhwcnD); | auto d = src_shape.at(kDhwcnD); | ||||
auto h = src_shape.at(kDhwcnH); | auto h = src_shape.at(kDhwcnH); | ||||
@@ -62,7 +62,7 @@ Status TransShapeDhwckToFz3D(const std::vector<int64_t> &src_shape, DataType dat | |||||
} | } | ||||
Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { | Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { | ||||
if (!CheckShapeValid(args.src_shape, kDhwcnDimsNum)) { | if (!CheckShapeValid(args.src_shape, kDhwcnDimsNum)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
int64_t d = args.src_shape[kDhwcnD]; | int64_t d = args.src_shape[kDhwcnD]; | ||||
int64_t h = args.src_shape[kDhwcnH]; | int64_t h = args.src_shape[kDhwcnH]; | ||||
@@ -94,10 +94,10 @@ Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
for (int64_t di = 0; di < d; di++) { | for (int64_t di = 0; di < d; di++) { | ||||
@@ -122,9 +122,9 @@ Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { | |||||
args.data + src_idx * data_size, static_cast<size_t>(data_size)); | args.data + src_idx * data_size, static_cast<size_t>(data_size)); | ||||
} | } | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||||
dst_offset, ret, pad_zero); | dst_offset, ret, pad_zero); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -149,28 +149,28 @@ Status FormatTransferDhwcnFractalZ3D::TransFormat(const TransArgs &args, TransRe | |||||
return ret; | return ret; | ||||
} | } | ||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | if (!IsTransShapeDstCorrect(args, expect_shape)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
if (args.src_format == FORMAT_DHWCN && args.dst_format == FORMAT_FRACTAL_Z_3D) { | if (args.src_format == FORMAT_DHWCN && args.dst_format == FORMAT_FRACTAL_Z_3D) { | ||||
return TransFormatDhwckToFz3D(args, result); | return TransFormatDhwckToFz3D(args, result); | ||||
} | } | ||||
return UNSUPPORTED; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
Status FormatTransferDhwcnFractalZ3D::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | Status FormatTransferDhwcnFractalZ3D::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | ||||
DataType data_type, Format dst_format, | DataType data_type, Format dst_format, | ||||
std::vector<int64_t> &dst_shape) { | std::vector<int64_t> &dst_shape) { | ||||
if (CheckDataTypeSupport(data_type) != SUCCESS) { | if (CheckDataTypeSupport(data_type) != SUCCESS) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (src_format == FORMAT_DHWCN && dst_format == FORMAT_FRACTAL_Z_3D) { | if (src_format == FORMAT_DHWCN && dst_format == FORMAT_FRACTAL_Z_3D) { | ||||
return TransShapeDhwckToFz3D(src_shape, data_type, dst_shape); | return TransShapeDhwckToFz3D(src_shape, data_type, dst_shape); | ||||
} | } | ||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
REGISTER_FORMAT_TRANSFER(FormatTransferDhwcnFractalZ3D, FORMAT_DHWCN, FORMAT_FRACTAL_Z_3D) | REGISTER_FORMAT_TRANSFER(FormatTransferDhwcnFractalZ3D, FORMAT_DHWCN, FORMAT_FRACTAL_Z_3D) | ||||
@@ -32,7 +32,7 @@ Status TransShapeToFz(int64_t d, int64_t n, int64_t c, int64_t h, int64_t w, Dat | |||||
std::vector<int64_t> &dst_shape) { | std::vector<int64_t> &dst_shape) { | ||||
auto c0 = GetCubeSizeByDataType(data_type); | auto c0 = GetCubeSizeByDataType(data_type); | ||||
if (c0 < 0) { | if (c0 < 0) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
auto c1 = Ceil(c, c0); | auto c1 = Ceil(c, c0); | ||||
@@ -50,7 +50,7 @@ Status TransShapeToFz(int64_t d, int64_t n, int64_t c, int64_t h, int64_t w, Dat | |||||
Status TransShapeDhwncToFz3DTranspose(const std::vector<int64_t> &src_shape, DataType data_type, | Status TransShapeDhwncToFz3DTranspose(const std::vector<int64_t> &src_shape, DataType data_type, | ||||
std::vector<int64_t> &dst_shape) { | std::vector<int64_t> &dst_shape) { | ||||
if (!CheckShapeValid(src_shape, kDhwncDimsNum)) { | if (!CheckShapeValid(src_shape, kDhwncDimsNum)) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
auto d = src_shape.at(kDhwncD); | auto d = src_shape.at(kDhwncD); | ||||
auto h = src_shape.at(kDhwncH); | auto h = src_shape.at(kDhwncH); | ||||
@@ -62,7 +62,7 @@ Status TransShapeDhwncToFz3DTranspose(const std::vector<int64_t> &src_shape, Dat | |||||
} | } | ||||
Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &result) { | Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &result) { | ||||
if (!CheckShapeValid(args.src_shape, kDhwncDimsNum)) { | if (!CheckShapeValid(args.src_shape, kDhwncDimsNum)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
int64_t d = args.src_shape[kDhwncD]; | int64_t d = args.src_shape[kDhwncD]; | ||||
int64_t h = args.src_shape[kDhwncH]; | int64_t h = args.src_shape[kDhwncH]; | ||||
@@ -95,10 +95,10 @@ Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &resul | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
for (int64_t di = 0; di < d; di++) { | for (int64_t di = 0; di < d; di++) { | ||||
@@ -123,9 +123,9 @@ Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &resul | |||||
args.data + src_idx * data_size, static_cast<size_t>(data_size)); | args.data + src_idx * data_size, static_cast<size_t>(data_size)); | ||||
} | } | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||||
dst_offset, ret, pad_zero); | dst_offset, ret, pad_zero); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -150,28 +150,28 @@ Status FormatTransferDhwncFractalZ3DTranspose::TransFormat(const TransArgs &args | |||||
return ret; | return ret; | ||||
} | } | ||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | if (!IsTransShapeDstCorrect(args, expect_shape)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
if (args.src_format == ge::FORMAT_DHWNC && args.dst_format == ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE) { | if (args.src_format == ge::FORMAT_DHWNC && args.dst_format == ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE) { | ||||
return TransFormatDhwncToFz3DTranspose(args, result); | return TransFormatDhwncToFz3DTranspose(args, result); | ||||
} | } | ||||
return UNSUPPORTED; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
Status FormatTransferDhwncFractalZ3DTranspose::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | Status FormatTransferDhwncFractalZ3DTranspose::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | ||||
DataType data_type, Format dst_format, | DataType data_type, Format dst_format, | ||||
std::vector<int64_t> &dst_shape) { | std::vector<int64_t> &dst_shape) { | ||||
if (CheckDataTypeSupport(data_type) != SUCCESS) { | if (CheckDataTypeSupport(data_type) != SUCCESS) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (src_format == FORMAT_DHWNC && dst_format == FORMAT_FRACTAL_Z_3D_TRANSPOSE) { | if (src_format == FORMAT_DHWNC && dst_format == FORMAT_FRACTAL_Z_3D_TRANSPOSE) { | ||||
return TransShapeDhwncToFz3DTranspose(src_shape, data_type, dst_shape); | return TransShapeDhwncToFz3DTranspose(src_shape, data_type, dst_shape); | ||||
} | } | ||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
REGISTER_FORMAT_TRANSFER(FormatTransferDhwncFractalZ3DTranspose, FORMAT_DHWNC, FORMAT_FRACTAL_Z_3D_TRANSPOSE) | REGISTER_FORMAT_TRANSFER(FormatTransferDhwncFractalZ3DTranspose, FORMAT_DHWNC, FORMAT_FRACTAL_Z_3D_TRANSPOSE) | ||||
@@ -87,8 +87,8 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap | |||||
hw_shape.push_back(DIM_DEFAULT_VALUE); | hw_shape.push_back(DIM_DEFAULT_VALUE); | ||||
hw_shape.push_back(src_shape[kNdDimIndexN]); | hw_shape.push_back(src_shape[kNdDimIndexN]); | ||||
if (!IsShapeValid(dst_shape)) { | if (!IsShapeValid(dst_shape)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
default: | default: | ||||
@@ -106,8 +106,8 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap | |||||
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]); | hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]); | ||||
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]); | hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]); | ||||
if (!IsShapeValid(dst_shape)) { | if (!IsShapeValid(dst_shape)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -117,14 +117,14 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
ShapeVector expect_src_shape; | ShapeVector expect_src_shape; | ||||
auto ret = TransShapeToFracNz(args.dst_shape, args.src_data_type, expect_src_shape, hw_shape); | auto ret = TransShapeToFracNz(args.dst_shape, args.src_data_type, expect_src_shape, hw_shape); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Trans shape from %s to %s, shape %s to %s, data type %s failed", | |||||
GELOGE(ret, "Trans shape from %s to %s, shape %s to %s, data type %s failed", | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), | TypeUtils::FormatToSerialString(args.dst_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), ShapeToString(args.dst_shape).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), ShapeToString(args.dst_shape).c_str(), | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return INTERNAL_ERROR; | |||||
return ret; | |||||
} | } | ||||
if (!IsTransShapeSrcCorrect(args, expect_src_shape)) { | if (!IsTransShapeSrcCorrect(args, expect_src_shape)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -139,10 +139,10 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
// src&dst_shape can be written as times*H*W & times*W1*H1*H0*W0, respectively. dst_shape_size >= kDimNum4D | // src&dst_shape can be written as times*H*W & times*W1*H1*H0*W0, respectively. dst_shape_size >= kDimNum4D | ||||
@@ -175,8 +175,8 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size * w0)); | static_cast<size_t>(size * w0)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return INTERNAL_ERROR; | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
auto w1_head = num_w1 * w0; | auto w1_head = num_w1 * w0; | ||||
@@ -189,8 +189,8 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return INTERNAL_ERROR; | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -210,10 +210,10 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
auto times = dst_hw_shape.at(kNdDimIndexN); | auto times = dst_hw_shape.at(kNdDimIndexN); | ||||
@@ -246,8 +246,8 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con | |||||
ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size * w0)); | static_cast<size_t>(size * w0)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return INTERNAL_ERROR; | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
auto w1_head = num_w1 * w0; | auto w1_head = num_w1 * w0; | ||||
@@ -260,8 +260,8 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con | |||||
ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return INTERNAL_ERROR; | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -273,13 +273,19 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con | |||||
} // namespace | } // namespace | ||||
Status FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult &result) { | Status FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult &result) { | ||||
if (!IsDataTypeSupport(args.src_data_type) || !CheckShape(args.src_format, args.src_shape) || | |||||
!IsShapeValid(args.dst_shape)) { | |||||
GELOGE(PARAM_INVALID, "Trans format from %s to %s, src shape %s, dst shape %s, data type %s is not supported", | |||||
if (!IsDataTypeSupport(args.src_data_type)) { | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Trans format from %s to %s, src shape %s, dst shape %s, data type %s is not supported", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | |||||
if (!CheckShape(args.src_format, args.src_shape) || !IsShapeValid(args.dst_shape)) { | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Trans format from %s to %s, src shape %s, dst shape %s, data type %s is not supported", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | ||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
@@ -292,7 +298,7 @@ Status FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult & | |||||
return ret; | return ret; | ||||
} | } | ||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | if (!IsTransShapeDstCorrect(args, expect_shape)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return TransFormatFromNdToFracNz(args, result, hw_shape); | return TransFormatFromNdToFracNz(args, result, hw_shape); | ||||
} | } | ||||
@@ -300,31 +306,38 @@ Status FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult & | |||||
Status FormatTransferFractalNz::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type, | Status FormatTransferFractalNz::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type, | ||||
Format dst_format, ShapeVector &dst_shape) { | Format dst_format, ShapeVector &dst_shape) { | ||||
if (!IsDataTypeSupport(data_type)) { | if (!IsDataTypeSupport(data_type)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID, | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, | |||||
"Trans format from %s to %s, src shape %s, data type %s is not supported", | "Trans format from %s to %s, src shape %s, data type %s is not supported", | ||||
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | ||||
ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (!CheckShape(src_format, src_shape)) { | if (!CheckShape(src_format, src_shape)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
"Trans format from %s to %s, src shape %s, data type %s is not supported", | "Trans format from %s to %s, src shape %s, data type %s is not supported", | ||||
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | ||||
ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
ShapeVector hw_shape; | ShapeVector hw_shape; | ||||
return TransShapeToFracNz(src_shape, data_type, dst_shape, hw_shape); | return TransShapeToFracNz(src_shape, data_type, dst_shape, hw_shape); | ||||
} | } | ||||
Status FormatTransferFractalNzND::TransFormat(const TransArgs &args, TransResult &result) { | Status FormatTransferFractalNzND::TransFormat(const TransArgs &args, TransResult &result) { | ||||
if (!IsDataTypeSupport(args.src_data_type) || !IsShapeValid(args.src_shape) || | |||||
!CheckShape(args.dst_format, args.dst_shape)) { | |||||
GELOGE(PARAM_INVALID, "Trans format from %s to %s, src shape %s, dst shape %s, data type %s is not supported", | |||||
if (!IsDataTypeSupport(args.src_data_type)) { | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Trans format from %s to %s, src shape %s, dst shape %s, data type %s is not supported", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | |||||
if (!IsShapeValid(args.src_shape) || !CheckShape(args.dst_format, args.dst_shape)) { | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Trans format from %s to %s, src shape %s, dst shape %s, data type %s is not supported", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | ||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
@@ -332,8 +345,9 @@ Status FormatTransferFractalNzND::TransFormat(const TransArgs &args, TransResult | |||||
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
ShapeVector hw_shape; | ShapeVector hw_shape; | ||||
if (CheckShapeRelation(args, hw_shape) != SUCCESS) { | |||||
return PARAM_INVALID; | |||||
Status ret = CheckShapeRelation(args, hw_shape); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | } | ||||
return TransFormatFromFracNzToNd(args, result, hw_shape); | return TransFormatFromFracNzToNd(args, result, hw_shape); | ||||
} | } | ||||
@@ -342,7 +356,7 @@ Status FormatTransferFractalNzND::TransShape(Format src_format, const ShapeVecto | |||||
Format dst_format, ShapeVector &dst_shape) { | Format dst_format, ShapeVector &dst_shape) { | ||||
GELOGD("The shape derivation from %s to %s is not unique. Trans shape is not supported", | GELOGD("The shape derivation from %s to %s is not unique. Trans shape is not supported", | ||||
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalNz, FORMAT_ND, FORMAT_FRACTAL_NZ) | REGISTER_FORMAT_TRANSFER(FormatTransferFractalNz, FORMAT_ND, FORMAT_FRACTAL_NZ) | ||||
@@ -42,7 +42,7 @@ Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_ | |||||
Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape) { | Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape) { | ||||
auto c0 = GetCubeSizeByDataType(data_type); | auto c0 = GetCubeSizeByDataType(data_type); | ||||
if (c0 < 0) { | if (c0 < 0) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
auto c1 = Ceil(c, c0); | auto c1 = Ceil(c, c0); | ||||
@@ -54,16 +54,16 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_ | |||||
dst_shape.push_back(kNiSize); | dst_shape.push_back(kNiSize); | ||||
dst_shape.push_back(c0); | dst_shape.push_back(c0); | ||||
if (!IsShapeValid(dst_shape)) { | if (!IsShapeValid(dst_shape)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
ShapeToString(dst_shape).c_str()); | ShapeToString(dst_shape).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status TransShapeNchwToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | Status TransShapeNchwToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | ||||
if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
auto n = src_shape.at(kNchwN); | auto n = src_shape.at(kNchwN); | ||||
@@ -75,7 +75,7 @@ Status TransShapeNchwToFz(const std::vector<int64_t> &src_shape, DataType data_t | |||||
Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | ||||
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { | if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
auto h = src_shape.at(kHwcnH); | auto h = src_shape.at(kHwcnH); | ||||
@@ -88,7 +88,7 @@ Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_t | |||||
Status TransShapeNhwcToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | Status TransShapeNhwcToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | ||||
if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { | if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
auto n = src_shape.at(kNhwcN); | auto n = src_shape.at(kNhwcN); | ||||
@@ -127,10 +127,10 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
dst == nullptr, | dst == nullptr, | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | ||||
return OUT_OF_MEMORY;); | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION;); | |||||
for (int64_t vfi = 0; vfi < vf_cnt; vfi++) { | for (int64_t vfi = 0; vfi < vf_cnt; vfi++) { | ||||
// vertical fractal matrix base index | // vertical fractal matrix base index | ||||
@@ -163,8 +163,8 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||||
if (protected_size < size) { | if (protected_size < size) { | ||||
std::string error = "Failed to operate the dst memory, protected_size is " + | std::string error = "Failed to operate the dst memory, protected_size is " + | ||||
FmtToStr(protected_size) + " and size is " + FmtToStr(size); | FmtToStr(protected_size) + " and size is " + FmtToStr(size); | ||||
GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str()); | |||||
return INTERNAL_ERROR; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_PARAM_INVALID; | |||||
} | } | ||||
char *dst_data = reinterpret_cast<char *>(dst.get() + offset); | char *dst_data = reinterpret_cast<char *>(dst.get() + offset); | ||||
const char *src_data = reinterpret_cast<const char *>(args.data + src_offset * size); | const char *src_data = reinterpret_cast<const char *>(args.data + src_offset * size); | ||||
@@ -173,9 +173,9 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||||
} | } | ||||
} | } | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset, | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset, | |||||
ret, need_pad_zero); | ret, need_pad_zero); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -213,10 +213,10 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
dst == nullptr, | dst == nullptr, | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | ||||
return OUT_OF_MEMORY;); | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION;); | |||||
for (int64_t c1i = 0; c1i < c1; c1i++) { | for (int64_t c1i = 0; c1i < c1; c1i++) { | ||||
for (int64_t hi = 0; hi < h; hi++) { | for (int64_t hi = 0; hi < h; hi++) { | ||||
@@ -235,9 +235,9 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||||
static_cast<size_t>(data_size)); | static_cast<size_t>(data_size)); | ||||
} else { | } else { | ||||
if (protected_size < data_size) { | if (protected_size < data_size) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||||
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||||
protected_size, data_size); | protected_size, data_size); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_PARAM_INVALID; | |||||
} | } | ||||
int64_t src_idx = hi * wcn + wi * cn + (c1i * c0 + c0i) * n + n1n0i; | int64_t src_idx = hi * wcn + wi * cn + (c1i * c0 + c0i) * n + n1n0i; | ||||
char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset); | char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset); | ||||
@@ -247,9 +247,9 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||||
} | } | ||||
} | } | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||||
dst_offset, ret, pad_zero); | dst_offset, ret, pad_zero); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -288,10 +288,10 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
dst == nullptr, | dst == nullptr, | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | ||||
return OUT_OF_MEMORY;); | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION;); | |||||
for (int64_t c1i = 0; c1i < c1; c1i++) { | for (int64_t c1i = 0; c1i < c1; c1i++) { | ||||
for (int64_t hi = 0; hi < h; hi++) { | for (int64_t hi = 0; hi < h; hi++) { | ||||
@@ -310,9 +310,9 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||||
static_cast<size_t>(data_size)); | static_cast<size_t>(data_size)); | ||||
} else { | } else { | ||||
if (protected_size < data_size) { | if (protected_size < data_size) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||||
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||||
protected_size, data_size); | protected_size, data_size); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_PARAM_INVALID; | |||||
} | } | ||||
int64_t src_idx = n1n0i * hwc + hi * wc + wi * c + (c1i * c0 + c0i); | int64_t src_idx = n1n0i * hwc + hi * wc + wi * c + (c1i * c0 + c0i); | ||||
char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset); | char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset); | ||||
@@ -322,9 +322,9 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||||
} | } | ||||
} | } | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||||
dst_offset, ret, pad_zero); | dst_offset, ret, pad_zero); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -349,7 +349,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r | |||||
return ret; | return ret; | ||||
} | } | ||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | if (!IsTransShapeDstCorrect(args, expect_shape)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { | if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { | ||||
@@ -364,13 +364,13 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r | |||||
return TransFormatFromNchwToFz(args, result); | return TransFormatFromNchwToFz(args, result); | ||||
} | } | ||||
return UNSUPPORTED; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | ||||
Format dst_format, std::vector<int64_t> &dst_shape) { | Format dst_format, std::vector<int64_t> &dst_shape) { | ||||
if (CheckDataTypeSupport(data_type) != SUCCESS) { | if (CheckDataTypeSupport(data_type) != SUCCESS) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) { | if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) { | ||||
@@ -383,7 +383,7 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i | |||||
return TransShapeNchwToFz(src_shape, data_type, dst_shape); | return TransShapeNchwToFz(src_shape, data_type, dst_shape); | ||||
} | } | ||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NCHW, FORMAT_FRACTAL_Z) | REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NCHW, FORMAT_FRACTAL_Z) | ||||
@@ -86,9 +86,9 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap | |||||
hw_shape.push_back(DIM_DEFAULT_VALUE); | hw_shape.push_back(DIM_DEFAULT_VALUE); | ||||
hw_shape.push_back(src_shape[kNdDimIndexN]); | hw_shape.push_back(src_shape[kNdDimIndexN]); | ||||
if (!IsShapeValid(dst_shape)) { | if (!IsShapeValid(dst_shape)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
ShapeToString(dst_shape).c_str()); | ShapeToString(dst_shape).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
default: | default: | ||||
@@ -106,9 +106,9 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap | |||||
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]); | hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]); | ||||
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]); | hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]); | ||||
if (!IsShapeValid(dst_shape)) { | if (!IsShapeValid(dst_shape)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
ShapeToString(dst_shape).c_str()); | ShapeToString(dst_shape).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -118,14 +118,14 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
ShapeVector expect_src_shape; | ShapeVector expect_src_shape; | ||||
auto ret = TransShapeToFracZz(args.dst_shape, args.src_data_type, expect_src_shape, hw_shape); | auto ret = TransShapeToFracZz(args.dst_shape, args.src_data_type, expect_src_shape, hw_shape); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Trans shape from %s to %s, shape %s to %s, data type %s failed", | |||||
GELOGE(ret, "Trans shape from %s to %s, shape %s to %s, data type %s failed", | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), | TypeUtils::FormatToSerialString(args.dst_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), ShapeToString(args.dst_shape).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), ShapeToString(args.dst_shape).c_str(), | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return INTERNAL_ERROR; | |||||
return ret; | |||||
} | } | ||||
if (!IsTransShapeSrcCorrect(args, expect_src_shape)) { | if (!IsTransShapeSrcCorrect(args, expect_src_shape)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -140,10 +140,10 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
// The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D | // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D | ||||
auto times = hw_shape.at(kNdDimIndexN); | auto times = hw_shape.at(kNdDimIndexN); | ||||
@@ -179,8 +179,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size * w0)); | static_cast<size_t>(size * w0)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return INTERNAL_ERROR; | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
auto w1_head = num_w1 * w0; | auto w1_head = num_w1 * w0; | ||||
@@ -195,8 +195,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return INTERNAL_ERROR; | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -217,10 +217,10 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
// The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D | // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D | ||||
@@ -257,8 +257,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size * w0)); | static_cast<size_t>(size * w0)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return INTERNAL_ERROR; | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
auto w1_head = num_w1 * w0; | auto w1_head = num_w1 * w0; | ||||
@@ -273,8 +273,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return INTERNAL_ERROR; | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d", dst_offset, ret); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -287,13 +287,19 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con | |||||
} // namespace | } // namespace | ||||
Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult &result) { | Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult &result) { | ||||
if (!IsDataTypeSupport(args.src_data_type) || !CheckShape(args.src_format, args.src_shape) || | |||||
!IsShapeValid(args.dst_shape)) { | |||||
GELOGE(PARAM_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||||
if (!IsDataTypeSupport(args.src_data_type)) { | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | |||||
if (!CheckShape(args.src_format, args.src_shape) || !IsShapeValid(args.dst_shape)) { | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | ||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
@@ -306,7 +312,7 @@ Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult & | |||||
return ret; | return ret; | ||||
} | } | ||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | if (!IsTransShapeDstCorrect(args, expect_shape)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return TransFormatFromNdToFracZz(args, result, hw_shape); | return TransFormatFromNdToFracZz(args, result, hw_shape); | ||||
} | } | ||||
@@ -314,31 +320,38 @@ Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult & | |||||
Status FormatTransferFractalZz::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type, | Status FormatTransferFractalZz::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type, | ||||
Format dst_format, ShapeVector &dst_shape) { | Format dst_format, ShapeVector &dst_shape) { | ||||
if (!IsDataTypeSupport(data_type)) { | if (!IsDataTypeSupport(data_type)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID, | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, | |||||
"Not support trans format from %s to %s, src shape %s, data type %s", | "Not support trans format from %s to %s, src shape %s, data type %s", | ||||
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | ||||
ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (!CheckShape(src_format, src_shape)) { | if (!CheckShape(src_format, src_shape)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
"Not support trans format from %s to %s, src shape %s, data type %s", | "Not support trans format from %s to %s, src shape %s, data type %s", | ||||
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | ||||
ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
ShapeVector hw_shape; | ShapeVector hw_shape; | ||||
return TransShapeToFracZz(src_shape, data_type, dst_shape, hw_shape); | return TransShapeToFracZz(src_shape, data_type, dst_shape, hw_shape); | ||||
} | } | ||||
Status FormatTransferFractalZzND::TransFormat(const TransArgs &args, TransResult &result) { | Status FormatTransferFractalZzND::TransFormat(const TransArgs &args, TransResult &result) { | ||||
if (!IsDataTypeSupport(args.src_data_type) || !IsShapeValid(args.src_shape) || | |||||
!CheckShape(args.dst_format, args.dst_shape)) { | |||||
GELOGE(PARAM_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||||
if (!IsDataTypeSupport(args.src_data_type)) { | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | |||||
if (!IsShapeValid(args.src_shape) || !CheckShape(args.dst_format, args.dst_shape)) { | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Not support trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | ||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
@@ -346,8 +359,9 @@ Status FormatTransferFractalZzND::TransFormat(const TransArgs &args, TransResult | |||||
ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
ShapeVector hw_shape; | ShapeVector hw_shape; | ||||
if (CheckShapeRelation(args, hw_shape) != SUCCESS) { | |||||
return PARAM_INVALID; | |||||
Status ret = CheckShapeRelation(args, hw_shape); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | } | ||||
return TransFormatFromFracZzToNd(args, result, hw_shape); | return TransFormatFromFracZzToNd(args, result, hw_shape); | ||||
} | } | ||||
@@ -356,7 +370,7 @@ Status FormatTransferFractalZzND::TransShape(Format src_format, const ShapeVecto | |||||
Format dst_format, ShapeVector &dst_shape) { | Format dst_format, ShapeVector &dst_shape) { | ||||
GELOGD("The shape derivation from %s to %s is not unique. Trans shape is not supported", | GELOGD("The shape derivation from %s to %s is not unique. Trans shape is not supported", | ||||
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZz, FORMAT_ND, FORMAT_FRACTAL_ZZ) | REGISTER_FORMAT_TRANSFER(FormatTransferFractalZz, FORMAT_ND, FORMAT_FRACTAL_ZZ) | ||||
@@ -37,25 +37,25 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { | |||||
std::string error = "Dose not support trans format from " + | std::string error = "Dose not support trans format from " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | ||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans shape from FORMAT_FRACTAL_Z to HWCN, invalid data type %s", | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to trans shape from FORMAT_FRACTAL_Z to HWCN, invalid data type %s", | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return UNSUPPORTED; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (!CheckShapeValid(src_shape, kFracZDimsNum)) { | if (!CheckShapeValid(src_shape, kFracZDimsNum)) { | ||||
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
if (!CheckShapeValid(dst_shape, kHwcnDimsNum)) { | if (!CheckShapeValid(dst_shape, kHwcnDimsNum)) { | ||||
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | ||||
if (c0 < 0) { | if (c0 < 0) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
int64_t c1 = Ceil(dst_shape.at(kHwcnC), c0); | int64_t c1 = Ceil(dst_shape.at(kHwcnC), c0); | ||||
int64_t n0 = Ceil(dst_shape.at(kHwcnN), static_cast<int64_t>(kNiSize)); | int64_t n0 = Ceil(dst_shape.at(kHwcnN), static_cast<int64_t>(kNiSize)); | ||||
@@ -64,8 +64,8 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { | |||||
std::string error = "Failed to check relationship between src shape" + | std::string error = "Failed to check relationship between src shape" + | ||||
FmtToStr(ShapeToString(src_shape)) + " and dst shape" + | FmtToStr(ShapeToString(src_shape)) + " and dst shape" + | ||||
FmtToStr(ShapeToString(dst_shape)); | FmtToStr(ShapeToString(dst_shape)); | ||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return PARAM_INVALID; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -74,10 +74,10 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { | |||||
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | ||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
auto n0 = args.src_shape.at(kFracZN0); | auto n0 = args.src_shape.at(kFracZN0); | ||||
@@ -113,11 +113,11 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
"Failed to copy data from FracZ offset %ld to HWCN[%ld, %ld, %ld, %ld] " | "Failed to copy data from FracZ offset %ld to HWCN[%ld, %ld, %ld, %ld] " | ||||
"offset %ld, err-code %d", | "offset %ld, err-code %d", | ||||
src_offset, h_idx, w_idx, c_idx, n_idx, dst_offset, ret); | src_offset, h_idx, w_idx, c_idx, n_idx, dst_offset, ret); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -130,8 +130,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
} // namespace | } // namespace | ||||
Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult &result) { | Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult &result) { | ||||
if (CheckArgsForFracZToHwcn(args) != SUCCESS) { | |||||
return PARAM_INVALID; | |||||
Status ret = CheckArgsForFracZToHwcn(args); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | } | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
@@ -142,18 +143,19 @@ Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult & | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
GELOGD("Begin to trans format from FracZ to HWCN, src shape %s, data type %s, dst shape %s, memory size %ld", | GELOGD("Begin to trans format from FracZ to HWCN, src shape %s, data type %s, dst shape %s, memory size %ld", | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
return INTERNAL_ERROR; | |||||
return ret; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -161,7 +163,7 @@ Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult & | |||||
Status FormatTransferFracZHwcn::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | Status FormatTransferFracZHwcn::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | ||||
Format dst_format, std::vector<int64_t> &dst_shape) { | Format dst_format, std::vector<int64_t> &dst_shape) { | ||||
GELOGD("The shape derivation from FracZ to HWCN is not unique. Trans shape in this direction is not supported"); | GELOGD("The shape derivation from FracZ to HWCN is not unique. Trans shape in this direction is not supported"); | ||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
REGISTER_FORMAT_TRANSFER(FormatTransferFracZHwcn, FORMAT_FRACTAL_Z, FORMAT_HWCN) | REGISTER_FORMAT_TRANSFER(FormatTransferFracZHwcn, FORMAT_FRACTAL_Z, FORMAT_HWCN) | ||||
@@ -38,32 +38,32 @@ Status CheckArgsForFracZToNchw(const TransArgs &args) { | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | ||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | ||||
return UNSUPPORTED; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans shape from FORMAT_FRACTAL_Z to NCHW, invalid data type %s", | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to trans shape from FORMAT_FRACTAL_Z to NCHW, invalid data type %s", | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return UNSUPPORTED; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (!CheckShapeValid(src_shape, kFracZDimsNum)) { | if (!CheckShapeValid(src_shape, kFracZDimsNum)) { | ||||
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
if (!CheckShapeValid(dst_shape, kNchwDimsNum)) { | if (!CheckShapeValid(dst_shape, kNchwDimsNum)) { | ||||
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | ||||
if (c0 < 0) { | if (c0 < 0) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
int64_t c1 = Ceil(dst_shape.at(kNchwC), c0); | int64_t c1 = Ceil(dst_shape.at(kNchwC), c0); | ||||
int64_t n0 = Ceil(dst_shape.at(kNchwN), static_cast<int64_t>(kNiSize)); | int64_t n0 = Ceil(dst_shape.at(kNchwN), static_cast<int64_t>(kNiSize)); | ||||
if (src_shape.at(kFracZHWC1) != dst_shape.at(kNchwH) * dst_shape.at(kNchwW) * c1 || src_shape.at(kFracZC0) != c0 || | if (src_shape.at(kFracZHWC1) != dst_shape.at(kNchwH) * dst_shape.at(kNchwW) * c1 || src_shape.at(kFracZC0) != c0 || | ||||
src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { | src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { | ||||
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||||
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -72,10 +72,10 @@ Status CheckArgsForFracZToNchw(const TransArgs &args) { | |||||
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | ||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
auto n0 = args.src_shape.at(kFracZN0); | auto n0 = args.src_shape.at(kFracZN0); | ||||
@@ -111,11 +111,11 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
"Failed to copy data from FracZ offset %ld to NCHW[%ld, %ld, %ld, %ld] offset %ld, " | "Failed to copy data from FracZ offset %ld to NCHW[%ld, %ld, %ld, %ld] offset %ld, " | ||||
"err-code %d", | "err-code %d", | ||||
src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret); | src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -128,8 +128,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
} // namespace | } // namespace | ||||
Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult &result) { | Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult &result) { | ||||
if (CheckArgsForFracZToNchw(args) != SUCCESS) { | |||||
return PARAM_INVALID; | |||||
Status ret = CheckArgsForFracZToNchw(args); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | } | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
@@ -140,19 +141,20 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
return INTERNAL_ERROR; | |||||
return ret; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -160,7 +162,7 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||||
Status FormatTransferFracZNchw::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | Status FormatTransferFracZNchw::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | ||||
Format dst_format, std::vector<int64_t> &dst_shape) { | Format dst_format, std::vector<int64_t> &dst_shape) { | ||||
GELOGD("The shape derivation from FracZ to NCHW is not unique. Trans shape in this direction is not supported"); | GELOGD("The shape derivation from FracZ to NCHW is not unique. Trans shape in this direction is not supported"); | ||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
REGISTER_FORMAT_TRANSFER(FormatTransferFracZNchw, FORMAT_FRACTAL_Z, FORMAT_NCHW) | REGISTER_FORMAT_TRANSFER(FormatTransferFracZNchw, FORMAT_FRACTAL_Z, FORMAT_NCHW) | ||||
@@ -43,9 +43,9 @@ Status TransShapeHwcnToC1hwncoc0(const DataType &data_type, const std::vector<in | |||||
dst_shape.push_back(cube_size); | dst_shape.push_back(cube_size); | ||||
dst_shape.push_back(cube_size); | dst_shape.push_back(cube_size); | ||||
if (!CheckShapeValid(dst_shape, kC1hwncoc0DimsNum)) { | if (!CheckShapeValid(dst_shape, kC1hwncoc0DimsNum)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
ShapeToString(dst_shape).c_str()); | ShapeToString(dst_shape).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -55,21 +55,21 @@ Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { | |||||
std::string error = "Dose not support trans format from " + | std::string error = "Dose not support trans format from " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | ||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to C1HWNCoC0, invalid data type %s", | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to trans shape from HWCN to C1HWNCoC0, invalid data type %s", | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return UNSUPPORTED; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (!CheckShapeValid(args.src_shape, kHwcnDimsNum)) { | if (!CheckShapeValid(args.src_shape, kHwcnDimsNum)) { | ||||
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
if (!CheckShapeValid(args.dst_shape, kC1hwncoc0DimsNum)) { | if (!CheckShapeValid(args.dst_shape, kC1hwncoc0DimsNum)) { | ||||
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
std::vector<int64_t> expect_dst_shape; | std::vector<int64_t> expect_dst_shape; | ||||
auto ret = TransShapeHwcnToC1hwncoc0(args.src_data_type, args.src_shape, expect_dst_shape); | auto ret = TransShapeHwcnToC1hwncoc0(args.src_data_type, args.src_shape, expect_dst_shape); | ||||
@@ -77,12 +77,12 @@ Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { | |||||
return ret; | return ret; | ||||
} | } | ||||
if (args.dst_shape != expect_dst_shape) { | if (args.dst_shape != expect_dst_shape) { | ||||
GELOGE(PARAM_INVALID, | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
"Failed to trans format, src and dst shape are not compatible. src shape %s, dst shape %s, " | "Failed to trans format, src and dst shape are not compatible. src shape %s, dst shape %s, " | ||||
"expect dst shape %s", | "expect dst shape %s", | ||||
ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(), | ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(), | ||||
ShapeToString(expect_dst_shape).c_str()); | ShapeToString(expect_dst_shape).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -91,10 +91,10 @@ Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { | |||||
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | ||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
auto h = args.src_shape.at(kHwcnH); | auto h = args.src_shape.at(kHwcnH); | ||||
@@ -135,22 +135,22 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
"Failed to copy data from HWCN[%ld, %ld, %ld, %ld] offset %ld to " | "Failed to copy data from HWCN[%ld, %ld, %ld, %ld] offset %ld to " | ||||
"C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | "C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | ||||
h_idx, w_idx, c_idx, n_idx, src_offset, c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, | h_idx, w_idx, c_idx, n_idx, src_offset, c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, | ||||
dst_offset, ret); | dst_offset, ret); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} else { | } else { | ||||
auto ret = | auto ret = | ||||
memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
"Failed to set to 0 to C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld, " | "Failed to set to 0 to C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld, " | ||||
"err-code %d", | "err-code %d", | ||||
c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, dst_offset, ret); | c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, dst_offset, ret); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -166,8 +166,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
} // namespace | } // namespace | ||||
Status FormatTransferHwcnC1hwncoc0::TransFormat(const TransArgs &args, TransResult &result) { | Status FormatTransferHwcnC1hwncoc0::TransFormat(const TransArgs &args, TransResult &result) { | ||||
if (CheckArgsForHwcnToC1hwncoc0(args) != SUCCESS) { | |||||
return PARAM_INVALID; | |||||
Status ret = CheckArgsForHwcnToC1hwncoc0(args); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | } | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
@@ -178,18 +179,20 @@ Status FormatTransferHwcnC1hwncoc0::TransFormat(const TransArgs &args, TransResu | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
GELOGD("Begin to trans format from HWCN to C1HWNCoC0, src shape %s, data type %s, dst shape %s, memory size %ld", | GELOGD("Begin to trans format from HWCN to C1HWNCoC0, src shape %s, data type %s, dst shape %s, memory size %ld", | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
return INTERNAL_ERROR; | |||||
return ret; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -198,15 +201,15 @@ Status FormatTransferHwcnC1hwncoc0::TransShape(Format src_format, const std::vec | |||||
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | ||||
if (src_format == FORMAT_HWCN && CheckDataTypeSupported(data_type)) { | if (src_format == FORMAT_HWCN && CheckDataTypeSupported(data_type)) { | ||||
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { | if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check src shape %s", | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", | |||||
ShapeToString(src_shape).c_str()); | ShapeToString(src_shape).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return TransShapeHwcnToC1hwncoc0(data_type, src_shape, dst_shape); | return TransShapeHwcnToC1hwncoc0(data_type, src_shape, dst_shape); | ||||
} else if (src_format != FORMAT_HWCN) { | } else if (src_format != FORMAT_HWCN) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} else { | } else { | ||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
} | } | ||||
@@ -37,33 +37,33 @@ Status CheckArgsForNc1hwc0ToNhwc(const TransArgs &args) { | |||||
std::string error = "Dose not support trans format from " + | std::string error = "Dose not support trans format from " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | ||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans shape from NC1HWC0 to NHWC, invalid data type %s", | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to trans shape from NC1HWC0 to NHWC, invalid data type %s", | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return UNSUPPORTED; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (!CheckShapeValid(args.src_shape, kNc1hwc0DimsNum)) { | if (!CheckShapeValid(args.src_shape, kNc1hwc0DimsNum)) { | ||||
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
if (!CheckShapeValid(args.dst_shape, kNhwcDimsNum)) { | if (!CheckShapeValid(args.dst_shape, kNhwcDimsNum)) { | ||||
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | ||||
if (c0 <= 0) { | if (c0 <= 0) { | ||||
GELOGE(PARAM_INVALID, "Failed to get cube size, the data type is invalid"); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to get cube size, the data type is invalid"); | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNhwcH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNhwcW) || | if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNhwcH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNhwcW) || | ||||
src_shape.at(kNc1hwc0N) != dst_shape.at(kNhwcN) || src_shape.at(kNc1hwc0C0) != c0 || | src_shape.at(kNc1hwc0N) != dst_shape.at(kNhwcN) || src_shape.at(kNc1hwc0C0) != c0 || | ||||
src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNhwcC), c0))) { | src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNhwcC), c0))) { | ||||
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||||
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -72,10 +72,10 @@ Status CheckArgsForNc1hwc0ToNhwc(const TransArgs &args) { | |||||
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | ||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
auto h = args.src_shape.at(kNc1hwc0H); | auto h = args.src_shape.at(kNc1hwc0H); | ||||
@@ -109,11 +109,11 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
"Failed to copy data from NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld to NHWC[%ld, %ld, %ld, %ld]" | "Failed to copy data from NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld to NHWC[%ld, %ld, %ld, %ld]" | ||||
" offset %ld, err-code %d", | " offset %ld, err-code %d", | ||||
n_idx, c1_idx, h_idx, w_idx, c0_idx, src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret); | n_idx, c1_idx, h_idx, w_idx, c0_idx, src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -126,8 +126,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
} // namespace | } // namespace | ||||
Status FormatTransferNc1hwc0Nhwc::TransFormat(const TransArgs &args, TransResult &result) { | Status FormatTransferNc1hwc0Nhwc::TransFormat(const TransArgs &args, TransResult &result) { | ||||
if (CheckArgsForNc1hwc0ToNhwc(args) != SUCCESS) { | |||||
return PARAM_INVALID; | |||||
Status ret = CheckArgsForNc1hwc0ToNhwc(args); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | } | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
@@ -138,18 +139,20 @@ Status FormatTransferNc1hwc0Nhwc::TransFormat(const TransArgs &args, TransResult | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
GELOGD("Begin to trans format from NC1HWC0 to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | GELOGD("Begin to trans format from NC1HWC0 to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
return INTERNAL_ERROR; | |||||
return ret; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -157,7 +160,7 @@ Status FormatTransferNc1hwc0Nhwc::TransFormat(const TransArgs &args, TransResult | |||||
Status FormatTransferNc1hwc0Nhwc::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | Status FormatTransferNc1hwc0Nhwc::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | ||||
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | ||||
GELOGD("The shape derivation from NC1HWC0 to NHWC is not unique. Trans shape in this direction is not supported"); | GELOGD("The shape derivation from NC1HWC0 to NHWC is not unique. Trans shape in this direction is not supported"); | ||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
REGISTER_FORMAT_TRANSFER(FormatTransferNc1hwc0Nhwc, FORMAT_NC1HWC0, FORMAT_NHWC) | REGISTER_FORMAT_TRANSFER(FormatTransferNc1hwc0Nhwc, FORMAT_NC1HWC0, FORMAT_NHWC) | ||||
@@ -45,7 +45,7 @@ Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_ | |||||
Status TransShape(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape) { | Status TransShape(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape) { | ||||
auto c0 = GetCubeSizeByDataType(data_type); | auto c0 = GetCubeSizeByDataType(data_type); | ||||
if (c0 < 0) { | if (c0 < 0) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
auto chw = c * h * w; | auto chw = c * h * w; | ||||
@@ -59,9 +59,9 @@ Status TransShape(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type | |||||
dst_shape.push_back(c0); | dst_shape.push_back(c0); | ||||
if (!IsShapeValid(dst_shape)) { | if (!IsShapeValid(dst_shape)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
ShapeToString(dst_shape).c_str()); | ShapeToString(dst_shape).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -69,7 +69,7 @@ Status TransShape(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type | |||||
Status TransShapeNchwToFzC04(const std::vector<int64_t> &src_shape, DataType data_type, | Status TransShapeNchwToFzC04(const std::vector<int64_t> &src_shape, DataType data_type, | ||||
std::vector<int64_t> &dst_shape) { | std::vector<int64_t> &dst_shape) { | ||||
if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
auto n = src_shape.at(kNchwN); | auto n = src_shape.at(kNchwN); | ||||
@@ -94,8 +94,8 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||||
std::vector<int64_t> expect_shape = {n, h, w, c}; | std::vector<int64_t> expect_shape = {n, h, w, c}; | ||||
auto ret = ge::formats::Transpose(data, args.src_shape, args.src_data_type, perm_arg_1, trans_result_1); | auto ret = ge::formats::Transpose(data, args.src_shape, args.src_data_type, perm_arg_1, trans_result_1); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to Transpose from NCHW to HWCN"); | |||||
return NOT_CHANGED; | |||||
GELOGE(ret, "Failed to Transpose from NCHW to HWCN"); | |||||
return ret; | |||||
} | } | ||||
TransArgs args_tmp = args; | TransArgs args_tmp = args; | ||||
@@ -104,8 +104,8 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||||
// check size it should be same with original | // check size it should be same with original | ||||
size_t expect_size = n * c * h * w * size; // before has do check about mul | size_t expect_size = n * c * h * w * size; // before has do check about mul | ||||
if (trans_result_1.length != expect_size) { | if (trans_result_1.length != expect_size) { | ||||
GELOGE(INTERNAL_ERROR, "size is not match after transpose!"); | |||||
return NOT_CHANGED; | |||||
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "size is not match after transpose!"); | |||||
return ACL_ERROR_GE_PARAM_INVALID; | |||||
} | } | ||||
// prepare for padding in chw | // prepare for padding in chw | ||||
@@ -118,20 +118,20 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||||
// data overflow check totally | // data overflow check totally | ||||
GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o), | GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o), | ||||
GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", h_o, w_o); | |||||
return INTERNAL_ERROR); | |||||
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", h_o, w_o); | |||||
return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o), | GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o), | ||||
GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", n_o, c_o); | |||||
return INTERNAL_ERROR); | |||||
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", n_o, c_o); | |||||
return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
auto t1 = h_o * w_o; | auto t1 = h_o * w_o; | ||||
auto t2 = n_o * c_o; | auto t2 = n_o * c_o; | ||||
GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", t1, t2); | GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", t1, t2); | ||||
return INTERNAL_ERROR); | |||||
return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
int64_t total_ele_cnt = n_o * c_o * h_o * w_o; | int64_t total_ele_cnt = n_o * c_o * h_o * w_o; | ||||
GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size), | GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size), | ||||
GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%d]", total_ele_cnt, size); | |||||
return INTERNAL_ERROR); | |||||
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%d]", total_ele_cnt, size); | |||||
return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
int64_t dst_size = total_ele_cnt * size; | int64_t dst_size = total_ele_cnt * size; | ||||
if (dst_size == 0) { | if (dst_size == 0) { | ||||
result.length = 0; | result.length = 0; | ||||
@@ -140,15 +140,15 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
auto retMem = memset_s(dst.get(), dst_size, 0, dst_size); | auto retMem = memset_s(dst.get(), dst_size, 0, dst_size); | ||||
if (retMem != EOK) { | if (retMem != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "memst failed!"); | |||||
return INTERNAL_ERROR; | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "memst failed!"); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
// copy data | // copy data | ||||
auto block = c * h * w * size; | auto block = c * h * w * size; | ||||
@@ -159,8 +159,8 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||||
for (auto k = 0; k < n; k++) { | for (auto k = 0; k < n; k++) { | ||||
ret = memcpy_s(p_d + k * stride, protectSize, p_s + k * block, block); | ret = memcpy_s(p_d + k * stride, protectSize, p_s + k * block, block); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "memcpy_s failed!"); | |||||
return INTERNAL_ERROR; | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "memcpy_s failed!"); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
protectSize = protectSize - block; | protectSize = protectSize - block; | ||||
} | } | ||||
@@ -169,8 +169,8 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||||
std::vector<int64_t> perm_arg_2 = {2, 0, 1, 3}; | std::vector<int64_t> perm_arg_2 = {2, 0, 1, 3}; | ||||
ret = ge::formats::Transpose(dst.get(), shape_o, args.src_data_type, perm_arg_2, result); | ret = ge::formats::Transpose(dst.get(), shape_o, args.src_data_type, perm_arg_2, result); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to Transpose from NCHW to HWCN"); | |||||
return NOT_CHANGED; | |||||
GELOGE(ret, "Failed to Transpose from NCHW to HWCN"); | |||||
return ret; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -180,7 +180,7 @@ Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uin | |||||
args_tmp = args; | args_tmp = args; | ||||
auto src_shape = args_tmp.src_shape; | auto src_shape = args_tmp.src_shape; | ||||
if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | ||||
@@ -190,8 +190,8 @@ Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uin | |||||
auto w = src_shape.at(kNchwW); | auto w = src_shape.at(kNchwW); | ||||
if (c > kMaxDimsNumC) { | if (c > kMaxDimsNumC) { | ||||
GELOGE(PARAM_INVALID, "Invalie dim c num[%lu].It should be in (0,4]", c); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Invalie dim c num[%lu].It should be in (0,4]", c); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
auto n_o = Ceil(n, c0) * c0; | auto n_o = Ceil(n, c0) * c0; | ||||
@@ -205,21 +205,21 @@ Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uin | |||||
// data overflow check | // data overflow check | ||||
GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o), | GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o), | ||||
GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", h_o, w_o); | |||||
return INTERNAL_ERROR); | |||||
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", h_o, w_o); | |||||
return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o), | GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o), | ||||
GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", n_o, c_o); | |||||
return INTERNAL_ERROR); | |||||
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", n_o, c_o); | |||||
return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
auto t1 = h_o * w_o; | auto t1 = h_o * w_o; | ||||
auto t2 = n_o * c_o; | auto t2 = n_o * c_o; | ||||
GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", t1, t2); | |||||
return INTERNAL_ERROR); | |||||
GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", t1, t2); | |||||
return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
int64_t total_ele_cnt = n_o * c_o * h_o * w_o; | int64_t total_ele_cnt = n_o * c_o * h_o * w_o; | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size), | GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size), | ||||
GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%d]", total_ele_cnt, size); | |||||
return INTERNAL_ERROR); | |||||
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%d]", total_ele_cnt, size); | |||||
return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
int64_t dst_size = total_ele_cnt * size; | int64_t dst_size = total_ele_cnt * size; | ||||
if (dst_size == 0) { | if (dst_size == 0) { | ||||
@@ -228,15 +228,15 @@ Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uin | |||||
dst.reset(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | dst.reset(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
auto ret = memset_s(dst.get(), dst_size, 0, dst_size); | auto ret = memset_s(dst.get(), dst_size, 0, dst_size); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "memst failed!"); | |||||
return INTERNAL_ERROR; | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "memst failed!"); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
auto p_s = args.data; | auto p_s = args.data; | ||||
@@ -249,8 +249,8 @@ Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uin | |||||
ret = memcpy_s(p_d + (i * c_o * h_o * w_o + j * h_o * w_o) * size, protectSize, | ret = memcpy_s(p_d + (i * c_o * h_o * w_o + j * h_o * w_o) * size, protectSize, | ||||
p_s + (i * c * h * w + j * h * w) * size, block); | p_s + (i * c * h * w + j * h * w) * size, block); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, "memcpy_s failed!"); | |||||
return INTERNAL_ERROR; | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "memcpy_s failed!"); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
protectSize = protectSize - block; | protectSize = protectSize - block; | ||||
} | } | ||||
@@ -270,7 +270,7 @@ Status FormatTransferNchwToFZC04::TransFormat(const TransArgs &args, TransResult | |||||
std::shared_ptr<uint8_t> dst = nullptr; | std::shared_ptr<uint8_t> dst = nullptr; | ||||
auto ret = PaddingNC(args, args_tmp, dst); | auto ret = PaddingNC(args, args_tmp, dst); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Padding in NC axis failed!"); | |||||
GELOGE(ret, "Padding in NC axis failed!"); | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -281,26 +281,26 @@ Status FormatTransferNchwToFZC04::TransFormat(const TransArgs &args, TransResult | |||||
} | } | ||||
if (!IsTransShapeDstCorrect(args_tmp, expect_shape)) { | if (!IsTransShapeDstCorrect(args_tmp, expect_shape)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
if (args_tmp.src_format == FORMAT_NCHW && args_tmp.dst_format == FORMAT_FRACTAL_Z_C04) { | if (args_tmp.src_format == FORMAT_NCHW && args_tmp.dst_format == FORMAT_FRACTAL_Z_C04) { | ||||
return TransFormatFromNchwToFzC04(args_tmp, result); | return TransFormatFromNchwToFzC04(args_tmp, result); | ||||
} | } | ||||
return UNSUPPORTED; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
Status FormatTransferNchwToFZC04::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | Status FormatTransferNchwToFZC04::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | ||||
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | ||||
if (CheckDataTypeSupport(data_type) != SUCCESS) { | if (CheckDataTypeSupport(data_type) != SUCCESS) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z_C04) { | if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z_C04) { | ||||
return TransShapeNchwToFzC04(src_shape, data_type, dst_shape); | return TransShapeNchwToFzC04(src_shape, data_type, dst_shape); | ||||
} | } | ||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
REGISTER_FORMAT_TRANSFER(FormatTransferNchwToFZC04, FORMAT_NCHW, FORMAT_FRACTAL_Z_C04) | REGISTER_FORMAT_TRANSFER(FormatTransferNchwToFZC04, FORMAT_NCHW, FORMAT_FRACTAL_Z_C04) | ||||
@@ -32,13 +32,13 @@ Status TransShapeNchwToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
std::vector<int64_t> &dst_shape) { | std::vector<int64_t> &dst_shape) { | ||||
int64_t c0 = GetCubeSizeByDataType(data_type); | int64_t c0 = GetCubeSizeByDataType(data_type); | ||||
if (c0 <= 0) { | if (c0 <= 0) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID, "Failed to get cube size, the data type is invalid"); | |||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to get cube size, the data type is invalid"); | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check src shape %s", | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", | |||||
ShapeToString(src_shape).c_str()); | ShapeToString(src_shape).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
dst_shape.clear(); | dst_shape.clear(); | ||||
dst_shape.push_back(src_shape.at(kNchwN)); | dst_shape.push_back(src_shape.at(kNchwN)); | ||||
@@ -47,9 +47,9 @@ Status TransShapeNchwToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
dst_shape.push_back(src_shape.at(kNchwW)); | dst_shape.push_back(src_shape.at(kNchwW)); | ||||
dst_shape.push_back(c0); | dst_shape.push_back(c0); | ||||
if (!CheckShapeValid(dst_shape, kNc1hwc0DimsNum)) { | if (!CheckShapeValid(dst_shape, kNc1hwc0DimsNum)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
ShapeToString(dst_shape).c_str()); | ShapeToString(dst_shape).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -59,8 +59,8 @@ Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | |||||
std::string error = "Dose not support trans format from " + | std::string error = "Dose not support trans format from " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | ||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
std::vector<int64_t> expect_5d_shape; | std::vector<int64_t> expect_5d_shape; | ||||
auto ret = TransShapeNchwToNc1hwc0(args.src_shape, args.src_data_type, expect_5d_shape); | auto ret = TransShapeNchwToNc1hwc0(args.src_shape, args.src_data_type, expect_5d_shape); | ||||
@@ -68,12 +68,12 @@ Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | |||||
return ret; | return ret; | ||||
} | } | ||||
if (expect_5d_shape != args.dst_shape) { | if (expect_5d_shape != args.dst_shape) { | ||||
GELOGE(PARAM_INVALID, | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
"Failed to trans format, the src and dst shape are not compatible. data" | "Failed to trans format, the src and dst shape are not compatible. data" | ||||
" type %s, src shape %s, dst shape %s, expect dst shape %s", | " type %s, src shape %s, dst shape %s, expect dst shape %s", | ||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.src_shape).c_str(), | TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.src_shape).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(expect_5d_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(expect_5d_shape).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -82,12 +82,12 @@ Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | |||||
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | ||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, | |||||
"Failed to trans format from %s to %s, can not alloc the memory for" | "Failed to trans format from %s to %s, can not alloc the memory for" | ||||
" dst buf %ld, shape %s", | " dst buf %ld, shape %s", | ||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
auto n = args.src_shape.at(kNchwN); | auto n = args.src_shape.at(kNchwN); | ||||
@@ -97,8 +97,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | ||||
if (c0 <= 0) { | if (c0 <= 0) { | ||||
GELOGE(INTERNAL_ERROR, "The c0 is invalid %ld", c0); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "The c0 is invalid %ld", c0); | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
int64_t c1 = (c - 1) / c0 + 1; | int64_t c1 = (c - 1) / c0 + 1; | ||||
int64_t hw = h * w; | int64_t hw = h * w; | ||||
@@ -129,21 +129,21 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
"Failed to copy data from NCHW[%ld] offset %ld to " | "Failed to copy data from NCHW[%ld] offset %ld to " | ||||
"NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | "NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | ||||
srcIdx, src_offset, n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | srcIdx, src_offset, n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} else { | } else { | ||||
auto ret = | auto ret = | ||||
memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
"Failed to set to 0 to " | "Failed to set to 0 to " | ||||
"NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | "NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | ||||
n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -159,8 +159,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
} // namespace | } // namespace | ||||
Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | ||||
if (CheckArgsForNchwToNc1hwc0(args) != SUCCESS) { | |||||
return PARAM_INVALID; | |||||
Status ret = CheckArgsForNchwToNc1hwc0(args); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | } | ||||
// Guarantee the validity of parameters in check function | // Guarantee the validity of parameters in check function | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
@@ -172,20 +173,21 @@ Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
GELOGD( | GELOGD( | ||||
"Begin to trans format from NCHW to NC1HWC0, src shape %s, data type " | "Begin to trans format from NCHW to NC1HWC0, src shape %s, data type " | ||||
"%s, dst shape %s memory size %ld", | "%s, dst shape %s memory size %ld", | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
return INTERNAL_ERROR; | |||||
return ret; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -195,7 +197,7 @@ Status FormatTransferNchwNc1hwc0::TransShape(Format src_format, const std::vecto | |||||
if (src_format == FORMAT_NCHW) { | if (src_format == FORMAT_NCHW) { | ||||
return TransShapeNchwToNc1hwc0(src_shape, data_type, dst_shape); | return TransShapeNchwToNc1hwc0(src_shape, data_type, dst_shape); | ||||
} else { | } else { | ||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
} | } | ||||
@@ -34,8 +34,8 @@ Status TransShapeNhwcToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
std::vector<int64_t> &dst_shape) { | std::vector<int64_t> &dst_shape) { | ||||
int64_t c0 = GetCubeSizeByDataType(data_type); | int64_t c0 = GetCubeSizeByDataType(data_type); | ||||
if (c0 <= 0) { | if (c0 <= 0) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID, "Failed to get cube size, the data type is invalid"); | |||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to get cube size, the data type is invalid"); | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
dst_shape.clear(); | dst_shape.clear(); | ||||
dst_shape.push_back(src_shape.at(kNhwcN)); | dst_shape.push_back(src_shape.at(kNhwcN)); | ||||
@@ -44,9 +44,9 @@ Status TransShapeNhwcToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
dst_shape.push_back(src_shape.at(kNhwcW)); | dst_shape.push_back(src_shape.at(kNhwcW)); | ||||
dst_shape.push_back(c0); | dst_shape.push_back(c0); | ||||
if (!CheckShapeValid(dst_shape, kNc1hwc0DimsNum)) { | if (!CheckShapeValid(dst_shape, kNc1hwc0DimsNum)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", | |||||
ShapeToString(dst_shape).c_str()); | ShapeToString(dst_shape).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -56,21 +56,21 @@ Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { | |||||
std::string error = "Dose not support trans format from " + | std::string error = "Dose not support trans format from " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | ||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans shape from NHWC to NC1HWC0, invalid data type %s", | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to trans shape from NHWC to NC1HWC0, invalid data type %s", | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return UNSUPPORTED; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (!CheckShapeValid(args.src_shape, kNhwcDimsNum)) { | if (!CheckShapeValid(args.src_shape, kNhwcDimsNum)) { | ||||
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
if (!CheckShapeValid(args.dst_shape, kNc1hwc0DimsNum)) { | if (!CheckShapeValid(args.dst_shape, kNc1hwc0DimsNum)) { | ||||
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
std::vector<int64_t> expect_dst_shape; | std::vector<int64_t> expect_dst_shape; | ||||
auto ret = TransShapeNhwcToNc1hwc0(args.src_shape, args.src_data_type, expect_dst_shape); | auto ret = TransShapeNhwcToNc1hwc0(args.src_shape, args.src_data_type, expect_dst_shape); | ||||
@@ -78,12 +78,12 @@ Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { | |||||
return ret; | return ret; | ||||
} | } | ||||
if (args.dst_shape != expect_dst_shape) { | if (args.dst_shape != expect_dst_shape) { | ||||
GELOGE(PARAM_INVALID, | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
"Failed to trans format, the src and dst shape are not compatible. src shape %s, dst shape %s, " | "Failed to trans format, the src and dst shape are not compatible. src shape %s, dst shape %s, " | ||||
"expect dst shape %s", | "expect dst shape %s", | ||||
ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(), | ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(), | ||||
ShapeToString(expect_dst_shape).c_str()); | ShapeToString(expect_dst_shape).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -92,10 +92,10 @@ Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { | |||||
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | ||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | TypeUtils::FormatToSerialString(args.src_format).c_str(), | ||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | ||||
return OUT_OF_MEMORY; | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | } | ||||
auto n = args.src_shape.at(kNhwcN); | auto n = args.src_shape.at(kNhwcN); | ||||
@@ -131,19 +131,19 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
if (c_idx < c) { | if (c_idx < c) { | ||||
auto ret = memcpy_s(dst.get() + dst_offset, protected_size, args.data + src_offset, size); | auto ret = memcpy_s(dst.get() + dst_offset, protected_size, args.data + src_offset, size); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
"Failed to copy data from NHWC[%ld, %ld, %ld, %ld] offset %ld to " | "Failed to copy data from NHWC[%ld, %ld, %ld, %ld] offset %ld to " | ||||
"NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld err-code %d", | "NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld err-code %d", | ||||
n_idx, h_idx, w_idx, c_idx, src_offset, n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | n_idx, h_idx, w_idx, c_idx, src_offset, n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} else { | } else { | ||||
auto ret = memset_s(dst.get() + dst_offset, protected_size, 0, size); | auto ret = memset_s(dst.get() + dst_offset, protected_size, 0, size); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
"Failed to set 0 to NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld base err-code %d", n_idx, c1_idx, | "Failed to set 0 to NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld base err-code %d", n_idx, c1_idx, | ||||
h_idx, w_idx, c0_idx, dst_offset, ret); | h_idx, w_idx, c0_idx, dst_offset, ret); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -158,8 +158,9 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
} // namespace | } // namespace | ||||
Status FormatTransferNhwcNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | Status FormatTransferNhwcNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | ||||
if (CheckArgsForNhwcToNc1hwc0(args) != SUCCESS) { | |||||
return PARAM_INVALID; | |||||
Status ret = CheckArgsForNhwcToNc1hwc0(args); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | } | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
@@ -170,18 +171,20 @@ Status FormatTransferNhwcNc1hwc0::TransFormat(const TransArgs &args, TransResult | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
GELOGD("Begin to trans format from NHWC to NC1HWC0, src shape %s, data type %s, dst shape %s, memory size %ld", | GELOGD("Begin to trans format from NHWC to NC1HWC0, src shape %s, data type %s, dst shape %s, memory size %ld", | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
return INTERNAL_ERROR; | |||||
return ret; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -190,15 +193,15 @@ Status FormatTransferNhwcNc1hwc0::TransShape(Format src_format, const std::vecto | |||||
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | ||||
if (src_format == FORMAT_NHWC && CheckDataTypeSupported(data_type)) { | if (src_format == FORMAT_NHWC && CheckDataTypeSupported(data_type)) { | ||||
if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { | if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { | ||||
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check src shape %s", | |||||
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", | |||||
ShapeToString(src_shape).c_str()); | ShapeToString(src_shape).c_str()); | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return TransShapeNhwcToNc1hwc0(src_shape, data_type, dst_shape); | return TransShapeNhwcToNc1hwc0(src_shape, data_type, dst_shape); | ||||
} else if (src_format != FORMAT_NHWC) { | } else if (src_format != FORMAT_NHWC) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} else { | } else { | ||||
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
} | } | ||||
@@ -141,7 +141,7 @@ std::vector<int64_t> TransShapeByPerm(const std::vector<int64_t> &src_shape, con | |||||
Status Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type, | Status Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type, | ||||
const std::vector<int64_t> &perm_arg, TransResult &result) { | const std::vector<int64_t> &perm_arg, TransResult &result) { | ||||
if (!IsTransposeArgValid(src, src_shape, src_data_type, perm_arg)) { | if (!IsTransposeArgValid(src, src_shape, src_data_type, perm_arg)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_PARAM_INVALID; | |||||
} | } | ||||
auto dst_shape = TransShapeByPerm(src_shape, perm_arg); | auto dst_shape = TransShapeByPerm(src_shape, perm_arg); | ||||
@@ -172,12 +172,12 @@ Status Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, Data | |||||
auto ret = memcpy_s(dst.get() + dst_offset_bytes, static_cast<size_t>(protected_size), src + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset_bytes, static_cast<size_t>(protected_size), src + src_offset, | ||||
static_cast<size_t>(data_size)); | static_cast<size_t>(data_size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
"Failed to transpose, src shape %s, perm arg %s, dst shape %s, " | "Failed to transpose, src shape %s, perm arg %s, dst shape %s, " | ||||
"failed to write to dst offset %ld, current dim offset %s", | "failed to write to dst offset %ld, current dim offset %s", | ||||
ShapeToString(src_shape).c_str(), ShapeToString(perm_arg).c_str(), ShapeToString(dst_shape).c_str(), | ShapeToString(src_shape).c_str(), ShapeToString(perm_arg).c_str(), ShapeToString(dst_shape).c_str(), | ||||
dst_offset_bytes, ShapeToString(dst_indexes).c_str()); | dst_offset_bytes, ShapeToString(dst_indexes).c_str()); | ||||
return INTERNAL_ERROR; | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | } | ||||
AddOne(dst_shape, dst_indexes); | AddOne(dst_shape, dst_indexes); | ||||
++dst_index; | ++dst_index; | ||||
@@ -192,14 +192,14 @@ Status TransposeWithShapeCheck(const uint8_t *data, const std::vector<int64_t> & | |||||
const std::vector<int64_t> &dst_shape, DataType src_data_type, | const std::vector<int64_t> &dst_shape, DataType src_data_type, | ||||
const std::vector<int64_t> &perm_arg, TransResult &result) { | const std::vector<int64_t> &perm_arg, TransResult &result) { | ||||
if (!IsTransposeArgValid(data, src_shape, src_data_type, perm_arg)) { | if (!IsTransposeArgValid(data, src_shape, src_data_type, perm_arg)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_PARAM_INVALID; | |||||
} | } | ||||
auto expected_shape = TransShapeByPerm(src_shape, perm_arg); | auto expected_shape = TransShapeByPerm(src_shape, perm_arg); | ||||
if (dst_shape != expected_shape) { | if (dst_shape != expected_shape) { | ||||
std::string error = "Failed to trans axis for perm_arg" + | std::string error = "Failed to trans axis for perm_arg" + | ||||
FmtToStr(ShapeToString(perm_arg)) + ", invalid dst shape" + | FmtToStr(ShapeToString(perm_arg)) + ", invalid dst shape" + | ||||
FmtToStr(ShapeToString(dst_shape)) + ", expect" + FmtToStr(ShapeToString(expected_shape)); | FmtToStr(ShapeToString(dst_shape)) + ", expect" + FmtToStr(ShapeToString(expected_shape)); | ||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||||
} | } | ||||
return Transpose(data, src_shape, src_data_type, perm_arg, result); | return Transpose(data, src_shape, src_data_type, perm_arg, result); | ||||
@@ -211,16 +211,16 @@ Status GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t | |||||
std::string error = "Failed to trans shape, do not support transpose from format " + | std::string error = "Failed to trans shape, do not support transpose from format " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " + | FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(dst_format)); | FmtToStr(TypeUtils::FormatToSerialString(dst_format)); | ||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
auto iter = dst_iter->second.find(dst_format); | auto iter = dst_iter->second.find(dst_format); | ||||
if (iter == dst_iter->second.end()) { | if (iter == dst_iter->second.end()) { | ||||
std::string error = "Failed to trans shape, do not support transpose from format " + | std::string error = "Failed to trans shape, do not support transpose from format " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " + | FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(dst_format)); | FmtToStr(TypeUtils::FormatToSerialString(dst_format)); | ||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
perm = iter->second; | perm = iter->second; | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -233,7 +233,7 @@ Status FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult & | |||||
return ret; | return ret; | ||||
} | } | ||||
if (!IsTransShapeDstCorrect(args, expected_shape)) { | if (!IsTransShapeDstCorrect(args, expected_shape)) { | ||||
return PARAM_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
return Transpose(args.data, args.src_shape, args.src_data_type, perm_args[args.src_format][args.dst_format], result); | return Transpose(args.data, args.src_shape, args.src_data_type, perm_args[args.src_format][args.dst_format], result); | ||||
@@ -244,7 +244,7 @@ Status FormatTransferTranspose::TransShape(Format src_format, const std::vector< | |||||
std::vector<int64_t> perm_arg; | std::vector<int64_t> perm_arg; | ||||
GE_CHK_STATUS_RET_NOLOG(GetPermByForamt(src_format, dst_format, perm_arg)); | GE_CHK_STATUS_RET_NOLOG(GetPermByForamt(src_format, dst_format, perm_arg)); | ||||
if (!IsShapeArgValid(src_shape, perm_arg)) { | if (!IsShapeArgValid(src_shape, perm_arg)) { | ||||
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; | |||||
return ACL_ERROR_GE_SHAPE_INVALID; | |||||
} | } | ||||
dst_shape = TransShapeByPerm(src_shape, perm_arg); | dst_shape = TransShapeByPerm(src_shape, perm_arg); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -38,14 +38,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArg | |||||
std::string error = "Failed to trans data from format " + | std::string error = "Failed to trans data from format " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | ||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
auto src_shape_size = GetItemNumByShape(args.src_shape); | auto src_shape_size = GetItemNumByShape(args.src_shape); | ||||
if (args.data == nullptr && src_shape_size != 0) { | if (args.data == nullptr && src_shape_size != 0) { | ||||
GELOGE(PARAM_INVALID, "Invalid input null data"); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Invalid input null data"); | |||||
return ACL_ERROR_GE_PARAM_INVALID; | |||||
} | } | ||||
return transfer->TransFormat(args, result); | return transfer->TransFormat(args, result); | ||||
@@ -64,8 +64,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransShape(Format src_form | |||||
std::string error = "Failed to trans data from format " + | std::string error = "Failed to trans data from format " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | ||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | ||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_FORMAT_INVALID; | |||||
} | } | ||||
return transfer->TransShape(src_format, src_shape, data_type, dst_format, dst_shape); | return transfer->TransShape(src_format, src_shape, data_type, dst_format, dst_shape); | ||||
@@ -77,13 +77,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastAr | |||||
std::string error = "Failed to trans data from datatype " + | std::string error = "Failed to trans data from datatype " + | ||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | ||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)); | FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)); | ||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | |||||
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_DATATYPE_INVALID, error.c_str()); | |||||
return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
} | } | ||||
if (args.data == nullptr && args.src_data_size != 0) { | if (args.data == nullptr && args.src_data_size != 0) { | ||||
GELOGE(PARAM_INVALID, "Invalid input null data"); | |||||
return PARAM_INVALID; | |||||
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Invalid input null data"); | |||||
return ACL_ERROR_GE_PARAM_INVALID; | |||||
} | } | ||||
return transfer->TransDataType(args, result); | return transfer->TransDataType(args, result); | ||||
@@ -110,9 +110,9 @@ GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_AIPP_MODE_INVALID, "AIPP mode invalid."); | |||||
GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "Task type invalid."); | GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "Task type invalid."); | ||||
GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, "Kernel type invalid."); | GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, "Kernel type invalid."); | ||||
GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_PLGMGR_PATH_INVALID, "Plugin path is invalid."); | GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_PLGMGR_PATH_INVALID, "Plugin path is invalid."); | ||||
GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID, "Format is invalid when transferring shape."); | |||||
GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Shape is invalid when transferring shape."); | |||||
GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID, "Datatype is invalid when transferring shape."); | |||||
GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_FORMAT_INVALID, "Format is invalid."); | |||||
GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_SHAPE_INVALID, "Shape is invalid."); | |||||
GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_DATATYPE_INVALID, "Datatype is invalid."); | |||||
GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_MEMORY_ALLOCATION, "Memory allocation error."); | GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_MEMORY_ALLOCATION, "Memory allocation error."); | ||||
GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate memory."); | GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate memory."); | ||||
@@ -53,9 +53,9 @@ static const uint32_t ACL_ERROR_GE_AIPP_MODE_INVALID = 145016; | |||||
static const uint32_t ACL_ERROR_GE_OP_TASK_TYPE_INVALID = 145017; | static const uint32_t ACL_ERROR_GE_OP_TASK_TYPE_INVALID = 145017; | ||||
static const uint32_t ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID = 145018; | static const uint32_t ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID = 145018; | ||||
static const uint32_t ACL_ERROR_GE_PLGMGR_PATH_INVALID = 145019; | static const uint32_t ACL_ERROR_GE_PLGMGR_PATH_INVALID = 145019; | ||||
static const uint32_t ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID = 145020; | |||||
static const uint32_t ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID = 145021; | |||||
static const uint32_t ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID = 145022; | |||||
static const uint32_t ACL_ERROR_GE_FORMAT_INVALID = 145020; | |||||
static const uint32_t ACL_ERROR_GE_SHAPE_INVALID = 145021; | |||||
static const uint32_t ACL_ERROR_GE_DATATYPE_INVALID = 145022; | |||||
static const uint32_t ACL_ERROR_GE_MEMORY_ALLOCATION = 245000; | static const uint32_t ACL_ERROR_GE_MEMORY_ALLOCATION = 245000; | ||||
static const uint32_t ACL_ERROR_GE_MEMORY_OPERATE_FAILED = 245001; | static const uint32_t ACL_ERROR_GE_MEMORY_OPERATE_FAILED = 245001; | ||||
static const uint32_t ACL_ERROR_GE_INTERNAL_ERROR = 545000; | static const uint32_t ACL_ERROR_GE_INTERNAL_ERROR = 545000; | ||||
@@ -1 +1 @@ | |||||
Subproject commit 7a51997cbd34e1869b9fb4ea5597a021e6427272 | |||||
Subproject commit 6b802ec3cf711e9942a7e2a74f04a53647aae473 |
@@ -1 +1 @@ | |||||
Subproject commit 227b10355427038785e95c81a41cda99893eba08 | |||||
Subproject commit 6a07f1a8b9b8b4630a5b60d9d8d02ec4a6314d68 |
@@ -365,7 +365,7 @@ TEST_F(UtestDataTypeTransfer, invalid_src_data_type) { | |||||
TransResult result; | TransResult result; | ||||
DataTypeTransfer transfer; | DataTypeTransfer transfer; | ||||
EXPECT_EQ(transfer.TransDataType(args, result), UNSUPPORTED); | |||||
EXPECT_EQ(transfer.TransDataType(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
/* | /* | ||||
@@ -386,8 +386,8 @@ TEST_F(UtestDataTypeTransfer, unsupprot_trans) { | |||||
TransResult result; | TransResult result; | ||||
DataTypeTransfer transfer; | DataTypeTransfer transfer; | ||||
EXPECT_EQ(transfer.TransDataType(args, result), UNSUPPORTED); | |||||
EXPECT_EQ(TransDataType(args, result), UNSUPPORTED); | |||||
EXPECT_EQ(transfer.TransDataType(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
EXPECT_EQ(TransDataType(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestDataTypeTransfer, unsupprot_trans2) { | TEST_F(UtestDataTypeTransfer, unsupprot_trans2) { | ||||
@@ -396,8 +396,8 @@ TEST_F(UtestDataTypeTransfer, unsupprot_trans2) { | |||||
TransResult result; | TransResult result; | ||||
DataTypeTransfer transfer; | DataTypeTransfer transfer; | ||||
EXPECT_EQ(transfer.TransDataType(args, result), UNSUPPORTED); | |||||
EXPECT_EQ(TransDataType(args, result), UNSUPPORTED); | |||||
EXPECT_EQ(transfer.TransDataType(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
EXPECT_EQ(TransDataType(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | } // namespace ge |
@@ -679,7 +679,7 @@ TEST_F(UtestFormatTransfer5dNhwc, nc1hwc0_to_nhwc_float2) { | |||||
} | } | ||||
Status status = | Status status = | ||||
transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | ||||
EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
EXPECT_EQ(status, ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransfer5dNhwc, invalid_src_format) { | TEST_F(UtestFormatTransfer5dNhwc, invalid_src_format) { | ||||
@@ -689,7 +689,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_src_format) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferNc1hwc0Nhwc transfer; | FormatTransferNc1hwc0Nhwc transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransfer5dNhwc, invalid_src_shape1) { | TEST_F(UtestFormatTransfer5dNhwc, invalid_src_shape1) { | ||||
@@ -699,7 +699,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_src_shape1) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferNc1hwc0Nhwc transfer; | FormatTransferNc1hwc0Nhwc transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransfer5dNhwc, InvalidSrcShape2) { | TEST_F(UtestFormatTransfer5dNhwc, InvalidSrcShape2) { | ||||
@@ -709,7 +709,7 @@ TEST_F(UtestFormatTransfer5dNhwc, InvalidSrcShape2) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferNc1hwc0Nhwc transfer; | FormatTransferNc1hwc0Nhwc transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransfer5dNhwc, invalid_src_data_type) { | TEST_F(UtestFormatTransfer5dNhwc, invalid_src_data_type) { | ||||
@@ -719,7 +719,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_src_data_type) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferNc1hwc0Nhwc transfer; | FormatTransferNc1hwc0Nhwc transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_format) { | TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_format) { | ||||
@@ -729,7 +729,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_format) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferNc1hwc0Nhwc transfer; | FormatTransferNc1hwc0Nhwc transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_shape1) { | TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_shape1) { | ||||
@@ -739,7 +739,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_shape1) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferNc1hwc0Nhwc transfer; | FormatTransferNc1hwc0Nhwc transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_shape2) { | TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_shape2) { | ||||
@@ -749,7 +749,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_dst_shape2) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferNc1hwc0Nhwc transfer; | FormatTransferNc1hwc0Nhwc transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransfer5dNhwc, invalid_src_dst_shape_relation) { | TEST_F(UtestFormatTransfer5dNhwc, invalid_src_dst_shape_relation) { | ||||
@@ -759,7 +759,7 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_src_dst_shape_relation) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferNc1hwc0Nhwc transfer; | FormatTransferNc1hwc0Nhwc transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | } // namespace ge |
@@ -39,7 +39,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_data_type_uint8) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferC1hwncoc0Hwcn transfer; | FormatTransferC1hwncoc0Hwcn transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_data_type_int32) { | TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_data_type_int32) { | ||||
@@ -50,7 +50,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_data_type_int32) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 3, 1}, DT_INT32}; | reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 3, 1}, DT_INT32}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_format_nc1khkwhwc0) { | TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_format_nc1khkwhwc0) { | ||||
@@ -61,7 +61,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_format_nc1khkw | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_NC1KHKWHWC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 3, 1}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_NC1KHKWHWC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 3, 1}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_dst_format_nchw) { | TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_dst_format_nchw) { | ||||
@@ -72,7 +72,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_dst_format_nchw) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_NCHW, {1, 4, 4, 1, 16, 16}, {4, 4, 3, 1}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_NCHW, {1, 4, 4, 1, 16, 16}, {4, 4, 3, 1}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_shape) { | TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_shape) { | ||||
@@ -83,7 +83,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_shape) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16}, {4, 4, 3, 1}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16}, {4, 4, 3, 1}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_shape2) { | TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_shape2) { | ||||
@@ -94,7 +94,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_shape2) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, -16}, {4, 4, 3, 1}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, -16}, {4, 4, 3, 1}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invali_dst_shape) { | TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invali_dst_shape) { | ||||
@@ -105,7 +105,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invali_dst_shape) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 3}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 3}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_dst_shape2) { | TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_dst_shape2) { | ||||
@@ -116,7 +116,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_dst_shape2) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 3, -1}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 3, -1}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_dst_shape_relation) { | TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_dst_shape_relation) { | ||||
@@ -127,7 +127,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_invalid_src_dst_shape_rela | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 17, 1}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_C1HWNCoC0, FORMAT_HWCN, {1, 4, 4, 1, 16, 16}, {4, 4, 17, 1}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_fp16_success_lt_cube) { | TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_fp16_success_lt_cube) { | ||||
@@ -158,7 +158,7 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_fp16_success_lt_cube) { | |||||
} | } | ||||
Status status = | Status status = | ||||
transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | ||||
EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
EXPECT_EQ(status, ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_gp16_success_eq_cube) { | TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_gp16_success_eq_cube) { | ||||
@@ -2332,7 +2332,7 @@ TEST_F(UtestFormatTransferNdFractNz, nd_shape4_fp16) { | |||||
} | } | ||||
EXPECT_EQ( | EXPECT_EQ( | ||||
transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractNz, nd_shape5_fp16) { | TEST_F(UtestFormatTransferNdFractNz, nd_shape5_fp16) { | ||||
@@ -4785,7 +4785,7 @@ TEST_F(UtestFormatTransferNdFractNz, nd_shape4_fp32) { | |||||
EXPECT_EQ((reinterpret_cast<float *>(result2.data.get()))[i], data[i]); | EXPECT_EQ((reinterpret_cast<float *>(result2.data.get()))[i], data[i]); | ||||
} | } | ||||
EXPECT_EQ(transfer2.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer2.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractNz, nchw_shape4_fp32) { | TEST_F(UtestFormatTransferNdFractNz, nchw_shape4_fp32) { | ||||
@@ -9058,9 +9058,9 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_shape) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_NZ, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_NZ, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalNz transfer; | FormatTransferFractalNz transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||||
ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type) { | TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type) { | ||||
@@ -9078,9 +9078,9 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type) { | |||||
DT_UNDEFINED}; | DT_UNDEFINED}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalNz transfer; | FormatTransferFractalNz transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID); | |||||
ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractNz, invalid_src_format) { | TEST_F(UtestFormatTransferNdFractNz, invalid_src_format) { | ||||
@@ -9093,9 +9093,9 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_format) { | |||||
DT_FLOAT16}; | DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalNz transfer; | FormatTransferFractalNz transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||||
ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractNz, invalid_dst_shape) { | TEST_F(UtestFormatTransferNdFractNz, invalid_dst_shape) { | ||||
@@ -9104,7 +9104,7 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_dst_shape) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_NZ, {1, 1, 4, 4}, {1, 1, 16, 16}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_NZ, {1, 1, 4, 4}, {1, 1, 16, 16}, DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalNz transfer; | FormatTransferFractalNz transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
SUCCESS); | SUCCESS); | ||||
} | } | ||||
@@ -9115,7 +9115,7 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_dst_shape2) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_NZ, FORMAT_NHWC, {1, 1, 1, 1, 16, 16}, {1, 4, 4}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_NZ, FORMAT_NHWC, {1, 1, 1, 1, 16, 16}, {1, 4, 4}, DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalNzND transfer; | FormatTransferFractalNzND transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type2) { | TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type2) { | ||||
@@ -9133,7 +9133,7 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type2) { | |||||
DT_UNDEFINED}; | DT_UNDEFINED}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalNzND transfer; | FormatTransferFractalNzND transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type3) { | TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type3) { | ||||
@@ -9151,7 +9151,7 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_data_type3) { | |||||
DT_VARIANT}; | DT_VARIANT}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalNzND transfer; | FormatTransferFractalNzND transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractNz, invalid_dst_format2) { | TEST_F(UtestFormatTransferNdFractNz, invalid_dst_format2) { | ||||
@@ -9164,8 +9164,8 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_dst_format2) { | |||||
DT_FLOAT16}; | DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalNzND transfer; | FormatTransferFractalNzND transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(TransFormat(args, result), UNSUPPORTED); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
EXPECT_EQ(TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractNz, invalid_src_shape2) { | TEST_F(UtestFormatTransferNdFractNz, invalid_src_shape2) { | ||||
@@ -9174,7 +9174,7 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_shape2) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_NZ, FORMAT_NHWC, {1, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_NZ, FORMAT_NHWC, {1, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalNzND transfer; | FormatTransferFractalNzND transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractNz, invalid_src_dst_shape_relation) { | TEST_F(UtestFormatTransferNdFractNz, invalid_src_dst_shape_relation) { | ||||
@@ -9187,7 +9187,7 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_dst_shape_relation) { | |||||
DT_FLOAT16}; | DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalNzND transfer; | FormatTransferFractalNzND transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | } // namespace ge |
@@ -1894,7 +1894,7 @@ TEST_F(UtestFormatTransferNdFractZz, nd_shape4_fp16_1) { | |||||
} | } | ||||
EXPECT_EQ( | EXPECT_EQ( | ||||
transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractZz, nd_shape4_fp16) { | TEST_F(UtestFormatTransferNdFractZz, nd_shape4_fp16) { | ||||
@@ -2071,7 +2071,7 @@ TEST_F(UtestFormatTransferNdFractZz, nd_shape4_fp16) { | |||||
} | } | ||||
EXPECT_EQ( | EXPECT_EQ( | ||||
transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractZz, nd_shape5_fp16) { | TEST_F(UtestFormatTransferNdFractZz, nd_shape5_fp16) { | ||||
@@ -7877,9 +7877,9 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_shape) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_ZZ, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_ZZ, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalZz transfer; | FormatTransferFractalZz transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||||
ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractZz, invalid_src_data_type) { | TEST_F(UtestFormatTransferNdFractZz, invalid_src_data_type) { | ||||
@@ -7897,9 +7897,9 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_data_type) { | |||||
DT_UNDEFINED}; | DT_UNDEFINED}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalZz transfer; | FormatTransferFractalZz transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID); | |||||
ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractZz, invalid_src_format) { | TEST_F(UtestFormatTransferNdFractZz, invalid_src_format) { | ||||
@@ -7912,10 +7912,10 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_format) { | |||||
DT_FLOAT16}; | DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalZz transfer; | FormatTransferFractalZz transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||||
EXPECT_EQ(TransFormat(args, result), UNSUPPORTED); | |||||
ACL_ERROR_GE_SHAPE_INVALID); | |||||
EXPECT_EQ(TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractZz, invalid_dst_shape) { | TEST_F(UtestFormatTransferNdFractZz, invalid_dst_shape) { | ||||
@@ -7924,7 +7924,7 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_dst_shape) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_ZZ, {1, 1, 4, 4}, {1, 1, 16, 16}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_ZZ, {1, 1, 4, 4}, {1, 1, 16, 16}, DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalZz transfer; | FormatTransferFractalZz transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
SUCCESS); | SUCCESS); | ||||
} | } | ||||
@@ -7935,7 +7935,7 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_dst_shape2) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_ZZ, FORMAT_NHWC, {1, 1, 1, 1, 16, 16}, {1, 4, 4}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_ZZ, FORMAT_NHWC, {1, 1, 1, 1, 16, 16}, {1, 4, 4}, DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalZzND transfer; | FormatTransferFractalZzND transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractZz, invalid_src_data_type2) { | TEST_F(UtestFormatTransferNdFractZz, invalid_src_data_type2) { | ||||
@@ -7953,7 +7953,7 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_data_type2) { | |||||
DT_UNDEFINED}; | DT_UNDEFINED}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalZzND transfer; | FormatTransferFractalZzND transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractZz, invalid_dst_format2) { | TEST_F(UtestFormatTransferNdFractZz, invalid_dst_format2) { | ||||
@@ -7966,8 +7966,8 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_dst_format2) { | |||||
DT_FLOAT16}; | DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalZzND transfer; | FormatTransferFractalZzND transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(TransFormat(args, result), UNSUPPORTED); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
EXPECT_EQ(TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractZz, invalid_src_shape2) { | TEST_F(UtestFormatTransferNdFractZz, invalid_src_shape2) { | ||||
@@ -7976,7 +7976,7 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_shape2) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_ZZ, FORMAT_NHWC, {1, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_ZZ, FORMAT_NHWC, {1, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalZzND transfer; | FormatTransferFractalZzND transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNdFractZz, invalid_src_dst_shape_relation) { | TEST_F(UtestFormatTransferNdFractZz, invalid_src_dst_shape_relation) { | ||||
@@ -7989,7 +7989,7 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_dst_shape_relation) { | |||||
DT_FLOAT16}; | DT_FLOAT16}; | ||||
TransResult result; | TransResult result; | ||||
FormatTransferFractalZzND transfer; | FormatTransferFractalZzND transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | } // namespace ge |
@@ -39,7 +39,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_data_type_invalid_dat | |||||
TransResult result; | TransResult result; | ||||
FormatTransferFracZHwcn transfer; | FormatTransferFracZHwcn transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_format_reserved) { | TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_format_reserved) { | ||||
@@ -50,7 +50,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_format_reserved) | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_RESERVED, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, 1, 1}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_RESERVED, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, 1, 1}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_format_reserved) { | TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_format_reserved) { | ||||
@@ -61,7 +61,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_format_reserved) | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_RESERVED, {16, 1, 16, 16}, {4, 4, 1, 1}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_RESERVED, {16, 1, 16, 16}, {4, 4, 1, 1}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_shape) { | TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_shape) { | ||||
@@ -72,7 +72,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_shape) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 1, 16, 16}, {4, 4, 1, 1}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 1, 16, 16}, {4, 4, 1, 1}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_shape2) { | TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_shape2) { | ||||
@@ -83,7 +83,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_shape2) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, -1, 16, 16}, {4, 4, 1, 1}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, -1, 16, 16}, {4, 4, 1, 1}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_shape) { | TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_shape) { | ||||
@@ -94,7 +94,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_shape) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, 1}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, 1}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_shape2) { | TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_shape2) { | ||||
@@ -105,7 +105,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_dst_shape2) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, -1, 1}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, -1, 1}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_dst_shape_relation1) { | TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_dst_shape_relation1) { | ||||
@@ -116,7 +116,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_dst_shape_relatio | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, 17, 1}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, 17, 1}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_dst_shape_relation2) { | TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_dst_shape_relation2) { | ||||
@@ -127,7 +127,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_invalid_src_dst_shape_relatio | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, 1, 17}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_HWCN, {16, 1, 16, 16}, {4, 4, 1, 17}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_fp16_success_lt_cube) { | TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_fp16_success_lt_cube) { | ||||
@@ -302,7 +302,7 @@ TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_fp16_success_eq_cube) { | |||||
} | } | ||||
Status status = | Status status = | ||||
transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | ||||
EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
EXPECT_EQ(status, ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_fp16_success_gt_cube) { | TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_fp16_success_gt_cube) { | ||||
@@ -39,7 +39,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_data_type) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferFracZNchw transfer; | FormatTransferFracZNchw transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_Invalid_src_format_reserved) { | TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_Invalid_src_format_reserved) { | ||||
@@ -50,7 +50,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_Invalid_src_format_reserved) | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_RESERVED, FORMAT_NCHW, {16, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_RESERVED, FORMAT_NCHW, {16, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_format_reserved) { | TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_format_reserved) { | ||||
@@ -61,7 +61,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_format_reserved) | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_RESERVED, {16, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_RESERVED, {16, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_shape) { | TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_shape) { | ||||
@@ -72,7 +72,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_shape) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 1, 16, 16}, {1, 1, 4, 4}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_shape2) { | TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_shape2) { | ||||
@@ -83,7 +83,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_shape2) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, -16, 16}, {1, 1, 4, 4}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, -16, 16}, {1, 1, 4, 4}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_shape) { | TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_shape) { | ||||
@@ -94,7 +94,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_shape) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 16, 16}, {1, 4, 4}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 16, 16}, {1, 4, 4}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_shape2) { | TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_shape2) { | ||||
@@ -105,7 +105,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_dst_shape2) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 16, 16}, {1, -1, 4, 4}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 16, 16}, {1, -1, 4, 4}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_dst_shape_relation1) { | TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_dst_shape_relation1) { | ||||
@@ -116,7 +116,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_dst_shape_relatio | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 16, 16}, {1, 17, 4, 4}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 16, 16}, {1, 17, 4, 4}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_dst_shape_relation2) { | TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_dst_shape_relation2) { | ||||
@@ -127,7 +127,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_invalid_src_dst_shape_relatio | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 16, 16}, {17, 1, 4, 4}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_FRACTAL_Z, FORMAT_NCHW, {16, 1, 16, 16}, {17, 1, 4, 4}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_fp16_success_lt_cube) { | TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_fp16_success_lt_cube) { | ||||
@@ -302,7 +302,7 @@ TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_fp16_success_eq_cube) { | |||||
} | } | ||||
Status status = | Status status = | ||||
transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | ||||
EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
EXPECT_EQ(status, ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_fp16_success_gt_cube) { | TEST_F(UtestFormatTransferFraczNchw, fracz_to_nchw_fp16_success_gt_cube) { | ||||
@@ -42,7 +42,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_data_type_uint8) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferHwcnC1hwncoc0 transfer; | FormatTransferHwcnC1hwncoc0 transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_data_type_int32) { | TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_data_type_int32) { | ||||
@@ -57,7 +57,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_data_type_int32) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_C1HWNCoC0, {4, 4, 3, 1}, {1, 4, 4, 1, 16, 16}, DT_INT32}; | reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_C1HWNCoC0, {4, 4, 3, 1}, {1, 4, 4, 1, 16, 16}, DT_INT32}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_format_nchw) { | TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_format_nchw) { | ||||
@@ -72,10 +72,10 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_format_nchw) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_NCHW, FORMAT_C1HWNCoC0, {4, 4, 3, 1}, {1, 4, 4, 1, 16, 16}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_NCHW, FORMAT_C1HWNCoC0, {4, 4, 3, 1}, {1, 4, 4, 1, 16, 16}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
Status status = | Status status = | ||||
transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | ||||
EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
EXPECT_EQ(status, ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_format_nc1khkwhwc0) { | TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_format_nc1khkwhwc0) { | ||||
@@ -90,7 +90,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_format_nc1khkwhw | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_NC1KHKWHWC0, {4, 4, 3, 1}, {1, 4, 4, 1, 16, 16}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_NC1KHKWHWC0, {4, 4, 3, 1}, {1, 4, 4, 1, 16, 16}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape) { | TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape) { | ||||
@@ -105,7 +105,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_NC1KHKWHWC0, {4, 4, 3}, {1, 4, 4, 1, 16, 16}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_NC1KHKWHWC0, {4, 4, 3}, {1, 4, 4, 1, 16, 16}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape2) { | TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape2) { | ||||
@@ -120,7 +120,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape2) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_C1HWNCoC0, {4, 4}, {1, 4, 4, 1, 16, 16}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_C1HWNCoC0, {4, 4}, {1, 4, 4, 1, 16, 16}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape3) { | TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape3) { | ||||
@@ -139,10 +139,10 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_src_shape3) { | |||||
DT_FLOAT}; | DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
Status status = | Status status = | ||||
transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | ||||
EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||||
EXPECT_EQ(status, ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_format) { | TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_format) { | ||||
@@ -157,7 +157,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_format) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_NC1KHKWHWC0, {4, 4, 3, 1}, {1, 1, 4, 4, 16, 16}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_NC1KHKWHWC0, {4, 4, 3, 1}, {1, 1, 4, 4, 16, 16}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_shape2) { | TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_shape2) { | ||||
@@ -172,7 +172,7 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_invalid_dst_shape2) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_C1HWNCoC0, {4, 4, 3, 1}, {2, 4, 4, 1, 16, 16}, DT_FLOAT}; | reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_C1HWNCoC0, {4, 4, 3, 1}, {2, 4, 4, 1, 16, 16}, DT_FLOAT}; | ||||
TransResult result; | TransResult result; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_fp16_success_lt_cube) { | TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_fp16_success_lt_cube) { | ||||
@@ -640,7 +640,7 @@ TEST_F(UtestFormatTransferNchw5d, invalid_data_format) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | ||||
FormatTransferNchwNc1hwc0 transfer; | FormatTransferNchwNc1hwc0 transfer; | ||||
EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | } // namespace ge |
@@ -691,7 +691,7 @@ TEST_F(UtestFormatTransferNhwc5d, invalid_src_shape1) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferNhwcNc1hwc0 transfer; | FormatTransferNhwcNc1hwc0 transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
/* | /* | ||||
@@ -716,10 +716,10 @@ TEST_F(UtestFormatTransferNhwc5d, invalid_src_format) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferNhwcNc1hwc0 transfer; | FormatTransferNhwcNc1hwc0 transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
Status status = | Status status = | ||||
transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); | ||||
EXPECT_EQ(status, ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
EXPECT_EQ(status, ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNhwc5d, invalid_dst_shape2) { | TEST_F(UtestFormatTransferNhwc5d, invalid_dst_shape2) { | ||||
@@ -729,7 +729,7 @@ TEST_F(UtestFormatTransferNhwc5d, invalid_dst_shape2) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferNhwcNc1hwc0 transfer; | FormatTransferNhwcNc1hwc0 transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNhwc5d, invalid_src_data_type) { | TEST_F(UtestFormatTransferNhwc5d, invalid_src_data_type) { | ||||
@@ -739,7 +739,7 @@ TEST_F(UtestFormatTransferNhwc5d, invalid_src_data_type) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferNhwcNc1hwc0 transfer; | FormatTransferNhwcNc1hwc0 transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNhwc5d, unsupport_dst_format) { | TEST_F(UtestFormatTransferNhwc5d, unsupport_dst_format) { | ||||
@@ -749,7 +749,7 @@ TEST_F(UtestFormatTransferNhwc5d, unsupport_dst_format) { | |||||
TransResult result; | TransResult result; | ||||
FormatTransferNhwcNc1hwc0 transfer; | FormatTransferNhwcNc1hwc0 transfer; | ||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||||
EXPECT_EQ(transfer.TransFormat(args, result), ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNhwc5d, invalid_data_shape) { | TEST_F(UtestFormatTransferNhwc5d, invalid_data_shape) { | ||||
@@ -758,13 +758,13 @@ TEST_F(UtestFormatTransferNhwc5d, invalid_data_shape) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | ||||
FormatTransferNhwcNc1hwc0 transfer; | FormatTransferNhwcNc1hwc0 transfer; | ||||
EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||||
ACL_ERROR_GE_SHAPE_INVALID); | |||||
TransArgs args2{ | TransArgs args2{ | ||||
reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_STRING}; | reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_STRING}; | ||||
FormatTransferNhwcNc1hwc0 transfer2; | FormatTransferNhwcNc1hwc0 transfer2; | ||||
EXPECT_EQ(transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | EXPECT_EQ(transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID); | |||||
ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | } // namespace ge |
@@ -5360,7 +5360,7 @@ TEST_F(UtestFormatTransferNhwcFz, invalid_data_type) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_NZ, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_VARIANT}; | reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_NZ, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_VARIANT}; | ||||
FormatTransferFractalZ transfer; | FormatTransferFractalZ transfer; | ||||
EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID); | |||||
ACL_ERROR_GE_DATATYPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNhwcFz, invalid_data_format) { | TEST_F(UtestFormatTransferNhwcFz, invalid_data_format) { | ||||
@@ -5369,7 +5369,7 @@ TEST_F(UtestFormatTransferNhwcFz, invalid_data_format) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_CHWN, FORMAT_FRACTAL_NZ, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_CHWN, FORMAT_FRACTAL_NZ, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | ||||
FormatTransferFractalZ transfer; | FormatTransferFractalZ transfer; | ||||
EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTransferNhwcFz, invalid_data_shape) { | TEST_F(UtestFormatTransferNhwcFz, invalid_data_shape) { | ||||
@@ -5378,19 +5378,19 @@ TEST_F(UtestFormatTransferNhwcFz, invalid_data_shape) { | |||||
reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_NHWC, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | ||||
FormatTransferFractalZ transfer; | FormatTransferFractalZ transfer; | ||||
EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | EXPECT_EQ(transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||||
ACL_ERROR_GE_SHAPE_INVALID); | |||||
TransArgs args2{ | TransArgs args2{ | ||||
reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | ||||
FormatTransferFractalZ transfer2; | FormatTransferFractalZ transfer2; | ||||
EXPECT_EQ(transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | EXPECT_EQ(transfer2.TransShape(args2.src_format, args2.src_shape, args2.src_data_type, args2.dst_format, args2.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||||
ACL_ERROR_GE_SHAPE_INVALID); | |||||
TransArgs args3{ | TransArgs args3{ | ||||
reinterpret_cast<uint8_t *>(data), FORMAT_NCHW, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | reinterpret_cast<uint8_t *>(data), FORMAT_NCHW, FORMAT_FRACTAL_Z, {1, 4, 4}, {1, 1, 1, 16, 16}, DT_FLOAT16}; | ||||
FormatTransferFractalZ transfer3; | FormatTransferFractalZ transfer3; | ||||
EXPECT_EQ(transfer3.TransShape(args3.src_format, args3.src_shape, args3.src_data_type, args3.dst_format, args3.dst_shape), | EXPECT_EQ(transfer3.TransShape(args3.src_format, args3.src_shape, args3.src_data_type, args3.dst_format, args3.dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||||
ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | } // namespace ge |
@@ -4659,14 +4659,14 @@ TEST_F(UtestFormatTranspose, invalid_data_shape) { | |||||
FormatTransferTranspose transfer; | FormatTransferTranspose transfer; | ||||
std::vector<int64_t> dst_shape; | std::vector<int64_t> dst_shape; | ||||
EXPECT_EQ(transfer.TransShape(FORMAT_NCHW, std::vector<int64_t>({}), DT_FLOAT16, FORMAT_HWCN, dst_shape), | EXPECT_EQ(transfer.TransShape(FORMAT_NCHW, std::vector<int64_t>({}), DT_FLOAT16, FORMAT_HWCN, dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID); | |||||
ACL_ERROR_GE_SHAPE_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTranspose, invalid_src_format) { | TEST_F(UtestFormatTranspose, invalid_src_format) { | ||||
FormatTransferTranspose transfer; | FormatTransferTranspose transfer; | ||||
std::vector<int64_t> dst_shape; | std::vector<int64_t> dst_shape; | ||||
EXPECT_EQ(transfer.TransShape(FORMAT_NC1HWC0, std::vector<int64_t>({1, 3, 8, 8}), DT_FLOAT16, FORMAT_HWCN, dst_shape), | EXPECT_EQ(transfer.TransShape(FORMAT_NC1HWC0, std::vector<int64_t>({1, 3, 8, 8}), DT_FLOAT16, FORMAT_HWCN, dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
TEST_F(UtestFormatTranspose, invalid_dst_format) { | TEST_F(UtestFormatTranspose, invalid_dst_format) { | ||||
@@ -4674,7 +4674,7 @@ TEST_F(UtestFormatTranspose, invalid_dst_format) { | |||||
std::vector<int64_t> dst_shape; | std::vector<int64_t> dst_shape; | ||||
std::vector<int64_t> src_shape; | std::vector<int64_t> src_shape; | ||||
EXPECT_EQ(transfer.TransShape(FORMAT_NCHW, src_shape, DT_FLOAT16, FORMAT_C1HWNC0, dst_shape), | EXPECT_EQ(transfer.TransShape(FORMAT_NCHW, src_shape, DT_FLOAT16, FORMAT_C1HWNC0, dst_shape), | ||||
ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID); | |||||
ACL_ERROR_GE_FORMAT_INVALID); | |||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | } // namespace ge |