Browse Source

!15594 fix data trans error

From: @liubuyu
Reviewed-by: @zhoufeng54,@jjfeing
Signed-off-by: @jjfeing
pull/15594/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
00ea69108c
4 changed files with 15 additions and 34 deletions
  1. +10
    -7
      mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc
  2. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc
  3. +3
    -24
      mindspore/ccsrc/common/trans.cc
  4. +0
    -2
      mindspore/ccsrc/common/trans.h

+ 10
- 7
mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc View File

@@ -92,10 +92,11 @@ void KernelQueryAll(const CNodePtr &kernel_node,
HostMetadataInfo(kernel_node, kernel_info_list);
}
if (kernel_info_list->empty()) {
MS_EXCEPTION(NotExistsError)
<< "Failed to obtain operator info, Please check whether the operator info is registered, Op full name:"
<< kernel_node->fullname_with_scope() << "Node Type: " << op_name
<< ", Node DebugString: " << kernel_node->DebugString() << "\n trace: " << trace::DumpSourceLines(kernel_node);
MS_EXCEPTION(NotExistsError) << "Can not find any available operator info for op [" << op_name << ", "
<< kernel_node->fullname_with_scope()
<< "]. Node DebugString:" << kernel_node->DebugString()
<< ", maybe the operator can not supported on current platform. \n trace "
<< trace::DumpSourceLines(kernel_node);
}
}

@@ -121,9 +122,11 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
}

if (kernel_info_list->empty()) {
MS_EXCEPTION(NotExistsError)
<< "Failed to obtain operator info. Please check whether the operator info is registered, Op full name:"
<< kernel_node->fullname_with_scope() << ". Node DebugString: " << kernel_node->DebugString();
MS_EXCEPTION(NotExistsError) << "Can not find any available operator info for op ["
<< AnfAlgo::GetCNodeName(kernel_node) << ", " << kernel_node->fullname_with_scope()
<< "]. Node DebugString:" << kernel_node->DebugString()
<< ", maybe the operator can not supported on current platform. \n trace "
<< trace::DumpSourceLines(kernel_node);
}
// check output
FilterInvalidKernelInfo(kernel_node, kernel_info_list);


+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc View File

@@ -537,7 +537,8 @@ string TbeKernelJsonCreator::GetSocVersion() {
}
if (soc_version_env != nullptr) {
if (std::strcmp(soc_version, soc_version_env) != 0) {
MS_LOG(WARNING) << "SocVerison will be change.";
MS_LOG(DEBUG) << "Detected the env SOC_VERSION, so the SocVersion will be changed to " << str_soc_version_env
<< ".";
ret = rtSetSocVersion(soc_version_env);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "SetSocVersion failed, errorno: " << ret;


+ 3
- 24
mindspore/ccsrc/common/trans.cc View File

@@ -186,18 +186,6 @@ bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const
}
}

size_t CubeSizeByType(const TypeId data_type) {
const size_t default_error = 0;
auto dt_size = abstract::TypeIdSize(data_type);
if (dt_size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return default_error;
} else if (dt_size == 1) {
return kCubeSize * 2;
}
return kCubeSize;
}

namespace {
bool CheckDims(const std::vector<size_t> &shape) {
if (shape.size() != kNchwDims) {
@@ -780,12 +768,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];

auto c0 = CubeSizeByType(args.src_data_type);
if (c0 < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
const size_t c0 = 16;
auto c1 = DivCeil(c, c0);
auto hw = h * w;
auto chw = c * hw;
@@ -1109,11 +1092,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
auto c0 = CubeSizeByType(args.src_data_type);
if (c0 < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
size_t c0 = 16;
if (args.device_format == kOpFormat_NC1HWC0_C04) {
c0 = 4;
}
@@ -1412,7 +1391,7 @@ bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
auto w = args.host_shape[4];

auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
auto c0 = CubeSizeByType(args.src_data_type);
const size_t c0 = 16;
auto c1 = DivCeil(c, c0);
auto hw = h * w;
auto dhw = d * hw;


+ 0
- 2
mindspore/ccsrc/common/trans.h View File

@@ -55,8 +55,6 @@ struct FormatArgs {
TypeId src_data_type;
};

size_t CubeSizeByType(const TypeId data_type);

std::vector<size_t> PaddingShape(const std::vector<size_t> &shape, const std::string &format,
const std::string &pad_index = {""});
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::string &padding_axis = {""});


Loading…
Cancel
Save