| @@ -16,8 +16,11 @@ endif() | |||||
| if(DEFINED ENV{D_PKG_SERVER}) | if(DEFINED ENV{D_PKG_SERVER}) | ||||
| set(GE_PB_PKG $ENV{D_PKG_SERVER}) | set(GE_PB_PKG $ENV{D_PKG_SERVER}) | ||||
| message("Download packages from PKG server") | |||||
| endif() | |||||
| message("Download packages from DPKG server") | |||||
| elseif(DEFINED ENV{MSLIBS_SERVER}) | |||||
| set(GE_PB_PKG "http://$ENV{MSLIBS_SERVER}:8081") | |||||
| message("Download packages from MSPKG server") | |||||
| endif () | |||||
| set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) | set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) | ||||
| set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) | set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) | ||||
| @@ -37,7 +40,7 @@ set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR} | |||||
| option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) | option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) | ||||
| if (ENABLE_OPEN_SRC) | if (ENABLE_OPEN_SRC) | ||||
| set(HI_PYTHON python3.7) | |||||
| set(HI_PYTHON python3) | |||||
| include(cmake/external_libs/protobuf_shared.cmake) | include(cmake/external_libs/protobuf_shared.cmake) | ||||
| include(cmake/external_libs/protobuf_static.cmake) | include(cmake/external_libs/protobuf_static.cmake) | ||||
| @@ -71,7 +74,7 @@ if (ENABLE_OPEN_SRC) | |||||
| set(STATIC_ACL_LIB ${GE_LIB_PATH}) | set(STATIC_ACL_LIB ${GE_LIB_PATH}) | ||||
| find_module(slog libslog.so ${GE_LIB_PATH}) | find_module(slog libslog.so ${GE_LIB_PATH}) | ||||
| find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) | find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) | ||||
| find_module(msprof libmsprof.so ${GE_LIB_PATH}) | |||||
| find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH}) | |||||
| find_module(hccl libhccl.so ${GE_LIB_PATH}) | find_module(hccl libhccl.so ${GE_LIB_PATH}) | ||||
| find_module(adump_server libadump_server.a ${GE_LIB_PATH}) | find_module(adump_server libadump_server.a ${GE_LIB_PATH}) | ||||
| find_module(runtime libruntime.so ${GE_LIB_PATH}) | find_module(runtime libruntime.so ${GE_LIB_PATH}) | ||||
| @@ -80,20 +83,19 @@ if (ENABLE_OPEN_SRC) | |||||
| find_module(error_manager liberror_manager.so ${GE_LIB_PATH}) | find_module(error_manager liberror_manager.so ${GE_LIB_PATH}) | ||||
| find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) | find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) | ||||
| find_module(error_manager_static liberror_manager.a ${GE_LIB_PATH}) | find_module(error_manager_static liberror_manager.a ${GE_LIB_PATH}) | ||||
| find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH}) | |||||
| find_module(msprofiler_fwk libmsprofiler_fwk.a ${GE_LIB_PATH}) | |||||
| #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | ||||
| else() | else() | ||||
| find_module(slog libslog.so ${ASCEND_ATC_DIR} ${ASCEND_DRIVER_COMMON_DIR}) | find_module(slog libslog.so ${ASCEND_ATC_DIR} ${ASCEND_DRIVER_COMMON_DIR}) | ||||
| find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR} ${ASCEND_RUNTIME_DIR}) | find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR} ${ASCEND_RUNTIME_DIR}) | ||||
| find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR} ${ASCEND_RUNTIME_DIR}) | find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR} ${ASCEND_RUNTIME_DIR}) | ||||
| if(PLATFORM STREQUAL "train") | if(PLATFORM STREQUAL "train") | ||||
| find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) | |||||
| find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) | find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) | ||||
| find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) | find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) | ||||
| find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) | find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) | ||||
| find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) | find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) | ||||
| find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) | find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) | ||||
| find_module(msprofiler libmsprofiler.a ${ASCEND_RUNTIME_DIR}) | |||||
| find_module(msprofiler_fwk libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) | |||||
| find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | ||||
| if(PRODUCT STREQUAL "flr3") | if(PRODUCT STREQUAL "flr3") | ||||
| message(FATAL_ERROR "This platform is not supported in train mode, build terminated") | message(FATAL_ERROR "This platform is not supported in train mode, build terminated") | ||||
| @@ -106,20 +108,17 @@ if (ENABLE_OPEN_SRC) | |||||
| find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | ||||
| find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) | find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) | ||||
| find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) | find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) | ||||
| #find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) | |||||
| #find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) | |||||
| if(PRODUCT STREQUAL "flr3") | if(PRODUCT STREQUAL "flr3") | ||||
| find_module(msprof libmsprof.so ${ASCEND_DRIVER_SHARE_DIR}) | |||||
| elseif(PRODUCT STREQUAL "flr1") | elseif(PRODUCT STREQUAL "flr1") | ||||
| find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | ||||
| find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) | |||||
| elseif(PRODUCT STREQUAL "flr2") | elseif(PRODUCT STREQUAL "flr2") | ||||
| # flr2 ascend_hal_stub limsprof ? | # flr2 ascend_hal_stub limsprof ? | ||||
| else() | else() | ||||
| find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) | find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) | ||||
| find_module(msprof libmsprof.so ${ASCEND_DRIVER_DIR}) | |||||
| endif() | endif() | ||||
| elseif(PLATFORM STREQUAL "all") | elseif(PLATFORM STREQUAL "all") | ||||
| find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) | |||||
| find_module(msprofiler libmsprofiler.a ${ASCEND_DRIVER_COMMON_DIR}) | |||||
| find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) | find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) | ||||
| find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) | find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) | ||||
| find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) | find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) | ||||
| @@ -127,14 +126,14 @@ if (ENABLE_OPEN_SRC) | |||||
| find_module(resource libresource.so ${ASCEND_ATC_DIR}) | find_module(resource libresource.so ${ASCEND_ATC_DIR}) | ||||
| find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | ||||
| find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) | find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) | ||||
| find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) | |||||
| find_module(msprofiler_fwk libmsprofiler_fwk.a ${ASCEND_ACL_DIR}) | |||||
| find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | ||||
| #find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) | #find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) | ||||
| else() | else() | ||||
| message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") | |||||
| message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") | |||||
| endif() | endif() | ||||
| if (ENABLE_GE_COV OR ENABLE_GE_UT) | |||||
| if (ENABLE_GE_COV OR ENABLE_GE_UT) | |||||
| add_subdirectory(tests) | add_subdirectory(tests) | ||||
| endif() | endif() | ||||
| @@ -23,6 +23,7 @@ ExternalProject_Add(gflags_build | |||||
| URL ${REQ_URL} | URL ${REQ_URL} | ||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | ||||
| #SOURCE_DIR ${GE_CODE_DIR}/../../third_party/gflags/src/gflags-2.2.2 | #SOURCE_DIR ${GE_CODE_DIR}/../../third_party/gflags/src/gflags-2.2.2 | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gflags_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR> | CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gflags_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR> | ||||
| BUILD_COMMAND $(MAKE) | BUILD_COMMAND $(MAKE) | ||||
| INSTALL_COMMAND $(MAKE) install | INSTALL_COMMAND $(MAKE) install | ||||
| @@ -10,7 +10,10 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | ||||
| endif() | endif() | ||||
| if (ENABLE_GITEE) | |||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/gtest/release-1.8.0.tar.gz") | |||||
| set(MD5 "") | |||||
| elseif (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.0.tar.gz") | set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.0.tar.gz") | ||||
| set(MD5 "") | set(MD5 "") | ||||
| else() | else() | ||||
| @@ -22,8 +25,9 @@ set (gtest_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack- | |||||
| set (gtest_CFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") | set (gtest_CFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") | ||||
| ExternalProject_Add(gtest_build | ExternalProject_Add(gtest_build | ||||
| URL ${REQ_URL} | URL ${REQ_URL} | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gtest_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gtest <SOURCE_DIR> | CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gtest_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gtest <SOURCE_DIR> | ||||
| -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON | |||||
| -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON | |||||
| BUILD_COMMAND $(MAKE) | BUILD_COMMAND $(MAKE) | ||||
| INSTALL_COMMAND $(MAKE) install | INSTALL_COMMAND $(MAKE) install | ||||
| EXCLUDE_FROM_ALL TRUE | EXCLUDE_FROM_ALL TRUE | ||||
| @@ -5,10 +5,14 @@ endif() | |||||
| include(ExternalProject) | include(ExternalProject) | ||||
| set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include) | set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include) | ||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip") | |||||
| set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7") | |||||
| set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include") | |||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/ge_nlohmann_json/include.zip") | |||||
| set(MD5 "0dc903888211db3a0f170304cd9f3a89") | |||||
| set(JSON_INCLUDE_DIR ${JSON_SRC_DIR}) | |||||
| #elseif (ENABLE_GITEE) | |||||
| # set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip") | |||||
| # set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7") | |||||
| #set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include") | |||||
| else() | else() | ||||
| set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip") | set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip") | ||||
| set(MD5 "0dc903888211db3a0f170304cd9f3a89") | set(MD5 "0dc903888211db3a0f170304cd9f3a89") | ||||
| @@ -18,6 +22,7 @@ ExternalProject_Add(json_build | |||||
| URL ${REQ_URL} | URL ${REQ_URL} | ||||
| #URL /home/txd/workspace/cloud_code/pkg/include.zip | #URL /home/txd/workspace/cloud_code/pkg/include.zip | ||||
| SOURCE_DIR ${JSON_SRC_DIR} | SOURCE_DIR ${JSON_SRC_DIR} | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND "" | CONFIGURE_COMMAND "" | ||||
| BUILD_COMMAND "" | BUILD_COMMAND "" | ||||
| INSTALL_COMMAND "" | INSTALL_COMMAND "" | ||||
| @@ -6,7 +6,10 @@ set(ONNX_PROTO_DIR ${CMAKE_BINARY_DIR}/onnx) | |||||
| set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) | set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) | ||||
| file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) | file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) | ||||
| if (ENABLE_GITEE) | |||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/onnx/onnx-1.6.0.tar.gz") | |||||
| set(MD5 "512f2779d6215d4a36f366b6b9acdf1e") | |||||
| elseif (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz") | set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz") | ||||
| set(MD5 "1bdbcecdd68ea8392630467646776e02") | set(MD5 "1bdbcecdd68ea8392630467646776e02") | ||||
| else() | else() | ||||
| @@ -19,6 +22,7 @@ ExternalProject_Add(onnx | |||||
| #URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz | #URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz | ||||
| #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 | #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 | ||||
| #SOURCE_DIR ${ONNX_SRC_DIR} | #SOURCE_DIR ${ONNX_SRC_DIR} | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND "" | CONFIGURE_COMMAND "" | ||||
| BUILD_COMMAND "" | BUILD_COMMAND "" | ||||
| #INSTALL_COMMAND "" | #INSTALL_COMMAND "" | ||||
| @@ -26,6 +26,7 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst | |||||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | ||||
| ExternalProject_Add(protobuf_build | ExternalProject_Add(protobuf_build | ||||
| URL ${REQ_URL} | URL ${REQ_URL} | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} | CONFIGURE_COMMAND ${CMAKE_COMMAND} | ||||
| -Dprotobuf_WITH_ZLIB=OFF | -Dprotobuf_WITH_ZLIB=OFF | ||||
| -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} | -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} | ||||
| @@ -27,6 +27,7 @@ ExternalProject_Add(protobuf_static_build | |||||
| URL ${REQ_URL} | URL ${REQ_URL} | ||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | ||||
| #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 | #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} | CONFIGURE_COMMAND ${CMAKE_COMMAND} | ||||
| -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} | -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} | ||||
| -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} | -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} | ||||
| @@ -1,115 +1,116 @@ | |||||
| if (HAVE_PROTOC) | |||||
| return() | |||||
| endif() | |||||
| include(ExternalProject) | |||||
| include(GNUInstallDirs) | |||||
| #set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output) | |||||
| if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) | |||||
| set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | |||||
| endif() | |||||
| if(GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz") | |||||
| else() | |||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "3d9e32700639618a4d2d342c99d4507a") | |||||
| endif () | |||||
| endif() | |||||
| set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") | |||||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | |||||
| ExternalProject_Add(protoc_build | |||||
| URL ${REQ_URL} | |||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||||
| #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake | |||||
| BUILD_COMMAND $(MAKE) | |||||
| INSTALL_COMMAND $(MAKE) install | |||||
| EXCLUDE_FROM_ALL TRUE | |||||
| ) | |||||
| set(PROTOC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protoc) | |||||
| set(protoc_EXECUTABLE ${PROTOC_PKG_DIR}/${CMAKE_INSTALL_BINDIR}/protoc) | |||||
| function(protobuf_generate comp c_var h_var) | |||||
| if(NOT ARGN) | |||||
| message(SEND_ERROR "Error: protobuf_generate() called without any proto files") | |||||
| return() | |||||
| endif() | |||||
| set(${c_var}) | |||||
| set(${h_var}) | |||||
| foreach(file ${ARGN}) | |||||
| get_filename_component(abs_file ${file} ABSOLUTE) | |||||
| get_filename_component(file_name ${file} NAME_WE) | |||||
| get_filename_component(file_dir ${abs_file} PATH) | |||||
| get_filename_component(parent_subdir ${file_dir} NAME) | |||||
| if("${parent_subdir}" STREQUAL "proto") | |||||
| set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto) | |||||
| else() | |||||
| set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${parent_subdir}) | |||||
| endif() | |||||
| list(APPEND ${c_var} "${proto_output_path}/${file_name}.pb.cc") | |||||
| list(APPEND ${h_var} "${proto_output_path}/${file_name}.pb.h") | |||||
| add_custom_command( | |||||
| OUTPUT "${proto_output_path}/${file_name}.pb.cc" "${proto_output_path}/${file_name}.pb.h" | |||||
| WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} | |||||
| COMMAND ${CMAKE_COMMAND} -E make_directory "${proto_output_path}" | |||||
| COMMAND ${protoc_EXECUTABLE} -I${file_dir} --cpp_out=${proto_output_path} ${abs_file} | |||||
| DEPENDS protoc_build ${abs_file} | |||||
| COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM ) | |||||
| endforeach() | |||||
| set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE) | |||||
| set(${c_var} ${${c_var}} PARENT_SCOPE) | |||||
| set(${h_var} ${${h_var}} PARENT_SCOPE) | |||||
| endfunction() | |||||
| function(protobuf_generate_py comp py_var) | |||||
| if(NOT ARGN) | |||||
| message(SEND_ERROR "Error: protobuf_generate_py() called without any proto files") | |||||
| return() | |||||
| endif() | |||||
| set(${py_var}) | |||||
| foreach(file ${ARGN}) | |||||
| get_filename_component(abs_file ${file} ABSOLUTE) | |||||
| get_filename_component(file_name ${file} NAME_WE) | |||||
| get_filename_component(file_dir ${abs_file} PATH) | |||||
| get_filename_component(parent_subdir ${file_dir} NAME) | |||||
| if("${parent_subdir}" STREQUAL "proto") | |||||
| set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto) | |||||
| else() | |||||
| set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${parent_subdir}) | |||||
| endif() | |||||
| list(APPEND ${py_var} "${proto_output_path}/${file_name}_pb2.py") | |||||
| add_custom_command( | |||||
| OUTPUT "${proto_output_path}/${file_name}_pb2.py" | |||||
| WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} | |||||
| COMMAND ${CMAKE_COMMAND} -E make_directory "${proto_output_path}" | |||||
| COMMAND ${protoc_EXECUTABLE} -I${file_dir} --python_out=${proto_output_path} ${abs_file} | |||||
| DEPENDS protoc_build ${abs_file} | |||||
| COMMENT "Running PYTHON protocol buffer compiler on ${file}" VERBATIM ) | |||||
| endforeach() | |||||
| set_source_files_properties(${${py_var}} PROPERTIES GENERATED TRUE) | |||||
| set(${py_var} ${${py_var}} PARENT_SCOPE) | |||||
| endfunction() | |||||
| #set(HAVE_PROTOC TRUE CACHE BOOL "protoc build add") | |||||
| set(HAVE_PROTOC TRUE) | |||||
| if (HAVE_PROTOC) | |||||
| return() | |||||
| endif() | |||||
| include(ExternalProject) | |||||
| include(GNUInstallDirs) | |||||
| #set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output) | |||||
| if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) | |||||
| set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | |||||
| endif() | |||||
| if(GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz") | |||||
| else() | |||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "3d9e32700639618a4d2d342c99d4507a") | |||||
| endif () | |||||
| endif() | |||||
| set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") | |||||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | |||||
| ExternalProject_Add(protoc_build | |||||
| URL ${REQ_URL} | |||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||||
| #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 | |||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake | |||||
| BUILD_COMMAND $(MAKE) | |||||
| INSTALL_COMMAND $(MAKE) install | |||||
| EXCLUDE_FROM_ALL TRUE | |||||
| ) | |||||
| set(PROTOC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protoc) | |||||
| set(protoc_EXECUTABLE ${PROTOC_PKG_DIR}/${CMAKE_INSTALL_BINDIR}/protoc) | |||||
| function(protobuf_generate comp c_var h_var) | |||||
| if(NOT ARGN) | |||||
| message(SEND_ERROR "Error: protobuf_generate() called without any proto files") | |||||
| return() | |||||
| endif() | |||||
| set(${c_var}) | |||||
| set(${h_var}) | |||||
| foreach(file ${ARGN}) | |||||
| get_filename_component(abs_file ${file} ABSOLUTE) | |||||
| get_filename_component(file_name ${file} NAME_WE) | |||||
| get_filename_component(file_dir ${abs_file} PATH) | |||||
| get_filename_component(parent_subdir ${file_dir} NAME) | |||||
| if("${parent_subdir}" STREQUAL "proto") | |||||
| set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto) | |||||
| else() | |||||
| set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${parent_subdir}) | |||||
| endif() | |||||
| list(APPEND ${c_var} "${proto_output_path}/${file_name}.pb.cc") | |||||
| list(APPEND ${h_var} "${proto_output_path}/${file_name}.pb.h") | |||||
| add_custom_command( | |||||
| OUTPUT "${proto_output_path}/${file_name}.pb.cc" "${proto_output_path}/${file_name}.pb.h" | |||||
| WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} | |||||
| COMMAND ${CMAKE_COMMAND} -E make_directory "${proto_output_path}" | |||||
| COMMAND ${protoc_EXECUTABLE} -I${file_dir} --cpp_out=${proto_output_path} ${abs_file} | |||||
| DEPENDS protoc_build ${abs_file} | |||||
| COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM ) | |||||
| endforeach() | |||||
| set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE) | |||||
| set(${c_var} ${${c_var}} PARENT_SCOPE) | |||||
| set(${h_var} ${${h_var}} PARENT_SCOPE) | |||||
| endfunction() | |||||
| function(protobuf_generate_py comp py_var) | |||||
| if(NOT ARGN) | |||||
| message(SEND_ERROR "Error: protobuf_generate_py() called without any proto files") | |||||
| return() | |||||
| endif() | |||||
| set(${py_var}) | |||||
| foreach(file ${ARGN}) | |||||
| get_filename_component(abs_file ${file} ABSOLUTE) | |||||
| get_filename_component(file_name ${file} NAME_WE) | |||||
| get_filename_component(file_dir ${abs_file} PATH) | |||||
| get_filename_component(parent_subdir ${file_dir} NAME) | |||||
| if("${parent_subdir}" STREQUAL "proto") | |||||
| set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto) | |||||
| else() | |||||
| set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${parent_subdir}) | |||||
| endif() | |||||
| list(APPEND ${py_var} "${proto_output_path}/${file_name}_pb2.py") | |||||
| add_custom_command( | |||||
| OUTPUT "${proto_output_path}/${file_name}_pb2.py" | |||||
| WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} | |||||
| COMMAND ${CMAKE_COMMAND} -E make_directory "${proto_output_path}" | |||||
| COMMAND ${protoc_EXECUTABLE} -I${file_dir} --python_out=${proto_output_path} ${abs_file} | |||||
| DEPENDS protoc_build ${abs_file} | |||||
| COMMENT "Running PYTHON protocol buffer compiler on ${file}" VERBATIM ) | |||||
| endforeach() | |||||
| set_source_files_properties(${${py_var}} PROPERTIES GENERATED TRUE) | |||||
| set(${py_var} ${${py_var}} PARENT_SCOPE) | |||||
| endfunction() | |||||
| #set(HAVE_PROTOC TRUE CACHE BOOL "protoc build add") | |||||
| set(HAVE_PROTOC TRUE) | |||||
| @@ -10,11 +10,20 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | ||||
| endif() | endif() | ||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/securec/v1.1.10.tar.gz") | |||||
| set(MD5 "") | |||||
| else() | |||||
| set(REQ_URL "https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz") | |||||
| set(MD5 "") | |||||
| endif () | |||||
| ExternalProject_Add(c_sec_build | ExternalProject_Add(c_sec_build | ||||
| URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz | |||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||||
| URL ${REQ_URL} | |||||
| #URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz | |||||
| #SOURCE_DIR ${GE_CODE_DIR}/../libc_sec | #SOURCE_DIR ${GE_CODE_DIR}/../libc_sec | ||||
| PATCH_COMMAND patch -p1 < ${GE_CODE_DIR}/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch | PATCH_COMMAND patch -p1 < ${GE_CODE_DIR}/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} | CONFIGURE_COMMAND ${CMAKE_COMMAND} | ||||
| -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} | -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} | ||||
| -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} | -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} | ||||
| @@ -60,6 +60,8 @@ set(TRAIN_SRC_LIST | |||||
| "common/dump/dump_manager.cc" | "common/dump/dump_manager.cc" | ||||
| "common/dump/dump_properties.cc" | "common/dump/dump_properties.cc" | ||||
| "common/dump/dump_op.cc" | "common/dump/dump_op.cc" | ||||
| "common/profiling/ge_profiling.cc" | |||||
| "common/profiling/ge_runner_profiling.cc" | |||||
| "engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
| "ge_local_engine/engine/host_cpu_engine.cc" | "ge_local_engine/engine/host_cpu_engine.cc" | ||||
| "generator/ge_generator.cc" | "generator/ge_generator.cc" | ||||
| @@ -201,6 +203,7 @@ set(TRAIN_SRC_LIST | |||||
| "host_kernels/sub_kernel.cc" | "host_kernels/sub_kernel.cc" | ||||
| "host_kernels/transdata_kernel.cc" | "host_kernels/transdata_kernel.cc" | ||||
| "host_kernels/unpack_kernel.cc" | "host_kernels/unpack_kernel.cc" | ||||
| "host_kernels/reformat_kernel.cc" | |||||
| "graph/passes/folding_pass.cc" | "graph/passes/folding_pass.cc" | ||||
| "graph/passes/get_original_format_pass.cc" | "graph/passes/get_original_format_pass.cc" | ||||
| "graph/passes/guarantee_const_pass.cc" | "graph/passes/guarantee_const_pass.cc" | ||||
| @@ -331,7 +334,6 @@ set(TRAIN_SRC_LIST | |||||
| "hybrid/hybrid_davinci_model.cc" | "hybrid/hybrid_davinci_model.cc" | ||||
| "executor/ge_executor.cc" | "executor/ge_executor.cc" | ||||
| "client/ge_api.cc" | "client/ge_api.cc" | ||||
| "client/ge_prof.cc" | |||||
| "analyzer/analyzer.cc" | "analyzer/analyzer.cc" | ||||
| "ir_build/ge_ir_build.cc" | "ir_build/ge_ir_build.cc" | ||||
| "ir_build/atc_ir_common.cc" | "ir_build/atc_ir_common.cc" | ||||
| @@ -487,6 +489,7 @@ set(INFER_SRC_LIST | |||||
| "host_kernels/slice_d_kernel.cc" | "host_kernels/slice_d_kernel.cc" | ||||
| "host_kernels/dynamic_stitch_kernel.cc" | "host_kernels/dynamic_stitch_kernel.cc" | ||||
| "host_kernels/identity_kernel.cc" | "host_kernels/identity_kernel.cc" | ||||
| "host_kernels/reformat_kernel.cc" | |||||
| "graph/passes/stop_gradient_pass.cc" | "graph/passes/stop_gradient_pass.cc" | ||||
| "graph/passes/prevent_gradient_pass.cc" | "graph/passes/prevent_gradient_pass.cc" | ||||
| "graph/passes/identity_pass.cc" | "graph/passes/identity_pass.cc" | ||||
| @@ -602,7 +605,7 @@ set(INFER_SRC_LIST | |||||
| if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) | if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) | ||||
| ############ libge_runner.so ############ | ############ libge_runner.so ############ | ||||
| add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) | |||||
| add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS} $<TARGET_OBJECTS:msprofiler_fwk>) | |||||
| target_compile_definitions(ge_runner PRIVATE | target_compile_definitions(ge_runner PRIVATE | ||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | PROTOBUF_INLINE_NOT_IN_HEADERS=0 | ||||
| @@ -647,7 +650,6 @@ target_link_libraries(ge_runner | |||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ge_memory | ge_memory | ||||
| adump_server | adump_server | ||||
| msprofiler | |||||
| static_mmpa | static_mmpa | ||||
| -Wl,--no-as-needed | -Wl,--no-as-needed | ||||
| graph | graph | ||||
| @@ -656,7 +658,6 @@ target_link_libraries(ge_runner | |||||
| register | register | ||||
| c_sec | c_sec | ||||
| slog | slog | ||||
| msprof | |||||
| runtime | runtime | ||||
| resource | resource | ||||
| error_manager | error_manager | ||||
| @@ -781,7 +782,6 @@ target_link_libraries(opensrc_ascendcl PRIVATE | |||||
| c_sec | c_sec | ||||
| runtime | runtime | ||||
| slog | slog | ||||
| msprof | |||||
| ascend_hal_stub | ascend_hal_stub | ||||
| -Wl,--as-needed | -Wl,--as-needed | ||||
| -lrt | -lrt | ||||
| @@ -797,12 +797,10 @@ set_target_properties(opensrc_ascendcl PROPERTIES | |||||
| add_custom_command( | add_custom_command( | ||||
| OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc | OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc | ||||
| ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc | ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc | ||||
| ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_prof.cc | |||||
| COMMAND echo "Generating stub files." | COMMAND echo "Generating stub files." | ||||
| && ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/stub/gen_stubapi.py ${GE_CODE_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} | && ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/stub/gen_stubapi.py ${GE_CODE_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} | ||||
| && mv ge_ir_build.cc stub_ge_ir_build.cc | && mv ge_ir_build.cc stub_ge_ir_build.cc | ||||
| && mv ge_api.cc stub_ge_api.cc | && mv ge_api.cc stub_ge_api.cc | ||||
| && mv ge_prof.cc stub_ge_prof.cc | |||||
| && echo "Generating stub files end." | && echo "Generating stub files end." | ||||
| #WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} | #WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} | ||||
| #DEPENDS stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} | #DEPENDS stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} | ||||
| @@ -811,7 +809,6 @@ add_custom_command( | |||||
| add_custom_target(ge_stub | add_custom_target(ge_stub | ||||
| DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc | DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc | ||||
| ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc | ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc | ||||
| ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_prof.cc | |||||
| ) | ) | ||||
| ################################################################## | ################################################################## | ||||
| @@ -853,7 +850,6 @@ target_include_directories(atc_stub_ge_compiler PRIVATE | |||||
| ############ stub/libge_runner.so ############ | ############ stub/libge_runner.so ############ | ||||
| add_library(fwk_stub_ge_runner SHARED | add_library(fwk_stub_ge_runner SHARED | ||||
| stub_ge_api.cc | stub_ge_api.cc | ||||
| stub_ge_prof.cc | |||||
| stub_ge_ir_build.cc | stub_ge_ir_build.cc | ||||
| ) | ) | ||||
| @@ -134,7 +134,7 @@ Status GEInitialize(const std::map<string, string> &options) { | |||||
| Status GEInitialize(const std::map<AscendString, AscendString> &options) { | Status GEInitialize(const std::map<AscendString, AscendString> &options) { | ||||
| std::map<std::string, std::string> str_options; | std::map<std::string, std::string> str_options; | ||||
| for (auto & option : options) { | |||||
| for (auto &option : options) { | |||||
| if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | ||||
| GELOGE(FAILED, "GEInitialize options is nullptr."); | GELOGE(FAILED, "GEInitialize options is nullptr."); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -1,369 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge/ge_prof.h" | |||||
| #include "ge/ge_api.h" | |||||
| #include "init/gelib.h" | |||||
| #include "common/debug/log.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "common/profiling/profiling_manager.h" | |||||
| #include "graph/load/graph_loader.h" | |||||
| #include "toolchain/prof_acl_api.h" | |||||
| using std::map; | |||||
| using std::string; | |||||
| using std::vector; | |||||
| namespace { | |||||
| const uint32_t kMaxDeviceNum = 64; | |||||
| const uint32_t kDeviceListIndex = 3; | |||||
| const std::string kProfilingInit = "prof_init"; | |||||
| const std::string kProfilingFinalize = "prof_finalize"; | |||||
| const std::string kProfilingStart = "prof_start"; | |||||
| const std::string kProfilingStop = "prof_stop"; | |||||
| const std::string kDeviceNums = "devNums"; | |||||
| const std::string kDeviceIdList = "devIdList"; | |||||
| const std::string kAicoreMetrics = "aicoreMetrics"; | |||||
| const std::map<ge::ProfilingAicoreMetrics, std::string> kProfAicoreMetricsToString = { | |||||
| {ge::kAicoreArithmaticThroughput, "AICORE_ARITHMATIC_THROUGHPUT"}, | |||||
| {ge::kAicorePipeline, "AICORE_PIPELINE"}, | |||||
| {ge::kAicoreSynchronization, "AICORE_SYNCHRONIZATION"}, | |||||
| {ge::kAicoreMemory, "AICORE_MEMORY"}, | |||||
| {ge::kAicoreInternalMemory, "AICORE_INTERNAL_MEMORY"}, | |||||
| {ge::kAicoreStall, "AICORE_STALL"}}; | |||||
| } // namespace | |||||
| static bool g_graph_prof_init_ = false; | |||||
| static std::mutex g_prof_mutex_; | |||||
| namespace ge { | |||||
| struct aclgrphProfConfig { | |||||
| ProfConfig config; | |||||
| }; | |||||
| Status aclgrphProfInit(const char *profiler_path, uint32_t length) { | |||||
| GELOGT(TRACE_INIT, "Graph prof init start"); | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge client is not initialized."); | |||||
| return FAILED; | |||||
| } | |||||
| std::lock_guard<std::mutex> lock(g_prof_mutex_); | |||||
| if (g_graph_prof_init_) { | |||||
| GELOGW("Multi graph profiling initializations."); | |||||
| return GE_PROF_MULTI_INIT; | |||||
| } | |||||
| Status ret = CheckPath(profiler_path, length); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Profiling config path is invalid."); | |||||
| return ret; | |||||
| } | |||||
| // if command mode is set, just return | |||||
| if (ProfilingManager::Instance().ProfilingOn()) { | |||||
| GELOGW("Graph prof init failed, cause profiling command pattern is running."); | |||||
| return GE_PROF_MODE_CONFLICT; | |||||
| } | |||||
| ret = ProfInit(profiler_path); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "ProfInit init fail"); | |||||
| return ret; | |||||
| } | |||||
| GraphLoader graph_loader; | |||||
| Command command; | |||||
| command.cmd_params.clear(); | |||||
| command.cmd_type = kProfilingInit; | |||||
| command.module_index = PROF_MODEL_LOAD; | |||||
| ret = graph_loader.CommandHandle(command); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Handle profiling command %s failed, config = %s", kProfilingInit.c_str(), profiler_path); | |||||
| return ret; | |||||
| } | |||||
| if (!g_graph_prof_init_) { | |||||
| g_graph_prof_init_ = true; | |||||
| GELOGI("Profiling init successfully."); | |||||
| } | |||||
| GELOGI("Successfully execute GraphProfInit."); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status aclgrphProfFinalize() { | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge client is not initialized."); | |||||
| return FAILED; | |||||
| } | |||||
| std::lock_guard<std::mutex> lock(g_prof_mutex_); | |||||
| // if command mode is set, just return | |||||
| if (ProfilingManager::Instance().ProfilingOn()) { | |||||
| GELOGW("Graph prof finalize failed, cause profiling command pattern is running."); | |||||
| return GE_PROF_MODE_CONFLICT; | |||||
| } | |||||
| if (!g_graph_prof_init_) { | |||||
| GELOGE(GE_PROF_NOT_INIT, "Graph not profiling initialize."); | |||||
| return GE_PROF_NOT_INIT; | |||||
| } | |||||
| GraphLoader graph_loader; | |||||
| Command command; | |||||
| command.cmd_params.clear(); | |||||
| command.cmd_type = kProfilingFinalize; | |||||
| Status ret = graph_loader.CommandHandle(command); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Handle profiling command %s failed.", kProfilingFinalize.c_str()); | |||||
| return ret; | |||||
| } | |||||
| ret = ProfFinalize(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Finalize profiling failed, result = %d", ret); | |||||
| } | |||||
| if (ret == SUCCESS) { | |||||
| g_graph_prof_init_ = false; | |||||
| GELOGI("Successfully execute GraphProfFinalize."); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| bool TransProfConfigToParam(const aclgrphProfConfig *profiler_config, vector<string> &prof_config_params) { | |||||
| prof_config_params.clear(); | |||||
| prof_config_params.emplace_back(kDeviceNums); | |||||
| prof_config_params.emplace_back(std::to_string(profiler_config->config.devNums)); | |||||
| prof_config_params.emplace_back(kDeviceIdList); | |||||
| std::string devID = ""; | |||||
| if (profiler_config->config.devNums == 0) { | |||||
| GELOGW("The device num is invalid."); | |||||
| return false; | |||||
| } | |||||
| for (uint32_t i = 0; i < profiler_config->config.devNums; i++) { | |||||
| devID.append(std::to_string(profiler_config->config.devIdList[i])); | |||||
| if (i != profiler_config->config.devNums - 1) { | |||||
| devID.append(","); | |||||
| } | |||||
| } | |||||
| prof_config_params.push_back(devID); | |||||
| prof_config_params.push_back(kAicoreMetrics); | |||||
| auto iter = | |||||
| kProfAicoreMetricsToString.find(static_cast<ProfilingAicoreMetrics>(profiler_config->config.aicoreMetrics)); | |||||
| if (iter == kProfAicoreMetricsToString.end()) { | |||||
| GELOGW("The prof aicore metrics is invalid."); | |||||
| return false; | |||||
| } | |||||
| prof_config_params.push_back(iter->second); | |||||
| return true; | |||||
| } | |||||
| bool isProfConfigValid(const uint32_t *deviceid_list, uint32_t device_nums) { | |||||
| if (deviceid_list == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "deviceIdList is nullptr"); | |||||
| return false; | |||||
| } | |||||
| if (device_nums == 0 || device_nums > kMaxDeviceNum) { | |||||
| GELOGE(PARAM_INVALID, "The device nums is invalid."); | |||||
| return false; | |||||
| } | |||||
| // real device num | |||||
| int32_t dev_count = 0; | |||||
| rtError_t rt_err = rtGetDeviceCount(&dev_count); | |||||
| if (rt_err != RT_ERROR_NONE) { | |||||
| GELOGE(INTERNAL_ERROR, "Get the Device count fail."); | |||||
| return false; | |||||
| } | |||||
| if (device_nums > static_cast<uint32_t>(dev_count)) { | |||||
| GELOGE(PARAM_INVALID, "Device num(%u) is not in range 1 ~ %d.", device_nums, dev_count); | |||||
| return false; | |||||
| } | |||||
| std::unordered_set<uint32_t> record; | |||||
| for (size_t i = 0; i < device_nums; ++i) { | |||||
| uint32_t dev_id = deviceid_list[i]; | |||||
| if (dev_id >= static_cast<uint32_t>(dev_count)) { | |||||
| GELOGE(PARAM_INVALID, "Device id %u is not in range 0 ~ %d(exclude %d)", dev_id, dev_count, dev_count); | |||||
| return false; | |||||
| } | |||||
| if (record.count(dev_id) > 0) { | |||||
| GELOGE(PARAM_INVALID, "Device id %u is duplicatedly set", dev_id); | |||||
| return false; | |||||
| } | |||||
| record.insert(dev_id); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| aclgrphProfConfig *aclgrphProfCreateConfig(uint32_t *deviceid_list, uint32_t device_nums, | |||||
| ProfilingAicoreMetrics aicore_metrics, ProfAicoreEvents *aicore_events, | |||||
| uint64_t data_type_config) { | |||||
| if (!isProfConfigValid(deviceid_list, device_nums)) { | |||||
| return nullptr; | |||||
| } | |||||
| aclgrphProfConfig *config = new (std::nothrow) aclgrphProfConfig(); | |||||
| if (config == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "new aclgrphProfConfig fail"); | |||||
| return nullptr; | |||||
| } | |||||
| config->config.devNums = device_nums; | |||||
| if (memcpy_s(config->config.devIdList, sizeof(config->config.devIdList), deviceid_list, | |||||
| device_nums * sizeof(uint32_t)) != EOK) { | |||||
| GELOGE(INTERNAL_ERROR, "copy devID failed. size = %u", device_nums); | |||||
| delete config; | |||||
| return nullptr; | |||||
| } | |||||
| config->config.aicoreMetrics = static_cast<ProfAicoreMetrics>(aicore_metrics); | |||||
| config->config.dataTypeConfig = data_type_config; | |||||
| GELOGI("Successfully create prof config."); | |||||
| return config; | |||||
| } | |||||
| Status aclgrphProfDestroyConfig(aclgrphProfConfig *profiler_config) { | |||||
| if (profiler_config == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "destroy profilerConfig failed, profilerConfig must not be nullptr"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| delete profiler_config; | |||||
| GELOGI("Successfully destroy prof config."); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status aclgrphProfStart(aclgrphProfConfig *profiler_config) { | |||||
| if (profiler_config == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "aclgrphProfConfig is invalid."); | |||||
| return FAILED; | |||||
| } | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge client is not initialized."); | |||||
| return FAILED; | |||||
| } | |||||
| std::lock_guard<std::mutex> lock(g_prof_mutex_); | |||||
| // if command mode is set, just return | |||||
| if (ProfilingManager::Instance().ProfilingOn()) { | |||||
| GELOGW("Graph prof finalize failed, cause profiling command pattern is running."); | |||||
| return GE_PROF_MODE_CONFLICT; | |||||
| } | |||||
| if (!g_graph_prof_init_) { | |||||
| GELOGE(GE_PROF_NOT_INIT, "Graph not profiling initialize."); | |||||
| return GE_PROF_NOT_INIT; | |||||
| } | |||||
| Status ret = ProfStartProfiling(&profiler_config->config); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Start profiling failed, prof result = %d", ret); | |||||
| return FAILED; | |||||
| } | |||||
| std::vector<string> prof_params; | |||||
| if (!TransProfConfigToParam(profiler_config, prof_params)) { | |||||
| GELOGE(PARAM_INVALID, "Transfer profilerConfig to string vector failed"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| GraphLoader graph_loader; | |||||
| Command command; | |||||
| command.cmd_params.clear(); | |||||
| command.cmd_type = kProfilingStart; | |||||
| command.cmd_params = prof_params; | |||||
| command.module_index = profiler_config->config.dataTypeConfig; | |||||
| GELOGI("Profiling will start, device nums:%s , deviceID:[%s], data type config: 0x%llx", prof_params[0].c_str(), | |||||
| prof_params[kDeviceListIndex].c_str(), command.module_index); | |||||
| ret = graph_loader.CommandHandle(command); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Handle profiling command failed"); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGI("Successfully execute GraphProfStartProfiling."); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status aclgrphProfStop(aclgrphProfConfig *profiler_config) { | |||||
| if (profiler_config == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "aclgrphProfConfig is invalid."); | |||||
| return FAILED; | |||||
| } | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge client is not initialized."); | |||||
| return FAILED; | |||||
| } | |||||
| std::lock_guard<std::mutex> lock(g_prof_mutex_); | |||||
| // if command mode is set, just return | |||||
| if (ProfilingManager::Instance().ProfilingOn()) { | |||||
| GELOGW("Graph prof finalize failed, cause profiling command pattern is running."); | |||||
| return GE_PROF_MODE_CONFLICT; | |||||
| } | |||||
| if (!g_graph_prof_init_) { | |||||
| GELOGE(GE_PROF_NOT_INIT, "Graph not profiling initialize."); | |||||
| return GE_PROF_NOT_INIT; | |||||
| } | |||||
| for (uint32_t i = 0; i < profiler_config->config.devNums; i++) { | |||||
| uint64_t data_type_config; | |||||
| Status status = ProfGetDataTypeConfig(profiler_config->config.devIdList[i], data_type_config); | |||||
| if (status != SUCCESS) { | |||||
| GELOGE(status, "Prof get data type config failed, prof result = %d", status); | |||||
| return status; | |||||
| } | |||||
| if (data_type_config != profiler_config->config.dataTypeConfig) { | |||||
| GELOGE(FAILED, "data type config verify failed"); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| std::vector<string> prof_params; | |||||
| if (!TransProfConfigToParam(profiler_config, prof_params)) { | |||||
| GELOGE(PARAM_INVALID, "Transfer profilerConfig to string vector failed"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| GraphLoader graph_loader; | |||||
| Command command; | |||||
| command.cmd_params.clear(); | |||||
| command.cmd_type = kProfilingStop; | |||||
| command.cmd_params = prof_params; | |||||
| command.module_index = profiler_config->config.dataTypeConfig; | |||||
| GELOGI("Profiling will stop, device nums:%s , deviceID:[%s], data type config: 0x%llx", prof_params[0].c_str(), | |||||
| prof_params[kDeviceListIndex].c_str(), command.module_index); | |||||
| Status ret = graph_loader.CommandHandle(command); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Handle profiling command failed"); | |||||
| return FAILED; | |||||
| } | |||||
| ret = ProfStopProfiling(&profiler_config->config); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Stop profiling failed, prof result = %d", ret); | |||||
| return ret; | |||||
| } | |||||
| GELOGI("Successfully execute GraphProfStopProfiling."); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -4,7 +4,6 @@ LOCAL_PATH := $(call my-dir) | |||||
| COMMON_LOCAL_SRC_FILES := \ | COMMON_LOCAL_SRC_FILES := \ | ||||
| proto/ge_api.proto \ | proto/ge_api.proto \ | ||||
| ge_api.cc \ | ge_api.cc \ | ||||
| ge_prof.cc \ | |||||
| COMMON_LOCAL_C_INCLUDES := \ | COMMON_LOCAL_C_INCLUDES := \ | ||||
| @@ -69,9 +68,9 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libgraph \ | libgraph \ | ||||
| libregister \ | libregister \ | ||||
| libge_compiler \ | libge_compiler \ | ||||
| libge_common \ | |||||
| libmsprof | |||||
| libge_common | |||||
| LOCAL_STATIC_LIBRARIES += libmsprofiler_fwk \ | |||||
| LOCAL_LDFLAGS := -lrt -ldl | LOCAL_LDFLAGS := -lrt -ldl | ||||
| @@ -104,8 +103,10 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libregister \ | libregister \ | ||||
| libruntime \ | libruntime \ | ||||
| libge_compiler \ | libge_compiler \ | ||||
| libge_common \ | |||||
| libmsprof | |||||
| libge_common | |||||
| LOCAL_STATIC_LIBRARIES += libmsprofiler_fwk \ | |||||
| LOCAL_LDFLAGS := -lrt -ldl | LOCAL_LDFLAGS := -lrt -ldl | ||||
| @@ -24,6 +24,7 @@ set(SRC_LIST | |||||
| "helper/om_file_helper.cc" | "helper/om_file_helper.cc" | ||||
| "helper/model_helper.cc" | "helper/model_helper.cc" | ||||
| "../model/ge_model.cc" | "../model/ge_model.cc" | ||||
| "../model/ge_root_model.cc" | |||||
| "auth/file_saver.cc" | "auth/file_saver.cc" | ||||
| "fp16_t.cc" | "fp16_t.cc" | ||||
| "math/fp16_math.cc" | "math/fp16_math.cc" | ||||
| @@ -54,8 +54,8 @@ Status FileSaver::OpenFile(int32_t &fd, const std::string &file_path) { | |||||
| Status FileSaver::WriteData(const void *data, uint32_t size, int32_t fd) { | Status FileSaver::WriteData(const void *data, uint32_t size, int32_t fd) { | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size == 0 || data == nullptr, return PARAM_INVALID); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size == 0 || data == nullptr, return PARAM_INVALID); | ||||
| mmSsize_t write_count; | mmSsize_t write_count; | ||||
| uint32_t size_2g = ((uint32_t) 0x1 << 31); | |||||
| uint32_t size_1g = ((uint32_t) 0x1 << 30); | |||||
| uint32_t size_2g = 2147483648; // 0x1 << 31 | |||||
| uint32_t size_1g = 1073741824; // 0x1 << 30 | |||||
| // Write data | // Write data | ||||
| if (size > size_2g) { | if (size > size_2g) { | ||||
| auto seek = reinterpret_cast<uint8_t *>(const_cast<void *>(data)); | auto seek = reinterpret_cast<uint8_t *>(const_cast<void *>(data)); | ||||
| @@ -258,6 +258,65 @@ FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, Mod | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status | |||||
| FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, | |||||
| vector<ModelPartitionTable *> &model_partition_tables, | |||||
| const vector<vector<ModelPartition>> &all_partition_datas) { | |||||
| file_header.is_encrypt = ModelEncryptType::UNENCRYPTED; | |||||
| const Status ret = SaveWithFileHeader(file_path, file_header, model_partition_tables, all_partition_datas); | |||||
| GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, FAILED, "save file failed, file_path:%s, file header len:%u.", | |||||
| file_path.c_str(), file_header.length); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, | |||||
| vector<ModelPartitionTable *> &model_partition_tables, | |||||
| const vector<vector<ModelPartition>> &all_partition_datas) { | |||||
| GE_CHK_BOOL_EXEC(model_partition_tables.size() == all_partition_datas.size(), | |||||
| return PARAM_INVALID, | |||||
| "model table size %zu does not match partition size %zu", | |||||
| model_partition_tables.size(), all_partition_datas.size()) | |||||
| for (size_t index = 0; index < model_partition_tables.size(); ++index) { | |||||
| auto &cur_partiton_data = all_partition_datas[index]; | |||||
| auto &cur_model_partition_table = *model_partition_tables[index]; | |||||
| GE_CHK_BOOL_RET_STATUS(!cur_partiton_data.empty() && cur_model_partition_table.num != 0 | |||||
| && cur_model_partition_table.num == cur_partiton_data.size(), FAILED, | |||||
| "Invalid param:partition data size is (%u), model_partition_table.num is (%zu).", | |||||
| cur_model_partition_table.num, cur_partiton_data.size()); | |||||
| } | |||||
| // Open file | |||||
| int32_t fd = 0; | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(OpenFile(fd, file_path) != SUCCESS, return FAILED); | |||||
| Status ret = SUCCESS; | |||||
| do { | |||||
| // Write file header | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| WriteData(static_cast<const void *>(&file_header), sizeof(ModelFileHeader), fd) != SUCCESS, ret = FAILED; | |||||
| break); | |||||
| for (size_t index = 0; index < model_partition_tables.size(); ++index) { | |||||
| // Write model partition table | |||||
| auto &cur_tabel = *model_partition_tables[index]; | |||||
| uint32_t table_size = static_cast<uint32_t>(SIZE_OF_MODEL_PARTITION_TABLE(cur_tabel)); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| WriteData(static_cast<const void *>(&cur_tabel), table_size, fd) != SUCCESS, ret = FAILED; break); | |||||
| // Write partition data | |||||
| auto &cur_partition_datas = all_partition_datas[index]; | |||||
| for (const auto &partition_data : cur_partition_datas) { | |||||
| GELOGI("GC:size[%zu]", partition_data.size); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| WriteData(static_cast<const void *>(partition_data.data), partition_data.size, fd) != SUCCESS, ret = FAILED; | |||||
| break); | |||||
| } | |||||
| } | |||||
| } while (0); | |||||
| // Close file | |||||
| GE_CHK_BOOL_RET_STATUS(mmClose(fd) == EN_OK, FAILED, "Close file failed."); | |||||
| return ret; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(const string &file_path, const void *data, | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(const string &file_path, const void *data, | ||||
| int len) { | int len) { | ||||
| if (data == nullptr || len <= 0) { | if (data == nullptr || len <= 0) { | ||||
| @@ -74,6 +74,10 @@ class FileSaver { | |||||
| ModelPartitionTable &model_partition_table, | ModelPartitionTable &model_partition_table, | ||||
| const std::vector<ModelPartition> &partition_datas); | const std::vector<ModelPartition> &partition_datas); | ||||
| static Status SaveToFile(const string &file_path, ModelFileHeader &file_header, | |||||
| vector<ModelPartitionTable *> &model_partition_tables, | |||||
| const vector<vector<ModelPartition>> &all_partition_datas); | |||||
| static Status SaveToBuffWithFileHeader(const ModelFileHeader &file_header, | static Status SaveToBuffWithFileHeader(const ModelFileHeader &file_header, | ||||
| ModelPartitionTable &model_partition_table, | ModelPartitionTable &model_partition_table, | ||||
| const std::vector<ModelPartition> &partitionDatas, | const std::vector<ModelPartition> &partitionDatas, | ||||
| @@ -108,6 +112,9 @@ class FileSaver { | |||||
| static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, | static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, | ||||
| ModelPartitionTable &model_partition_table, | ModelPartitionTable &model_partition_table, | ||||
| const std::vector<ModelPartition> &partition_datas); | const std::vector<ModelPartition> &partition_datas); | ||||
| static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, | |||||
| vector<ModelPartitionTable *> &model_partition_tables, | |||||
| const vector<vector<ModelPartition>> &all_partition_datas); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_COMMON_AUTH_FILE_SAVER_H_ | #endif // GE_COMMON_AUTH_FILE_SAVER_H_ | ||||
| @@ -25,32 +25,38 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const char* kBase64Chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" | |||||
| "abcdefghijklmnopqrstuvwxyz" | |||||
| "0123456789+/"; | |||||
| const char *kBase64Chars = | |||||
| "ABCDEFGHIJKLMNOPQRSTUVWXYZ" | |||||
| "abcdefghijklmnopqrstuvwxyz" | |||||
| "0123456789+/"; | |||||
| const char kEqualSymbol = '='; | const char kEqualSymbol = '='; | ||||
| const size_t kBase64CharsNum = 64; | const size_t kBase64CharsNum = 64; | ||||
| const size_t kThreeByteOneGroup = 3; | const size_t kThreeByteOneGroup = 3; | ||||
| const size_t kFourByteOneGroup = 4; | const size_t kFourByteOneGroup = 4; | ||||
| } | |||||
| const size_t kThreeByteOneGroupIndex0 = 0; | |||||
| const size_t kThreeByteOneGroupIndex1 = 1; | |||||
| const size_t kThreeByteOneGroupIndex2 = 2; | |||||
| const size_t kFourByteOneGroupIndex0 = 0; | |||||
| const size_t kFourByteOneGroupIndex1 = 1; | |||||
| const size_t kFourByteOneGroupIndex2 = 2; | |||||
| const size_t kFourByteOneGroupIndex3 = 3; | |||||
| } // namespace | |||||
| namespace base64 { | namespace base64 { | ||||
| static inline bool IsBase64Char(const char &c) { | |||||
| return (isalnum(c) || (c == '+') || (c == '/')); | |||||
| } | |||||
| static inline bool IsBase64Char(const char &c) { return (isalnum(c) || (c == '+') || (c == '/')); } | |||||
| static std::string EncodeToBase64(const std::string &raw_data) { | static std::string EncodeToBase64(const std::string &raw_data) { | ||||
| size_t encode_length = raw_data.size() / kThreeByteOneGroup * kFourByteOneGroup; | size_t encode_length = raw_data.size() / kThreeByteOneGroup * kFourByteOneGroup; | ||||
| encode_length += raw_data.size() % kThreeByteOneGroup == 0 ? 0 : kFourByteOneGroup; | encode_length += raw_data.size() % kThreeByteOneGroup == 0 ? 0 : kFourByteOneGroup; | ||||
| size_t raw_data_index = 0 ; | |||||
| size_t raw_data_index = 0; | |||||
| size_t encode_data_index = 0; | size_t encode_data_index = 0; | ||||
| std::string encode_data; | std::string encode_data; | ||||
| encode_data.resize(encode_length); | encode_data.resize(encode_length); | ||||
| for (; raw_data_index + kThreeByteOneGroup <= raw_data.size(); raw_data_index += kThreeByteOneGroup) { | for (; raw_data_index + kThreeByteOneGroup <= raw_data.size(); raw_data_index += kThreeByteOneGroup) { | ||||
| auto char_1 = static_cast<uint8_t>(raw_data[raw_data_index]); | auto char_1 = static_cast<uint8_t>(raw_data[raw_data_index]); | ||||
| auto char_2 = static_cast<uint8_t>(raw_data[raw_data_index + 1]); | |||||
| auto char_3 = static_cast<uint8_t>(raw_data[raw_data_index + 2]); | |||||
| auto char_2 = static_cast<uint8_t>(raw_data[raw_data_index + kThreeByteOneGroupIndex1]); | |||||
| auto char_3 = static_cast<uint8_t>(raw_data[raw_data_index + kThreeByteOneGroupIndex2]); | |||||
| encode_data[encode_data_index++] = kBase64Chars[char_1 >> 2u]; | encode_data[encode_data_index++] = kBase64Chars[char_1 >> 2u]; | ||||
| encode_data[encode_data_index++] = kBase64Chars[((char_1 << 4u) & 0x30) | (char_2 >> 4u)]; | encode_data[encode_data_index++] = kBase64Chars[((char_1 << 4u) & 0x30) | (char_2 >> 4u)]; | ||||
| encode_data[encode_data_index++] = kBase64Chars[((char_2 << 2u) & 0x3c) | (char_3 >> 6u)]; | encode_data[encode_data_index++] = kBase64Chars[((char_2 << 2u) & 0x3c) | (char_3 >> 6u)]; | ||||
| @@ -80,8 +86,7 @@ static std::string EncodeToBase64(const std::string &raw_data) { | |||||
| #pragma GCC diagnostic ignored "-Wunused-function" | #pragma GCC diagnostic ignored "-Wunused-function" | ||||
| static Status DecodeFromBase64(const std::string &base64_data, std::string &decode_data) { | static Status DecodeFromBase64(const std::string &base64_data, std::string &decode_data) { | ||||
| if (base64_data.size() % kFourByteOneGroup != 0) { | if (base64_data.size() % kFourByteOneGroup != 0) { | ||||
| GELOGE(PARAM_INVALID, "base64 data size must can be divided by 4, but given data size is %zu", | |||||
| base64_data.size()); | |||||
| GELOGE(PARAM_INVALID, "base64 data size must can be divided by 4, but given data size is %zu", base64_data.size()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| decode_data.clear(); | decode_data.clear(); | ||||
| @@ -92,10 +97,10 @@ static Status DecodeFromBase64(const std::string &base64_data, std::string &deco | |||||
| return static_cast<uint8_t>(std::distance(kBase64Chars, char_pos)) & 0xff; | return static_cast<uint8_t>(std::distance(kBase64Chars, char_pos)) & 0xff; | ||||
| }; | }; | ||||
| for (std::size_t input_data_index = 0; input_data_index < base64_data_len; input_data_index += 4) { | |||||
| for (std::size_t input_data_index = 0; input_data_index < base64_data_len; input_data_index += kFourByteOneGroup) { | |||||
| for (size_t i = 0; i < kFourByteOneGroup; ++i) { | for (size_t i = 0; i < kFourByteOneGroup; ++i) { | ||||
| if (base64_data[input_data_index + i] == kEqualSymbol && | if (base64_data[input_data_index + i] == kEqualSymbol && | ||||
| input_data_index >= base64_data_len - 4 && i > 1) { | |||||
| input_data_index >= base64_data_len - kFourByteOneGroup && i > 1) { | |||||
| byte_4[i] = kBase64CharsNum; | byte_4[i] = kBase64CharsNum; | ||||
| } else if (IsBase64Char(base64_data[input_data_index + i])) { | } else if (IsBase64Char(base64_data[input_data_index + i])) { | ||||
| byte_4[i] = FindCharInBase64Chars(base64_data[input_data_index + i]); | byte_4[i] = FindCharInBase64Chars(base64_data[input_data_index + i]); | ||||
| @@ -104,19 +109,23 @@ static Status DecodeFromBase64(const std::string &base64_data, std::string &deco | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| } | } | ||||
| decode_data += static_cast<char>((byte_4[0] << 2u) + ((byte_4[1] & 0x30) >> 4u)); | |||||
| if (byte_4[2] >= kBase64CharsNum){ | |||||
| decode_data += | |||||
| static_cast<char>((byte_4[kFourByteOneGroupIndex0] << 2u) + ((byte_4[kFourByteOneGroupIndex1] & 0x30) >> 4u)); | |||||
| if (byte_4[kFourByteOneGroupIndex2] >= kBase64CharsNum) { | |||||
| break; | break; | ||||
| } else if (byte_4[3] >= kBase64CharsNum) { | |||||
| decode_data += static_cast<char>(((byte_4[1] & 0x0f) << 4u) + ((byte_4[2] & 0x3c) >> 2u)); | |||||
| } else if (byte_4[kFourByteOneGroupIndex3] >= kBase64CharsNum) { | |||||
| decode_data += static_cast<char>(((byte_4[kFourByteOneGroupIndex1] & 0x0f) << 4u) + | |||||
| ((byte_4[kFourByteOneGroupIndex2] & 0x3c) >> 2u)); | |||||
| break; | break; | ||||
| } | } | ||||
| decode_data += static_cast<char>(((byte_4[1] & 0x0f) << 4u) + ((byte_4[2] & 0x3c) >> 2u)); | |||||
| decode_data += static_cast<char>(((byte_4[2] & 0x03) << 6u) + byte_4[3]); | |||||
| decode_data += static_cast<char>(((byte_4[kFourByteOneGroupIndex1] & 0x0f) << 4u) + | |||||
| ((byte_4[kFourByteOneGroupIndex2] & 0x3c) >> 2u)); | |||||
| decode_data += | |||||
| static_cast<char>(((byte_4[kFourByteOneGroupIndex2] & 0x03) << 6u) + byte_4[kFourByteOneGroupIndex3]); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| #pragma GCC diagnostic pop | #pragma GCC diagnostic pop | ||||
| } | |||||
| } // namespace base64 | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_COMMON_BASE64_H_ | #endif // GE_COMMON_BASE64_H_ | ||||
| @@ -139,7 +139,8 @@ int MemoryDumper::OpenFile(const char *filename) { | |||||
| GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
| -1 != path_split_pos, string prefix_path = std::string(filename).substr(0, path_split_pos); | -1 != path_split_pos, string prefix_path = std::string(filename).substr(0, path_split_pos); | ||||
| string last_path = std::string(filename).substr(path_split_pos, strlen(filename) - 1); | string last_path = std::string(filename).substr(path_split_pos, strlen(filename) - 1); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(prefix_path.length() >= MMPA_MAX_PATH, return kInvalidFd, "Prefix path is too long!"); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(prefix_path.length() >= MMPA_MAX_PATH, | |||||
| return kInvalidFd, "Prefix path is too long!"); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmRealPath(prefix_path.c_str(), tmp_path, MMPA_MAX_PATH) != EN_OK, return kInvalidFd, | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmRealPath(prefix_path.c_str(), tmp_path, MMPA_MAX_PATH) != EN_OK, return kInvalidFd, | ||||
| "Dir %s does not exit.", prefix_path.c_str()); | "Dir %s does not exit.", prefix_path.c_str()); | ||||
| real_path = std::string(tmp_path) + last_path;) | real_path = std::string(tmp_path) + last_path;) | ||||
| @@ -23,12 +23,30 @@ | |||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "framework/common/types.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace formats { | namespace formats { | ||||
| namespace { | namespace { | ||||
| const int kDimSize4D = 4; | const int kDimSize4D = 4; | ||||
| const size_t kSingleDim = 1; | |||||
| const size_t kNdDimIndexN = 0; | |||||
| const size_t kNdDimIndexH = 1; | |||||
| const size_t kNdDimIndexW = 2; | |||||
| const size_t kDimDValueBNdFNz = 2; // dim d-value between Nd and FractalZz | |||||
| const size_t kNdDimCountBackwardsW = 1; | |||||
| const size_t kNdDimCountBackwardsWH = 2; | |||||
| const size_t kFNzDimCountBackwardsW0 = 1; | |||||
| const size_t kFNzDimCountBackwardsW0H0 = 2; | |||||
| const size_t kFNzDimCountBackwardsW0H0H1 = 3; | |||||
| const size_t kFNzDimCountBackwardsW0H0H1W1 = 4; | |||||
| bool IsDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0; } | bool IsDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0; } | ||||
| using ShapeVector = std::vector<int64_t>; | using ShapeVector = std::vector<int64_t>; | ||||
| @@ -60,14 +78,14 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap | |||||
| auto w0 = GetCubeSizeByDataType(data_type); | auto w0 = GetCubeSizeByDataType(data_type); | ||||
| int64_t h0 = kCubeSize; | int64_t h0 = kCubeSize; | ||||
| switch (src_shape.size()) { | switch (src_shape.size()) { | ||||
| case 1: | |||||
| dst_shape.push_back(Ceil(src_shape[0], w0)); | |||||
| dst_shape.push_back(1); | |||||
| case kSingleDim: | |||||
| dst_shape.push_back(Ceil(src_shape[kNdDimIndexN], w0)); | |||||
| dst_shape.push_back(DIM_DEFAULT_VALUE); | |||||
| dst_shape.push_back(h0); | dst_shape.push_back(h0); | ||||
| dst_shape.push_back(w0); | dst_shape.push_back(w0); | ||||
| hw_shape.push_back(1); | |||||
| hw_shape.push_back(1); | |||||
| hw_shape.push_back(src_shape[0]); | |||||
| hw_shape.push_back(DIM_DEFAULT_VALUE); | |||||
| hw_shape.push_back(DIM_DEFAULT_VALUE); | |||||
| hw_shape.push_back(src_shape[kNdDimIndexN]); | |||||
| if (!IsShapeValid(dst_shape)) { | if (!IsShapeValid(dst_shape)) { | ||||
| GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| @@ -76,17 +94,17 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap | |||||
| default: | default: | ||||
| auto size = src_shape.size(); | auto size = src_shape.size(); | ||||
| int64_t times = 1; | int64_t times = 1; | ||||
| for (size_t i = 0; i != size - 2; i++) { | |||||
| for (size_t i = 0; i != size - kDimDValueBNdFNz; i++) { | |||||
| dst_shape.push_back(src_shape[i]); | dst_shape.push_back(src_shape[i]); | ||||
| times *= src_shape[i]; | times *= src_shape[i]; | ||||
| } | } | ||||
| dst_shape.push_back(Ceil(src_shape[size - 1], w0)); | |||||
| dst_shape.push_back(Ceil(src_shape[size - 2], h0)); | |||||
| dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsW], w0)); | |||||
| dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsWH], h0)); | |||||
| dst_shape.push_back(h0); | dst_shape.push_back(h0); | ||||
| dst_shape.push_back(w0); | dst_shape.push_back(w0); | ||||
| hw_shape.push_back(times); | hw_shape.push_back(times); | ||||
| hw_shape.push_back(src_shape[size - 2]); | |||||
| hw_shape.push_back(src_shape[size - 1]); | |||||
| hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]); | |||||
| hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]); | |||||
| if (!IsShapeValid(dst_shape)) { | if (!IsShapeValid(dst_shape)) { | ||||
| GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| @@ -128,16 +146,16 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||||
| } | } | ||||
| // src&dst_shape can be written as times*H*W & times*W1*H1*H0*W0, respectively. dst_shape_size >= kDimNum4D | // src&dst_shape can be written as times*H*W & times*W1*H1*H0*W0, respectively. dst_shape_size >= kDimNum4D | ||||
| auto times = hw_shape.at(0); | |||||
| auto h = hw_shape.at(1); | |||||
| auto w = hw_shape.at(2); | |||||
| auto times = hw_shape.at(kNdDimIndexN); | |||||
| auto h = hw_shape.at(kNdDimIndexH); | |||||
| auto w = hw_shape.at(kNdDimIndexW); | |||||
| auto hw = h * w; | auto hw = h * w; | ||||
| auto shape_size = args.dst_shape.size(); | auto shape_size = args.dst_shape.size(); | ||||
| auto w1 = args.dst_shape[shape_size - 4]; | |||||
| auto h1 = args.dst_shape[shape_size - 3]; | |||||
| auto h0 = args.dst_shape[shape_size - 2]; | |||||
| auto w0 = args.dst_shape[shape_size - 1]; | |||||
| auto w1 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0H1W1]; | |||||
| auto h1 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0H1]; | |||||
| auto h0 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0]; | |||||
| auto w0 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0]; | |||||
| auto h1h0 = h1 * h0; | auto h1h0 = h1 * h0; | ||||
| auto h1h0w0 = h1h0 * w0; | auto h1h0w0 = h1h0 * w0; | ||||
| auto w1h1h0w0 = w1 * h1h0w0; | auto w1h1h0w0 = w1 * h1h0w0; | ||||
| @@ -198,16 +216,16 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con | |||||
| return OUT_OF_MEMORY; | return OUT_OF_MEMORY; | ||||
| } | } | ||||
| auto times = dst_hw_shape.at(0); | |||||
| auto h = dst_hw_shape.at(1); | |||||
| auto w = dst_hw_shape.at(2); | |||||
| auto times = dst_hw_shape.at(kNdDimIndexN); | |||||
| auto h = dst_hw_shape.at(kNdDimIndexH); | |||||
| auto w = dst_hw_shape.at(kNdDimIndexW); | |||||
| auto hw = h * w; | auto hw = h * w; | ||||
| auto shape_size = args.src_shape.size(); | auto shape_size = args.src_shape.size(); | ||||
| auto w1 = args.src_shape[shape_size - 4]; | |||||
| auto h1 = args.src_shape[shape_size - 3]; | |||||
| auto h0 = args.src_shape[shape_size - 2]; | |||||
| auto w0 = args.src_shape[shape_size - 1]; | |||||
| auto w1 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0H1W1]; | |||||
| auto h1 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0H1]; | |||||
| auto h0 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0]; | |||||
| auto w0 = args.src_shape[shape_size - kFNzDimCountBackwardsW0]; | |||||
| auto h1h0 = h1 * h0; | auto h1h0 = h1 * h0; | ||||
| auto h1h0w0 = h1h0 * w0; | auto h1h0w0 = h1h0 * w0; | ||||
| auto w1h1h0w0 = w1 * h1h0w0; | auto w1h1h0w0 = w1 * h1h0w0; | ||||
| @@ -23,12 +23,29 @@ | |||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "framework/common/types.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace formats { | namespace formats { | ||||
| namespace { | namespace { | ||||
| const int kDimSize4D = 4; | const int kDimSize4D = 4; | ||||
| const size_t kSingleDim = 1; | |||||
| const size_t kNdDimIndexN = 0; | |||||
| const size_t kNdDimIndexH = 1; | |||||
| const size_t kNdDimIndexW = 2; | |||||
| const size_t kDimDValueBNdFZz = 2; // dim d-value between Nd and FractalZz | |||||
| const size_t kNdDimCountBackwardsW = 1; | |||||
| const size_t kNdDimCountBackwardsWH = 2; | |||||
| const size_t kFZzDimCountBackwardsW0 = 1; | |||||
| const size_t kFZzDimCountBackwardsW0H0 = 2; | |||||
| const size_t kFZzDimCountBackwardsW0H0W1 = 3; | |||||
| const size_t kFZzDimCountBackwardsW0H0W1H1 = 4; | |||||
| bool IsDataTypeSupport(DataType d_type) { return GetSizeByDataType(d_type) > 0; } | bool IsDataTypeSupport(DataType d_type) { return GetSizeByDataType(d_type) > 0; } | ||||
| using ShapeVector = std::vector<int64_t>; | using ShapeVector = std::vector<int64_t>; | ||||
| @@ -40,8 +57,8 @@ bool CheckShape(Format format, const ShapeVector &shape) { | |||||
| case FORMAT_NHWC: | case FORMAT_NHWC: | ||||
| return CheckShapeValid(shape, kDimSize4D); | return CheckShapeValid(shape, kDimSize4D); | ||||
| default: | default: | ||||
| std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | |||||
| " and FORMAT_FRACTAL_ZZ is not supported."; | |||||
| std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | |||||
| " and FORMAT_FRACTAL_ZZ is not supported."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -60,14 +77,14 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap | |||||
| auto w0 = GetCubeSizeByDataType(data_type); | auto w0 = GetCubeSizeByDataType(data_type); | ||||
| auto h0 = GetCubeSizeByDataType(data_type); | auto h0 = GetCubeSizeByDataType(data_type); | ||||
| switch (src_shape.size()) { | switch (src_shape.size()) { | ||||
| case 1: | |||||
| dst_shape.push_back(1); | |||||
| dst_shape.push_back(Ceil(src_shape[0], w0)); | |||||
| case kSingleDim: | |||||
| dst_shape.push_back(DIM_DEFAULT_VALUE); | |||||
| dst_shape.push_back(Ceil(src_shape[kNdDimIndexN], w0)); | |||||
| dst_shape.push_back(h0); | dst_shape.push_back(h0); | ||||
| dst_shape.push_back(w0); | dst_shape.push_back(w0); | ||||
| hw_shape.push_back(1); | |||||
| hw_shape.push_back(1); | |||||
| hw_shape.push_back(src_shape[0]); | |||||
| hw_shape.push_back(DIM_DEFAULT_VALUE); | |||||
| hw_shape.push_back(DIM_DEFAULT_VALUE); | |||||
| hw_shape.push_back(src_shape[kNdDimIndexN]); | |||||
| if (!IsShapeValid(dst_shape)) { | if (!IsShapeValid(dst_shape)) { | ||||
| GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| @@ -76,17 +93,17 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap | |||||
| default: | default: | ||||
| auto size = src_shape.size(); | auto size = src_shape.size(); | ||||
| int64_t times = 1; | int64_t times = 1; | ||||
| for (size_t i = 0; i != size - 2; i++) { | |||||
| for (size_t i = 0; i != size - kDimDValueBNdFZz; i++) { | |||||
| dst_shape.push_back(src_shape[i]); | dst_shape.push_back(src_shape[i]); | ||||
| times *= src_shape[i]; | times *= src_shape[i]; | ||||
| } | } | ||||
| dst_shape.push_back(Ceil(src_shape[size - 2], h0)); | |||||
| dst_shape.push_back(Ceil(src_shape[size - 1], w0)); | |||||
| dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsWH], h0)); | |||||
| dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsW], w0)); | |||||
| dst_shape.push_back(h0); | dst_shape.push_back(h0); | ||||
| dst_shape.push_back(w0); | dst_shape.push_back(w0); | ||||
| hw_shape.push_back(times); | hw_shape.push_back(times); | ||||
| hw_shape.push_back(src_shape[size - 2]); | |||||
| hw_shape.push_back(src_shape[size - 1]); | |||||
| hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]); | |||||
| hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]); | |||||
| if (!IsShapeValid(dst_shape)) { | if (!IsShapeValid(dst_shape)) { | ||||
| GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| @@ -127,16 +144,16 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||||
| return OUT_OF_MEMORY; | return OUT_OF_MEMORY; | ||||
| } | } | ||||
| // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D | // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D | ||||
| auto times = hw_shape.at(0); | |||||
| auto h = hw_shape.at(1); | |||||
| auto w = hw_shape.at(2); | |||||
| auto times = hw_shape.at(kNdDimIndexN); | |||||
| auto h = hw_shape.at(kNdDimIndexH); | |||||
| auto w = hw_shape.at(kNdDimIndexW); | |||||
| auto hw = h * w; | auto hw = h * w; | ||||
| auto shape_size = args.dst_shape.size(); | auto shape_size = args.dst_shape.size(); | ||||
| auto h1 = args.dst_shape[shape_size - 4]; | |||||
| auto w1 = args.dst_shape[shape_size - 3]; | |||||
| auto h0 = args.dst_shape[shape_size - 2]; | |||||
| auto w0 = args.dst_shape[shape_size - 1]; | |||||
| auto h1 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0W1H1]; | |||||
| auto w1 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0W1]; | |||||
| auto h0 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0]; | |||||
| auto w0 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0]; | |||||
| auto h0w0 = h0 * w0; | auto h0w0 = h0 * w0; | ||||
| auto w1h0w0 = w1 * h0w0; | auto w1h0w0 = w1 * h0w0; | ||||
| auto h1w1h0w0 = h1 * w1h0w0; | auto h1w1h0w0 = h1 * w1h0w0; | ||||
| @@ -155,8 +172,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||||
| auto src_offset = (src_h_head + w1_idx * w0) * size; | auto src_offset = (src_h_head + w1_idx * w0) * size; | ||||
| auto dst_offset = (h0_head + w1_idx * h0w0) * size; | auto dst_offset = (h0_head + w1_idx * h0w0) * size; | ||||
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
| static_cast<size_t>(size * w0)); | static_cast<size_t>(size * w0)); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| @@ -171,8 +188,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||||
| auto src_offset = (src_h_head + src_w_idx) * size; | auto src_offset = (src_h_head + src_w_idx) * size; | ||||
| auto dst_offset = (w0_head + w0_idx) * size; | auto dst_offset = (w0_head + w0_idx) * size; | ||||
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
| static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| @@ -205,16 +222,16 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con | |||||
| } | } | ||||
| // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D | // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D | ||||
| auto times = dst_hw_shape.at(0); | |||||
| auto h = dst_hw_shape.at(1); | |||||
| auto w = dst_hw_shape.at(2); | |||||
| auto times = dst_hw_shape.at(kNdDimIndexN); | |||||
| auto h = dst_hw_shape.at(kNdDimIndexH); | |||||
| auto w = dst_hw_shape.at(kNdDimIndexW); | |||||
| auto hw = h * w; | auto hw = h * w; | ||||
| auto shape_size = args.src_shape.size(); | auto shape_size = args.src_shape.size(); | ||||
| auto h1 = args.src_shape[shape_size - 4]; | |||||
| auto w1 = args.src_shape[shape_size - 3]; | |||||
| auto h0 = args.src_shape[shape_size - 2]; | |||||
| auto w0 = args.src_shape[shape_size - 1]; | |||||
| auto h1 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0W1H1]; | |||||
| auto w1 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0W1]; | |||||
| auto h0 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0]; | |||||
| auto w0 = args.src_shape[shape_size - kFZzDimCountBackwardsW0]; | |||||
| auto h0w0 = h0 * w0; | auto h0w0 = h0 * w0; | ||||
| auto w1h0w0 = w1 * h0w0; | auto w1h0w0 = w1 * h0w0; | ||||
| auto h1w1h0w0 = h1 * w1h0w0; | auto h1w1h0w0 = h1 * w1h0w0; | ||||
| @@ -233,8 +250,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con | |||||
| auto src_offset = (h0_head + w1_idx * h0w0) * size; | auto src_offset = (h0_head + w1_idx * h0w0) * size; | ||||
| auto dst_offset = (dst_h_head + w1_idx * w0) * size; | auto dst_offset = (dst_h_head + w1_idx * w0) * size; | ||||
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
| static_cast<size_t>(size * w0)); | static_cast<size_t>(size * w0)); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| @@ -249,8 +266,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con | |||||
| auto dst_w_idx = w1_head + w0_idx; | auto dst_w_idx = w1_head + w0_idx; | ||||
| auto dst_offset = (dst_h_head + dst_w_idx) * size; | auto dst_offset = (dst_h_head + dst_w_idx) * size; | ||||
| auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| ? dst_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
| static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| @@ -35,7 +35,6 @@ | |||||
| * Padding to (N, ceil(Z/16)*16) | * Padding to (N, ceil(Z/16)*16) | ||||
| * Last Step: View the (N, ceil(Z/16)*16) as 4D (N/16, 16, C/16, 16) and transpose to (C/16, N/16, 16, 16) | * Last Step: View the (N, ceil(Z/16)*16) as 4D (N/16, 16, C/16, 16) and transpose to (C/16, N/16, 16, 16) | ||||
| */ | */ | ||||
| namespace ge { | namespace ge { | ||||
| namespace formats { | namespace formats { | ||||
| namespace { | namespace { | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <securec.h> | #include <securec.h> | ||||
| #include <memory> | #include <memory> | ||||
| #include "common/formats/utils/formats_definitions.h" | |||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| @@ -29,21 +30,21 @@ namespace formats { | |||||
| namespace { | namespace { | ||||
| std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{ | std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{ | ||||
| {FORMAT_NCHW, | {FORMAT_NCHW, | ||||
| {{FORMAT_NHWC, std::vector<int64_t>({0, 2, 3, 1})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({2, 3, 1, 0})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({1, 2, 3, 0})}}}, | |||||
| {{FORMAT_NHWC, std::vector<int64_t>({kNchwN, kNchwH, kNchwW, kNchwC})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({kNchwH, kNchwW, kNchwC, kNchwN})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({kNchwC, kNchwH, kNchwW, kNchwN})}}}, | |||||
| {FORMAT_NHWC, | {FORMAT_NHWC, | ||||
| {{FORMAT_NCHW, std::vector<int64_t>({0, 3, 1, 2})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({3, 1, 2, 0})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({1, 2, 3, 0})}}}, | |||||
| {{FORMAT_NCHW, std::vector<int64_t>({kNhwcN, kNhwcC, kNhwcH, kNhwcW})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({kNhwcC, kNhwcH, kNhwcW, kNhwcN})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({kNhwcH, kNhwcW, kNhwcC, kNhwcN})}}}, | |||||
| {FORMAT_HWCN, | {FORMAT_HWCN, | ||||
| {{FORMAT_NCHW, std::vector<int64_t>({3, 2, 0, 1})}, | |||||
| {FORMAT_NHWC, std::vector<int64_t>({3, 0, 1, 2})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({2, 0, 1, 3})}}}, | |||||
| {{FORMAT_NCHW, std::vector<int64_t>({kHwcnN, kHwcnC, kHwcnH, kHwcnW})}, | |||||
| {FORMAT_NHWC, std::vector<int64_t>({kHwcnN, kHwcnH, kHwcnW, kHwcnC})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({kHwcnC, kHwcnH, kHwcnW, kHwcnN})}}}, | |||||
| {FORMAT_CHWN, | {FORMAT_CHWN, | ||||
| {{FORMAT_NCHW, std::vector<int64_t>({3, 0, 1, 2})}, | |||||
| {FORMAT_NHWC, std::vector<int64_t>({3, 1, 2, 0})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({1, 2, 0, 3})}}}, | |||||
| {{FORMAT_NCHW, std::vector<int64_t>({kChwnN, kChwnC, kChwnH, kChwnW})}, | |||||
| {FORMAT_NHWC, std::vector<int64_t>({kChwnN, kChwnH, kChwnW, kChwnC})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({kChwnH, kChwnW, kChwnC, kChwnN})}}}, | |||||
| }; | }; | ||||
| bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | ||||
| @@ -23,6 +23,7 @@ static const int kCubeSize = 16; | |||||
| static const int kNiSize = 16; | static const int kNiSize = 16; | ||||
| static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL; | static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL; | ||||
| enum NchwDimIndex { | enum NchwDimIndex { | ||||
| kNchwN, | kNchwN, | ||||
| kNchwC, | kNchwC, | ||||
| @@ -47,6 +48,14 @@ enum HwcnDimIndex { | |||||
| kHwcnDimsNum | kHwcnDimsNum | ||||
| }; | }; | ||||
| enum ChwnDimIndex { | |||||
| kChwnC, | |||||
| kChwnH, | |||||
| kChwnW, | |||||
| kChwnN, | |||||
| kChwnDimsNum | |||||
| }; | |||||
| enum Nc1hwc0DimIndex { | enum Nc1hwc0DimIndex { | ||||
| kNc1hwc0N, | kNc1hwc0N, | ||||
| kNc1hwc0C1, | kNc1hwc0C1, | ||||
| @@ -123,7 +123,10 @@ Status PluginManager::LoadSo(const string &path, const vector<string> &func_chec | |||||
| if (handle == nullptr) { | if (handle == nullptr) { | ||||
| const char *error = mmDlerror(); | const char *error = mmDlerror(); | ||||
| GE_IF_BOOL_EXEC(error == nullptr, error = ""); | GE_IF_BOOL_EXEC(error == nullptr, error = ""); | ||||
| GELOGE(GE_PLGMGR_PATH_INVALID, "Failed to dlopen %s!", error); | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, | |||||
| {"mmDlopen", "shared library path is " + FmtToStr(file_path_dlopen) + ". Errormessage" + FmtToStr(error)}); | |||||
| GELOGE(GE_PLGMGR_PATH_INVALID, "Failed to dlopen the shared library path[%s]. Errormessage[%s]!", | |||||
| file_path_dlopen.c_str(), error); | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -132,6 +135,9 @@ Status PluginManager::LoadSo(const string &path, const vector<string> &func_chec | |||||
| for (const auto &func_name : func_check_list) { | for (const auto &func_name : func_check_list) { | ||||
| auto real_fn = (void (*)())mmDlsym(handle, const_cast<char *>(func_name.c_str())); | auto real_fn = (void (*)())mmDlsym(handle, const_cast<char *>(func_name.c_str())); | ||||
| if (real_fn == nullptr) { | if (real_fn == nullptr) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, | |||||
| {"mmDlsym", FmtToStr(func_name) + " is skipped since function" + | |||||
| FmtToStr(func_name) + " is not existed!"}); | |||||
| GELOGE(GE_PLGMGR_PATH_INVALID, "%s is skipped since function %s is not existed!", func_name.c_str(), | GELOGE(GE_PLGMGR_PATH_INVALID, "%s is skipped since function %s is not existed!", func_name.c_str(), | ||||
| func_name.c_str()); | func_name.c_str()); | ||||
| is_valid = false; | is_valid = false; | ||||
| @@ -37,6 +37,8 @@ | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| const int kBaseInt = 10; | |||||
| std::map<string, string> TBEPluginManager::options_ = {}; | std::map<string, string> TBEPluginManager::options_ = {}; | ||||
| // Get Singleton Instance | // Get Singleton Instance | ||||
| @@ -155,7 +157,7 @@ void TBEPluginManager::GetCustomOpPath(std::string &customop_path) { | |||||
| domi::FrameworkType type = domi::TENSORFLOW; | domi::FrameworkType type = domi::TENSORFLOW; | ||||
| auto it = options_.find(FRAMEWORK_TYPE); | auto it = options_.find(FRAMEWORK_TYPE); | ||||
| if (it != options_.end()) { | if (it != options_.end()) { | ||||
| type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, 10)); | |||||
| type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, kBaseInt)); | |||||
| } | } | ||||
| fmk_type = ge::TypeUtils::FmkTypeToSerialString(type); | fmk_type = ge::TypeUtils::FmkTypeToSerialString(type); | ||||
| GELOGI("Framework type is %s.", fmk_type.c_str()); | GELOGI("Framework type is %s.", fmk_type.c_str()); | ||||
| @@ -7,6 +7,7 @@ GE_COMMON_LOCAL_SRC_FILES := \ | |||||
| helper/om_file_helper.cc \ | helper/om_file_helper.cc \ | ||||
| helper/model_helper.cc \ | helper/model_helper.cc \ | ||||
| ../model/ge_model.cc \ | ../model/ge_model.cc \ | ||||
| ../model/ge_root_model.cc \ | |||||
| auth/file_saver.cc \ | auth/file_saver.cc \ | ||||
| fp16_t.cc \ | fp16_t.cc \ | ||||
| math/fp16_math.cc \ | math/fp16_math.cc \ | ||||
| @@ -32,6 +32,7 @@ using domi::ModelTaskDef; | |||||
| namespace { | namespace { | ||||
| const int64_t kOriginalOmPartitionNum = 1; | const int64_t kOriginalOmPartitionNum = 1; | ||||
| const uint32_t kStatiOmFileModelNum = 1; | |||||
| } | } | ||||
| @@ -39,7 +40,7 @@ namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } | ||||
| Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, ModelPartitionType type, | Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, ModelPartitionType type, | ||||
| const uint8_t *data, size_t size) { | |||||
| const uint8_t *data, size_t size, size_t model_index) { | |||||
| if (size < 1 || size > UINT32_MAX) { | if (size < 1 || size > UINT32_MAX) { | ||||
| GELOGE(PARAM_INVALID, "Add model partition failed, partition size %zu invalid", size); | GELOGE(PARAM_INVALID, "Add model partition failed, partition size %zu invalid", size); | ||||
| if (size > UINT32_MAX) { | if (size > UINT32_MAX) { | ||||
| @@ -68,25 +69,16 @@ Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_fil | |||||
| partition_model.data = const_cast<uint8_t *>(data); | partition_model.data = const_cast<uint8_t *>(data); | ||||
| partition_model.size = static_cast<uint32_t>(size); | partition_model.size = static_cast<uint32_t>(size); | ||||
| partition_model.type = type; | partition_model.type = type; | ||||
| if (om_file_save_helper->AddPartition(partition_model) != SUCCESS) { | |||||
| if (om_file_save_helper->AddPartition(partition_model, model_index) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "Add model partition failed, partition size %zu", size); | GELOGE(PARAM_INVALID, "Add model partition failed, partition size %zu", size); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmModel(const GeModelPtr &ge_model, | |||||
| const SaveParam &save_param, | |||||
| const std::string &output_file, | |||||
| ModelBufferData& model) { | |||||
| if (output_file.empty()) { | |||||
| GELOGE(FAILED, "GraphBuilder SaveModel received invalid file name prefix"); | |||||
| return FAILED; | |||||
| } | |||||
| GE_IF_BOOL_EXEC(ge_model == nullptr, GELOGE(FAILED, "Ge_model is nullptr"); return FAILED); | |||||
| std::shared_ptr<OmFileSaveHelper> om_file_save_helper = ge::MakeShared<OmFileSaveHelper>(); | |||||
| GE_CHECK_NOTNULL(om_file_save_helper); | |||||
| Status ModelHelper::SaveModelDef(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, | |||||
| const GeModelPtr &ge_model, ge::Buffer &model_buffer, size_t model_index) { | |||||
| ModelPtr model_tmp = ge::MakeShared<ge::Model>(ge_model->GetName(), ge_model->GetPlatformVersion()); | ModelPtr model_tmp = ge::MakeShared<ge::Model>(ge_model->GetName(), ge_model->GetPlatformVersion()); | ||||
| if (model_tmp == nullptr) { | if (model_tmp == nullptr) { | ||||
| GELOGE(FAILED, "Create Model %s Ptr failed", ge_model->GetName().c_str()); | GELOGE(FAILED, "Create Model %s Ptr failed", ge_model->GetName().c_str()); | ||||
| @@ -96,16 +88,21 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod | |||||
| model_tmp->SetVersion(ge_model->GetVersion()); | model_tmp->SetVersion(ge_model->GetVersion()); | ||||
| model_tmp->SetAttr(ge_model->MutableAttrMap()); | model_tmp->SetAttr(ge_model->MutableAttrMap()); | ||||
| ge::Buffer model_buffer; | |||||
| (void)model_tmp->Save(model_buffer); | (void)model_tmp->Save(model_buffer); | ||||
| GELOGD("MODEL_DEF size is %zu", model_buffer.GetSize()); | GELOGD("MODEL_DEF size is %zu", model_buffer.GetSize()); | ||||
| if (model_buffer.GetSize() > 0) { | if (model_buffer.GetSize() > 0) { | ||||
| if (SaveModelPartition(om_file_save_helper, ModelPartitionType::MODEL_DEF, model_buffer.GetData(), | if (SaveModelPartition(om_file_save_helper, ModelPartitionType::MODEL_DEF, model_buffer.GetData(), | ||||
| model_buffer.GetSize()) != SUCCESS) { | |||||
| model_buffer.GetSize(), model_index) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "Add model graph partition failed"); | GELOGE(PARAM_INVALID, "Add model graph partition failed"); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelHelper::SaveModelWeights(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, | |||||
| const GeModelPtr &ge_model, size_t model_index) { | |||||
| auto ge_model_weight = ge_model->GetWeight(); | auto ge_model_weight = ge_model->GetWeight(); | ||||
| GELOGD("WEIGHTS_DATA size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); | GELOGD("WEIGHTS_DATA size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); | ||||
| // weight is not necessary | // weight is not necessary | ||||
| @@ -113,31 +110,43 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod | |||||
| GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, | GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, | ||||
| ModelPartitionType::WEIGHTS_DATA, | ModelPartitionType::WEIGHTS_DATA, | ||||
| ge_model_weight.GetData(), | ge_model_weight.GetData(), | ||||
| ge_model_weight.GetSize()), "Add weight partition failed"); | |||||
| ge_model_weight.GetSize(), model_index), "Add weight partition failed"); | |||||
| } | } | ||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelHelper::SaveModelTbeKernel(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, | |||||
| const GeModelPtr &ge_model, size_t model_index) { | |||||
| TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore(); | TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore(); | ||||
| GELOGD("TBE_KERNELS size is %zu", tbe_kernel_store.DataSize()); | GELOGD("TBE_KERNELS size is %zu", tbe_kernel_store.DataSize()); | ||||
| if (tbe_kernel_store.DataSize() > 0) { | if (tbe_kernel_store.DataSize() > 0) { | ||||
| GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, | |||||
| ModelPartitionType::TBE_KERNELS, | |||||
| tbe_kernel_store.Data(), | |||||
| tbe_kernel_store.DataSize()), "Add tbe kernel partition failed"); | |||||
| GE_CHK_STATUS_RET( | |||||
| SaveModelPartition(om_file_save_helper, ModelPartitionType::TBE_KERNELS, | |||||
| ge_model->GetTBEKernelStore().Data(), ge_model->GetTBEKernelStore().DataSize(), | |||||
| model_index), "Add tbe kernel partition failed"); | |||||
| } | } | ||||
| // no need to check value, DATA->NetOutput | // no need to check value, DATA->NetOutput | ||||
| (void)tbe_kernel_store.Load(tbe_kernel_store.Data(), tbe_kernel_store.DataSize()); | (void)tbe_kernel_store.Load(tbe_kernel_store.Data(), tbe_kernel_store.DataSize()); | ||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelHelper::SaveModelCustAICPU(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, | |||||
| const GeModelPtr &ge_model, size_t model_index) { | |||||
| CustAICPUKernelStore cust_aicpu_kernel_store = ge_model->GetCustAICPUKernelStore(); | CustAICPUKernelStore cust_aicpu_kernel_store = ge_model->GetCustAICPUKernelStore(); | ||||
| GELOGD("cust aicpu kernels size is %zu", cust_aicpu_kernel_store.DataSize()); | GELOGD("cust aicpu kernels size is %zu", cust_aicpu_kernel_store.DataSize()); | ||||
| if (cust_aicpu_kernel_store.DataSize() > 0) { | if (cust_aicpu_kernel_store.DataSize() > 0) { | ||||
| GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, | GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, | ||||
| ModelPartitionType::CUST_AICPU_KERNELS, | ModelPartitionType::CUST_AICPU_KERNELS, | ||||
| cust_aicpu_kernel_store.Data(), | |||||
| cust_aicpu_kernel_store.DataSize()), | |||||
| ge_model->GetCustAICPUKernelStore().Data(), | |||||
| cust_aicpu_kernel_store.DataSize(), model_index), | |||||
| "Add cust aicpu kernel partition failed"); | "Add cust aicpu kernel partition failed"); | ||||
| } | } | ||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelHelper::SaveModelTaskDef(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, | |||||
| const GeModelPtr &ge_model, ge::Buffer &task_buffer, size_t model_index) { | |||||
| std::shared_ptr<ModelTaskDef> model_task_def = ge_model->GetModelTaskDefPtr(); | std::shared_ptr<ModelTaskDef> model_task_def = ge_model->GetModelTaskDefPtr(); | ||||
| if (model_task_def == nullptr) { | if (model_task_def == nullptr) { | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Create model task def ptr failed"); | GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Create model task def ptr failed"); | ||||
| @@ -146,9 +155,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod | |||||
| size_t partition_task_size = model_task_def->ByteSizeLong(); | size_t partition_task_size = model_task_def->ByteSizeLong(); | ||||
| GE_IF_BOOL_EXEC(partition_task_size == 0 || partition_task_size > INT_MAX, | GE_IF_BOOL_EXEC(partition_task_size == 0 || partition_task_size > INT_MAX, | ||||
| GELOGE(FAILED, "Model_def's byte size (%zu) is invalid!", partition_task_size); | GELOGE(FAILED, "Model_def's byte size (%zu) is invalid!", partition_task_size); | ||||
| return FAILED); | |||||
| return FAILED); | |||||
| ge::Buffer task_buffer(partition_task_size); | |||||
| task_buffer = ge::Buffer(partition_task_size); | |||||
| if (task_buffer.GetSize() == 0) { | if (task_buffer.GetSize() == 0) { | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc model task def buffer failed"); | GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc model task def buffer failed"); | ||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | return ACL_ERROR_GE_MEMORY_ALLOCATION; | ||||
| @@ -159,21 +168,28 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod | |||||
| GELOGD("TASK_INFO size is %zu", partition_task_size); | GELOGD("TASK_INFO size is %zu", partition_task_size); | ||||
| if (SaveModelPartition(om_file_save_helper, ModelPartitionType::TASK_INFO, task_buffer.GetData(), | if (SaveModelPartition(om_file_save_helper, ModelPartitionType::TASK_INFO, task_buffer.GetData(), | ||||
| partition_task_size) != SUCCESS) { | |||||
| partition_task_size, model_index) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "Add model task def partition failed"); | GELOGE(PARAM_INVALID, "Add model task def partition failed"); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelHelper::SaveModelHeader(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, | |||||
| const GeModelPtr &ge_model, size_t model_num) { | |||||
| // Save target/version to model_header | // Save target/version to model_header | ||||
| ModelFileHeader &model_header = om_file_save_helper->GetModelFileHeader(); | ModelFileHeader &model_header = om_file_save_helper->GetModelFileHeader(); | ||||
| model_header.platform_type = ge_model->GetPlatformType(); | model_header.platform_type = ge_model->GetPlatformType(); | ||||
| model_header.om_ir_version = ge_model->GetVersion(); | model_header.om_ir_version = ge_model->GetVersion(); | ||||
| model_header.model_num = model_num; | |||||
| std::string platform_version = ge_model->GetPlatformVersion(); | std::string platform_version = ge_model->GetPlatformVersion(); | ||||
| errno_t err; | errno_t err; | ||||
| err = memcpy_s(model_header.platform_version, PLATFORM_VERSION_LEN, platform_version.c_str(), | err = memcpy_s(model_header.platform_version, PLATFORM_VERSION_LEN, platform_version.c_str(), | ||||
| platform_version.size() + 1); | platform_version.size() + 1); | ||||
| if (err != EOK) { | if (err != EOK) { | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "ModelHelper SaveModel failed while allocating memory for platform_version."); | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, | |||||
| "ModelHelper SaveModel failed while allocating memory for platform_version."); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | return ACL_ERROR_GE_MEMORY_ALLOCATION; | ||||
| } | } | ||||
| string version = reinterpret_cast<char *>(model_header.platform_version); | string version = reinterpret_cast<char *>(model_header.platform_version); | ||||
| @@ -188,8 +204,142 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod | |||||
| } | } | ||||
| string model_name = reinterpret_cast<char *>(model_header.name); | string model_name = reinterpret_cast<char *>(model_header.name); | ||||
| GELOGD("Model name save:%s", model_name.c_str()); | GELOGD("Model name save:%s", model_name.c_str()); | ||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelHelper::SaveAllModelPartiton(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, | |||||
| const GeModelPtr &ge_model, ge::Buffer &model_buffer, | |||||
| ge::Buffer &task_buffer, size_t model_index) { | |||||
| if (SaveModelDef(om_file_save_helper, ge_model, model_buffer, model_index) != SUCCESS) { | |||||
| GELOGE(FAILED, "save model def failed"); | |||||
| return FAILED; | |||||
| } | |||||
| if (SaveModelWeights(om_file_save_helper, ge_model, model_index) != SUCCESS) { | |||||
| GELOGE(FAILED, "save model weights failed"); | |||||
| return FAILED; | |||||
| } | |||||
| if (SaveModelTbeKernel(om_file_save_helper, ge_model, model_index) != SUCCESS) { | |||||
| GELOGE(FAILED, "save model tbe kernel failed"); | |||||
| return FAILED; | |||||
| } | |||||
| if (SaveModelCustAICPU(om_file_save_helper, ge_model, model_index) != SUCCESS) { | |||||
| GELOGE(FAILED, "save model cust ai cpu failed"); | |||||
| return FAILED; | |||||
| } | |||||
| if (SaveModelTaskDef(om_file_save_helper, ge_model, task_buffer, model_index) != SUCCESS) { | |||||
| GELOGE(FAILED, "save task def failed"); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmModel(const GeModelPtr &ge_model, | |||||
| const SaveParam &save_param, | |||||
| const std::string &output_file, | |||||
| ModelBufferData& model) { | |||||
| if (output_file.empty()) { | |||||
| GELOGE(FAILED, "GraphBuilder SaveModel received invalid file name prefix"); | |||||
| return FAILED; | |||||
| } | |||||
| Status ret = om_file_save_helper->SaveModel(save_param, output_file.c_str(), model, is_offline_); | |||||
| GE_IF_BOOL_EXEC(ge_model == nullptr, GELOGE(FAILED, "Ge_model is nullptr"); return FAILED); | |||||
| std::shared_ptr<OmFileSaveHelper> om_file_save_helper = ge::MakeShared<OmFileSaveHelper>(); | |||||
| GE_CHECK_NOTNULL(om_file_save_helper); | |||||
| ge::Buffer model_buffer; | |||||
| ge::Buffer task_buffer; | |||||
| auto ret = SaveAllModelPartiton(om_file_save_helper, ge_model, model_buffer, task_buffer); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "save all model partition failed"); | |||||
| return ret; | |||||
| } | |||||
| ret = SaveModelHeader(om_file_save_helper, ge_model); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "save model header failed"); | |||||
| return ret; | |||||
| } | |||||
| ret = om_file_save_helper->SaveModel(save_param, output_file.c_str(), model, is_offline_); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "OmFileSaveHelper SaveModel return fail."); | |||||
| return ret; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmRootModel( | |||||
| const GeRootModelPtr &ge_root_model, | |||||
| const SaveParam &save_param, | |||||
| const std::string &output_file, | |||||
| ModelBufferData& model, | |||||
| bool is_unknown_shape) { | |||||
| GE_CHECK_NOTNULL(ge_root_model); | |||||
| GE_IF_BOOL_EXEC(ge_root_model == nullptr, GELOGE(FAILED, "Ge_root_model is nullptr"); return FAILED); | |||||
| auto &name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); | |||||
| GE_IF_BOOL_EXEC(name_to_ge_model.empty(), GELOGE(FAILED, "Ge_root_model has no sub model"); return FAILED); | |||||
| GE_IF_BOOL_EXEC(output_file.empty(), | |||||
| GELOGE(FAILED, "GraphBuilder SaveModel received invalid file name prefix"); | |||||
| return FAILED); | |||||
| if (!is_unknown_shape) { | |||||
| auto &model_root = name_to_ge_model.begin()->second; | |||||
| return SaveToOmModel(model_root, save_param, output_file, model); | |||||
| } | |||||
| std::shared_ptr<OmFileSaveHelper> om_file_save_helper = ge::MakeShared<OmFileSaveHelper>(); | |||||
| GE_CHECK_NOTNULL(om_file_save_helper); | |||||
| auto &first_ge_model = name_to_ge_model.at(ge_root_model->GetRootGraph()->GetName()); | |||||
| // ge root model must be the first to be loaded | |||||
| vector<string> model_names{ge_root_model->GetRootGraph()->GetName()}; | |||||
| for (auto &item : name_to_ge_model) { | |||||
| if (item.first != model_names.front()) { | |||||
| model_names.emplace_back(item.first); | |||||
| } | |||||
| } | |||||
| vector<ge::Buffer> model_buffers(model_names.size()); | |||||
| vector<ge::Buffer> task_buffers(model_names.size()); | |||||
| size_t cur_index = 0; | |||||
| if (model_names.size() > 1) { | |||||
| GELOGD("only save first model MODEL_DEF"); | |||||
| if (SaveModelDef(om_file_save_helper, first_ge_model, model_buffers[cur_index], cur_index) != SUCCESS) { | |||||
| GELOGE(FAILED, "save model def failed"); | |||||
| return FAILED; | |||||
| } | |||||
| ++cur_index; | |||||
| } | |||||
| for (; cur_index < model_names.size(); ++cur_index) { | |||||
| auto model_name = model_names[cur_index]; | |||||
| GELOGD("cur model %s index is %zu", model_name.c_str(), cur_index); | |||||
| const GeModelPtr &ge_model = name_to_ge_model.at(model_name); | |||||
| auto ret = SaveAllModelPartiton(om_file_save_helper, ge_model, model_buffers[cur_index], | |||||
| task_buffers[cur_index], cur_index); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Save model %s failed", model_name.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| } | |||||
| auto ret = SaveModelHeader(om_file_save_helper, first_ge_model, model_names.size()); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Save model %s header failed", first_ge_model->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| ret = om_file_save_helper->SaveRootModel(save_param, output_file.c_str(), model, is_offline_); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(FAILED, "OmFileSaveHelper SaveModel return fail."); | GELOGE(FAILED, "OmFileSaveHelper SaveModel return fail."); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -288,7 +438,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c | |||||
| } | } | ||||
| file_header_ = reinterpret_cast<ModelFileHeader *>(model_data.model_data); | file_header_ = reinterpret_cast<ModelFileHeader *>(model_data.model_data); | ||||
| OmFileLoadHelper om_load_helper; | OmFileLoadHelper om_load_helper; | ||||
| status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_); | status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_); | ||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| @@ -310,7 +459,61 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c | |||||
| GELOGE(status, "GenerateGeModel failed"); | GELOGE(status, "GenerateGeModel failed"); | ||||
| return status; | return status; | ||||
| } | } | ||||
| GELOGD("in ModelHelper::LoadModel, is_assign_model_ is setted to true!"); | |||||
| is_assign_model_ = true; | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadRootModel(const ge::ModelData &model_data) { | |||||
| if (model_data.model_data == nullptr || model_data.model_len == 0) { | |||||
| GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "Model_data is nullptr, or model_data_size is 0"); | |||||
| return GE_EXEC_MODEL_DATA_SIZE_INVALID; | |||||
| } | |||||
| if (is_assign_model_) { | |||||
| GELOGE(GE_EXEC_LOAD_MODEL_REPEATED, "Model helper has already loaded!"); | |||||
| return GE_EXEC_LOAD_MODEL_REPEATED; | |||||
| } | |||||
| if (ReleaseLocalModelData() != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "ReleaseLocalModelData failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| Status status = ge::DavinciModelParser::ParseModelContent(model_data, model_addr_tmp_, model_len_tmp_); | |||||
| if (status != SUCCESS) { | |||||
| GELOGE(status, "Parse model content failed!"); | |||||
| return status; | |||||
| } | |||||
| file_header_ = reinterpret_cast<ModelFileHeader *>(model_data.model_data); | |||||
| //model verison 1.0 file header does not have model_num member | |||||
| is_unknown_shape_model_ = file_header_->version >= ge::MODEL_VERSION && | |||||
| file_header_->model_num > kStatiOmFileModelNum; | |||||
| GELOGD("cur om model is ge root model or no %d, model version %zu", is_unknown_shape_model_, file_header_->version); | |||||
| OmFileLoadHelper om_load_helper; | |||||
| if (is_unknown_shape_model_) { | |||||
| auto model_num = file_header_->model_num; | |||||
| status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_, model_num); | |||||
| } else { | |||||
| status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_); | |||||
| } | |||||
| if (status != SUCCESS) { | |||||
| GELOGE(status, "Om_load_helper init failed"); | |||||
| model_addr_tmp_ = nullptr; | |||||
| return status; | |||||
| } | |||||
| // Encrypt model need to del temp model/no encrypt model don't need to del model | |||||
| model_addr_tmp_ = nullptr; | |||||
| status = GenerateGeRootModel(om_load_helper); | |||||
| if (status != SUCCESS) { | |||||
| GELOGE(status, "GenerateGeRootModel failed"); | |||||
| return status; | |||||
| } | |||||
| GELOGD("in ModelHelper::LoadRootModel, is_assign_model_ is setted to true!"); | |||||
| is_assign_model_ = true; | is_assign_model_ = true; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -341,6 +544,61 @@ Status ModelHelper::GenerateGeModel(OmFileLoadHelper &om_load_helper) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ModelHelper::GenerateGeRootModel(OmFileLoadHelper &om_load_helper) { | |||||
| GELOGD("Begin to generate ge root model"); | |||||
| root_model_ = ge::MakeShared<ge::GeRootModel>(); | |||||
| GE_CHECK_NOTNULL(root_model_); | |||||
| if (!is_unknown_shape_model_) { | |||||
| if (GenerateGeModel(om_load_helper) != SUCCESS) { | |||||
| GELOGE(FAILED, "GenerateGeModel failed"); | |||||
| return FAILED; | |||||
| } | |||||
| GE_CHECK_NOTNULL(model_); | |||||
| root_model_->SetRootGraph(GraphUtils::GetComputeGraph(model_->GetGraph())); | |||||
| return SUCCESS; | |||||
| } | |||||
| bool is_first_model = true; | |||||
| for (size_t mode_index = 0; mode_index < file_header_->model_num; ++mode_index) { | |||||
| GeModelPtr cur_model = ge::MakeShared<ge::GeModel>(); | |||||
| Status ret = LoadModelData(om_load_helper, cur_model, mode_index); | |||||
| if (ret != SUCCESS) { | |||||
| return GE_EXEC_LOAD_MODEL_PARTITION_FAILED; | |||||
| } | |||||
| if (is_first_model) { | |||||
| is_first_model = false; | |||||
| root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph())); | |||||
| root_model_->SetModelId(cur_model->GetModelId()); | |||||
| model_ = cur_model; | |||||
| continue; | |||||
| } | |||||
| ret = LoadWeights(om_load_helper, cur_model, mode_index); | |||||
| if (ret != SUCCESS) { | |||||
| return GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED; | |||||
| } | |||||
| ret = LoadTBEKernelStore(om_load_helper, cur_model, mode_index); | |||||
| if (ret != SUCCESS) { | |||||
| return GE_EXEC_LOAD_KERNEL_PARTITION_FAILED; | |||||
| } | |||||
| ret = LoadCustAICPUKernelStore(om_load_helper, cur_model, mode_index); | |||||
| if (ret != SUCCESS) { | |||||
| return GE_EXEC_LOAD_KERNEL_PARTITION_FAILED; | |||||
| } | |||||
| ret = LoadTask(om_load_helper, cur_model, mode_index); | |||||
| if (ret != SUCCESS) { | |||||
| return GE_EXEC_LOAD_TASK_PARTITION_FAILED; | |||||
| } | |||||
| root_model_->SetSubgraphInstanceNameToModel(cur_model->GetName(), cur_model); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelHelper::LoadModelData(OmFileLoadHelper &om_load_helper) { | Status ModelHelper::LoadModelData(OmFileLoadHelper &om_load_helper) { | ||||
| ModelPartition partition_model_def; | ModelPartition partition_model_def; | ||||
| // no need to check value, DATA->NetOutput | // no need to check value, DATA->NetOutput | ||||
| @@ -366,6 +624,28 @@ void ModelHelper::SetModelToGeModel(ge::Model &model) { | |||||
| model_->SetAttr(model.MutableAttrMap()); | model_->SetAttr(model.MutableAttrMap()); | ||||
| } | } | ||||
| Status ModelHelper::LoadModelData(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index) { | |||||
| ModelPartition partition_model_def; | |||||
| // no need to check value, DATA->NetOutput | |||||
| om_load_helper.GetModelPartition(ModelPartitionType::MODEL_DEF, partition_model_def, mode_index); | |||||
| GELOGD("Model_def partition addr:%p,size:%u", partition_model_def.data, partition_model_def.size); | |||||
| ge::Model model; | |||||
| if (ge::Model::Load(partition_model_def.data, partition_model_def.size, model) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Load model failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| cur_model->SetGraph(model.GetGraph()); | |||||
| cur_model->SetName(model.GetName()); | |||||
| cur_model->SetVersion(model.GetVersion()); | |||||
| cur_model->SetPlatformVersion(model.GetPlatformVersion()); | |||||
| cur_model->SetAttr(model.MutableAttrMap()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper) { | Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper) { | ||||
| ModelPartition partition; | ModelPartition partition; | ||||
| if (om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition) != SUCCESS) { | if (om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition) != SUCCESS) { | ||||
| @@ -379,6 +659,19 @@ Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index) { | |||||
| ModelPartition partition; | |||||
| if (om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition, mode_index) != SUCCESS) { | |||||
| GELOGE(FAILED, "Get weight model partition failed."); | |||||
| return FAILED; | |||||
| } | |||||
| ge::Buffer weight = ge::Buffer::CopyFrom(partition.data, partition.size); | |||||
| cur_model->SetWeight(weight); | |||||
| GELOGD("GetWeight size:%u", partition.size); | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper) { | ||||
| ModelPartition task_partition; | ModelPartition task_partition; | ||||
| if (om_load_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition) != SUCCESS) { | if (om_load_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition) != SUCCESS) { | ||||
| @@ -398,6 +691,27 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(Om | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper, | |||||
| GeModelPtr &cur_model, | |||||
| size_t mode_index) { | |||||
| ModelPartition task_partition; | |||||
| if (om_load_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition, mode_index) != SUCCESS) { | |||||
| GELOGE(FAILED, "Get task model partition failed."); | |||||
| return FAILED; | |||||
| } | |||||
| std::shared_ptr<ModelTaskDef> task = ge::MakeShared<ModelTaskDef>(); | |||||
| GE_CHECK_NOTNULL(task); | |||||
| if (task_partition.size != 0) { | |||||
| if (!ReadProtoFromArray(task_partition.data, task_partition.size, task.get())) { | |||||
| GELOGE(INTERNAL_ERROR, "ReadProtoFromArray failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGD("TASK_INFO op_size:%zu, stream_num:%u", task->op().size(), task->stream_num()); | |||||
| } | |||||
| cur_model->SetModelTaskDef(task); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelHelper::LoadTBEKernelStore(OmFileLoadHelper &om_load_helper) { | Status ModelHelper::LoadTBEKernelStore(OmFileLoadHelper &om_load_helper) { | ||||
| // Load tbe kernels | // Load tbe kernels | ||||
| ModelPartition partition_kernel_def; | ModelPartition partition_kernel_def; | ||||
| @@ -414,6 +728,23 @@ Status ModelHelper::LoadTBEKernelStore(OmFileLoadHelper &om_load_helper) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ModelHelper::LoadTBEKernelStore(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index) { | |||||
| // Load tbe kernels | |||||
| ModelPartition partition_kernel_def; | |||||
| TBEKernelStore kernel_store; | |||||
| if (om_load_helper.GetModelPartition(ModelPartitionType::TBE_KERNELS, partition_kernel_def, mode_index) == | |||||
| SUCCESS) { | |||||
| GELOGD("Kernels partition size:%u", partition_kernel_def.size); | |||||
| if (kernel_store.Load(partition_kernel_def.data, partition_kernel_def.size)) { | |||||
| GELOGD("Load tbe kernels success"); | |||||
| } else { | |||||
| GELOGW("Load tbe kernels failed"); | |||||
| } | |||||
| } | |||||
| cur_model->SetTBEKernelStore(kernel_store); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper) { | Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper) { | ||||
| // Load cust aicpu kernels | // Load cust aicpu kernels | ||||
| ModelPartition partition_kernel_def; | ModelPartition partition_kernel_def; | ||||
| @@ -421,19 +752,39 @@ Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper) { | |||||
| if (om_load_helper.GetModelPartition(ModelPartitionType::CUST_AICPU_KERNELS, partition_kernel_def) == SUCCESS) { | if (om_load_helper.GetModelPartition(ModelPartitionType::CUST_AICPU_KERNELS, partition_kernel_def) == SUCCESS) { | ||||
| GELOGD("Kernels partition size:%u", partition_kernel_def.size); | GELOGD("Kernels partition size:%u", partition_kernel_def.size); | ||||
| if (kernel_store.Load(partition_kernel_def.data, partition_kernel_def.size)) { | if (kernel_store.Load(partition_kernel_def.data, partition_kernel_def.size)) { | ||||
| GELOGI("Load cust aicpu kernels success"); | |||||
| GELOGD("Load cust aicpu kernels success"); | |||||
| } else { | |||||
| GELOGW("Load cust aicpu kernels failed"); | |||||
| } | } | ||||
| } | } | ||||
| model_->SetCustAICPUKernelStore(kernel_store); | model_->SetCustAICPUKernelStore(kernel_store); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper, | |||||
| GeModelPtr &cur_model, size_t mode_index) { | |||||
| // Load cust aicpu kernels | |||||
| ModelPartition partition_kernel_def; | |||||
| CustAICPUKernelStore kernel_store; | |||||
| if (om_load_helper.GetModelPartition(ModelPartitionType::CUST_AICPU_KERNELS, partition_kernel_def, mode_index) | |||||
| == SUCCESS) { | |||||
| GELOGD("Kernels partition size:%u", partition_kernel_def.size); | |||||
| if (kernel_store.Load(partition_kernel_def.data, partition_kernel_def.size)) { | |||||
| GELOGD("Load cust aicpu kernels success"); | |||||
| } else { | |||||
| GELOGW("Load cust aicpu kernels failed"); | |||||
| } | |||||
| } | |||||
| cur_model->SetCustAICPUKernelStore(kernel_store); | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeModel() { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeModel() { | ||||
| if (model_ != nullptr) { | if (model_ != nullptr) { | ||||
| return model_; | return model_; | ||||
| } | } | ||||
| GELOGI("Model has not been loaded!"); | |||||
| GELOGD("Model has not been loaded!"); | |||||
| std::shared_ptr<ge::GeModel> out_model = ge::MakeShared<ge::GeModel>(); | std::shared_ptr<ge::GeModel> out_model = ge::MakeShared<ge::GeModel>(); | ||||
| if (out_model == nullptr) { | if (out_model == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| @@ -441,6 +792,20 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeMo | |||||
| return out_model; | return out_model; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeRootModelPtr ModelHelper::GetGeRootModel() { | |||||
| if (root_model_ != nullptr) { | |||||
| return root_model_; | |||||
| } | |||||
| GELOGD("Model has not been loaded!"); | |||||
| std::shared_ptr<ge::GeRootModel> out_model = ge::MakeShared<ge::GeRootModel>(); | |||||
| if (out_model == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| return out_model; | |||||
| } | |||||
| Status ModelHelper::ReleaseLocalModelData() noexcept { | Status ModelHelper::ReleaseLocalModelData() noexcept { | ||||
| Status result = SUCCESS; | Status result = SUCCESS; | ||||
| if (model_addr_tmp_ != nullptr) { | if (model_addr_tmp_ != nullptr) { | ||||
| @@ -52,6 +52,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(u | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(uint8_t *model_data, | |||||
| uint32_t model_data_size, | |||||
| uint32_t model_num) { | |||||
| Status status = LoadModelPartitionTable(model_data, model_data_size, model_num); | |||||
| if (status != SUCCESS) { | |||||
| return status; | |||||
| } | |||||
| is_inited_ = true; | |||||
| return SUCCESS; | |||||
| } | |||||
| // Use both | // Use both | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, | ||||
| ModelPartition &partition) { | ModelPartition &partition) { | ||||
| @@ -79,6 +90,37 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetMod | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, | |||||
| ModelPartition &partition, | |||||
| size_t model_index) { | |||||
| if (!is_inited_) { | |||||
| GELOGE(PARAM_INVALID, "OmFileLoadHelper has not been initialized!"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| if (model_index >= model_contexts_.size()) { | |||||
| GELOGE(PARAM_INVALID, "cur index : %zu, model_contexts size:%zu", model_index, model_contexts_.size()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| auto &cur_ctx = model_contexts_[model_index]; | |||||
| bool found = false; | |||||
| for (ModelPartition &part : cur_ctx.partition_datas_) { | |||||
| if (part.type == type) { | |||||
| partition = part; | |||||
| found = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!found) { | |||||
| if (type != ModelPartitionType::TBE_KERNELS && type != ModelPartitionType::WEIGHTS_DATA && | |||||
| type != ModelPartitionType::CUST_AICPU_KERNELS) { | |||||
| GELOGE(FAILED, "GetModelPartition:type:%d is not in partition_datas!", static_cast<int>(type)); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OmFileLoadHelper::CheckModelValid(const ge::ModelData &model) const { | Status OmFileLoadHelper::CheckModelValid(const ge::ModelData &model) const { | ||||
| // Parameter validity check | // Parameter validity check | ||||
| if (model.model_data == nullptr) { | if (model.model_data == nullptr) { | ||||
| @@ -138,7 +180,8 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint | |||||
| context_.partition_datas_.push_back(partition); | context_.partition_datas_.push_back(partition); | ||||
| if (partition.size > model_data_size || mem_offset > model_data_size - partition.size) { | if (partition.size > model_data_size || mem_offset > model_data_size - partition.size) { | ||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "The partition size %zu is greater than the model data size %u.", | |||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, | |||||
| "The partition size %zu is greater than the model data size %u.", | |||||
| partition.size + mem_offset, model_data_size); | partition.size + mem_offset, model_data_size); | ||||
| return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID; | return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID; | ||||
| } | } | ||||
| @@ -148,6 +191,61 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, uint32_t model_data_size, uint32_t model_num) { | |||||
| if (model_data == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "Param model_data must not be null!"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| uint32_t cur_offset = 0; | |||||
| for (uint32_t index = 0; index < model_num; ++index) { | |||||
| // Init partition table | |||||
| auto partition_table = reinterpret_cast<ModelPartitionTable *>(model_data + cur_offset); | |||||
| size_t partition_table_size = SIZE_OF_MODEL_PARTITION_TABLE(*partition_table); | |||||
| cur_offset += partition_table_size; | |||||
| GELOGD("Cur model index %zu: ModelPartitionTable num :%u, " | |||||
| "ModelFileHeader length :%zu, ModelPartitionTable length :%zu", | |||||
| index, partition_table->num, sizeof(ModelFileHeader), partition_table_size); | |||||
| if (model_data_size <= cur_offset) { | |||||
| GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "invalid model data, partition_table->num:%u, model data size %u", | |||||
| partition_table->num, model_data_size); | |||||
| return GE_EXEC_MODEL_DATA_SIZE_INVALID; | |||||
| } | |||||
| for (uint32_t i = 0; i < partition_table->num; i++) { | |||||
| ModelPartition partition; | |||||
| partition.size = partition_table->partition[i].mem_size; | |||||
| partition.data = model_data + cur_offset; | |||||
| partition.type = partition_table->partition[i].type; | |||||
| if (index >= model_contexts_.size()) { | |||||
| if (index != model_contexts_.size()) { | |||||
| GELOGE(FAILED, "cur index is %zu make model_contexts_ overflow", index); | |||||
| return FAILED; | |||||
| } | |||||
| OmFileContext tmp_ctx; | |||||
| tmp_ctx.partition_datas_.push_back(partition); | |||||
| model_contexts_.push_back(tmp_ctx); | |||||
| } else { | |||||
| model_contexts_[index].partition_datas_.push_back(partition); | |||||
| } | |||||
| if (partition.size > model_data_size || cur_offset > model_data_size - partition.size) { | |||||
| GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "The partition size %zu is greater than the model data size %u.", | |||||
| partition.size + cur_offset, model_data_size); | |||||
| return GE_EXEC_MODEL_DATA_SIZE_INVALID; | |||||
| } | |||||
| cur_offset += partition.size; | |||||
| GELOGD("Partition, type:%d, size:%u, model_index:%zu", static_cast<int>(partition.type), partition.size, index); | |||||
| } | |||||
| } | |||||
| if (cur_offset != model_data_size) { | |||||
| GELOGE(FAILED, "do not get the complete model, read end offset:%zu, all size:%zu", cur_offset, model_data_size); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::vector<ModelPartition> | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::vector<ModelPartition> | ||||
| &OmFileSaveHelper::GetModelPartitions() const { | &OmFileSaveHelper::GetModelPartitions() const { | ||||
| return context_.partition_datas_; | return context_.partition_datas_; | ||||
| @@ -172,6 +270,28 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSave | |||||
| return partition_table; | return partition_table; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSaveHelper::GetPartitionTable( | |||||
| size_t cur_ctx_index) { | |||||
| auto &cur_ctx = model_contexts_[cur_ctx_index]; | |||||
| auto partition_size = static_cast<uint32_t>(cur_ctx.partition_datas_.size()); | |||||
| // Build ModelPartitionTable, flex array | |||||
| cur_ctx.partition_table_.clear(); | |||||
| cur_ctx.partition_table_.resize(sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * partition_size, 0); | |||||
| auto partition_table = reinterpret_cast<ModelPartitionTable *>(cur_ctx.partition_table_.data()); | |||||
| partition_table->num = partition_size; | |||||
| uint32_t mem_offset = 0; | |||||
| for (uint32_t i = 0; i < partition_size; i++) { | |||||
| ModelPartition partition = cur_ctx.partition_datas_[i]; | |||||
| partition_table->partition[i] = {partition.type, mem_offset, partition.size}; | |||||
| mem_offset += partition.size; | |||||
| GELOGD("Partition, type:%d, size:%u", static_cast<int>(partition.type), partition.size); | |||||
| } | |||||
| return partition_table; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileSaveHelper::AddPartition(ModelPartition &partition) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileSaveHelper::AddPartition(ModelPartition &partition) { | ||||
| if (ge::CheckUint32AddOverflow(context_.model_data_len_, partition.size) != SUCCESS) { | if (ge::CheckUint32AddOverflow(context_.model_data_len_, partition.size) != SUCCESS) { | ||||
| GELOGE(FAILED, "UINT32 %u and %u addition can result in overflow!", context_.model_data_len_, partition.size); | GELOGE(FAILED, "UINT32 %u and %u addition can result in overflow!", context_.model_data_len_, partition.size); | ||||
| @@ -182,6 +302,27 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileSaveHelper::AddPar | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status OmFileSaveHelper::AddPartition(ModelPartition &partition, size_t cur_index) { | |||||
| if (ge::CheckUint32AddOverflow(context_.model_data_len_, partition.size) != SUCCESS) { | |||||
| GELOGE(FAILED, "UINT32 %u and %u addition can result in overflow!", context_.model_data_len_, partition.size); | |||||
| return FAILED; | |||||
| } | |||||
| if (cur_index >= model_contexts_.size()) { | |||||
| if (cur_index != model_contexts_.size()) { | |||||
| GELOGE(FAILED, "cur index is %zu make model_contexts_ overflow", cur_index); | |||||
| return FAILED; | |||||
| } | |||||
| OmFileContext tmp_ctx; | |||||
| tmp_ctx.model_data_len_ += partition.size; | |||||
| tmp_ctx.partition_datas_.push_back(partition); | |||||
| model_contexts_.push_back(tmp_ctx); | |||||
| } else { | |||||
| model_contexts_[cur_index].model_data_len_ += partition.size; | |||||
| model_contexts_[cur_index].partition_datas_.push_back(partition); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OmFileSaveHelper::SaveModel(const SaveParam &save_param, const char *output_file, ModelBufferData &model, | Status OmFileSaveHelper::SaveModel(const SaveParam &save_param, const char *output_file, ModelBufferData &model, | ||||
| bool is_offline) { | bool is_offline) { | ||||
| (void)save_param.cert_file; | (void)save_param.cert_file; | ||||
| @@ -198,6 +339,10 @@ Status OmFileSaveHelper::SaveModel(const SaveParam &save_param, const char *outp | |||||
| Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferData &model, bool is_offline) { | Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferData &model, bool is_offline) { | ||||
| #if !defined(NONSUPPORT_SAVE_TO_FILE) | #if !defined(NONSUPPORT_SAVE_TO_FILE) | ||||
| if (context_.partition_datas_.empty()) { | |||||
| GE_CHK_BOOL_EXEC(!model_contexts_.empty(), return FAILED, "mode contexts empty"); | |||||
| context_ = model_contexts_.front(); | |||||
| } | |||||
| uint32_t model_data_len = context_.model_data_len_; | uint32_t model_data_len = context_.model_data_len_; | ||||
| if (model_data_len == 0) { | if (model_data_len == 0) { | ||||
| GELOGE(domi::PARAM_INVALID, "Model data len error! should not be 0"); | GELOGE(domi::PARAM_INVALID, "Model data len error! should not be 0"); | ||||
| @@ -231,4 +376,53 @@ Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferDat | |||||
| return SUCCESS; | return SUCCESS; | ||||
| #endif | #endif | ||||
| } | } | ||||
| Status OmFileSaveHelper::SaveRootModel(const SaveParam &save_param, const char *output_file, | |||||
| ModelBufferData &model, bool is_offline) { | |||||
| (void)save_param.cert_file; | |||||
| (void)save_param.ek_file; | |||||
| (void)save_param.encode_mode; | |||||
| (void)save_param.hw_key_file; | |||||
| (void)save_param.pri_key_file; | |||||
| #if !defined(NONSUPPORT_SAVE_TO_FILE) | |||||
| vector<ModelPartitionTable *> model_partition_tabels; | |||||
| vector<vector<ModelPartition>> all_model_partitions; | |||||
| for (size_t ctx_index = 0; ctx_index < model_contexts_.size(); ++ctx_index) { | |||||
| auto &cur_ctx = model_contexts_[ctx_index]; | |||||
| uint32_t cur_model_data_len = cur_ctx.model_data_len_; | |||||
| if (cur_model_data_len == 0) { | |||||
| GELOGE(domi::PARAM_INVALID, "Model data len error! should not be 0"); | |||||
| return domi::PARAM_INVALID; | |||||
| } | |||||
| auto tmp_table = GetPartitionTable(ctx_index); | |||||
| if (tmp_table == nullptr) { | |||||
| GELOGE(ge::GE_GRAPH_SAVE_FAILED, "SaveModelToFile execute failed: partition_table is NULL."); | |||||
| return ge::GE_GRAPH_SAVE_FAILED; | |||||
| } | |||||
| uint32_t size_of_table = SIZE_OF_MODEL_PARTITION_TABLE(*tmp_table); | |||||
| FMK_UINT32_ADDCHECK(size_of_table, cur_model_data_len) | |||||
| FMK_UINT32_ADDCHECK(size_of_table + cur_model_data_len, model_header_.length) | |||||
| model_header_.length += size_of_table + cur_model_data_len; | |||||
| model_partition_tabels.push_back(tmp_table); | |||||
| all_model_partitions.push_back(cur_ctx.partition_datas_); | |||||
| GELOGD("sizeof(ModelPartitionTable):%u, cur_model_data_len:%u, cur_context_index:%zu", | |||||
| size_of_table, cur_model_data_len, ctx_index); | |||||
| } | |||||
| Status ret; | |||||
| if (is_offline) { | |||||
| ret = FileSaver::SaveToFile(output_file, model_header_, model_partition_tabels, all_model_partitions); | |||||
| } else { | |||||
| GELOGW("do not support save ge root model to buff now"); | |||||
| return FAILED; | |||||
| } | |||||
| if (ret == SUCCESS) { | |||||
| GELOGD("Save model success without encrypt."); | |||||
| } | |||||
| return ret; | |||||
| #else | |||||
| return SUCCESS; | |||||
| #endif | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -357,7 +357,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OpUtils::TransDataHWCK2KCH | |||||
| const char *w_data = (const char *)input; | const char *w_data = (const char *)input; | ||||
| int64_t count = h * w * c * k; | int64_t count = h * w * c * k; | ||||
| GE_IF_BOOL_EXEC(count <= 0, GELOGW("Count value must be greater than 0, but count = %ld", count); return ); | |||||
| GE_IF_BOOL_EXEC(count <= 0, GELOGW("Count value must be greater than 0, but count = %ld", count); return); | |||||
| float *buf = new (std::nothrow) float[count](); | float *buf = new (std::nothrow) float[count](); | ||||
| GE_RT_VOID_CHECK_NOTNULL(buf); | GE_RT_VOID_CHECK_NOTNULL(buf); | ||||
| float *src_buff = nullptr; | float *src_buff = nullptr; | ||||
| @@ -0,0 +1,199 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/profiling/ge_profiling.h" | |||||
| #include "runtime/base.h" | |||||
| #include "common/profiling/profiling_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "graph/load/graph_loader.h" | |||||
| #include "init/gelib.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| namespace { | |||||
| const uint32_t kDeviceListIndex = 3; | |||||
| const std::string kDeviceNums = "devNums"; | |||||
| const std::string kDeviceIdList = "devIdList"; | |||||
| const std::string kProfilingInit = "prof_init"; | |||||
| const std::string kProfilingFinalize = "prof_finalize"; | |||||
| const std::string kProfilingStart = "prof_start"; | |||||
| const std::string kProfilingStop = "prof_stop"; | |||||
| const std::string kProfModelSubscribe = "prof_model_subscribe"; | |||||
| const std::string kProfModelUnsubscribe = "prof_model_cancel_subscribe"; | |||||
| const std::string kRtSetDeviceRegName = "profiling"; | |||||
| const std::map<ProfCommandHandleType, std::string> kProfCommandTypeMap = { | |||||
| {kProfCommandhandleInit, kProfilingInit}, | |||||
| {kProfCommandhandleStart, kProfilingStart}, | |||||
| {kProfCommandhandleStop, kProfilingStop}, | |||||
| {kProfCommandhandleFinalize, kProfilingFinalize}, | |||||
| {kProfCommandhandleModelSubscribe, kProfModelSubscribe}, | |||||
| {kProfCommandhandleModelUnsubscribe, kProfModelUnsubscribe}}; | |||||
| } // namespace | |||||
| bool TransProfConfigToParam(const ProfCommandHandleData &profCommand, vector<string> &prof_config_params) { | |||||
| prof_config_params.clear(); | |||||
| prof_config_params.emplace_back(kDeviceNums); | |||||
| prof_config_params.emplace_back(std::to_string(profCommand.devNums)); | |||||
| prof_config_params.emplace_back(kDeviceIdList); | |||||
| std::string devID = ""; | |||||
| if (profCommand.devNums == 0) { | |||||
| GELOGW("The device num is invalid."); | |||||
| return false; | |||||
| } | |||||
| for (uint32_t i = 0; i < profCommand.devNums; i++) { | |||||
| devID.append(std::to_string(profCommand.devIdList[i])); | |||||
| if (i != profCommand.devNums - 1) { | |||||
| devID.append(","); | |||||
| } | |||||
| } | |||||
| prof_config_params.push_back(devID); | |||||
| return true; | |||||
| } | |||||
| bool isProfConfigValid(const uint32_t *deviceid_list, uint32_t device_nums) { | |||||
| if (deviceid_list == nullptr) { | |||||
| GELOGE(ge::PARAM_INVALID, "deviceIdList is nullptr"); | |||||
| return false; | |||||
| } | |||||
| if (device_nums == 0 || device_nums > MAX_DEV_NUM) { | |||||
| GELOGE(ge::PARAM_INVALID, "The device nums: %u is invalid.", device_nums); | |||||
| return false; | |||||
| } | |||||
| // real device num | |||||
| int32_t dev_count = 0; | |||||
| rtError_t rt_err = rtGetDeviceCount(&dev_count); | |||||
| if (rt_err != RT_ERROR_NONE) { | |||||
| GELOGE(ge::INTERNAL_ERROR, "Get the Device count fail."); | |||||
| return false; | |||||
| } | |||||
| if (device_nums > static_cast<uint32_t>(dev_count)) { | |||||
| GELOGE(ge::PARAM_INVALID, "Device num(%u) is not in range 1 ~ %d.", device_nums, dev_count); | |||||
| return false; | |||||
| } | |||||
| std::unordered_set<uint32_t> record; | |||||
| for (size_t i = 0; i < device_nums; ++i) { | |||||
| uint32_t dev_id = deviceid_list[i]; | |||||
| if (dev_id >= static_cast<uint32_t>(dev_count)) { | |||||
| GELOGE(ge::PARAM_INVALID, "Device id %u is not in range 0 ~ %d(exclude %d)", dev_id, dev_count, dev_count); | |||||
| return false; | |||||
| } | |||||
| if (record.count(dev_id) > 0) { | |||||
| GELOGE(ge::PARAM_INVALID, "Device id %u is duplicatedly set", dev_id); | |||||
| return false; | |||||
| } | |||||
| record.insert(dev_id); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| ge::Status RegProfCtrlCallback(MsprofCtrlCallback func) { | |||||
| if (func == nullptr) { | |||||
| GELOGE(ge::PARAM_INVALID, "Msprof ctrl callback is nullptr."); | |||||
| return ge::PARAM_INVALID; | |||||
| } | |||||
| if (ge::ProfilingManager::Instance().GetMsprofCallback().msprofCtrlCallback != nullptr) { | |||||
| GELOGW("Msprof ctrl callback is exist, just ignore it."); | |||||
| } else { | |||||
| GELOGI("GE register Msprof ctrl callback."); | |||||
| ge::ProfilingManager::Instance().SetMsprofCtrlCallback(func); | |||||
| } | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| ge::Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) { | |||||
| if (func == nullptr) { | |||||
| GELOGE(ge::PARAM_INVALID, "MsprofSetDeviceCallback callback is nullptr."); | |||||
| return ge::PARAM_INVALID; | |||||
| } | |||||
| // Pass MsprofSetDeviceCallback to runtime | |||||
| GELOGI("GE pass setdevice callback to runtime."); | |||||
| ge::Status rt_ret = rtRegDeviceStateCallback(kRtSetDeviceRegName.c_str(), static_cast<rtDeviceStateCallback>(func)); | |||||
| if (rt_ret != ge::SUCCESS) { | |||||
| GELOGE(rt_ret, "Pass MsprofSetDeviceCallback to runtime failed!"); | |||||
| return rt_ret; | |||||
| } | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| ge::Status RegProfReporterCallback(MsprofReporterCallback func) { | |||||
| if (func == nullptr) { | |||||
| GELOGE(ge::PARAM_INVALID, "MsprofReporterCallback callback is nullptr."); | |||||
| return ge::PARAM_INVALID; | |||||
| } | |||||
| if (ge::ProfilingManager::Instance().GetMsprofCallback().msprofReporterCallback != nullptr) { | |||||
| GELOGW("Msprof reporter callback is exist, just ignore it."); | |||||
| } else { | |||||
| GELOGI("GE register Msprof reporter callback."); | |||||
| ge::ProfilingManager::Instance().SetMsprofReporterCallback(func); | |||||
| // Pass MsprofReporterCallback to runtime | |||||
| ge::Status rt_ret = rtSetMsprofReporterCallback(func); | |||||
| if (rt_ret != ge::SUCCESS) { | |||||
| GELOGE(rt_ret, "Pass MsprofReporterCallback to runtime failed!!"); | |||||
| return rt_ret; | |||||
| } | |||||
| // Pass MsprofReporterCallback to hccl | |||||
| } | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) { | |||||
| if (type != kProfCommandhandleFinalize) { | |||||
| GE_CHECK_NOTNULL(data); | |||||
| } | |||||
| ProfCommandHandleData *prof_config_param = (ProfCommandHandleData *)data; | |||||
| auto iter = kProfCommandTypeMap.find(type); | |||||
| if (iter == kProfCommandTypeMap.end()) { | |||||
| GELOGW("The prof comand type is invalid."); | |||||
| return ge::PARAM_INVALID; | |||||
| } | |||||
| std::vector<string> prof_params; | |||||
| if (type == kProfCommandhandleStart || type == kProfCommandhandleStop) { | |||||
| if (!isProfConfigValid(prof_config_param->devIdList, prof_config_param->devNums)) { | |||||
| return ge::FAILED; | |||||
| } | |||||
| if (!TransProfConfigToParam(*prof_config_param, prof_params)) { | |||||
| GELOGE(ge::PARAM_INVALID, "Transfer profilerConfig to string vector failed"); | |||||
| return ge::PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| ge::GraphLoader graph_loader; | |||||
| ge::Command command; | |||||
| command.cmd_params.clear(); | |||||
| command.cmd_type = iter->second; | |||||
| command.cmd_params = prof_params; | |||||
| if (type != kProfCommandhandleFinalize) { | |||||
| command.module_index = prof_config_param->profSwitch; | |||||
| } | |||||
| GELOGI("GE commandhandle execute, Command Type: %d, data type config: 0x%llx", type, command.module_index); | |||||
| if (type == kProfCommandhandleStart || type == kProfCommandhandleStop) { | |||||
| GELOGI("Profiling device nums:%s , deviceID:[%s]", prof_params[0].c_str(), prof_params[kDeviceListIndex].c_str()); | |||||
| } | |||||
| ge::Status ret = graph_loader.CommandHandle(command); | |||||
| if (ret != ge::SUCCESS) { | |||||
| GELOGE(ret, "Handle profiling command failed"); | |||||
| return ge::FAILED; | |||||
| } | |||||
| GELOGI("Successfully execute profiling command type: %d, command 0x%llx.", type, command.module_index); | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/profiling/ge_runner_profiling.h" | |||||
| #include "init/gelib.h" | |||||
| bool IsInitialize() { | |||||
| std::shared_ptr<ge::GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || instance_ptr->InitFlag() == false) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| @@ -24,16 +24,9 @@ | |||||
| #include "graph/load/new_model_manager/davinci_model.h" | #include "graph/load/new_model_manager/davinci_model.h" | ||||
| namespace { | namespace { | ||||
| const char *const kJobID = "jobID"; | |||||
| const char *const kDeviceID = "deviceID"; | |||||
| const char *const kStartCfg = "startCfg"; | |||||
| const char *const kFeatures = "features"; | |||||
| const char *const kConf = "conf"; | |||||
| const char *const kEvents = "events"; | |||||
| const char *const kAiCoreEvents = "ai_core_events"; | |||||
| const char *const kName = "name"; | |||||
| const char *const kTraceID = "traceId"; | |||||
| const char *const kProfDir = "resultPath"; | |||||
| const char *const kTrainingTrace = "training_trace"; | |||||
| const char *const kFpPoint = "fp_point"; | |||||
| const char *const kBpPoint = "bp_point"; | |||||
| const size_t kReportMaxLen = 2048; | const size_t kReportMaxLen = 2048; | ||||
| const int32_t kMaxDeviceNum = 256; | const int32_t kMaxDeviceNum = 256; | ||||
| const std::string kConfigNumsdev = "devNums"; | const std::string kConfigNumsdev = "devNums"; | ||||
| @@ -45,7 +38,13 @@ const std::string kProfModelUnsubscribe = "prof_model_cancel_subscribe"; | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| ProfilingManager::ProfilingManager() : subscribe_count_(0) {} | |||||
| ProfilingManager::ProfilingManager() : is_load_profiling_(false), | |||||
| is_execute_profiling_(false), | |||||
| is_training_trace_(false), | |||||
| subscribe_count_(0) { | |||||
| prof_cb_.msprofCtrlCallback = nullptr; | |||||
| prof_cb_.msprofReporterCallback = nullptr; | |||||
| } | |||||
| ProfilingManager::~ProfilingManager() {} | ProfilingManager::~ProfilingManager() {} | ||||
| @@ -58,44 +57,29 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In | |||||
| #ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
| vector<int32_t>().swap(device_id_); | vector<int32_t>().swap(device_id_); | ||||
| subscribe_count_ = 0; | subscribe_count_ = 0; | ||||
| job_id_ = options.job_id; | |||||
| GELOGI("ProfilingManager::Init job_id:%s", job_id_.c_str()); | |||||
| GELOGI("ProfilingManager::Init job_id:%s", options.job_id.c_str()); | |||||
| Status ret; | |||||
| if (!recv_profiling_config_.empty()) { | |||||
| GELOGI("Profiling json config from acl:%s", recv_profiling_config_.c_str()); | |||||
| ret = InitFromAclCfg(recv_profiling_config_); | |||||
| } else { | |||||
| ret = InitFromOptions(options); | |||||
| if (ret == SUCCESS && is_load_profiling_) { | |||||
| device_id_.push_back(options.device_id); | |||||
| } | |||||
| } | |||||
| struct MsprofGeOptions prof_conf = {{ 0 }}; | |||||
| Status ret = InitFromOptions(options, prof_conf); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Failed to init profiling."); | GELOGE(ret, "Failed to init profiling."); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (is_load_profiling_) { | |||||
| // register Framework to profiling | |||||
| int result = Msprof::Engine::Init(GE_PROFILING_MODULE, &engine_); | |||||
| if (result != 0) { | |||||
| GELOGE(FAILED, "Register profiling engine failed."); | |||||
| return FAILED; | |||||
| if (is_execute_profiling_) { | |||||
| if (prof_cb_.msprofCtrlCallback == nullptr) { | |||||
| GELOGE(ge::PARAM_INVALID, "MsprofCtrlCallback callback is nullptr."); | |||||
| return ge::PARAM_INVALID; | |||||
| } | } | ||||
| // profiling startup first time | |||||
| GELOGI("Begin to init profiling, device num %zu", device_id_.size()); | |||||
| for (size_t i = 0; i < device_id_.size(); ++i) { | |||||
| ret = StartProfiling(0, device_id_[i]); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGW("Profiling start failed on device %d.", device_id_[i]); | |||||
| continue; | |||||
| } | |||||
| GELOGI("Profiling init succ on device %d.", device_id_[i]); | |||||
| int32_t cb_ret = prof_cb_.msprofCtrlCallback( | |||||
| static_cast<uint32_t>(MsprofCtrlCallbackType::MSPROF_CTRL_INIT_GE_OPTIONS), | |||||
| static_cast<void *>(&prof_conf), sizeof(MsprofGeOptions)); | |||||
| if (cb_ret != 0) { | |||||
| GELOGE(FAILED, "Call msprofCtrlCallback failed, type:%u, return:%d", | |||||
| static_cast<uint32_t>(MsprofCtrlCallbackType::MSPROF_CTRL_INIT_GE_OPTIONS), cb_ret); | |||||
| return FAILED; | |||||
| } | } | ||||
| GELOGI("Profiling init success"); | |||||
| } else { | } else { | ||||
| GELOGI("The profiling is off, skip the initialization"); | GELOGI("The profiling is off, skip the initialization"); | ||||
| } | } | ||||
| @@ -103,288 +87,116 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::InitFromAclCfg( | |||||
| const std::string &config) { | |||||
| ge::Status ProfilingManager::InitFromOptions(const Options &options, MsprofGeOptions &prof_conf) { | |||||
| #ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
| try { | |||||
| is_load_profiling_ = false; | |||||
| is_execute_profiling_ = false; | |||||
| profiling_opts_.clear(); | |||||
| op_trace_conf_.clear(); | |||||
| Json start_prof_conf = Json::parse(config); | |||||
| Json &prof_conf = start_prof_conf[kStartCfg][0]; | |||||
| job_id_ = prof_conf[kJobID]; | |||||
| auto iter = prof_conf.find(kProfDir); | |||||
| if (iter != prof_conf.end()) { | |||||
| prof_dir_ = prof_conf[kProfDir]; | |||||
| } | |||||
| Json &device_id = prof_conf[kDeviceID]; | |||||
| if (device_id.size() != 0) { | |||||
| vector<int32_t>().swap(device_id_); | |||||
| bool is_all = false; | |||||
| for (size_t i = 0; i < device_id.size(); i++) { | |||||
| std::string device_id_str = device_id[i].get<std::string>(); | |||||
| if (device_id_str == "all") { | |||||
| is_all = true; | |||||
| break; | |||||
| } | |||||
| device_id_.push_back(std::stoi(device_id_str)); | |||||
| } | |||||
| if (is_all) { | |||||
| int32_t count = 0; | |||||
| rtError_t rt_err = rtGetDeviceCount(&count); | |||||
| if (rt_err != RT_ERROR_NONE) { | |||||
| GELOGE(FAILED, "Call rtGetDeviceCount to get device failed."); | |||||
| } | |||||
| vector<int32_t>().swap(device_id_); | |||||
| for (int32_t i = 0; i < count; ++i) { | |||||
| device_id_.push_back(i); | |||||
| } | |||||
| } | |||||
| // enable profiling by env | |||||
| char env_profiling_mode[MMPA_MAX_PATH] = { 0x00 }; | |||||
| is_load_profiling_ = false; // Change in ProfInit | |||||
| is_execute_profiling_ = false; | |||||
| if (options.profiling_mode == "1" && !options.profiling_options.empty()) { | |||||
| // enable profiling by ge option | |||||
| if (memcpy_s(prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX, options.profiling_options.c_str(), | |||||
| options.profiling_options.size()) != EOK) { | |||||
| GELOGE(INTERNAL_ERROR, "copy profiling_options failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | } | ||||
| Json &features = prof_conf[kFeatures]; | |||||
| if (ParseFeaturesFromAclCfg(features) != SUCCESS) { | |||||
| GELOGE(FAILED, "Parse feature from acl cfg failed."); | |||||
| return FAILED; | |||||
| is_execute_profiling_ = true; | |||||
| GELOGI("The profiling in options is %s, %s. origin option: %s", options.profiling_mode.c_str(), | |||||
| prof_conf.options, options.profiling_options.c_str()); | |||||
| } else { | |||||
| (void)mmGetEnv("PROFILING_MODE", env_profiling_mode, MMPA_MAX_PATH); | |||||
| (void)mmGetEnv("PROFILING_OPTIONS", prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX); | |||||
| // The env is invalid | |||||
| if ((strcmp("true", env_profiling_mode) != 0) || (strcmp(prof_conf.options, "\0") == 0)) { | |||||
| return SUCCESS; | |||||
| } | } | ||||
| is_load_profiling_ = true; | |||||
| // enable profiling by env | |||||
| is_execute_profiling_ = true; | is_execute_profiling_ = true; | ||||
| } catch (...) { | |||||
| GELOGE(FAILED, "Json conf is not invalid !"); | |||||
| GELOGI("The profiling in env is %s, %s", env_profiling_mode, prof_conf.options); | |||||
| } | |||||
| if (!is_execute_profiling_) { | |||||
| return SUCCESS; | |||||
| } | |||||
| // Parse json str for bp fp | |||||
| Status ret = ParseOptions(prof_conf.options); | |||||
| if (ret != ge::SUCCESS) { | |||||
| GELOGE(ge::PARAM_INVALID, "Parse training trace param failed."); | |||||
| return ge::PARAM_INVALID; | return ge::PARAM_INVALID; | ||||
| } | } | ||||
| if (memcpy_s(prof_conf.jobId, sizeof(prof_conf.jobId), options.job_id.c_str(), | |||||
| sizeof(options.job_id.c_str())) != EOK) { | |||||
| GELOGE(INTERNAL_ERROR, "copy job_id failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| #endif | #endif | ||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::ParseFeaturesFromAclCfg( | |||||
| const Json &features) { | |||||
| #ifdef DAVINCI_SUPPORT_PROFILING | |||||
| ge::Status ProfilingManager::ParseOptions(const std::string &options) { | |||||
| if (options.empty()) { | |||||
| GELOGE(ge::PARAM_INVALID, "Profiling options is empty."); | |||||
| return ge::PARAM_INVALID; | |||||
| } | |||||
| try { | try { | ||||
| for (size_t i = 0; i < features.size(); ++i) { | |||||
| const Json &feature = features[i]; | |||||
| if ((feature.find(kName) == feature.end()) || feature[kName].is_null()) { | |||||
| continue; | |||||
| } | |||||
| const std::string &name = feature[kName]; | |||||
| if (name == "op_trace") { | |||||
| const Json &conf = feature[kConf]; | |||||
| const Json &events = conf[0][kEvents]; | |||||
| const std::string &ai_core_events = events[0][kAiCoreEvents]; | |||||
| GELOGI("Op trace config from acl ai_core_events:%s", ai_core_events.c_str()); | |||||
| is_op_trace_ = true; | |||||
| ProfMgrConf prof_mgr_conf; | |||||
| int result = ProfMgrGetConf(ai_core_events, &prof_mgr_conf); | |||||
| if (result != 0) { | |||||
| GELOGE(FAILED, "ProfMgrGetConf failed."); | |||||
| return FAILED; | |||||
| } | |||||
| op_trace_conf_ = prof_mgr_conf.conf; | |||||
| op_trace_iter_num_ = static_cast<int32_t>(op_trace_conf_.size()); | |||||
| GELOGI("Op trace profiling iter num %d,", op_trace_iter_num_); | |||||
| } else if (name == "task_trace") { | |||||
| is_op_trace_ = false; | |||||
| if (feature.find(kConf) != feature.end()) { | |||||
| const Json &conf = feature[kConf]; | |||||
| std::stringstream task_trace_conf; | |||||
| task_trace_conf << conf; | |||||
| task_trace_conf_ = task_trace_conf.str(); | |||||
| } | |||||
| GELOGI("Task trace config from acl"); | |||||
| } else if (name == "system_trace") { | |||||
| is_op_trace_ = false; | |||||
| const Json &conf = feature[kConf]; | |||||
| std::stringstream system_trace_conf; | |||||
| system_trace_conf << conf; | |||||
| system_trace_conf_ = system_trace_conf.str(); | |||||
| GELOGI("System trace config from acl"); | |||||
| } | |||||
| profiling_opts_.push_back(name); | |||||
| Json prof_options = Json::parse(options); | |||||
| const std::string training_trace = prof_options[kTrainingTrace]; | |||||
| if (training_trace.empty()) { | |||||
| GELOGI("Training trace will not take effect."); | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| GELOGI("GE profiling training trace:%s", training_trace.c_str()); | |||||
| if (training_trace != "on") { | |||||
| GELOGE(ge::PARAM_INVALID, "Training trace param:%s is invalid.", training_trace.c_str()); | |||||
| return ge::PARAM_INVALID; | |||||
| } | |||||
| fp_point_ = prof_options[kFpPoint]; | |||||
| bp_point_ = prof_options[kBpPoint]; | |||||
| if (!fp_point_.empty() && !bp_point_.empty()) { | |||||
| GELOGI("Training trace bp fp is set, bp_point:%s, fp_point:%s.", bp_point_.c_str(), fp_point_.c_str()); | |||||
| } | } | ||||
| } catch (...) { | } catch (...) { | ||||
| GELOGE(ge::PARAM_INVALID, "Json conf feature is not invalid !"); | |||||
| GELOGE(FAILED, "Json prof_conf options is invalid."); | |||||
| return ge::PARAM_INVALID; | return ge::PARAM_INVALID; | ||||
| } | } | ||||
| #endif | |||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::InitFromOptions(const Options &options) { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProfiling() { | |||||
| #ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
| // enable profiling support two ways: env and front end | |||||
| char profiling_mode_temp[MMPA_MAX_PATH] = { 0x00 }; | |||||
| char prof_options_temp[MMPA_MAX_PATH] = { 0x00 }; | |||||
| (void)mmGetEnv("PROFILING_MODE", profiling_mode_temp, MMPA_MAX_PATH); | |||||
| (void)mmGetEnv("PROFILING_OPTIONS", prof_options_temp, MMPA_MAX_PATH ); | |||||
| const char *profiling_mode = profiling_mode_temp; | |||||
| const char *prof_options = prof_options_temp; | |||||
| if ((profiling_mode == nullptr) || (strcmp("true", profiling_mode) != 0) || (prof_options == nullptr)) { | |||||
| is_load_profiling_ = false; | |||||
| is_execute_profiling_ = false; | |||||
| } else { | |||||
| std::string prof_options_str = std::string(prof_options); | |||||
| profiling_opts_ = StringUtils::Split(prof_options_str, ':'); | |||||
| is_load_profiling_ = true; | |||||
| is_execute_profiling_ = true; | |||||
| GELOGI("The profiling in env is %s, %s", profiling_mode, prof_options); | |||||
| } | |||||
| if (!is_load_profiling_) { | |||||
| const std::string enable_profiling = "1"; | |||||
| if (options.profiling_mode != enable_profiling || options.profiling_options.empty()) { | |||||
| is_load_profiling_ = false; | |||||
| is_execute_profiling_ = false; | |||||
| return SUCCESS; | |||||
| } else { | |||||
| profiling_opts_ = StringUtils::Split(options.profiling_options, ':'); | |||||
| is_load_profiling_ = true; | |||||
| is_execute_profiling_ = true; | |||||
| GELOGI("The profiling in options is %s, %s", options.profiling_mode.c_str(), options.profiling_options.c_str()); | |||||
| } | |||||
| } | |||||
| // features:'training_trace', 'task_trace' or 'op_trace' etc | |||||
| if (!profiling_opts_.empty()) { | |||||
| if (profiling_opts_[0] == "op_trace") { | |||||
| is_op_trace_ = true; | |||||
| // op trace get conf | |||||
| ProfMgrConf prof_mgr_conf; | |||||
| int result = ProfMgrGetConf("", &prof_mgr_conf); | |||||
| if (result != 0) { | |||||
| GELOGE(FAILED, "ProfMgrGetConf failed."); | |||||
| return FAILED; | |||||
| } | |||||
| op_trace_conf_ = prof_mgr_conf.conf; | |||||
| op_trace_iter_num_ = static_cast<int32_t>(op_trace_conf_.size()); | |||||
| GELOGI("op trace profiling iter num %d,", op_trace_iter_num_); | |||||
| } else { | |||||
| is_op_trace_ = false; | |||||
| op_trace_iter_num_ = 1; | |||||
| uint64_t module = GetProfilingModule(); | |||||
| // The following if case will not be executed in normal case, inc case of ProfStopProfiling is abnormal | |||||
| int32_t device_num = static_cast<int32_t>(device_id_.size()); | |||||
| if (device_num != 0) { | |||||
| auto device_id_ptr = std::unique_ptr<uint32_t[]>(new (std::nothrow) uint32_t[device_num]); | |||||
| if (device_id_ptr == nullptr) { | |||||
| GELOGE(FAILED, "Stop profiling: device id ptr is null."); | |||||
| return; | |||||
| } | } | ||||
| } | |||||
| #endif | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::StartProfiling(int32_t iter_num, | |||||
| int32_t device_id) { | |||||
| #ifdef DAVINCI_SUPPORT_PROFILING | |||||
| if (!profiling_opts_.empty()) { | |||||
| GELOGI("Start profiling index is %d", iter_num); | |||||
| // current one docker only use one device | |||||
| Json p_device; | |||||
| try { | |||||
| // profiling need physical_device_id | |||||
| p_device[kDeviceID] = std::to_string(device_id); | |||||
| p_device[kJobID] = job_id_; | |||||
| p_device[kTraceID] = std::to_string(GetContext().TraceId()); | |||||
| if (!prof_dir_.empty()) { | |||||
| p_device[kProfDir] = prof_dir_; | |||||
| GELOGI("Prof dir: %s.", prof_dir_.c_str()); | |||||
| } | |||||
| Json features; | |||||
| if (is_op_trace_) { | |||||
| Json f; | |||||
| f[kName] = "op_trace"; | |||||
| Json conf; | |||||
| if (op_trace_conf_.size() <= static_cast<size_t>(iter_num)) { | |||||
| GELOGE(FAILED, "Op trace iter num is invalid!"); | |||||
| return FAILED; | |||||
| } | |||||
| Json events; | |||||
| events[0] = nlohmann::json::parse(op_trace_conf_[iter_num]); | |||||
| conf[0][kEvents] = events; | |||||
| f[kConf] = conf; | |||||
| features[0] = f; | |||||
| if (iter_num == 0) { | |||||
| is_load_ = true; | |||||
| } | |||||
| } else { | |||||
| for (std::vector<std::string>::size_type i = 0; i < profiling_opts_.size(); i++) { | |||||
| Json f; | |||||
| if (profiling_opts_[i] == "system_trace") { | |||||
| f[kConf] = nlohmann::json::parse(system_trace_conf_); | |||||
| } else if (profiling_opts_[i] == "task_trace") { | |||||
| if (!task_trace_conf_.empty()) { | |||||
| f[kConf] = nlohmann::json::parse(task_trace_conf_); | |||||
| } | |||||
| } | |||||
| f[kName] = profiling_opts_[i]; | |||||
| features[i] = f; | |||||
| } | |||||
| is_load_ = true; | |||||
| } | |||||
| p_device[kFeatures] = features; | |||||
| // only one device, but sProfMgrStartUp API require for device list | |||||
| Json devices; | |||||
| devices[0] = p_device; | |||||
| Json start_cfg; | |||||
| start_cfg[kStartCfg] = devices; | |||||
| // convert json to string | |||||
| std::stringstream ss; | |||||
| ss << start_cfg; | |||||
| send_profiling_config_ = ss.str(); | |||||
| GELOGI("Profiling config %s\n", send_profiling_config_.c_str()); | |||||
| } catch (...) { | |||||
| GELOGE(FAILED, "Op trace json conf is not invalid !"); | |||||
| return FAILED; | |||||
| for (int32_t i = 0; i < device_num; i++) { | |||||
| device_id_ptr[i] = static_cast<uint32_t>(device_id_[i]); | |||||
| } | } | ||||
| // runtime startup for profiling | |||||
| uint64_t module = GetProfilingModule(); | |||||
| int32_t device_num = 1; | |||||
| uint32_t device_id_rt = static_cast<uint32_t>(device_id); | |||||
| GE_CHK_RT_RET(rtProfilerStart(module, device_num, &device_id_rt)); | |||||
| // call profiling startup API | |||||
| ProfMgrCfg prof_cfg = {send_profiling_config_}; | |||||
| void *prof_handle = ProfMgrStartUp(&prof_cfg); | |||||
| if (prof_handle == nullptr) { | |||||
| GELOGW("ProfMgrStartUp failed on device %d ", device_id); | |||||
| return FAILED; | |||||
| rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGW("Call rtProfilerStop failed, ret:%d", rt_ret); | |||||
| } | } | ||||
| GELOGD("StartProfiling, prof_handle: %p", prof_handle); | |||||
| prof_handle_vec_.push_back(prof_handle); | |||||
| } | } | ||||
| #endif | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProfiling() { | |||||
| #ifdef DAVINCI_SUPPORT_PROFILING | |||||
| Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | |||||
| if (reporter != nullptr) { | |||||
| int ret = reporter->Flush(); | |||||
| GELOGI("Report data end, ret is %d", ret); | |||||
| // stop profiling | |||||
| if (prof_cb_.msprofCtrlCallback == nullptr) { | |||||
| GELOGE(ge::PARAM_INVALID, "MsprofCtrlCallback callback is nullptr."); | |||||
| return; | |||||
| } | } | ||||
| uint64_t module = GetProfilingModule(); | |||||
| int32_t device_num = static_cast<int32_t>(device_id_.size()); | |||||
| auto device_id_ptr = std::unique_ptr<uint32_t[]>(new (std::nothrow) uint32_t[device_num]); | |||||
| if (device_id_ptr == nullptr) { | |||||
| GELOGE(FAILED, "Stop profiling: device id ptr is null."); | |||||
| int32_t cb_ret = prof_cb_.msprofCtrlCallback(static_cast<uint32_t>(MsprofCtrlCallbackType::MSPROF_CTRL_FINALIZE), | |||||
| nullptr, 0); | |||||
| if (cb_ret != 0) { | |||||
| GELOGW("call msprofCtrlCallback failed, type:%u, return:%d", | |||||
| static_cast<uint32_t>(MsprofCtrlCallbackType::MSPROF_CTRL_FINALIZE), cb_ret); | |||||
| return; | return; | ||||
| } | } | ||||
| for (int32_t i = 0; i < device_num; i++) { | |||||
| device_id_ptr[i] = static_cast<uint32_t>(device_id_[i]); | |||||
| } | |||||
| rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGW("Call rtProfilerStop failed, ret:%d", rt_ret); | |||||
| } | |||||
| for (size_t i = 0; i < prof_handle_vec_.size(); ++i) { | |||||
| int result = ProfMgrStop(prof_handle_vec_[i]); | |||||
| if (result != 0) { | |||||
| GELOGW("ProfMgr stop return fail:%d, handle:%p", result, prof_handle_vec_[i]); | |||||
| } | |||||
| } | |||||
| vector<void *>().swap(prof_handle_vec_); | |||||
| is_load_ = false; | |||||
| recv_profiling_config_ = ""; | |||||
| GELOGI("Stop Profiling success."); | GELOGI("Stop Profiling success."); | ||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -392,12 +204,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProf | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingTaskDescInfo( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingTaskDescInfo( | ||||
| uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, const int32_t &device_id) { | uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, const int32_t &device_id) { | ||||
| #ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
| Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | |||||
| if (reporter == nullptr) { | |||||
| GELOGI("Profiling report is nullptr!"); | |||||
| return; | |||||
| } | |||||
| std::string data; | std::string data; | ||||
| for (const auto &task : task_desc_info) { | for (const auto &task : task_desc_info) { | ||||
| std::string model_name = task.model_name; | std::string model_name = task.model_name; | ||||
| @@ -412,7 +218,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin | |||||
| .append(std::to_string(stream_id)).append(" ") | .append(std::to_string(stream_id)).append(" ") | ||||
| .append(std::to_string(model_id)).append("\n")); | .append(std::to_string(model_id)).append("\n")); | ||||
| Msprof::Engine::ReporterData reporter_data{}; | |||||
| ReporterData reporter_data{}; | |||||
| reporter_data.deviceId = device_id; | reporter_data.deviceId = device_id; | ||||
| reporter_data.data = (unsigned char *)data.c_str(); | reporter_data.data = (unsigned char *)data.c_str(); | ||||
| reporter_data.dataLen = data.size(); | reporter_data.dataLen = data.size(); | ||||
| @@ -422,9 +228,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin | |||||
| return; | return; | ||||
| } | } | ||||
| ret = reporter->Report(&reporter_data); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Reporter data of task_desc_info fail!"); | |||||
| int32_t cb_ret = CallMsprofReport(reporter_data); | |||||
| if (cb_ret != 0) { | |||||
| GELOGE(cb_ret, "Reporter data of task_desc_info failed, ret:%d", cb_ret); | |||||
| return; | return; | ||||
| } | } | ||||
| } | } | ||||
| @@ -436,9 +242,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingGraphDescInfo( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingGraphDescInfo( | ||||
| uint32_t model_id, const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, const int32_t &device_id) { | uint32_t model_id, const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, const int32_t &device_id) { | ||||
| #ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
| Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | |||||
| GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return;); | |||||
| std::string data; | std::string data; | ||||
| for (const auto &graph : compute_graph_desc_info) { | for (const auto &graph : compute_graph_desc_info) { | ||||
| data.append("model_name:") | data.append("model_name:") | ||||
| @@ -493,64 +296,52 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin | |||||
| } | } | ||||
| data.append(" model_id:").append(std::to_string(model_id)); | data.append(" model_id:").append(std::to_string(model_id)); | ||||
| data.append("\n"); | data.append("\n"); | ||||
| Msprof::Engine::ReporterData reporter_data{}; | |||||
| Report(device_id, data, *reporter, reporter_data); | |||||
| GraphDescReport(device_id, data); | |||||
| data.clear(); | data.clear(); | ||||
| } | } | ||||
| #endif | #endif | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Report( | |||||
| const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter, | |||||
| Msprof::Engine::ReporterData &reporter_data) { | |||||
| void ProfilingManager::GraphDescReport(const int32_t &device_id, const string &data) { | |||||
| #ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
| ReporterData reporter_data{}; | |||||
| int ret = -1; | |||||
| int32_t cb_ret = -1; | |||||
| size_t index = data.size() / kReportMaxLen; | size_t index = data.size() / kReportMaxLen; | ||||
| if (index >= 1) { | if (index >= 1) { | ||||
| reporter_data.deviceId = device_id; | reporter_data.deviceId = device_id; | ||||
| int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); | |||||
| ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); | |||||
| GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;); | GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;); | ||||
| for (size_t i = 0; i < index; ++i) { | for (size_t i = 0; i < index; ++i) { | ||||
| reporter_data.data = (unsigned char *)data.c_str() + kReportMaxLen * i; | reporter_data.data = (unsigned char *)data.c_str() + kReportMaxLen * i; | ||||
| reporter_data.dataLen = kReportMaxLen; | reporter_data.dataLen = kReportMaxLen; | ||||
| ret = reporter.Report(&reporter_data); | |||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Reporter data of graph_desc_info fail!"); return;); | |||||
| cb_ret = CallMsprofReport(reporter_data); | |||||
| GE_IF_BOOL_EXEC(cb_ret != 0, GELOGE(cb_ret, "Reporter data of graph_desc_info failed, ret:%d", cb_ret); return;); | |||||
| } | } | ||||
| reporter_data.dataLen = data.size() - kReportMaxLen * index; | reporter_data.dataLen = data.size() - kReportMaxLen * index; | ||||
| if (reporter_data.dataLen != 0) { | if (reporter_data.dataLen != 0) { | ||||
| reporter_data.data = (unsigned char *)data.c_str() + kReportMaxLen * index; | reporter_data.data = (unsigned char *)data.c_str() + kReportMaxLen * index; | ||||
| ret = reporter.Report(&reporter_data); | |||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Reporter data of graph_desc_info fail!"); return;); | |||||
| cb_ret = CallMsprofReport(reporter_data); | |||||
| GE_IF_BOOL_EXEC(cb_ret != 0, GELOGE(cb_ret, "Reporter data of graph_desc_info failed, ret:%d", cb_ret); return;); | |||||
| } | } | ||||
| } else { | } else { | ||||
| reporter_data.deviceId = device_id; | reporter_data.deviceId = device_id; | ||||
| reporter_data.data = (unsigned char *)data.c_str(); | reporter_data.data = (unsigned char *)data.c_str(); | ||||
| reporter_data.dataLen = data.size(); | reporter_data.dataLen = data.size(); | ||||
| int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); | |||||
| ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); | |||||
| GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;); | GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;); | ||||
| ret = reporter.Report(&reporter_data); | |||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Reporter data of graph_desc_info fail!"); return;); | |||||
| } | |||||
| #endif | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::PluginUnInit(const std::string &module) const { | |||||
| #ifdef DAVINCI_SUPPORT_PROFILING | |||||
| int ret = Msprof::Engine::UnInit(module); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "profiling plugin uninit failed, ret:%d", ret); | |||||
| cb_ret = CallMsprofReport(reporter_data); | |||||
| GE_IF_BOOL_EXEC(cb_ret != 0, GELOGE(cb_ret, "Reporter data of graph_desc_info failed, ret:%d", cb_ret); return;); | |||||
| } | } | ||||
| #endif | #endif | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportProfilingData( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportProfilingData( | ||||
| uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, | uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, | ||||
| const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, | |||||
| bool check_device) { | |||||
| const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info) { | |||||
| #ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
| int32_t logic_device_id = 0; | int32_t logic_device_id = 0; | ||||
| rtError_t rt_ret = rtGetDevice(&logic_device_id); | rtError_t rt_ret = rtGetDevice(&logic_device_id); | ||||
| @@ -559,13 +350,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportPr | |||||
| return; | return; | ||||
| } | } | ||||
| GELOGD("current logic_device_id:%d", logic_device_id); | GELOGD("current logic_device_id:%d", logic_device_id); | ||||
| if (check_device) { | |||||
| auto ret = std::find(device_id_.begin(), device_id_.end(), logic_device_id); | |||||
| if (ret == device_id_.end()) { | |||||
| GELOGE(FAILED, "get valid phy_device_id failed, profiling report failed."); | |||||
| return; | |||||
| } | |||||
| } | |||||
| GELOGD("start ProfilingTaskDescInfo."); | GELOGD("start ProfilingTaskDescInfo."); | ||||
| ProfilingTaskDescInfo(model_id, task_desc_info, logic_device_id); | ProfilingTaskDescInfo(model_id, task_desc_info, logic_device_id); | ||||
| GELOGD("start ProfilingGraphDescInfo."); | GELOGD("start ProfilingGraphDescInfo."); | ||||
| @@ -574,11 +358,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportPr | |||||
| #endif | #endif | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::SetProfilingConfig( | |||||
| const std::string &profiling_cfg) { | |||||
| recv_profiling_config_ = profiling_cfg; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t ProfilingManager::GetProfilingModule() { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t ProfilingManager::GetProfilingModule() { | ||||
| uint64_t module = PROF_MODEL_EXECUTE_MASK | | uint64_t module = PROF_MODEL_EXECUTE_MASK | | ||||
| PROF_RUNTIME_API_MASK | | PROF_RUNTIME_API_MASK | | ||||
| @@ -594,9 +373,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t ProfilingManager::GetP | |||||
| return module; | return module; | ||||
| } | } | ||||
| void ProfilingManager::UpdateSubscribeDeviceModuleMap(std::string prof_type, | |||||
| uint32_t device_id, | |||||
| uint64_t module) { | |||||
| void ProfilingManager::UpdateSubscribeDeviceModuleMap(std::string prof_type, uint32_t device_id, uint64_t module) { | |||||
| #ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
| if (prof_type == kProfModelSubscribe) { | if (prof_type == kProfModelSubscribe) { | ||||
| if (subs_dev_module_.find(device_id) != subs_dev_module_.end()) { | if (subs_dev_module_.find(device_id) != subs_dev_module_.end()) { | ||||
| @@ -608,9 +385,13 @@ void ProfilingManager::UpdateSubscribeDeviceModuleMap(std::string prof_type, | |||||
| subs_dev_module_[device_id] = dev_info; | subs_dev_module_[device_id] = dev_info; | ||||
| } | } | ||||
| } else if (prof_type == kProfModelUnsubscribe) { | } else if (prof_type == kProfModelUnsubscribe) { | ||||
| if (subs_dev_module_.find(device_id) != subs_dev_module_.end()) { | |||||
| if (subs_dev_module_[device_id].subscribe_count > 0) { | |||||
| subs_dev_module_[device_id].subscribe_count--; | |||||
| auto iter = subs_dev_module_.find(device_id); | |||||
| if (iter != subs_dev_module_.end()) { | |||||
| if (iter->second.subscribe_count > 0) { | |||||
| iter->second.subscribe_count--; | |||||
| } | |||||
| if (iter->second.subscribe_count == 0) { | |||||
| subs_dev_module_.erase(iter); | |||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -626,10 +407,11 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo | |||||
| uint64_t model_load_mask = module & PROF_MODEL_LOAD_MASK; | uint64_t model_load_mask = module & PROF_MODEL_LOAD_MASK; | ||||
| if ((subscribe_count_ == 0) && (model_load_mask == PROF_MODEL_LOAD_MASK)) { | if ((subscribe_count_ == 0) && (model_load_mask == PROF_MODEL_LOAD_MASK)) { | ||||
| // register framework to profiling | // register framework to profiling | ||||
| int32_t result = Msprof::Engine::Init(GE_PROFILING_MODULE, &engine_); | |||||
| if (result != SUCCESS) { | |||||
| GELOGE(FAILED, "Register profiling engine failed."); | |||||
| return FAILED; | |||||
| // register Framework to profiling | |||||
| int32_t cb_ret = PluginInit(); | |||||
| if (cb_ret != 0) { | |||||
| GELOGE(cb_ret, "profiling plugin init failed, ret:%d", cb_ret); | |||||
| return cb_ret; | |||||
| } | } | ||||
| GELOGI("Prof subscribe: model load profiling on."); | GELOGI("Prof subscribe: model load profiling on."); | ||||
| } | } | ||||
| @@ -647,7 +429,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo | |||||
| UpdateSubscribeDeviceModuleMap(kProfModelSubscribe, device[0], module); | UpdateSubscribeDeviceModuleMap(kProfModelSubscribe, device[0], module); | ||||
| // Report profiling data | // Report profiling data | ||||
| Status p_ret = davinci_model->ReportProfilingData(false); | |||||
| Status p_ret = davinci_model->ReportProfilingData(); | |||||
| if (p_ret != SUCCESS) { | if (p_ret != SUCCESS) { | ||||
| GELOGE(p_ret, "Report profiling data failed."); | GELOGE(p_ret, "Report profiling data failed."); | ||||
| return p_ret; | return p_ret; | ||||
| @@ -672,6 +454,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo | |||||
| auto iter = subs_dev_module_.find(device[0]); | auto iter = subs_dev_module_.find(device[0]); | ||||
| if (iter != subs_dev_module_.end()) { | if (iter != subs_dev_module_.end()) { | ||||
| if (subs_dev_module_[device[0]].subscribe_count == 1) { | if (subs_dev_module_[device[0]].subscribe_count == 1) { | ||||
| // The same device_id, only stop at last time | |||||
| rtError_t rt_ret = rtProfilerStop(subs_dev_module_[device[0]].module, dev_num, device); | rtError_t rt_ret = rtProfilerStop(subs_dev_module_[device[0]].module, dev_num, device); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(FAILED, "Runtime profiler stop failed."); | GELOGE(FAILED, "Runtime profiler stop failed."); | ||||
| @@ -679,15 +462,15 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo | |||||
| } | } | ||||
| } | } | ||||
| UpdateSubscribeDeviceModuleMap(kProfModelUnsubscribe, device[0], subs_dev_module_[device[0]].module); | UpdateSubscribeDeviceModuleMap(kProfModelUnsubscribe, device[0], subs_dev_module_[device[0]].module); | ||||
| } else { | |||||
| GELOGE(FAILED, "The device_id:%u has not been subscribed, do not need to cancel.", device[0]); | |||||
| return FAILED; | |||||
| } | } | ||||
| subscribe_count_--; | subscribe_count_--; | ||||
| if (subscribe_count_ == 0) { | if (subscribe_count_ == 0) { | ||||
| int32_t ret = Msprof::Engine::UnInit(GE_PROFILING_MODULE); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Profiling plugin uninit failed, ret:%d", ret); | |||||
| return ret; | |||||
| } | |||||
| // profiling plugin uninit at last subscription | |||||
| PluginUnInit(); | |||||
| } | } | ||||
| #endif | #endif | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -700,11 +483,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfIn | |||||
| if (model_load_mask == PROF_MODEL_LOAD_MASK) { | if (model_load_mask == PROF_MODEL_LOAD_MASK) { | ||||
| // register Framework to profiling | // register Framework to profiling | ||||
| int32_t result = Msprof::Engine::Init(GE_PROFILING_MODULE, &engine_); | |||||
| if (result != SUCCESS) { | |||||
| GELOGE(FAILED, "Register profiling engine failed."); | |||||
| return FAILED; | |||||
| int32_t cb_ret = PluginInit(); | |||||
| if (cb_ret != 0) { | |||||
| GELOGE(cb_ret, "profiling plugin init failed, ret:%d", cb_ret); | |||||
| return cb_ret; | |||||
| } | } | ||||
| int32_t device_num = -1; | int32_t device_num = -1; | ||||
| rtError_t rt_ret = rtProfilerStart(model_load_mask, device_num, nullptr); | rtError_t rt_ret = rtProfilerStart(model_load_mask, device_num, nullptr); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| @@ -719,7 +503,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfIn | |||||
| if (training_trace_mask == PROF_TRAINING_TRACE_MASK) { | if (training_trace_mask == PROF_TRAINING_TRACE_MASK) { | ||||
| is_training_trace_ = true; | is_training_trace_ = true; | ||||
| } | } | ||||
| is_acl_api_mode_ = true; | |||||
| GELOGI("Prof init success."); | GELOGI("Prof init success."); | ||||
| #endif | #endif | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -730,19 +513,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfFi | |||||
| std::lock_guard<std::mutex> lock(mutex_); | std::lock_guard<std::mutex> lock(mutex_); | ||||
| is_load_profiling_ = false; | is_load_profiling_ = false; | ||||
| is_training_trace_ = false; | is_training_trace_ = false; | ||||
| is_acl_api_mode_ = false; | |||||
| is_execute_profiling_ = false; | |||||
| // profiling plugin uninit | |||||
| PluginUnInit(); | |||||
| int32_t ret = Msprof::Engine::UnInit(GE_PROFILING_MODULE); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Profiling plugin uninit failed, ret:%d", ret); | |||||
| } | |||||
| int32_t dev_num = -1; | int32_t dev_num = -1; | ||||
| rtError_t rt_ret = rtProfilerStop(PROF_MODEL_LOAD_MASK, dev_num, nullptr); | rtError_t rt_ret = rtProfilerStop(PROF_MODEL_LOAD_MASK, dev_num, nullptr); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(FAILED, "Runtime profiler stop failed."); | GELOGE(FAILED, "Runtime profiler stop failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| for (auto device_id_module : device_id_module_map_) { | for (auto device_id_module : device_id_module_map_) { | ||||
| if (device_id_module.second != 0) { | if (device_id_module.second != 0) { | ||||
| uint32_t device_id = static_cast<uint32_t>(device_id_module.first); | uint32_t device_id = static_cast<uint32_t>(device_id_module.first); | ||||
| @@ -792,6 +573,7 @@ Status ProfilingManager::ProfParseDeviceId(const std::map<std::string, std::stri | |||||
| return FAILED; | return FAILED; | ||||
| } catch (std::out_of_range &) { | } catch (std::out_of_range &) { | ||||
| GELOGE(FAILED, "Device id: %s is out of range.", decvice_id[i].c_str()); | GELOGE(FAILED, "Device id: %s is out of range.", decvice_id[i].c_str()); | ||||
| return FAILED; | |||||
| } catch (...) { | } catch (...) { | ||||
| GELOGE(FAILED, "Device id: %s cannot change to int.", decvice_id[i].c_str()); | GELOGE(FAILED, "Device id: %s cannot change to int.", decvice_id[i].c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -818,6 +600,7 @@ Status ProfilingManager::ProfParseParam(const std::map<std::string, std::string> | |||||
| return FAILED; | return FAILED; | ||||
| } catch (std::out_of_range &) { | } catch (std::out_of_range &) { | ||||
| GELOGE(FAILED, "Device num: %s is out of range.", iter->second.c_str()); | GELOGE(FAILED, "Device num: %s is out of range.", iter->second.c_str()); | ||||
| return FAILED; | |||||
| } catch (...) { | } catch (...) { | ||||
| GELOGE(FAILED, "Device num: %s cannot change to int.", iter->second.c_str()); | GELOGE(FAILED, "Device num: %s cannot change to int.", iter->second.c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -859,7 +642,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt | |||||
| for (int32_t i = 0; i < device_num; i++) { | for (int32_t i = 0; i < device_num; i++) { | ||||
| device_id_ptr[i] = static_cast<uint32_t>(device_list[i]); | device_id_ptr[i] = static_cast<uint32_t>(device_list[i]); | ||||
| } | } | ||||
| GELOGD("Runtime config param: 0x%llx, device num: %d.", module, device_num); | |||||
| GELOGI("Runtime config param: 0x%llx, device num: %d.", module, device_num); | |||||
| rtError_t rt_ret = rtProfilerStart(module, device_num, device_id_ptr.get()); | rtError_t rt_ret = rtProfilerStart(module, device_num, device_id_ptr.get()); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| @@ -878,7 +661,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt | |||||
| GELOGW("Prof start: load model module is invalid."); | GELOGW("Prof start: load model module is invalid."); | ||||
| } | } | ||||
| UpdateDeviceIdModuleMap(kProfStart, module, device_list); | UpdateDeviceIdModuleMap(kProfStart, module, device_list); | ||||
| GELOGD("Prof start profiling success."); | |||||
| GELOGI("Prof start profiling success."); | |||||
| #endif | #endif | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -901,7 +684,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt | |||||
| for (int32_t i = 0; i < device_num; i++) { | for (int32_t i = 0; i < device_num; i++) { | ||||
| device_id_ptr[i] = static_cast<uint32_t>(device_list[i]); | device_id_ptr[i] = static_cast<uint32_t>(device_list[i]); | ||||
| } | } | ||||
| GELOGD("Prof stop: runtime config param: 0x%llx, device num: %d", module, device_num); | |||||
| GELOGI("Prof stop: runtime config param: 0x%llx, device num: %d", module, device_num); | |||||
| rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get()); | rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get()); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(FAILED, "Prof stop: runtime profiler config proc failed."); | GELOGE(FAILED, "Prof stop: runtime profiler config proc failed."); | ||||
| @@ -921,7 +704,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt | |||||
| GELOGW("Prof stop: load model module is invalid."); | GELOGW("Prof stop: load model module is invalid."); | ||||
| } | } | ||||
| UpdateDeviceIdModuleMap(kProfStop, module, device_list); | UpdateDeviceIdModuleMap(kProfStop, module, device_list); | ||||
| GELOGD("Prof stop profiling success."); | |||||
| GELOGI("Prof stop profiling success."); | |||||
| #endif | #endif | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -963,47 +746,90 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ProfilingManager::Profilin | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(rt_ret, "Runtime get logic_device_id failed, current logic_device_id:%d", logic_device_id); | GELOGE(rt_ret, "Runtime get logic_device_id failed, current logic_device_id:%d", logic_device_id); | ||||
| } | } | ||||
| GELOGD("Current logic_device_id:%d", logic_device_id); | |||||
| GELOGI("Current logic_device_id:%d", logic_device_id); | |||||
| bool execute_model_prof_on = false; | bool execute_model_prof_on = false; | ||||
| auto iter = std::find(device_id_.begin(), device_id_.end(), logic_device_id); | auto iter = std::find(device_id_.begin(), device_id_.end(), logic_device_id); | ||||
| if (iter != device_id_.end()) { | if (iter != device_id_.end()) { | ||||
| execute_model_prof_on = true; | execute_model_prof_on = true; | ||||
| } | } | ||||
| GELOGD("Flag is_execute_profiling: %d, execute_model_prof_on: %d", is_execute_profiling_, execute_model_prof_on); | |||||
| return is_execute_profiling_ || execute_model_prof_on; | |||||
| GELOGI("Flag is_execute_profiling: %d, execute_model_prof_on: %d", is_execute_profiling_, execute_model_prof_on); | |||||
| return execute_model_prof_on; | |||||
| } | } | ||||
| /** | |||||
| * @brief Profiling PluginImpl | |||||
| */ | |||||
| // PluginImpl static variable init | |||||
| Msprof::Engine::Reporter *PluginImpl::reporter_ = nullptr; | |||||
| PluginImpl::PluginImpl(const std::string &module) : module_(module) { GELOGI("Create PluginImpl\n"); } | |||||
| int PluginImpl::Init(const Msprof::Engine::Reporter *reporter) { | |||||
| GELOGI("PluginImpl init"); | |||||
| reporter_ = const_cast<Msprof::Engine::Reporter *>(reporter); | |||||
| return 0; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::PluginInit() const { | |||||
| if (prof_cb_.msprofReporterCallback == nullptr) { | |||||
| GELOGE(ge::PARAM_INVALID, "MsprofReporterCallback callback is nullptr."); | |||||
| return ge::PARAM_INVALID; | |||||
| } | |||||
| return prof_cb_.msprofReporterCallback( | |||||
| static_cast<uint32_t>(MsprofReporterModuleId::MSPROF_MODULE_FRAMEWORK), | |||||
| static_cast<uint32_t>(MsprofReporterCallbackType::MSPROF_REPORTER_INIT), | |||||
| nullptr, 0); | |||||
| } | } | ||||
| int PluginImpl::UnInit() { | |||||
| GELOGI("PluginImpl Uninit"); | |||||
| reporter_ = nullptr; | |||||
| return 0; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::PluginUnInit() const { | |||||
| #ifdef DAVINCI_SUPPORT_PROFILING | |||||
| if (prof_cb_.msprofReporterCallback == nullptr) { | |||||
| GELOGE(ge::PARAM_INVALID, "MsprofReporterCallback callback is nullptr."); | |||||
| return; | |||||
| } | |||||
| int32_t cb_ret = prof_cb_.msprofReporterCallback( | |||||
| static_cast<uint32_t>(MsprofReporterModuleId::MSPROF_MODULE_FRAMEWORK), | |||||
| static_cast<uint32_t>(MsprofReporterCallbackType::MSPROF_REPORTER_UNINIT), | |||||
| nullptr, 0); | |||||
| if (cb_ret != 0) { | |||||
| GELOGW("profiling plugin uninit failed, ret:%d", cb_ret); | |||||
| } | |||||
| #endif | |||||
| } | } | ||||
| Msprof::Engine::PluginIntf *ProfilingEngineImpl::CreatePlugin() { | |||||
| GELOGI(" Create Plugin"); | |||||
| return new (std::nothrow) PluginImpl(GE_PROFILING_MODULE); | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::CallMsprofReport( | |||||
| ReporterData &reporter_data) const { | |||||
| if (prof_cb_.msprofReporterCallback == nullptr) { | |||||
| GELOGE(ge::PARAM_INVALID, "MsprofReporterCallback callback is nullptr."); | |||||
| return ge::PARAM_INVALID; | |||||
| } | |||||
| return prof_cb_.msprofReporterCallback( | |||||
| static_cast<uint32_t>(MsprofReporterModuleId::MSPROF_MODULE_FRAMEWORK), | |||||
| static_cast<uint32_t>(MsprofReporterCallbackType::MSPROF_REPORTER_REPORT), | |||||
| static_cast<void *>(&reporter_data), sizeof(ReporterData)); | |||||
| } | } | ||||
| int ProfilingEngineImpl::ReleasePlugin(Msprof::Engine::PluginIntf *plugin) { | |||||
| if (plugin != nullptr) { | |||||
| delete plugin; | |||||
| plugin = nullptr; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpPoint( | |||||
| std::string &fp_point, std::string &bp_point) { | |||||
| // Env or options mode, fp_point_/bp_point_ have initiliazed on profiling init | |||||
| if (!fp_point_.empty() && !bp_point_.empty()) { | |||||
| fp_point = fp_point_; | |||||
| bp_point = bp_point_; | |||||
| GELOGI("Bp Fp have been initialized in env or options. bp_point: %s, fp_point: %s", bp_point.c_str(), fp_point.c_str()); | |||||
| return; | |||||
| } | |||||
| // ProfApi mode and training trace is set | |||||
| try { | |||||
| char env_profiling_options[MSPROF_OPTIONS_DEF_LEN_MAX] = { 0x00 }; | |||||
| INT32 ret = mmGetEnv("PROFILING_OPTIONS", env_profiling_options, MSPROF_OPTIONS_DEF_LEN_MAX); | |||||
| if (ret != EN_OK) { | |||||
| GELOGI("PROFILING_OPTIONS env is not exist."); | |||||
| return; | |||||
| } | |||||
| GELOGI("Parse env PROFILING_OPTIONS:%s.", env_profiling_options); | |||||
| Json prof_options = Json::parse(env_profiling_options); | |||||
| fp_point_ = prof_options[kFpPoint]; | |||||
| bp_point_ = prof_options[kBpPoint]; | |||||
| fp_point = fp_point_; | |||||
| bp_point = bp_point_; | |||||
| if (!fp_point_.empty() && !bp_point_.empty()) { | |||||
| GELOGI("Training trace bp fp is set, bp_point:%s, fp_point:%s.", bp_point_.c_str(), fp_point_.c_str()); | |||||
| } | |||||
| } catch (...) { | |||||
| GELOGE(FAILED, "Json prof options is invalid."); | |||||
| return; | |||||
| } | } | ||||
| return 0; | |||||
| return; | |||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -26,9 +26,7 @@ | |||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "framework/common/ge_types.h" | #include "framework/common/ge_types.h" | ||||
| #include "external/register/register_types.h" | #include "external/register/register_types.h" | ||||
| #include "toolchain/prof_engine.h" | |||||
| #include "toolchain/prof_mgr_core.h" | |||||
| #include "toolchain/prof_acl_api.h" | |||||
| #include "toolchain/prof_callback.h" | |||||
| using std::map; | using std::map; | ||||
| using std::string; | using std::string; | ||||
| @@ -37,35 +35,33 @@ using Json = nlohmann::json; | |||||
| namespace { | namespace { | ||||
| const std::string GE_PROFILING_MODULE = "Framework"; | const std::string GE_PROFILING_MODULE = "Framework"; | ||||
| // DataTypeConfig MASK | |||||
| #define PROF_ACL_API_MASK 0x0001 | |||||
| #define PROF_TASK_TIME_MASK 0x0002 | |||||
| #define PROF_AICORE_METRICS_MASK 0x0004 | |||||
| #define PROF_AICPU_TRACE_MASK 0x0008 | |||||
| #define PROF_MODEL_EXECUTE_MASK 0x0010 | |||||
| #define PROF_RUNTIME_API_MASK 0x0020 | |||||
| #define PROF_RUNTIME_TRACE_MASK 0x0040 | |||||
| #define PROF_SCHEDULE_TIMELINE_MASK 0x0080 | |||||
| #define PROF_SCHEDULE_TRACE_MASK 0x0100 | |||||
| #define PROF_AIVECTORCORE_METRICS_MASK 0x0200 | |||||
| #define PROF_SUBTASK_TIME_MASK 0x0400 | |||||
| #define PROF_TRAINING_TRACE_MASK 0x0800 | |||||
| #define PROF_HCCL_TRACE_MASK 0x1000 | |||||
| #define PROF_DATA_PROCESS_MASK 0x2000 | |||||
| #define PROF_MODEL_LOAD_MASK 0x8000000000000000 | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| struct DeviceSubsInfo { | struct DeviceSubsInfo { | ||||
| uint64_t module; | uint64_t module; | ||||
| uint32_t subscribe_count; | uint32_t subscribe_count; | ||||
| }; | }; | ||||
| // register Plugin | |||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY PluginImpl : public Msprof::Engine::PluginIntf { | |||||
| public: | |||||
| explicit PluginImpl(const std::string &module); | |||||
| ~PluginImpl() {} | |||||
| int Init(const Msprof::Engine::Reporter *reporter); | |||||
| int UnInit(); | |||||
| static Msprof::Engine::Reporter *GetPluginReporter() { return reporter_; } | |||||
| private: | |||||
| static Msprof::Engine::Reporter *reporter_; | |||||
| std::string module_; | |||||
| }; | |||||
| // register Engine | |||||
| class ProfilingEngineImpl : public Msprof::Engine::EngineIntf { | |||||
| public: | |||||
| ProfilingEngineImpl() {} | |||||
| ~ProfilingEngineImpl() {} | |||||
| Msprof::Engine::PluginIntf *CreatePlugin(); | |||||
| int ReleasePlugin(Msprof::Engine::PluginIntf *plugin); | |||||
| struct MsprofCallback { | |||||
| MsprofCtrlCallback msprofCtrlCallback; | |||||
| MsprofReporterCallback msprofReporterCallback; | |||||
| }; | }; | ||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { | class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { | ||||
| @@ -73,68 +69,54 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { | |||||
| ProfilingManager(); | ProfilingManager(); | ||||
| virtual ~ProfilingManager(); | virtual ~ProfilingManager(); | ||||
| static ProfilingManager &Instance(); | static ProfilingManager &Instance(); | ||||
| ge::Status Init(const Options &options); | |||||
| ge::Status InitFromOptions(const Options &options); | |||||
| ge::Status InitFromAclCfg(const std::string &config); | |||||
| ge::Status StartProfiling(int32_t iter, int32_t device_id); | |||||
| void UpdateSubscribeDeviceModuleMap(std::string prof_type, uint32_t device_id, uint64_t module); | |||||
| ge::Status ProfModelSubscribe(uint64_t module, void *model); | |||||
| ge::Status ProfModelUnsubscribe(void *model); | |||||
| ge::Status ProfInit(uint64_t module); | |||||
| ge::Status ProfFinalize(); | |||||
| ge::Status ProfStartProfiling(uint64_t module, const std::map<std::string, std::string> &config_para); | |||||
| ge::Status ProfStopProfiling(uint64_t module, const std::map<std::string, std::string> &config_para); | |||||
| Status Init(const Options &options); | |||||
| Status ProfInit(uint64_t module); | |||||
| Status ProfFinalize(); | |||||
| Status ProfStartProfiling(uint64_t module, const std::map<std::string, std::string> &config_para); | |||||
| Status ProfStopProfiling(uint64_t module, const std::map<std::string, std::string> &config_para); | |||||
| Status ProfModelSubscribe(uint64_t module, void *model); | |||||
| Status ProfModelUnsubscribe(void *model); | |||||
| void StopProfiling(); | void StopProfiling(); | ||||
| bool ProfilingOpTraceOn() const { return is_op_trace_; } | |||||
| bool ProfilingLoadFlag() const { return is_load_; } | |||||
| bool ProfilingTrainingTraceOn() const { return is_training_trace_; } | bool ProfilingTrainingTraceOn() const { return is_training_trace_; } | ||||
| bool ProfilingModelLoadOn() const { return is_load_profiling_; } | bool ProfilingModelLoadOn() const { return is_load_profiling_; } | ||||
| bool ProfilingModelExecuteOn() const; | bool ProfilingModelExecuteOn() const; | ||||
| bool ProfilingOn() const { return is_load_profiling_ && is_execute_profiling_; } // only used by command pattern | |||||
| bool IsAclApiMode() const { return is_acl_api_mode_; } | |||||
| int32_t GetOpTraceIterNum() const { return op_trace_iter_num_; } | |||||
| bool ProfilingOn() const { return is_load_profiling_ && is_execute_profiling_; } // is_execute_profiling_ only used by ge option and env | |||||
| void ReportProfilingData(uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, | void ReportProfilingData(uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, | ||||
| const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, | |||||
| bool check_device); | |||||
| void Report(const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter, | |||||
| Msprof::Engine::ReporterData &reporter_data); | |||||
| const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info); | |||||
| void ProfilingTaskDescInfo(uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, | void ProfilingTaskDescInfo(uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, | ||||
| const int32_t &device_id); | const int32_t &device_id); | ||||
| void ProfilingGraphDescInfo(uint32_t model_id, const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, | void ProfilingGraphDescInfo(uint32_t model_id, const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, | ||||
| const int32_t &device_id); | const int32_t &device_id); | ||||
| void SetProfilingConfig(const string &profiling_cfg); | |||||
| vector<int32_t> GetProfilingDeviceId() const { return device_id_; } | |||||
| void PluginUnInit(const std::string &module) const; | |||||
| Status PluginInit() const; | |||||
| void PluginUnInit() const; | |||||
| Status CallMsprofReport(ReporterData &reporter_data) const; | |||||
| struct MsprofCallback &GetMsprofCallback() { return prof_cb_; } | |||||
| void SetMsprofCtrlCallback(MsprofCtrlCallback func) { prof_cb_.msprofCtrlCallback = func; } | |||||
| void SetMsprofReporterCallback(MsprofReporterCallback func) { prof_cb_.msprofReporterCallback = func; } | |||||
| void GetFpBpPoint(std::string &fp_point, std::string &bp_point); | |||||
| private: | private: | ||||
| ge::Status ParseFeaturesFromAclCfg(const Json &feature); | |||||
| ge::Status ProfParseParam(const std::map<std::string, std::string> &config_para, int32_t &device_num, | |||||
| vector<int32_t> &device_list); | |||||
| ge::Status ProfParseDeviceId(const std::map<std::string, std::string> &config_para, | |||||
| Status InitFromOptions(const Options &options, MsprofGeOptions &prof_conf); | |||||
| Status ParseOptions(const std::string &options); | |||||
| Status ProfParseParam(const std::map<std::string, std::string> &config_para, int32_t &device_num, | |||||
| vector<int32_t> &device_list); | |||||
| Status ProfParseDeviceId(const std::map<std::string, std::string> &config_para, | |||||
| vector<int32_t> &device_list); | vector<int32_t> &device_list); | ||||
| uint64_t GetProfilingModule(); | uint64_t GetProfilingModule(); | ||||
| void GraphDescReport(const int32_t &device_id, const string &data); | |||||
| void UpdateDeviceIdModuleMap(string prof_type, uint64_t module, const vector<int32_t> &device_list); | void UpdateDeviceIdModuleMap(string prof_type, uint64_t module, const vector<int32_t> &device_list); | ||||
| bool is_load_profiling_ = false; | |||||
| bool is_execute_profiling_ = false; | |||||
| bool is_op_trace_ = false; | |||||
| bool is_load_ = false; | |||||
| bool is_training_trace_ = false; | |||||
| bool is_acl_api_mode_ = false; | |||||
| int32_t op_trace_iter_num_ = 0; | |||||
| string job_id_; | |||||
| string prof_dir_; | |||||
| void UpdateSubscribeDeviceModuleMap(std::string prof_type, uint32_t device_id, uint64_t module); | |||||
| bool is_load_profiling_; | |||||
| bool is_execute_profiling_; | |||||
| bool is_training_trace_; | |||||
| vector<int32_t> device_id_; | vector<int32_t> device_id_; | ||||
| vector<string> op_trace_conf_; | |||||
| vector<string> profiling_opts_; | |||||
| vector<void *> prof_handle_vec_; | |||||
| string recv_profiling_config_; | |||||
| string send_profiling_config_; | |||||
| string system_trace_conf_; | |||||
| string task_trace_conf_; | |||||
| const ProfilingEngineImpl engine_; | |||||
| map<int32_t, uint64_t> device_id_module_map_; // key: device_id, value: profiling on module | map<int32_t, uint64_t> device_id_module_map_; // key: device_id, value: profiling on module | ||||
| map<uint32_t, DeviceSubsInfo> subs_dev_module_; // key: device_id, value: profiling on module | map<uint32_t, DeviceSubsInfo> subs_dev_module_; // key: device_id, value: profiling on module | ||||
| uint32_t subscribe_count_; | uint32_t subscribe_count_; | ||||
| std::mutex mutex_; | std::mutex mutex_; | ||||
| MsprofCallback prof_cb_; | |||||
| std::string fp_point_; | |||||
| std::string bp_point_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_COMMON_PROFILING_PROFILING_MANAGER_H_ | #endif // GE_COMMON_PROFILING_PROFILING_MANAGER_H_ | ||||
| @@ -801,7 +801,7 @@ const uint32_t XRGB_CHN_NUM = 4; | |||||
| /// | /// | ||||
| const bool DEFAULT_GLOBAL_POOLING = false; | const bool DEFAULT_GLOBAL_POOLING = false; | ||||
| const uint32_t MODEL_VERSION = 0x10000000; ///< Model version 1.0/// | |||||
| const uint32_t MODEL_VERSION = 0x20000000; ///< Model version 2.0/// | |||||
| // Eltwise's input size | // Eltwise's input size | ||||
| const int ELTWISE_MIN_INPUT_SIZE = 2; | const int ELTWISE_MIN_INPUT_SIZE = 2; | ||||
| @@ -51,14 +51,15 @@ namespace { | |||||
| * If such an exception is encountered during operation, | * If such an exception is encountered during operation, | ||||
| * the proto file can be divided into several small files or the limit value can be increased. | * the proto file can be divided into several small files or the limit value can be increased. | ||||
| */ | */ | ||||
| const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. | |||||
| const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M | |||||
| const int kFileSizeOutLimitedOrOpenFailed = -1; | |||||
| const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. | |||||
| const int kWarningThreshold = 1073741824; // 536870912 * 2 536870912 represent 512M | |||||
| /// The maximum length of the file. | /// The maximum length of the file. | ||||
| const uint32_t kMaxFileSizeLimit = UINT32_MAX; // 4G for now | |||||
| const uint32_t kMaxFileSizeLimit = UINT32_MAX; // 4G for now | |||||
| const int kMaxBuffSize = 256; | const int kMaxBuffSize = 256; | ||||
| const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character"; | const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character"; | ||||
| constexpr uint32_t kMaxConfigFileByte = 10 * 1024 * 1024; | |||||
| constexpr uint32_t kMaxConfigFileByte = 10485760; // 10 * 1024 * 1024 | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| @@ -76,7 +77,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co | |||||
| std::string real_path = RealPath(file); | std::string real_path = RealPath(file); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "pb file path '%s' not valid", file); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "pb file path '%s' not valid", file); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == kFileSizeOutLimitedOrOpenFailed, return false, | |||||
| "file size not valid."); | |||||
| std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); | std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); | ||||
| if (!fs.is_open()) { | if (!fs.is_open()) { | ||||
| @@ -118,20 +120,20 @@ long GetFileLength(const std::string &input_file) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); | ||||
| unsigned long long file_length = 0; | unsigned long long file_length = 0; | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
| mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, strerror(errno)}); | |||||
| return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno)); | |||||
| mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, strerror(errno)}); | |||||
| return kFileSizeOutLimitedOrOpenFailed, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno)); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file}); | ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file}); | ||||
| return -1, "File[%s] size is 0, not valid.", input_file.c_str()); | return -1, "File[%s] size is 0, not valid.", input_file.c_str()); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > kMaxFileSizeLimit, | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19016", {"filepath", "filesize", "maxlen"}, | |||||
| {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); | |||||
| return -1, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length, | |||||
| kMaxFileSizeLimit); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| file_length > kMaxFileSizeLimit, ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19016", {"filepath", "filesize", "maxlen"}, | |||||
| {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); | |||||
| return kFileSizeOutLimitedOrOpenFailed, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length, | |||||
| kMaxFileSizeLimit); | |||||
| return static_cast<long>(file_length); | return static_cast<long>(file_length); | ||||
| } | } | ||||
| @@ -187,7 +189,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co | |||||
| std::streamsize size = file.tellg(); | std::streamsize size = file.tellg(); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((size <= 0), file.close(); return false, "file length <= 0, not valid."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((size <= 0), file.close(); return false, "file length <= 0, not valid."); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size > static_cast<int64_t >(kMaxFileSizeLimit), file.close(); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size > static_cast<int64_t>(kMaxFileSizeLimit), file.close(); | |||||
| return false, "file size %ld is out of limit: %d.", size, kMaxFileSizeLimit); | return false, "file size %ld is out of limit: %d.", size, kMaxFileSizeLimit); | ||||
| file.seekg(0, std::ios::beg); // [no need to check value] | file.seekg(0, std::ios::beg); // [no need to check value] | ||||
| @@ -210,8 +212,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
| GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); | GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); | ||||
| auto dir_path_len = directory_path.length(); | auto dir_path_len = directory_path.length(); | ||||
| if (dir_path_len >= MMPA_MAX_PATH) { | if (dir_path_len >= MMPA_MAX_PATH) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19002", {"filepath", "size"}, {directory_path, std::to_string(MMPA_MAX_PATH)}); | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, | |||||
| {directory_path, std::to_string(MMPA_MAX_PATH)}); | |||||
| GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), MMPA_MAX_PATH); | GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), MMPA_MAX_PATH); | ||||
| return -1; | return -1; | ||||
| } | } | ||||
| @@ -224,8 +226,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| if (errno != EEXIST) { | if (errno != EEXIST) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); | ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); | ||||
| GELOGW("Can not create directory %s. Make sure the directory exists and writable.", | |||||
| directory_path.c_str()); | |||||
| GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| } | } | ||||
| @@ -265,7 +266,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch | |||||
| std::string real_path = RealPath(file); | std::string real_path = RealPath(file); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), ErrorManager::GetInstance().ATCReportErrMessage( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), ErrorManager::GetInstance().ATCReportErrMessage( | ||||
| "E19000", {"path", "errmsg"}, {file, strerror(errno)}); | |||||
| "E19000", {"path", "errmsg"}, {file, strerror(errno)}); | |||||
| return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, strerror(errno)); | return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, strerror(errno)); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); | ||||
| @@ -301,13 +302,13 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const cha | |||||
| google::protobuf::io::IstreamInputStream input(&fs); | google::protobuf::io::IstreamInputStream input(&fs); | ||||
| bool ret = google::protobuf::TextFormat::Parse(&input, message); | bool ret = google::protobuf::TextFormat::Parse(&input, message); | ||||
| GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
| !ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file.")); | |||||
| !ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file.")); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() { | ||||
| mmTimeval tv {}; | |||||
| mmTimeval tv{}; | |||||
| int ret = mmGetTimeOfDay(&tv, nullptr); | int ret = mmGetTimeOfDay(&tv, nullptr); | ||||
| GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret); | GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret); | ||||
| auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds | auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds | ||||
| @@ -315,7 +316,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() | |||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t GetCurrentSecondTimestap() { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t GetCurrentSecondTimestap() { | ||||
| mmTimeval tv {}; | |||||
| mmTimeval tv{}; | |||||
| int ret = mmGetTimeOfDay(&tv, nullptr); | int ret = mmGetTimeOfDay(&tv, nullptr); | ||||
| GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret); | GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret); | ||||
| auto total_use_time = tv.tv_sec; // seconds | auto total_use_time = tv.tv_sec; // seconds | ||||
| @@ -350,8 +351,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInt64MulOverflow(int6 | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL."); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= MMPA_MAX_PATH, | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= MMPA_MAX_PATH, | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(MMPA_MAX_PATH)}); | |||||
| return "", "Path[%s] len is too long, it must be less than %d", path, MMPA_MAX_PATH); | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, | |||||
| {path, std::to_string(MMPA_MAX_PATH)}); | |||||
| return "", "Path[%s] len is too long, it must be less than %d", path, MMPA_MAX_PATH); | |||||
| // Nullptr is returned when the path does not exist or there is no permission | // Nullptr is returned when the path does not exist or there is no permission | ||||
| // Return absolute path when path is accessible | // Return absolute path when path is accessible | ||||
| @@ -385,16 +387,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const | |||||
| // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores | // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores | ||||
| // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) | // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) | ||||
| #ifdef __GNUC__ | #ifdef __GNUC__ | ||||
| std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$"; | |||||
| std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$"; | |||||
| #else | #else | ||||
| std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$"; | |||||
| std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$"; | |||||
| #endif | #endif | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
| !ValidateStr(real_path, mode), | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||||
| {atc_param, real_path, kPathValidReason}); | |||||
| return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason); | |||||
| !ValidateStr(real_path, mode), | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||||
| {atc_param, real_path, kPathValidReason}); | |||||
| return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason); | |||||
| // The absolute path points to a file that is not readable | // The absolute path points to a file that is not readable | ||||
| if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) { | if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) { | ||||
| @@ -416,24 +418,25 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const | |||||
| } | } | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path.c_str()) >= MMPA_MAX_PATH, | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path.c_str()) >= MMPA_MAX_PATH, | ||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19002", {"filepath", "size"}, {file_path, std::to_string(MMPA_MAX_PATH)}); | |||||
| return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(), MMPA_MAX_PATH); | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19002", {"filepath", "size"}, {file_path, std::to_string(MMPA_MAX_PATH)}); | |||||
| return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(), | |||||
| MMPA_MAX_PATH); | |||||
| // A regular matching expression to verify the validity of the input file path | // A regular matching expression to verify the validity of the input file path | ||||
| // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores | // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores | ||||
| // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) | // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) | ||||
| #ifdef __GNUC__ | #ifdef __GNUC__ | ||||
| std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$"; | |||||
| std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$"; | |||||
| #else | #else | ||||
| std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$"; | |||||
| std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$"; | |||||
| #endif | #endif | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
| !ValidateStr(file_path, mode), | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||||
| {atc_param, file_path, kPathValidReason}); | |||||
| return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason); | |||||
| !ValidateStr(file_path, mode), | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||||
| {atc_param, file_path, kPathValidReason}); | |||||
| return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason); | |||||
| std::string real_path = RealPath(file_path.c_str()); | std::string real_path = RealPath(file_path.c_str()); | ||||
| // Can get absolute path (file exists) | // Can get absolute path (file exists) | ||||
| @@ -17,6 +17,7 @@ set(SRC_LIST | |||||
| "../common/dump/dump_properties.cc" | "../common/dump/dump_properties.cc" | ||||
| "../common/dump/dump_manager.cc" | "../common/dump/dump_manager.cc" | ||||
| "../common/dump/dump_op.cc" | "../common/dump/dump_op.cc" | ||||
| "../common/profiling/ge_profiling.cc" | |||||
| "../graph/load/graph_loader.cc" | "../graph/load/graph_loader.cc" | ||||
| "../graph/execute/graph_execute.cc" | "../graph/execute/graph_execute.cc" | ||||
| "../omm/csa_interact.cc" | "../omm/csa_interact.cc" | ||||
| @@ -244,7 +245,6 @@ target_link_libraries(ge_executor_shared PRIVATE | |||||
| mmpa | mmpa | ||||
| graph | graph | ||||
| register | register | ||||
| msprof | |||||
| error_manager | error_manager | ||||
| ascend_hal_stub | ascend_hal_stub | ||||
| ascend_protobuf | ascend_protobuf | ||||
| @@ -283,7 +283,8 @@ Status GeExecutor::Initialize() { | |||||
| // Start profiling | // Start profiling | ||||
| Options profiling_options; | Options profiling_options; | ||||
| profiling_options.device_id = 0; | profiling_options.device_id = 0; | ||||
| profiling_options.job_id = ""; | |||||
| // job id need to be set, the value is meaningless; | |||||
| profiling_options.job_id = "1"; | |||||
| ProfilingManager::Instance().Init(profiling_options); | ProfilingManager::Instance().Init(profiling_options); | ||||
| isInit_ = true; | isInit_ = true; | ||||
| @@ -303,7 +304,7 @@ Status GeExecutor::Finalize() { | |||||
| // Stop profiling | // Stop profiling | ||||
| if (ProfilingManager::Instance().ProfilingOn()) { | if (ProfilingManager::Instance().ProfilingOn()) { | ||||
| ProfilingManager::Instance().StopProfiling(); | ProfilingManager::Instance().StopProfiling(); | ||||
| ProfilingManager::Instance().PluginUnInit(GE_PROFILING_MODULE); | |||||
| ProfilingManager::Instance().PluginUnInit(); | |||||
| } | } | ||||
| GELOGI("Uninit GeExecutor over."); | GELOGI("Uninit GeExecutor over."); | ||||
| @@ -638,7 +639,8 @@ Status GeExecutor::UnloadModel(uint32_t model_id) { | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR; | return ACL_ERROR_GE_INTERNAL_ERROR; | ||||
| } | } | ||||
| std::shared_ptr<hybrid::HybridDavinciModel> hybrid_davinci_model = ModelManager::GetInstance()->GetHybridModel(model_id); | |||||
| std::shared_ptr<hybrid::HybridDavinciModel> hybrid_davinci_model = | |||||
| ModelManager::GetInstance()->GetHybridModel(model_id); | |||||
| if (hybrid_davinci_model != nullptr) { | if (hybrid_davinci_model != nullptr) { | ||||
| uint64_t session_id = hybrid_davinci_model->GetSessionId(); | uint64_t session_id = hybrid_davinci_model->GetSessionId(); | ||||
| VarManagerPool::Instance().RemoveVarManager(session_id); | VarManagerPool::Instance().RemoveVarManager(session_id); | ||||
| @@ -8,6 +8,7 @@ local_ge_executor_src_files := \ | |||||
| ../common/dump/dump_op.cc \ | ../common/dump/dump_op.cc \ | ||||
| ../common/ge/plugin_manager.cc \ | ../common/ge/plugin_manager.cc \ | ||||
| ../common/ge/op_tiling_manager.cc \ | ../common/ge/op_tiling_manager.cc \ | ||||
| ../common/profiling/ge_profiling.cc \ | |||||
| ../graph/load/graph_loader.cc \ | ../graph/load/graph_loader.cc \ | ||||
| ../graph/execute/graph_execute.cc \ | ../graph/execute/graph_execute.cc \ | ||||
| ../omm/csa_interact.cc \ | ../omm/csa_interact.cc \ | ||||
| @@ -177,7 +178,6 @@ local_ge_executor_shared_library := \ | |||||
| libmmpa \ | libmmpa \ | ||||
| libgraph \ | libgraph \ | ||||
| libregister \ | libregister \ | ||||
| libmsprof \ | |||||
| liberror_manager \ | liberror_manager \ | ||||
| local_ge_executor_ldflags := -lrt -ldl \ | local_ge_executor_ldflags := -lrt -ldl \ | ||||
| @@ -234,7 +234,6 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libmmpa \ | libmmpa \ | ||||
| libgraph \ | libgraph \ | ||||
| libregister \ | libregister \ | ||||
| libmsprof \ | |||||
| liberror_manager \ | liberror_manager \ | ||||
| stub/libascend_hal \ | stub/libascend_hal \ | ||||
| @@ -272,7 +271,6 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libruntime \ | libruntime \ | ||||
| libslog \ | libslog \ | ||||
| libmmpa \ | libmmpa \ | ||||
| libmsprof \ | |||||
| LOCAL_LDFLAGS += $(local_ge_executor_ldflags) | LOCAL_LDFLAGS += $(local_ge_executor_ldflags) | ||||
| @@ -304,7 +302,6 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libruntime \ | libruntime \ | ||||
| libslog \ | libslog \ | ||||
| libmmpa \ | libmmpa \ | ||||
| libmsprof \ | |||||
| ifeq ($(device_os),android) | ifeq ($(device_os),android) | ||||
| LOCAL_LDFLAGS += -ldl | LOCAL_LDFLAGS += -ldl | ||||
| @@ -164,6 +164,7 @@ OMG_HOST_SRC_FILES := \ | |||||
| host_kernels/slice_d_kernel.cc \ | host_kernels/slice_d_kernel.cc \ | ||||
| host_kernels/dynamic_stitch_kernel.cc \ | host_kernels/dynamic_stitch_kernel.cc \ | ||||
| host_kernels/identity_kernel.cc \ | host_kernels/identity_kernel.cc \ | ||||
| host_kernels/reformat_kernel.cc \ | |||||
| graph/passes/stop_gradient_pass.cc \ | graph/passes/stop_gradient_pass.cc \ | ||||
| graph/passes/prevent_gradient_pass.cc \ | graph/passes/prevent_gradient_pass.cc \ | ||||
| graph/passes/identity_pass.cc \ | graph/passes/identity_pass.cc \ | ||||
| @@ -14,7 +14,6 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "host_cpu_engine.h" | #include "host_cpu_engine.h" | ||||
| #include <dlfcn.h> | |||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
| #include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
| @@ -96,8 +95,8 @@ Status GetDataNumber(const GeTensorDesc &out_desc, uint64_t &data_num) { | |||||
| void HostCpuEngine::CloseSo() { | void HostCpuEngine::CloseSo() { | ||||
| for (auto handle : lib_handles_) { | for (auto handle : lib_handles_) { | ||||
| if (dlclose(handle) != 0) { | |||||
| GELOGW("failed to close handle, message: %s", dlerror()); | |||||
| if (mmDlclose(handle) != 0) { | |||||
| GELOGW("failed to close handle, message: %s", mmDlerror()); | |||||
| } | } | ||||
| } | } | ||||
| lib_handles_.clear(); | lib_handles_.clear(); | ||||
| @@ -323,13 +322,13 @@ Status HostCpuEngine::LoadLibs(std::vector<std::string> &lib_paths) { | |||||
| Status HostCpuEngine::LoadLib(const std::string &lib_path) { | Status HostCpuEngine::LoadLib(const std::string &lib_path) { | ||||
| GELOGI("To invoke dlopen on lib: %s", lib_path.c_str()); | GELOGI("To invoke dlopen on lib: %s", lib_path.c_str()); | ||||
| auto handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||||
| auto handle = mmDlopen(lib_path.c_str(), MMPA_RTLD_NOW | MMPA_RTLD_GLOBAL); | |||||
| if (handle == nullptr) { | if (handle == nullptr) { | ||||
| GELOGE(INTERNAL_ERROR, "Failed to invoke dlopen. path = %s, error = %s", lib_path.c_str(), dlerror()); | |||||
| GELOGE(INTERNAL_ERROR, "Failed to invoke dlopen. path = %s, error = %s", lib_path.c_str(), mmDlerror()); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| auto initialize = (Status (*)(const HostCpuContext &))dlsym(handle, "Initialize"); | |||||
| auto initialize = (Status (*)(const HostCpuContext &))mmDlsym(handle, "Initialize"); | |||||
| if (initialize != nullptr) { | if (initialize != nullptr) { | ||||
| GELOGI("Invoke function Initialize in lib: %s", lib_path.c_str()); | GELOGI("Invoke function Initialize in lib: %s", lib_path.c_str()); | ||||
| if (initialize(HostCpuContext()) != SUCCESS) { | if (initialize(HostCpuContext()) != SUCCESS) { | ||||
| @@ -29,6 +29,8 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| common/dump/dump_manager.cc \ | common/dump/dump_manager.cc \ | ||||
| common/dump/dump_properties.cc \ | common/dump/dump_properties.cc \ | ||||
| common/dump/dump_op.cc \ | common/dump/dump_op.cc \ | ||||
| common/profiling/ge_profiling.cc \ | |||||
| common/profiling/ge_runner_profiling.cc \ | |||||
| engine_manager/dnnengine_manager.cc \ | engine_manager/dnnengine_manager.cc \ | ||||
| ge_local_engine/engine/host_cpu_engine.cc \ | ge_local_engine/engine/host_cpu_engine.cc \ | ||||
| generator/ge_generator.cc \ | generator/ge_generator.cc \ | ||||
| @@ -170,6 +172,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| host_kernels/sub_kernel.cc \ | host_kernels/sub_kernel.cc \ | ||||
| host_kernels/transdata_kernel.cc \ | host_kernels/transdata_kernel.cc \ | ||||
| host_kernels/unpack_kernel.cc \ | host_kernels/unpack_kernel.cc \ | ||||
| host_kernels/reformat_kernel.cc \ | |||||
| graph/passes/folding_pass.cc \ | graph/passes/folding_pass.cc \ | ||||
| graph/passes/get_original_format_pass.cc \ | graph/passes/get_original_format_pass.cc \ | ||||
| graph/passes/guarantee_const_pass.cc \ | graph/passes/guarantee_const_pass.cc \ | ||||
| @@ -306,7 +309,6 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| LIBCLIENT_LOCAL_SRC_FILES := \ | LIBCLIENT_LOCAL_SRC_FILES := \ | ||||
| proto/ge_api.proto \ | proto/ge_api.proto \ | ||||
| client/ge_api.cc \ | client/ge_api.cc \ | ||||
| client/ge_prof.cc \ | |||||
| RUNNER_LOCAL_C_INCLUDES := \ | RUNNER_LOCAL_C_INCLUDES := \ | ||||
| $(LOCAL_PATH) ./ \ | $(LOCAL_PATH) ./ \ | ||||
| @@ -371,7 +373,7 @@ LOCAL_SRC_FILES += $(LIBCLIENT_LOCAL_SRC_FILES) | |||||
| LOCAL_STATIC_LIBRARIES := libge_memory \ | LOCAL_STATIC_LIBRARIES := libge_memory \ | ||||
| libadump_server \ | libadump_server \ | ||||
| libmsprofiler \ | |||||
| libmsprofiler_fwk \ | |||||
| libmmpa \ | libmmpa \ | ||||
| LOCAL_SHARED_LIBRARIES := \ | LOCAL_SHARED_LIBRARIES := \ | ||||
| @@ -381,7 +383,6 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libgraph \ | libgraph \ | ||||
| libregister \ | libregister \ | ||||
| libge_common \ | libge_common \ | ||||
| libmsprof \ | |||||
| liberror_manager \ | liberror_manager \ | ||||
| LOCAL_LDFLAGS := -lrt -ldl | LOCAL_LDFLAGS := -lrt -ldl | ||||
| @@ -408,7 +409,6 @@ endif | |||||
| LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES) | LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES) | ||||
| LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_api.cc \ | LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_api.cc \ | ||||
| ../../out/ge/lib64/stub/ge_prof.cc \ | |||||
| ../../out/ge/lib64/stub/ge_ir_build.cc \ | ../../out/ge/lib64/stub/ge_ir_build.cc \ | ||||
| LOCAL_SHARED_LIBRARIES := | LOCAL_SHARED_LIBRARIES := | ||||
| @@ -464,7 +464,6 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libc_sec \ | libc_sec \ | ||||
| libslog \ | libslog \ | ||||
| libmmpa \ | libmmpa \ | ||||
| libmsprof \ | |||||
| LOCAL_LDFLAGS := -lrt -ldl | LOCAL_LDFLAGS := -lrt -ldl | ||||
| @@ -497,7 +496,6 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libc_sec \ | libc_sec \ | ||||
| libslog \ | libslog \ | ||||
| libmmpa \ | libmmpa \ | ||||
| libmsprof \ | |||||
| LOCAL_LDFLAGS := -lrt -ldl | LOCAL_LDFLAGS := -lrt -ldl | ||||
| @@ -28,6 +28,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace model_runner { | namespace model_runner { | ||||
| const int kOffsetUnit = 8; | |||||
| RuntimeModel::~RuntimeModel() { | RuntimeModel::~RuntimeModel() { | ||||
| GELOGI("RuntimeModel destructor start"); | GELOGI("RuntimeModel destructor start"); | ||||
| @@ -495,7 +496,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model | |||||
| return false; | return false; | ||||
| } | } | ||||
| uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data())); | uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data())); | ||||
| int64_t offset = elem_num * 8; | |||||
| int64_t offset = elem_num * kOffsetUnit; | |||||
| uintptr_t hbm_raw_data_base_addr = reinterpret_cast<uintptr_t>(constant->output_addrs[0]) + offset; | uintptr_t hbm_raw_data_base_addr = reinterpret_cast<uintptr_t>(constant->output_addrs[0]) + offset; | ||||
| for (int64_t i = elem_num - 1; i >= 0; --i) { | for (int64_t i = elem_num - 1; i >= 0; --i) { | ||||
| buff[i] = hbm_raw_data_base_addr + (buff[i] - buff[0]); | buff[i] = hbm_raw_data_base_addr + (buff[i] - buff[0]); | ||||
| @@ -156,7 +156,12 @@ static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, GeTen | |||||
| } | } | ||||
| string op_type; | string op_type; | ||||
| if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) { | |||||
| bool is_const = false; | |||||
| (void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const); | |||||
| if (is_const) { | |||||
| GELOGD("Get input[%d] is const", index); | |||||
| op_type = CONSTANTOP; | |||||
| } else if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) { | |||||
| op_type = DATA; | op_type = DATA; | ||||
| } | } | ||||
| @@ -165,6 +170,18 @@ static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, GeTen | |||||
| if (data_op == nullptr) { | if (data_op == nullptr) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (is_const) { | |||||
| ConstGeTensorPtr tensor_value; | |||||
| if (!AttrUtils::GetTensor(tensor, ge::ATTR_NAME_WEIGHTS, tensor_value)) { | |||||
| GELOGE(FAILED, "Get value failed, node name:%s.", tensor.GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (!AttrUtils::SetTensor(data_op, ge::ATTR_NAME_WEIGHTS, tensor_value)) { | |||||
| GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS fail."); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| (void)AttrUtils::SetBool(data_op, "_is_single_op", true); | (void)AttrUtils::SetBool(data_op, "_is_single_op", true); | ||||
| GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add input desc fail."); | GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add input desc fail."); | ||||
| @@ -240,6 +257,8 @@ class GeGenerator::Impl { | |||||
| Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model); | Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model); | ||||
| Status SaveRootModel(const string &file_name_prefix, GeRootModelPtr &model, ModelBufferData &model_buff); | |||||
| Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs, | Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs, | ||||
| const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); | const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); | ||||
| @@ -505,19 +524,7 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
| GE_CHECK_NOTNULL(ge_root_model); | GE_CHECK_NOTNULL(ge_root_model); | ||||
| GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); | GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); | ||||
| ModelHelper model_helper; | |||||
| string model_name = ""; | |||||
| Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), model_name); | |||||
| if (name_ret != SUCCESS) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); | |||||
| GELOGE(FAILED, "Get model_name failed. Param --output is invalid"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); | |||||
| GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; | |||||
| GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model can not be null"); | |||||
| ge_model->SetName(model_name); | |||||
| ret = impl_->SaveModel(file_name_prefix, ge_model, model); | |||||
| ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Save model failed"); | GELOGE(ret, "Save model failed"); | ||||
| if (impl_->graph_manager_.Finalize() != SUCCESS) { | if (impl_->graph_manager_.Finalize() != SUCCESS) { | ||||
| @@ -567,6 +574,9 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> | |||||
| Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, | Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, | ||||
| const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, | const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, | ||||
| bool is_offline) { | bool is_offline) { | ||||
| if (!is_offline) { | |||||
| (void)AttrUtils::SetBool(op_desc, ATTR_DYNAMIC_SHAPE_SINGLE_AICPU, true); | |||||
| } | |||||
| if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) { | if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "input param is invalid when build single op!"); | GELOGE(PARAM_INVALID, "input param is invalid when build single op!"); | ||||
| @@ -594,40 +604,11 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in | |||||
| // 2. Create ComputeGraph. | // 2. Create ComputeGraph. | ||||
| string name = ge::CurrentTimeInStr() + "_" + model_file_name; | string name = ge::CurrentTimeInStr() + "_" + model_file_name; | ||||
| ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(name); | |||||
| GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR); | |||||
| // 3. Add Node to ComputeGraph. | |||||
| NodePtr op_node = compute_graph->AddNode(op_desc); | |||||
| GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR); | |||||
| // 4. Create InputData node. | |||||
| int32_t arg_index = 0; | |||||
| if (inputs.empty()) { | |||||
| for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | |||||
| GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR); | |||||
| if (!IsNeedConnectInputOpForSingleOp(*input_desc)) { | |||||
| continue; | |||||
| } | |||||
| GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false)); | |||||
| arg_index++; | |||||
| } | |||||
| } else { | |||||
| for (const auto &in_desc : inputs) { | |||||
| GeTensorDesc input_desc = in_desc.GetTensorDesc(); | |||||
| GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true)); | |||||
| arg_index++; | |||||
| } | |||||
| Graph graph; | |||||
| if (BuildSingleOpGraph(op_desc, inputs, outputs, name, graph) != ge::SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "make graph fail."); | |||||
| return GRAPH_FAILED; | |||||
| } | } | ||||
| // 5. Create Output node. | |||||
| if (!outputs.empty()) { | |||||
| GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs)); | |||||
| } | |||||
| // dump ComputeGraph. | |||||
| compute_graph->Dump(); | |||||
| Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||||
| GELOGI("ATC parser success in single op build."); | GELOGI("ATC parser success in single op build."); | ||||
| GeRootModelPtr ge_root_model = nullptr; | GeRootModelPtr ge_root_model = nullptr; | ||||
| @@ -683,6 +664,46 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor | |||||
| return BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false); | return BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false); | ||||
| } | } | ||||
| Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vector<GeTensor> &inputs, | |||||
| const vector<GeTensor> &outputs, std::string graph_name, Graph &graph) { | |||||
| ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(graph_name); | |||||
| GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR); | |||||
| // 1. Add Node to ComputeGraph. | |||||
| NodePtr op_node = compute_graph->AddNode(op_desc); | |||||
| GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR); | |||||
| // 2. Create InputData node. | |||||
| int32_t arg_index = 0; | |||||
| if (inputs.empty()) { | |||||
| for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | |||||
| GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR); | |||||
| if (!IsNeedConnectInputOpForSingleOp(*input_desc)) { | |||||
| continue; | |||||
| } | |||||
| GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false)); | |||||
| arg_index++; | |||||
| } | |||||
| } else { | |||||
| for (const auto &in_desc : inputs) { | |||||
| GeTensorDesc input_desc = in_desc.GetTensorDesc(); | |||||
| GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true)); | |||||
| arg_index++; | |||||
| } | |||||
| } | |||||
| // 3. Create Output node. | |||||
| if (!outputs.empty()) { | |||||
| GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs)); | |||||
| } | |||||
| // dump ComputeGraph node. | |||||
| compute_graph->Dump(); | |||||
| graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs, | Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs, | ||||
| const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) { | const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) { | ||||
| GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID); | GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID); | ||||
| @@ -712,6 +733,44 @@ Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr & | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootModelPtr &ge_root_model, | |||||
| ModelBufferData &model_buff) { | |||||
| bool is_unknown_shape = false; | |||||
| auto ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "Check root model is unkonwn shape failed"); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("begin save root model, cur model is unkonwn shape model ? : %d", is_unknown_shape); | |||||
| GE_CHK_BOOL_EXEC(!ge_root_model->GetSubgraphInstanceNameToModel().empty(), return FAILED, | |||||
| "ge root model has no sub model") | |||||
| GeModelPtr model_root = nullptr; | |||||
| if (is_unknown_shape) { | |||||
| model_root = make_shared<GeModel>(); | |||||
| model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph())); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root); | |||||
| model_root->SetName(ge_root_model->GetRootGraph()->GetName()); | |||||
| } else { | |||||
| model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second; | |||||
| } | |||||
| // set atc version | |||||
| if (!SetAtcVersionInfo(*(model_root.get()))) { | |||||
| GELOGW("SetPackageVersionInfo of atc failed!"); | |||||
| } | |||||
| // set opp version | |||||
| if (!SetOppVersionInfo(*(model_root.get()))) { | |||||
| GELOGW("SetPackageVersionInfo of ops failed!"); | |||||
| } | |||||
| ModelHelper model_helper; | |||||
| model_helper.SetSaveMode(is_offline_); | |||||
| ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Save to om model failed"); | |||||
| return ret; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs, | Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs, | ||||
| GeRootModelPtr &ge_root_model) { | GeRootModelPtr &ge_root_model) { | ||||
| static std::atomic<GraphId> atomic_graph_id(0); | static std::atomic<GraphId> atomic_graph_id(0); | ||||
| @@ -349,7 +349,8 @@ static Status GenerateTaskForConstant(const std::shared_ptr<ComputeGraph> &graph | |||||
| GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); | GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); | ||||
| std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; | std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; | ||||
| if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { | if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { | ||||
| GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str()); | |||||
| GELOGE(FAILED, "Insert memcpy between %s and %s failed.", | |||||
| in_node->GetName().c_str(), node->GetName().c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } | } | ||||
| @@ -475,7 +476,7 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr | |||||
| } | } | ||||
| Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { | Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { | ||||
| // set input_desc.size = src_node.output_desc.size | |||||
| // Set the size of input_desc to 'src_node.output_desc.size' | |||||
| if (node_ptr->GetType() == DATA) { | if (node_ptr->GetType() == DATA) { | ||||
| bool is_unknown_shape = false; | bool is_unknown_shape = false; | ||||
| GE_CHK_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node_ptr, is_unknown_shape), | GE_CHK_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node_ptr, is_unknown_shape), | ||||
| @@ -498,7 +499,7 @@ Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { | |||||
| GE_IF_BOOL_EXEC(src_op == nullptr, continue); | GE_IF_BOOL_EXEC(src_op == nullptr, continue); | ||||
| auto node_op_desc = node_ptr->GetOpDesc(); | auto node_op_desc = node_ptr->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | ||||
| // set dst_node.input_desc = src_node.output_desc | |||||
| // Set the input_desc of dst_node to 'src_node.output_desc' | |||||
| auto output_desc = src_op->GetOutputDescPtr(peer_out_anchor->GetIdx()); | auto output_desc = src_op->GetOutputDescPtr(peer_out_anchor->GetIdx()); | ||||
| int64_t size = 0; | int64_t size = 0; | ||||
| GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_desc, size) != SUCCESS, GELOGI("Get size failed!")); | GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_desc, size) != SUCCESS, GELOGI("Get size failed!")); | ||||
| @@ -512,7 +513,6 @@ Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { | |||||
| auto input_desc = node_op_desc->MutableInputDesc(in_data_anchor->GetIdx()); | auto input_desc = node_op_desc->MutableInputDesc(in_data_anchor->GetIdx()); | ||||
| GE_CHECK_NOTNULL(input_desc); | GE_CHECK_NOTNULL(input_desc); | ||||
| (void) ge::TensorUtils::SetSize(*input_desc, size); | (void) ge::TensorUtils::SetSize(*input_desc, size); | ||||
| GE_CHK_STATUS_RET(node_op_desc->UpdateInputDesc(in_data_anchor->GetIdx(), *input_desc)); | |||||
| GELOGD("%s input desc, dim_size: %zu, mem_size: %ld, format: %s, type: %s.", node_ptr->GetName().c_str(), | GELOGD("%s input desc, dim_size: %zu, mem_size: %ld, format: %s, type: %s.", node_ptr->GetName().c_str(), | ||||
| input_desc->GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(), | input_desc->GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(), | ||||
| TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str()); | TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str()); | ||||
| @@ -21,8 +21,8 @@ | |||||
| namespace { | namespace { | ||||
| const uint32_t kRangeCeilInterval = 2; | const uint32_t kRangeCeilInterval = 2; | ||||
| const uint32_t kLogBase = 2; | const uint32_t kLogBase = 2; | ||||
| const int64_t kLargeBlockSize = 8 * 1024 * 1024; | |||||
| const int64_t kLargeBlockRangeSize = 10; | |||||
| const int64_t kLargeBlockSize = 8388608; // 8 * 1024 * 1024 | |||||
| const int64_t kLargeBlockRangeSize = 2; | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| @@ -73,15 +73,17 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector<int64_t> &range_ceils) { | |||||
| GELOGE(FAILED, "dividend is 0!"); | GELOGE(FAILED, "dividend is 0!"); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| // Memory size is 512 aligned, so it is not necessary to take less than 512 | |||||
| int64_t min_memory_size = (all_memory_size.back() > MEM_ALIGN_SIZE) ? MEM_ALIGN_SIZE : all_memory_size.front(); | |||||
| auto range_number = static_cast<size_t>( | auto range_number = static_cast<size_t>( | ||||
| ceil(log(all_memory_size.back() / static_cast<double>(all_memory_size.front())) / log(kLogBase))); | |||||
| ceil(log(all_memory_size.back() / static_cast<double>(min_memory_size)) / log(kLogBase))); | |||||
| range_number = (range_number == 0) ? 1 : range_number; | range_number = (range_number == 0) ? 1 : range_number; | ||||
| GELOGD("Range number: %zu", range_number); | GELOGD("Range number: %zu", range_number); | ||||
| vector<vector<int64_t>> ranges(range_number); | vector<vector<int64_t>> ranges(range_number); | ||||
| GE_CHK_BOOL_EXEC((range_number != 0), return PARAM_INVALID, "range_number can't be 0."); | GE_CHK_BOOL_EXEC((range_number != 0), return PARAM_INVALID, "range_number can't be 0."); | ||||
| size_t range_number_limit = all_memory_size.size() / range_number; | size_t range_number_limit = all_memory_size.size() / range_number; | ||||
| int64_t range_ceil = all_memory_size[0]; | |||||
| int64_t range_ceil = min_memory_size; | |||||
| for (size_t i = 1; i <= range_number; i++) { | for (size_t i = 1; i <= range_number; i++) { | ||||
| GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(static_cast<uint64_t>(range_ceil), kRangeCeilInterval), | GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(static_cast<uint64_t>(range_ceil), kRangeCeilInterval), | ||||
| GELOGE(FAILED, "Multiply result is out of range."); | GELOGE(FAILED, "Multiply result is out of range."); | ||||
| @@ -114,7 +116,7 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector<int64_t> &range_ceils) { | |||||
| range_ceils.push_back(range.back()); | range_ceils.push_back(range.back()); | ||||
| } | } | ||||
| } | } | ||||
| GELOGD("Range ceils: %s", ToString(range_ceils).c_str()); | |||||
| GELOGI("Range ceils: %s", ToString(range_ceils).c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -65,6 +65,98 @@ void AlignMemOffset(size_t &mem_align_size) { | |||||
| mem_align_size = (mem_align_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; | mem_align_size = (mem_align_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; | ||||
| } | } | ||||
| static bool CompareLifeTime(const NodeTypeIndex &left, const NodeTypeIndex &right) { | |||||
| auto left_node_op_desc = left.node->GetOpDesc(); | |||||
| auto right_node_op_desc = right.node->GetOpDesc(); | |||||
| if ((left_node_op_desc != nullptr) && (right_node_op_desc != nullptr) | |||||
| && (left_node_op_desc->GetId() < right_node_op_desc->GetId())) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void GetLifeList(const MemoryBlock &block, std::vector<NodeTypeIndex> &life_list, bool child) { | |||||
| for (auto &node : block.NodeTypeIndexList()) { | |||||
| life_list.emplace_back(node); | |||||
| } | |||||
| if (child) { | |||||
| for (auto child_block : block.ChildBlockList()) { | |||||
| if (child_block == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if (block.stream_id_ != child_block->stream_id_ || !block.same_stream_ || !child_block->same_stream_) { | |||||
| life_list.clear(); | |||||
| return; | |||||
| } | |||||
| GetLifeList(*child_block, life_list, child); | |||||
| } | |||||
| } | |||||
| } | |||||
| bool CrossLifeTime(const NodeTypeIndex &left, const NodeTypeIndex &right) { | |||||
| if ((left.node == nullptr) || (right.node == nullptr)) { | |||||
| return true; | |||||
| } | |||||
| auto left_node_op_desc = left.node->GetOpDesc(); | |||||
| auto right_node_op_desc = right.node->GetOpDesc(); | |||||
| if ((left_node_op_desc != nullptr) && (right_node_op_desc != nullptr)) { | |||||
| if (left_node_op_desc->GetId() < right_node_op_desc->GetId()) { | |||||
| if (left.life_time_end >= static_cast<size_t>(right_node_op_desc->GetId())) { | |||||
| return true; | |||||
| } | |||||
| } else if (left_node_op_desc->GetId() == right_node_op_desc->GetId()) { | |||||
| return true; | |||||
| } else { | |||||
| if (right.life_time_end >= static_cast<size_t>(left_node_op_desc->GetId())) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| /// | |||||
| /// When child block's life time are not cross with parent block, they can be reused(only same stream). | |||||
| /// |-----------------------------parent block---------------------| | |||||
| /// |------child block1--------------||------child block2------| | |||||
| /// |--child block1-1-| | |||||
| /// | |||||
| bool CanIntervalLifeReuse(MemoryBlock &parent_block, MemoryBlock &child_block) { | |||||
| // judge by interval life time, only same stream can be judged by interval life time | |||||
| if (parent_block.stream_id_ != child_block.stream_id_ || !parent_block.same_stream_ || !child_block.same_stream_ | |||||
| || parent_block.NodeTypeIndexList().empty() || child_block.NodeTypeIndexList().empty()) { | |||||
| return false; | |||||
| } | |||||
| // quick judge by front and back node | |||||
| if (CrossLifeTime(parent_block.NodeTypeIndexList().front(), child_block.NodeTypeIndexList().front())) { | |||||
| return false; | |||||
| } | |||||
| if (CrossLifeTime(parent_block.NodeTypeIndexList().back(), child_block.NodeTypeIndexList().back())) { | |||||
| return false; | |||||
| } | |||||
| std::vector<NodeTypeIndex> life_list; | |||||
| GetLifeList(parent_block, life_list, false); | |||||
| GetLifeList(child_block, life_list, true); | |||||
| if (life_list.empty()) { | |||||
| return false; | |||||
| } | |||||
| std::sort(life_list.begin(), life_list.end(), CompareLifeTime); | |||||
| size_t pre_life_end = 0; | |||||
| for (auto &node : life_list) { | |||||
| auto node_op_desc = node.node->GetOpDesc(); | |||||
| if (node_op_desc != nullptr && pre_life_end >= static_cast<size_t>(node_op_desc->GetId())) { | |||||
| // life time cross | |||||
| return false; | |||||
| } | |||||
| pre_life_end = node.life_time_end; | |||||
| } | |||||
| GELOGI("Block size[%zu, %zu] life time are not cross.", parent_block.Size(), child_block.Size()); | |||||
| return true; | |||||
| } | |||||
| void MemoryBlock::SetHeadOffset(size_t offset) { | void MemoryBlock::SetHeadOffset(size_t offset) { | ||||
| head_offset_ = offset; | head_offset_ = offset; | ||||
| size_t child_offset = head_offset_; | size_t child_offset = head_offset_; | ||||
| @@ -125,20 +217,12 @@ size_t MemoryBlock::AlignSize() const { | |||||
| return align_block_size; | return align_block_size; | ||||
| } | } | ||||
| bool MemoryBlock::IsSameLabel(std::string &first_batch_label) { | |||||
| if (node_type_index_list_.empty()) { | |||||
| bool MemoryBlock::IsSameBatchLabel() { | |||||
| // only same batch label can reuse | |||||
| if (batch_label_.empty() || node_type_index_list_.empty()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto node_op_desc = node_type_index_list_[0].node->GetOpDesc(); | |||||
| if (node_op_desc == nullptr) { | |||||
| return false; | |||||
| } | |||||
| // not all op has ATTR_NAME_BATCH_LABEL, no need check return value, only check out parameter | |||||
| (void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, first_batch_label); | |||||
| if (first_batch_label.empty()) { | |||||
| return false; | |||||
| } | |||||
| bool all_same_label = true; | bool all_same_label = true; | ||||
| for (size_t index = 1; index < node_type_index_list_.size(); ++index) { | for (size_t index = 1; index < node_type_index_list_.size(); ++index) { | ||||
| if (node_type_index_list_[index].node == nullptr) { | if (node_type_index_list_[index].node == nullptr) { | ||||
| @@ -147,8 +231,9 @@ bool MemoryBlock::IsSameLabel(std::string &first_batch_label) { | |||||
| std::string batch_label; | std::string batch_label; | ||||
| auto index_op_desc = node_type_index_list_[index].node->GetOpDesc(); | auto index_op_desc = node_type_index_list_[index].node->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(index_op_desc == nullptr, continue); | GE_IF_BOOL_EXEC(index_op_desc == nullptr, continue); | ||||
| // not all op has ATTR_NAME_BATCH_LABEL, no need check return value, only check out parameter | |||||
| (void)ge::AttrUtils::GetStr(index_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | (void)ge::AttrUtils::GetStr(index_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | ||||
| if (first_batch_label != batch_label) { | |||||
| if (batch_label_ != batch_label) { | |||||
| all_same_label = false; | all_same_label = false; | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -197,7 +282,7 @@ void MemoryBlock::AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLi | |||||
| } | } | ||||
| void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life) { | void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life) { | ||||
| if (CanNotLifeReuse(this) || CanNotLifeReuse(block)) { | |||||
| if (CanNotLifeReuse(this) || CanNotLifeReuse(block) || (batch_label_ != block->batch_label_)) { | |||||
| return; | return; | ||||
| } | } | ||||
| if (block->continuous_block_) { | if (block->continuous_block_) { | ||||
| @@ -207,16 +292,27 @@ void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_ | |||||
| MemoryBlock *parent = nullptr; | MemoryBlock *parent = nullptr; | ||||
| MemoryBlock *child = nullptr; | MemoryBlock *child = nullptr; | ||||
| // merge small block to large block | // merge small block to large block | ||||
| if (block->GetDependLifeBegin(stream_id_, total_node_depend_stream_life) > GetLifeEnd()) { | |||||
| if ((child_offset_ + block->AlignSize()) <= AlignSize()) { | |||||
| parent = this; | |||||
| child = block; | |||||
| } else if ((block->child_offset_ + AlignSize()) <= block->AlignSize()) { | |||||
| parent = block; | |||||
| child = this; | |||||
| // noalign size 802816 + 802816 = 1605632 can reuse | |||||
| // after 32 align size 802848 + 802848 > 1605664 can't reuse | |||||
| // after 512 align size 803328 + 803328 > 1606144 can't reuse | |||||
| // so 803328 + 803328 = 1606144 + 512 can reuse | |||||
| if ((child_offset_ + block->AlignSize()) <= (AlignSize() + MEM_ALIGN_SIZE)) { | |||||
| parent = this; | |||||
| child = block; | |||||
| } else if ((block->child_offset_ + AlignSize()) <= (block->AlignSize() + MEM_ALIGN_SIZE)) { | |||||
| parent = block; | |||||
| child = this; | |||||
| } | |||||
| if ((parent != nullptr) && (child != nullptr)) { | |||||
| // Different streams must use stream dependency to judge the life cycle | |||||
| // In case same stream if it has child block, can judge all the child block's life time in CanIntervalLifeReuse | |||||
| bool can_block_life_reuse = (child->child_blocks_.empty() | |||||
| && (block->GetDependLifeBegin(stream_id_, total_node_depend_stream_life) > GetLifeEnd())); | |||||
| if (!can_block_life_reuse && !CanIntervalLifeReuse(*parent, *child)) { | |||||
| return; | |||||
| } | } | ||||
| } | |||||
| if ((parent != nullptr) && (child != nullptr) && child->child_blocks_.empty()) { | |||||
| parent->child_blocks_.emplace_back(child); | parent->child_blocks_.emplace_back(child); | ||||
| parent->child_offset_ += child->AlignSize(); | parent->child_offset_ += child->AlignSize(); | ||||
| child->deleted_block_ = true; | child->deleted_block_ = true; | ||||
| @@ -261,6 +357,7 @@ size_t MemoryBlock::GetDependLifeBegin(int64_t stream_id, DependStreamLife &tota | |||||
| void AddDependLife(const ge::NodePtr &org_node, const ge::NodePtr &node, int64_t stream_id, | void AddDependLife(const ge::NodePtr &org_node, const ge::NodePtr &node, int64_t stream_id, | ||||
| std::map<int64_t, size_t> &depend_stream_life, DependStreamLife &total_node_depend_stream_life) { | std::map<int64_t, size_t> &depend_stream_life, DependStreamLife &total_node_depend_stream_life) { | ||||
| GE_CHECK_NOTNULL_EXEC(node, return); | GE_CHECK_NOTNULL_EXEC(node, return); | ||||
| GE_CHECK_NOTNULL_EXEC(org_node, return); | |||||
| auto node_desc = node->GetOpDesc(); | auto node_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL_EXEC(node_desc, return); | GE_CHECK_NOTNULL_EXEC(node_desc, return); | ||||
| auto node_id = node_desc->GetId(); | auto node_id = node_desc->GetId(); | ||||
| @@ -415,12 +512,60 @@ BlockMemAssigner::~BlockMemAssigner() { | |||||
| } | } | ||||
| } | } | ||||
| void GetMaxBatchAllMemorySize(std::map<std::string, vector<int64_t>> &batch_all_memory_size, | |||||
| std::map<std::string, int64_t> batch_total_size, vector<int64_t> &all_memory_size, | |||||
| std::string &max_batch_label) { | |||||
| // use max batch all memory size for reuse range | |||||
| int64_t max_batch_size = 0; | |||||
| for (const auto &it : batch_total_size) { | |||||
| GELOGI("Batch[%s] total memory size[%ld]", it.first.c_str(), it.second); | |||||
| // no batch label | |||||
| if (it.first.empty()) { | |||||
| continue; | |||||
| } | |||||
| if (it.second > max_batch_size) { | |||||
| max_batch_size = it.second; | |||||
| max_batch_label = it.first; | |||||
| } | |||||
| } | |||||
| GELOGI("Max batch[%s] total memory size[%ld]", max_batch_label.c_str(), max_batch_size); | |||||
| for (const auto &it : batch_all_memory_size) { | |||||
| if (it.first.empty() || (it.first == max_batch_label)) { | |||||
| all_memory_size.insert(all_memory_size.end(), it.second.begin(), it.second.end()); | |||||
| } | |||||
| } | |||||
| // all_memory_size can't be empty | |||||
| if (all_memory_size.empty()) { | |||||
| all_memory_size.emplace_back(MEM_ALIGN_SIZE); | |||||
| } | |||||
| sort(all_memory_size.begin(), all_memory_size.end()); | |||||
| GELOGD("All memory size: %s", ToString(all_memory_size).c_str()); | |||||
| for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) { | |||||
| if (*iter == 0) { | |||||
| iter = all_memory_size.erase(iter); | |||||
| } else { | |||||
| ++iter; | |||||
| } | |||||
| } | |||||
| } | |||||
| void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) { | void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) { | ||||
| vector<int64_t> temp; | vector<int64_t> temp; | ||||
| std::map<std::string, vector<int64_t>> batch_all_memory_size; | |||||
| std::map<std::string, int64_t> batch_total_size; | |||||
| for (const NodePtr &n : compute_graph_->GetAllNodes()) { | for (const NodePtr &n : compute_graph_->GetAllNodes()) { | ||||
| auto node_op_desc = n->GetOpDesc(); | auto node_op_desc = n->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | ||||
| if (CheckIsZeroMemNodeType(node_op_desc->GetType())) { | |||||
| continue; | |||||
| } | |||||
| std::string batch_label; | |||||
| (void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| if (node_op_desc->GetType() == ATOMICADDRCLEAN) { | if (node_op_desc->GetType() == ATOMICADDRCLEAN) { | ||||
| atomic_addr_clean_id_ = node_op_desc->GetId(); | atomic_addr_clean_id_ = node_op_desc->GetId(); | ||||
| } | } | ||||
| @@ -434,9 +579,14 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) { | |||||
| if (!reuse_input) { | if (!reuse_input) { | ||||
| int64_t size = 0; | int64_t size = 0; | ||||
| GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(output_desc, size) != SUCCESS, GELOGI("Get size failed")); | GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(output_desc, size) != SUCCESS, GELOGI("Get size failed")); | ||||
| if (anchor_to_symbol_.empty()) { | |||||
| all_memory_size.emplace_back(size); | |||||
| batch_all_memory_size[batch_label].emplace_back(size); | |||||
| if (batch_total_size.find(batch_label) == batch_total_size.end()) { | |||||
| batch_total_size[batch_label] = size; | |||||
| } else { | } else { | ||||
| batch_total_size[batch_label] += size; | |||||
| } | |||||
| if (!anchor_to_symbol_.empty()) { | |||||
| auto iter1 = anchor_to_symbol_.find(NodeIndexIO(n, out_anchor->GetIdx(), kOut).ToString()); | auto iter1 = anchor_to_symbol_.find(NodeIndexIO(n, out_anchor->GetIdx(), kOut).ToString()); | ||||
| if (iter1 == anchor_to_symbol_.end()) { | if (iter1 == anchor_to_symbol_.end()) { | ||||
| continue; | continue; | ||||
| @@ -452,23 +602,11 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) { | |||||
| } | } | ||||
| } | } | ||||
| temp.clear(); | temp.clear(); | ||||
| GetNodeWorkSpaceSize(n, temp); | |||||
| all_memory_size.insert(all_memory_size.end(), temp.begin(), temp.end()); | |||||
| } | |||||
| for (const auto &pair : symbol_size_) { | |||||
| all_memory_size.emplace_back(pair.second); | |||||
| } | |||||
| sort(all_memory_size.begin(), all_memory_size.end()); | |||||
| GELOGD("All memory size: %s", ToString(all_memory_size).c_str()); | |||||
| for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) { | |||||
| if (*iter == 0) { | |||||
| iter = all_memory_size.erase(iter); | |||||
| } else { | |||||
| ++iter; | |||||
| } | |||||
| GetNodeWorkSpaceSize(n, temp, batch_total_size[batch_label]); | |||||
| batch_all_memory_size[batch_label].insert(batch_all_memory_size[batch_label].end(), temp.begin(), temp.end()); | |||||
| } | } | ||||
| GELOGI("The last atomic_addr_clean node id: %ld", atomic_addr_clean_id_); | |||||
| GetMaxBatchAllMemorySize(batch_all_memory_size, batch_total_size, all_memory_size, max_batch_label_); | |||||
| InitReuseFlag(); | InitReuseFlag(); | ||||
| PrintSymbolMap(); | PrintSymbolMap(); | ||||
| } | } | ||||
| @@ -529,16 +667,6 @@ bool CanReuseBySize(const map<string, uint64_t> &reusable_block_counts, const Me | |||||
| bool can_reuse = false; | bool can_reuse = false; | ||||
| if (reusable_block.Size() == block_size) { | if (reusable_block.Size() == block_size) { | ||||
| can_reuse = true; | can_reuse = true; | ||||
| } else { | |||||
| string key = std::to_string(reusable_block.Size()); | |||||
| key += "_" + std::to_string(reusable_block.stream_id_); | |||||
| key += "_" + std::to_string(reusable_block.memory_type_); | |||||
| auto it = reusable_block_counts.find(key); | |||||
| GE_IF_BOOL_EXEC((it != reusable_block_counts.end() && (it->second > kReuseMaxCount)) && | |||||
| (reusable_block.Size() > block_size), | |||||
| can_reuse = true; | |||||
| GELOGD("Less size mem reuse, reuse block size:%zu, current block size:%zu", | |||||
| reusable_block.Size(), block_size);); | |||||
| } | } | ||||
| return can_reuse; | return can_reuse; | ||||
| } | } | ||||
| @@ -860,34 +988,35 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "Input parameter n is null."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "Input parameter n is null."); | ||||
| auto node_op_desc = n->GetOpDesc(); | auto node_op_desc = n->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(node_op_desc == nullptr, return nullptr); | GE_IF_BOOL_EXEC(node_op_desc == nullptr, return nullptr); | ||||
| std::string batch_label; | |||||
| (void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| if (batch_label.empty() || (batch_label == max_batch_label_)) { | |||||
| size_t align_size = real_size; | |||||
| AlignMemOffset(align_size); | |||||
| theory_memory_size_ += align_size; | |||||
| if (theory_memory_size_ > theory_min_memory_size_) { | |||||
| theory_min_memory_size_ = theory_memory_size_; | |||||
| } | |||||
| } | |||||
| bool is_reuse_memory = false; | bool is_reuse_memory = false; | ||||
| string ge_disable_reuse_mem_env = "0"; | |||||
| (void)ge::GetContext().GetOption(OPTION_EXEC_DISABLE_REUSED_MEMORY, ge_disable_reuse_mem_env); | |||||
| if (ge_disable_reuse_mem_env != "1") { | |||||
| if (ge_disable_reuse_mem_env_ != "1") { | |||||
| bool reuse_mem_flag = (mem_type == kOutput) ? IsPreReuse(n, out_index) : | bool reuse_mem_flag = (mem_type == kOutput) ? IsPreReuse(n, out_index) : | ||||
| !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]); | !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]); | ||||
| is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && | is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && | ||||
| !node_op_desc->HasAttr(kOpNoReuseMem) && reuse_mem_flag && is_op_reuse_mem; | !node_op_desc->HasAttr(kOpNoReuseMem) && reuse_mem_flag && is_op_reuse_mem; | ||||
| auto stream_id = node_op_desc->GetStreamId(); | |||||
| if (is_reuse_memory && !continuous && !reusable_blocks_[memory_type].empty()) { | |||||
| for (auto it = reusable_blocks_[memory_type][stream_id].begin(); | |||||
| it != reusable_blocks_[memory_type][stream_id].end(); ++it) { | |||||
| bool do_reuse = is_reuse_memory && !continuous && !reusable_blocks_[memory_type].empty(); | |||||
| if (do_reuse) { | |||||
| auto stream_id = node_op_desc->GetStreamId(); | |||||
| for (auto it = reusable_blocks_[memory_type][stream_id].rbegin(); | |||||
| it != reusable_blocks_[memory_type][stream_id].rend(); ++it) { | |||||
| MemoryBlock *reusable_block = *it; | MemoryBlock *reusable_block = *it; | ||||
| if (!IsPostReuse(reusable_block)) { | if (!IsPostReuse(reusable_block)) { | ||||
| reusable_block->reuse_mem_ = false; | reusable_block->reuse_mem_ = false; | ||||
| GELOGI("Unreusable block."); | GELOGI("Unreusable block."); | ||||
| continue; | continue; | ||||
| } | } | ||||
| std::string batch_label; | |||||
| if (reusable_block->IsSameLabel(batch_label)) { | |||||
| std::string op_label; | |||||
| (void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, op_label); | |||||
| if (batch_label != op_label) { | |||||
| GELOGI("label diff, op name %s", node_op_desc->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| GE_IF_BOOL_EXEC(reusable_block->batch_label_ != batch_label, continue); | |||||
| // A node can reuse blocks of the same stream and preorder streams | // A node can reuse blocks of the same stream and preorder streams | ||||
| if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous)) { | if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous)) { | ||||
| @@ -901,7 +1030,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, | |||||
| reusable_block->continuous_block_ = continuous; | reusable_block->continuous_block_ = continuous; | ||||
| reusable_block->ref_count_++; | reusable_block->ref_count_++; | ||||
| ReduceReusableBlockCount(*reusable_block, reusable_block_counts_); | ReduceReusableBlockCount(*reusable_block, reusable_block_counts_); | ||||
| reusable_blocks_[memory_type][stream_id].erase(it); | |||||
| reusable_blocks_[memory_type][stream_id].erase((++it).base()); | |||||
| return reusable_block; | return reusable_block; | ||||
| } | } | ||||
| } | } | ||||
| @@ -914,10 +1043,11 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, | |||||
| // Data and netoutput need zero copy block | // Data and netoutput need zero copy block | ||||
| block->is_zero_copy_ = IsZeroCopyBlock(n, continuous); | block->is_zero_copy_ = IsZeroCopyBlock(n, continuous); | ||||
| block->Init(real_size, mem_type, n, out_index, no_align_size); | |||||
| block->Init(real_size, mem_type, n, out_index, no_align_size, node_op_desc->GetStreamId()); | |||||
| block->stream_id_ = node_op_desc->GetStreamId(); | block->stream_id_ = node_op_desc->GetStreamId(); | ||||
| block->ref_count_++; | block->ref_count_++; | ||||
| block->continuous_block_ = continuous; | block->continuous_block_ = continuous; | ||||
| block->batch_label_ = batch_label; | |||||
| if (mem_type == kOutput) { | if (mem_type == kOutput) { | ||||
| auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString()); | auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString()); | ||||
| if (iter != anchor_to_symbol_.end()) { | if (iter != anchor_to_symbol_.end()) { | ||||
| @@ -945,6 +1075,11 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (CheckIsZeroMemNodeType(n->GetType())) { | |||||
| zero_memory_list_.emplace_back(n, kOutput, index); | |||||
| continue; | |||||
| } | |||||
| int64_t size = 0; | int64_t size = 0; | ||||
| if (ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS) { | if (ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS) { | ||||
| GELOGI("Get size failed"); | GELOGI("Get size failed"); | ||||
| @@ -957,9 +1092,7 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec | |||||
| // only apply total size in first block | // only apply total size in first block | ||||
| if (index != 0) { | if (index != 0) { | ||||
| zero_memory_list_.emplace_back(n, kOutput, index); | zero_memory_list_.emplace_back(n, kOutput, index); | ||||
| } | |||||
| if (index == 0) { | |||||
| } else { | |||||
| NodeIndexIO node_index_io(n, index, kOut); | NodeIndexIO node_index_io(n, index, kOut); | ||||
| auto iter = anchor_to_symbol_.find(node_index_io.ToString()); | auto iter = anchor_to_symbol_.find(node_index_io.ToString()); | ||||
| if (iter != anchor_to_symbol_.end()) { | if (iter != anchor_to_symbol_.end()) { | ||||
| @@ -972,6 +1105,10 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec | |||||
| } | } | ||||
| } | } | ||||
| if (total_size == 0) { | |||||
| return nullptr; | |||||
| } | |||||
| auto block_size = GetBlockSize(total_size, ranges); | auto block_size = GetBlockSize(total_size, ranges); | ||||
| GELOGI("Node[%s] continuous out memory size[%ld] block size[%zu]", node_op_desc->GetName().c_str(), | GELOGI("Node[%s] continuous out memory size[%ld] block size[%zu]", node_op_desc->GetName().c_str(), | ||||
| total_size, block_size); | total_size, block_size); | ||||
| @@ -1119,15 +1256,28 @@ bool IsKnownSubgraphData(const NodePtr &node) { | |||||
| return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); | return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); | ||||
| } | } | ||||
| void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory) { | |||||
| void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory, | |||||
| bool same_stream) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(to_release == nullptr, return, "Input parameter to_release is null."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(to_release == nullptr, return, "Input parameter to_release is null."); | ||||
| GE_CHK_TRUE_EXEC_INFO(to_release->ref_count_ <= 0, return, "Release memory"); | GE_CHK_TRUE_EXEC_INFO(to_release->ref_count_ <= 0, return, "Release memory"); | ||||
| GE_CHK_TRUE_EXEC_INFO(!to_release->reuse_mem_, return, "doesn't reuse memory"); | GE_CHK_TRUE_EXEC_INFO(!to_release->reuse_mem_, return, "doesn't reuse memory"); | ||||
| --to_release->ref_count_; | --to_release->ref_count_; | ||||
| if (!same_stream) { | |||||
| to_release->same_stream_ = false; | |||||
| } | |||||
| if (to_release->ref_count_ == 0) { | if (to_release->ref_count_ == 0) { | ||||
| to_release->SetLifeTimeEnd(life_time_); | |||||
| reusable_memory.emplace_back(to_release); | |||||
| AddReusableBlockCount(*to_release, reusable_block_counts_); | |||||
| if (to_release->reuse_mem_ && !to_release->RealSizeList().empty()) { | |||||
| if (to_release->batch_label_.empty() || (to_release->batch_label_ == max_batch_label_)) { | |||||
| size_t align_size = to_release->RealSizeList().back(); | |||||
| AlignMemOffset(align_size); | |||||
| theory_memory_size_ -= align_size; | |||||
| } | |||||
| } | |||||
| if (to_release->same_stream_) { | |||||
| to_release->SetLifeTimeEnd(life_time_); | |||||
| reusable_memory.emplace_back(to_release); | |||||
| AddReusableBlockCount(*to_release, reusable_block_counts_); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1167,10 +1317,9 @@ void BlockMemAssigner::ReleaseInputNodeOutMemory(const unordered_map<string, vec | |||||
| node_type_indexs.back().node->GetName().c_str()); | node_type_indexs.back().node->GetName().c_str()); | ||||
| if ((node_type_indexs.back().node == in_anchor->GetPeerOutAnchor()->GetOwnerNode()) && | if ((node_type_indexs.back().node == in_anchor->GetPeerOutAnchor()->GetOwnerNode()) && | ||||
| (node_type_indexs.back().index == static_cast<uint32_t>(in_anchor->GetPeerOutAnchor()->GetIdx())) && | |||||
| (node->GetOpDesc()->GetStreamId() == block->stream_id_)) { | |||||
| ReleaseMemory(block, reusable_memory); | |||||
| if (block->ref_count_ == 0) { | |||||
| (node_type_indexs.back().index == static_cast<uint32_t>(in_anchor->GetPeerOutAnchor()->GetIdx()))) { | |||||
| ReleaseMemory(block, reusable_memory, (node->GetOpDesc()->GetStreamId() == block->stream_id_)); | |||||
| if (block->ref_count_ == 0 && block->same_stream_) { | |||||
| SetLastUsedInputMemAttr(node, in_anchor->GetIdx()); | SetLastUsedInputMemAttr(node, in_anchor->GetIdx()); | ||||
| } | } | ||||
| } | } | ||||
| @@ -1267,7 +1416,8 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector | |||||
| bool no_need_assign_memory = ((size == 0) || CheckIsZeroMemNodeType(node->GetType())); | bool no_need_assign_memory = ((size == 0) || CheckIsZeroMemNodeType(node->GetType())); | ||||
| if (!no_need_assign_memory) { | if (!no_need_assign_memory) { | ||||
| out_node_set_continuous_input = | out_node_set_continuous_input = | ||||
| IsOutNodeSetContinuousInput(node, i, peer_name, peer_input_index, no_need_assign_memory, reset_zero_copy_flag); | |||||
| IsOutNodeSetContinuousInput(node, i, peer_name, peer_input_index, | |||||
| no_need_assign_memory, reset_zero_copy_flag); | |||||
| GE_IF_BOOL_EXEC(!no_need_assign_memory, | GE_IF_BOOL_EXEC(!no_need_assign_memory, | ||||
| no_need_assign_memory = IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input);); | no_need_assign_memory = IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input);); | ||||
| } | } | ||||
| @@ -1328,7 +1478,8 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) { | |||||
| iter->second[stream_id].clear(); | iter->second[stream_id].clear(); | ||||
| } | } | ||||
| vector<int64_t> temp; | vector<int64_t> temp; | ||||
| GetNodeWorkSpaceSize(n, temp); | |||||
| int64_t tatal_size = 0; | |||||
| GetNodeWorkSpaceSize(n, temp, tatal_size); | |||||
| vector<int64_t> workspace_bytes; | vector<int64_t> workspace_bytes; | ||||
| vector<int64_t> tvm_workspace_memory_type; | vector<int64_t> tvm_workspace_memory_type; | ||||
| bool has_tvm_workspace_mem_type_attr = | bool has_tvm_workspace_mem_type_attr = | ||||
| @@ -1349,7 +1500,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) { | |||||
| bool workspace_skip_flag = false; | bool workspace_skip_flag = false; | ||||
| if (has_tvm_workspace_mem_type_attr && tvm_workspace_memory_type[i] == RT_MEMORY_L1) { | if (has_tvm_workspace_mem_type_attr && tvm_workspace_memory_type[i] == RT_MEMORY_L1) { | ||||
| GELOGI( | GELOGI( | ||||
| "fusion: node[%s]workspace index[%zu] is not hbm type, add to zero_memory_list, workspace memory type [%ld]", | |||||
| "fusion:node[%s]workspace index[%zu] is not hbm type, add to zero_memory_list, workspace memory type [%ld]", | |||||
| node_op_desc->GetName().c_str(), i, tvm_workspace_memory_type[i]); | node_op_desc->GetName().c_str(), i, tvm_workspace_memory_type[i]); | ||||
| workspace_skip_flag = true; | workspace_skip_flag = true; | ||||
| } | } | ||||
| @@ -1380,9 +1531,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) { | |||||
| (void)mem_block; // Fix warning | (void)mem_block; // Fix warning | ||||
| } | } | ||||
| bool merge_dynamic_batch = false; | |||||
| GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), merge_dynamic_batch = MergeDynamicBatchBlocks()); | |||||
| GE_IF_BOOL_EXEC((!(ge_disable_reuse_mem_env_ == "1") && !merge_dynamic_batch), ReuseBlocksByLifeTime(ranges.size())); | |||||
| GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), ReuseBlocksByLifeTime(ranges.size())); | |||||
| AssignContinuousBlocks(); | AssignContinuousBlocks(); | ||||
| ResizeMemoryBlocks(); | ResizeMemoryBlocks(); | ||||
| @@ -1402,92 +1551,19 @@ void BlockMemAssigner::CheckWorkspaceReuse(const vector<bool> &workspace_reuse_f | |||||
| } | } | ||||
| } | } | ||||
| void BlockMemAssigner::GetNodeWorkSpaceSize(const NodePtr &node, vector<int64_t> &workspace_memory) { | |||||
| void BlockMemAssigner::GetNodeWorkSpaceSize(const NodePtr &node, vector<int64_t> &workspace_memory, | |||||
| int64_t &total_size) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node->GetOpDesc() == nullptr, return, "Op desc is null."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node->GetOpDesc() == nullptr, return, "Op desc is null."); | ||||
| vector<int64_t> workspace_byte_nums = node->GetOpDesc()->GetWorkspaceBytes(); | vector<int64_t> workspace_byte_nums = node->GetOpDesc()->GetWorkspaceBytes(); | ||||
| GELOGD("node[%s] size:%zu", node->GetOpDesc()->GetName().c_str(), workspace_byte_nums.size()); | GELOGD("node[%s] size:%zu", node->GetOpDesc()->GetName().c_str(), workspace_byte_nums.size()); | ||||
| for (int64_t byte_size : workspace_byte_nums) { | for (int64_t byte_size : workspace_byte_nums) { | ||||
| workspace_memory.emplace_back(byte_size); | workspace_memory.emplace_back(byte_size); | ||||
| total_size += byte_size; | |||||
| GELOGD("push back size:%ld", byte_size); | GELOGD("push back size:%ld", byte_size); | ||||
| } | } | ||||
| } | } | ||||
| // descending order | |||||
| static bool CompareBlockMaxSize(MemoryBlock *left, MemoryBlock *right) { | |||||
| if (left == nullptr || right == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto left_max_size = std::max_element(left->RealSizeList().begin(), left->RealSizeList().end()); | |||||
| if (left_max_size != left->RealSizeList().end()) { | |||||
| auto right_max_size = std::max_element(right->RealSizeList().begin(), right->RealSizeList().end()); | |||||
| if (right_max_size == right->RealSizeList().end() || (*left_max_size > *right_max_size)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void MergeBlocks(std::vector<MemoryBlock *> &dest, std::vector<MemoryBlock *> &src) { | |||||
| for (size_t i = 0; i < dest.size(); ++i) { | |||||
| if (i >= src.size()) { | |||||
| return; | |||||
| } | |||||
| if (dest[i] != nullptr && src[i] != nullptr) { | |||||
| if (!dest[i]->reuse_mem_ || !src[i]->reuse_mem_) { | |||||
| GELOGD("Diff batch's workspace can't be reused, i: %zu, dest[i]: %s, stream: %ld, src[i]: %s, stream: %ld.", | |||||
| i, dest[i]->String().c_str(), dest[i]->stream_id_, src[i]->String().c_str(), src[i]->stream_id_); | |||||
| continue; | |||||
| } | |||||
| for (auto &symbol : src[i]->SymbolList()) { | |||||
| dest[i]->AddSymbol(symbol); | |||||
| } | |||||
| for (size_t j = 0; j < src[i]->NodeTypeIndexList().size(); ++j) { | |||||
| dest[i]->AddNodeTypeIndex(src[i]->NodeTypeIndexList()[j], | |||||
| src[i]->RealSizeList()[j], | |||||
| src[i]->NoAlignSizeList()[j]); | |||||
| src[i]->deleted_block_ = true; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| bool BlockMemAssigner::MergeDynamicBatchBlocks() { | |||||
| bool merged = false; | |||||
| std::map<std::string, std::vector<MemoryBlock *>> dynamic_batch_blocks; | |||||
| for (auto block : memory_blocks_) { | |||||
| if (block == nullptr) { | |||||
| continue; | |||||
| } | |||||
| std::string batch_label; | |||||
| if (block->IsSameLabel(batch_label)) { | |||||
| dynamic_batch_blocks[batch_label].emplace_back(block); | |||||
| } | |||||
| } | |||||
| auto it = dynamic_batch_blocks.begin(); | |||||
| auto it_max = it; | |||||
| // find max block counts | |||||
| for (; it != dynamic_batch_blocks.end(); ++it) { | |||||
| if (it->second.size() > it_max->second.size()) { | |||||
| it_max = it; | |||||
| } | |||||
| std::sort(it->second.begin(), it->second.end(), CompareBlockMaxSize); | |||||
| } | |||||
| if (it_max != dynamic_batch_blocks.end()) { | |||||
| GELOGD("MergeDynamicBatch %s block counts %zu", it_max->first.c_str(), it_max->second.size()); | |||||
| } | |||||
| for (it = dynamic_batch_blocks.begin(); it != dynamic_batch_blocks.end(); ++it) { | |||||
| if (it != it_max) { | |||||
| GELOGD("MergeDynamicBatch from %s to %s", it->first.c_str(), it_max->first.c_str()); | |||||
| MergeBlocks(it_max->second, it->second); | |||||
| merged = true; | |||||
| } | |||||
| } | |||||
| return merged; | |||||
| } | |||||
| // asending order | // asending order | ||||
| static bool CompareBlockIndex(MemoryBlock *left, MemoryBlock *right) { | static bool CompareBlockIndex(MemoryBlock *left, MemoryBlock *right) { | ||||
| if (left == nullptr || right == nullptr) { | if (left == nullptr || right == nullptr) { | ||||
| @@ -1597,38 +1673,93 @@ void BlockMemAssigner::ReuseBlocksByLifeTime(size_t range_size) { | |||||
| } | } | ||||
| } | } | ||||
| void AddBlockMemOffset(size_t &mem_offset, size_t &p2p_mem_offset, MemoryBlock &block) { | |||||
| if (block.memory_type_ == RT_MEMORY_HBM) { | |||||
| if (block.first_continuous_block_) { | |||||
| mem_offset += MEM_ALIGN_SIZE; | |||||
| } | |||||
| block.Resize(); | |||||
| block.SetHeadOffset(mem_offset); | |||||
| mem_offset += block.Size(); | |||||
| block.SetTailOffset(mem_offset - 1); | |||||
| } else if (block.memory_type_ == RT_MEMORY_P2P_DDR) { | |||||
| if (block.first_continuous_block_) { | |||||
| p2p_mem_offset += MEM_ALIGN_SIZE; | |||||
| } | |||||
| block.Resize(); | |||||
| block.SetHeadOffset(p2p_mem_offset); | |||||
| p2p_mem_offset += block.Size(); | |||||
| block.SetTailOffset(p2p_mem_offset - 1); | |||||
| } | |||||
| } | |||||
| bool DynamicBatchBlockReuse(MemoryBlock &block) { | |||||
| return (block.IsSameBatchLabel() && block.reuse_mem_); | |||||
| } | |||||
| /// | /// | ||||
| /// @ingroup domi_omg | /// @ingroup domi_omg | ||||
| /// @brief traverse memory size, resize, calculate offset | |||||
| /// @brief get max batch memory size, others reuse this block memory | |||||
| /// @param [in&out] memory_blocks_ memory block, after calculating offset | /// @param [in&out] memory_blocks_ memory block, after calculating offset | ||||
| /// |-dynamic batch block batch1| | |||||
| /// |-dynamic batch block batch2----| | |||||
| /// |-dynamic batch block batch3--| | |||||
| /// | /// | ||||
| void BlockMemAssigner::ResizeMemoryBlocks() { | |||||
| for (auto &memory_block : memory_blocks_) { | |||||
| if (memory_block == nullptr || memory_block->deleted_block_ || memory_block->is_zero_copy_) { | |||||
| void BlockMemAssigner::ResizeDynamicBatchBlocks() { | |||||
| std::map<std::string, std::vector<MemoryBlock *>> dynamic_batch_blocks; | |||||
| for (auto block : memory_blocks_) { | |||||
| if (block == nullptr) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (memory_block->memory_type_ == RT_MEMORY_HBM) { | |||||
| if (memory_block->first_continuous_block_) { | |||||
| mem_offset_ += MEM_ALIGN_SIZE; | |||||
| } | |||||
| // when memory is not reuseable, it can't be reused by different branch | |||||
| if (DynamicBatchBlockReuse(*block)) { | |||||
| dynamic_batch_blocks[block->batch_label_].emplace_back(block); | |||||
| } | |||||
| } | |||||
| memory_block->Resize(); | |||||
| memory_block->SetHeadOffset(mem_offset_); | |||||
| mem_offset_ += memory_block->Size(); | |||||
| memory_block->SetTailOffset(mem_offset_ - 1); | |||||
| } else if (memory_block->memory_type_ == RT_MEMORY_P2P_DDR) { | |||||
| if (memory_block->first_continuous_block_) { | |||||
| p2p_mem_offset_ += MEM_ALIGN_SIZE; | |||||
| size_t max_mem_offset = mem_offset_; | |||||
| size_t max_p2p_mem_offset = p2p_mem_offset_; | |||||
| for (auto &batch_blocks : dynamic_batch_blocks) { | |||||
| size_t mem_offset = mem_offset_; | |||||
| size_t p2p_mem_offset = p2p_mem_offset_; | |||||
| for (auto block : batch_blocks.second) { | |||||
| if (block == nullptr || block->deleted_block_ || block->is_zero_copy_) { | |||||
| continue; | |||||
| } | } | ||||
| AddBlockMemOffset(mem_offset, p2p_mem_offset, *block); | |||||
| } | |||||
| if (mem_offset > max_mem_offset) { | |||||
| max_mem_offset = mem_offset; | |||||
| } | |||||
| if (p2p_mem_offset > max_p2p_mem_offset) { | |||||
| max_p2p_mem_offset = p2p_mem_offset; | |||||
| } | |||||
| GELOGI("Batch[%s] offset[%zu] p2p_offset[%zu]", batch_blocks.first.c_str(), mem_offset, p2p_mem_offset); | |||||
| } | |||||
| mem_offset_ = max_mem_offset; | |||||
| p2p_mem_offset_ = max_p2p_mem_offset; | |||||
| } | |||||
| memory_block->Resize(); | |||||
| memory_block->SetHeadOffset(p2p_mem_offset_); | |||||
| p2p_mem_offset_ += memory_block->Size(); | |||||
| memory_block->SetTailOffset(p2p_mem_offset_ - 1); | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief traverse memory size, resize, calculate offset | |||||
| /// @param [in&out] memory_blocks_ memory block, after calculating offset | |||||
| /// |-not dynamic batch block-||-dynamic batch block batch1| |-zero copy block-| | |||||
| /// |-not dynamic batch block-||-dynamic batch block batch2----||-zero copy block-| | |||||
| /// |-not dynamic batch block-||-dynamic batch block batch3--| |-zero copy block-| | |||||
| /// | |||||
| void BlockMemAssigner::ResizeMemoryBlocks() { | |||||
| for (auto &memory_block : memory_blocks_) { | |||||
| if (memory_block == nullptr || memory_block->deleted_block_ || memory_block->is_zero_copy_ | |||||
| || DynamicBatchBlockReuse(*memory_block)) { | |||||
| continue; | |||||
| } | } | ||||
| AddBlockMemOffset(mem_offset_, p2p_mem_offset_, *memory_block); | |||||
| } | } | ||||
| GELOGD("mem_offset_ exclude zero_copy_memory is %zu, p2p_mem_offset_ exclude zero_copy_memory is %zu.", | |||||
| mem_offset_, p2p_mem_offset_); | |||||
| ResizeDynamicBatchBlocks(); | |||||
| GELOGI("mem_offset_ exclude zero_copy_memory is %zu, p2p_mem_offset_ exclude zero_copy_memory is %zu," | |||||
| "theory_min_memory_size %zu", mem_offset_, p2p_mem_offset_, theory_min_memory_size_); | |||||
| } | } | ||||
| /// | /// | ||||
| @@ -1641,7 +1772,7 @@ void BlockMemAssigner::ResizeMemoryBlocks() { | |||||
| /// @return Status result | /// @return Status result | ||||
| /// | /// | ||||
| void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, | void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, | ||||
| size_t real_size, size_t no_align_size, bool child_block) { | |||||
| size_t real_size, size_t no_align_size, int32_t child_block_level) { | |||||
| ge::OpDescPtr op_desc = node_type.node->GetOpDesc(); | ge::OpDescPtr op_desc = node_type.node->GetOpDesc(); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc == nullptr, return, "op_desc is null."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc == nullptr, return, "op_desc is null."); | ||||
| string graph_name = node_type.node->GetOwnerComputeGraph()->GetName(); | string graph_name = node_type.node->GetOwnerComputeGraph()->GetName(); | ||||
| @@ -1689,14 +1820,15 @@ void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, | |||||
| } | } | ||||
| op_desc->SetWorkspace(workspace_list); | op_desc->SetWorkspace(workspace_list); | ||||
| } | } | ||||
| GELOGI("[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu]" | |||||
| " noalignsize[%zu] life time begin[%zu] life time end[%zu] child[%d:%d:%d:%d] isref[%d].", graph_name.c_str(), | |||||
| GELOGI("[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu] noalignsize[%zu] " | |||||
| "life time begin[%zu] life time end[%zu] child[%d:%d:%d:%d:%d] isref[%d] batch[%s]", graph_name.c_str(), | |||||
| op_desc->GetName().c_str(), node_type.GetMemType().c_str(), node_type.index, offset, op_desc->GetStreamId(), | op_desc->GetName().c_str(), node_type.GetMemType().c_str(), node_type.index, offset, op_desc->GetStreamId(), | ||||
| block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block, block->reuse_mem_, | |||||
| block->continuous_block_, block->deleted_block_, node_type.ref_input); | |||||
| block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block_level, block->reuse_mem_, | |||||
| block->continuous_block_, block->is_zero_copy_, block->same_stream_, node_type.ref_input, | |||||
| block->batch_label_.c_str()); | |||||
| } | } | ||||
| void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) { | |||||
| void SetBlockOpMemOffset(MemoryBlock *block, int32_t child_block_level) { | |||||
| if (block == nullptr) { | if (block == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -1709,9 +1841,14 @@ void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) { | |||||
| real_size = block->RealSizeList()[index]; | real_size = block->RealSizeList()[index]; | ||||
| no_align_size = block->NoAlignSizeList()[index]; | no_align_size = block->NoAlignSizeList()[index]; | ||||
| } | } | ||||
| SetOffsetSize(node_type_index, block, real_size, no_align_size, child_block); | |||||
| SetOffsetSize(node_type_index, block, real_size, no_align_size, child_block_level); | |||||
| index++; | index++; | ||||
| } | } | ||||
| child_block_level++; | |||||
| for (MemoryBlock *child_block : block->ChildBlockList()) { | |||||
| SetBlockOpMemOffset(child_block, child_block_level); | |||||
| } | |||||
| } | } | ||||
| void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) { | void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) { | ||||
| @@ -1724,16 +1861,13 @@ void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| SetBlockOpMemOffset(memory_block, false); | |||||
| for (MemoryBlock *child_block : memory_block->ChildBlockList()) { | |||||
| SetBlockOpMemOffset(child_block, true); | |||||
| } | |||||
| SetBlockOpMemOffset(memory_block, 0); | |||||
| } | } | ||||
| if (!is_zero_copy) { | if (!is_zero_copy) { | ||||
| for (const NodeTypeIndex &node_type_index : zero_memory_list_) { | for (const NodeTypeIndex &node_type_index : zero_memory_list_) { | ||||
| MemoryBlock block(0, 0); | MemoryBlock block(0, 0); | ||||
| SetOffsetSize(node_type_index, &block, 0, 0, false); | |||||
| SetOffsetSize(node_type_index, &block, 0, 0, 0); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -65,6 +65,7 @@ class MemoryBlock { | |||||
| stream_id_(stream_id), | stream_id_(stream_id), | ||||
| deleted_block_(false), | deleted_block_(false), | ||||
| reuse_mem_(reuse_mem), | reuse_mem_(reuse_mem), | ||||
| same_stream_(true), | |||||
| input_index_(0), | input_index_(0), | ||||
| continuous_block_(false), | continuous_block_(false), | ||||
| first_continuous_block_(false), | first_continuous_block_(false), | ||||
| @@ -85,10 +86,14 @@ class MemoryBlock { | |||||
| symbol_list_.clear(); | symbol_list_.clear(); | ||||
| } | } | ||||
| void Init(size_t real_size, OpMemoryType type, const ge::NodePtr &node, uint32_t out_index, size_t no_align_size) { | |||||
| void Init(size_t real_size, OpMemoryType type, const ge::NodePtr &node, uint32_t out_index, size_t no_align_size, | |||||
| int64_t stream_id) { | |||||
| real_size_list_.emplace_back(real_size); | real_size_list_.emplace_back(real_size); | ||||
| no_align_size_list_.emplace_back(no_align_size); | no_align_size_list_.emplace_back(no_align_size); | ||||
| node_type_index_list_.emplace_back(node, type, out_index, false); | node_type_index_list_.emplace_back(node, type, out_index, false); | ||||
| if (stream_id != stream_id_) { | |||||
| same_stream_ = false; | |||||
| } | |||||
| } | } | ||||
| size_t Size() const { return block_size_; } | size_t Size() const { return block_size_; } | ||||
| @@ -106,6 +111,12 @@ class MemoryBlock { | |||||
| node_type_index_list_.emplace_back(node_type_index); | node_type_index_list_.emplace_back(node_type_index); | ||||
| real_size_list_.emplace_back(real_size); | real_size_list_.emplace_back(real_size); | ||||
| no_align_size_list_.emplace_back(no_align_size); | no_align_size_list_.emplace_back(no_align_size); | ||||
| if ((node_type_index.node != nullptr) && (node_type_index.node->GetOpDesc() != nullptr)) { | |||||
| auto stream_id = node_type_index.node->GetOpDesc()->GetStreamId(); | |||||
| if (stream_id != stream_id_) { | |||||
| same_stream_ = false; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| void AddSymbol(const std::string &symbol) { | void AddSymbol(const std::string &symbol) { | ||||
| @@ -122,7 +133,7 @@ class MemoryBlock { | |||||
| std::string String(); | std::string String(); | ||||
| bool IsSameLabel(std::string &first_batch_label); | |||||
| bool IsSameBatchLabel(); | |||||
| void AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life); | void AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life); | ||||
| @@ -142,6 +153,7 @@ class MemoryBlock { | |||||
| int64_t stream_id_; | int64_t stream_id_; | ||||
| bool deleted_block_; | bool deleted_block_; | ||||
| bool reuse_mem_; | bool reuse_mem_; | ||||
| bool same_stream_; | |||||
| uint32_t input_index_; | uint32_t input_index_; | ||||
| bool continuous_block_; | bool continuous_block_; | ||||
| bool first_continuous_block_; | bool first_continuous_block_; | ||||
| @@ -149,6 +161,7 @@ class MemoryBlock { | |||||
| bool is_zero_copy_; | bool is_zero_copy_; | ||||
| std::map<int64_t, size_t> depend_stream_life_; | std::map<int64_t, size_t> depend_stream_life_; | ||||
| int64_t memory_type_; | int64_t memory_type_; | ||||
| std::string batch_label_; | |||||
| private: | private: | ||||
| size_t block_size_; | size_t block_size_; | ||||
| std::vector<size_t> real_size_list_; | std::vector<size_t> real_size_list_; | ||||
| @@ -209,7 +222,7 @@ class BlockMemAssigner : public MemAssigner { | |||||
| void GetOutAndWorkSpaceMem(std::vector<int64_t> &all_memory_size); | void GetOutAndWorkSpaceMem(std::vector<int64_t> &all_memory_size); | ||||
| void GetNodeWorkSpaceSize(const ge::NodePtr &node, std::vector<int64_t> &workspace_memory); | |||||
| void GetNodeWorkSpaceSize(const ge::NodePtr &node, std::vector<int64_t> &workspace_memory, int64_t &total_size); | |||||
| /// | /// | ||||
| /// @ingroup GE | /// @ingroup GE | ||||
| @@ -353,7 +366,7 @@ class BlockMemAssigner : public MemAssigner { | |||||
| /// @return void | /// @return void | ||||
| /// @author | /// @author | ||||
| /// | /// | ||||
| void ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory); | |||||
| void ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory, bool same_stream = true); | |||||
| /// | /// | ||||
| /// @ingroup GE | /// @ingroup GE | ||||
| @@ -379,11 +392,11 @@ class BlockMemAssigner : public MemAssigner { | |||||
| /// | /// | ||||
| /// @ingroup GE | /// @ingroup GE | ||||
| /// @brief Merge memory blocks between different batchs | |||||
| /// @brief Resize memory blocks for each batchs | |||||
| /// @return merge or not | /// @return merge or not | ||||
| /// @author | /// @author | ||||
| /// | /// | ||||
| bool MergeDynamicBatchBlocks(); | |||||
| void ResizeDynamicBatchBlocks(); | |||||
| void AssignContinuousBlocks(); | void AssignContinuousBlocks(); | ||||
| @@ -436,6 +449,17 @@ class BlockMemAssigner : public MemAssigner { | |||||
| int64_t atomic_addr_clean_id_ = 0; | int64_t atomic_addr_clean_id_ = 0; | ||||
| size_t theory_min_memory_size_ = 0; | |||||
| size_t theory_memory_size_ = 0; | |||||
| std::string max_batch_label_; | |||||
| /// | |||||
| /// @ [stream1][nodeid] | |||||
| /// @[nodeid] [stream2][nodeid] | |||||
| /// @ [stream2][nodeid] | |||||
| /// | |||||
| DependStreamLife total_node_depend_stream_life_; | DependStreamLife total_node_depend_stream_life_; | ||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -419,7 +419,8 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, | |||||
| GE_IF_BOOL_EXEC(is_peer_output_continuous && (peer_output_size != 1), | GE_IF_BOOL_EXEC(is_peer_output_continuous && (peer_output_size != 1), | ||||
| std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + | std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + | ||||
| " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + | " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + | ||||
| " requires continuous output. There may be conflict between the two. This node is not supported now."; | |||||
| " requires continuous output. There may be conflict between the two." + | |||||
| "This node is not supported now."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | ||||
| return PARAM_INVALID;); | return PARAM_INVALID;); | ||||
| @@ -429,7 +430,8 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, | |||||
| GE_IF_BOOL_EXEC(is_peer_reference, | GE_IF_BOOL_EXEC(is_peer_reference, | ||||
| std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + | std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + | ||||
| " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + | " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + | ||||
| " requires continuous output. There may be conflict between the two. This node is not supported now."; | |||||
| " requires continuous output. There may be conflict between the two." + | |||||
| "This node is not supported now."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | ||||
| return PARAM_INVALID;); | return PARAM_INVALID;); | ||||
| @@ -1646,9 +1648,9 @@ ge::Status GraphMemoryAssigner::SetAtomicCleanAttr(const NodePtr &node, const ve | |||||
| } | } | ||||
| string atomic_mem_size_str = ss.str(); | string atomic_mem_size_str = ss.str(); | ||||
| GELOGI("[IMAS]SetAtomicCleanAttr : Set graph[%s] atomic_node[%s] output offset [%s] size[%s] streamid[%ld]", | |||||
| GELOGI("[IMAS]SetAtomicCleanAttr : Set %s atomic_node name[%s] output[0] offset to [%s] streamid[%ld] size[%s]", | |||||
| node->GetOwnerComputeGraph()->GetName().c_str(), node_op_desc->GetName().c_str(), | node->GetOwnerComputeGraph()->GetName().c_str(), node_op_desc->GetName().c_str(), | ||||
| atomic_mem_start_str.c_str(), atomic_mem_size_str.c_str(), node->GetOpDesc()->GetStreamId()); | |||||
| atomic_mem_start_str.c_str(), node->GetOpDesc()->GetStreamId(), atomic_mem_size_str.c_str()); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -282,7 +282,7 @@ Status ModelBuilder::SetInputOutputDesc() { | |||||
| void ModelBuilder::AddNodeInputProperty() { | void ModelBuilder::AddNodeInputProperty() { | ||||
| for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { | for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { | ||||
| auto node_op_desc = node->GetOpDesc(); | auto node_op_desc = node->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return ); | |||||
| GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return); | |||||
| vector<string> src_name_list; | vector<string> src_name_list; | ||||
| vector<int64_t> src_index_list; | vector<int64_t> src_index_list; | ||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
| @@ -309,10 +309,10 @@ void ModelBuilder::AddNodeInputProperty() { | |||||
| for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { | for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { | ||||
| auto node_op_desc = node->GetOpDesc(); | auto node_op_desc = node->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return ); | |||||
| GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return); | |||||
| GE_IF_BOOL_EXEC(node_op_desc->GetType() == NETOUTPUT, continue); | GE_IF_BOOL_EXEC(node_op_desc->GetType() == NETOUTPUT, continue); | ||||
| auto out_control_anchor = node->GetOutControlAnchor(); | auto out_control_anchor = node->GetOutControlAnchor(); | ||||
| GE_IF_BOOL_EXEC(out_control_anchor == nullptr, GELOGW("out_control_anchor is nullptr"); return ); | |||||
| GE_IF_BOOL_EXEC(out_control_anchor == nullptr, GELOGW("out_control_anchor is nullptr"); return); | |||||
| vector<string> dst_name_list; | vector<string> dst_name_list; | ||||
| vector<int64_t> dst_index_list; | vector<int64_t> dst_index_list; | ||||
| string dst_name_temp; | string dst_name_temp; | ||||
| @@ -330,7 +330,7 @@ void ModelBuilder::AddNodeInputProperty() { | |||||
| dst_name_temp = ""; | dst_name_temp = ""; | ||||
| int64_t dst_index = kWrongIndex; // assign an impossible value to dst_index. | int64_t dst_index = kWrongIndex; // assign an impossible value to dst_index. | ||||
| for (const auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | for (const auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | ||||
| GE_IF_BOOL_EXEC(in_data_anchor == nullptr, GELOGW("in_data_anchor is nullptr"); return ); | |||||
| GE_IF_BOOL_EXEC(in_data_anchor == nullptr, GELOGW("in_data_anchor is nullptr"); return); | |||||
| ge::NodePtr dst_node = in_data_anchor->GetOwnerNode(); | ge::NodePtr dst_node = in_data_anchor->GetOwnerNode(); | ||||
| dst_name_temp = dst_name_temp.empty() ? dst_node->GetName() : dst_name_temp + ":" + dst_node->GetName(); | dst_name_temp = dst_name_temp.empty() ? dst_node->GetName() : dst_name_temp + ":" + dst_node->GetName(); | ||||
| dst_index = in_data_anchor->GetIdx(); | dst_index = in_data_anchor->GetIdx(); | ||||
| @@ -49,7 +49,8 @@ inline bool HasContinuousStreamLabel(const ge::OpDescPtr &op_desc, std::string & | |||||
| } | } | ||||
| bool IsHcclOp(const string &op_type) { | bool IsHcclOp(const string &op_type) { | ||||
| const set<string> hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER, ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE}); | |||||
| const set<string> hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER, | |||||
| ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE}); | |||||
| return hccl_op_types.find(op_type) != hccl_op_types.end(); | return hccl_op_types.find(op_type) != hccl_op_types.end(); | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -38,7 +38,7 @@ void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Grap | |||||
| continue; | continue; | ||||
| } | } | ||||
| for (ge::NodePtr &node : subgraph->GetDirectNode()) { | for (ge::NodePtr &node : subgraph->GetDirectNode()) { | ||||
| GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return ); | |||||
| GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return); | |||||
| if ((node->GetType() == END) || (node->GetType() == PLACEHOLDER)) { | if ((node->GetType() == END) || (node->GetType() == PLACEHOLDER)) { | ||||
| node->GetOpDesc()->SetId(static_cast<int64_t>(node_size)); | node->GetOpDesc()->SetId(static_cast<int64_t>(node_size)); | ||||
| node_size++; | node_size++; | ||||
| @@ -49,8 +49,6 @@ const char *const kIsLastNode = "is_last_node"; | |||||
| const char *const kIsInputVar = "INPUT_IS_VAR"; | const char *const kIsInputVar = "INPUT_IS_VAR"; | ||||
| const char *const kIsOutputVar = "OUTPUT_IS_VAR"; | const char *const kIsOutputVar = "OUTPUT_IS_VAR"; | ||||
| const char *const kProfilingMode = "PROFILING_MODE"; | const char *const kProfilingMode = "PROFILING_MODE"; | ||||
| const char *const kProfilingFpPoint = "FP_POINT"; | |||||
| const char *const kProfilingBpPoint = "BP_POINT"; | |||||
| const uint32_t kProfilingArStep = 2; | const uint32_t kProfilingArStep = 2; | ||||
| const uint64_t kProfilingFpStartLogid = 1; | const uint64_t kProfilingFpStartLogid = 1; | ||||
| const uint64_t kProfilingBpEndLogid = 2; | const uint64_t kProfilingBpEndLogid = 2; | ||||
| @@ -810,35 +808,23 @@ Status TaskGenerator::GetFpBpIndex(const ComputeGraphPtr &graph, ProfilingPoint | |||||
| vector<uint32_t> &all_reduce_nodes, std::string &fp_point_str, | vector<uint32_t> &all_reduce_nodes, std::string &fp_point_str, | ||||
| std::string &bp_point_str) const { | std::string &bp_point_str) const { | ||||
| if (ge::GetContext().GetOption(OPTION_EXEC_PROFILING_FPPONIT_OPTIONS, fp_point_str) == SUCCESS && | |||||
| ge::GetContext().GetOption(OPTION_EXEC_PROFILING_BPPONIT_OPTIONS, bp_point_str) == SUCCESS && | |||||
| !fp_point_str.empty() && !bp_point_str.empty()) { | |||||
| return SUCCESS; | |||||
| } | |||||
| ProfilingManager::Instance().GetFpBpPoint(fp_point_str, bp_point_str); | |||||
| Status ret = SUCCESS; | Status ret = SUCCESS; | ||||
| const char *fp_point = std::getenv(kProfilingFpPoint); | |||||
| if (fp_point == nullptr) { | |||||
| if (fp_point_str.empty()) { | |||||
| ret = AutoFindFpOpIndex(graph, profiling_point); | ret = AutoFindFpOpIndex(graph, profiling_point); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGW("First forward profiling op_index not set and FindFpOpIndex failed."); | GELOGW("First forward profiling op_index not set and FindFpOpIndex failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } else { | |||||
| fp_point_str = string(fp_point); | |||||
| GELOGI("Get fp_point_str from env %s", fp_point_str.c_str()); | |||||
| } | } | ||||
| const char *bp_point = std::getenv(kProfilingBpPoint); | |||||
| if (bp_point == nullptr) { | |||||
| if (bp_point_str.empty()) { | |||||
| ret = AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes); | ret = AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGW("Last backward profiling op_index not set and FindBpOpIndex failed."); | GELOGW("Last backward profiling op_index not set and FindBpOpIndex failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } else { | |||||
| bp_point_str = string(bp_point); | |||||
| GELOGI("Get bp_point_str from env %s", bp_point_str.c_str()); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -86,7 +86,6 @@ | |||||
| | Node | | | Node | | ||||
| +------------+ | +------------+ | ||||
| *******************************************************************************/ | *******************************************************************************/ | ||||
| namespace ge { | namespace ge { | ||||
| class CaseOpLabelMaker : public LabelMaker { | class CaseOpLabelMaker : public LabelMaker { | ||||
| public: | public: | ||||
| @@ -70,7 +70,6 @@ | |||||
| | Node | | | Node | | ||||
| +------------+ | +------------+ | ||||
| *******************************************************************************/ | *******************************************************************************/ | ||||
| namespace ge { | namespace ge { | ||||
| class IfOpLabelMaker : public LabelMaker { | class IfOpLabelMaker : public LabelMaker { | ||||
| public: | public: | ||||
| @@ -54,7 +54,6 @@ | |||||
| | c | | | c | | ||||
| +---------------+ | +---------------+ | ||||
| *******************************************************************************/ | *******************************************************************************/ | ||||
| namespace ge { | namespace ge { | ||||
| class PartitionedCallLabelMaker : public LabelMaker { | class PartitionedCallLabelMaker : public LabelMaker { | ||||
| public: | public: | ||||
| @@ -70,7 +70,6 @@ | |||||
| | Node | | | Node | | ||||
| +------------+ | +------------+ | ||||
| *******************************************************************************/ | *******************************************************************************/ | ||||
| namespace ge { | namespace ge { | ||||
| class WhileOpLabelMaker : public LabelMaker { | class WhileOpLabelMaker : public LabelMaker { | ||||
| public: | public: | ||||
| @@ -283,7 +283,8 @@ Status GraphLoader::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asyn | |||||
| std::vector<GeTensorDesc> &output_desc) { | std::vector<GeTensorDesc> &output_desc) { | ||||
| auto model_manager = ModelManager::GetInstance(); | auto model_manager = ModelManager::GetInstance(); | ||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| Status ret = model_manager->ExecuteModel(model_id, stream, async_mode, input_data, input_desc, output_data, output_desc); | |||||
| Status ret = model_manager->ExecuteModel(model_id, stream, async_mode, | |||||
| input_data, input_desc, output_data, output_desc); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Execute model failed, model_id:%u.", model_id); | GELOGE(ret, "Execute model failed, model_id:%u.", model_id); | ||||
| return ret; | return ret; | ||||
| @@ -919,11 +919,11 @@ Status DataDumper::DumpExceptionInfo(const std::vector<rtExceptionInfo> exceptio | |||||
| ReplaceStringElem(op_name); | ReplaceStringElem(op_name); | ||||
| ReplaceStringElem(op_type); | ReplaceStringElem(op_type); | ||||
| string dump_file_path = | string dump_file_path = | ||||
| "./" + op_type + "." + op_name + "." + to_string(op_desc_info.task_id) + "." + to_string(now_time); | |||||
| "./" + op_type + "." + op_name + "." + std::to_string(op_desc_info.task_id) + "." + std::to_string(now_time); | |||||
| GELOGI("The exception dump file path is %s", dump_file_path.c_str()); | GELOGI("The exception dump file path is %s", dump_file_path.c_str()); | ||||
| uint64_t proto_size = dump_data.ByteSizeLong(); | uint64_t proto_size = dump_data.ByteSizeLong(); | ||||
| unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]); | |||||
| std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]); | |||||
| bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); | bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); | ||||
| if (!ret || proto_size == 0) { | if (!ret || proto_size == 0) { | ||||
| GELOGE(PARAM_INVALID, "Dump data proto serialize failed"); | GELOGE(PARAM_INVALID, "Dump data proto serialize failed"); | ||||
| @@ -16,7 +16,6 @@ | |||||
| #include "graph/load/new_model_manager/davinci_model.h" | #include "graph/load/new_model_manager/davinci_model.h" | ||||
| #include <cce/dnn.h> | |||||
| #include <graph/utils/node_utils.h> | #include <graph/utils/node_utils.h> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <map> | #include <map> | ||||
| @@ -84,7 +83,7 @@ const uint32_t kAddrLen = sizeof(void *); | |||||
| const int kDecimal = 10; | const int kDecimal = 10; | ||||
| const int kBytes = 8; | const int kBytes = 8; | ||||
| const uint32_t kDataMemAlignSizeCompare = 64; | const uint32_t kDataMemAlignSizeCompare = 64; | ||||
| const uint32_t kDumpL1FusionOpMByteSize = 2 * 1024 * 1024; | |||||
| const uint32_t kDumpL1FusionOpMByteSize = 2097152; // 2 * 1024 * 1024 | |||||
| const uint32_t kDumpFlagOfL1Fusion = 0; | const uint32_t kDumpFlagOfL1Fusion = 0; | ||||
| const char *const kDefaultBatchLable = "Batch_default"; | const char *const kDefaultBatchLable = "Batch_default"; | ||||
| const char *const kGetDynamicDimsName = "ascend_mbatch_get_dynamic_dims_node"; | const char *const kGetDynamicDimsName = "ascend_mbatch_get_dynamic_dims_node"; | ||||
| @@ -331,8 +330,8 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { | |||||
| GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc feature map memory failed. size: %zu", data_size); | GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc feature map memory failed. size: %zu", data_size); | ||||
| return GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED; | return GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED; | ||||
| } | } | ||||
| GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, | |||||
| mem_base_, data_size); | |||||
| GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", | |||||
| runtime_param_.graph_id, mem_base_, data_size); | |||||
| if (!is_inner_weight_base_) { | if (!is_inner_weight_base_) { | ||||
| weights_mem_base_ = mem_base_; | weights_mem_base_ = mem_base_; | ||||
| @@ -713,7 +712,7 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size | |||||
| // collect profiling for ge | // collect profiling for ge | ||||
| auto &profiling_manager = ProfilingManager::Instance(); | auto &profiling_manager = ProfilingManager::Instance(); | ||||
| if (profiling_manager.ProfilingModelLoadOn()) { | if (profiling_manager.ProfilingModelLoadOn()) { | ||||
| Status p_ret = ReportProfilingData(!profiling_manager.IsAclApiMode()); | |||||
| Status p_ret = ReportProfilingData(); | |||||
| if (p_ret != SUCCESS) { | if (p_ret != SUCCESS) { | ||||
| GELOGE(p_ret, "Report profiling data failed."); | GELOGE(p_ret, "Report profiling data failed."); | ||||
| return p_ret; | return p_ret; | ||||
| @@ -724,14 +723,14 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| Status DavinciModel::ReportProfilingData(bool check_device) { | |||||
| Status DavinciModel::ReportProfilingData() { | |||||
| std::vector<ComputeGraphDescInfo> compute_graph_desc_info; | std::vector<ComputeGraphDescInfo> compute_graph_desc_info; | ||||
| Status ret = GetComputeGraphInfo(compute_graph_desc_info); | Status ret = GetComputeGraphInfo(compute_graph_desc_info); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "GetComputeGraphInfo failed."); | GELOGE(ret, "GetComputeGraphInfo failed."); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| ProfilingManager::Instance().ReportProfilingData(model_id_, GetTaskDescInfo(), compute_graph_desc_info, check_device); | |||||
| ProfilingManager::Instance().ReportProfilingData(model_id_, GetTaskDescInfo(), compute_graph_desc_info); | |||||
| GE_CHK_STATUS(SinkModelProfile(), "Sink model profiler failed."); | GE_CHK_STATUS(SinkModelProfile(), "Sink model profiler failed."); | ||||
| op_list_.clear(); | op_list_.clear(); | ||||
| @@ -1544,7 +1543,8 @@ Status DavinciModel::LoadWithQueue() { | |||||
| } | } | ||||
| if (output_queue_ids_.size() != new_output_data_info_.size()) { | if (output_queue_ids_.size() != new_output_data_info_.size()) { | ||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID, "Output queue ids not match model: output_queue=%zu output_data=%zu", | |||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID, | |||||
| "Output queue ids not match model: output_queue=%zu output_data=%zu", | |||||
| output_queue_ids_.size(), new_output_data_info_.size()); | output_queue_ids_.size(), new_output_data_info_.size()); | ||||
| return ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID; | return ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID; | ||||
| } | } | ||||
| @@ -2186,8 +2186,9 @@ Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data | |||||
| const std::vector<DataBuffer> &blobs = input_data.blobs; | const std::vector<DataBuffer> &blobs = input_data.blobs; | ||||
| for (const auto &data : new_input_data_info_) { | for (const auto &data : new_input_data_info_) { | ||||
| if (data.first >= blobs.size()) { | if (data.first >= blobs.size()) { | ||||
| GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld", blobs.size(), | |||||
| new_input_data_info_.size(), data.first, data.second.GetDataInfo().at(0).first); | |||||
| GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld, op_name(%s)", blobs.size(), | |||||
| new_input_data_info_.size(), data.first, data.second.GetDataInfo().at(0).first, | |||||
| data.second.GetOpName().c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -2198,13 +2199,14 @@ Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data | |||||
| } | } | ||||
| uint64_t data_size = data.second.GetDataSize(); | uint64_t data_size = data.second.GetDataSize(); | ||||
| GE_CHK_BOOL_RET_STATUS(data_size >= data_buf.length, PARAM_INVALID, | GE_CHK_BOOL_RET_STATUS(data_size >= data_buf.length, PARAM_INVALID, | ||||
| "input data size(%lu) does not match model required size(%lu), ret failed.", data_buf.length, | |||||
| data_size); | |||||
| "input data size(%lu) does not match model required size(%lu), op_name(%s) ret failed.", | |||||
| data_buf.length, data_size, data.second.GetOpName().c_str()); | |||||
| void *mem_addr = data.second.GetBasicAddr(); | void *mem_addr = data.second.GetBasicAddr(); | ||||
| void *data_buf_addr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(data_buf.data)); | void *data_buf_addr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(data_buf.data)); | ||||
| uint64_t data_buf_length = data_buf.length; | uint64_t data_buf_length = data_buf.length; | ||||
| GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] dst[%p] src[%p] mem_size[%lu] datasize[%lu]", | |||||
| runtime_param_.graph_id, data.first, mem_addr, data_buf_addr, data_size, data_buf_length); | |||||
| GELOGI("CopyPlainData memcpy graph_%u type[F] input[%s] rank[%u] dst[%p] src[%p] mem_size[%lu] datasize[%lu]", | |||||
| runtime_param_.graph_id, data.second.GetOpName().c_str(), data.first, mem_addr, data_buf_addr, data_size, | |||||
| data_buf_length); | |||||
| GE_CHK_RT_RET(rtMemcpy(mem_addr, data_size, data_buf_addr, data_buf_length, kind)); | GE_CHK_RT_RET(rtMemcpy(mem_addr, data_size, data_buf_addr, data_buf_length, kind)); | ||||
| } | } | ||||
| @@ -2248,10 +2250,8 @@ inline int64_t SumSize(const vector<int64_t> &size_list) { | |||||
| Status DavinciModel::SinkModelProfile() { | Status DavinciModel::SinkModelProfile() { | ||||
| // profiling plugin must be registered | // profiling plugin must be registered | ||||
| Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | |||||
| GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return SUCCESS); | |||||
| Msprof::Engine::ReporterData reporter_data{}; | |||||
| auto &prof_mgr = ProfilingManager::Instance(); | |||||
| ReporterData reporter_data{}; | |||||
| // report model data tag name | // report model data tag name | ||||
| std::string tag_name; | std::string tag_name; | ||||
| tag_name.append("model_load_info_").append(std::to_string(this->Id())); | tag_name.append("model_load_info_").append(std::to_string(this->Id())); | ||||
| @@ -2269,32 +2269,32 @@ Status DavinciModel::SinkModelProfile() { | |||||
| reporter_data.deviceId = device_id_; | reporter_data.deviceId = device_id_; | ||||
| reporter_data.data = (unsigned char *)&name_len; | reporter_data.data = (unsigned char *)&name_len; | ||||
| reporter_data.dataLen = sizeof(int32_t); | reporter_data.dataLen = sizeof(int32_t); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| reporter_data.data = (unsigned char *)name.c_str(); | reporter_data.data = (unsigned char *)name.c_str(); | ||||
| reporter_data.dataLen = name.size(); | reporter_data.dataLen = name.size(); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| uint32_t model_id = this->Id(); | uint32_t model_id = this->Id(); | ||||
| reporter_data.data = (unsigned char *)&model_id; | reporter_data.data = (unsigned char *)&model_id; | ||||
| reporter_data.dataLen = sizeof(uint32_t); | reporter_data.dataLen = sizeof(uint32_t); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| // Load Start/End Time | // Load Start/End Time | ||||
| int64_t start_time = this->GetLoadBeginTime(); | int64_t start_time = this->GetLoadBeginTime(); | ||||
| reporter_data.data = (unsigned char *)&start_time; | reporter_data.data = (unsigned char *)&start_time; | ||||
| reporter_data.dataLen = sizeof(int64_t); | reporter_data.dataLen = sizeof(int64_t); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| int64_t end_time = this->GetLoadEndTime(); | int64_t end_time = this->GetLoadEndTime(); | ||||
| reporter_data.data = (unsigned char *)&end_time; | reporter_data.data = (unsigned char *)&end_time; | ||||
| reporter_data.dataLen = sizeof(int64_t); | reporter_data.dataLen = sizeof(int64_t); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| int32_t task_num = task_list_.size(); | int32_t task_num = task_list_.size(); | ||||
| std::multimap<uint32_t, uint32_t> op_id_map; | std::multimap<uint32_t, uint32_t> op_id_map; | ||||
| @@ -2308,6 +2308,7 @@ Status DavinciModel::SinkModelProfile() { | |||||
| uint32_t op_num = fusion_op_info->original_op_names.size(); | uint32_t op_num = fusion_op_info->original_op_names.size(); | ||||
| uint32_t task_id = task->GetTaskID(); | uint32_t task_id = task->GetTaskID(); | ||||
| if (op_num > 0) { | if (op_num > 0) { | ||||
| GELOGI("task.id = %u, opNum = %u", task_id, op_num); | |||||
| op_id_map.insert(std::make_pair(fusion_op_info->op_index, task_id)); | op_id_map.insert(std::make_pair(fusion_op_info->op_index, task_id)); | ||||
| } | } | ||||
| } | } | ||||
| @@ -2350,39 +2351,39 @@ Status DavinciModel::SinkModelProfile() { | |||||
| int32_t fusion_op_name_len = fusion_op_name.size() == 0 ? 1 : fusion_op_name.size(); | int32_t fusion_op_name_len = fusion_op_name.size() == 0 ? 1 : fusion_op_name.size(); | ||||
| reporter_data.data = (unsigned char *)&fusion_op_name_len; | reporter_data.data = (unsigned char *)&fusion_op_name_len; | ||||
| reporter_data.dataLen = sizeof(int32_t); | reporter_data.dataLen = sizeof(int32_t); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| reporter_data.data = (unsigned char *)fusion_op_name.c_str(); | reporter_data.data = (unsigned char *)fusion_op_name.c_str(); | ||||
| reporter_data.dataLen = fusion_op_name_len; | reporter_data.dataLen = fusion_op_name_len; | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| // original op name before fusion | // original op name before fusion | ||||
| reporter_data.data = (unsigned char *)&op_num; | reporter_data.data = (unsigned char *)&op_num; | ||||
| reporter_data.dataLen = sizeof(int32_t); | reporter_data.dataLen = sizeof(int32_t); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| for (uint32_t k = 0; k < op_num; k++) { | for (uint32_t k = 0; k < op_num; k++) { | ||||
| std::string op_name = fusion_op_info->original_op_names[k]; | std::string op_name = fusion_op_info->original_op_names[k]; | ||||
| int32_t op_name_len = op_name.size() == 0 ? 1 : op_name.size(); | int32_t op_name_len = op_name.size() == 0 ? 1 : op_name.size(); | ||||
| reporter_data.data = (unsigned char *)&op_name_len; | reporter_data.data = (unsigned char *)&op_name_len; | ||||
| reporter_data.dataLen = sizeof(int32_t); | reporter_data.dataLen = sizeof(int32_t); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| reporter_data.data = (unsigned char *)op_name.c_str(); | reporter_data.data = (unsigned char *)op_name.c_str(); | ||||
| reporter_data.dataLen = op_name_len; | reporter_data.dataLen = op_name_len; | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| } | } | ||||
| // stream id info | // stream id info | ||||
| uint32_t streamId = task->GetStreamId(); | uint32_t streamId = task->GetStreamId(); | ||||
| reporter_data.data = (unsigned char *)&streamId; | reporter_data.data = (unsigned char *)&streamId; | ||||
| reporter_data.dataLen = sizeof(int32_t); | reporter_data.dataLen = sizeof(int32_t); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| // memory info | // memory info | ||||
| struct memoryInfo memory_info; | struct memoryInfo memory_info; | ||||
| @@ -2398,22 +2399,22 @@ Status DavinciModel::SinkModelProfile() { | |||||
| memory_info.weight_size + memory_info.input_size + memory_info.output_size + memory_info.workspace_size; | memory_info.weight_size + memory_info.input_size + memory_info.output_size + memory_info.workspace_size; | ||||
| reporter_data.data = (unsigned char *)&memory_info; | reporter_data.data = (unsigned char *)&memory_info; | ||||
| reporter_data.dataLen = sizeof(struct memoryInfo); | reporter_data.dataLen = sizeof(struct memoryInfo); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| // task info | // task info | ||||
| reporter_data.data = (unsigned char *)&task_count; | reporter_data.data = (unsigned char *)&task_count; | ||||
| reporter_data.dataLen = sizeof(uint32_t); | reporter_data.dataLen = sizeof(uint32_t); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| Range task_range = op_id_map.equal_range(op_id); | Range task_range = op_id_map.equal_range(op_id); | ||||
| for (CIT idx = task_range.first; idx != task_range.second; ++idx) { | for (CIT idx = task_range.first; idx != task_range.second; ++idx) { | ||||
| uint32_t task_id = idx->second; | uint32_t task_id = idx->second; | ||||
| reporter_data.data = (unsigned char *)&task_id; | reporter_data.data = (unsigned char *)&task_id; | ||||
| reporter_data.dataLen = sizeof(uint32_t); | reporter_data.dataLen = sizeof(uint32_t); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -2422,10 +2423,8 @@ Status DavinciModel::SinkModelProfile() { | |||||
| Status DavinciModel::SinkTimeProfile(const InputData ¤t_data) { | Status DavinciModel::SinkTimeProfile(const InputData ¤t_data) { | ||||
| // profiling plugin must be registered | // profiling plugin must be registered | ||||
| Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | |||||
| GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return SUCCESS); | |||||
| Msprof::Engine::ReporterData reporter_data{}; | |||||
| auto &prof_mgr = ProfilingManager::Instance(); | |||||
| ReporterData reporter_data{}; | |||||
| // report model data tag name | // report model data tag name | ||||
| std::string tag_name; | std::string tag_name; | ||||
| tag_name.append("model_time_info_") | tag_name.append("model_time_info_") | ||||
| @@ -2448,33 +2447,33 @@ Status DavinciModel::SinkTimeProfile(const InputData ¤t_data) { | |||||
| size_t name_len = name.size(); | size_t name_len = name.size(); | ||||
| reporter_data.data = (unsigned char *)&name_len; | reporter_data.data = (unsigned char *)&name_len; | ||||
| reporter_data.dataLen = sizeof(int32_t); | reporter_data.dataLen = sizeof(int32_t); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| reporter_data.data = (unsigned char *)name.c_str(); | reporter_data.data = (unsigned char *)name.c_str(); | ||||
| reporter_data.dataLen = name.size(); | reporter_data.dataLen = name.size(); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", | |||||
| this->Id()); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| // request id | // request id | ||||
| uint64_t request_id = current_data.request_id; | uint64_t request_id = current_data.request_id; | ||||
| reporter_data.data = (unsigned char *)&request_id; | reporter_data.data = (unsigned char *)&request_id; | ||||
| reporter_data.dataLen = sizeof(uint32_t); | reporter_data.dataLen = sizeof(uint32_t); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index); | "Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index); | ||||
| // thread id | // thread id | ||||
| int32_t thread_id = GetDataInputTid(); | int32_t thread_id = GetDataInputTid(); | ||||
| reporter_data.data = (unsigned char *)&thread_id; | reporter_data.data = (unsigned char *)&thread_id; | ||||
| reporter_data.dataLen = sizeof(int32_t); | reporter_data.dataLen = sizeof(int32_t); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index); | "Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index); | ||||
| // time info | // time info | ||||
| time_info_.modelId = this->Id(); | time_info_.modelId = this->Id(); | ||||
| reporter_data.data = (unsigned char *)&time_info_; | reporter_data.data = (unsigned char *)&time_info_; | ||||
| reporter_data.dataLen = sizeof(struct timeInfo); | reporter_data.dataLen = sizeof(struct timeInfo); | ||||
| GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index); | "Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -2696,8 +2695,9 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b | |||||
| is_getnext_sink_dynamic_ = true; | is_getnext_sink_dynamic_ = true; | ||||
| cur_dynamic_dims_.clear(); | cur_dynamic_dims_.clear(); | ||||
| cur_dynamic_dims_.resize(shape_of_cur_dynamic_dims_); | cur_dynamic_dims_.resize(shape_of_cur_dynamic_dims_); | ||||
| GE_CHK_RT_RET(rtMemcpy(cur_dynamic_dims_.data(), shape_of_cur_dynamic_dims_ * sizeof(int64_t), | |||||
| netoutput_last_input_addr_, netoutput_last_input_size_, RT_MEMCPY_DEVICE_TO_HOST)); | |||||
| auto ret = rtMemcpy(cur_dynamic_dims_.data(), shape_of_cur_dynamic_dims_ * sizeof(int64_t), | |||||
| netoutput_last_input_addr_, netoutput_last_input_size_, RT_MEMCPY_DEVICE_TO_HOST); | |||||
| GE_CHK_RT_RET(ret); | |||||
| } | } | ||||
| GELOGD("Cur dynamic dims is %s.", formats::JoinToString(cur_dynamic_dims_).c_str()); | GELOGD("Cur dynamic dims is %s.", formats::JoinToString(cur_dynamic_dims_).c_str()); | ||||
| if (GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) { | if (GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) { | ||||
| @@ -2801,76 +2801,42 @@ void *DavinciModel::Run(DavinciModel *model) { | |||||
| reinterpret_cast<int64_t *>(shape_data_buffer_data) + | reinterpret_cast<int64_t *>(shape_data_buffer_data) + | ||||
| shape_data_buffer_length / sizeof(int64_t)); | shape_data_buffer_length / sizeof(int64_t)); | ||||
| GELOGD("Data: cur dynamic dims is %s", formats::JoinToString(model->cur_dynamic_dims_).c_str()); | GELOGD("Data: cur dynamic dims is %s", formats::JoinToString(model->cur_dynamic_dims_).c_str()); | ||||
| delete[] (int64_t *)current_data.blobs.back().data; | |||||
| delete[] reinterpret_cast<int64_t *>(current_data.blobs.back().data); | |||||
| current_data.blobs.pop_back(); | current_data.blobs.pop_back(); | ||||
| } | } | ||||
| GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_PRE_PROC_END)); | GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_PRE_PROC_END)); | ||||
| GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_START)); | GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_START)); | ||||
| if (ProfilingManager::Instance().ProfilingOpTraceOn()) { | |||||
| GELOGI("GetOpTraceIterNum:%d", ProfilingManager::Instance().GetOpTraceIterNum()); | |||||
| for (int32_t i = 0; i < ProfilingManager::Instance().GetOpTraceIterNum(); i++) { | |||||
| if (!ProfilingManager::Instance().ProfilingLoadFlag()) { | |||||
| vector<int32_t> prof_device_id_vec = ProfilingManager::Instance().GetProfilingDeviceId(); | |||||
| for (size_t j = 0; j < prof_device_id_vec.size(); ++j) { | |||||
| // just profiling, no need to check value | |||||
| (void)ProfilingManager::Instance().StartProfiling(i, prof_device_id_vec[j]); | |||||
| } | |||||
| } | |||||
| GELOGI("rtModelExecute start."); | |||||
| rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false; | |||||
| (void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput()); | |||||
| continue); // [No need to check value] | |||||
| GELOGI("rtModelExecute end"); | |||||
| GELOGI("rtStreamSynchronize start."); | |||||
| rt_ret = rtStreamSynchronize(model->rt_model_stream_); | |||||
| if (rt_ret == kModelAbortNormal || rt_ret == kModelAbortNormalNew) { | |||||
| GELOGI("The model with multiple datasets aborts normally."); | |||||
| } else { | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false; | |||||
| (void)model->ReturnResult(current_data.index, false, seq_end_flag, data_wrapper->GetOutput()); | |||||
| continue); // [No need to check value] | |||||
| } | |||||
| GELOGI("rtStreamSynchronize end."); | |||||
| (void)ProfilingManager::Instance().StopProfiling(); // just profiling, no need to check value | |||||
| } | |||||
| GE_TIMESTAMP_START(rtModelExecute); | |||||
| GELOGI("rtModelExecute start."); | |||||
| rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false; | |||||
| (void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput()); | |||||
| CsaInteract::GetInstance().WriteErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); | |||||
| continue); | |||||
| GELOGI("rtModelExecute end"); | |||||
| GE_IF_BOOL_EXEC(model->is_first_execute_, GE_TIMESTAMP_EVENT_END(rtModelExecute, "GraphExcute::rtModelExecute")); | |||||
| GE_TIMESTAMP_START(rtStreamSynchronize); | |||||
| GELOGI("rtStreamSynchronize start."); | |||||
| rt_ret = rtStreamSynchronize(model->rt_model_stream_); | |||||
| if (rt_ret == kEndOfSequence || rt_ret == kEndOfSequenceNew) { | |||||
| seq_end_flag = true; | |||||
| } | |||||
| if (rt_ret == kModelAbortNormal || rt_ret == kModelAbortNormalNew) { | |||||
| GELOGI("The model with multiple datasets aborts normally."); | |||||
| } else { | } else { | ||||
| GE_TIMESTAMP_START(rtModelExecute); | |||||
| GELOGI("rtModelExecute start."); | |||||
| rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false; | |||||
| (void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput()); | |||||
| CsaInteract::GetInstance().WriteErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); | |||||
| continue); | |||||
| GELOGI("rtModelExecute end"); | |||||
| GE_IF_BOOL_EXEC(model->is_first_execute_, GE_TIMESTAMP_EVENT_END(rtModelExecute, "GraphExcute::rtModelExecute")); | |||||
| GE_TIMESTAMP_START(rtStreamSynchronize); | |||||
| GELOGI("rtStreamSynchronize start."); | |||||
| rt_ret = rtStreamSynchronize(model->rt_model_stream_); | |||||
| if (rt_ret == kEndOfSequence || rt_ret == kEndOfSequenceNew) { | |||||
| seq_end_flag = true; | |||||
| } | |||||
| if (rt_ret == kModelAbortNormal || rt_ret == kModelAbortNormalNew) { | |||||
| GELOGI("The model with multiple datasets aborts normally."); | |||||
| } else { | |||||
| GE_IF_BOOL_EXEC( | |||||
| rt_ret != RT_ERROR_NONE, rslt_flg = false; GELOGI("seq_end_flg: %d", seq_end_flag); | |||||
| (void)model->ReturnResult(current_data.index, false, seq_end_flag, | |||||
| data_wrapper->GetOutput()); // [No need to check value] | |||||
| CsaInteract::GetInstance().StoreInternalErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); | |||||
| continue); | |||||
| } | |||||
| GELOGI("rtStreamSynchronize end."); | |||||
| GE_IF_BOOL_EXEC(model->is_first_execute_, | |||||
| GE_TIMESTAMP_EVENT_END(rtStreamSynchronize, "GraphExcute::Wait for rtStreamSynchronize")); | |||||
| GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_END)); | |||||
| GE_IF_BOOL_EXEC( | |||||
| rt_ret != RT_ERROR_NONE, rslt_flg = false; GELOGI("seq_end_flg: %d", seq_end_flag); | |||||
| (void)model->ReturnResult(current_data.index, false, seq_end_flag, | |||||
| data_wrapper->GetOutput()); // [No need to check value] | |||||
| CsaInteract::GetInstance().StoreInternalErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); | |||||
| continue); | |||||
| } | } | ||||
| GELOGI("rtStreamSynchronize end."); | |||||
| GE_IF_BOOL_EXEC(model->is_first_execute_, | |||||
| GE_TIMESTAMP_EVENT_END(rtStreamSynchronize, "GraphExcute::Wait for rtStreamSynchronize")); | |||||
| GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_END)); | |||||
| GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), | GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), | ||||
| model->SetProfileTime(MODEL_AFTER_PROC_START)); | model->SetProfileTime(MODEL_AFTER_PROC_START)); | ||||
| GE_TIMESTAMP_START(ReturnResult3); | GE_TIMESTAMP_START(ReturnResult3); | ||||
| @@ -3170,21 +3136,29 @@ Status DavinciModel::DistributeTask() { | |||||
| const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); | const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); | ||||
| for (size_t task_index = 0; task_index < task_list_.size(); ++task_index) { | for (size_t task_index = 0; task_index < task_list_.size(); ++task_index) { | ||||
| auto &task_def = model_task_def->task(task_index); | |||||
| auto &task = task_list_.at(task_index); | auto &task = task_list_.at(task_index); | ||||
| GE_CHK_STATUS_RET(task->Distribute(), "Task[%zu] distribute fail", task_index); | GE_CHK_STATUS_RET(task->Distribute(), "Task[%zu] distribute fail", task_index); | ||||
| // for data dump | // for data dump | ||||
| auto op_index = std::max(model_task_def->task(task_index).kernel().context().op_index(), | |||||
| model_task_def->task(task_index).kernel_ex().op_index()); | |||||
| auto op_index = std::max(task_def.kernel().context().op_index(), | |||||
| task_def.kernel_ex().op_index()); | |||||
| OpDescPtr op = GetOpByIndex(op_index); | OpDescPtr op = GetOpByIndex(op_index); | ||||
| GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
| SaveDumpOpInfo(runtime_param_, op, task->GetTaskID(), task->GetStreamId()); | |||||
| if (reinterpret_cast<void *>(task->GetDumpArgs()) != nullptr) { | if (reinterpret_cast<void *>(task->GetDumpArgs()) != nullptr) { | ||||
| bool call_dump = GetDumpProperties().IsLayerNeedDump(name_, om_name_, op->GetName()) && task->CallSaveDumpInfo(); | bool call_dump = GetDumpProperties().IsLayerNeedDump(name_, om_name_, op->GetName()) && task->CallSaveDumpInfo(); | ||||
| if (call_dump || is_op_debug_reg_) { | if (call_dump || is_op_debug_reg_) { | ||||
| SaveDumpTask(task->GetTaskID(), task->GetStreamId(), op, task->GetDumpArgs()); | SaveDumpTask(task->GetTaskID(), task->GetStreamId(), op, task->GetDumpArgs()); | ||||
| } | } | ||||
| } | } | ||||
| auto task_type = static_cast<rtModelTaskType_t>(task_def.type()); | |||||
| bool no_need_profiling = (task_type != RT_MODEL_TASK_KERNEL) | |||||
| && (task_type != RT_MODEL_TASK_KERNEL_EX) | |||||
| && (task_type != RT_MODEL_TASK_HCCL); | |||||
| GE_IF_BOOL_EXEC(no_need_profiling, continue); | |||||
| SaveDumpOpInfo(runtime_param_, op, task->GetTaskID(), task->GetStreamId()); | |||||
| // Load task info for profiling | // Load task info for profiling | ||||
| TaskDescInfo task_desc_info; | TaskDescInfo task_desc_info; | ||||
| if (!om_name_.empty()) { | if (!om_name_.empty()) { | ||||
| @@ -3193,7 +3167,7 @@ Status DavinciModel::DistributeTask() { | |||||
| task_desc_info.model_name = name_; | task_desc_info.model_name = name_; | ||||
| } | } | ||||
| task_desc_info.op_name = op->GetName(); | task_desc_info.op_name = op->GetName(); | ||||
| task_desc_info.block_dim = model_task_def->task(task_index).kernel().block_dim(); | |||||
| task_desc_info.block_dim = task_def.kernel().block_dim(); | |||||
| task_desc_info.task_id = task->GetTaskID(); | task_desc_info.task_id = task->GetTaskID(); | ||||
| task_desc_info.stream_id = task->GetStreamId(); | task_desc_info.stream_id = task->GetStreamId(); | ||||
| task_desc_info_.emplace_back(task_desc_info); | task_desc_info_.emplace_back(task_desc_info); | ||||
| @@ -3391,14 +3365,14 @@ bool DavinciModel::CheckInputAndModelSize(const int64_t &input_size, const int64 | |||||
| /// | /// | ||||
| Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic) { | Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic) { | ||||
| if (UpdateIoTaskArgs(new_input_data_info_, true, input_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) { | if (UpdateIoTaskArgs(new_input_data_info_, true, input_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "[ZCPY] Update input data to model failed."); | |||||
| return PARAM_INVALID; | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[ZCPY] Update input data to model failed."); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | } | ||||
| if (UpdateIoTaskArgs(new_output_data_info_, false, output_data.blobs, is_dynamic, input_data.batch_label) != | if (UpdateIoTaskArgs(new_output_data_info_, false, output_data.blobs, is_dynamic, input_data.batch_label) != | ||||
| SUCCESS) { | SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "[ZCPY] Update output data to model failed."); | |||||
| return PARAM_INVALID; | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[ZCPY] Update output data to model failed."); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | } | ||||
| for (ZeroCopyTask &task : zero_copy_tasks_) { | for (ZeroCopyTask &task : zero_copy_tasks_) { | ||||
| @@ -3444,7 +3418,7 @@ Status DavinciModel::UpdateIoTaskArgs(const std::map<uint32_t, ZeroCopyOffset> & | |||||
| } | } | ||||
| if (!CheckInputAndModelSize(buffer.length, data.second.GetDataSize(), is_dynamic)) { | if (!CheckInputAndModelSize(buffer.length, data.second.GetDataSize(), is_dynamic)) { | ||||
| GELOGE(FAILED, "Check input size and model size failed"); | |||||
| GELOGE(FAILED, "Check input size and model size failed, op[%s]", data.second.GetOpName().c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -3861,7 +3835,8 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa | |||||
| if (!is_async_mode_) { | if (!is_async_mode_) { | ||||
| GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_START)); | GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_START)); | ||||
| ret = CopyOutputData(input_data.index, output_data, RT_MEMCPY_DEVICE_TO_DEVICE); | ret = CopyOutputData(input_data.index, output_data, RT_MEMCPY_DEVICE_TO_DEVICE); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Copy Output data to user failed."); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ACL_ERROR_GE_INTERNAL_ERROR, | |||||
| "Copy Output data to user failed."); | |||||
| GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_END)); | GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_END)); | ||||
| } | } | ||||
| @@ -4061,7 +4036,7 @@ void DavinciModel::SetDataDumperArgs(const ComputeGraphPtr &compute_graph) { | |||||
| data_dumper_.SetDeviceId(device_id); | data_dumper_.SetDeviceId(device_id); | ||||
| // set loop count addr | // set loop count addr | ||||
| auto get_var_addr = [](const OpDescPtr &op, const RuntimeParam &runtime_param) -> void * { | |||||
| auto get_var_addr = [](const OpDescPtr &op, const RuntimeParam &runtime_param) -> void *{ | |||||
| if (op != nullptr) { | if (op != nullptr) { | ||||
| auto v_output_size = ModelUtils::GetOutputSize(op); | auto v_output_size = ModelUtils::GetOutputSize(op); | ||||
| auto v_output_addr = ModelUtils::GetOutputDataAddrs(runtime_param, op); | auto v_output_addr = ModelUtils::GetOutputDataAddrs(runtime_param, op); | ||||
| @@ -440,7 +440,7 @@ class DavinciModel { | |||||
| Status SinkTimeProfile(const InputData ¤t_data); | Status SinkTimeProfile(const InputData ¤t_data); | ||||
| Status ReportProfilingData(bool check_device = true); | |||||
| Status ReportProfilingData(); | |||||
| void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) { | void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) { | ||||
| data_dumper_.SaveDumpOpInfo(model_param, op, task_id, stream_id); | data_dumper_.SaveDumpOpInfo(model_param, op, task_id, stream_id); | ||||
| @@ -40,9 +40,7 @@ const int kCmdParSize = 2; | |||||
| const int kDumpCmdPairSize = 2; | const int kDumpCmdPairSize = 2; | ||||
| const std::size_t kProfCmdParaMaxSize = 1000; | const std::size_t kProfCmdParaMaxSize = 1000; | ||||
| const std::size_t kProfStartCmdParaSize = 2; | const std::size_t kProfStartCmdParaSize = 2; | ||||
| const std::string kCmdTypeProfile = "profile"; | |||||
| const std::string kCmdTypeDump = "dump"; | const std::string kCmdTypeDump = "dump"; | ||||
| const std::string kCmdTypeProfiling = "profiling"; | |||||
| const std::string kCmdTypeProfInit = "prof_init"; | const std::string kCmdTypeProfInit = "prof_init"; | ||||
| const std::string kCmdTypeProfFinalize = "prof_finalize"; | const std::string kCmdTypeProfFinalize = "prof_finalize"; | ||||
| const std::string kCmdTypeProfStart = "prof_start"; | const std::string kCmdTypeProfStart = "prof_start"; | ||||
| @@ -51,6 +49,9 @@ const std::string kCmdTypeProfModelSubscribe = "prof_model_subscribe"; | |||||
| const std::string kCmdTypeProfModelUnsubscribe = "prof_model_cancel_subscribe"; | const std::string kCmdTypeProfModelUnsubscribe = "prof_model_cancel_subscribe"; | ||||
| const char *const kBatchLoadBuf = "batchLoadsoFrombuf"; | const char *const kBatchLoadBuf = "batchLoadsoFrombuf"; | ||||
| const char *const kDeleteCustOp = "deleteCustOp"; | const char *const kDeleteCustOp = "deleteCustOp"; | ||||
| const int kTimeSpecNano = 1000000000; | |||||
| const int kTimeSpecMiro = 1000000; | |||||
| const int kSessionMaxBias = 100; | |||||
| struct CustAicpuSoBuf { | struct CustAicpuSoBuf { | ||||
| uint64_t kernelSoBuf; | uint64_t kernelSoBuf; | ||||
| uint32_t kernelSoBufLen; | uint32_t kernelSoBufLen; | ||||
| @@ -224,7 +225,7 @@ ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) { | |||||
| ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id) { | ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id) { | ||||
| GELOGD("destroy aicpu kernel in session_id %lu, model_id %u.", session_id, model_id); | GELOGD("destroy aicpu kernel in session_id %lu, model_id %u.", session_id, model_id); | ||||
| std::lock_guard<std::mutex> lock(sess_ids_mutex_); | |||||
| std::lock_guard<std::mutex> lock(map_mutex_); | |||||
| std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); | std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); | ||||
| if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { | if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { | ||||
| Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY, session_id, model_id); | Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY, session_id, model_id); | ||||
| @@ -237,7 +238,7 @@ ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_ | |||||
| } | } | ||||
| ge::Status ModelManager::CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint64_t kernel_id) { | ge::Status ModelManager::CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint64_t kernel_id) { | ||||
| std::lock_guard<std::mutex> lock(sess_ids_mutex_); | |||||
| std::lock_guard<std::mutex> lock(map_mutex_); | |||||
| std::vector<uint64_t> v_aicpu_kernel; | std::vector<uint64_t> v_aicpu_kernel; | ||||
| std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); | std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); | ||||
| if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { | if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { | ||||
| @@ -345,7 +346,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||||
| GELOGI("Parse model %u success.", model_id); | GELOGI("Parse model %u success.", model_id); | ||||
| davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + | |||||
| davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano + | |||||
| timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | ||||
| davinci_model->SetProfileTime(MODEL_LOAD_END); | davinci_model->SetProfileTime(MODEL_LOAD_END); | ||||
| } while (0); | } while (0); | ||||
| @@ -629,8 +630,7 @@ Status ModelManager::Stop(uint32_t model_id) { | |||||
| /// | /// | ||||
| Status ModelManager::HandleCommand(const Command &command) { | Status ModelManager::HandleCommand(const Command &command) { | ||||
| static const std::map<std::string, std::function<uint32_t(const Command &)>> cmds = { | static const std::map<std::string, std::function<uint32_t(const Command &)>> cmds = { | ||||
| {kCmdTypeProfile, HandleProfileCommand}, {kCmdTypeDump, HandleDumpCommand}, | |||||
| {kCmdTypeProfiling, HandleAclProfilingCommand}, {kCmdTypeProfInit, HandleProfInitCommand}, | |||||
| {kCmdTypeDump, HandleDumpCommand}, {kCmdTypeProfInit, HandleProfInitCommand}, | |||||
| {kCmdTypeProfFinalize, HandleProfFinalizeCommand}, {kCmdTypeProfStart, HandleProfStartCommand}, | {kCmdTypeProfFinalize, HandleProfFinalizeCommand}, {kCmdTypeProfStart, HandleProfStartCommand}, | ||||
| {kCmdTypeProfStop, HandleProfStopCommand}, | {kCmdTypeProfStop, HandleProfStopCommand}, | ||||
| {kCmdTypeProfModelSubscribe, HandleProfModelSubscribeCommand}, | {kCmdTypeProfModelSubscribe, HandleProfModelSubscribeCommand}, | ||||
| @@ -645,21 +645,6 @@ Status ModelManager::HandleCommand(const Command &command) { | |||||
| } | } | ||||
| } | } | ||||
| Status ModelManager::HandleAclProfilingCommand(const Command &command) { | |||||
| if (command.cmd_params.size() < kCmdParSize) { | |||||
| GELOGE(PARAM_INVALID, "When the cmd_type is 'profiling', the size of cmd_params must larger than 2."); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| std::string map_key = command.cmd_params[0]; | |||||
| std::string value = command.cmd_params[1]; | |||||
| if (map_key == PROFILE_CONFIG) { | |||||
| ProfilingManager::Instance().SetProfilingConfig(value); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelManager::GetModelByCmd(const Command &command, | Status ModelManager::GetModelByCmd(const Command &command, | ||||
| std::shared_ptr<DavinciModel> &davinci_model) { | std::shared_ptr<DavinciModel> &davinci_model) { | ||||
| if (command.cmd_params.size() < kCmdParSize) { | if (command.cmd_params.size() < kCmdParSize) { | ||||
| @@ -806,29 +791,6 @@ Status ModelManager::HandleProfStopCommand(const Command &command) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ModelManager::HandleProfileCommand(const Command &command) { | |||||
| if (command.cmd_params.size() < kCmdParSize) { | |||||
| GELOGE(PARAM_INVALID, "When the cmd_type is 'profile', the size of cmd_params must larger than 2."); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| std::string map_key = command.cmd_params[0]; | |||||
| std::string value = command.cmd_params[1]; | |||||
| GELOGI("Profiling mode, Command key:%s , value:%s ", map_key.c_str(), value.c_str()); | |||||
| auto iter = PROFILE_COMPONENT_MAP.find(map_key); | |||||
| if (iter != PROFILE_COMPONENT_MAP.end()) { | |||||
| std::string property_value = (value == "on") ? "1" : "0"; | |||||
| PropertiesManager::Instance().SetPropertyValue(iter->second, property_value); | |||||
| } | |||||
| if ((map_key == PROFILER_JOBCTX || map_key == PROFILER_TARGET_PATH || map_key == RTS_PROFILE_PATH)) { | |||||
| PropertiesManager::Instance().SetPropertyValue(map_key, value); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| static Status ParserPara(const Command &command, const string &dump_key, string &dump_value) { | static Status ParserPara(const Command &command, const string &dump_key, string &dump_value) { | ||||
| auto iter = std::find(command.cmd_params.begin(), command.cmd_params.end(), dump_key); | auto iter = std::find(command.cmd_params.begin(), command.cmd_params.end(), dump_key); | ||||
| if (iter != command.cmd_params.end()) { | if (iter != command.cmd_params.end()) { | ||||
| @@ -1072,12 +1034,12 @@ Status ModelManager::GenSessionId(uint64_t &session_id) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to get current time."); | GELOGE(INTERNAL_ERROR, "Failed to get current time."); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| session_id = static_cast<uint64_t>(tv.tv_sec * 1000000 + tv.tv_usec); // 1000000us | |||||
| session_id = static_cast<uint64_t>(tv.tv_sec * kTimeSpecMiro + tv.tv_usec); // 1000000us | |||||
| session_id_bias_++; | session_id_bias_++; | ||||
| // max bais 100. | // max bais 100. | ||||
| session_id_bias_ = session_id_bias_ % 100; | |||||
| session_id = session_id * 100 + session_id_bias_; | |||||
| session_id_bias_ = session_id_bias_ % kSessionMaxBias; | |||||
| session_id = session_id * kSessionMaxBias + session_id_bias_; | |||||
| GELOGD("Generate new session id: %lu.", session_id); | GELOGD("Generate new session id: %lu.", session_id); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -1086,8 +1048,7 @@ Status ModelManager::GenSessionId(uint64_t &session_id) { | |||||
| Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model, shared_ptr<ModelListener> listener, | Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model, shared_ptr<ModelListener> listener, | ||||
| void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { | void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { | ||||
| GE_CHK_BOOL_RET_STATUS(model.key.empty() || mmAccess2(model.key.c_str(), M_F_OK) == EN_OK, | GE_CHK_BOOL_RET_STATUS(model.key.empty() || mmAccess2(model.key.c_str(), M_F_OK) == EN_OK, | ||||
| ACL_ERROR_GE_PARAM_INVALID, | |||||
| "input key file path %s is invalid, %s", model.key.c_str(), strerror(errno)); | |||||
| ACL_ERROR_GE_PARAM_INVALID, "input key file path %s is invalid, %s", model.key.c_str(), strerror(errno)); | |||||
| GenModelId(&model_id); | GenModelId(&model_id); | ||||
| shared_ptr<DavinciModel> davinci_model = nullptr; | shared_ptr<DavinciModel> davinci_model = nullptr; | ||||
| @@ -1148,7 +1109,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model | |||||
| GELOGI("Parse model %u success.", model_id); | GELOGI("Parse model %u success.", model_id); | ||||
| davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + | |||||
| davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano + | |||||
| timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | ||||
| davinci_model->SetProfileTime(MODEL_LOAD_END); | davinci_model->SetProfileTime(MODEL_LOAD_END); | ||||
| @@ -1252,7 +1213,8 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy | |||||
| } | } | ||||
| std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | ||||
| GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid model id %u.", model_id); | |||||
| GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, | |||||
| "Invalid model id %u, check weather model has been loaded or not.", model_id); | |||||
| if (davinci_model->NeedDestroyAicpuKernel()) { | if (davinci_model->NeedDestroyAicpuKernel()) { | ||||
| GELOGI("Start to destroy specified aicpu kernel."); | GELOGI("Start to destroy specified aicpu kernel."); | ||||
| @@ -1289,8 +1251,8 @@ Status ModelManager::CreateAicpuSession(uint64_t session_id) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name) { | |||||
| GELOGI("LoadCustAicpuSo in, op name %s, so name %s", op_desc->GetName().c_str(), so_name.c_str()); | |||||
| Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name, bool &loaded) { | |||||
| GELOGD("LoadCustAicpuSo in, op name %s, so name %s", op_desc->GetName().c_str(), so_name.c_str()); | |||||
| std::lock_guard<std::mutex> lock(cust_aicpu_mutex_); | std::lock_guard<std::mutex> lock(cust_aicpu_mutex_); | ||||
| CustAICPUKernelPtr aicpu_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_CUSTAICPU_KERNEL, CustAICPUKernelPtr()); | CustAICPUKernelPtr aicpu_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_CUSTAICPU_KERNEL, CustAICPUKernelPtr()); | ||||
| if (aicpu_kernel == nullptr) { | if (aicpu_kernel == nullptr) { | ||||
| @@ -1313,18 +1275,24 @@ Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_ | |||||
| std::map<string, CustAICPUKernelPtr> new_so_name; | std::map<string, CustAICPUKernelPtr> new_so_name; | ||||
| new_so_name.insert({so_name, aicpu_kernel}); | new_so_name.insert({so_name, aicpu_kernel}); | ||||
| cust_aicpu_so_[resource_id] = new_so_name; | cust_aicpu_so_[resource_id] = new_so_name; | ||||
| GELOGI("LoadCustAicpuSo new aicpu so resource id %lu", resource_id); | |||||
| loaded = false; | |||||
| GELOGD("LoadCustAicpuSo new aicpu so name %s, resource id %lu", so_name.c_str(), resource_id); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| auto it_so_name = it->second.find(so_name); | auto it_so_name = it->second.find(so_name); | ||||
| if (it_so_name == it->second.end()) { | if (it_so_name == it->second.end()) { | ||||
| it->second.insert({so_name, aicpu_kernel}); | it->second.insert({so_name, aicpu_kernel}); | ||||
| GELOGI("LoadCustAicpuSo add aicpu so resource id %lu", resource_id); | |||||
| loaded = false; | |||||
| GELOGD("LoadCustAicpuSo add aicpu so name %s, resource id %lu", so_name.c_str(), resource_id); | |||||
| return SUCCESS; | |||||
| } | } | ||||
| loaded = true; | |||||
| GELOGD("LoadCustAicpuSo so name %s has been loaded.", so_name.c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { | Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { | ||||
| GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str()); | |||||
| std::lock_guard<std::mutex> lock(cust_aicpu_mutex_); | std::lock_guard<std::mutex> lock(cust_aicpu_mutex_); | ||||
| if (cust_aicpu_so_.size() == 0) return SUCCESS; | if (cust_aicpu_so_.size() == 0) return SUCCESS; | ||||
| // get current context | // get current context | ||||
| @@ -169,8 +169,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
| /// @brief comment handle function | /// @brief comment handle function | ||||
| /// | /// | ||||
| ge::Status HandleCommand(const Command &command); | ge::Status HandleCommand(const Command &command); | ||||
| static ge::Status HandleAclProfilingCommand(const Command &command); | |||||
| static ge::Status HandleProfileCommand(const Command &command); | |||||
| static ge::Status HandleDumpCommand(const Command &command); | static ge::Status HandleDumpCommand(const Command &command); | ||||
| static ge::Status HandleProfModelSubscribeCommand(const Command &command); | static ge::Status HandleProfModelSubscribeCommand(const Command &command); | ||||
| static ge::Status HandleProfModelUnsubscribeCommand(const Command &command); | static ge::Status HandleProfModelUnsubscribeCommand(const Command &command); | ||||
| @@ -289,7 +287,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
| ge::Status DestroyAicpuSessionForInfer(uint32_t model_id); | ge::Status DestroyAicpuSessionForInfer(uint32_t model_id); | ||||
| ge::Status LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name); | |||||
| ge::Status LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name, bool &loaded); | |||||
| ge::Status LaunchCustAicpuSo(); | ge::Status LaunchCustAicpuSo(); | ||||
| @@ -61,7 +61,7 @@ vector<int64_t> ModelUtils::GetInputSize(ConstOpDescPtr op_desc) { | |||||
| GELOGI("Get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); | GELOGI("Get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); | ||||
| continue); | continue); | ||||
| GELOGI("[IMAS]GetInputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); | |||||
| GELOGI("GetInputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); | |||||
| v_input_size.push_back(tensor_size); | v_input_size.push_back(tensor_size); | ||||
| } | } | ||||
| @@ -96,7 +96,7 @@ vector<int64_t> ModelUtils::GetOutputSize(ConstOpDescPtr op_desc) { | |||||
| GELOGI("Get size from TensorDesc failed, op : %s, output index : %zu", op_desc->GetName().c_str(), i); | GELOGI("Get size from TensorDesc failed, op : %s, output index : %zu", op_desc->GetName().c_str(), i); | ||||
| continue); | continue); | ||||
| GELOGI("[IMAS]GetOutputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); | |||||
| GELOGI("GetOutputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); | |||||
| v_output_size.push_back(tensor_size); | v_output_size.push_back(tensor_size); | ||||
| } | } | ||||
| @@ -279,9 +279,10 @@ Status HcclTaskInfo::SetAddrs(const std::shared_ptr<OpDesc> &op_desc, | |||||
| output_data_addr = output_data_addrs_.empty() ? nullptr : output_data_addrs_[i]; | output_data_addr = output_data_addrs_.empty() ? nullptr : output_data_addrs_[i]; | ||||
| } | } | ||||
| kernel_hccl_infos[i].inputDataAddr = input_data_addr; | kernel_hccl_infos[i].inputDataAddr = input_data_addr; | ||||
| if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER || hccl_type == HCOMREDUCE) { | |||||
| if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER) { | |||||
| kernel_hccl_infos[i].outputDataAddr = output_data_addr; | kernel_hccl_infos[i].outputDataAddr = output_data_addr; | ||||
| } else if (hccl_type == HCOMALLREDUCE || hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE) { | |||||
| } else if (hccl_type == HCOMALLREDUCE || | |||||
| hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE || hccl_type == HCOMREDUCE) { | |||||
| GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), | GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), | ||||
| "davinci_model: GetHcomOperationType fail!"); | "davinci_model: GetHcomOperationType fail!"); | ||||
| kernel_hccl_infos[i].outputDataAddr = output_data_addr; | kernel_hccl_infos[i].outputDataAddr = output_data_addr; | ||||
| @@ -43,6 +43,13 @@ const char *kIsLastNode = "is_last_node"; | |||||
| const char *kIsFirstNode = "is_first_node"; | const char *kIsFirstNode = "is_first_node"; | ||||
| const int64_t kCloseSkt = 100; | const int64_t kCloseSkt = 100; | ||||
| const uint32_t kAddrLen = sizeof(void *); | const uint32_t kAddrLen = sizeof(void *); | ||||
| const int kBaseInt = 10; | |||||
| const int kStrtolFail = 0; | |||||
| const int kArgsInputDesc = 0; | |||||
| const int kArgsInputAddr = 1; | |||||
| const int kArgsOutputDesc = 2; | |||||
| const int kArgsOutputAddr = 3; | |||||
| const int kArgsAttrHandle = 4; | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| @@ -66,7 +73,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci | |||||
| // get opcontext stored in model | // get opcontext stored in model | ||||
| const domi::KernelContext &context = kernel_def.context(); | const domi::KernelContext &context = kernel_def.context(); | ||||
| // get kernel_type | // get kernel_type | ||||
| kernel_type_ = static_cast<cce::ccKernelType>(context.kernel_type()); | |||||
| kernel_type_ = static_cast<ccKernelType>(context.kernel_type()); | |||||
| // get opdesc | // get opdesc | ||||
| op_desc_ = davinci_model_->GetOpByIndex(context.op_index()); | op_desc_ = davinci_model_->GetOpByIndex(context.op_index()); | ||||
| GE_CHECK_NOTNULL(op_desc_); | GE_CHECK_NOTNULL(op_desc_); | ||||
| @@ -88,13 +95,13 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci | |||||
| // get bin_file_key | // get bin_file_key | ||||
| const char *bin_file_key = davinci_model_->GetRegisterStub(op_desc_->GetName(), session_graph_model_id); | const char *bin_file_key = davinci_model_->GetRegisterStub(op_desc_->GetName(), session_graph_model_id); | ||||
| // new aicpu kernel(rtCpuKernelLaunch) no need to check function | // new aicpu kernel(rtCpuKernelLaunch) no need to check function | ||||
| if (kernel_type_ == cce::ccKernelType::CCE_AI_CORE) { | |||||
| if (kernel_type_ == ccKernelType::CCE_AI_CORE) { | |||||
| rtError_t rt_ret; | rtError_t rt_ret; | ||||
| rt_ret = rtGetFunctionByName(const_cast<char *>(kernel_def.stub_func().c_str()), &stub_func_); | rt_ret = rtGetFunctionByName(const_cast<char *>(kernel_def.stub_func().c_str()), &stub_func_); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. stub_func: %s", | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. stub_func: %s", | ||||
| kernel_def.stub_func().c_str()); | kernel_def.stub_func().c_str()); | ||||
| return RT_ERROR_TO_GE_STATUS(rt_ret);); | return RT_ERROR_TO_GE_STATUS(rt_ret);); | ||||
| } else if (kernel_type_ == cce::ccKernelType::TE) { | |||||
| } else if (kernel_type_ == ccKernelType::TE) { | |||||
| rtError_t rt_ret; | rtError_t rt_ret; | ||||
| rt_ret = rtGetFunctionByName(bin_file_key, &stub_func_); | rt_ret = rtGetFunctionByName(bin_file_key, &stub_func_); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, | ||||
| @@ -111,7 +118,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci | |||||
| ctx_.opIndex2[i] = context.origin_op_index(i); | ctx_.opIndex2[i] = context.origin_op_index(i); | ||||
| } | } | ||||
| ctx_.opCount = context.origin_op_index_size(); | ctx_.opCount = context.origin_op_index_size(); | ||||
| if (kernel_type_ == cce::ccKernelType::TE) { | |||||
| if (kernel_type_ == ccKernelType::TE) { | |||||
| ctx_.opIndex = context.op_index(); | ctx_.opIndex = context.op_index(); | ||||
| uint16_t *args_offset_tmp = reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data())); | uint16_t *args_offset_tmp = reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data())); | ||||
| if (context.args_offset().size() / sizeof(uint16_t) < 1) { | if (context.args_offset().size() / sizeof(uint16_t) < 1) { | ||||
| @@ -120,9 +127,9 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci | |||||
| } | } | ||||
| ret = InitTVMTask(args_offset_tmp[0], kernel_def); | ret = InitTVMTask(args_offset_tmp[0], kernel_def); | ||||
| } else if (kernel_type_ == cce::ccKernelType::CUSTOMIZED) { | |||||
| } else if (kernel_type_ == ccKernelType::CUSTOMIZED) { | |||||
| ret = InitAICPUCustomTask(context.op_index(), kernel_def); | ret = InitAICPUCustomTask(context.op_index(), kernel_def); | ||||
| } else if (kernel_type_ == cce::ccKernelType::AI_CPU || kernel_type_ == cce::ccKernelType::CUST_AI_CPU) { | |||||
| } else if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) { | |||||
| ret = InitAicpuTask(context.op_index(), kernel_def); | ret = InitAicpuTask(context.op_index(), kernel_def); | ||||
| } else { | } else { | ||||
| if (kernel_def.args().empty() || args_size_ == 0) { | if (kernel_def.args().empty() || args_size_ == 0) { | ||||
| @@ -371,9 +378,9 @@ Status KernelTaskInfo::Distribute() { | |||||
| rtError_t rt_ret = RT_ERROR_NONE; | rtError_t rt_ret = RT_ERROR_NONE; | ||||
| char skt_enable_env[MMPA_MAX_PATH] = { 0x00 }; | char skt_enable_env[MMPA_MAX_PATH] = { 0x00 }; | ||||
| INT32 res = mmGetEnv("SKT_ENABLE", skt_enable_env, MMPA_MAX_PATH); | INT32 res = mmGetEnv("SKT_ENABLE", skt_enable_env, MMPA_MAX_PATH); | ||||
| int64_t env_flag = (res == EN_OK) ? strtol(skt_enable_env, nullptr, 10) : 0; | |||||
| int64_t env_flag = (res == EN_OK) ? strtol(skt_enable_env, nullptr, kBaseInt) : kStrtolFail; | |||||
| bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); | bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); | ||||
| if (kernel_type_ == cce::ccKernelType::AI_CPU || kernel_type_ == cce::ccKernelType::CUST_AI_CPU) { | |||||
| if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) { | |||||
| GELOGI("distribute task info kernel_type %d, flag %d", kernel_type_, dump_flag_); | GELOGI("distribute task info kernel_type %d, flag %d", kernel_type_, dump_flag_); | ||||
| // blockDim is reserved parameter, set to 1 | // blockDim is reserved parameter, set to 1 | ||||
| rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name_.c_str()), | rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name_.c_str()), | ||||
| @@ -749,15 +756,15 @@ Status KernelTaskInfo::InitAICPUCustomTask(uint32_t op_index, const domi::Kernel | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } | } | ||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[0])) = | |||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsInputDesc])) = | |||||
| static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_descs)); // arg 0 | static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_descs)); // arg 0 | ||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[1])) = | |||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsInputAddr])) = | |||||
| static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_addrs)); // arg 1 | static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_addrs)); // arg 1 | ||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[2])) = | |||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsOutputDesc])) = | |||||
| static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_descs)); // arg 2 | static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_descs)); // arg 2 | ||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[3])) = | |||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsOutputAddr])) = | |||||
| static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_addrs)); // arg 3 | static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_addrs)); // arg 3 | ||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[4])) = | |||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsAttrHandle])) = | |||||
| static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.attr_handle)); // arg 4 | static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.attr_handle)); // arg 4 | ||||
| rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); | rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); | ||||
| @@ -874,8 +881,10 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| if (kernel_type_ == cce::ccKernelType::CUST_AI_CPU) { | |||||
| GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc, so_name_), "launch cust aicpu so failed"); | |||||
| if (kernel_type_ == ccKernelType::CUST_AI_CPU) { | |||||
| bool loaded = false; | |||||
| GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc, so_name_, loaded), | |||||
| "launch cust aicpu so failed"); | |||||
| } | } | ||||
| // copy args to new host memory | // copy args to new host memory | ||||
| @@ -946,7 +955,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
| GELOGI("Op debug is open in aicpu task info"); | GELOGI("Op debug is open in aicpu task info"); | ||||
| dump_args_ = static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead); | dump_args_ = static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead); | ||||
| } | } | ||||
| if (kernel_type_ == cce::ccKernelType::CUST_AI_CPU) { | |||||
| if (kernel_type_ == ccKernelType::CUST_AI_CPU) { | |||||
| dump_flag_ |= RT_KERNEL_CUSTOM_AICPU; | dump_flag_ |= RT_KERNEL_CUSTOM_AICPU; | ||||
| } | } | ||||
| @@ -1076,7 +1085,7 @@ Status KernelTaskInfo::StoreInputOutputTensor(const std::vector<void *> &input_d | |||||
| Status KernelTaskInfo::SetContext(const domi::KernelDef &kernel_def) { | Status KernelTaskInfo::SetContext(const domi::KernelDef &kernel_def) { | ||||
| const domi::KernelContext &context = kernel_def.context(); | const domi::KernelContext &context = kernel_def.context(); | ||||
| ctx_.kernelType = static_cast<cce::ccKernelType>(context.kernel_type()); | |||||
| ctx_.kernelType = static_cast<ccKernelType>(context.kernel_type()); | |||||
| ctx_.opId = context.op_id(); | ctx_.opId = context.op_id(); | ||||
| ctx_.kernelFuncId = context.kernel_func_id(); | ctx_.kernelFuncId = context.kernel_func_id(); | ||||
| ctx_.isFlowtable = context.is_flowtable(); | ctx_.isFlowtable = context.is_flowtable(); | ||||
| @@ -1161,10 +1170,10 @@ Status KernelTaskInfo::CceUpdateKernelArgs(const domi::KernelContext &context, u | |||||
| GELOGE(GE_PLGMGR_SO_NOT_EXIST, "Failed in dlopen %s! ", error); | GELOGE(GE_PLGMGR_SO_NOT_EXIST, "Failed in dlopen %s! ", error); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| cce::ccStatus_t cc_ret; | |||||
| ccStatus_t cc_ret; | |||||
| std::string update_kernel_args = "ccUpdateKernelArgs"; | std::string update_kernel_args = "ccUpdateKernelArgs"; | ||||
| auto cceUpdateKernelArgs = (cce::ccStatus_t(*)(cce::ccOpContext &, uint64_t, uint64_t, uint64_t, void *, uint64_t, | |||||
| void *))mmDlsym(handle, const_cast<char *>(update_kernel_args.c_str())); | |||||
| auto cceUpdateKernelArgs = (ccStatus_t(*)(ccOpContext &, uint64_t, uint64_t, | |||||
| uint64_t, void *, uint64_t, void *))mmDlsym(handle, const_cast<char *>(update_kernel_args.c_str())); | |||||
| if (cceUpdateKernelArgs == nullptr) { | if (cceUpdateKernelArgs == nullptr) { | ||||
| GELOGE(FAILED, "Failed to invoke function ccUpdateKernelArgs"); | GELOGE(FAILED, "Failed to invoke function ccUpdateKernelArgs"); | ||||
| if (mmDlclose(handle) != 0) { | if (mmDlclose(handle) != 0) { | ||||
| @@ -1189,7 +1198,7 @@ Status KernelTaskInfo::CceUpdateKernelArgs(const domi::KernelContext &context, u | |||||
| GELOGW("Failed to close handle %s", error); | GELOGW("Failed to close handle %s", error); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (cc_ret != cce::CC_STATUS_SUCCESS) { | |||||
| if (cc_ret != CC_STATUS_SUCCESS) { | |||||
| GELOGE(CCE_FAILED, "Call cce api failed, ret: 0x%X", cc_ret); | GELOGE(CCE_FAILED, "Call cce api failed, ret: 0x%X", cc_ret); | ||||
| return CCE_FAILED; | return CCE_FAILED; | ||||
| } | } | ||||
| @@ -43,7 +43,7 @@ class KernelTaskInfo : public TaskInfo { | |||||
| stream_id_(0), | stream_id_(0), | ||||
| so_name_(""), | so_name_(""), | ||||
| kernel_name_(""), | kernel_name_(""), | ||||
| kernel_type_(cce::ccKernelType::CCE_AI_CORE), | |||||
| kernel_type_(ccKernelType::CCE_AI_CORE), | |||||
| dump_flag_(RT_KERNEL_DEFAULT), | dump_flag_(RT_KERNEL_DEFAULT), | ||||
| dump_args_(nullptr), | dump_args_(nullptr), | ||||
| op_desc_(nullptr), | op_desc_(nullptr), | ||||
| @@ -75,7 +75,7 @@ class KernelTaskInfo : public TaskInfo { | |||||
| Status Release() override; | Status Release() override; | ||||
| cce::ccOpContext *GetCtx() override { return &ctx_; } | |||||
| ccOpContext *GetCtx() override { return &ctx_; } | |||||
| FusionOpInfo *GetFusionOpInfo() override { return &fusion_op_info_; } | FusionOpInfo *GetFusionOpInfo() override { return &fusion_op_info_; } | ||||
| @@ -92,7 +92,7 @@ class KernelTaskInfo : public TaskInfo { | |||||
| bool CallSaveDumpInfo() override { return call_save_dump_; }; | bool CallSaveDumpInfo() override { return call_save_dump_; }; | ||||
| cce::ccOpContext ctx_; | |||||
| ccOpContext ctx_; | |||||
| FusionOpInfo fusion_op_info_; | FusionOpInfo fusion_op_info_; | ||||
| private: | private: | ||||
| @@ -153,7 +153,7 @@ class KernelTaskInfo : public TaskInfo { | |||||
| uint32_t stream_id_; | uint32_t stream_id_; | ||||
| std::string so_name_; | std::string so_name_; | ||||
| std::string kernel_name_; | std::string kernel_name_; | ||||
| cce::ccKernelType kernel_type_; | |||||
| ccKernelType kernel_type_; | |||||
| uint32_t dump_flag_; | uint32_t dump_flag_; | ||||
| void *dump_args_; | void *dump_args_; | ||||
| OpDescPtr op_desc_; | OpDescPtr op_desc_; | ||||
| @@ -41,7 +41,7 @@ class StreamSwitchTaskInfo : public TaskInfo { | |||||
| Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | ||||
| private: | private: | ||||
| void SetInputAndValuePtr(DavinciModel *davinci_model, const vector<void *> &input_data_addrs); | |||||
| void SetInputAndValuePtr(DavinciModel *davinci_model, const std::vector<void *> &input_data_addrs); | |||||
| void *input_ptr_; | void *input_ptr_; | ||||
| rtCondition_t cond_; | rtCondition_t cond_; | ||||
| void *value_ptr_; | void *value_ptr_; | ||||
| @@ -49,7 +49,7 @@ class StreamSwitchTaskInfo : public TaskInfo { | |||||
| uint32_t true_stream_id_; | uint32_t true_stream_id_; | ||||
| rtSwitchDataType_t data_type_; | rtSwitchDataType_t data_type_; | ||||
| static const uint32_t kInputNum = 2; | static const uint32_t kInputNum = 2; | ||||
| vector<int64_t> fixed_addr_offset_; | |||||
| std::vector<int64_t> fixed_addr_offset_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_STREAM_SWITCH_TASK_INFO_H_ | #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_STREAM_SWITCH_TASK_INFO_H_ | ||||
| @@ -25,10 +25,11 @@ Status SuperKernel::Launch(rtStream_t stream, uint32_t dump_flag) { | |||||
| const void *args[] = {this->GetNavTablePtr(), | const void *args[] = {this->GetNavTablePtr(), | ||||
| reinterpret_cast<const void *>(static_cast<uintptr_t>(this->GetNavTableSize()))}; | reinterpret_cast<const void *>(static_cast<uintptr_t>(this->GetNavTableSize()))}; | ||||
| rtError_t rt_ret = rtMalloc((void **)&(device_args_addr_), sizeof(args), RT_MEMORY_HBM); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failied. error: 0x%X", rt_ret); return | |||||
| RT_ERROR_TO_GE_STATUS(rt_ret);) | |||||
| rt_ret = rtMemcpy((void *)device_args_addr_, sizeof(args), (void *)args, sizeof(args), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| rtError_t rt_ret = rtMalloc(reinterpret_cast<void **>(&device_args_addr_), sizeof(args), RT_MEMORY_HBM); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failied. error: 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret);) | |||||
| rt_ret = rtMemcpy(reinterpret_cast<void *>(device_args_addr_), sizeof(args), reinterpret_cast<void *>(args), | |||||
| sizeof(args), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy failied. error: 0x%X", rt_ret); | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy failied. error: 0x%X", rt_ret); | ||||
| return RT_ERROR_TO_GE_STATUS(rt_ret);) | return RT_ERROR_TO_GE_STATUS(rt_ret);) | ||||
| rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr_, sizeof(args), NULL, stream, | rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr_, sizeof(args), NULL, stream, | ||||
| @@ -19,6 +19,8 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace skt { | namespace skt { | ||||
| const size_t kFusedKernelMinimumSize = 2; | |||||
| const size_t kFusedKernelSizeUnit = 2; | |||||
| SuperKernelFactory &SuperKernelFactory::GetInstance() { | SuperKernelFactory &SuperKernelFactory::GetInstance() { | ||||
| static SuperKernelFactory factory; | static SuperKernelFactory factory; | ||||
| return factory; | return factory; | ||||
| @@ -79,17 +81,17 @@ Status SuperKernelFactory::FuseKernels(const std::vector<void *> &stub_func_list | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (super_kernel_size < 2) { | |||||
| if (super_kernel_size < kFusedKernelMinimumSize) { | |||||
| GELOGW( | GELOGW( | ||||
| "SKT: the number of kernels being fused must be greater than or " | "SKT: the number of kernels being fused must be greater than or " | ||||
| "equal to 2"); | "equal to 2"); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| GELOGI("SKT: superkernel start fuse, superkernel size %zu.", stub_func_list.size()); | GELOGI("SKT: superkernel start fuse, superkernel size %zu.", stub_func_list.size()); | ||||
| const size_t nav_table_len = 2 * stub_func_list.size(); | |||||
| const size_t nav_table_len = kFusedKernelSizeUnit * stub_func_list.size(); | |||||
| std::unique_ptr<uint64_t[]> nav_table(new(std::nothrow) uint64_t[nav_table_len]); | std::unique_ptr<uint64_t[]> nav_table(new(std::nothrow) uint64_t[nav_table_len]); | ||||
| GE_CHECK_NOTNULL(nav_table); | GE_CHECK_NOTNULL(nav_table); | ||||
| uint64_t nav_table_size = 2 * stub_func_list.size() * sizeof(int64_t); | |||||
| uint64_t nav_table_size = kFusedKernelSizeUnit * stub_func_list.size() * sizeof(int64_t); | |||||
| rtError_t rt_ret; | rtError_t rt_ret; | ||||
| void *hbm_nav_table_addr = nullptr; | void *hbm_nav_table_addr = nullptr; | ||||
| @@ -101,21 +103,21 @@ Status SuperKernelFactory::FuseKernels(const std::vector<void *> &stub_func_list | |||||
| GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); | GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); | ||||
| // store two uint64_t address | // store two uint64_t address | ||||
| // address divided by 4 because of 32bits encoding, call offset will *4 when calculating | // address divided by 4 because of 32bits encoding, call offset will *4 when calculating | ||||
| nav_table[i * 2] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(sub_device_func)) / 4; | |||||
| GELOGD("SKT: CALL offet %lu", nav_table[i * 2]); | |||||
| nav_table[i * 2 + 1] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(args_addr_list[i])); | |||||
| GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]); | |||||
| nav_table[i * kFusedKernelSizeUnit] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(sub_device_func)) / 4; | |||||
| GELOGD("SKT: CALL offet %lu", nav_table[i * kFusedKernelSizeUnit]); | |||||
| nav_table[i * kFusedKernelSizeUnit + 1] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(args_addr_list[i])); | |||||
| GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * kFusedKernelSizeUnit + 1]); | |||||
| } | } | ||||
| rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); | |||||
| rt_ret = rtMalloc(reinterpret_cast<void **>(&hbm_nav_table_addr), nav_table_size, RT_MEMORY_HBM); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failed. error: 0x%X", rt_ret); | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failed. error: 0x%X", rt_ret); | ||||
| return RT_ERROR_TO_GE_STATUS(rt_ret);) | return RT_ERROR_TO_GE_STATUS(rt_ret);) | ||||
| rt_ret = | |||||
| rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table.get(), nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); | |||||
| rt_ret = rtMemcpy(reinterpret_cast<void *>(hbm_nav_table_addr), nav_table_size, | |||||
| reinterpret_cast<void *>(nav_table.get()), nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy failed. error: 0x%X", rt_ret); | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy failed. error: 0x%X", rt_ret); | ||||
| GE_CHK_RT(rtFree(hbm_nav_table_addr)); return RT_ERROR_TO_GE_STATUS(rt_ret);) | GE_CHK_RT(rtFree(hbm_nav_table_addr)); return RT_ERROR_TO_GE_STATUS(rt_ret);) | ||||
| // Create the necessary metadata for the super kernel | // Create the necessary metadata for the super kernel | ||||
| h = std::unique_ptr<skt::SuperKernel>( | |||||
| new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim)); | |||||
| h = | |||||
| std::unique_ptr<skt::SuperKernel>(new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim)); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace skt | } // namespace skt | ||||
| @@ -20,7 +20,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "cce/customize.h" | #include "cce/customize.h" | ||||
| #include "cce/taskdown_common.hpp" | |||||
| #include "framework/common/taskdown_common.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "graph/load/new_model_manager/ts_mem_mall.h" | #include "graph/load/new_model_manager/ts_mem_mall.h" | ||||
| #include "graph/load/new_model_manager/task_info/task_info_factory.h" | #include "graph/load/new_model_manager/task_info/task_info_factory.h" | ||||
| @@ -63,8 +63,8 @@ struct RuntimeParam { | |||||
| }; | }; | ||||
| typedef struct FusionOpInfo { | typedef struct FusionOpInfo { | ||||
| vector<string> original_op_names; | |||||
| string op_name; | |||||
| std::vector<std::string> original_op_names; | |||||
| std::string op_name; | |||||
| uint32_t op_index; | uint32_t op_index; | ||||
| uint32_t stream_id; | uint32_t stream_id; | ||||
| } FusionOpInfo; | } FusionOpInfo; | ||||
| @@ -87,7 +87,7 @@ class TaskInfo { | |||||
| virtual Status Release() { return SUCCESS; } | virtual Status Release() { return SUCCESS; } | ||||
| virtual cce::ccOpContext *GetCtx() { return nullptr; } | |||||
| virtual ccOpContext *GetCtx() { return nullptr; } | |||||
| virtual uint32_t GetTaskID() { return 0xFFFFFFFF; } | virtual uint32_t GetTaskID() { return 0xFFFFFFFF; } | ||||
| @@ -25,7 +25,7 @@ | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| namespace { | namespace { | ||||
| constexpr uint32_t kMaxTsMemBlock = 2 * 1024 * 1024; // Max block 2M | |||||
| constexpr uint32_t kMaxTsMemBlock = 2097152; // Max block 2M 2 * 1024 * 1024 | |||||
| constexpr uint32_t kTsMemAligment = 64; // Malloc for 64 bits align | constexpr uint32_t kTsMemAligment = 64; // Malloc for 64 bits align | ||||
| constexpr uint32_t kTsMemAlignMask = kTsMemAligment - 1; | constexpr uint32_t kTsMemAlignMask = kTsMemAligment - 1; | ||||
| } | } | ||||
| @@ -35,6 +35,7 @@ Status ZeroCopyOffset::InitInputDataInfo(int64_t output_size, void *virtual_addr | |||||
| GELOGI("[ZCPY] Start to InitInputDataInfo of %s, total_data_size is %ld, virtual_addr is %p", | GELOGI("[ZCPY] Start to InitInputDataInfo of %s, total_data_size is %ld, virtual_addr is %p", | ||||
| op_desc->GetName().c_str(), output_size, virtual_addr); | op_desc->GetName().c_str(), output_size, virtual_addr); | ||||
| basic_addr_ = virtual_addr; | basic_addr_ = virtual_addr; | ||||
| op_name_ = op_desc->GetName(); | |||||
| (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); | (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); | ||||
| (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); | (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); | ||||
| GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID, | GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID, | ||||
| @@ -82,6 +83,7 @@ Status ZeroCopyOffset::InitOutputDataInfo(const vector<int64_t> &input_size_list | |||||
| GELOGD("Tensor data size: GetSize=%ld, GetTensorSizeInBytes=%ld", input_size_list[idx], size); | GELOGD("Tensor data size: GetSize=%ld, GetTensorSizeInBytes=%ld", input_size_list[idx], size); | ||||
| basic_addr_ = virtual_addr_list[idx]; | basic_addr_ = virtual_addr_list[idx]; | ||||
| op_name_ = op_desc->GetName(); | |||||
| (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); | (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); | ||||
| (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); | (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); | ||||
| GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID, | GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID, | ||||
| @@ -66,9 +66,12 @@ class ZeroCopyOffset { | |||||
| int64_t GetDataSize() const { return data_size_; } | int64_t GetDataSize() const { return data_size_; } | ||||
| // value of *outside_addrs_ from davinci_model | // value of *outside_addrs_ from davinci_model | ||||
| std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; } | std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; } | ||||
| // name of op | |||||
| std::string GetOpName() const { return op_name_; } | |||||
| private: | private: | ||||
| void *basic_addr_ = nullptr; | void *basic_addr_ = nullptr; | ||||
| std::string op_name_; | |||||
| uint32_t data_count_ = 0; | uint32_t data_count_ = 0; | ||||
| std::vector<std::pair<int64_t, void *>> data_info_; | std::vector<std::pair<int64_t, void *>> data_info_; | ||||
| vector<int64_t> relative_offset_; | vector<int64_t> relative_offset_; | ||||
| @@ -80,4 +83,4 @@ class ZeroCopyOffset { | |||||
| std::vector<int64_t> zero_copy_relative_offset_; | std::vector<int64_t> zero_copy_relative_offset_; | ||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_OFFSET_H_ | |||||
| #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_OFFSET_H_ | |||||
| @@ -131,7 +131,7 @@ Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, void *buffer_addr, const ma | |||||
| auto dst_addr = static_cast<uint8_t *>(buffer_addr); | auto dst_addr = static_cast<uint8_t *>(buffer_addr); | ||||
| GELOGI("[ZCPY] %s update task, args_addr: %p, size: %zu, offset: %zu, virtual_addr: 0x%lx, user_data_addr: %p", | GELOGI("[ZCPY] %s update task, args_addr: %p, size: %zu, offset: %zu, virtual_addr: 0x%lx, user_data_addr: %p", | ||||
| name_.c_str(), args_addr_, args_size_, offset, addr, buffer_addr); | name_.c_str(), args_addr_, args_size_, offset, addr, buffer_addr); | ||||
| *(uintptr_t *)(args_info + offset) = reinterpret_cast<uintptr_t>(dst_addr); | |||||
| *reinterpret_cast<uintptr_t *>(args_info + offset)= reinterpret_cast<uintptr_t>(dst_addr); | |||||
| is_updated_ = true; | is_updated_ = true; | ||||
| } | } | ||||
| } | } | ||||
| @@ -25,13 +25,13 @@ | |||||
| namespace ge { | namespace ge { | ||||
| const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, | const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, | ||||
| 8 * kMByteSize, | |||||
| 32 * kMByteSize, | |||||
| 128 * kMByteSize, | |||||
| kBinSizeUnit8 * kMByteSize, | |||||
| kBinSizeUnit32 * kMByteSize, | |||||
| kBinSizeUnit128 * kMByteSize, | |||||
| kGByteSize, | kGByteSize, | ||||
| 4 * kGByteSize, | |||||
| 16 * kGByteSize, | |||||
| 26 * kGByteSize}; | |||||
| kBinSizeUnit4 * kGByteSize, | |||||
| kBinSizeUnit16 * kGByteSize, | |||||
| kBinSizeUnit26 * kGByteSize}; | |||||
| static bool BlockComparator(const Block *left, const Block *right) { | static bool BlockComparator(const Block *left, const Block *right) { | ||||
| if (left->size != right->size) { | if (left->size != right->size) { | ||||
| @@ -34,10 +34,17 @@ | |||||
| namespace ge { | namespace ge { | ||||
| constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes | constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes | ||||
| constexpr size_t kBinSizeUnit4 = 4; | |||||
| constexpr size_t kBinSizeUnit8 = 8; | |||||
| constexpr size_t kBinSizeUnit16 = 16; | |||||
| constexpr size_t kBinSizeUnit26 = 26; | |||||
| constexpr size_t kBinSizeUnit32 = 32; | |||||
| constexpr size_t kBinSizeUnit128 = 128; | |||||
| constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold | constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold | ||||
| constexpr size_t kKByteSize = 1024; | constexpr size_t kKByteSize = 1024; | ||||
| constexpr size_t kMByteSize = 1024 * 1024; | |||||
| constexpr size_t kGByteSize = 1024 * 1024 * 1024; | |||||
| constexpr size_t kMByteSize = 1048576; // 1024 * 1024 | |||||
| constexpr size_t kGByteSize = 1073741824; // 1024 * 1024 * 1024 | |||||
| static const uint32_t kNumBins = 8; | static const uint32_t kNumBins = 8; | ||||
| @@ -533,9 +533,8 @@ Status GraphManager::CopySubGraphAndMarkFusion(const ComputeGraphPtr &compute_gr | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_graph, | |||||
| Graph2SubGraphInfoList &sub_graph_map, | |||||
| uint64_t session_id) { | |||||
| Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_graph, | |||||
| Graph2SubGraphInfoList &sub_graph_map, uint64_t session_id) { | |||||
| GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
| // use default 16 multi thread | // use default 16 multi thread | ||||
| const uint32_t thread_num = 16; | const uint32_t thread_num = 16; | ||||
| @@ -550,14 +549,14 @@ Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_gr | |||||
| (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | ||||
| } | } | ||||
| std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, | std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, | ||||
| compute_graph->GetGraphID(), subgraph, compute_graph, session_id, GetThreadLocalContext()); | |||||
| compute_graph->GetGraphID(), subgraph, compute_graph, session_id, | |||||
| GetThreadLocalContext()); | |||||
| if (!f.valid()) { | if (!f.valid()) { | ||||
| GELOGE(FAILED, "Future is invalid"); | GELOGE(FAILED, "Future is invalid"); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| vector_future.emplace_back(std::move(f)); | vector_future.emplace_back(std::move(f)); | ||||
| } | } | ||||
| for (auto &function_graph : compute_graph->GetAllSubgraphs()) { | for (auto &function_graph : compute_graph->GetAllSubgraphs()) { | ||||
| auto subgraph_list = sub_graph_map[function_graph]; | auto subgraph_list = sub_graph_map[function_graph]; | ||||
| for (const auto &subgraph : subgraph_list) { | for (const auto &subgraph : subgraph_list) { | ||||
| @@ -651,62 +650,13 @@ Status GraphManager::ReplaceSubgraphWithOriGraph(const ComputeGraphPtr &compute_ | |||||
| Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph, GraphPartitioner &partitioner) { | Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph, GraphPartitioner &partitioner) { | ||||
| GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
| auto sub_graph_map = partitioner.GetSubGraphMap(); | auto sub_graph_map = partitioner.GetSubGraphMap(); | ||||
| std::string buffer_optimize; | |||||
| graphStatus graph_status = ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize); | |||||
| bool need_lx_fusion = (graph_status == GRAPH_SUCCESS) && (buffer_optimize != kOffOptimize); | |||||
| if (options_.build_mode.empty() && need_lx_fusion) { | |||||
| GELOGI("Enter normal mode with buffer_optimize:%s.", buffer_optimize.c_str()); | |||||
| /// 1. Copy subgraph for buffer optimize while lx fusion failed. | |||||
| /// 2. Set graph with attr "lx_fusion" for fusion optimize. | |||||
| std::unordered_map<std::string, ComputeGraphPtr> copy_graphs; | |||||
| GE_TIMESTAMP_START(CopySubGraphAndMarkFusion); | |||||
| Status ret = CopySubGraphAndMarkFusion(compute_graph, sub_graph_map, copy_graphs); | |||||
| GE_TIMESTAMP_EVENT_END(CopySubGraphAndMarkFusion, "SetSubgraph:CopySubGraphAndMarkFusion"); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "CopySubGraphAndMarkFusion failed."); | |||||
| return ret; | |||||
| } | |||||
| // Multiply optimize subgraph with lx fusion | |||||
| ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Multiply optimize subgraph with lx fusion failed."); | |||||
| return ret; | |||||
| } | |||||
| // Check whether all subgraph lx fusion success | |||||
| GE_TIMESTAMP_START(CheckAllFusionOptimizeSuccess); | |||||
| if (CheckAllFusionOptimizeSuccess(compute_graph, sub_graph_map)) { | |||||
| GE_TIMESTAMP_EVENT_END(CheckAllFusionOptimizeSuccess, "SetSubgraph:CheckAllFusionOptimizeSuccess"); | |||||
| return SUCCESS; | |||||
| } | |||||
| // Replace subgraph with original graph for lx buffer | |||||
| ret = ReplaceSubgraphWithOriGraph(compute_graph, sub_graph_map, copy_graphs); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Replace subgraph with original graph failed."); | |||||
| return ret; | |||||
| } | |||||
| // Multiply optimize subgraph with lx buffer | |||||
| ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Multiply optimize subgraph with lx buffer failed."); | |||||
| return ret; | |||||
| } | |||||
| } else { | |||||
| /// Multiply optimize subgraph: | |||||
| /// 1. run lx buffer while build_mode is normal and buffer_optimize is empty or "off_optimize"; | |||||
| /// 2. run lx fusion or buffer according build_mode and build_step in fe. | |||||
| GELOGD("Directly optimize subgraph with build mode:%s, and step:%s, buffer_optimize:%s.", | |||||
| options_.build_mode.c_str(), | |||||
| options_.build_step.c_str(), | |||||
| buffer_optimize.c_str()); | |||||
| Status ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Multiply optimize subgraph with lx buffer"); | |||||
| return ret; | |||||
| } | |||||
| GELOGD("Directly optimize subgraph with build mode:%s, and step:%s.", | |||||
| options_.build_mode.c_str(), | |||||
| options_.build_step.c_str()); | |||||
| Status ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Multiply optimize subgraph failed"); | |||||
| return ret; | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -2515,7 +2465,6 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager | |||||
| GetContext().SetSessionId(session_id); | GetContext().SetSessionId(session_id); | ||||
| GetThreadLocalContext() = ge_context; | GetThreadLocalContext() = ge_context; | ||||
| graph_manager->UpdateLocalOmgContext(root_graph_id); | graph_manager->UpdateLocalOmgContext(root_graph_id); | ||||
| ComputeGraphPtr compute_graph_tmp = sub_graph_info_ptr->GetSubGraph(); | ComputeGraphPtr compute_graph_tmp = sub_graph_info_ptr->GetSubGraph(); | ||||
| const std::string &engine_name = sub_graph_info_ptr->GetEngineName(); | const std::string &engine_name = sub_graph_info_ptr->GetEngineName(); | ||||
| GELOGD("ProcessSubGraphWithMultiThreads start, graph name is %s, engine_name is %s, thread id is %lu", | GELOGD("ProcessSubGraphWithMultiThreads start, graph name is %s, engine_name is %s, thread id is %lu", | ||||
| @@ -2523,6 +2472,10 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager | |||||
| pthread_self()); | pthread_self()); | ||||
| GE_DUMP(compute_graph_tmp, "OptimizeSubGraphBefore"); | GE_DUMP(compute_graph_tmp, "OptimizeSubGraphBefore"); | ||||
| GE_CHECK_NOTNULL(compute_graph_tmp); | GE_CHECK_NOTNULL(compute_graph_tmp); | ||||
| if (!AttrUtils::SetInt(*compute_graph_tmp, ATTR_NAME_ROOT_GRAPH_ID, root_graph_id)) { | |||||
| GELOGE(FAILED, "Failed to set attr ATTR_NAME_ROOT_GRAPH_ID for subgraph, graph_id: %u.", root_graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| compute_graph_tmp->SetSessionID(session_id); | compute_graph_tmp->SetSessionID(session_id); | ||||
| Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp, | Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp, | ||||
| compute_graph, | compute_graph, | ||||
| @@ -2688,9 +2641,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||||
| } | } | ||||
| // it will not execute graph preprocess, optimize, parition, build if the graph has built successful. | // it will not execute graph preprocess, optimize, parition, build if the graph has built successful. | ||||
| GELOGI("Start for run graph async."); | GELOGI("Start for run graph async."); | ||||
| GeRootModelPtr ge_root_model = nullptr; | GeRootModelPtr ge_root_model = nullptr; | ||||
| if (graph_manager->IsGraphNeedBuild(graph_node)) { | if (graph_manager->IsGraphNeedBuild(graph_node)) { | ||||
| if (graph_node->GetBuildFlag()) { | if (graph_node->GetBuildFlag()) { | ||||
| @@ -280,9 +280,9 @@ Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uin | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| uint64_t free_size = total_size_ - var_mem_size_; | uint64_t free_size = total_size_ - var_mem_size_; | ||||
| if (free_size < (size + kSessionMemAlignSize * 2)) { | |||||
| if (free_size < (size + kSessionMemAlignSize * kSessionMemAlignUnit)) { | |||||
| GELOGE(PARAM_INVALID, "Out of memory : current var size[%lu] exceeds total var size[%lu]", | GELOGE(PARAM_INVALID, "Out of memory : current var size[%lu] exceeds total var size[%lu]", | ||||
| size + kSessionMemAlignSize * 2 + var_mem_size_, total_size_); | |||||
| size + kSessionMemAlignSize * kSessionMemAlignUnit + var_mem_size_, total_size_); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -42,6 +42,7 @@ const size_t kGraphMemoryBuffer = 4UL * 1024UL * 1024UL * 1024UL; | |||||
| const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL; | const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL; | ||||
| const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY"; | const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY"; | ||||
| const uint64_t kSessionMemAlignSize = 512; | const uint64_t kSessionMemAlignSize = 512; | ||||
| const size_t kSessionMemAlignUnit = 2; | |||||
| enum MemStatus { | enum MemStatus { | ||||
| NORMAL = 0, | NORMAL = 0, | ||||
| @@ -106,7 +106,7 @@ Status HostMemManager::QueryVarMemInfo(const string &op_name, uint64_t &base_add | |||||
| GELOGE(INTERNAL_ERROR, "Find host base base_addr failed,node name:%s!", op_name.c_str()); | GELOGE(INTERNAL_ERROR, "Find host base base_addr failed,node name:%s!", op_name.c_str()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(var_memory_base_map_[op_name].device_address)); | |||||
| base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(var_memory_base_map_[op_name].device_address)); | |||||
| data_size = var_memory_base_map_[op_name].mem_size; | data_size = var_memory_base_map_[op_name].mem_size; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -32,7 +32,8 @@ Debug::~Debug() = default; | |||||
| void Debug::DumpProto(const Message &proto, const char *file) { | void Debug::DumpProto(const Message &proto, const char *file) { | ||||
| std::string file_path = RealPath(file); | std::string file_path = RealPath(file); | ||||
| int fd = mmOpen2(file_path.c_str(), M_WRONLY | M_CREAT | O_TRUNC, M_IRUSR | M_IWUSR | M_UMASK_GRPREAD | M_UMASK_OTHREAD); | |||||
| int fd = mmOpen2(file_path.c_str(), M_WRONLY | M_CREAT | O_TRUNC, M_IRUSR | M_IWUSR | M_UMASK_GRPREAD | | |||||
| M_UMASK_OTHREAD); | |||||
| if (fd == -1) { | if (fd == -1) { | ||||
| GELOGW("Write %s failed", file_path.c_str()); | GELOGW("Write %s failed", file_path.c_str()); | ||||
| return; | return; | ||||
| @@ -263,7 +263,8 @@ Status HcomOmeUtil::GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &ro | |||||
| Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, | Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, | ||||
| std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) { | |||||
| if (op_desc->GetType() == HCOMBROADCAST || | |||||
| op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) { | |||||
| GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| int64_t root_id = 0; | int64_t root_id = 0; | ||||
| Status dmrt = GetHcclRootId(op_desc, root_id); | Status dmrt = GetHcclRootId(op_desc, root_id); | ||||
| @@ -26,6 +26,13 @@ | |||||
| namespace { | namespace { | ||||
| using namespace ge; | using namespace ge; | ||||
| const int kIdentityAnchorIndex = 0; | const int kIdentityAnchorIndex = 0; | ||||
| const size_t kSerialStringVecSize = 4; | |||||
| const int kCaseReadOnly = 0; | |||||
| const int kCaseScopeWriteable = 2; | |||||
| const int kCaseWriteable = 3; | |||||
| const int kCaseInvalidRWType = 5; | |||||
| // rw type of input. | // rw type of input. | ||||
| enum class InputRWType { | enum class InputRWType { | ||||
| kReadOnly, // Normal op input only read | kReadOnly, // Normal op input only read | ||||
| @@ -55,7 +62,7 @@ thread_local map<string, NodeInputOutputRWType> node_rwtype_map_; | |||||
| /// @return rw_type_name | /// @return rw_type_name | ||||
| /// | /// | ||||
| static std::string InputRWTypeToSerialString(InputRWType rw_type) { | static std::string InputRWTypeToSerialString(InputRWType rw_type) { | ||||
| const static char *names[4] = {"ReadOnly", "Writeable", "ScopeWriteable", "InvalidRWType"}; | |||||
| const static char *names[kSerialStringVecSize] = {"ReadOnly", "Writeable", "ScopeWriteable", "InvalidRWType"}; | |||||
| return names[static_cast<int>(rw_type)]; | return names[static_cast<int>(rw_type)]; | ||||
| } | } | ||||
| @@ -65,7 +72,7 @@ static std::string InputRWTypeToSerialString(InputRWType rw_type) { | |||||
| /// @return rw_type_name | /// @return rw_type_name | ||||
| /// | /// | ||||
| static std::string OutputRWTypeToSerialString(OutputRWType rw_type) { | static std::string OutputRWTypeToSerialString(OutputRWType rw_type) { | ||||
| const static char *names[4] = {"ReadOnly", "SoftRead", "Writeable", "InvalidRWType"}; | |||||
| const static char *names[kSerialStringVecSize] = {"ReadOnly", "SoftRead", "Writeable", "InvalidRWType"}; | |||||
| return names[static_cast<int>(rw_type)]; | return names[static_cast<int>(rw_type)]; | ||||
| } | } | ||||
| @@ -118,13 +125,13 @@ InputRWType GetInputRwTypeInConflict(const std::set<int> &rw_type_set) { | |||||
| } | } | ||||
| switch (total_rw_type) { | switch (total_rw_type) { | ||||
| case 0: | |||||
| case kCaseReadOnly: | |||||
| return InputRWType::kReadOnly; // all input rw type is readonly | return InputRWType::kReadOnly; // all input rw type is readonly | ||||
| case 2: | |||||
| case kCaseScopeWriteable: | |||||
| return InputRWType::kScopeWriteable; // readonly 2 scope_writeable | return InputRWType::kScopeWriteable; // readonly 2 scope_writeable | ||||
| case 3: | |||||
| case kCaseWriteable: | |||||
| return InputRWType::kWriteable; // all input rw type is writeable or readonly 2 writeable | return InputRWType::kWriteable; // all input rw type is writeable or readonly 2 writeable | ||||
| case 5: | |||||
| case kCaseInvalidRWType: | |||||
| return InputRWType::kInvalidRWType; // writeable 2 scope_writeable | return InputRWType::kInvalidRWType; // writeable 2 scope_writeable | ||||
| default: | default: | ||||
| return InputRWType::kInvalidRWType; | return InputRWType::kInvalidRWType; | ||||
| @@ -643,7 +650,7 @@ Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | |||||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | ||||
| GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | ||||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | ||||
| pre_node->GetName().c_str(), node->GetName().c_str()); | |||||
| pre_node->GetName().c_str(), node->GetName().c_str()); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -614,32 +614,32 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vector<ge::SubGraphInfoPtr | |||||
| } | } | ||||
| // flush parent node of subgraph | // flush parent node of subgraph | ||||
| sub_graph->SetParentNode(compute_graph->GetParentNode()); | sub_graph->SetParentNode(compute_graph->GetParentNode()); | ||||
| (void) AttrUtils::SetStr(*sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName()); | |||||
| auto sgi = MakeShared<SubGraphInfo>(); | |||||
| if (sgi == nullptr) { | |||||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "[GraphPartitioner]: MakeShared sub graph info failed."); | |||||
| return FAILED; | |||||
| } | |||||
| // set engine name | |||||
| sgi->SetEngineName(engine_name); | |||||
| // set stream label | |||||
| string sub_graph_stream; | |||||
| if (AttrUtils::GetStr(sub_graph->GetDirectNode().at(0)->GetOpDesc(), ATTR_NAME_STREAM_LABEL, sub_graph_stream)) { | |||||
| sgi->SetStreamLabel(sub_graph_stream); | |||||
| } | |||||
| /// for now inputFlag is the same before and after partition. It should | |||||
| /// be changed according to the real partition | |||||
| std::vector<bool> sub_graph_input(graph_info_.input_size_, true); | |||||
| std::vector<bool> sub_graph_output(graph_info_.output_size_, true); | |||||
| sgi->SetSubGraph(sub_graph); | |||||
| sgi->SetOutputFlag(sub_graph_output); | |||||
| sgi->SetInputFlag(sub_graph_input); | |||||
| sgi->SetOutputContext(graph_info_.output_name_); | |||||
| AddEndPldInformationToSubGraphInfo(sgi); | |||||
| GELOGI("[GraphPartitioner]: subGraph engine name is %s, graph name is %s, stream label is %s", | |||||
| engine_name.c_str(), | |||||
| sub_graph->GetName().c_str(), | |||||
| sgi->GetStreamLabel().empty() ? "null" : sgi->GetStreamLabel().c_str()); | |||||
| (void)AttrUtils::SetStr(*sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName()); | |||||
| GELOGD("set attr success. subgraph(%s) with parent graph(%s)", sub_graph->GetName().c_str(), | |||||
| compute_graph->GetName().c_str()); | |||||
| auto sgi = MakeShared<SubGraphInfo>(); | |||||
| if (sgi == nullptr) { | |||||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "[GraphPartitioner]: MakeShared sub graph info failed."); | |||||
| return FAILED; | |||||
| } | |||||
| // set engine name | |||||
| sgi->SetEngineName(engine_name); | |||||
| // set stream label | |||||
| string sub_graph_stream; | |||||
| if (AttrUtils::GetStr(sub_graph->GetDirectNode().at(0)->GetOpDesc(), ATTR_NAME_STREAM_LABEL, sub_graph_stream)) { | |||||
| sgi->SetStreamLabel(sub_graph_stream); | |||||
| } | |||||
| /// for now inputFlag is the same before and after partition. It should | |||||
| /// be changed according to the real partition | |||||
| std::vector<bool> sub_graph_input(graph_info_.input_size_, true); | |||||
| std::vector<bool> sub_graph_output(graph_info_.output_size_, true); | |||||
| sgi->SetSubGraph(sub_graph); | |||||
| sgi->SetOutputFlag(sub_graph_output); | |||||
| sgi->SetInputFlag(sub_graph_input); | |||||
| sgi->SetOutputContext(graph_info_.output_name_); | |||||
| AddEndPldInformationToSubGraphInfo(sgi); | |||||
| GELOGI("[GraphPartitioner]: subGraph engine name is %s, graph name is %s, stream label is %s", engine_name.c_str(), | |||||
| sub_graph->GetName().c_str(), sgi->GetStreamLabel().empty() ? "null" : sgi->GetStreamLabel().c_str()); | |||||
| if (engine_name != input_subgraph_name) { // do not add Data subGraph into SubGraphInfo | if (engine_name != input_subgraph_name) { // do not add Data subGraph into SubGraphInfo | ||||
| output_subgraphs.push_back(sgi); | output_subgraphs.push_back(sgi); | ||||
| } else { | } else { | ||||
| @@ -74,10 +74,88 @@ Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| // just hccl may mark atomic from ops kernel now, and hccl's atomic if for all input | |||||
| bool AtomicAddrCleanPass::CheckAtomicFromOpsKernel(const NodePtr &node) { | |||||
| // 1.Check if isAtomic attrs exist for HCOM | |||||
| std::shared_ptr<GELib> instance_ptr = GELib::GetInstance(); | |||||
| if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | |||||
| GELOGW("GELib not initialized, atomic from ops kernel judge false, node_name: %s", node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj(); | |||||
| vector<OpInfo> op_info_vec = ops_kernel_manager.GetOpsKernelInfo(node->GetType()); | |||||
| for (const auto &op_info : op_info_vec) { | |||||
| if (op_info.isAtomic) { | |||||
| // check peer input is DATA | |||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| if (in_data_anchor->GetPeerOutAnchor() != nullptr && | |||||
| in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) { | |||||
| auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); | |||||
| if (peer_in_node->GetType() == DATA) { | |||||
| GELOGI("Recognized atomic op %s from %s engine and input is DATA.", node->GetName().c_str(), | |||||
| op_info.engine.c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| GELOGI("Recognized atomic op %s from %s engine.", node->GetName().c_str(), op_info.engine.c_str()); | |||||
| hcom_node_vec_.push_back(node); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool AtomicAddrCleanPass::IsOutputIndexPeerInputAtomic(const NodePtr &node, int64_t output_index) { | |||||
| auto out_data_anchor = node->GetAllOutDataAnchors().at(output_index); | |||||
| if (out_data_anchor == nullptr) { | |||||
| return false; | |||||
| } | |||||
| for (auto input_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
| auto output_node = input_anchor->GetOwnerNode(); | |||||
| // just hccl may mark atomic from ops kernel now, and hccl's atomic if for all input | |||||
| // hccl's attr ATOMIC_ATTR_INPUT_INDEX mark on CalcOpRunningParam, can't be get here | |||||
| if (CheckAtomicFromOpsKernel(output_node)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool AtomicAddrCleanPass::CheckSkipInsertInLoopGraph(const NodePtr &node) { | |||||
| OpDescPtr op_desc = node->GetOpDesc(); | |||||
| std::map<string, std::map<int, int>> node_workspace_offset; | |||||
| bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | |||||
| bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); | |||||
| node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset); | |||||
| if (!has_atomic_input && has_atomic_output && node_workspace_offset.empty()) { | |||||
| std::vector<int64_t> atomic_output_index; | |||||
| (void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); | |||||
| bool is_all_output_peer_also_atomic = true; | |||||
| for (const auto &output_index : atomic_output_index) { | |||||
| if (!IsOutputIndexPeerInputAtomic(node, output_index)) { | |||||
| is_all_output_peer_also_atomic = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (is_all_output_peer_also_atomic) { | |||||
| GELOGI("all out peer node input atomic, skip this out atomic process, node name: %s", node->GetName().c_str()); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| Status AtomicAddrCleanPass::HandleLoopGraph(ComputeGraphPtr &graph, const vector<NodePtr> &atomic_node_vec) { | Status AtomicAddrCleanPass::HandleLoopGraph(ComputeGraphPtr &graph, const vector<NodePtr> &atomic_node_vec) { | ||||
| // Loop graph , insert clean node follow atomic node | // Loop graph , insert clean node follow atomic node | ||||
| int index = 0; | int index = 0; | ||||
| for (const auto &node : atomic_node_vec) { | for (const auto &node : atomic_node_vec) { | ||||
| if (CheckSkipInsertInLoopGraph(node)) { | |||||
| continue; | |||||
| } | |||||
| // Insert atomic clean op | // Insert atomic clean op | ||||
| NodePtr clean_addr_node = InsertAtomicAddrCleanNode(graph); | NodePtr clean_addr_node = InsertAtomicAddrCleanNode(graph); | ||||
| if (clean_addr_node == nullptr) { | if (clean_addr_node == nullptr) { | ||||
| @@ -249,32 +327,10 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| // 1.Check if isAtomic attrs exist for HCOM | // 1.Check if isAtomic attrs exist for HCOM | ||||
| std::shared_ptr<GELib> instance_ptr = GELib::GetInstance(); | |||||
| if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | |||||
| GELOGW("GELib not initialized"); | |||||
| return false; | |||||
| if (CheckAtomicFromOpsKernel(node)) { | |||||
| return true; | |||||
| } | } | ||||
| OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj(); | |||||
| vector<OpInfo> op_info_vec = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType()); | |||||
| for (const auto &op_info : op_info_vec) { | |||||
| if (op_info.isAtomic) { | |||||
| GELOGI("Recognized atomic op %s from DNN_HCCL engine.", op_desc->GetName().c_str()); | |||||
| // check peer input is DATA | |||||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| if (in_data_anchor->GetPeerOutAnchor() != nullptr && | |||||
| in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) { | |||||
| auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); | |||||
| if (peer_in_node->GetType() == DATA) { | |||||
| GELOGI("Recognized atomic op %s from DNN_HCCL engine and input is DATA.", op_desc->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| hcom_node_vec_.push_back(node); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| // 2.Check atomic attr in node | // 2.Check atomic attr in node | ||||
| std::map<string, std::map<int, int>> node_workspace_offset; | std::map<string, std::map<int, int>> node_workspace_offset; | ||||
| bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | ||||
| @@ -84,6 +84,11 @@ class AtomicAddrCleanPass : public GraphPass { | |||||
| Status HandleDispersedAtomicNodes(ComputeGraphPtr &graph, const std::vector<NodePtr> &atomic_node_vec, | Status HandleDispersedAtomicNodes(ComputeGraphPtr &graph, const std::vector<NodePtr> &atomic_node_vec, | ||||
| std::vector<NodePtr> &common_atomic_nodes); | std::vector<NodePtr> &common_atomic_nodes); | ||||
| bool CheckAtomicFromOpsKernel(const NodePtr &node); | |||||
| bool IsOutputIndexPeerInputAtomic(const NodePtr &node, int64_t output_index); | |||||
| bool CheckSkipInsertInLoopGraph(const NodePtr &node); | |||||
| vector<NodePtr> hcom_node_vec_; | vector<NodePtr> hcom_node_vec_; | ||||
| bool is_loop_graph_ = false; | bool is_loop_graph_ = false; | ||||
| @@ -24,11 +24,7 @@ Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { | |||||
| FindNodes(graph); | FindNodes(graph); | ||||
| for (const auto &node : need_label_nodes_) { | for (const auto &node : need_label_nodes_) { | ||||
| OpDescPtr op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { | |||||
| GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); | |||||
| } | |||||
| GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); | |||||
| } | } | ||||
| GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); | GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); | ||||
| @@ -55,13 +51,15 @@ Status AttachStreamLabelPass::ClearStatus() { | |||||
| /// | /// | ||||
| void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { | void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { | ||||
| for (const NodePtr &node : graph->GetDirectNode()) { | for (const NodePtr &node : graph->GetDirectNode()) { | ||||
| const std::string &type = node->GetType(); | |||||
| if (type == STREAMSWITCH) { | |||||
| const auto &op_desc = node->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| const std::string &type = op_desc->GetType(); | |||||
| if ((type == STREAMSWITCH) && op_desc->HasAttr(ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG)) { | |||||
| stream_switch_nodes_.emplace_back(node); | stream_switch_nodes_.emplace_back(node); | ||||
| } else if (type == STREAMMERGE) { | |||||
| if ((node->GetOpDesc() != nullptr) && !node->GetOpDesc()->HasAttr(ATTR_NAME_NEXT_ITERATION)) { | |||||
| need_label_nodes_.emplace_back(node); | |||||
| } | |||||
| } else if ((type == STREAMMERGE) && !op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) { | |||||
| need_label_nodes_.emplace_back(node); | |||||
| } else if ((type == ENTER) || (type == REFENTER)) { | } else if ((type == ENTER) || (type == REFENTER)) { | ||||
| enter_nodes_.emplace_back(node); | enter_nodes_.emplace_back(node); | ||||
| } | } | ||||
| @@ -83,11 +81,15 @@ void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { | |||||
| /// | /// | ||||
| Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | ||||
| std::string stream_label; | std::string stream_label; | ||||
| if (AttachFlag(node, stream_label) != SUCCESS) { | |||||
| GELOGE(FAILED, "Attach flag for node %s failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| std::unordered_set<NodePtr> branch_nodes; | std::unordered_set<NodePtr> branch_nodes; | ||||
| std::unordered_set<NodePtr> visited; | std::unordered_set<NodePtr> visited; | ||||
| std::stack<NodePtr> nodes; | std::stack<NodePtr> nodes; | ||||
| nodes.push(node); | nodes.push(node); | ||||
| static const std::set<std::string> end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; | static const std::set<std::string> end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; | ||||
| while (!nodes.empty()) { | while (!nodes.empty()) { | ||||
| NodePtr cur_node = nodes.top(); | NodePtr cur_node = nodes.top(); | ||||
| @@ -95,10 +97,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | |||||
| if (visited.count(cur_node) > 0) { | if (visited.count(cur_node) > 0) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (AttachFlag(cur_node, stream_label) != SUCCESS) { | |||||
| GELOGE(FAILED, "Attach flag for node %s failed.", cur_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| const std::string &type = cur_node->GetType(); | const std::string &type = cur_node->GetType(); | ||||
| for (const auto &out_node : cur_node->GetOutAllNodes()) { | for (const auto &out_node : cur_node->GetOutAllNodes()) { | ||||
| @@ -115,10 +113,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | |||||
| visited.insert(cur_node); | visited.insert(cur_node); | ||||
| } | } | ||||
| if (node->GetType() == STREAMSWITCH) { | |||||
| GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); | |||||
| } | |||||
| for (const NodePtr &tmp_node : branch_nodes) { | for (const NodePtr &tmp_node : branch_nodes) { | ||||
| GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str()); | GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str()); | ||||
| GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed."); | GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed."); | ||||
| @@ -148,11 +142,10 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, | GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, | ||||
| "StreamSwitch get attr TRUE_BRANCH_STREAM failed."); | "StreamSwitch get attr TRUE_BRANCH_STREAM failed."); | ||||
| stream_label += (value ? "_t" : "_f"); | stream_label += (value ? "_t" : "_f"); | ||||
| GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); | |||||
| } else if (type == STREAMMERGE) { | } else if (type == STREAMMERGE) { | ||||
| stream_label = node->GetName(); | stream_label = node->GetName(); | ||||
| GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); | GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); | ||||
| } else if ((type == EXIT) || (type == REFEXIT)) { | |||||
| GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -166,12 +159,13 @@ Status AttachStreamLabelPass::UpdateEnterNode() { | |||||
| std::unordered_map<NodePtr, std::vector<NodePtr>> enter_active_map; | std::unordered_map<NodePtr, std::vector<NodePtr>> enter_active_map; | ||||
| for (const auto &enter_node : enter_nodes_) { | for (const auto &enter_node : enter_nodes_) { | ||||
| for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { | for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { | ||||
| if (out_ctrl_node->GetType() == STREAMACTIVE) { | |||||
| if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) { | |||||
| enter_active_map[out_ctrl_node] = {enter_node}; | |||||
| } else { | |||||
| enter_active_map[out_ctrl_node].emplace_back(enter_node); | |||||
| } | |||||
| if (out_ctrl_node->GetType() != STREAMACTIVE) { | |||||
| continue; | |||||
| } | |||||
| if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) { | |||||
| enter_active_map[out_ctrl_node] = {enter_node}; | |||||
| } else { | |||||
| enter_active_map[out_ctrl_node].emplace_back(enter_node); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -226,9 +220,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||||
| std::string stream_label; | std::string stream_label; | ||||
| GE_CHECK_NOTNULL(active_node); | GE_CHECK_NOTNULL(active_node); | ||||
| (void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); | (void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); | ||||
| if (stream_label.empty()) { | if (stream_label.empty()) { | ||||
| GELOGW("stream_label of enter_active & enter_nodes is empty."); | |||||
| GELOGD("stream_label of enter_active %s is empty.", active_node->GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -238,7 +231,6 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||||
| GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); | GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); | ||||
| } | } | ||||
| } | } | ||||
| GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "Set stream label failed."); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -37,6 +37,12 @@ Status CondRemovePass::Run(NodePtr &node) { | |||||
| OutDataAnchorPtr cond_out_anchor = nullptr; | OutDataAnchorPtr cond_out_anchor = nullptr; | ||||
| InDataAnchorPtr cond_in_anchor = nullptr; | InDataAnchorPtr cond_in_anchor = nullptr; | ||||
| Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor); | Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor); | ||||
| if (ret == NOT_CHANGED) { | |||||
| return SUCCESS; | |||||
| } else if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "Get cond_info for node %s failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| int32_t cond_index = 0; | int32_t cond_index = 0; | ||||
| GELOGD("Handle cond remove for node %s.", node->GetOpDesc()->GetName().c_str()); | GELOGD("Handle cond remove for node %s.", node->GetOpDesc()->GetName().c_str()); | ||||
| bool if_cond_const = CheckIfCondConstInput(cond_out_anchor, cond_in_anchor, cond_index); | bool if_cond_const = CheckIfCondConstInput(cond_out_anchor, cond_in_anchor, cond_index); | ||||
| @@ -322,11 +328,11 @@ Status CondRemovePass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, | |||||
| std::string type = node->GetType(); | std::string type = node->GetType(); | ||||
| if ((kIfOpTypes.count(type) != 0) || (kCaseOpTypes.count(type) != 0)) { | if ((kIfOpTypes.count(type) != 0) || (kCaseOpTypes.count(type) != 0)) { | ||||
| if (GetCondInfoForIfCase(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { | if (GetCondInfoForIfCase(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { | ||||
| GELOGE(FAILED, "Get cond_info for if node failed."); | |||||
| GELOGE(FAILED, "Get cond_info for if/case node failed."); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } else { | } else { | ||||
| GELOGD("no need cond_pass for node %s.", node->GetName().c_str()); | |||||
| GELOGD("no need cond_remove_pass for node %s.", node->GetName().c_str()); | |||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| } | } | ||||
| @@ -38,7 +38,6 @@ namespace ge { | |||||
| * \ / | * \ / | ||||
| * B | * B | ||||
| */ | */ | ||||
| Status CtrlEdgeTransferPass::Run(ge::ComputeGraphPtr graph) { | Status CtrlEdgeTransferPass::Run(ge::ComputeGraphPtr graph) { | ||||
| GELOGD("CtrlEdgeTransferPass start running"); | GELOGD("CtrlEdgeTransferPass start running"); | ||||
| GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
| @@ -21,6 +21,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const int kDataIndexOffset = 2; | |||||
| Status MappingSubgraphInput(const ComputeGraphPtr &graph, const std::function<int(int data_index)> &input) { | Status MappingSubgraphInput(const ComputeGraphPtr &graph, const std::function<int(int data_index)> &input) { | ||||
| for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
| if (node->GetType() != DATA) { | if (node->GetType() != DATA) { | ||||
| @@ -111,7 +112,7 @@ Status ParseSubgraphPostFnWhile(const string &subgraph_name, const ComputeGraphP | |||||
| Status ParseSubgraphPostFnFor(const string &subgraph_name, const ComputeGraphPtr &graph) { | Status ParseSubgraphPostFnFor(const string &subgraph_name, const ComputeGraphPtr &graph) { | ||||
| return MappingSubgraphIndex(graph, | return MappingSubgraphIndex(graph, | ||||
| [](int data_index) { return (data_index == 0) ? 0 : data_index + 2; }, | |||||
| [](int data_index) { return (data_index == 0) ? 0 : data_index + kDataIndexOffset; }, | |||||
| [](int retval_index) { return retval_index; }); | [](int retval_index) { return retval_index; }); | ||||
| } | } | ||||
| @@ -16,6 +16,7 @@ | |||||
| #include "graph/passes/enter_pass.h" | #include "graph/passes/enter_pass.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| @@ -72,33 +73,25 @@ Status EnterPass::Run(NodePtr &node) { | |||||
| } | } | ||||
| Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { | Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { | ||||
| auto out_nodes_of_in_node = in_node->GetOutAllNodes(); | |||||
| if (out_nodes_of_in_node.size() != kOutNodesNum) { | |||||
| if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| if (!node->GetOutControlNodes().empty()) { | |||||
| bool is_constant_flag = true; | |||||
| (void)AttrUtils::GetBool(node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant_flag); | |||||
| if (!is_constant_flag) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| for (const auto &out_node : node->GetOutDataNodes()) { | |||||
| GE_CHECK_NOTNULL(out_node); | |||||
| if (out_node->GetType() == MERGE) { | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); | GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); | ||||
| GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); | GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); | ||||
| auto out_data_anchor = node->GetOutDataAnchor(0); | |||||
| const auto &out_data_anchor = node->GetOutDataAnchor(0); | |||||
| GE_CHECK_NOTNULL(out_data_anchor); | GE_CHECK_NOTNULL(out_data_anchor); | ||||
| for (auto peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
| for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
| GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); | GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); | ||||
| GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); | GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); | ||||
| } | } | ||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, node)) | |||||
| GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)); | |||||
| AddNodeDeleted(node); | |||||
| AddRePassNodesWithInOut(in_node); | AddRePassNodesWithInOut(in_node); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -37,6 +37,7 @@ namespace { | |||||
| const uint32_t kSubgraphLoopVarInputIndex = 0; | const uint32_t kSubgraphLoopVarInputIndex = 0; | ||||
| const uint32_t kSubgraphInputIndex = 1; | const uint32_t kSubgraphInputIndex = 1; | ||||
| const uint32_t kWhileOutputIndex = 5; | const uint32_t kWhileOutputIndex = 5; | ||||
| const size_t kIDiffValue = 2; | |||||
| const std::string kAbs = "Abs"; | const std::string kAbs = "Abs"; | ||||
| } | } | ||||
| @@ -137,7 +138,7 @@ Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &n | |||||
| for_info.ctrl_inputs = std::move(ctrl_inputs); | for_info.ctrl_inputs = std::move(ctrl_inputs); | ||||
| for_info.ctrl_outputs = std::move(ctrl_outputs); | for_info.ctrl_outputs = std::move(ctrl_outputs); | ||||
| GELOGI("Build for_info for node %s succ.", node->GetName().c_str()); | |||||
| GELOGI("Build for_info for node %s success.", node->GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -159,13 +160,7 @@ OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| if (peer_out_anchor == nullptr) { | |||||
| GELOGE(FAILED, "FindInputWithIndex %s:%u failed: peer_out_anchor is NULL.", node->GetName().c_str(), index); | |||||
| return nullptr; | |||||
| } | |||||
| return peer_out_anchor; | |||||
| return in_data_anchor->GetPeerOutAnchor(); | |||||
| } | } | ||||
| /// | /// | ||||
| @@ -186,20 +181,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnc | |||||
| uint32_t input_data_num = node->GetAllInDataAnchorsSize(); | uint32_t input_data_num = node->GetAllInDataAnchorsSize(); | ||||
| for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) { | for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) { | ||||
| InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index); | InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index); | ||||
| if (in_data_anchor == nullptr) { | |||||
| GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index); | |||||
| return FAILED; | |||||
| } | |||||
| GE_IF_BOOL_EXEC(in_data_anchor->GetPeerOutAnchor() == nullptr, | |||||
| GELOGW("Get null input by index %d from node %s ", | |||||
| in_data_anchor->GetIdx(), node->GetName().c_str()); | |||||
| continue); | |||||
| GE_CHECK_NOTNULL(in_data_anchor); | |||||
| data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor()); | data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor()); | ||||
| } | } | ||||
| for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||||
| for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||||
| std::vector<ge::InDataAnchorPtr> peer_in_data_anchors; | std::vector<ge::InDataAnchorPtr> peer_in_data_anchors; | ||||
| for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
| for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
| peer_in_data_anchors.emplace_back(peer_in_data_anchor); | peer_in_data_anchors.emplace_back(peer_in_data_anchor); | ||||
| } | } | ||||
| data_outputs.emplace_back(peer_in_data_anchors); | data_outputs.emplace_back(peer_in_data_anchors); | ||||
| @@ -207,13 +195,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnc | |||||
| InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor(); | InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor(); | ||||
| GE_CHECK_NOTNULL(in_ctrl_anchor); | GE_CHECK_NOTNULL(in_ctrl_anchor); | ||||
| for (auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { | |||||
| for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { | |||||
| ctrl_inputs.emplace_back(peer_out_ctrl_anchor); | ctrl_inputs.emplace_back(peer_out_ctrl_anchor); | ||||
| } | } | ||||
| OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor(); | OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor(); | ||||
| GE_CHECK_NOTNULL(out_ctrl_anchor); | GE_CHECK_NOTNULL(out_ctrl_anchor); | ||||
| for (auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
| for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
| ctrl_outputs.emplace_back(peer_in_ctrl_anchor); | ctrl_outputs.emplace_back(peer_in_ctrl_anchor); | ||||
| } | } | ||||
| @@ -707,7 +695,7 @@ Status ForPass::UpdateForBodyInputMapping(const WhileInfo &while_info) { | |||||
| } else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) { | } else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) { | ||||
| continue; | continue; | ||||
| } else { | } else { | ||||
| input_mapping[i] = i - 2; | |||||
| input_mapping[i] = i - kIDiffValue; | |||||
| } | } | ||||
| } | } | ||||
| for_body->UpdateInputMapping(input_mapping); | for_body->UpdateInputMapping(input_mapping); | ||||
| @@ -19,6 +19,8 @@ | |||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| const size_t kTwoInputNodesSize = 2; | |||||
| Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | ||||
| for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
| auto node_type = NodeUtils::GetNodeType(*node); | auto node_type = NodeUtils::GetNodeType(*node); | ||||
| @@ -52,7 +54,7 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | |||||
| /// Enter-----------+ | /// Enter-----------+ | ||||
| /// +-> Merge | /// +-> Merge | ||||
| /// NextIteration---+ | /// NextIteration---+ | ||||
| if (input_nodes.size() == 2) { | |||||
| if (input_nodes.size() == kTwoInputNodesSize) { | |||||
| if (input_nodes.at(0)->GetType() == ENTER && input_nodes.at(1)->GetType() == NEXTITERATION) { | if (input_nodes.at(0)->GetType() == ENTER && input_nodes.at(1)->GetType() == NEXTITERATION) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -21,18 +21,16 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "common/ge_inner_error_codes.h" | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/passes/pass_utils.h" | #include "graph/passes/pass_utils.h" | ||||
| using domi::PARAM_INVALID; | |||||
| using domi::SUCCESS; | |||||
| namespace ge { | namespace ge { | ||||
| const int kValueIndexOutputIndex = 1; | const int kValueIndexOutputIndex = 1; | ||||
| const size_t kCaseNoInput = 0; | |||||
| const size_t kCaseOneInput = 1; | |||||
| Status MergePass::Run(NodePtr &node) { | Status MergePass::Run(NodePtr &node) { | ||||
| GELOGD("MergePass running"); | GELOGD("MergePass running"); | ||||
| @@ -47,15 +45,14 @@ Status MergePass::Run(NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| auto out_data_anchors = node->GetAllOutDataAnchors(); | |||||
| if (out_data_anchors.empty()) { | |||||
| if (node->GetAllOutDataAnchors().empty()) { | |||||
| GELOGE(PARAM_INVALID, "[%s] Merge node output anchor is empty", node->GetName().c_str()); | GELOGE(PARAM_INVALID, "[%s] Merge node output anchor is empty", node->GetName().c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| auto in_data_nodes = node->GetInDataNodes(); | |||||
| const auto &in_data_nodes = node->GetInDataNodes(); | |||||
| switch (in_data_nodes.size()) { | switch (in_data_nodes.size()) { | ||||
| case 0: { | |||||
| case kCaseNoInput: { | |||||
| /// Case A: input_count = 0, the output of merge node is inactive as well | /// Case A: input_count = 0, the output of merge node is inactive as well | ||||
| /// In which case the output branch can be removed | /// In which case the output branch can be removed | ||||
| /// until another merge node is met | /// until another merge node is met | ||||
| @@ -70,7 +67,7 @@ Status MergePass::Run(NodePtr &node) { | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| case 1: { // Case B: input_count = 1, the merge node can be optimized out | |||||
| case kCaseOneInput: { // Case B: input_count = 1, the merge node can be optimized out | |||||
| std::vector<int> merge_io_map = {PassUtils::GetUniqueInDataAnchorIndex(node), -1}; | std::vector<int> merge_io_map = {PassUtils::GetUniqueInDataAnchorIndex(node), -1}; | ||||
| if (merge_io_map[0] != -1 && IsNeedChangeIndexToConstant(node)) { | if (merge_io_map[0] != -1 && IsNeedChangeIndexToConstant(node)) { | ||||
| int index = merge_io_map[0]; | int index = merge_io_map[0]; | ||||
| @@ -22,9 +22,6 @@ | |||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| using std::string; | |||||
| using std::vector; | |||||
| namespace ge { | namespace ge { | ||||
| Status MultiBatchPass::Run(ComputeGraphPtr graph) { | Status MultiBatchPass::Run(ComputeGraphPtr graph) { | ||||
| GELOGD("MultiBatchPass Enter"); | GELOGD("MultiBatchPass Enter"); | ||||
| @@ -53,7 +50,7 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| std::vector<std::vector<int64_t>> batch_shape; | std::vector<std::vector<int64_t>> batch_shape; | ||||
| vector<vector<int64_t>> combined_batch; | |||||
| std::vector<std::vector<int64_t>> combined_batch; | |||||
| if (!CheckSwitchN(batch_shape, combined_batch)) { | if (!CheckSwitchN(batch_shape, combined_batch)) { | ||||
| GELOGE(FAILED, "CheckSwitchN failed."); | GELOGE(FAILED, "CheckSwitchN failed."); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -104,6 +101,7 @@ Status MultiBatchPass::ClearStatus() { | |||||
| /// | /// | ||||
| Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) { | Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) { | ||||
| const auto &func_desc = case_node->GetOpDesc(); | const auto &func_desc = case_node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(func_desc); | |||||
| if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) { | if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) { | ||||
| GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str()); | GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -114,7 +112,7 @@ Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr | |||||
| const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]); | const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]); | ||||
| GE_CHECK_NOTNULL(subgraph); | GE_CHECK_NOTNULL(subgraph); | ||||
| const string batch_label = "Batch_" + std::to_string(i); | |||||
| const std::string batch_label = "Batch_" + std::to_string(i); | |||||
| for (const auto &node : subgraph->GetDirectNode()) { | for (const auto &node : subgraph->GetDirectNode()) { | ||||
| (void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); | (void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); | ||||
| } | } | ||||
| @@ -139,12 +137,12 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor | |||||
| continue; | continue; | ||||
| } | } | ||||
| InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||||
| const auto &in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||||
| if (in_data_anchor == nullptr) { | if (in_data_anchor == nullptr) { | ||||
| GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str()); | GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor(); | |||||
| const auto &pred_input = in_data_anchor->GetPeerOutAnchor(); | |||||
| if (pred_input == nullptr) { | if (pred_input == nullptr) { | ||||
| GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str()); | GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -178,12 +176,10 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status MultiBatchPass::GetDynamicType() { | Status MultiBatchPass::GetDynamicType() { | ||||
| for (const auto &switchn : switch_n_nodes_) { | |||||
| auto switchn_desc = switchn->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(switchn_desc); | |||||
| for (const auto &switch_n : switch_n_nodes_) { | |||||
| int32_t dynamic_type = static_cast<int32_t>(FIXED); | int32_t dynamic_type = static_cast<int32_t>(FIXED); | ||||
| if (!AttrUtils::GetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) { | |||||
| GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switchn->GetName().c_str()); | |||||
| if (!AttrUtils::GetInt(switch_n->GetOpDesc(), ATTR_DYNAMIC_TYPE, dynamic_type)) { | |||||
| GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switch_n->GetName().c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (dynamic_type == static_cast<int32_t>(FIXED)) { | if (dynamic_type == static_cast<int32_t>(FIXED)) { | ||||
| @@ -191,7 +187,7 @@ Status MultiBatchPass::GetDynamicType() { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) { | if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) { | ||||
| GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switchn node should be same, while one is %d and another is %d.", | |||||
| GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switch_n node should be same, while one is %d and another is %d.", | |||||
| dynamic_type, dynamic_type_); | dynamic_type, dynamic_type_); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -212,21 +208,19 @@ Status MultiBatchPass::GetDynamicType() { | |||||
| Status MultiBatchPass::GetUserDesignateShape() { | Status MultiBatchPass::GetUserDesignateShape() { | ||||
| data_name_order_.clear(); | data_name_order_.clear(); | ||||
| bool first_check = true; | bool first_check = true; | ||||
| for (const auto &switchn : switch_n_nodes_) { | |||||
| auto switchn_desc = switchn->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(switchn_desc); | |||||
| vector<string> cur_switchn_data_name_order; | |||||
| if (!AttrUtils::GetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_switchn_data_name_order)) { | |||||
| GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switchn->GetName().c_str()); | |||||
| for (const auto &switch_n : switch_n_nodes_) { | |||||
| std::vector<std::string> cur_data_name_order; | |||||
| if (!AttrUtils::GetListStr(switch_n->GetOpDesc(), ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_data_name_order)) { | |||||
| GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switch_n->GetName().c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (first_check) { | if (first_check) { | ||||
| data_name_order_ = cur_switchn_data_name_order; | |||||
| data_name_order_ = cur_data_name_order; | |||||
| first_check = false; | first_check = false; | ||||
| } else { | } else { | ||||
| if (data_name_order_ != cur_switchn_data_name_order) { | |||||
| if (data_name_order_ != cur_data_name_order) { | |||||
| GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.", | GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.", | ||||
| switchn->GetName().c_str()); | |||||
| switch_n->GetName().c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } | } | ||||
| @@ -245,7 +239,8 @@ Status MultiBatchPass::GetUserDesignateShape() { | |||||
| /// @param [out] combined_batch | /// @param [out] combined_batch | ||||
| /// @return bool | /// @return bool | ||||
| /// | /// | ||||
| bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<vector<int64_t>> &combined_batch) { | |||||
| bool MultiBatchPass::CheckSwitchN(std::vector<std::vector<int64_t>> &batch_shape, | |||||
| std::vector<std::vector<int64_t>> &combined_batch) { | |||||
| // Check if output_num of different SwitchN is same | // Check if output_num of different SwitchN is same | ||||
| uint32_t batch_num = 0; | uint32_t batch_num = 0; | ||||
| for (const NodePtr &node : switch_n_nodes_) { | for (const NodePtr &node : switch_n_nodes_) { | ||||
| @@ -281,7 +276,8 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v | |||||
| } | } | ||||
| size_t tmp_combined_dim_num = combined_batch[i].size(); | size_t tmp_combined_dim_num = combined_batch[i].size(); | ||||
| if (combined_dim_num != tmp_combined_dim_num) { | if (combined_dim_num != tmp_combined_dim_num) { | ||||
| GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num); | |||||
| GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", | |||||
| combined_dim_num, i, tmp_combined_dim_num); | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -296,11 +292,11 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v | |||||
| /// @param [out] combined_batch | /// @param [out] combined_batch | ||||
| /// @return bool | /// @return bool | ||||
| /// | /// | ||||
| bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &batch_shape, | |||||
| vector<vector<int64_t>> &combined_batch) { | |||||
| bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, std::vector<std::vector<int64_t>> &batch_shape, | |||||
| std::vector<std::vector<int64_t>> &combined_batch) { | |||||
| // Check if output_shape of different SwitchN is same | // Check if output_shape of different SwitchN is same | ||||
| vector<vector<int64_t>> idx_batch_shape; | |||||
| vector<vector<int64_t>> idx_combined_batch; | |||||
| std::vector<std::vector<int64_t>> idx_batch_shape; | |||||
| std::vector<std::vector<int64_t>> idx_combined_batch; | |||||
| for (uint32_t i = 0; i < batch_num; i++) { | for (uint32_t i = 0; i < batch_num; i++) { | ||||
| idx_batch_shape.clear(); | idx_batch_shape.clear(); | ||||
| idx_combined_batch.clear(); | idx_combined_batch.clear(); | ||||
| @@ -310,7 +306,7 @@ bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &b | |||||
| GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str()); | GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str()); | ||||
| return false; | return false; | ||||
| } | } | ||||
| vector<int64_t> output_dims; | |||||
| std::vector<int64_t> output_dims; | |||||
| if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) { | if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) { | ||||
| GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i); | GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i); | ||||
| return false; | return false; | ||||
| @@ -385,8 +381,8 @@ Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, | Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, | ||||
| const vector<vector<int64_t>> &batch_shape, | |||||
| const vector<vector<int64_t>> &combined_batch) { | |||||
| const std::vector<std::vector<int64_t>> &batch_shape, | |||||
| const std::vector<std::vector<int64_t>> &combined_batch) { | |||||
| NodePtr pred_value_node = pred_value->GetOwnerNode(); | NodePtr pred_value_node = pred_value->GetOwnerNode(); | ||||
| // Create SwitchCase node | // Create SwitchCase node | ||||
| const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN; | const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN; | ||||
| @@ -429,31 +425,11 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s | |||||
| return false; | return false; | ||||
| } | } | ||||
| size_t num = output_shape.size(); | |||||
| size_t dim_num = output_shape[0].size(); | |||||
| for (size_t i = 1; i < num; i++) { | |||||
| size_t tmp_dim_num = output_shape[i].size(); | |||||
| if (dim_num != tmp_dim_num) { | |||||
| GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num); | |||||
| for (auto iter = output_shape.begin() + 1; iter != output_shape.end(); ++iter) { | |||||
| if (output_shape[0] != *iter) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| if (dim_num == 0) { | |||||
| return true; | |||||
| } | |||||
| for (size_t i = 0; i < dim_num; i++) { | |||||
| int64_t dim_value = output_shape[0][i]; | |||||
| for (size_t j = 1; j < num; j++) { | |||||
| int64_t tmp_dim_value = output_shape[j][i]; | |||||
| if (dim_value != tmp_dim_value) { | |||||
| GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i, | |||||
| dim_value, j, tmp_dim_value); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -468,8 +444,8 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s | |||||
| /// | /// | ||||
| NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, | NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, | ||||
| const OutDataAnchorPtr &pred_value, | const OutDataAnchorPtr &pred_value, | ||||
| const vector<vector<int64_t>> &batch_shape, | |||||
| const vector<vector<int64_t>> &combined_batch) { | |||||
| const std::vector<std::vector<int64_t>> &batch_shape, | |||||
| const std::vector<std::vector<int64_t>> &combined_batch) { | |||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN); | OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN); | ||||
| if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
| GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str()); | GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str()); | ||||
| @@ -512,7 +488,7 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const | |||||
| GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str()); | GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str()); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| const string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); | |||||
| const std::string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); | |||||
| if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) { | if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) { | ||||
| GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str()); | GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str()); | ||||
| return nullptr; | return nullptr; | ||||
| @@ -37,10 +37,6 @@ | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| const uint32_t kShapeDimSize = 1; | |||||
| const uint32_t DIM_SIZE_TWO = 2; | |||||
| } // namespace | |||||
| Status PassUtils::ConstructTensorDescWithData(const GeTensorDesc &out_desc, std::vector<int64_t> &data, | Status PassUtils::ConstructTensorDescWithData(const GeTensorDesc &out_desc, std::vector<int64_t> &data, | ||||
| std::vector<GeTensorPtr> &v_output, const bool scalar_output) { | std::vector<GeTensorPtr> &v_output, const bool scalar_output) { | ||||
| @@ -149,10 +149,10 @@ Status SubgraphPass::SubgraphOutputNode(const ComputeGraphPtr &graph, const Node | |||||
| // 5. While->NetOutput in known subgraph | // 5. While->NetOutput in known subgraph | ||||
| std::string op_type; | std::string op_type; | ||||
| bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) || | bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) || | ||||
| IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) || | |||||
| ((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)) || | |||||
| (!graph->GetGraphUnknownFlag() && NodeUtils::IsDynamicShape(node) && | |||||
| (kWhileOpTypes.count(in_node->GetType()) != 0)); | |||||
| IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) || | |||||
| ((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)) || | |||||
| (!graph->GetGraphUnknownFlag() && NodeUtils::IsDynamicShape(node) && | |||||
| (kWhileOpTypes.count(in_node->GetType()) != 0)); | |||||
| if (insert_flag) { | if (insert_flag) { | ||||
| GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); | GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); | ||||
| std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; | std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; | ||||
| @@ -72,25 +72,26 @@ Status SwitchToStreamSwitchPass::CheckCycleDependence(const ComputeGraphPtr &gra | |||||
| std::unordered_map<NodePtr, std::vector<NodePtr>> cond_switch_map; | std::unordered_map<NodePtr, std::vector<NodePtr>> cond_switch_map; | ||||
| for (const NodePtr &node : graph->GetDirectNode()) { | for (const NodePtr &node : graph->GetDirectNode()) { | ||||
| GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed."); | GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed."); | ||||
| if ((type == SWITCH) || (type == REFSWITCH)) { | |||||
| InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||||
| GE_CHECK_NOTNULL(in_cond_anchor); | |||||
| OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||||
| if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) { | |||||
| GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if ((type != SWITCH) && (type != REFSWITCH)) { | |||||
| continue; | |||||
| } | |||||
| InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||||
| GE_CHECK_NOTNULL(in_cond_anchor); | |||||
| OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||||
| if (FindSwitchCondInput(peer_out_anchor) != SUCCESS) { | |||||
| GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| NodePtr cond_node = peer_out_anchor->GetOwnerNode(); | |||||
| auto iter = cond_switch_map.find(cond_node); | |||||
| if (iter == cond_switch_map.end()) { | |||||
| cond_switch_map[cond_node] = { node }; | |||||
| } else { | |||||
| iter->second.emplace_back(node); | |||||
| } | |||||
| switch_nodes_.emplace_back(node); | |||||
| NodePtr cond_node = peer_out_anchor->GetOwnerNode(); | |||||
| auto iter = cond_switch_map.find(cond_node); | |||||
| if (iter == cond_switch_map.end()) { | |||||
| cond_switch_map[cond_node] = { node }; | |||||
| } else { | |||||
| iter->second.emplace_back(node); | |||||
| } | } | ||||
| switch_nodes_.emplace_back(node); | |||||
| } | } | ||||
| MarkCycleDependence(cond_switch_map); | MarkCycleDependence(cond_switch_map); | ||||
| @@ -241,10 +242,6 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou | |||||
| if (idx == SWITCH_DATA_INPUT) { | if (idx == SWITCH_DATA_INPUT) { | ||||
| peer_data_anchor = peer_out_anchor; | peer_data_anchor = peer_out_anchor; | ||||
| } else { | } else { | ||||
| if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) { | |||||
| GELOGE(FAILED, "Find pred_input for switch_node %s failed.", switch_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| peer_cond_anchor = peer_out_anchor; | peer_cond_anchor = peer_out_anchor; | ||||
| } | } | ||||
| } | } | ||||
| @@ -254,15 +251,14 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou | |||||
| /// | /// | ||||
| /// @brief Find Switch cond input | /// @brief Find Switch cond input | ||||
| /// @param [in] pass_switch_flag | |||||
| /// @param [out] peer_cond_anchor | /// @param [out] peer_cond_anchor | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) { | |||||
| Status SwitchToStreamSwitchPass::FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor) { | |||||
| NodePtr tmp_node = nullptr; | NodePtr tmp_node = nullptr; | ||||
| string type; | |||||
| bool need_pass_type = true; | |||||
| while (need_pass_type) { | |||||
| std::string type; | |||||
| bool pass_flag = true; | |||||
| while (pass_flag) { | |||||
| if (tmp_node == nullptr) { | if (tmp_node == nullptr) { | ||||
| tmp_node = peer_cond_anchor->GetOwnerNode(); | tmp_node = peer_cond_anchor->GetOwnerNode(); | ||||
| } else { | } else { | ||||
| @@ -274,7 +270,7 @@ Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutD | |||||
| } | } | ||||
| GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed."); | GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed."); | ||||
| need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH))); | |||||
| pass_flag = ((type == SWITCH) || (type == REFSWITCH)); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -369,7 +365,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_ | |||||
| } | } | ||||
| } else { | } else { | ||||
| int64_t switch_group_id = GetGroupId(stream_switch); | int64_t switch_group_id = GetGroupId(stream_switch); | ||||
| map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map; | |||||
| std::map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map; | |||||
| std::list<NodePtr> false_node_list; | std::list<NodePtr> false_node_list; | ||||
| std::list<NodePtr> true_node_list; | std::list<NodePtr> true_node_list; | ||||
| std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list; | std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list; | ||||
| @@ -389,7 +385,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_ | |||||
| /// @return group_id | /// @return group_id | ||||
| /// | /// | ||||
| int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | ||||
| string tailing_optimization_option; | |||||
| std::string tailing_optimization_option; | |||||
| bool is_tailing_optimization = false; | bool is_tailing_optimization = false; | ||||
| if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) { | if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) { | ||||
| // "1" means it's True from frontend option | // "1" means it's True from frontend option | ||||
| @@ -400,7 +396,7 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| string hccl_group_id; | |||||
| std::string hccl_group_id; | |||||
| if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { | if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { | ||||
| GELOGI("Node %s can not find hccl group id.", node->GetName().c_str()); | GELOGI("Node %s can not find hccl group id.", node->GetName().c_str()); | ||||
| return 0; | return 0; | ||||
| @@ -432,6 +428,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||||
| same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | ||||
| OutDataAnchorPtr peer_cond_anchor = iter->first; | OutDataAnchorPtr peer_cond_anchor = iter->first; | ||||
| GE_CHECK_NOTNULL(peer_cond_anchor); | |||||
| NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); | NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); | ||||
| GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str()); | GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str()); | ||||
| @@ -549,6 +546,7 @@ NodePtr SwitchToStreamSwitchPass::CreateCastOp(const ComputeGraphPtr &graph, con | |||||
| NodePtr cast_node = graph->AddNode(cast_desc); | NodePtr cast_node = graph->AddNode(cast_desc); | ||||
| GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed."); | GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed."); | ||||
| // Cast node has and only has one input | |||||
| GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed."); | GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed."); | ||||
| return cast_node; | return cast_node; | ||||
| @@ -614,24 +612,24 @@ Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_no | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| for (const NodePtr &in_ctl_node : switch_node->GetInControlNodes()) { | |||||
| GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), | |||||
| for (const NodePtr &in_ctrl_node : switch_node->GetInControlNodes()) { | |||||
| GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), | |||||
| "Remove ctl edge failed."); | "Remove ctl edge failed."); | ||||
| GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { | |||||
| GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||||
| GE_IF_BOOL_EXEC(!in_ctrl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { | |||||
| GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||||
| "Add ctl edge failed."); | "Add ctl edge failed."); | ||||
| }); | }); | ||||
| GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue); | |||||
| if (same_cond_switch.count(in_ctl_node) > 0) { | |||||
| GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||||
| GE_IF_BOOL_EXEC(in_ctrl_node->GetType() != STREAMSWITCH, continue); | |||||
| if (same_cond_switch.count(in_ctrl_node) > 0) { | |||||
| GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||||
| "Remove ctl edge failed."); | "Remove ctl edge failed."); | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto find_res1 = switch_node_map_.find(in_ctl_node); | |||||
| auto find_res1 = switch_node_map_.find(in_ctrl_node); | |||||
| GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { | GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { | ||||
| GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctrl_node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| }); | }); | ||||
| auto find_res2 = find_res1->second.find(orig_switch_name); | auto find_res2 = find_res1->second.find(orig_switch_name); | ||||
| @@ -42,9 +42,9 @@ namespace ge { | |||||
| +-----------+ +-----------+ | +-----------+ +-----------+ | ||||
| | Const | | VariableV2| | | Const | | VariableV2| | ||||
| +-----------+ +-----------+ | +-----------+ +-----------+ | ||||
| */ | |||||
| /* Switch branch op optimize, Switches in same case merge to one StreamSwitch, update following nodes' input | |||||
| Switch branch op optimize, Switches in same case merge to one StreamSwitch, update following nodes' input | |||||
| +-----------+ | +-----------+ | ||||
| / | task2 | \ | / | task2 | \ | ||||
| @@ -131,11 +131,10 @@ class SwitchToStreamSwitchPass : public GraphPass { | |||||
| /// | /// | ||||
| /// @brief Find Switch cond input | /// @brief Find Switch cond input | ||||
| /// @param [in] pass_switch_flag | |||||
| /// @param [out] peer_cond_anchor | /// @param [out] peer_cond_anchor | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor); | |||||
| Status FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor); | |||||
| /// | /// | ||||
| /// @brief Create StreamSwitch Node | /// @brief Create StreamSwitch Node | ||||
| @@ -70,8 +70,10 @@ std::string TransOpBreadthFusionPass::GetNodeId(const int anchor_index, const No | |||||
| trans_data_type = true; | trans_data_type = true; | ||||
| trans_format = true; | trans_format = true; | ||||
| trans_shape = true; | trans_shape = true; | ||||
| } else if (node->GetType() == RESHAPE) { | |||||
| } else if (node->GetType() == RESHAPE || node->GetType() == EXPANDDIMS || node->GetType() == SQUEEZE) { | |||||
| trans_shape = true; | trans_shape = true; | ||||
| } else if (node->GetType() == REFORMAT) { | |||||
| trans_format = true; | |||||
| } | } | ||||
| id << node->GetType() << '-' << anchor_index; | id << node->GetType() << '-' << anchor_index; | ||||