| @@ -1,26 +0,0 @@ | |||||
| <!-- Thanks for sending a pull request! Here are some tips for you: | |||||
| If this is your first time, please read our contributor guidelines: https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md | |||||
| --> | |||||
| **What type of PR is this?** | |||||
| > Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line: | |||||
| > | |||||
| > /kind bug | |||||
| > /kind task | |||||
| > /kind feature | |||||
| **What does this PR do / why do we need it**: | |||||
| **Which issue(s) this PR fixes**: | |||||
| <!-- | |||||
| *Automatically closes linked issue when PR is merged. | |||||
| Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`. | |||||
| --> | |||||
| Fixes # | |||||
| **Special notes for your reviewers**: | |||||
| @@ -1,19 +0,0 @@ | |||||
| --- | |||||
| name: RFC | |||||
| about: Use this template for the new feature or enhancement | |||||
| labels: kind/feature or kind/enhancement | |||||
| --- | |||||
| ## Background | |||||
| - Describe the status of the problem you wish to solve | |||||
| - Attach the relevant issue, if there is one | |||||
| ## Introduction | |||||
| - Describe the general solution, design and/or pseudo-code | |||||
| ## Trail | |||||
| | No. | Task Description | Related Issue(URL) | | |||||
| | --- | ---------------- | ------------------ | | |||||
| | 1 | | | | |||||
| | 2 | | | | |||||
| @@ -1,43 +0,0 @@ | |||||
| --- | |||||
| name: Bug Report | |||||
| about: Use this template for reporting a bug | |||||
| labels: kind/bug | |||||
| --- | |||||
| <!-- Thanks for sending an issue! Here are some tips for you: | |||||
| If this is your first time, please read our contributor guidelines: https://github.com/mindspore-ai/mindspore/blob/master/CONTRIBUTING.md | |||||
| --> | |||||
| ## Environment | |||||
| ### Hardware Environment(`Ascend`/`GPU`/`CPU`): | |||||
| > Uncomment only one ` /device <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line: | |||||
| > | |||||
| > `/device ascend`</br> | |||||
| > `/device gpu`</br> | |||||
| > `/device cpu`</br> | |||||
| ### Software Environment: | |||||
| - **MindSpore version (source or binary)**: | |||||
| - **Python version (e.g., Python 3.7.5)**: | |||||
| - **OS platform and distribution (e.g., Linux Ubuntu 16.04)**: | |||||
| - **GCC/Compiler version (if compiled from source)**: | |||||
| ## Describe the current behavior | |||||
| ## Describe the expected behavior | |||||
| ## Steps to reproduce the issue | |||||
| 1. | |||||
| 2. | |||||
| 3. | |||||
| ## Related log / screenshot | |||||
| ## Special notes for this issue | |||||
| @@ -1,19 +0,0 @@ | |||||
| --- | |||||
| name: Task | |||||
| about: Use this template for task tracking | |||||
| labels: kind/task | |||||
| --- | |||||
| ## Task Description | |||||
| ## Task Goal | |||||
| ## Sub Task | |||||
| | No. | Task Description | Issue ID | | |||||
| | --- | ---------------- | -------- | | |||||
| | 1 | | | | |||||
| | 2 | | | | |||||
| @@ -1,24 +0,0 @@ | |||||
| <!-- Thanks for sending a pull request! Here are some tips for you: | |||||
| If this is your first time, please read our contributor guidelines: https://github.com/mindspore-ai/mindspore/blob/master/CONTRIBUTING.md | |||||
| --> | |||||
| **What type of PR is this?** | |||||
| > Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line: | |||||
| > | |||||
| > `/kind bug`</br> | |||||
| > `/kind task`</br> | |||||
| > `/kind feature`</br> | |||||
| **What does this PR do / why do we need it**: | |||||
| **Which issue(s) this PR fixes**: | |||||
| <!-- | |||||
| *Automatically closes linked issue when PR is merged. | |||||
| Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`. | |||||
| --> | |||||
| Fixes # | |||||
| **Special notes for your reviewers**: | |||||
| @@ -1,30 +1,29 @@ | |||||
| # GraphEngine | |||||
| /build | |||||
| /output | |||||
| /prebuilts | |||||
| /cov | |||||
| *.ir | |||||
| *.out | |||||
| # Dynamic libraries | |||||
| # *.so | |||||
| *.dylib | |||||
| # Static libraries | |||||
| *.la | |||||
| *.lai | |||||
| *.a | |||||
| *.lib | |||||
| # Protocol buffers | |||||
| *_pb2.py | |||||
| *.pb.h | |||||
| *.pb.cc | |||||
| # Object files | |||||
| *.o | |||||
| # Editor | |||||
| .vscode | |||||
| .idea/ | |||||
| cmake-build-* | |||||
| # GraphEngine | |||||
| /build | |||||
| /output | |||||
| /prebuilts | |||||
| *.ir | |||||
| *.out | |||||
| # Dynamic libraries | |||||
| # *.so | |||||
| *.dylib | |||||
| # Static libraries | |||||
| *.la | |||||
| *.lai | |||||
| *.a | |||||
| *.lib | |||||
| # Protocol buffers | |||||
| *_pb2.py | |||||
| *.pb.h | |||||
| *.pb.cc | |||||
| # Object files | |||||
| *.o | |||||
| # Editor | |||||
| .vscode | |||||
| .idea/ | |||||
| cmake-build-* | |||||
| @@ -1,8 +0,0 @@ | |||||
| [submodule "parser"] | |||||
| path = parser | |||||
| url = https://gitee.com/ascend/parser.git | |||||
| branch = master | |||||
| [submodule "metadef"] | |||||
| path = metadef | |||||
| url = https://gitee.com/ascend/metadef.git | |||||
| branch = master | |||||
| @@ -1,168 +1,136 @@ | |||||
| # Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| cmake_minimum_required(VERSION 3.14) | cmake_minimum_required(VERSION 3.14) | ||||
| project (GraphEngine[CXX]) | project (GraphEngine[CXX]) | ||||
| set(CMAKE_CXX_STANDARD 14) | |||||
| set(GE_CODE_DIR ${CMAKE_CURRENT_LIST_DIR}) | |||||
| set(CMAKE_SKIP_INSTALL_ALL_DEPENDENCY TRUE) | |||||
| set(LIBRARY_OUTPUT_PATH ${PROJECT_BINARY_DIR}) | |||||
| set(GE_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) | |||||
| set(GE_PROTO_DIR ${GE_SOURCE_DIR}/src) | |||||
| if (NOT BUILD_PATH) | if (NOT BUILD_PATH) | ||||
| set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build") | set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build") | ||||
| endif() | endif() | ||||
| if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | |||||
| set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) | |||||
| else() | |||||
| set(ASCEND_DIR /usr/local/Ascend) | |||||
| # architecture: aarch64 or x86_64 | |||||
| message(STATUS "System architecture: ${CMAKE_HOST_SYSTEM_PROCESSOR}") | |||||
| # system: euleros or ubuntu | |||||
| if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") | |||||
| execute_process( | |||||
| COMMAND bash "-c" "cat /etc/os-release | grep ^ID= | awk -F '=' '{print $2}'" | |||||
| OUTPUT_VARIABLE SYSTEM_TYPE | |||||
| ) | |||||
| MESSAGE(STATUS "System type: ${SYSTEM_TYPE}.") | |||||
| endif() | endif() | ||||
| if(DEFINED ENV{D_PKG_SERVER}) | |||||
| set(GE_PB_PKG $ENV{D_PKG_SERVER}) | |||||
| message("Download packages from DPKG server") | |||||
| elseif(DEFINED ENV{MSLIBS_SERVER}) | |||||
| set(GE_PB_PKG "http://$ENV{MSLIBS_SERVER}:8081") | |||||
| message("Download packages from MSPKG server") | |||||
| endif () | |||||
| set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) | |||||
| set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) | |||||
| set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share) | |||||
| set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) | |||||
| set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) | |||||
| set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) | |||||
| set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) | |||||
| set(ASCEND_MS_RUNTIME_PATH ${ASCEND_RUNTIME_DIR} ${ASCEND_ACL_DIR} ${ASCEND_ATC_DIR}) | |||||
| set(ASCEND_MS_DRIVER_PATH ${ASCEND_DRIVER_DIR} ${ASCEND_DRIVER_COMMON_DIR}) | |||||
| set(ATLAS_RUNTIME_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64) | |||||
| set(ATLAS_ACL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/lib64) | |||||
| set(ATLAS_ATC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/atc/lib64) | |||||
| set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR}) | |||||
| option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) | |||||
| if (ENABLE_OPEN_SRC) | |||||
| set(HI_PYTHON python3) | |||||
| include(cmake/external_libs/protobuf_shared.cmake) | |||||
| include(cmake/external_libs/protobuf_static.cmake) | |||||
| include(cmake/external_libs/protoc.cmake) | |||||
| include(cmake/external_libs/gflags.cmake) | |||||
| include(cmake/external_libs/gtest.cmake) | |||||
| include(cmake/external_libs/securec.cmake) | |||||
| include(cmake/external_libs/json.cmake) | |||||
| include(cmake/FindModule.cmake) | |||||
| include(cmake/intf_pub_linux.cmake) | |||||
| # if D_LINK_PATH is set in environment variables, search libraries in given path | |||||
| if(DEFINED ENV{D_LINK_PATH}) | |||||
| # D_LINK_PATH is set | |||||
| set(GE_LIB_PATH $ENV{D_LINK_PATH}) | |||||
| set(GE_SYS_ARCH "") | |||||
| if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64") | |||||
| # x86 ubuntu | |||||
| set(GE_SYS_ARCH "x86_64") | |||||
| elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64") | |||||
| # arm euleros | |||||
| set(GE_SYS_ARCH "aarch64") | |||||
| else() | |||||
| message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") | |||||
| endif() | |||||
| set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) | |||||
| set(STATIC_ACL_LIB ${GE_LIB_PATH}) | |||||
| find_module(slog libalog.so ${GE_LIB_PATH}) | |||||
| find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) | |||||
| find_module(msprofiler_ext libmsprofiler.a ${GE_LIB_PATH}) | |||||
| find_module(hccl libhccl.so ${GE_LIB_PATH}) | |||||
| find_module(adump_server libadump_server.a ${GE_LIB_PATH}) | |||||
| find_module(runtime libruntime.so ${GE_LIB_PATH}) | |||||
| find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH}) | |||||
| find_module(resource libresource.so ${GE_LIB_PATH}) | |||||
| find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) | |||||
| find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH}) | |||||
| #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | |||||
| elseif(ENABLE_GE_COV OR ENABLE_GE_UT) | |||||
| add_subdirectory(tests) | |||||
| # download json headers, rather than whole repository | |||||
| include(${GE_SOURCE_DIR}/cmake/ge_utils.cmake) | |||||
| include(${GE_SOURCE_DIR}/cmake/external_libs/json.cmake) | |||||
| include(${GE_SOURCE_DIR}/cmake/external_libs/eigen.cmake) | |||||
| include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake) | |||||
| include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake) | |||||
| include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake) | |||||
| set(CMAKE_SKIP_RPATH TRUE) | |||||
| # for CPU/GPU mode, find c_sec and slog from local prebuild | |||||
| if(NOT ENABLE_D AND NOT GE_ONLY) | |||||
| set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}) | |||||
| find_library(c_sec libc_sec.so ${GE_PREBUILD_PATH}) | |||||
| find_library(slog libslog.so ${GE_PREBUILD_PATH}) | |||||
| # if D_LINK_PATH is set in environment variables, search libraries in given path | |||||
| elseif(DEFINED ENV{D_LINK_PATH}) | |||||
| # D_LINK_PATH is set | |||||
| set(GE_LIB_PATH $ENV{D_LINK_PATH}) | |||||
| set(GE_SYS_ARCH "") | |||||
| if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64") | |||||
| # x86 ubuntu | |||||
| set(GE_SYS_ARCH "x86_64") | |||||
| elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64") | |||||
| # arm euleros | |||||
| set(GE_SYS_ARCH "aarch64") | |||||
| else() | else() | ||||
| find_module(slog libalog.so ${ASCEND_ATC_DIR}) | |||||
| find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) | |||||
| if(PLATFORM STREQUAL "train") | |||||
| find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) | |||||
| find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) | |||||
| find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) | |||||
| find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | |||||
| if(PRODUCT STREQUAL "flr3") | |||||
| message(FATAL_ERROR "This platform is not supported in train mode, build terminated") | |||||
| endif() | |||||
| elseif(PLATFORM STREQUAL "inference") | |||||
| find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) | |||||
| find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) | |||||
| find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) | |||||
| find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) | |||||
| if(PRODUCT STREQUAL "flr3") | |||||
| elseif(PRODUCT STREQUAL "flr1") | |||||
| find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | |||||
| elseif(PRODUCT STREQUAL "flr2") | |||||
| # flr2 ascend_hal_stub limsprof ? | |||||
| else() | |||||
| find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) | |||||
| endif() | |||||
| elseif(PLATFORM STREQUAL "all") | |||||
| find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) | |||||
| find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) | |||||
| find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) | |||||
| find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) | |||||
| find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) | |||||
| find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) | |||||
| else() | |||||
| message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") | |||||
| endif() | |||||
| message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") | |||||
| endif() | endif() | ||||
| set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) | |||||
| find_library(c_sec libc_sec.so ${GE_LIB_PATH}) | |||||
| find_library(slog libslog.so ${GE_LIB_PATH}) | |||||
| find_library(mmpa libmmpa.so ${GE_LIB_PATH}) | |||||
| find_library(runtime libruntime.so ${GE_LIB_PATH}) | |||||
| find_library(msprof libmsprof.so ${GE_LIB_PATH}) | |||||
| find_library(register libregister.so ${GE_LIB_PATH}) | |||||
| find_library(hccl libhccl.so ${GE_LIB_PATH}) | |||||
| find_library(cce libcce.so ${GE_LIB_PATH}) | |||||
| find_library(resource libresource.so ${GE_LIB_PATH}) | |||||
| else() | |||||
| # Ascend mode | |||||
| set(HIAI_INSTALLED_DIR /usr/local/HiAI) | |||||
| # set(HIAI_DRIVER_DIR ${HIAI_INSTALLED_DIR}/driver/lib64/common) | |||||
| # set(HIAI_RUNTIME_DIR ${HIAI_INSTALLED_DIR}/fwkacllib/lib64) | |||||
| set(HIAI_DRIVER_DIR ${HIAI_INSTALLED_DIR}/driver/lib64) | |||||
| set(HIAI_RUNTIME_DIR ${HIAI_INSTALLED_DIR}/runtime/lib64) | |||||
| find_library(c_sec libc_sec.so ${HIAI_DRIVER_DIR}) | |||||
| find_library(slog libslog.so ${HIAI_DRIVER_DIR}) | |||||
| find_library(mmpa libmmpa.so ${HIAI_DRIVER_DIR}) | |||||
| find_library(cce libcce.so ${HIAI_RUNTIME_DIR}) | |||||
| find_library(hccl libhccl.so ${HIAI_RUNTIME_DIR}) | |||||
| find_library(runtime libruntime.so ${HIAI_RUNTIME_DIR}) | |||||
| find_library(msprof libmsprof.so ${HIAI_RUNTIME_DIR}) | |||||
| find_library(register libregister.so ${HIAI_RUNTIME_DIR}) | |||||
| find_library(resource libresource.so ${HIAI_RUNTIME_DIR}) | |||||
| endif() | |||||
| set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | |||||
| set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/parser) | |||||
| set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..) | |||||
| add_subdirectory(metadef) | |||||
| add_subdirectory(parser) | |||||
| #add_subdirectory(metadef/graph) | |||||
| #add_subdirectory(metadef/register) | |||||
| elseif (ENABLE_D OR ENABLE_ACL) | |||||
| # compiling with MindSpore | |||||
| include(cmake/external_libs/protobuf_static.cmake) | |||||
| include(cmake/external_libs/protoc.cmake) | |||||
| include(cmake/external_libs/securec.cmake) | |||||
| include(cmake/external_libs/json.cmake) | |||||
| include(cmake/FindModule.cmake) | |||||
| include(cmake/intf_pub_linux.cmake) | |||||
| # common libraries | |||||
| find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) | |||||
| find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) | |||||
| if (ENABLE_D) | |||||
| # training | |||||
| find_module(runtime libruntime.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) | |||||
| find_module(register libregister.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) | |||||
| endif () | |||||
| set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | |||||
| add_subdirectory(metadef) | |||||
| elseif(ENABLE_MS_TESTCASES) | |||||
| include(cmake/external_libs/protobuf_static.cmake) | |||||
| include(cmake/external_libs/protoc.cmake) | |||||
| include(cmake/external_libs/securec.cmake) | |||||
| include(cmake/FindModule.cmake) | |||||
| include(cmake/intf_pub_linux.cmake) | |||||
| # add compile flags | |||||
| include(CheckCXXCompilerFlag) | |||||
| check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11) | |||||
| if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") | |||||
| message("Build in Debug mode") | |||||
| set(CMAKE_C_FLAGS "-O0 -g -Wall -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -pipe -fPIC ${CMAKE_C_FLAGS}") | |||||
| set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -pipe -fPIC ${CMAKE_CXX_FLAGS}") | |||||
| if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") | |||||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -rdynamic") | |||||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic") | |||||
| endif() | |||||
| else() | |||||
| set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -pipe ${CMAKE_C_FLAGS}") | |||||
| set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -pipe ${CMAKE_CXX_FLAGS}") | |||||
| endif () | |||||
| # common libraries | |||||
| find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) | |||||
| find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) | |||||
| # force __FILE__ to show relative path of file, from source directory, as cmake project makes __FILE__ absolute directory | |||||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst ${CMAKE_SOURCE_DIR}/,,$(abspath $<))\"' -Wno-builtin-macro-redefined") | |||||
| # compile libraries from following directories | |||||
| # libgraph is compiled in any situation | |||||
| add_subdirectory(${GE_SOURCE_DIR}/src/common/graph) | |||||
| if(ENABLE_D) | |||||
| # if MindSpore compiles in D mode, compile the following libraries | |||||
| add_subdirectory(${GE_SOURCE_DIR}/src/ge/common) | |||||
| add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_runtime) | |||||
| elseif(GE_ONLY) | |||||
| # standalone GraphEngine compiles all following libraries | |||||
| add_subdirectory(${GE_SOURCE_DIR}/src/ge/common) | |||||
| add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_runtime) | |||||
| add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_local_engine) | |||||
| add_subdirectory(${GE_SOURCE_DIR}/src/ge/graph/build/memory) | |||||
| add_subdirectory(${GE_SOURCE_DIR}/src/ge/) | |||||
| add_subdirectory(${GE_SOURCE_DIR}/src/ge/executor) | |||||
| add_subdirectory(${GE_SOURCE_DIR}/src/ge/client) | |||||
| add_subdirectory(${GE_SOURCE_DIR}/src/ge/plugin/engine) | |||||
| endif() | |||||
| set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | |||||
| add_subdirectory(metadef) | |||||
| else() | |||||
| set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef) | |||||
| set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser) | |||||
| set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..) | |||||
| if (ENABLE_GE_COV OR ENABLE_GE_UT OR ENABLE_GE_ST) | |||||
| add_subdirectory(tests) | |||||
| endif() | endif() | ||||
| add_subdirectory(ge) | |||||
| @@ -1,5 +1,3 @@ | |||||
| [查看中文](./README_CN.md) | |||||
| GraphEngine(GE) is a sub-module of MindSpore connecting the front end and devices which was designed by the researchers and engineers within Huawei Technologies Co.,Ltd. GE is implemented via C++. It takes the graph of front end as its input and a series of graph operations are carried out to adapt the graph to a certain form which can be effectively operated on devices. GE is specifically designed for an efficient operation on Ascend Chips. GE is automatically called without any exposure to the users. GE mainly consists of two parts, i.e. GE API and GE Core. The architecture diagram of GE is illustrated as follows | GraphEngine(GE) is a sub-module of MindSpore connecting the front end and devices which was designed by the researchers and engineers within Huawei Technologies Co.,Ltd. GE is implemented via C++. It takes the graph of front end as its input and a series of graph operations are carried out to adapt the graph to a certain form which can be effectively operated on devices. GE is specifically designed for an efficient operation on Ascend Chips. GE is automatically called without any exposure to the users. GE mainly consists of two parts, i.e. GE API and GE Core. The architecture diagram of GE is illustrated as follows | ||||
|  |  | ||||
| @@ -12,21 +10,25 @@ | |||||
| GE Core acts as the core module of GE and is responsible for graph processing operations. It consists of six parts, i.e. graph preparation, graph partition, graph optimization, graph compilation, graph loading and graph execution. These six parts are performed in series and all together complete the complicated graph processing operations. | GE Core acts as the core module of GE and is responsible for graph processing operations. It consists of six parts, i.e. graph preparation, graph partition, graph optimization, graph compilation, graph loading and graph execution. These six parts are performed in series and all together complete the complicated graph processing operations. | ||||
| - Graph preparation & Whole graph optimization | |||||
| - Graph preparation | |||||
| All the shapes of feature maps and variables in the graph are inferred in this stage for memory allocation later. Some aggregations of operators like allreduce are performed as well. | |||||
| All the shapes of feature maps and variables in the graph are inferred in this stage for memory allocation later. Some aggregations of operators like allreduce are performed as well. Ascend Chips are heterogeneous chips including CPUs and vector calculation units, i.e. AICORE. Each operator in the graph is assigned to a certain operating core according to the costs and supports. These two cores correspond to two different abstract engines in software. | |||||
| - Graph partition | - Graph partition | ||||
| Ascend Chips are heterogeneous chips including CPUs and vector calculation units, i.e. AICORE. Each operator in the graph is assigned to a certain operating core according to the costs and supports. These two cores correspond to two different abstract engines in software. The whole graph is split into several sub-graphs based on the assigned engine in previous stage. Certain operators are added to the sub-graphs as the marks for graph edges. Such a partition enables an efficient optimization, compilation in next stages. | |||||
| The whole graph is split into several sub-graphs based on the assigned engine in previous stage. Certain operators are added to the sub-graphs as the marks for graph edges. Such a partition enables an efficient optimization, compilation in next stages. | |||||
| - Subgraph optimization | |||||
| - Graph optimization | |||||
| Different optimizer interfaces are called due to different engines that each sub-graph belongs to. To thoroughly utilize the calculation ability of the CUBE module in AICORE, a novel data layout format for faster hardware fetch is applied and the transition between the normal 4D format and this special format is performed in this stage. Such an operation guarantees less data handling between RAMs and CUBEs. Certain combination of operators is fused into a single big operator to further reduce the computation costs. This fusion is carried out in this stage as well. | Different optimizer interfaces are called due to different engines that each sub-graph belongs to. To thoroughly utilize the calculation ability of the CUBE module in AICORE, a novel data layout format for faster hardware fetch is applied and the transition between the normal 4D format and this special format is performed in this stage. Such an operation guarantees less data handling between RAMs and CUBEs. Certain combination of operators is fused into a single big operator to further reduce the computation costs. This fusion is carried out in this stage as well. | ||||
| - Graph compilation & Graph loading | |||||
| - Graph compilation | |||||
| This stage can be divided into two parts, i.e. resources allocation and graph compilation. Memory allocation is completed considering memory reuse strategy in resources allocation stage. According to the graph information, the queue, event, stream resources are allocated. Each operator is compiled to a task bound to a certain stream. Tasks on the same stream are performed in series and tasks on different streams can be executed in parallel. This stream partition is completed in this stage. | |||||
| - Graph loading | |||||
| GraphEngine uses real-time operator compilation technology, i.e. the operator executable program is generated in real time according to the network structure. Meanwhile, memory allocation is completed considering memory reuse strategy in resources allocation stage. According to the graph information, the queue, event, stream resources are allocated. Each operator is compiled to a task bound to a certain stream. Tasks on the same stream are performed in series and tasks on different streams can be executed in parallel. In the Graph Loading stage, the operators of graph are assigned to different engines according to the engine information, and the graph is loaded on the devices for running. | |||||
| According to the engine information, the operators of graph are assigned to different engines and in this stage, the graph is loaded on the devices for running. | |||||
| - Graph execution | - Graph execution | ||||
| @@ -44,7 +46,7 @@ | |||||
| ## Installing GraphEngine | ## Installing GraphEngine | ||||
| GE is automatically installed and compiled once you finish installing MindSpore. There are three dynamic link libraries corresponding to GE. | |||||
| GE is automatically installed and compiled once you finish installing MindSpore. A subset of dynamic libraries of GraphEngine can be found inside MindSpore package. | |||||
| ## Installing Using the Source Code | ## Installing Using the Source Code | ||||
| @@ -1,111 +0,0 @@ | |||||
| [View English](./README.md) | |||||
| 图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。 | |||||
|  | |||||
| - GE API | |||||
| GE API是连接前端模块ME和GE Core的接口,负责GE Core中初始化、Session管理模块的接口,支持运行环境初始化,Session创建、销毁,图添加执行。 | |||||
| - GE Core | |||||
| GE Core是GE的核心模块,负责整个训练过程中的图管理。GE Core中的图处理可细分为六大步骤,分别是图准备、图拆分、图优化、图编译、图加载和图执行,对于ME下发的每一张图都会经过这六个步骤的操作,最终得到可以直接在底层硬件上高效执行的图。 | |||||
| - 图准备 & 整图优化 | |||||
| 完成整图级别的数据准备和优化,涉及到IR库及算子库。使用IR库中算子的InferShape函数,完成整图的Shape推导,以便后续申请内存;同时根据算子的聚合属性,完成某些算子的聚合优化,如allreduce算子,会按照聚合参数,将若干各参数对应梯度的allreduce算子聚合为一个,以此减少通讯耗时。 | |||||
| - 图拆分 | |||||
| 昇腾AI处理器是一种异构芯片,含有CPU(AICPU)和向量计算部件AICORE,图中每个算子会按照开销模型选择执行的核心,此阶段会对算子进行最优的核心分配,每种核心对应软件上的一个抽象引擎;按照之前对各算子的引擎分配,以引擎为边界,将整图拆分为若干子图,在图边界算子上插入相应的Placeholder算子以做标识,之后的优化、编译、加载操作均会以子图为单位进行,这样可以有效减少优化过程的耗时。 | |||||
| - 子图优化 | |||||
| 根据子图所属引擎,调用不同的优化器接口执行优化。为了充分发挥昇腾AI处理器中AICORE模块的算力,在AICORE内CUBE单元进行计算的算子会采用一种5D的数据格式,图优化阶段会对相应算子进行4D/5D的类型转换;为了进一步发挥CUBE单元的算力,减少数据搬运次数,GE会对某种范式的算子连接进行融合操作,此步骤也在图优化阶段进行;对所有子图优化之后,需进行算子运行属性计算,以计算输入输出内存大小。 | |||||
| - 图编译 & 图加载 | |||||
| GE采用即时算子编译技术,即按照实际网络结构即时编译生成算子可执行程序,同时完成内存复用与内存分配、流分配、算子可执行程序加载等。每个算子执行任务绑定到特定的流上,同一个流的任务是串行执行的,不同流上的任务可以并行执行。图加载阶段按照引擎归属的runtime,将子图加载到硬件上准备执行。 | |||||
| - 图执行 | |||||
| 最终在硬件上执行子图,并返回相应的输出值。为了提高运行效率,图执行阶段提供了一种下沉模式,可以在底层硬件上连续运行多轮再返回输出值,以此减少从底层硬件拷贝数据的次数。 | |||||
| 在训练/推理过程中,上述过程会自动执行,通过上述图操作,GE可以将前端下发的图转换为一种可以在昇腾AI处理器上高效运行的图模式。 | |||||
| <!-- TOC --> | |||||
| - [安装说明](#安装说明) | |||||
| - [安装GE](#安装ge) | |||||
| - [源码安装](#源码安装) | |||||
| - [社区](#社区) | |||||
| - [贡献](#贡献) | |||||
| - [Release Notes](#release-notes) | |||||
| - [License](#license) | |||||
| <!-- /TOC --> | |||||
| # 安装说明 | |||||
| ## 安装GE | |||||
| GE内嵌在MindSpore安装包中,MindSpore安装完毕后,GE以三个动态库的方式被调用。 | |||||
| ## 源码安装 | |||||
| GE也支持由源码编译,进行源码编译前,首先确保你有昇腾910 AI处理器的环境,同时系统满足以下要求: | |||||
| - GCC >= 7.3.0 | |||||
| - CMake >= 3.14.0 | |||||
| - Autoconf >= 2.64 | |||||
| - Libtool >= 2.4.6 | |||||
| - Automake >= 1.15.1 | |||||
| 编译完成后会生成几个动态库,他们会链接到MindSpore中执行,无法单独运行。 | |||||
| 1. 下载GE源码。 | |||||
| GE源码托管在码云平台,可由此下载。 | |||||
| ``` | |||||
| git clone https://gitee.com/mindspore/graphengine.git | |||||
| cd graphengine | |||||
| ``` | |||||
| 2. 在GE根目录下执行下列命令即可进行编译。 | |||||
| ``` | |||||
| bash build.sh | |||||
| ``` | |||||
| > - 开始编译之前,请确保正确设置相关的环境变量。 | |||||
| > - 在`build.sh`的脚本中,会进行`git clone`操作,请确保网络连接正常且git配置正确。 | |||||
| > - 在`build.sh`的脚本中,默认会8线程编译,如果机器性能较差,可能会编译失败。可以通过`-j{线程数}`来控制线程数,如`bash build.sh -j4`。 | |||||
| 3. 完成编译后,相应的动态库文件会生成在output文件夹中。 | |||||
| 更多指令帮助,可以使用: | |||||
| ``` | |||||
| bash build.sh -h | |||||
| ``` | |||||
| 如果想清除历史编译记录,可以如下操作: | |||||
| ``` | |||||
| rm -rf build/ output/ | |||||
| bash build.sh | |||||
| ``` | |||||
| ## 社区 | |||||
| - [MindSpore Slack](https://join.slack.com/t/mindspore/shared_invite/enQtOTcwMTIxMDI3NjM0LTNkMWM2MzI5NjIyZWU5ZWQ5M2EwMTQ5MWNiYzMxOGM4OWFhZjI4M2E5OGI2YTg3ODU1ODE2Njg1MThiNWI3YmQ) - 可以提问和找答案。 | |||||
| ## 贡献 | |||||
| 欢迎参与贡献,更多信息详见[Contributor Wiki](https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md)。 | |||||
| ## Release Notes | |||||
| Release Notes请参考[RELEASE](RELEASE.md). | |||||
| ## License | |||||
| [Apache License 2.0](LICENSE) | |||||
| @@ -1,123 +1,3 @@ | |||||
| # Release 1.0.0 | |||||
| ## Major Features and Improvements | |||||
| * Automatically dump the input and output of the abnormal operator when the network execution is abnormal; | |||||
| * Realize dynamic multi-batch based on GotoLabel; | |||||
| * Optimize the performance of dynamic shape; | |||||
| * The dynamic resolution feature supports new scene that the network has multiple inputs and the shape of each input is different. | |||||
| ## Bugfixes | |||||
| * Fixed the issue that the input and output data of the AICPU operator cannot be dumped in the single-operator execution scenario. | |||||
| * Fixed the execution failure in the custom AICPU operator cascading scenario. | |||||
| * Fixed the issue that in the dynamic batch+dynamic AIPP scenario, the getinputformat and getinputdims parameters are inconsistent. | |||||
| ## Thanks to our Contributors | |||||
| Thanks goes to these wonderful people: wuweikang,wangcong,weiyang,yanghaorang,xutianchun,shibeiji,zhouchao, tanghuikang, zhoulili, liujunzhu, zhengyuanhua, taoxiangdong Contributions of any kind are welcome! | |||||
| Contributions of any kind are welcome! | |||||
| # Release 0.7.0-beta | |||||
| ## Major Features and Improvements | |||||
| * Conditional operator memory supports separate allocation of 4G memory space; | |||||
| * In the zero-copy scenario, atomic_clean supports cleaning the memory of each part of the output when the network is multi-output; | |||||
| * Support profiling of multiple levels of data in inference scenarios; | |||||
| * In the online compilation scenarios, GE compilation time optimization. | |||||
| ## Bugfixes | |||||
| * Fix the issue that calculation result is wrong when the unknown subgraph contains conditional operations; | |||||
| * Fix the issue that the hccl executor fails to load the task when the input of hccl operator is unknown shape; | |||||
| * Fix the issue that allgather output is wrong when it exists in the unknown subgraph and its input is unknown shape; | |||||
| ## Thanks to our Contributors | |||||
| Thanks goes to these wonderful people: wuweikang, wangcong, weiyang, yanghaorang, xutianchun, shibeiji, zhouchao, tanghuikang, zhoulili, liujunzhu, zhengyuanhua, taoxiangdong. | |||||
| Contributions of any kind are welcome! | |||||
| # Release 0.6.0-beta | |||||
| ## Major Features and Improvements | |||||
| - GE supports function control operators such as If/Case/While/For. | |||||
| - In a single operator call scenario, GE supports recording the correspondence between operators and tasks for performance commissioning. | |||||
| - GE supports new operator overflow positioning solution. | |||||
| ## Bugfixes | |||||
| - Fix the problem that the aclmdlGetCurOutputDims interface failed to query output Dims in dynamic batch scenarios. | |||||
| - Fix the problem that the operator compilation options (advanced and advanced) cannot be selected. | |||||
| - Fix the problem that zero copy function cannot be performed in the scene of converging conditional operators after Data operators. | |||||
| - Fix the problem that the empty graph cannot be handled. | |||||
| ## Thanks to our Contributors | |||||
| Thanks goes to these wonderful people: | |||||
| wangcong,weiyang,yanghaorang,xutianchun,shibeiji,zhouchao, tanghuikang, zhoulili, liujunzhu, zhengyuanhua, taoxiangdong | |||||
| Contributions of any kind are welcome! | |||||
| # Release 0.5.0-beta | |||||
| ## Major Features and Improvements | |||||
| - Optimize Allreduce trailing parallelism, rebuild the calculation graph dependencies, adjust the calculation order, and maximize the efficiency of running calculation and gradient aggregation communication in parallel; the gain is largest in large-data-volume gradient aggregation and low-bandwidth/large-cluster scenarios. | |||||
| - Advance constant folding, variable fusion, conversion operator related optimization pass to the end of the graph preparation. | |||||
| - Modify memory allocation algorithm, optimize GE memory allocation, and reduce memory usage in training multi-PCS scenarios. | |||||
| - Support IR composition, model compilation, inference execution in the same process. | |||||
| ## Bugfixes | |||||
| - Fix the bug that the graphic attribute "output_name_idx_" is not serialized to the GEIR model file, resulting in the failure of the Fast-RCNN network offline inference model generation. | |||||
| - Introduce timestamp in the dump data storage directory, to ensure that the dump file generated is in a different directory each time it is executed. | |||||
| - Reinforce the ParserJsonFile interface to fix the program coredump bug caused by the injection of abnormal json files. | |||||
| - Fix bugs in the Stream binding failure scenario and the stream resource leakage. | |||||
| ## Thanks to our Contributors | |||||
| Thanks goes to these wonderful people: | |||||
| wangcong,weiyang,yanghaorang,xutianchun,shibeiji | |||||
| Contributions of any kind are welcome! | |||||
| # Release 0.3.0-alpha | |||||
| ## Major Features and Improvements | |||||
| - It supports dynamic batches and shapes with certain fixed levels.([!22](https://gitee.com/mindspore/graphengine/pulls/22)) | |||||
| - Scope fusion interfaces are opened allowing user defined scope fusion rules.([!24](https://gitee.com/mindspore/graphengine/pulls/24)) | |||||
| - Enhance the maintenance and measurement capability.([!28](https://gitee.com/mindspore/graphengine/pulls/28)) | |||||
| - A package of compiled libraries is generated after compilation to facilitate code deployment.([!21](https://gitee.com/mindspore/graphengine/pulls/21)) | |||||
| ## Bugfixes | |||||
| - Fix the bug that the interface of GE IR construction operator does not support dynamic input in the middle of the ordinary input port.([!24](https://gitee.com/mindspore/graphengine/pulls/24)) | |||||
| - Fix checkpoint subgraph validation and data callback process to resolve the problem that checkpoint could not be generated in some scenarios.([!28](https://gitee.com/mindspore/graphengine/pulls/28)) | |||||
| - When MindSpore is compiled in a directory involving symbolic links, GE log records real path of which code was built when executing testcases using installed whl package.([!16](https://gitee.com/mindspore/graphengine/pulls/16), [!489](https://gitee.com/mindspore/mindspore/pulls/489)) | |||||
| - Find third-party software in specified directories only.([!18](https://gitee.com/mindspore/graphengine/pulls/18)) | |||||
| ## Thanks to our Contributors | |||||
| Thanks goes to these wonderful people: | |||||
| wangcong,weiyang,yanghaorang,xutianchun,shibeiji | |||||
| Contributions of any kind are welcome! | |||||
| # Release 0.2.0-alpha | |||||
| ## Major Features and Improvements | |||||
| - Provides a common graph-level option, and multiple requirements can also share this mechanism in the future. | |||||
| - Improve graph compilation performance. | |||||
| - Optimize memory allocation. | |||||
| - Optimize several operators e.g., Slice, StridedSlice, ScatterMax etc. | |||||
| ## Bugfixes | |||||
| - Delete redundant code.([#I1EU2Z](https://gitee.com/mindspore/graphengine/issues/I1EU2Z)) | |||||
| - Fix HCCL initialization bugs under train and eval scenarios.([#I1DIBJ](https://gitee.com/mindspore/graphengine/issues/I1DIBJ)) | |||||
| - Optimize compilation and linking process, enhancing efficiency and performance of concurrent compilation of GraphEngine and MindSpore. ([#I1DFIY](https://gitee.com/mindspore/mindspore/issues/I1DFIY)) | |||||
| - Fix the bug that GE checkpoint cannot save variable names correctly.([#I1DIBJ](https://gitee.com/mindspore/graphengine/issues/I1DIBJ)) | |||||
| - Save dump files on every iteration instead of every execution.([#I1DIBJ](https://gitee.com/mindspore/graphengine/issues/I1DIBJ)) | |||||
| ## Thanks to our Contributors | |||||
| Thanks goes to these wonderful people: Wang Cong, Tianchun Xu, Haoran Yang. | |||||
| Contributions of any kind are welcome! | |||||
| # Release 0.1.0-alpha | # Release 0.1.0-alpha | ||||
| This is the initial release of GraphEngine(GE) which was designed by the researchers and engineers in Huawei Technologies Co.,Ltd. GE is implemented via C++ and acts as a powerful backing force for MindSpore. GE is a linked up module between MindSpore front end and Ascend Chips. | This is the initial release of GraphEngine(GE) which was designed by the researchers and engineers in Huawei Technologies Co.,Ltd. GE is implemented via C++ and acts as a powerful backing force for MindSpore. GE is a linked up module between MindSpore front end and Ascend Chips. | ||||
| @@ -458,76 +458,3 @@ Copyright (c) Facebook Inc. and Microsoft Corporation. | |||||
| License: MIT License | License: MIT License | ||||
| Please see above. | Please see above. | ||||
| Software: caffe 1.0 | |||||
| License: BSD 2-Clause License | |||||
| Open Source Software Licensed Under the BSD 2-Clause License | |||||
| GraphEngine uses source code files from caffe so as to support model format conversion from caffe model to GraphEngine model. | |||||
| Please see below for the full list of source code files from caffe that are used by GraphEngine. | |||||
| The below software in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||||
| ---------------------------------------------------------------------------------------- | |||||
| 1. caffe.proto master | |||||
| All contributions by the University of California: | |||||
| Copyright (c) 2014-2017 The Regents of the University of California (Regents) | |||||
| All rights reserved. | |||||
| Terms of the BSD 2-Clause License: | |||||
| -------------------------------------------------------------------- | |||||
| Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: | |||||
| Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. | |||||
| Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. | |||||
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| Software: tensorflow 1.15.0 | |||||
| License: Apache-2.0 License | |||||
| Open Source Software Licensed Under the Apache-2.0 License | |||||
| GraphEngine uses source code files from tensorflow so as to support model format conversion from tensorflow model to GraphEngine model. | |||||
| Please see below for the full list of source code files from tensorflow that are used by GraphEngine. | |||||
| The below software in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||||
| ---------------------------------------------------------------------------------------- | |||||
| 1. attr_value.proto master | |||||
| Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |||||
| 2. function.proto master | |||||
| Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |||||
| 3. graph.proto master | |||||
| Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |||||
| 4. node_def.proto master | |||||
| Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |||||
| 5. op_def.proto master | |||||
| Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |||||
| 6. resource_handle.proto master | |||||
| Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |||||
| 7. tensor.proto master | |||||
| Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |||||
| 8. tensor_shape.proto master | |||||
| Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |||||
| 9. types.proto master | |||||
| Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |||||
| 10. versions.proto master | |||||
| Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |||||
| Terms of the Apache-2.0 License: | |||||
| Please see above. | |||||
| @@ -23,7 +23,7 @@ export BUILD_PATH="${BASEPATH}/build/" | |||||
| usage() | usage() | ||||
| { | { | ||||
| echo "Usage:" | echo "Usage:" | ||||
| echo "sh build.sh [-j[n]] [-h] [-v] [-s] [-t] [-u] [-c] [-S on|off] [-M]" | |||||
| echo "sh build.sh [-j[n]] [-A] [-h] [-v] [-s] [-t] [-u] [-c]" | |||||
| echo "" | echo "" | ||||
| echo "Options:" | echo "Options:" | ||||
| echo " -h Print usage" | echo " -h Print usage" | ||||
| @@ -32,52 +32,35 @@ usage() | |||||
| echo " -j[n] Set the number of threads used for building GraphEngine, default is 8" | echo " -j[n] Set the number of threads used for building GraphEngine, default is 8" | ||||
| echo " -t Build and execute ut" | echo " -t Build and execute ut" | ||||
| echo " -c Build ut with coverage tag" | echo " -c Build ut with coverage tag" | ||||
| echo " -p Build inference or train" | |||||
| echo " -v Display build command" | echo " -v Display build command" | ||||
| echo " -S Enable enable download cmake compile dependency from gitee , default off" | |||||
| echo " -M build MindSpore mode" | |||||
| echo "to be continued ..." | echo "to be continued ..." | ||||
| } | } | ||||
| # check value of input is 'on' or 'off' | |||||
| # usage: check_on_off arg_value arg_name | |||||
| check_on_off() | |||||
| { | |||||
| if [[ "X$1" != "Xon" && "X$1" != "Xoff" ]]; then | |||||
| echo "Invalid value $1 for option -$2" | |||||
| usage | |||||
| exit 1 | |||||
| fi | |||||
| } | |||||
| # parse and set options | |||||
| # parse and set options | |||||
| checkopts() | checkopts() | ||||
| { | { | ||||
| VERBOSE="" | VERBOSE="" | ||||
| THREAD_NUM=8 | THREAD_NUM=8 | ||||
| # ENABLE_GE_UT_ONLY_COMPILE="off" | |||||
| ENABLE_GE_UT_ONLY_COMPILE="off" | |||||
| ENABLE_GE_UT="off" | ENABLE_GE_UT="off" | ||||
| ENABLE_GE_ST="off" | ENABLE_GE_ST="off" | ||||
| ENABLE_GE_COV="off" | ENABLE_GE_COV="off" | ||||
| PLATFORM="" | |||||
| PRODUCT="normal" | |||||
| ENABLE_GITEE="off" | |||||
| MINDSPORE_MODE="off" | |||||
| GE_ONLY="on" | |||||
| # Process the options | # Process the options | ||||
| while getopts 'ustchj:p:g:vS:M' opt | |||||
| while getopts 'ustchj:vA' opt | |||||
| do | do | ||||
| OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') | OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') | ||||
| case "${opt}" in | case "${opt}" in | ||||
| u) | u) | ||||
| # ENABLE_GE_UT_ONLY_COMPILE="on" | |||||
| ENABLE_GE_UT_ONLY_COMPILE="on" | |||||
| ENABLE_GE_UT="on" | ENABLE_GE_UT="on" | ||||
| ;; | ;; | ||||
| s) | s) | ||||
| ENABLE_GE_ST="on" | ENABLE_GE_ST="on" | ||||
| ;; | ;; | ||||
| t) | t) | ||||
| ENABLE_GE_UT="on" | |||||
| ;; | |||||
| ENABLE_GE_UT="on" | |||||
| ;; | |||||
| c) | c) | ||||
| ENABLE_GE_COV="on" | ENABLE_GE_COV="on" | ||||
| ;; | ;; | ||||
| @@ -91,19 +74,8 @@ checkopts() | |||||
| v) | v) | ||||
| VERBOSE="VERBOSE=1" | VERBOSE="VERBOSE=1" | ||||
| ;; | ;; | ||||
| p) | |||||
| PLATFORM=$OPTARG | |||||
| ;; | |||||
| g) | |||||
| PRODUCT=$OPTARG | |||||
| ;; | |||||
| S) | |||||
| check_on_off $OPTARG S | |||||
| ENABLE_GITEE="$OPTARG" | |||||
| echo "enable download from gitee" | |||||
| ;; | |||||
| M) | |||||
| MINDSPORE_MODE="on" | |||||
| A) | |||||
| usage | |||||
| ;; | ;; | ||||
| *) | *) | ||||
| echo "Undefined option: ${opt}" | echo "Undefined option: ${opt}" | ||||
| @@ -114,9 +86,6 @@ checkopts() | |||||
| } | } | ||||
| checkopts "$@" | checkopts "$@" | ||||
| git submodule update --init metadef | |||||
| git submodule update --init parser | |||||
| mk_dir() { | mk_dir() { | ||||
| local create_dir="$1" # the target to make | local create_dir="$1" # the target to make | ||||
| @@ -131,10 +100,9 @@ echo "---------------- GraphEngine build start ----------------" | |||||
| build_graphengine() | build_graphengine() | ||||
| { | { | ||||
| echo "create build directory and build GraphEngine"; | echo "create build directory and build GraphEngine"; | ||||
| mk_dir "${BUILD_PATH}" | |||||
| cd "${BUILD_PATH}" | |||||
| CMAKE_ARGS="-DBUILD_PATH=$BUILD_PATH" | |||||
| mk_dir "${BUILD_PATH}/graphengine" | |||||
| cd "${BUILD_PATH}/graphengine" | |||||
| CMAKE_ARGS="-DBUILD_PATH=$BUILD_PATH -DGE_ONLY=$GE_ONLY" | |||||
| if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then | if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then | ||||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_COV=ON" | CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_COV=ON" | ||||
| @@ -149,55 +117,17 @@ build_graphengine() | |||||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_ST=ON" | CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_ST=ON" | ||||
| fi | fi | ||||
| if [[ "X$ENABLE_GITEE" = "Xon" ]]; then | |||||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GITEE=ON" | |||||
| fi | |||||
| if [[ "X$MINDSPORE_MODE" = "Xoff" ]]; then | |||||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_OPEN_SRC=True -DCMAKE_INSTALL_PREFIX=${OUTPUT_PATH} -DPLATFORM=${PLATFORM} -DPRODUCT=${PRODUCT}" | |||||
| else | |||||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_D=ON -DCMAKE_INSTALL_PREFIX=${OUTPUT_PATH}" | |||||
| fi | |||||
| echo "${CMAKE_ARGS}" | echo "${CMAKE_ARGS}" | ||||
| cmake ${CMAKE_ARGS} .. | |||||
| if [ $? -ne 0 ] | |||||
| then | |||||
| echo "execute command: cmake ${CMAKE_ARGS} .. failed." | |||||
| return 1 | |||||
| fi | |||||
| COMMON_TARGET="ge_local_engine ge_local_opskernel_builder host_cpu_engine host_cpu_opskernel_builder ge_common engine fmk_parser parser_common _caffe_parser fmk_onnx_parser graph register engine_conf.json optimizer_priority.pbtxt " | |||||
| TARGET=${COMMON_TARGET} | |||||
| if [ "x${PLATFORM}" = "xtrain" ] | |||||
| then | |||||
| TARGET="ge_runner fwk_atc.bin ${TARGET}" | |||||
| elif [ "x${PLATFORM}" = "xinference" ] | |||||
| then | |||||
| TARGET="ge_compiler atc_atc.bin ge_executor_shared ${TARGET}" | |||||
| elif [ "X$ENABLE_GE_UT" = "Xon" ] | |||||
| then | |||||
| TARGET="ut_libgraph ut_libge_multiparts_utest ut_libge_others_utest ut_libge_kernel_utest ut_libge_distinct_load_utest" | |||||
| elif [ "X$MINDSPORE_MODE" = "Xon" ] | |||||
| then | |||||
| TARGET="ge_common graph" | |||||
| elif [ "x${PLATFORM}" = "xall" ] | |||||
| then | |||||
| # build all the target | |||||
| TARGET="ge_runner ge_compiler fwk_atc.bin atc_atc.bin ge_executor_shared ${TARGET}" | |||||
| fi | |||||
| make ${VERBOSE} ${TARGET} -j${THREAD_NUM} && make install | |||||
| if [ $? -ne 0 ] | |||||
| then | |||||
| echo "execute command: make ${VERBOSE} -j${THREAD_NUM} && make install failed." | |||||
| return 1 | |||||
| fi | |||||
| cmake ${CMAKE_ARGS} ../.. | |||||
| make ${VERBOSE} -j${THREAD_NUM} | |||||
| echo "GraphEngine build success!" | echo "GraphEngine build success!" | ||||
| } | } | ||||
| g++ -v | g++ -v | ||||
| mk_dir ${OUTPUT_PATH} | |||||
| build_graphengine || { echo "GraphEngine build failed."; return; } | |||||
| build_graphengine | |||||
| echo "---------------- GraphEngine build finished ----------------" | echo "---------------- GraphEngine build finished ----------------" | ||||
| mk_dir ${OUTPUT_PATH} | |||||
| cp -rf "${BUILD_PATH}/graphengine/"*.so "${OUTPUT_PATH}" | |||||
| rm -rf "${OUTPUT_PATH}/"libproto* | |||||
| rm -f ${OUTPUT_PATH}/libgmock*.so | rm -f ${OUTPUT_PATH}/libgmock*.so | ||||
| rm -f ${OUTPUT_PATH}/libgtest*.so | rm -f ${OUTPUT_PATH}/libgtest*.so | ||||
| rm -f ${OUTPUT_PATH}/lib*_stub.so | rm -f ${OUTPUT_PATH}/lib*_stub.so | ||||
| @@ -205,144 +135,38 @@ rm -f ${OUTPUT_PATH}/lib*_stub.so | |||||
| chmod -R 750 ${OUTPUT_PATH} | chmod -R 750 ${OUTPUT_PATH} | ||||
| find ${OUTPUT_PATH} -name "*.so*" -print0 | xargs -0 chmod 500 | find ${OUTPUT_PATH} -name "*.so*" -print0 | xargs -0 chmod 500 | ||||
| echo "---------------- GraphEngine output generated ----------------" | |||||
| echo "---------------- GraphEngine output package generated ----------------" | |||||
| if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then | |||||
| cp ${BUILD_PATH}/tests/ut/common/graph/ut_libgraph ${OUTPUT_PATH} | |||||
| cp ${BUILD_PATH}/tests/ut/ge/ut_libge_multiparts_utest ${OUTPUT_PATH} | |||||
| cp ${BUILD_PATH}/tests/ut/ge/ut_libge_distinct_load_utest ${OUTPUT_PATH} | |||||
| cp ${BUILD_PATH}/tests/ut/ge/ut_libge_others_utest ${OUTPUT_PATH} | |||||
| cp ${BUILD_PATH}/tests/ut/ge/ut_libge_kernel_utest ${OUTPUT_PATH} | |||||
| RUN_TEST_CASE=${OUTPUT_PATH}/ut_libgraph && ${RUN_TEST_CASE} && | |||||
| RUN_TEST_CASE=${OUTPUT_PATH}/ut_libge_multiparts_utest && ${RUN_TEST_CASE} && | |||||
| RUN_TEST_CASE=${OUTPUT_PATH}/ut_libge_distinct_load_utest && ${RUN_TEST_CASE} && | |||||
| RUN_TEST_CASE=${OUTPUT_PATH}/ut_libge_others_utest && ${RUN_TEST_CASE} && | |||||
| RUN_TEST_CASE=${OUTPUT_PATH}/ut_libge_kernel_utest && ${RUN_TEST_CASE} | |||||
| if [[ "$?" -ne 0 ]]; then | |||||
| echo "!!! UT FAILED, PLEASE CHECK YOUR CHANGES !!!" | |||||
| echo -e "\033[31m${RUN_TEST_CASE}\033[0m" | |||||
| exit 1; | |||||
| fi | |||||
| echo "Generating coverage statistics, please wait..." | |||||
| cd ${BASEPATH} | |||||
| rm -rf ${BASEPATH}/cov | |||||
| mkdir ${BASEPATH}/cov | |||||
| lcov -c -d build/tests/ut/ge -d build/tests/ut/common/graph/ -o cov/tmp.info | |||||
| lcov -r cov/tmp.info '*/output/*' '*/build/opensrc/*' '*/build/proto/*' '*/third_party/*' '*/tests/*' '/usr/local/*' '/usr/include/*' '*/metadef/*' '*/parser/*' -o cov/coverage.info | |||||
| cd ${BASEPATH}/cov | |||||
| genhtml coverage.info | |||||
| if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then | |||||
| cp ${BUILD_PATH}/graphengine/tests/st/st_resnet50_train ${OUTPUT_PATH} | |||||
| fi | fi | ||||
| # generate output package in tar form, including ut/st libraries/executables | |||||
| generate_package() | |||||
| { | |||||
| cd "${BASEPATH}" | |||||
| GRAPHENGINE_LIB_PATH="lib" | |||||
| ACL_PATH="acllib/lib64" | |||||
| FWK_PATH="fwkacllib/lib64" | |||||
| ATC_PATH="atc/lib64" | |||||
| ATC_BIN_PATH="atc/bin" | |||||
| FWK_BIN_PATH="fwkacllib/bin" | |||||
| FWK_INCLUDE_PATH="fwkacllib/include" | |||||
| ATC_INCLUDE_PATH="atc/include" | |||||
| NNENGINE_PATH="plugin/nnengine/ge_config" | |||||
| OPSKERNEL_PATH="plugin/opskernel" | |||||
| ACL_LIB=("libge_common.so" "libgraph.so" "libregister.so" "liberror_manager.so" "libge_executor.so") | |||||
| ATC_LIB=("libc_sec.so" "libge_common.so" "libge_compiler.so" "libgraph.so" "libregister.so" "liberror_manager.so") | |||||
| FWK_LIB=("libge_common.so" "libge_runner.so" "libgraph.so" "libregister.so" "liberror_manager.so") | |||||
| PLUGIN_OPSKERNEL=("libge_local_engine.so" "libge_local_opskernel_builder.so" "libhost_cpu_engine.so" "libhost_cpu_opskernel_builder.so" "optimizer_priority.pbtxt") | |||||
| PARSER_LIB=("lib_caffe_parser.so" "libfmk_onnx_parser.so" "libfmk_parser.so" "libparser_common.so") | |||||
| rm -rf ${OUTPUT_PATH:?}/${FWK_PATH}/ | |||||
| rm -rf ${OUTPUT_PATH:?}/${ACL_PATH}/ | |||||
| rm -rf ${OUTPUT_PATH:?}/${ATC_PATH}/ | |||||
| rm -rf ${OUTPUT_PATH:?}/${ATC_BIN_PATH}/ | |||||
| rm -rf ${OUTPUT_PATH:?}/${FWK_BIN_PATH}/ | |||||
| mk_dir "${OUTPUT_PATH}/${FWK_PATH}/${NNENGINE_PATH}" | |||||
| mk_dir "${OUTPUT_PATH}/${FWK_PATH}/${OPSKERNEL_PATH}" | |||||
| mk_dir "${OUTPUT_PATH}/${ATC_PATH}/${NNENGINE_PATH}" | |||||
| mk_dir "${OUTPUT_PATH}/${ATC_PATH}/${OPSKERNEL_PATH}" | |||||
| mk_dir "${OUTPUT_PATH}/${ACL_PATH}" | |||||
| mk_dir "${OUTPUT_PATH}/${ATC_BIN_PATH}" | |||||
| mk_dir "${OUTPUT_PATH}/${FWK_BIN_PATH}" | |||||
| mk_dir "${OUTPUT_PATH}/${FWK_INCLUDE_PATH}" | |||||
| mk_dir "${OUTPUT_PATH}/${ATC_INCLUDE_PATH}" | |||||
| cd "${OUTPUT_PATH}" | |||||
| find ./ -name graphengine_lib.tar -exec rm {} \; | |||||
| cp ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH}/engine_conf.json ${OUTPUT_PATH}/${FWK_PATH}/${NNENGINE_PATH} | |||||
| cp ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH}/engine_conf.json ${OUTPUT_PATH}/${ATC_PATH}/${NNENGINE_PATH} | |||||
| find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name libengine.so -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH}/${NNENGINE_PATH}/../ \; | |||||
| find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name libengine.so -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH}/${NNENGINE_PATH}/../ \; | |||||
| MAX_DEPTH=1 | |||||
| # if [ "x${PLATFORM}" = "xall" ] || [ "x${PLATFORM}" = "xinference" ] | |||||
| # then | |||||
| # MAX_DEPTH=2 | |||||
| # fi | |||||
| for lib in "${PLUGIN_OPSKERNEL[@]}"; | |||||
| do | |||||
| find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth ${MAX_DEPTH} -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH}/${OPSKERNEL_PATH} \; | |||||
| find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth ${MAX_DEPTH} -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH}/${OPSKERNEL_PATH} \; | |||||
| done | |||||
| for lib in "${PARSER_LIB[@]}"; | |||||
| do | |||||
| find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH} \; | |||||
| find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \; | |||||
| done | |||||
| for lib in "${FWK_LIB[@]}"; | |||||
| do | |||||
| find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH} \; | |||||
| done | |||||
| for lib in "${ACL_LIB[@]}"; | |||||
| do | |||||
| find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ACL_PATH} \; | |||||
| done | |||||
| for lib in "${ATC_LIB[@]}"; | |||||
| do | |||||
| find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \; | |||||
| done | |||||
| find ./lib/atclib -name atc.bin -exec cp {} "${OUTPUT_PATH}/${ATC_BIN_PATH}" \; | |||||
| find ./lib/fwkacl -name atc.bin -exec cp {} "${OUTPUT_PATH}/${FWK_BIN_PATH}" \; | |||||
| cp -r ${OUTPUT_PATH}/../metadef/inc/external/* ${ATC_INCLUDE_PATH} | |||||
| cp -r ${OUTPUT_PATH}/../parser/inc/external/* ${ATC_INCLUDE_PATH} | |||||
| cp -r ${OUTPUT_PATH}/../inc/external/* ${ATC_INCLUDE_PATH} | |||||
| cp -r ${OUTPUT_PATH}/../metadef/inc/external/* ${FWK_INCLUDE_PATH} | |||||
| cp -r ${OUTPUT_PATH}/../parser/inc/external/* ${FWK_INCLUDE_PATH} | |||||
| cp -r ${OUTPUT_PATH}/../inc/external/* ${FWK_INCLUDE_PATH} | |||||
| if [ "x${PLATFORM}" = "xtrain" ] | |||||
| then | |||||
| tar -cf graphengine_lib.tar fwkacllib | |||||
| elif [ "x${PLATFORM}" = "xinference" ] | |||||
| then | |||||
| tar -cf graphengine_lib.tar acllib atc | |||||
| elif [ "x${PLATFORM}" = "xall" ] | |||||
| then | |||||
| tar -cf graphengine_lib.tar fwkacllib acllib atc | |||||
| fi | |||||
| } | |||||
| if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then | |||||
| cp ${BUILD_PATH}/graphengine/tests/ut/common/graph/ut_libgraph ${OUTPUT_PATH} | |||||
| cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_multiparts_utest ${OUTPUT_PATH} | |||||
| cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_distinct_load_utest ${OUTPUT_PATH} | |||||
| cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_others_utest ${OUTPUT_PATH} | |||||
| cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_kernel_utest ${OUTPUT_PATH} | |||||
| if [[ "X${ENABLE_GE_UT_ONLY_COMPILE}" != "Xon" ]]; then | |||||
| export LD_LIBRARY_PATH=${D_LINK_PATH}/x86_64/:${BUILD_PATH}/graphengine/:/usr/local/HiAI/driver/lib64:/usr/local/HiAI/runtime/lib64:${LD_LIBRARY_PATH} | |||||
| echo ${LD_LIBRARY_PATH} | |||||
| ${OUTPUT_PATH}/ut_libgraph && | |||||
| ${OUTPUT_PATH}/ut_libge_multiparts_utest && | |||||
| ${OUTPUT_PATH}/ut_libge_distinct_load_utest && | |||||
| ${OUTPUT_PATH}/ut_libge_others_utest && | |||||
| ${OUTPUT_PATH}/ut_libge_kernel_utest | |||||
| if [[ "$?" -ne 0 ]]; then | |||||
| echo "!!! UT FAILED, PLEASE CHECK YOUR CHANGES !!!" | |||||
| exit 1; | |||||
| fi | |||||
| fi | |||||
| if [[ "X$ENABLE_GE_UT" = "Xoff" && "X$MINDSPORE_MODE" = "Xoff" ]]; then | |||||
| generate_package | |||||
| elif [ "X$MINDSPORE_MODE" = "Xon" ] | |||||
| then | |||||
| cd "${OUTPUT_PATH}" | |||||
| find ./ -name graphengine_lib.tar -exec rm {} \; | |||||
| tar -cf graphengine_lib.tar lib | |||||
| if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then | |||||
| echo "Generating coverage statistics, please wait..." | |||||
| cd ${BASEPATH} | |||||
| rm -rf ${BASEPATH}/cov | |||||
| mkdir ${BASEPATH}/cov | |||||
| gcovr -r ./ --exclude 'third_party' --exclude 'build' --exclude 'tests' --exclude 'prebuild' --exclude 'inc' --print-summary --html --html-details -d -o cov/index.html | |||||
| fi | |||||
| fi | fi | ||||
| echo "---------------- GraphEngine package archive generated ----------------" | |||||
| @@ -1,5 +0,0 @@ | |||||
| [graphengine] | |||||
| ge | |||||
| inc | |||||
| metadef | |||||
| parser | |||||
| @@ -1,29 +0,0 @@ | |||||
| #[[ | |||||
| module - the name of export imported target | |||||
| name - find the library name | |||||
| path - find the library path | |||||
| #]] | |||||
| function(find_module module name) | |||||
| if (TARGET ${module}) | |||||
| return() | |||||
| endif() | |||||
| set(options) | |||||
| set(oneValueArgs) | |||||
| set(multiValueArgs) | |||||
| cmake_parse_arguments(MODULE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) | |||||
| set(path ${MODULE_UNPARSED_ARGUMENTS}) | |||||
| find_library(${module}_LIBRARY_DIR NAMES ${name} NAMES_PER_DIR PATHS ${path} | |||||
| PATH_SUFFIXES lib | |||||
| ) | |||||
| message(STATUS "find ${name} location ${${module}_LIBRARY_DIR}") | |||||
| if ("${${module}_LIBRARY_DIR}" STREQUAL "${module}_LIBRARY_DIR-NOTFOUND") | |||||
| message(FATAL_ERROR "${name} not found in ${path}") | |||||
| endif() | |||||
| add_library(${module} SHARED IMPORTED) | |||||
| set_target_properties(${module} PROPERTIES | |||||
| IMPORTED_LOCATION ${${module}_LIBRARY_DIR} | |||||
| ) | |||||
| endfunction() | |||||
| @@ -0,0 +1,13 @@ | |||||
| set(Eigen3_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") | |||||
| set(Eigen3_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") | |||||
| set(Eigen3_NS "ge_") | |||||
| graphengine_add_pkg(Eigen3 | |||||
| VER 3.3.7 | |||||
| URL https://gitlab.com/libeigen/eigen/-/archive/3.3.7/eigen-3.3.7.tar.gz | |||||
| MD5 9e30f67e8531477de4117506fe44669b | |||||
| CMAKE_OPTION -DBUILD_TESTING=OFF) | |||||
| find_package(Eigen3 3.3.7 REQUIRED ${GE_FIND_NO_DEFAULT_PATH}) | |||||
| set_property(TARGET Eigen3::Eigen PROPERTY IMPORTED_GLOBAL TRUE) | |||||
| add_library(graphengine::eigen ALIAS Eigen3::Eigen) | |||||
| include_directories(${EIGEN3_INCLUDE_DIRS}) | |||||
| @@ -1,48 +0,0 @@ | |||||
| if (HAVE_GFLAGS) | |||||
| return() | |||||
| endif() | |||||
| include(ExternalProject) | |||||
| if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) | |||||
| set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | |||||
| endif() | |||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz") | |||||
| set(MD5 "") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz") | |||||
| set(MD5 "") | |||||
| endif () | |||||
| set (gflags_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -Dgoogle=ascend_private") | |||||
| ExternalProject_Add(gflags_build | |||||
| URL ${REQ_URL} | |||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||||
| #SOURCE_DIR ${GE_CODE_DIR}/../../third_party/gflags/src/gflags-2.2.2 | |||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gflags_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR> | |||||
| BUILD_COMMAND $(MAKE) | |||||
| INSTALL_COMMAND $(MAKE) install | |||||
| EXCLUDE_FROM_ALL TRUE | |||||
| ) | |||||
| set(GFLAGS_PKG_DIR ${CMAKE_INSTALL_PREFIX}/gflags) | |||||
| add_library(gflags_static STATIC IMPORTED) | |||||
| set_target_properties(gflags_static PROPERTIES | |||||
| IMPORTED_LOCATION ${GFLAGS_PKG_DIR}/lib/libgflags.a | |||||
| ) | |||||
| add_library(gflags INTERFACE) | |||||
| target_include_directories(gflags INTERFACE ${GFLAGS_PKG_DIR}/include) | |||||
| target_link_libraries(gflags INTERFACE gflags_static) | |||||
| add_dependencies(gflags gflags_build) | |||||
| #set(HAVE_GFLAGS TRUE CACHE BOOL "gflags build add") | |||||
| set(HAVE_GFLAGS TRUE) | |||||
| @@ -1,81 +1,16 @@ | |||||
| # Builds googletest release-1.8.1 through ExternalProject and declares the imported shared targets | |||||
| # gtest, gtest_main, gmock and gmock_main, all located under ${CMAKE_INSTALL_PREFIX}/gtest. | |||||
| # Guard: skip the whole file when HAVE_GTEST is already set (it is set on the last line below). | |||||
| if (HAVE_GTEST) | |||||
| return() | |||||
| endif() | |||||
| include(ExternalProject) | |||||
| # When the install prefix equals one of these two literal paths, redirect it to ${GE_CODE_DIR}/output. | |||||
| if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) | |||||
| set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | |||||
| endif() | |||||
| # Archive location priority: GE_PB_PKG mirror, then the Gitee mirror, then GitHub. | |||||
| # MD5 is assigned an empty string in every branch and is not passed to ExternalProject_Add below. | |||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/ge_gtest/release-1.8.1.tar.gz") | |||||
| set(MD5 "") | |||||
| elseif (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.1.tar.gz") | |||||
| set(MD5 "") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/google/googletest/archive/release-1.8.1.tar.gz") | |||||
| set(MD5 "") | |||||
| endif () | |||||
| # Only gtest_CXXFLAGS is forwarded to the configure command; gtest_CFLAGS is not referenced in this block. | |||||
| set (gtest_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") | |||||
| set (gtest_CFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") | |||||
| # EXCLUDE_FROM_ALL: the external build runs only when a target that depends on gtest_build is built. | |||||
| ExternalProject_Add(gtest_build | |||||
| URL ${REQ_URL} | |||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gtest_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gtest <SOURCE_DIR> | |||||
| -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_MACOSX_RPATH=TRUE | |||||
| BUILD_COMMAND $(MAKE) | |||||
| INSTALL_COMMAND $(MAKE) install | |||||
| EXCLUDE_FROM_ALL TRUE | |||||
| ) | |||||
| set(GTEST_PKG_DIR ${CMAKE_INSTALL_PREFIX}/gtest) | |||||
| # Create the include directory at configure time, before the external build has produced it, | |||||
| # because it is used as an INTERFACE include directory of the imported targets below. | |||||
| file(MAKE_DIRECTORY ${GTEST_PKG_DIR}/include) | |||||
| add_library(gtest SHARED IMPORTED) | |||||
| set_target_properties(gtest PROPERTIES | |||||
| IMPORTED_LOCATION ${GTEST_PKG_DIR}/lib/libgtest.so | |||||
| ) | |||||
| add_library(gtest_main SHARED IMPORTED) | |||||
| set_target_properties(gtest_main PROPERTIES | |||||
| IMPORTED_LOCATION ${GTEST_PKG_DIR}/lib/libgtest_main.so | |||||
| ) | |||||
| target_include_directories(gtest INTERFACE ${GTEST_PKG_DIR}/include) | |||||
| target_include_directories(gtest_main INTERFACE ${GTEST_PKG_DIR}/include) | |||||
| add_library(gmock SHARED IMPORTED) | |||||
| set_target_properties(gmock PROPERTIES | |||||
| IMPORTED_LOCATION ${GTEST_PKG_DIR}/lib/libgmock.so | |||||
| ) | |||||
| add_library(gmock_main SHARED IMPORTED) | |||||
| set_target_properties(gmock_main PROPERTIES | |||||
| IMPORTED_LOCATION ${GTEST_PKG_DIR}/lib/libgmock_main.so | |||||
| ) | |||||
| target_include_directories(gmock INTERFACE ${GTEST_PKG_DIR}/include) | |||||
| target_include_directories(gmock_main INTERFACE ${GTEST_PKG_DIR}/include) | |||||
| set(INSTALL_BASE_DIR "") | |||||
| set(INSTALL_LIBRARY_DIR lib) | |||||
| # OPTIONAL: installation does not fail when one of these files does not exist. | |||||
| install(FILES ${GTEST_PKG_DIR}/lib/libgtest.so ${GTEST_PKG_DIR}/lib/libgtest_main.so ${GTEST_PKG_DIR}/lib/libgmock.so ${GTEST_PKG_DIR}/lib/libgmock_main.so OPTIONAL | |||||
| DESTINATION ${INSTALL_LIBRARY_DIR}) | |||||
| # NOTE(review): only `gtest` is tied to gtest_build here; gtest_main, gmock and gmock_main get no | |||||
| # add_dependencies call in this block - confirm consumers always link `gtest` as well. | |||||
| add_dependencies(gtest gtest_build) | |||||
| #set(HAVE_GFLAGS TRUE CACHE BOOL "gflags build add") | |||||
| set(HAVE_GTEST TRUE) | |||||
| # Obtain googletest 1.8.0 through graphengine_add_pkg and expose it as graphengine::gtest / graphengine::gtest_main. | |||||
| # The two flag variables follow the <pkg_name>_CXXFLAGS / <pkg_name>_CFLAGS naming that graphengine_add_pkg reads. | |||||
| set(ge_gtest_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") | |||||
| set(ge_gtest_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") | |||||
| graphengine_add_pkg(ge_gtest | |||||
| VER 1.8.0 | |||||
| LIBS gtest gtest_main | |||||
| URL https://github.com/google/googletest/archive/release-1.8.0.tar.gz | |||||
| MD5 16877098823401d1bf2ed7891d7dce36 | |||||
| CMAKE_OPTION -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON | |||||
| -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON) | |||||
| # Project-facing aliases for the package-qualified targets. | |||||
| add_library(graphengine::gtest ALIAS ge_gtest::gtest) | |||||
| add_library(graphengine::gtest_main ALIAS ge_gtest::gtest_main) | |||||
| include_directories(${ge_gtest_INC}) | |||||
| # Copy the shared libraries (located in ../lib relative to the include dir) into build/graphengine at configure time. | |||||
| file(COPY ${ge_gtest_INC}/../lib/libgtest.so DESTINATION ${CMAKE_SOURCE_DIR}/build/graphengine) | |||||
| file(COPY ${ge_gtest_INC}/../lib/libgtest_main.so DESTINATION ${CMAKE_SOURCE_DIR}/build/graphengine) | |||||
| @@ -1,38 +1,9 @@ | |||||
| # Downloads the nlohmann/json v3.6.1 header archive through ExternalProject and exposes it | |||||
| # as the INTERFACE target `json`. Nothing is configured, built or installed - the archive is only unpacked. | |||||
| # Guard: skip the whole file when HAVE_JSON is already set (it is set on the last line below). | |||||
| if (HAVE_JSON) | |||||
| return() | |||||
| endif() | |||||
| include(ExternalProject) | |||||
| set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include) | |||||
| # Archive location: GE_PB_PKG mirror if set, otherwise the GitHub release. The Gitee branch is commented out. | |||||
| # MD5 is assigned here but not passed to ExternalProject_Add below. | |||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/ge_nlohmann_json/include.zip") | |||||
| set(MD5 "0dc903888211db3a0f170304cd9f3a89") | |||||
| set(JSON_INCLUDE_DIR ${JSON_SRC_DIR}) | |||||
| #elseif (ENABLE_GITEE) | |||||
| # set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip") | |||||
| # set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7") | |||||
| #set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip") | |||||
| set(MD5 "0dc903888211db3a0f170304cd9f3a89") | |||||
| set(JSON_INCLUDE_DIR ${JSON_SRC_DIR}) | |||||
| endif () | |||||
| # The archive is unpacked into JSON_SRC_DIR; the empty commands disable the configure, build and install steps. | |||||
| ExternalProject_Add(json_build | |||||
| URL ${REQ_URL} | |||||
| #URL /home/txd/workspace/cloud_code/pkg/include.zip | |||||
| SOURCE_DIR ${JSON_SRC_DIR} | |||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND "" | |||||
| BUILD_COMMAND "" | |||||
| INSTALL_COMMAND "" | |||||
| EXCLUDE_FROM_ALL TRUE | |||||
| ) | |||||
| # Consumers link `json` to get the include path; the dependency triggers the download step. | |||||
| add_library(json INTERFACE) | |||||
| target_include_directories(json INTERFACE ${JSON_INCLUDE_DIR}) | |||||
| add_dependencies(json json_build) | |||||
| #set(HAVE_JSON TRUE CACHE BOOL "json build add") | |||||
| set(HAVE_JSON TRUE) | |||||
| # Obtain the nlohmann/json 3.6.1 headers through graphengine_add_pkg (HEAD_ONLY: include path only) | |||||
| # and expose them as graphengine::json. | |||||
| # NOTE(review): the flag variables are named nlohmann_json_* while the package is ge_nlohmann_json - | |||||
| # confirm whether graphengine_add_pkg is expected to pick them up. | |||||
| set(nlohmann_json_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") | |||||
| set(nlohmann_json_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") | |||||
| graphengine_add_pkg(ge_nlohmann_json | |||||
| VER 3.6.1 | |||||
| HEAD_ONLY ./ | |||||
| URL https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip | |||||
| MD5 0dc903888211db3a0f170304cd9f3a89) | |||||
| include_directories(${ge_nlohmann_json_INC}) | |||||
| add_library(graphengine::json ALIAS ge_nlohmann_json) | |||||
| @@ -1,41 +1,5 @@ | |||||
| # Downloads the ONNX 1.6.0 sources through ExternalProject solely to obtain onnx.proto, | |||||
| # which the install step copies to ${CMAKE_BINARY_DIR}/onnx/onnx.proto. | |||||
| include(ExternalProject) | |||||
| #set(ONNX_SRC_DIR /home/txd/workspace/cloud_code/graphengine/build/graphengine/open_source/onnx) | |||||
| #set(ONNX_PROTO ${ONNX_SRC_DIR}/onnx/onnx.proto) | |||||
| set(ONNX_PROTO_DIR ${CMAKE_BINARY_DIR}/onnx) | |||||
| set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) | |||||
| # The destination directory is created at configure time. | |||||
| file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) | |||||
| # Archive location priority: GE_PB_PKG mirror, then the Gitee mirror, then the GitHub release. | |||||
| # MD5 is assigned here but not passed to ExternalProject_Add below. | |||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/onnx/onnx-1.6.0.tar.gz") | |||||
| set(MD5 "512f2779d6215d4a36f366b6b9acdf1e") | |||||
| elseif (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz") | |||||
| set(MD5 "1bdbcecdd68ea8392630467646776e02") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz") | |||||
| set(MD5 "512f2779d6215d4a36f366b6b9acdf1e") | |||||
| endif () | |||||
| # No configure or build step; the install step is a single file copy of onnx.proto. | |||||
| ExternalProject_Add(onnx | |||||
| URL ${REQ_URL} | |||||
| #URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz | |||||
| #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 | |||||
| #SOURCE_DIR ${ONNX_SRC_DIR} | |||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND "" | |||||
| BUILD_COMMAND "" | |||||
| #INSTALL_COMMAND "" | |||||
| INSTALL_COMMAND ${CMAKE_COMMAND} -E copy <SOURCE_DIR>/onnx/onnx.proto ${ONNX_PROTO_FILE} | |||||
| #BUILD_ALWAYS TRUE | |||||
| EXCLUDE_FROM_ALL TRUE | |||||
| ) | |||||
| # Generate C++ protobuf sources for onnx.proto. | |||||
| #   comp  - component name forwarded to ge_protobuf_generate | |||||
| #   c_var - name of the variable that receives the generated source list | |||||
| #   h_var - name of the variable that receives the generated header list | |||||
| # The command-less custom command declares ${ONNX_PROTO_FILE} as an output that depends on the `onnx` | |||||
| # target, so the proto file is treated as produced by that target. | |||||
| # NOTE(review): every invocation declares the same OUTPUT again - confirm the macro is used once per directory. | |||||
| macro(onnx_protobuf_generate comp c_var h_var) | |||||
| add_custom_command(OUTPUT ${ONNX_PROTO_FILE} | |||||
| DEPENDS onnx | |||||
| ) | |||||
| ge_protobuf_generate(${comp} ${c_var} ${h_var} ${ONNX_PROTO_FILE}) | |||||
| endmacro() | |||||
| # Obtain the ONNX 1.6.0 release archive through graphengine_add_pkg in HEAD_ONLY mode | |||||
| # (include path relative to the package directory is ./). | |||||
| graphengine_add_pkg(onnx | |||||
| VER 1.6.0 | |||||
| HEAD_ONLY ./ | |||||
| URL https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz | |||||
| MD5 512f2779d6215d4a36f366b6b9acdf1e) | |||||
| @@ -0,0 +1,57 @@ | |||||
| # Bring protobuf 3.8.0 into the build as a subdirectory, unless a protobuf::libprotobuf target already exists. | |||||
| if (NOT TARGET protobuf::libprotobuf) | |||||
| # HEAD_ONLY: the package directory (${protobuf_DIRPATH}) is used directly as the source tree below. | |||||
| graphengine_add_pkg(protobuf | |||||
| VER 3.8.0 | |||||
| HEAD_ONLY ./ | |||||
| URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz | |||||
| MD5 3d9e32700639618a4d2d342c99d4507a) | |||||
| set(protobuf_BUILD_TESTS OFF CACHE BOOL "Disahble protobuf test") | |||||
| set(protobuf_BUILD_SHARED_LIBS ON CACHE BOOL "Gen shared library") | |||||
| # Save CMAKE_CXX_FLAGS, drop " -Wall" and " -Werror" while protobuf's own CMake files are processed, | |||||
| # and restore the saved value after add_subdirectory. | |||||
| set(_ms_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) | |||||
| string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") | |||||
| string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") | |||||
| # Patch the downloaded libprotobuf.cmake in place: the text "VERSION ${protobuf_VERSION}" becomes "VERSION 19". | |||||
| # The file is rewritten on every configure run that reaches this branch. | |||||
| set(PROTOBUF_CMAKE_FILE "${protobuf_DIRPATH}/cmake/libprotobuf.cmake" ) | |||||
| FILE(READ ${PROTOBUF_CMAKE_FILE} GE_MR_PROTOBUF_CMAKE) | |||||
| STRING(REPLACE "VERSION \${protobuf_VERSION}" "VERSION 19" GE_MR_PROTOBUF_CMAKE_V19 "${GE_MR_PROTOBUF_CMAKE}" ) | |||||
| FILE(WRITE ${PROTOBUF_CMAKE_FILE} "${GE_MR_PROTOBUF_CMAKE_V19}") | |||||
| add_subdirectory(${protobuf_DIRPATH}/cmake ${protobuf_DIRPATH}/build) | |||||
| set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS}) | |||||
| endif() | |||||
| # The remaining statements run whether or not the branch above was taken. | |||||
| set(PROTOBUF_LIBRARY protobuf::libprotobuf) | |||||
| include_directories(${protobuf_DIRPATH}/src) | |||||
| # NOTE(review): the alias needs a non-alias target named `libprotobuf`; if protobuf::libprotobuf was | |||||
| # provided from elsewhere and the branch above was skipped, confirm `libprotobuf` exists. | |||||
| add_library(ge_protobuf::protobuf ALIAS libprotobuf) | |||||
| # Declare protoc rules that turn .proto files into C++ sources. | |||||
| #   comp  - component name; output goes to ${CMAKE_BINARY_DIR}/proto/<comp>/proto | |||||
| #   c_var - name of the caller variable that receives the list of generated .pb.cc files | |||||
| #   h_var - name of the caller variable that receives the list of generated .pb.h files | |||||
| #   ARGN  - the .proto files; an empty list raises SEND_ERROR and returns without touching the outputs | |||||
| function(ge_protobuf_generate comp c_var h_var) | |||||
|     if(NOT ARGN) | |||||
|         message(SEND_ERROR "Error: ge_protobuf_generate() called without any proto files") | |||||
|         return() | |||||
|     endif() | |||||
|     # Start from empty result lists inside this function's scope. | |||||
|     set(${c_var}) | |||||
|     set(${h_var}) | |||||
|     # All generated files of one component share a single output directory. | |||||
|     set(_ge_pb_out_dir "${CMAKE_BINARY_DIR}/proto/${comp}/proto") | |||||
|     foreach(proto_file ${ARGN}) | |||||
|         get_filename_component(proto_abs ${proto_file} ABSOLUTE) | |||||
|         get_filename_component(proto_stem ${proto_file} NAME_WE) | |||||
|         get_filename_component(proto_dir ${proto_abs} PATH) | |||||
|         set(_ge_pb_cc "${_ge_pb_out_dir}/${proto_stem}.pb.cc") | |||||
|         set(_ge_pb_h "${_ge_pb_out_dir}/${proto_stem}.pb.h") | |||||
|         list(APPEND ${c_var} "${_ge_pb_cc}") | |||||
|         list(APPEND ${h_var} "${_ge_pb_h}") | |||||
|         # The rule first creates the output directory, then runs protoc with the proto's own directory as include root. | |||||
|         add_custom_command( | |||||
|             OUTPUT "${_ge_pb_cc}" | |||||
|                    "${_ge_pb_h}" | |||||
|             WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} | |||||
|             COMMAND ${CMAKE_COMMAND} -E make_directory "${_ge_pb_out_dir}" | |||||
|             COMMAND protobuf::protoc -I${proto_dir} --cpp_out=${_ge_pb_out_dir} ${proto_abs} | |||||
|             DEPENDS protobuf::protoc ${proto_abs} | |||||
|             COMMENT "Running C++ protocol buffer compiler on ${proto_file}" VERBATIM ) | |||||
|     endforeach() | |||||
|     set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE) | |||||
|     # Hand both lists back to the caller. | |||||
|     set(${c_var} ${${c_var}} PARENT_SCOPE) | |||||
|     set(${h_var} ${${h_var}} PARENT_SCOPE) | |||||
| endfunction() | |||||
| @@ -1,69 +0,0 @@ | |||||
| # Builds protobuf 3.8.0 as a shared library through ExternalProject and declares the imported | |||||
| # target `ascend_protobuf`, located under ${CMAKE_INSTALL_PREFIX}/protobuf. | |||||
| # Guard: skip the whole file when HAVE_PROTOBUF is already set (it is set on the last line below). | |||||
| if (HAVE_PROTOBUF) | |||||
| return() | |||||
| endif() | |||||
| include(ExternalProject) | |||||
| include(GNUInstallDirs) | |||||
| # When the install prefix equals one of these two literal paths, redirect it to ${GE_CODE_DIR}/output. | |||||
| if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) | |||||
| set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | |||||
| endif() | |||||
| # Archive location priority: GE_PB_PKG mirror, then the Gitee mirror, then GitHub. | |||||
| # MD5 is assigned only in the two public branches and is not passed to ExternalProject_Add below. | |||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz") | |||||
| else() | |||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "3d9e32700639618a4d2d342c99d4507a") | |||||
| endif () | |||||
| endif() | |||||
| # -Dgoogle=ascend_private replaces the identifier `google` with `ascend_private` throughout the protobuf sources. | |||||
| set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2 -Dgoogle=ascend_private") | |||||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | |||||
| # The host toolchain (compilers, linker, ar, ranlib) is forwarded explicitly to the nested configure. | |||||
| # NOTE(review): CMAKE_CXX_LDFLAGS and LIB_PREFIX are not variables CMake itself defines - confirm the | |||||
| # protobuf build scripts in use actually consume them. | |||||
| ExternalProject_Add(protobuf_build | |||||
| URL ${REQ_URL} | |||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} | |||||
| -Dprotobuf_WITH_ZLIB=OFF | |||||
| -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} | |||||
| -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} | |||||
| -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} | |||||
| -DCMAKE_LINKER=${CMAKE_LINKER} | |||||
| -DCMAKE_AR=${CMAKE_AR} | |||||
| -DCMAKE_RANLIB=${CMAKE_RANLIB} | |||||
| -DLIB_PREFIX=ascend_ | |||||
| -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=ON -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protobuf <SOURCE_DIR>/cmake | |||||
| BUILD_COMMAND $(MAKE) | |||||
| INSTALL_COMMAND $(MAKE) install | |||||
| EXCLUDE_FROM_ALL TRUE | |||||
| ) | |||||
| include(GNUInstallDirs) | |||||
| set(PROTOBUF_SHARED_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf) | |||||
| add_library(ascend_protobuf SHARED IMPORTED) | |||||
| # Create the include directory at configure time, before the external build has produced it. | |||||
| file(MAKE_DIRECTORY ${PROTOBUF_SHARED_PKG_DIR}/include) | |||||
| set_target_properties(ascend_protobuf PROPERTIES | |||||
| IMPORTED_LOCATION ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/libascend_protobuf.so | |||||
| ) | |||||
| target_include_directories(ascend_protobuf INTERFACE ${PROTOBUF_SHARED_PKG_DIR}/include) | |||||
| set(INSTALL_BASE_DIR "") | |||||
| set(INSTALL_LIBRARY_DIR lib) | |||||
| # NOTE(review): the two file names below have no `lib` prefix, unlike IMPORTED_LOCATION above | |||||
| # (libascend_protobuf.so). Because of OPTIONAL a non-existent file is not an error - confirm the names. | |||||
| install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.8.0.0 OPTIONAL | |||||
| DESTINATION ${INSTALL_LIBRARY_DIR}) | |||||
| install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so OPTIONAL | |||||
| DESTINATION ${INSTALL_LIBRARY_DIR}) | |||||
| add_dependencies(ascend_protobuf protobuf_build) | |||||
| #set(HAVE_PROTOBUF TRUE CACHE BOOL "protobuf build add") | |||||
| set(HAVE_PROTOBUF TRUE) | |||||
| @@ -1,66 +0,0 @@ | |||||
| # Builds protobuf 3.8.0 as a static library through ExternalProject and exposes it as the | |||||
| # INTERFACE target `ascend_protobuf_static`, installed under ${CMAKE_INSTALL_PREFIX}/protobuf_static. | |||||
| # Guard: skip the whole file when HAVE_PROTOBUF_STATIC is already set (it is set on the last line below). | |||||
| if (HAVE_PROTOBUF_STATIC) | |||||
| return() | |||||
| endif() | |||||
| include(ExternalProject) | |||||
| include(GNUInstallDirs) | |||||
| #set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output) | |||||
| # When the install prefix equals one of these two literal paths, redirect it to ${GE_CODE_DIR}/output. | |||||
| if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) | |||||
| set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | |||||
| endif() | |||||
| # Archive location priority: GE_PB_PKG mirror, then the Gitee mirror, then GitHub. | |||||
| # MD5 is assigned only in the two public branches and is not passed to ExternalProject_Add below. | |||||
| if(GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz") | |||||
| else() | |||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "3d9e32700639618a4d2d342c99d4507a") | |||||
| endif () | |||||
| endif() | |||||
| # -Dgoogle=ascend_private replaces the identifier `google` with `ascend_private` throughout the protobuf sources. | |||||
| set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2 -Dgoogle=ascend_private") | |||||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | |||||
| set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static) | |||||
| # BUILD_SHARED_LIBS is not passed here, unlike the shared variant of this script. | |||||
| # NOTE(review): CMAKE_CXX_LDFLAGS and LIB_PREFIX are not variables CMake itself defines - confirm the | |||||
| # protobuf build scripts in use actually consume them. | |||||
| ExternalProject_Add(protobuf_static_build | |||||
| URL ${REQ_URL} | |||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||||
| #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 | |||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} | |||||
| -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} | |||||
| -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} | |||||
| -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} | |||||
| -DCMAKE_LINKER=${CMAKE_LINKER} | |||||
| -DCMAKE_AR=${CMAKE_AR} | |||||
| -DCMAKE_RANLIB=${CMAKE_RANLIB} | |||||
| -Dprotobuf_WITH_ZLIB=OFF | |||||
| -DLIB_PREFIX=ascend_ | |||||
| -Dprotobuf_BUILD_TESTS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${PROTOBUF_STATIC_PKG_DIR} <SOURCE_DIR>/cmake | |||||
| BUILD_COMMAND $(MAKE) | |||||
| INSTALL_COMMAND $(MAKE) install | |||||
| EXCLUDE_FROM_ALL TRUE | |||||
| ) | |||||
| include(GNUInstallDirs) | |||||
| # The imported archive is wrapped in an INTERFACE target that carries the include path | |||||
| # and the dependency on the external build. | |||||
| add_library(ascend_protobuf_static_lib STATIC IMPORTED) | |||||
| set_target_properties(ascend_protobuf_static_lib PROPERTIES | |||||
| IMPORTED_LOCATION ${PROTOBUF_STATIC_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/libascend_protobuf.a | |||||
| ) | |||||
| add_library(ascend_protobuf_static INTERFACE) | |||||
| target_include_directories(ascend_protobuf_static INTERFACE ${PROTOBUF_STATIC_PKG_DIR}/include) | |||||
| target_link_libraries(ascend_protobuf_static INTERFACE ascend_protobuf_static_lib) | |||||
| # With any of these options enabled the include path is also added directory-wide. | |||||
| if (ENABLE_D OR ENABLE_ACL OR ENABLE_MS_TESTCASES) | |||||
| include_directories(${PROTOBUF_STATIC_PKG_DIR}/include) | |||||
| endif () | |||||
| add_dependencies(ascend_protobuf_static protobuf_static_build) | |||||
| set(HAVE_PROTOBUF_STATIC TRUE) | |||||
| @@ -1,116 +0,0 @@ | |||||
| # Builds the protoc compiler from protobuf 3.8.0 through ExternalProject and records its path | |||||
| # in protoc_EXECUTABLE for the generator functions defined further down in this file. | |||||
| # Guard: skip the whole file when HAVE_PROTOC is already set. | |||||
| if (HAVE_PROTOC) | |||||
| return() | |||||
| endif() | |||||
| include(ExternalProject) | |||||
| include(GNUInstallDirs) | |||||
| #set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output) | |||||
| # When the install prefix equals one of these two literal paths, redirect it to ${GE_CODE_DIR}/output. | |||||
| if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) | |||||
| set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | |||||
| endif() | |||||
| # Archive location priority: GE_PB_PKG mirror, then the Gitee mirror, then GitHub. | |||||
| # MD5 is assigned only in the two public branches and is not passed to ExternalProject_Add below. | |||||
| if(GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz") | |||||
| else() | |||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "3d9e32700639618a4d2d342c99d4507a") | |||||
| endif () | |||||
| endif() | |||||
| # These flags do not contain -Dgoogle=ascend_private, and the toolchain variables are not forwarded here. | |||||
| set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") | |||||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | |||||
| # Static build (BUILD_SHARED_LIBS=OFF) installed under ${CMAKE_INSTALL_PREFIX}/protoc. | |||||
| # NOTE(review): CMAKE_CXX_LDFLAGS is not a variable CMake itself defines - confirm it has an effect. | |||||
| ExternalProject_Add(protoc_build | |||||
| URL ${REQ_URL} | |||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||||
| #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 | |||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake | |||||
| BUILD_COMMAND $(MAKE) | |||||
| INSTALL_COMMAND $(MAKE) install | |||||
| EXCLUDE_FROM_ALL TRUE | |||||
| ) | |||||
| set(PROTOC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protoc) | |||||
| # Path of the installed compiler; the file exists only after protoc_build has run. | |||||
| set(protoc_EXECUTABLE ${PROTOC_PKG_DIR}/${CMAKE_INSTALL_BINDIR}/protoc) | |||||
| # Declare protoc rules that turn .proto files into C++ sources. | |||||
| #   comp  - component name; output root is ${CMAKE_BINARY_DIR}/proto/<comp>/proto | |||||
| #   c_var - name of the caller variable that receives the list of generated .pb.cc files | |||||
| #   h_var - name of the caller variable that receives the list of generated .pb.h files | |||||
| #   ARGN  - the .proto files; an empty list raises SEND_ERROR and returns without touching the outputs | |||||
| # A proto whose directory is not named "proto" is generated into a sub-directory named after that directory. | |||||
| function(protobuf_generate comp c_var h_var) | |||||
|     if(NOT ARGN) | |||||
|         message(SEND_ERROR "Error: protobuf_generate() called without any proto files") | |||||
|         return() | |||||
|     endif() | |||||
|     # Start from empty result lists inside this function's scope. | |||||
|     set(${c_var}) | |||||
|     set(${h_var}) | |||||
|     foreach(proto_file ${ARGN}) | |||||
|         get_filename_component(proto_abs ${proto_file} ABSOLUTE) | |||||
|         get_filename_component(proto_stem ${proto_file} NAME_WE) | |||||
|         get_filename_component(proto_dir ${proto_abs} PATH) | |||||
|         get_filename_component(proto_dir_name ${proto_dir} NAME) | |||||
|         # Default output directory, extended by the directory name unless that name is "proto". | |||||
|         set(gen_dir ${CMAKE_BINARY_DIR}/proto/${comp}/proto) | |||||
|         if(NOT "${proto_dir_name}" STREQUAL "proto") | |||||
|             set(gen_dir ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${proto_dir_name}) | |||||
|         endif() | |||||
|         set(gen_cc "${gen_dir}/${proto_stem}.pb.cc") | |||||
|         set(gen_h "${gen_dir}/${proto_stem}.pb.h") | |||||
|         list(APPEND ${c_var} "${gen_cc}") | |||||
|         list(APPEND ${h_var} "${gen_h}") | |||||
|         # The rule depends on protoc_build so the compiler exists before it is invoked. | |||||
|         add_custom_command( | |||||
|             OUTPUT "${gen_cc}" "${gen_h}" | |||||
|             WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} | |||||
|             COMMAND ${CMAKE_COMMAND} -E make_directory "${gen_dir}" | |||||
|             COMMAND ${protoc_EXECUTABLE} -I${proto_dir} --cpp_out=${gen_dir} ${proto_abs} | |||||
|             DEPENDS protoc_build ${proto_abs} | |||||
|             COMMENT "Running C++ protocol buffer compiler on ${proto_file}" VERBATIM ) | |||||
|     endforeach() | |||||
|     set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE) | |||||
|     # Hand both lists back to the caller. | |||||
|     set(${c_var} ${${c_var}} PARENT_SCOPE) | |||||
|     set(${h_var} ${${h_var}} PARENT_SCOPE) | |||||
| endfunction() | |||||
| # Declare protoc rules that turn .proto files into Python modules. | |||||
| #   comp   - component name; output root is ${CMAKE_BINARY_DIR}/proto/<comp>/proto | |||||
| #   py_var - name of the caller variable that receives the list of generated *_pb2.py files | |||||
| #   ARGN   - the .proto files; an empty list raises SEND_ERROR and returns without touching the output | |||||
| # A proto whose directory is not named "proto" is generated into a sub-directory named after that directory. | |||||
| function(protobuf_generate_py comp py_var) | |||||
|     if(NOT ARGN) | |||||
|         message(SEND_ERROR "Error: protobuf_generate_py() called without any proto files") | |||||
|         return() | |||||
|     endif() | |||||
|     # Start from an empty result list inside this function's scope. | |||||
|     set(${py_var}) | |||||
|     foreach(proto_file ${ARGN}) | |||||
|         get_filename_component(proto_abs ${proto_file} ABSOLUTE) | |||||
|         get_filename_component(proto_stem ${proto_file} NAME_WE) | |||||
|         get_filename_component(proto_dir ${proto_abs} PATH) | |||||
|         get_filename_component(proto_dir_name ${proto_dir} NAME) | |||||
|         # Default output directory, extended by the directory name unless that name is "proto". | |||||
|         set(gen_dir ${CMAKE_BINARY_DIR}/proto/${comp}/proto) | |||||
|         if(NOT "${proto_dir_name}" STREQUAL "proto") | |||||
|             set(gen_dir ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${proto_dir_name}) | |||||
|         endif() | |||||
|         set(gen_py "${gen_dir}/${proto_stem}_pb2.py") | |||||
|         list(APPEND ${py_var} "${gen_py}") | |||||
|         # The rule depends on protoc_build so the compiler exists before it is invoked. | |||||
|         add_custom_command( | |||||
|             OUTPUT "${gen_py}" | |||||
|             WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} | |||||
|             COMMAND ${CMAKE_COMMAND} -E make_directory "${gen_dir}" | |||||
|             COMMAND ${protoc_EXECUTABLE} -I${proto_dir} --python_out=${gen_dir} ${proto_abs} | |||||
|             DEPENDS protoc_build ${proto_abs} | |||||
|             COMMENT "Running PYTHON protocol buffer compiler on ${proto_file}" VERBATIM ) | |||||
|     endforeach() | |||||
|     set_source_files_properties(${${py_var}} PROPERTIES GENERATED TRUE) | |||||
|     # Hand the list back to the caller. | |||||
|     set(${py_var} ${${py_var}} PARENT_SCOPE) | |||||
| endfunction() | |||||
| #set(HAVE_PROTOC TRUE CACHE BOOL "protoc build add") | |||||
| # Mark this file as processed; the HAVE_PROTOC check at the top of the file returns early once this is set. | |||||
| # It is a normal (non-cache) variable, so it is visible only in this scope and scopes created below it. | |||||
| set(HAVE_PROTOC TRUE) | |||||
| @@ -1,71 +0,0 @@ | |||||
| # Builds libboundscheck (securec) v1.1.10 through ExternalProject and declares two targets located | |||||
| # under ${CMAKE_INSTALL_PREFIX}/c_sec: the imported shared library `c_sec` and the INTERFACE target | |||||
| # `c_sec_static` that wraps the static archive. | |||||
| # Guard: skip the whole file when HAVE_C_SEC is already set (it is set on the last line below). | |||||
| if (HAVE_C_SEC) | |||||
| return() | |||||
| endif() | |||||
| include(ExternalProject) | |||||
| # When the install prefix equals one of these two literal paths, redirect it to ${GE_CODE_DIR}/output. | |||||
| if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) | |||||
| set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | |||||
| endif() | |||||
| # Archive location: GE_PB_PKG mirror if set, otherwise Gitee. | |||||
| # MD5 is assigned an empty string in both branches and is not passed to ExternalProject_Add below. | |||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/securec/v1.1.10.tar.gz") | |||||
| set(MD5 "") | |||||
| else() | |||||
| set(REQ_URL "https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz") | |||||
| set(MD5 "") | |||||
| endif () | |||||
| # The patch step applies 0001-add-securec-cmake-script.patch from the metadef tree to the unpacked | |||||
| # sources before the CMake configure step runs on them. | |||||
| # The host toolchain (compilers, linker, ar, ranlib) is forwarded explicitly to the nested configure. | |||||
| ExternalProject_Add(c_sec_build | |||||
| URL ${REQ_URL} | |||||
| #URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz | |||||
| #SOURCE_DIR ${GE_CODE_DIR}/../libc_sec | |||||
| PATCH_COMMAND patch -p1 < ${GE_CODE_DIR}/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch | |||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} | |||||
| -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} | |||||
| -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} | |||||
| -DCMAKE_LINKER=${CMAKE_LINKER} | |||||
| -DCMAKE_AR=${CMAKE_AR} | |||||
| -DCMAKE_RANLIB=${CMAKE_RANLIB} | |||||
| -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/c_sec <SOURCE_DIR> | |||||
| BUILD_COMMAND $(MAKE) | |||||
| INSTALL_COMMAND $(MAKE) install | |||||
| EXCLUDE_FROM_ALL TRUE | |||||
| ) | |||||
| set(C_SEC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/c_sec) | |||||
| # Shared-library flavour. | |||||
| add_library(c_sec SHARED IMPORTED) | |||||
| # Create the include directory at configure time, before the external build has produced it. | |||||
| file(MAKE_DIRECTORY ${C_SEC_PKG_DIR}/include) | |||||
| set_target_properties(c_sec PROPERTIES | |||||
| IMPORTED_LOCATION ${C_SEC_PKG_DIR}/lib/libc_sec.so | |||||
| ) | |||||
| target_include_directories(c_sec INTERFACE ${C_SEC_PKG_DIR}/include) | |||||
| add_dependencies(c_sec c_sec_build) | |||||
| set(INSTALL_BASE_DIR "") | |||||
| set(INSTALL_LIBRARY_DIR lib) | |||||
| # OPTIONAL: installation does not fail when the file does not exist. | |||||
| install(FILES ${C_SEC_PKG_DIR}/lib/libc_sec.so OPTIONAL | |||||
| DESTINATION ${INSTALL_LIBRARY_DIR}) | |||||
| # Static flavour: the imported archive is wrapped in an INTERFACE target that carries the include | |||||
| # path and the dependency on the external build. | |||||
| add_library(c_sec_static_lib STATIC IMPORTED) | |||||
| set_target_properties(c_sec_static_lib PROPERTIES | |||||
| IMPORTED_LOCATION ${C_SEC_PKG_DIR}/lib/libc_sec.a | |||||
| ) | |||||
| add_library(c_sec_static INTERFACE) | |||||
| target_include_directories(c_sec_static INTERFACE ${C_SEC_PKG_DIR}/include) | |||||
| target_link_libraries(c_sec_static INTERFACE c_sec_static_lib) | |||||
| add_dependencies(c_sec_static c_sec_build) | |||||
| #set(HAVE_C_SEC TRUE CACHE BOOL "c_sec build add") | |||||
| set(HAVE_C_SEC TRUE) | |||||
| @@ -0,0 +1,349 @@ | |||||
| include(FetchContent) | |||||
| set(FETCHCONTENT_QUIET OFF) | |||||
| # Add a sub-directory and append the object files of one of its targets to a caller-owned list. | |||||
| #   des_submodule_objs - name of the caller's list variable that collects $<TARGET_OBJECTS:...> entries | |||||
| #   sub_dir            - directory passed to add_subdirectory | |||||
| #   submodule_name_obj - target expected to exist after the sub-directory has been added | |||||
| # Stops with FATAL_ERROR when the target is missing or its entry is already in the list. | |||||
| function(graphengine_add_submodule_obj des_submodule_objs sub_dir submodule_name_obj) | |||||
|     add_subdirectory(${sub_dir}) | |||||
|     # Generator expression that stands for the target's object files. | |||||
|     set(__ge_submodule_obj_expr "$<TARGET_OBJECTS:${submodule_name_obj}>") | |||||
|     if(TARGET ${submodule_name_obj}) | |||||
|         if("${__ge_submodule_obj_expr}" IN_LIST ${des_submodule_objs}) | |||||
|             message(FATAL_ERROR "submodule '${submodule_name_obj}' added more than once. in ${CMAKE_CURRENT_LIST_FILE}") | |||||
|         endif() | |||||
|     else() | |||||
|         message(FATAL_ERROR "Can not find submodule '${submodule_name_obj}'. in ${CMAKE_CURRENT_LIST_FILE}") | |||||
|     endif() | |||||
|     # Write the extended list back into the caller's variable. | |||||
|     set(${des_submodule_objs} ${${des_submodule_objs}} ${__ge_submodule_obj_expr} PARENT_SCOPE) | |||||
| endfunction() | |||||
| # Package cache directory shared by the helper functions below; created on first use. | |||||
| get_filename_component(_MS_LIB_CACHE ~/.mslib REALPATH) | |||||
| if (NOT EXISTS ${_MS_LIB_CACHE}) | |||||
|     file(MAKE_DIRECTORY ${_MS_LIB_CACHE}) | |||||
| endif () | |||||
| # set(FETCHCONTENT_BASE_DIR ${_MS_LIB_CACHE}) | |||||
| # set(CMAKE_PREFIX_PATH ${_MS_LIB_CACHE}) | |||||
| # An optional local package server is taken from the MSLIBS_SERVER environment variable. | |||||
| if (DEFINED ENV{MSLIBS_SERVER}) | |||||
|     set(LOCAL_LIBS_SERVER $ENV{MSLIBS_SERVER}) | |||||
|     message("LOCAL_LIBS_SERVER: ${LOCAL_LIBS_SERVER}") | |||||
| endif () | |||||
| # Make sure requests to the local server bypass any proxy by listing it in no_proxy. | |||||
| if(LOCAL_LIBS_SERVER) | |||||
|     # if(ENV{var}) cannot test an environment variable and always evaluates to false, so the value | |||||
|     # is compared explicitly; otherwise an existing no_proxy would be overwritten every time. | |||||
|     if ("$ENV{no_proxy}" STREQUAL "") | |||||
|         set(ENV{no_proxy} "${LOCAL_LIBS_SERVER}") | |||||
|     else() | |||||
|         # Append the server only when it is not part of the existing value. | |||||
|         string(FIND "$ENV{no_proxy}" "${LOCAL_LIBS_SERVER}" IP_POS) | |||||
|         if (${IP_POS} EQUAL -1) | |||||
|             set(ENV{no_proxy} "$ENV{no_proxy},${LOCAL_LIBS_SERVER}") | |||||
|         endif () | |||||
|     endif () | |||||
| endif() | |||||
| # Download and unpack a package archive through FetchContent. | |||||
| #   pkg_name - FetchContent content name; also the prefix of the <pkg_name>_SOURCE_DIR result | |||||
| #   pkg_url  - archive URL | |||||
| #   pkg_md5  - expected MD5 of the archive, passed as URL_HASH | |||||
| # Result: <pkg_name>_SOURCE_DIR is set in the caller's scope. | |||||
| function(__download_pkg pkg_name pkg_url pkg_md5) | |||||
|     if(LOCAL_LIBS_SERVER) | |||||
|         # Put the local server first and keep the original URL as the second entry of the URL list. | |||||
|         get_filename_component(_URL_FILE_NAME ${pkg_url} NAME) | |||||
|         set(pkg_url "http://${LOCAL_LIBS_SERVER}:8081/libs/${pkg_name}/${_URL_FILE_NAME}" ${pkg_url}) | |||||
|     endif() | |||||
|     FetchContent_Declare( | |||||
|         ${pkg_name} | |||||
|         URL ${pkg_url} | |||||
|         URL_HASH MD5=${pkg_md5} | |||||
|     ) | |||||
|     FetchContent_GetProperties(${pkg_name}) | |||||
|     message("download: ${${pkg_name}_SOURCE_DIR} , ${pkg_name} , ${pkg_url}") | |||||
|     # FetchContent_Populate may be called only once per content name, hence the guard. | |||||
|     if(NOT ${pkg_name}_POPULATED) | |||||
|         FetchContent_Populate(${pkg_name}) | |||||
|     endif() | |||||
|     # Export the source directory in every case: when the content was populated by an earlier call, | |||||
|     # FetchContent_GetProperties has already filled it in and the caller still needs it. | |||||
|     set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE) | |||||
| endfunction() | |||||
| # Fetch a package either from a git repository or, when LOCAL_LIBS_SERVER is set, as an archive | |||||
| # from the local server. | |||||
| #   pkg_name       - FetchContent content name; also the prefix of the <pkg_name>_SOURCE_DIR result | |||||
| #   pkg_url        - git repository URL (replaced by the local-server URL when one is configured) | |||||
| #   pkg_git_commit - git tag/commit; on the local server it is the name of the archive to download | |||||
| #   pkg_md5        - expected MD5 of the archive; used only for the local-server download | |||||
| # Result: <pkg_name>_SOURCE_DIR is set in the caller's scope. | |||||
| function(__download_pkg_with_git pkg_name pkg_url pkg_git_commit pkg_md5) | |||||
|     if(LOCAL_LIBS_SERVER) | |||||
|         set(pkg_url "http://${LOCAL_LIBS_SERVER}:8081/libs/${pkg_name}/${pkg_git_commit}") | |||||
|         FetchContent_Declare( | |||||
|             ${pkg_name} | |||||
|             URL ${pkg_url} | |||||
|             URL_HASH MD5=${pkg_md5} | |||||
|         ) | |||||
|     else() | |||||
|         FetchContent_Declare( | |||||
|             ${pkg_name} | |||||
|             GIT_REPOSITORY ${pkg_url} | |||||
|             GIT_TAG ${pkg_git_commit}) | |||||
|     endif() | |||||
|     FetchContent_GetProperties(${pkg_name}) | |||||
|     message("download: ${${pkg_name}_SOURCE_DIR} , ${pkg_name} , ${pkg_url}") | |||||
|     # FetchContent_Populate may be called only once per content name, hence the guard. | |||||
|     if(NOT ${pkg_name}_POPULATED) | |||||
|         FetchContent_Populate(${pkg_name}) | |||||
|     endif() | |||||
|     # Export the source directory in every case: when the content was populated by an earlier call, | |||||
|     # FetchContent_GetProperties has already filled it in and the caller still needs it. | |||||
|     set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE) | |||||
| endfunction() | |||||
| # Look for an already built package under ${<pkg_name>_BASE_DIR} and create imported GLOBAL targets for it. | |||||
| #   pkg_name - package name; targets are named <pkg_name>::<name> | |||||
| #   pkg_exe  - executable to look for in <base>/bin; a false value skips the executable lookup | |||||
| #   ARGN     - library names to look for in <base>/lib | |||||
| # Result: <pkg_name>_LIBS (list of created library targets) is set in the caller's scope. | |||||
| # The function returns early, without setting that result, as soon as the executable or any library is missing. | |||||
| function(__find_pkg_then_add_target pkg_name pkg_exe) | |||||
| unset(${pkg_name}_LIBS) | |||||
| message("_FIND:${${pkg_name}_BASE_DIR}") | |||||
| if(pkg_exe) | |||||
| # NO_DEFAULT_PATH: only the package's own bin directory is searched. | |||||
| find_program(${pkg_exe}_EXE ${pkg_exe} PATHS ${${pkg_name}_BASE_DIR}/bin NO_DEFAULT_PATH) | |||||
| if(NOT ${pkg_exe}_EXE) | |||||
| return() | |||||
| endif() | |||||
| add_executable(${pkg_name}::${pkg_exe} IMPORTED GLOBAL) | |||||
| set_target_properties(${pkg_name}::${pkg_exe} PROPERTIES | |||||
| IMPORTED_LOCATION ${${pkg_exe}_EXE} | |||||
| ) | |||||
| message("found ${${pkg_exe}_EXE}") | |||||
| endif() | |||||
| foreach(_LIB_NAME ${ARGN}) | |||||
| set(_LIB_SEARCH_NAME ${_LIB_NAME}) | |||||
| set(_LIB_TYPE SHARED) | |||||
| # With <pkg_name>_USE_STATIC_LIBS the full static file name (prefix + name + suffix) is searched instead. | |||||
| if (${pkg_name}_USE_STATIC_LIBS) | |||||
| set(_LIB_SEARCH_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}") | |||||
| set(_LIB_TYPE STATIC) | |||||
| endif () | |||||
| # Reset the result variable to a NOTFOUND value before searching (presumably to force a fresh | |||||
| # lookup; verify how this interacts with a cached value from an earlier configure). | |||||
| set(${_LIB_NAME}_LIB ${_LIB_NAME}_LIB-NOTFOUND) | |||||
| find_library(${_LIB_NAME}_LIB ${_LIB_SEARCH_NAME} PATHS ${${pkg_name}_BASE_DIR}/lib NO_DEFAULT_PATH) | |||||
| if(NOT ${_LIB_NAME}_LIB) | |||||
| return() | |||||
| endif() | |||||
| add_library(${pkg_name}::${_LIB_NAME} ${_LIB_TYPE} IMPORTED GLOBAL) | |||||
| set_target_properties(${pkg_name}::${_LIB_NAME} PROPERTIES | |||||
| INTERFACE_INCLUDE_DIRECTORIES "${${pkg_name}_BASE_DIR}/include" | |||||
| IMPORTED_LOCATION ${${_LIB_NAME}_LIB} | |||||
| ) | |||||
| list(APPEND ${pkg_name}_LIBS ${pkg_name}::${_LIB_NAME}) | |||||
| message("found ${${_LIB_NAME}_LIB}") | |||||
| # Strip the last path component to get the directory of the library, and cache it as | |||||
| # <pkg_name>_LIBPATH (cache type STRING, with "INTERNAL" as the help text). | |||||
| STRING( REGEX REPLACE "(.+)/(.+)" "\\1" LIBPATH ${${_LIB_NAME}_LIB}) | |||||
| set(${pkg_name}_LIBPATH ${LIBPATH} CACHE STRING INTERNAL) | |||||
| endforeach(_LIB_NAME) | |||||
| set(${pkg_name}_LIBS ${${pkg_name}_LIBS} PARENT_SCOPE) | |||||
| endfunction() | |||||
| # Run a command at configure time and abort the configure when it does not exit with 0. | |||||
| #   COMMAND <arg>...        - program and arguments to execute | |||||
| #   WORKING_DIRECTORY <dir> - directory to run the command in | |||||
| function(__exec_cmd) | |||||
|     # No option keywords, one single-value keyword, one multi-value keyword. | |||||
|     cmake_parse_arguments(EXEC "" "WORKING_DIRECTORY" "COMMAND" ${ARGN} ) | |||||
|     execute_process( | |||||
|         COMMAND ${EXEC_COMMAND} | |||||
|         WORKING_DIRECTORY ${EXEC_WORKING_DIRECTORY} | |||||
|         RESULT_VARIABLE _exec_status) | |||||
|     if(_exec_status EQUAL "0") | |||||
|         return() | |||||
|     endif() | |||||
|     # Any other result, including a non-numeric error string, is fatal. | |||||
|     message(FATAL_ERROR "error! when ${EXEC_COMMAND} in ${EXEC_WORKING_DIRECTORY}") | |||||
| endfunction() | |||||
| # Detect whether the set of patch files of a package changed since the last configure. | |||||
| # The body does not read its `pkg_patches` parameter; it uses PKG_PATCHES and pkg_name as they are | |||||
| # visible from the calling scope. | |||||
| # The MD5 sums of all patches are joined into one string and compared with the content of | |||||
| # ${_MS_LIB_CACHE}/<pkg_name>_patch.md5. On a mismatch the <pkg_name>-subbuild directory in the cache | |||||
| # is removed and the new string is stored in that file. | |||||
| function(__check_patches pkg_patches) | |||||
|     # check patches | |||||
|     if (NOT PKG_PATCHES) | |||||
|         return() | |||||
|     endif () | |||||
|     # TOUCH creates the record file when it does not exist yet, so the READ below cannot fail on a first run. | |||||
|     file(TOUCH ${_MS_LIB_CACHE}/${pkg_name}_patch.md5) | |||||
|     file(READ ${_MS_LIB_CACHE}/${pkg_name}_patch.md5 ${pkg_name}_PATCHES_MD5) | |||||
|     message("patches md5:${${pkg_name}_PATCHES_MD5}") | |||||
|     # Build ",<md5>,<md5>..." in the order the patches are listed. | |||||
|     set(${pkg_name}_PATCHES_NEW_MD5 ) | |||||
|     foreach(_patch_file ${PKG_PATCHES}) | |||||
|         file(MD5 ${_patch_file} _patch_digest) | |||||
|         string(APPEND ${pkg_name}_PATCHES_NEW_MD5 ",${_patch_digest}") | |||||
|     endforeach() | |||||
|     if (${pkg_name}_PATCHES_MD5 STREQUAL ${pkg_name}_PATCHES_NEW_MD5) | |||||
|         return() | |||||
|     endif () | |||||
|     # The variable set on the next line is local to this function (no PARENT_SCOPE). | |||||
|     set(${pkg_name}_PATCHES ${PKG_PATCHES}) | |||||
|     file(REMOVE_RECURSE "${_MS_LIB_CACHE}/${pkg_name}-subbuild") | |||||
|     file(WRITE ${_MS_LIB_CACHE}/${pkg_name}_patch.md5 ${${pkg_name}_PATCHES_NEW_MD5}) | |||||
|     message("patches changed : ${${pkg_name}_PATCHES_NEW_MD5}") | |||||
| endfunction() | |||||
| # Bundle of find_*() options that switch off the listed CMake, environment and system search locations. | |||||
| set(GE_FIND_NO_DEFAULT_PATH NO_CMAKE_PATH NO_CMAKE_ENVIRONMENT_PATH NO_SYSTEM_ENVIRONMENT_PATH | |||||
| NO_CMAKE_BUILDS_PATH NO_CMAKE_PACKAGE_REGISTRY NO_CMAKE_SYSTEM_PATH | |||||
| NO_CMAKE_SYSTEM_PACKAGE_REGISTRY) | |||||
| # Also publish the list one scope up. | |||||
| # NOTE(review): PARENT_SCOPE needs an enclosing scope; confirm this file is never processed in the | |||||
| # top-level scope, where there is no parent to set the variable in. | |||||
| set(GE_FIND_NO_DEFAULT_PATH ${GE_FIND_NO_DEFAULT_PATH} PARENT_SCOPE) | |||||
# graphengine_add_pkg(<pkg_name>
#     [URL <url>] [MD5 <md5>] [GIT_REPOSITORY <repo>] [GIT_TAG <tag>] [VER <version>]
#     [EXE <name>] [DIR <local source dir>] [HEAD_ONLY <include subdir>]
#     [CMAKE_OPTION ...] [LIBS ...] [PRE_CONFIGURE_COMMAND ...] [CONFIGURE_COMMAND ...]
#     [BUILD_OPTION ...] [INSTALL_INCS ...] [INSTALL_LIBS ...] [PATCHES ...])
#
# Makes a third-party package available from the local cache ${_MS_LIB_CACHE}. The install
# prefix is ${_MS_LIB_CACHE}/<pkg>_<hash>, where <hash> is the MD5 of the compiler versions,
# all arguments, the per-package flags and the MD5 of every patch file. If the package is
# not already usable from that prefix, its sources are obtained (download, git, or DIR),
# patched, built (header copy / CMake / configure+make) and installed there.
#
# Variables set in the caller's scope:
#   <PkgName>_ROOT  - install prefix (original spelling of the name)
#   <pkg>_INC       - include directory (lower-case name)
# Cache variable: <pkg>_DIRPATH - install prefix.
#
# The helpers __find_pkg_then_add_target, __download_pkg, __download_pkg_with_git and
# __exec_cmd are defined elsewhere in this file.
function(graphengine_add_pkg pkg_name)
    set(options)
    set(oneValueArgs URL MD5 GIT_REPOSITORY GIT_TAG VER EXE DIR HEAD_ONLY)
    set(multiValueArgs CMAKE_OPTION LIBS PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION
            INSTALL_INCS INSTALL_LIBS PATCHES)
    cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    # find_package() needs the caller's spelling; everything else uses the lower-case name.
    set(__FIND_PKG_NAME ${pkg_name})
    string(TOLOWER ${pkg_name} pkg_name)
    message("pkg name:${__FIND_PKG_NAME},${pkg_name}")

    # Fold the content hash of every patch into the configuration hash, so that editing a
    # patch selects a new cache directory.
    set(${pkg_name}_PATCHES_HASH)
    foreach(_PATCH ${PKG_PATCHES})
        file(MD5 ${_PATCH} _PF_MD5)
        set(${pkg_name}_PATCHES_HASH "${${pkg_name}_PATCHES_HASH},${_PF_MD5}")
    endforeach(_PATCH)

    # check options
    # NOTE: the text below (including its line breaks) is hashed; do not re-indent the
    # continuation lines, or every existing cache directory name changes.
    set(${pkg_name}_CONFIG_TXT
            "${CMAKE_CXX_COMPILER_VERSION}-${CMAKE_C_COMPILER_VERSION}
${ARGN} - ${${pkg_name}_USE_STATIC_LIBS}- ${${pkg_name}_PATCHES_HASH}
${${pkg_name}_CXXFLAGS}--${${pkg_name}_CFLAGS}--${${pkg_name}_LDFLAGS}")
    string(REPLACE ";" "-" ${pkg_name}_CONFIG_TXT ${${pkg_name}_CONFIG_TXT})
    string(MD5 ${pkg_name}_CONFIG_HASH ${${pkg_name}_CONFIG_TXT})

    message("${pkg_name} config hash: ${${pkg_name}_CONFIG_HASH}")

    set(${pkg_name}_BASE_DIR ${_MS_LIB_CACHE}/${pkg_name}_${${pkg_name}_CONFIG_HASH})
    # INTERNAL implies FORCE: the cached path must follow the hash. A plain
    # "CACHE STRING" entry would keep the value of the first configure run forever.
    set(${pkg_name}_DIRPATH ${${pkg_name}_BASE_DIR} CACHE INTERNAL "install prefix of ${pkg_name}")

    # Header-only package that was already installed: just create the interface target.
    if(EXISTS ${${pkg_name}_BASE_DIR}/options.txt AND PKG_HEAD_ONLY)
        # set(... PARENT_SCOPE) does not touch the local scope, so define the variable
        # locally first; otherwise the target below would get an empty include list.
        set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY})
        set(${pkg_name}_INC ${${pkg_name}_INC} PARENT_SCOPE)
        add_library(${pkg_name} INTERFACE)
        target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC})
        return()
    endif()

    if(NOT PKG_EXE)
        set(PKG_EXE 0)
    endif()

    # <PkgName>_ROOT is needed both here (for the find calls below) and by the caller.
    set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR})
    set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR} PARENT_SCOPE)

    # First attempt: the package may already be installed in the cache directory.
    if(PKG_LIBS)
        __find_pkg_then_add_target(${pkg_name} ${PKG_EXE} ${PKG_LIBS})
        if(${pkg_name}_LIBS)
            set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE)
            message("Found libs: ${${pkg_name}_LIBS}")
            return()
        endif()
    elseif(NOT PKG_HEAD_ONLY)
        find_package(${__FIND_PKG_NAME} ${PKG_VER} NO_CMAKE_SYSTEM_PATH NO_SYSTEM_ENVIRONMENT_PATH)
        if(${__FIND_PKG_NAME}_FOUND)
            set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE)
            message("Found pkg: ${__FIND_PKG_NAME}")
            return()
        endif()
    endif()

    # Obtain the sources: an explicit DIR wins, otherwise fetch via git or URL.
    if(NOT PKG_DIR)
        if(PKG_GIT_REPOSITORY)
            __download_pkg_with_git(${pkg_name} ${PKG_GIT_REPOSITORY} ${PKG_GIT_TAG} ${PKG_MD5})
        else()
            __download_pkg(${pkg_name} ${PKG_URL} ${PKG_MD5})
        endif()
    else()
        set(${pkg_name}_SOURCE_DIR ${PKG_DIR})
    endif()
    # Record the configuration text that produced this cache directory.
    file(WRITE ${${pkg_name}_BASE_DIR}/options.txt ${${pkg_name}_CONFIG_TXT})
    message("${pkg_name}_SOURCE_DIR : ${${pkg_name}_SOURCE_DIR}")

    # Apply every patch from the source root; a failing patch aborts the configure run.
    foreach(_PATCH_FILE ${PKG_PATCHES})
        message("patching ${${pkg_name}_SOURCE_DIR} -p1 < ${_PATCH_FILE}")
        execute_process(COMMAND patch -p1 INPUT_FILE ${_PATCH_FILE}
                WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}
                RESULT_VARIABLE Result)
        if(NOT Result EQUAL "0")
            message(FATAL_ERROR "Failed patch: ${_PATCH_FILE}")
        endif()
    endforeach(_PATCH_FILE)

    # Hold a directory lock on the install prefix until this function returns
    # (GUARD FUNCTION); give up after 600 seconds.
    file(LOCK ${${pkg_name}_BASE_DIR} DIRECTORY GUARD FUNCTION RESULT_VARIABLE ${pkg_name}_LOCK_RET TIMEOUT 600)
    if(NOT ${pkg_name}_LOCK_RET EQUAL "0")
        message(FATAL_ERROR "error! when try lock ${${pkg_name}_BASE_DIR} : ${${pkg_name}_LOCK_RET}")
    endif()

    if(${pkg_name}_SOURCE_DIR)
        if(PKG_HEAD_ONLY)
            # Header-only: copy the source tree into the prefix and expose an interface target.
            file(GLOB ${pkg_name}_SOURCE_SUBDIRS ${${pkg_name}_SOURCE_DIR}/*)
            file(COPY ${${pkg_name}_SOURCE_SUBDIRS} DESTINATION ${${pkg_name}_BASE_DIR})
            # Local definition first, see the comment on the early-return branch above.
            set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY})
            set(${pkg_name}_INC ${${pkg_name}_INC} PARENT_SCOPE)
            add_library(${pkg_name} INTERFACE)
            target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC})

        elseif(PKG_CMAKE_OPTION)
            # in cmake
            file(MAKE_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build)
            if(${pkg_name}_CFLAGS)
                set(${pkg_name}_CMAKE_CFLAGS "-DCMAKE_C_FLAGS=${${pkg_name}_CFLAGS}")
            endif()
            if(${pkg_name}_CXXFLAGS)
                set(${pkg_name}_CMAKE_CXXFLAGS "-DCMAKE_CXX_FLAGS=${${pkg_name}_CXXFLAGS}")
            endif()

            # Linker flags are only forwarded for shared builds.
            if(${pkg_name}_LDFLAGS)
                if(${pkg_name}_USE_STATIC_LIBS)
                    #set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_STATIC_LINKER_FLAGS=${${pkg_name}_LDFLAGS}")
                else()
                    set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_SHARED_LINKER_FLAGS=${${pkg_name}_LDFLAGS}")
                endif()
            endif()

            __exec_cmd(COMMAND ${CMAKE_COMMAND} ${PKG_CMAKE_OPTION} -G ${CMAKE_GENERATOR}
                    ${${pkg_name}_CMAKE_CFLAGS} ${${pkg_name}_CMAKE_CXXFLAGS} ${${pkg_name}_CMAKE_LDFLAGS}
                    -DCMAKE_INSTALL_PREFIX=${${pkg_name}_BASE_DIR} ..
                    WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build)

            __exec_cmd(COMMAND ${CMAKE_COMMAND} --build . --target install -- -j8
                    WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build)

        else()
            if(${pkg_name}_CFLAGS)
                set(${pkg_name}_MAKE_CFLAGS "CFLAGS=${${pkg_name}_CFLAGS}")
            endif()
            if(${pkg_name}_CXXFLAGS)
                set(${pkg_name}_MAKE_CXXFLAGS "CXXFLAGS=${${pkg_name}_CXXFLAGS}")
            endif()
            if(${pkg_name}_LDFLAGS)
                set(${pkg_name}_MAKE_LDFLAGS "LDFLAGS=${${pkg_name}_LDFLAGS}")
            endif()
            # in configure && make
            if(PKG_PRE_CONFIGURE_COMMAND)
                __exec_cmd(COMMAND ${PKG_PRE_CONFIGURE_COMMAND}
                        WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR})
            endif()

            if(PKG_CONFIGURE_COMMAND)
                __exec_cmd(COMMAND ${PKG_CONFIGURE_COMMAND}
                        ${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS}
                        --prefix=${${pkg_name}_BASE_DIR}
                        WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR})
            endif()
            # Without a configure step the flags have to be handed to make directly.
            set(${pkg_name}_BUILD_OPTION ${PKG_BUILD_OPTION})
            if(NOT PKG_CONFIGURE_COMMAND)
                set(${pkg_name}_BUILD_OPTION ${${pkg_name}_BUILD_OPTION}
                        ${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS})
            endif()
            # build
            __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} ${${pkg_name}_BUILD_OPTION} -j8
                    WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR})

            # Either copy the listed artifacts by hand or rely on the project's install target.
            if(PKG_INSTALL_INCS OR PKG_INSTALL_LIBS)
                file(GLOB ${pkg_name}_INSTALL_INCS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_INCS})
                file(GLOB ${pkg_name}_INSTALL_LIBS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_LIBS})
                file(COPY ${${pkg_name}_INSTALL_INCS} DESTINATION ${${pkg_name}_BASE_DIR}/include)
                file(COPY ${${pkg_name}_INSTALL_LIBS} DESTINATION ${${pkg_name}_BASE_DIR}/lib)
            else()
                __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} install WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR})
            endif()
        endif()
    endif()

    # Second attempt: the freshly installed package must now be discoverable.
    if(PKG_LIBS)
        __find_pkg_then_add_target(${pkg_name} ${PKG_EXE} ${PKG_LIBS})
        set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE)
        if(NOT ${pkg_name}_LIBS)
            message(FATAL_ERROR "Can not find pkg: ${pkg_name}")
        endif()
    else()
        find_package(${__FIND_PKG_NAME} ${PKG_VER} QUIET)
        if(${__FIND_PKG_NAME}_FOUND)
            set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE)
            message("Found pkg: ${${__FIND_PKG_NAME}_LIBRARIES}")
            return()
        endif()
    endif()
endfunction()
| @@ -1,52 +0,0 @@ | |||||
# intf_pub: interface target that carries the common compile/link settings.
# Nothing is compiled here; targets that link against intf_pub inherit the settings.
add_library(intf_pub INTERFACE)

# Warnings, position-independent code and stack protection for every consumer.
target_compile_options(intf_pub INTERFACE
    -Wall
    -fPIC
    -fstack-protector-strong
)

# NOTE(review): WIN64=1 / LINUX=0 are defined next to ELF-style linker flags below —
# confirm which platform this variant of the file is meant for.
target_compile_definitions(intf_pub INTERFACE
    $<$<STREQUAL:${PRODUCT_SIDE},host>:_GLIBCXX_USE_CXX11_ABI=0>
    $<$<CONFIG:Release>:CFG_BUILD_NDEBUG>
    $<$<CONFIG:Debug>:CFG_BUILD_DEBUG>
    WIN64=1
    LINUX=0
)

# Hardening flags; the build id is dropped in Release builds only.
target_link_options(intf_pub INTERFACE
    -Wl,-z,relro
    -Wl,-z,now
    -Wl,-z,noexecstack
    $<$<CONFIG:Release>:-Wl,--build-id=none>
)

# Intentionally empty: no extra link directories are exported.
target_link_directories(intf_pub INTERFACE
)
# intf_ccec: interface target with the cross-compilation settings
# (aarch64-linux-android29, Cortex-A73) rooted at ${HCC_PATH}.
add_library(intf_ccec INTERFACE)

# Target/sysroot selection followed by the same warning and hardening flags as intf_pub.
target_compile_options(intf_ccec INTERFACE
    -mcpu=cortex-a73
    --target=aarch64-linux-android29
    --sysroot=${HCC_PATH}/../sysroot
    -L${HCC_PATH}/../lib/gcc/aarch64-linux-android/4.9.x
    -Wall
    -fPIC
    -fstack-protector-strong
)

target_compile_definitions(intf_ccec INTERFACE
    $<$<STREQUAL:${PRODUCT_SIDE},host>:_GLIBCXX_USE_CXX11_ABI=0>
    $<$<CONFIG:Release>:CFG_BUILD_NDEBUG>
    $<$<CONFIG:Debug>:CFG_BUILD_DEBUG>
)

# The link step repeats the target/sysroot options and adds the hardening flags.
target_link_options(intf_ccec INTERFACE
    -mcpu=cortex-a73
    --target=aarch64-linux-android29
    --sysroot=${HCC_PATH}/../sysroot
    -L${HCC_PATH}/../lib/gcc/aarch64-linux-android/4.9.x
    -Wl,-cce-host-android
    -Wl,-z,relro
    -Wl,-z,now
    -Wl,-z,noexecstack
    $<$<CONFIG:Release>:-Wl,--build-id=none>
)
| @@ -1,34 +0,0 @@ | |||||
# Guard against processing this file twice: intf_pub may only be created once.
if(HAVE_PUB)
    return()
endif()

# intf_pub: interface target that carries the common compile/link settings.
add_library(intf_pub INTERFACE)

# -fstack-protector-all is selected when CMAKE_SYSTEM_NAME equals "centos",
# -fstack-protector-strong otherwise; C++ sources are built as C++11.
target_compile_options(intf_pub INTERFACE
    -Wall
    -fPIC
    $<IF:$<STREQUAL:${CMAKE_SYSTEM_NAME},centos>,-fstack-protector-all,-fstack-protector-strong>
    $<$<COMPILE_LANGUAGE:CXX>:-std=c++11>
)

target_compile_definitions(intf_pub INTERFACE
    _GLIBCXX_USE_CXX11_ABI=0
    $<$<CONFIG:Release>:CFG_BUILD_NDEBUG>
    $<$<CONFIG:Debug>:CFG_BUILD_DEBUG>
    WIN64=1
    LINUX=0
    LOG_CPP
)

# Hardening flags; the build id is dropped in Release builds only.
target_link_options(intf_pub INTERFACE
    -Wl,-z,relro
    -Wl,-z,now
    -Wl,-z,noexecstack
    $<$<CONFIG:Release>:-Wl,--build-id=none>
)

# Intentionally empty: no extra link directories are exported.
target_link_directories(intf_pub INTERFACE
)

target_link_libraries(intf_pub INTERFACE
    -lpthread
)

# The guard is a normal (non-cache) variable, so it is only visible in this scope
# and the scopes below it.
#set(HAVE_PUB TRUE CACHE BOOL "pub add")
set(HAVE_PUB TRUE)
| @@ -1,24 +0,0 @@ | |||||
# intf_pub: interface target that carries the common compile/link settings.
add_library(intf_pub INTERFACE)

# -fstack-protector-all is selected when OS_TYPE equals "centos",
# -fstack-protector-strong otherwise; C++ sources are built as C++11.
target_compile_options(intf_pub INTERFACE
    -Wall
    -fPIC
    $<IF:$<STREQUAL:${OS_TYPE},centos>,-fstack-protector-all,-fstack-protector-strong>
    $<$<COMPILE_LANGUAGE:CXX>:-std=c++11>
)

# The old libstdc++ ABI is requested only when PRODUCT_SIDE equals "host".
target_compile_definitions(intf_pub INTERFACE
    $<$<STREQUAL:${PRODUCT_SIDE},host>:_GLIBCXX_USE_CXX11_ABI=0>
    OS_TYPE=WIN64
    WIN64=1
    LINUX=0
    $<$<CONFIG:Release>:CFG_BUILD_NDEBUG>
    $<$<CONFIG:Debug>:CFG_BUILD_DEBUG>
)

target_link_options(intf_pub INTERFACE
    $<$<CONFIG:Release>:-Wl,--build-id=none>
)

# Intentionally empty: no link directories or libraries are exported.
target_link_directories(intf_pub INTERFACE
)

target_link_libraries(intf_pub INTERFACE
)
| @@ -1,340 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "analyzer.h" | |||||
| #include <cstdlib> | |||||
| #include <cstdio> | |||||
| #include <iostream> | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/util.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| using json = nlohmann::json; | |||||
| using Status = ge::Status; | |||||
| using ComputeGraph = ge::ComputeGraph; | |||||
| using namespace analyzer; | |||||
| namespace { | |||||
| constexpr int kFileAuthority = 0640; | |||||
| constexpr int kJsonDumpLevel = 4; | |||||
| const std::string kFilePath = "./"; | |||||
| const std::string kAnalyzeFile = "ge_check_op.json"; | |||||
| const std::string kUnknownShape = "unknownshape"; | |||||
| const std::string kUnsupport = "unsupport"; | |||||
| const std::string kSessionId = "session_id"; | |||||
| const std::string kGraphId = "graph_id"; | |||||
| const std::string kOpInfo = "op_info"; | |||||
| const std::string kErrorType = "error_type"; | |||||
| const std::string kOpName = "name"; | |||||
| const std::string kOpType = "type"; | |||||
| const std::string kReason = "reason"; | |||||
| const std::string kInput = "input"; | |||||
| const std::string kOutput = "output"; | |||||
| const std::string kShape = "shape"; | |||||
| const std::string kDataType = "data_type"; | |||||
| const std::string kLayout = "layout"; | |||||
| const std::string kResult = "result"; | |||||
| const std::string kOp = "op"; | |||||
| std::map<analyzer::AnalyzeType, std::string> errors_map { | |||||
| {PARSER, "paser_error"}, | |||||
| {INFER_SHAPE, "infer_shape_error"}, | |||||
| {CHECKSUPPORT, "check_support_error"}, | |||||
| {GRAPH_OPTIMIZE, "graph_optimize_error"}, | |||||
| {GRAPH_PARTION, "graph_partion_error"}, | |||||
| {GRAPH_BUILDER, "graph_builder_error"} | |||||
| }; | |||||
| } | |||||
| Analyzer* Analyzer::GetInstance() { | |||||
| static Analyzer instance; | |||||
| return &instance; | |||||
| } | |||||
| Status Analyzer::BuildJsonObject(uint64_t session_id, uint64_t graph_id) { | |||||
| GELOGD("Start to build map. SessionId:%lu GraphId:%lu", session_id, graph_id); | |||||
| std::lock_guard<std::recursive_mutex> lg(mutex_); | |||||
| auto iter = graph_infos_.find(session_id); | |||||
| if (iter == graph_infos_.end()) { | |||||
| std::shared_ptr<GraphInfo> graph_info(new(std::nothrow) GraphInfo()); | |||||
| GE_CHECK_NOTNULL(graph_info); | |||||
| std::map<uint64_t, std::shared_ptr<GraphInfo>> graph_map; | |||||
| graph_map[graph_id] = graph_info; | |||||
| graph_info->session_id = session_id; | |||||
| graph_info->graph_id = graph_id; | |||||
| graph_infos_.insert({session_id, graph_map}); | |||||
| } else { | |||||
| auto iter1 = (iter->second).find(graph_id); | |||||
| if (iter1 == (iter->second).end()) { | |||||
| std::shared_ptr<GraphInfo> graph_info(new(std::nothrow) GraphInfo()); | |||||
| GE_CHECK_NOTNULL(graph_info); | |||||
| graph_info->session_id = session_id; | |||||
| graph_info->graph_id = graph_id; | |||||
| (iter->second).insert({graph_id, graph_info}); | |||||
| } else { | |||||
| GELOGI("session_id:%lu graph_id:%lu already existed json object", session_id, graph_id); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| ge::Status Analyzer::Initialize() { | |||||
| // Initialize file | |||||
| string real_path = RealPath(kFilePath.c_str()); | |||||
| if (real_path.empty()) { | |||||
| GELOGE(FAILED, "[Check][AnalyzeFilePath]File path is empty, Path invalid."); | |||||
| REPORT_CALL_ERROR("E19999", "Analyze file path check invalid, it is empty"); | |||||
| return FAILED; | |||||
| } | |||||
| json_file_name_ = real_path + "/" + kAnalyzeFile; | |||||
| return SUCCESS; | |||||
| } | |||||
| void Analyzer::Finalize() { | |||||
| GELOGD("Analyzer start to finalize!"); | |||||
| std::lock_guard<std::recursive_mutex> lg(mutex_); | |||||
| for (auto &session_resource : graph_infos_) { | |||||
| session_resource.second.clear(); | |||||
| } | |||||
| graph_infos_.clear(); | |||||
| std::lock_guard<std::mutex> lk(file_mutex_); | |||||
| if (json_file_.is_open()) { | |||||
| json_file_.close(); | |||||
| } | |||||
| } | |||||
| void Analyzer::DestroySessionJsonObject(uint64_t session_id) { | |||||
| std::lock_guard<std::recursive_mutex> lg(mutex_); | |||||
| auto iter = graph_infos_.find(session_id); | |||||
| if (iter == graph_infos_.end()) { | |||||
| GELOGW("can not find the stored object by session_id[%lu].Do nothing", session_id); | |||||
| } else { | |||||
| graph_infos_.erase(iter); | |||||
| } | |||||
| } | |||||
| void Analyzer::DestroyGraphJsonObject(uint64_t session_id, uint64_t graph_id) { | |||||
| std::lock_guard<std::recursive_mutex> lg(mutex_); | |||||
| auto iter = graph_infos_.find(session_id); | |||||
| if (iter == graph_infos_.end()) { | |||||
| GELOGW("can not find the stored object by session_id[%lu].Do nothing", session_id); | |||||
| } else { | |||||
| auto iter1 = (iter->second).find(graph_id); | |||||
| if (iter1 == (iter->second).end()) { | |||||
| GELOGW("Can not find the graph json object by session_id[%lu] and graph_id[%lu]. Do nothing.", session_id, | |||||
| graph_id); | |||||
| return; | |||||
| } | |||||
| (iter->second).erase(iter1); | |||||
| } | |||||
| } | |||||
| std::shared_ptr<GraphInfo> Analyzer::GetJsonObject(uint64_t session_id, uint64_t graph_id) { | |||||
| std::lock_guard<std::recursive_mutex> lg(mutex_); | |||||
| auto iter = graph_infos_.find(session_id); | |||||
| if (iter == graph_infos_.end()) { | |||||
| GELOGE(PARAM_INVALID, "[Check][SessionId]session_id:%lu does not exist! " | |||||
| "graph_id:%lu", session_id, graph_id); | |||||
| REPORT_INNER_ERROR("E19999", "Sessin_id %lu does not exist, graph_id %lu", | |||||
| session_id, graph_id); | |||||
| return nullptr; | |||||
| } else { | |||||
| auto iter1 = (iter->second).find(graph_id); | |||||
| if (iter1 == (iter->second).end()) { | |||||
| GELOGE(PARAM_INVALID, "[Check][GraphId]graph_id:%lu does not exist! " | |||||
| "session_id:%lu.", graph_id, session_id); | |||||
| REPORT_INNER_ERROR("E19999", "Graph_id %lu does not exist, session_id %lu", | |||||
| graph_id, session_id); | |||||
| return nullptr; | |||||
| } | |||||
| GELOGI("GetJsonObject Success!session_id:%lu graph_id:%lu", session_id, graph_id); | |||||
| return iter1->second; | |||||
| } | |||||
| } | |||||
| void Analyzer::ClearHistoryFile() { | |||||
| GELOGD("Analyzer start to clear history file!"); | |||||
| // Remove history files | |||||
| int res = remove(json_file_name_.c_str()); | |||||
| GELOGD("remove file %s, result:%d", json_file_name_.c_str(), res); | |||||
| } | |||||
| ge::Status Analyzer::CreateAnalyzerFile() { | |||||
| if (is_json_file_create_) { | |||||
| GELOGD("analyzer file has been created!No necessary to create again!"); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGD("start to create analyzer file!"); | |||||
| std::lock_guard<std::mutex> lg(file_mutex_); | |||||
| int fd = open(json_file_name_.c_str(), O_WRONLY | O_CREAT | O_TRUNC, kFileAuthority); | |||||
| if (fd < 0) { | |||||
| GELOGE(INTERNAL_ERROR, "[FileOpen][AnalyzeFile]Fail to open the analyze file: %s.", | |||||
| json_file_name_.c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Failed to open analyze file %s", json_file_name_.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (close(fd) != 0) { | |||||
| GELOGE(INTERNAL_ERROR, "[FileClose][AnalyzeFile]Fail to close the analyze file: %s.", | |||||
| json_file_name_.c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Failed to clsoe analyze file %s", json_file_name_.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| is_json_file_create_ = true; | |||||
| GELOGD("success to create analyzer file[%s]!", json_file_name_.c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| ge::Status Analyzer::SaveAnalyzerDataToFile(uint64_t session_id, uint64_t graph_id) { | |||||
| GELOGD("start to save analyze file"); | |||||
| auto graph_info = GetJsonObject(session_id, graph_id); | |||||
| GE_CHECK_NOTNULL(graph_info); | |||||
| if (graph_info->op_info.size() == 0) { | |||||
| GELOGD("session_id:%lu graph_id:%lu does not owner op info, break it!", session_id, graph_id); | |||||
| return SUCCESS; | |||||
| } | |||||
| std::lock_guard<std::mutex> lg(file_mutex_); | |||||
| json_file_.open(json_file_name_, std::ios::app); | |||||
| if (!json_file_.is_open()) { | |||||
| GELOGE(FAILED, "[Check][AnalyzeFile]analyze file does not exist[%s]", | |||||
| json_file_name_.c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Analyze file %s dose not exist", json_file_name_.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| json jsn; | |||||
| GraphInfoToJson(jsn, *graph_info); | |||||
| bool ret_failed = false; | |||||
| try { | |||||
| json_file_ << jsn.dump(kJsonDumpLevel) << std::endl; | |||||
| } catch (nlohmann::detail::type_error &e) { | |||||
| GELOGE(FAILED, | |||||
| "[Json.dump][GraphInfo]Dump analyze file [%s] failed because [%s]," | |||||
| "session_id:%lu, graph_id:%lu", | |||||
| json_file_name_.c_str(), e.what(), session_id, graph_id); | |||||
| REPORT_INNER_ERROR("E19999", "Dump analyze file %s failed because %s, " | |||||
| "session_id %lu, graph_id %lu", | |||||
| json_file_name_.c_str(), e.what(), session_id, graph_id); | |||||
| ret_failed = true; | |||||
| } | |||||
| json_file_.close(); | |||||
| return ret_failed ? FAILED : SUCCESS; | |||||
| } | |||||
| ge::Status Analyzer::DoAnalyze(DataInfo &data_info) { | |||||
| GELOGD("start to do analyzer process"); | |||||
| auto pnode = data_info.node_ptr; | |||||
| GE_CHECK_NOTNULL(pnode); | |||||
| auto desc = pnode->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(desc); | |||||
| // buff analyze data | |||||
| std::lock_guard<std::recursive_mutex> lg(mutex_); | |||||
| auto graph_info = GetJsonObject(data_info.session_id, data_info.graph_id); | |||||
| GE_CHECK_NOTNULL(graph_info); | |||||
| auto status = SaveOpInfo(desc, data_info, graph_info); | |||||
| if (status != SUCCESS) { | |||||
| GELOGE(status, | |||||
| "[Check][SaveOpInfo]save op info: desc_name [%s] desc_type [%s] failed!", | |||||
| desc->GetName().c_str(), desc->GetType().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Save op info: desc_name %s, desc_type %s failed", | |||||
| desc->GetName().c_str(), desc->GetType().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // create json file | |||||
| return CreateAnalyzerFile(); | |||||
| } | |||||
| ge::Status Analyzer::SaveOpInfo(ge::OpDescPtr desc, DataInfo &data_info, | |||||
| std::shared_ptr<analyzer::GraphInfo> graph_info) { | |||||
| auto iter = errors_map.find(data_info.analyze_type); | |||||
| if (iter == errors_map.end()) { | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| OpInfo op_info; | |||||
| op_info.error_type = iter->second; | |||||
| op_info.op_name = desc->GetName(); | |||||
| op_info.op_type = desc->GetType(); | |||||
| op_info.reason = data_info.reason; | |||||
| for (const auto &ptr : desc->GetAllInputsDescPtr()) { | |||||
| TensorInfo tensor_info; | |||||
| tensor_info.shape = ptr->GetShape().GetDims(); | |||||
| tensor_info.d_type = ge::TypeUtils::DataTypeToSerialString(ptr->GetDataType()); | |||||
| tensor_info.layout = ge::TypeUtils::FormatToSerialString(ptr->GetFormat()); | |||||
| op_info.input_info.emplace_back(tensor_info); | |||||
| } | |||||
| for (const auto &ptr : desc->GetAllOutputsDescPtr()) { | |||||
| TensorInfo tensor_info; | |||||
| tensor_info.shape = ptr->GetShape().GetDims(); | |||||
| tensor_info.d_type = ge::TypeUtils::DataTypeToSerialString(ptr->GetDataType()); | |||||
| tensor_info.layout = ge::TypeUtils::FormatToSerialString(ptr->GetFormat()); | |||||
| op_info.output_info.emplace_back(tensor_info); | |||||
| } | |||||
| graph_info->op_info.emplace_back(op_info); | |||||
| return SUCCESS; | |||||
| } | |||||
| void Analyzer::TensorInfoToJson(json& j, const TensorInfo &tensor_info) { | |||||
| j[kShape] = tensor_info.shape; | |||||
| j[kDataType] = tensor_info.d_type; | |||||
| j[kLayout] = tensor_info.layout; | |||||
| } | |||||
| void Analyzer::OpInfoToJson(json& j, const OpInfo &op_info) { | |||||
| j[kErrorType] = op_info.error_type; | |||||
| j[kOpName] = op_info.op_name; | |||||
| j[kOpType] = op_info.op_type; | |||||
| j[kReason] = op_info.reason; | |||||
| for (size_t i = 0; i < op_info.input_info.size(); i++) { | |||||
| json json_tensor_info; | |||||
| TensorInfoToJson(json_tensor_info, op_info.input_info.at(i)); | |||||
| j[kInput + std::to_string(i)] = json_tensor_info; | |||||
| } | |||||
| for (size_t i = 0; i < op_info.output_info.size(); i++) { | |||||
| json json_tensor_info; | |||||
| TensorInfoToJson(json_tensor_info, op_info.output_info.at(i)); | |||||
| j[kOutput + std::to_string(i)] = json_tensor_info; | |||||
| } | |||||
| } | |||||
| void Analyzer::GraphInfoToJson(json& j, const GraphInfo &graph_info) { | |||||
| GELOGD("start to buff graph info!"); | |||||
| j[kSessionId] = graph_info.session_id; | |||||
| j[kGraphId] = graph_info.graph_id; | |||||
| std::vector<json> json_op_infos; | |||||
| for (size_t i = 0; i < graph_info.op_info.size(); i++) { | |||||
| json json_op_info; | |||||
| OpInfoToJson(json_op_info, graph_info.op_info.at(i)); | |||||
| json_op_infos.emplace_back(json_op_info); | |||||
| } | |||||
| j[kOp] = json_op_infos; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,195 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DOMI_ANALYZER_ANANLYZER_H_ | |||||
| #define DOMI_ANALYZER_ANANLYZER_H_ | |||||
#include <atomic>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

#include "nlohmann/json.hpp"
#include "external/ge/ge_api_types.h"
#include "graph/compute_graph.h"
#include "graph/node.h"
| namespace ge { | |||||
| namespace analyzer { | |||||
| enum AnalyzeType { | |||||
| PARSER = 0, | |||||
| INFER_SHAPE = 1, | |||||
| CHECKSUPPORT = 2, | |||||
| GRAPH_OPTIMIZE = 3, | |||||
| GRAPH_PARTION = 4, | |||||
| GRAPH_BUILDER = 5, | |||||
| }; | |||||
| struct TensorInfo { | |||||
| vector<int64_t> shape; | |||||
| string d_type; | |||||
| string layout; | |||||
| }; | |||||
| struct OpInfo { | |||||
| string error_type; | |||||
| string op_name; | |||||
| string op_type; | |||||
| std::vector<TensorInfo> input_info; | |||||
| std::vector<TensorInfo> output_info; | |||||
| string reason; | |||||
| }; | |||||
| struct GraphInfo { | |||||
| uint64_t session_id = 0; | |||||
| uint64_t graph_id = 0; | |||||
| std::vector<OpInfo> op_info; | |||||
| }; | |||||
| struct DataInfo { | |||||
| DataInfo() = default; | |||||
| ~DataInfo() = default; | |||||
| DataInfo(uint64_t sess, uint64_t graph, AnalyzeType type, | |||||
| ge::NodePtr node, std::string error_info) { | |||||
| session_id = sess; | |||||
| graph_id = graph; | |||||
| analyze_type = type; | |||||
| node_ptr = node; | |||||
| reason = error_info; | |||||
| } | |||||
| uint64_t session_id; | |||||
| uint64_t graph_id; | |||||
| AnalyzeType analyze_type; | |||||
| ge::NodePtr node_ptr{nullptr}; | |||||
| std::string reason; | |||||
| }; | |||||
| } | |||||
| class Analyzer { | |||||
| public: | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: get analyzer instance. | |||||
| * @param [in]: None | |||||
| * @return: Analyzer instance ptr | |||||
| */ | |||||
| static Analyzer *GetInstance(); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: check whether env var ENABLE_NETWORK_ANALYSIS_DEBUG is enabled. | |||||
| * When enable env, it will keep adaptor sink geop graph even though fail. | |||||
| * @param [in]: None | |||||
| * @return: true: enable env false : disable env | |||||
| */ | |||||
| bool IsEnableNetAnalyzeDebug() { return std::getenv("ENABLE_NETWORK_ANALYSIS_DEBUG") != nullptr; } | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: build buff object by sess id and graph id . | |||||
| * @param [in]: session id & graph id | |||||
| * @return: 0: success other: failed | |||||
| */ | |||||
| ge::Status BuildJsonObject(uint64_t session_id, uint64_t graph_id); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: get buff object by sess id and graph id . | |||||
| * @param [in]: session id & graph id | |||||
| * @return: nullptr if failed | |||||
| */ | |||||
| std::shared_ptr<analyzer::GraphInfo> GetJsonObject(uint64_t session_id, uint64_t graph_id); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: analyzer globle init method. | |||||
| * @param [in]: None | |||||
| * @return: None | |||||
| */ | |||||
| ge::Status Initialize(); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: DeConstruct method. Release all used resource of analyzer. | |||||
| * @param [in]: None | |||||
| * @return: None | |||||
| */ | |||||
| void Finalize(); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: DeConstruct method. Only release resource about session id. | |||||
| * @param [in]: None | |||||
| * @return: None | |||||
| */ | |||||
| void DestroySessionJsonObject(uint64_t session_id); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: DeConstruct method. Only release resource about session id and graph id. | |||||
| * @param [in]: None | |||||
| * @return: None | |||||
| */ | |||||
| void DestroyGraphJsonObject(uint64_t session_id, uint64_t graph_id); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: main process method. Buff analyzed data and output to json file | |||||
| * @param [in]: DataInfo Object | |||||
| * @return: 0: SUCCESS other: FAILED | |||||
| */ | |||||
| ge::Status DoAnalyze(analyzer::DataInfo &data_info); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: Buff analyzed data and output to json file | |||||
| * @param [in]: session id , graph id | |||||
| * @return: 0: SUCCESS other: FAILED | |||||
| */ | |||||
| ge::Status SaveAnalyzerDataToFile(uint64_t session_id, uint64_t graph_id); | |||||
| Analyzer(const Analyzer &) = delete; | |||||
| Analyzer& operator=(const Analyzer&) = delete; | |||||
| Analyzer(Analyzer &&) = delete; | |||||
| Analyzer& operator=(Analyzer &&) = delete; | |||||
| private: | |||||
| void TensorInfoToJson(nlohmann::json& j, const analyzer::TensorInfo &tensor_info); | |||||
| void OpInfoToJson(nlohmann::json& j, const analyzer::OpInfo &op_info); | |||||
| void GraphInfoToJson(nlohmann::json& j, const analyzer::GraphInfo &graph_info); | |||||
| ge::Status SaveOpInfo(ge::OpDescPtr desc, analyzer::DataInfo &data_info, | |||||
| std::shared_ptr<analyzer::GraphInfo> graph_info); | |||||
| void ClearHistoryFile(); | |||||
| ge::Status CreateAnalyzerFile(); | |||||
| explicit Analyzer() {}; | |||||
| ~Analyzer() = default; | |||||
| private: | |||||
| std::map<uint64_t, std::map<uint64_t, std::shared_ptr<analyzer::GraphInfo>>> graph_infos_; | |||||
| std::recursive_mutex mutex_; // protect graph_infos_ | |||||
| std::mutex file_mutex_; // protect json_file_ | |||||
| std::ofstream json_file_; | |||||
| std::string json_file_name_; | |||||
| std::atomic_bool is_json_file_create_{false}; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // DOMI_ANALYZER_ANANLYZER_H_ | |||||
| @@ -1,797 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge/ge_api.h" | |||||
| #include <iostream> | |||||
| #include <malloc.h> | |||||
| #include "common/debug/log.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "common/ge/datatype_util.h" | |||||
| #include "proto/ge_api.pb.h" | |||||
| #include "graph/model_serialize.h" | |||||
| #include "graph/detail/model_serialize_imp.h" | |||||
| #include "graph/utils/tensor_adapter.h" | |||||
| #include "init/gelib.h" | |||||
| #include "session/session_manager.h" | |||||
| #include "graph/opsproto_manager.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| #include "graph/manager/util/rt_context_util.h" | |||||
| #include "graph/common/ge_call_wrapper.h" | |||||
| #include "register/op_registry.h" | |||||
| #include "common/ge/tbe_plugin_manager.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "toolchain/plog.h" | |||||
| using domi::OpRegistry; | |||||
| using std::map; | |||||
| using std::string; | |||||
| using std::vector; | |||||
namespace {
// Longest job id (OPTION_EXEC_JOB_ID value) that CheckOptionsValid accepts.
constexpr int32_t kMaxStrLen{128};
}  // namespace

// True between a successful GEInitializeImpl and the next successful GEFinalize.
static bool g_ge_initialized{false};
// Held by GEFinalize and by ~Session while they release resources.
static std::mutex g_ge_release_mutex;
| namespace ge { | |||||
| void GetOpsProtoPath(std::string &opsproto_path) { | |||||
| GELOGI("Enter get ops proto path schedule"); | |||||
| const char *path_env = std::getenv("ASCEND_OPP_PATH"); | |||||
| if (path_env != nullptr) { | |||||
| std::string path = path_env; | |||||
| opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/"); | |||||
| GELOGI("Get opsproto so path from env: %s", path.c_str()); | |||||
| return; | |||||
| } | |||||
| std::string path_base = PluginManager::GetPath(); | |||||
| GELOGI("path_base is %s", path_base.c_str()); | |||||
| path_base = path_base.substr(0, path_base.rfind('/')); | |||||
| path_base = path_base.substr(0, path_base.rfind('/') + 1); | |||||
| opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); | |||||
| } | |||||
| Status CheckOptionsValid(const std::map<string, string> &options) { | |||||
| // check job_id is valid | |||||
| auto job_id_iter = options.find(OPTION_EXEC_JOB_ID); | |||||
| if (job_id_iter != options.end()) { | |||||
| if (job_id_iter->second.length() > kMaxStrLen) { | |||||
| GELOGE(PARAM_INVALID,"[Check][JobId]Failed," | |||||
| "the job_id [%s] string length: %zu > max string length: %d", | |||||
| job_id_iter->second.c_str(), job_id_iter->second.length(), kMaxStrLen); | |||||
| REPORT_INPUT_ERROR("E10051", std::vector<std::string>({"id","length"}), | |||||
| std::vector<std::string>({job_id_iter->second, | |||||
| std::to_string(kMaxStrLen)})); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| // Initialize GE, prepare for execution, call GELib::Initialize | |||||
| Status GEInitializeImpl(const std::map<string, string> &options) { | |||||
| ErrorManager::GetInstance().GenWorkStreamIdDefault(); | |||||
| GELOGT(TRACE_INIT, "GEInitialize start"); | |||||
| std::string path_base = ge::GELib::GetPath(); | |||||
| auto ret = ErrorManager::GetInstance().Init(path_base); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(GE_CLI_INIT_FAILED, | |||||
| "[Init][PathBase]Init failed when pass param path_base:%s", path_base.c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Init failed when pass param path_base:%s", path_base.c_str()); | |||||
| return ret; | |||||
| } | |||||
| // 0.check init status | |||||
| if (g_ge_initialized) { | |||||
| GELOGW("GEInitialize is called more than once"); | |||||
| return SUCCESS; | |||||
| } | |||||
| ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOpsProtoInit); | |||||
| // Load OpsProto lib plugin | |||||
| std::string opsproto_path; | |||||
| GetOpsProtoPath(opsproto_path); | |||||
| OpsProtoManager *manager = OpsProtoManager::Instance(); | |||||
| std::map<string, string> option_tmp; | |||||
| option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | |||||
| GE_TIMESTAMP_START(GEInitialize); | |||||
| bool is_proto_init = manager->Initialize(option_tmp); | |||||
| GE_TIMESTAMP_END(GEInitialize, "GEInitialize::ManagerInitialize"); | |||||
| if (!is_proto_init) { | |||||
| GELOGE(GE_CLI_INIT_FAILED, | |||||
| "[Init][OpsProtoPath]Loading OpsProto lib plugin failed, OpsProtoPath:%s invalid.", | |||||
| opsproto_path.c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Loading OpsProto lib plugin failed, OpsProtoPath:%s invalid", | |||||
| opsproto_path.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); | |||||
| // check options is valid | |||||
| GE_TIMESTAMP_START(CheckOptionsValid); | |||||
| if (CheckOptionsValid(options) != SUCCESS) { | |||||
| return FAILED; | |||||
| } | |||||
| GE_TIMESTAMP_END(CheckOptionsValid, "GEInitialize::CheckOptionsValid"); | |||||
| ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOpsProtoInit); | |||||
| GE_TIMESTAMP_START(InitPreparation); | |||||
| TBEPluginManager::Instance().InitPreparation(options); | |||||
| GE_TIMESTAMP_END(InitPreparation, "GEInitialize::InitPreparation"); | |||||
| // call Initialize | |||||
| GELOGT(TRACE_RUNNING, "Initializing environment"); | |||||
| ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); | |||||
| GE_TIMESTAMP_START(GELibInitialize); | |||||
| ret = ge::GELib::Initialize(options); | |||||
| GE_TIMESTAMP_END(GELibInitialize, "GEInitialize::GELibInitialize"); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(GE_CLI_INIT_FAILED, "[Init][GELib]Failed, error code = %u", ret); | |||||
| return FAILED; | |||||
| } | |||||
| // 7.check return status, return | |||||
| if (!g_ge_initialized) { | |||||
| // Initialize success, first time calling initialize | |||||
| g_ge_initialized = true; | |||||
| } | |||||
| GELOGT(TRACE_STOP, "GEInitialize finished"); | |||||
| return ret; | |||||
| } | |||||
| // Initialize GE, prepare for execution, call GELib::Initialize | |||||
| Status GEInitialize(const std::map<string, string> &options) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); | |||||
| if (DlogReportInitialize() != SUCCESS) { | |||||
| GELOGW("Dlog report device log initialize failed."); | |||||
| } | |||||
| return GEInitializeImpl(options); | |||||
| } | |||||
| Status GEInitialize(const std::map<AscendString, AscendString> &options) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); | |||||
| std::map<std::string, std::string> str_options; | |||||
| for (auto &option : options) { | |||||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||||
| GELOGE(FAILED, "[Check][Param]Options invalid, first or second option is nullptr."); | |||||
| REPORT_INNER_ERROR("E19999", "Check parameter's options invalid," | |||||
| "the first or second option is nullptr."); | |||||
| return FAILED; | |||||
| } | |||||
| std::string key = option.first.GetString(); | |||||
| std::string val = option.second.GetString(); | |||||
| str_options[key] = val; | |||||
| } | |||||
| if (DlogReportInitialize() != SUCCESS) { | |||||
| GELOGW("Dlog report device log initialize failed."); | |||||
| } | |||||
| return GEInitializeImpl(str_options); | |||||
| } | |||||
| // GE finalize, releasing all resources | |||||
| Status GEFinalize() { | |||||
| std::lock_guard<std::mutex> lock(g_ge_release_mutex); | |||||
| // check init status | |||||
| if (!g_ge_initialized) { | |||||
| GELOGW("[FINAL][FINAL]GEFinalize is called before GEInitialize"); | |||||
| return SUCCESS; | |||||
| } | |||||
| ErrorManager::GetInstance().SetStage(error_message::kFinalize, error_message::kFinalize); | |||||
| ErrorManager::GetInstance().GenWorkStreamIdDefault(); | |||||
| GELOGT(TRACE_INIT, "GEFinalize start"); | |||||
| // call Finalize | |||||
| Status ret = SUCCESS; | |||||
| Status middle_ret; | |||||
| GELOGT(TRACE_RUNNING, "Finalizing environment"); | |||||
| std::shared_ptr<GELib> instancePtr = ge::GELib::GetInstance(); | |||||
| if (instancePtr == nullptr || !instancePtr->InitFlag()) { | |||||
| GELOGW("GEFinalize Failed: GE not initialized."); | |||||
| ret = GE_CLI_GE_NOT_INITIALIZED; | |||||
| } | |||||
| if (ret != GE_CLI_GE_NOT_INITIALIZED) { | |||||
| middle_ret = instancePtr->Finalize(); | |||||
| GELOGI("GEFinalize finalize gelib ret=%u", middle_ret); | |||||
| if (middle_ret != SUCCESS) { | |||||
| ret = middle_ret; | |||||
| } | |||||
| } | |||||
| middle_ret = TBEPluginManager::Instance().Finalize(); | |||||
| if (middle_ret != SUCCESS) { | |||||
| ret = middle_ret; | |||||
| } | |||||
| if (g_ge_initialized && ret == SUCCESS) { | |||||
| // Unified destruct rt_context | |||||
| RtContextUtil::GetInstance().DestroyAllRtContexts(); | |||||
| g_ge_initialized = false; | |||||
| } | |||||
| // to avoid memory fragment, use malloc_trim to back free stack to system | |||||
| malloc_trim(0); | |||||
| if (DlogReportFinalize() != SUCCESS) { | |||||
| GELOGW("Dlog report device log finalize failed."); | |||||
| } | |||||
| GELOGT(TRACE_STOP, "GEFinalize finished"); | |||||
| return ret; | |||||
| } | |||||
| std::string GEGetErrorMsg() { | |||||
| return ErrorManager::GetInstance().GetErrorMessage(); | |||||
| } | |||||
| std::string GEGetWarningMsg() { | |||||
| return ErrorManager::GetInstance().GetWarningMessage(); | |||||
| } | |||||
| // Initialize session,which calls innerSession | |||||
| Session::Session(const std::map<string, string> &options) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); | |||||
| GELOGT(TRACE_INIT, "Session Constructor start"); | |||||
| ErrorManager::GetInstance().GenWorkStreamIdDefault(); | |||||
| // check init status | |||||
| sessionId_ = 0; | |||||
| if (!g_ge_initialized) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
| "[Construct][Session]Failed because lack GEInitialize call before."); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "Creating session failed because lack GEInitialize call before."); | |||||
| return; | |||||
| } | |||||
| // call Initialize | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
| "[Construct][Session]Failed, GELib instance is nullptr or it is not InitFlag"); | |||||
| return; | |||||
| } | |||||
| GELOGT(TRACE_RUNNING, "Creating session"); | |||||
| uint64_t session_id = 0; | |||||
| Status ret = instance_ptr->SessionManagerObj().CreateSession(options, session_id); | |||||
| GELOGT(TRACE_RUNNING, "Session id is %lu", session_id); | |||||
| // check return status, return, update session id if success | |||||
| if (ret == SUCCESS) { | |||||
| sessionId_ = session_id; | |||||
| } else { | |||||
| GELOGE(ret, "[Construct][Session]Failed, error code:%u.", ret); | |||||
| return; | |||||
| } | |||||
| GELOGT(TRACE_STOP, "Session Constructor finished"); | |||||
| } | |||||
| Session::Session(const std::map<AscendString, AscendString> &options) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); | |||||
| GELOGT(TRACE_INIT, "Session Constructor start"); | |||||
| ErrorManager::GetInstance().GenWorkStreamIdDefault(); | |||||
| // check init status | |||||
| sessionId_ = 0; | |||||
| if (!g_ge_initialized) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
| "[Construct][Session]Failed because lack GEInitialize call before."); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "Creating session failed because lack GEInitialize call before."); | |||||
| return; | |||||
| } | |||||
| // call Initialize | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
| "[Construct][Session]Failed, the GELib instance is nullptr or is not InitFlag"); | |||||
| return; | |||||
| } | |||||
| GELOGT(TRACE_RUNNING, "Creating session"); | |||||
| std::map<std::string, std::string> str_options; | |||||
| for (auto &option : options) { | |||||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||||
| GELOGE(FAILED, "[Construct][Session]Failed, the first or second option is nullptr."); | |||||
| REPORT_INNER_ERROR("E19999", "Creating session's options invalid," | |||||
| "the first or second option is nullptr."); | |||||
| return; | |||||
| } | |||||
| std::string key = option.first.GetString(); | |||||
| std::string val = option.second.GetString(); | |||||
| str_options[key] = val; | |||||
| } | |||||
| uint64_t session_id = 0; | |||||
| Status ret = instance_ptr->SessionManagerObj().CreateSession(str_options, session_id); | |||||
| GELOGT(TRACE_RUNNING, "Session id is %lu", session_id); | |||||
| // check return status, return, update session id if success | |||||
| if (ret == SUCCESS) { | |||||
| sessionId_ = session_id; | |||||
| } else { | |||||
| GELOGE(ret, "[Construct][Session]Failed, error code:%u.", ret); | |||||
| REPORT_CALL_ERROR("E19999", "Construct session failed, error code:%u.", ret); | |||||
| return; | |||||
| } | |||||
| GELOGT(TRACE_STOP, "Session Constructor finished"); | |||||
| } | |||||
| // session destructor | |||||
| Session::~Session() { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kFinalize, error_message::kFinalize); | |||||
| GELOGT(TRACE_INIT, "Session Destructor start"); | |||||
| // 0.check init status | |||||
| if (!g_ge_initialized) { | |||||
| GELOGW("GE is not yet initialized or is finalized."); | |||||
| return; | |||||
| } | |||||
| Status ret = FAILED; | |||||
| std::lock_guard<std::mutex> lock(g_ge_release_mutex); | |||||
| try { | |||||
| uint64_t session_id = sessionId_; | |||||
| // call DestroySession | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGW("GE is not yet initialized or is finalized."); | |||||
| return; | |||||
| } | |||||
| GELOGT(TRACE_RUNNING, "Session id is %lu", session_id); | |||||
| GELOGT(TRACE_RUNNING, "Destroying session"); | |||||
| ret = instance_ptr->SessionManagerObj().DestroySession(session_id); | |||||
| } catch (google::protobuf::FatalException &e) { | |||||
| GELOGE(GE_CLI_SESS_DESTROY_FAILED, "[Destruct][Session]Failed " | |||||
| "because get fatalException."); | |||||
| REPORT_CALL_ERROR("E19999", "Destruct session failed, get fatal exception"); | |||||
| } | |||||
| // check return status, return, update session id if success | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Destruct][Session]Failed, error code:%u.", ret); | |||||
| REPORT_CALL_ERROR("E19999", "Destruct session failed, error code:%u.", ret); | |||||
| } | |||||
| GELOGT(TRACE_STOP, "Session Destructor finished"); | |||||
| } | |||||
| // Add Graph | |||||
| Status Session::AddGraph(uint32_t graph_id, const Graph &graph) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); | |||||
| std::map<std::string, std::string> options; | |||||
| ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); | |||||
| return AddGraph(graph_id, graph, options); | |||||
| } | |||||
| // Add Graph | |||||
| Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map<std::string, std::string> &options) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); | |||||
| GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_); | |||||
| ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
| "[Add][Graph]Failed because GELib instance is nullptr or it is not InitFlag."); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "AddGraph Failed, GELib instance is nullptr or it is not InitFlag."); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("Adding graph to session"); | |||||
| Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, options); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, | |||||
| "[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", | |||||
| ret, sessionId_, graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("AddGraph finished in Session."); | |||||
| return ret; | |||||
| } | |||||
| //Add Graph | |||||
| Status Session::AddGraph(uint32_t graph_id, const Graph &graph, | |||||
| const std::map<AscendString, AscendString> &options) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); | |||||
| GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_); | |||||
| ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
| "[Add][Graph]Failed, the GELib instance is nullptr or is not InitFlag."); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "AddGraph Failed, GELib instance is nullptr or it is not InitFlag."); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("Adding graph to session"); | |||||
| std::map<std::string, std::string> str_options; | |||||
| for (auto &option : options) { | |||||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||||
| GELOGE(FAILED, "[Add][Graph]Failed, the first or second option is nullptr."); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "Add Graph Failed, the first or second option is nullptr."); | |||||
| return FAILED; | |||||
| } | |||||
| std::string key = option.first.GetString(); | |||||
| std::string val = option.second.GetString(); | |||||
| str_options[key] = val; | |||||
| } | |||||
| Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, str_options); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, | |||||
| "[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", | |||||
| ret, sessionId_, graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("AddGraph finished in Session."); | |||||
| return ret; | |||||
| } | |||||
| Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); | |||||
| ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); | |||||
| std::map<AscendString, AscendString> options; | |||||
| return AddGraphWithCopy(graph_id, graph, options); | |||||
| } | |||||
| // Add Graph With Copy | |||||
| Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph, | |||||
| const std::map<AscendString, AscendString> &options) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); | |||||
| GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_); | |||||
| ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
| "[Add][Graph]Failed, the GELib instance is nullptr or is not InitFlag."); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "AddGraph Failed, GELib instance is nullptr or is not InitFlag."); | |||||
| return FAILED; | |||||
| } | |||||
| std::map<std::string, std::string> str_options; | |||||
| for (auto it = options.begin(); it != options.end(); ++it) { | |||||
| str_options.insert({it->first.GetString(), it->second.GetString()}); | |||||
| } | |||||
| GELOGD("Adding graph to session"); | |||||
| Status ret = instance_ptr->SessionManagerObj().AddGraphWithCopy(sessionId_, graph_id, graph, str_options); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, | |||||
| "[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", | |||||
| ret, sessionId_, graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("AddGraph finished in Session."); | |||||
| return ret; | |||||
| } | |||||
| // Remove Graph | |||||
| Status Session::RemoveGraph(uint32_t graph_id) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); | |||||
| GELOGT(TRACE_INIT, "Session RemoveGraph start"); | |||||
| ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); | |||||
| // call RemoveGraph | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (!instance_ptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
| "[Remove][Graph]Failed, GELib instance is nullptr or is not InitFlag, " | |||||
| "session_id %lu, graph_id %u", sessionId_, graph_id); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "RemoveGraph Failed, GELib instance is nullptr or is not InitFlag, " | |||||
| "session_id %lu, graph_id %u", sessionId_, graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGT(TRACE_RUNNING, "Removing Graph from session"); | |||||
| Status ret = instance_ptr->SessionManagerObj().RemoveGraph(sessionId_, graph_id); | |||||
| // check return status, return | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, | |||||
| "[Remove][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", | |||||
| ret, sessionId_, graph_id); | |||||
| REPORT_CALL_ERROR("E19999", "Remove graph failed, error code:%u, " | |||||
| "session_id:%lu, graph_id:%u", ret, sessionId_, graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGT(TRACE_STOP, "Session RemoveGraph finished"); | |||||
| return ret; | |||||
| } | |||||
| // Print Output Result | |||||
| void PrintOutputResult(std::vector<Tensor> &outputs) { | |||||
| if (outputs.empty() || outputs[0].GetData() == nullptr) { | |||||
| GELOGW("outputs is empty or data is nullptr."); | |||||
| return; | |||||
| } | |||||
| size_t out_buf_size = outputs[0].GetSize(); | |||||
| TensorDesc desc(outputs[0].GetTensorDesc()); | |||||
| DataType data_type = desc.GetDataType(); | |||||
| auto iter = CONST_OPDATA_TYPE_SIZE_MAP.find(data_type); | |||||
| if (iter == CONST_OPDATA_TYPE_SIZE_MAP.end()) { | |||||
| GELOGI("DataType %s has not defined size", TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return; | |||||
| } | |||||
| size_t length = CONST_OPDATA_TYPE_SIZE_MAP[data_type]; | |||||
| for (size_t i = 0; i < 10 && i < (out_buf_size / length); ++i) { // take first 10 at most | |||||
| switch (data_type) { | |||||
| case DT_BOOL: | |||||
| case DT_INT8: | |||||
| case DT_UINT8: | |||||
| GELOGI("output data[%zu]=%d", i, *(reinterpret_cast<int8_t *>(outputs[0].GetData()) + i)); | |||||
| break; | |||||
| case DT_INT16: | |||||
| case DT_UINT16: | |||||
| GELOGI("output data[%zu]=%d", i, *(reinterpret_cast<int16_t *>(outputs[0].GetData()) + i)); | |||||
| break; | |||||
| case DT_INT32: | |||||
| case DT_UINT32: | |||||
| GELOGI("output data[%zu]=%d", i, *(reinterpret_cast<int32_t *>(outputs[0].GetData()) + i)); | |||||
| break; | |||||
| case DT_INT64: | |||||
| case DT_UINT64: | |||||
| GELOGI("output data[%zu]=%ld", i, *(reinterpret_cast<int64_t *>(outputs[0].GetData()) + i)); | |||||
| break; | |||||
| case DT_FLOAT: | |||||
| GELOGI("output data[%zu]=%f", i, *(reinterpret_cast<float *>(outputs[0].GetData()) + i)); | |||||
| break; | |||||
| case DT_DOUBLE: | |||||
| GELOGI("output data[%zu]=%lf", i, *(reinterpret_cast<double *>(outputs[0].GetData()) + i)); | |||||
| break; | |||||
| default: | |||||
| GELOGI("Output datatype %s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return; | |||||
| } | |||||
| } | |||||
| } | |||||
| // Run Graph | |||||
| Status Session::RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, std::vector<Tensor> &outputs) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); | |||||
| GELOGT(TRACE_INIT, "Session RunGraph start"); | |||||
| ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); | |||||
| std::vector<Tensor> graph_inputs = inputs; | |||||
| // call RunGraph | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
| "[Run][Graph]Failed, GELib instance is nullptr or is not InitFlag, " | |||||
| "session_id %lu, graph_id %u", sessionId_, graph_id); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "RunGraph Failed, GELib instance is nullptr or is not InitFlag, " | |||||
| "session_id %lu, graph_id %u", sessionId_, graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGT(TRACE_RUNNING, "Running Graph"); | |||||
| Status ret = instance_ptr->SessionManagerObj().RunGraph(sessionId_, graph_id, graph_inputs, outputs); | |||||
| // check return status | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, | |||||
| "[Run][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", | |||||
| ret, sessionId_, graph_id); | |||||
| REPORT_CALL_ERROR("E19999", "Remove graph failed, error code:%u, " | |||||
| "session_id:%lu, graph_id:%u", ret, sessionId_, graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| // print output | |||||
| if (outputs.size() > 0) { | |||||
| PrintOutputResult(outputs); | |||||
| } | |||||
| // return | |||||
| GELOGT(TRACE_STOP, "Session RunGraph finished"); | |||||
| return ret; | |||||
| } | |||||
| // Run Graph with stream Asynchronously | |||||
| Status Session::RunGraphWithStreamAsync(uint32_t graph_id, void *stream, const std::vector<Tensor> &inputs, | |||||
| std::vector<Tensor> &outputs) { | |||||
| ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther); | |||||
| GELOGT(TRACE_INIT, "Session run graph with stream async start"); | |||||
| ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
| "[Run][Graph]Run graph with stream asyn failed, the GELib instance is nullptr," | |||||
| "session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "Run graph with stream asyn failed, the GELib instance is nullptr" | |||||
| "session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream); | |||||
| return FAILED; | |||||
| } | |||||
| if (!instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
| "[Run][Graph]Run graph with stream asyn failed, the GELib instance is not init," | |||||
| "session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "Run graph with stream asyn failed, the GELib instance is not init," | |||||
| "session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGT(TRACE_RUNNING, "Run Graph Run graph with stream asyn."); | |||||
| Status ret = instance_ptr->SessionManagerObj().RunGraphWithStreamAsync(sessionId_, graph_id, stream, inputs, | |||||
| outputs); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Run][Graph]Run graph with stream asyn Failed," | |||||
| "error code = %u, session id = %lu, graph id = %u, stream = %p.", ret, sessionId_, graph_id, stream); | |||||
| REPORT_CALL_ERROR("E19999", "[Run][Graph]Run graph with stream asyn failed, error code = %u, session id = %lu," | |||||
| "graph id = %u, stream = %p.", ret, sessionId_, graph_id, stream); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGT(TRACE_STOP, "Session run graph with stream async finished"); | |||||
| return SUCCESS; | |||||
| } | |||||
| // Register Call Back | |||||
| Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback) { | |||||
| ErrorManager::GetInstance().GenWorkStreamIdDefault(); | |||||
| return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | |||||
| } | |||||
| Status Session::RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback) { | |||||
| ErrorManager::GetInstance().GenWorkStreamIdDefault(); | |||||
| std::string str_key; | |||||
| if (key != nullptr) { | |||||
| str_key = key; | |||||
| } | |||||
| return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, str_key, callback); | |||||
| } | |||||
| // Build Graph | |||||
| Status Session::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); | |||||
| ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
| "[Build][Graph]Failed, the GELib instance is nullptr or is not InitFlag, " | |||||
| "session_id %lu, graph_id %u", sessionId_, graph_id); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "Build graph failed, the GELib instance is nullptr or is not InitFlag, " | |||||
| "session_id %lu, graph_id %u", sessionId_, graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGT(TRACE_RUNNING, "Building Graph"); | |||||
| Status ret = instance_ptr->SessionManagerObj().BuildGraph(sessionId_, graph_id, inputs); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, | |||||
| "[Build][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", | |||||
| ret, sessionId_, graph_id); | |||||
| REPORT_CALL_ERROR("E19999", "Build graph failed , error code:%u, " | |||||
| "session_id:%lu, graph_id:%u", ret, sessionId_, graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| // Build Graph | |||||
| // Builds graph_id of this session from the given input tensors through the GELib session manager. | |||||
| // Returns SUCCESS when the session manager reports SUCCESS; otherwise logs and reports the error | |||||
| // and returns FAILED (the detailed status code is only written to the log). | |||||
| Status Session::BuildGraph(uint32_t graph_id, const std::vector<ge::Tensor> &inputs) { | |||||
|   ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); | |||||
|   ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); | |||||
|   const std::shared_ptr<GELib> ge_lib = ge::GELib::GetInstance(); | |||||
|   const bool ge_ready = (ge_lib != nullptr) && ge_lib->InitFlag(); | |||||
|   if (!ge_ready) { | |||||
|     GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
|            "[Build][Graph]Failed, the GELib instance is nullptr or is not InitFlag, " | |||||
|            "session_id %lu, graph_id %u", sessionId_, graph_id); | |||||
|     REPORT_INNER_ERROR("E19999", | |||||
|                        "Build graph failed, the GELib instance is nullptr or is not InitFlag, " | |||||
|                        "session_id %lu, graph_id %u", sessionId_, graph_id); | |||||
|     return FAILED; | |||||
|   } | |||||
|   GELOGT(TRACE_RUNNING, "Building Graph"); | |||||
|   const Status build_status = ge_lib->SessionManagerObj().BuildGraph(sessionId_, graph_id, inputs); | |||||
|   if (build_status == SUCCESS) { | |||||
|     return SUCCESS; | |||||
|   } | |||||
|   // Failure path: keep the detailed code in the log, hand the caller a plain FAILED. | |||||
|   GELOGE(build_status, | |||||
|          "[Build][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", | |||||
|          build_status, sessionId_, graph_id); | |||||
|   REPORT_CALL_ERROR("E19999", "Build graph failed , error code:%u, " | |||||
|                     "session_id:%lu, graph_id:%u", build_status, sessionId_, graph_id); | |||||
|   return FAILED; | |||||
| } | |||||
| // Run Graph Asynchronously | |||||
| // Forwards graph_id of this session, its inputs and the completion callback to the GELib session manager. | |||||
| // The callback is passed through unchecked (see the warning logged below). | |||||
| // Returns SUCCESS when the session manager call returns SUCCESS; returns FAILED when GELib is missing or | |||||
| // not initialized, or when the session manager reports an error (the detailed code is only logged). | |||||
| Status Session::RunGraphAsync(uint32_t graph_id, const std::vector<ge::Tensor> &inputs, | |||||
|                               RunAsyncCallback callback) { | |||||
|   ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); | |||||
|   ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); | |||||
|   std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
|   if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
|     GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
|            "[Run][Graph]RunGraphAsyncFailed, the GELib instance is nullptr or is not InitFlag, " | |||||
|            "session_id %lu, graph_id %u", sessionId_, graph_id); | |||||
|     REPORT_INNER_ERROR("E19999", | |||||
|                        "RunGraphAsync Failed, the GELib instance is nullptr or is not InitFlag, " | |||||
|                        "session_id %lu, graph_id %u", sessionId_, graph_id); | |||||
|     return FAILED; | |||||
|   } | |||||
|   GELOGT(TRACE_RUNNING, "Run Graph Asynchronously"); | |||||
|   GELOGW( | |||||
|       "The callback function will not be checked. Please ensure that the implementation of the function is trusted."); | |||||
|   // Use the instance validated above: a second GetInstance() call would not be covered by the nullptr check, | |||||
|   // and the shared_ptr held here keeps the validated instance alive for the duration of the call. | |||||
|   Status ret = instance_ptr->SessionManagerObj().RunGraphAsync(sessionId_, graph_id, inputs, callback); | |||||
|   if (ret != SUCCESS) { | |||||
|     GELOGE(ret, "[Run][Graph]RunGraphAsync Failed, error code:%u, session_id:%lu, graph_id:%u.", | |||||
|            ret, sessionId_, graph_id); | |||||
|     REPORT_CALL_ERROR("E19999", "RunGraphAsync Failed, error code:%u, session_id:%lu, " | |||||
|                       "graph_id:%u", ret, sessionId_, graph_id); | |||||
|     return FAILED; | |||||
|   } | |||||
|   return SUCCESS; | |||||
| } | |||||
| // Get Variables | |||||
| // Fetches the variables named in var_names for this session into var_values via the GELib session manager. | |||||
| // Returns SUCCESS when the session manager call returns SUCCESS; returns FAILED when GELib is missing or | |||||
| // not initialized, or when the lookup fails (the detailed code is logged and reported). | |||||
| Status Session::GetVariables(const std::vector<std::string> &var_names, std::vector<Tensor> &var_values) { | |||||
|   ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); | |||||
|   ErrorManager::GetInstance().GenWorkStreamIdDefault(); | |||||
|   auto instance_ptr = ge::GELib::GetInstance(); | |||||
|   if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
|     GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
|            "[Get][Variables]Failed, the GELib instance is nullptr or is not InitFlag."); | |||||
|     REPORT_INNER_ERROR("E19999", | |||||
|                        "GetVariables failed, the GELib instance is nullptr or is not InitFlag."); | |||||
|     return FAILED; | |||||
|   } | |||||
|   GELOGT(TRACE_RUNNING, "Get Variables"); | |||||
|   // Use the instance validated above rather than an unchecked second GetInstance() call. | |||||
|   Status ret = instance_ptr->SessionManagerObj().GetVariables(sessionId_, var_names, var_values); | |||||
|   if (ret != SUCCESS) { | |||||
|     GELOGE(ret, "[Get][Variables]Failed, error code:%u, session_id:%lu.", ret, sessionId_); | |||||
|     // Report the call failure as the AscendString overload of GetVariables does. | |||||
|     REPORT_CALL_ERROR("E19999", "Get variables failed, error code:%u, session_id:%lu.", | |||||
|                       ret, sessionId_); | |||||
|     return FAILED; | |||||
|   } | |||||
|   return SUCCESS; | |||||
| } | |||||
| // Get Variables (AscendString overload) | |||||
| // Converts every name in var_names to ge::string and fetches the matching variables of this session into | |||||
| // var_values via the GELib session manager. | |||||
| // Returns FAILED when GELib is missing or not initialized, when any name's GetString() is nullptr, or when | |||||
| // the session manager reports an error; returns SUCCESS otherwise. | |||||
| Status Session::GetVariables(const std::vector<AscendString> &var_names, std::vector<Tensor> &var_values) { | |||||
|   ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); | |||||
|   ErrorManager::GetInstance().GenWorkStreamIdDefault(); | |||||
|   auto instance_ptr = ge::GELib::GetInstance(); | |||||
|   if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
|     GELOGE(GE_CLI_GE_NOT_INITIALIZED, | |||||
|            "[Get][Variables]Failed, the GELib instance is nullptr or is not InitFlag."); | |||||
|     REPORT_INNER_ERROR("E19999", | |||||
|                        "GetVariables failed, the GELib instance is nullptr or is not InitFlag."); | |||||
|     return FAILED; | |||||
|   } | |||||
|   GELOGT(TRACE_RUNNING, "Get Variables"); | |||||
|   std::vector<ge::string> str_var_names; | |||||
|   str_var_names.reserve(var_names.size());  // one allocation instead of repeated growth | |||||
|   for (const auto &var_name : var_names) { | |||||
|     const char *const raw_name = var_name.GetString(); | |||||
|     if (raw_name == nullptr) { | |||||
|       GELOGE(FAILED, "[Get][Variable]Failed, variables' names are nullptr."); | |||||
|       REPORT_INNER_ERROR("E19999", "GetVariables failed, variables' names are nullptr."); | |||||
|       return FAILED; | |||||
|     } | |||||
|     str_var_names.emplace_back(raw_name); | |||||
|   } | |||||
|   // Use the instance validated above rather than an unchecked second GetInstance() call. | |||||
|   Status ret = instance_ptr->SessionManagerObj().GetVariables(sessionId_, str_var_names, var_values); | |||||
|   if (ret != SUCCESS) { | |||||
|     GELOGE(ret, "[Get][Variables]Failed, error code:%u, session_id:%lu.", ret, sessionId_); | |||||
|     REPORT_CALL_ERROR("E19999", "Get variables failed, error code:%u, session_id:%lu.", | |||||
|                       ret, sessionId_); | |||||
|     return FAILED; | |||||
|   } | |||||
|   return SUCCESS; | |||||
| } | |||||
| bool Session::IsGraphNeedRebuild(uint32_t graph_id) { /* Returns the session manager's answer on whether graph_id of this session needs to be rebuilt. */ | |||||
| return ge::GELib::GetInstance()->SessionManagerObj().IsGraphNeedRebuild(sessionId_, graph_id); /* NOTE(review): GetInstance() is dereferenced without the nullptr/InitFlag() check the other Session methods perform - confirm callers guarantee GELib is initialized. */ | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,116 +0,0 @@ | |||||
| # Shared inputs for the libge_client modules defined below: the common source list plus the include lists used by the host (COMMON_*) and device (DEVICE_*) builds. | |||||
| LOCAL_PATH := $(call my-dir) | |||||
| COMMON_LOCAL_SRC_FILES := \ | |||||
| proto/ge_api.proto \ | |||||
| ge_api.cc \ | |||||
| COMMON_LOCAL_C_INCLUDES := \ | |||||
| proto/ge_ir.proto \ | |||||
| proto/task.proto \ | |||||
| proto/om.proto \ | |||||
| proto/insert_op.proto \ | |||||
| $(LOCAL_PATH) ./ \ | |||||
| $(LOCAL_PATH)/../ \ | |||||
| $(LOCAL_PATH)/../../ \ | |||||
| $(TOPDIR)inc \ | |||||
| $(TOPDIR)inc/external \ | |||||
| $(TOPDIR)inc/external/graph \ | |||||
| $(TOPDIR)inc/common \ | |||||
| $(TOPDIR)inc/framework \ | |||||
| $(TOPDIR)inc/graph \ | |||||
| $(TOPDIR)libc_sec/include \ | |||||
| $(TOPDIR)ops/built-in/op_proto/inc \ | |||||
| third_party/json/include \ | |||||
| third_party/protobuf/include \ | |||||
| third_party/opencv/include \ | |||||
| DEVICE_LOCAL_C_INCLUDES := \ | |||||
| proto/ge_ir.proto \ | |||||
| proto/task.proto \ | |||||
| proto/om.proto \ | |||||
| proto/insert_op.proto \ | |||||
| $(LOCAL_PATH) ./ \ | |||||
| $(LOCAL_PATH)/../ \ | |||||
| $(LOCAL_PATH)/../../ \ | |||||
| $(TOPDIR)inc \ | |||||
| $(TOPDIR)inc/external \ | |||||
| $(TOPDIR)inc/external/graph \ | |||||
| $(TOPDIR)inc/framework \ | |||||
| $(TOPDIR)inc/common \ | |||||
| $(TOPDIR)inc/graph \ | |||||
| $(TOPDIR)libc_sec/include \ | |||||
| $(TOPDIR)ops/built-in/op_proto/inc \ | |||||
| third_party/json/include \ | |||||
| third_party/protobuf/include \ | |||||
| third_party/opencv/include \ | |||||
| # Host build of libge_client (host inference): shared library produced via BUILD_HOST_SHARED_LIBRARY from the common sources and COMMON_LOCAL_C_INCLUDES | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := libge_client | |||||
| LOCAL_CFLAGS += -Werror | |||||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -Dgoogle=ascend_private | |||||
| ifeq ($(DEBUG), 1) | |||||
| LOCAL_CFLAGS += -g -O0 | |||||
| endif | |||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||||
| LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) | |||||
| LOCAL_SHARED_LIBRARIES := \ | |||||
| libc_sec \ | |||||
| libascend_protobuf \ | |||||
| libslog \ | |||||
| libmmpa \ | |||||
| libgraph \ | |||||
| libregister \ | |||||
| libge_compiler \ | |||||
| libge_common | |||||
| LOCAL_STATIC_LIBRARIES += libmsprofiler_fwk \ | |||||
| LOCAL_LDFLAGS := -lrt -ldl | |||||
| LOCAL_SHARED_LIBRARIES += \ | |||||
| libruntime \ | |||||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||||
| # Device build of libge_client: shared library produced via BUILD_SHARED_LIBRARY from the common sources and DEVICE_LOCAL_C_INCLUDES, with the device-only defines (-DOMG_DEVICE_VERSION, -DDEV_VISIBILITY, -DGOOGLE_PROTOBUF_NO_RTTI) | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := libge_client | |||||
| LOCAL_CFLAGS += -Werror | |||||
| LOCAL_CFLAGS += -DGOOGLE_PROTOBUF_NO_RTTI -DDEV_VISIBILITY | |||||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| LOCAL_CFLAGS += -DOMG_DEVICE_VERSION -DREUSE_MEMORY=1 -Dgoogle=ascend_private | |||||
| LOCAL_MODULE_CLASS := SHARED_LIBRARIES | |||||
| LOCAL_C_INCLUDES := $(DEVICE_LOCAL_C_INCLUDES) | |||||
| LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) | |||||
| LOCAL_SHARED_LIBRARIES := \ | |||||
| libc_sec \ | |||||
| libascend_protobuf \ | |||||
| libslog \ | |||||
| libmmpa \ | |||||
| libgraph \ | |||||
| libregister \ | |||||
| libruntime \ | |||||
| libge_compiler \ | |||||
| libge_common | |||||
| LOCAL_STATIC_LIBRARIES += libmsprofiler_fwk \ | |||||
| LOCAL_LDFLAGS := -lrt -ldl | |||||
| LOCAL_CFLAGS += \ | |||||
| -Wall | |||||
| include $(BUILD_SHARED_LIBRARY) | |||||
| @@ -1 +0,0 @@ | |||||
| ../../proto/ge_api.proto | |||||
| @@ -1,193 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package ge.proto; | |||||
| enum DataType /* Element data types; used as the type of TensorDescriptor.dtype in this file. */ | |||||
| { | |||||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||||
| DT_FLOAT = 1; // float type | |||||
| DT_FLOAT16 = 2; // fp16 type | |||||
| DT_INT8 = 3; // int8 type | |||||
| DT_UINT8 = 4; // uint8 type | |||||
| DT_INT16 = 5; // int16 type | |||||
| DT_UINT16 = 6; // uint16 type | |||||
| DT_INT32 = 7; // int32 type | |||||
| DT_INT64 = 8; // int64 type | |||||
| DT_UINT32 = 9; // unsigned int32 | |||||
| DT_UINT64 = 10; // unsigned int64 | |||||
| DT_BOOL = 11; // bool type | |||||
| DT_DOUBLE = 12; // double type | |||||
| DT_STRING = 13; // string type | |||||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||||
| DT_COMPLEX64 = 16; // complex64 type | |||||
| DT_COMPLEX128 = 17; // complex128 type | |||||
| DT_QINT8 = 18; // qint8 type | |||||
| DT_QINT16 = 19; // qint16 type | |||||
| DT_QINT32 = 20; // qint32 type | |||||
| DT_QUINT8 = 21; // quint8 type | |||||
| DT_QUINT16 = 22; // quint16 type | |||||
| DT_RESOURCE = 23; // resource type | |||||
| DT_STRING_REF = 24; // string_ref type | |||||
| DT_DUAL = 25; /**< dual output type */ | |||||
| DT_VARIANT = 26; // variant type | |||||
| DT_BF16 = 27; // bf16 type | |||||
| DT_INT4 = 28; // int4 type | |||||
| } | |||||
| message AttrDef /* One attribute value; at most one member of the oneof "value" below is set. */ | |||||
| { | |||||
| message ListValue /* List-typed payload, used by AttrDef.list. */ | |||||
| { | |||||
| enum ListValueType{ | |||||
| VT_LIST_NONE = 0; | |||||
| VT_LIST_STRING = 1; | |||||
| VT_LIST_INT = 2; | |||||
| VT_LIST_FLOAT = 3; | |||||
| VT_LIST_BOOL = 4; | |||||
| VT_LIST_BYTES = 5; | |||||
| VT_LIST_TENSOR_DESC = 6; | |||||
| VT_LIST_TENSOR = 7; | |||||
| VT_LIST_GRAPH = 8; | |||||
| VT_LIST_NAMED_ATTRS = 9; | |||||
| VT_LIST_DATA_TYPE = 10; | |||||
| } | |||||
| repeated bytes s = 2; // "list(string)" | |||||
| repeated int64 i = 3; // "list(int)" | |||||
| repeated float f = 4; // "list(float)" | |||||
| repeated bool b = 5; // "list(bool)" | |||||
| repeated bytes bt = 7; /* "list(bytes)" */ | |||||
| repeated TensorDescriptor td = 8; /* list of TensorDescriptor */ | |||||
| repeated TensorDef t = 9; /* list of TensorDef */ | |||||
| repeated GraphDef g = 10; /* list of GraphDef */ | |||||
| repeated NamedAttrs na = 11; /* list of NamedAttrs */ | |||||
| repeated int64 dt = 12; // list ge::DataType | |||||
| ListValueType val_type = 20; /* presumably tells which of the repeated fields above is the populated one - TODO confirm */ | |||||
| } | |||||
| message ListListInt{ | |||||
| message ListInt{ | |||||
| repeated int64 list_i = 1; // list int | |||||
| } | |||||
| repeated ListInt list_list_i = 1; // list list int | |||||
| } | |||||
| oneof value | |||||
| { | |||||
| bytes s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| bytes bt = 7; /* "bytes" */ | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; // Used to support attr nesting | |||||
| TensorDescriptor td = 11; // GeTensorDesc type | |||||
| TensorDef t = 12; // GeTensor type | |||||
| GraphDef g = 13; // Graph type | |||||
| ListListInt list_list_int = 14; // List List Int type | |||||
| int64 dt = 15; // ge::DataType | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs | |||||
| { | |||||
| string name = 1; /* Name the attribute list is attached to */ | |||||
| map<string, AttrDef> attr = 2; /* Attribute name -> attribute value */ | |||||
| } | |||||
| // Shape / dimension description, using row-major order | |||||
| message ShapeDef | |||||
| { | |||||
| repeated int64 dim = 1; // Size of each dimension | |||||
| } | |||||
| // Multidimensional data description | |||||
| message TensorDescriptor /* NOTE(review): fields 9-21 carry no description in this file; confirm their meaning with the code that reads/writes them. */ | |||||
| { | |||||
| string name = 1; // Optional parameter, tensor name | |||||
| DataType dtype = 2; // tensor datatype | |||||
| ShapeDef shape = 3; // Shape / dimension | |||||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||||
| bool has_out_attr = 9; | |||||
| int64 size = 10; | |||||
| int64 weight_size = 11; | |||||
| bool reuse_input = 12; | |||||
| bool output_tensor = 13; | |||||
| string device_type = 14; | |||||
| bool input_tensor =15; | |||||
| int64 real_dim_cnt = 16; | |||||
| int64 reuse_input_index = 17; | |||||
| int64 data_offset = 18; | |||||
| int64 cmps_size = 19; | |||||
| string cmps_tab = 20; | |||||
| int64 cmps_tab_offset = 21; | |||||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||||
| } | |||||
| // GeTensor definition | |||||
| message TensorDef | |||||
| { | |||||
| TensorDescriptor desc = 1; // Tensor description | |||||
| bytes data = 2; // Tensor data | |||||
| } | |||||
| // Operator description | |||||
| message OpDef /* NOTE(review): fields 20-35 are undocumented in this file; src_name/src_index and dst_name/dst_index look like parallel lists - confirm with the serializer. */ | |||||
| { | |||||
| string name = 1; // name | |||||
| string type = 2; // type | |||||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||||
| bool has_out_attr = 20; | |||||
| int64 id = 21; | |||||
| int64 stream_id =22; | |||||
| repeated string input_name = 23; | |||||
| repeated string src_name = 24; | |||||
| repeated int64 src_index = 25; | |||||
| repeated string dst_name = 26; | |||||
| repeated int64 dst_index = 27; | |||||
| repeated int64 input_i = 28; | |||||
| repeated int64 output_i = 29; | |||||
| repeated int64 workspace = 30; | |||||
| repeated int64 workspace_bytes = 31; | |||||
| repeated bool is_input_const = 32; | |||||
| repeated TensorDescriptor input_desc = 33; /* Descriptors of the op's inputs */ | |||||
| repeated TensorDescriptor output_desc = 34; /* Descriptors of the op's outputs */ | |||||
| repeated string subgraph_name = 35; | |||||
| } | |||||
| // Graph definition | |||||
| message GraphDef | |||||
| { | |||||
| string name = 1; // name | |||||
| repeated string input = 4; // Graph input | |||||
| repeated string output = 5; // Graph output | |||||
| repeated OpDef op = 6; // List of operators | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| // model definition | |||||
| message ModelDef | |||||
| { | |||||
| string name = 1; // name | |||||
| uint32 version = 2; // IR Proto version | |||||
| string custom_version = 3; // User model version number, passed in by user | |||||
| repeated GraphDef graph = 7; // Graph definition; graph[0] represents the main graph in the ModelDef | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| @@ -1,140 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| message InsertNewOps { /* Top-level container for the parameter sets of ops to insert. */ | |||||
| repeated AippOpParams aipp_op = 1; /* AIPP op parameter sets */ | |||||
| repeated MultiShapeOpParams multi_shape_op = 2; /* Multi-shape op parameter sets */ | |||||
| } | |||||
| message AippOpParams { /* Parameters of one AIPP op; which model input it applies to is selected by related_input_rank / related_input_name. */ | |||||
| enum InputFormat { | |||||
| UNDEFINED = 0; | |||||
| YUV420SP_U8 = 1; | |||||
| XRGB8888_U8 = 2; | |||||
| RGB888_U8 = 3; | |||||
| YUV400_U8 = 4; | |||||
| NC1HWC0DI_FP16 = 5; | |||||
| NC1HWC0DI_S8 = 6; | |||||
| ARGB8888_U8 = 7; | |||||
| YUYV_U8 = 8; | |||||
| YUV422SP_U8 = 9; | |||||
| AYUV444_U8 = 10; | |||||
| RAW10 = 11; | |||||
| RAW12 = 12; | |||||
| RAW16 = 13; | |||||
| RAW24 = 14; | |||||
| RGB16 = 15; | |||||
| RGB20 = 16; | |||||
| RGB24 = 17; | |||||
| RGB8_IR = 18; | |||||
| RGB16_IR = 19; | |||||
| RGB24_IR = 20; | |||||
| } | |||||
| enum AippMode { | |||||
| undefined = 0; | |||||
| static = 1; | |||||
| dynamic = 2; | |||||
| } | |||||
| // AIPP mode: distinguishes static AIPP from dynamic AIPP | |||||
| AippMode aipp_mode = 1; | |||||
| // related_input_rank is required. It is an integer in the range >= 0 and <= the number of input Data operators; the default value is 0. | |||||
| // It identifies which model input AIPP is applied to. For example, if the model has two inputs and AIPP is needed on the second one, set related_input_rank to 1. | |||||
| uint32 related_input_rank = 2; | |||||
| // related_input_name is optional and the top name of data node which inserts aipp | |||||
| string related_input_name = 6; | |||||
| // input_edge_idx is optional. It is an integer in the range >= 0. | |||||
| // It allows different AIPP processing for different outputs of the Data operator. If it is not set, AIPP is applied by default to all output edges of the model input selected by related_input_rank. | |||||
| // Each configured value must be <= the number of output edges of the Data operator. | |||||
| repeated uint32 input_edge_idx = 3; | |||||
| // [Begin] Dynamic AIPP parameters; ignored when static AIPP is configured | |||||
| uint32 max_src_image_size = 4; | |||||
| // Whether rotation is supported. Off by default; enabling rotation support costs additional memory and performance. | |||||
| bool support_rotation = 5; | |||||
| // [End] Dynamic AIPP parameters | |||||
| // [Begin] Static AIPP parameters; ignored when dynamic AIPP is configured | |||||
| InputFormat input_format = 51; | |||||
| bool csc_switch = 52; | |||||
| float cpadding_value = 53; | |||||
| bool rbuv_swap_switch = 54; | |||||
| bool ax_swap_switch = 55; | |||||
| bool single_line_mode = 56; | |||||
| int32 src_image_size_w = 57; | |||||
| int32 src_image_size_h = 58; | |||||
| bool crop = 59; | |||||
| int32 load_start_pos_w = 60; | |||||
| int32 load_start_pos_h = 61; | |||||
| int32 crop_size_w = 62; | |||||
| int32 crop_size_h = 63; | |||||
| bool resize = 64; | |||||
| int32 resize_output_w = 65; | |||||
| int32 resize_output_h = 66; | |||||
| bool padding = 67; | |||||
| int32 left_padding_size = 68; | |||||
| int32 right_padding_size = 69; | |||||
| int32 top_padding_size = 70; | |||||
| int32 bottom_padding_size = 71; | |||||
| float padding_value = 72; | |||||
| int32 mean_chn_0 = 10; | |||||
| int32 mean_chn_1 = 11; | |||||
| int32 mean_chn_2 = 12; | |||||
| int32 mean_chn_3 = 19; | |||||
| float min_chn_0 = 13; | |||||
| float min_chn_1 = 14; | |||||
| float min_chn_2 = 15; | |||||
| float min_chn_3 = 20; | |||||
| repeated float var_reci_chn_0 = 16; | |||||
| repeated float var_reci_chn_1 = 17; | |||||
| repeated float var_reci_chn_2 = 18; | |||||
| repeated float var_reci_chn_3 = 21; | |||||
| repeated int32 matrix_r0c0 = 30; | |||||
| repeated int32 matrix_r0c1 = 31; | |||||
| repeated int32 matrix_r0c2 = 32; | |||||
| repeated int32 matrix_r1c0 = 33; | |||||
| repeated int32 matrix_r1c1 = 34; | |||||
| repeated int32 matrix_r1c2 = 35; | |||||
| repeated int32 matrix_r2c0 = 36; | |||||
| repeated int32 matrix_r2c1 = 37; | |||||
| repeated int32 matrix_r2c2 = 38; | |||||
| repeated int32 output_bias_0 = 39; | |||||
| repeated int32 output_bias_1 = 40; | |||||
| repeated int32 output_bias_2 = 41; | |||||
| repeated int32 input_bias_0 = 42; | |||||
| repeated int32 input_bias_1 = 43; | |||||
| repeated int32 input_bias_2 = 44; | |||||
| // [End] Static AIPP parameters | |||||
| // The n number that is used for raw/rgbir data into f16 transformation. | |||||
| // The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||||
| uint32 raw_rgbir_to_f16_n = 45; | |||||
| } | |||||
| message MultiShapeOpParams { /* Parameters of one multi-shape op; element type of InsertNewOps.multi_shape_op. */ | |||||
| enum MultiShapeMode { | |||||
| batch = 0; // dynamic batch | |||||
| resolution = 1; // dynamic resolution, reserved for extension | |||||
| } | |||||
| MultiShapeMode mode = 1; // operator mode | |||||
| uint32 related_input_rank = 2; // which input the newly added operator is inserted at | |||||
| repeated uint32 batch_list = 11; // batch_list values; the number of entries must be between 2 and 8 | |||||
| } | |||||
| @@ -1,396 +0,0 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| enum TargetType /* Type of ModelDef.target_type. */ | |||||
| { | |||||
| MINI = 0; | |||||
| TINY = 1; | |||||
| LITE = 2; | |||||
| } | |||||
| // Offline model: top-level message holding the operator list together with size and count fields | |||||
| message ModelDef { | |||||
| string name = 1; | |||||
| uint32 version = 2; | |||||
| uint64 memory_size = 10; | |||||
| uint32 stream_num = 11; | |||||
| uint32 event_num = 12; | |||||
| uint64 weight_size = 13; | |||||
| uint32 label_num = 15; | |||||
| repeated OpDef op = 20; /* Operators of the model */ | |||||
| TargetType target_type = 23; | |||||
| map<string, AttrDef> attr = 30; /* Attribute name -> attribute value */ | |||||
| }; | |||||
| // Operator definition; type-specific parameters live in the op_params oneof (at most one member set) | |||||
| message OpDef { | |||||
| string name = 1; | |||||
| string type = 2; | |||||
| uint32 id = 3; | |||||
| uint32 stream_id = 4; | |||||
| repeated string input_name = 5; | |||||
| repeated string src_name = 8; | |||||
| repeated int32 src_index = 9; | |||||
| repeated int64 input = 10; | |||||
| repeated int64 output = 11; | |||||
| repeated TensorDescriptor input_desc = 12; | |||||
| repeated TensorDescriptor output_desc = 13; | |||||
| repeated WeightDef weights = 14; | |||||
| repeated string dst_name = 15; | |||||
| repeated int32 dst_index = 16; | |||||
| repeated int64 workspace = 20; | |||||
| repeated uint32 workspace_bytes = 21; | |||||
| repeated string weight_name = 22; | |||||
| repeated bool is_input_const = 23; | |||||
| map<string, AttrDef> attr = 30; | |||||
| QuantizeFactorParams quantize_factor = 31; | |||||
| oneof op_params { | |||||
| // Field numbers in this oneof start at 100 and advance in steps of 100 | |||||
| SendOpParams sender_param = 100; | |||||
| RecvOpParams receiver_param = 200; | |||||
| ConvolutionOpParams convolution_param = 300; | |||||
| PoolingOpParams pooling_param = 400; | |||||
| EltwiseOpParams eltwise_param = 500; | |||||
| BatchNormOpParams batchnorm_param = 600; | |||||
| ScaleOpParams scale_param = 700; | |||||
| FullConnectionOpParams full_connection_param = 800; | |||||
| SoftmaxOpParams softmax_param = 900; | |||||
| ActivationOpParams activation_param = 1000; | |||||
| ReshapeOpParams reshape_param = 1100; | |||||
| } | |||||
| }; | |||||
| message SendOpParams { /* Payload of OpDef.op_params.sender_param. */ | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| message RecvOpParams { /* Payload of OpDef.op_params.receiver_param. */ | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| enum QuantizeScaleType /* Type of QuantizeFactorParams.scale_type. */ | |||||
| { | |||||
| VECTOR_SCALE = 0; | |||||
| SCALAR_SCALE = 1; | |||||
| } | |||||
| enum QuantizeScaleMode /* Type of QuantizeFactor.scale_mode. */ | |||||
| { | |||||
| NORMAL_MODE = 0; | |||||
| SQRT_MODE = 1; | |||||
| } | |||||
| enum QuantizeAlgorithm /* Type of QuantizeFactorParams.quantize_algo. */ | |||||
| { | |||||
| NON_OFFSET_ALGO = 0; | |||||
| HALF_OFFSET_ALGO = 1; | |||||
| ALL_OFFSET_ALGO = 2; | |||||
| } | |||||
| message QuantizeFactor /* Used for the quantize/dequantize/requantize members of QuantizeFactorParams; each *_value bytes field is paired with an int64 *_offset field. */ | |||||
| { | |||||
| QuantizeScaleMode scale_mode = 1; | |||||
| bytes scale_value = 2; | |||||
| int64 scale_offset = 3; | |||||
| bytes offset_data_value = 4; | |||||
| int64 offset_data_offset = 5; | |||||
| bytes offset_weight_value = 6; | |||||
| int64 offset_weight_offset = 7; | |||||
| bytes offset_pad_value = 8; | |||||
| int64 offset_pad_offset = 9; | |||||
| }; | |||||
| message QuantizeCalcFactor /* Type of QuantizeFactorParams.quantizecalc_param. */ | |||||
| { | |||||
| bytes offsetw = 1; | |||||
| int64 offsetw_offset = 2; | |||||
| bytes offsetd = 3; | |||||
| int64 offsetd_offset = 4; | |||||
| bytes scalereq = 5; | |||||
| int64 scaledreq_offset = 6; /* NOTE(review): spelled "scaledreq" while the paired bytes field is "scalereq"; renaming would change the generated API - confirm before touching. */ | |||||
| bytes offsetdnext = 7; | |||||
| int64 offsetdnext_offset = 8; | |||||
| } | |||||
| message QuantizeFactorParams /* Type of OpDef.quantize_factor. */ | |||||
| { | |||||
| QuantizeAlgorithm quantize_algo = 1; | |||||
| QuantizeScaleType scale_type = 2; | |||||
| QuantizeFactor quantize_param = 3; | |||||
| QuantizeFactor dequantize_param = 4; | |||||
| QuantizeFactor requantize_param = 5; | |||||
| QuantizeCalcFactor quantizecalc_param = 6; | |||||
| }; | |||||
| message ConvolutionOpParams { /* Payload of OpDef.op_params.convolution_param. */ | |||||
| int32 mode = 1; | |||||
| int32 algo = 2; | |||||
| int32 pad_mode = 3; | |||||
| uint32 group = 4; | |||||
| uint32 num_output = 5; | |||||
| repeated uint32 pad = 10; | |||||
| repeated uint32 stride = 11; | |||||
| repeated uint32 dilation = 12; | |||||
| repeated uint32 kernel = 13; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| WeightDef filter = 40; | |||||
| WeightDef bias = 41; | |||||
| bool relu_flag = 62; | |||||
| repeated uint32 adj = 70; | |||||
| repeated uint32 target_shape = 71; | |||||
| repeated uint32 before_pad = 72; | |||||
| }; | |||||
| message PoolingOpParams { /* Payload of OpDef.op_params.pooling_param. */ | |||||
| int32 mode = 1; | |||||
| int32 nan_opt = 2; | |||||
| int32 pad_mode = 3; | |||||
| bool global_pooling = 4; | |||||
| repeated uint32 window = 10; | |||||
| repeated uint32 pad = 11; | |||||
| repeated uint32 stride = 12; | |||||
| bool ceil_mode = 13; | |||||
| int32 data_mode = 14; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| repeated uint32 before_pad = 22; | |||||
| }; | |||||
| message EltwiseOpParams { /* Payload of OpDef.op_params.eltwise_param. */ | |||||
| int32 mode = 1; | |||||
| repeated float coeff = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| repeated WeightDef weight = 5; | |||||
| bool relu_flag = 6; | |||||
| }; | |||||
| message ActivationOpParams { /* Payload of OpDef.op_params.activation_param. */ | |||||
| int32 mode = 1; | |||||
| float coef = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message BatchNormOpParams { /* Payload of OpDef.op_params.batchnorm_param. */ | |||||
| int32 mode = 1; | |||||
| float alpha = 2; | |||||
| float beta = 3; | |||||
| double epsilon = 4;// optional, [default = 1e-5] | |||||
| bool use_global_stats = 5; // optional, true by default (testing mode) | |||||
| float moving_average_fraction = 6; // optional, [default = .999] | |||||
| WeightDef estimated_mean = 7; | |||||
| WeightDef estimated_variance = 8; | |||||
| WeightDef scale = 9; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message ScaleOpParams { /* Payload of OpDef.op_params.scale_param. */ | |||||
| WeightDef scale = 1; | |||||
| WeightDef bias = 2; | |||||
| }; | |||||
| message ReshapeOpParams { /* Payload of OpDef.op_params.reshape_param. */ | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| ShapeDef shape = 3; | |||||
| int32 axis = 4; | |||||
| int32 num_axes = 5; | |||||
| int32 format = 6; | |||||
| }; | |||||
| message SoftmaxOpParams { /* Payload of OpDef.op_params.softmax_param. */ | |||||
| int32 algo = 1; | |||||
| int32 mode = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message FullConnectionOpParams { /* Payload of OpDef.op_params.full_connection_param. */ | |||||
| WeightDef filter = 1; | |||||
| WeightDef bias = 2; | |||||
| uint32 num_output = 3; | |||||
| bool relu_flag = 12; | |||||
| }; | |||||
| message FlattenOpParams { /* NOTE(review): this message and the eight that follow (through RsqrtOpParams) are not referenced by OpDef.op_params or any other message in this file - confirm they are still needed. */ | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 start_axis = 3; | |||||
| int32 end_axis = 4; | |||||
| } | |||||
| message AddLimitedOpParams { /* Same field layout as MulLimitedOpParams. */ | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulLimitedOpParams { /* Same field layout as AddLimitedOpParams. */ | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message AddOpParams { /* Same field layout as MulOpParams and SubOpParams. */ | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulOpParams { /* Same field layout as AddOpParams and SubOpParams. */ | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message SubOpParams { /* Same field layout as AddOpParams and MulOpParams. */ | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message BiasAddOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message MatMulOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| bool transposeX = 3; | |||||
| bool transposeW = 4; | |||||
| WeightDef filter = 10; | |||||
| WeightDef bias = 12; | |||||
| }; | |||||
| message RsqrtOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| }; | |||||
| message WeightDef { /* Weight description; used by OpDef.weights and by the filter/bias/weight/scale fields of the *OpParams messages. */ | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| ShapeDef shape = 3; | |||||
| bytes data = 4; | |||||
| int64 data_offset = 5; | |||||
| uint32 cmps_size = 6; | |||||
| bytes cmps_tab = 7; | |||||
| int64 cmps_tab_offset = 10; | |||||
| CompressInfo cmps_info = 8; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 11; | |||||
| } | |||||
| message ShapeDef { /* Dimension list; used by WeightDef.shape and ReshapeOpParams.shape. */ | |||||
| repeated int64 dim = 1; | |||||
| } | |||||
| enum DeviceType { | |||||
| NPU = 0; // NPU; being value 0, this is the proto3 default | |||||
| CPU = 1; // CPU | |||||
| } | |||||
| message AllOffsetQuantizeInfo { /* Used by WeightDef and TensorDescriptor (alloffset_quantize_info). */ | |||||
| float scale = 1; | |||||
| int32 offset = 2; | |||||
| } | |||||
| message TensorDescriptor { /* Type of OpDef.input_desc / OpDef.output_desc. */ | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| repeated int64 dim = 3; | |||||
| uint32 size = 4; | |||||
| bool reuse_input = 5; | |||||
| bool output_tensor = 7; | |||||
| DeviceType device_type = 8; | |||||
| bool input_tensor = 9; | |||||
| uint32 real_dim_cnt = 10; | |||||
| uint32 reuse_input_index = 11; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 12; | |||||
| } | |||||
| message CompressInfo { /* Type of WeightDef.cmps_info. */ | |||||
| int32 blockRow = 1; // block row | |||||
| int32 blockCol = 2; // block col | |||||
| int32 fractalK = 3; // fractal K | |||||
| int32 fractalN = 4; // fractal N | |||||
| int32 lastFractalK = 5; // K of last fractal | |||||
| int32 lastFractalN = 6; // N of last fractal | |||||
| int32 cubeSize = 7; // cube's length | |||||
| int32 loadDir = 8; // data load direction: 0 = column load, 1 = row load | |||||
| } | |||||
| message AttrDef { /* One attribute value; at most one member of the oneof "value" below is set. */ | |||||
| message ListValue { | |||||
| repeated string s = 2; // "list(string)" | |||||
| repeated int64 i = 3 [packed = true]; // "list(int)" | |||||
| repeated float f = 4 [packed = true]; // "list(float)" | |||||
| repeated bool b = 5 [packed = true]; // "list(bool)" | |||||
| repeated uint32 u = 6 [packed = true]; // "list(uint)" | |||||
| repeated bytes bt = 7; /* "list(bytes)" */ | |||||
| } | |||||
| oneof value { | |||||
| string s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| uint32 u = 6; // "uint32" | |||||
| bytes bt = 7; /* "bytes" */ | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; /* nested named attribute list */ | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs { | |||||
| string name = 1; /* Name the attribute list is attached to */ | |||||
| map<string, AttrDef> attr = 2; /* Attribute name -> attribute value */ | |||||
| } | |||||
| @@ -1,179 +0,0 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| message ModelTaskDef { /* Task list of a model together with its size, count and address fields. */ | |||||
| string version = 1; | |||||
| map<string, string> attr = 9; // Extended field | |||||
| repeated TaskDef task = 10; /* Tasks of the model */ | |||||
| uint64 memory_size = 11; | |||||
| uint32 stream_num = 12; | |||||
| uint32 event_num = 13; | |||||
| uint64 weight_size = 14; | |||||
| repeated bytes op = 15; // input/output opdef in bytes | |||||
| uint64 base_addr = 16; // base addr | |||||
| uint64 weight_addr = 17; // weight addr | |||||
| uint32 batch_num = 18; | |||||
| } | |||||
| message TaskDef { /* One task; presumably `type` selects which of the *Def sub-messages below is meaningful - TODO confirm with the task loader. */ | |||||
| uint32 id = 1; | |||||
| uint32 type = 2; | |||||
| uint32 stream_id = 10; | |||||
| uint32 event_id = 11; | |||||
| KernelDef kernel = 20; | |||||
| KernelExDef kernel_ex = 21; | |||||
| KernelHcclDef kernel_hccl = 25; | |||||
| EventExDef event_ex = 26; | |||||
| LogTimeStampDef log_timestamp = 28; | |||||
| uint32 label_id = 30; | |||||
| MemcpyAsyncDef memcpy_async = 31; | |||||
| StreamSwitchDef stream_switch = 32; | |||||
| StreamActiveDef stream_active = 33; | |||||
| bytes private_def = 34; | |||||
| uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future | |||||
| StreamSwitchNDef stream_switch_n = 36; | |||||
| LabelSetDef label_set = 37; | |||||
| LabelGotoExDef label_goto_ex = 38; | |||||
| LabelSwitchByIndexDef label_switch_by_index = 39; | |||||
| KernelDefWithHandle kernel_with_handle = 40; | |||||
| } | |||||
// Kernel launch description addressed by a stub function name: a KernelContext, the
// block dimension, the raw argument blob with its size, and further opaque byte blobs
// (sm_desc, flowtable, kernel_ext_info).
// args_size and kernel_ext_info_size are carried separately from the bytes fields they sit next to.
| message KernelDef { | |||||
| KernelContext context = 1; | |||||
| string stub_func = 10; | |||||
| uint32 block_dim = 11; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes sm_desc = 14; | |||||
| bytes flowtable = 15; | |||||
| string so_name = 16; | |||||
| string kernel_name = 17; | |||||
| bytes kernel_ext_info = 18; | |||||
| uint32 kernel_ext_info_size = 19; | |||||
| } | |||||
// Variant of the kernel description that carries a numeric handle and a device function
// name instead of a stub function name, plus original_kernel_key and node_info strings.
// Its field numbers differ from KernelDef's (block_dim is 12 here, 11 there), so the two
// messages are not wire-compatible.
| message KernelDefWithHandle { | |||||
| KernelContext context = 1; | |||||
| uint64 handle = 10; | |||||
| string dev_func = 11; | |||||
| uint32 block_dim = 12; | |||||
| uint32 args_size = 13; | |||||
| bytes args = 14; | |||||
| bytes sm_desc = 15; | |||||
| string original_kernel_key = 16; | |||||
| string node_info = 17; | |||||
| } | |||||
// Context shared by KernelDef and KernelDefWithHandle: kernel type and operator
// identifiers, plus a description of how the args blob is laid out (offsets, count,
// flowtable flag).
| message KernelContext { | |||||
| uint32 kernel_type = 1; | |||||
| uint32 op_id = 2; // OP type in CCE | |||||
| uint32 kernel_func_id = 3; | |||||
| uint32 op_index = 4; // TE/Custom operator | |||||
| bool is_flowtable = 5; // Identify whether args is a flowtable structure | |||||
| bytes args_offset = 6; // args offset information | |||||
| uint32 args_count = 7; // args count | |||||
| repeated uint32 origin_op_index = 8; | |||||
| } | |||||
// Kernel description for the kernel_ex task flavour: flags, operator index, an argument
// blob, a serialized task_info blob and an extension-info blob, each bytes field paired
// with an explicit size field.
// Field numbers are sparse (1, 4, 12-17) and must stay as they are.
| message KernelExDef { | |||||
| uint32 flags = 1; | |||||
| uint32 op_index = 4; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput | |||||
| uint32 task_info_size = 15; | |||||
| bytes kernel_ext_info = 16; | |||||
| uint32 kernel_ext_info_size = 17; | |||||
| } | |||||
// Payload of TaskDef.kernel_hccl: an operator index and an HCCL type name.
// Numbering starts at 8; lower numbers are unused in this definition.
| message KernelHcclDef { | |||||
| uint32 op_index = 8; | |||||
| string hccl_type = 9; | |||||
| } | |||||
// Payload of TaskDef.event_ex: an operator index and a numeric event type code.
| message EventExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 event_type = 2; | |||||
| } | |||||
// Payload of TaskDef.log_timestamp: a 64-bit log id, a notify flag and a numeric `flat` value.
// NOTE(review): `flat` may be a misspelling of `flag`; the name is part of the generated
// API, so it is left untouched. Confirm before renaming.
| message LogTimeStampDef { | |||||
| uint64 logid = 1; | |||||
| bool notify = 2; | |||||
| uint32 flat = 3; | |||||
| } | |||||
// Payload of TaskDef.memcpy_async: destination and source as 64-bit values, a destination
// limit (dst_max), a byte/element count, a numeric copy kind and the operator index.
// NOTE(review): dst/src look like addresses and count like a byte count; confirm units with the consumer.
| message MemcpyAsyncDef { | |||||
| uint64 dst = 1; | |||||
| uint64 dst_max = 2; | |||||
| uint64 src = 3; | |||||
| uint64 count = 4; | |||||
| uint32 kind = 5; | |||||
| uint32 op_index = 6; | |||||
| } | |||||
// Payload of TaskDef.stream_switch: operator index, the stream id used for the "true"
// branch, a signed comparison value, a 64-bit value_ptr and a numeric data type code.
| message StreamSwitchDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 true_stream_id = 2; | |||||
| int64 value = 3; | |||||
| uint64 value_ptr = 4; | |||||
| uint32 data_type = 5; | |||||
| } | |||||
// Payload of TaskDef.stream_active: an operator index and the id of the stream to activate.
| message StreamActiveDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 active_stream_id = 2; | |||||
| } | |||||
// Payload of TaskDef.stream_switch_n: the N-way counterpart of StreamSwitchDef, with
// repeated target values and repeated "true" stream ids instead of single ones.
// NOTE(review): presumably target_value[i] pairs with true_stream_id[i] and `size` is
// their count; the schema does not enforce either, so verify against the producer.
| message StreamSwitchNDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 size = 2; | |||||
| repeated int64 target_value = 3; | |||||
| repeated uint32 true_stream_id = 4; | |||||
| uint32 element_size = 5; | |||||
| uint32 data_type = 6; | |||||
| } | |||||
// Payload of TaskDef.label_set: operator index, label id and model id.
// Structurally identical to LabelGotoExDef.
| message LabelSetDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
// Payload of TaskDef.label_goto_ex: operator index, label id and model id.
// Structurally identical to LabelSetDef.
| message LabelGotoExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
// Payload of TaskDef.label_switch_by_index: an operator index and a label_max value.
// NOTE(review): label_max looks like the upper bound of the switch index; confirm with the consumer.
| message LabelSwitchByIndexDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_max = 2; | |||||
| } | |||||
| @@ -1,246 +0,0 @@ | |||||
# Proto definitions (all taken from METADEF_DIR) whose generated code this directory needs.
# protobuf_generate is invoked twice over the same list: once for the `ge` output set
# (PROTO_SRCS/PROTO_HDRS) and once for the `ge_static` output set
# (PROTO_STATIC_SRCS/PROTO_STATIC_HDRS).
| set(PROTO_LIST | |||||
| "${METADEF_DIR}/proto/om.proto" | |||||
| "${METADEF_DIR}/proto/ge_ir.proto" | |||||
| "${METADEF_DIR}/proto/insert_op.proto" | |||||
| "${METADEF_DIR}/proto/task.proto" | |||||
| "${METADEF_DIR}/proto/tensorflow/attr_value.proto" | |||||
| "${METADEF_DIR}/proto/tensorflow/function.proto" | |||||
| "${METADEF_DIR}/proto/tensorflow/graph.proto" | |||||
| "${METADEF_DIR}/proto/tensorflow/node_def.proto" | |||||
| "${METADEF_DIR}/proto/tensorflow/op_def.proto" | |||||
| "${METADEF_DIR}/proto/tensorflow/resource_handle.proto" | |||||
| "${METADEF_DIR}/proto/tensorflow/tensor.proto" | |||||
| "${METADEF_DIR}/proto/tensorflow/tensor_shape.proto" | |||||
| "${METADEF_DIR}/proto/tensorflow/types.proto" | |||||
| "${METADEF_DIR}/proto/tensorflow/versions.proto" | |||||
| ) | |||||
| protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
| protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST}) | |||||
# Source files compiled into every ge_common library variant that consumes ${SRC_LIST}.
# Paths are relative to this directory; two entries reach into ../model.
| set(SRC_LIST | |||||
| "context/ctx.cc" | |||||
| "model_saver.cc" | |||||
| "ge/datatype_util.cc" | |||||
| "helper/om_file_helper.cc" | |||||
| "helper/model_helper.cc" | |||||
| "../model/ge_model.cc" | |||||
| "../model/ge_root_model.cc" | |||||
| "auth/file_saver.cc" | |||||
| "fp16_t.cc" | |||||
| "math/fp16_math.cc" | |||||
| "debug/memory_dumper.cc" | |||||
| "formats/utils/formats_trans_utils.cc" | |||||
| "dump/dump_properties.cc" | |||||
| "formats/format_transfers/datatype_transfer.cc" | |||||
| "formats/format_transfers/format_transfer_transpose.cc" | |||||
| "formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" | |||||
| "formats/format_transfers/format_transfer_fractal_z.cc" | |||||
| "formats/format_transfers/format_transfer_fractal_nz.cc" | |||||
| "formats/format_transfers/format_transfer_fractal_zz.cc" | |||||
| "formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" | |||||
| "formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" | |||||
| "formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" | |||||
| "formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" | |||||
| "formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" | |||||
| "formats/format_transfers/format_transfer_fracz_nchw.cc" | |||||
| "formats/format_transfers/format_transfer_fracz_nhwc.cc" | |||||
| "formats/format_transfers/format_transfer_fracz_hwcn.cc" | |||||
| "formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" | |||||
| "formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" | |||||
| "formats/format_transfers/format_transfer_nchw_fz_c04.cc" | |||||
| "formats/formats.cc" | |||||
| "ge_format_util.cc" | |||||
| "fmk_error_codes.cc" | |||||
| "util.cc" | |||||
| "properties_manager.cc" | |||||
| "types.cc" | |||||
| "model_parser/model_parser.cc" | |||||
| "kernel_store.cc" | |||||
| "tbe_kernel_store.cc" | |||||
| "cust_aicpu_kernel_store.cc" | |||||
| "op/attr_value_util.cc" | |||||
| "op/ge_op_utils.cc" | |||||
| "thread_pool.cc" | |||||
| "ge/tbe_plugin_manager.cc" | |||||
| ) | |||||
# Library definitions depend on the ENABLE_D / ENABLE_ACL switches:
#   - neither set: build the shared ge_common (linked against the ascend_protobuf target)
#     and additionally the static ge_common_static (linked against ascend_protobuf_static);
#   - either set:  build only the shared ge_common, linked against ascend_protobuf_static.
# Both shared variants compile with -fvisibility=hidden and link with -Wl,-Bsymbolic.
# The libraries listed between -Wl,--no-as-needed and -Wl,--as-needed are linked with
# --no-as-needed in effect.
| if (NOT ENABLE_D AND NOT ENABLE_ACL) | |||||
| ############ libge_common.so ############ | |||||
| add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_common PRIVATE | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| HOST_VISIBILITY | |||||
| FMK_SUPPORT_DUMP | |||||
| OS_CENTOS | |||||
| google=ascend_private | |||||
| FUNC_VISIBILITY | |||||
| ) | |||||
| target_compile_options(ge_common PRIVATE | |||||
| -fvisibility=hidden | |||||
| -O2 | |||||
| -Werror | |||||
| -Wno-deprecated-declarations | |||||
| -fno-common | |||||
| ) | |||||
| target_include_directories(ge_common PRIVATE | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/ge/common | |||||
| ${GE_CODE_DIR}/ge/common/op | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${METADEF_DIR}/inc/graph | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| #### yellow zone #### | |||||
| ${GE_DEPEND_DIR}/inc | |||||
| ${GE_DEPEND_DIR}/inc/cce | |||||
| #### blue zone #### | |||||
| #${GE_DEPEND_DIR}/include | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||||
| ) | |||||
| target_link_options(ge_common PRIVATE | |||||
| -Wl,-Bsymbolic | |||||
| ) | |||||
| target_link_libraries(ge_common PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| static_mmpa | |||||
| -Wl,--no-as-needed | |||||
| graph | |||||
| ascend_protobuf | |||||
| register | |||||
| c_sec | |||||
| error_manager | |||||
| slog | |||||
| -Wl,--as-needed | |||||
| json | |||||
| $<$<NOT:$<STREQUAL:${TARGET_SYSTEM_NAME},Android>>:-lrt> | |||||
| -ldl | |||||
| ) | |||||
| ############ libge_common.a ############ | |||||
# The static variant selects its options per TARGET_SYSTEM_NAME (Linux/Android vs Windows)
# through generator expressions and uses the ge_static generated-proto include directory.
| add_library(ge_common_static STATIC ${SRC_LIST} ${PROTO_STATIC_HDRS}) | |||||
| target_compile_definitions(ge_common_static PRIVATE | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| HOST_VISIBILITY | |||||
| FMK_SUPPORT_DUMP | |||||
| OS_CENTOS | |||||
| google=ascend_private | |||||
| $<IF:$<STREQUAL:${TARGET_SYSTEM_NAME},Windows>,OS_TYPE=WIN,OS_TYPE=0> | |||||
| $<$<STREQUAL:${TARGET_SYSTEM_NAME},Windows>:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> | |||||
| LOG_CPP | |||||
| FUNC_VISIBILITY | |||||
| ) | |||||
| target_compile_options(ge_common_static PRIVATE | |||||
| $<$<OR:$<STREQUAL:${TARGET_SYSTEM_NAME},Linux>,$<STREQUAL:${TARGET_SYSTEM_NAME},Android>>:-fvisibility=hidden -O2 -Werror -Wno-deprecated-declarations -fno-common> | |||||
| $<$<AND:$<STREQUAL:${TARGET_SYSTEM_NAME},Windows>,$<STREQUAL:${CMAKE_CONFIGURATION_TYPES},Debug>>:/MTd> | |||||
| $<$<AND:$<STREQUAL:${TARGET_SYSTEM_NAME},Windows>,$<STREQUAL:${CMAKE_CONFIGURATION_TYPES},Release>>:/MT> | |||||
| ) | |||||
| target_include_directories(ge_common_static PRIVATE | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/ge/common | |||||
| ${GE_CODE_DIR}/ge/common/op | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${METADEF_DIR}/inc/graph | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge_static | |||||
| #### yellow zone #### | |||||
| ${GE_DEPEND_DIR}/inc | |||||
| ${GE_DEPEND_DIR}/inc/cce | |||||
| #### blue zone #### | |||||
| #${GE_DEPEND_DIR}/include | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||||
| ) | |||||
| target_link_libraries(ge_common_static PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| ascend_protobuf_static | |||||
| json | |||||
| c_sec | |||||
| $<$<NOT:$<STREQUAL:${TARGET_SYSTEM_NAME},Android>>:-lrt> | |||||
| -ldl | |||||
| ) | |||||
| else () | |||||
| ############ libge_common.so w/static protobuf ############ | |||||
# Compared with the shared variant in the other branch, this one adds the LOG_CPP
# definition, omits the GE_DEPEND_DIR include directories and links -lrt unconditionally.
| add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_common PRIVATE | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| HOST_VISIBILITY | |||||
| FMK_SUPPORT_DUMP | |||||
| OS_CENTOS | |||||
| google=ascend_private | |||||
| LOG_CPP | |||||
| FUNC_VISIBILITY | |||||
| ) | |||||
| target_compile_options(ge_common PRIVATE | |||||
| -fvisibility=hidden | |||||
| -O2 | |||||
| -Werror | |||||
| -Wno-deprecated-declarations | |||||
| -fno-common | |||||
| ) | |||||
| target_include_directories(ge_common PRIVATE | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/ge/common | |||||
| ${GE_CODE_DIR}/ge/common/op | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${METADEF_DIR}/inc/graph | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||||
| ) | |||||
| target_link_options(ge_common PRIVATE | |||||
| -Wl,-Bsymbolic | |||||
| ) | |||||
| target_link_libraries(ge_common PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| ascend_protobuf_static | |||||
| -Wl,--no-as-needed | |||||
| graph | |||||
| register | |||||
| c_sec | |||||
| error_manager | |||||
| slog | |||||
| static_mmpa | |||||
| -Wl,--as-needed | |||||
| json | |||||
| -lrt | |||||
| -ldl | |||||
| ) | |||||
| endif () | |||||
| ############ install ############ | |||||
# Installs the ge_common library artifact into ${INSTALL_LIBRARY_DIR} ("lib").
# OPTIONAL makes a missing artifact a non-error at install time.
# NOTE(review): INSTALL_BASE_DIR is set to an empty string and not read in the lines
# visible here; check for other users before removing it.
| set(INSTALL_BASE_DIR "") | |||||
| set(INSTALL_LIBRARY_DIR lib) | |||||
| install(TARGETS ge_common OPTIONAL | |||||
| LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} | |||||
| ) | |||||
| @@ -1,408 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/auth/file_saver.h" | |||||
| #include <securec.h> | |||||
| #include <cstdlib> | |||||
| #include <fstream> | |||||
| #include <vector> | |||||
| #include "common/math/math_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "framework/common/util.h" | |||||
namespace {
// Value a successful CreateDirectory() call is expected to return (see FileSaver::CheckPath).
constexpr int kFileOpSuccess = 0;
}  // namespace
| namespace ge { | |||||
| Status FileSaver::OpenFile(int32_t &fd, const std::string &file_path) { | |||||
| if (CheckPath(file_path) != SUCCESS) { | |||||
| GELOGE(FAILED, "[Check][FilePath]Check output file failed, file_path:%s.", | |||||
| file_path.c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Check output file failed, file_path:%s.", | |||||
| file_path.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| char real_path[MMPA_MAX_PATH] = {0}; | |||||
| GE_IF_BOOL_EXEC(mmRealPath(file_path.c_str(), real_path, MMPA_MAX_PATH) != EN_OK, | |||||
| GELOGI("File %s is not exist, it will be created.", file_path.c_str())); | |||||
| // Open file | |||||
| mmMode_t mode = M_IRUSR | M_IWUSR; | |||||
| fd = mmOpen2(real_path, M_RDWR | M_CREAT | O_TRUNC, mode); | |||||
| if (fd == EN_INVALID_PARAM || fd == EN_ERROR) { | |||||
| // -1: Failed to open file; - 2: Illegal parameter | |||||
| GELOGE(FAILED, "[Open][File]Failed. errno:%d, errmsg:%s", fd, strerror(errno)); | |||||
| REPORT_INNER_ERROR("E19999", "Open file failed, errno:%d, errmsg:%s.", | |||||
| fd, strerror(errno)); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FileSaver::WriteData(const void *data, uint32_t size, int32_t fd) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size == 0 || data == nullptr, return PARAM_INVALID); | |||||
| mmSsize_t write_count; | |||||
| uint32_t size_2g = 2147483648; // 0x1 << 31 | |||||
| uint32_t size_1g = 1073741824; // 0x1 << 30 | |||||
| // Write data | |||||
| if (size > size_2g) { | |||||
| auto seek = reinterpret_cast<uint8_t *>(const_cast<void *>(data)); | |||||
| while (size > size_1g) { | |||||
| write_count = mmWrite(fd, reinterpret_cast<void *>(seek), size_1g); | |||||
| if (write_count == EN_INVALID_PARAM || write_count == EN_ERROR) { | |||||
| GELOGE(FAILED, "[Write][Data]Failed, errno:%ld, errmsg:%s", | |||||
| write_count, strerror(errno)); | |||||
| REPORT_INNER_ERROR("E19999", "Write data failed, errno:%ld, errmsg:%s.", | |||||
| write_count, strerror(errno)); | |||||
| return FAILED; | |||||
| } | |||||
| size -= size_1g; | |||||
| seek += size_1g; | |||||
| } | |||||
| write_count = mmWrite(fd, reinterpret_cast<void *>(seek), size); | |||||
| } else { | |||||
| write_count = mmWrite(fd, const_cast<void *>(data), size); | |||||
| } | |||||
| // -1: Failed to write to file; - 2: Illegal parameter | |||||
| if (write_count == EN_INVALID_PARAM || write_count == EN_ERROR) { | |||||
| GELOGE(FAILED, "[Write][Data]Failed. mmpa_errorno = %ld, error:%s", | |||||
| write_count, strerror(errno)); | |||||
| REPORT_INNER_ERROR("E19999", "Write data failed, mmpa_errorno = %ld, error:%s.", | |||||
| write_count, strerror(errno)); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, const void *data, | |||||
| int len) { | |||||
| if (data == nullptr || len <= 0) { | |||||
| GELOGE(FAILED, "[Check][Param]Failed, model_data is null or the " | |||||
| "length[%d] is less than 1.", len); | |||||
| REPORT_INNER_ERROR("E19999", "Save file failed, model_data is null or the " | |||||
| "length:%d is less than 1.", len); | |||||
| return FAILED; | |||||
| } | |||||
| // Open file | |||||
| int32_t fd = 0; | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(OpenFile(fd, file_path) != SUCCESS, return FAILED, "OpenFile FAILED"); | |||||
| Status ret = SUCCESS; | |||||
| do { | |||||
| // Write file header | |||||
| GE_CHK_BOOL_EXEC(WriteData(static_cast<const void *>(&file_header), sizeof(ModelFileHeader), fd) == SUCCESS, | |||||
| ret = FAILED; | |||||
| break, "WriteData FAILED"); | |||||
| // write data | |||||
| GE_CHK_BOOL_EXEC(WriteData(data, static_cast<uint32_t>(len), fd) == SUCCESS, ret = FAILED, "WriteData FAILED"); | |||||
| } while (0); | |||||
| // Close file | |||||
| if (mmClose(fd) != 0) { // mmClose 0: success | |||||
| GELOGE(FAILED, "[Close][File]Failed, error_code:%u errmsg:%s", ret, strerror(errno)); | |||||
| REPORT_INNER_ERROR("E19999", "Close file failed, error_code:%u errmsg:%s", | |||||
| ret, strerror(errno)); | |||||
| ret = FAILED; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, | |||||
| ModelPartitionTable &model_partition_table, | |||||
| const std::vector<ModelPartition> &partition_datas) { | |||||
| GE_CHK_BOOL_RET_STATUS(!partition_datas.empty() && model_partition_table.num != 0 | |||||
| && model_partition_table.num == partition_datas.size(), FAILED, | |||||
| "Invalid param:partition data size is (%u), model_partition_table.num is (%zu).", | |||||
| model_partition_table.num, partition_datas.size()); | |||||
| // Open file | |||||
| int32_t fd = 0; | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(OpenFile(fd, file_path) != SUCCESS, return FAILED); | |||||
| Status ret = SUCCESS; | |||||
| do { | |||||
| // Write file header | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| WriteData(static_cast<const void *>(&file_header), sizeof(ModelFileHeader), fd) != SUCCESS, ret = FAILED; | |||||
| break); | |||||
| // Write model partition table | |||||
| uint32_t table_size = static_cast<uint32_t>(SIZE_OF_MODEL_PARTITION_TABLE(model_partition_table)); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| WriteData(static_cast<const void *>(&model_partition_table), table_size, fd) != SUCCESS, ret = FAILED; break); | |||||
| // Write partition data | |||||
| for (const auto &partitionData : partition_datas) { | |||||
| GELOGI("GC:size[%u]", partitionData.size); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| WriteData(static_cast<const void *>(partitionData.data), partitionData.size, fd) != SUCCESS, ret = FAILED; | |||||
| break); | |||||
| } | |||||
| } while (0); | |||||
| // Close file | |||||
| if (mmClose(fd) != EN_OK) { | |||||
| GELOGE(FAILED, "[Close][File]Failed, error_code:%u errmsg:%s", ret, strerror(errno)); | |||||
| REPORT_CALL_ERROR("E19999", "Close file failed, error_code:%u errmsg:%s", | |||||
| ret, strerror(errno)); | |||||
| ret = FAILED; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| Status FileSaver::SaveToBuffWithFileHeader(const ModelFileHeader &file_header, | |||||
| ModelPartitionTable &model_partition_table, | |||||
| const std::vector<ModelPartition> &partition_datas, | |||||
| ge::ModelBufferData &model) { | |||||
| const vector<ModelPartitionTable *> model_partition_tables = { &model_partition_table }; | |||||
| const std::vector<std::vector<ModelPartition>> all_partition_datas = { partition_datas }; | |||||
| return SaveToBuffWithFileHeader(file_header, model_partition_tables, all_partition_datas, model); | |||||
| } | |||||
| Status FileSaver::SaveToBuffWithFileHeader(const ModelFileHeader &file_header, | |||||
| const vector<ModelPartitionTable *> &model_partition_tables, | |||||
| const std::vector<std::vector<ModelPartition>> &all_partition_datas, | |||||
| ge::ModelBufferData &model) { | |||||
| GE_CHK_BOOL_RET_STATUS(model_partition_tables.size() == all_partition_datas.size(), PARAM_INVALID, | |||||
| "Model table size %zu does not match partition size %zu.", | |||||
| model_partition_tables.size(), all_partition_datas.size()); | |||||
| for (size_t index = 0; index < model_partition_tables.size(); ++index) { | |||||
| auto &cur_partiton_data = all_partition_datas[index]; | |||||
| auto &cur_model_partition_table = *model_partition_tables[index]; | |||||
| GE_CHK_BOOL_RET_STATUS(!cur_partiton_data.empty() && cur_model_partition_table.num != 0 | |||||
| && cur_model_partition_table.num == cur_partiton_data.size(), FAILED, | |||||
| "Invalid param: partition data size is (%zu), model_partition_table.num is (%u).", | |||||
| cur_partiton_data.size(), cur_model_partition_table.num); | |||||
| } | |||||
| uint64_t model_header_size = sizeof(ModelFileHeader); | |||||
| uint64_t total_size = model_header_size; | |||||
| for (size_t index = 0; index < model_partition_tables.size(); ++index) { | |||||
| auto &cur_model_partition_table = *model_partition_tables[index]; | |||||
| total_size += static_cast<uint64_t>(SIZE_OF_MODEL_PARTITION_TABLE(cur_model_partition_table)); | |||||
| auto &cur_partition_data = all_partition_datas[index]; | |||||
| for (const auto &partition_data : cur_partition_data) { | |||||
| auto ret = ge::CheckUint64AddOverflow(total_size, partition_data.size); | |||||
| GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, FAILED, "Add uint64 overflow!"); | |||||
| total_size += partition_data.size; | |||||
| } | |||||
| } | |||||
| // save to buff | |||||
| auto buff = reinterpret_cast<uint8_t *>(malloc(total_size)); | |||||
| GE_CHK_BOOL_RET_STATUS(buff != nullptr, FAILED, "Malloc failed!"); | |||||
| GE_PRINT_DYNAMIC_MEMORY(malloc, "File buffer.", total_size) | |||||
| model.data.reset(buff, [](uint8_t *buff) { | |||||
| GELOGD("Free online model memory."); | |||||
| free(buff); | |||||
| buff = nullptr; | |||||
| }); | |||||
| model.length = total_size; | |||||
| uint64_t left_space = total_size; | |||||
| auto ret_mem = memcpy_s(buff, left_space, reinterpret_cast<void *>(const_cast<ModelFileHeader *>(&file_header)), | |||||
| model_header_size); | |||||
| GE_CHK_BOOL_RET_STATUS(ret_mem == EOK, FAILED, "Memcpy_s failed!"); | |||||
| buff += model_header_size; | |||||
| left_space -= model_header_size; | |||||
| for (size_t index = 0; index < model_partition_tables.size(); ++index) { | |||||
| auto &cur_tabel = *model_partition_tables[index]; | |||||
| uint64_t table_size = static_cast<uint64_t>(SIZE_OF_MODEL_PARTITION_TABLE(cur_tabel)); | |||||
| ret_mem = memcpy_s(buff, left_space, reinterpret_cast<void *>(&cur_tabel), table_size); | |||||
| GE_CHK_BOOL_RET_STATUS(ret_mem == EOK, FAILED, "Memcpy_s failed!"); | |||||
| buff += table_size; | |||||
| left_space -= table_size; | |||||
| auto &cur_partition_data = all_partition_datas[index]; | |||||
| for (const auto &partition_data : cur_partition_data) { | |||||
| ret_mem = memcpy_s(buff, left_space, reinterpret_cast<void *>(const_cast<uint8_t *>(partition_data.data)), | |||||
| partition_data.size); | |||||
| GE_CHK_BOOL_RET_STATUS(ret_mem == EOK, FAILED, "Memcpy_s failed!"); | |||||
| buff += partition_data.size; | |||||
| left_space -= partition_data.size; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::CheckPath(const std::string &file_path) { | |||||
| // Determine file path length | |||||
| if (file_path.size() >= MMPA_MAX_PATH) { | |||||
| GELOGE(FAILED, "[Check][FilePath]Failed, file path's length:%zu > mmpa_max_path:%d", | |||||
| file_path.size(), MMPA_MAX_PATH); | |||||
| REPORT_INNER_ERROR("E19999", "Check file path failed, file path's length:%zu > " | |||||
| "mmpa_max_path:%d", file_path.size(), MMPA_MAX_PATH); | |||||
| return FAILED; | |||||
| } | |||||
| // Find the last separator | |||||
| int path_split_pos = static_cast<int>(file_path.size() - 1); | |||||
| for (; path_split_pos >= 0; path_split_pos--) { | |||||
| if (file_path[path_split_pos] == '\\' || file_path[path_split_pos] == '/') { | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (path_split_pos == 0) { | |||||
| return SUCCESS; | |||||
| } | |||||
| // If there is a path before the file name, create the path | |||||
| if (path_split_pos != -1) { | |||||
| if (CreateDirectory(std::string(file_path).substr(0, static_cast<size_t>(path_split_pos))) != kFileOpSuccess) { | |||||
| GELOGE(FAILED, "[Create][Directory]Failed, file path:%s.", file_path.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status | |||||
| FileSaver::SaveToFile(const string &file_path, const ge::ModelData &model, const ModelFileHeader *model_file_header) { | |||||
| if (file_path.empty() || model.model_data == nullptr || model.model_len == 0) { | |||||
| GELOGE(FAILED, "[Save][File]Incorrect input param, " | |||||
| "file_path is empty or model_data is nullptr or model_len is 0"); | |||||
| REPORT_INNER_ERROR("E19999", "Save file failed, at least one of the " | |||||
| "input parameters(file_path, model_data, model_len) is incorrect"); | |||||
| return FAILED; | |||||
| } | |||||
| ModelFileHeader file_header; | |||||
| int32_t copy_header_ret = 0; | |||||
| GE_IF_BOOL_EXEC(model_file_header != nullptr, copy_header_ret = memcpy_s(&file_header, sizeof(ModelFileHeader), | |||||
| model_file_header, sizeof(ModelFileHeader))); | |||||
| GE_CHK_BOOL_RET_STATUS(copy_header_ret == 0, FAILED, "Copy ModelFileHeader failed, memcpy_s return: %d", | |||||
| copy_header_ret); | |||||
| file_header.length = model.model_len; | |||||
| file_header.is_encrypt = ModelEncryptType::UNENCRYPTED; | |||||
| const Status ret = SaveWithFileHeader(file_path, file_header, model.model_data, file_header.length); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "[Save][File]Failed, file_path:%s, file_header_len:%u, error_code:%u.", | |||||
| file_path.c_str(), file_header.length, ret); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status | |||||
| FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, ModelPartitionTable &model_partition_table, | |||||
| const std::vector<ModelPartition> &partition_datas) { | |||||
| file_header.is_encrypt = ModelEncryptType::UNENCRYPTED; | |||||
| const Status ret = SaveWithFileHeader(file_path, file_header, model_partition_table, partition_datas); | |||||
| GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, FAILED, "save file failed, file_path:%s, file header len:%u.", | |||||
| file_path.c_str(), file_header.length); | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status | |||||
| FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, | |||||
| vector<ModelPartitionTable *> &model_partition_tables, | |||||
| const vector<vector<ModelPartition>> &all_partition_datas) { | |||||
| file_header.is_encrypt = ModelEncryptType::UNENCRYPTED; | |||||
| const Status ret = SaveWithFileHeader(file_path, file_header, model_partition_tables, all_partition_datas); | |||||
| GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, FAILED, "save file failed, file_path:%s, file header len:%u.", | |||||
| file_path.c_str(), file_header.length); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, | |||||
| vector<ModelPartitionTable *> &model_partition_tables, | |||||
| const vector<vector<ModelPartition>> &all_partition_datas) { | |||||
| GE_CHK_BOOL_EXEC(model_partition_tables.size() == all_partition_datas.size(), | |||||
| return PARAM_INVALID, | |||||
| "model table size %zu does not match partition size %zu", | |||||
| model_partition_tables.size(), all_partition_datas.size()) | |||||
| for (size_t index = 0; index < model_partition_tables.size(); ++index) { | |||||
| auto &cur_partiton_data = all_partition_datas[index]; | |||||
| auto &cur_model_partition_table = *model_partition_tables[index]; | |||||
| GE_CHK_BOOL_RET_STATUS(!cur_partiton_data.empty() && cur_model_partition_table.num != 0 | |||||
| && cur_model_partition_table.num == cur_partiton_data.size(), FAILED, | |||||
| "Invalid param:partition data size is (%u), model_partition_table.num is (%zu).", | |||||
| cur_model_partition_table.num, cur_partiton_data.size()); | |||||
| } | |||||
| // Open file | |||||
| int32_t fd = 0; | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(OpenFile(fd, file_path) != SUCCESS, return FAILED); | |||||
| Status ret = SUCCESS; | |||||
| do { | |||||
| // Write file header | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| WriteData(static_cast<const void *>(&file_header), sizeof(ModelFileHeader), fd) != SUCCESS, ret = FAILED; | |||||
| break); | |||||
| for (size_t index = 0; index < model_partition_tables.size(); ++index) { | |||||
| // Write model partition table | |||||
| auto &cur_tabel = *model_partition_tables[index]; | |||||
| uint32_t table_size = static_cast<uint32_t>(SIZE_OF_MODEL_PARTITION_TABLE(cur_tabel)); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| WriteData(static_cast<const void *>(&cur_tabel), table_size, fd) != SUCCESS, ret = FAILED; break); | |||||
| // Write partition data | |||||
| auto &cur_partition_datas = all_partition_datas[index]; | |||||
| for (const auto &partition_data : cur_partition_datas) { | |||||
| GELOGI("part_size[%u]", partition_data.size); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| WriteData(static_cast<const void *>(partition_data.data), partition_data.size, fd) != SUCCESS, ret = FAILED; | |||||
| break); | |||||
| } | |||||
| } | |||||
| } while (0); | |||||
| // Close file | |||||
| if (mmClose(fd) != 0) { // mmClose 0: success | |||||
| GELOGE(FAILED, "[Close][File]Failed, error_code:%u errmsg:%s", ret, strerror(errno)); | |||||
| REPORT_CALL_ERROR("E19999", "Close file failed, error_code:%u errmsg:%s", | |||||
| ret, strerror(errno)); | |||||
| ret = FAILED; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(const string &file_path, const void *data, | |||||
| int len) { | |||||
| if (data == nullptr || len <= 0) { | |||||
| GELOGE(FAILED, "[Check][Param]Failed, model_data is null or the " | |||||
| "length[%d] is less than 1.", len); | |||||
| REPORT_INNER_ERROR("E19999", "Save file failed, the model_data is null or " | |||||
| "its length:%d is less than 1.", len); | |||||
| return FAILED; | |||||
| } | |||||
| // Open file | |||||
| int32_t fd = 0; | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(OpenFile(fd, file_path) != SUCCESS, return FAILED, "OpenFile FAILED"); | |||||
| Status ret = SUCCESS; | |||||
| // write data | |||||
| GE_CHK_BOOL_EXEC(SUCCESS == WriteData(data, (uint32_t)len, fd), ret = FAILED, "WriteData FAILED"); | |||||
| // Close file | |||||
| if (mmClose(fd) != 0) { // mmClose 0: success | |||||
| GELOGE(FAILED, "[Close][File]Failed, error_code:%u errmsg:%s", ret, strerror(errno)); | |||||
| REPORT_CALL_ERROR("E19999", "Close file failed, error_code:%u errmsg:%s", | |||||
| ret, strerror(errno)); | |||||
| ret = FAILED; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,125 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_AUTH_FILE_SAVER_H_ | |||||
| #define GE_COMMON_AUTH_FILE_SAVER_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "framework/common/helper/om_file_helper.h" | |||||
| #include "framework/common/types.h" | |||||
| #include "external/ge/ge_ir_build.h" | |||||
| #include "graph/buffer.h" | |||||
| #include "mmpa/mmpa_api.h" | |||||
// Bundle of raw input buffers. Apart from model_name, every buffer pointer is followed by a
// matching *_len field.
// NOTE(review): ownership of the pointed-to memory and the unit of the lengths (presumably bytes)
// are not visible in this header - confirm against the callers.
struct PROC_PARAM {
  uint8_t *model_name;

  // ISV Ek buffer and its length
  uint8_t *model_key;
  uint32_t model_key_len;

  // ISV root certificate buffer and its length
  uint8_t *root_cert;
  uint32_t root_cert_len;

  // ISV private key buffer and its length
  uint8_t *pri_key;
  uint32_t pri_key_len;

  // Raw AI Module Image buffer and its length
  uint8_t *ai_image;
  uint32_t ai_image_len;

  // ISV HW key buffer and its length
  uint8_t *hw_key;
  uint32_t hw_key_len;
};
// Pair of output buffers, each stored as a raw pointer followed by its length field.
// NOTE(review): who allocates and releases these buffers is not visible in this header - confirm.
struct ProcOut {
  // Passcode buffer and its length
  uint8_t *passcode;
  uint32_t passcode_len;

  // Encrypted image buffer and its length
  uint8_t *encrypted_img;
  uint32_t encrypted_img_len;
};
| namespace ge { | |||||
| using std::string; | |||||
| class FileSaver { | |||||
| public: | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief save model, no encryption | |||||
| /// @return Status result | |||||
| /// | |||||
| static Status SaveToFile(const string &file_path, const ge::ModelData &model, | |||||
| const ModelFileHeader *model_file_header = nullptr); | |||||
| static Status SaveToFile(const string &file_path, ModelFileHeader &model_file_header, | |||||
| ModelPartitionTable &model_partition_table, | |||||
| const std::vector<ModelPartition> &partition_datas); | |||||
| static Status SaveToFile(const string &file_path, ModelFileHeader &file_header, | |||||
| vector<ModelPartitionTable *> &model_partition_tables, | |||||
| const vector<vector<ModelPartition>> &all_partition_datas); | |||||
| static Status SaveToBuffWithFileHeader(const ModelFileHeader &file_header, | |||||
| ModelPartitionTable &model_partition_table, | |||||
| const std::vector<ModelPartition> &partition_datas, | |||||
| ge::ModelBufferData& model); | |||||
| static Status SaveToBuffWithFileHeader(const ModelFileHeader &file_header, | |||||
| const std::vector<ModelPartitionTable *> &model_partition_tables, | |||||
| const std::vector<std::vector<ModelPartition>> &all_partition_datas, | |||||
| ge::ModelBufferData &model); | |||||
| static Status SaveToFile(const string &file_path, const void *data, int len); | |||||
| protected: | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Check validity of the file path | |||||
| /// @return Status result | |||||
| /// | |||||
| static Status CheckPath(const string &file_path); | |||||
| static Status WriteData(const void *data, uint32_t size, int32_t fd); | |||||
| static Status OpenFile(int32_t &fd, const std::string &file_path); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief save model to file | |||||
| /// @param [in] file_path file output path | |||||
| /// @param [in] file_header file header info | |||||
| /// @param [in] data model data | |||||
| /// @param [in] len model length | |||||
| /// @return Status result | |||||
| /// | |||||
| static Status SaveWithFileHeader(const string &file_path, const ModelFileHeader &file_header, const void *data, | |||||
| int len); | |||||
| static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, | |||||
| ModelPartitionTable &model_partition_table, | |||||
| const std::vector<ModelPartition> &partition_datas); | |||||
| static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, | |||||
| std::vector<ModelPartitionTable *> &model_partition_tables, | |||||
| const std::vector<std::vector<ModelPartition>> &all_partition_datas); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_AUTH_FILE_SAVER_H_ | |||||
| @@ -1,131 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_BASE64_H_ | |||||
| #define GE_COMMON_BASE64_H_ | |||||
| #include <algorithm> | |||||
| #include <string> | |||||
| #include "debug/ge_log.h" | |||||
| #include "ge_error_codes.h" | |||||
| namespace ge { | |||||
namespace {
// Base64 alphabet; the position of a character in this table is its 6-bit value.
// The pointer itself is const so that no includer of this header can re-point the table.
const char *const kBase64Chars =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789+/";
// Padding character that fills up the last encoded group
const char kEqualSymbol = '=';
// Number of characters in kBase64Chars
const size_t kBase64CharsNum = 64;
// Three raw bytes are encoded as one group of four base64 characters
const size_t kThreeByteOneGroup = 3;
const size_t kFourByteOneGroup = 4;
// Offsets inside a raw three-byte group
const size_t kThreeByteOneGroupIndex0 = 0;
const size_t kThreeByteOneGroupIndex1 = 1;
const size_t kThreeByteOneGroupIndex2 = 2;
// Offsets inside an encoded four-character group
const size_t kFourByteOneGroupIndex0 = 0;
const size_t kFourByteOneGroupIndex1 = 1;
const size_t kFourByteOneGroupIndex2 = 2;
const size_t kFourByteOneGroupIndex3 = 3;
}  // namespace
| namespace base64 { | |||||
/// @brief Tell whether c belongs to the base64 alphabet (A-Z, a-z, 0-9, '+', '/').
/// @param [in] c character to test; the padding character '=' is not part of the alphabet
/// @return true if c is an alphabet character
// Explicit range checks are used instead of isalnum(): passing a negative char to isalnum() is
// undefined behaviour and its result depends on the current locale.
static inline bool IsBase64Char(const char &c) {
  const bool is_upper = (c >= 'A') && (c <= 'Z');
  const bool is_lower = (c >= 'a') && (c <= 'z');
  const bool is_digit = (c >= '0') && (c <= '9');
  return is_upper || is_lower || is_digit || (c == '+') || (c == '/');
}
| static std::string EncodeToBase64(const std::string &raw_data) { | |||||
| size_t encode_length = raw_data.size() / kThreeByteOneGroup * kFourByteOneGroup; | |||||
| encode_length += raw_data.size() % kThreeByteOneGroup == 0 ? 0 : kFourByteOneGroup; | |||||
| size_t raw_data_index = 0; | |||||
| size_t encode_data_index = 0; | |||||
| std::string encode_data; | |||||
| encode_data.resize(encode_length); | |||||
| for (; raw_data_index + kThreeByteOneGroup <= raw_data.size(); raw_data_index += kThreeByteOneGroup) { | |||||
| auto char_1 = static_cast<uint8_t>(raw_data[raw_data_index]); | |||||
| auto char_2 = static_cast<uint8_t>(raw_data[raw_data_index + kThreeByteOneGroupIndex1]); | |||||
| auto char_3 = static_cast<uint8_t>(raw_data[raw_data_index + kThreeByteOneGroupIndex2]); | |||||
| encode_data[encode_data_index++] = kBase64Chars[char_1 >> 2u]; | |||||
| encode_data[encode_data_index++] = kBase64Chars[((char_1 << 4u) & 0x30) | (char_2 >> 4u)]; | |||||
| encode_data[encode_data_index++] = kBase64Chars[((char_2 << 2u) & 0x3c) | (char_3 >> 6u)]; | |||||
| encode_data[encode_data_index++] = kBase64Chars[char_3 & 0x3f]; | |||||
| } | |||||
| if (raw_data_index < raw_data.size()) { | |||||
| auto tail = raw_data.size() - raw_data_index; | |||||
| auto char_1 = static_cast<uint8_t>(raw_data[raw_data_index]); | |||||
| if (tail == 1) { | |||||
| encode_data[encode_data_index++] = kBase64Chars[char_1 >> 2u]; | |||||
| encode_data[encode_data_index++] = kBase64Chars[(char_1 << 4u) & 0x30]; | |||||
| encode_data[encode_data_index++] = kEqualSymbol; | |||||
| encode_data[encode_data_index++] = kEqualSymbol; | |||||
| } else { | |||||
| auto char_2 = static_cast<uint8_t>(raw_data[raw_data_index + 1]); | |||||
| encode_data[encode_data_index++] = kBase64Chars[char_1 >> 2u]; | |||||
| encode_data[encode_data_index++] = kBase64Chars[((char_1 << 4u) & 0x30) | (char_2 >> 4u)]; | |||||
| encode_data[encode_data_index++] = kBase64Chars[(char_2 << 2u) & 0x3c]; | |||||
| encode_data[encode_data_index++] = kEqualSymbol; | |||||
| } | |||||
| } | |||||
| return encode_data; | |||||
| } | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-function" | |||||
| static Status DecodeFromBase64(const std::string &base64_data, std::string &decode_data) { | |||||
| if (base64_data.size() % kFourByteOneGroup != 0) { | |||||
| GELOGE(PARAM_INVALID, "base64 data size must can be divided by 4, but given data size is %zu", base64_data.size()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| decode_data.clear(); | |||||
| size_t base64_data_len = base64_data.size(); | |||||
| uint8_t byte_4[kFourByteOneGroup]; | |||||
| auto FindCharInBase64Chars = [&](const char &raw_char) -> uint8_t { | |||||
| auto char_pos = std::find(kBase64Chars, kBase64Chars + kBase64CharsNum, raw_char); | |||||
| return static_cast<uint8_t>(std::distance(kBase64Chars, char_pos)) & 0xff; | |||||
| }; | |||||
| for (std::size_t input_data_index = 0; input_data_index < base64_data_len; input_data_index += kFourByteOneGroup) { | |||||
| for (size_t i = 0; i < kFourByteOneGroup; ++i) { | |||||
| if (base64_data[input_data_index + i] == kEqualSymbol && | |||||
| input_data_index >= base64_data_len - kFourByteOneGroup && i > 1) { | |||||
| byte_4[i] = kBase64CharsNum; | |||||
| } else if (IsBase64Char(base64_data[input_data_index + i])) { | |||||
| byte_4[i] = FindCharInBase64Chars(base64_data[input_data_index + i]); | |||||
| } else { | |||||
| GELOGE(PARAM_INVALID, "given base64 data is illegal"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| decode_data += | |||||
| static_cast<char>((byte_4[kFourByteOneGroupIndex0] << 2u) + ((byte_4[kFourByteOneGroupIndex1] & 0x30) >> 4u)); | |||||
| if (byte_4[kFourByteOneGroupIndex2] >= kBase64CharsNum) { | |||||
| break; | |||||
| } else if (byte_4[kFourByteOneGroupIndex3] >= kBase64CharsNum) { | |||||
| decode_data += static_cast<char>(((byte_4[kFourByteOneGroupIndex1] & 0x0f) << 4u) + | |||||
| ((byte_4[kFourByteOneGroupIndex2] & 0x3c) >> 2u)); | |||||
| break; | |||||
| } | |||||
| decode_data += static_cast<char>(((byte_4[kFourByteOneGroupIndex1] & 0x0f) << 4u) + | |||||
| ((byte_4[kFourByteOneGroupIndex2] & 0x3c) >> 2u)); | |||||
| decode_data += | |||||
| static_cast<char>(((byte_4[kFourByteOneGroupIndex2] & 0x03) << 6u) + byte_4[kFourByteOneGroupIndex3]); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| #pragma GCC diagnostic pop | |||||
| } // namespace base64 | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_BASE64_H_ | |||||
| @@ -1,25 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "framework/omg/omg_inner_types.h" | |||||
| using ge::OmgContext; | |||||
| namespace domi { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OmgContext &GetContext() { | |||||
| static OmgContext context; | |||||
| return context; | |||||
| } | |||||
| } // namespace domi | |||||
| @@ -1,39 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/cust_aicpu_kernel_store.h" | |||||
| namespace ge { | |||||
| CustAICPUKernelStore::CustAICPUKernelStore() {} | |||||
| void CustAICPUKernelStore::AddCustAICPUKernel(const CustAICPUKernelPtr &kernel) { | |||||
| AddKernel(kernel); | |||||
| } | |||||
| void CustAICPUKernelStore::LoadCustAICPUKernelBinToOpDesc(const std::shared_ptr<ge::OpDesc> &op_desc) const { | |||||
| GELOGD("LoadCustAICPUKernelBinToOpDesc in."); | |||||
| if (op_desc != nullptr) { | |||||
| auto kernel_bin = FindKernel(op_desc->GetName()); | |||||
| if (kernel_bin != nullptr) { | |||||
| GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(ge::OP_EXTATTR_CUSTAICPU_KERNEL, kernel_bin), | |||||
| GELOGW("LoadKernelCustAICPUBinToOpDesc: SetExtAttr for kernel_bin failed");) | |||||
| GELOGI("Load cust aicpu kernel:%s, %zu", kernel_bin->GetName().c_str(), kernel_bin->GetBinDataSize()); | |||||
| } | |||||
| } | |||||
| GELOGD("LoadCustAICPUKernelBinToOpDesc success."); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,35 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_CUST_AICPU_KERNEL_STORE_H_ | |||||
| #define GE_COMMON_CUST_AICPU_KERNEL_STORE_H_ | |||||
| #include "common/kernel_store.h" | |||||
| namespace ge { | |||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY CustAICPUKernelStore : public KernelStore { | |||||
| public: | |||||
| CustAICPUKernelStore(); | |||||
| ~CustAICPUKernelStore() {} | |||||
| void AddCustAICPUKernel(const CustAICPUKernelPtr &kernel); | |||||
| void LoadCustAICPUKernelBinToOpDesc(const std::shared_ptr<ge::OpDesc> &op_desc) const; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_CUST_AICPU_KERNEL_STORE_H_ | |||||
| @@ -1,175 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/debug/memory_dumper.h" | |||||
| #include <string> | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/util.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| using std::string; | |||||
namespace {
// Descriptor value meaning "no file is open"
constexpr int kInvalidFd = -1;
}  // namespace
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY MemoryDumper::MemoryDumper() : fd_(kInvalidFd) {} | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY MemoryDumper::~MemoryDumper() { Close(); } | |||||
| // Dump the data to the file | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::DumpToFile(const char *filename, void *data, | |||||
| int64_t len) { | |||||
| #ifdef FMK_SUPPORT_DUMP | |||||
| GE_CHECK_NOTNULL(filename); | |||||
| GE_CHECK_NOTNULL(data); | |||||
| if (len == 0) { | |||||
| GELOGE(FAILED, "[Check][Param]Failed, data length is 0."); | |||||
| REPORT_INNER_ERROR("E19999", "Check param failed, data length is 0."); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| // Open the file | |||||
| int fd = OpenFile(filename); | |||||
| if (fd == kInvalidFd) { | |||||
| GELOGE(FAILED, "[Open][File]Failed, filename:%s.", filename); | |||||
| REPORT_INNER_ERROR("E19999", "Opne file failed, filename:%s.", filename); | |||||
| return FAILED; | |||||
| } | |||||
| // Write the data to the file | |||||
| Status ret = SUCCESS; | |||||
| int32_t mmpa_ret = mmWrite(fd, data, len); | |||||
| // mmWrite return -1:Failed to write data to file;return -2:Invalid parameter | |||||
| if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { | |||||
| GELOGE(FAILED, "[Write][Data]Failed, errno:%d, errmsg:%s", mmpa_ret, strerror(errno)); | |||||
| REPORT_INNER_ERROR("E19999", "Write data failed, errno:%d, errmsg:%s.", | |||||
| mmpa_ret, strerror(errno)); | |||||
| ret = FAILED; | |||||
| } | |||||
| // Close the file | |||||
| if (mmClose(fd) != EN_OK) { // mmClose return 0: success | |||||
| GELOGE(FAILED, "[Close][File]Failed, error_code:%u, filename:%s errmsg:%s.", ret, filename, strerror(errno)); | |||||
| REPORT_INNER_ERROR("E19999", "Close file failed, error_code:%u, filename:%s errmsg:%s.", | |||||
| ret, filename, strerror(errno)); | |||||
| ret = FAILED; | |||||
| } | |||||
| return ret; | |||||
| #else | |||||
| GELOGW("need to define FMK_SUPPORT_DUMP for dump op input and output."); | |||||
| return SUCCESS; | |||||
| #endif | |||||
| } | |||||
| // Open file | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::Open(const char *filename) { | |||||
| GE_CHK_BOOL_RET_STATUS(filename != nullptr, FAILED, "Incorrect parameter. filename is nullptr"); | |||||
| // Try to remove file first for reduce the close time by overwriting way | |||||
| // (The process of file closing will be about 100~200ms slower per file when written by overwriting way) | |||||
| // If remove file failed, then try to open it with overwriting way | |||||
| int ret = remove(filename); | |||||
| // If remove file failed, print the warning log | |||||
| if (ret != 0) { | |||||
| GELOGW("Remove file failed."); | |||||
| } | |||||
| fd_ = OpenFile(filename); | |||||
| if (fd_ == kInvalidFd) { | |||||
| GELOGE(FAILED, "[Open][File]Failed, filename:%s.", filename); | |||||
| REPORT_INNER_ERROR("E19999", "Open file:%s failed.", filename); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| // Dump the data to file | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::Dump(void *data, uint32_t len) const { | |||||
| GE_CHK_BOOL_RET_STATUS(data != nullptr, FAILED, "Incorrect parameter. data is nullptr"); | |||||
| #ifdef FMK_SUPPORT_DUMP | |||||
| int32_t mmpa_ret = mmWrite(fd_, data, len); | |||||
| // mmWrite return -1:failed to write data to file;return -2:invalid parameter | |||||
| if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { | |||||
| GELOGE(FAILED, "[Write][Data]Failed, errno:%d, errmsg:%s", mmpa_ret, strerror(errno)); | |||||
| REPORT_INNER_ERROR("E19999", "Write data to file failed, errno:%d, errmsg:%s.", | |||||
| mmpa_ret, strerror(errno)); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| #else | |||||
| GELOGW("need to define FMK_SUPPORT_DUMP for dump op input and output."); | |||||
| return SUCCESS; | |||||
| #endif | |||||
| } | |||||
| // Close file | |||||
| void MemoryDumper::Close() noexcept { | |||||
| // Close file | |||||
| if (fd_ != kInvalidFd && mmClose(fd_) != EN_OK) { | |||||
| GELOGW("Close file failed, errmsg:%s.", strerror(errno)); | |||||
| } | |||||
| fd_ = kInvalidFd; | |||||
| } | |||||
| // Open file | |||||
| int MemoryDumper::OpenFile(const char *filename) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(filename == nullptr, return kInvalidFd, "Incorrect parameter. filename is nullptr"); | |||||
| // Find the last separator | |||||
| int path_split_pos = static_cast<int>(strlen(filename) - 1); | |||||
| for (; path_split_pos >= 0; path_split_pos--) { | |||||
| GE_IF_BOOL_EXEC(filename[path_split_pos] == '\\' || filename[path_split_pos] == '/', break;) | |||||
| } | |||||
| // Get the absolute path | |||||
| string real_path; | |||||
| char tmp_path[MMPA_MAX_PATH] = {0}; | |||||
| GE_IF_BOOL_EXEC( | |||||
| -1 != path_split_pos, string prefix_path = std::string(filename).substr(0, path_split_pos); | |||||
| string last_path = std::string(filename).substr(path_split_pos, strlen(filename) - 1); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(prefix_path.length() >= MMPA_MAX_PATH, | |||||
| return kInvalidFd, "Prefix path is too long!"); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmRealPath(prefix_path.c_str(), tmp_path, MMPA_MAX_PATH) != EN_OK, return kInvalidFd, | |||||
| "Dir %s does not exit, errmsg:%s.", prefix_path.c_str(), strerror(errno)); | |||||
| real_path = std::string(tmp_path) + last_path;) | |||||
| GE_IF_BOOL_EXEC( | |||||
| path_split_pos == -1 || path_split_pos == 0, | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(filename) >= MMPA_MAX_PATH, return kInvalidFd, "Prefix path is too long!"); | |||||
| GE_IF_BOOL_EXEC(mmRealPath(filename, tmp_path, MMPA_MAX_PATH) != EN_OK, | |||||
| GELOGI("File %s does not exit, it will be created.", filename)); | |||||
| real_path = std::string(tmp_path);) | |||||
| // Open file, only the current user can read and write, to avoid malicious application access | |||||
| // Using the O_EXCL, if the file already exists,return failed to avoid privilege escalation vulnerability. | |||||
| mmMode_t mode = M_IRUSR | M_IWUSR; | |||||
| int32_t fd = mmOpen2(real_path.c_str(), M_RDWR | M_CREAT | M_APPEND, mode); | |||||
| if (fd == EN_ERROR || fd == EN_INVALID_PARAM) { | |||||
| GELOGE(kInvalidFd, "[Open][File]Failed. errno:%d, errmsg:%s, filename:%s.", | |||||
| fd, strerror(errno), filename); | |||||
| return kInvalidFd; | |||||
| } | |||||
| return fd; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,90 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_DEBUG_MEMORY_DUMPER_H_ | |||||
| #define GE_COMMON_DEBUG_MEMORY_DUMPER_H_ | |||||
| #include <stdint.h> | |||||
| #include "framework/common/types.h" | |||||
| #include "mmpa/mmpa_api.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| namespace ge { | |||||
| // MemoryDumper:dump memory data for internal test | |||||
| // Output in one time: using DumpToFile | |||||
| // Open file at one time and output multiple times: create MemoryDumper object first, and using Open/Dump/Close | |||||
| class MemoryDumper { | |||||
| public: | |||||
| MemoryDumper(); | |||||
| ~MemoryDumper(); | |||||
| // Assignment/copy is not allowed to avoid repeated release | |||||
| MemoryDumper &operator=(const MemoryDumper &dumper) = delete; | |||||
| MemoryDumper(const MemoryDumper &dumper) = delete; | |||||
| /** @ingroup domi_common | |||||
| * @brief write memory data to file, if the filename is not exist, create it first | |||||
| * @param [in] filename the output file path, specific to filename | |||||
| * @param [in] data the memory data | |||||
| * @param [in] len length of data | |||||
| * @return SUCCESS output success | |||||
| * @return FAILED output failed | |||||
| * @author | |||||
| */ | |||||
| static Status DumpToFile(const char *filename, void *data, int64_t len); | |||||
| /** @ingroup domi_common | |||||
| * @brief open the dump file | |||||
| * @param [in] filename the output file path, specific to filename | |||||
| * @return SUCCESS open file success | |||||
| * @return FAILED open file failed | |||||
| * @author | |||||
| */ | |||||
| Status Open(const char *filename); | |||||
| /** @ingroup domi_common | |||||
| * @brief write the Memory data to file | |||||
| * @param [in] data the memory data | |||||
| * @param [in] len length of data | |||||
| * @return SUCCESS success | |||||
| * @return FAILED failed | |||||
| * @author | |||||
| */ | |||||
| Status Dump(void *data, uint32_t len) const; | |||||
| /** @ingroup domi_common | |||||
| * @brief close the Dump file | |||||
| * @return SUCCESS success | |||||
| * @return FAILED failed | |||||
| * @author | |||||
| */ | |||||
| void Close() noexcept; | |||||
| private: | |||||
| /** @ingroup domi_common | |||||
| * @brief open the dump file | |||||
| * @param [in] filename the output file path, specific to filename | |||||
| * @return int the file handle after file open, -1 means open file failed | |||||
| * @author | |||||
| */ | |||||
| static int OpenFile(const char *filename); | |||||
| int fd_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_DEBUG_MEMORY_DUMPER_H_ | |||||
| @@ -1,159 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/dump/dump_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
namespace {
// Switch values that the dump configuration strings are compared against
constexpr const char *kDumpOFF = "OFF";
constexpr const char *kDumpoff = "off";
constexpr const char *kDumpOn = "on";
// Key under which SetDumpConf and its helpers store properties in dump_properties_map_
constexpr uint64_t kInferSessionId = 0;
// Op debug mode passed to SetOpDebugMode when dump debug is on
constexpr uint32_t kAllOverflow = 3;
}  // namespace
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpManager &DumpManager::GetInstance() { | |||||
| static DumpManager instance; | |||||
| return instance; | |||||
| } | |||||
| bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump_properties) { | |||||
| if (dump_config.dump_status.empty() && dump_config.dump_debug.empty()) { | |||||
| dump_properties_map_.emplace(kInferSessionId, dump_properties); | |||||
| GELOGI("Dump does not open"); | |||||
| return false; | |||||
| } | |||||
| GELOGI("Dump status is %s, dump debug is %s.", dump_config.dump_status.c_str(), dump_config.dump_debug.c_str()); | |||||
| if ((dump_config.dump_status == kDumpoff || dump_config.dump_status == kDumpOFF) && | |||||
| dump_config.dump_debug == kDumpoff) { | |||||
| dump_properties.ClearDumpPropertyValue(); | |||||
| dump_properties_map_.emplace(kInferSessionId, dump_properties); | |||||
| return false; | |||||
| } | |||||
| if (dump_config.dump_status == kDumpOn && dump_config.dump_debug == kDumpOn) { | |||||
| GELOGW("Not support coexistence of dump debug and dump status."); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void DumpManager::SetDumpDebugConf(const DumpConfig &dump_config, DumpProperties &dump_properties) { | |||||
| if (dump_config.dump_debug == kDumpOn) { | |||||
| GELOGI("Only do overflow detection, dump debug is %s.", dump_config.dump_debug.c_str()); | |||||
| dump_properties.InitInferOpDebug(); | |||||
| dump_properties.SetOpDebugMode(kAllOverflow); | |||||
| } | |||||
| } | |||||
| void DumpManager::SetDumpList(const DumpConfig &dump_config, DumpProperties &dump_properties) { | |||||
| for (const auto &model_dump : dump_config.dump_list) { | |||||
| std::string model_name = model_dump.model_name; | |||||
| GELOGI("Dump model is %s", model_name.c_str()); | |||||
| std::set<std::string> dump_layers; | |||||
| for (const auto &layer : model_dump.layers) { | |||||
| GELOGI("Dump layer is %s in model", layer.c_str()); | |||||
| dump_layers.insert(layer); | |||||
| } | |||||
| dump_properties.AddPropertyValue(model_name, dump_layers); | |||||
| } | |||||
| } | |||||
| Status DumpManager::SetNormalDumpConf(const DumpConfig &dump_config, DumpProperties &dump_properties) { | |||||
| if (dump_config.dump_status == kDumpOn) { | |||||
| GELOGI("Only do normal dump process, dump status is %s.", dump_config.dump_status.c_str()); | |||||
| dump_properties.SetDumpStatus(dump_config.dump_status); | |||||
| std::string dump_op_switch = dump_config.dump_op_switch; | |||||
| dump_properties.SetDumpOpSwitch(dump_op_switch); | |||||
| if (dump_op_switch == kDumpoff && dump_config.dump_list.empty()) { | |||||
| dump_properties_map_.emplace(kInferSessionId, dump_properties); | |||||
| GELOGE(PARAM_INVALID, "[Check][DumpList]Invalid, dump_op_switch is %s", dump_op_switch.c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Dump list check invalid, dump_op_switch is %s", dump_op_switch.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| if (!dump_config.dump_list.empty()) { | |||||
| if (dump_op_switch == kDumpOn) { | |||||
| GELOGI("Start to dump model and single op, dump op switch is %s", dump_op_switch.c_str()); | |||||
| } else { | |||||
| GELOGI("Only dump model, dump op switch is %s", dump_op_switch.c_str()); | |||||
| } | |||||
| SetDumpList(dump_config, dump_properties); | |||||
| } else { | |||||
| GELOGI("Only dump single op, dump op switch is %s", dump_op_switch.c_str()); | |||||
| } | |||||
| GELOGI("Dump mode is %s", dump_config.dump_mode.c_str()); | |||||
| dump_properties.SetDumpMode(dump_config.dump_mode); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DumpManager::SetDumpPath(const DumpConfig &dump_config, DumpProperties &dump_properties) { | |||||
| std::string dump_path = dump_config.dump_path; | |||||
| if (dump_path.empty()) { | |||||
| GELOGE(PARAM_INVALID, "[Check][DumpPath]It is empty"); | |||||
| REPORT_INNER_ERROR("E19999", "Dump path check is empty"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| if (dump_path[dump_path.size() - 1] != '/') { | |||||
| dump_path = dump_path + "/"; | |||||
| } | |||||
| dump_path = dump_path + CurrentTimeInStr() + "/"; | |||||
| GELOGI("Dump path is %s", dump_path.c_str()); | |||||
| dump_properties.SetDumpPath(dump_path); | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpManager::SetDumpConf(const DumpConfig &dump_config) { | |||||
| DumpProperties dump_properties; | |||||
| if (!NeedDoDump(dump_config, dump_properties)) { | |||||
| GELOGD("No need do dump process."); | |||||
| return SUCCESS; | |||||
| } | |||||
| SetDumpDebugConf(dump_config, dump_properties); | |||||
| GE_CHK_STATUS_RET(SetNormalDumpConf(dump_config, dump_properties), "[Init][DumpConf] failed when dump status is on."); | |||||
| GE_CHK_STATUS_RET(SetDumpPath(dump_config, dump_properties), "[Init][DumpPath] failed."); | |||||
| dump_properties_map_[kInferSessionId] = dump_properties; | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const DumpProperties &DumpManager::GetDumpProperties( | |||||
| uint64_t session_id) { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| auto iter = dump_properties_map_.find(session_id); | |||||
| if (iter != dump_properties_map_.end()) { | |||||
| return iter->second; | |||||
| } | |||||
| static DumpProperties default_properties; | |||||
| return default_properties; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpManager::AddDumpProperties( | |||||
| uint64_t session_id, const DumpProperties &dump_properties) { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| dump_properties_map_.emplace(session_id, dump_properties); | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpManager::RemoveDumpProperties(uint64_t session_id) { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| auto iter = dump_properties_map_.find(session_id); | |||||
| if (iter != dump_properties_map_.end()) { | |||||
| dump_properties_map_.erase(iter); | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,46 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_DUMP_DUMP_MANAGER_H_ | |||||
| #define GE_COMMON_DUMP_DUMP_MANAGER_H_ | |||||
| #include <mutex> | |||||
| #include "common/dump/dump_properties.h" | |||||
| #include "common/ge_types.h" | |||||
| namespace ge { | |||||
| class DumpManager { | |||||
| public: | |||||
| static DumpManager &GetInstance(); | |||||
| Status SetDumpConf(const DumpConfig &dump_config); | |||||
| const DumpProperties &GetDumpProperties(uint64_t session_id); | |||||
| const std::map<uint64_t, DumpProperties> &GetDumpPropertiesMap() { return dump_properties_map_; } | |||||
| void AddDumpProperties(uint64_t session_id, const DumpProperties &dump_properties); | |||||
| void RemoveDumpProperties(uint64_t session_id); | |||||
| private: | |||||
| bool NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump_properties); | |||||
| void SetDumpDebugConf(const DumpConfig &dump_config, DumpProperties &dump_properties); | |||||
| Status SetDumpPath(const DumpConfig &dump_config, DumpProperties &dump_properties); | |||||
| Status SetNormalDumpConf(const DumpConfig &dump_config, DumpProperties &dump_properties); | |||||
| void SetDumpList(const DumpConfig &dump_config, DumpProperties &dump_properties); | |||||
| std::mutex mutex_; | |||||
| std::map<uint64_t, DumpProperties> dump_properties_map_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_DUMP_DUMP_MANAGER_H_ | |||||
| @@ -1,332 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/dump/dump_op.h" | |||||
| #include "common/dump/dump_manager.h" | |||||
| #include "common/ge/datatype_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/util.h" | |||||
| #include "framework/common/types.h" | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "proto/ge_ir.pb.h" | |||||
| #include "proto/op_mapping.pb.h" | |||||
| #include "runtime/mem.h" | |||||
| #include "aicpu/common/aicpu_task_struct.h" | |||||
namespace {
// Value written into OpMappingInfo::flag when launching the dump op.
const uint32_t kAicpuLoadFlag = 1U;
// Dump mode names compared against DumpProperties::GetDumpMode().
const char *const kDumpInput = "input";
const char *const kDumpOutput = "output";
const char *const kDumpAll = "all";
// Kernel name handed to rtCpuKernelLaunch.
const char *const kDumpKernelsDumpOp = "DumpDataInfo";
}  // namespace
| namespace ge { | |||||
| DumpOp::~DumpOp() { | |||||
| if (proto_dev_mem_ != nullptr) { | |||||
| (void)rtFree(proto_dev_mem_); | |||||
| } | |||||
| if (proto_size_dev_mem_ != nullptr) { | |||||
| (void)rtFree(proto_size_dev_mem_); | |||||
| } | |||||
| proto_dev_mem_ = nullptr; | |||||
| proto_size_dev_mem_ = nullptr; | |||||
| } | |||||
| void DumpOp::SetLoopAddr(void *global_step, void *loop_per_iter, void *loop_cond) { | |||||
| global_step_ = reinterpret_cast<uintptr_t>(global_step); | |||||
| loop_per_iter_ = reinterpret_cast<uintptr_t>(loop_per_iter); | |||||
| loop_cond_ = reinterpret_cast<uintptr_t>(loop_cond); | |||||
| } | |||||
| void DumpOp::SetDynamicModelInfo(const string &dynamic_model_name, const string &dynamic_om_name, | |||||
| uint32_t dynamic_model_id) { | |||||
| dynamic_model_name_ = dynamic_model_name; | |||||
| dynamic_om_name_ = dynamic_om_name; | |||||
| dynamic_model_id_ = dynamic_model_id; | |||||
| } | |||||
| static void SetOpMappingLoopAddr(uintptr_t step_id, uintptr_t loop_per_iter, uintptr_t loop_cond, | |||||
| toolkit::aicpu::dump::OpMappingInfo &op_mapping_info) { | |||||
| if (step_id != 0) { | |||||
| GELOGI("step_id exists."); | |||||
| op_mapping_info.set_step_id_addr(static_cast<uint64_t>(step_id)); | |||||
| } else { | |||||
| GELOGI("step_id is null."); | |||||
| } | |||||
| if (loop_per_iter != 0) { | |||||
| GELOGI("loop_per_iter exists."); | |||||
| op_mapping_info.set_iterations_per_loop_addr(static_cast<uint64_t>(loop_per_iter)); | |||||
| } else { | |||||
| GELOGI("loop_per_iter is null."); | |||||
| } | |||||
| if (loop_cond != 0) { | |||||
| GELOGI("loop_cond exists."); | |||||
| op_mapping_info.set_loop_cond_addr(static_cast<uint64_t>(loop_cond)); | |||||
| } else { | |||||
| GELOGI("loop_cond is null."); | |||||
| } | |||||
| } | |||||
| Status DumpOp::DumpOutput(toolkit::aicpu::dump::Task &task) { | |||||
| GELOGI("Start dump output in Launch dump op"); | |||||
| const auto &output_descs = op_desc_->GetAllOutputsDesc(); | |||||
| for (size_t i = 0; i < output_descs.size(); ++i) { | |||||
| toolkit::aicpu::dump::Output output; | |||||
| output.set_data_type(static_cast<int32_t>(DataTypeUtil::GetIrDataType(output_descs.at(i).GetDataType()))); | |||||
| output.set_format(static_cast<int32_t>(output_descs.at(i).GetFormat())); | |||||
| for (auto dim : output_descs.at(i).GetShape().GetDims()) { | |||||
| output.mutable_shape()->add_dim(dim); | |||||
| } | |||||
| for (auto dim : output_descs.at(i).GetOriginShape().GetDims()) { | |||||
| output.mutable_origin_shape()->add_dim(dim); | |||||
| } | |||||
| int64_t output_size = 0; | |||||
| if (TensorUtils::GetTensorSizeInBytes(output_descs.at(i), output_size) != SUCCESS) { | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Get][TensorSize]Failed, output %zu, node %s(%s),", | |||||
| i, op_desc_->GetName().c_str(), op_desc_->GetType().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Get output %zu tensor size of node %s(%s) failed", | |||||
| i, op_desc_->GetName().c_str(), op_desc_->GetType().c_str()); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR; | |||||
| } | |||||
| GELOGD("Get output size in lanch dump op is %ld", output_size); | |||||
| output.set_size(output_size); | |||||
| output.set_address(static_cast<uint64_t>(output_addrs_[i])); | |||||
| task.mutable_output()->Add(std::move(output)); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DumpOp::DumpInput(toolkit::aicpu::dump::Task &task) { | |||||
| GELOGI("Start dump input in Launch dump op"); | |||||
| const auto &input_descs = op_desc_->GetAllInputsDesc(); | |||||
| for (size_t i = 0; i < input_descs.size(); ++i) { | |||||
| toolkit::aicpu::dump::Input input; | |||||
| input.set_data_type(static_cast<int32_t>(DataTypeUtil::GetIrDataType(input_descs.at(i).GetDataType()))); | |||||
| input.set_format(static_cast<int32_t>(input_descs.at(i).GetFormat())); | |||||
| for (auto dim : input_descs.at(i).GetShape().GetDims()) { | |||||
| input.mutable_shape()->add_dim(dim); | |||||
| } | |||||
| for (auto dim : input_descs.at(i).GetOriginShape().GetDims()) { | |||||
| input.mutable_origin_shape()->add_dim(dim); | |||||
| } | |||||
| int64_t input_size = 0; | |||||
| if (TensorUtils::GetTensorSizeInBytes(input_descs.at(i), input_size) != SUCCESS) { | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Get][TensorSize]Failed, input %zu, node %s(%s)", | |||||
| i, op_desc_->GetName().c_str(), op_desc_->GetType().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Get input %zu tensor size of node %s(%s) failed", | |||||
| i, op_desc_->GetName().c_str(), op_desc_->GetType().c_str()); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR; | |||||
| } | |||||
| GELOGD("Get input size in lanch dump op is %ld", input_size); | |||||
| input.set_size(input_size); | |||||
| input.set_address(static_cast<uint64_t>(input_addrs_[i])); | |||||
| task.mutable_input()->Add(std::move(input)); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void DumpOp::SetDumpInfo(const DumpProperties &dump_properties, const OpDescPtr &op_desc, vector<uintptr_t> input_addrs, | |||||
| vector<uintptr_t> output_addrs, rtStream_t stream) { | |||||
| dump_properties_ = dump_properties; | |||||
| op_desc_ = op_desc; | |||||
| input_addrs_ = input_addrs; | |||||
| output_addrs_ = output_addrs; | |||||
| stream_ = stream; | |||||
| } | |||||
| Status DumpOp::ExecutorDumpOp(toolkit::aicpu::dump::OpMappingInfo &op_mapping_info) { | |||||
| std::string proto_msg; | |||||
| size_t proto_size = op_mapping_info.ByteSizeLong(); | |||||
| bool ret = op_mapping_info.SerializeToString(&proto_msg); | |||||
| if (!ret || proto_size == 0) { | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Serialize][Protobuf]Failed, proto_size is %zu", | |||||
| proto_size); | |||||
| REPORT_CALL_ERROR("E19999", "[Serialize][Protobuf]Failed, proto_size is %zu", proto_size); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR; | |||||
| } | |||||
| rtError_t rt_ret = rtMalloc(&proto_dev_mem_, proto_size, RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(rt_ret, "[Call][rtMalloc]Failed, ret: 0x%X", rt_ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMalloc failed, ret: 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| rt_ret = rtMemcpy(proto_dev_mem_, proto_size, proto_msg.c_str(), proto_size, RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(rt_ret, "[Call][rtMemcpy]Failed, ret: 0x%X", rt_ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, ret: 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| rt_ret = rtMalloc(&proto_size_dev_mem_, sizeof(size_t), RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(rt_ret, "[Call][rtMalloc]Failed, ret: 0x%X", rt_ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMalloc failed, ret: 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| rt_ret = rtMemcpy(proto_size_dev_mem_, sizeof(size_t), &proto_size, sizeof(size_t), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(rt_ret, "[Call][rtMemcpy]Failed, ret 0x%X", rt_ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, ret 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| constexpr int32_t io_addr_num = 2; | |||||
| constexpr uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addr_num * sizeof(uint64_t); | |||||
| char args[args_size] = {0}; | |||||
| auto param_head = reinterpret_cast<aicpu::AicpuParamHead *>(args); | |||||
| param_head->length = args_size; | |||||
| param_head->ioAddrNum = io_addr_num; | |||||
| auto io_addr = reinterpret_cast<uint64_t *>(args + sizeof(aicpu::AicpuParamHead)); | |||||
| io_addr[0] = reinterpret_cast<uintptr_t>(proto_dev_mem_); | |||||
| io_addr[1] = reinterpret_cast<uintptr_t>(proto_size_dev_mem_); | |||||
| rt_ret = rtCpuKernelLaunch(nullptr, kDumpKernelsDumpOp, | |||||
| 1, // blockDim default 1 | |||||
| args, args_size, | |||||
| nullptr, // no need smDesc | |||||
| stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(rt_ret, "[Call][rtCpuKernelLaunch]Failed, ret 0x%X", rt_ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtCpuKernelLaunch failed, ret 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| GELOGI("Kernel launch dump op success"); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DumpOp::SetDumpModelName(toolkit::aicpu::dump::OpMappingInfo &op_mapping_info) { | |||||
| if (dynamic_model_name_.empty() && dynamic_om_name_.empty()) { | |||||
| GELOGI("Single op dump, no need set model name"); | |||||
| return SUCCESS; | |||||
| } | |||||
| std::set<std::string> model_list = dump_properties_.GetAllDumpModel(); | |||||
| bool not_find_by_omname = model_list.find(dynamic_om_name_) == model_list.end(); | |||||
| bool not_find_by_modelname = model_list.find(dynamic_model_name_) == model_list.end(); | |||||
| std::string dump_model_name = not_find_by_omname ? dynamic_model_name_ : dynamic_om_name_; | |||||
| if (model_list.find(DUMP_ALL_MODEL) == model_list.end()) { | |||||
| if (not_find_by_omname && not_find_by_modelname) { | |||||
| std::string model_list_str; | |||||
| for (auto &model : model_list) { | |||||
| model_list_str += "[" + model + "]."; | |||||
| } | |||||
| GELOGW("Model %s will not be set to dump, dump list: %s", dump_model_name.c_str(), model_list_str.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| if (!dump_model_name.empty() && dump_properties_.IsDumpOpen()) { | |||||
| GELOGI("Dump model name is %s", dump_model_name.c_str()); | |||||
| op_mapping_info.set_model_name(dump_model_name); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DumpOp::LaunchDumpOp() { | |||||
| GELOGI("Start to launch dump op %s", op_desc_->GetName().c_str()); | |||||
| int32_t device_id = 0; | |||||
| rtError_t rt_ret = rtGetDevice(&device_id); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(rt_ret, "[Call][rtGetDevice]Failed, ret 0x%X", rt_ret); | |||||
| REPORT_CALL_ERROR("E19999", "[Call][rtGetDevice]Failed, ret 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| if (device_id < 0) { | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][DeviceId]Failed, device_id %d", device_id); | |||||
| REPORT_INNER_ERROR("E19999","Check device_id %d failed", device_id); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR; | |||||
| } | |||||
| toolkit::aicpu::dump::OpMappingInfo op_mapping_info; | |||||
| auto dump_path = dump_properties_.GetDumpPath() + std::to_string(device_id) + "/"; | |||||
| op_mapping_info.set_dump_path(dump_path); | |||||
| op_mapping_info.set_flag(kAicpuLoadFlag); | |||||
| op_mapping_info.set_dump_step(dump_properties_.GetDumpStep()); | |||||
| op_mapping_info.set_model_id(dynamic_model_id_); | |||||
| if (SetDumpModelName(op_mapping_info) != SUCCESS) { | |||||
| return SUCCESS; | |||||
| } | |||||
| SetOpMappingLoopAddr(global_step_, loop_per_iter_, loop_cond_, op_mapping_info); | |||||
| GELOGI("Dump step is %s ,dump path is %s in Launch dump op", dump_properties_.GetDumpStep().c_str(), | |||||
| dump_path.c_str()); | |||||
| uint32_t task_id = 0; | |||||
| uint32_t stream_id = 0; | |||||
| rt_ret = rtGetTaskIdAndStreamID(&task_id, &stream_id); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGW("call rtGetTaskIdAndStreamID failed, ret = 0x%X", rt_ret); | |||||
| } | |||||
| toolkit::aicpu::dump::Task task; | |||||
| task.set_task_id(task_id); | |||||
| task.set_stream_id(stream_id); | |||||
| task.mutable_op()->set_op_name(op_desc_->GetName()); | |||||
| task.mutable_op()->set_op_type(op_desc_->GetType()); | |||||
| if (dump_properties_.GetDumpMode() == kDumpOutput) { | |||||
| auto ret = DumpOutput(task); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Dump][Output]Failed, node %s(%s), ret 0x%X", | |||||
| op_desc_->GetName().c_str(), op_desc_->GetType().c_str(), ret); | |||||
| REPORT_CALL_ERROR("E19999", "Dump Output failed, node %s(%s), ret 0x%X", | |||||
| op_desc_->GetName().c_str(), op_desc_->GetType().c_str(), ret); | |||||
| return ret; | |||||
| } | |||||
| op_mapping_info.mutable_task()->Add(std::move(task)); | |||||
| } | |||||
| if (dump_properties_.GetDumpMode() == kDumpInput) { | |||||
| auto ret = DumpInput(task); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Dump][Input]Failed, node %s(%s), ret 0x%X", | |||||
| op_desc_->GetName().c_str(), op_desc_->GetType().c_str(), ret); | |||||
| REPORT_CALL_ERROR("E19999", "Dump Input failed, node %s(%s), ret 0x%X", | |||||
| op_desc_->GetName().c_str(), op_desc_->GetType().c_str(), ret); | |||||
| return ret; | |||||
| } | |||||
| op_mapping_info.mutable_task()->Add(std::move(task)); | |||||
| } | |||||
| if (dump_properties_.GetDumpMode() == kDumpAll || dump_properties_.IsOpDebugOpen()) { | |||||
| auto ret = DumpOutput(task); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Dump][Output]Failed when in dumping all, node %s(%s), ret 0x%X", | |||||
| op_desc_->GetName().c_str(), op_desc_->GetType().c_str(), ret); | |||||
| REPORT_CALL_ERROR("E19999", "Dump Output failed when in dumping all, node %s(%s), ret 0x%X", | |||||
| op_desc_->GetName().c_str(), op_desc_->GetType().c_str(), ret); | |||||
| return ret; | |||||
| } | |||||
| ret = DumpInput(task); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Dump][Input]Failed when in dumping all, node %s(%s), ret 0x%X", | |||||
| op_desc_->GetName().c_str(), op_desc_->GetType().c_str(), ret); | |||||
| REPORT_CALL_ERROR("E19999", "Dump Input failed when in dumping all, node %s(%s), ret 0x%X", | |||||
| op_desc_->GetName().c_str(), op_desc_->GetType().c_str(), ret); | |||||
| return ret; | |||||
| } | |||||
| op_mapping_info.mutable_task()->Add(std::move(task)); | |||||
| } | |||||
| auto ret = ExecutorDumpOp(op_mapping_info); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Dump][Op]Failed, ret 0x%X", ret); | |||||
| REPORT_CALL_ERROR("E19999", "Executor dump op failed, ret 0x%X", ret); | |||||
| return ret; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,63 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_DUMP_DUMP_OP_H_ | |||||
| #define GE_COMMON_DUMP_DUMP_OP_H_ | |||||
| #include <string> | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #include "common/properties_manager.h" | |||||
| #include "proto/op_mapping.pb.h" | |||||
| #include "runtime/stream.h" | |||||
| namespace ge { | |||||
| class DumpOp { | |||||
| public: | |||||
| DumpOp() = default; | |||||
| ~DumpOp(); | |||||
| void SetDumpInfo(const DumpProperties &dump_properties, const OpDescPtr &op_desc, vector<uintptr_t> input_addrs, | |||||
| vector<uintptr_t> output_addrs, rtStream_t stream); | |||||
| Status LaunchDumpOp(); | |||||
| void SetLoopAddr(void *global_step, void *loop_per_iter, void *loop_cond); | |||||
| void SetDynamicModelInfo(const string &dynamic_model_name, const string &dynamic_om_name, uint32_t dynamic_model_id); | |||||
| private: | |||||
| Status ExecutorDumpOp(toolkit::aicpu::dump::OpMappingInfo &op_mapping_info); | |||||
| Status DumpOutput(toolkit::aicpu::dump::Task &task); | |||||
| Status DumpInput(toolkit::aicpu::dump::Task &task); | |||||
| Status SetDumpModelName(toolkit::aicpu::dump::OpMappingInfo &op_mapping_info); | |||||
| DumpProperties dump_properties_; | |||||
| OpDescPtr op_desc_; | |||||
| std::vector<uintptr_t> input_addrs_; | |||||
| std::vector<uintptr_t> output_addrs_; | |||||
| void *proto_dev_mem_ = nullptr; | |||||
| void *proto_size_dev_mem_ = nullptr; | |||||
| rtStream_t stream_; | |||||
| uintptr_t global_step_; | |||||
| uintptr_t loop_per_iter_; | |||||
| uintptr_t loop_cond_; | |||||
| std::string dynamic_model_name_; | |||||
| std::string dynamic_om_name_; | |||||
| std::uint32_t dynamic_model_id_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_DUMP_DUMP_OP_H_ | |||||
| @@ -1,285 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/dump/dump_properties.h" | |||||
| #include <cstdio> | |||||
| #include <string> | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "framework/common/ge_types.h" | |||||
| #include "framework/common/types.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/ge_context.h" | |||||
| #include "graph/utils/attr_utils.h" | |||||
namespace {
// Option value meaning "enabled" for the enable-dump / enable-dump-debug options.
const std::string kEnableFlag = "1";
// Value of the dump status / dump op switch that means "on".
const std::string kDumpStatusOpen = "on";
// Bit flags stored in op_debug_mode_.
const uint32_t kAicoreOverflow = 0x1U;
const uint32_t kAtomicOverflow = 0x2U;
const uint32_t kAllOverflow = kAicoreOverflow | kAtomicOverflow;
}  // namespace
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties::DumpProperties(const DumpProperties &other) { | |||||
| CopyFrom(other); | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &DumpProperties::operator=( | |||||
| const DumpProperties &other) { | |||||
| CopyFrom(other); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitByOptions() { | |||||
| enable_dump_.clear(); | |||||
| enable_dump_debug_.clear(); | |||||
| dump_path_.clear(); | |||||
| dump_step_.clear(); | |||||
| dump_mode_.clear(); | |||||
| is_train_op_debug_ = false; | |||||
| is_infer_op_debug_ = false; | |||||
| op_debug_mode_ = 0; | |||||
| std::string enable_dump; | |||||
| (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP, enable_dump); | |||||
| enable_dump_ = enable_dump; | |||||
| std::string enable_dump_debug; | |||||
| (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP_DEBUG, enable_dump_debug); | |||||
| enable_dump_debug_ = enable_dump_debug; | |||||
| if ((enable_dump_ == kEnableFlag) || (enable_dump_debug_ == kEnableFlag)) { | |||||
| std::string dump_path; | |||||
| if (GetContext().GetOption(OPTION_EXEC_DUMP_PATH, dump_path) == GRAPH_SUCCESS) { | |||||
| if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { | |||||
| dump_path = dump_path + "/"; | |||||
| } | |||||
| dump_path = dump_path + CurrentTimeInStr() + "/"; | |||||
| GELOGI("Get dump path %s successfully", dump_path.c_str()); | |||||
| SetDumpPath(dump_path); | |||||
| } else { | |||||
| GELOGW("Dump path is not set"); | |||||
| } | |||||
| } | |||||
| if (enable_dump_ == kEnableFlag) { | |||||
| std::string dump_step; | |||||
| if (GetContext().GetOption(OPTION_EXEC_DUMP_STEP, dump_step) == GRAPH_SUCCESS) { | |||||
| GELOGI("Get dump step %s successfully", dump_step.c_str()); | |||||
| SetDumpStep(dump_step); | |||||
| } | |||||
| string dump_mode; | |||||
| if (GetContext().GetOption(OPTION_EXEC_DUMP_MODE, dump_mode) == GRAPH_SUCCESS) { | |||||
| GELOGI("Get dump mode %s successfully", dump_mode.c_str()); | |||||
| SetDumpMode(dump_mode); | |||||
| } | |||||
| AddPropertyValue(DUMP_ALL_MODEL, {}); | |||||
| } | |||||
| SetDumpDebugOptions(); | |||||
| } | |||||
| // The following is the new dump scenario of the fusion operator | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::AddPropertyValue( | |||||
| const std::string &model, const std::set<std::string> &layers) { | |||||
| for (const std::string &layer : layers) { | |||||
| GELOGI("This model %s config to dump layer %s", model.c_str(), layer.c_str()); | |||||
| } | |||||
| model_dump_properties_map_[model] = layers; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::DeletePropertyValue(const std::string &model) { | |||||
| auto iter = model_dump_properties_map_.find(model); | |||||
| if (iter != model_dump_properties_map_.end()) { | |||||
| model_dump_properties_map_.erase(iter); | |||||
| } | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::ClearDumpPropertyValue() { | |||||
| model_dump_properties_map_.clear(); | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::ClearDumpInfo() { | |||||
| enable_dump_.clear(); | |||||
| enable_dump_debug_.clear(); | |||||
| dump_path_.clear(); | |||||
| dump_step_.clear(); | |||||
| dump_mode_.clear(); | |||||
| dump_op_switch_.clear(); | |||||
| dump_status_.clear(); | |||||
| is_train_op_debug_ = false; | |||||
| is_infer_op_debug_ = false; | |||||
| op_debug_mode_ = 0; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set<std::string> DumpProperties::GetAllDumpModel() const { | |||||
| std::set<std::string> model_list; | |||||
| for (auto &iter : model_dump_properties_map_) { | |||||
| model_list.insert(iter.first); | |||||
| } | |||||
| return model_list; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set<std::string> DumpProperties::GetPropertyValue( | |||||
| const std::string &model) const { | |||||
| auto iter = model_dump_properties_map_.find(model); | |||||
| if (iter != model_dump_properties_map_.end()) { | |||||
| return iter->second; | |||||
| } | |||||
| return {}; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsLayerNeedDump( | |||||
| const std::string &model, const std::string &om_name, const std::string &op_name) const { | |||||
| // if dump all | |||||
| GELOGD("model name is %s om name is %s op is %s in layer need dump", model.c_str(), om_name.c_str(), op_name.c_str()); | |||||
| if (model_dump_properties_map_.find(DUMP_ALL_MODEL) != model_dump_properties_map_.end()) { | |||||
| return true; | |||||
| } | |||||
| // if this model need dump | |||||
| auto om_name_iter = model_dump_properties_map_.find(om_name); | |||||
| auto model_name_iter = model_dump_properties_map_.find(model); | |||||
| if (om_name_iter != model_dump_properties_map_.end() || model_name_iter != model_dump_properties_map_.end()) { | |||||
| // if no dump layer info, dump all layer in this model | |||||
| auto model_iter = om_name_iter != model_dump_properties_map_.end() ? om_name_iter : model_name_iter; | |||||
| if (model_iter->second.empty()) { | |||||
| return true; | |||||
| } | |||||
| return model_iter->second.find(op_name) != model_iter->second.end(); | |||||
| } | |||||
| GELOGD("Model %s is not seated to be dump.", model.c_str()); | |||||
| return false; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpPath(const std::string &path) { | |||||
| dump_path_ = path; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpPath() const { | |||||
| return dump_path_; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpStep(const std::string &step) { | |||||
| dump_step_ = step; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpStep() const { | |||||
| return dump_step_; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpMode(const std::string &mode) { | |||||
| dump_mode_ = mode; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpMode() const { | |||||
| return dump_mode_; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpStatus(const std::string &status) { | |||||
| dump_status_ = status; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpStatus() const { | |||||
| return dump_status_; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitInferOpDebug() { | |||||
| is_infer_op_debug_ = true; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetOpDebugMode(const uint32_t &op_debug_mode) { | |||||
| op_debug_mode_ = op_debug_mode; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpOpSwitch( | |||||
| const std::string &dump_op_switch) { | |||||
| dump_op_switch_ = dump_op_switch; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpOpSwitch() const { | |||||
| return dump_op_switch_; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsSingleOpNeedDump() const { | |||||
| if (dump_op_switch_ == kDumpStatusOpen) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsDumpOpen() const { | |||||
| if (enable_dump_ == kEnableFlag || dump_status_ == kDumpStatusOpen) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void DumpProperties::CopyFrom(const DumpProperties &other) { | |||||
| if (&other != this) { | |||||
| enable_dump_ = other.enable_dump_; | |||||
| enable_dump_debug_ = other.enable_dump_debug_; | |||||
| dump_path_ = other.dump_path_; | |||||
| dump_step_ = other.dump_step_; | |||||
| dump_mode_ = other.dump_mode_; | |||||
| dump_status_ = other.dump_status_; | |||||
| dump_op_switch_ = other.dump_op_switch_; | |||||
| model_dump_properties_map_ = other.model_dump_properties_map_; | |||||
| is_train_op_debug_ = other.is_train_op_debug_; | |||||
| is_infer_op_debug_ = other.is_infer_op_debug_; | |||||
| op_debug_mode_ = other.op_debug_mode_; | |||||
| } | |||||
| } | |||||
| void DumpProperties::SetDumpDebugOptions() { | |||||
| if (enable_dump_debug_ == kEnableFlag) { | |||||
| std::string dump_debug_mode; | |||||
| if (GetContext().GetOption(OPTION_EXEC_DUMP_DEBUG_MODE, dump_debug_mode) == GRAPH_SUCCESS) { | |||||
| GELOGD("Get dump debug mode %s successfully", dump_debug_mode.c_str()); | |||||
| } else { | |||||
| GELOGW("Dump debug mode is not set."); | |||||
| return; | |||||
| } | |||||
| if (dump_debug_mode == OP_DEBUG_AICORE) { | |||||
| GELOGD("ge.exec.dumpDebugMode=aicore_overflow, op debug is open."); | |||||
| is_train_op_debug_ = true; | |||||
| op_debug_mode_ = kAicoreOverflow; | |||||
| } else if (dump_debug_mode == OP_DEBUG_ATOMIC) { | |||||
| GELOGD("ge.exec.dumpDebugMode=atomic_overflow, op debug is open."); | |||||
| is_train_op_debug_ = true; | |||||
| op_debug_mode_ = kAtomicOverflow; | |||||
| } else if (dump_debug_mode == OP_DEBUG_ALL) { | |||||
| GELOGD("ge.exec.dumpDebugMode=all, op debug is open."); | |||||
| is_train_op_debug_ = true; | |||||
| op_debug_mode_ = kAllOverflow; | |||||
| } else { | |||||
| GELOGW("ge.exec.dumpDebugMode is invalid."); | |||||
| } | |||||
| } else { | |||||
| GELOGI("ge.exec.enableDumpDebug is false or is not set."); | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,115 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_DUMP_DUMP_PROPERTIES_H_ | |||||
| #define GE_COMMON_DUMP_DUMP_PROPERTIES_H_ | |||||
| #include <map> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| namespace ge { | |||||
// Container for data-dump and op-debug configuration: enable flags, dump
// path/step/mode/status, the per-model set of layers to dump, and the op
// debug mode.
class DumpProperties {
 public:
  DumpProperties() = default;
  ~DumpProperties() = default;
  DumpProperties(const DumpProperties &dump);
  DumpProperties &operator=(const DumpProperties &dump);

  void InitByOptions();

  // Per-model layer selection.
  void AddPropertyValue(const std::string &model, const std::set<std::string> &layers);
  void DeletePropertyValue(const std::string &model);
  void ClearDumpPropertyValue();
  void ClearDumpInfo();
  std::set<std::string> GetAllDumpModel() const;
  std::set<std::string> GetPropertyValue(const std::string &model) const;
  bool IsLayerNeedDump(const std::string &model, const std::string &om_name, const std::string &op_name) const;

  // Plain string settings.
  void SetDumpPath(const std::string &path);
  const std::string &GetDumpPath() const;
  void SetDumpStep(const std::string &step);
  const std::string &GetDumpStep() const;
  void SetDumpMode(const std::string &mode);
  const std::string &GetDumpMode() const;
  void SetDumpStatus(const std::string &status);
  const std::string &GetDumpStatus() const;

  void InitInferOpDebug();
  // True once InitInferOpDebug() has been called on this object (or copied in).
  bool IsInferOpDebug() const {
    return is_infer_op_debug_;
  }

  void SetDumpOpSwitch(const std::string &dump_op_switch);
  const std::string &GetDumpOpSwitch() const;

  // Op debug is open when either the training or the inference flag is set.
  bool IsOpDebugOpen() const {
    return is_train_op_debug_ || is_infer_op_debug_;
  }
  bool IsDumpOpen() const;
  bool IsSingleOpNeedDump() const;

  void SetOpDebugMode(const uint32_t &op_debug_mode);
  uint32_t GetOpDebugMode() const {
    return op_debug_mode_;
  }

  const std::string &GetEnableDump() const {
    return enable_dump_;
  }
  const std::string &GetEnableDumpDebug() const {
    return enable_dump_debug_;
  }

 private:
  void CopyFrom(const DumpProperties &other);
  void SetDumpDebugOptions();

  std::string enable_dump_;
  std::string enable_dump_debug_;
  std::string dump_path_;
  std::string dump_step_;
  std::string dump_mode_;
  std::string dump_status_;
  std::string dump_op_switch_;
  // Model name -> names of the layers selected for dumping.
  std::map<std::string, std::set<std::string>> model_dump_properties_map_;
  bool is_train_op_debug_{false};
  bool is_infer_op_debug_{false};
  uint32_t op_debug_mode_{0};
};
| } | |||||
| #endif //GE_COMMON_DUMP_DUMP_PROPERTIES_H_ | |||||
| @@ -1,21 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "adx_datadump_server.h" | |||||
// Stub implementation: does nothing and always returns 0.
int AdxDataDumpServerUnInit() {
  return 0;
}
// Stub implementation: does nothing and always returns 0.
int AdxDataDumpServerInit() {
  return 0;
}
| @@ -1,241 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/dump/exception_dumper.h" | |||||
| #include "common/ge/datatype_util.h" | |||||
| #include "common/debug/memory_dumper.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/manager/util/debug.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "graph/load/model_manager/model_utils.h" | |||||
| #include "proto/dump_task.pb.h" | |||||
| namespace { | |||||
| static uint64_t GetNowTime() { | |||||
| uint64_t ret = 0; | |||||
| mmTimeval tv; | |||||
| if (mmGetTimeOfDay(&tv, nullptr) == 0) { | |||||
| ret = tv.tv_sec * 1000000ULL + tv.tv_usec; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
// Replaces every ' ', '.', '/' and '\\' in `str` with '_' in place, so the
// result can be used as a single file-name component.
// A range-for is used instead of an unqualified for_each call, which was only
// found through ADL and a transitive <algorithm> include.
static void ReplaceStringElem(std::string &str) {
  for (char &ch : str) {
    if ((ch == ' ') || (ch == '.') || (ch == '/') || (ch == '\\')) {
      ch = '_';
    }
  }
}
| static void SetDumpData(const ge::OpDescInfo &op_desc_info, toolkit::dump::DumpData &dump_data) { | |||||
| dump_data.set_version("2.0"); | |||||
| dump_data.set_dump_time(GetNowTime()); | |||||
| dump_data.set_op_name(op_desc_info.op_name); | |||||
| for (size_t i = 0; i < op_desc_info.input_format.size(); ++i) { | |||||
| toolkit::dump::OpInput input; | |||||
| input.set_data_type(toolkit::dump::OutputDataType( | |||||
| ge::DataTypeUtil::GetIrDataType(op_desc_info.input_data_type[i]))); | |||||
| input.set_format(toolkit::dump::OutputFormat(op_desc_info.input_format[i])); | |||||
| for (auto dim : op_desc_info.input_shape[i]) { | |||||
| input.mutable_shape()->add_dim(dim); | |||||
| } | |||||
| input.set_size(op_desc_info.input_size[i]); | |||||
| GELOGI("[Set][DumpData] The input size int exception is %ld", op_desc_info.input_size[i]); | |||||
| dump_data.mutable_input()->Add(std::move(input)); | |||||
| } | |||||
| for (size_t j = 0; j < op_desc_info.output_format.size(); ++j) { | |||||
| toolkit::dump::OpOutput output; | |||||
| output.set_data_type(toolkit::dump::OutputDataType( | |||||
| ge::DataTypeUtil::GetIrDataType(op_desc_info.output_data_type[j]))); | |||||
| output.set_format(toolkit::dump::OutputFormat(op_desc_info.output_format[j])); | |||||
| for (auto dim : op_desc_info.output_shape[j]) { | |||||
| output.mutable_shape()->add_dim(dim); | |||||
| } | |||||
| output.set_size(op_desc_info.output_size[j]); | |||||
| GELOGI("[Set][DumpData] The output size int exception is %ld", op_desc_info.output_size[j]); | |||||
| dump_data.mutable_output()->Add(std::move(output)); | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| namespace ge { | |||||
| ExceptionDumper::~ExceptionDumper() {} | |||||
| void ExceptionDumper::SaveDumpOpInfo(const OpDescPtr &op, uint32_t task_id, uint32_t stream_id, | |||||
| vector<void *> &input_addrs, vector<void *> &output_addrs) { | |||||
| OpDescInfo op_desc_info; | |||||
| SaveOpDescInfo(op, task_id, stream_id, op_desc_info); | |||||
| op_desc_info.input_addrs = input_addrs; | |||||
| op_desc_info.output_addrs = output_addrs; | |||||
| op_desc_info_.emplace_back(std::move(op_desc_info)); | |||||
| } | |||||
| void ExceptionDumper::SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, | |||||
| uint32_t task_id, uint32_t stream_id) { | |||||
| OpDescInfo op_desc_info; | |||||
| SaveOpDescInfo(op, task_id, stream_id, op_desc_info); | |||||
| op_desc_info.input_addrs = ModelUtils::GetInputDataAddrs(model_param, op); | |||||
| op_desc_info.output_addrs = ModelUtils::GetOutputDataAddrs(model_param, op); | |||||
| op_desc_info_.emplace_back(std::move(op_desc_info)); | |||||
| } | |||||
| void ExceptionDumper::SaveOpDescInfo(const OpDescPtr &op, uint32_t task_id, uint32_t stream_id, | |||||
| OpDescInfo &op_desc_info) { | |||||
| if (op == nullptr) { | |||||
| GELOGW("[Save][OpExceptionInfo] op desc ptr is null."); | |||||
| return; | |||||
| } | |||||
| GELOGD("[Save][OpExceptionInfo] Start to save dump op [%s] info of task_id: %u, stream_id: %u", | |||||
| op->GetName().c_str(), task_id, stream_id); | |||||
| op_desc_info.op_name = op->GetName(); | |||||
| op_desc_info.op_type = op->GetType(); | |||||
| op_desc_info.task_id = task_id; | |||||
| op_desc_info.stream_id = stream_id; | |||||
| for (size_t i = 0; i < op->GetAllInputsSize(); ++i) { | |||||
| GeTensorDescPtr input_tensor_desc = op->MutableInputDesc(i); | |||||
| if (input_tensor_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| op_desc_info.input_format.emplace_back(input_tensor_desc->GetFormat()); | |||||
| op_desc_info.input_shape.emplace_back(input_tensor_desc->GetShape().GetDims()); | |||||
| op_desc_info.input_data_type.emplace_back(input_tensor_desc->GetDataType()); | |||||
| int64_t input_size = 0; | |||||
| if (TensorUtils::GetTensorSizeInBytes(*input_tensor_desc, input_size) != SUCCESS) { | |||||
| GELOGW("[Save][OpExceptionInfo] Op [%s] get input size failed.", op->GetName().c_str()); | |||||
| return; | |||||
| } | |||||
| GELOGD("[Save][OpExceptionInfo] Save dump op info, the input size is %ld", input_size); | |||||
| op_desc_info.input_size.emplace_back(input_size); | |||||
| } | |||||
| for (size_t j = 0; j < op->GetOutputsSize(); ++j) { | |||||
| GeTensorDescPtr output_tensor_desc = op->MutableOutputDesc(j); | |||||
| if (output_tensor_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| op_desc_info.output_format.emplace_back(output_tensor_desc->GetFormat()); | |||||
| op_desc_info.output_shape.emplace_back(output_tensor_desc->GetShape().GetDims()); | |||||
| op_desc_info.output_data_type.emplace_back(output_tensor_desc->GetDataType()); | |||||
| int64_t output_size = 0; | |||||
| if (TensorUtils::GetTensorSizeInBytes(*output_tensor_desc, output_size) != SUCCESS) { | |||||
| GELOGW("[Save][OpExceptionInfo] Op [%s] get output size failed.", op->GetName().c_str()); | |||||
| return; | |||||
| } | |||||
| GELOGD("[Save][OpExceptionInfo] Save dump op info, the output size is %ld.", output_size); | |||||
| op_desc_info.output_size.emplace_back(output_size); | |||||
| } | |||||
| } | |||||
| Status ExceptionDumper::DumpExceptionInfo(const std::vector<rtExceptionInfo> &exception_infos) const { | |||||
| GELOGI("[Dump][Exception] Start to dump exception info"); | |||||
| for (const rtExceptionInfo &iter : exception_infos) { | |||||
| OpDescInfo op_desc_info; | |||||
| if (GetOpDescInfo(iter.streamid, iter.taskid, op_desc_info)) { | |||||
| toolkit::dump::DumpData dump_data; | |||||
| SetDumpData(op_desc_info, dump_data); | |||||
| uint64_t now_time = GetNowTime(); | |||||
| std::string op_name = op_desc_info.op_name; | |||||
| std::string op_type = op_desc_info.op_type; | |||||
| ReplaceStringElem(op_name); | |||||
| ReplaceStringElem(op_type); | |||||
| string dump_file_path = | |||||
| "./" + op_type + "." + op_name + "." + std::to_string(op_desc_info.task_id) + "." + std::to_string(now_time); | |||||
| GELOGI("[Dump][Exception] The exception dump file path is %s", dump_file_path.c_str()); | |||||
| uint64_t proto_size = dump_data.ByteSizeLong(); | |||||
| std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]); | |||||
| bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); | |||||
| if (!ret || proto_size == 0) { | |||||
| REPORT_INNER_ERROR("E19999", "Serialize proto to string fail"); | |||||
| GELOGE(PARAM_INVALID, "[Dump][Exception] Dump data proto serialize failed"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| GE_CHK_STATUS_RET(MemoryDumper::DumpToFile(dump_file_path.c_str(), &proto_size, sizeof(uint64_t)), | |||||
| "Failed to dump proto size"); | |||||
| GE_CHK_STATUS_RET(MemoryDumper::DumpToFile(dump_file_path.c_str(), proto_msg.get(), proto_size), | |||||
| "Failed to dump proto msg"); | |||||
| if (DumpExceptionInput(op_desc_info, dump_file_path) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "[Dump][Exception] Dump exception input failed"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| if (DumpExceptionOutput(op_desc_info, dump_file_path) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "[Dump][Exception] Dump exception output failed"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| GELOGI("[Dump][Exception] Dump exception info SUCCESS"); | |||||
| } else { | |||||
| GELOGE(PARAM_INVALID, "[Dump][Exception] Get op desc info failed,task id:%u,stream id:%u", | |||||
| iter.taskid, iter.streamid); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| bool ExceptionDumper::GetOpDescInfo(uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info) const { | |||||
| GELOGI("[Get][OpDescInfo] There are %zu op need to dump.", op_desc_info_.size()); | |||||
| for (size_t index = 0; index < op_desc_info_.size(); ++index) { | |||||
| OpDescInfo dump_op_info = op_desc_info_.at(index); | |||||
| if (dump_op_info.task_id == task_id && dump_op_info.stream_id == stream_id) { | |||||
| GELOGI("[Get][OpDescInfo] Find exception op [%s] of task_id: %u, stream_id: %u.", | |||||
| dump_op_info.op_name.c_str(), task_id, stream_id); | |||||
| op_desc_info = dump_op_info; | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| Status ExceptionDumper::DumpExceptionInput(const OpDescInfo &op_desc_info, const string &dump_file) const { | |||||
| GELOGI("[Dump][ExceptionInput] Start to dump exception input"); | |||||
| for (size_t i = 0; i < op_desc_info.input_addrs.size(); i++) { | |||||
| if (Debug::DumpDevMem(dump_file.data(), op_desc_info.input_addrs.at(i), op_desc_info.input_size.at(i)) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "[Dump][ExceptionInput] Dump the %zu input data of op [%s] failed", | |||||
| i, op_desc_info.op_name.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ExceptionDumper::DumpExceptionOutput(const OpDescInfo &op_desc_info, const string &dump_file) const { | |||||
| GELOGI("[Dump][ExceptionOutput] Start to dump exception output"); | |||||
| for (size_t i = 0; i < op_desc_info.output_addrs.size(); i++) { | |||||
| if (Debug::DumpDevMem(dump_file.data(), op_desc_info.output_addrs.at(i), op_desc_info.output_size.at(i)) != | |||||
| SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "[Dump][ExceptionInput] Dump the %zu input data of op [%s] failed", | |||||
| i, op_desc_info.op_name.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| OpDescInfo *ExceptionDumper::MutableOpDescInfo(uint32_t task_id, uint32_t stream_id) { | |||||
| for (OpDescInfo &op_desc_info : op_desc_info_) { | |||||
| if (op_desc_info.task_id == task_id && op_desc_info.stream_id == stream_id) { | |||||
| return &op_desc_info; | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,48 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_DUMP_EXCEPTION_DUMPER_H_ | |||||
| #define GE_COMMON_DUMP_EXCEPTION_DUMPER_H_ | |||||
| #include <vector> | |||||
| #include "graph/op_desc.h" | |||||
| #include "framework/common/ge_types.h" | |||||
| #include "graph/load/model_manager/task_info/task_info.h" | |||||
| namespace ge { | |||||
| class ExceptionDumper { | |||||
| public: | |||||
| ExceptionDumper() = default; | |||||
| ~ExceptionDumper(); | |||||
| void SaveDumpOpInfo(const OpDescPtr &op, uint32_t task_id, uint32_t stream_id, | |||||
| std::vector<void *> &input_addrs, std::vector<void *> &output_addrs); | |||||
| void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id); | |||||
| Status DumpExceptionInfo(const std::vector<rtExceptionInfo> &exception_infos) const; | |||||
| bool GetOpDescInfo(uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info) const; | |||||
| OpDescInfo *MutableOpDescInfo(uint32_t task_id, uint32_t stream_id); | |||||
| private: | |||||
| void SaveOpDescInfo(const OpDescPtr &op, uint32_t task_id, uint32_t stream_id, OpDescInfo &op_desc_info); | |||||
| Status DumpExceptionInput(const OpDescInfo &op_desc_info, const std::string &dump_file) const; | |||||
| Status DumpExceptionOutput(const OpDescInfo &op_desc_info, const std::string &dump_file) const; | |||||
| std::vector<OpDescInfo> op_desc_info_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_DUMP_EXCEPTION_DUMPER_H_ | |||||
| @@ -1,152 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "opdebug_register.h" | |||||
namespace {
// Byte size of the op debug buffer allocated by MallocMemForOpdebug.
constexpr size_t kOpDebugMemorySize = 2048UL;
// Byte size of the buffer that stores the address of the op debug buffer.
constexpr size_t kDebugP2pSize = 8UL;
}  // namespace
| namespace ge { | |||||
| OpdebugRegister::~OpdebugRegister() {} | |||||
| Status OpdebugRegister::RegisterDebugForModel(rtModel_t model_handle, uint32_t op_debug_mode, DataDumper &data_dumper) { | |||||
| GELOGD("Start to register debug for model in overflow"); | |||||
| auto ret = MallocMemForOpdebug(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Malloc][MemForOpdebug]Failed when debug for model overflow, ret:0x%X", ret); | |||||
| REPORT_CALL_ERROR("E19999", "Malloc memory for opdebug failed when debug " | |||||
| "for model in overflow, ret 0x%X", ret); | |||||
| return ret; | |||||
| } | |||||
| uint32_t debug_stream_id = 0; | |||||
| uint32_t debug_task_id = 0; | |||||
| auto rt_ret = rtDebugRegister(model_handle, op_debug_mode, op_debug_addr_, &debug_stream_id, &debug_task_id); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtDebugRegister error, ret: 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| GELOGD("debug_task_id:%u, debug_stream_id:%u in model overflow", debug_task_id, debug_stream_id); | |||||
| data_dumper.SaveOpDebugId(debug_task_id, debug_stream_id, p2p_debug_addr_, true); | |||||
| return SUCCESS; | |||||
| } | |||||
| void OpdebugRegister::UnregisterDebugForModel(rtModel_t model_handle) { | |||||
| rtError_t rt_ret = RT_ERROR_NONE; | |||||
| if (model_handle != nullptr) { | |||||
| GELOGD("start to call rtDebugUnRegister in model overflow."); | |||||
| rt_ret = rtDebugUnRegister(model_handle); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGW("rtDebugUnRegister failed, ret: 0x%X", rt_ret); | |||||
| } | |||||
| } | |||||
| if (op_debug_addr_ != nullptr) { | |||||
| rt_ret = rtFree(op_debug_addr_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGW("rtFree failed, ret: 0x%X", rt_ret); | |||||
| } | |||||
| op_debug_addr_ = nullptr; | |||||
| } | |||||
| if (p2p_debug_addr_ != nullptr) { | |||||
| rt_ret = rtFree(p2p_debug_addr_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGW("rtFree failed, ret: 0x%X", rt_ret); | |||||
| } | |||||
| p2p_debug_addr_ = nullptr; | |||||
| } | |||||
| return; | |||||
| } | |||||
| Status OpdebugRegister::RegisterDebugForStream(rtStream_t stream, uint32_t op_debug_mode, DataDumper &data_dumper) { | |||||
| GELOGD("Start to register debug for stream in stream overflow"); | |||||
| auto ret = MallocMemForOpdebug(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Malloc][MemForOpdebug]Failed when debug for stream in overflow, ret:0x%X", ret); | |||||
| REPORT_CALL_ERROR("E19999", "Malloc memory for opdebug failed when debug " | |||||
| "for stream in overflow, ret:0x%X", ret); | |||||
| return ret; | |||||
| } | |||||
| uint32_t debug_stream_id = 0; | |||||
| uint32_t debug_task_id = 0; | |||||
| auto rt_ret = rtDebugRegisterForStream(stream, op_debug_mode, op_debug_addr_, &debug_stream_id, &debug_task_id); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "[Call][rtDebugRegisterForStream]Failed, ret 0x%X", rt_ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtDebugRegisterForStream failed, ret 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| GELOGD("debug_task_id:%u, debug_stream_id:%u in stream overflow.", debug_task_id, debug_stream_id); | |||||
| data_dumper.SaveOpDebugId(debug_task_id, debug_stream_id, p2p_debug_addr_, true); | |||||
| return SUCCESS; | |||||
| } | |||||
| void OpdebugRegister::UnregisterDebugForStream(rtStream_t stream) { | |||||
| rtError_t rt_ret = RT_ERROR_NONE; | |||||
| if (stream != nullptr) { | |||||
| GELOGD("start call rtDebugUnRegisterForStream in unknown shape over flow."); | |||||
| rt_ret = rtDebugUnRegisterForStream(stream); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGW("rtDebugUnRegisterForStream failed, ret: 0x%X", rt_ret); | |||||
| } | |||||
| } | |||||
| if (op_debug_addr_ != nullptr) { | |||||
| rt_ret = rtFree(op_debug_addr_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGW("rtFree failed, ret: 0x%X", rt_ret); | |||||
| } | |||||
| op_debug_addr_ = nullptr; | |||||
| } | |||||
| if (p2p_debug_addr_ != nullptr) { | |||||
| rt_ret = rtFree(p2p_debug_addr_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGW("rtFree failed, ret: 0x%X", rt_ret); | |||||
| } | |||||
| p2p_debug_addr_ = nullptr; | |||||
| } | |||||
| return; | |||||
| } | |||||
| Status OpdebugRegister::MallocMemForOpdebug() { | |||||
| rtError_t rt_ret = rtMalloc(&op_debug_addr_, kOpDebugMemorySize, RT_MEMORY_DDR); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "[Call][rtMalloc]Failed, ret 0x%X", rt_ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMalloc failed, ret 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| uint64_t debug_addrs_tmp = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(op_debug_addr_)); | |||||
| // For data dump, aicpu needs the pointer to pointer that save the real debug address. | |||||
| rt_ret = rtMalloc(&p2p_debug_addr_, kDebugP2pSize, RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "[Call][rtMalloc]Failed, ret 0x%X", rt_ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMalloc failed, ret 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| rt_ret = rtMemcpy(p2p_debug_addr_, sizeof(uint64_t), &debug_addrs_tmp, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "[Call][rtMemcpy]To p2p_addr error 0x%X", rt_ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMemcpy to p2p_addr error 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,44 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_DUMP_OPDEBUG_REGISTER_H_ | |||||
| #define GE_COMMON_DUMP_OPDEBUG_REGISTER_H_ | |||||
| #include <map> | |||||
| #include "common/debug/ge_log.h" | |||||
| #include "common/debug/log.h" | |||||
| #include "graph/load/model_manager/data_dumper.h" | |||||
| namespace ge { | |||||
| class OpdebugRegister { | |||||
| public: | |||||
| OpdebugRegister() = default; | |||||
| ~OpdebugRegister(); | |||||
| Status RegisterDebugForModel(rtModel_t model_handle, uint32_t op_debug_mode, DataDumper &data_dumper); | |||||
| void UnregisterDebugForModel(rtModel_t model_handle); | |||||
| Status RegisterDebugForStream(rtStream_t stream, uint32_t op_debug_mode, DataDumper &data_dumper); | |||||
| void UnregisterDebugForStream(rtStream_t stream); | |||||
| private: | |||||
| Status MallocMemForOpdebug(); | |||||
| void *op_debug_addr_ = nullptr; | |||||
| void *p2p_debug_addr_ = nullptr; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_DUMP_OPDEBUG_REGISTER_H_ | |||||
| @@ -1,64 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "framework/common/fmk_error_codes.h" | |||||
| namespace domi { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY StatusFactory *StatusFactory::Instance() { | |||||
| static StatusFactory instance; | |||||
| return &instance; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void StatusFactory::RegisterErrorNo(uint32_t err, | |||||
| const std::string &desc) { | |||||
| if (err_desc_.find(err) != err_desc_.end()) { | |||||
| return; | |||||
| } | |||||
| err_desc_[err] = desc; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string StatusFactory::GetErrDesc(uint32_t err) { | |||||
| auto iter_find = err_desc_.find(err); | |||||
| if (iter_find == err_desc_.end()) { | |||||
| return ""; | |||||
| } | |||||
| return iter_find->second; | |||||
| } | |||||
// Descriptions attached to the framework error codes via DEF_ERRORNO.
// The numeric values in the trailing comments are informational only.
// General error code
DEF_ERRORNO(SUCCESS, "Success");
DEF_ERRORNO(FAILED, "Failed");
// Common error code
DEF_ERRORNO(MEMALLOC_FAILED, "Failed to allocate memory!"); // 50331648
DEF_ERRORNO(PARAM_INVALID, "Parameter's invalid!"); // 50331649
DEF_ERRORNO(CCE_FAILED, "Failed to call CCE API!"); // 50331650
DEF_ERRORNO(RT_FAILED, "Failed to call runtime API!"); // 50331651
DEF_ERRORNO(INTERNAL_ERROR, "Internal errors"); // 50331652
DEF_ERRORNO(CSEC_ERROR, "Failed to call libc_sec API!"); // 50331653
DEF_ERRORNO(TEE_ERROR, "Failed to call tee API!"); // NOTE(review): comment repeated 50331653; presumably 50331654 - confirm
DEF_ERRORNO(UNSUPPORTED, "Parameter's unsupported!");
DEF_ERRORNO(OUT_OF_MEMORY, "Out of memory!");
// Model parsing and initialization error codes
DEF_ERRORNO(PARSE_MODEL_FAILED, "Failed to parse the model!");
DEF_ERRORNO(PARSE_WEIGHTS_FAILED, "Failed to parse the weights!");
DEF_ERRORNO(NOT_INITIALIZED, "It hasn't been initialized!");
DEF_ERRORNO(TIMEOUT, "Running time out!");
// Model execution and data queue error codes
DEF_ERRORNO(MODEL_NOT_READY, "The model is not ready yet!");
DEF_ERRORNO(PUSH_DATA_FAILED, "Failed to push data!");
DEF_ERRORNO(DATA_QUEUE_ISFULL, "Data queue is full!");
| } // namespace domi | |||||
| @@ -1,190 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/datatype_transfer.h" | |||||
| #include <cstdint> | |||||
| #include <map> | |||||
| #include <utility> | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "common/fp16_t.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| #include "securec.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
// Identifies one supported (source type -> destination type) cast.
// The values are spelled out explicitly; they match the implicit numbering
// of the declaration order.
enum DataTypeTransMode {
  kTransferWithDatatypeFloatToFloat16 = 0,
  kTransferWithDatatypeFloatToInt32 = 1,
  kTransferWithDatatypeFloat16ToFloat = 2,
  kTransferWithDatatypeFloat16ToInt32 = 3,
  kTransferWithDatatypeInt32ToFloat = 4,
  kTransferWithDatatypeInt32ToFloat16 = 5,
  kTransferWithDatatypeInt32ToUint8 = 6,
  kTransferWithDatatypeInt32ToInt8 = 7,
  kTransferWithDatatypeUint8ToFloat = 8,
  kTransferWithDatatypeUint8ToInt32 = 9,
  kTransferWithDatatypeInt8ToFloat = 10,
  kTransferWithDatatypeInt8ToInt32 = 11,
  kTransferWithDatatypeInt64ToInt32 = 12,
  kTransferWithDatatypeInt32ToInt64 = 13,
  kTransferWithDatatypeInt32ToDouble = 14,
  kTransferWithDatatypeDoubleToInt32 = 15,
};
| std::map<std::pair<DataType, DataType>, DataTypeTransMode> trans_mode_map{ | |||||
| {std::pair<DataType, DataType>(DT_FLOAT, DT_FLOAT16), kTransferWithDatatypeFloatToFloat16}, | |||||
| {std::pair<DataType, DataType>(DT_FLOAT, DT_INT32), kTransferWithDatatypeFloatToInt32}, | |||||
| {std::pair<DataType, DataType>(DT_FLOAT16, DT_FLOAT), kTransferWithDatatypeFloat16ToFloat}, | |||||
| {std::pair<DataType, DataType>(DT_FLOAT16, DT_INT32), kTransferWithDatatypeFloat16ToInt32}, | |||||
| {std::pair<DataType, DataType>(DT_INT32, DT_FLOAT), kTransferWithDatatypeInt32ToFloat}, | |||||
| {std::pair<DataType, DataType>(DT_INT32, DT_FLOAT16), kTransferWithDatatypeInt32ToFloat16}, | |||||
| {std::pair<DataType, DataType>(DT_INT32, DT_UINT8), kTransferWithDatatypeInt32ToUint8}, | |||||
| {std::pair<DataType, DataType>(DT_INT32, DT_INT8), kTransferWithDatatypeInt32ToInt8}, | |||||
| {std::pair<DataType, DataType>(DT_UINT8, DT_FLOAT), kTransferWithDatatypeUint8ToFloat}, | |||||
| {std::pair<DataType, DataType>(DT_UINT8, DT_INT32), kTransferWithDatatypeUint8ToInt32}, | |||||
| {std::pair<DataType, DataType>(DT_INT8, DT_FLOAT), kTransferWithDatatypeInt8ToFloat}, | |||||
| {std::pair<DataType, DataType>(DT_INT8, DT_INT32), kTransferWithDatatypeInt8ToInt32}, | |||||
| {std::pair<DataType, DataType>(DT_INT64, DT_INT32), kTransferWithDatatypeInt64ToInt32}, | |||||
| {std::pair<DataType, DataType>(DT_INT32, DT_INT64), kTransferWithDatatypeInt32ToInt64}, | |||||
| {std::pair<DataType, DataType>(DT_INT32, DT_DOUBLE), kTransferWithDatatypeInt32ToDouble}, | |||||
| {std::pair<DataType, DataType>(DT_DOUBLE, DT_INT32), kTransferWithDatatypeDoubleToInt32}, | |||||
| }; | |||||
| template <typename SrcT, typename DstT> | |||||
| Status TransDataSrc2Dst(const CastArgs &args, uint8_t *dst, const size_t data_size) { | |||||
| SrcT src_data; | |||||
| for (size_t idx = 0; idx != data_size; idx++) { | |||||
| src_data = reinterpret_cast<const SrcT *>(args.data)[idx]; | |||||
| reinterpret_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| template <typename SrcT> | |||||
| Status TransDataSrc2Fp16(const CastArgs &args, uint8_t *dst, const size_t data_size) { | |||||
| fp16_t src_data; | |||||
| for (size_t idx = 0; idx != data_size; idx++) { | |||||
| src_data = reinterpret_cast<const SrcT *>(args.data)[idx]; | |||||
| reinterpret_cast<uint16_t *>(dst)[idx] = src_data.val; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CastKernel(const CastArgs &args, uint8_t *dst, const size_t data_size, const DataTypeTransMode trans_mode) { | |||||
| static std::map<DataTypeTransMode, std::function<Status(const CastArgs &, uint8_t *, const size_t)>> | |||||
| transfer_handle = { | |||||
| {kTransferWithDatatypeFloatToFloat16, TransDataSrc2Fp16<float>}, | |||||
| {kTransferWithDatatypeFloatToInt32, TransDataSrc2Dst<float, int32_t>}, | |||||
| {kTransferWithDatatypeFloat16ToFloat, TransDataSrc2Dst<fp16_t, float>}, | |||||
| {kTransferWithDatatypeFloat16ToInt32, TransDataSrc2Dst<fp16_t, int32_t>}, | |||||
| {kTransferWithDatatypeInt32ToFloat, TransDataSrc2Dst<int32_t, float>}, | |||||
| {kTransferWithDatatypeInt32ToFloat16, TransDataSrc2Fp16<int32_t>}, | |||||
| {kTransferWithDatatypeInt32ToUint8, TransDataSrc2Dst<int32_t, uint8_t>}, | |||||
| {kTransferWithDatatypeInt32ToInt8, TransDataSrc2Dst<int32_t, int8_t>}, | |||||
| {kTransferWithDatatypeUint8ToFloat, TransDataSrc2Dst<uint8_t, float>}, | |||||
| {kTransferWithDatatypeUint8ToInt32, TransDataSrc2Dst<uint8_t, int32_t>}, | |||||
| {kTransferWithDatatypeInt8ToFloat, TransDataSrc2Dst<int8_t, float>}, | |||||
| {kTransferWithDatatypeInt8ToInt32, TransDataSrc2Dst<int8_t, int32_t>}, | |||||
| {kTransferWithDatatypeInt64ToInt32, TransDataSrc2Dst<int64_t, int32_t>}, | |||||
| {kTransferWithDatatypeInt32ToInt64, TransDataSrc2Dst<int32_t, int64_t>}, | |||||
| {kTransferWithDatatypeInt32ToDouble, TransDataSrc2Dst<int32_t, double>}, | |||||
| {kTransferWithDatatypeDoubleToInt32, TransDataSrc2Dst<double, int32_t>}, | |||||
| }; | |||||
| auto it = transfer_handle.find(trans_mode); | |||||
| if (it == transfer_handle.end()) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } else { | |||||
| return (it->second)(args, dst, data_size); | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result) { | |||||
| GELOGD("Begin trans data from %s to %s, data size %zu", TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str(), args.src_data_size); | |||||
| std::pair<DataType, DataType> trans_info(args.src_data_type, args.dst_data_type); | |||||
| auto iter = trans_mode_map.find(trans_info); | |||||
| if (iter == trans_mode_map.end()) { | |||||
| std::string error = "Failed to trans data from datatype " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + " , it is not supported."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_DATATYPE_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| auto trans_mode = iter->second; | |||||
| int size = GetSizeByDataType(args.dst_data_type); | |||||
| if (size <= 0) { | |||||
| std::string error = "Failed to calc size from data type" + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", it is not supported."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_DATATYPE_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) { | |||||
| std::string error = "args.src_data_size" + FmtToStr(args.src_data_size) + | |||||
| " or data type size" + FmtToStr(size) + " is too big"; | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | |||||
| size_t total_size = static_cast<size_t>(args.src_data_size * size); | |||||
| result.length = total_size; | |||||
| if (total_size == 0) { | |||||
| GELOGI("In TransDataType, total_size is zero, has no data."); | |||||
| return SUCCESS; | |||||
| } | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, | |||||
| "[Allocate][DSTMemory]Failed, memory for dst buf %zu, data size %zu", | |||||
| total_size, args.src_data_size); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to allocate memory for dst buf %zu, data size %zu", | |||||
| total_size, args.src_data_size); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) { | |||||
| std::string error = "Failed to cast data from datatype " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", data size is " + | |||||
| FmtToStr(std::to_string(args.src_data_size)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_INTERNAL_ERROR, error.c_str()); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR; | |||||
| } | |||||
| result.data = dst; | |||||
| return SUCCESS; | |||||
| } | |||||
| std::shared_ptr<DataTypeTransfer> BuildDataTypeTransfer(const CastArgs &args) { | |||||
| if (!DataTypeTransferExists(args)) { | |||||
| return nullptr; | |||||
| } | |||||
| return ge::MakeShared<DataTypeTransfer>(); | |||||
| } | |||||
| bool DataTypeTransferExists(const CastArgs &args) { | |||||
| std::pair<DataType, DataType> trans_info(args.src_data_type, args.dst_data_type); | |||||
| auto iter = trans_mode_map.find(trans_info); | |||||
| return iter != trans_mode_map.end(); | |||||
| } | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,48 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_DATATYPE_TRANSFER_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_DATATYPE_TRANSFER_H_ | |||||
| #include <functional> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| #include "external/graph/types.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| struct CastArgs { | |||||
| const uint8_t *data; | |||||
| size_t src_data_size; | |||||
| DataType src_data_type; | |||||
| DataType dst_data_type; | |||||
| }; | |||||
| class DataTypeTransfer { | |||||
| public: | |||||
| Status TransDataType(const CastArgs &args, TransResult &result); | |||||
| }; | |||||
| std::shared_ptr<DataTypeTransfer> BuildDataTypeTransfer(const CastArgs &args); | |||||
| bool DataTypeTransferExists(const CastArgs &args); | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_DATATYPE_TRANSFER_H_ | |||||
| @@ -1,202 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
| bool CheckDataTypeSupported(const DataType &data_type) { | |||||
| return (data_type == DT_FLOAT || data_type == DT_FLOAT16 || data_type == DT_INT8); | |||||
| } | |||||
| Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||||
| auto src_shape = args.src_shape; | |||||
| auto dst_shape = args.dst_shape; | |||||
| if (args.src_format != FORMAT_C1HWNCoC0 || args.dst_format != FORMAT_HWCN) { | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| if (!CheckDataTypeSupported(args.src_data_type)) { | |||||
| std::string error = "Failed to trans shape from NC1HWNCoC0 to HWCN, invalid data type" + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_DATATYPE_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (!CheckShapeValid(src_shape, kC1hwncoc0DimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][SrcShape]Failed, src shape %s", | |||||
| ShapeToString(src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| if (!CheckShapeValid(dst_shape, kHwcnDimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][DSTShape]Failed, dst shape %s.", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| auto cube_size = GetCubeSizeByDataType(args.src_data_type); | |||||
| if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / cube_size + 1 || | |||||
| src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || | |||||
| src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || | |||||
| src_shape.at(kC1hwncoc0C0) != cube_size) { | |||||
| std::string error = "Failed to check relationship between src and dst shape, src shape" + | |||||
| FmtToStr(ShapeToString(src_shape)) + ", dst shape" + FmtToStr(ShapeToString(dst_shape)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size, int64_t total_size) { | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, | |||||
| "[Allocate][DSTMemory]Failed to allcoate memory for dst buf %ld, " | |||||
| "shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to allcoate memory for dst buf %ld, " | |||||
| "shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| auto h = args.src_shape.at(kC1hwncoc0H); | |||||
| auto w = args.src_shape.at(kC1hwncoc0W); | |||||
| auto n = args.src_shape.at(kC1hwncoc0N); | |||||
| auto c0 = args.src_shape.at(kC1hwncoc0C0); | |||||
| auto co = args.src_shape.at(kC1hwncoc0Co); | |||||
| auto c = args.dst_shape.at(kHwcnC); | |||||
| auto cube_size = GetCubeSizeByDataType(args.src_data_type); | |||||
| int64_t cn = c * n; | |||||
| int64_t wcn = w * cn; | |||||
| int64_t coc0 = co * c0; | |||||
| int64_t ncoc0 = n * coc0; | |||||
| int64_t wncoc0 = w * ncoc0; | |||||
| int64_t hwncoc0 = h * wncoc0; | |||||
| for (int64_t h_idx = 0; h_idx < h; h_idx++) { | |||||
| int64_t h_head_addr = h_idx * wcn; | |||||
| for (int64_t w_idx = 0; w_idx < w; w_idx++) { | |||||
| int64_t w_head_addr = h_head_addr + w_idx * cn; | |||||
| for (int64_t c_idx = 0; c_idx < c; c_idx++) { | |||||
| int64_t c_head_addr = w_head_addr + c_idx * n; | |||||
| for (int64_t n_idx = 0; n_idx < n; n_idx++) { | |||||
| int64_t dst_idx = c_head_addr + n_idx; | |||||
| int64_t c1_idx = c_idx / cube_size; | |||||
| int64_t c0_idx = c_idx % cube_size; | |||||
| int64_t co_idx = c0_idx; | |||||
| int64_t src_idx = c1_idx * hwncoc0 + h_idx * wncoc0 + w_idx * ncoc0 + n_idx * coc0 + co_idx * c0 + c0_idx; | |||||
| auto src_offset = src_idx * size; | |||||
| auto dst_offset = dst_idx * size; | |||||
| // The memcpy_s/memset_s argument `dstMax` must be less than 2G | |||||
| auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? total_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
| "[Operate][Memory]Failed to copy data from " | |||||
| "C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld to " | |||||
| "HWCN[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, src_offset, | |||||
| h_idx, w_idx, c_idx, n_idx, dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to copy data from " | |||||
| "C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld to " | |||||
| "HWCN[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, src_offset, | |||||
| h_idx, w_idx, c_idx, n_idx, dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| Status ret = CheckArgsForC1hwncoc0ToHwcn(args); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| int64_t total_size = GetItemNumByShape(args.dst_shape) * size; | |||||
| if (total_size <= 0) { | |||||
| int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
| if (total_size == 0 && src_size == 0) { | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Get][Shape]Failed, total size %ld from dst shape %s, " | |||||
| "src shape %s.", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(args.src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Get shape faield, total size %ld from dst shape %s, src shape %s.", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(args.src_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| GELOGD("Begin to trans format from C1HWNCoC0 to HWCN, src shape %s, data type %s, dst shape %s, memory size %ld.", | |||||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Get][Data]Failed when after trans, src shape %s, data type %s, dst shape %s, " | |||||
| "memory size %ld, error_code %u", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to get data after trans, src shape %s, data type %s, " | |||||
| "dst shape %s, memory size %ld, error_code %u", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size, ret); | |||||
| return ret; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FormatTransferC1hwncoc0Hwcn::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | |||||
| DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | |||||
| GELOGD("The shape derivation from C1HWNCoC0 to HWCN is not unique. Trans shape in this direction is not supported."); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
// NOTE(review): presumably makes FormatTransferC1hwncoc0Hwcn the handler for the
// FORMAT_C1HWNCoC0 -> FORMAT_HWCN pair; the macro is defined outside this file - confirm.
| REGISTER_FORMAT_TRANSFER(FormatTransferC1hwncoc0Hwcn, FORMAT_C1HWNCoC0, FORMAT_HWCN) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,34 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_C1HWNCOC0_HWCN_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_C1HWNCOC0_HWCN_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| class FormatTransferC1hwncoc0Hwcn : public FormatTransfer { | |||||
| public: | |||||
| Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_C1HWNCOC0_HWCN_H_ | |||||
| @@ -1,185 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
| Status CheckDataTypeSupport(DataType dtype) { return GetSizeByDataType(dtype) > 0 ? SUCCESS : UNSUPPORTED; } | |||||
| Status TransShapeToFz(int64_t d, int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, | |||||
| std::vector<int64_t> &dst_shape) { | |||||
| auto c0 = GetCubeSizeByDataType(data_type); | |||||
| if (c0 < 0) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| auto c1 = Ceil(c, c0); | |||||
| auto no = Ceil(n, static_cast<int64_t>(kNiSize)); | |||||
| dst_shape.clear(); | |||||
| dst_shape.push_back(d * c1 * h * w); | |||||
| dst_shape.push_back(no); | |||||
| dst_shape.push_back(kNiSize); | |||||
| dst_shape.push_back(c0); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransShapeDhwckToFz3D(const std::vector<int64_t> &src_shape, DataType data_type, | |||||
| std::vector<int64_t> &dst_shape) { | |||||
| if (!CheckShapeValid(src_shape, kDhwcnDimsNum)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| auto d = src_shape.at(kDhwcnD); | |||||
| auto h = src_shape.at(kDhwcnH); | |||||
| auto w = src_shape.at(kDhwcnW); | |||||
| auto c = src_shape.at(kDhwcnC); | |||||
| auto n = src_shape.at(kDhwcnN); | |||||
| return TransShapeToFz(d, n, c, h, w, data_type, dst_shape); | |||||
| } | |||||
| Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { | |||||
| if (!CheckShapeValid(args.src_shape, kDhwcnDimsNum)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| int64_t d = args.src_shape[kDhwcnD]; | |||||
| int64_t h = args.src_shape[kDhwcnH]; | |||||
| int64_t w = args.src_shape[kDhwcnW]; | |||||
| int64_t c = args.src_shape[kDhwcnC]; | |||||
| int64_t n = args.src_shape[kDhwcnN]; | |||||
| int64_t n1n0 = Ceil(n, static_cast<int64_t>(kNiSize)) * kNiSize; | |||||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||||
| int64_t c1 = Ceil(c, c0); | |||||
| auto cn = c * n; | |||||
| auto wcn = w * cn; | |||||
| auto hwcn = h * wcn; | |||||
| auto n1n0c0 = n1n0 * c0; | |||||
| auto wn1n0c0 = w * n1n0c0; | |||||
| auto hwn1n0c0 = h * wn1n0c0; | |||||
| auto c1hwn1n0c0 = c1 * hwn1n0c0; | |||||
| int64_t data_size = GetSizeByDataType(args.src_data_type); | |||||
| int64_t dst_size = 1; | |||||
| for (auto dim : args.dst_shape) { | |||||
| dst_size *= dim; | |||||
| } | |||||
| dst_size *= data_size; | |||||
| if (dst_size == 0) { | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allocate][DSTMemory]Failed to allcoate memory " | |||||
| "for dst buf %ld when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to allcoate memory for dst buf %ld " | |||||
| "when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| for (int64_t di = 0; di < d; di++) { | |||||
| for (int64_t c1i = 0; c1i < c1; c1i++) { | |||||
| for (int64_t hi = 0; hi < h; hi++) { | |||||
| for (int64_t wi = 0; wi < w; wi++) { | |||||
| for (int64_t n1n0i = 0; n1n0i < n1n0; n1n0i++) { | |||||
| for (int64_t c0i = 0; c0i < c0; c0i++) { | |||||
| int64_t dst_idx = di * c1hwn1n0c0 + c1i * hwn1n0c0 + hi * wn1n0c0 + wi * n1n0c0 + n1n0i * c0 + c0i; | |||||
| int64_t dst_offset = dst_idx * data_size; | |||||
| auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); | |||||
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| errno_t ret; | |||||
| if (pad_zero) { | |||||
| ret = memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, | |||||
| static_cast<size_t>(data_size)); | |||||
| } else { | |||||
| int64_t src_idx = di * hwcn + hi * wcn + wi * cn + (c1i * c0 + c0i) * n + n1n0i; | |||||
| ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), | |||||
| args.data + src_idx * data_size, static_cast<size_t>(data_size)); | |||||
| } | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Operate][DSTMemory]Failed at " | |||||
| "offset %ld, error-code %d, pad mode %d", dst_offset, ret, pad_zero); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to operate dst memory at offset %ld, " | |||||
| "error-code %d, pad mode %d", dst_offset, ret, pad_zero); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = dst_size; | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status FormatTransferDhwcnFractalZ3D::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str()); | |||||
| std::vector<int64_t> expect_shape; | |||||
| auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| if (args.src_format == FORMAT_DHWCN && args.dst_format == FORMAT_FRACTAL_Z_3D) { | |||||
| return TransFormatDhwckToFz3D(args, result); | |||||
| } | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| Status FormatTransferDhwcnFractalZ3D::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | |||||
| DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) { | |||||
| if (CheckDataTypeSupport(data_type) != SUCCESS) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (src_format == FORMAT_DHWCN && dst_format == FORMAT_FRACTAL_Z_3D) { | |||||
| return TransShapeDhwckToFz3D(src_shape, data_type, dst_shape); | |||||
| } | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
// NOTE(review): presumably makes FormatTransferDhwcnFractalZ3D the handler for the
// FORMAT_DHWCN -> FORMAT_FRACTAL_Z_3D pair; the macro is defined outside this file - confirm.
| REGISTER_FORMAT_TRANSFER(FormatTransferDhwcnFractalZ3D, FORMAT_DHWCN, FORMAT_FRACTAL_Z_3D) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,33 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWCN_FRACTAL_Z_3D_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWCN_FRACTAL_Z_3D_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| class FormatTransferDhwcnFractalZ3D : public FormatTransfer { | |||||
| public: | |||||
| Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWCN_FRACTAL_Z_3D_H_ | |||||
| @@ -1,186 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
| Status CheckDataTypeSupport(DataType dtype) { return GetSizeByDataType(dtype) > 0 ? SUCCESS : UNSUPPORTED; } | |||||
| Status TransShapeToFz(int64_t d, int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, | |||||
| std::vector<int64_t> &dst_shape) { | |||||
| auto c0 = GetCubeSizeByDataType(data_type); | |||||
| if (c0 < 0) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| auto c1 = Ceil(c, c0); | |||||
| auto no = Ceil(n, static_cast<int64_t>(kNiSize)); | |||||
| dst_shape.clear(); | |||||
| dst_shape.push_back(d * c1 * h * w); | |||||
| dst_shape.push_back(no); | |||||
| dst_shape.push_back(kNiSize); | |||||
| dst_shape.push_back(c0); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransShapeDhwncToFz3DTranspose(const std::vector<int64_t> &src_shape, DataType data_type, | |||||
| std::vector<int64_t> &dst_shape) { | |||||
| if (!CheckShapeValid(src_shape, kDhwncDimsNum)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| auto d = src_shape.at(kDhwncD); | |||||
| auto h = src_shape.at(kDhwncH); | |||||
| auto w = src_shape.at(kDhwncW); | |||||
| auto n = src_shape.at(kDhwncN); | |||||
| auto c = src_shape.at(kDhwncC); | |||||
| // exchange n c, normalize process with dhwcn to fraz3D | |||||
| return TransShapeToFz(d, c, n, h, w, data_type, dst_shape); | |||||
| } | |||||
// Copies DHWNC data into a newly allocated FRACTAL_Z_3D_TRANSPOSE buffer, one element at a time.
// The source N and C extents are swapped when they are read (c <- shape[N], n <- shape[C]), so all
// index arithmetic below is written in DHWCN terms. The destination is addressed as
// [d][c1][h][w][n1n0][c0]; positions that lie beyond the real c or n extent are zero-filled.
// @param args   source buffer, shapes and data type; args.dst_shape determines the allocation size
// @param result receives the buffer and its byte length (only the length when the destination is empty)
// @return SUCCESS, ACL_ERROR_GE_SHAPE_INVALID, ACL_ERROR_GE_MEMORY_ALLOCATION or
//         ACL_ERROR_GE_MEMORY_OPERATE_FAILED
| Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &result) { | |||||
| if (!CheckShapeValid(args.src_shape, kDhwncDimsNum)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| int64_t d = args.src_shape[kDhwncD]; | |||||
| int64_t h = args.src_shape[kDhwncH]; | |||||
| int64_t w = args.src_shape[kDhwncW]; | |||||
| // exchange nc ,for normalize process with dhwcn to Fz3D | |||||
| int64_t c = args.src_shape[kDhwncN]; | |||||
| int64_t n = args.src_shape[kDhwncC]; | |||||
// n1n0 is the padded n extent (Ceil(n, kNiSize) blocks of kNiSize); c1 is the number of c0-sized blocks over c.
| int64_t n1n0 = Ceil(n, static_cast<int64_t>(kNiSize)) * kNiSize; | |||||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||||
| int64_t c1 = Ceil(c, c0); | |||||
// Element strides of the source, innermost first.
| auto cn = c * n; | |||||
| auto wcn = w * cn; | |||||
| auto hwcn = h * wcn; | |||||
// Element strides of the destination, innermost first.
| auto n1n0c0 = n1n0 * c0; | |||||
| auto wn1n0c0 = w * n1n0c0; | |||||
| auto hwn1n0c0 = h * wn1n0c0; | |||||
| auto c1hwn1n0c0 = c1 * hwn1n0c0; | |||||
| int64_t data_size = GetSizeByDataType(args.src_data_type); | |||||
// Destination size in bytes: product of all dst_shape dimensions times the element size.
| int64_t dst_size = 1; | |||||
| for (auto dim : args.dst_shape) { | |||||
| dst_size *= dim; | |||||
| } | |||||
| dst_size *= data_size; | |||||
// An empty destination is reported as success without allocating; result.data is left untouched.
| if (dst_size == 0) { | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
// The array form of default_delete is required because the buffer comes from new[].
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allocate][DSTMemory]Failed to allcoate memory " | |||||
| "for dst buf %ld when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to allcoate memory for dst buf %ld " | |||||
| "when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
// Visit every destination element in storage order.
| for (int64_t di = 0; di < d; di++) { | |||||
| for (int64_t c1i = 0; c1i < c1; c1i++) { | |||||
| for (int64_t hi = 0; hi < h; hi++) { | |||||
| for (int64_t wi = 0; wi < w; wi++) { | |||||
| for (int64_t n1n0i = 0; n1n0i < n1n0; n1n0i++) { | |||||
| for (int64_t c0i = 0; c0i < c0; c0i++) { | |||||
| int64_t dst_idx = di * c1hwn1n0c0 + c1i * hwn1n0c0 + hi * wn1n0c0 + wi * n1n0c0 + n1n0i * c0 + c0i; | |||||
| int64_t dst_offset = dst_idx * data_size; | |||||
// Bytes left in dst from this offset, capped at SECUREC_MEM_MAX_LEN; used as the destMax argument below.
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
// (c1i * c0 + c0i) is the logical c index; anything past the real c or n extent is padding.
| auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); | |||||
| errno_t ret; | |||||
| if (pad_zero) { | |||||
| ret = memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, | |||||
| static_cast<size_t>(data_size)); | |||||
| } else { | |||||
// Source element index with n as the innermost axis and the logical c index next to it.
| int64_t src_idx = di * hwcn + hi * wcn + wi * cn + (c1i * c0 + c0i) * n + n1n0i; | |||||
| ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), | |||||
| args.data + src_idx * data_size, static_cast<size_t>(data_size)); | |||||
| } | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Operate][DSTMemory]Failed at " | |||||
| "offset %ld, error-code %d, pad mode %d", dst_offset, ret, pad_zero); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to operate dst memory at offset %ld, " | |||||
| "error-code %d, pad mode %d", dst_offset, ret, pad_zero); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = dst_size; | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status FormatTransferDhwncFractalZ3DTranspose::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str()); | |||||
| std::vector<int64_t> expect_shape; | |||||
| auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| if (args.src_format == ge::FORMAT_DHWNC && args.dst_format == ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE) { | |||||
| return TransFormatDhwncToFz3DTranspose(args, result); | |||||
| } | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| Status FormatTransferDhwncFractalZ3DTranspose::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | |||||
| DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) { | |||||
| if (CheckDataTypeSupport(data_type) != SUCCESS) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (src_format == FORMAT_DHWNC && dst_format == FORMAT_FRACTAL_Z_3D_TRANSPOSE) { | |||||
| return TransShapeDhwncToFz3DTranspose(src_shape, data_type, dst_shape); | |||||
| } | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
// NOTE(review): presumably makes FormatTransferDhwncFractalZ3DTranspose discoverable for the
// FORMAT_DHWNC -> FORMAT_FRACTAL_Z_3D_TRANSPOSE pair; confirm against the macro in register_format_transfer.h.
| REGISTER_FORMAT_TRANSFER(FormatTransferDhwncFractalZ3DTranspose, FORMAT_DHWNC, FORMAT_FRACTAL_Z_3D_TRANSPOSE) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,33 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWNC_FRACTAL_Z_3D_TRANSPOSE_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWNC_FRACTAL_Z_3D_TRANSPOSE_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| class FormatTransferDhwncFractalZ3DTranspose : public FormatTransfer { | |||||
| public: | |||||
| Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWNC_FRACTAL_Z_3D_TRANSPOSE_H_ | |||||
| @@ -1,463 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_fractal_nz.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "framework/common/types.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
// Rank and index constants shared by the ND <-> FRACTAL_NZ helpers in this file.
// Required rank for NCHW / NHWC shapes in CheckShape.
| const int kDimSize4D = 4; | |||||
// Rank of a source shape that gets the dedicated single-dimension branch in TransShapeToFracNz.
| const size_t kSingleDim = 1; | |||||
// Positions inside the {times, H, W} vector produced by TransShapeToFracNz.
| const size_t kNdDimIndexN = 0; | |||||
| const size_t kNdDimIndexH = 1; | |||||
| const size_t kNdDimIndexW = 2; | |||||
| const size_t kDimDValueBNdFNz = 2; // difference in dimension count between the ND shape and its FractalNz shape | |||||
// Offsets counted back from the end of an ND shape: size - 1 is W, size - 2 is H.
| const size_t kNdDimCountBackwardsW = 1; | |||||
| const size_t kNdDimCountBackwardsWH = 2; | |||||
// Offsets counted back from the end of a FractalNz shape: W0, H0, H1, W1 in that order.
| const size_t kFNzDimCountBackwardsW0 = 1; | |||||
| const size_t kFNzDimCountBackwardsW0H0 = 2; | |||||
| const size_t kFNzDimCountBackwardsW0H0H1 = 3; | |||||
| const size_t kFNzDimCountBackwardsW0H0H1W1 = 4; | |||||
| bool IsDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0; } | |||||
| using ShapeVector = std::vector<int64_t>; | |||||
| bool CheckShape(Format format, const ShapeVector &shape) { | |||||
| switch (format) { | |||||
| case FORMAT_ND: | |||||
| return IsShapeValid(shape); | |||||
| case FORMAT_NCHW: | |||||
| case FORMAT_NHWC: | |||||
| return CheckShapeValid(shape, kDimSize4D); | |||||
| default: | |||||
| std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | |||||
| " and FORMAT_FRACTAL_NZ is not supported."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| /** | |||||
| * After the conversion to two-dimensional matrix, the memory arrangement is small z and large N. | |||||
| * @src_shape: N*H*W | |||||
| * @dst_shape: N*W1*H1*H0*w0 | |||||
| * @return | |||||
| */ | |||||
| Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, ShapeVector &dst_shape, | |||||
| ShapeVector &hw_shape) { | |||||
| dst_shape.clear(); | |||||
| hw_shape.clear(); | |||||
| auto w0 = GetCubeSizeByDataType(data_type); | |||||
| int64_t h0 = kCubeSize; | |||||
| switch (src_shape.size()) { | |||||
| case kSingleDim: | |||||
| dst_shape.push_back(Ceil(src_shape[kNdDimIndexN], w0)); | |||||
| dst_shape.push_back(DIM_DEFAULT_VALUE); | |||||
| dst_shape.push_back(h0); | |||||
| dst_shape.push_back(w0); | |||||
| hw_shape.push_back(DIM_DEFAULT_VALUE); | |||||
| hw_shape.push_back(DIM_DEFAULT_VALUE); | |||||
| hw_shape.push_back(src_shape[kNdDimIndexN]); | |||||
| if (!IsShapeValid(dst_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][DSTShape]Failed, dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| default: | |||||
| auto size = src_shape.size(); | |||||
| int64_t times = 1; | |||||
| for (size_t i = 0; i != size - kDimDValueBNdFNz; i++) { | |||||
| dst_shape.push_back(src_shape[i]); | |||||
| times *= src_shape[i]; | |||||
| } | |||||
| dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsW], w0)); | |||||
| dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsWH], h0)); | |||||
| dst_shape.push_back(h0); | |||||
| dst_shape.push_back(w0); | |||||
| hw_shape.push_back(times); | |||||
| hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]); | |||||
| hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]); | |||||
| if (!IsShapeValid(dst_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][DSTShape]Failed, dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
| ShapeVector expect_src_shape; | |||||
| auto ret = TransShapeToFracNz(args.dst_shape, args.src_data_type, expect_src_shape, hw_shape); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Transfer][ShapeToFracNz]Failed, shape from %s to %s, shape %s to %s, " | |||||
| "data type %s, error_code:%u", TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ret); | |||||
| return ret; | |||||
| } | |||||
| if (!IsTransShapeSrcCorrect(args, expect_src_shape)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
// Copies ND data (viewed as times x H x W) into a newly allocated FRACTAL_NZ buffer
// (times x W1 x H1 x H0 x W0).
// Each source row is moved in two steps: whole w0-wide chunks with one memcpy_s per chunk, then the
// remaining tail elements one by one. Destination bytes that are never written stay zero because the
// buffer is value-initialised.
// @param args     source buffer, shapes and data type
// @param result   receives the buffer and its byte length (only the length when the destination is empty)
// @param hw_shape {times, H, W} as produced by TransShapeToFracNz
// @return SUCCESS, ACL_ERROR_GE_MEMORY_ALLOCATION or ACL_ERROR_GE_MEMORY_OPERATE_FAILED
| Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | |||||
// An empty destination is reported as success without allocating; result.data is left untouched.
| if (dst_size == 0) { | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
// The trailing () value-initialises the array, so padding positions are zero without an explicit memset.
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allocate][DSTMemory]Failed to allocate memory " | |||||
| "for dst buf %ld when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to allocate memory for dst buf %ld " | |||||
| "trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| // src&dst_shape can be written as times*H*W & times*W1*H1*H0*W0, respectively. dst_shape_size >= kDimNum4D | |||||
| auto times = hw_shape.at(kNdDimIndexN); | |||||
| auto h = hw_shape.at(kNdDimIndexH); | |||||
| auto w = hw_shape.at(kNdDimIndexW); | |||||
| auto hw = h * w; | |||||
// The four trailing destination dimensions are W1, H1, H0, W0.
| auto shape_size = args.dst_shape.size(); | |||||
| auto w1 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0H1W1]; | |||||
| auto h1 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0H1]; | |||||
| auto h0 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0]; | |||||
| auto w0 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0]; | |||||
| auto h1h0 = h1 * h0; | |||||
| auto h1h0w0 = h1h0 * w0; | |||||
| auto w1h1h0w0 = w1 * h1h0w0; | |||||
// Number of complete w0-wide chunks in one source row.
| auto num_w1 = w / w0; | |||||
| for (int64_t times_idx = 0; times_idx < times; times_idx++) { | |||||
| auto times_head = times_idx * w1h1h0w0; | |||||
| auto src_times_head = times_idx * hw; | |||||
| for (int64_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) { | |||||
| auto h1h0_head = times_head + h1h0_idx * w0; | |||||
| auto src_h_head = src_times_head + h1h0_idx * w; | |||||
// Complete chunks: w0 contiguous source elements go to chunk w1_idx of the current row.
| for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) { | |||||
| auto dst_offset = (h1h0_head + w1_idx * h1h0w0) * size; | |||||
| auto src_offset = (src_h_head + w1_idx * w0) * size; | |||||
// Bytes left in dst from this offset, capped at SECUREC_MEM_MAX_LEN; used as the destMax argument.
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) ? | |||||
| dst_size - dst_offset : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size * w0)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED,"[Operate][DSTMemory]Failed at offset %ld, " | |||||
| "error-code %d", dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to operate dst memory at offset %ld, error-code %d", | |||||
| dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
// Tail: the w % w0 elements after the last complete chunk are copied one element at a time.
| auto w1_head = num_w1 * w0; | |||||
| for (int64_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) { | |||||
| auto src_w_idx = w1_head + w0_idx; | |||||
| auto dst_offset = (h1h0_head + num_w1 * h1h0w0 + w0_idx) * size; | |||||
| auto src_offset = (src_h_head + src_w_idx) * size; | |||||
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) ? | |||||
| dst_size - dst_offset : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED,"[Operate][DSTMemory]Failed at offset %ld, " | |||||
| "error-code %d", dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to operate dst memory at offset %ld, error-code %d", | |||||
| dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
// Copies FRACTAL_NZ data (times x W1 x H1 x H0 x W0) into a newly allocated ND buffer
// (viewed as times x H x W).
// Mirrors TransFormatFromNdToFracNz: each destination row is filled from whole w0-wide chunks with
// one memcpy_s per chunk, followed by the remaining tail elements one by one.
// @param args         source buffer, shapes and data type
// @param result       receives the buffer and its byte length (only the length when the destination is empty)
// @param dst_hw_shape {times, H, W} describing the ND side
// @return SUCCESS, ACL_ERROR_GE_MEMORY_ALLOCATION or ACL_ERROR_GE_MEMORY_OPERATE_FAILED
| Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | |||||
// An empty destination is reported as success without allocating; result.data is left untouched.
| if (dst_size == 0) { | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allocate][DSTMemory]Failed to trans format " | |||||
| "from %s to %s, memory for dst buf %ld", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to trans format from %s to %s and allocate memory " | |||||
| "for dst buf %ld", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| auto times = dst_hw_shape.at(kNdDimIndexN); | |||||
| auto h = dst_hw_shape.at(kNdDimIndexH); | |||||
| auto w = dst_hw_shape.at(kNdDimIndexW); | |||||
| auto hw = h * w; | |||||
// The four trailing source dimensions are W1, H1, H0, W0.
| auto shape_size = args.src_shape.size(); | |||||
| auto w1 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0H1W1]; | |||||
| auto h1 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0H1]; | |||||
| auto h0 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0]; | |||||
| auto w0 = args.src_shape[shape_size - kFNzDimCountBackwardsW0]; | |||||
| auto h1h0 = h1 * h0; | |||||
| auto h1h0w0 = h1h0 * w0; | |||||
| auto w1h1h0w0 = w1 * h1h0w0; | |||||
// Number of complete w0-wide chunks in one destination row.
| auto num_w1 = w / w0; | |||||
| errno_t ret; | |||||
| for (int64_t times_idx = 0; times_idx < times; times_idx++) { | |||||
| auto times_head = times_idx * w1h1h0w0; | |||||
| auto dst_times_head = times_idx * hw; | |||||
| for (int64_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) { | |||||
| auto h1h0_head = times_head + h1h0_idx * w0; | |||||
| auto dst_h_head = dst_times_head + h1h0_idx * w; | |||||
// Complete chunks: chunk w1_idx of the current row supplies w0 contiguous destination elements.
| for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) { | |||||
| auto src_offset = (h1h0_head + w1_idx * h1h0w0) * size; | |||||
| auto dst_offset = (dst_h_head + w1_idx * w0) * size; | |||||
// Bytes left in dst from this offset, capped at SECUREC_MEM_MAX_LEN; used as the destMax argument.
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) ? | |||||
| dst_size - dst_offset : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size * w0)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Operate][DSTMemory]Failed at offset %ld, " | |||||
| "error-code %d", | |||||
| dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to operate dst memory at offset %ld, error-code %d", | |||||
| dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
// Tail: the w % w0 elements after the last complete chunk are copied one element at a time.
| auto w1_head = num_w1 * w0; | |||||
| for (int64_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) { | |||||
| auto dst_w_idx = w1_head + w0_idx; | |||||
| auto src_offset = (h1h0_head + num_w1 * h1h0w0 + w0_idx) * size; | |||||
| auto dst_offset = (dst_h_head + dst_w_idx) * size; | |||||
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) ? | |||||
| dst_size - dst_offset : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Operate][DSTMemory]Failed at offset %ld, " | |||||
| "error-code %d", | |||||
| dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to operate dst memory at offset %ld, error-code %d", | |||||
| dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| if (!IsDataTypeSupport(args.src_data_type)) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, | |||||
| "[Check][Datatype]Failed, trans format from %s to %s, src shape %s, dst shape %s, " | |||||
| "data type %s is not supported", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Check datatype failed, trans format from %s to %s, src shape %s, " | |||||
| "dst shape %s, data type %s is not supported", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (!CheckShape(args.src_format, args.src_shape) || !IsShapeValid(args.dst_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
| "[Check][Shape]Failed, trans format from %s to %s, " | |||||
| "src shape %s, dst shape %s, data type %s is not supported", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Check shape failed, trans format from %s to %s, " | |||||
| "src shape %s, dst shape %s, data type %s is not supported", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| ShapeVector expect_shape; | |||||
| ShapeVector hw_shape; | |||||
| auto ret = TransShapeToFracNz(args.src_shape, args.src_data_type, expect_shape, hw_shape); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return TransFormatFromNdToFracNz(args, result, hw_shape); | |||||
| } | |||||
| Status FormatTransferFractalNz::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type, | |||||
| Format dst_format, ShapeVector &dst_shape) { | |||||
| if (!IsDataTypeSupport(data_type)) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, | |||||
| "[Check][Datatype]Failed, trans format from %s to %s, src shape %s, " | |||||
| "data type %s is not supported", | |||||
| TypeUtils::FormatToSerialString(src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(dst_format).c_str(), | |||||
| ShapeToString(src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Check datatype failed, trans format from %s to %s, src shape %s, " | |||||
| "data type %s is not supported", | |||||
| TypeUtils::FormatToSerialString(src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(dst_format).c_str(), | |||||
| ShapeToString(src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (!CheckShape(src_format, src_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
| "[Check][Shape]Failed, trans format from %s to %s, src shape %s, " | |||||
| "data type %s is not supported", | |||||
| TypeUtils::FormatToSerialString(src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(dst_format).c_str(), | |||||
| ShapeToString(src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Check shape failed, trans format from %s to %s, src shape %s, " | |||||
| "data type %s is not supported", | |||||
| TypeUtils::FormatToSerialString(src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(dst_format).c_str(), | |||||
| ShapeToString(src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| ShapeVector hw_shape; | |||||
| return TransShapeToFracNz(src_shape, data_type, dst_shape, hw_shape); | |||||
| } | |||||
| Status FormatTransferFractalNzND::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| if (!IsDataTypeSupport(args.src_data_type)) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, | |||||
| "[Check][Datatype]Failed, trans format from %s to %s, src shape %s, dst shape %s, " | |||||
| "data type %s is not supported", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Check datatype failed, trans format from %s to %s, src shape %s, " | |||||
| "dst shape %s, data type %s is not supported", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (!IsShapeValid(args.src_shape) || !CheckShape(args.dst_format, args.dst_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
| "[Check][Shape]Failed, trans format from %s to %s, src shape %s, dst shape %s, " | |||||
| "data type %s is not supported", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Check shape failed, trans format from %s to %s, src shape %s, " | |||||
| "dst shape %s, data type %s is not supported", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| ShapeVector hw_shape; | |||||
| Status ret = CheckShapeRelation(args, hw_shape); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| return TransFormatFromFracNzToNd(args, result, hw_shape); | |||||
| } | |||||
| Status FormatTransferFractalNzND::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type, | |||||
| Format dst_format, ShapeVector &dst_shape) { | |||||
| GELOGD("The shape derivation from %s to %s is not unique. Trans shape is not supported", | |||||
| TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
// Registrations for the two transfer classes implemented in this file:
// FormatTransferFractalNz for ND / NCHW / NHWC -> FRACTAL_NZ, and FormatTransferFractalNzND for the
// opposite directions.
// NOTE(review): the macro is defined elsewhere; the argument order is assumed to be
// (class, source format, destination format) - confirm against register_format_transfer.h.
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalNz, FORMAT_ND, FORMAT_FRACTAL_NZ) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalNz, FORMAT_NCHW, FORMAT_FRACTAL_NZ) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalNz, FORMAT_NHWC, FORMAT_FRACTAL_NZ) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalNzND, FORMAT_FRACTAL_NZ, FORMAT_ND) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalNzND, FORMAT_FRACTAL_NZ, FORMAT_NCHW) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalNzND, FORMAT_FRACTAL_NZ, FORMAT_NHWC) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,44 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_NZ_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_NZ_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| // transfer from nd to nz | |||||
| class FormatTransferFractalNz : public FormatTransfer { | |||||
| public: | |||||
| Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| // transfer nz to nd | |||||
| class FormatTransferFractalNzND : public FormatTransfer { | |||||
| public: | |||||
| Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_NZ_H_ | |||||
| @@ -1,572 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_fractal_z.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/debug/log.h" | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
// Extent of the "d" axis in TransFormatHwcnToFzWithGroups (loop bound, size and stride factor); fixed to 1.
| constexpr int64_t kDim = 1; | |||||
/**
 * Greatest common divisor of x and y, computed with the Euclidean algorithm.
 *
 * @param x first operand
 * @param y second operand; when it is 0 the function returns x instead of dividing by zero
 * @return the common divisor found by repeated remainder steps
 */
static int64_t Measure(int64_t x, int64_t y) {
  if (y == 0) {
    // x % 0 is undefined behaviour; gcd(x, 0) is x by convention.
    return x;
  }
  int64_t z = y;
  while (x % y != 0) {
    z = x % y;
    x = y;
    y = z;
  }
  return z;
}

/**
 * Least common multiple of a and b.
 *
 * @return -1 when b is 0, otherwise a * b / Measure(a, b)
 */
static int64_t Lcm(int64_t a, int64_t b) {
  if (b == 0) {
    return -1;
  }
  // Divide first: Measure(a, b) always divides a exactly, so the value is unchanged, but the
  // intermediate product a * b (which may overflow int64_t even when the result fits) is never formed.
  return (a / Measure(a, b)) * b;
}
| Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; } | |||||
| /** | |||||
| * FZ represents the weight of convolution,. | |||||
| * After the conversion to two-dimensional matrix, the memory arrangement is small n and large Z. | |||||
| * If 4D(eg.NCHW) is used to represent convolution kernel, N is width, HWC is height. | |||||
| * | |||||
| * frac_z axises: (C1*H*W, No, Ni, C0), which Ni = 16, C0 = 16/32, No = Ceil(N/Ni), C1 = Ceil(C/C0) | |||||
| * @return | |||||
| */ | |||||
| Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape) { | |||||
| auto c0 = GetCubeSizeByDataType(data_type); | |||||
| if (c0 < 0) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| auto c1 = Ceil(c, c0); | |||||
| auto no = Ceil(n, static_cast<int64_t>(kNiSize)); | |||||
| dst_shape.clear(); | |||||
| dst_shape.push_back(h * w * c1); | |||||
| dst_shape.push_back(no); | |||||
| dst_shape.push_back(kNiSize); | |||||
| dst_shape.push_back(c0); | |||||
| if (!IsShapeValid(dst_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Failed, dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape, | |||||
| int64_t groups) { | |||||
| auto c0 = GetCubeSizeByDataType(data_type); | |||||
| if (c0 < 0) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| int64_t cin_ori = c; | |||||
| int64_t cout_ori = n / groups; | |||||
| int64_t cube_k = GetCubeSizeByDataType(data_type); | |||||
| int64_t e_mult = std::min( | |||||
| Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, static_cast<int64_t>(kCubeSize)) / (cout_ori)), | |||||
| groups); | |||||
| int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k; | |||||
| int64_t c1_dim = cin_opt / cube_k; | |||||
| int64_t g_dim = Ceil(groups, e_mult); | |||||
| auto n1 = Ceil(cout_ori * e_mult, static_cast<int64_t>(kCubeSize)); | |||||
| dst_shape.clear(); | |||||
| dst_shape.push_back(g_dim * c1_dim * h * w); | |||||
| dst_shape.push_back(n1); | |||||
| dst_shape.push_back(16); | |||||
| dst_shape.push_back(cube_k); | |||||
| if (!IsShapeValid(dst_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Failed, dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransShapeNchwToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | |||||
| if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| auto n = src_shape.at(kNchwN); | |||||
| auto c = src_shape.at(kNchwC); | |||||
| auto h = src_shape.at(kNchwH); | |||||
| auto w = src_shape.at(kNchwW); | |||||
| return TransShapeToFz(n, c, h, w, data_type, dst_shape); | |||||
| } | |||||
| Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | |||||
| if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| auto h = src_shape.at(kHwcnH); | |||||
| auto w = src_shape.at(kHwcnW); | |||||
| auto c = src_shape.at(kHwcnC); | |||||
| auto n = src_shape.at(kHwcnN); | |||||
| return TransShapeToFz(n, c, h, w, data_type, dst_shape); | |||||
| } | |||||
| Status TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape | |||||
| , int64_t groups){ | |||||
| if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| auto h = src_shape.at(kHwcnH); | |||||
| auto w = src_shape.at(kHwcnW); | |||||
| auto c = src_shape.at(kHwcnC); | |||||
| auto n = src_shape.at(kHwcnN); | |||||
| return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, groups); | |||||
| } | |||||
| Status TransShapeNhwcToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) { | |||||
| if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| auto n = src_shape.at(kNhwcN); | |||||
| auto h = src_shape.at(kNhwcH); | |||||
| auto w = src_shape.at(kNhwcW); | |||||
| auto c = src_shape.at(kNhwcC); | |||||
| return TransShapeToFz(n, c, h, w, data_type, dst_shape); | |||||
| } | |||||
/**
 * Rearranges NCHW data into FRACTAL_Z layout, one element at a time.
 *
 * The output consists of vf_cnt (= C1 * H * W) by hf_cnt (= Ceil(N, kNiSize)) fractal blocks of
 * c0 * kNiSize elements each. A destination element whose source N or C index lies outside the
 * source shape (or whose source offset lies past the source element count) is written as zero;
 * every other element is copied byte-wise from args.data.
 *
 * args.src_shape is read with at() but not otherwise validated here.
 *
 * @param args   source buffer, shape (indexed as NCHW) and data type
 * @param result receives the newly allocated buffer and its length in bytes
 * @return SUCCESS, ACL_ERROR_GE_MEMORY_ALLOCATION, ACL_ERROR_GE_PARAM_INVALID or
 *         ACL_ERROR_GE_MEMORY_OPERATE_FAILED
 */
| Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||||
| int64_t n = args.src_shape.at(kNchwN); | |||||
| int64_t c = args.src_shape.at(kNchwC); | |||||
| int64_t h = args.src_shape.at(kNchwH); | |||||
| int64_t w = args.src_shape.at(kNchwW); | |||||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||||
| int64_t c1 = Ceil(c, c0); | |||||
| int64_t hw = h * w; | |||||
| int64_t chw = c * hw; | |||||
| int64_t nchw = n * chw; | |||||
| int64_t hwc0 = hw * c0; | |||||
| // horizontal fractal matrix count (N) | |||||
| int64_t hf_cnt = Ceil(n, static_cast<int64_t>(kNiSize)); | |||||
| // vertical fractal matrix count (C1HWC0) | |||||
| int64_t vf_cnt = c1 * hw; | |||||
| // elements count in one fractal | |||||
| int64_t fractal_ele_cnt = c0 * kNiSize; | |||||
| int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| int64_t dst_size = total_ele_cnt * size; | |||||
// NOTE(review): GE_CHK_BOOL_EXEC_NOLOG is defined outside this file; presumably it runs its second
// argument when the condition is false, so an empty output returns SUCCESS with only result.length
// set - confirm against the macro definition.
| GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;); | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
// NOTE(review): presumably the statements passed below run when dst == nullptr (allocation failed)
// - confirm against the macro definition.
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| dst == nullptr, | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allocate][DSTMemory]Failed to allcoate memory " | |||||
| "for dst buf %ld when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to allcoate memory for dst buf %ld " | |||||
| "when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION;); | |||||
| for (int64_t vfi = 0; vfi < vf_cnt; vfi++) { | |||||
| // vertical fractal matrix base index | |||||
| auto vf_base_i = vfi * hf_cnt; | |||||
| for (int64_t hfi = 0; hfi < hf_cnt; hfi++) { | |||||
| // global fractal matrix index | |||||
| auto gfi = vf_base_i + hfi; | |||||
| auto src_n_offset = hfi * chw * kNiSize; | |||||
| auto src_f_offset = src_n_offset + vfi % hw + vfi / hw * hwc0; | |||||
| for (int64_t row = 0; row < c0; row++) { | |||||
| auto src_ci = vfi / hw * c0 + row; | |||||
| auto src_row_offset = src_f_offset + row * hw; | |||||
| for (int col = 0; col < kNiSize; col++) { | |||||
| auto src_ni = hfi * kNiSize + col; | |||||
| auto src_offset = src_row_offset + chw * col; | |||||
| // pad 0 | |||||
| // 1. src_ni greater than n | |||||
| // 2. src_ci greater than c | |||||
| // 3. source address greater than original array size | |||||
| auto need_pad_zero = src_ni >= n || src_offset >= nchw || src_ci >= c; | |||||
// Inside one fractal the destination index is col * c0 + row.
| auto idx = gfi * fractal_ele_cnt + col * c0 + row; | |||||
| auto offset = idx * size; | |||||
// Bytes left in dst from this offset, capped at SECUREC_MEM_MAX_LEN; used as the destination
// capacity for memset_s and as the bound checked before the manual copy.
| auto protected_size = dst_size - offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? dst_size - offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| errno_t ret = EOK; | |||||
| if (need_pad_zero) { | |||||
| ret = memset_s(dst.get() + offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | |||||
| } else { | |||||
| if (protected_size < size) { | |||||
| std::string error = "Failed to operate the dst memory, protected_size is " + | |||||
| FmtToStr(protected_size) + " and size is " + FmtToStr(size); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | |||||
// Byte-wise copy of one element; ret is not assigned on this path and stays EOK.
| char *dst_data = reinterpret_cast<char *>(dst.get() + offset); | |||||
| const char *src_data = reinterpret_cast<const char *>(args.data + src_offset * size); | |||||
| for (int64_t index = 0; index < size; index++) { | |||||
| *dst_data++ = *src_data++; | |||||
| } | |||||
| } | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED,"[Operate][DSTMemory]Failed at offset %ld, " | |||||
| "error-code %d pad mode %d", | |||||
| offset, ret, need_pad_zero); | |||||
| REPORT_CALL_ERROR("E19999","Failed to operate dst memory at offset %ld, " | |||||
| "error-code %d pad mode %d", | |||||
| offset, ret, need_pad_zero); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, int64_t groups){ | |||||
| int64_t h_dim = args.src_shape[kHwcnH]; | |||||
| int64_t w_dim = args.src_shape[kHwcnW]; | |||||
| int64_t c_dim = args.src_shape[kHwcnC]; | |||||
| int64_t n_dim = args.src_shape[kHwcnN]; | |||||
| int64_t cin_ori = c_dim; | |||||
| int64_t cout_ori = n_dim / groups; | |||||
| if (cin_ori == 0 || cout_ori == 0) { | |||||
| GELOGE(GRAPH_FAILED, "[Check][Param]Failed, cin_ori, cout_ori must not be equal 0, " | |||||
| "and current cin_ori, cout_ori, groups are %ld %ld %ld", cin_ori, cout_ori, groups); | |||||
| REPORT_CALL_ERROR("E19999", "Check graph param failed, cin_ori, cout_ori must not be equal 0," | |||||
| "and current cin_ori, cout_ori, groups are %ld %ld %ld", | |||||
| cin_ori, cout_ori, groups); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| const int64_t cube_k = GetCubeSizeByDataType(args.src_data_type); | |||||
| int64_t e_mult = std::min( | |||||
| Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, static_cast<int64_t>(kCubeSize)) / (cout_ori)), | |||||
| groups); | |||||
| int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k; | |||||
| int64_t cout_opt = Ceil(e_mult * cout_ori, static_cast<int64_t>(kCubeSize)) * static_cast<int64_t>(kCubeSize); | |||||
| int64_t c1_dim = cin_opt / cube_k; | |||||
| int64_t g_dim = Ceil(groups, e_mult); | |||||
| int64_t dim_cin = cin_opt / cube_k; | |||||
| int64_t data_size = GetSizeByDataType(args.src_data_type); | |||||
| int64_t size_output_data = g_dim * kDim * dim_cin * h_dim * w_dim * cout_opt * cube_k * data_size; | |||||
| if (size_output_data == 0) { | |||||
| result.length = static_cast<size_t>(size_output_data); | |||||
| return SUCCESS; | |||||
| } | |||||
| errno_t ret = EOK; | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[size_output_data], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allocate][DSTMemory]Failed to allcoate memory " | |||||
| "for dst buf %ld when trans format from %s to %s", | |||||
| size_output_data, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to allcoate memory for dst buf %ld " | |||||
| "when trans format from %s to %s", | |||||
| size_output_data, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| ret = memset_s(dst.get(), static_cast<size_t>(size_output_data), 0, static_cast<size_t>(size_output_data)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Operate][DSTMemory]Failed, ret is %d", ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to operate dst memory, ret is %d", ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| for (int64_t g = 0; g < groups; g++) { | |||||
| for (int64_t d = 0; d < kDim; d++) { | |||||
| for (int64_t c = 0; c < c_dim; c++) { | |||||
| for (int64_t h = 0; h < h_dim; h++) { | |||||
| for (int64_t w = 0; w < w_dim; w++) { | |||||
| for (int64_t n = 0; n < cout_ori; n++) { | |||||
| int64_t e_val = g % e_mult; | |||||
| int64_t dst_ci = e_val * cin_ori + c; | |||||
| int64_t dst_co = e_val * cout_ori + n; | |||||
| int64_t src_co = g * cout_ori + n; | |||||
| int64_t tempory = dst_ci % cube_k; | |||||
| int64_t srx_inx = 0; | |||||
| int64_t dst_inx = (g / e_mult) * kDim * c1_dim * h_dim * w_dim * cout_opt * cube_k + | |||||
| d * c1_dim * h_dim * w_dim * cout_opt * cube_k + | |||||
| (dst_ci / cube_k) * h_dim * w_dim * cout_opt * cube_k + | |||||
| h * w_dim * cout_opt * cube_k + w * cout_opt * cube_k + | |||||
| dst_co * cube_k + tempory; | |||||
| srx_inx = d * h_dim * w_dim * c_dim * n_dim + h * w_dim * c_dim * n_dim + | |||||
| w * c_dim * n_dim + c * n_dim + src_co; | |||||
| char *dst_data = reinterpret_cast<char *>(dst.get() + dst_inx * data_size); | |||||
| const char *src_data = reinterpret_cast<const char *>(args.data + srx_inx * data_size); | |||||
| for (int64_t index = 0; index < data_size; index++) { | |||||
| *dst_data++ = *src_data++; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(size_output_data); | |||||
| return SUCCESS; | |||||
| } | |||||
/**
 * Rearranges HWCN data into FRACTAL_Z layout, one element at a time.
 *
 * The destination is walked in the order (c1, h, w, n1n0, c0), where n1n0 is N rounded up to a
 * multiple of kNiSize and c1 = Ceil(C, c0). Elements whose channel index (c1i * c0 + c0i) or
 * n index lies outside the source shape are written as zero; all others are copied byte-wise
 * from args.data.
 *
 * The output byte size is taken from args.dst_shape (product of its dims times the element size),
 * not recomputed from the source shape.
 *
 * @param args   source buffer, shapes (source indexed as HWCN) and data type
 * @param result receives the newly allocated buffer and its length in bytes
 * @return SUCCESS, ACL_ERROR_GE_MEMORY_ALLOCATION, ACL_ERROR_GE_PARAM_INVALID or
 *         ACL_ERROR_GE_MEMORY_OPERATE_FAILED
 */
| Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||||
| int64_t h = args.src_shape[kHwcnH]; | |||||
| int64_t w = args.src_shape[kHwcnW]; | |||||
| int64_t c = args.src_shape[kHwcnC]; | |||||
| int64_t n = args.src_shape[kHwcnN]; | |||||
| int64_t n1n0 = Ceil(n, static_cast<int64_t>(kNiSize)) * kNiSize; | |||||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||||
| int64_t c1 = Ceil(c, c0); | |||||
// Source strides (cn, wcn) and destination strides (n1n0c0, wn1n0c0, hwn1n0c0), in elements.
| auto cn = c * n; | |||||
| auto wcn = w * cn; | |||||
| auto n1n0c0 = n1n0 * c0; | |||||
| auto wn1n0c0 = w * n1n0c0; | |||||
| auto hwn1n0c0 = h * wn1n0c0; | |||||
| int64_t data_size = GetSizeByDataType(args.src_data_type); | |||||
| int64_t dst_size = 1; | |||||
| for (auto dim : args.dst_shape) { | |||||
| dst_size *= dim; | |||||
| } | |||||
| dst_size *= data_size; | |||||
// NOTE(review): GE_CHK_BOOL_EXEC_NOLOG is defined outside this file; presumably it runs its second
// argument when the condition is false, so an empty output returns SUCCESS with only result.length
// set - confirm against the macro definition.
| GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;); | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
// NOTE(review): presumably the statements passed below run when dst == nullptr (allocation failed)
// - confirm against the macro definition.
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| dst == nullptr, | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allocate][DSTMemory]Failed to allcoate memory " | |||||
| "for dst buf %ld when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to allcoate memory for dst buf %ld " | |||||
| "when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION;); | |||||
| for (int64_t c1i = 0; c1i < c1; c1i++) { | |||||
| for (int64_t hi = 0; hi < h; hi++) { | |||||
| for (int64_t wi = 0; wi < w; wi++) { | |||||
| for (int64_t n1n0i = 0; n1n0i < n1n0; n1n0i++) { | |||||
| for (int64_t c0i = 0; c0i < c0; c0i++) { | |||||
| int64_t dst_idx = c1i * hwn1n0c0 + hi * wn1n0c0 + wi * n1n0c0 + n1n0i * c0 + c0i; | |||||
| int64_t dst_offset = dst_idx * data_size; | |||||
// Bytes left in dst from this offset, capped at SECUREC_MEM_MAX_LEN; used as the destination
// capacity for memset_s and as the bound checked before the manual copy.
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
// Zero-fill positions that exist only because C and N were rounded up.
| auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); | |||||
| errno_t ret = EOK; | |||||
| if (pad_zero) { | |||||
| ret = memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, | |||||
| static_cast<size_t>(data_size)); | |||||
| } else { | |||||
| if (protected_size < data_size) { | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID,"[Operate][DSTMemory]Failed, protected_size " | |||||
| "is %ld and size is %ld", | |||||
| protected_size, data_size); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | |||||
// Source element index in HWCN order; byte-wise copy of one element. ret stays EOK on this path.
| int64_t src_idx = hi * wcn + wi * cn + (c1i * c0 + c0i) * n + n1n0i; | |||||
| char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset); | |||||
| const char *src_data = reinterpret_cast<const char *>(args.data + src_idx * data_size); | |||||
| for (int64_t index = 0; index < data_size; index++) { | |||||
| *dst_data++ = *src_data++; | |||||
| } | |||||
| } | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Operate][DSTMemory]Failed, " | |||||
| "at offset %ld, error-code %d, pad mode %d", dst_offset, ret, pad_zero); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to operate dst memoery at offset %ld, " | |||||
| "error-code %d, pad mode %d", | |||||
| dst_offset, ret, pad_zero); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
/**
 * Rearranges NHWC data into FRACTAL_Z layout, one element at a time.
 *
 * The destination is walked in the order (c1, h, w, n1n0, c0), where n1n0 is N rounded up to a
 * multiple of kNiSize and c1 = Ceil(C, c0). Elements whose channel index (c1i * c0 + c0i) or
 * n index lies outside the source shape are written as zero; all others are copied byte-wise
 * from args.data.
 *
 * The output byte size is taken from args.dst_shape (product of its dims times the element size),
 * not recomputed from the source shape.
 *
 * @param args   source buffer, shapes (source indexed as NHWC) and data type
 * @param result receives the newly allocated buffer and its length in bytes
 * @return SUCCESS, ACL_ERROR_GE_MEMORY_ALLOCATION, ACL_ERROR_GE_PARAM_INVALID or
 *         ACL_ERROR_GE_MEMORY_OPERATE_FAILED
 */
| Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||||
| int64_t n = args.src_shape[kNhwcN]; | |||||
| int64_t h = args.src_shape[kNhwcH]; | |||||
| int64_t w = args.src_shape[kNhwcW]; | |||||
| int64_t c = args.src_shape[kNhwcC]; | |||||
// Source strides (wc, hwc) and, further down, destination strides (n1n0c0, wn1n0c0, hwn1n0c0), in elements.
| auto wc = w * c; | |||||
| auto hwc = h * w * c; | |||||
| int64_t n1n0 = Ceil(n, static_cast<int64_t>(kNiSize)) * kNiSize; | |||||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||||
| int64_t c1 = Ceil(c, c0); | |||||
| auto n1n0c0 = n1n0 * c0; | |||||
| auto wn1n0c0 = w * n1n0c0; | |||||
| auto hwn1n0c0 = h * wn1n0c0; | |||||
| int64_t data_size = GetSizeByDataType(args.src_data_type); | |||||
| int64_t dst_size = 1; | |||||
| for (auto dim : args.dst_shape) { | |||||
| dst_size *= dim; | |||||
| } | |||||
| dst_size *= data_size; | |||||
// NOTE(review): GE_CHK_BOOL_EXEC_NOLOG is defined outside this file; presumably it runs its second
// argument when the condition is false, so an empty output returns SUCCESS with only result.length
// set - confirm against the macro definition.
| GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;); | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
// NOTE(review): presumably the statements passed below run when dst == nullptr (allocation failed)
// - confirm against the macro definition.
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| dst == nullptr, | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allocate][DSTMemory]Failed to allcoate memory " | |||||
| "for dst buf %ld when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to allcoate memory for dst buf %ld " | |||||
| "when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION;); | |||||
| for (int64_t c1i = 0; c1i < c1; c1i++) { | |||||
| for (int64_t hi = 0; hi < h; hi++) { | |||||
| for (int64_t wi = 0; wi < w; wi++) { | |||||
| for (int64_t n1n0i = 0; n1n0i < n1n0; n1n0i++) { | |||||
| for (int64_t c0i = 0; c0i < c0; c0i++) { | |||||
| int64_t dst_idx = c1i * hwn1n0c0 + hi * wn1n0c0 + wi * n1n0c0 + n1n0i * c0 + c0i; | |||||
| int64_t dst_offset = dst_idx * data_size; | |||||
// Bytes left in dst from this offset, capped at SECUREC_MEM_MAX_LEN; used as the destination
// capacity for memset_s and as the bound checked before the manual copy.
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
// Zero-fill positions that exist only because C and N were rounded up.
| auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); | |||||
| errno_t ret = EOK; | |||||
| if (pad_zero) { | |||||
| ret = memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, | |||||
| static_cast<size_t>(data_size)); | |||||
| } else { | |||||
| if (protected_size < data_size) { | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Operate][DSTMemory]Failed, protected_size " | |||||
| "is %ld and size is %ld", | |||||
| protected_size, data_size); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | |||||
// Source element index in NHWC order; byte-wise copy of one element. ret stays EOK on this path.
| int64_t src_idx = n1n0i * hwc + hi * wc + wi * c + (c1i * c0 + c0i); | |||||
| char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset); | |||||
| const char *src_data = reinterpret_cast<const char *>(args.data + src_idx * data_size); | |||||
| for (int64_t index = 0; index < data_size; index++) { | |||||
| *dst_data++ = *src_data++; | |||||
| } | |||||
| } | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Operate][DSTMemory]Failed at offset %ld," | |||||
| " error-code %d, pad mode %d", dst_offset, ret, pad_zero); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to operate dst memory at offset %ld, " | |||||
| "error-code %d, pad mode %d", | |||||
| dst_offset, ret, pad_zero); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str()); | |||||
| std::vector<int64_t> expect_shape; | |||||
| auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { | |||||
| return TransFormatNhwcToFz(args, result); | |||||
| } | |||||
| if ((args.src_format == FORMAT_HWCN) && (GetPrimaryFormat(args.dst_format) == FORMAT_FRACTAL_Z)) { | |||||
| if (GetSubFormat(args.dst_format) > 1) { | |||||
| return TransFormatHwcnToFzWithGroups(args, result, GetSubFormat(args.dst_format)); | |||||
| } | |||||
| return TransFormatHwcnToFz(args, result); | |||||
| } | |||||
| if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { | |||||
| return TransFormatFromNchwToFz(args, result); | |||||
| } | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | |||||
| Format dst_format, std::vector<int64_t> &dst_shape) { | |||||
| if (CheckDataTypeSupport(data_type) != SUCCESS) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) { | |||||
| return TransShapeNhwcToFz(src_shape, data_type, dst_shape); | |||||
| } | |||||
| if ((src_format == FORMAT_HWCN) && (GetPrimaryFormat(dst_format) == FORMAT_FRACTAL_Z)) { | |||||
| if (GetSubFormat(dst_format) > 1) { | |||||
| return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, GetSubFormat(dst_format)); | |||||
| } | |||||
| return TransShapeHwcnToFz(src_shape, data_type, dst_shape); | |||||
| } | |||||
| if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) { | |||||
| return TransShapeNchwToFz(src_shape, data_type, dst_shape); | |||||
| } | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
// Registrations of FormatTransferFractalZ for NCHW / HWCN / NHWC -> FRACTAL_Z.
// NOTE(review): the macro is defined elsewhere; the argument order is assumed to be
// (class, source format, destination format) - confirm against register_format_transfer.h.
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NCHW, FORMAT_FRACTAL_Z) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_HWCN, FORMAT_FRACTAL_Z) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NHWC, FORMAT_FRACTAL_Z) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,35 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_Z_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_Z_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
// FormatTransfer implementation whose two overrides convert data and shapes into FORMAT_FRACTAL_Z.
| class FormatTransferFractalZ : public FormatTransfer { | |||||
| public: | |||||
// Transforms the data described by `args` and writes the outcome into `result`.
| Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
// Derives `dst_shape` in `dst_format` from `src_shape` in `src_format` for elements of `data_type`.
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_Z_H_ | |||||
| @@ -1,482 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_fractal_zz.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "framework/common/types.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
// Number of dimensions CheckShape requires for NCHW / NHWC shapes.
| const int kDimSize4D = 4; | |||||
// Rank at which TransShapeToFracZz takes its one-dimensional path.
| const size_t kSingleDim = 1; | |||||
// Indices into the three-element hw_shape vector {times, H, W}; kNdDimIndexN also indexes
// the only dimension of a rank-1 source shape.
| const size_t kNdDimIndexN = 0; | |||||
| const size_t kNdDimIndexH = 1; | |||||
| const size_t kNdDimIndexW = 2; | |||||
| const size_t kDimDValueBNdFZz = 2; // dim d-value between Nd and FractalZz | |||||
// Offsets counted backwards from the end of an ND shape: W is last, H is second to last.
| const size_t kNdDimCountBackwardsW = 1; | |||||
| const size_t kNdDimCountBackwardsWH = 2; | |||||
// Offsets counted backwards from the end of a FractalZz shape, whose tail is H1, W1, H0, W0.
| const size_t kFZzDimCountBackwardsW0 = 1; | |||||
| const size_t kFZzDimCountBackwardsW0H0 = 2; | |||||
| const size_t kFZzDimCountBackwardsW0H0W1 = 3; | |||||
| const size_t kFZzDimCountBackwardsW0H0W1H1 = 4; | |||||
| bool IsDataTypeSupport(DataType d_type) { return GetSizeByDataType(d_type) > 0; } | |||||
| using ShapeVector = std::vector<int64_t>; | |||||
| bool CheckShape(Format format, const ShapeVector &shape) { | |||||
| switch (format) { | |||||
| case FORMAT_ND: | |||||
| return IsShapeValid(shape); | |||||
| case FORMAT_NCHW: | |||||
| case FORMAT_NHWC: | |||||
| return CheckShapeValid(shape, kDimSize4D); | |||||
| default: | |||||
| std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | |||||
| " and FORMAT_FRACTAL_ZZ is not supported."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| /** | |||||
| * After the conversion to two-dimensional matrix, the memory arrangement is small z and large Z. | |||||
| * @src_shape: N*H*W | |||||
| * @dst_shape: N*H1*W1*H0*w0 | |||||
| * @return | |||||
| */ | |||||
| Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, ShapeVector &dst_shape, | |||||
| ShapeVector &hw_shape) { | |||||
| dst_shape.clear(); | |||||
| hw_shape.clear(); | |||||
| auto w0 = GetCubeSizeByDataType(data_type); | |||||
| auto h0 = GetCubeSizeByDataType(data_type); | |||||
| switch (src_shape.size()) { | |||||
| case kSingleDim: | |||||
| dst_shape.push_back(DIM_DEFAULT_VALUE); | |||||
| dst_shape.push_back(Ceil(src_shape[kNdDimIndexN], w0)); | |||||
| dst_shape.push_back(h0); | |||||
| dst_shape.push_back(w0); | |||||
| hw_shape.push_back(DIM_DEFAULT_VALUE); | |||||
| hw_shape.push_back(DIM_DEFAULT_VALUE); | |||||
| hw_shape.push_back(src_shape[kNdDimIndexN]); | |||||
| if (!IsShapeValid(dst_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][DSTShape]Failed, dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| default: | |||||
| auto size = src_shape.size(); | |||||
| int64_t times = 1; | |||||
| for (size_t i = 0; i != size - kDimDValueBNdFZz; i++) { | |||||
| dst_shape.push_back(src_shape[i]); | |||||
| times *= src_shape[i]; | |||||
| } | |||||
| dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsWH], h0)); | |||||
| dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsW], w0)); | |||||
| dst_shape.push_back(h0); | |||||
| dst_shape.push_back(w0); | |||||
| hw_shape.push_back(times); | |||||
| hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]); | |||||
| hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]); | |||||
| if (!IsShapeValid(dst_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][DSTShape]Failed, dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
| ShapeVector expect_src_shape; | |||||
| auto ret = TransShapeToFracZz(args.dst_shape, args.src_data_type, expect_src_shape, hw_shape); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Trans][ShapeToFracZz] Failed from %s to %s, shape %s to %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to trans shape from %s to %s, shape %s to %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ret; | |||||
| } | |||||
| if (!IsTransShapeSrcCorrect(args, expect_src_shape)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
// Copies the ND-layout buffer `args.data` into a newly allocated FRACTAL_ZZ-layout buffer.
//
// `hw_shape` is the flattened {times, H, W} view of the source shape; the tile sizes H1, W1, H0, W0
// are read from the last four entries of args.dst_shape.
// On success `result.data` owns the new buffer and `result.length` is its size in bytes.
// When the destination holds zero bytes only `result.length` is set and SUCCESS is returned.
// Returns ACL_ERROR_GE_MEMORY_ALLOCATION if the buffer cannot be allocated and
// ACL_ERROR_GE_MEMORY_OPERATE_FAILED if any memcpy_s call fails.
| Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | |||||
| if (dst_size == 0) { | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
// The array is value-initialized, so padding positions that the loops below never write stay zero.
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allocate][DSTMemory]Failed to allcoate memory " | |||||
| "for dst buf %ld when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to allcoate memory for dst buf %ld " | |||||
| "when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D | |||||
| auto times = hw_shape.at(kNdDimIndexN); | |||||
| auto h = hw_shape.at(kNdDimIndexH); | |||||
| auto w = hw_shape.at(kNdDimIndexW); | |||||
| auto hw = h * w; | |||||
| auto shape_size = args.dst_shape.size(); | |||||
| auto h1 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0W1H1]; | |||||
| auto w1 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0W1]; | |||||
| auto h0 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0]; | |||||
| auto w0 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0]; | |||||
// Element strides of the destination layout, innermost to outermost.
| auto h0w0 = h0 * w0; | |||||
| auto w1h0w0 = w1 * h0w0; | |||||
| auto h1w1h0w0 = h1 * w1h0w0; | |||||
// Number of W tiles that are completely filled by source data; the remainder is handled element-wise.
| auto num_w1 = w / w0; | |||||
| for (int64_t times_idx = 0; times_idx < times; times_idx++) { | |||||
| auto times_head = times_idx * h1w1h0w0; | |||||
| auto src_times_head = times_idx * hw; | |||||
| for (int64_t h1_idx = 0; h1_idx < h1; h1_idx++) { | |||||
| auto h1_head = times_head + h1_idx * w1h0w0; | |||||
| auto src_h1_head = h1_idx * h0; | |||||
// The second loop condition stops at the last real source row inside a partially filled H tile.
| for (int64_t h0_idx = 0; h0_idx < h0 && h0_idx + src_h1_head < h; h0_idx++) { | |||||
| auto h0_head = h1_head + h0_idx * w0; | |||||
| auto src_h_head = src_times_head + (src_h1_head + h0_idx) * w; | |||||
// Full tiles: copy one run of w0 contiguous source elements per W tile.
| for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) { | |||||
| auto src_offset = (src_h_head + w1_idx * w0) * size; | |||||
| auto dst_offset = (h0_head + w1_idx * h0w0) * size; | |||||
// Destination capacity handed to memcpy_s: the bytes left in dst, capped at SECUREC_MEM_MAX_LEN.
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size * w0)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Operate][DSTMemory]Failed at offset %ld, " | |||||
| "error-code %d", | |||||
| dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to operate dst memory at offset %ld, error-code %d", | |||||
| dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
// Tail of the row: the remaining (w % w0) elements go into the last, partially filled W tile one by one.
| auto w1_head = num_w1 * w0; | |||||
| auto w0_head = h0_head + num_w1 * h0w0; | |||||
| for (int64_t w0_idx = 0; w0_idx + w1_head < w; w0_idx++) { | |||||
| auto src_w_idx = w1_head + w0_idx; | |||||
| auto src_offset = (src_h_head + src_w_idx) * size; | |||||
| auto dst_offset = (w0_head + w0_idx) * size; | |||||
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Operate][DSTMemory]Failed at offset %ld, " | |||||
| "error-code %d", | |||||
| dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to operate dst memory at offset %ld, error-code %d", | |||||
| dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
// Copies the FRACTAL_ZZ-layout buffer `args.data` into a newly allocated ND-layout buffer.
//
// `dst_hw_shape` is the flattened {times, H, W} view of the destination shape; the tile sizes
// H1, W1, H0, W0 are read from the last four entries of args.src_shape.
// On success `result.data` owns the new buffer and `result.length` is its size in bytes.
// When the destination holds zero bytes only `result.length` is set and SUCCESS is returned.
// Returns ACL_ERROR_GE_MEMORY_ALLOCATION if the buffer cannot be allocated and
// ACL_ERROR_GE_MEMORY_OPERATE_FAILED if any memcpy_s call fails.
| Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | |||||
| if (dst_size == 0) { | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allocate][DSTMemory]Failed to allcoate memory " | |||||
| "for dst buf %ld when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to allcoate memory for dst buf %ld " | |||||
| "when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| // The dst&src_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. src_shape_size >= kDimNum4D | |||||
| auto times = dst_hw_shape.at(kNdDimIndexN); | |||||
| auto h = dst_hw_shape.at(kNdDimIndexH); | |||||
| auto w = dst_hw_shape.at(kNdDimIndexW); | |||||
| auto hw = h * w; | |||||
| auto shape_size = args.src_shape.size(); | |||||
| auto h1 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0W1H1]; | |||||
| auto w1 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0W1]; | |||||
| auto h0 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0]; | |||||
| auto w0 = args.src_shape[shape_size - kFZzDimCountBackwardsW0]; | |||||
// Element strides of the source (FRACTAL_ZZ) layout, innermost to outermost.
| auto h0w0 = h0 * w0; | |||||
| auto w1h0w0 = w1 * h0w0; | |||||
| auto h1w1h0w0 = h1 * w1h0w0; | |||||
// Number of W tiles that map to a full run of w0 destination elements; the remainder is handled element-wise.
| auto num_w1 = w / w0; | |||||
| for (int64_t times_idx = 0; times_idx < times; times_idx++) { | |||||
| auto times_head = times_idx * h1w1h0w0; | |||||
| auto dst_times_head = times_idx * hw; | |||||
| for (int64_t h1_idx = 0; h1_idx < h1; h1_idx++) { | |||||
| auto h1_head = times_head + h1_idx * w1h0w0; | |||||
| auto dst_h1_head = h1_idx * h0; | |||||
// The second loop condition skips source rows that lie beyond the H rows of the destination.
| for (int64_t h0_idx = 0; h0_idx < h0 && h0_idx + dst_h1_head < h; h0_idx++) { | |||||
| auto h0_head = h1_head + h0_idx * w0; | |||||
| auto dst_h_head = dst_times_head + (dst_h1_head + h0_idx) * w; | |||||
// Full tiles: copy one run of w0 contiguous elements per W tile into the destination row.
| for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) { | |||||
| auto src_offset = (h0_head + w1_idx * h0w0) * size; | |||||
| auto dst_offset = (dst_h_head + w1_idx * w0) * size; | |||||
// Destination capacity handed to memcpy_s: the bytes left in dst, capped at SECUREC_MEM_MAX_LEN.
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size * w0)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Operate][DSTMemory]Failed at offset %ld, " | |||||
| "error-code %d", | |||||
| dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to operate dst memory at offset %ld, error-code %d", | |||||
| dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
// Tail of the row: the remaining (w % w0) elements are read from the last W tile one by one.
| auto w1_head = num_w1 * w0; | |||||
| auto w0_head = h0_head + num_w1 * h0w0; | |||||
| for (int64_t w0_idx = 0; w0_idx + w1_head < w; w0_idx++) { | |||||
| auto src_offset = (w0_head + w0_idx) * size; | |||||
| auto dst_w_idx = w1_head + w0_idx; | |||||
| auto dst_offset = (dst_h_head + dst_w_idx) * size; | |||||
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Operate][DSTMemory]Failed at offset %ld, " | |||||
| "error-code %d", | |||||
| dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to operate dst memory at offset %ld, error-code %d", | |||||
| dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| if (!IsDataTypeSupport(args.src_data_type)) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, | |||||
| "[Check][Datatype]Failed, not support trans format from %s to %s, " | |||||
| "src shape %s, dst shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Check datatype failed, not support trans format " | |||||
| "from %s to %s, src shape %s, dst shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (!CheckShape(args.src_format, args.src_shape) || !IsShapeValid(args.dst_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
| "[Check][Shape]Failed, not support trans format from %s to %s, " | |||||
| "src shape %s, dst shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Check shape failed, not support trans format from %s to %s, " | |||||
| "src shape %s, dst shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| ShapeVector expect_shape; | |||||
| ShapeVector hw_shape; | |||||
| auto ret = TransShapeToFracZz(args.src_shape, args.src_data_type, expect_shape, hw_shape); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return TransFormatFromNdToFracZz(args, result, hw_shape); | |||||
| } | |||||
| Status FormatTransferFractalZz::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type, | |||||
| Format dst_format, ShapeVector &dst_shape) { | |||||
| if (!IsDataTypeSupport(data_type)) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, | |||||
| "[Check][Datatype]Failed, not support trans format from %s to %s, " | |||||
| "src shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(dst_format).c_str(), | |||||
| ShapeToString(src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Check datatype failed, not support trans format from %s to %s, " | |||||
| "src shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(dst_format).c_str(), | |||||
| ShapeToString(src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (!CheckShape(src_format, src_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
| "[Check][Shape]Failed, not support trans format from %s to %s, " | |||||
| "src shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(dst_format).c_str(), | |||||
| ShapeToString(src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Check shape failed, not support trans format from %s to %s, " | |||||
| "src shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(dst_format).c_str(), | |||||
| ShapeToString(src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| ShapeVector hw_shape; | |||||
| return TransShapeToFracZz(src_shape, data_type, dst_shape, hw_shape); | |||||
| } | |||||
| Status FormatTransferFractalZzND::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| if (!IsDataTypeSupport(args.src_data_type)) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, | |||||
| "[Check][Datatype]Failed, not support trans format from %s to %s, " | |||||
| "src shape %s, dst shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Check datatype Failed, not support trans format from %s to %s, " | |||||
| "src shape %s, dst shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (!IsShapeValid(args.src_shape) || !CheckShape(args.dst_format, args.dst_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Failed, not support trans format " | |||||
| "from %s to %s, src shape %s, dst shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Check shape failed, not support trans format from %s to %s, " | |||||
| "src shape %s, dst shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| GELOGD("Begin to trans format from %s to %s, src shape %s, dst shape %s, data type %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| ShapeVector hw_shape; | |||||
| Status ret = CheckShapeRelation(args, hw_shape); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| return TransFormatFromFracZzToNd(args, result, hw_shape); | |||||
| } | |||||
| Status FormatTransferFractalZzND::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type, | |||||
| Format dst_format, ShapeVector &dst_shape) { | |||||
| GELOGD("The shape derivation from %s to %s is not unique. Trans shape is not supported", | |||||
| TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
// Registrations into FORMAT_FRACTAL_ZZ: FormatTransferFractalZz for ND, NCHW and NHWC sources.
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZz, FORMAT_ND, FORMAT_FRACTAL_ZZ) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZz, FORMAT_NCHW, FORMAT_FRACTAL_ZZ) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZz, FORMAT_NHWC, FORMAT_FRACTAL_ZZ) | |||||
// Registrations out of FORMAT_FRACTAL_ZZ: FormatTransferFractalZzND for ND, NCHW and NHWC destinations.
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZzND, FORMAT_FRACTAL_ZZ, FORMAT_ND) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZzND, FORMAT_FRACTAL_ZZ, FORMAT_NCHW) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZzND, FORMAT_FRACTAL_ZZ, FORMAT_NHWC) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,44 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_ZZ_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_ZZ_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| // Transfer from nd to zz | |||||
// FormatTransfer implementation for the direction into FORMAT_FRACTAL_ZZ.
| class FormatTransferFractalZz : public FormatTransfer { | |||||
| public: | |||||
// Transforms the data described by `args` and writes the outcome into `result`.
| Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
// Derives `dst_shape` in `dst_format` from `src_shape` in `src_format` for elements of `data_type`.
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| // Transfer zz to nd | |||||
// FormatTransfer implementation for the direction out of FORMAT_FRACTAL_ZZ.
| class FormatTransferFractalZzND : public FormatTransfer { | |||||
| public: | |||||
// Transforms the data described by `args` and writes the outcome into `result`.
| Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
// Same signature as the base-class override; see the implementation file for what it supports.
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_ZZ_H_ | |||||
| @@ -1,202 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_fracz_hwcn.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
| bool CheckDataTypeSupported(const DataType &data_type) { return GetSizeByDataType(data_type) > 0; } | |||||
| Status CheckArgsForFracZToHwcn(const TransArgs &args) { | |||||
| auto src_shape = args.src_shape; | |||||
| auto dst_shape = args.dst_shape; | |||||
| if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_HWCN) { | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| if (!CheckDataTypeSupported(args.src_data_type)) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Check][DataType]Failed, " | |||||
| "shape from FORMAT_FRACTAL_Z to HWCN, invalid data type %s", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Failed to trans shape from FORMAT_FRACTAL_Z to HWCN, " | |||||
| "invalid data type %s", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (!CheckShapeValid(src_shape, kFracZDimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", | |||||
| ShapeToString(src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", | |||||
| ShapeToString(src_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| if (!CheckShapeValid(dst_shape, kHwcnDimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||||
| if (c0 < 0) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| int64_t c1 = Ceil(dst_shape.at(kHwcnC), c0); | |||||
| int64_t n0 = Ceil(dst_shape.at(kHwcnN), static_cast<int64_t>(kNiSize)); | |||||
| if (src_shape.at(kFracZHWC1) != dst_shape.at(kHwcnH) * dst_shape.at(kHwcnW) * c1 || src_shape.at(kFracZC0) != c0 || | |||||
| src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { | |||||
| std::string error = "Failed to check relationship between src shape" + | |||||
| FmtToStr(ShapeToString(src_shape)) + " and dst shape" + | |||||
| FmtToStr(ShapeToString(dst_shape)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, | |||||
| "[Allocate][DSTMemory]Failed, memory for dst buf %ld, shape %s " | |||||
| "when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld, shape %s " | |||||
| "when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| auto n0 = args.src_shape.at(kFracZN0); | |||||
| auto ni = args.src_shape.at(kFracZNi); | |||||
| auto c0 = args.src_shape.at(kFracZC0); | |||||
| auto h = args.dst_shape.at(kHwcnH); | |||||
| auto w = args.dst_shape.at(kHwcnW); | |||||
| auto c = args.dst_shape.at(kHwcnC); | |||||
| auto n = args.dst_shape.at(kHwcnN); | |||||
| int64_t nc = ni * n0; | |||||
| int64_t ncc0 = nc * c0; | |||||
| int64_t wncc0 = w * ncc0; | |||||
| int64_t hwncc0 = h * wncc0; | |||||
| int64_t cn = c * n; | |||||
| int64_t wcn = w * cn; | |||||
| for (int64_t h_idx = 0; h_idx < h; h_idx++) { | |||||
| int64_t h_head_addr = h_idx * wcn; | |||||
| for (int64_t w_idx = 0; w_idx < w; w_idx++) { | |||||
| int64_t w_head_addr = h_head_addr + w_idx * cn; | |||||
| for (int64_t c_idx = 0; c_idx < c; c_idx++) { | |||||
| int64_t c_head_addr = w_head_addr + c_idx * n; | |||||
| for (int64_t n_idx = 0; n_idx < n; n_idx++) { | |||||
| int64_t dst_idx = c_head_addr + n_idx; | |||||
| int64_t c1_idx = c_idx / c0; | |||||
| int64_t c0_idx = c_idx % c0; | |||||
| int64_t nc_idx = n_idx; | |||||
| int64_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx; | |||||
| auto src_offset = src_idx * size; | |||||
| auto dst_offset = dst_idx * size; | |||||
| auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) ? | |||||
| total_size - dst_offset : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
| "[Operate][Memory]Failed to copy data from FracZ offset %ld to " | |||||
| "HWCN[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| src_offset, h_idx, w_idx, c_idx, n_idx, dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to copy data from FracZ offset %ld to " | |||||
| "HWCN[%ld, %ld, %ld, %ld], offset %ld, err-code %d", | |||||
| src_offset, h_idx, w_idx, c_idx, n_idx, dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| Status ret = CheckArgsForFracZToHwcn(args); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||||
| if (total_size <= 0) { | |||||
| int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
| if (total_size == 0 && src_size == 0) { | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Get][ShapeSize]Failed, " | |||||
| "total size %ld from dst shape %s, src shape %s", total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to get total size %ld from " | |||||
| "dst shape %s, src shape %s", total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| GELOGD("Begin to trans format from FracZ to HWCN, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Get][Data]Failed after trans, src shape %s, " | |||||
| "data type %s, dst shape %s, memory size %ld, error_code %u", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to get data after trans, src shape %s, " | |||||
| "data type %s, dst shape %s, memory size %ld, error_code %u", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size, ret); | |||||
| return ret; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FormatTransferFracZHwcn::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | |||||
| Format dst_format, std::vector<int64_t> &dst_shape) { | |||||
| GELOGD("The shape derivation from FracZ to HWCN is not unique. Trans shape in this direction is not supported"); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFracZHwcn, FORMAT_FRACTAL_Z, FORMAT_HWCN) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,34 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACZ_HWCN_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACZ_HWCN_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| class FormatTransferFracZHwcn : public FormatTransfer { | |||||
| public: | |||||
| Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACZ_HWCN_H_ | |||||
| @@ -1,207 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_fracz_nchw.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
| bool CheckDataTypeSupported(const DataType &data_type) { return GetSizeByDataType(data_type) > 0; } | |||||
| Status CheckArgsForFracZToNchw(const TransArgs &args) { | |||||
| auto src_shape = args.src_shape; | |||||
| auto dst_shape = args.dst_shape; | |||||
| if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_NCHW) { | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| if (!CheckDataTypeSupported(args.src_data_type)) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Check][DataType]Failed, " | |||||
| "shape from FORMAT_FRACTAL_Z to NCHW, invalid data type %s", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Failed to trans shape from FORMAT_FRACTAL_Z to NCHW, " | |||||
| "invalid data type %s", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (!CheckShapeValid(src_shape, kFracZDimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", | |||||
| ShapeToString(src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", | |||||
| ShapeToString(src_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| if (!CheckShapeValid(dst_shape, kNchwDimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||||
| if (c0 < 0) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| int64_t c1 = Ceil(dst_shape.at(kNchwC), c0); | |||||
| int64_t n0 = Ceil(dst_shape.at(kNchwN), static_cast<int64_t>(kNiSize)); | |||||
| if (src_shape.at(kFracZHWC1) != dst_shape.at(kNchwH) * dst_shape.at(kNchwW) * c1 || | |||||
| src_shape.at(kFracZC0) != c0 || src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
| "[Check][Shape]Failed to check relationship between src and dst shape, " | |||||
| "src shape %s, dst shape %s", | |||||
| ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Failed to check relationship between src and dst shape, " | |||||
| "src shape %s, dst shape %s", | |||||
| ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, | |||||
| "[Allocate][DSTMemory]Failed, memory for dst buf %ld, shape %s " | |||||
| "when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld, shape %s " | |||||
| "when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| auto n0 = args.src_shape.at(kFracZN0); | |||||
| auto ni = args.src_shape.at(kFracZNi); | |||||
| auto c0 = args.src_shape.at(kFracZC0); | |||||
| auto h = args.dst_shape.at(kNchwH); | |||||
| auto w = args.dst_shape.at(kNchwW); | |||||
| auto c = args.dst_shape.at(kNchwC); | |||||
| auto n = args.dst_shape.at(kNchwN); | |||||
| int64_t nc = ni * n0; | |||||
| int64_t ncc0 = nc * c0; | |||||
| int64_t wncc0 = w * ncc0; | |||||
| int64_t hwncc0 = h * wncc0; | |||||
| int64_t hw = h * w; | |||||
| int64_t chw = c * hw; | |||||
| for (int64_t n_idx = 0; n_idx < n; n_idx++) { | |||||
| int64_t n_head_addr = n_idx * chw; | |||||
| for (int64_t c_idx = 0; c_idx < c; c_idx++) { | |||||
| int64_t c_head_addr = n_head_addr + c_idx * hw; | |||||
| for (int64_t h_idx = 0; h_idx < h; h_idx++) { | |||||
| int64_t h_head_addr = c_head_addr + h_idx * w; | |||||
| for (int64_t w_idx = 0; w_idx < w; w_idx++) { | |||||
| int64_t dst_idx = h_head_addr + w_idx; | |||||
| int64_t c1_idx = c_idx / c0; | |||||
| int64_t c0_idx = c_idx % c0; | |||||
| int64_t nc_idx = n_idx; | |||||
| int64_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx; | |||||
| auto src_offset = src_idx * size; | |||||
| auto dst_offset = dst_idx * size; | |||||
| auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) ? | |||||
| total_size - dst_offset : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
| "[Operate][Memory]Failed to copy data from FracZ offset %ld to " | |||||
| "NCHW[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999","Failed to copy data from FracZ offset %ld to " | |||||
| "NCHW[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret ); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| Status ret = CheckArgsForFracZToNchw(args); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||||
| if (total_size <= 0) { | |||||
| int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
| if (total_size == 0 && src_size == 0) { | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Get][ShapeSize]Failed, total size %ld from dst shape %s, " | |||||
| "src shape %s", total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to get total size %ld from dst shape %s, src shape %s", | |||||
| total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Get][Data]Failed, after trans, src shape %s, data type %s, " | |||||
| "dst shape %s, memory size %ld", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to get data after trans, src shape %s, " | |||||
| "data type %s, dst shape %s, memory size %ld", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||||
| return ret; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FormatTransferFracZNchw::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | |||||
| Format dst_format, std::vector<int64_t> &dst_shape) { | |||||
| GELOGD("The shape derivation from FracZ to NCHW is not unique. Trans shape in this direction is not supported"); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFracZNchw, FORMAT_FRACTAL_Z, FORMAT_NCHW) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,34 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACZ_NCHW_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACZ_NCHW_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| class FormatTransferFracZNchw : public FormatTransfer { | |||||
| public: | |||||
| Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACZ_NCHW_H_ | |||||
| @@ -1,206 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_fracz_nhwc.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
| bool CheckDataTypeSupported(const DataType &data_type) { return GetSizeByDataType(data_type) > 0; } | |||||
| Status CheckArgsForFracZToNhwc(const TransArgs &args) { | |||||
| auto src_shape = args.src_shape; | |||||
| auto dst_shape = args.dst_shape; | |||||
| if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_NHWC) { | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| if (!CheckDataTypeSupported(args.src_data_type)) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Check][DataType]Failed, " | |||||
| "shape from FORMAT_FRACTAL_Z to NCHW, invalid data type %s", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Failed to trans shape from FORMAT_FRACTAL_Z to NCHW, " | |||||
| "invalid data type %s", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (!CheckShapeValid(src_shape, kFracZDimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", | |||||
| ShapeToString(src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", | |||||
| ShapeToString(src_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| if (!CheckShapeValid(dst_shape, kNhwcDimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||||
| if (c0 < 0) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| int64_t c1 = Ceil(dst_shape.at(kNhwcC), c0); | |||||
| int64_t n0 = Ceil(dst_shape.at(kNhwcN), static_cast<int64_t>(kNiSize)); | |||||
| if (src_shape.at(kFracZHWC1) != dst_shape.at(kNhwcH) * dst_shape.at(kNhwcW) * c1 || | |||||
| src_shape.at(kFracZC0) != c0 || src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
| "[Check][Shape]Failed to check relationship between src and dst shape, " | |||||
| "src shape %s, dst shape %s", | |||||
| ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Failed to check relationship between src and dst shape, " | |||||
| "src shape %s, dst shape %s", | |||||
| ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size, int64_t total_size) { | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, | |||||
| "[Allocate][DSTMemory]Failed, memory for dst buf %ld, " | |||||
| "shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld, " | |||||
| "shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| auto n0 = args.src_shape.at(kFracZN0); | |||||
| auto ni = args.src_shape.at(kFracZNi); | |||||
| auto c0 = args.src_shape.at(kFracZC0); | |||||
| auto h = args.dst_shape.at(kNhwcH); | |||||
| auto w = args.dst_shape.at(kNhwcW); | |||||
| auto c = args.dst_shape.at(kNhwcC); | |||||
| auto n = args.dst_shape.at(kNhwcN); | |||||
| int64_t nc = ni * n0; | |||||
| int64_t ncc0 = nc * c0; | |||||
| int64_t wncc0 = w * ncc0; | |||||
| int64_t hwncc0 = h * wncc0; | |||||
| int64_t wc = w * c; | |||||
| int64_t hwc = h * wc; | |||||
| for (int64_t n_idx = 0; n_idx < n; n_idx++) { | |||||
| int64_t n_head_addr = n_idx * hwc; | |||||
| for (int64_t h_idx = 0; h_idx < h; h_idx++) { | |||||
| int64_t h_head_addr = n_head_addr + h_idx * wc; | |||||
| for (int64_t w_idx = 0; w_idx < w; w_idx++) { | |||||
| int64_t w_head_addr = h_head_addr + w_idx * c; | |||||
| for (int64_t c_idx = 0; c_idx < c; c_idx++) { | |||||
| int64_t dst_idx = w_head_addr + c_idx; | |||||
| int64_t c1_idx = c_idx / c0; | |||||
| int64_t c0_idx = c_idx % c0; | |||||
| int64_t nc_idx = n_idx; | |||||
| int64_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx; | |||||
| auto src_offset = src_idx * size; | |||||
| auto dst_offset = dst_idx * size; | |||||
| auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) ? | |||||
| total_size - dst_offset : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
| "[Operate][Memory]Failed to copy data from FracZ offset %ld to " | |||||
| "NCHW[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999","Failed to copy data from FracZ offset %ld to " | |||||
| "NCHW[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| src_offset, n_idx, c_idx, h_idx, w_idx, dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status FormatTransferFracZNhwc::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| Status ret = CheckArgsForFracZToNhwc(args); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||||
| if (total_size <= 0) { | |||||
| int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
| if (total_size == 0 && src_size == 0) { | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Get][ShapeSize]Failed, total size %ld from dst shape %s, " | |||||
| "src shape %s", total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to get total size %ld from dst shape %s, src shape %s", | |||||
| total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | |||||
| GELOGD("Begin to trans format from FracZ to NHWC, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Get][Data]Failed, after trans, src shape %s, data type %s, " | |||||
| "dst shape %s, memory size %ld, error_code %u", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size, ret); | |||||
| REPORT_CALL_ERROR("E19999","Failed to get data after trans, src shape %s, data type %s, " | |||||
| "dst shape %s, memory size %ld, error_code %u", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size, ret); | |||||
| return ret; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FormatTransferFracZNhwc::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | |||||
| Format dst_format, std::vector<int64_t> &dst_shape) { | |||||
| GELOGD("The shape derivation from FracZ to NHWC is not unique. Trans shape in this direction is not supported"); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferFracZNhwc, FORMAT_FRACTAL_Z, FORMAT_NHWC) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,35 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACZ_NHWC_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACZ_NHWC_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| class FormatTransferFracZNhwc : public FormatTransfer { | |||||
| public: | |||||
| Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACZ_NHWC_H_ | |||||
| @@ -1,264 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
| // Reports whether the HWCN -> C1HWNCoC0 transfer accepts data_type. | |||||
| // Exactly three types are accepted: DT_FLOAT, DT_FLOAT16 and DT_INT8. | |||||
| bool CheckDataTypeSupported(const DataType &data_type) { | |||||
|   if (data_type == DT_FLOAT) { | |||||
|     return true; | |||||
|   } | |||||
|   if (data_type == DT_FLOAT16) { | |||||
|     return true; | |||||
|   } | |||||
|   return data_type == DT_INT8; | |||||
| } | |||||
| // Derives the C1HWNCoC0 shape that corresponds to an HWCN src_shape. | |||||
| // dst_shape is cleared and refilled with [Ceil(C, cube), H, W, N, cube, cube], where cube is the value | |||||
| // GetCubeSizeByDataType returns for data_type. | |||||
| // Returns SUCCESS, or ACL_ERROR_GE_SHAPE_INVALID when the derived shape does not pass CheckShapeValid. | |||||
| Status TransShapeHwcnToC1hwncoc0(const DataType &data_type, const std::vector<int64_t> &src_shape, | |||||
|                                  std::vector<int64_t> &dst_shape) { | |||||
|   const auto cube = GetCubeSizeByDataType(data_type); | |||||
|   const auto cube_as_i64 = static_cast<int64_t>(cube); | |||||
|   dst_shape.clear(); | |||||
|   dst_shape.push_back(Ceil(src_shape.at(kHwcnC), cube_as_i64)); | |||||
|   dst_shape.push_back(src_shape.at(kHwcnH)); | |||||
|   dst_shape.push_back(src_shape.at(kHwcnW)); | |||||
|   dst_shape.push_back(src_shape.at(kHwcnN)); | |||||
|   // The last two dimensions (Co and C0) are both the cube size. | |||||
|   dst_shape.push_back(cube); | |||||
|   dst_shape.push_back(cube); | |||||
|   if (CheckShapeValid(dst_shape, kC1hwncoc0DimsNum)) { | |||||
|     return SUCCESS; | |||||
|   } | |||||
|   GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", | |||||
|          ShapeToString(dst_shape).c_str()); | |||||
|   REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", ShapeToString(dst_shape).c_str()); | |||||
|   return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| // Validates every precondition of an HWCN -> C1HWNCoC0 transfer. | |||||
| // Checks run in this order: format pair, data type, src shape, dst shape, and finally dst shape against | |||||
| // the shape derived from src shape. Returns SUCCESS or the error code of the first failing check. | |||||
| Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { | |||||
|   const bool formats_match = (args.src_format == FORMAT_HWCN) && (args.dst_format == FORMAT_C1HWNCoC0); | |||||
|   if (!formats_match) { | |||||
|     const auto src_name = FmtToStr(TypeUtils::FormatToSerialString(args.src_format)); | |||||
|     const auto dst_name = FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
|     const std::string message = "Dose not support trans format from " + src_name + " to " + dst_name; | |||||
|     GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, message.c_str()); | |||||
|     return ACL_ERROR_GE_FORMAT_INVALID; | |||||
|   } | |||||
|   if (!CheckDataTypeSupported(args.src_data_type)) { | |||||
|     GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, | |||||
|            "[Trans][Shape]Failed, shape from HWCN to C1HWNCoC0, invalid data type %s", | |||||
|            TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
|     REPORT_INNER_ERROR("E19999", "Failed to trans shape from HWCN to C1HWNCoC0, invalid data type %s", | |||||
|                        TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
|     return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
|   } | |||||
|   if (!CheckShapeValid(args.src_shape, kHwcnDimsNum)) { | |||||
|     GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", | |||||
|            ShapeToString(args.src_shape).c_str()); | |||||
|     REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", ShapeToString(args.src_shape).c_str()); | |||||
|     return ACL_ERROR_GE_SHAPE_INVALID; | |||||
|   } | |||||
|   if (!CheckShapeValid(args.dst_shape, kC1hwncoc0DimsNum)) { | |||||
|     GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", | |||||
|            ShapeToString(args.dst_shape).c_str()); | |||||
|     REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", ShapeToString(args.dst_shape).c_str()); | |||||
|     return ACL_ERROR_GE_SHAPE_INVALID; | |||||
|   } | |||||
|   // Recompute the destination shape from the source and require an exact match with the caller's one. | |||||
|   std::vector<int64_t> derived_shape; | |||||
|   const auto derive_status = TransShapeHwcnToC1hwncoc0(args.src_data_type, args.src_shape, derived_shape); | |||||
|   if (derive_status != SUCCESS) { | |||||
|     return derive_status; | |||||
|   } | |||||
|   if (args.dst_shape == derived_shape) { | |||||
|     return SUCCESS; | |||||
|   } | |||||
|   GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
|          "[Trans][Shape]Failed, src shape %s and dst shape %s are not compatible. expect dst shape %s", | |||||
|          ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
|          ShapeToString(derived_shape).c_str()); | |||||
|   REPORT_INNER_ERROR("E19999", | |||||
|                      "Failed to trans format, src shape %s and dst shape %s are not compatible. expect dst shape %s", | |||||
|                      ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
|                      ShapeToString(derived_shape).c_str()); | |||||
|   return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { /* Allocates a total_size-byte buffer, fills it element by element from args.data (indexed as HWCN) in C1HWNCoC0 order, and on success stores the buffer in result.data and total_size in result.length. size is the byte width of one element and scales element indices into byte offsets. Returns SUCCESS, ACL_ERROR_GE_MEMORY_ALLOCATION or ACL_ERROR_GE_MEMORY_OPERATE_FAILED. */ | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); /* nothrow new: a failed allocation yields nullptr, handled right below; the array deleter matches new[] */ | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allocate][DSTMemory]Failed, " | |||||
| "memory for dst buf %ld, shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld, " | |||||
| "shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| auto h = args.src_shape.at(kHwcnH); /* source extents, read from the HWCN shape */ | |||||
| auto w = args.src_shape.at(kHwcnW); | |||||
| auto c = args.src_shape.at(kHwcnC); | |||||
| auto n = args.src_shape.at(kHwcnN); | |||||
| auto c1 = args.dst_shape.at(kC1hwncoc0C1); /* destination-only extents, read from the C1HWNCoC0 shape */ | |||||
| auto c0 = args.dst_shape.at(kC1hwncoc0C0); | |||||
| auto co = args.dst_shape.at(kC1hwncoc0Co); | |||||
| int64_t coc0 = co * c0; /* destination strides in elements: one N step, then W, H and C1 steps below */ | |||||
| int64_t ncoc0 = n * coc0; | |||||
| int64_t wncoc0 = w * ncoc0; | |||||
| int64_t hwncoc0 = h * wncoc0; | |||||
| int64_t cn = c * n; /* source strides in elements: one W step, then one H step */ | |||||
| int64_t wcn = w * cn; | |||||
| for (int64_t c1_idx = 0; c1_idx < c1; c1_idx++) { /* loops nest in destination order C1, H, W, N, Co, C0, so dst_idx advances contiguously */ | |||||
| int64_t c1_head_addr = c1_idx * hwncoc0; | |||||
| for (int64_t h_idx = 0; h_idx < h; h_idx++) { | |||||
| int64_t h_head_addr = c1_head_addr + h_idx * wncoc0; | |||||
| for (int64_t w_idx = 0; w_idx < w; w_idx++) { | |||||
| int64_t w_head_addr = h_head_addr + w_idx * ncoc0; | |||||
| for (int64_t n_idx = 0; n_idx < n; n_idx++) { | |||||
| int64_t n_head_addr = w_head_addr + n_idx * coc0; | |||||
| for (int64_t co_idx = 0; co_idx < co; co_idx++) { | |||||
| int64_t co_head_addr = n_head_addr + co_idx * c0; | |||||
| for (int64_t c0_idx = 0; c0_idx < c0; c0_idx++) { | |||||
| int64_t dst_idx = c0_idx + co_head_addr; | |||||
| auto dst_offset = dst_idx * size; /* byte offset of this element inside dst */ | |||||
| auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) /* bytes left in dst from dst_offset, capped at SECUREC_MEM_MAX_LEN; passed as the destination capacity to memcpy_s/memset_s */ | |||||
| ? total_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| int64_t c_idx = c0_idx + c1_idx * c0; /* source channel index rebuilt from (c1_idx, c0_idx) */ | |||||
| int64_t src_idx = h_idx * wcn + w_idx * cn + c_idx * n + n_idx; /* linear HWCN element index */ | |||||
| auto src_offset = src_idx * size; | |||||
| if (c_idx < c && c0_idx == co_idx) { /* copy only when the channel exists in the source and the slot lies on the Co == C0 diagonal */ | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Copy][Data]Failed, " | |||||
| "data from HWCN[%ld, %ld, %ld, %ld] offset %ld to " | |||||
| "C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| h_idx, w_idx, c_idx, n_idx, src_offset, c1_idx, h_idx, w_idx, | |||||
| n_idx, co_idx, c0_idx, dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to copy data from " | |||||
| "HWCN[%ld, %ld, %ld, %ld] offset %ld " | |||||
| "to, C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] " | |||||
| "offset %ld, err-code %d", | |||||
| h_idx, w_idx, c_idx, n_idx, src_offset, c1_idx, h_idx, w_idx, | |||||
| n_idx, co_idx, c0_idx, dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } else { /* c_idx >= c or c0_idx != co_idx: this destination element is zero-filled */ | |||||
| auto ret = | |||||
| memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
| "[Operate][Memory]Failed to set to 0 to " | |||||
| "C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to set to 0 to " | |||||
| "C1HWNCoC0[%ld, %ld, %ld, %ld, %ld, %ld] offset %ld, " | |||||
| "err-code %d", | |||||
| c1_idx, h_idx, w_idx, n_idx, co_idx, c0_idx, dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; /* result is written only after every element was processed without error */ | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| // Transfers args.data from HWCN to C1HWNCoC0. | |||||
| // Steps: validate the arguments, compute the destination byte size from dst_shape and the element size, | |||||
| // short-circuit when both shapes hold zero items, otherwise delegate to GetDstDataAfterTrans. | |||||
| // Returns SUCCESS, the validation error, ACL_ERROR_GE_SHAPE_INVALID for a non-positive size, or the | |||||
| // error reported by GetDstDataAfterTrans. | |||||
| Status FormatTransferHwcnC1hwncoc0::TransFormat(const TransArgs &args, TransResult &result) { | |||||
|   const Status check_status = CheckArgsForHwcnToC1hwncoc0(args); | |||||
|   if (check_status != SUCCESS) { | |||||
|     return check_status; | |||||
|   } | |||||
|   const int elem_size = GetSizeByDataType(args.src_data_type); | |||||
|   const auto total_size = GetItemNumByShape(args.dst_shape) * elem_size; | |||||
|   if (total_size <= 0) { | |||||
|     const int64_t src_items = GetItemNumByShape(args.src_shape); | |||||
|     const bool both_empty = (total_size == 0) && (src_items == 0); | |||||
|     if (both_empty) { | |||||
|       // Nothing to move: only the length is reported. | |||||
|       result.length = static_cast<size_t>(total_size); | |||||
|       return SUCCESS; | |||||
|     } | |||||
|     GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
|            "[Get][ShapeSize]Failed, total size %ld from dst shape %s, src shape %s", | |||||
|            total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
|     REPORT_CALL_ERROR("E19999", "Failed to get total size %ld from dst shape %s, src shape %s", | |||||
|                       total_size, ShapeToString(args.dst_shape).c_str(), | |||||
|                       ShapeToString(args.src_shape).c_str()); | |||||
|     return ACL_ERROR_GE_SHAPE_INVALID; | |||||
|   } | |||||
|   GELOGD("Begin to trans format from HWCN to C1HWNCoC0, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
|          ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
|          ShapeToString(args.dst_shape).c_str(), total_size); | |||||
|   const Status trans_status = GetDstDataAfterTrans(args, result, elem_size, total_size); | |||||
|   if (trans_status == SUCCESS) { | |||||
|     return SUCCESS; | |||||
|   } | |||||
|   GELOGE(trans_status, | |||||
|          "[Get][Data]Failed, after trans, src shape %s, data type %s, dst shape %s, memory size %ld, error_code %u", | |||||
|          ShapeToString(args.src_shape).c_str(), | |||||
|          TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
|          ShapeToString(args.dst_shape).c_str(), total_size, trans_status); | |||||
|   REPORT_CALL_ERROR("E19999", | |||||
|                     "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld, error_code %u", | |||||
|                     ShapeToString(args.src_shape).c_str(), | |||||
|                     TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
|                     ShapeToString(args.dst_shape).c_str(), total_size, trans_status); | |||||
|   return trans_status; | |||||
| } | |||||
| // Derives the C1HWNCoC0 dst_shape for an HWCN src_shape. | |||||
| // Returns ACL_ERROR_GE_FORMAT_INVALID when src_format is not FORMAT_HWCN, ACL_ERROR_GE_DATATYPE_INVALID when | |||||
| // data_type is rejected by CheckDataTypeSupported, ACL_ERROR_GE_SHAPE_INVALID when src_shape fails | |||||
| // CheckShapeValid, and otherwise whatever TransShapeHwcnToC1hwncoc0 returns. dst_format is not consulted. | |||||
| Status FormatTransferHwcnC1hwncoc0::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | |||||
|                                                DataType data_type, Format dst_format, | |||||
|                                                std::vector<int64_t> &dst_shape) { | |||||
|   if (src_format != FORMAT_HWCN) { | |||||
|     return ACL_ERROR_GE_FORMAT_INVALID; | |||||
|   } | |||||
|   if (!CheckDataTypeSupported(data_type)) { | |||||
|     return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
|   } | |||||
|   if (CheckShapeValid(src_shape, kHwcnDimsNum)) { | |||||
|     return TransShapeHwcnToC1hwncoc0(data_type, src_shape, dst_shape); | |||||
|   } | |||||
|   GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", | |||||
|          ShapeToString(src_shape).c_str()); | |||||
|   REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", ShapeToString(src_shape).c_str()); | |||||
|   return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferHwcnC1hwncoc0, FORMAT_HWCN, FORMAT_C1HWNCoC0) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,35 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_C1HWNCOC0_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_C1HWNCOC0_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| // FormatTransfer implementation declared for the HWCN / C1HWNCoC0 pair named in the class identifier. | |||||
| // Only the two overrides are declared here; their definitions are not part of this header. | |||||
| class FormatTransferHwcnC1hwncoc0 : public FormatTransfer { | |||||
|  public: | |||||
|   // Data-transfer entry point: reads the request from args and reports through result. | |||||
|   Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
|   // Shape-transfer entry point: dst_shape is the only non-const (output) argument. | |||||
|   Status TransShape(Format src_format, | |||||
|                     const std::vector<int64_t> &src_shape, | |||||
|                     DataType data_type, | |||||
|                     Format dst_format, | |||||
|                     std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_HWCN_C1HWNCOC0_H_ | |||||
| @@ -1,207 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
| // A data type is accepted for NC1HWC0 -> NCHW when GetSizeByDataType reports a positive size for it. | |||||
| bool CheckDataTypeSupported(const DataType &data_type) { | |||||
|   const auto type_size = GetSizeByDataType(data_type); | |||||
|   return type_size > 0; | |||||
| } | |||||
| // Validates the arguments of an NC1HWC0 -> NCHW transfer before any data is touched. | |||||
| // Checks, in order: the format pair, the data type, both shapes via CheckShapeValid, the cube size (C0) | |||||
| // for the data type, and finally that N/H/W agree, that src C0 equals the cube size and that | |||||
| // src C1 == Ceil(dst C, C0). | |||||
| // Returns SUCCESS, or the ACL_ERROR_GE_* code of the first check that fails. | |||||
| Status CheckArgsForNc1hwc0ToNchw(const TransArgs &args) { | |||||
|   // Bind by reference: both shapes are only read here, so copying the vectors on every call is wasted work. | |||||
|   const auto &src_shape = args.src_shape; | |||||
|   const auto &dst_shape = args.dst_shape; | |||||
|   if (args.src_format != FORMAT_NC1HWC0 || args.dst_format != FORMAT_NCHW) { | |||||
|     std::string error = "Dose not support trans format from " + | |||||
|                         FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
|                         FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
|     GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
|     return ACL_ERROR_GE_FORMAT_INVALID; | |||||
|   } | |||||
|   if (!CheckDataTypeSupported(args.src_data_type)) { | |||||
|     GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Check][DataType]Failed, shape from NC1HWC0 to NCHW, " | |||||
|            "invalid data type %s", | |||||
|            TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
|     REPORT_INNER_ERROR("E19999", "Failed to trans shape from NC1HWC0 to NCHW, invalid data type %s", | |||||
|                        TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
|     return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
|   } | |||||
|   if (!CheckShapeValid(src_shape, kNc1hwc0DimsNum)) { | |||||
|     GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", | |||||
|            ShapeToString(src_shape).c_str()); | |||||
|     REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", | |||||
|                       ShapeToString(src_shape).c_str()); | |||||
|     return ACL_ERROR_GE_SHAPE_INVALID; | |||||
|   } | |||||
|   if (!CheckShapeValid(dst_shape, kNchwDimsNum)) { | |||||
|     GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", | |||||
|            ShapeToString(dst_shape).c_str()); | |||||
|     REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", | |||||
|                       ShapeToString(dst_shape).c_str()); | |||||
|     return ACL_ERROR_GE_SHAPE_INVALID; | |||||
|   } | |||||
|   // C0 is dictated by the data type; a non-positive value means the type has no cube size. | |||||
|   int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||||
|   if (c0 <= 0) { | |||||
|     GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Get][Cube]Failed, the data type %s is invalid", | |||||
|            TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
|     // Message spelling fixed ("tyep" -> "type") so it matches the NHWC sibling's report text. | |||||
|     REPORT_CALL_ERROR("E19999", "Failed to get cube size, the data type %s is invalid", | |||||
|                       TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
|     return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
|   } | |||||
|   // N, H and W must be identical; the source channel split must be exactly (Ceil(C, C0), C0). | |||||
|   if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNchwH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNchwW) || | |||||
|       src_shape.at(kNc1hwc0N) != dst_shape.at(kNchwN) || src_shape.at(kNc1hwc0C0) != c0 || | |||||
|       src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNchwC), c0))) { | |||||
|     GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Failed to check relationship between " | |||||
|            "src shape %s and dst shape %s", | |||||
|            ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||||
|     REPORT_INNER_ERROR("E19999", "Failed to check relationship between src shape %s " | |||||
|                        "and dst shape %s", | |||||
|                        ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||||
|     return ACL_ERROR_GE_SHAPE_INVALID; | |||||
|   } | |||||
|   return SUCCESS; | |||||
| } | |||||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { /* Allocates a total_size-byte buffer, copies every NCHW destination element from its NC1HWC0 position in args.data, and on success stores the buffer in result.data and total_size in result.length. size is the byte width of one element and scales element indices into byte offsets. Returns SUCCESS, ACL_ERROR_GE_MEMORY_ALLOCATION or ACL_ERROR_GE_MEMORY_OPERATE_FAILED. */ | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); /* nothrow new: a failed allocation yields nullptr, handled right below; the array deleter matches new[] */ | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allocate][DSTMemory]Failed, " | |||||
| "memory for dst buf %ld, shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld, " | |||||
| "shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| auto h = args.src_shape.at(kNc1hwc0H); /* extents read from the NC1HWC0 source shape */ | |||||
| auto w = args.src_shape.at(kNc1hwc0W); | |||||
| auto n = args.src_shape.at(kNc1hwc0N); | |||||
| auto c1 = args.src_shape.at(kNc1hwc0C1); | |||||
| auto c0 = args.src_shape.at(kNc1hwc0C0); | |||||
| auto c = args.dst_shape.at(kNchwC); /* channel count comes from the NCHW destination shape */ | |||||
| int64_t hw = h * w; /* destination strides in elements: one C step, then one N step */ | |||||
| int64_t chw = c * hw; | |||||
| int64_t wc0 = w * c0; /* source strides in elements: one H step, one C1 step, one N step */ | |||||
| int64_t hwc0 = h * wc0; | |||||
| int64_t c1hwc0 = c1 * hwc0; | |||||
| for (int64_t n_idx = 0; n_idx < n; n_idx++) { /* loops nest in destination order N, C, H, W, so dst_idx advances contiguously */ | |||||
| int64_t n_head_addr = n_idx * chw; | |||||
| for (int64_t c_idx = 0; c_idx < c; c_idx++) { | |||||
| int64_t c_head_addr = n_head_addr + c_idx * hw; | |||||
| for (int64_t h_idx = 0; h_idx < h; h_idx++) { | |||||
| int64_t h_head_addr = c_head_addr + h_idx * w; | |||||
| for (int64_t w_idx = 0; w_idx < w; w_idx++) { | |||||
| int64_t dst_idx = h_head_addr + w_idx; | |||||
| int64_t c1_idx = c_idx / c0; /* split the channel index into its C1 block ... */ | |||||
| int64_t c0_idx = c_idx % c0; /* ... and its position inside that block */ | |||||
| int64_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx; /* linear NC1HWC0 element index */ | |||||
| auto src_offset = src_idx * size; | |||||
| auto dst_offset = dst_idx * size; | |||||
| auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) /* bytes left in dst from dst_offset, capped at SECUREC_MEM_MAX_LEN; passed as the destination capacity to memcpy_s */ | |||||
| ? total_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Copy][Data]Failed, data from " | |||||
| "NC1HWC0[%ld, %ld, %ld, %ld, %ld] " | |||||
| "src offset %ld to NCHW[%ld, %ld, %ld, %ld], dst offset %ld, err-code %d", | |||||
| n_idx, c1_idx, h_idx, w_idx, c0_idx, src_offset, n_idx, | |||||
| c_idx, h_idx, w_idx, dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to copy data from NC1HWC0[%ld, %ld, %ld, %ld, %ld] " | |||||
| "src offset %ld to NCHW[%ld, %ld, %ld, %ld], dst offset %ld, " | |||||
| "err-code %d", | |||||
| n_idx, c1_idx, h_idx, w_idx, c0_idx, src_offset, n_idx, | |||||
| c_idx, h_idx, w_idx, dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; /* result is written only after every element was copied without error */ | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| // Transfers args.data from NC1HWC0 to NCHW. | |||||
| // Steps: validate the arguments, compute the destination byte size from dst_shape and the element size, | |||||
| // short-circuit when both shapes hold zero items, otherwise delegate to GetDstDataAfterTrans. | |||||
| // Returns SUCCESS, the validation error, ACL_ERROR_GE_PARAM_INVALID for a non-positive size, or the | |||||
| // error reported by GetDstDataAfterTrans. | |||||
| Status FormatTransferNc1hwc0Nchw::TransFormat(const TransArgs &args, TransResult &result) { | |||||
|   const Status check_status = CheckArgsForNc1hwc0ToNchw(args); | |||||
|   if (check_status != SUCCESS) { | |||||
|     return check_status; | |||||
|   } | |||||
|   const int elem_size = GetSizeByDataType(args.src_data_type); | |||||
|   const auto total_size = GetItemNumByShape(args.dst_shape) * elem_size; | |||||
|   if (total_size <= 0) { | |||||
|     const int64_t src_items = GetItemNumByShape(args.src_shape); | |||||
|     const bool both_empty = (total_size == 0) && (src_items == 0); | |||||
|     if (both_empty) { | |||||
|       // Nothing to move: only the length is reported. | |||||
|       result.length = static_cast<size_t>(total_size); | |||||
|       return SUCCESS; | |||||
|     } | |||||
|     GELOGE(ACL_ERROR_GE_PARAM_INVALID, | |||||
|            "[Get][ShapeSize]Failed, total size %ld from dst shape %s, src shape %s", | |||||
|            total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
|     REPORT_CALL_ERROR("E19999", "Failed to get total size %ld from dst shape %s, src shape %s", | |||||
|                       total_size, ShapeToString(args.dst_shape).c_str(), | |||||
|                       ShapeToString(args.src_shape).c_str()); | |||||
|     return ACL_ERROR_GE_PARAM_INVALID; | |||||
|   } | |||||
|   GELOGD("Begin to trans format from NC1HWC0 to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
|          ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
|          ShapeToString(args.dst_shape).c_str(), total_size); | |||||
|   const Status trans_status = GetDstDataAfterTrans(args, result, elem_size, total_size); | |||||
|   if (trans_status == SUCCESS) { | |||||
|     return SUCCESS; | |||||
|   } | |||||
|   GELOGE(trans_status, | |||||
|          "[Get][Data]Failed, after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
|          ShapeToString(args.src_shape).c_str(), | |||||
|          TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
|          ShapeToString(args.dst_shape).c_str(), total_size); | |||||
|   REPORT_CALL_ERROR("E19999", | |||||
|                     "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
|                     ShapeToString(args.src_shape).c_str(), | |||||
|                     TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
|                     ShapeToString(args.dst_shape).c_str(), total_size); | |||||
|   return trans_status; | |||||
| } | |||||
| // Shape derivation is deliberately unsupported in this direction: the function logs why and always | |||||
| // returns ACL_ERROR_GE_FORMAT_INVALID. None of the arguments is read and dst_shape is left untouched. | |||||
| Status FormatTransferNc1hwc0Nchw::TransShape(Format src_format, | |||||
|                                              const std::vector<int64_t> &src_shape, | |||||
|                                              DataType data_type, | |||||
|                                              Format dst_format, | |||||
|                                              std::vector<int64_t> &dst_shape) { | |||||
|   GELOGD("The shape derivation from NC1HWC0 to NCHW is not unique. " | |||||
|          "Trans shape in this direction is not supported"); | |||||
|   return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferNc1hwc0Nchw, FORMAT_NC1HWC0, FORMAT_NCHW) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,35 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_NC1HWC0_NCHW_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_NC1HWC0_NCHW_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| // FormatTransfer implementation declared for the NC1HWC0 / NCHW pair named in the class identifier. | |||||
| // Only the two overrides are declared here; their definitions are not part of this header. | |||||
| class FormatTransferNc1hwc0Nchw : public FormatTransfer { | |||||
|  public: | |||||
|   // Data-transfer entry point: reads the request from args and reports through result. | |||||
|   Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
|   // Shape-transfer entry point: dst_shape is the only non-const (output) argument. | |||||
|   Status TransShape(Format src_format, | |||||
|                     const std::vector<int64_t> &src_shape, | |||||
|                     DataType data_type, | |||||
|                     Format dst_format, | |||||
|                     std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_NC1HWC0_NCHW_H_ | |||||
| @@ -1,208 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
| // A data type is accepted for NC1HWC0 -> NHWC when GetSizeByDataType reports a positive size for it. | |||||
| bool CheckDataTypeSupported(const DataType &data_type) { | |||||
|   const auto type_size = GetSizeByDataType(data_type); | |||||
|   return type_size > 0; | |||||
| } | |||||
| // Validates the arguments of an NC1HWC0 -> NHWC transfer before any data is touched. | |||||
| // Checks, in order: the format pair, the data type, both shapes via CheckShapeValid, the cube size (C0) | |||||
| // for the data type, and finally that N/H/W agree, that src C0 equals the cube size and that | |||||
| // src C1 == Ceil(dst C, C0). | |||||
| // Returns SUCCESS, or the ACL_ERROR_GE_* code of the first check that fails. | |||||
| Status CheckArgsForNc1hwc0ToNhwc(const TransArgs &args) { | |||||
|   // Bind by reference: both shapes are only read here, so copying the vectors on every call is wasted work. | |||||
|   const auto &src_shape = args.src_shape; | |||||
|   const auto &dst_shape = args.dst_shape; | |||||
|   if (args.src_format != FORMAT_NC1HWC0 || args.dst_format != FORMAT_NHWC) { | |||||
|     std::string error = "Dose not support trans format from " + | |||||
|                         FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
|                         FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
|     GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
|     return ACL_ERROR_GE_FORMAT_INVALID; | |||||
|   } | |||||
|   if (!CheckDataTypeSupported(args.src_data_type)) { | |||||
|     GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Check][DataType]Failed, shape from NC1HWC0 to NHWC, " | |||||
|            "invalid data type %s", | |||||
|            TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
|     REPORT_INNER_ERROR("E19999", "Failed to trans shape from NC1HWC0 to NHWC, invalid data type %s", | |||||
|                        TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
|     return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
|   } | |||||
|   if (!CheckShapeValid(src_shape, kNc1hwc0DimsNum)) { | |||||
|     GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", | |||||
|            ShapeToString(src_shape).c_str()); | |||||
|     REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", | |||||
|                       ShapeToString(src_shape).c_str()); | |||||
|     return ACL_ERROR_GE_SHAPE_INVALID; | |||||
|   } | |||||
|   if (!CheckShapeValid(dst_shape, kNhwcDimsNum)) { | |||||
|     GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", | |||||
|            ShapeToString(dst_shape).c_str()); | |||||
|     REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", | |||||
|                       ShapeToString(dst_shape).c_str()); | |||||
|     return ACL_ERROR_GE_SHAPE_INVALID; | |||||
|   } | |||||
|   // C0 is dictated by the data type; a non-positive value means the type has no cube size. | |||||
|   int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||||
|   if (c0 <= 0) { | |||||
|     GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Get][Cube]Failed, the data type %s is invalid", | |||||
|            TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
|     REPORT_CALL_ERROR("E19999", "Failed to get cube size, the data type %s is invalid", | |||||
|                       TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
|     return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
|   } | |||||
|   // N, H and W must be identical; the source channel split must be exactly (Ceil(C, C0), C0). | |||||
|   if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNhwcH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNhwcW) || | |||||
|       src_shape.at(kNc1hwc0N) != dst_shape.at(kNhwcN) || src_shape.at(kNc1hwc0C0) != c0 || | |||||
|       src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNhwcC), c0))) { | |||||
|     GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Failed to check relationship between " | |||||
|            "src shape %s and dst shape %s", | |||||
|            ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||||
|     REPORT_INNER_ERROR("E19999", "Failed to check relationship between src shape %s " | |||||
|                        "and dst shape %s", | |||||
|                        ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||||
|     return ACL_ERROR_GE_SHAPE_INVALID; | |||||
|   } | |||||
|   return SUCCESS; | |||||
| } | |||||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allocate][DSTMemory]Failed, memory for dst buf %ld, " | |||||
| "shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld, " | |||||
| "shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| auto h = args.src_shape.at(kNc1hwc0H); | |||||
| auto w = args.src_shape.at(kNc1hwc0W); | |||||
| auto n = args.src_shape.at(kNc1hwc0N); | |||||
| auto c1 = args.src_shape.at(kNc1hwc0C1); | |||||
| auto c0 = args.src_shape.at(kNc1hwc0C0); | |||||
| auto c = args.dst_shape.at(kNhwcC); | |||||
| int64_t wc = w * c; | |||||
| int64_t hwc = h * wc; | |||||
| int64_t wc0 = w * c0; | |||||
| int64_t hwc0 = h * wc0; | |||||
| int64_t c1hwc0 = c1 * hwc0; | |||||
| for (int64_t n_idx = 0; n_idx < n; n_idx++) { | |||||
| int64_t n_head_addr = n_idx * hwc; | |||||
| for (int64_t h_idx = 0; h_idx < h; h_idx++) { | |||||
| int64_t h_head_addr = n_head_addr + h_idx * wc; | |||||
| for (int64_t w_idx = 0; w_idx < w; w_idx++) { | |||||
| int64_t w_head_addr = h_head_addr + w_idx * c; | |||||
| for (int64_t c_idx = 0; c_idx < c; c_idx++) { | |||||
| int64_t dst_idx = w_head_addr + c_idx; | |||||
| int64_t c1_idx = c_idx / c0; | |||||
| int64_t c0_idx = c_idx % c0; | |||||
| int64_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx; | |||||
| auto src_offset = src_idx * size; | |||||
| auto dst_offset = dst_idx * size; | |||||
| auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? total_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
| "[Copy][Data]Failed, data from NC1HWC0[%ld, %ld, %ld, %ld, %ld] " | |||||
| "offset %ld to NHWC[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| n_idx, c1_idx, h_idx, w_idx, c0_idx, src_offset, n_idx, c_idx, | |||||
| h_idx, w_idx, dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to copy data from NC1HWC0[%ld, %ld, %ld, %ld, %ld] " | |||||
| "offset %ld to NHWC[%ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| n_idx, c1_idx, h_idx, w_idx, c0_idx, src_offset, n_idx, c_idx, | |||||
| h_idx, w_idx, dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status FormatTransferNc1hwc0Nhwc::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| Status ret = CheckArgsForNc1hwc0ToNhwc(args); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||||
| if (total_size <= 0) { | |||||
| int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
| if (total_size == 0 && src_size == 0) { | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Get][ShapeSize]Failed, total size %ld from dst shape %s, " | |||||
| "src shape %s", total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to get total size %ld from dst shape %s, src shape %s", | |||||
| total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| GELOGD("[Trans][Format]Begin to trans format from NC1HWC0 to NCHW, " | |||||
| "src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Get][Data]Failed, after trans, src shape %s, data type %s, " | |||||
| "dst shape %s, memory size %ld, error_code %u", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size, ret); | |||||
| REPORT_CALL_ERROR("E19999", "[Get][Data]Failed, after trans, src shape %s, " | |||||
| "data type %s, dst shape %s, memory size %ld, error_code %u", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size, ret); | |||||
| return ret; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FormatTransferNc1hwc0Nhwc::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | |||||
| DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | |||||
| GELOGD("The shape derivation from NC1HWC0 to NHWC is not unique. Trans shape in this direction is not supported"); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferNc1hwc0Nhwc, FORMAT_NC1HWC0, FORMAT_NHWC) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,35 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_NC1HWC0_NHWC_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_NC1HWC0_NHWC_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| class FormatTransferNc1hwc0Nhwc : public FormatTransfer { | |||||
| public: | |||||
| Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_NC1HWC0_NHWC_H_ | |||||
| @@ -1,367 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_nchw_fz_c04.h" | |||||
| #include "common/formats/format_transfers/format_transfer_transpose.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include <cstdlib> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "common/util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| /** 【Explanation of the transfer from NCHW to FZ_C04】 | |||||
| * First step: pad the N and C axes. C must be less than or equal to 4. | |||||
| * After padding, the shape is (n = ceil(n,16)*16, 4, h, w). | |||||
| * Second step: transpose. The shape becomes (n = ceil(n,16)*16, h, w, 4). | |||||
| * Third step: view the 4D tensor as 2D: the first dim is N and the second dim is Z = h*w*c. | |||||
| * Pad it to (N, ceil(Z/16)*16). | |||||
| * Last step: view (N, C) with C = ceil(Z/16)*16 as 4D (N/16, 16, C/16, 16) and transpose it to (C/16, N/16, 16, 16). | |||||
| */ | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
// Upper bound for the C dimension of the NCHW source: PaddingNC rejects any larger value
// and pads C up to this size.
constexpr int64_t kMaxDimsNumC{4};
| Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; } | |||||
| Status TransShape(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape) { | |||||
| auto c0 = GetCubeSizeByDataType(data_type); | |||||
| if (c0 < 0) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| auto chw = c * h * w; | |||||
| auto first_dim = Ceil(chw, c0); | |||||
| auto no = Ceil(n, static_cast<int64_t>(c0)); | |||||
| dst_shape.clear(); | |||||
| dst_shape.push_back(first_dim); | |||||
| dst_shape.push_back(no); | |||||
| dst_shape.push_back(c0); | |||||
| dst_shape.push_back(c0); | |||||
| if (!IsShapeValid(dst_shape)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransShapeNchwToFzC04(const std::vector<int64_t> &src_shape, DataType data_type, | |||||
| std::vector<int64_t> &dst_shape) { | |||||
| if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| auto n = src_shape.at(kNchwN); | |||||
| auto c = src_shape.at(kNchwC); | |||||
| auto h = src_shape.at(kNchwH); | |||||
| auto w = src_shape.at(kNchwW); | |||||
| return TransShape(n, c, h, w, data_type, dst_shape); | |||||
| } | |||||
| Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||||
| int64_t n = args.src_shape.at(kNchwN); | |||||
| int64_t c = args.src_shape.at(kNchwC); | |||||
| int64_t h = args.src_shape.at(kNchwH); | |||||
| int64_t w = args.src_shape.at(kNchwW); | |||||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| auto data = args.data; | |||||
| TransResult trans_result_1; | |||||
| std::vector<int64_t> perm_arg_1 = {0, 2, 3, 1}; | |||||
| std::vector<int64_t> expect_shape = {n, h, w, c}; | |||||
| auto ret = ge::formats::Transpose(data, args.src_shape, args.src_data_type, perm_arg_1, trans_result_1); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Trans][Formats]Failed from NCHW to HWCN, src_shape %s, src_data_type %s", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failede to trans formats from NCHW to HWCN, src_shape %s, " | |||||
| "src_data_type %s", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ret; | |||||
| } | |||||
| TransArgs args_tmp = args; | |||||
| args_tmp.src_shape = expect_shape; | |||||
| args_tmp.data = trans_result_1.data.get(); | |||||
| // check size it should be same with original | |||||
| size_t expect_size = n * c * h * w * size; // before has do check about mul | |||||
| if (trans_result_1.length != expect_size) { | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Shape]size %zu is not match expect size %zu " | |||||
| "after transpose", | |||||
| trans_result_1.length, expect_size); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | |||||
| // prepare for padding in chw | |||||
| int64_t tmp = h * w * c; | |||||
| int64_t n_o = Ceil(n, static_cast<int64_t>(c0)); | |||||
| int64_t c_o = c0; | |||||
| int64_t h_o = Ceil(tmp, c0); | |||||
| int64_t w_o = c0; | |||||
| std::vector<int64_t> shape_o = {n_o, c_o, h_o, w_o}; | |||||
| // data overflow check totally | |||||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o), | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Shape]Failed, " | |||||
| "int64 mul overflow.A[%ld], B[%ld]", h_o, w_o); | |||||
| REPORT_CALL_ERROR("E19999", "Check shape failed, int64 mul overflow.A[%ld], " | |||||
| "B[%ld]", h_o, w_o); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o), | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Shape]Failed, " | |||||
| "int64 mul overflow.A[%ld], B[%ld]", n_o, c_o); | |||||
| REPORT_CALL_ERROR("E19999", "Check shape failed, int64 mul overflow.A[%ld], " | |||||
| "B[%ld]", n_o, c_o); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
| auto t1 = h_o * w_o; | |||||
| auto t2 = n_o * c_o; | |||||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Shape]Failed, " | |||||
| "int64 mul overflow.A[%ld], B[%ld]", t1, t2); | |||||
| REPORT_CALL_ERROR("E19999", "Check shape failed, " | |||||
| "int64 mul overflow.A[%ld], B[%ld]", t1, t2); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
| int64_t total_ele_cnt = n_o * c_o * h_o * w_o; | |||||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size), | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Shape]Failed, " | |||||
| "int64 mul overflow.A[%ld], B[%d]", total_ele_cnt, size); | |||||
| REPORT_CALL_ERROR("E19999", "Check shape failed, int64 mul overflow.A[%ld], " | |||||
| "B[%d]", total_ele_cnt, size); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
| int64_t dst_size = total_ele_cnt * size; | |||||
| if (dst_size == 0) { | |||||
| result.length = 0; | |||||
| return SUCCESS; | |||||
| } | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to alloc the memory for dst buf %ld " | |||||
| "when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld " | |||||
| "when trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| auto retMem = memset_s(dst.get(), dst_size, 0, dst_size); | |||||
| if (retMem != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Set][Memory]Failed, dst buf %ld, error_code %d", | |||||
| dst_size, retMem); | |||||
| REPORT_CALL_ERROR("E19999", "Set memory failed, dst buf %ld, error_code %d", dst_size, retMem); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| // copy data | |||||
| auto block = c * h * w * size; | |||||
| auto stride = h_o * w_o * size; | |||||
| auto p_s = trans_result_1.data.get(); | |||||
| auto p_d = dst.get(); | |||||
| auto protectSize = dst_size; | |||||
| for (auto k = 0; k < n; k++) { | |||||
| ret = memcpy_s(p_d + k * stride, protectSize, p_s + k * block, block); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Set][Memcpy]Failed, block %zu, stride %zu, " | |||||
| "protect_size %ld, error_code %d", block, stride, protectSize, ret); | |||||
| REPORT_CALL_ERROR("E19999", "[Set][Memcpy]Failed, block %zu, stride %zu, " | |||||
| "protect_size %ld, error_code %d", block, stride, protectSize, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| protectSize = protectSize - block; | |||||
| } | |||||
| // transpose : 2,0,1,3 | |||||
| std::vector<int64_t> perm_arg_2 = {2, 0, 1, 3}; | |||||
| ret = ge::formats::Transpose(dst.get(), shape_o, args.src_data_type, perm_arg_2, result); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Trans][Formats]Failed from NCHW to HWCN, error_code %u", ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to trans formats from NCHW to HWCN, error_code %u", ret); | |||||
| return ret; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uint8_t> &dst) { | |||||
| args_tmp = args; | |||||
| auto src_shape = args_tmp.src_shape; | |||||
| if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||||
| auto n = src_shape.at(kNchwN); | |||||
| auto c = src_shape.at(kNchwC); | |||||
| auto h = src_shape.at(kNchwH); | |||||
| auto w = src_shape.at(kNchwW); | |||||
| if (c > kMaxDimsNumC) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Invalid dim c num[%lu]. " | |||||
| "It should be in (0,4]", c); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| auto n_o = Ceil(n, c0) * c0; | |||||
| auto c_o = kMaxDimsNumC; | |||||
| auto h_o = h; | |||||
| auto w_o = w; | |||||
| args_tmp.src_shape.at(kNchwN) = n_o; | |||||
| args_tmp.src_shape.at(kNchwC) = c_o; | |||||
| args_tmp.src_shape.at(kNchwH) = h_o; | |||||
| args_tmp.src_shape.at(kNchwW) = w_o; | |||||
| // data overflow check | |||||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o), | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Shape]Failed, " | |||||
| "int64 mul overflow.A[%ld], B[%ld]", h_o, w_o); | |||||
| REPORT_CALL_ERROR("E19999", "Check shape failed, int64 mul overflow.A[%ld], " | |||||
| "B[%ld]", h_o, w_o); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o), | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Shape]Failed, " | |||||
| "int64 mul overflow.A[%ld], B[%ld]", n_o, c_o); | |||||
| REPORT_CALL_ERROR("E19999", "Check shape failed, int64 mul overflow.A[%ld], " | |||||
| "B[%ld]", n_o, c_o); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
| auto t1 = h_o * w_o; | |||||
| auto t2 = n_o * c_o; | |||||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Shape]Failed, " | |||||
| "int64 mul overflow.A[%ld], B[%ld]", t1, t2); | |||||
| REPORT_CALL_ERROR("E19999", "Check shape failed, int64 mul overflow.A[%ld], " | |||||
| "B[%ld]", t1, t2); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
| int64_t total_ele_cnt = n_o * c_o * h_o * w_o; | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size), | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Shape]Failed, " | |||||
| "int64 mul overflow.A[%ld], B[%d]", total_ele_cnt, size); | |||||
| REPORT_CALL_ERROR("E19999", "Check shape failed, int64 mul overflow.A[%ld], " | |||||
| "B[%d]", total_ele_cnt, size); | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR); | |||||
| int64_t dst_size = total_ele_cnt * size; | |||||
| if (dst_size == 0) { | |||||
| return SUCCESS; | |||||
| } | |||||
| dst.reset(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to alloc the memory for dst buf %ld when " | |||||
| "trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld when " | |||||
| "trans format from %s to %s", | |||||
| dst_size, TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| auto ret = memset_s(dst.get(), dst_size, 0, dst_size); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Set][Memory]Failed, dst buf %ld, error_code %d", | |||||
| dst_size, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Set memory failed, dst buf %ld, error_code %d", dst_size, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| auto p_s = args.data; | |||||
| auto p_d = dst.get(); | |||||
| auto block = h * w * size; | |||||
| auto protectSize = dst_size; | |||||
| for (int i = 0; i < n; i++) { | |||||
| for (int j = 0; j < c; j++) { | |||||
| ret = memcpy_s(p_d + (i * c_o * h_o * w_o + j * h_o * w_o) * size, protectSize, | |||||
| p_s + (i * c * h * w + j * h * w) * size, block); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Set][Memcpy]Failed, block %zu, " | |||||
| "protect_size %ld, error_code %d", block, protectSize, ret); | |||||
| REPORT_CALL_ERROR("E19999", "[Set][Memcpy]Failed, block %zu, protect_size %ld, " | |||||
| "error_code %d", block, protectSize, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| protectSize = protectSize - block; | |||||
| } | |||||
| } | |||||
| args_tmp.data = dst.get(); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status FormatTransferNchwToFZC04::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s", | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str()); | |||||
| TransArgs args_tmp = args; | |||||
| std::shared_ptr<uint8_t> dst = nullptr; | |||||
| auto ret = PaddingNC(args, args_tmp, dst); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Padding][NCAxis]Failed, error_code %u", ret); | |||||
| REPORT_CALL_ERROR("E19999", "Padding in NC axis failed, error_code %u", ret); | |||||
| return ret; | |||||
| } | |||||
| std::vector<int64_t> expect_shape; | |||||
| ret = TransShape(args_tmp.src_format, args_tmp.src_shape, args_tmp.src_data_type, | |||||
| args_tmp.dst_format, expect_shape); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (!IsTransShapeDstCorrect(args_tmp, expect_shape)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| if (args_tmp.src_format == FORMAT_NCHW && args_tmp.dst_format == FORMAT_FRACTAL_Z_C04) { | |||||
| return TransFormatFromNchwToFzC04(args_tmp, result); | |||||
| } | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| Status FormatTransferNchwToFZC04::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | |||||
| DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | |||||
| if (CheckDataTypeSupport(data_type) != SUCCESS) { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z_C04) { | |||||
| return TransShapeNchwToFzC04(src_shape, data_type, dst_shape); | |||||
| } | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferNchwToFZC04, FORMAT_NCHW, FORMAT_FRACTAL_Z_C04) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,35 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_NCHW_FZC04_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_NCHW_FZC04_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| class FormatTransferNchwToFZC04 : public FormatTransfer { | |||||
| public: | |||||
| Status TransFormat(const ge::formats::TransArgs &args, ge::formats::TransResult &result) override; | |||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_NCHW_FZC04_H_ | |||||
| @@ -1,247 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
| Status TransShapeNchwToNc1hwc0(const std::vector<int64_t> &src_shape, DataType data_type, | |||||
| std::vector<int64_t> &dst_shape) { | |||||
| int64_t c0 = GetCubeSizeByDataType(data_type); | |||||
| if (c0 <= 0) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Get][Cube]Failed, the data type %s is invalid", | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to get cube size, the data type %s is invalid", | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (!CheckShapeValid(src_shape, kNchwDimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", | |||||
| ShapeToString(src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", | |||||
| ShapeToString(src_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| dst_shape.clear(); | |||||
| dst_shape.push_back(src_shape.at(kNchwN)); | |||||
| dst_shape.push_back(Ceil(src_shape.at(kNchwC), c0)); | |||||
| dst_shape.push_back(src_shape.at(kNchwH)); | |||||
| dst_shape.push_back(src_shape.at(kNchwW)); | |||||
| dst_shape.push_back(c0); | |||||
| if (!CheckShapeValid(dst_shape, kNc1hwc0DimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | |||||
| if (args.src_format != FORMAT_NCHW || args.dst_format != FORMAT_NC1HWC0) { | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| std::vector<int64_t> expect_5d_shape; | |||||
| auto ret = TransShapeNchwToNc1hwc0(args.src_shape, args.src_data_type, expect_5d_shape); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (expect_5d_shape != args.dst_shape) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
| "[Trans][Format]Failed, the src and dst shape are not compatible. " | |||||
| "data type %s, src shape %s, dst shape %s, expect dst shape %s", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(expect_5d_shape).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Failed to trans formats, the src and dst shape are not " | |||||
| "compatible. data type %s, src shape %s, dst shape %s, expect dst shape %s", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(expect_5d_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, | |||||
| "[Allcoate][Memory]Failed to alloc the memory for dst buf %ld, " | |||||
| "shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld, " | |||||
| "shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| auto n = args.src_shape.at(kNchwN); | |||||
| auto c = args.src_shape.at(kNchwC); | |||||
| auto h = args.src_shape.at(kNchwH); | |||||
| auto w = args.src_shape.at(kNchwW); | |||||
| int64_t c0 = GetCubeSizeByDataType(args.src_data_type); | |||||
| if (c0 <= 0) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Check][Shape]The c0 is invalid %ld, data_type %s", | |||||
| c0, TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Check shape failed, the c0 is invalid %ld, data_type %s", | |||||
| c0, TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| int64_t c1 = (c - 1) / c0 + 1; | |||||
| int64_t hw = h * w; | |||||
| int64_t chw = c * hw; | |||||
| int64_t hwc0 = hw * c0; | |||||
| int64_t c1hwc0 = c1 * hwc0; | |||||
| int64_t wc0 = w * c0; | |||||
| for (int64_t n_idx = 0; n_idx < n; n_idx++) { | |||||
| int64_t n_head_addr = n_idx * c1hwc0; | |||||
| for (int64_t c1_idx = 0; c1_idx < c1; c1_idx++) { | |||||
| int64_t c1_head_addr = n_head_addr + c1_idx * hwc0; | |||||
| for (int64_t h_idx = 0; h_idx < h; h_idx++) { | |||||
| int64_t h_head_addr = c1_head_addr + h_idx * wc0; | |||||
| for (int64_t w_idx = 0; w_idx < w; w_idx++) { | |||||
| int64_t w_head_addr = h_head_addr + w_idx * c0; | |||||
| for (int64_t c0_idx = 0; c0_idx < c0; c0_idx++) { | |||||
| int64_t dst_index = c0_idx + w_head_addr; | |||||
| int64_t dst_offset = dst_index * size; | |||||
| auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? total_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| int64_t cIdx = c0_idx + c1_idx * c0; | |||||
| int64_t srcIdx = n_idx * chw + cIdx * hw + h_idx * w + w_idx; | |||||
| auto src_offset = srcIdx * size; | |||||
| if (cIdx < c) { | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | |||||
| static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
| "[Operate][Memory]Failed to copy data from NCHW[%ld] offset %ld " | |||||
| "to NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| srcIdx, src_offset, n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to copy data from NCHW[%ld] offset %ld " | |||||
| "to NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| srcIdx, src_offset, n_idx, c1_idx, h_idx, w_idx, c0_idx, | |||||
| dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } else { | |||||
| auto ret = | |||||
| memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
| "[Operate][Memory]Failed to set to 0 to NC1HWC0[%ld, %ld, %ld, %ld, %ld] " | |||||
| "offset %ld, err-code %d", | |||||
| n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to set to 0 to " | |||||
| "NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld, err-code %d", | |||||
| n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| Status ret = CheckArgsForNchwToNc1hwc0(args); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| // Guarantee the validity of parameters in check function | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||||
| if (total_size <= 0) { | |||||
| int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
| if (total_size == 0 && src_size == 0) { | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Get][Shape]Failed, total size %ld from dst shape %s, " | |||||
| "src shape %s", total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to get total size %ld from dst shape %s, src shape %s", | |||||
| total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| GELOGD( | |||||
| "Begin to trans format from NCHW to NC1HWC0, src shape %s, data type " | |||||
| "%s, dst shape %s memory size %ld", | |||||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Get][Data]Failed, after trans, src shape %s, data type %s, " | |||||
| "dst shape %s, memory size %ld, error_code %u", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to get data after trans, src shape %s, data type %s, " | |||||
| "dst shape %s, memory size %ld, error_code %u", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size, ret); | |||||
| return ret; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FormatTransferNchwNc1hwc0::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | |||||
| DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | |||||
| if (src_format == FORMAT_NCHW) { | |||||
| return TransShapeNchwToNc1hwc0(src_shape, data_type, dst_shape); | |||||
| } else { | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| } | |||||
// Associates FormatTransferNchwNc1hwc0 with the FORMAT_NCHW -> FORMAT_NC1HWC0 pair.
// NOTE(review): the macro is defined outside this file; presumably it adds the class to the
// format-transfer registry at static-initialization time -- confirm in register_format_transfer.h.
REGISTER_FORMAT_TRANSFER(FormatTransferNchwNc1hwc0, FORMAT_NCHW, FORMAT_NC1HWC0)
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,35 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_NCHW_NC1HWC0_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_NCHW_NC1HWC0_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
// FormatTransfer implementation for the NCHW -> NC1HWC0 conversion.
class FormatTransferNchwNc1hwc0 : public FormatTransfer {
 public:
  // Converts the tensor data described by `args` and stores the converted buffer and its
  // byte length in `result`. Returns SUCCESS or an error status.
  Status TransFormat(const TransArgs &args, TransResult &result) override;
  // Computes `dst_shape` from `src_shape` and `data_type`. Returns an error status when
  // `src_format` is not FORMAT_NCHW.
  Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
                    std::vector<int64_t> &dst_shape) override;
};
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_NCHW_NC1HWC0_H_ | |||||
| @@ -1,258 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
| bool CheckDataTypeSupported(const DataType &data_type) { return GetSizeByDataType(data_type) > 0; } | |||||
| Status TransShapeNhwcToNc1hwc0(const std::vector<int64_t> &src_shape, DataType data_type, | |||||
| std::vector<int64_t> &dst_shape) { | |||||
| int64_t c0 = GetCubeSizeByDataType(data_type); | |||||
| if (c0 <= 0) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Get][Cube]Failed, the data type %s is invalid", | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to get cube size, the data type %s is invalid", | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| dst_shape.clear(); | |||||
| dst_shape.push_back(src_shape.at(kNhwcN)); | |||||
| dst_shape.push_back(Ceil(src_shape.at(kNhwcC), c0)); | |||||
| dst_shape.push_back(src_shape.at(kNhwcH)); | |||||
| dst_shape.push_back(src_shape.at(kNhwcW)); | |||||
| dst_shape.push_back(c0); | |||||
| if (!CheckShapeValid(dst_shape, kNc1hwc0DimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", | |||||
| ShapeToString(dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { | |||||
| if (args.src_format != FORMAT_NHWC || args.dst_format != FORMAT_NC1HWC0) { | |||||
| std::string error = "Dose not support trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| if (!CheckDataTypeSupported(args.src_data_type)) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Check][DataType]Failed from NHWC to NC1HWC0, " | |||||
| "invalid data type %s", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Failed to trans shape from NHWC to NC1HWC0, invalid data type %s", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (!CheckShapeValid(args.src_shape, kNhwcDimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", | |||||
| ShapeToString(args.src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", | |||||
| ShapeToString(args.src_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| if (!CheckShapeValid(args.dst_shape, kNc1hwc0DimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", | |||||
| ShapeToString(args.dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Dst shape %s check valid", | |||||
| ShapeToString(args.dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| std::vector<int64_t> expect_dst_shape; | |||||
| auto ret = TransShapeNhwcToNc1hwc0(args.src_shape, args.src_data_type, expect_dst_shape); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (args.dst_shape != expect_dst_shape) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, | |||||
| "[Trans][Format]Failed , the src shape %s and dst shape %s are not compatible. " | |||||
| "expect dst shape %s", | |||||
| ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(expect_dst_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to trans format, the src shape %s and " | |||||
| "dst shape %s are not compatible. expect dst shape %s", | |||||
| ShapeToString(args.src_shape).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
| ShapeToString(expect_dst_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||||
| if (dst == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Allcoate][Memory]Failed, memory for dst buf %ld, " | |||||
| "shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to alloc the memory for dst buf %ld, " | |||||
| "shape %s when trans format from %s to %s", | |||||
| total_size, ShapeToString(args.dst_shape).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| auto n = args.src_shape.at(kNhwcN); | |||||
| auto h = args.src_shape.at(kNhwcH); | |||||
| auto w = args.src_shape.at(kNhwcW); | |||||
| auto c = args.src_shape.at(kNhwcC); | |||||
| auto c1 = args.dst_shape.at(kNc1hwc0C1); | |||||
| auto c0 = args.dst_shape.at(kNc1hwc0C0); | |||||
| int64_t wc = w * c; | |||||
| int64_t hwc = h * wc; | |||||
| int64_t wc0 = w * c0; | |||||
| int64_t hwc0 = h * wc0; | |||||
| int64_t c1hwc0 = c1 * hwc0; | |||||
| for (int64_t n_idx = 0; n_idx < n; n_idx++) { | |||||
| int64_t n_head_addr = n_idx * c1hwc0; | |||||
| for (int64_t c1_idx = 0; c1_idx < c1; c1_idx++) { | |||||
| int64_t c1_head_addr = n_head_addr + c1_idx * hwc0; | |||||
| for (int64_t h_idx = 0; h_idx < h; h_idx++) { | |||||
| int64_t h_head_addr = c1_head_addr + h_idx * wc0; | |||||
| for (int64_t w_idx = 0; w_idx < w; w_idx++) { | |||||
| int64_t w_head_addr = h_head_addr + w_idx * c0; | |||||
| for (int64_t c0_idx = 0; c0_idx < c0; c0_idx++) { | |||||
| int64_t dst_idx = c0_idx + w_head_addr; | |||||
| int64_t dst_offset = dst_idx * size; | |||||
| auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? total_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| int64_t c_idx = c0_idx + c1_idx * c0; | |||||
| int64_t src_idx = n_idx * hwc + h_idx * wc + w_idx * c + c_idx; | |||||
| auto src_offset = src_idx * size; | |||||
| if (c_idx < c) { | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, protected_size, args.data + src_offset, size); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
| "[Operate][Memory]Failed to copy data from NHWC[%ld, %ld, %ld, %ld] " | |||||
| "offset %ld to NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld err-code %d", | |||||
| n_idx, h_idx, w_idx, c_idx, src_offset, | |||||
| n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to copy data from NHWC[%ld, %ld, %ld, %ld] " | |||||
| "offset %ld to " | |||||
| "NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld err-code %d", | |||||
| n_idx, h_idx, w_idx, c_idx, src_offset, | |||||
| n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } else { | |||||
| auto ret = memset_s(dst.get() + dst_offset, protected_size, 0, size); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
| "[Operate][Memory]Failed to set 0 to " | |||||
| "NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld base err-code %d", | |||||
| n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to set 0 to " | |||||
| "NC1HWC0[%ld, %ld, %ld, %ld, %ld] offset %ld base err-code %d", | |||||
| n_idx, c1_idx, h_idx, w_idx, c0_idx, dst_offset, ret); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status FormatTransferNhwcNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| Status ret = CheckArgsForNhwcToNc1hwc0(args); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||||
| if (total_size <= 0) { | |||||
| int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
| if (total_size == 0 && src_size == 0) { | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Get][ShapeSize]Failed, " | |||||
| "total size %ld from dst shape %s, src shape %s", total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "[Get][Shape]Failed, total size %ld from " | |||||
| "dst shape %s, src shape %s", total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| GELOGD("Begin to trans format from NHWC to NC1HWC0, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||||
| ret = GetDstDataAfterTrans(args, result, size, total_size); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Get][Data]Failed, after trans, src shape %s, data type %s, " | |||||
| "dst shape %s, memory size %ld, error_code %u", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size, ret); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to get data after trans, src shape %s, data type %s, " | |||||
| "dst shape %s, memory size %ld, error_code %u", | |||||
| ShapeToString(args.src_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size, ret); | |||||
| return ret; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FormatTransferNhwcNc1hwc0::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | |||||
| DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | |||||
| if (src_format == FORMAT_NHWC && CheckDataTypeSupported(data_type)) { | |||||
| if (!CheckShapeValid(src_shape, kNhwcDimsNum)) { | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", | |||||
| ShapeToString(src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", | |||||
| ShapeToString(src_shape).c_str()); | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return TransShapeNhwcToNc1hwc0(src_shape, data_type, dst_shape); | |||||
| } else if (src_format != FORMAT_NHWC) { | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } else { | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| } | |||||
// Associates FormatTransferNhwcNc1hwc0 with the FORMAT_NHWC -> FORMAT_NC1HWC0 pair.
// NOTE(review): the macro is defined outside this file; presumably it adds the class to the
// format-transfer registry at static-initialization time -- confirm in register_format_transfer.h.
REGISTER_FORMAT_TRANSFER(FormatTransferNhwcNc1hwc0, FORMAT_NHWC, FORMAT_NC1HWC0)
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,35 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_NHWC_NC1HWC0_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_NHWC_NC1HWC0_H_ | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
// FormatTransfer implementation for the NHWC -> NC1HWC0 conversion.
class FormatTransferNhwcNc1hwc0 : public FormatTransfer {
 public:
  // Converts the tensor data described by `args` and stores the converted buffer and its
  // byte length in `result`. Returns SUCCESS or an error status.
  Status TransFormat(const TransArgs &args, TransResult &result) override;
  // Computes `dst_shape` from `src_shape` and `data_type`. Returns an error status when
  // `src_format` is not FORMAT_NHWC, the data type is unsupported, or the source shape is invalid.
  Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
                    std::vector<int64_t> &dst_shape) override;
};
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_NHWC_NC1HWC0_H_ | |||||
| @@ -1,273 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/format_transfers/format_transfer_transpose.h" | |||||
| #include <securec.h> | |||||
| #include <memory> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| namespace { | |||||
| std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{ | |||||
| {FORMAT_NCHW, | |||||
| {{FORMAT_NHWC, std::vector<int64_t>({kNchwN, kNchwH, kNchwW, kNchwC})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({kNchwH, kNchwW, kNchwC, kNchwN})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({kNchwC, kNchwH, kNchwW, kNchwN})}}}, | |||||
| {FORMAT_NHWC, | |||||
| {{FORMAT_NCHW, std::vector<int64_t>({kNhwcN, kNhwcC, kNhwcH, kNhwcW})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({kNhwcC, kNhwcH, kNhwcW, kNhwcN})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({kNhwcH, kNhwcW, kNhwcC, kNhwcN})}}}, | |||||
| {FORMAT_HWCN, | |||||
| {{FORMAT_NCHW, std::vector<int64_t>({kHwcnN, kHwcnC, kHwcnH, kHwcnW})}, | |||||
| {FORMAT_NHWC, std::vector<int64_t>({kHwcnN, kHwcnH, kHwcnW, kHwcnC})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({kHwcnC, kHwcnH, kHwcnW, kHwcnN})}}}, | |||||
| {FORMAT_CHWN, | |||||
| {{FORMAT_NCHW, std::vector<int64_t>({kChwnN, kChwnC, kChwnH, kChwnW})}, | |||||
| {FORMAT_NHWC, std::vector<int64_t>({kChwnN, kChwnH, kChwnW, kChwnC})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({kChwnH, kChwnW, kChwnC, kChwnN})}}}, | |||||
| }; | |||||
| bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | |||||
| if (src_shape.empty()) { | |||||
| std::string error = "Failed to transpose, empty src shape"; | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Trans][Shape]Failed, empty src shape"); | |||||
| return false; | |||||
| } | |||||
| for (auto dim : src_shape) { | |||||
| if (dim < 0) { | |||||
| std::string error = "Failed to transpose, negative dim in src shape " + FmtToStr(ShapeToString(src_shape)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| if (perm_arg.size() != src_shape.size()) { | |||||
| std::string error = "Failed to transpose, the size of src shape" + FmtToStr(src_shape.size()) + | |||||
| " and perm arg" + FmtToStr(perm_arg.size()) + " are different"; | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||||
| return false; | |||||
| } | |||||
| std::vector<int64_t> exists(perm_arg.size()); | |||||
| for (auto perm : perm_arg) { | |||||
| if (perm < 0 || static_cast<size_t>(perm) >= perm_arg.size() || ++exists[perm] > 1) { | |||||
| std::string error = "Failed to transpose, duplicated perm arg " + FmtToStr(perm) + | |||||
| ", perm arg " + FmtToStr(JoinToString(perm_arg)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool IsTransposeArgValid(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type, | |||||
| const std::vector<int64_t> &perm_arg) { | |||||
| if (src == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Trans][Param]Failed, the src is null"); | |||||
| return false; | |||||
| } | |||||
| if (GetSizeByDataType(src_data_type) < 0) { | |||||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "[Trans][Param]Failed, the data type %s is not support", | |||||
| TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to transpose, the data type %s is not support", | |||||
| TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
| return false; | |||||
| } | |||||
| return IsShapeArgValid(src_shape, perm_arg); | |||||
| } | |||||
// Returns, for each axis of `shape`, the number of elements spanned by one step along that axis:
// the last axis gets 1 and every earlier axis gets the product of all later dims.
// An empty `shape` yields an empty result.
std::vector<int64_t> GenHeads(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size());
  if (shape.empty()) {
    return strides;
  }

  size_t axis = shape.size() - 1;
  strides[axis] = 1;
  // Walk from the innermost axis outwards, accumulating the product of the dims already visited.
  while (axis > 0) {
    strides[axis - 1] = shape[axis] * strides[axis];
    --axis;
  }
  return strides;
}
// Returns the sum of offsets[i] * indexes[i] over every position of `indexes`.
// `offsets` must have at least as many entries as `indexes`.
int64_t GenOffset(const std::vector<int64_t> &offsets, const std::vector<int64_t> &indexes) {
  int64_t total = 0;
  size_t pos = 0;
  for (const int64_t index : indexes) {
    total += offsets[pos] * index;
    ++pos;
  }
  return total;
}
// Advances the multi-dimensional counter `indexes` by one position within `shape`.
//
// The last axis is incremented first; whenever an axis reaches its dim in `shape` it wraps to 0 and
// carries into the previous axis. Axis 0 never wraps, so it may end up equal to shape[0] after the
// final position. `shape` must have at least as many entries as `indexes`.
// An empty `indexes` is left untouched (previously `size() - 1` underflowed and indexed out of bounds).
void AddOne(const std::vector<int64_t> &shape, std::vector<int64_t> &indexes) {
  if (indexes.empty()) {
    return;
  }
  size_t i = indexes.size() - 1;
  ++indexes[i];
  // Propagate the carry towards axis 0 while the current axis has overflowed.
  while (i > 0 && indexes[i] >= shape[i]) {
    indexes[i] = 0;
    ++indexes[i - 1];
    --i;
  }
}
// Reorders `src_shape` by `perm_arg`: output position i receives src_shape[perm_arg[i]].
// The result has the rank of `src_shape`; positions not covered by `perm_arg` stay 0.
// Every entry of `perm_arg` must be a valid index into `src_shape`.
std::vector<int64_t> TransShapeByPerm(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) {
  std::vector<int64_t> permuted(src_shape.size());
  size_t out_pos = 0;
  for (const int64_t src_axis : perm_arg) {
    permuted[out_pos] = src_shape[src_axis];
    ++out_pos;
  }
  return permuted;
}
| } // namespace | |||||
| Status Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type, | |||||
| const std::vector<int64_t> &perm_arg, TransResult &result) { | |||||
| if (!IsTransposeArgValid(src, src_shape, src_data_type, perm_arg)) { | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | |||||
| auto dst_shape = TransShapeByPerm(src_shape, perm_arg); | |||||
| auto src_origin_ordered_heads = GenHeads(src_shape); | |||||
| auto src_heads = TransShapeByPerm(src_origin_ordered_heads, perm_arg); | |||||
| int64_t dst_ele_num = GetItemNumByShape(dst_shape); | |||||
| int64_t data_size = GetSizeByDataType(src_data_type); | |||||
| int64_t dst_size = data_size * dst_ele_num; | |||||
| GELOGD("Begin to transpose, src shape %s, perm arg %s, dst shape %s, data type %s", JoinToString(src_shape).c_str(), | |||||
| JoinToString(perm_arg).c_str(), JoinToString(dst_shape).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
| if (dst_ele_num == 0) { | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
| int64_t dst_index = 0; | |||||
| std::vector<int64_t> dst_indexes(dst_shape.size()); | |||||
| while (dst_index < dst_ele_num) { | |||||
| auto src_offset = GenOffset(src_heads, dst_indexes) * data_size; | |||||
| auto dst_offset_bytes = dst_index * data_size; | |||||
| auto protected_size = dst_size - dst_offset_bytes < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||||
| ? dst_size - dst_offset_bytes | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset_bytes, static_cast<size_t>(protected_size), src + src_offset, | |||||
| static_cast<size_t>(data_size)); | |||||
| if (ret != EOK) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, | |||||
| "[Operate][Memory]Failed to transpose, src shape %s, perm arg %s, dst shape %s, " | |||||
| "failed to write to dst offset %ld, current dim offset %s", | |||||
| ShapeToString(src_shape).c_str(), ShapeToString(perm_arg).c_str(), ShapeToString(dst_shape).c_str(), | |||||
| dst_offset_bytes, ShapeToString(dst_indexes).c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to transpose, src shape %s, perm arg %s, dst shape %s, " | |||||
| "failed to write to dst offset %ld, current dim offset %s", | |||||
| ShapeToString(src_shape).c_str(), ShapeToString(perm_arg).c_str(), | |||||
| ShapeToString(dst_shape).c_str(), | |||||
| dst_offset_bytes, ShapeToString(dst_indexes).c_str()); | |||||
| return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
| } | |||||
| AddOne(dst_shape, dst_indexes); | |||||
| ++dst_index; | |||||
| } | |||||
| result.data = dst; | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransposeWithShapeCheck(const uint8_t *data, const std::vector<int64_t> &src_shape, | |||||
| const std::vector<int64_t> &dst_shape, DataType src_data_type, | |||||
| const std::vector<int64_t> &perm_arg, TransResult &result) { | |||||
| if (!IsTransposeArgValid(data, src_shape, src_data_type, perm_arg)) { | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | |||||
| auto expected_shape = TransShapeByPerm(src_shape, perm_arg); | |||||
| if (dst_shape != expected_shape) { | |||||
| std::string error = "Failed to trans axis for perm_arg" + | |||||
| FmtToStr(ShapeToString(perm_arg)) + ", invalid dst shape" + | |||||
| FmtToStr(ShapeToString(dst_shape)) + ", expect" + FmtToStr(ShapeToString(expected_shape)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||||
| } | |||||
| return Transpose(data, src_shape, src_data_type, perm_arg, result); | |||||
| } | |||||
| Status GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> &perm) { | |||||
| auto dst_iter = perm_args.find(src_format); | |||||
| if (dst_iter == perm_args.end()) { | |||||
| std::string error = "Failed to trans shape, do not support transpose from format " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| auto iter = dst_iter->second.find(dst_format); | |||||
| if (iter == dst_iter->second.end()) { | |||||
| std::string error = "Failed to trans shape, do not support transpose from format " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| perm = iter->second; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| std::vector<int64_t> expected_shape; | |||||
| auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expected_shape); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (!IsTransShapeDstCorrect(args, expected_shape)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| return Transpose(args.data, args.src_shape, args.src_data_type, perm_args[args.src_format][args.dst_format], result); | |||||
| } | |||||
| Status FormatTransferTranspose::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | |||||
| Format dst_format, std::vector<int64_t> &dst_shape) { | |||||
| std::vector<int64_t> perm_arg; | |||||
| GE_CHK_STATUS_RET_NOLOG(GetPermByForamt(src_format, dst_format, perm_arg)); | |||||
| if (!IsShapeArgValid(src_shape, perm_arg)) { | |||||
| return ACL_ERROR_GE_SHAPE_INVALID; | |||||
| } | |||||
| dst_shape = TransShapeByPerm(src_shape, perm_arg); | |||||
| return SUCCESS; | |||||
| } | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_NCHW, FORMAT_NHWC) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_NCHW, FORMAT_HWCN) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_NCHW, FORMAT_CHWN) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_NHWC, FORMAT_NCHW) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_NHWC, FORMAT_CHWN) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_NHWC, FORMAT_HWCN) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_HWCN, FORMAT_NCHW) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_HWCN, FORMAT_NHWC) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_HWCN, FORMAT_CHWN) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_CHWN, FORMAT_NCHW) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_CHWN, FORMAT_NHWC) | |||||
| REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_CHWN, FORMAT_HWCN) | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,46 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_TRANSPOSE_H_ | |||||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_TRANSPOSE_H_ | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "register/register_format_transfer.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| Status Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type, | |||||
| const std::vector<int64_t> &perm_arg, TransResult &result); | |||||
| Status TransposeWithShapeCheck(const uint8_t *src, const std::vector<int64_t> &src_shape, | |||||
| const std::vector<int64_t> &dst_shape, DataType src_data_type, | |||||
| const std::vector<int64_t> &perm_arg, TransResult &result); | |||||
| Status GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> &perm); | |||||
| class FormatTransferTranspose : public FormatTransfer { | |||||
| public: | |||||
| Status TransFormat(const TransArgs &args, TransResult &result) override; | |||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) override; | |||||
| }; | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_TRANSPOSE_H_ | |||||
| @@ -1,106 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/formats.h" | |||||
| #include <securec.h> | |||||
| #include <cmath> | |||||
| #include <cstring> | |||||
| #include <functional> | |||||
| #include <sstream> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArgs &args, TransResult &result) { | |||||
| auto transfer = BuildFormatTransfer(args); | |||||
| if (transfer == nullptr) { | |||||
| std::string error = "Failed to trans data from format " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| auto src_shape_size = GetItemNumByShape(args.src_shape); | |||||
| if (args.data == nullptr && src_shape_size != 0) { | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Shape]Failed, input data is null " | |||||
| "or shape size not euqal to 0, src_shape %s", | |||||
| ShapeToString(args.src_shape).c_str()); | |||||
| REPORT_CALL_ERROR("E19999","Failed to check shape, input data is null " | |||||
| "or shape size not equal to 0, src_shape %s", | |||||
| ShapeToString(args.src_shape).c_str()); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | |||||
| return transfer->TransFormat(args, result); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransShape(Format src_format, | |||||
| const std::vector<int64_t> &src_shape, | |||||
| DataType data_type, | |||||
| Format dst_format, | |||||
| std::vector<int64_t> &dst_shape) { | |||||
| formats::TransArgs args; | |||||
| args.src_format = src_format; | |||||
| args.dst_format = dst_format; | |||||
| auto transfer = BuildFormatTransfer(args); | |||||
| if (transfer == nullptr) { | |||||
| std::string error = "Failed to trans data from format " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_FORMAT_INVALID; | |||||
| } | |||||
| return transfer->TransShape(src_format, src_shape, data_type, dst_format, dst_shape); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastArgs &args, TransResult &result) { | |||||
| auto transfer = BuildDataTypeTransfer(args); | |||||
| if (transfer == nullptr) { | |||||
| std::string error = "Failed to trans data from datatype " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)); | |||||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_DATATYPE_INVALID, error.c_str()); | |||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | |||||
| } | |||||
| if (args.data == nullptr && args.src_data_size != 0) { | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Param]Failed, input data is null " | |||||
| "or data size not equal to 0, src_data_size %ld", args.src_data_size); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | |||||
| return transfer->TransDataType(args, result); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool IsTransFormatSupport(const TransArgs &args) { | |||||
| return FormatTransferExists(args); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool IsTransDataTypeSupport(const CastArgs &args) { | |||||
| return DataTypeTransferExists(args); | |||||
| } | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| @@ -1,49 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_FORMATS_H_ | |||||
| #define GE_COMMON_FORMATS_FORMATS_H_ | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "common/formats/format_transfers/datatype_transfer.h" | |||||
| #include "register/register_format_transfer.h" | |||||
| #include "external/graph/types.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| /** | |||||
| * Convert the data format, and put the converted format and length in the result | |||||
| * @param args | |||||
| * @param result | |||||
| * @return | |||||
| */ | |||||
| Status TransFormat(const TransArgs &args, TransResult &result); | |||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | |||||
| Format dst_format, std::vector<int64_t> &dst_shape); | |||||
| Status TransDataType(const CastArgs &args, TransResult &result); | |||||
| bool IsTransFormatSupport(const TransArgs &args); | |||||
| bool IsTransDataTypeSupport(const CastArgs &args); | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_FORMATS_H_ | |||||
| @@ -1,105 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_ | |||||
| #define GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_ | |||||
| namespace ge { | |||||
| namespace formats { | |||||
static const int kCubeSize = 16;
static const int kNiSize = 16;
// Limit used when validating the total item count of a shape: 1024^4 (2^40).
// The factors are signed 64-bit literals so the product is never evaluated in unsigned long,
// which is only 32 bits wide on some platforms and would wrap to 0 there.
static const int64_t kShapeItemNumMAX = 1024LL * 1024LL * 1024LL * 1024LL;
// Dimension indices for NCHW; kNchwDimsNum equals the number of indices before it.
enum NchwDimIndex {
  kNchwN = 0,
  kNchwC = 1,
  kNchwH = 2,
  kNchwW = 3,
  kNchwDimsNum = 4
};
// Dimension indices for NHWC; kNhwcDimsNum equals the number of indices before it.
enum NhwcDimIndex {
  kNhwcN = 0,
  kNhwcH = 1,
  kNhwcW = 2,
  kNhwcC = 3,
  kNhwcDimsNum = 4
};
// Dimension indices for HWCN; kHwcnDimsNum equals the number of indices before it.
enum HwcnDimIndex {
  kHwcnH = 0,
  kHwcnW = 1,
  kHwcnC = 2,
  kHwcnN = 3,
  kHwcnDimsNum = 4
};
// Dimension indices for CHWN; kChwnDimsNum equals the number of indices before it.
enum ChwnDimIndex {
  kChwnC = 0,
  kChwnH = 1,
  kChwnW = 2,
  kChwnN = 3,
  kChwnDimsNum = 4
};
// Dimension indices for NC1HWC0; kNc1hwc0DimsNum equals the number of indices before it.
enum Nc1hwc0DimIndex {
  kNc1hwc0N = 0,
  kNc1hwc0C1 = 1,
  kNc1hwc0H = 2,
  kNc1hwc0W = 3,
  kNc1hwc0C0 = 4,
  kNc1hwc0DimsNum = 5
};
// Dimension indices for C1HWNCoC0; kC1hwncoc0DimsNum equals the number of indices before it.
enum C1hwncoc0DimIndex {
  kC1hwncoc0C1 = 0,
  kC1hwncoc0H = 1,
  kC1hwncoc0W = 2,
  kC1hwncoc0N = 3,
  kC1hwncoc0Co = 4,
  kC1hwncoc0C0 = 5,
  kC1hwncoc0DimsNum = 6
};
// Dimension indices for FracZ; kFracZDimsNum equals the number of indices before it.
enum FracZDimIndex {
  kFracZHWC1 = 0,
  kFracZN0 = 1,
  kFracZNi = 2,
  kFracZC0 = 3,
  kFracZDimsNum = 4
};
// Dimension indices for DHWCN; kDhwcnDimsNum equals the number of indices before it.
enum DhwcnDimIndex {
  kDhwcnD = 0,
  kDhwcnH = 1,
  kDhwcnW = 2,
  kDhwcnC = 3,
  kDhwcnN = 4,
  kDhwcnDimsNum = 5
};
// Dimension indices for DHWNC; kDhwncDimsNum equals the number of indices before it.
enum DhwncDimIndex {
  kDhwncD = 0,
  kDhwncH = 1,
  kDhwncW = 2,
  kDhwncN = 3,
  kDhwncC = 4,
  kDhwncDimsNum = 5
};
| } // namespace formats | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_ | |||||
| @@ -1,130 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include <cstdint> | |||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace formats { | |||||
| int64_t GetCubeSizeByDataType(DataType data_type) { | |||||
| // Current cube does not support 4 bytes and longer data | |||||
| auto size = GetSizeByDataType(data_type); | |||||
| if (size <= 0) { | |||||
| std::string error = "Failed to get cube size, the data type " + | |||||
| FmtToStr(TypeUtils::DataTypeToSerialString(data_type)) + " is invalid"; | |||||
| GE_WARNINGLOG_AND_ERRORMSG(error.c_str()); | |||||
| return -1; | |||||
| } else if (size == 1) { | |||||
| return kCubeSize * 2; // 32 bytes cube size | |||||
| } else { | |||||
| return kCubeSize; | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string ShapeToString(const GeShape &shape) { | |||||
| return ShapeToString(shape.GetDims()); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string ShapeToString(const std::vector<int64_t> &shape) { | |||||
| return JoinToString(shape); | |||||
| } | |||||
/**
 * Multiplies all dims of a shape together.
 * @param shape  dims to multiply
 * @return the product of the dims; 1 for an empty shape. No overflow check is performed.
 */
int64_t GetItemNumByShape(const std::vector<int64_t> &shape) {
  int64_t item_count = 1;
  for (auto it = shape.begin(); it != shape.end(); ++it) {
    item_count *= *it;
  }
  return item_count;
}
| bool CheckShapeValid(const std::vector<int64_t> &shape, const int64_t expect_dims) { | |||||
| if (expect_dims <= 0 || shape.size() != static_cast<size_t>(expect_dims)) { | |||||
| std::string error = "Invalid shape, dims num " + FmtToStr(shape.size()) + | |||||
| ", expect " + FmtToStr(expect_dims); | |||||
| GE_WARNINGLOG_AND_ERRORMSG(error.c_str()); | |||||
| return false; | |||||
| } | |||||
| return IsShapeValid(shape); | |||||
| } | |||||
| bool IsShapeValid(const std::vector<int64_t> &shape) { | |||||
| if (shape.empty()) { | |||||
| return false; | |||||
| } | |||||
| int64_t num = 1; | |||||
| for (auto dim : shape) { | |||||
| if (dim < 0) { | |||||
| std::string error = "Invalid negative dims in the shape " + FmtToStr(ShapeToString(shape)); | |||||
| GE_WARNINGLOG_AND_ERRORMSG(error.c_str()); | |||||
| return false; | |||||
| } | |||||
| if (dim != 0 && kShapeItemNumMAX / dim < num) { | |||||
| std::string error = "Shape overflow, the total count should be less than " + FmtToStr(kShapeItemNumMAX); | |||||
| GE_WARNINGLOG_AND_ERRORMSG(error.c_str()); | |||||
| return false; | |||||
| } | |||||
| num *= dim; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool IsShapeEqual(const GeShape &src, const GeShape &dst) { | |||||
| if (src.GetDims().size() != dst.GetDims().size()) { | |||||
| return false; | |||||
| } | |||||
| for (size_t i = 0; i < src.GetDims().size(); ++i) { | |||||
| if (src.GetDim(i) != dst.GetDim(i)) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool IsTransShapeSrcCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape) { | |||||
| if (args.src_shape != expect_shape) { | |||||
| std::string error = "Failed to trans format from" + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)) + ", invalid relationship between src shape " + | |||||
| FmtToStr(ShapeToString(args.src_shape)) + " and dst " + | |||||
| FmtToStr(ShapeToString(args.dst_shape)); | |||||
| GE_WARNINGLOG_AND_ERRORMSG(error.c_str()); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool IsTransShapeDstCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape) { | |||||
| if (!args.dst_shape.empty() && args.dst_shape != expect_shape) { | |||||
| std::string error = "Failed to trans format from " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
| FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)) + ", the dst shape" + | |||||
| FmtToStr(ShapeToString(args.dst_shape)) + " is invalid, expect" + | |||||
| FmtToStr(ShapeToString(expect_shape)); | |||||
| GE_WARNINGLOG_AND_ERRORMSG(error.c_str()); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace formats | |||||
| } // namespace ge | |||||