|
|
|
@@ -186,18 +186,6 @@ bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
size_t CubeSizeByType(const TypeId data_type) { |
|
|
|
const size_t default_error = 0; |
|
|
|
auto dt_size = abstract::TypeIdSize(data_type); |
|
|
|
if (dt_size < 1) { |
|
|
|
MS_LOG(ERROR) << "Illegal dtype."; |
|
|
|
return default_error; |
|
|
|
} else if (dt_size == 1) { |
|
|
|
return kCubeSize * 2; |
|
|
|
} |
|
|
|
return kCubeSize; |
|
|
|
} |
|
|
|
|
|
|
|
namespace { |
|
|
|
bool CheckDims(const std::vector<size_t> &shape) { |
|
|
|
if (shape.size() != kNchwDims) { |
|
|
|
@@ -780,12 +768,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) { |
|
|
|
auto c = args.host_shape[kC]; |
|
|
|
auto h = args.host_shape[kH]; |
|
|
|
auto w = args.host_shape[kW]; |
|
|
|
|
|
|
|
auto c0 = CubeSizeByType(args.src_data_type); |
|
|
|
if (c0 < 1) { |
|
|
|
MS_LOG(ERROR) << "Illegal dtype."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
const size_t c0 = 16; |
|
|
|
auto c1 = DivCeil(c, c0); |
|
|
|
auto hw = h * w; |
|
|
|
auto chw = c * hw; |
|
|
|
@@ -1109,11 +1092,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) { |
|
|
|
auto c = args.host_shape[kC]; |
|
|
|
auto h = args.host_shape[kH]; |
|
|
|
auto w = args.host_shape[kW]; |
|
|
|
auto c0 = CubeSizeByType(args.src_data_type); |
|
|
|
if (c0 < 1) { |
|
|
|
MS_LOG(ERROR) << "Illegal dtype."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
size_t c0 = 16; |
|
|
|
if (args.device_format == kOpFormat_NC1HWC0_C04) { |
|
|
|
c0 = 4; |
|
|
|
} |
|
|
|
@@ -1412,7 +1391,7 @@ bool NcdhwToFracZ3D(const FormatArgs &args, void *result) { |
|
|
|
auto w = args.host_shape[4]; |
|
|
|
|
|
|
|
auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize; |
|
|
|
auto c0 = CubeSizeByType(args.src_data_type); |
|
|
|
const size_t c0 = 16; |
|
|
|
auto c1 = DivCeil(c, c0); |
|
|
|
auto hw = h * w; |
|
|
|
auto dhw = d * hw; |
|
|
|
|