Browse Source

Feature:single op support Huawei private format enter

tags/v1.1.0
wxl 3 years ago
parent
commit
e1a405b030
2 changed files with 59 additions and 10 deletions
  1. +57
    -10
      ge/offline/single_op_parser.cc
  2. +2
    -0
      ge/offline/single_op_parser.h

+ 57
- 10
ge/offline/single_op_parser.cc View File

@@ -44,9 +44,11 @@ constexpr char const *kKeyAttr = "attr";
constexpr char const *kKeyName = "name"; constexpr char const *kKeyName = "name";
constexpr char const *kKeyType = "type"; constexpr char const *kKeyType = "type";
constexpr char const *kKeyShape = "shape"; constexpr char const *kKeyShape = "shape";
constexpr char const *kKeyOriginShape = "origin_shape";
constexpr char const *kKeyShapeRange = "shape_range"; constexpr char const *kKeyShapeRange = "shape_range";
constexpr char const *kKeyValue = "value"; constexpr char const *kKeyValue = "value";
constexpr char const *kKeyFormat = "format"; constexpr char const *kKeyFormat = "format";
constexpr char const *kKeyOriginFormat = "origin_format";
constexpr char const *kFileSuffix = ".om"; constexpr char const *kFileSuffix = ".om";
constexpr char const *kKeyDynamicInput = "dynamic_input"; constexpr char const *kKeyDynamicInput = "dynamic_input";
constexpr char const *kKeyDynamicOutput = "dynamic_output"; constexpr char const *kKeyDynamicOutput = "dynamic_output";
@@ -90,9 +92,42 @@ map<string, Format> kFormatDict = {
{"nchw", FORMAT_NCHW}, {"nchw", FORMAT_NCHW},
{"nhwc", FORMAT_NHWC}, {"nhwc", FORMAT_NHWC},
{"nd", FORMAT_ND}, {"nd", FORMAT_ND},
{"fractal_nz", FORMAT_FRACTAL_NZ},
{"fractal_z", FORMAT_FRACTAL_Z},
{"nc1hwc0", FORMAT_NC1HWC0}, {"nc1hwc0", FORMAT_NC1HWC0},
{"fractal_z", FORMAT_FRACTAL_Z},
{"nc1c0hwpad", FORMAT_NC1C0HWPAD},
{"nhwc1c0", FORMAT_NHWC1C0},
{"fsr_nchw", FORMAT_FSR_NCHW},
{"fractal_deconv", FORMAT_FRACTAL_DECONV},
{"c1hwnc0", FORMAT_C1HWNC0},
{"fractal_deconv_transpose", FORMAT_FRACTAL_DECONV_TRANSPOSE},
{"fractal_deconv_sp_stride_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS},
{"nc1hwc0_c04", FORMAT_NC1HWC0_C04},
{"fractal_z_c04", FORMAT_FRACTAL_Z_C04},
{"chwn", FORMAT_CHWN},
{"deconv_sp_stride8_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS},
{"nc1khkwhwc0", FORMAT_NC1KHKWHWC0},
{"bn_weight", FORMAT_BN_WEIGHT},
{"filter_hwck", FORMAT_FILTER_HWCK},
{"hwcn", FORMAT_HWCN},
{"lookup_lookups", FORMAT_HASHTABLE_LOOKUP_LOOKUPS},
{"lookup_keys", FORMAT_HASHTABLE_LOOKUP_KEYS},
{"lookup_value", FORMAT_HASHTABLE_LOOKUP_VALUE},
{"lookup_output", FORMAT_HASHTABLE_LOOKUP_OUTPUT},
{"lookup_hits", FORMAT_HASHTABLE_LOOKUP_HITS},
{"md", FORMAT_MD},
{"c1hwncoc0", FORMAT_C1HWNCoC0},
{"fractal_nz", FORMAT_FRACTAL_NZ},
{"ndhwc", FORMAT_NDHWC},
{"ncdhw", FORMAT_NCDHW},
{"dhwcn", FORMAT_DHWCN},
{"dhwnc", FORMAT_DHWNC},
{"ndc1hwc0", FORMAT_NDC1HWC0},
{"fractal_z_3d", FORMAT_FRACTAL_Z_3D},
{"fractal_z_3d_transpose", FORMAT_FRACTAL_Z_3D_TRANSPOSE},
{"cn", FORMAT_CN},
{"nc", FORMAT_NC},
{"fractal_zn_lstm", FORMAT_FRACTAL_ZN_LSTM},
{"fractal_z_g", FORMAT_FRACTAL_Z_G}
}; };
} }


