/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file axis_util.h
 * \brief get the axis value
 */
#ifndef COMMON_UTILS_TRANSFER_AXIS_UTIL_H_
#define COMMON_UTILS_TRANSFER_AXIS_UTIL_H_

/* NOTE(review): the targets of the next three #include directives are missing
 * in this copy of the file. This header uses std::function, std::shared_ptr,
 * std::vector, std::map and the fixed-width integer types, so the standard
 * headers providing those need to be named here - restore before building. */
#include
#include
#include
#include "external/graph/ge_error_codes.h"
#include "external/graph/types.h"
#include "framework/common/debug/ge_log.h"

namespace common {
namespace transformer {

/* Both values are 4. NOTE(review): presumably the default/expected rank of a
 * 4-D shape; the code that consumes them is not in this header - confirm. */
const int32_t DIM_DEFAULT_SIZE = 4;
const uint32_t NCHW_DIMENSION_NUM = 4;

/* Axis position constants, grouped by format. In every group the value of a
 * constant equals the position of that axis letter in the format name
 * (e.g. NCHW: N=0, C=1, H=2, W=3).
 * NOTE(review): presumably used as indices into a dimension vector laid out
 * in the named format - confirm against the implementation file. */

/* NCHW */
const int32_t AXIS_NCHW_DIM_N = 0;
const int32_t AXIS_NCHW_DIM_C = 1;
const int32_t AXIS_NCHW_DIM_H = 2;
const int32_t AXIS_NCHW_DIM_W = 3;

/* NHWC */
const int32_t AXIS_NHWC_DIM_N = 0;
const int32_t AXIS_NHWC_DIM_H = 1;
const int32_t AXIS_NHWC_DIM_W = 2;
const int32_t AXIS_NHWC_DIM_C = 3;

/* NC1HWC0 (C0 is declared out of numeric order but still holds position 4) */
const int32_t AXIS_NC1HWC0_DIM_N = 0;
const int32_t AXIS_NC1HWC0_DIM_C1 = 1;
const int32_t AXIS_NC1HWC0_DIM_C0 = 4;
const int32_t AXIS_NC1HWC0_DIM_H = 2;
const int32_t AXIS_NC1HWC0_DIM_W = 3;

/* HWCN */
const int32_t AXIS_HWCN_DIM_H = 0;
const int32_t AXIS_HWCN_DIM_W = 1;
const int32_t AXIS_HWCN_DIM_C = 2;
const int32_t AXIS_HWCN_DIM_N = 3;

/* C1HWNCoC0 */
const int32_t AXIS_C1HWNCoC0_DIM_C1 = 0;
const int32_t AXIS_C1HWNCoC0_DIM_H = 1;
const int32_t AXIS_C1HWNCoC0_DIM_W = 2;
const int32_t AXIS_C1HWNCoC0_DIM_N = 3;
const int32_t AXIS_C1HWNCoC0_DIM_Co = 4;
const int32_t AXIS_C1HWNCoC0_DIM_C0 = 5;

/* NDHWC (the five-axis groups below carry no AXIS_ prefix) */
const int32_t NDHWC_DIM_N = 0;
const int32_t NDHWC_DIM_D = 1;
const int32_t NDHWC_DIM_H = 2;
const int32_t NDHWC_DIM_W = 3;
const int32_t NDHWC_DIM_C = 4;

/* NCDHW */
const int32_t NCDHW_DIM_N = 0;
const int32_t NCDHW_DIM_C = 1;
const int32_t NCDHW_DIM_D = 2;
const int32_t NCDHW_DIM_H = 3;
const int32_t NCDHW_DIM_W = 4;

/* DHWCN */
const int32_t DHWCN_DIM_D = 0;
const int32_t DHWCN_DIM_H = 1;
const int32_t DHWCN_DIM_W = 2;
const int32_t DHWCN_DIM_C = 3;
const int32_t DHWCN_DIM_N = 4;

/* DHWNC */
const int32_t DHWNC_DIM_D = 0;
const int32_t DHWNC_DIM_H = 1;
const int32_t DHWNC_DIM_W = 2;
const int32_t DHWNC_DIM_N = 3;
const int32_t DHWNC_DIM_C = 4;

/* CHECK_NOTNULL(val): if val compares equal to nullptr, logs the stringified
 * argument through GELOGE with GRAPH_FAILED and executes `return false;` in
 * the calling function. It can therefore only be used inside functions whose
 * return type accepts `false`. The do { } while (0) wrapper lets the macro be
 * used as a single statement followed by a semicolon. */
#define CHECK_NOTNULL(val) \
  do { \
    if ((val) == nullptr) { \
      GELOGE(GRAPH_FAILED, "[ERROR]Parameter[%s] must not be null.", #val); \
      return false; \
    } \
  } while (0)

/* CHECK(cond, log_func, return_expr): when cond evaluates to TRUE, runs
 * log_func and then return_expr. Note the polarity: cond describes the
 * failure condition, not the requirement that must hold. */
#define CHECK(cond, log_func, return_expr) \
  do { \
    if (cond) { \
      log_func; \
      return_expr; \
    } \
  } while (0)

/* Identifiers for the individual axes. Values 0..5 follow the order given in
 * the "Axis value is arranged as {N,C,H,W,C1,C0,...}" note further down.
 * NOTE(review): AXIS_BOTTOM (8) is one past the last real axis and is
 * presumably the element count / end marker - confirm with the callers. */
enum AxisValueType {
  AXIS_N = 0,
  AXIS_C = 1,
  AXIS_H = 2,
  AXIS_W = 3,
  AXIS_C1 = 4,
  AXIS_C0 = 5,
  AXIS_Co = 6,
  AXIS_D = 7,
  AXIS_BOTTOM = 8
};

/* Declared here, defined elsewhere.
 * NOTE(review): by its name this presumably returns dividend / divisor
 * rounded up; the handling of divisor == 0 cannot be seen from this header. */
int64_t DivisionCeiling(int64_t dividend, int64_t divisor);

/* Axis value is arranged as {N,C,H,W,C1,C0,...} */
/* The first parameter is old shape's dimension,
 * second is c0 and third is axis value.
 * NOTE(review): the signature also takes a fourth parameter, which the
 * AxisUtil helpers name ndValue / nd_value.
*/
/* NOTE(review): the template argument lists are missing in this copy of the
 * file - the std::function signature below is truncated, and std::shared_ptr,
 * std::vector and std::map appear without arguments. They must be restored
 * from the authoritative source; the element and key types cannot be read off
 * this text. */
using GetAxisValueInfoByFormat = std::function&, const uint32_t&, std::vector&, std::vector&)>;
using GetAxisValueInfoByFormatPtr = std::shared_ptr;

/* Declares the per-format "get axis value" helpers together with a map member
 * that holds them. Only declarations are visible here; all behaviour lives in
 * the implementation file. */
class AxisUtil {
 public:
  /* Defined out of line. NOTE(review): presumably fills getAxisValueFuncMap -
   * confirm in the implementation file. */
  AxisUtil();
  /* Empty destructor; the `;` after the body is a redundant empty declaration. */
  ~AxisUtil(){};

