Browse Source

Transdata

tags/v1.3.0
zk 4 years ago
parent
commit
459c9a4ab2
2 changed files with 233 additions and 21 deletions
  1. +172
    -21
      ge/common/formats/format_transfers/format_transfer_fractal_z.cc
  2. +61
    -0
      tests/ut/ge/common/format_transfer_hwcn_fractalz_unittest.cc

+ 172
- 21
ge/common/formats/format_transfers/format_transfer_fractal_z.cc View File

@@ -29,6 +29,39 @@
namespace ge { namespace ge {
namespace formats { namespace formats {
namespace { namespace {
constexpr int64_t kCubeN = 16;
constexpr int64_t kDim = 1;

static int64_t Measure(int64_t x, int64_t y) {
int64_t z = y;
while (x % y != 0) {
z = x % y;
x = y;
y = z;
}
return z;
}
// least common multiple
static int64_t Lcm(int64_t a, int64_t b) {
if (b == 0) {
return -1;
}
int64_t temp = (a * b) / (Measure(a, b));
return temp;
}
// get the result of two number divisor and let result round up
static int64_t DivCeil(int64_t a, int64_t b) {
if (b == 0) {
return -1;
} else {
int64_t ret = a / b;
if ((a % b) != 0) {
ret++;
}
return ret;
}
}

Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; } Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; }


/** /**
@@ -61,6 +94,35 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_
return SUCCESS; return SUCCESS;
} }


Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape
, int64_t groups) {
auto c0 = GetCubeSizeByDataType(data_type);
if (c0 < 0) {
return ACL_ERROR_GE_DATATYPE_INVALID;
}
int64_t cin_ori = c;
int64_t cout_ori = n / groups;
int64_t cube_k = data_type == DT_INT8 ? 32 : 16;
int64_t e_mult = std::min(
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)),
groups);
int64_t cin_opt = DivCeil(e_mult * cin_ori, cube_k) * cube_k;
int64_t c1_dim = cin_opt / cube_k;
int64_t g_dim = DivCeil(groups, e_mult);
auto n1 = DivCeil(cout_ori * e_mult, kCubeN);
dst_shape.clear();
dst_shape.push_back(g_dim * c1_dim * h * w);
dst_shape.push_back(n1);
dst_shape.push_back(16);
dst_shape.push_back(cube_k);
if (!IsShapeValid(dst_shape)) {
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s",
ShapeToString(dst_shape).c_str());
return ACL_ERROR_GE_SHAPE_INVALID;
}
return SUCCESS;
}

Status TransShapeNchwToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { Status TransShapeNchwToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) {
if (!CheckShapeValid(src_shape, kNchwDimsNum)) { if (!CheckShapeValid(src_shape, kNchwDimsNum)) {
return ACL_ERROR_GE_SHAPE_INVALID; return ACL_ERROR_GE_SHAPE_INVALID;
@@ -82,10 +144,24 @@ Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_t
auto w = src_shape.at(kHwcnW); auto w = src_shape.at(kHwcnW);
auto c = src_shape.at(kHwcnC); auto c = src_shape.at(kHwcnC);
auto n = src_shape.at(kHwcnN); auto n = src_shape.at(kHwcnN);

return TransShapeToFz(n, c, h, w, data_type, dst_shape); return TransShapeToFz(n, c, h, w, data_type, dst_shape);
} }


Status TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape
, int64_t groups){
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) {
return ACL_ERROR_GE_SHAPE_INVALID;
}

auto h = src_shape.at(kHwcnH);
auto w = src_shape.at(kHwcnW);
auto c = src_shape.at(kHwcnC);
auto n = src_shape.at(kHwcnN);

return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, groups);
}


Status TransShapeNhwcToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { Status TransShapeNhwcToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) {
if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { if (!CheckShapeValid(src_shape, kNhwcDimsNum)) {
return ACL_ERROR_GE_SHAPE_INVALID; return ACL_ERROR_GE_SHAPE_INVALID;
@@ -127,8 +203,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr, dst == nullptr,
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION,
"Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
return ACL_ERROR_GE_MEMORY_ALLOCATION;); return ACL_ERROR_GE_MEMORY_ALLOCATION;);
@@ -174,8 +249,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
} }
} }
if (ret != EOK) { if (ret != EOK) {
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED,
"Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset,
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset,
ret, need_pad_zero); ret, need_pad_zero);
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
} }
@@ -189,6 +263,85 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
return SUCCESS; return SUCCESS;
} }


Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, int64_t groups){
int64_t h_dim = args.src_shape[kHwcnH];
int64_t w_dim = args.src_shape[kHwcnW];
int64_t c_dim = args.src_shape[kHwcnC];
int64_t n_dim = args.src_shape[kHwcnN];
int64_t cin_ori = c_dim;
int64_t cout_ori = n_dim / groups;
if (cin_ori == 0 || cout_ori == 0) {
GELOGE(GRAPH_FAILED,
"Cin_ori, cout_ori must not be equal 0, "
"and current cin_ori, cout_ori, groups are %d %d %d",
cin_ori, cout_ori, groups);
return GRAPH_FAILED;
}
const int64_t cube_k = args.src_data_type == DT_INT8 ? 32 : 16;
int64_t e_mult = std::min(
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)),
groups);
int64_t cin_opt = DivCeil(e_mult * cin_ori, cube_k) * cube_k;
int64_t cout_opt = DivCeil(e_mult * cout_ori, kCubeN) * kCubeN;
int64_t c1_dim = cin_opt / cube_k;
int64_t g_dim = DivCeil(groups, e_mult);
int64_t dim_cin = cin_opt / cube_k;
int64_t data_size = GetCubeSizeByDataType(args.src_data_type);
int64_t size_output_data =
g_dim * kDim * dim_cin * h_dim * w_dim * cout_opt * cube_k * data_size;
GE_CHK_BOOL_EXEC_NOLOG(size_output_data != 0, result.length = static_cast<size_t>(size_output_data);
return SUCCESS;);
errno_t ret = EOK;
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[size_output_data], std::default_delete<uint8_t[]>());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr,
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), size_output_data);
return ACL_ERROR_GE_MEMORY_ALLOCATION;);
ret = memset_s(dst.get(), size_output_data, 0, size_output_data);
if (ret != EOK) {
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory , error-code %d, ret %d",
ret);
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
}
for (int64_t g = 0; g < groups; g++) {
for (int64_t d = 0; d < kDim; d++) {
for (int64_t c = 0; c < c_dim; c++) {
for (int64_t h = 0; h < h_dim; h++) {
for (int64_t w = 0; w < w_dim; w++) {
for (int64_t n = 0; n < cout_ori; n++) {
int64_t e_val = g % e_mult;
int64_t dst_ci = e_val * cin_ori + c;
int64_t dst_co = e_val * cout_ori + n;
int64_t src_co = g * cout_ori + n;
int64_t tempory = dst_ci % cube_k;
int64_t srx_inx = 0;
int64_t dst_inx =
(g / e_mult) * kDim * c1_dim * h_dim * w_dim * cout_opt *
cube_k +
d * c1_dim * h_dim * w_dim * cout_opt * cube_k +
(dst_ci / cube_k) * h_dim * w_dim * cout_opt * cube_k +
h * w_dim * cout_opt * cube_k + w * cout_opt * cube_k +
dst_co * cube_k + tempory;
srx_inx = d * h_dim * w_dim * c_dim * n_dim +
h * w_dim * c_dim * n_dim + w * c_dim * n_dim +
c * n_dim + src_co;
char *dst_data = reinterpret_cast<char *>(dst.get() + dst_inx * data_size);
const char *src_data = reinterpret_cast<const char *>(args.data + srx_inx * data_size);
for (int64_t index = 0; index < data_size; index++) {
*dst_data++ = *src_data++;
}
}
}
}
}
}
}
result.data = dst;
result.length = static_cast<size_t>(size_output_data);
return SUCCESS;
}
Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) {
int64_t h = args.src_shape[kHwcnH]; int64_t h = args.src_shape[kHwcnH];
int64_t w = args.src_shape[kHwcnW]; int64_t w = args.src_shape[kHwcnW];
@@ -215,8 +368,7 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) {
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr, dst == nullptr,
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION,
"Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
return ACL_ERROR_GE_MEMORY_ALLOCATION;); return ACL_ERROR_GE_MEMORY_ALLOCATION;);
@@ -238,8 +390,7 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) {
static_cast<size_t>(data_size)); static_cast<size_t>(data_size));
} else { } else {
if (protected_size < data_size) { if (protected_size < data_size) {
GELOGE(ACL_ERROR_GE_PARAM_INVALID,
"Failed to operate the dst memory, protected_size is %ld and size is %ld",
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to operate the dst memory, protected_size is %ld and size is %ld",
protected_size, data_size); protected_size, data_size);
return ACL_ERROR_GE_PARAM_INVALID; return ACL_ERROR_GE_PARAM_INVALID;
} }
@@ -251,8 +402,7 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) {
} }
} }
if (ret != EOK) { if (ret != EOK) {
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED,
"Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d",
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d",
dst_offset, ret, pad_zero); dst_offset, ret, pad_zero);
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
} }
@@ -293,8 +443,7 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) {
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr, dst == nullptr,
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION,
"Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
return ACL_ERROR_GE_MEMORY_ALLOCATION;); return ACL_ERROR_GE_MEMORY_ALLOCATION;);
@@ -316,8 +465,7 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) {
static_cast<size_t>(data_size)); static_cast<size_t>(data_size));
} else { } else {
if (protected_size < data_size) { if (protected_size < data_size) {
GELOGE(ACL_ERROR_GE_PARAM_INVALID,
"Failed to operate the dst memory, protected_size is %ld and size is %ld",
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to operate the dst memory, protected_size is %ld and size is %ld",
protected_size, data_size); protected_size, data_size);
return ACL_ERROR_GE_PARAM_INVALID; return ACL_ERROR_GE_PARAM_INVALID;
} }
@@ -329,8 +477,7 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) {
} }
} }
if (ret != EOK) { if (ret != EOK) {
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED,
"Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d",
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d",
dst_offset, ret, pad_zero); dst_offset, ret, pad_zero);
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
} }
@@ -363,15 +510,16 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatNhwcToFz(args, result); return TransFormatNhwcToFz(args, result);
} }

if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) {
if ((args.src_format == FORMAT_HWCN) && (GetPrimaryFormat(args.dst_format) == FORMAT_FRACTAL_Z)) {
if (GetSubFormat(args.dst_format) >= 1) {
return TransFormatHwcnToFzWithGroups(args, result, GetSubFormat(args.dst_format));
}
return TransFormatHwcnToFz(args, result); return TransFormatHwcnToFz(args, result);
} }


if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatFromNchwToFz(args, result); return TransFormatFromNchwToFz(args, result);
} }

return ACL_ERROR_GE_FORMAT_INVALID; return ACL_ERROR_GE_FORMAT_INVALID;
} }


@@ -384,7 +532,10 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i
if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) { if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) {
return TransShapeNhwcToFz(src_shape, data_type, dst_shape); return TransShapeNhwcToFz(src_shape, data_type, dst_shape);
} }
if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) {
if ((src_format == FORMAT_HWCN) && (GetPrimaryFormat(dst_format) == FORMAT_FRACTAL_Z)) {
if (GetSubFormat(dst_format) >= 1) {
return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, GetSubFormat(dst_format));
}
return TransShapeHwcnToFz(src_shape, data_type, dst_shape); return TransShapeHwcnToFz(src_shape, data_type, dst_shape);
} }
if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) { if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) {


+ 61
- 0
tests/ut/ge/common/format_transfer_hwcn_fractalz_unittest.cc View File

@@ -34427,6 +34427,40 @@ TEST_F(UtestFormatTransferHwcnFz, fp32_2c_2n_pad) {
} }
} }


TEST_F(UtestFormatTransferHwcnFz, fp16_1c_1n_with_groups) {
uint16_t data[1 * 1 * 1 * 2] = {19, 88};
uint16_t ret[1 * 1 * 16 * 16] ={19 , 0, 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 88, 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,
0 , 0 , 0, 0 ,0 , 0, 0, 0 , 0 , 0 , 0, 0, 0 , 0 , 0, 0,};
FormatTransferFractalZ transfer;
ge::Format old_format = FORMAT_FRACTAL_Z;
int32_t groups = 2;
ge::Format new_format = static_cast<ge::Foramt>(ge::GetFormatFromSub(old_format, groups));
TransArgs args{
reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, new_format, std::vector<int64_t>({1, 1, 1, 2}),
std::vector<int64_t>({1, 1, 16, 16}), DT_FLOAT16};

TransResult result;
EXPECT_EQ(transfer.TransFormat(args, result), SUCCESS);
EXPECT_EQ(result.length, sizeof(ret) / sizeof(ret[0]) * 2);
for (int i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
EXPECT_EQ((reinterpret_cast<uint16_t *>(result.data.get()))[i], ret[i]);
}
}

TEST_F(UtestFormatTransferHwcnFz, build_transfer_fp32) { TEST_F(UtestFormatTransferHwcnFz, build_transfer_fp32) {
float data[5 * 5 * 31 * 17]; float data[5 * 5 * 31 * 17];
TransArgs args{ TransArgs args{
@@ -34454,6 +34488,24 @@ TEST_F(UtestFormatTransferHwcnFz, build_transfer_int8) {
EXPECT_NE(transfer, nullptr); EXPECT_NE(transfer, nullptr);
} }


TEST_F(UtestFormatTransferHwcnFz, build_transfer_int8) {
int8_t data[4 * 4 * 3 * 1];
TransArgs args{
reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_FRACTAL_Z, std::vector<int64_t>({4, 4, 3, 1}),
std::vector<int64_t>({16, 1, 16, 32}), DT_INT8};
auto transfer = BuildFormatTransfer(args);
EXPECT_NE(transfer, nullptr);
}

TEST_F(UtestFormatTransferHwcnFz, build_transfer_int8) {
int8_t data[4 * 4 * 3 * 1];
TransArgs args{
reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_FRACTAL_Z, std::vector<int64_t>({4, 4, 3, 1}),
std::vector<int64_t>({16, 1, 16, 32}), DT_INT8};
auto transfer = BuildFormatTransfer(args);
EXPECT_NE(transfer, nullptr);
}

TEST_F(UtestFormatTransferHwcnFz, build_transfer_not_support) { TEST_F(UtestFormatTransferHwcnFz, build_transfer_not_support) {
float data[50 * 2 * 16 * 16]; float data[50 * 2 * 16 * 16];
TransArgs args{ TransArgs args{
@@ -34462,5 +34514,14 @@ TEST_F(UtestFormatTransferHwcnFz, build_transfer_not_support) {
auto transfer = BuildFormatTransfer(args); auto transfer = BuildFormatTransfer(args);
EXPECT_EQ(transfer, nullptr); EXPECT_EQ(transfer, nullptr);
} }

TEST_F(UtestFormatTransferHwcnFz, build_transfer_int8_with_groups) {
int8_t data[4 * 4 * 3 * 1];
TransArgs args{
reinterpret_cast<uint8_t *>(data), FORMAT_HWCN, FORMAT_FRACTAL_Z, std::vector<int64_t>({4, 4, 3, 1}),
std::vector<int64_t>({16, 1, 16, 32}), DT_INT8};
auto transfer = BuildFormatTransfer(args);
EXPECT_NE(transfer, nullptr);
}
} // namespace formats } // namespace formats
} // namespace ge } // namespace ge

Loading…
Cancel
Save