From e1a405b030987cf030b0a174ba1e26137fef8ae6 Mon Sep 17 00:00:00 2001 From: wxl Date: Tue, 17 Nov 2020 11:05:42 +0800 Subject: [PATCH] Feature:single op support Huawei private format enter --- ge/offline/single_op_parser.cc | 67 +++++++++++++++++++++++++++++----- ge/offline/single_op_parser.h | 2 + 2 files changed, 59 insertions(+), 10 deletions(-) diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc index 72f742e9..d30e2e8f 100644 --- a/ge/offline/single_op_parser.cc +++ b/ge/offline/single_op_parser.cc @@ -44,9 +44,11 @@ constexpr char const *kKeyAttr = "attr"; constexpr char const *kKeyName = "name"; constexpr char const *kKeyType = "type"; constexpr char const *kKeyShape = "shape"; +constexpr char const *kKeyOriginShape = "origin_shape"; constexpr char const *kKeyShapeRange = "shape_range"; constexpr char const *kKeyValue = "value"; constexpr char const *kKeyFormat = "format"; +constexpr char const *kKeyOriginFormat = "origin_format"; constexpr char const *kFileSuffix = ".om"; constexpr char const *kKeyDynamicInput = "dynamic_input"; constexpr char const *kKeyDynamicOutput = "dynamic_output"; @@ -90,9 +92,42 @@ map kFormatDict = { {"nchw", FORMAT_NCHW}, {"nhwc", FORMAT_NHWC}, {"nd", FORMAT_ND}, - {"fractal_nz", FORMAT_FRACTAL_NZ}, - {"fractal_z", FORMAT_FRACTAL_Z}, {"nc1hwc0", FORMAT_NC1HWC0}, + {"fractal_z", FORMAT_FRACTAL_Z}, + {"nc1c0hwpad", FORMAT_NC1C0HWPAD}, + {"nhwc1c0", FORMAT_NHWC1C0}, + {"fsr_nchw", FORMAT_FSR_NCHW}, + {"fractal_deconv", FORMAT_FRACTAL_DECONV}, + {"c1hwnc0", FORMAT_C1HWNC0}, + {"fractal_deconv_transpose", FORMAT_FRACTAL_DECONV_TRANSPOSE}, + {"fractal_deconv_sp_stride_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, + {"nc1hwc0_c04", FORMAT_NC1HWC0_C04}, + {"fractal_z_c04", FORMAT_FRACTAL_Z_C04}, + {"chwn", FORMAT_CHWN}, + {"deconv_sp_stride8_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, + {"nc1khkwhwc0", FORMAT_NC1KHKWHWC0}, + {"bn_weight", FORMAT_BN_WEIGHT}, + {"filter_hwck", FORMAT_FILTER_HWCK}, + {"hwcn", FORMAT_HWCN}, + {"lookup_lookups", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, + {"lookup_keys", FORMAT_HASHTABLE_LOOKUP_KEYS}, + {"lookup_value", FORMAT_HASHTABLE_LOOKUP_VALUE}, + {"lookup_output", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, + {"lookup_hits", FORMAT_HASHTABLE_LOOKUP_HITS}, + {"md", FORMAT_MD}, + {"c1hwncoc0", FORMAT_C1HWNCoC0}, + {"fractal_nz", FORMAT_FRACTAL_NZ}, + {"ndhwc", FORMAT_NDHWC}, + {"ncdhw", FORMAT_NCDHW}, + {"dhwcn", FORMAT_DHWCN}, + {"dhwnc", FORMAT_DHWNC}, + {"ndc1hwc0", FORMAT_NDC1HWC0}, + {"fractal_z_3d", FORMAT_FRACTAL_Z_3D}, + {"fractal_z_3d_transpose", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, + {"cn", FORMAT_CN}, + {"nc", FORMAT_NC}, + {"fractal_zn_lstm", FORMAT_FRACTAL_ZN_LSTM}, + {"fractal_z_g", FORMAT_FRACTAL_Z_G} }; } @@ -118,10 +153,19 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { if (it != j.end()) { desc.dim_ranges = j.at(kKeyShapeRange).get>>(); } + it = j.find(kKeyOriginShape); + if (it != j.end()) { + desc.ori_dims = j.at(kKeyOriginShape).get>(); + } string format_str = j.at(kKeyFormat).get(); string type_str = j.at(kKeyType).get(); desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); + it = j.find(kKeyOriginFormat); + if (it != j.end()) { + string origin_format_str = j.at(kKeyOriginFormat).get(); + desc.ori_format = GetValue(kFormatDict, origin_format_str, FORMAT_RESERVED); + } auto tensor_name = j.find(kKeyName); if (tensor_name != j.end()) { desc.name = tensor_name->get(); @@ -303,10 +347,7 @@ Status SingleOpParser::ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param) { auto op_desc = CreateOpDesc(single_op_desc.op); - if (op_desc == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to create instance of opDesc"); - return MEMALLOC_FAILED; - } + GE_CHECK_NOTNULL(op_desc); std::stringstream file_name; file_name << index; @@ -319,9 +360,12 @@ Status SingleOpParser::ConvertToBuildParam(int index, GeTensorDesc ge_tensor_desc(GeShape(desc.dims), desc.format, desc.type); - ge_tensor_desc.SetOriginFormat(desc.format); + auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format; + auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims; + ge_tensor_desc.SetOriginFormat(ori_format_to_set); + ge_tensor_desc.SetOriginShape(GeShape(ori_dims)); GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc)); - TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size()); + TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size()); TensorUtils::SetInputTensor(ge_tensor_desc, true); TensorUtils::SetOutputTensor(ge_tensor_desc, false); if (desc.name.empty()) { @@ -341,9 +385,12 @@ Status SingleOpParser::ConvertToBuildParam(int index, GeTensorDesc ge_tensor_desc(GeShape(desc.dims), desc.format, desc.type); - ge_tensor_desc.SetOriginFormat(desc.format); + auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format; + auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims; + ge_tensor_desc.SetOriginFormat(ori_format_to_set); + ge_tensor_desc.SetOriginShape(GeShape(ori_dims)); GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc)); - TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size()); + TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size()); TensorUtils::SetInputTensor(ge_tensor_desc, false); TensorUtils::SetOutputTensor(ge_tensor_desc, true); if (desc.name.empty()) { diff --git a/ge/offline/single_op_parser.h b/ge/offline/single_op_parser.h index 7e30fa4b..19879a32 100644 --- a/ge/offline/single_op_parser.h +++ b/ge/offline/single_op_parser.h @@ -30,8 +30,10 @@ namespace ge { struct SingleOpTensorDesc { std::string name; std::vector dims; + std::vector ori_dims; std::vector> dim_ranges; ge::Format format = ge::FORMAT_RESERVED; + ge::Format ori_format = ge::FORMAT_RESERVED; ge::DataType type = ge::DT_UNDEFINED; std::string dynamic_input_name; };