| @@ -94,6 +94,7 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | return ACL_ERROR_GE_SHAPE_INVALID; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| break; | |||||
| default: | default: | ||||
| auto size = src_shape.size(); | auto size = src_shape.size(); | ||||
| int64_t times = 1; | int64_t times = 1; | ||||
| @@ -116,6 +117,7 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | return ACL_ERROR_GE_SHAPE_INVALID; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| break; | |||||
| } | } | ||||
| } | } | ||||
| @@ -49,19 +49,24 @@ const size_t kFZzDimCountBackwardsW0H0W1H1 = 4; | |||||
| bool IsDataTypeSupport(DataType d_type) { return GetSizeByDataType(d_type) > 0; } | bool IsDataTypeSupport(DataType d_type) { return GetSizeByDataType(d_type) > 0; } | ||||
| using ShapeVector = std::vector<int64_t>; | using ShapeVector = std::vector<int64_t>; | ||||
| bool ret1 = false; | |||||
| bool CheckShape(Format format, const ShapeVector &shape) { | bool CheckShape(Format format, const ShapeVector &shape) { | ||||
| switch (format) { | switch (format) { | ||||
| case FORMAT_ND: | case FORMAT_ND: | ||||
| return IsShapeValid(shape); | |||||
| ret1 = IsShapeValid(shape); | |||||
| break; | |||||
| case FORMAT_NCHW: | case FORMAT_NCHW: | ||||
| case FORMAT_NHWC: | case FORMAT_NHWC: | ||||
| return CheckShapeValid(shape, kDimSize4D); | |||||
| ret1 = CheckShapeValid(shape, kDimSize4D); | |||||
| break; | |||||
| default: | default: | ||||
| std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | ||||
| " and FORMAT_FRACTAL_ZZ is not supported."; | " and FORMAT_FRACTAL_ZZ is not supported."; | ||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | ||||
| return false; | |||||
| ret1 = false; | |||||
| break; | |||||
| } | } | ||||
| return ret1; | |||||
| } | } | ||||
| /** | /** | ||||
| @@ -76,6 +81,7 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap | |||||
| hw_shape.clear(); | hw_shape.clear(); | ||||
| auto w0 = GetCubeSizeByDataType(data_type); | auto w0 = GetCubeSizeByDataType(data_type); | ||||
| auto h0 = GetCubeSizeByDataType(data_type); | auto h0 = GetCubeSizeByDataType(data_type); | ||||
| auto ret2 = SUCCESS; | |||||
| switch (src_shape.size()) { | switch (src_shape.size()) { | ||||
| case kSingleDim: | case kSingleDim: | ||||
| dst_shape.push_back(DIM_DEFAULT_VALUE); | dst_shape.push_back(DIM_DEFAULT_VALUE); | ||||
| @@ -90,9 +96,10 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap | |||||
| ShapeToString(dst_shape).c_str()); | ShapeToString(dst_shape).c_str()); | ||||
| REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s", | REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s", | ||||
| ShapeToString(dst_shape).c_str()); | ShapeToString(dst_shape).c_str()); | ||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| ret2 = ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | } | ||||
| return SUCCESS; | |||||
| ret2 = SUCCESS; | |||||
| break; | |||||
| default: | default: | ||||
| auto size = src_shape.size(); | auto size = src_shape.size(); | ||||
| int64_t times = 1; | int64_t times = 1; | ||||
| @@ -112,10 +119,12 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap | |||||
| ShapeToString(dst_shape).c_str()); | ShapeToString(dst_shape).c_str()); | ||||
| REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s", | REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s", | ||||
| ShapeToString(dst_shape).c_str()); | ShapeToString(dst_shape).c_str()); | ||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| ret2 = ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | } | ||||
| return SUCCESS; | |||||
| ret2 = SUCCESS; | |||||
| break; | |||||
| } | } | ||||
| return ret2; | |||||
| } | } | ||||
| Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | ||||
| @@ -576,10 +576,12 @@ static uint16_t Fp16Mul(uint16_t v_1, uint16_t v_2) { | |||||
| uint16_t s_a, s_b; | uint16_t s_a, s_b; | ||||
| int16_t e_a, e_b; | int16_t e_a, e_b; | ||||
| uint32_t m_a, m_b; | uint32_t m_a, m_b; | ||||
| uint16_t s_ret, m_ret; | |||||
| uint16_t s_ret; | |||||
| uint16_t m_ret; | |||||
| int16_t e_ret; | int16_t e_ret; | ||||
| uint32_t mul_m; | uint32_t mul_m; | ||||
| uint16_t m_a_tmp, m_b_tmp; | |||||
| uint16_t m_a_tmp; | |||||
| uint16_t m_b_tmp; | |||||
| // 1.Extract | // 1.Extract | ||||
| ExtractFp16(v_1, s_a, e_a, m_a_tmp); | ExtractFp16(v_1, s_a, e_a, m_a_tmp); | ||||
| ExtractFp16(v_2, s_b, e_b, m_b_tmp); | ExtractFp16(v_2, s_b, e_b, m_b_tmp); | ||||
| @@ -635,7 +637,8 @@ static uint16_t Fp16Div(uint16_t v_1, uint16_t v_2) { | |||||
| uint16_t ret; | uint16_t ret; | ||||
| if (FP16_IS_ZERO(v_2)) { // result is inf | if (FP16_IS_ZERO(v_2)) { // result is inf | ||||
| // throw "fp16_t division by zero."; | // throw "fp16_t division by zero."; | ||||
| uint16_t s_a, s_b; | |||||
| uint16_t s_a; | |||||
| uint16_t s_b; | |||||
| uint16_t s_ret; | uint16_t s_ret; | ||||
| s_a = FP16_EXTRAC_SIGN(v_1); | s_a = FP16_EXTRAC_SIGN(v_1); | ||||
| s_b = FP16_EXTRAC_SIGN(v_2); | s_b = FP16_EXTRAC_SIGN(v_2); | ||||
| @@ -644,11 +647,15 @@ static uint16_t Fp16Div(uint16_t v_1, uint16_t v_2) { | |||||
| } else if (FP16_IS_ZERO(v_1)) { | } else if (FP16_IS_ZERO(v_1)) { | ||||
| ret = 0u; | ret = 0u; | ||||
| } else { | } else { | ||||
| uint16_t s_a, s_b; | |||||
| int16_t e_a, e_b; | |||||
| uint64_t m_a, m_b; | |||||
| uint16_t s_a; | |||||
| uint16_t s_b; | |||||
| int16_t e_a; | |||||
| int16_t e_b; | |||||
| uint64_t m_a; | |||||
| uint64_t m_b; | |||||
| float m_div; | float m_div; | ||||
| uint16_t m_a_tmp, m_b_tmp; | |||||
| uint16_t m_a_tmp; | |||||
| uint16_t m_b_tmp; | |||||
| // 1.Extract | // 1.Extract | ||||
| ExtractFp16(v_1, s_a, e_a, m_a_tmp); | ExtractFp16(v_1, s_a, e_a, m_a_tmp); | ||||
| ExtractFp16(v_2, s_b, e_b, m_b_tmp); | ExtractFp16(v_2, s_b, e_b, m_b_tmp); | ||||
| @@ -742,9 +749,12 @@ bool fp16_t::operator!=(const fp16_t &fp) const { | |||||
| return result; | return result; | ||||
| } | } | ||||
| bool fp16_t::operator>(const fp16_t &fp) const { | bool fp16_t::operator>(const fp16_t &fp) const { | ||||
| uint16_t s_a, s_b; | |||||
| uint16_t e_a, e_b; | |||||
| uint16_t m_a, m_b; | |||||
| uint16_t s_a; | |||||
| uint16_t s_b; | |||||
| uint16_t e_a; | |||||
| uint16_t e_b; | |||||
| uint16_t m_a; | |||||
| uint16_t m_b; | |||||
| bool result = true; | bool result = true; | ||||
| // 1.Extract | // 1.Extract | ||||
| @@ -823,9 +833,11 @@ fp16_t &fp16_t::operator=(const fp16_t &fp) { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| fp16_t &fp16_t::operator=(const float &f_val) { | fp16_t &fp16_t::operator=(const float &f_val) { | ||||
| uint16_t s_ret, m_ret; | |||||
| uint16_t s_ret; | |||||
| uint16_t m_ret; | |||||
| int16_t e_ret; | int16_t e_ret; | ||||
| uint32_t e_f, m_f; | |||||
| uint32_t e_f; | |||||
| uint32_t m_f; | |||||
| const uint32_t ui32_v = *(reinterpret_cast<const uint32_t *>(&f_val)); // 1:8:23bit sign:exp:man | const uint32_t ui32_v = *(reinterpret_cast<const uint32_t *>(&f_val)); // 1:8:23bit sign:exp:man | ||||
| uint32_t m_len_delta; | uint32_t m_len_delta; | ||||
| @@ -874,7 +886,9 @@ fp16_t &fp16_t::operator=(const float &f_val) { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| fp16_t &fp16_t::operator=(const int8_t &i_val) { | fp16_t &fp16_t::operator=(const int8_t &i_val) { | ||||
| uint16_t s_ret, e_ret, m_ret; | |||||
| uint16_t s_ret; | |||||
| uint16_t e_ret; | |||||
| uint16_t m_ret; | |||||
| s_ret = static_cast<uint16_t>(((static_cast<uint8_t>(i_val)) & 0x80) >> kDim7); | s_ret = static_cast<uint16_t>(((static_cast<uint8_t>(i_val)) & 0x80) >> kDim7); | ||||
| m_ret = static_cast<uint16_t>(((static_cast<uint8_t>(i_val)) & kInt8Max)); | m_ret = static_cast<uint16_t>(((static_cast<uint8_t>(i_val)) & kInt8Max)); | ||||
| @@ -898,7 +912,9 @@ fp16_t &fp16_t::operator=(const int8_t &i_val) { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| fp16_t &fp16_t::operator=(const uint8_t &ui_val) { | fp16_t &fp16_t::operator=(const uint8_t &ui_val) { | ||||
| uint16_t s_ret, e_ret, m_ret; | |||||
| uint16_t s_ret; | |||||
| uint16_t e_ret; | |||||
| uint16_t m_ret; | |||||
| s_ret = 0; | s_ret = 0; | ||||
| e_ret = 0; | e_ret = 0; | ||||
| m_ret = ui_val; | m_ret = ui_val; | ||||
| @@ -345,7 +345,8 @@ Status OpUtils::SetOutputSliceData(void *data, int64_t data_size, int32_t data_t | |||||
| break; | break; | ||||
| default: | default: | ||||
| GELOGW("Unsupported data type: %s", TypeUtils::DataTypeToSerialString(static_cast<DataType>(data_type)).c_str()); | GELOGW("Unsupported data type: %s", TypeUtils::DataTypeToSerialString(static_cast<DataType>(data_type)).c_str()); | ||||
| return PARAM_INVALID; | |||||
| ret = PARAM_INVALID; | |||||
| break; | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -198,6 +198,7 @@ Status HostCpuEngine::PrepareOutputs(const ge::ConstOpDescPtr &op_desc, | |||||
| GELOGW("data type %s not support.", | GELOGW("data type %s not support.", | ||||
| TypeUtils::DataTypeToSerialString(out_desc.GetDataType()).c_str()); | TypeUtils::DataTypeToSerialString(out_desc.GetDataType()).c_str()); | ||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| break; | |||||
| } | } | ||||
| } | } | ||||
| @@ -423,6 +423,7 @@ Status ModelUtils::GetVarAddr(const RuntimeParam &model_param, const ConstOpDesc | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Get mem_type:%d for offset:%ld is unsupported, check invalid", | GELOGE(PARAM_INVALID, "[Check][Param] Get mem_type:%d for offset:%ld is unsupported, check invalid", | ||||
| mem_type, offset); | mem_type, offset); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| break; | |||||
| } | } | ||||
| GE_CHECK_NOTNULL(var_addr); | GE_CHECK_NOTNULL(var_addr); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -247,10 +247,13 @@ MemResource *MemResource::BuildMemResourceFromType(rtMemType_t mem_type) { | |||||
| switch (mem_type) { | switch (mem_type) { | ||||
| case RT_MEMORY_HBM: | case RT_MEMORY_HBM: | ||||
| return new (std::nothrow) HbmMemResource(); | return new (std::nothrow) HbmMemResource(); | ||||
| break; | |||||
| case RT_MEMORY_RDMA_HBM: | case RT_MEMORY_RDMA_HBM: | ||||
| return new (std::nothrow) RdmaMemResource(); | return new (std::nothrow) RdmaMemResource(); | ||||
| break; | |||||
| default: | default: | ||||
| return nullptr; | return nullptr; | ||||
| break; | |||||
| } | } | ||||
| } | } | ||||
| @@ -189,6 +189,7 @@ Status AddKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstGe | |||||
| default: | default: | ||||
| GELOGI("Add kernel data type %s not support.", TypeUtils::DataTypeToSerialString(data_type).c_str()); | GELOGI("Add kernel data type %s not support.", TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| break; | |||||
| } | } | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -124,6 +124,7 @@ Status EmptyKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<Const | |||||
| default: | default: | ||||
| GELOGW("invalid data type: %s", TypeUtils::DataTypeToSerialString(data_type).c_str()); | GELOGW("invalid data type: %s", TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| break; | |||||
| } | } | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -115,6 +115,7 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge | |||||
| default: | default: | ||||
| GELOGW("invalid data type: %s", TypeUtils::DataTypeToSerialString(data_type).c_str()); | GELOGW("invalid data type: %s", TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| break; | |||||
| } | } | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "GenData failed, data_type: %s", TypeUtils::DataTypeToSerialString(data_type).c_str()); | GELOGE(ret, "GenData failed, data_type: %s", TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
| @@ -244,6 +244,7 @@ Status FloorDivKernel::ComputeByDataType(DataType data_type, const std::vector<C | |||||
| default: | default: | ||||
| GELOGW("FloorDivKernel does not support Data type:%s", TypeUtils::DataTypeToSerialString(data_type).c_str()); | GELOGW("FloorDivKernel does not support Data type:%s", TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| break; | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -58,6 +58,7 @@ Status CheckYIsZero(T const &y, DataType &type) { | |||||
| break; | break; | ||||
| default: | default: | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| break; | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -210,6 +210,7 @@ Status GatherV2Kernel::GenData(const int64_t data_num, ConstGeTensorPtr tensor_x | |||||
| default: | default: | ||||
| GELOGI("Only support 4 dims and below but input axis is %ld", axis); | GELOGI("Only support 4 dims and below but input axis is %ld", axis); | ||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| break; | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -269,6 +270,7 @@ Status GatherV2Kernel::Process(int64_t axis, DataType data_type, ConstGeTensorPt | |||||
| default: | default: | ||||
| GELOGI("GatherV2Kernel does not support this Data type:%s", TypeUtils::DataTypeToSerialString(data_type).c_str()); | GELOGI("GatherV2Kernel does not support this Data type:%s", TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| break; | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -96,6 +96,7 @@ Status RsqrtKernel::RsqrtCompute(ConstGeTensorPtr &input_tensor_ptr, GeTensorPtr | |||||
| default: | default: | ||||
| GELOGW("Input data type must be FP16, FP32 and DOUBLE."); | GELOGW("Input data type must be FP16, FP32 and DOUBLE."); | ||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| break; | |||||
| } | } | ||||
| } | } | ||||
| GE_IF_BOOL_EXEC(output_tensor_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_size) != GRAPH_SUCCESS, | GE_IF_BOOL_EXEC(output_tensor_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_size) != GRAPH_SUCCESS, | ||||
| @@ -136,6 +137,7 @@ Status RsqrtKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<Const | |||||
| default: | default: | ||||
| GELOGW("Input data type must be FP16, FP32 and DOUBLE."); | GELOGW("Input data type must be FP16, FP32 and DOUBLE."); | ||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| break; | |||||
| } | } | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGW("Rsqrt folding failed."); | GELOGW("Rsqrt folding failed."); | ||||
| @@ -221,7 +221,9 @@ Status StridedSliceKernel::InitParamWithAttrs(const std::vector<ConstGeTensorPtr | |||||
| // handle new_axis_mask | // handle new_axis_mask | ||||
| ExpandDimsWithNewAxis(begin_tensor, x_dims_num, x_dims); | ExpandDimsWithNewAxis(begin_tensor, x_dims_num, x_dims); | ||||
| vector<int64_t> orig_begin_vec, orig_end_vec, orig_stride_vec; | |||||
| vector<int64_t> orig_begin_vec; | |||||
| vector<int64_t> orig_end_vec; | |||||
| vector<int64_t> orig_stride_vec; | |||||
| GetOriginStrideVec(input, orig_begin_vec, orig_end_vec, orig_stride_vec); | GetOriginStrideVec(input, orig_begin_vec, orig_end_vec, orig_stride_vec); | ||||
| // calculate begin_mask & end_mask by ellipsis_mask | // calculate begin_mask & end_mask by ellipsis_mask | ||||
| ExpandStrideWithEllipsisMask(x_dims_num, x_dims, orig_begin_vec, orig_end_vec, orig_stride_vec); | ExpandStrideWithEllipsisMask(x_dims_num, x_dims, orig_begin_vec, orig_end_vec, orig_stride_vec); | ||||
| @@ -18,7 +18,7 @@ | |||||
| #include "graph/runtime_inference_context.h" | #include "graph/runtime_inference_context.h" | ||||
| #include "graph/load/model_manager/model_manager.h" | #include "graph/load/model_manager/model_manager.h" | ||||
| #include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
| #include "hybrid/executor//worker//shape_inference_engine.h" | |||||
| #include "hybrid/executor/worker/shape_inference_engine.h" | |||||
| #include "common/profiling/profiling_manager.h" | #include "common/profiling/profiling_manager.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -80,6 +80,7 @@ Status RtsNodeTask::GetScalarIndexValue(TaskContext &task_context, uint32_t inde | |||||
| default: { | default: { | ||||
| GELOGE(UNSUPPORTED, "Data type %s not index type.", TypeUtils::DataTypeToSerialString(data_type).c_str()); | GELOGE(UNSUPPORTED, "Data type %s not index type.", TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| break; | |||||
| } | } | ||||
| } | } | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "ge/ge_api_error_codes.h" | #include "ge/ge_api_error_codes.h" | ||||
| #include "graph//types.h" | |||||
| #include "graph/types.h" | |||||
| #include "runtime/mem.h" | #include "runtime/mem.h" | ||||
| namespace ge { | namespace ge { | ||||