From 16c5ba59006b40a1fd7f0c1fa9a23fa42bd78117 Mon Sep 17 00:00:00 2001 From: wjm Date: Fri, 11 Jun 2021 18:20:07 +0800 Subject: [PATCH] add --- .../formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc | 1 + .../formats/format_transfers/format_transfer_fractal_nz.cc | 2 +- .../formats/format_transfers/format_transfer_fractal_z.cc | 3 +++ .../formats/format_transfers/format_transfer_fracz_hwcn.cc | 1 + .../formats/format_transfers/format_transfer_fracz_nchw.cc | 1 + .../formats/format_transfers/format_transfer_fracz_nhwc.cc | 1 + .../formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc | 1 + .../formats/format_transfers/format_transfer_nc1hwc0_nchw.cc | 1 + .../formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc | 1 + .../formats/format_transfers/format_transfer_nchw_nc1hwc0.cc | 1 + .../formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc | 1 + .../formats/format_transfers/format_transfer_transpose.cc | 1 + 12 files changed, 14 insertions(+), 1 deletion(-) diff --git a/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc b/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc index ce271c6d..aae95584 100644 --- a/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc +++ b/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc @@ -123,6 +123,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc b/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc index 798ec55a..4f597e32 100755 --- a/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc @@ -59,7 +59,7 @@ bool CheckShape(Format format, const ShapeVector &shape) { return CheckShapeValid(shape, kDimSize4D); default: std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + - " and FORMAT_FRACTAL_NZ is not supported."; + " and FORMAT_FRACTAL_NZ is not supported."; GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); return false; } diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index 38125979..882a2a68 100644 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -226,6 +226,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { auto protected_size = dst_size - offset < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); errno_t ret = EOK; if (need_pad_zero) { ret = memset_s(dst.get() + offset, static_cast(protected_size), 0, static_cast(size)); @@ -390,6 +391,7 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); errno_t ret = EOK; if (pad_zero) { @@ -474,6 +476,7 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); errno_t ret = EOK; if (pad_zero) { diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc b/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc index f6af7534..abe6263b 100755 --- a/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc +++ b/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc @@ -128,6 +128,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto dst_offset = dst_idx * size; auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc b/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc index aaeca490..58073397 100755 --- a/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc +++ b/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc @@ -130,6 +130,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto dst_offset = dst_idx * size; auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc b/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc index 1e71ea09..3122f137 100755 --- a/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc +++ b/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc @@ -128,6 +128,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size auto dst_offset = dst_idx * size; auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc index cb7f889b..c597cde0 100755 --- a/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc @@ -149,6 +149,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); int64_t c_idx = c0_idx + c1_idx * c0; int64_t src_idx = h_idx * wcn + w_idx * cn + c_idx * n + n_idx; auto src_offset = src_idx * size; diff --git a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc index 09ff45d9..c442bee9 100755 --- a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc +++ b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc @@ -129,6 +129,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc index e9e41cd1..603ddffa 100755 --- a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc +++ b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc @@ -129,6 +129,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc b/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc index ea2b1d7f..5cab311d 100755 --- a/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc +++ b/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc @@ -144,6 +144,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); int64_t cIdx = c0_idx + c1_idx * c0; int64_t srcIdx = n_idx * chw + cIdx * hw + h_idx * w + w_idx; auto src_offset = srcIdx * size; diff --git a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc index 518790b6..939c967c 100755 --- a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc +++ b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc @@ -149,6 +149,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); int64_t c_idx = c0_idx + c1_idx * c0; int64_t src_idx = n_idx * hwc + h_idx * wc + w_idx * c + c_idx; auto src_offset = src_idx * size; diff --git a/ge/common/formats/format_transfers/format_transfer_transpose.cc b/ge/common/formats/format_transfers/format_transfer_transpose.cc index 54c5444b..9a4d3fd6 100755 --- a/ge/common/formats/format_transfers/format_transfer_transpose.cc +++ b/ge/common/formats/format_transfers/format_transfer_transpose.cc @@ -171,6 +171,7 @@ Status Transpose(const uint8_t *src, const std::vector &src_shape, Data auto protected_size = dst_size - dst_offset_bytes < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - dst_offset_bytes : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset_bytes, static_cast(protected_size), src + src_offset, static_cast(data_size)); if (ret != EOK) {