|
|
@@ -44,9 +44,11 @@ constexpr char const *kKeyAttr = "attr"; |
|
|
|
constexpr char const *kKeyName = "name"; |
|
|
|
constexpr char const *kKeyType = "type"; |
|
|
|
constexpr char const *kKeyShape = "shape"; |
|
|
|
constexpr char const *kKeyOriginShape = "origin_shape"; |
|
|
|
constexpr char const *kKeyShapeRange = "shape_range"; |
|
|
|
constexpr char const *kKeyValue = "value"; |
|
|
|
constexpr char const *kKeyFormat = "format"; |
|
|
|
constexpr char const *kKeyOriginFormat = "origin_format"; |
|
|
|
constexpr char const *kFileSuffix = ".om"; |
|
|
|
constexpr char const *kKeyDynamicInput = "dynamic_input"; |
|
|
|
constexpr char const *kKeyDynamicOutput = "dynamic_output"; |
|
|
@@ -90,9 +92,42 @@ map<string, Format> kFormatDict = { |
|
|
|
{"nchw", FORMAT_NCHW}, |
|
|
|
{"nhwc", FORMAT_NHWC}, |
|
|
|
{"nd", FORMAT_ND}, |
|
|
|
{"fractal_nz", FORMAT_FRACTAL_NZ}, |
|
|
|
{"fractal_z", FORMAT_FRACTAL_Z}, |
|
|
|
{"nc1hwc0", FORMAT_NC1HWC0}, |
|
|
|
{"fractal_z", FORMAT_FRACTAL_Z}, |
|
|
|
{"nc1c0hwpad", FORMAT_NC1C0HWPAD}, |
|
|
|
{"nhwc1c0", FORMAT_NHWC1C0}, |
|
|
|
{"fsr_nchw", FORMAT_FSR_NCHW}, |
|
|
|
{"fractal_deconv", FORMAT_FRACTAL_DECONV}, |
|
|
|
{"c1hwnc0", FORMAT_C1HWNC0}, |
|
|
|
{"fractal_deconv_transpose", FORMAT_FRACTAL_DECONV_TRANSPOSE}, |
|
|
|
{"fractal_deconv_sp_stride_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, |
|
|
|
{"nc1hwc0_c04", FORMAT_NC1HWC0_C04}, |
|
|
|
{"fractal_z_c04", FORMAT_FRACTAL_Z_C04}, |
|
|
|
{"chwn", FORMAT_CHWN}, |
|
|
|
{"deconv_sp_stride8_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, |
|
|
|
{"nc1khkwhwc0", FORMAT_NC1KHKWHWC0}, |
|
|
|
{"bn_weight", FORMAT_BN_WEIGHT}, |
|
|
|
{"filter_hwck", FORMAT_FILTER_HWCK}, |
|
|
|
{"hwcn", FORMAT_HWCN}, |
|
|
|
{"lookup_lookups", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, |
|
|
|
{"lookup_keys", FORMAT_HASHTABLE_LOOKUP_KEYS}, |
|
|
|
{"lookup_value", FORMAT_HASHTABLE_LOOKUP_VALUE}, |
|
|
|
{"lookup_output", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, |
|
|
|
{"lookup_hits", FORMAT_HASHTABLE_LOOKUP_HITS}, |
|
|
|
{"md", FORMAT_MD}, |
|
|
|
{"c1hwncoc0", FORMAT_C1HWNCoC0}, |
|
|
|
{"fractal_nz", FORMAT_FRACTAL_NZ}, |
|
|
|
{"ndhwc", FORMAT_NDHWC}, |
|
|
|
{"ncdhw", FORMAT_NCDHW}, |
|
|
|
{"dhwcn", FORMAT_DHWCN}, |
|
|
|
{"dhwnc", FORMAT_DHWNC}, |
|
|
|
{"ndc1hwc0", FORMAT_NDC1HWC0}, |
|
|
|
{"fractal_z_3d", FORMAT_FRACTAL_Z_3D}, |
|
|
|
{"fractal_z_3d_transpose", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, |
|
|
|
{"cn", FORMAT_CN}, |
|
|
|
{"nc", FORMAT_NC}, |
|
|
|
{"fractal_zn_lstm", FORMAT_FRACTAL_ZN_LSTM}, |
|
|
|
{"fractal_z_g", FORMAT_FRACTAL_Z_G} |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
@@ -118,10 +153,19 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { |
|
|
|
if (it != j.end()) { |
|
|
|
desc.dim_ranges = j.at(kKeyShapeRange).get<vector<std::vector<int64_t>>>(); |
|
|
|
} |
|
|
|
it = j.find(kKeyOriginShape); |
|
|
|
if (it != j.end()) { |
|
|
|
desc.ori_dims = j.at(kKeyOriginShape).get<vector<int64_t>>(); |
|
|
|
} |
|
|
|
string format_str = j.at(kKeyFormat).get<string>(); |
|
|
|
string type_str = j.at(kKeyType).get<string>(); |
|
|
|
desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); |
|
|
|
desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); |
|
|
|
it = j.find(kKeyOriginFormat); |
|
|
|
if (it != j.end()) { |
|
|
|
string origin_format_str = j.at(kKeyOriginFormat).get<string>(); |
|
|
|
desc.ori_format = GetValue(kFormatDict, origin_format_str, FORMAT_RESERVED); |
|
|
|
} |
|
|
|
auto tensor_name = j.find(kKeyName); |
|
|
|
if (tensor_name != j.end()) { |
|
|
|
desc.name = tensor_name->get<string>(); |
|
|
@@ -303,10 +347,7 @@ Status SingleOpParser::ConvertToBuildParam(int index, |
|
|
|
const SingleOpDesc &single_op_desc, |
|
|
|
SingleOpBuildParam &build_param) { |
|
|
|
auto op_desc = CreateOpDesc(single_op_desc.op); |
|
|
|
if (op_desc == nullptr) { |
|
|
|
GELOGE(MEMALLOC_FAILED, "Failed to create instance of opDesc"); |
|
|
|
return MEMALLOC_FAILED; |
|
|
|
} |
|
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
|
|
|
|
|
std::stringstream file_name; |
|
|
|
file_name << index; |
|
|
@@ -319,9 +360,12 @@ Status SingleOpParser::ConvertToBuildParam(int index, |
|
|
|
GeTensorDesc ge_tensor_desc(GeShape(desc.dims), |
|
|
|
desc.format, |
|
|
|
desc.type); |
|
|
|
ge_tensor_desc.SetOriginFormat(desc.format); |
|
|
|
auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format; |
|
|
|
auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims; |
|
|
|
ge_tensor_desc.SetOriginFormat(ori_format_to_set); |
|
|
|
ge_tensor_desc.SetOriginShape(GeShape(ori_dims)); |
|
|
|
GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc)); |
|
|
|
TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size()); |
|
|
|
TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size()); |
|
|
|
TensorUtils::SetInputTensor(ge_tensor_desc, true); |
|
|
|
TensorUtils::SetOutputTensor(ge_tensor_desc, false); |
|
|
|
if (desc.name.empty()) { |
|
|
@@ -341,9 +385,12 @@ Status SingleOpParser::ConvertToBuildParam(int index, |
|
|
|
GeTensorDesc ge_tensor_desc(GeShape(desc.dims), |
|
|
|
desc.format, |
|
|
|
desc.type); |
|
|
|
ge_tensor_desc.SetOriginFormat(desc.format); |
|
|
|
auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format; |
|
|
|
auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims; |
|
|
|
ge_tensor_desc.SetOriginFormat(ori_format_to_set); |
|
|
|
ge_tensor_desc.SetOriginShape(GeShape(ori_dims)); |
|
|
|
GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc)); |
|
|
|
TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size()); |
|
|
|
TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size()); |
|
|
|
TensorUtils::SetInputTensor(ge_tensor_desc, false); |
|
|
|
TensorUtils::SetOutputTensor(ge_tensor_desc, true); |
|
|
|
if (desc.name.empty()) { |
|
|
|