| @@ -29,9 +29,8 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace formats { | namespace formats { | ||||
| namespace { | namespace { | ||||
| constexpr int64_t kCubeN = 16; | |||||
| constexpr int64_t kDim = 1; | constexpr int64_t kDim = 1; | ||||
| constexpr int64_t kCubeN = 16; | |||||
| static int64_t Measure(int64_t x, int64_t y) { | static int64_t Measure(int64_t x, int64_t y) { | ||||
| int64_t z = y; | int64_t z = y; | ||||
| while (x % y != 0) { | while (x % y != 0) { | ||||
| @@ -266,7 +265,7 @@ Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, | |||||
| "groups are %ld %ld %ld",cin_ori, cout_ori, groups); | "groups are %ld %ld %ld",cin_ori, cout_ori, groups); | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| const int64_t cube_k = GetCubeSizeByDataType(data_type); | |||||
| const int64_t cube_k = GetCubeSizeByDataType(args.src_data_type); | |||||
| int64_t e_mult = std::min( | int64_t e_mult = std::min( | ||||
| Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), | Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), | ||||
| groups); | groups); | ||||
| @@ -277,16 +276,18 @@ Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, | |||||
| int64_t dim_cin = cin_opt / cube_k; | int64_t dim_cin = cin_opt / cube_k; | ||||
| int64_t data_size = GetSizeByDataType(args.src_data_type); | int64_t data_size = GetSizeByDataType(args.src_data_type); | ||||
| int64_t size_output_data = g_dim * kDim * dim_cin * h_dim * w_dim * cout_opt * cube_k * data_size; | int64_t size_output_data = g_dim * kDim * dim_cin * h_dim * w_dim * cout_opt * cube_k * data_size; | ||||
| GE_CHK_BOOL_EXEC_NOLOG(size_output_data != 0, result.length = static_cast<size_t>(size_output_data); | |||||
| return SUCCESS;); | |||||
| if(size_output_data == 0){ | |||||
| result.length = static_cast<size_t>(size_output_data); | |||||
| return SUCCESS; | |||||
| } | |||||
| errno_t ret = EOK; | errno_t ret = EOK; | ||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[size_output_data], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[size_output_data], std::default_delete<uint8_t[]>()); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| dst == nullptr, | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), size_output_data); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION;); | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), size_output_data); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| ret = memset_s(dst.get(), static_cast<size_t>(size_output_data), 0, static_cast<size_t>(size_output_data)); | ret = memset_s(dst.get(), static_cast<size_t>(size_output_data), 0, static_cast<size_t>(size_output_data)); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory, ret is %d", ret); | GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory, ret is %d", ret); | ||||