|
@@ -29,6 +29,39 @@ |
|
|
namespace ge { |
|
|
namespace ge { |
|
|
namespace formats { |
|
|
namespace formats { |
|
|
namespace { |
|
|
namespace { |
|
|
|
|
|
constexpr int64_t kCubeN = 16; |
|
|
|
|
|
constexpr int64_t kDim = 1; |
|
|
|
|
|
|
|
|
|
|
|
static int64_t Measure(int64_t x, int64_t y) { |
|
|
|
|
|
int64_t z = y; |
|
|
|
|
|
while (x % y != 0) { |
|
|
|
|
|
z = x % y; |
|
|
|
|
|
x = y; |
|
|
|
|
|
y = z; |
|
|
|
|
|
} |
|
|
|
|
|
return z; |
|
|
|
|
|
} |
|
|
|
|
|
// least common multiple |
|
|
|
|
|
static int64_t Lcm(int64_t a, int64_t b) { |
|
|
|
|
|
if (b == 0) { |
|
|
|
|
|
return -1; |
|
|
|
|
|
} |
|
|
|
|
|
int64_t temp = (a * b) / (Measure(a, b)); |
|
|
|
|
|
return temp; |
|
|
|
|
|
} |
|
|
|
|
|
// get the result of two number divisor and let result round up |
|
|
|
|
|
static int64_t DivCeil(int64_t a, int64_t b) { |
|
|
|
|
|
if (b == 0) { |
|
|
|
|
|
return -1; |
|
|
|
|
|
} else { |
|
|
|
|
|
int64_t ret = a / b; |
|
|
|
|
|
if ((a % b) != 0) { |
|
|
|
|
|
ret++; |
|
|
|
|
|
} |
|
|
|
|
|
return ret; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; } |
|
|
Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; } |
|
|
|
|
|
|
|
|
/** |
|
|
/** |
|
@@ -61,6 +94,35 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_ |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape |
|
|
|
|
|
, int64_t groups) { |
|
|
|
|
|
auto c0 = GetCubeSizeByDataType(data_type); |
|
|
|
|
|
if (c0 < 0) { |
|
|
|
|
|
return ACL_ERROR_GE_DATATYPE_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
int64_t cin_ori = c; |
|
|
|
|
|
int64_t cout_ori = n / groups; |
|
|
|
|
|
int64_t cube_k = data_type == DT_INT8 ? 32 : 16; |
|
|
|
|
|
int64_t e_mult = std::min( |
|
|
|
|
|
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), |
|
|
|
|
|
groups); |
|
|
|
|
|
int64_t cin_opt = DivCeil(e_mult * cin_ori, cube_k) * cube_k; |
|
|
|
|
|
int64_t c1_dim = cin_opt / cube_k; |
|
|
|
|
|
int64_t g_dim = DivCeil(groups, e_mult); |
|
|
|
|
|
auto n1 = DivCeil(cout_ori * e_mult, kCubeN); |
|
|
|
|
|
dst_shape.clear(); |
|
|
|
|
|
dst_shape.push_back(g_dim * c1_dim * h * w); |
|
|
|
|
|
dst_shape.push_back(n1); |
|
|
|
|
|
dst_shape.push_back(16); |
|
|
|
|
|
dst_shape.push_back(cube_k); |
|
|
|
|
|
if (!IsShapeValid(dst_shape)) { |
|
|
|
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", |
|
|
|
|
|
ShapeToString(dst_shape).c_str()); |
|
|
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
Status TransShapeNchwToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { |
|
|
Status TransShapeNchwToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { |
|
|
if (!CheckShapeValid(src_shape, kNchwDimsNum)) { |
|
|
if (!CheckShapeValid(src_shape, kNchwDimsNum)) { |
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
@@ -82,10 +144,24 @@ Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_t |
|
|
auto w = src_shape.at(kHwcnW); |
|
|
auto w = src_shape.at(kHwcnW); |
|
|
auto c = src_shape.at(kHwcnC); |
|
|
auto c = src_shape.at(kHwcnC); |
|
|
auto n = src_shape.at(kHwcnN); |
|
|
auto n = src_shape.at(kHwcnN); |
|
|
|
|
|
|
|
|
return TransShapeToFz(n, c, h, w, data_type, dst_shape); |
|
|
return TransShapeToFz(n, c, h, w, data_type, dst_shape); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape |
|
|
|
|
|
, int64_t groups){ |
|
|
|
|
|
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { |
|
|
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
auto h = src_shape.at(kHwcnH); |
|
|
|
|
|
auto w = src_shape.at(kHwcnW); |
|
|
|
|
|
auto c = src_shape.at(kHwcnC); |
|
|
|
|
|
auto n = src_shape.at(kHwcnN); |
|
|
|
|
|
|
|
|
|
|
|
return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, groups); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Status TransShapeNhwcToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { |
|
|
Status TransShapeNhwcToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { |
|
|
if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { |
|
|
if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { |
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
@@ -127,8 +203,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { |
|
|
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); |
|
|
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); |
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( |
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( |
|
|
dst == nullptr, |
|
|
dst == nullptr, |
|
|
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, |
|
|
|
|
|
"Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", |
|
|
|
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", |
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(), |
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(), |
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); |
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); |
|
|
return ACL_ERROR_GE_MEMORY_ALLOCATION;); |
|
|
return ACL_ERROR_GE_MEMORY_ALLOCATION;); |
|
@@ -174,8 +249,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
if (ret != EOK) { |
|
|
if (ret != EOK) { |
|
|
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, |
|
|
|
|
|
"Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset, |
|
|
|
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset, |
|
|
ret, need_pad_zero); |
|
|
ret, need_pad_zero); |
|
|
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; |
|
|
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; |
|
|
} |
|
|
} |
|
@@ -189,6 +263,85 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, int64_t groups){ |
|
|
|
|
|
int64_t h_dim = args.src_shape[kHwcnH]; |
|
|
|
|
|
int64_t w_dim = args.src_shape[kHwcnW]; |
|
|
|
|
|
int64_t c_dim = args.src_shape[kHwcnC]; |
|
|
|
|
|
int64_t n_dim = args.src_shape[kHwcnN]; |
|
|
|
|
|
int64_t cin_ori = c_dim; |
|
|
|
|
|
int64_t cout_ori = n_dim / groups; |
|
|
|
|
|
if (cin_ori == 0 || cout_ori == 0) { |
|
|
|
|
|
GELOGE(GRAPH_FAILED, |
|
|
|
|
|
"Cin_ori, cout_ori must not be equal 0, " |
|
|
|
|
|
"and current cin_ori, cout_ori, groups are %d %d %d", |
|
|
|
|
|
cin_ori, cout_ori, groups); |
|
|
|
|
|
return GRAPH_FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
const int64_t cube_k = args.src_data_type == DT_INT8 ? 32 : 16; |
|
|
|
|
|
int64_t e_mult = std::min( |
|
|
|
|
|
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), |
|
|
|
|
|
groups); |
|
|
|
|
|
int64_t cin_opt = DivCeil(e_mult * cin_ori, cube_k) * cube_k; |
|
|
|
|
|
int64_t cout_opt = DivCeil(e_mult * cout_ori, kCubeN) * kCubeN; |
|
|
|
|
|
int64_t c1_dim = cin_opt / cube_k; |
|
|
|
|
|
int64_t g_dim = DivCeil(groups, e_mult); |
|
|
|
|
|
int64_t dim_cin = cin_opt / cube_k; |
|
|
|
|
|
int64_t data_size = GetCubeSizeByDataType(args.src_data_type); |
|
|
|
|
|
int64_t size_output_data = |
|
|
|
|
|
g_dim * kDim * dim_cin * h_dim * w_dim * cout_opt * cube_k * data_size; |
|
|
|
|
|
GE_CHK_BOOL_EXEC_NOLOG(size_output_data != 0, result.length = static_cast<size_t>(size_output_data); |
|
|
|
|
|
return SUCCESS;); |
|
|
|
|
|
errno_t ret = EOK; |
|
|
|
|
|
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[size_output_data], std::default_delete<uint8_t[]>()); |
|
|
|
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( |
|
|
|
|
|
dst == nullptr, |
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", |
|
|
|
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(), |
|
|
|
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), size_output_data); |
|
|
|
|
|
return ACL_ERROR_GE_MEMORY_ALLOCATION;); |
|
|
|
|
|
ret = memset_s(dst.get(), size_output_data, 0, size_output_data); |
|
|
|
|
|
if (ret != EOK) { |
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory , error-code %d, ret %d", |
|
|
|
|
|
ret); |
|
|
|
|
|
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
for (int64_t g = 0; g < groups; g++) { |
|
|
|
|
|
for (int64_t d = 0; d < kDim; d++) { |
|
|
|
|
|
for (int64_t c = 0; c < c_dim; c++) { |
|
|
|
|
|
for (int64_t h = 0; h < h_dim; h++) { |
|
|
|
|
|
for (int64_t w = 0; w < w_dim; w++) { |
|
|
|
|
|
for (int64_t n = 0; n < cout_ori; n++) { |
|
|
|
|
|
int64_t e_val = g % e_mult; |
|
|
|
|
|
int64_t dst_ci = e_val * cin_ori + c; |
|
|
|
|
|
int64_t dst_co = e_val * cout_ori + n; |
|
|
|
|
|
int64_t src_co = g * cout_ori + n; |
|
|
|
|
|
int64_t tempory = dst_ci % cube_k; |
|
|
|
|
|
int64_t srx_inx = 0; |
|
|
|
|
|
int64_t dst_inx = |
|
|
|
|
|
(g / e_mult) * kDim * c1_dim * h_dim * w_dim * cout_opt * |
|
|
|
|
|
cube_k + |
|
|
|
|
|
d * c1_dim * h_dim * w_dim * cout_opt * cube_k + |
|
|
|
|
|
(dst_ci / cube_k) * h_dim * w_dim * cout_opt * cube_k + |
|
|
|
|
|
h * w_dim * cout_opt * cube_k + w * cout_opt * cube_k + |
|
|
|
|
|
dst_co * cube_k + tempory; |
|
|
|
|
|
srx_inx = d * h_dim * w_dim * c_dim * n_dim + |
|
|
|
|
|
h * w_dim * c_dim * n_dim + w * c_dim * n_dim + |
|
|
|
|
|
c * n_dim + src_co; |
|
|
|
|
|
char *dst_data = reinterpret_cast<char *>(dst.get() + dst_inx * data_size); |
|
|
|
|
|
const char *src_data = reinterpret_cast<const char *>(args.data + srx_inx * data_size); |
|
|
|
|
|
for (int64_t index = 0; index < data_size; index++) { |
|
|
|
|
|
*dst_data++ = *src_data++; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
result.data = dst; |
|
|
|
|
|
result.length = static_cast<size_t>(size_output_data); |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { |
|
|
Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { |
|
|
int64_t h = args.src_shape[kHwcnH]; |
|
|
int64_t h = args.src_shape[kHwcnH]; |
|
|
int64_t w = args.src_shape[kHwcnW]; |
|
|
int64_t w = args.src_shape[kHwcnW]; |
|
@@ -215,8 +368,7 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { |
|
|
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); |
|
|
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); |
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( |
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( |
|
|
dst == nullptr, |
|
|
dst == nullptr, |
|
|
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, |
|
|
|
|
|
"Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", |
|
|
|
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", |
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(), |
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(), |
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); |
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); |
|
|
return ACL_ERROR_GE_MEMORY_ALLOCATION;); |
|
|
return ACL_ERROR_GE_MEMORY_ALLOCATION;); |
|
@@ -238,8 +390,7 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { |
|
|
static_cast<size_t>(data_size)); |
|
|
static_cast<size_t>(data_size)); |
|
|
} else { |
|
|
} else { |
|
|
if (protected_size < data_size) { |
|
|
if (protected_size < data_size) { |
|
|
GELOGE(ACL_ERROR_GE_PARAM_INVALID, |
|
|
|
|
|
"Failed to operate the dst memory, protected_size is %ld and size is %ld", |
|
|
|
|
|
|
|
|
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to operate the dst memory, protected_size is %ld and size is %ld", |
|
|
protected_size, data_size); |
|
|
protected_size, data_size); |
|
|
return ACL_ERROR_GE_PARAM_INVALID; |
|
|
return ACL_ERROR_GE_PARAM_INVALID; |
|
|
} |
|
|
} |
|
@@ -251,8 +402,7 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
if (ret != EOK) { |
|
|
if (ret != EOK) { |
|
|
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, |
|
|
|
|
|
"Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", |
|
|
|
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", |
|
|
dst_offset, ret, pad_zero); |
|
|
dst_offset, ret, pad_zero); |
|
|
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; |
|
|
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; |
|
|
} |
|
|
} |
|
@@ -293,8 +443,7 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { |
|
|
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); |
|
|
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); |
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( |
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( |
|
|
dst == nullptr, |
|
|
dst == nullptr, |
|
|
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, |
|
|
|
|
|
"Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", |
|
|
|
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", |
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(), |
|
|
TypeUtils::FormatToSerialString(args.src_format).c_str(), |
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); |
|
|
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); |
|
|
return ACL_ERROR_GE_MEMORY_ALLOCATION;); |
|
|
return ACL_ERROR_GE_MEMORY_ALLOCATION;); |
|
@@ -316,8 +465,7 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { |
|
|
static_cast<size_t>(data_size)); |
|
|
static_cast<size_t>(data_size)); |
|
|
} else { |
|
|
} else { |
|
|
if (protected_size < data_size) { |
|
|
if (protected_size < data_size) { |
|
|
GELOGE(ACL_ERROR_GE_PARAM_INVALID, |
|
|
|
|
|
"Failed to operate the dst memory, protected_size is %ld and size is %ld", |
|
|
|
|
|
|
|
|
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to operate the dst memory, protected_size is %ld and size is %ld", |
|
|
protected_size, data_size); |
|
|
protected_size, data_size); |
|
|
return ACL_ERROR_GE_PARAM_INVALID; |
|
|
return ACL_ERROR_GE_PARAM_INVALID; |
|
|
} |
|
|
} |
|
@@ -329,8 +477,7 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
if (ret != EOK) { |
|
|
if (ret != EOK) { |
|
|
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, |
|
|
|
|
|
"Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", |
|
|
|
|
|
|
|
|
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", |
|
|
dst_offset, ret, pad_zero); |
|
|
dst_offset, ret, pad_zero); |
|
|
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; |
|
|
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; |
|
|
} |
|
|
} |
|
@@ -363,15 +510,16 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r |
|
|
if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { |
|
|
if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { |
|
|
return TransFormatNhwcToFz(args, result); |
|
|
return TransFormatNhwcToFz(args, result); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) { |
|
|
|
|
|
|
|
|
if ((args.src_format == FORMAT_HWCN) && (GetPrimaryFormat(args.dst_format) == FORMAT_FRACTAL_Z)) { |
|
|
|
|
|
if (GetSubFormat(args.dst_format) >= 1) { |
|
|
|
|
|
return TransFormatHwcnToFzWithGroups(args, result, GetSubFormat(args.dst_format)); |
|
|
|
|
|
} |
|
|
return TransFormatHwcnToFz(args, result); |
|
|
return TransFormatHwcnToFz(args, result); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { |
|
|
if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { |
|
|
return TransFormatFromNchwToFz(args, result); |
|
|
return TransFormatFromNchwToFz(args, result); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
return ACL_ERROR_GE_FORMAT_INVALID; |
|
|
return ACL_ERROR_GE_FORMAT_INVALID; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@@ -384,7 +532,10 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i |
|
|
if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) { |
|
|
if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) { |
|
|
return TransShapeNhwcToFz(src_shape, data_type, dst_shape); |
|
|
return TransShapeNhwcToFz(src_shape, data_type, dst_shape); |
|
|
} |
|
|
} |
|
|
if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) { |
|
|
|
|
|
|
|
|
if ((src_format == FORMAT_HWCN) && (GetPrimaryFormat(dst_format) == FORMAT_FRACTAL_Z)) { |
|
|
|
|
|
if (GetSubFormat(dst_format) >= 1) { |
|
|
|
|
|
return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, GetSubFormat(dst_format)); |
|
|
|
|
|
} |
|
|
return TransShapeHwcnToFz(src_shape, data_type, dst_shape); |
|
|
return TransShapeHwcnToFz(src_shape, data_type, dst_shape); |
|
|
} |
|
|
} |
|
|
if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) { |
|
|
if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) { |
|
|