  /* Takes the original format, the dimension vector and c0 as inputs and
   * axisValue / ndValue as non-const reference (output) parameters; returns
   * bool. NOTE(review): presumably looks up `format` in getAxisValueFuncMap
   * and forwards to the matching static helper - confirm. */
  bool GetAxisValueByOriginFormat(const ge::Format& format, const std::vector& dimVec, const uint32_t& c0,
                                  std::vector& axisValue, std::vector& ndValue);

  /* NOTE(review): presumably reports whether a helper is registered for
   * `format`; it is not declared const although the name suggests a pure
   * query - confirm before relying on that. */
  bool HasAxisValueFunc(const ge::Format& format);

 private:
  /* Every helper below has the same four-parameter shape as the
   * GetAxisValueInfoByFormat alias: input dimension vector, c0, and two
   * non-const reference vectors; each returns bool.
   * NOTE(review): presumably CheckParams validates the arguments for the
   * others - confirm in the implementation file. */
  static bool CheckParams(const std::vector& originalDimVec, const uint32_t& c0, std::vector& axisValue,
                          std::vector& ndValue);

  /* One helper per original format named in the function suffix. */
  static bool GetAxisValueByNCHW(const std::vector& originalDimVec, const uint32_t& c0, std::vector& axisValue,
                                 std::vector& ndValue);
  static bool GetAxisValueByNHWC(const std::vector& originalDimVec, const uint32_t& c0, std::vector& axisValue,
                                 std::vector& ndValue);
  static bool GetAxisValueByNC1HWC0(const std::vector& originalDimVec, const uint32_t& c0, std::vector& axisValue,
                                    std::vector& ndValue);
  static bool GetAxisValueByFz(const std::vector& originalDimVec, const uint32_t& c0, std::vector& axisValue,
                               std::vector& ndValue);
  static bool GetAxisValueByHWCN(const std::vector& originalDimVec, const uint32_t& c0, std::vector& axisValue,
                                 std::vector& ndValue);
  static bool GetAxisValueByND(const std::vector& originalDimVec, const uint32_t& c0, std::vector& axisValue,
                               std::vector& ndValue);
  static bool GetAxisValueByC1HWNCoC0(const std::vector& originalDimVec, const uint32_t& c0, std::vector& axisValue,
                                      std::vector& ndValue);

  /* Helpers for the five-axis formats. Their parameter names use snake_case
   * while the helpers above use camelCase; the parameter types and order are
   * the same. */
  static bool GetAxisValueByNDHWC(const std::vector& original_dim_vec, const uint32_t& c0, std::vector& axis_value,
                                  std::vector& nd_value);
  static bool GetAxisValueByNCDHW(const std::vector& original_dim_vec, const uint32_t& c0, std::vector& axis_value,
                                  std::vector& nd_value);
  static bool GetAxisValueByDHWCN(const std::vector& original_dim_vec, const uint32_t& c0, std::vector& axis_value,
                                  std::vector& nd_value);
  static bool GetAxisValueByDHWNC(const std::vector& original_dim_vec, const uint32_t& c0, std::vector& axis_value,
                                  std::vector& nd_value);

  /* map of GetAxisValueInfoByFormat, get axis value by different original
   * formats.
   * NOTE(review): the key/value template arguments are missing in this copy. */
  std::map getAxisValueFuncMap;
};
}  // namespace transformer
}  // namespace common
#endif  // COMMON_UTILS_TRANSFER_AXIS_UTIL_H_