@@ -118,10 +153,19 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) {
if (it != j.end()) { if (it != j.end()) {
desc.dim_ranges = j.at(kKeyShapeRange).get<vector<std::vector<int64_t>>>(); desc.dim_ranges = j.at(kKeyShapeRange).get<vector<std::vector<int64_t>>>();
} }
it = j.find(kKeyOriginShape);
if (it != j.end()) {
desc.ori_dims = j.at(kKeyOriginShape).get<vector<int64_t>>();
}
string format_str = j.at(kKeyFormat).get<string>(); string format_str = j.at(kKeyFormat).get<string>();
string type_str = j.at(kKeyType).get<string>(); string type_str = j.at(kKeyType).get<string>();
desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED);
desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED);
it = j.find(kKeyOriginFormat);
if (it != j.end()) {
string origin_format_str = j.at(kKeyOriginFormat).get<string>();
desc.ori_format = GetValue(kFormatDict, origin_format_str, FORMAT_RESERVED);
}
auto tensor_name = j.find(kKeyName); auto tensor_name = j.find(kKeyName);
if (tensor_name != j.end()) { if (tensor_name != j.end()) {
desc.name = tensor_name->get<string>(); desc.name = tensor_name->get<string>();
@@ -303,10 +347,7 @@ Status SingleOpParser::ConvertToBuildParam(int index,
const SingleOpDesc &single_op_desc, const SingleOpDesc &single_op_desc,
SingleOpBuildParam &build_param) { SingleOpBuildParam &build_param) {
auto op_desc = CreateOpDesc(single_op_desc.op); auto op_desc = CreateOpDesc(single_op_desc.op);
if (op_desc == nullptr) {
GELOGE(MEMALLOC_FAILED, "Failed to create instance of opDesc");
return MEMALLOC_FAILED;
}
GE_CHECK_NOTNULL(op_desc);


std::stringstream file_name; std::stringstream file_name;
file_name << index; file_name << index;
@@ -319,9 +360,12 @@ Status SingleOpParser::ConvertToBuildParam(int index,
GeTensorDesc ge_tensor_desc(GeShape(desc.dims), GeTensorDesc ge_tensor_desc(GeShape(desc.dims),
desc.format, desc.format,
desc.type); desc.type);
ge_tensor_desc.SetOriginFormat(desc.format);
auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format;
auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims;
ge_tensor_desc.SetOriginFormat(ori_format_to_set);
ge_tensor_desc.SetOriginShape(GeShape(ori_dims));
GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc)); GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc));
TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size());
TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size());
TensorUtils::SetInputTensor(ge_tensor_desc, true); TensorUtils::SetInputTensor(ge_tensor_desc, true);
TensorUtils::SetOutputTensor(ge_tensor_desc, false); TensorUtils::SetOutputTensor(ge_tensor_desc, false);
if (desc.name.empty()) { if (desc.name.empty()) {
@@ -341,9 +385,12 @@ Status SingleOpParser::ConvertToBuildParam(int index,
GeTensorDesc ge_tensor_desc(GeShape(desc.dims), GeTensorDesc ge_tensor_desc(GeShape(desc.dims),
desc.format, desc.format,
desc.type); desc.type);
ge_tensor_desc.SetOriginFormat(desc.format);
auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format;
auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims;
ge_tensor_desc.SetOriginFormat(ori_format_to_set);
ge_tensor_desc.SetOriginShape(GeShape(ori_dims));
GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc)); GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc));
TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size());
TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size());
TensorUtils::SetInputTensor(ge_tensor_desc, false); TensorUtils::SetInputTensor(ge_tensor_desc, false);
TensorUtils::SetOutputTensor(ge_tensor_desc, true); TensorUtils::SetOutputTensor(ge_tensor_desc, true);
if (desc.name.empty()) { if (desc.name.empty()) {


+ 2
- 0
ge/offline/single_op_parser.h View File

@@ -30,8 +30,10 @@ namespace ge {
struct SingleOpTensorDesc { struct SingleOpTensorDesc {
std::string name; std::string name;
std::vector<int64_t> dims; std::vector<int64_t> dims;
std::vector<int64_t> ori_dims;
std::vector<std::vector<int64_t>> dim_ranges; std::vector<std::vector<int64_t>> dim_ranges;
ge::Format format = ge::FORMAT_RESERVED; ge::Format format = ge::FORMAT_RESERVED;
ge::Format ori_format = ge::FORMAT_RESERVED;
ge::DataType type = ge::DT_UNDEFINED; ge::DataType type = ge::DT_UNDEFINED;
std::string dynamic_input_name; std::string dynamic_input_name;
}; };


Loading…
Cancel
Save