| @@ -1,8 +1,8 @@ | |||||
| [submodule "parser"] | [submodule "parser"] | ||||
| path = parser | path = parser | ||||
| url = https://gitee.com/ascend/parser.git | url = https://gitee.com/ascend/parser.git | ||||
| branch = development | |||||
| branch = r1.2.0 | |||||
| [submodule "metadef"] | [submodule "metadef"] | ||||
| path = metadef | path = metadef | ||||
| url = https://gitee.com/ascend/metadef.git | url = https://gitee.com/ascend/metadef.git | ||||
| branch = development | |||||
| branch = r1.2.0 | |||||
| @@ -16,11 +16,8 @@ endif() | |||||
| if(DEFINED ENV{D_PKG_SERVER}) | if(DEFINED ENV{D_PKG_SERVER}) | ||||
| set(GE_PB_PKG $ENV{D_PKG_SERVER}) | set(GE_PB_PKG $ENV{D_PKG_SERVER}) | ||||
| message("Download packages from DPKG server") | |||||
| elseif(DEFINED ENV{MSLIBS_SERVER}) | |||||
| set(GE_PB_PKG "http://$ENV{MSLIBS_SERVER}:8081") | |||||
| message("Download packages from MSPKG server") | |||||
| endif () | |||||
| message("Download packages from PKG server") | |||||
| endif() | |||||
| set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) | set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) | ||||
| set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) | set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) | ||||
| @@ -85,8 +82,6 @@ if (ENABLE_OPEN_SRC) | |||||
| find_module(error_manager_static liberror_manager.a ${GE_LIB_PATH}) | find_module(error_manager_static liberror_manager.a ${GE_LIB_PATH}) | ||||
| find_module(msprofiler_fwk libmsprofiler_fwk.a ${GE_LIB_PATH}) | find_module(msprofiler_fwk libmsprofiler_fwk.a ${GE_LIB_PATH}) | ||||
| #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | ||||
| elseif(ENABLE_GE_COV OR ENABLE_GE_UT) | |||||
| add_subdirectory(tests) | |||||
| else() | else() | ||||
| find_module(slog libslog.so ${ASCEND_ATC_DIR}) | find_module(slog libslog.so ${ASCEND_ATC_DIR}) | ||||
| find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) | find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) | ||||
| @@ -110,7 +105,7 @@ if (ENABLE_OPEN_SRC) | |||||
| find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | ||||
| find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) | find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) | ||||
| find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) | find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) | ||||
| #find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) | |||||
| #find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) | |||||
| if(PRODUCT STREQUAL "flr3") | if(PRODUCT STREQUAL "flr3") | ||||
| elseif(PRODUCT STREQUAL "flr1") | elseif(PRODUCT STREQUAL "flr1") | ||||
| find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | ||||
| @@ -120,7 +115,7 @@ if (ENABLE_OPEN_SRC) | |||||
| find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) | find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) | ||||
| endif() | endif() | ||||
| elseif(PLATFORM STREQUAL "all") | elseif(PLATFORM STREQUAL "all") | ||||
| find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) | |||||
| find_module(msprofiler libmsprofiler.a ${ASCEND_DRIVER_COMMON_DIR}) | |||||
| find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) | find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) | ||||
| find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) | find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) | ||||
| find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) | find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) | ||||
| @@ -128,12 +123,17 @@ if (ENABLE_OPEN_SRC) | |||||
| find_module(resource libresource.so ${ASCEND_ATC_DIR}) | find_module(resource libresource.so ${ASCEND_ATC_DIR}) | ||||
| find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | ||||
| find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) | find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) | ||||
| find_module(msprofiler_fwk libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) | |||||
| find_module(msprofiler_fwk libmsprofiler_fwk.a ${ASCEND_ACL_DIR}) | |||||
| find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | ||||
| #find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) | #find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) | ||||
| else() | else() | ||||
| message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") | |||||
| message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") | |||||
| endif() | endif() | ||||
| if (ENABLE_GE_COV OR ENABLE_GE_UT) | |||||
| add_subdirectory(tests) | |||||
| endif() | |||||
| endif() | endif() | ||||
| set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | ||||
| @@ -224,14 +224,12 @@ if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then | |||||
| # fi | # fi | ||||
| # if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then | # if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then | ||||
| echo "Generating coverage statistics, please wait..." | |||||
| cd ${BASEPATH} | |||||
| rm -rf ${BASEPATH}/cov | |||||
| mkdir ${BASEPATH}/cov | |||||
| lcov -c -d build/tests/ut/ge -d build/tests/ut/common/graph/ -o cov/tmp.info | |||||
| lcov --remove cov/tmp.info '*/output/*' '*/build/opensrc/*' '*/build/proto/*' '*/third_party/*' '*/tests/*' '/usr/local/*' -o cov/coverage.info | |||||
| cd ${BASEPATH}/cov | |||||
| genhtml coverage.info | |||||
| # echo "Generating coverage statistics, please wait..." | |||||
| # cd ${BASEPATH} | |||||
| # rm -rf ${BASEPATH}/cov | |||||
| # mkdir ${BASEPATH}/cov | |||||
| # gcovr -r ./ --exclude 'third_party' --exclude 'build' --exclude 'tests' --exclude 'prebuild' --exclude 'inc' --print-summary --html --html-details -d -o cov/index.html | |||||
| # fi | |||||
| fi | fi | ||||
| # generate output package in tar form, including ut/st libraries/executables | # generate output package in tar form, including ut/st libraries/executables | ||||
| @@ -21,7 +21,7 @@ function(find_module module name) | |||||
| if ("${${module}_LIBRARY_DIR}" STREQUAL "${module}_LIBRARY_DIR-NOTFOUND") | if ("${${module}_LIBRARY_DIR}" STREQUAL "${module}_LIBRARY_DIR-NOTFOUND") | ||||
| message(FATAL_ERROR "${name} not found in ${path}") | message(FATAL_ERROR "${name} not found in ${path}") | ||||
| endif() | endif() | ||||
| add_library(${module} SHARED IMPORTED) | add_library(${module} SHARED IMPORTED) | ||||
| set_target_properties(${module} PROPERTIES | set_target_properties(${module} PROPERTIES | ||||
| IMPORTED_LOCATION ${${module}_LIBRARY_DIR} | IMPORTED_LOCATION ${${module}_LIBRARY_DIR} | ||||
| @@ -23,7 +23,6 @@ ExternalProject_Add(gflags_build | |||||
| URL ${REQ_URL} | URL ${REQ_URL} | ||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | ||||
| #SOURCE_DIR ${GE_CODE_DIR}/../../third_party/gflags/src/gflags-2.2.2 | #SOURCE_DIR ${GE_CODE_DIR}/../../third_party/gflags/src/gflags-2.2.2 | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gflags_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR> | CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gflags_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR> | ||||
| BUILD_COMMAND $(MAKE) | BUILD_COMMAND $(MAKE) | ||||
| INSTALL_COMMAND $(MAKE) install | INSTALL_COMMAND $(MAKE) install | ||||
| @@ -10,10 +10,7 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | ||||
| endif() | endif() | ||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/ge_gtest/release-1.8.0.tar.gz") | |||||
| set(MD5 "") | |||||
| elseif (ENABLE_GITEE) | |||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.0.tar.gz") | set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.0.tar.gz") | ||||
| set(MD5 "") | set(MD5 "") | ||||
| else() | else() | ||||
| @@ -25,9 +22,8 @@ set (gtest_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack- | |||||
| set (gtest_CFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") | set (gtest_CFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") | ||||
| ExternalProject_Add(gtest_build | ExternalProject_Add(gtest_build | ||||
| URL ${REQ_URL} | URL ${REQ_URL} | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gtest_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gtest <SOURCE_DIR> | CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gtest_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gtest <SOURCE_DIR> | ||||
| -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON | |||||
| -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON | |||||
| BUILD_COMMAND $(MAKE) | BUILD_COMMAND $(MAKE) | ||||
| INSTALL_COMMAND $(MAKE) install | INSTALL_COMMAND $(MAKE) install | ||||
| EXCLUDE_FROM_ALL TRUE | EXCLUDE_FROM_ALL TRUE | ||||
| @@ -5,24 +5,19 @@ endif() | |||||
| include(ExternalProject) | include(ExternalProject) | ||||
| set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include) | set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include) | ||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/ge_nlohmann_json/include.zip") | |||||
| set(MD5 "0dc903888211db3a0f170304cd9f3a89") | |||||
| set(JSON_INCLUDE_DIR ${JSON_SRC_DIR}) | |||||
| #elseif (ENABLE_GITEE) | |||||
| #if (ENABLE_GITEE) | |||||
| # set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip") | # set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip") | ||||
| # set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7") | # set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7") | ||||
| #set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip") | |||||
| set(MD5 "0dc903888211db3a0f170304cd9f3a89") | |||||
| set(JSON_INCLUDE_DIR ${JSON_SRC_DIR}) | |||||
| endif () | |||||
| # set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include") | |||||
| #else() | |||||
| set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip") | |||||
| set(MD5 "0dc903888211db3a0f170304cd9f3a89") | |||||
| set(JSON_INCLUDE_DIR ${JSON_SRC_DIR}) | |||||
| #endif () | |||||
| ExternalProject_Add(json_build | ExternalProject_Add(json_build | ||||
| URL ${REQ_URL} | URL ${REQ_URL} | ||||
| #URL /home/txd/workspace/cloud_code/pkg/include.zip | #URL /home/txd/workspace/cloud_code/pkg/include.zip | ||||
| SOURCE_DIR ${JSON_SRC_DIR} | SOURCE_DIR ${JSON_SRC_DIR} | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND "" | CONFIGURE_COMMAND "" | ||||
| BUILD_COMMAND "" | BUILD_COMMAND "" | ||||
| INSTALL_COMMAND "" | INSTALL_COMMAND "" | ||||
| @@ -6,10 +6,7 @@ set(ONNX_PROTO_DIR ${CMAKE_BINARY_DIR}/onnx) | |||||
| set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) | set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) | ||||
| file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) | file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) | ||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/onnx/onnx-1.6.0.tar.gz") | |||||
| set(MD5 "512f2779d6215d4a36f366b6b9acdf1e") | |||||
| elseif (ENABLE_GITEE) | |||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz") | set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz") | ||||
| set(MD5 "1bdbcecdd68ea8392630467646776e02") | set(MD5 "1bdbcecdd68ea8392630467646776e02") | ||||
| else() | else() | ||||
| @@ -22,7 +19,6 @@ ExternalProject_Add(onnx | |||||
| #URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz | #URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz | ||||
| #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 | #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 | ||||
| #SOURCE_DIR ${ONNX_SRC_DIR} | #SOURCE_DIR ${ONNX_SRC_DIR} | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND "" | CONFIGURE_COMMAND "" | ||||
| BUILD_COMMAND "" | BUILD_COMMAND "" | ||||
| #INSTALL_COMMAND "" | #INSTALL_COMMAND "" | ||||
| @@ -26,7 +26,6 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst | |||||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | ||||
| ExternalProject_Add(protobuf_build | ExternalProject_Add(protobuf_build | ||||
| URL ${REQ_URL} | URL ${REQ_URL} | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} | CONFIGURE_COMMAND ${CMAKE_COMMAND} | ||||
| -Dprotobuf_WITH_ZLIB=OFF | -Dprotobuf_WITH_ZLIB=OFF | ||||
| -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} | -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} | ||||
| @@ -27,7 +27,6 @@ ExternalProject_Add(protobuf_static_build | |||||
| URL ${REQ_URL} | URL ${REQ_URL} | ||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | ||||
| #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 | #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} | CONFIGURE_COMMAND ${CMAKE_COMMAND} | ||||
| -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} | -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} | ||||
| -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} | -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} | ||||
| @@ -30,7 +30,6 @@ ExternalProject_Add(protoc_build | |||||
| URL ${REQ_URL} | URL ${REQ_URL} | ||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | ||||
| #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 | #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake | CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake | ||||
| BUILD_COMMAND $(MAKE) | BUILD_COMMAND $(MAKE) | ||||
| INSTALL_COMMAND $(MAKE) install | INSTALL_COMMAND $(MAKE) install | ||||
| @@ -10,20 +10,11 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | ||||
| endif() | endif() | ||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/securec/v1.1.10.tar.gz") | |||||
| set(MD5 "") | |||||
| else() | |||||
| set(REQ_URL "https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz") | |||||
| set(MD5 "") | |||||
| endif () | |||||
| ExternalProject_Add(c_sec_build | ExternalProject_Add(c_sec_build | ||||
| URL ${REQ_URL} | |||||
| #URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz | |||||
| URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz | |||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||||
| #SOURCE_DIR ${GE_CODE_DIR}/../libc_sec | #SOURCE_DIR ${GE_CODE_DIR}/../libc_sec | ||||
| PATCH_COMMAND patch -p1 < ${GE_CODE_DIR}/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch | PATCH_COMMAND patch -p1 < ${GE_CODE_DIR}/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch | ||||
| TLS_VERIFY OFF | |||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} | CONFIGURE_COMMAND ${CMAKE_COMMAND} | ||||
| -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} | -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} | ||||
| -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} | -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} | ||||
| @@ -157,8 +157,6 @@ set(TRAIN_SRC_LIST | |||||
| "graph/passes/compile_nodes_pass.cc" | "graph/passes/compile_nodes_pass.cc" | ||||
| "graph/passes/constant_folding_pass.cc" | "graph/passes/constant_folding_pass.cc" | ||||
| "graph/passes/constant_fuse_same_pass.cc" | "graph/passes/constant_fuse_same_pass.cc" | ||||
| "graph/passes/remove_same_const_pass.cc" | |||||
| "graph/passes/useless_control_out_remove_pass.cc" | |||||
| "graph/passes/control_trigger_pass.cc" | "graph/passes/control_trigger_pass.cc" | ||||
| "graph/passes/dimension_adjust_pass.cc" | "graph/passes/dimension_adjust_pass.cc" | ||||
| "graph/passes/dimension_compute_pass.cc" | "graph/passes/dimension_compute_pass.cc" | ||||
| @@ -524,8 +522,6 @@ set(INFER_SRC_LIST | |||||
| "graph/passes/assign_pass.cc" | "graph/passes/assign_pass.cc" | ||||
| "graph/passes/addn_pass.cc" | "graph/passes/addn_pass.cc" | ||||
| "graph/passes/common_subexpression_elimination_pass.cc" | "graph/passes/common_subexpression_elimination_pass.cc" | ||||
| "graph/passes/remove_same_const_pass.cc" | |||||
| "graph/passes/useless_control_out_remove_pass.cc" | |||||
| "graph/passes/transop_symmetry_elimination_pass.cc" | "graph/passes/transop_symmetry_elimination_pass.cc" | ||||
| "graph/passes/save_pass.cc" | "graph/passes/save_pass.cc" | ||||
| "graph/passes/switch_dead_branch_elimination.cc" | "graph/passes/switch_dead_branch_elimination.cc" | ||||
| @@ -611,7 +607,7 @@ set(INFER_SRC_LIST | |||||
| if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) | if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) | ||||
| ############ libge_runner.so ############ | ############ libge_runner.so ############ | ||||
| add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) | |||||
| add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS} $<TARGET_OBJECTS:msprofiler_fwk>) | |||||
| target_compile_definitions(ge_runner PRIVATE | target_compile_definitions(ge_runner PRIVATE | ||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | PROTOBUF_INLINE_NOT_IN_HEADERS=0 | ||||
| @@ -652,14 +648,11 @@ target_include_directories(ge_runner PRIVATE | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ||||
| ) | ) | ||||
| target_link_libraries(ge_runner PRIVATE | |||||
| target_link_libraries(ge_runner | |||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ge_memory | ge_memory | ||||
| adump_server | adump_server | ||||
| static_mmpa | static_mmpa | ||||
| -Wl,--whole-archive | |||||
| msprofiler_fwk | |||||
| -Wl,--no-whole-archive | |||||
| -Wl,--no-as-needed | -Wl,--no-as-needed | ||||
| graph | graph | ||||
| ge_common | ge_common | ||||
| @@ -719,7 +712,7 @@ target_include_directories(ge_compiler PRIVATE | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ||||
| ) | ) | ||||
| target_link_libraries(ge_compiler PRIVATE | |||||
| target_link_libraries(ge_compiler | |||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ge_memory | ge_memory | ||||
| static_mmpa | static_mmpa | ||||
| @@ -773,14 +766,7 @@ target_link_options(opensrc_ascendcl PRIVATE | |||||
| -Wl,--allow-multiple-definition | -Wl,--allow-multiple-definition | ||||
| -Wl,-z,muldefs | -Wl,-z,muldefs | ||||
| -Wl,-Bsymbolic | -Wl,-Bsymbolic | ||||
| -Wl,--exclude-libs,libascend_protobuf.a | |||||
| -Wl,--exclude-libs,libge_executor.a | |||||
| -Wl,--exclude-libs,libge_common.a | |||||
| -Wl,--exclude-libs,libgraph.a | |||||
| -Wl,--exclude-libs,libmmpa.a | |||||
| -Wl,--exclude-libs,libregister.a | |||||
| -Wl,--exclude-libs,liberror_manager.a | |||||
| -Wl,--exclude-libs,libadump_server.a | |||||
| -Wl,--exclude-libs,ALL | |||||
| ) | ) | ||||
| target_link_libraries(opensrc_ascendcl PRIVATE | target_link_libraries(opensrc_ascendcl PRIVATE | ||||
| -Wl,--whole-archive | -Wl,--whole-archive | ||||
| @@ -94,9 +94,6 @@ Status DumpOp::DumpOutput(aicpu::dump::Task &task) { | |||||
| for (auto dim : output_descs.at(i).GetShape().GetDims()) { | for (auto dim : output_descs.at(i).GetShape().GetDims()) { | ||||
| output.mutable_shape()->add_dim(dim); | output.mutable_shape()->add_dim(dim); | ||||
| } | } | ||||
| for (auto dim : output_descs.at(i).GetOriginShape().GetDims()) { | |||||
| output.mutable_origin_shape()->add_dim(dim); | |||||
| } | |||||
| int64_t output_size = 0; | int64_t output_size = 0; | ||||
| if (TensorUtils::GetTensorSizeInBytes(output_descs.at(i), output_size) != SUCCESS) { | if (TensorUtils::GetTensorSizeInBytes(output_descs.at(i), output_size) != SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "Get output size filed"); | GELOGE(PARAM_INVALID, "Get output size filed"); | ||||
| @@ -121,9 +118,6 @@ Status DumpOp::DumpInput(aicpu::dump::Task &task) { | |||||
| for (auto dim : input_descs.at(i).GetShape().GetDims()) { | for (auto dim : input_descs.at(i).GetShape().GetDims()) { | ||||
| input.mutable_shape()->add_dim(dim); | input.mutable_shape()->add_dim(dim); | ||||
| } | } | ||||
| for (auto dim : input_descs.at(i).GetOriginShape().GetDims()) { | |||||
| input.mutable_origin_shape()->add_dim(dim); | |||||
| } | |||||
| int64_t input_size = 0; | int64_t input_size = 0; | ||||
| if (TensorUtils::GetTensorSizeInBytes(input_descs.at(i), input_size) != SUCCESS) { | if (TensorUtils::GetTensorSizeInBytes(input_descs.at(i), input_size) != SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "Get output size filed"); | GELOGE(PARAM_INVALID, "Get output size filed"); | ||||
| @@ -220,15 +214,8 @@ Status DumpOp::LaunchDumpOp() { | |||||
| SetOpMappingLoopAddr(global_step_, loop_per_iter_, loop_cond_, op_mapping_info); | SetOpMappingLoopAddr(global_step_, loop_per_iter_, loop_cond_, op_mapping_info); | ||||
| GELOGI("Dump step is %s ,dump path is %s ,in Launch dump op", dump_properties_.GetDumpStep().c_str(), | GELOGI("Dump step is %s ,dump path is %s ,in Launch dump op", dump_properties_.GetDumpStep().c_str(), | ||||
| dump_path.c_str()); | dump_path.c_str()); | ||||
| uint32_t task_id = 0; | |||||
| uint32_t stream_id = 0; | |||||
| rt_ret = rtGetTaskIdAndStreamID(&task_id, &stream_id); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGW("call rtGetTaskIdAndStreamID failed, ret = 0x%X", rt_ret); | |||||
| } | |||||
| aicpu::dump::Task task; | aicpu::dump::Task task; | ||||
| task.set_task_id(task_id); | |||||
| task.set_stream_id(stream_id); | |||||
| task.mutable_op()->set_op_name(op_desc_->GetName()); | task.mutable_op()->set_op_name(op_desc_->GetName()); | ||||
| task.mutable_op()->set_op_type(op_desc_->GetType()); | task.mutable_op()->set_op_type(op_desc_->GetType()); | ||||
| if (dump_properties_.GetDumpMode() == kDumpOutput) { | if (dump_properties_.GetDumpMode() == kDumpOutput) { | ||||
| @@ -184,7 +184,7 @@ void TBEPluginManager::LoadCustomOpLib() { | |||||
| std::string fmk_type = std::to_string(domi::TENSORFLOW); | std::string fmk_type = std::to_string(domi::TENSORFLOW); | ||||
| auto it = options_.find(ge::FRAMEWORK_TYPE); | auto it = options_.find(ge::FRAMEWORK_TYPE); | ||||
| if (it != options_.end()) { | if (it != options_.end()) { | ||||
| fmk_type = it->second; | |||||
| fmk_type = it->second; | |||||
| } | } | ||||
| std::vector<OpRegistrationData> registration_datas = domi::OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> registration_datas = domi::OpRegistry::Instance()->registrationDatas; | ||||
| GELOGI("The size of registration_datas is: %zu", registration_datas.size()); | GELOGI("The size of registration_datas is: %zu", registration_datas.size()); | ||||
| @@ -192,7 +192,7 @@ void TBEPluginManager::LoadCustomOpLib() { | |||||
| if (std::to_string(reg_data.GetFrameworkType()) == fmk_type) { | if (std::to_string(reg_data.GetFrameworkType()) == fmk_type) { | ||||
| GELOGD("Begin to register optype: %s, imply_type: %s", reg_data.GetOmOptype().c_str(), | GELOGD("Begin to register optype: %s, imply_type: %s", reg_data.GetOmOptype().c_str(), | ||||
| TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str()); | TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str()); | ||||
| (void)domi::OpRegistry::Instance()->Register(reg_data); | |||||
| domi::OpRegistry::Instance()->Register(reg_data); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -182,7 +182,7 @@ ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t le | |||||
| command.module_index = prof_config_param->profSwitch; | command.module_index = prof_config_param->profSwitch; | ||||
| } | } | ||||
| GELOGI("GE commandhandle execute, Command Type: %s, data type config: 0x%llx", iter->second.c_str(), | GELOGI("GE commandhandle execute, Command Type: %s, data type config: 0x%llx", iter->second.c_str(), | ||||
| command.module_index); | |||||
| command.module_index); | |||||
| if (type == kProfCommandhandleStart || type == kProfCommandhandleStop) { | if (type == kProfCommandhandleStart || type == kProfCommandhandleStop) { | ||||
| GELOGI("Profiling device nums:%s , deviceID:[%s]", prof_params[0].c_str(), prof_params[kDeviceListIndex].c_str()); | GELOGI("Profiling device nums:%s , deviceID:[%s]", prof_params[0].c_str(), prof_params[kDeviceListIndex].c_str()); | ||||
| } | } | ||||
| @@ -38,8 +38,10 @@ const std::string kProfModelUnsubscribe = "prof_model_cancel_subscribe"; | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| ProfilingManager::ProfilingManager() | |||||
| : is_load_profiling_(false), is_execute_profiling_(false), is_training_trace_(false), subscribe_count_(0) { | |||||
| ProfilingManager::ProfilingManager() : is_load_profiling_(false), | |||||
| is_execute_profiling_(false), | |||||
| is_training_trace_(false), | |||||
| subscribe_count_(0) { | |||||
| prof_cb_.msprofCtrlCallback = nullptr; | prof_cb_.msprofCtrlCallback = nullptr; | ||||
| prof_cb_.msprofReporterCallback = nullptr; | prof_cb_.msprofReporterCallback = nullptr; | ||||
| } | } | ||||
| @@ -100,8 +102,8 @@ ge::Status ProfilingManager::InitFromOptions(const Options &options, MsprofGeOpt | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| is_execute_profiling_ = true; | is_execute_profiling_ = true; | ||||
| GELOGI("The profiling in options is %s, %s. origin option: %s", options.profiling_mode.c_str(), prof_conf.options, | |||||
| options.profiling_options.c_str()); | |||||
| GELOGI("The profiling in options is %s, %s. origin option: %s", options.profiling_mode.c_str(), | |||||
| prof_conf.options, options.profiling_options.c_str()); | |||||
| } else { | } else { | ||||
| (void)mmGetEnv("PROFILING_MODE", env_profiling_mode, MMPA_MAX_PATH); | (void)mmGetEnv("PROFILING_MODE", env_profiling_mode, MMPA_MAX_PATH); | ||||
| (void)mmGetEnv("PROFILING_OPTIONS", prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX); | (void)mmGetEnv("PROFILING_OPTIONS", prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX); | ||||
| @@ -141,9 +143,6 @@ ge::Status ProfilingManager::ParseOptions(const std::string &options) { | |||||
| } | } | ||||
| try { | try { | ||||
| Json prof_options = Json::parse(options); | Json prof_options = Json::parse(options); | ||||
| if (options.find(kTrainingTrace) == std::string::npos) { | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| const std::string training_trace = prof_options[kTrainingTrace]; | const std::string training_trace = prof_options[kTrainingTrace]; | ||||
| if (training_trace.empty()) { | if (training_trace.empty()) { | ||||
| GELOGI("Training trace will not take effect."); | GELOGI("Training trace will not take effect."); | ||||
| @@ -212,16 +211,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin | |||||
| uint32_t block_dim = task.block_dim; | uint32_t block_dim = task.block_dim; | ||||
| uint32_t task_id = task.task_id; | uint32_t task_id = task.task_id; | ||||
| uint32_t stream_id = task.stream_id; | uint32_t stream_id = task.stream_id; | ||||
| std::string shape_type = task.shape_type; | |||||
| int64_t cur_iter_num = task.cur_iter_num; | |||||
| data = model_name.append(" ") | data = model_name.append(" ") | ||||
| .append(op_name).append(" ") | .append(op_name).append(" ") | ||||
| .append(std::to_string(block_dim)).append(" ") | |||||
| .append(std::to_string(block_dim).append(" ") | |||||
| .append(std::to_string(task_id)).append(" ") | .append(std::to_string(task_id)).append(" ") | ||||
| .append(std::to_string(stream_id)).append(" ") | .append(std::to_string(stream_id)).append(" ") | ||||
| .append(std::to_string(model_id)).append(" ") | |||||
| .append(shape_type).append(" ") | |||||
| .append(std::to_string(cur_iter_num)).append("\n"); | |||||
| .append(std::to_string(model_id)).append("\n")); | |||||
| ReporterData reporter_data{}; | ReporterData reporter_data{}; | ||||
| reporter_data.deviceId = device_id; | reporter_data.deviceId = device_id; | ||||
| @@ -846,7 +841,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpP | |||||
| return; | return; | ||||
| } | } | ||||
| } | } | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -15,7 +15,6 @@ message Output { | |||||
| int32 original_output_data_type = 7; | int32 original_output_data_type = 7; | ||||
| int32 original_output_format = 8; | int32 original_output_format = 8; | ||||
| uint64 size = 9; | uint64 size = 9; | ||||
| Shape origin_shape = 10; | |||||
| } | } | ||||
| message Input { | message Input { | ||||
| @@ -24,7 +23,6 @@ message Input { | |||||
| Shape shape = 3; | Shape shape = 3; | ||||
| uint64 address = 4; | uint64 address = 4; | ||||
| uint64 size = 5; | uint64 size = 5; | ||||
| Shape origin_shape = 6; | |||||
| } | } | ||||
| enum BufferType { | enum BufferType { | ||||
| @@ -209,33 +209,19 @@ bool IsDynmaicDimsSizeMatchModel(const vector<uint64_t> cur_dynamic_dims, | |||||
| namespace ge { | namespace ge { | ||||
| bool GeExecutor::isInit_ = false; | bool GeExecutor::isInit_ = false; | ||||
| static void InitOpsProtoManger() { | |||||
| string opsproto_path; | |||||
| const char *path_env = std::getenv("ASCEND_OPP_PATH"); | |||||
| if (path_env != nullptr) { | |||||
| string path = path_env; | |||||
| string file_path = RealPath(path.c_str()); | |||||
| if (file_path.empty()) { | |||||
| GELOGE(FAILED, "File path %s is invalid.", path.c_str()); | |||||
| return; | |||||
| class ModelListenerAdapter : public ModelListener { | |||||
| public: | |||||
| domi::Status OnComputeDone(uint32_t model_id, uint32_t dataIndex, uint32_t resultCode, | |||||
| std::vector<ge::OutputTensorInfo> &outputs) { | |||||
| if (listener == nullptr) { | |||||
| GELOGE(ge::FAILED, "listener is null."); | |||||
| return FAILED; | |||||
| } | } | ||||
| opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/"); | |||||
| GELOGI("Get opsproto so path from env : %s", path.c_str()); | |||||
| } else { | |||||
| string path_base = PluginManager::GetPath(); | |||||
| GELOGI("path_base is %s", path_base.c_str()); | |||||
| path_base = path_base.substr(0, path_base.rfind('/')); | |||||
| path_base = path_base.substr(0, path_base.rfind('/') + 1); | |||||
| opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); | |||||
| } | |||||
| GELOGI("Get opsproto path is %s", opsproto_path.c_str()); | |||||
| OpsProtoManager *manager = OpsProtoManager::Instance(); | |||||
| map<string, string> option_tmp; | |||||
| option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | |||||
| (void)manager->Initialize(option_tmp); | |||||
| } | |||||
| return listener->OnComputeDone(model_id, dataIndex, resultCode, outputs); | |||||
| } | |||||
| std::shared_ptr<ge::ModelListener> listener; | |||||
| }; | |||||
| GeExecutor::GeExecutor() {} | GeExecutor::GeExecutor() {} | ||||
| @@ -246,16 +232,6 @@ Status GeExecutor::Initialize() { | |||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| OpTilingManager::GetInstance().LoadSo(); | |||||
| Status initHostCpuEngineStatus = HostCpuEngine::GetInstance().Initialize(); | |||||
| if (initHostCpuEngineStatus != SUCCESS) { | |||||
| GELOGE(initHostCpuEngineStatus, "Failed to initialize HostCpuEngine"); | |||||
| return initHostCpuEngineStatus; | |||||
| } | |||||
| InitOpsProtoManger(); | |||||
| std::vector<rtMemType_t> mem_type(1, RT_MEMORY_HBM); | std::vector<rtMemType_t> mem_type(1, RT_MEMORY_HBM); | ||||
| mem_type.push_back(RT_MEMORY_P2P_DDR); | mem_type.push_back(RT_MEMORY_P2P_DDR); | ||||
| auto ret = MemManager::Instance().Initialize(mem_type); | auto ret = MemManager::Instance().Initialize(mem_type); | ||||
| @@ -560,6 +536,60 @@ Status GeExecutor::SetDynamicAippData(uint32_t model_id, void *dynamic_input_add | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| // Load model | |||||
| Status GeExecutor::LoadModelOffline(uint32_t &model_id, const std::string &path, const std::string &key, | |||||
| int32_t priority, std::shared_ptr<ge::ModelListener> listener) { | |||||
| GELOGI("load model offline begin."); | |||||
| if (!isInit_) { | |||||
| GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); | |||||
| return ACL_ERROR_GE_EXEC_NOT_INIT; | |||||
| } | |||||
| string filePath = RealPath(path.c_str()); | |||||
| if (filePath.empty()) { | |||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID, | |||||
| "File path is invalid. please check your text file '%s'.", path.c_str()); | |||||
| return ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID; | |||||
| } | |||||
| std::shared_ptr<ModelListenerAdapter> listener_adapter = MakeShared<ModelListenerAdapter>(); | |||||
| if (listener_adapter == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "ModelListenerAdapter make shared failed!"); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| listener_adapter->listener = listener; | |||||
| Status ret = GraphLoader::LoadModelFromFile(path, key, priority, listener_adapter, model_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[GeExecutor] LoadModelFromFile failed"); | |||||
| return ACL_ERROR_GE_LOAD_MODEL; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GeExecutor::LoadModel(uint32_t &model_id, const ModelData &model_data, | |||||
| std::shared_ptr<ge::ModelListener> listener) { | |||||
| GELOGI("Load model begin."); | |||||
| if (!isInit_) { | |||||
| GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); | |||||
| return ACL_ERROR_GE_EXEC_NOT_INIT; | |||||
| } | |||||
| std::shared_ptr<ModelListenerAdapter> listener_adapter = MakeShared<ModelListenerAdapter>(); | |||||
| if (listener_adapter == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "ModelListenerAdapter make shared failed!"); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| listener_adapter->listener = listener; | |||||
| Status ret = GraphLoader::LoadModel(model_data, listener_adapter, model_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[GeExecutor] LoadModel failed."); | |||||
| return ACL_ERROR_GE_LOAD_MODEL; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| Status GeExecutor::UnloadModel(uint32_t model_id) { | Status GeExecutor::UnloadModel(uint32_t model_id) { | ||||
| GELOGD("unload model %u begin.", model_id); | GELOGD("unload model %u begin.", model_id); | ||||
| if (!isInit_) { | if (!isInit_) { | ||||
| @@ -592,6 +622,21 @@ Status GeExecutor::UnloadModel(uint32_t model_id) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GeExecutor::RunModel(const ge::RunModelData &input_data, ge::RunModelData &output_data) { | |||||
| GELOGI("run model begin."); | |||||
| if (!isInit_) { | |||||
| GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); | |||||
| return ACL_ERROR_GE_EXEC_NOT_INIT; | |||||
| } | |||||
| InputData inputs; | |||||
| GetDomiInputData(input_data, inputs); | |||||
| OutputData outputs; | |||||
| GetDomiOutputData(output_data, outputs); | |||||
| return GraphExecutor::DataInput(inputs, outputs); | |||||
| } | |||||
| // Get input and output descriptor | // Get input and output descriptor | ||||
| Status GeExecutor::GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | Status GeExecutor::GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | ||||
| std::vector<ge::TensorDesc> &output_desc, bool new_model_desc) { | std::vector<ge::TensorDesc> &output_desc, bool new_model_desc) { | ||||
| @@ -15,7 +15,6 @@ message Output { | |||||
| int32 original_output_data_type = 7; | int32 original_output_data_type = 7; | ||||
| int32 original_output_format = 8; | int32 original_output_format = 8; | ||||
| uint64 size = 9; | uint64 size = 9; | ||||
| Shape origin_shape = 10; | |||||
| } | } | ||||
| message Input { | message Input { | ||||
| @@ -24,7 +23,6 @@ message Input { | |||||
| Shape shape = 3; | Shape shape = 3; | ||||
| uint64 address = 4; | uint64 address = 4; | ||||
| uint64 size = 5; | uint64 size = 5; | ||||
| Shape origin_shape = 6; | |||||
| } | } | ||||
| enum BufferType { | enum BufferType { | ||||
| @@ -191,8 +191,6 @@ OMG_HOST_SRC_FILES := \ | |||||
| graph/passes/control_trigger_pass.cc \ | graph/passes/control_trigger_pass.cc \ | ||||
| graph/passes/cond_pass.cc \ | graph/passes/cond_pass.cc \ | ||||
| graph/passes/cond_remove_pass.cc \ | graph/passes/cond_remove_pass.cc \ | ||||
| graph/passes/remove_same_const_pass.cc \ | |||||
| graph/passes/useless_control_out_remove_pass.cc \ | |||||
| graph/passes/for_pass.cc \ | graph/passes/for_pass.cc \ | ||||
| graph/passes/enter_pass.cc \ | graph/passes/enter_pass.cc \ | ||||
| graph/passes/assign_pass.cc \ | graph/passes/assign_pass.cc \ | ||||
| @@ -39,7 +39,7 @@ namespace { | |||||
| } \ | } \ | ||||
| ge_tensor = MakeShared<GeTensor>(out_desc); \ | ge_tensor = MakeShared<GeTensor>(out_desc); \ | ||||
| GE_CHECK_NOTNULL(ge_tensor); \ | GE_CHECK_NOTNULL(ge_tensor); \ | ||||
| GELOGD("node:%s allocate output %zu success, size=%lld", op_desc->GetName().c_str(), i, data_num * sizeof(TYPE));\ | |||||
| GELOGI("node:%s allocate output %zu success, size=%lld", op_desc->GetName().c_str(), i, data_num * sizeof(TYPE));\ | |||||
| if (ge_tensor->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_num * sizeof(TYPE)) != GRAPH_SUCCESS) { \ | if (ge_tensor->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_num * sizeof(TYPE)) != GRAPH_SUCCESS) { \ | ||||
| GELOGE(MEMALLOC_FAILED, "Set data for output %zu of node %s failed.", i, op_desc->GetName().c_str()); \ | GELOGE(MEMALLOC_FAILED, "Set data for output %zu of node %s failed.", i, op_desc->GetName().c_str()); \ | ||||
| return MEMALLOC_FAILED; \ | return MEMALLOC_FAILED; \ | ||||
| @@ -50,7 +50,8 @@ namespace { | |||||
| } else { \ | } else { \ | ||||
| ge_tensor = outputs[i]; \ | ge_tensor = outputs[i]; \ | ||||
| GE_CHECK_NOTNULL(ge_tensor); \ | GE_CHECK_NOTNULL(ge_tensor); \ | ||||
| GELOGD("node:%s existed output %zu", op_desc->GetName().c_str(), i); \ | |||||
| GELOGI("node:%s existed output %zu, addr=%p, size=%lld", op_desc->GetName().c_str(), i, \ | |||||
| reinterpret_cast<const uint8_t *>(ge_tensor->GetData().data()), ge_tensor->GetData().size()); \ | |||||
| } \ | } \ | ||||
| auto tensor = TensorAdapter::AsTensor(*ge_tensor); \ | auto tensor = TensorAdapter::AsTensor(*ge_tensor); \ | ||||
| auto tensor_name = op_desc->GetOutputNameByIndex(i); \ | auto tensor_name = op_desc->GetOutputNameByIndex(i); \ | ||||
| @@ -126,8 +126,6 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| graph/passes/compile_nodes_pass.cc \ | graph/passes/compile_nodes_pass.cc \ | ||||
| graph/passes/constant_folding_pass.cc \ | graph/passes/constant_folding_pass.cc \ | ||||
| graph/passes/constant_fuse_same_pass.cc \ | graph/passes/constant_fuse_same_pass.cc \ | ||||
| graph/passes/remove_same_const_pass.cc \ | |||||
| graph/passes/useless_control_out_remove_pass.cc \ | |||||
| graph/passes/control_trigger_pass.cc \ | graph/passes/control_trigger_pass.cc \ | ||||
| graph/passes/dimension_adjust_pass.cc \ | graph/passes/dimension_adjust_pass.cc \ | ||||
| graph/passes/dimension_compute_pass.cc \ | graph/passes/dimension_compute_pass.cc \ | ||||
| @@ -272,7 +272,6 @@ static void ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor> | |||||
| std::vector<int64_t> dynamic_shape_dims = {kDynamicDimValue}; | std::vector<int64_t> dynamic_shape_dims = {kDynamicDimValue}; | ||||
| GeShape dynamic_shape(dynamic_shape_dims); | GeShape dynamic_shape(dynamic_shape_dims); | ||||
| std::vector<std::pair<int64_t, int64_t>> dynamic_shape_range; | |||||
| ge::GeTensor inputTensor; | ge::GeTensor inputTensor; | ||||
| ge::GeTensorDesc desc(input_desc); | ge::GeTensorDesc desc(input_desc); | ||||
| @@ -281,7 +280,6 @@ static void ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor> | |||||
| (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const); | (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const); | ||||
| if (!is_const && shape_ori.GetDims().size() > 0) { | if (!is_const && shape_ori.GetDims().size() > 0) { | ||||
| desc.SetShape(dynamic_shape); | desc.SetShape(dynamic_shape); | ||||
| desc.SetShapeRange(dynamic_shape_range); | |||||
| } | } | ||||
| inputTensor.SetTensorDesc(desc); | inputTensor.SetTensorDesc(desc); | ||||
| @@ -530,6 +528,24 @@ bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| static Status SetModelNameForDump(GeRootModelPtr ge_root_model) { | |||||
| ModelHelper model_helper; | |||||
| string model_name = ""; | |||||
| GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); | |||||
| Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), | |||||
| model_name); | |||||
| if (name_ret != SUCCESS) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); | |||||
| GELOGE(FAILED, "Get model_name failed. Param --output is invalid."); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); | |||||
| GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; | |||||
| GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model cannot be null"); | |||||
| ge_model->SetName(model_name); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | ||||
| ModelBufferData &model, bool is_offline) { | ModelBufferData &model, bool is_offline) { | ||||
| rtContext_t ctx = nullptr; | rtContext_t ctx = nullptr; | ||||
| @@ -538,7 +554,6 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
| GELOGD("Current ctx is null."); | GELOGD("Current ctx is null."); | ||||
| ctx = nullptr; | ctx = nullptr; | ||||
| } | } | ||||
| GeRootModelPtr ge_root_model = nullptr; | GeRootModelPtr ge_root_model = nullptr; | ||||
| GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); | GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); | ||||
| impl_->is_offline_ = is_offline; | impl_->is_offline_ = is_offline; | ||||
| @@ -562,22 +577,11 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
| impl_->build_step_.c_str()); | impl_->build_step_.c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GE_CHECK_NOTNULL(ge_root_model); | GE_CHECK_NOTNULL(ge_root_model); | ||||
| GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); | |||||
| ModelHelper model_helper; | |||||
| string model_name = ""; | |||||
| Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), | |||||
| model_name); | |||||
| if (name_ret != SUCCESS) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); | |||||
| GELOGE(FAILED, "Get model_name failed. Param --output is invalid."); | |||||
| return PARAM_INVALID; | |||||
| ret = SetModelNameForDump(ge_root_model); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | } | ||||
| map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); | |||||
| GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; | |||||
| GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model cannot be null"); | |||||
| ge_model->SetName(model_name); | |||||
| ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model); | ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Save model failed"); | GELOGE(ret, "Save model failed"); | ||||
| @@ -586,11 +590,9 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (ctx != nullptr) { | if (ctx != nullptr) { | ||||
| (void)rtCtxSetCurrent(ctx); | (void)rtCtxSetCurrent(ctx); | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -99,7 +99,7 @@ Status GraphMemoryAssigner::AssignMemory() { | |||||
| MemoryOffset memory_offset(RT_MEMORY_HBM, mem_assigner->GetMemOffset()); | MemoryOffset memory_offset(RT_MEMORY_HBM, mem_assigner->GetMemOffset()); | ||||
| memory_offset_.emplace(RT_MEMORY_HBM, memory_offset); | memory_offset_.emplace(RT_MEMORY_HBM, memory_offset); | ||||
| if (mem_assigner->GetP2PMemOffset() >= 0) { | |||||
| if (mem_assigner->GetP2PMemOffset() > 0) { | |||||
| MemoryOffset p2p_memory_offset(RT_MEMORY_P2P_DDR, mem_assigner->GetP2PMemOffset()); | MemoryOffset p2p_memory_offset(RT_MEMORY_P2P_DDR, mem_assigner->GetP2PMemOffset()); | ||||
| memory_offset_.emplace(RT_MEMORY_P2P_DDR, p2p_memory_offset); | memory_offset_.emplace(RT_MEMORY_P2P_DDR, p2p_memory_offset); | ||||
| } | } | ||||
| @@ -224,7 +224,6 @@ Status ModelBuilder::AdjustConstWeightSize(const ge::NodePtr &node, size_t &mem_ | |||||
| GeTensorDesc &tensor_desc = weight->MutableTensorDesc(); | GeTensorDesc &tensor_desc = weight->MutableTensorDesc(); | ||||
| size_t output_size = weight->GetData().size(); | size_t output_size = weight->GetData().size(); | ||||
| TensorUtils::SetDataOffset(tensor_desc, mem_offset); | TensorUtils::SetDataOffset(tensor_desc, mem_offset); | ||||
| GELOGD("Node: %s, weight size: %zu.", node->GetName().c_str(), output_size); | |||||
| mem_offset += output_size; | mem_offset += output_size; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -66,13 +66,13 @@ bool StreamGraphOptimizer::IsSameStreamIdOrBatchLabel(const ComputeGraphPtr &com | |||||
| if (AttrUtils::GetStr(cur_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { | if (AttrUtils::GetStr(cur_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { | ||||
| label_set.insert(batch_label); | label_set.insert(batch_label); | ||||
| } else { | } else { | ||||
| GELOGD("Node %s[%s] has no batch label, subgraph %s, stream id: %ld", cur_node->GetName().c_str(), | |||||
| GELOGD("Node %s[%s] has no batch_label, subgraph %s, stream id: %ld ", cur_node->GetName().c_str(), | |||||
| cur_node->GetType().c_str(), comp_graph->GetName().c_str(), stream_id); | cur_node->GetType().c_str(), comp_graph->GetName().c_str(), stream_id); | ||||
| continue; | continue; | ||||
| } | } | ||||
| GELOGD("Node %s in subgraph %s stream id: %ld, node num: %zu", cur_node->GetName().c_str(), | |||||
| comp_graph->GetName().c_str(), stream_id, comp_graph->GetDirectNodesSize()); | |||||
| GELOGD("Node %s in subgraph %s stream id: %ld, batch_label: %s, node num: %zu", cur_node->GetName().c_str(), | |||||
| comp_graph->GetName().c_str(), stream_id, batch_label.c_str(), comp_graph->GetDirectNodesSize()); | |||||
| } | } | ||||
| if (stream_set.size() > 1 || label_set.size() > 1) { | if (stream_set.size() > 1 || label_set.size() > 1) { | ||||
| GELOGI("Nodes of graph: %s have different stream id or batch_label, node num: %zu, different stream num: %zu.", | GELOGI("Nodes of graph: %s have different stream id or batch_label, node num: %zu, different stream num: %zu.", | ||||
| @@ -126,12 +126,14 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com | |||||
| run_context.graphStreamList.size()); | run_context.graphStreamList.size()); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| run_context.stream = run_context.graphStreamList[stream_id]; | run_context.stream = run_context.graphStreamList[stream_id]; | ||||
| std::string batch_label; | |||||
| (void)AttrUtils::GetStr(subgraph, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| std::string batch_label; | |||||
| (void)AttrUtils::GetStr(subgraph, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu, " | GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu, " | ||||
| "batch_label: %s", subgraph->GetName().c_str(), engine_name.c_str(), stream_id, | |||||
| "batch_label: %s", subgraph->GetName().c_str(), engine_name.c_str(), stream_id, | |||||
| static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream)), batch_label.c_str()); | static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream)), batch_label.c_str()); | ||||
| for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { | for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { | ||||
| GE_CHECK_NOTNULL(*iter); | GE_CHECK_NOTNULL(*iter); | ||||
| Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context); | Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context); | ||||
| @@ -122,14 +122,14 @@ Status GraphLoader::LoadDataFromFile(const std::string &path, const std::string | |||||
| ModelData &model_data) { | ModelData &model_data) { | ||||
| Status ret; | Status ret; | ||||
| if (!CheckInputPathValid(path)) { | if (!CheckInputPathValid(path)) { | ||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID, "model path is invalid: %s", path.c_str()); | |||||
| return ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID; | |||||
| GELOGE(GE_EXEC_MODEL_PATH_INVALID, "model path is invalid: %s", path.c_str()); | |||||
| return GE_EXEC_MODEL_PATH_INVALID; | |||||
| } | } | ||||
| GELOGI("Load model begin, model path is: %s", path.c_str()); | GELOGI("Load model begin, model path is: %s", path.c_str()); | ||||
| if (!key_path.empty() && !CheckInputPathValid(key_path)) { | if (!key_path.empty() && !CheckInputPathValid(key_path)) { | ||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "decrypt_key path is invalid: %s", key_path.c_str()); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| GELOGE(GE_EXEC_MODEL_KEY_PATH_INVALID, "decrypt_key path is invalid: %s", key_path.c_str()); | |||||
| return GE_EXEC_MODEL_KEY_PATH_INVALID; | |||||
| } | } | ||||
| ret = DavinciModelParser::LoadFromFile(path.c_str(), key_path.c_str(), priority, model_data); | ret = DavinciModelParser::LoadFromFile(path.c_str(), key_path.c_str(), priority, model_data); | ||||
| @@ -144,6 +144,63 @@ Status GraphLoader::LoadDataFromFile(const std::string &path, const std::string | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphLoader::LoadModelFromFile(const std::string &path, const std::string &key_path, int32_t priority, | |||||
| const std::shared_ptr<ModelListener> &listener, uint32_t &model_id) { | |||||
| Status ret; | |||||
| ModelData model_data; | |||||
| ret = LoadDataFromFile(path, key_path, priority, model_data); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "LoadModelFromFile: Load failed. ret = %u", ret); | |||||
| if (model_data.model_data != nullptr) { | |||||
| delete[] static_cast<char *>(model_data.model_data); | |||||
| model_data.model_data = nullptr; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| ret = LoadModel(model_data, listener, model_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "LoadModel: Load failed. ret = %u", ret); | |||||
| if (model_data.model_data != nullptr) { | |||||
| delete[] static_cast<char *>(model_data.model_data); | |||||
| model_data.model_data = nullptr; | |||||
| } | |||||
| } | |||||
| if (model_data.model_data != nullptr) { | |||||
| delete[] static_cast<char *>(model_data.model_data); | |||||
| model_data.model_data = nullptr; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| Status GraphLoader::LoadModel(const ModelData &model_data, const std::shared_ptr<ModelListener> &listener, | |||||
| uint32_t &model_id) { | |||||
| GELOGI("Load model begin, model_id:%u.", model_id); | |||||
| // For GeOp, Open Device 0 here. | |||||
| GE_CHK_RT_RET(rtSetDevice(0)); | |||||
| auto model_manager = ModelManager::GetInstance(); | |||||
| GE_CHECK_NOTNULL(model_manager); | |||||
| Status ret = model_manager->LoadModelOffline(model_id, model_data, listener); | |||||
| if (ret != SUCCESS) { | |||||
| GE_CHK_RT(rtDeviceReset(0)); | |||||
| GELOGE(ret, "LoadModel: Load failed."); | |||||
| return ret; | |||||
| } | |||||
| ret = model_manager->Start(model_id); | |||||
| if (ret != SUCCESS) { | |||||
| if (model_manager->Unload(model_id) != SUCCESS) { | |||||
| GELOGE(FAILED, "LoadModel: Unload failed while trying to unload after a failed start."); | |||||
| } | |||||
| GELOGE(ret, "LoadModel: Start failed."); | |||||
| return ret; | |||||
| } | |||||
| GELOGI("LoadModel: Start model success, model_id:%u.", model_id); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphLoader::CommandHandle(const Command &command) { | Status GraphLoader::CommandHandle(const Command &command) { | ||||
| try { | try { | ||||
| auto model_manager = ModelManager::GetInstance(); | auto model_manager = ModelManager::GetInstance(); | ||||
| @@ -168,16 +225,16 @@ Status GraphLoader::CommandHandle(const Command &command) { | |||||
| } | } | ||||
| Status GraphLoader::LoadModelFromData(uint32_t &model_id, const ModelData &model_data, void *dev_ptr, | Status GraphLoader::LoadModelFromData(uint32_t &model_id, const ModelData &model_data, void *dev_ptr, | ||||
| size_t mem_size, void *weight_ptr, size_t weight_size) { | |||||
| size_t memsize, void *weight_ptr, size_t weightsize) { | |||||
| GELOGI("Load model begin, model_id:%u.", model_id); | GELOGI("Load model begin, model_id:%u.", model_id); | ||||
| // For ACL, Open Device from App. | // For ACL, Open Device from App. | ||||
| auto model_manager = ModelManager::GetInstance(); | auto model_manager = ModelManager::GetInstance(); | ||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| Status ret = model_manager->LoadModelOffline( | Status ret = model_manager->LoadModelOffline( | ||||
| model_id, model_data, nullptr, dev_ptr, mem_size, weight_ptr, weight_size); | |||||
| model_id, model_data, nullptr, dev_ptr, memsize, weight_ptr, weightsize); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ACL_ERROR_GE_LOAD_MODEL, "Load model failed, model_id:%u.", model_id); | |||||
| return ACL_ERROR_GE_LOAD_MODEL; | |||||
| GELOGE(ret, "Load model failed, model_id:%u.", model_id); | |||||
| return ret; | |||||
| } | } | ||||
| GELOGI("Load model success, model_id:%u.", model_id); | GELOGI("Load model success, model_id:%u.", model_id); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -202,8 +259,8 @@ Status GraphLoader::LoadModelWithQ(uint32_t &model_id, const ModelData &model_da | |||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| Status ret = model_manager->LoadModelWithQ(model_id, model_data, input_queue_ids, output_queue_ids); | Status ret = model_manager->LoadModelWithQ(model_id, model_data, input_queue_ids, output_queue_ids); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ACL_ERROR_GE_LOAD_MODEL, "Load model with queue failed, model_id:%u.", model_id); | |||||
| return ACL_ERROR_GE_LOAD_MODEL; | |||||
| GELOGE(ret, "Load model with queue failed, model_id:%u.", model_id); | |||||
| return ret; | |||||
| } | } | ||||
| GELOGI("Load model with queue success, model_id:%u.", model_id); | GELOGI("Load model with queue success, model_id:%u.", model_id); | ||||
| @@ -44,6 +44,12 @@ class GraphLoader { | |||||
| static Status GetMaxUsedMemory(uint32_t model_id, uint64_t &max_size); | static Status GetMaxUsedMemory(uint32_t model_id, uint64_t &max_size); | ||||
| static Status LoadModel(const ModelData &model_data, const std::shared_ptr<ModelListener> &listener, | |||||
| uint32_t &model_id); | |||||
| static Status LoadModelFromFile(const std::string &path, const std::string &key_path, int32_t priority, | |||||
| const std::shared_ptr<ModelListener> &listener, uint32_t &model_id); | |||||
| static Status CommandHandle(const Command &command); | static Status CommandHandle(const Command &command); | ||||
| static Status GetMemoryInfo(int64_t &free); | static Status GetMemoryInfo(int64_t &free); | ||||
| @@ -319,9 +319,6 @@ Status DataDumper::GenerateOutput(aicpu::dump::Output &output, const OpDesc::Vis | |||||
| for (auto dim : tensor_descs.at(index).GetShape().GetDims()) { | for (auto dim : tensor_descs.at(index).GetShape().GetDims()) { | ||||
| output.mutable_shape()->add_dim(dim); | output.mutable_shape()->add_dim(dim); | ||||
| } | } | ||||
| for (auto dim : tensor_descs.at(index).GetOriginShape().GetDims()) { | |||||
| output.mutable_origin_shape()->add_dim(dim); | |||||
| } | |||||
| int64_t output_size = 0; | int64_t output_size = 0; | ||||
| if (TensorUtils::GetTensorSizeInBytes(tensor_descs.at(index), output_size) != SUCCESS) { | if (TensorUtils::GetTensorSizeInBytes(tensor_descs.at(index), output_size) != SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "Get output size filed"); | GELOGE(PARAM_INVALID, "Get output size filed"); | ||||
| @@ -479,9 +476,6 @@ Status DataDumper::GenerateInput(aicpu::dump::Input &input, const OpDesc::Vistor | |||||
| for (auto dim : tensor_descs.at(index).GetShape().GetDims()) { | for (auto dim : tensor_descs.at(index).GetShape().GetDims()) { | ||||
| input.mutable_shape()->add_dim(dim); | input.mutable_shape()->add_dim(dim); | ||||
| } | } | ||||
| for (auto dim : tensor_descs.at(index).GetOriginShape().GetDims()) { | |||||
| input.mutable_origin_shape()->add_dim(dim); | |||||
| } | |||||
| int64_t input_size = 0; | int64_t input_size = 0; | ||||
| if (AttrUtils::GetInt(tensor_descs.at(index), ATTR_NAME_INPUT_ORIGIN_SIZE, input_size)) { | if (AttrUtils::GetInt(tensor_descs.at(index), ATTR_NAME_INPUT_ORIGIN_SIZE, input_size)) { | ||||
| GELOGI("Get aipp input size according to attr is %ld", input_size); | GELOGI("Get aipp input size according to attr is %ld", input_size); | ||||
| @@ -289,8 +289,8 @@ Status DavinciModel::InitWeightMem(void *dev_ptr, void *weight_ptr, size_t weigh | |||||
| if (weight_ptr == nullptr) { | if (weight_ptr == nullptr) { | ||||
| weights_mem_base_ = MallocWeightsMem(weights_size); | weights_mem_base_ = MallocWeightsMem(weights_size); | ||||
| if (weights_mem_base_ == nullptr) { | if (weights_mem_base_ == nullptr) { | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc weight memory failed. size: %zu", weights_size); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| GELOGE(GE_EXEC_ALLOC_WEIGHT_MEM_FAILED, "Alloc weight memory failed. size: %zu", weights_size); | |||||
| return GE_EXEC_ALLOC_WEIGHT_MEM_FAILED; | |||||
| } | } | ||||
| is_inner_weight_base_ = true; | is_inner_weight_base_ = true; | ||||
| } | } | ||||
| @@ -307,8 +307,8 @@ Status DavinciModel::InitWeightMem(void *dev_ptr, void *weight_ptr, size_t weigh | |||||
| Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { | Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { | ||||
| if (is_feature_map_mem_has_inited_) { | if (is_feature_map_mem_has_inited_) { | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "call InitFeatureMapMem more than once ."); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| GELOGE(FAILED, "call InitFeatureMapMem more than once ."); | |||||
| return FAILED; | |||||
| } | } | ||||
| is_feature_map_mem_has_inited_ = true; | is_feature_map_mem_has_inited_ = true; | ||||
| @@ -316,8 +316,8 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { | |||||
| std::size_t p2p_data_size = P2PMemInfos().at(RT_MEMORY_P2P_DDR).memory_size; | std::size_t p2p_data_size = P2PMemInfos().at(RT_MEMORY_P2P_DDR).memory_size; | ||||
| if ((dev_ptr != nullptr) && (mem_size < TotalMemSize())) { | if ((dev_ptr != nullptr) && (mem_size < TotalMemSize())) { | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Invalid mem param: mem_size=%zu totalsize=%zu.", mem_size, TotalMemSize()); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| GELOGE(FAILED, "Invalid mem param: mem_size=%zu totalsize=%zu.", mem_size, TotalMemSize()); | |||||
| return FAILED; | |||||
| } | } | ||||
| mem_base_ = static_cast<uint8_t *>(dev_ptr); | mem_base_ = static_cast<uint8_t *>(dev_ptr); | ||||
| @@ -327,8 +327,8 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { | |||||
| if (TotalMemSize() && mem_base_ == nullptr) { | if (TotalMemSize() && mem_base_ == nullptr) { | ||||
| mem_base_ = MallocFeatureMapMem(data_size); | mem_base_ = MallocFeatureMapMem(data_size); | ||||
| if (mem_base_ == nullptr) { | if (mem_base_ == nullptr) { | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc feature map memory failed. size: %zu", data_size); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc feature map memory failed. size: %zu", data_size); | |||||
| return GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED; | |||||
| } | } | ||||
| GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", | GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", | ||||
| runtime_param_.graph_id, mem_base_, data_size); | runtime_param_.graph_id, mem_base_, data_size); | ||||
| @@ -343,8 +343,8 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { | |||||
| if (p2p_data_size != 0) { | if (p2p_data_size != 0) { | ||||
| p2p_mem_base_ = MallocP2PMem(p2p_data_size); | p2p_mem_base_ = MallocP2PMem(p2p_data_size); | ||||
| if (p2p_mem_base_ == nullptr) { | if (p2p_mem_base_ == nullptr) { | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc p2p memory failed,size: %zu", p2p_data_size); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| GELOGE(GE_EXEC_ALLOC_P2P_MEM_FAILED, "Alloc p2p memory failed,size: %zu", p2p_data_size); | |||||
| return GE_EXEC_ALLOC_P2P_MEM_FAILED; | |||||
| } | } | ||||
| GELOGI("InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, | GELOGI("InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, | ||||
| p2p_mem_base_, p2p_data_size); | p2p_mem_base_, p2p_data_size); | ||||
| @@ -710,7 +710,6 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size | |||||
| } | } | ||||
| // collect profiling for ge | // collect profiling for ge | ||||
| GE_CHK_STATUS_RET(InitModelProfile(), "Init model profile failed"); | |||||
| auto &profiling_manager = ProfilingManager::Instance(); | auto &profiling_manager = ProfilingManager::Instance(); | ||||
| if (profiling_manager.ProfilingModelLoadOn()) { | if (profiling_manager.ProfilingModelLoadOn()) { | ||||
| Status p_ret = ReportProfilingData(); | Status p_ret = ReportProfilingData(); | ||||
| @@ -971,7 +970,7 @@ Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index, ma | |||||
| uint32_t parent_index = 0; // Ignore subgraph Data Node. | uint32_t parent_index = 0; // Ignore subgraph Data Node. | ||||
| if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | ||||
| GELOGI("Init zero copy by subgraph Data node: %s.", op_desc->GetName().c_str()); | GELOGI("Init zero copy by subgraph Data node: %s.", op_desc->GetName().c_str()); | ||||
| return SUCCESS; | |||||
| return InitInputBatchLabel(node); | |||||
| } | } | ||||
| data_op_list_.push_back(op_desc); | data_op_list_.push_back(op_desc); | ||||
| @@ -1012,6 +1011,10 @@ Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index, ma | |||||
| } | } | ||||
| data_op_index++; | data_op_index++; | ||||
| if (InitInputZeroCopyNodes(node) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "Input zero copy nodes init failed!"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -1033,6 +1036,39 @@ void DavinciModel::AdjustDataOpList(const map<uint32_t, OpDescPtr> &data_by_inde | |||||
| } | } | ||||
| } | } | ||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief input zero copy node Initialize. | |||||
| /// @param [in] NodePtr: Data Op. | |||||
| /// @return Status | |||||
| /// | |||||
| Status DavinciModel::InitInputZeroCopyNodes(const NodePtr &node) { | |||||
| auto out_data_anchor = node->GetOutDataAnchor(kDataIndex); | |||||
| if (out_data_anchor == nullptr) { | |||||
| GELOGE(FAILED, "Out data anchor is nullptr"); | |||||
| return FAILED; | |||||
| } | |||||
| for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
| auto node = peer_in_data_anchor->GetOwnerNode(); | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(FAILED, "Op desc is nullptr"); | |||||
| return FAILED; | |||||
| } | |||||
| string batch_label; | |||||
| (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| if (batch_label.empty()) { | |||||
| batch_label = kDefaultBatchLable; | |||||
| } | |||||
| if (zero_copy_op_id_batch_label_.find(op_desc->GetId()) == zero_copy_op_id_batch_label_.end()) { | |||||
| zero_copy_op_id_batch_label_.emplace(pair<int64_t, string>(op_desc->GetId(), batch_label)); | |||||
| GELOGD("Init input zero copy nodes success, op name:%s, op id: %ld, batch label: %s.", op_desc->GetName().c_str(), | |||||
| op_desc->GetId(), batch_label.c_str()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| bool DavinciModel::IsGetNextSinkDynamic(const OpDescPtr &op_desc) { | bool DavinciModel::IsGetNextSinkDynamic(const OpDescPtr &op_desc) { | ||||
| bool getnext_sink_dynamic = false; | bool getnext_sink_dynamic = false; | ||||
| if (ge::AttrUtils::GetBool(op_desc, ATTR_GETNEXT_SINK_DYNMAIC, getnext_sink_dynamic) && getnext_sink_dynamic) { | if (ge::AttrUtils::GetBool(op_desc, ATTR_GETNEXT_SINK_DYNMAIC, getnext_sink_dynamic) && getnext_sink_dynamic) { | ||||
| @@ -1058,7 +1094,7 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { | |||||
| if (owner_graph->GetParentGraph() != nullptr) { | if (owner_graph->GetParentGraph() != nullptr) { | ||||
| GELOGI("Init zero copy by subgraph NetOutput node: %s.", op_desc->GetName().c_str()); | GELOGI("Init zero copy by subgraph NetOutput node: %s.", op_desc->GetName().c_str()); | ||||
| op_list_.erase(op_desc->GetId()); | op_list_.erase(op_desc->GetId()); | ||||
| return SUCCESS; | |||||
| return InitOutputBatchLabel(node); | |||||
| } | } | ||||
| output_op_list_.push_back(op_desc); | output_op_list_.push_back(op_desc); | ||||
| @@ -1110,6 +1146,8 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { | |||||
| } | } | ||||
| } | } | ||||
| GE_IF_BOOL_EXEC(InitOutputZeroCopyNodes(node) != SUCCESS, | |||||
| GELOGE(PARAM_INVALID, "Output zero copy nodes init failed!"); return PARAM_INVALID;); | |||||
| GetAllGearsInfo(node); | GetAllGearsInfo(node); | ||||
| if (is_getnext_sink_dynamic_) { | if (is_getnext_sink_dynamic_) { | ||||
| GE_IF_BOOL_EXEC(GetGetDynamicDimsNodeInfo(node) != SUCCESS, | GE_IF_BOOL_EXEC(GetGetDynamicDimsNodeInfo(node) != SUCCESS, | ||||
| @@ -1305,6 +1343,121 @@ void DavinciModel::ParseDynamicOutShape(const std::vector<std::string> &str_info | |||||
| } | } | ||||
| } | } | ||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief output zero copy node Initialize. | |||||
| /// @param [in] NodePtr: netoutput Op. | |||||
| /// @return Status | |||||
| /// | |||||
| Status DavinciModel::InitOutputZeroCopyNodes(const NodePtr &node) { | |||||
| set<NodePtr> nodes_need_record; | |||||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| if (peer_out_data_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto peer_node = peer_out_data_anchor->GetOwnerNode(); | |||||
| nodes_need_record.emplace(peer_node); | |||||
| // Merge node output multiplexed input, upstream nodes need to be considered in multiple batch scenarios | |||||
| if (peer_node->GetType() == MERGE) { | |||||
| for (const auto &merge_peer_in_data_anchor : peer_node->GetAllInDataAnchors()) { | |||||
| auto merge_peer_out_data_anchor = merge_peer_in_data_anchor->GetPeerOutAnchor(); | |||||
| if (merge_peer_out_data_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto merge_peer_node = merge_peer_out_data_anchor->GetOwnerNode(); | |||||
| nodes_need_record.emplace(merge_peer_node); | |||||
| } | |||||
| } else { | |||||
| for (const auto &other_in_data_anchor : peer_out_data_anchor->GetPeerInDataAnchors()) { | |||||
| auto other_in_node = other_in_data_anchor->GetOwnerNode(); | |||||
| if (other_in_node->GetType() != NETOUTPUT) { | |||||
| nodes_need_record.emplace(other_in_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| for (const auto &node_need_record : nodes_need_record) { | |||||
| auto op_desc = node_need_record->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| string batch_label; | |||||
| (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| if (batch_label.empty()) { | |||||
| batch_label = kDefaultBatchLable; | |||||
| } | |||||
| if (zero_copy_op_id_batch_label_.find(op_desc->GetId()) == zero_copy_op_id_batch_label_.end()) { | |||||
| zero_copy_op_id_batch_label_.emplace(pair<int64_t, string>(op_desc->GetId(), batch_label)); | |||||
| GELOGD("Init Output zero copy nodes success, op name:%s, op id: %ld, batch label: %s.", | |||||
| op_desc->GetName().c_str(), op_desc->GetId(), batch_label.c_str()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief input zero copy node Initialize. | |||||
| /// @param [in] NodePtr: Data Op. | |||||
| /// @return Status | |||||
| /// | |||||
| Status DavinciModel::InitInputBatchLabel(const NodePtr &node) { | |||||
| string batch_label; | |||||
| if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
| return SUCCESS; // Not Multi-batch. | |||||
| } | |||||
| const auto &out_data_anchor = node->GetOutDataAnchor(kDataIndex); | |||||
| GE_CHECK_NOTNULL(out_data_anchor); | |||||
| for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
| const auto &node = peer_in_data_anchor->GetOwnerNode(); | |||||
| const auto &op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (zero_copy_op_id_batch_label_.find(op_desc->GetId()) == zero_copy_op_id_batch_label_.end()) { | |||||
| zero_copy_op_id_batch_label_[op_desc->GetId()] = batch_label; | |||||
| GELOGD("Init input zero copy nodes success, op name: %s, op id: %ld, batch label: %s", op_desc->GetName().c_str(), | |||||
| op_desc->GetId(), batch_label.c_str()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief output zero copy node Initialize for Case. | |||||
| /// @param [in] NodePtr: netoutput Op. | |||||
| /// @return Status | |||||
| /// | |||||
| Status DavinciModel::InitOutputBatchLabel(const NodePtr &node) { | |||||
| string batch_label; | |||||
| if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
| return SUCCESS; // Not Multi-batch. | |||||
| } | |||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| const auto &peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| if (peer_out_data_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| const auto &peer_node = peer_out_data_anchor->GetOwnerNode(); | |||||
| const auto &op_desc = peer_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (zero_copy_op_id_batch_label_.find(op_desc->GetId()) == zero_copy_op_id_batch_label_.end()) { | |||||
| zero_copy_op_id_batch_label_[op_desc->GetId()] = batch_label; | |||||
| GELOGD("Init Output zero copy nodes success, op name: %s, op id: %ld, batch label: %s", | |||||
| op_desc->GetName().c_str(), op_desc->GetId(), batch_label.c_str()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief LabelSet Op Initialize. | /// @brief LabelSet Op Initialize. | ||||
| /// @param [in] op_desc: LabelSet Op descriptor. | /// @param [in] op_desc: LabelSet Op descriptor. | ||||
| @@ -2087,61 +2240,12 @@ Status DavinciModel::SyncVarData() { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| Status DavinciModel::InitModelProfile() { | |||||
| for (const auto &task : task_list_) { | |||||
| GE_CHECK_NOTNULL(task); | |||||
| const FusionOpInfo *fusion_op_info = task->GetFusionOpInfo(); | |||||
| // when type is RT_MODEL_TASK_KERNEL, ctx is not null | |||||
| if ((fusion_op_info == nullptr) || fusion_op_info->original_op_names.empty()) { | |||||
| continue; | |||||
| } | |||||
| GELOGI("task.id = %u, opNum = %zu", task->GetTaskID(), fusion_op_info->original_op_names.size()); | |||||
| op_id_map_.insert(std::make_pair(fusion_op_info->op_index, task->GetTaskID())); | |||||
| } | |||||
| std::set<uint32_t> task_id_set; | |||||
| using CIT = std::multimap<uint32_t, uint32_t>::const_iterator; | |||||
| using Range = std::pair<CIT, CIT>; | |||||
| for (const auto &task : task_list_) { | |||||
| GE_CHECK_NOTNULL(task); | |||||
| const FusionOpInfo *fusion_op_info = task->GetFusionOpInfo(); | |||||
| if ((fusion_op_info == nullptr) || fusion_op_info->original_op_names.empty()) { | |||||
| continue; | |||||
| } | |||||
| if (task_id_set.count(task->GetTaskID()) > 0) { | |||||
| continue; | |||||
| } | |||||
| const auto &op_desc = GetOpByIndex(fusion_op_info->op_index); | |||||
| GE_CHK_BOOL_EXEC(op_desc != nullptr, return FAILED, "index: %u out of range", fusion_op_info->op_index); | |||||
| ProfileInfo profile; | |||||
| profile.fusion_info = *fusion_op_info; | |||||
| Range range = op_id_map_.equal_range(fusion_op_info->op_index); | |||||
| for (CIT range_idx = range.first; range_idx != range.second; ++range_idx) { | |||||
| profile.task_count++; | |||||
| task_id_set.insert(range_idx->second); | |||||
| } | |||||
| // memory info | |||||
| TaskMemInfo &mem_info = profile.memory_info; | |||||
| const auto input_size = ModelUtils::GetInputSize(op_desc); | |||||
| const auto output_size = ModelUtils::GetOutputSize(op_desc); | |||||
| const auto workspace_size = ModelUtils::GetWorkspaceSize(op_desc); | |||||
| const auto weight_size = ModelUtils::GetWeightSize(op_desc); | |||||
| mem_info.input_size = std::accumulate(input_size.begin(), input_size.end(), 0); | |||||
| mem_info.output_size = std::accumulate(output_size.begin(), output_size.end(), 0); | |||||
| mem_info.workspace_size = std::accumulate(workspace_size.begin(), workspace_size.end(), 0); | |||||
| mem_info.weight_size = std::accumulate(weight_size.begin(), weight_size.end(), 0); | |||||
| mem_info.total_size = mem_info.weight_size + mem_info.input_size + mem_info.output_size + mem_info.workspace_size; | |||||
| profile_list_.emplace_back(profile); | |||||
| inline int64_t SumSize(const vector<int64_t> &size_list) { | |||||
| int64_t sum_size = 0; | |||||
| for (const int64_t &size : size_list) { | |||||
| sum_size += size; | |||||
| } | } | ||||
| GELOGI("fusion task size: %zu, profile info size: %zu", op_id_map_.size(), profile_list_.size()); | |||||
| return SUCCESS; | |||||
| return sum_size; | |||||
| } | } | ||||
| Status DavinciModel::SinkModelProfile() { | Status DavinciModel::SinkModelProfile() { | ||||
| @@ -2149,12 +2253,18 @@ Status DavinciModel::SinkModelProfile() { | |||||
| auto &prof_mgr = ProfilingManager::Instance(); | auto &prof_mgr = ProfilingManager::Instance(); | ||||
| ReporterData reporter_data{}; | ReporterData reporter_data{}; | ||||
| // report model data tag name | // report model data tag name | ||||
| std::string tag_name("model_load_info_" + std::to_string(this->Id())); | |||||
| std::string tag_name; | |||||
| tag_name.append("model_load_info_").append(std::to_string(this->Id())); | |||||
| GE_CHK_BOOL_EXEC(memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN, tag_name.c_str(), tag_name.size()) == EOK, | GE_CHK_BOOL_EXEC(memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN, tag_name.c_str(), tag_name.size()) == EOK, | ||||
| return FAILED, "Sink model tag memcpy error."); | return FAILED, "Sink model tag memcpy error."); | ||||
| // Model Header | // Model Header | ||||
| std::string name = om_name_.empty() ? name_ : om_name_; | |||||
| string name; | |||||
| if (!om_name_.empty()) { | |||||
| name = om_name_; | |||||
| } else { | |||||
| name = name_; | |||||
| } | |||||
| size_t name_len = name.size(); | size_t name_len = name.size(); | ||||
| reporter_data.deviceId = device_id_; | reporter_data.deviceId = device_id_; | ||||
| reporter_data.data = (unsigned char *)&name_len; | reporter_data.data = (unsigned char *)&name_len; | ||||
| @@ -2186,71 +2296,128 @@ Status DavinciModel::SinkModelProfile() { | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | ||||
| "Reporter data fail, model id:%u.", this->Id()); | "Reporter data fail, model id:%u.", this->Id()); | ||||
| int32_t task_num = task_list_.size(); | |||||
| std::multimap<uint32_t, uint32_t> op_id_map; | |||||
| std::set<uint32_t> task_id_set; | |||||
| for (int32_t i = 0; i < task_num; i++) { | |||||
| auto task = task_list_[i]; | |||||
| GE_CHECK_NOTNULL(task); | |||||
| auto fusion_op_info = task->GetFusionOpInfo(); | |||||
| // when type is RT_MODEL_TASK_KERNEL, ctx is not null | |||||
| if (fusion_op_info != nullptr) { | |||||
| uint32_t op_num = fusion_op_info->original_op_names.size(); | |||||
| uint32_t task_id = task->GetTaskID(); | |||||
| if (op_num > 0) { | |||||
| GELOGI("task.id = %u, opNum = %u", task_id, op_num); | |||||
| op_id_map.insert(std::make_pair(fusion_op_info->op_index, task_id)); | |||||
| } | |||||
| } | |||||
| } | |||||
| struct memoryInfo { | |||||
| int64_t input_size; | |||||
| int64_t output_size; | |||||
| int64_t weight_size; | |||||
| int64_t workspace_size; | |||||
| int64_t total_size; | |||||
| memoryInfo() : input_size(0), output_size(0), weight_size(0), workspace_size(0), total_size(0) {} | |||||
| }; | |||||
| using CIT = std::multimap<uint32_t, uint32_t>::const_iterator; | using CIT = std::multimap<uint32_t, uint32_t>::const_iterator; | ||||
| using Range = std::pair<CIT, CIT>; | using Range = std::pair<CIT, CIT>; | ||||
| for (const ProfileInfo &profile : profile_list_) { | |||||
| // op name after fusion | |||||
| string fusion_op_name = profile.fusion_info.op_name; | |||||
| int32_t fusion_op_name_len = fusion_op_name.size() == 0 ? 1 : fusion_op_name.size(); | |||||
| reporter_data.data = (unsigned char *)&fusion_op_name_len; | |||||
| reporter_data.dataLen = sizeof(int32_t); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| reporter_data.data = (unsigned char *)fusion_op_name.c_str(); | |||||
| reporter_data.dataLen = fusion_op_name_len; | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| // original op name before fusion | |||||
| uint32_t op_num = profile.fusion_info.original_op_names.size(); | |||||
| reporter_data.data = (unsigned char *)&op_num; | |||||
| reporter_data.dataLen = sizeof(int32_t); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| for (uint32_t k = 0; k < op_num; k++) { | |||||
| std::string op_name = profile.fusion_info.original_op_names[k]; | |||||
| int32_t op_name_len = op_name.size() == 0 ? 1 : op_name.size(); | |||||
| reporter_data.data = (unsigned char *)&op_name_len; | |||||
| for (int32_t i = 0; i < task_num; i++) { | |||||
| auto task = task_list_[i]; | |||||
| GE_CHECK_NOTNULL(task); | |||||
| auto fusion_op_info = task->GetFusionOpInfo(); | |||||
| if (fusion_op_info != nullptr && fusion_op_info->original_op_names.size() > 0) { | |||||
| uint32_t task_id = task->GetTaskID(); | |||||
| uint32_t op_num = fusion_op_info->original_op_names.size(); | |||||
| uint32_t task_count = 0; | |||||
| if (task_id_set.count(task_id) != 0) { | |||||
| continue; | |||||
| } | |||||
| uint32_t op_id = fusion_op_info->op_index; | |||||
| Range range = op_id_map.equal_range(op_id); | |||||
| for (CIT range_idx = range.first; range_idx != range.second; ++range_idx) { | |||||
| task_count++; | |||||
| uint32_t task_id = range_idx->second; | |||||
| task_id_set.insert(task_id); | |||||
| } | |||||
| // op name after fusion | |||||
| string fusion_op_name = fusion_op_info->op_name; | |||||
| int32_t fusion_op_name_len = fusion_op_name.size() == 0 ? 1 : fusion_op_name.size(); | |||||
| reporter_data.data = (unsigned char *)&fusion_op_name_len; | |||||
| reporter_data.dataLen = sizeof(int32_t); | reporter_data.dataLen = sizeof(int32_t); | ||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | ||||
| "Reporter data fail, model id:%u.", this->Id()); | "Reporter data fail, model id:%u.", this->Id()); | ||||
| reporter_data.data = (unsigned char *)op_name.c_str(); | |||||
| reporter_data.dataLen = op_name_len; | |||||
| reporter_data.data = (unsigned char *)fusion_op_name.c_str(); | |||||
| reporter_data.dataLen = fusion_op_name_len; | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| // original op name before fusion | |||||
| reporter_data.data = (unsigned char *)&op_num; | |||||
| reporter_data.dataLen = sizeof(int32_t); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | ||||
| "Reporter data fail, model id:%u.", this->Id()); | "Reporter data fail, model id:%u.", this->Id()); | ||||
| } | |||||
| // stream id info | |||||
| uint32_t streamId = profile.fusion_info.stream_id; | |||||
| reporter_data.data = (unsigned char *)&streamId; | |||||
| reporter_data.dataLen = sizeof(int32_t); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| // memory info | |||||
| reporter_data.data = (unsigned char *)&profile.memory_info; | |||||
| reporter_data.dataLen = sizeof(profile.memory_info); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| // task info | |||||
| reporter_data.data = (unsigned char *)&profile.task_count; | |||||
| reporter_data.dataLen = sizeof(uint32_t); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| Range task_range = op_id_map_.equal_range(profile.fusion_info.op_index); | |||||
| for (CIT idx = task_range.first; idx != task_range.second; ++idx) { | |||||
| uint32_t task_id = idx->second; | |||||
| reporter_data.data = (unsigned char *)&task_id; | |||||
| for (uint32_t k = 0; k < op_num; k++) { | |||||
| std::string op_name = fusion_op_info->original_op_names[k]; | |||||
| int32_t op_name_len = op_name.size() == 0 ? 1 : op_name.size(); | |||||
| reporter_data.data = (unsigned char *)&op_name_len; | |||||
| reporter_data.dataLen = sizeof(int32_t); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| reporter_data.data = (unsigned char *)op_name.c_str(); | |||||
| reporter_data.dataLen = op_name_len; | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| } | |||||
| // stream id info | |||||
| uint32_t streamId = task->GetStreamId(); | |||||
| reporter_data.data = (unsigned char *)&streamId; | |||||
| reporter_data.dataLen = sizeof(int32_t); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| // memory info | |||||
| struct memoryInfo memory_info; | |||||
| uint32_t op_index = fusion_op_info->op_index; | |||||
| auto iter = op_list_.find(op_index); | |||||
| GE_CHK_BOOL_EXEC(iter != op_list_.end(), return FAILED, "index is out of range, index: %u", op_index); | |||||
| auto op_desc = iter->second; | |||||
| memory_info.input_size = SumSize(ModelUtils::GetInputSize(op_desc)); | |||||
| memory_info.output_size = SumSize(ModelUtils::GetOutputSize(op_desc)); | |||||
| memory_info.workspace_size = SumSize(ModelUtils::GetWorkspaceSize(op_desc)); | |||||
| memory_info.weight_size = SumSize(ModelUtils::GetWeightSize(op_desc)); | |||||
| memory_info.total_size = | |||||
| memory_info.weight_size + memory_info.input_size + memory_info.output_size + memory_info.workspace_size; | |||||
| reporter_data.data = (unsigned char *)&memory_info; | |||||
| reporter_data.dataLen = sizeof(struct memoryInfo); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| // task info | |||||
| reporter_data.data = (unsigned char *)&task_count; | |||||
| reporter_data.dataLen = sizeof(uint32_t); | reporter_data.dataLen = sizeof(uint32_t); | ||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | ||||
| "Reporter data fail, model id:%u.", this->Id()); | "Reporter data fail, model id:%u.", this->Id()); | ||||
| Range task_range = op_id_map.equal_range(op_id); | |||||
| for (CIT idx = task_range.first; idx != task_range.second; ++idx) { | |||||
| uint32_t task_id = idx->second; | |||||
| reporter_data.data = (unsigned char *)&task_id; | |||||
| reporter_data.dataLen = sizeof(uint32_t); | |||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | |||||
| "Reporter data fail, model id:%u.", this->Id()); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -2824,19 +2991,19 @@ Status DavinciModel::CreateKnownZeroCopyMap(const vector<void *> &inputs, const | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status DavinciModel::UpdateKnownZeroCopyAddr(vector<void *> &total_io_addrs) { | |||||
| for (size_t i = 0; i < total_io_addrs.size(); ++i) { | |||||
| auto it_in = knonw_input_data_info_.find(total_io_addrs[i]); | |||||
| Status DavinciModel::UpdateKnownZeroCopyAddr() { | |||||
| for (size_t i = 0; i < total_io_addrs_.size(); ++i) { | |||||
| auto it_in = knonw_input_data_info_.find(total_io_addrs_[i]); | |||||
| if (it_in != knonw_input_data_info_.end()) { | if (it_in != knonw_input_data_info_.end()) { | ||||
| GELOGI("DavinciModel::UpdateKnownZeroCopyAddr input %zu,v addr %p,p addr %p .", i, total_io_addrs[i], | |||||
| knonw_input_data_info_.at(total_io_addrs[i])); | |||||
| total_io_addrs[i] = knonw_input_data_info_.at(total_io_addrs[i]); | |||||
| GELOGI("DavinciModel::UpdateKnownZeroCopyAddr input %zu,v addr %p,p addr %p .", i, total_io_addrs_[i], | |||||
| knonw_input_data_info_.at(total_io_addrs_[i])); | |||||
| total_io_addrs_[i] = knonw_input_data_info_.at(total_io_addrs_[i]); | |||||
| } | } | ||||
| auto it_out = knonw_output_data_info_.find(total_io_addrs[i]); | |||||
| auto it_out = knonw_output_data_info_.find(total_io_addrs_[i]); | |||||
| if (it_out != knonw_output_data_info_.end()) { | if (it_out != knonw_output_data_info_.end()) { | ||||
| GELOGI("DavinciModel::UpdateKnownZeroCopyAddr output %zu,v addr %p,p addr %p .", i, total_io_addrs[i], | |||||
| knonw_output_data_info_.at(total_io_addrs[i])); | |||||
| total_io_addrs[i] = knonw_output_data_info_.at(total_io_addrs[i]); | |||||
| GELOGI("DavinciModel::UpdateKnownZeroCopyAddr output %zu,v addr %p,p addr %p .", i, total_io_addrs_[i], | |||||
| knonw_output_data_info_.at(total_io_addrs_[i])); | |||||
| total_io_addrs_[i] = knonw_output_data_info_.at(total_io_addrs_[i]); | |||||
| } | } | ||||
| } | } | ||||
| GELOGI("DavinciModel::UpdateKnownZeroCopyAddr success."); | GELOGI("DavinciModel::UpdateKnownZeroCopyAddr success."); | ||||
| @@ -2865,7 +3032,7 @@ Status DavinciModel::UpdateKnownNodeArgs(const vector<void *> &inputs, const vec | |||||
| } else { | } else { | ||||
| total_io_addrs_ = orig_total_io_addrs_; | total_io_addrs_ = orig_total_io_addrs_; | ||||
| } | } | ||||
| GE_CHK_STATUS_RET(UpdateKnownZeroCopyAddr(total_io_addrs_), "DavinciModel::UpdateKnownZeroCopyAddr failed."); | |||||
| GE_CHK_STATUS_RET(UpdateKnownZeroCopyAddr(), "DavinciModel::UpdateKnownZeroCopyAddr failed."); | |||||
| if (total_args_size_ == 0) { | if (total_args_size_ == 0) { | ||||
| GELOGW("DavinciModel::UpdateKnownNodeArgs device args %p, dst size %u, pass rtMemcpy.", args_, total_args_size_); | GELOGW("DavinciModel::UpdateKnownNodeArgs device args %p, dst size %u, pass rtMemcpy.", args_, total_args_size_); | ||||
| @@ -2932,14 +3099,7 @@ Status DavinciModel::MallocKnownArgs() { | |||||
| GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); | ||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | return RT_ERROR_TO_GE_STATUS(rt_ret); | ||||
| } | } | ||||
| // malloc dynamic and static hybrid memory | |||||
| if (total_hybrid_args_size_ != 0) { | |||||
| rt_ret = rtMalloc(&hybrid_addrs_, total_hybrid_args_size_, RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| } | |||||
| // malloc fixed addr memory, eg: rts op | // malloc fixed addr memory, eg: rts op | ||||
| if (total_fixed_addr_size_ != 0) { | if (total_fixed_addr_size_ != 0) { | ||||
| GELOGI("Begin to allocate fixed addr."); | GELOGI("Begin to allocate fixed addr."); | ||||
| @@ -2993,7 +3153,9 @@ Status DavinciModel::DistributeTask() { | |||||
| } | } | ||||
| auto task_type = static_cast<rtModelTaskType_t>(task_def.type()); | auto task_type = static_cast<rtModelTaskType_t>(task_def.type()); | ||||
| bool no_need_profiling = (task_type != RT_MODEL_TASK_KERNEL) && (task_type != RT_MODEL_TASK_KERNEL_EX); | |||||
| bool no_need_profiling = (task_type != RT_MODEL_TASK_KERNEL) | |||||
| && (task_type != RT_MODEL_TASK_KERNEL_EX) | |||||
| && (task_type != RT_MODEL_TASK_HCCL); | |||||
| GE_IF_BOOL_EXEC(no_need_profiling, continue); | GE_IF_BOOL_EXEC(no_need_profiling, continue); | ||||
| SaveDumpOpInfo(runtime_param_, op, task->GetTaskID(), task->GetStreamId()); | SaveDumpOpInfo(runtime_param_, op, task->GetTaskID(), task->GetStreamId()); | ||||
| @@ -3008,8 +3170,6 @@ Status DavinciModel::DistributeTask() { | |||||
| task_desc_info.block_dim = task_def.kernel().block_dim(); | task_desc_info.block_dim = task_def.kernel().block_dim(); | ||||
| task_desc_info.task_id = task->GetTaskID(); | task_desc_info.task_id = task->GetTaskID(); | ||||
| task_desc_info.stream_id = task->GetStreamId(); | task_desc_info.stream_id = task->GetStreamId(); | ||||
| task_desc_info.shape_type = "static"; | |||||
| task_desc_info.cur_iter_num = 0; | |||||
| task_desc_info_.emplace_back(task_desc_info); | task_desc_info_.emplace_back(task_desc_info); | ||||
| if (flag) { | if (flag) { | ||||
| if (task->GetSktTaskID() != 0xFFFFFFFF) { | if (task->GetSktTaskID() != 0xFFFFFFFF) { | ||||
| @@ -3097,20 +3257,27 @@ void DavinciModel::SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector<v | |||||
| for (auto &input_outside_addrs : new_input_outside_addrs_) { | for (auto &input_outside_addrs : new_input_outside_addrs_) { | ||||
| ZeroCopyOffset &input_outside = input_outside_addrs.second; | ZeroCopyOffset &input_outside = input_outside_addrs.second; | ||||
| input_outside.SetOutsideAddrsValue(zero_copy_task, outside_addrs[i], args, offset + i * kAddrLen); | |||||
| bool ret = input_outside.SetOutsideAddrsValue(zero_copy_task, outside_addrs[i], args, offset + i * kAddrLen); | |||||
| if (ret) { | |||||
| void *args_val = static_cast<uint8_t *>(args) + offset + i * kAddrLen; | |||||
| SetBatchLabelAddr(op_desc, reinterpret_cast<uintptr_t>(args_val)); | |||||
| } | |||||
| } | } | ||||
| for (auto &output_outside_addrs : new_output_outside_addrs_) { | for (auto &output_outside_addrs : new_output_outside_addrs_) { | ||||
| ZeroCopyOffset &output_outside = output_outside_addrs.second; | ZeroCopyOffset &output_outside = output_outside_addrs.second; | ||||
| output_outside.SetOutsideAddrsValue(zero_copy_task, outside_addrs[i], args, offset + i * kAddrLen); | |||||
| bool ret = output_outside.SetOutsideAddrsValue(zero_copy_task, outside_addrs[i], args, offset + i * kAddrLen); | |||||
| if (ret) { | |||||
| void *args_val = static_cast<uint8_t *>(args) + offset + i * kAddrLen; | |||||
| SetBatchLabelAddr(op_desc, reinterpret_cast<uintptr_t>(args_val)); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| string batch_label; | |||||
| if (!AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label) || batch_label.empty()) { | |||||
| auto it = zero_copy_op_id_batch_label_.find(op_desc->GetId()); | |||||
| if (it == zero_copy_op_id_batch_label_.end()) { | |||||
| zero_copy_task.SetBatchLabel(kDefaultBatchLable); | zero_copy_task.SetBatchLabel(kDefaultBatchLable); | ||||
| } else { | } else { | ||||
| zero_copy_task.SetBatchLabel(batch_label); | |||||
| zero_copy_task.SetBatchLabel(it->second); | |||||
| } | } | ||||
| std::lock_guard<std::mutex> lock(outside_addrs_mutex_); | std::lock_guard<std::mutex> lock(outside_addrs_mutex_); | ||||
| @@ -3120,6 +3287,27 @@ void DavinciModel::SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector<v | |||||
| } | } | ||||
| } | } | ||||
| void DavinciModel::SetBatchLabelAddr(const OpDescPtr &op_desc, uintptr_t addr) { | |||||
| // Establish a mapping between batch label and zero copy address for multi-batch scenes | |||||
| auto it = zero_copy_op_id_batch_label_.find(op_desc->GetId()); | |||||
| if (it == zero_copy_op_id_batch_label_.end()) { | |||||
| return; | |||||
| } | |||||
| const string &batch_label = it->second; | |||||
| auto iter = zero_copy_batch_label_addrs_.find(batch_label); | |||||
| if (iter != zero_copy_batch_label_addrs_.end()) { | |||||
| iter->second.insert(addr); | |||||
| GELOGD("[ZCPY] Set zero copy batch label and addrs success, batch label: %s, op name:%s.", batch_label.c_str(), | |||||
| op_desc->GetName().c_str()); | |||||
| } else { | |||||
| set<uintptr_t> addrs = {addr}; | |||||
| zero_copy_batch_label_addrs_.emplace(pair<string, set<uintptr_t>>(batch_label, addrs)); | |||||
| GELOGD("[ZCPY] New added zero copy batch label and addrs success, batch label: %s, op name:%s.", | |||||
| batch_label.c_str(), op_desc->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief Copy Check input size and model op size. | /// @brief Copy Check input size and model op size. | ||||
| @@ -3253,15 +3441,15 @@ Status DavinciModel::UpdateIoTaskArgs(const std::map<uint32_t, ZeroCopyOffset> & | |||||
| void *addr = data.second.GetDataInfo().at(count).second; | void *addr = data.second.GetDataInfo().at(count).second; | ||||
| void *buffer_addr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(buffer.data) + | void *buffer_addr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(buffer.data) + | ||||
| data.second.GetRelativeOffset().at(count)); | data.second.GetRelativeOffset().at(count)); | ||||
| GELOGI("[ZCPY] Copy %s blobs_index %u, virtual_addr: %p, size: %ld, user_data_addr: %p, batch_label: %s", | |||||
| input_or_output.c_str(), data.first, addr, size, buffer_addr, batch_label.c_str()); | |||||
| GELOGI("[ZCPY] Copy %s blobs_index %u, virtual_addr: %p, size: %ld, user_data_addr: %p", input_or_output.c_str(), | |||||
| data.first, addr, size, buffer_addr); | |||||
| // For input data, just copy for rts task. | // For input data, just copy for rts task. | ||||
| for (ZeroCopyTask &task : zero_copy_tasks_) { | for (ZeroCopyTask &task : zero_copy_tasks_) { | ||||
| if (task.GetBatchLabel() != kDefaultBatchLable && task.GetBatchLabel() != batch_label) { | if (task.GetBatchLabel() != kDefaultBatchLable && task.GetBatchLabel() != batch_label) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| uintptr_t addr_val = reinterpret_cast<uintptr_t>(addr); | uintptr_t addr_val = reinterpret_cast<uintptr_t>(addr); | ||||
| if (task.UpdateTaskParam(addr_val, buffer_addr) != SUCCESS) { | |||||
| if (task.UpdateTaskParam(addr_val, buffer_addr, zero_copy_batch_label_addrs_, batch_label) != SUCCESS) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } | } | ||||
| @@ -3623,6 +3811,9 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa | |||||
| GELOGD("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, is_async_mode_); | GELOGD("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, is_async_mode_); | ||||
| GE_CHK_STATUS_RET(InitModelStream(stream), "Init model stream failed."); | GE_CHK_STATUS_RET(InitModelStream(stream), "Init model stream failed."); | ||||
| is_dynamic_ = input_data.is_dynamic_batch; | is_dynamic_ = input_data.is_dynamic_batch; | ||||
| if (!is_dynamic_) { | |||||
| zero_copy_batch_label_addrs_.clear(); | |||||
| } | |||||
| GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_PRE_PROC_START)); | GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_PRE_PROC_START)); | ||||
| Status ret = CopyModelData(input_data, output_data, is_dynamic_); | Status ret = CopyModelData(input_data, output_data, is_dynamic_); | ||||
| @@ -76,20 +76,6 @@ struct timeInfo { | |||||
| int64_t dumpEndTime; | int64_t dumpEndTime; | ||||
| }; | }; | ||||
| struct TaskMemInfo { | |||||
| int64_t input_size{0}; | |||||
| int64_t output_size{0}; | |||||
| int64_t weight_size{0}; | |||||
| int64_t workspace_size{0}; | |||||
| int64_t total_size{0}; | |||||
| }; | |||||
| struct ProfileInfo { | |||||
| FusionOpInfo fusion_info; | |||||
| TaskMemInfo memory_info; | |||||
| uint32_t task_count{0}; | |||||
| }; | |||||
| enum ExecuteMode { | enum ExecuteMode { | ||||
| INITIALIZATION, | INITIALIZATION, | ||||
| SYNCHRONIZATION, | SYNCHRONIZATION, | ||||
| @@ -240,6 +226,8 @@ class DavinciModel { | |||||
| const vector<OpDescPtr> &GetDataList() const { return data_op_list_; } | const vector<OpDescPtr> &GetDataList() const { return data_op_list_; } | ||||
| // get Op | // get Op | ||||
| const map<uint32_t, OpDescPtr> &GetOpList() const { return op_list_; } | |||||
| OpDescPtr GetOpByIndex(uint32_t index) const { | OpDescPtr GetOpByIndex(uint32_t index) const { | ||||
| if (op_list_.find(index) == op_list_.end()) { | if (op_list_.find(index) == op_list_.end()) { | ||||
| return nullptr; | return nullptr; | ||||
| @@ -448,6 +436,10 @@ class DavinciModel { | |||||
| int64_t GetLoadEndTime() { return load_end_time_; } | int64_t GetLoadEndTime() { return load_end_time_; } | ||||
| Status SinkModelProfile(); | |||||
| Status SinkTimeProfile(const InputData ¤t_data); | |||||
| Status ReportProfilingData(); | Status ReportProfilingData(); | ||||
| void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) { | void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) { | ||||
| @@ -484,14 +476,6 @@ class DavinciModel { | |||||
| void SetTotalIOAddrs(vector<void *> &io_addrs) { | void SetTotalIOAddrs(vector<void *> &io_addrs) { | ||||
| total_io_addrs_.insert(total_io_addrs_.end(), io_addrs.begin(), io_addrs.end()); | total_io_addrs_.insert(total_io_addrs_.end(), io_addrs.begin(), io_addrs.end()); | ||||
| } | } | ||||
| void SetHybridArgsSize(uint32_t args_size) { total_hybrid_args_size_ += args_size; } | |||||
| uint32_t GetHybridArgsSize() { | |||||
| return total_hybrid_args_size_; | |||||
| } | |||||
| void *GetCurrentHybridArgsAddr(uint32_t offset) { | |||||
| void *cur_args = static_cast<char *>(hybrid_addrs_) + offset; | |||||
| return cur_args; | |||||
| } | |||||
| void SetTotalFixedAddrsSize(string tensor_name, int64_t fix_addr_size); | void SetTotalFixedAddrsSize(string tensor_name, int64_t fix_addr_size); | ||||
| int64_t GetFixedAddrsSize(string tensor_name); | int64_t GetFixedAddrsSize(string tensor_name); | ||||
| void *GetCurrentFixedAddr(int64_t offset) const { | void *GetCurrentFixedAddr(int64_t offset) const { | ||||
| @@ -510,7 +494,7 @@ class DavinciModel { | |||||
| Status MallocKnownArgs(); | Status MallocKnownArgs(); | ||||
| Status UpdateKnownNodeArgs(const vector<void *> &inputs, const vector<void *> &outputs); | Status UpdateKnownNodeArgs(const vector<void *> &inputs, const vector<void *> &outputs); | ||||
| Status CreateKnownZeroCopyMap(const vector<void *> &inputs, const vector<void *> &outputs); | Status CreateKnownZeroCopyMap(const vector<void *> &inputs, const vector<void *> &outputs); | ||||
| Status UpdateKnownZeroCopyAddr(vector<void *> &total_io_addrs); | |||||
| Status UpdateKnownZeroCopyAddr(); | |||||
| void SetKnownNodeAddrNotChanged(bool base_addr_not_changed) { base_addr_not_changed_ = base_addr_not_changed; } | void SetKnownNodeAddrNotChanged(bool base_addr_not_changed) { base_addr_not_changed_ = base_addr_not_changed; } | ||||
| Status GetOrigInputInfo(uint32_t index, OriginInputInfo &orig_input_info); | Status GetOrigInputInfo(uint32_t index, OriginInputInfo &orig_input_info); | ||||
| @@ -545,6 +529,15 @@ class DavinciModel { | |||||
| struct timeInfo time_info_; | struct timeInfo time_info_; | ||||
| int32_t dataInputTid; | int32_t dataInputTid; | ||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief Save Batch label Info. | |||||
| /// @param [in] const OpDescPtr &op_desc | |||||
| /// @param [in] uintptr_t addr: address value in args block. | |||||
| /// @return None. | |||||
| /// | |||||
| void SetBatchLabelAddr(const OpDescPtr &op_desc, uintptr_t addr); | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief Copy Check input size and model op size. | /// @brief Copy Check input size and model op size. | ||||
| @@ -656,6 +649,14 @@ class DavinciModel { | |||||
| /// | /// | ||||
| void AdjustDataOpList(const map<uint32_t, OpDescPtr> &data_by_index); | void AdjustDataOpList(const map<uint32_t, OpDescPtr> &data_by_index); | ||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief input zero copy node Initialize. | |||||
| /// @param [in] NodePtr: Data Op. | |||||
| /// @return Status | |||||
| /// | |||||
| Status InitInputZeroCopyNodes(const NodePtr &node); | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief NetOutput Op Initialize. | /// @brief NetOutput Op Initialize. | ||||
| @@ -664,6 +665,30 @@ class DavinciModel { | |||||
| /// | /// | ||||
| Status InitNetOutput(const NodePtr &node); | Status InitNetOutput(const NodePtr &node); | ||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief output zero copy node Initialize. | |||||
| /// @param [in] NodePtr: Data Op. | |||||
| /// @return Status | |||||
| /// | |||||
| Status InitOutputZeroCopyNodes(const NodePtr &node); | |||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief input zero copy node Initialize for Case. | |||||
| /// @param [in] NodePtr: Data Op. | |||||
| /// @return Status | |||||
| /// | |||||
| Status InitInputBatchLabel(const NodePtr &node); | |||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief output zero copy node Initialize for Case. | |||||
| /// @param [in] NodePtr: netoutput Op. | |||||
| /// @return Status | |||||
| /// | |||||
| Status InitOutputBatchLabel(const NodePtr &node); | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief Constant Op Init. | /// @brief Constant Op Init. | ||||
| @@ -812,11 +837,6 @@ class DavinciModel { | |||||
| void SetDataDumperArgs(const ComputeGraphPtr &compute_graph); | void SetDataDumperArgs(const ComputeGraphPtr &compute_graph); | ||||
| Status InitModelProfile(); | |||||
| Status SinkModelProfile(); | |||||
| Status SinkTimeProfile(const InputData ¤t_data); | |||||
| Status GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data, | Status GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data, | ||||
| std::vector<ge::OutputTensorInfo> &outputs); | std::vector<ge::OutputTensorInfo> &outputs); | ||||
| @@ -894,6 +914,11 @@ class DavinciModel { | |||||
| std::vector<ZeroCopyTask> zero_copy_tasks_; // Task used Data or NetOutput addr. | std::vector<ZeroCopyTask> zero_copy_tasks_; // Task used Data or NetOutput addr. | ||||
| std::set<const void *> copy_only_addrs_; // Address need copy to original place. | std::set<const void *> copy_only_addrs_; // Address need copy to original place. | ||||
| // {op_id, batch_label} | |||||
| std::map<int64_t, std::string> zero_copy_op_id_batch_label_; | |||||
| // {batch_label, addrs} | |||||
| std::map<std::string, std::set<uintptr_t>> zero_copy_batch_label_addrs_; | |||||
| std::vector<TaskInfoPtr> task_list_; | std::vector<TaskInfoPtr> task_list_; | ||||
| // rt_moodel_handle | // rt_moodel_handle | ||||
| rtModel_t rt_model_handle_; | rtModel_t rt_model_handle_; | ||||
| @@ -952,8 +977,6 @@ class DavinciModel { | |||||
| void *args_ = nullptr; | void *args_ = nullptr; | ||||
| void *args_host_ = nullptr; | void *args_host_ = nullptr; | ||||
| void *fixed_addrs_ = nullptr; | void *fixed_addrs_ = nullptr; | ||||
| void *hybrid_addrs_ = nullptr; | |||||
| uint32_t total_hybrid_args_size_ = 0; | |||||
| int64_t total_fixed_addr_size_ = 0; | int64_t total_fixed_addr_size_ = 0; | ||||
| std::map<const void *, void *> knonw_input_data_info_; | std::map<const void *, void *> knonw_input_data_info_; | ||||
| std::map<const void *, void *> knonw_output_data_info_; | std::map<const void *, void *> knonw_output_data_info_; | ||||
| @@ -993,9 +1016,6 @@ class DavinciModel { | |||||
| // key: input_index: input is merge node; value: each gear info and each output shape | // key: input_index: input is merge node; value: each gear info and each output shape | ||||
| std::map<size_t, std::map<vector<int64_t>, vector<int64_t>>> merge_nodes_gear_and_real_out_shape_info_; | std::map<size_t, std::map<vector<int64_t>, vector<int64_t>>> merge_nodes_gear_and_real_out_shape_info_; | ||||
| std::vector<std::vector<int64_t>> all_gears_info_; | std::vector<std::vector<int64_t>> all_gears_info_; | ||||
| std::multimap<uint32_t, uint32_t> op_id_map_; | |||||
| std::vector<ProfileInfo> profile_list_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_ | #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_ | ||||
| @@ -89,7 +89,6 @@ Status ModelManager::KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, u | |||||
| if (op_type == aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY) { | if (op_type == aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY) { | ||||
| std::vector<uint64_t> v_aicpu_kernel; | std::vector<uint64_t> v_aicpu_kernel; | ||||
| std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); | std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); | ||||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||||
| auto iter = model_aicpu_kernel_.find(model_key); | auto iter = model_aicpu_kernel_.find(model_key); | ||||
| if (iter != model_aicpu_kernel_.end()) { | if (iter != model_aicpu_kernel_.end()) { | ||||
| GELOGD("kernel destroy session_id %lu, model_id %u.", session_id, model_id); | GELOGD("kernel destroy session_id %lu, model_id %u.", session_id, model_id); | ||||
| @@ -177,7 +176,7 @@ Status ModelManager::KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, u | |||||
| } | } | ||||
| void ModelManager::DestroyAicpuSession(uint64_t session_id) { | void ModelManager::DestroyAicpuSession(uint64_t session_id) { | ||||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||||
| std::lock_guard<std::mutex> lock(sess_ids_mutex_); | |||||
| auto it = sess_ids_.find(session_id); | auto it = sess_ids_.find(session_id); | ||||
| if (it == sess_ids_.end()) { | if (it == sess_ids_.end()) { | ||||
| GELOGI("The session: %lu not created.", session_id); | GELOGI("The session: %lu not created.", session_id); | ||||
| @@ -206,7 +205,7 @@ void ModelManager::DestroyAicpuSession(uint64_t session_id) { | |||||
| } | } | ||||
| ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) { | ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) { | ||||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||||
| std::lock_guard<std::mutex> lock(map_mutex_); | |||||
| auto hybrid_davinci_model = hybrid_model_map_.find(model_id); | auto hybrid_davinci_model = hybrid_model_map_.find(model_id); | ||||
| if (hybrid_davinci_model != hybrid_model_map_.end()) { | if (hybrid_davinci_model != hybrid_model_map_.end()) { | ||||
| uint64_t session_id = hybrid_davinci_model->second->GetSessionId(); | uint64_t session_id = hybrid_davinci_model->second->GetSessionId(); | ||||
| @@ -216,8 +215,8 @@ ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) { | |||||
| auto it = model_map_.find(model_id); | auto it = model_map_.find(model_id); | ||||
| if (it == model_map_.end()) { | if (it == model_map_.end()) { | ||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, "model id %u does not exists.", model_id); | |||||
| return ACL_ERROR_GE_EXEC_MODEL_ID_INVALID; | |||||
| GELOGE(GE_EXEC_MODEL_ID_INVALID, "model id %u does not exists.", model_id); | |||||
| return GE_EXEC_MODEL_ID_INVALID; | |||||
| } | } | ||||
| uint64_t session_id = it->second->GetSessionId(); | uint64_t session_id = it->second->GetSessionId(); | ||||
| DestroyAicpuSession(session_id); | DestroyAicpuSession(session_id); | ||||
| @@ -226,7 +225,7 @@ ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) { | |||||
| ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id) { | ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id) { | ||||
| GELOGD("destroy aicpu kernel in session_id %lu, model_id %u.", session_id, model_id); | GELOGD("destroy aicpu kernel in session_id %lu, model_id %u.", session_id, model_id); | ||||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||||
| std::lock_guard<std::mutex> lock(map_mutex_); | |||||
| std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); | std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); | ||||
| if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { | if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { | ||||
| Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY, session_id, model_id); | Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY, session_id, model_id); | ||||
| @@ -239,7 +238,7 @@ ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_ | |||||
| } | } | ||||
| ge::Status ModelManager::CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint64_t kernel_id) { | ge::Status ModelManager::CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint64_t kernel_id) { | ||||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||||
| std::lock_guard<std::mutex> lock(map_mutex_); | |||||
| std::vector<uint64_t> v_aicpu_kernel; | std::vector<uint64_t> v_aicpu_kernel; | ||||
| std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); | std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); | ||||
| if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { | if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { | ||||
| @@ -251,7 +250,7 @@ ge::Status ModelManager::CreateAicpuKernel(uint64_t session_id, uint32_t model_i | |||||
| } | } | ||||
| ModelManager::~ModelManager() { | ModelManager::~ModelManager() { | ||||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||||
| std::lock_guard<std::mutex> lock(map_mutex_); | |||||
| model_map_.clear(); | model_map_.clear(); | ||||
| model_aicpu_kernel_.clear(); | model_aicpu_kernel_.clear(); | ||||
| cust_aicpu_so_.clear(); | cust_aicpu_so_.clear(); | ||||
| @@ -359,18 +358,18 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||||
| void ModelManager::InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model) { | void ModelManager::InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model) { | ||||
| GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", id); | GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", id); | ||||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||||
| std::lock_guard<std::mutex> lock(map_mutex_); | |||||
| model_map_[id] = davinci_model; | model_map_[id] = davinci_model; | ||||
| } | } | ||||
| void ModelManager::InsertModel(uint32_t id, shared_ptr<hybrid::HybridDavinciModel> &hybrid_model) { | void ModelManager::InsertModel(uint32_t id, shared_ptr<hybrid::HybridDavinciModel> &hybrid_model) { | ||||
| GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", id); | GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", id); | ||||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||||
| std::lock_guard<std::mutex> lock(map_mutex_); | |||||
| hybrid_model_map_[id] = hybrid_model; | hybrid_model_map_[id] = hybrid_model; | ||||
| } | } | ||||
| Status ModelManager::DeleteModel(uint32_t id) { | Status ModelManager::DeleteModel(uint32_t id) { | ||||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||||
| std::lock_guard<std::mutex> lock(map_mutex_); | |||||
| auto it = model_map_.find(id); | auto it = model_map_.find(id); | ||||
| auto hybrid_model_it = hybrid_model_map_.find(id); | auto hybrid_model_it = hybrid_model_map_.find(id); | ||||
| @@ -385,22 +384,22 @@ Status ModelManager::DeleteModel(uint32_t id) { | |||||
| } else if (hybrid_model_it != hybrid_model_map_.end()) { | } else if (hybrid_model_it != hybrid_model_map_.end()) { | ||||
| (void)hybrid_model_map_.erase(hybrid_model_it); | (void)hybrid_model_map_.erase(hybrid_model_it); | ||||
| } else { | } else { | ||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, "model id %u does not exists.", id); | |||||
| return ACL_ERROR_GE_EXEC_MODEL_ID_INVALID; | |||||
| GELOGE(GE_EXEC_MODEL_ID_INVALID, "model id %u does not exists.", id); | |||||
| return GE_EXEC_MODEL_ID_INVALID; | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| std::shared_ptr<DavinciModel> ModelManager::GetModel(uint32_t id) { | std::shared_ptr<DavinciModel> ModelManager::GetModel(uint32_t id) { | ||||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||||
| std::lock_guard<std::mutex> lock(map_mutex_); | |||||
| auto it = model_map_.find(id); | auto it = model_map_.find(id); | ||||
| return (it == model_map_.end()) ? nullptr : it->second; | return (it == model_map_.end()) ? nullptr : it->second; | ||||
| } | } | ||||
| std::shared_ptr<hybrid::HybridDavinciModel> ModelManager::GetHybridModel(uint32_t id) { | std::shared_ptr<hybrid::HybridDavinciModel> ModelManager::GetHybridModel(uint32_t id) { | ||||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||||
| std::lock_guard<std::mutex> lock(map_mutex_); | |||||
| auto it = hybrid_model_map_.find(id); | auto it = hybrid_model_map_.find(id); | ||||
| return (it == hybrid_model_map_.end()) ? nullptr : it->second; | return (it == hybrid_model_map_.end()) ? nullptr : it->second; | ||||
| @@ -903,7 +902,7 @@ Status ModelManager::GetInputOutputDescInfo(const uint32_t model_id, vector<Inpu | |||||
| } | } | ||||
| std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | ||||
| GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, | |||||
| GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, GE_EXEC_MODEL_ID_INVALID, | |||||
| "GetInputOutputDescInfo Failed, Invalid model id %u!", model_id); | "GetInputOutputDescInfo Failed, Invalid model id %u!", model_id); | ||||
| davinci_model->SetModelDescVersion(new_model_desc); | davinci_model->SetModelDescVersion(new_model_desc); | ||||
| @@ -971,9 +970,8 @@ Status ModelManager::GetUserDesignateShapeOrder(const uint32_t model_id, | |||||
| } | } | ||||
| Status ModelManager::GetCurShape(const uint32_t model_id, std::vector<int64_t> &batch_info, int32_t &dynamic_type) { | Status ModelManager::GetCurShape(const uint32_t model_id, std::vector<int64_t> &batch_info, int32_t &dynamic_type) { | ||||
| auto davinci_model = GetModel(model_id); | |||||
| GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, | |||||
| "GetCurShape Failed, Invalid Model ID %u!", model_id); | |||||
| std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | |||||
| GE_CHECK_NOTNULL(davinci_model); | |||||
| davinci_model->GetCurShape(batch_info, dynamic_type); | davinci_model->GetCurShape(batch_info, dynamic_type); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -986,8 +984,7 @@ Status ModelManager::GetModelAttr(uint32_t model_id, std::vector<string> &dynami | |||||
| } | } | ||||
| std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | ||||
| GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, | |||||
| "GetModelAttr Failed, Invalid Model ID %u!", model_id); | |||||
| GE_CHECK_NOTNULL(davinci_model); | |||||
| davinci_model->GetModelAttr(dynamic_output_shape_info); | davinci_model->GetModelAttr(dynamic_output_shape_info); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -997,8 +994,9 @@ Status ModelManager::GetInputOutputDescInfoForZeroCopy(const uint32_t model_id, | |||||
| std::vector<uint32_t> &inputFormats, | std::vector<uint32_t> &inputFormats, | ||||
| std::vector<uint32_t> &outputFormats) { | std::vector<uint32_t> &outputFormats) { | ||||
| std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | ||||
| GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, | |||||
| "GetInputOutputDescInfo Failed, Invalid model id %u!", model_id); | |||||
| GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "GetInputOutputDescInfo Failed, Invalid model id %u!", | |||||
| model_id); | |||||
| return davinci_model->GetInputOutputDescInfoForZeroCopy(input_desc, output_desc, inputFormats, outputFormats); | return davinci_model->GetInputOutputDescInfoForZeroCopy(input_desc, output_desc, inputFormats, outputFormats); | ||||
| } | } | ||||
| @@ -1013,14 +1011,18 @@ Status ModelManager::GetInputOutputDescInfoForZeroCopy(const uint32_t model_id, | |||||
| Status ModelManager::GetAIPPInfo(const uint32_t model_id, uint32_t index, AippConfigInfo &aipp_info) { | Status ModelManager::GetAIPPInfo(const uint32_t model_id, uint32_t index, AippConfigInfo &aipp_info) { | ||||
| std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | ||||
| GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, | GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, | ||||
| "GetAIPPInfo failed, invalid model_id is %u.", model_id); | |||||
| "GetAIPPInfo failed, invalid model_id is %u.", | |||||
| model_id); | |||||
| return davinci_model->GetAIPPInfo(index, aipp_info); | return davinci_model->GetAIPPInfo(index, aipp_info); | ||||
| } | } | ||||
| Status ModelManager::GetAippType(uint32_t model_id, uint32_t index, InputAippType &type, size_t &aipp_index) { | Status ModelManager::GetAippType(uint32_t model_id, uint32_t index, InputAippType &type, size_t &aipp_index) { | ||||
| std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | ||||
| GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, | GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, | ||||
| "GetAIPPInfo failed, invalid model_id is %u.", model_id); | |||||
| "GetAIPPInfo failed, invalid model_id is %u.", | |||||
| model_id); | |||||
| return davinci_model->GetAippType(index, type, aipp_index); | return davinci_model->GetAippType(index, type, aipp_index); | ||||
| } | } | ||||
| @@ -1053,15 +1055,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model | |||||
| mmTimespec timespec = mmGetTickCount(); | mmTimespec timespec = mmGetTickCount(); | ||||
| ModelHelper model_helper; | ModelHelper model_helper; | ||||
| Status ret = model_helper.LoadRootModel(model); | |||||
| if (model_helper.GetModelType()) { | |||||
| bool is_shape_unknown = false; | |||||
| GE_CHK_STATUS_RET(model_helper.GetGeRootModel()->CheckIsUnknownShape(is_shape_unknown), | |||||
| "CheckIsUnknownShape failed, model id:%u", model_id); | |||||
| if (is_shape_unknown || GetContext().GetHostExecFlag()) { | |||||
| return DoLoadHybridModelOnline(model_id, model_helper.GetGeRootModel(), listener); | |||||
| } | |||||
| } | |||||
| Status ret = model_helper.LoadModel(model); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "load model failed."); | GELOGE(ret, "load model failed."); | ||||
| return ret; | return ret; | ||||
| @@ -1075,8 +1069,8 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Make shared failed"); | GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Make shared failed"); | ||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | return ACL_ERROR_GE_MEMORY_ALLOCATION; | ||||
| } catch (...) { | } catch (...) { | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Make shared failed since other exception raise"); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| GELOGE(INTERNAL_ERROR, "Make shared failed since other exception raise"); | |||||
| return INTERNAL_ERROR; | |||||
| } | } | ||||
| ret = davinci_model->Assign(ge_model); | ret = davinci_model->Assign(ge_model); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -1088,7 +1082,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model | |||||
| int32_t device_id = 0; | int32_t device_id = 0; | ||||
| rtError_t rt_ret = rtGetDevice(&device_id); | rtError_t rt_ret = rtGetDevice(&device_id); | ||||
| if (rt_ret != RT_ERROR_NONE || device_id < 0) { | if (rt_ret != RT_ERROR_NONE || device_id < 0) { | ||||
| GELOGE(rt_ret, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id); | |||||
| GELOGE(RT_FAILED, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | return RT_ERROR_TO_GE_STATUS(rt_ret); | ||||
| } | } | ||||
| davinci_model->SetDeviceId(device_id); | davinci_model->SetDeviceId(device_id); | ||||
| @@ -1220,7 +1214,7 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy | |||||
| std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | ||||
| GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, | GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, | ||||
| "Invalid model id %u, check whether model has been loaded or not.", model_id); | |||||
| "Invalid model id %u, check weather model has been loaded or not.", model_id); | |||||
| if (davinci_model->NeedDestroyAicpuKernel()) { | if (davinci_model->NeedDestroyAicpuKernel()) { | ||||
| GELOGI("Start to destroy specified aicpu kernel."); | GELOGI("Start to destroy specified aicpu kernel."); | ||||
| @@ -1243,7 +1237,7 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy | |||||
| } | } | ||||
| Status ModelManager::CreateAicpuSession(uint64_t session_id) { | Status ModelManager::CreateAicpuSession(uint64_t session_id) { | ||||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||||
| std::lock_guard<std::mutex> lock(sess_ids_mutex_); | |||||
| auto it = sess_ids_.find(session_id); | auto it = sess_ids_.find(session_id); | ||||
| // never been created by any model | // never been created by any model | ||||
| if (it == sess_ids_.end()) { | if (it == sess_ids_.end()) { | ||||
| @@ -1462,7 +1456,8 @@ void ModelManager::GenModelId(uint32_t *id) { | |||||
| if (id == nullptr) { | if (id == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| std::lock_guard<std::recursive_mutex> lock(map_mutex_); | |||||
| std::lock_guard<std::mutex> lock(map_mutex_); | |||||
| *id = ++max_model_id_; | *id = ++max_model_id_; | ||||
| } | } | ||||
| @@ -353,7 +353,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
| std::map<uint32_t, std::shared_ptr<hybrid::HybridDavinciModel>> hybrid_model_map_; | std::map<uint32_t, std::shared_ptr<hybrid::HybridDavinciModel>> hybrid_model_map_; | ||||
| std::map<std::string, std::vector<uint64_t>> model_aicpu_kernel_; | std::map<std::string, std::vector<uint64_t>> model_aicpu_kernel_; | ||||
| uint32_t max_model_id_; | uint32_t max_model_id_; | ||||
| std::recursive_mutex map_mutex_; | |||||
| std::mutex map_mutex_; | |||||
| std::mutex sess_ids_mutex_; | |||||
| std::mutex session_id_create_mutex_; | std::mutex session_id_create_mutex_; | ||||
| static::std::mutex exeception_infos_mutex_; | static::std::mutex exeception_infos_mutex_; | ||||
| uint64_t session_id_bias_; | uint64_t session_id_bias_; | ||||
| @@ -90,18 +90,20 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci | |||||
| fusion_op_info_.op_index = context.op_index(); fusion_op_info_.original_op_names = original_op_names; | fusion_op_info_.op_index = context.op_index(); fusion_op_info_.original_op_names = original_op_names; | ||||
| fusion_op_info_.op_name = op_desc_->GetName()); | fusion_op_info_.op_name = op_desc_->GetName()); | ||||
| string session_graph_model_id; | |||||
| davinci_model_->GetUniqueId(op_desc_, session_graph_model_id); | |||||
| // get bin_file_key | |||||
| const char *bin_file_key = davinci_model_->GetRegisterStub(op_desc_->GetName(), session_graph_model_id); | |||||
| // new aicpu kernel(rtCpuKernelLaunch) no need to check function | // new aicpu kernel(rtCpuKernelLaunch) no need to check function | ||||
| if (kernel_type_ == ccKernelType::CCE_AI_CORE) { | if (kernel_type_ == ccKernelType::CCE_AI_CORE) { | ||||
| rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(kernel_def.stub_func().c_str()), &stub_func_); | |||||
| rtError_t rt_ret; | |||||
| rt_ret = rtGetFunctionByName(const_cast<char *>(kernel_def.stub_func().c_str()), &stub_func_); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. stub_func: %s", | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. stub_func: %s", | ||||
| kernel_def.stub_func().c_str()); | kernel_def.stub_func().c_str()); | ||||
| return RT_ERROR_TO_GE_STATUS(rt_ret);); | return RT_ERROR_TO_GE_STATUS(rt_ret);); | ||||
| } else if (kernel_type_ == ccKernelType::TE) { | } else if (kernel_type_ == ccKernelType::TE) { | ||||
| // get bin_file_key | |||||
| string session_graph_model_id; | |||||
| davinci_model_->GetUniqueId(op_desc_, session_graph_model_id); | |||||
| const char *bin_file_key = davinci_model_->GetRegisterStub(op_desc_->GetName(), session_graph_model_id); | |||||
| rtError_t rt_ret = rtGetFunctionByName(bin_file_key, &stub_func_); | |||||
| rtError_t rt_ret; | |||||
| rt_ret = rtGetFunctionByName(bin_file_key, &stub_func_); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, | ||||
| GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. bin_file_key: %s", bin_file_key); | GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. bin_file_key: %s", bin_file_key); | ||||
| return RT_ERROR_TO_GE_STATUS(rt_ret);); | return RT_ERROR_TO_GE_STATUS(rt_ret);); | ||||
| @@ -370,11 +372,7 @@ Status KernelTaskInfo::SuperKernelDistribute() { | |||||
| Status KernelTaskInfo::Distribute() { | Status KernelTaskInfo::Distribute() { | ||||
| GELOGD("KernelTaskInfo Distribute Start."); | GELOGD("KernelTaskInfo Distribute Start."); | ||||
| if (davinci_model_->IsKnownNode()) { | if (davinci_model_->IsKnownNode()) { | ||||
| if (kernel_type_ == ccKernelType::TE) { | |||||
| args_ = davinci_model_->GetCurrentArgsAddr(args_offset_); | |||||
| } else if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) { | |||||
| args_ = davinci_model_->GetCurrentHybridArgsAddr(hybrid_args_offset_); | |||||
| } | |||||
| args_ = davinci_model_->GetCurrentArgsAddr(args_offset_); | |||||
| GELOGI("Known node %s args addr %p, offset %u.", op_desc_->GetName().c_str(), args_, args_offset_); | GELOGI("Known node %s args addr %p, offset %u.", op_desc_->GetName().c_str(), args_, args_offset_); | ||||
| } | } | ||||
| rtError_t rt_ret = RT_ERROR_NONE; | rtError_t rt_ret = RT_ERROR_NONE; | ||||
| @@ -430,31 +428,36 @@ Status KernelTaskInfo::UpdateArgs() { | |||||
| const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); | const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); | ||||
| vector<void *> input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc_); | vector<void *> input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc_); | ||||
| vector<void *> output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc_); | vector<void *> output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc_); | ||||
| vector<void *> workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc_); | |||||
| vector<void *> io_addrs; | vector<void *> io_addrs; | ||||
| io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); | |||||
| io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); | |||||
| if (kernel_type_ == ccKernelType::TE) { | |||||
| vector<void *> workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc_); | |||||
| if (!op_desc_->HasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR)) { | |||||
| io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); | |||||
| io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); | |||||
| io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); | io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); | ||||
| davinci_model_->SetTotalIOAddrs(io_addrs); | |||||
| } else if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) { | |||||
| davinci_model_->UpdateKnownZeroCopyAddr(io_addrs); | |||||
| uintptr_t io_addr = reinterpret_cast<uintptr_t>(args_addr.get()) + sizeof(aicpu::AicpuParamHead); | |||||
| auto addrs_size = sizeof(uint64_t) * io_addrs.size(); | |||||
| errno_t sec_ret = memcpy_s(reinterpret_cast<void *>(io_addr), addrs_size, io_addrs.data(), addrs_size); | |||||
| if (sec_ret != EOK) { | |||||
| GELOGE(FAILED, "memcpy failed, ret: %d", sec_ret); | |||||
| return FAILED; | |||||
| } | |||||
| // copy args to device | |||||
| rtError_t rt_ret = rtMemcpy(args_, args_size_, args_addr.get(), args_size_, RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } else { | |||||
| string peer_input_name; | |||||
| if (AttrUtils::GetStr(op_desc_, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name)) { | |||||
| uint32_t output_index = davinci_model_->GetFixedAddrOutputIndex(peer_input_name); | |||||
| if (output_index > output_data_addrs.size()) { | |||||
| GELOGE(FAILED, "The output data addr size[%zu] and output index[%u] are inconsistent.", | |||||
| output_data_addrs.size(), output_index); | |||||
| return FAILED; | |||||
| } | |||||
| io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); | |||||
| for (size_t i = 0; i < output_data_addrs.size(); ++i) { | |||||
| if (i == output_index) { | |||||
| void *fixed_addr = davinci_model_->GetCurrentFixedAddr(fixed_addr_offset_); | |||||
| io_addrs.emplace_back(fixed_addr); | |||||
| continue; | |||||
| } | |||||
| io_addrs.emplace_back(output_data_addrs[i]); | |||||
| } | |||||
| io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); | |||||
| } | } | ||||
| } | } | ||||
| davinci_model_->SetTotalIOAddrs(io_addrs); | |||||
| GELOGI("KernelTaskInfo::UpdateArgs success."); | GELOGI("KernelTaskInfo::UpdateArgs success."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -530,18 +533,33 @@ Status KernelTaskInfo::UpdateL2Data(const domi::KernelDef &kernel_def) { | |||||
| } | } | ||||
| Status KernelTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | Status KernelTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | ||||
| const domi::KernelDef &kernel_def = task_def.kernel(); | |||||
| domi::KernelDef kernel_def = task_def.kernel(); | |||||
| uint32_t args_size = kernel_def.args_size(); | |||||
| args_offset_ = davinci_model->GetTotalArgsSize(); | |||||
| davinci_model->SetTotalArgsSize(args_size); | |||||
| GELOGI("kernel task name , args_size %u, args_offset %u", args_size, args_offset_); | |||||
| // get opcontext stored in model | |||||
| const domi::KernelContext &context = kernel_def.context(); | const domi::KernelContext &context = kernel_def.context(); | ||||
| kernel_type_ = static_cast<ccKernelType>(context.kernel_type()); | |||||
| if (kernel_type_ == ccKernelType::TE) { | |||||
| uint32_t args_size = kernel_def.args_size(); | |||||
| args_offset_ = davinci_model->GetTotalArgsSize(); | |||||
| davinci_model->SetTotalArgsSize(args_size); | |||||
| GELOGI("kernel task name , args_size %u, args_offset %u", args_size, args_offset_); | |||||
| } else if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) { | |||||
| hybrid_args_offset_ = davinci_model->GetHybridArgsSize(); | |||||
| davinci_model->SetHybridArgsSize(kernel_def.args_size()); | |||||
| GELOGI("aicpu kernel task name , args_size %u, args_offset %u", kernel_def.args_size(), hybrid_args_offset_); | |||||
| // get opdesc | |||||
| op_desc_ = davinci_model->GetOpByIndex(context.op_index()); | |||||
| GE_CHECK_NOTNULL(op_desc_); | |||||
| // alloc fixed addr | |||||
| string peer_input_name; | |||||
| if (AttrUtils::GetStr(op_desc_, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name) && !peer_input_name.empty()) { | |||||
| uint32_t output_index = davinci_model->GetFixedAddrOutputIndex(peer_input_name); | |||||
| if (output_index > op_desc_->GetOutputsSize()) { | |||||
| GELOGE(FAILED, "The output size[%zu] and output index[%u] are inconsistent.", op_desc_->GetOutputsSize(), | |||||
| output_index); | |||||
| return FAILED; | |||||
| } | |||||
| fixed_addr_offset_ = davinci_model->GetFixedAddrsSize(peer_input_name); | |||||
| auto tensor_desc = op_desc_->GetOutputDesc(output_index); | |||||
| int64_t tensor_size = 0; | |||||
| GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); | |||||
| davinci_model->SetTotalFixedAddrsSize(peer_input_name, tensor_size); | |||||
| GELOGI("Calculate stream switch task args , tensor size is %ld, fixed addr offset %ld", tensor_size, | |||||
| fixed_addr_offset_); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -870,7 +888,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
| } | } | ||||
| // copy args to new host memory | // copy args to new host memory | ||||
| args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]); | |||||
| std::unique_ptr<uint8_t[]> args_addr(new (std::nothrow) uint8_t[args_size_]); | |||||
| GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) | GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) | ||||
| errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); | errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); | ||||
| if (sec_ret != EOK) { | if (sec_ret != EOK) { | ||||
| @@ -878,23 +896,8 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| auto aicpu_param_head = reinterpret_cast<aicpu::AicpuParamHead *>(args_addr.get()); | |||||
| const auto &ext_info = kernel_def.kernel_ext_info(); | |||||
| auto init_ret = InitAicpuTaskExtInfo(ext_info); | |||||
| if (init_ret != SUCCESS) { | |||||
| GELOGE(init_ret, "Init aicpu task ext info failed, ext_info size=%zu", ext_info.size()); | |||||
| return init_ret; | |||||
| } | |||||
| GELOGI("Node[%s] type[%s] kernel_ext_info size=%zu, aicpu_ext_info_addr_=%p", op_desc_->GetName().c_str(), | |||||
| op_desc_->GetType().c_str(), ext_info.size(), aicpu_ext_info_addr_); | |||||
| aicpu_param_head->extInfoAddr = reinterpret_cast<uintptr_t>(aicpu_ext_info_addr_); | |||||
| aicpu_param_head->extInfoLength = static_cast<uintptr_t>(ext_info.size()); | |||||
| if (davinci_model_->IsKnownNode()) { | |||||
| return SUCCESS; | |||||
| } | |||||
| const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); | const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); | ||||
| vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); | vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); | ||||
| vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); | vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); | ||||
| vector<void *> io_addrs; | vector<void *> io_addrs; | ||||
| @@ -911,6 +914,19 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
| } | } | ||||
| } | } | ||||
| auto aicpu_param_head = reinterpret_cast<aicpu::AicpuParamHead *>(args_addr.get()); | |||||
| const auto &ext_info = kernel_def.kernel_ext_info(); | |||||
| auto init_ret = InitAicpuTaskExtInfo(ext_info); | |||||
| if (init_ret != SUCCESS) { | |||||
| GELOGE(init_ret, "Init aicpu task ext info failed, ext_info size=%zu", ext_info.size()); | |||||
| return init_ret; | |||||
| } | |||||
| GELOGI("Node[%s] type[%s] kernel_ext_info size=%zu, aicpu_ext_info_addr_=%p", op_desc_->GetName().c_str(), | |||||
| op_desc_->GetType().c_str(), ext_info.size(), aicpu_ext_info_addr_); | |||||
| aicpu_param_head->extInfoAddr = reinterpret_cast<uintptr_t>(aicpu_ext_info_addr_); | |||||
| aicpu_param_head->extInfoLength = static_cast<uintptr_t>(ext_info.size()); | |||||
| // malloc device memory for args | // malloc device memory for args | ||||
| rtError_t rt_ret = rtMalloc(static_cast<void **>(&args_), args_size_, RT_MEMORY_HBM); | rtError_t rt_ret = rtMalloc(static_cast<void **>(&args_), args_size_, RT_MEMORY_HBM); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| @@ -159,9 +159,7 @@ class KernelTaskInfo : public TaskInfo { | |||||
| OpDescPtr op_desc_; | OpDescPtr op_desc_; | ||||
| DavinciModel *davinci_model_; | DavinciModel *davinci_model_; | ||||
| uint32_t args_offset_ = 0; | uint32_t args_offset_ = 0; | ||||
| uint32_t hybrid_args_offset_ = 0; | |||||
| int64_t fixed_addr_offset_ = 0; | int64_t fixed_addr_offset_ = 0; | ||||
| std::unique_ptr<uint8_t[]> args_addr = nullptr; | |||||
| bool call_save_dump_ = false; | bool call_save_dump_ = false; | ||||
| // aicpu ext_info device mem | // aicpu ext_info device mem | ||||
| @@ -183,18 +183,22 @@ void ZeroCopyOffset::SetOutputOutsideAddrs(const int64_t &input_offset, const bo | |||||
| addr_count_ = out_count; | addr_count_ = out_count; | ||||
| } | } | ||||
| void ZeroCopyOffset::SetOutsideAddrsValue(ZeroCopyTask &zero_copy_task, void *outside_addr, void *args, size_t offset) { | |||||
| bool ZeroCopyOffset::SetOutsideAddrsValue(ZeroCopyTask &zero_copy_task, void *outside_addr, void *args, size_t offset) { | |||||
| const auto addr_val = reinterpret_cast<uintptr_t>(outside_addr); | const auto addr_val = reinterpret_cast<uintptr_t>(outside_addr); | ||||
| bool set_batch_label_flag = false; | |||||
| for (uint32_t out_count = 0; out_count < GetAddrCount(); ++out_count) { | for (uint32_t out_count = 0; out_count < GetAddrCount(); ++out_count) { | ||||
| auto args_addrs = outside_addrs_[out_count].find(outside_addr); | |||||
| if (args_addrs != outside_addrs_[out_count].end()) { | |||||
| auto &addrs_mapping_list = GetOutsideAddrs(); | |||||
| auto args_addrs = addrs_mapping_list[out_count].find(outside_addr); | |||||
| if (args_addrs != addrs_mapping_list[out_count].end()) { | |||||
| GE_CHK_STATUS(zero_copy_task.SetTaskArgsOffset(addr_val, offset), "Input args invalid."); | GE_CHK_STATUS(zero_copy_task.SetTaskArgsOffset(addr_val, offset), "Input args invalid."); | ||||
| void *args_val = static_cast<uint8_t *>(args) + offset; | void *args_val = static_cast<uint8_t *>(args) + offset; | ||||
| args_addrs->second.push_back(args_val); | args_addrs->second.push_back(args_val); | ||||
| GELOGD("[ZCPY] set copy input: virtual_addr: 0x%lx, task_addr: %p, args: %p, offset: %zu.", addr_val, args_val, | GELOGD("[ZCPY] set copy input: virtual_addr: 0x%lx, task_addr: %p, args: %p, offset: %zu.", addr_val, args_val, | ||||
| args, offset); | args, offset); | ||||
| set_batch_label_flag = true; | |||||
| } | } | ||||
| } | } | ||||
| return set_batch_label_flag; | |||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -51,7 +51,7 @@ class ZeroCopyOffset { | |||||
| const OpDescPtr &op_desc, const size_t &idx, bool &fusion_flag); | const OpDescPtr &op_desc, const size_t &idx, bool &fusion_flag); | ||||
| void SetOutputOutsideAddrs(const int64_t &input_offset, const bool &fusion_flag, void *addr, | void SetOutputOutsideAddrs(const int64_t &input_offset, const bool &fusion_flag, void *addr, | ||||
| std::vector<void *> &tensor_addrs); | std::vector<void *> &tensor_addrs); | ||||
| void SetOutsideAddrsValue(ZeroCopyTask &zero_copy_task, void *outside_addr, void *args, size_t offset); | |||||
| bool SetOutsideAddrsValue(ZeroCopyTask &zero_copy_task, void *outside_addr, void *args, size_t offset); | |||||
| // basic_addr of l2-fusion | // basic_addr of l2-fusion | ||||
| void *GetBasicAddr() const { return basic_addr_; } | void *GetBasicAddr() const { return basic_addr_; } | ||||
| @@ -22,6 +22,8 @@ | |||||
| #include "common/ge_compiler_options.h" | #include "common/ge_compiler_options.h" | ||||
| namespace ge { | namespace ge { | ||||
| const char *const kDefaultBatchLable = "Batch_default"; | |||||
| ZeroCopyTask::ZeroCopyTask(const string &name, uint8_t *args, size_t size) | ZeroCopyTask::ZeroCopyTask(const string &name, uint8_t *args, size_t size) | ||||
| : name_(name), args_addr_(args), args_size_(size), is_updated_(false) {} | : name_(name), args_addr_(args), args_size_(size), is_updated_(false) {} | ||||
| @@ -64,18 +66,59 @@ void ZeroCopyTask::SetOriginalArgs(const void *info, size_t size) { | |||||
| const uint8_t *data = static_cast<const uint8_t *>(info); | const uint8_t *data = static_cast<const uint8_t *>(info); | ||||
| args_info_.assign(data, data + size); | args_info_.assign(data, data + size); | ||||
| GELOGI("[ZCPY] %s set original args info: %p, args_addr: %p, args size: %zu, info size: %zu", name_.c_str(), info, | |||||
| GELOGI("[ZCPY] %s set info from virtual_addr: %p, args_addr: %p, args size: %zu, info size: %zu", name_.c_str(), info, | |||||
| args_addr_, args_size_, size); | args_addr_, args_size_, size); | ||||
| } | } | ||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief Check is dynamic batch node. | |||||
| * @param [in] addr: virtual address value from Op. | |||||
| * @param [in] data: data buffer from user. | |||||
| * @param [in] batch_addrs: dynamic batch addr info. | |||||
| * @param [in] batch_label: batch label. | |||||
| * @return: true / false | |||||
| */ | |||||
| bool ZeroCopyTask::CheckDynamicBatch(const map<string, set<uintptr_t>> &batch_addrs, const string &batch_label, | |||||
| uintptr_t addr) { | |||||
| // Used for dynamic batch / resolution scene | |||||
| set<uintptr_t> dynamic_input_addrs; | |||||
| auto dynamic_input_iter = batch_addrs.find(batch_label); | |||||
| if (dynamic_input_iter != batch_addrs.end()) { | |||||
| dynamic_input_addrs = dynamic_input_iter->second; | |||||
| } | |||||
| set<uintptr_t> fix_input_addrs; | |||||
| auto fix_input_iter = batch_addrs.find(kDefaultBatchLable); | |||||
| if (fix_input_iter != batch_addrs.end()) { | |||||
| fix_input_addrs = fix_input_iter->second; | |||||
| } | |||||
| if (fix_input_addrs.empty()) { | |||||
| if (!dynamic_input_addrs.empty() && dynamic_input_addrs.find(addr) == dynamic_input_addrs.end()) { | |||||
| return false; | |||||
| } | |||||
| } else { | |||||
| if (!dynamic_input_addrs.empty() && dynamic_input_addrs.find(addr) == dynamic_input_addrs.end() && | |||||
| fix_input_addrs.find(addr) == fix_input_addrs.end()) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| /** | /** | ||||
| * @ingroup ge | * @ingroup ge | ||||
| * @brief Set user data addr to Task param. | * @brief Set user data addr to Task param. | ||||
| * @param [in] addr: virtual address value from Op. | * @param [in] addr: virtual address value from Op. | ||||
| * @param [in] buffer_addr: real_data_buffer_addr from user. | * @param [in] buffer_addr: real_data_buffer_addr from user. | ||||
| * @param [in] batch_addrs: dynamic batch addr info. | |||||
| * @param [in] batch_label: batch label. | |||||
| * @return: void | * @return: void | ||||
| */ | */ | ||||
| Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, void *buffer_addr) { | |||||
| Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, void *buffer_addr, const map<string, set<uintptr_t>> &batch_addrs, | |||||
| const string &batch_label) { | |||||
| auto iter = task_addr_offset_.find(addr); | auto iter = task_addr_offset_.find(addr); | ||||
| if (iter != task_addr_offset_.end()) { | if (iter != task_addr_offset_.end()) { | ||||
| auto &cur_pair = *iter; | auto &cur_pair = *iter; | ||||
| @@ -67,9 +67,12 @@ class ZeroCopyTask { | |||||
| * @brief Set user data addr to Task param. | * @brief Set user data addr to Task param. | ||||
| * @param [in] addr: virtual address value from Op. | * @param [in] addr: virtual address value from Op. | ||||
| * @param [in] buffer_addr: data buffer_addr from user. | * @param [in] buffer_addr: data buffer_addr from user. | ||||
| * @param [in] batch_addrs: dynamic batch addr info. | |||||
| * @param [in] batch_label: batch label. | |||||
| * @return: 0 SUCCESS / others FAILED | * @return: 0 SUCCESS / others FAILED | ||||
| */ | */ | ||||
| ge::Status UpdateTaskParam(uintptr_t addr, void *buffer_addr); | |||||
| ge::Status UpdateTaskParam(uintptr_t addr, void *buffer_addr, const map<string, set<uintptr_t>> &batch_addrs, | |||||
| const string &batch_label); | |||||
| /** | /** | ||||
| * @ingroup ge | * @ingroup ge | ||||
| @@ -88,6 +91,9 @@ class ZeroCopyTask { | |||||
| return batch_label_; | return batch_label_; | ||||
| } | } | ||||
| protected: | |||||
| bool CheckDynamicBatch(const map<string, set<uintptr_t>> &batch_addrs, const string &batch_label, uintptr_t addr); | |||||
| private: | private: | ||||
| const string name_; | const string name_; | ||||
| @@ -23,15 +23,25 @@ | |||||
| #include <sstream> | #include <sstream> | ||||
| #include <string> | #include <string> | ||||
| #include <thread> | #include <thread> | ||||
| #include <utility> | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
| #include "common/thread_pool.h" | #include "common/thread_pool.h" | ||||
| #include "common/util.h" | |||||
| #include "external/graph/types.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| #include "framework/common/ge_types.h" | |||||
| #include "analyzer/analyzer.h" | #include "analyzer/analyzer.h" | ||||
| #include "graph/common/ge_call_wrapper.h" | #include "graph/common/ge_call_wrapper.h" | ||||
| #include "graph/common/local_context.h" | #include "graph/common/local_context.h" | ||||
| #include "graph/common/transop_util.h" | #include "graph/common/transop_util.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "graph/ge_global_options.h" | #include "graph/ge_global_options.h" | ||||
| #include "graph/ge_local_context.h" | |||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #include "graph/manager/util/rt_context_util.h" | #include "graph/manager/util/rt_context_util.h" | ||||
| #include "graph/partition/dynamic_shape_partition.h" | #include "graph/partition/dynamic_shape_partition.h" | ||||
| #include "graph/passes/enter_pass.h" | #include "graph/passes/enter_pass.h" | ||||
| @@ -51,6 +61,8 @@ | |||||
| #include "graph/passes/dimension_adjust_pass.h" | #include "graph/passes/dimension_adjust_pass.h" | ||||
| #include "graph/passes/dimension_compute_pass.h" | #include "graph/passes/dimension_compute_pass.h" | ||||
| #include "graph/passes/flow_ctrl_pass.h" | #include "graph/passes/flow_ctrl_pass.h" | ||||
| #include "graph/passes/hccl_group_pass.h" | |||||
| #include "graph/passes/hccl_memcpy_pass.h" | |||||
| #include "graph/passes/identity_pass.h" | #include "graph/passes/identity_pass.h" | ||||
| #include "graph/passes/input_output_connection_identify_pass.h" | #include "graph/passes/input_output_connection_identify_pass.h" | ||||
| #include "graph/passes/iterator_op_pass.h" | #include "graph/passes/iterator_op_pass.h" | ||||
| @@ -65,7 +77,7 @@ | |||||
| #include "graph/passes/permute_pass.h" | #include "graph/passes/permute_pass.h" | ||||
| #include "graph/passes/prune_pass.h" | #include "graph/passes/prune_pass.h" | ||||
| #include "graph/passes/ref_identity_delete_op_pass.h" | #include "graph/passes/ref_identity_delete_op_pass.h" | ||||
| #include "graph/passes/remove_same_const_pass.h" | |||||
| #include "graph/passes/replace_with_empty_const_pass.h" | |||||
| #include "graph/passes/reshape_recovery_pass.h" | #include "graph/passes/reshape_recovery_pass.h" | ||||
| #include "graph/passes/reshape_remove_pass.h" | #include "graph/passes/reshape_remove_pass.h" | ||||
| #include "graph/passes/same_transdata_breadth_fusion_pass.h" | #include "graph/passes/same_transdata_breadth_fusion_pass.h" | ||||
| @@ -75,12 +87,13 @@ | |||||
| #include "graph/passes/switch_logic_remove_pass.h" | #include "graph/passes/switch_logic_remove_pass.h" | ||||
| #include "graph/passes/switch_to_stream_switch_pass.h" | #include "graph/passes/switch_to_stream_switch_pass.h" | ||||
| #include "graph/passes/transop_breadth_fusion_pass.h" | #include "graph/passes/transop_breadth_fusion_pass.h" | ||||
| #include "graph/passes/transop_depth_fusion_pass.h" | |||||
| #include "graph/passes/transop_nearby_allreduce_fusion_pass.h" | #include "graph/passes/transop_nearby_allreduce_fusion_pass.h" | ||||
| #include "graph/passes/transop_symmetry_elimination_pass.h" | #include "graph/passes/transop_symmetry_elimination_pass.h" | ||||
| #include "graph/passes/transop_without_reshape_fusion_pass.h" | #include "graph/passes/transop_without_reshape_fusion_pass.h" | ||||
| #include "graph/passes/transpose_transdata_pass.h" | #include "graph/passes/transpose_transdata_pass.h" | ||||
| #include "graph/passes/useless_control_out_remove_pass.h" | |||||
| #include "graph/passes/variable_op_pass.h" | #include "graph/passes/variable_op_pass.h" | ||||
| #include "graph/passes/variable_prepare_op_pass.h" | |||||
| #include "graph/passes/variable_ref_delete_op_pass.h" | #include "graph/passes/variable_ref_delete_op_pass.h" | ||||
| #include "graph/passes/variable_ref_useless_control_out_delete_pass.h" | #include "graph/passes/variable_ref_useless_control_out_delete_pass.h" | ||||
| #include "graph/passes/end_of_sequence_add_control_pass.h" | #include "graph/passes/end_of_sequence_add_control_pass.h" | ||||
| @@ -91,6 +104,9 @@ | |||||
| #include "graph/passes/memcpy_addr_async_pass.h" | #include "graph/passes/memcpy_addr_async_pass.h" | ||||
| #include "graph/build/label_allocator.h" | #include "graph/build/label_allocator.h" | ||||
| #include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
| #include "graph/utils/type_utils.h" | |||||
| #include "graph/graph_util.h" | |||||
| #include "graph/types.h" | |||||
| #include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
| #include "init/gelib.h" | #include "init/gelib.h" | ||||
| #include "ir_build/atc_ir_common.h" | #include "ir_build/atc_ir_common.h" | ||||
| @@ -534,8 +550,7 @@ Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_gr | |||||
| (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | ||||
| } | } | ||||
| std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, | std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, | ||||
| compute_graph->GetGraphID(), subgraph, | |||||
| compute_graph->GetName(), session_id, | |||||
| compute_graph->GetGraphID(), subgraph, compute_graph, session_id, | |||||
| GetThreadLocalContext()); | GetThreadLocalContext()); | ||||
| if (!f.valid()) { | if (!f.valid()) { | ||||
| GELOGE(FAILED, "Future is invalid"); | GELOGE(FAILED, "Future is invalid"); | ||||
| @@ -550,8 +565,7 @@ Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_gr | |||||
| (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | ||||
| } | } | ||||
| std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, | std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, | ||||
| compute_graph->GetGraphID(), subgraph, | |||||
| compute_graph->GetName(), session_id, | |||||
| compute_graph->GetGraphID(), subgraph, compute_graph, session_id, | |||||
| GetThreadLocalContext()); | GetThreadLocalContext()); | ||||
| if (!f.valid()) { | if (!f.valid()) { | ||||
| GELOGE(FAILED, "Future is invalid"); | GELOGE(FAILED, "Future is invalid"); | ||||
| @@ -2134,7 +2148,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
| TransposeTransDataPass transpose_transdata_pass; | TransposeTransDataPass transpose_transdata_pass; | ||||
| TransOpSymmetryEliminationPass symmetry_elimination_pass; | TransOpSymmetryEliminationPass symmetry_elimination_pass; | ||||
| DimensionComputePass dimension_compute_pass; | DimensionComputePass dimension_compute_pass; | ||||
| UselessControlOutRemovePass useless_control_out_remove_pass; | |||||
| names_to_passes.emplace_back("EnterPass", &enter_pass); | names_to_passes.emplace_back("EnterPass", &enter_pass); | ||||
| names_to_passes.emplace_back("AddNPass", &addn_pass); | names_to_passes.emplace_back("AddNPass", &addn_pass); | ||||
| names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); | names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); | ||||
| @@ -2148,7 +2161,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
| names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); | names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); | ||||
| names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | ||||
| names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); | names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); | ||||
| names_to_passes.emplace_back("UselessControlOutRemovePass", &useless_control_out_remove_pass); | |||||
| GE_TIMESTAMP_START(names_to_passes); | GE_TIMESTAMP_START(names_to_passes); | ||||
| ret = GEPass(compute_graph).Run(names_to_passes); | ret = GEPass(compute_graph).Run(names_to_passes); | ||||
| GE_TIMESTAMP_END(names_to_passes, "GraphManager::OptimizeStage1_2"); | GE_TIMESTAMP_END(names_to_passes, "GraphManager::OptimizeStage1_2"); | ||||
| @@ -2189,8 +2201,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass", | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass", | ||||
| new (std::nothrow) VariableRefUselessControlOutDeletePass)) | new (std::nothrow) VariableRefUselessControlOutDeletePass)) | ||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass)) | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass)) | ||||
| GE_CHK_STATUS_RET( | |||||
| graph_pass.AddPass("OptimizeStage1_3::RemoveSameConstPass", new (std::nothrow) RemoveSameConstPass)) | |||||
| if (options_.train_graph_flag) { | if (options_.train_graph_flag) { | ||||
| // Priority: The GlobalStepInsertPass should work before graph partitioner. | // Priority: The GlobalStepInsertPass should work before graph partitioner. | ||||
| // Reason: Make sure that the var "global_step" can be partitioned to known sub graph and allocated memory | // Reason: Make sure that the var "global_step" can be partitioned to known sub graph and allocated memory | ||||
| @@ -2461,8 +2471,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra | |||||
| Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, | Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, | ||||
| const SubGraphInfoPtr &sub_graph_info_ptr, | const SubGraphInfoPtr &sub_graph_info_ptr, | ||||
| const std::string &root_graph_name, | |||||
| uint64_t session_id, | |||||
| const ComputeGraphPtr &compute_graph, uint64_t session_id, | |||||
| const GEThreadLocalContext &ge_context) { | const GEThreadLocalContext &ge_context) { | ||||
| if (sub_graph_info_ptr != nullptr && graph_manager != nullptr) { | if (sub_graph_info_ptr != nullptr && graph_manager != nullptr) { | ||||
| GetContext().SetSessionId(session_id); | GetContext().SetSessionId(session_id); | ||||
| @@ -2479,13 +2488,9 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager | |||||
| GELOGE(FAILED, "Failed to set attr ATTR_NAME_ROOT_GRAPH_ID for subgraph, graph_id: %u.", root_graph_id); | GELOGE(FAILED, "Failed to set attr ATTR_NAME_ROOT_GRAPH_ID for subgraph, graph_id: %u.", root_graph_id); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (!AttrUtils::SetStr(*compute_graph_tmp, ATTR_NAME_ROOT_GRAPH_NAME, root_graph_name)) { | |||||
| GELOGE(FAILED, "Failed to set attr ATTR_NAME_ROOT_GRAPH_NAME for subgraph, \ | |||||
| root_graph_name: %s.", root_graph_name.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| compute_graph_tmp->SetSessionID(session_id); | compute_graph_tmp->SetSessionID(session_id); | ||||
| Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp, | Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp, | ||||
| compute_graph, | |||||
| engine_name); | engine_name); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "SubGraph optimize Failed %s", engine_name.c_str()); | GELOGE(ret, "SubGraph optimize Failed %s", engine_name.c_str()); | ||||
| @@ -219,8 +219,7 @@ class GraphManager { | |||||
| static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, | static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, | ||||
| const SubGraphInfoPtr &sub_graph_info_ptr, | const SubGraphInfoPtr &sub_graph_info_ptr, | ||||
| const std::string &root_graph_name, | |||||
| uint64_t session_id, | |||||
| const ComputeGraphPtr &compute_graph, uint64_t session_id, | |||||
| const GEThreadLocalContext &ge_context); | const GEThreadLocalContext &ge_context); | ||||
| Status ParseInputsDims(const std::vector<InputTensorInfo> &input_tensor); | Status ParseInputsDims(const std::vector<InputTensorInfo> &input_tensor); | ||||
| void ParseInputsDimsForData(const std::vector<InputTensorInfo> &input_tensor); | void ParseInputsDimsForData(const std::vector<InputTensorInfo> &input_tensor); | ||||
| @@ -16,7 +16,10 @@ | |||||
| #include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
| #include <set> | |||||
| #include <string> | #include <string> | ||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/manager/graph_caching_allocator.h" | #include "graph/manager/graph_caching_allocator.h" | ||||
| #include "graph/manager/rdma_pool_allocator.h" | #include "graph/manager/rdma_pool_allocator.h" | ||||
| @@ -76,7 +76,8 @@ void AddNodeInputProperty(ComputeGraphPtr &compute_graph) { | |||||
| } | } | ||||
| } | } | ||||
| Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std::string &engine_name) { | |||||
| Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const ComputeGraphPtr &parent_graph, | |||||
| const std::string &engine_name) { | |||||
| if (compute_graph == nullptr) { | if (compute_graph == nullptr) { | ||||
| GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeSubGraph]: compute_graph is nullptr."); | GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeSubGraph]: compute_graph is nullptr."); | ||||
| return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; | return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; | ||||
| @@ -105,6 +106,10 @@ Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std | |||||
| for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) { | for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) { | ||||
| Status ret = (*iter)->OptimizeFusedGraphAfterGraphSlice(*(compute_graph)); | Status ret = (*iter)->OptimizeFusedGraphAfterGraphSlice(*(compute_graph)); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| auto root_graph = ge::GraphUtils::FindRootGraph(parent_graph); | |||||
| if (root_graph != nullptr) { | |||||
| ErrorManager::GetInstance().SaveMstuneCompileFailedMsg(root_graph->GetName()); | |||||
| } | |||||
| GELOGE(ret, "[OptimizeSubGraph][OptimizeFusedGraphAfterGraphSlice]: graph optimize failed, ret:%d", ret); | GELOGE(ret, "[OptimizeSubGraph][OptimizeFusedGraphAfterGraphSlice]: graph optimize failed, ret:%d", ret); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -42,7 +42,8 @@ class GraphOptimize { | |||||
| ~GraphOptimize() = default; | ~GraphOptimize() = default; | ||||
| // subgraph optimize | // subgraph optimize | ||||
| Status OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std::string &engine_name); | |||||
| Status OptimizeSubGraph(ComputeGraphPtr &compute_graph, const ComputeGraphPtr &parent_graph, | |||||
| const std::string &engine_name); | |||||
| // original graph optimize | // original graph optimize | ||||
| Status OptimizeOriginalGraph(ComputeGraphPtr &compute_graph); | Status OptimizeOriginalGraph(ComputeGraphPtr &compute_graph); | ||||
| @@ -18,8 +18,6 @@ | |||||
| #include "ge/ge_api_types.h" | #include "ge/ge_api_types.h" | ||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| using std::string; | |||||
| namespace ge { | namespace ge { | ||||
| Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { | Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { | ||||
| GELOGD("AttachStreamLabelPass Enter."); | GELOGD("AttachStreamLabelPass Enter."); | ||||
| @@ -189,10 +187,21 @@ Status AttachStreamLabelPass::UpdateEnterNode() { | |||||
| } | } | ||||
| std::stack<NodePtr> enter_nodes; | std::stack<NodePtr> enter_nodes; | ||||
| std::string batch_label; | |||||
| for (const auto &enter_node : pair.second) { | for (const auto &enter_node : pair.second) { | ||||
| enter_nodes.emplace(enter_node); | enter_nodes.emplace(enter_node); | ||||
| std::string tmp_label; | |||||
| (void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); | |||||
| if (!tmp_label.empty()) { | |||||
| if (batch_label.empty()) { | |||||
| batch_label = tmp_label; | |||||
| } else if (batch_label != tmp_label) { | |||||
| GELOGE(FAILED, "multi batch_label exist, label1=%s, label2=%s.", batch_label.c_str(), tmp_label.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) { | |||||
| if (UpdateLoopBranch(enter_nodes, active_label_list[0], batch_label) != SUCCESS) { | |||||
| GELOGE(FAILED, "Update stream_label for loop_branch failed."); | GELOGE(FAILED, "Update stream_label for loop_branch failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -217,7 +226,10 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||||
| } | } | ||||
| for (const auto &enter_node : enter_nodes) { | for (const auto &enter_node : enter_nodes) { | ||||
| GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); | |||||
| GE_CHECK_NOTNULL(enter_node->GetOpDesc()); | |||||
| if (enter_node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL)) { | |||||
| GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); | |||||
| } | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -229,7 +241,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||||
| /// @param [in] batch_label | /// @param [in] batch_label | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const string &stream_label) { | |||||
| Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label, | |||||
| const std::string &batch_label) { | |||||
| std::stack<NodePtr> nodes(enter_nodes); | std::stack<NodePtr> nodes(enter_nodes); | ||||
| NodePtr cur_node = nullptr; | NodePtr cur_node = nullptr; | ||||
| while (!nodes.empty()) { | while (!nodes.empty()) { | ||||
| @@ -238,6 +251,11 @@ Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_ | |||||
| for (const NodePtr &out_node : cur_node->GetOutAllNodes()) { | for (const NodePtr &out_node : cur_node->GetOutAllNodes()) { | ||||
| OpDescPtr out_desc = out_node->GetOpDesc(); | OpDescPtr out_desc = out_node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(out_desc); | GE_CHECK_NOTNULL(out_desc); | ||||
| std::string tmp_label; | |||||
| (void)AttrUtils::GetStr(out_desc, ATTR_NAME_BATCH_LABEL, tmp_label); | |||||
| if (!tmp_label.empty() && (tmp_label != batch_label)) { | |||||
| continue; | |||||
| } | |||||
| std::string out_type = out_desc->GetType(); | std::string out_type = out_desc->GetType(); | ||||
| bool need_skip = | bool need_skip = | ||||
| out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER) || | out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER) || | ||||
| @@ -58,9 +58,11 @@ class AttachStreamLabelPass : public GraphPass { | |||||
| /// @brief Update stream_label for loop_branch | /// @brief Update stream_label for loop_branch | ||||
| /// @param [in] enter_nodes | /// @param [in] enter_nodes | ||||
| /// @param [in] stream_label | /// @param [in] stream_label | ||||
| /// @param [in] batch_label | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label); | |||||
| static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label, | |||||
| const std::string &batch_label); | |||||
| /// | /// | ||||
| /// @brief Update stream_label start with enter nodes | /// @brief Update stream_label start with enter nodes | ||||
| @@ -96,7 +96,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder | |||||
| node->GetName().c_str(), node->GetType().c_str()); | node->GetName().c_str(), node->GetType().c_str()); | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { | |||||
| if (node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { | |||||
| GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str()); | GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str()); | ||||
| nodes_re_pass.insert(node_to_re_pass); | nodes_re_pass.insert(node_to_re_pass); | ||||
| } else { | } else { | ||||
| @@ -80,71 +80,7 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) { | |||||
| } | } | ||||
| } | } | ||||
| ret = DealWithInNodes(node); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "DealWithInNodes of %s failed.", node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| std::vector<int> data_relink_io_map = {kDataInputIndex}; | std::vector<int> data_relink_io_map = {kDataInputIndex}; | ||||
| return IsolateAndDeleteNode(node, data_relink_io_map); | return IsolateAndDeleteNode(node, data_relink_io_map); | ||||
| } | } | ||||
| Status DimensionAdjustPass::DealWithInNodes(NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| auto in_data_anchors = node->GetAllInDataAnchors(); | |||||
| for (auto &in_data_anchor : in_data_anchors) { | |||||
| if (in_data_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto in_node_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| if (in_node_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto in_node = in_node_anchor->GetOwnerNode(); | |||||
| if (in_node->GetType() == SWITCHN) { | |||||
| auto identity_name = node->GetName() + "_ctrl_identity_" + std::to_string(in_data_anchor->GetIdx()); | |||||
| auto identity = | |||||
| AddIdentityNodeToGraph(identity_name, node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()), graph); | |||||
| GE_CHECK_NOTNULL(identity); | |||||
| GELOGI("Create new identity node[%s] after node %s[type: %s] success.", identity->GetName().c_str(), | |||||
| in_node->GetName().c_str(), in_node->GetType().c_str()); | |||||
| GE_CHK_STATUS_RET(GraphUtils::AddEdge(in_node_anchor, identity->GetInDataAnchor(0))) | |||||
| GE_CHECK_NOTNULL(identity->GetOutControlAnchor()); | |||||
| if (identity->GetOutControlAnchor()->IsLinkedWith(node->GetInControlAnchor())) { | |||||
| continue; | |||||
| } | |||||
| GE_CHK_STATUS_RET(GraphUtils::AddEdge(identity->GetOutControlAnchor(), node->GetInControlAnchor())) | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| NodePtr DimensionAdjustPass::AddIdentityNodeToGraph(const string &name, const GeTensorDesc &tensor, | |||||
| ComputeGraphPtr &graph) { | |||||
| if (graph == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Comput graph ptr is null in creating identity node."); | |||||
| return nullptr; | |||||
| } | |||||
| OpDescPtr desc = MakeShared<OpDesc>("", ""); | |||||
| if (desc == nullptr) { | |||||
| GELOGE(MEMALLOC_FAILED, "Failed to create op desc."); | |||||
| return nullptr; | |||||
| } | |||||
| desc->SetName(name); | |||||
| desc->SetType(IDENTITY); | |||||
| auto ret = desc->AddInputDesc(tensor); | |||||
| auto ret2 = desc->AddOutputDesc(tensor); | |||||
| if ((ret != GRAPH_SUCCESS) || (ret2 != GRAPH_SUCCESS)) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to add input/output desc in creating identity."); | |||||
| return nullptr; | |||||
| } | |||||
| return graph->AddNodeFront(desc); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -34,10 +34,6 @@ namespace ge { | |||||
| class DimensionAdjustPass : public BaseNodePass { | class DimensionAdjustPass : public BaseNodePass { | ||||
| public: | public: | ||||
| Status Run(ge::NodePtr &node) override; | Status Run(ge::NodePtr &node) override; | ||||
| private: | |||||
| Status DealWithInNodes(ge::NodePtr &node); | |||||
| NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tensor, ComputeGraphPtr &graph); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -23,7 +23,6 @@ | |||||
| namespace { | namespace { | ||||
| const size_t kOutNodesNum = 1; | const size_t kOutNodesNum = 1; | ||||
| const size_t kInCtrlNodesNum = 1; | |||||
| } | } | ||||
| namespace ge { | namespace ge { | ||||
| @@ -56,7 +55,6 @@ Status EnterPass::Run(NodePtr &node) { | |||||
| if (out_ctrl_node == nullptr) { | if (out_ctrl_node == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| GELOGI("Remove control edge from %s to %s.", node->GetName().c_str(), out_ctrl_node->GetName().c_str()); | |||||
| if (GraphUtils::RemoveEdge(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()) != GRAPH_SUCCESS) { | if (GraphUtils::RemoveEdge(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()) != GRAPH_SUCCESS) { | ||||
| GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", node->GetName().c_str(), | GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", node->GetName().c_str(), | ||||
| out_ctrl_node->GetName().c_str()); | out_ctrl_node->GetName().c_str()); | ||||
| @@ -64,12 +62,8 @@ Status EnterPass::Run(NodePtr &node) { | |||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| if (OptimizeEnterWithOnlyDataOut(node, in_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "Optimize enter node[%s] with only out data node failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (UnlinkCtrlEdgeBeforeConst(node) != SUCCESS) { | |||||
| GELOGE(FAILED, "Unlink control edge before const of node[%s]'s out nodes failed.", node->GetName().c_str()); | |||||
| if (OptimizeEnter(node, in_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "Optimize enter node[%s] failed.", node->GetName().c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } | } | ||||
| @@ -78,7 +72,7 @@ Status EnterPass::Run(NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status EnterPass::OptimizeEnterWithOnlyDataOut(NodePtr &node, NodePtr &in_node) { | |||||
| Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { | |||||
| if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) { | if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -89,61 +83,17 @@ Status EnterPass::OptimizeEnterWithOnlyDataOut(NodePtr &node, NodePtr &in_node) | |||||
| } | } | ||||
| GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); | GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); | ||||
| GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))) | |||||
| GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); | |||||
| const auto &out_data_anchor = node->GetOutDataAnchor(0); | const auto &out_data_anchor = node->GetOutDataAnchor(0); | ||||
| GE_CHECK_NOTNULL(out_data_anchor); | GE_CHECK_NOTNULL(out_data_anchor); | ||||
| for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | ||||
| GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)) | |||||
| GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)) | |||||
| GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); | |||||
| GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); | |||||
| } | } | ||||
| GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)) | |||||
| GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)); | |||||
| AddNodeDeleted(node); | AddNodeDeleted(node); | ||||
| AddRePassNodesWithInOut(in_node); | AddRePassNodesWithInOut(in_node); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status EnterPass::UnlinkCtrlEdgeBeforeConst(NodePtr &node) { | |||||
| auto out_ctrl_nodes = node->GetOutControlNodes(); | |||||
| if (out_ctrl_nodes.empty()) { | |||||
| return SUCCESS; | |||||
| } | |||||
| auto out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
| GE_CHECK_NOTNULL(out_ctrl_anchor); | |||||
| for (auto &out_ctrl_node : out_ctrl_nodes) { | |||||
| GE_CHECK_NOTNULL(out_ctrl_node); | |||||
| if ((out_ctrl_node->GetType() != CONSTANT) && (out_ctrl_node->GetType() != CONSTANTOP)) { | |||||
| continue; | |||||
| } | |||||
| auto in_ctrl_nodes = out_ctrl_node->GetInControlNodes(); | |||||
| if (in_ctrl_nodes.size() != kInCtrlNodesNum) { | |||||
| continue; | |||||
| } | |||||
| // Skip when has merge out | |||||
| bool has_merge_out = false; | |||||
| auto out_nodes_of_const = out_ctrl_node->GetOutAllNodes(); | |||||
| for (const auto &out_node_of_const : out_nodes_of_const) { | |||||
| GE_CHECK_NOTNULL(out_node_of_const); | |||||
| if (out_node_of_const->GetType() == MERGE || out_node_of_const->GetType() == REFMERGE) { | |||||
| has_merge_out = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (has_merge_out) { | |||||
| continue; | |||||
| } | |||||
| GELOGI("Unlink control edge from %s to %s.", node->GetName().c_str(), out_ctrl_node->GetName().c_str()); | |||||
| GE_CHK_STATUS_RET(out_ctrl_anchor->Unlink(out_ctrl_node->GetInControlAnchor())) | |||||
| for (auto &out_node_of_const : out_nodes_of_const) { | |||||
| if (!out_ctrl_anchor->IsLinkedWith(out_node_of_const->GetInControlAnchor())) { | |||||
| GELOGI("Link control edge from %s to %s.", node->GetName().c_str(), out_node_of_const->GetName().c_str()); | |||||
| GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(out_node_of_const->GetInControlAnchor())) | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -25,8 +25,7 @@ class EnterPass : public BaseNodePass { | |||||
| Status Run(NodePtr &node) override; | Status Run(NodePtr &node) override; | ||||
| private: | private: | ||||
| Status OptimizeEnterWithOnlyDataOut(NodePtr &node, NodePtr &in_node); | |||||
| Status UnlinkCtrlEdgeBeforeConst(NodePtr &node); | |||||
| Status OptimizeEnter(NodePtr &node, NodePtr &in_node); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_PASSES_ENTER_PASS_H_ | #endif // GE_GRAPH_PASSES_ENTER_PASS_H_ | ||||
| @@ -173,7 +173,10 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto in_node = in_node_anchor->GetOwnerNode(); | auto in_node = in_node_anchor->GetOwnerNode(); | ||||
| if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH) || (in_node->GetType() == SWITCHN)) { | |||||
| if (in_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) { | |||||
| GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); | GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); | ||||
| auto ret = in_node_anchor->Unlink(in_data_anchor); | auto ret = in_node_anchor->Unlink(in_data_anchor); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -89,6 +89,16 @@ Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, co | |||||
| GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed"); | GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed"); | ||||
| } | } | ||||
| if (merge_op_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) { | |||||
| string batch_label; | |||||
| (void)AttrUtils::GetStr(merge_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| if (!batch_label.empty()) { | |||||
| auto stream_merge_desc = stream_merge->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(stream_merge_desc); | |||||
| (void)AttrUtils::SetStr(stream_merge_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| } | |||||
| } | |||||
| return AddActiveNodes(graph, stream_merge); | return AddActiveNodes(graph, stream_merge); | ||||
| } | } | ||||
| @@ -19,8 +19,6 @@ | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| using std::string; | |||||
| namespace ge { | namespace ge { | ||||
| Status NextIterationPass::Run(ComputeGraphPtr graph) { | Status NextIterationPass::Run(ComputeGraphPtr graph) { | ||||
| GELOGD("NextIterationPass Enter"); | GELOGD("NextIterationPass Enter"); | ||||
| @@ -37,6 +35,10 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) { | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| if (GroupWithNoBatch(graph) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Group enter_nodes failed without batch_label attr."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (FindWhileGroups() != SUCCESS) { | if (FindWhileGroups() != SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, "Find while groups failed."); | GELOGE(INTERNAL_ERROR, "Find while groups failed."); | ||||
| @@ -71,22 +73,75 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| string batch_label; | |||||
| if (ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
| frame_name += batch_label; | |||||
| std::string batch_label; | |||||
| (void)ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||||
| if (batch_label.empty()) { | |||||
| auto frame_iter = frame_enter_map_.find(frame_name); | |||||
| if (frame_iter == frame_enter_map_.end()) { | |||||
| std::vector<NodePtr> enter_nodes; | |||||
| enter_nodes.emplace_back(enter_node); | |||||
| frame_enter_map_[frame_name] = enter_nodes; | |||||
| } else { | |||||
| frame_iter->second.emplace_back(enter_node); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | } | ||||
| auto iter = loop_group_map_.find(frame_name); | |||||
| if (iter == loop_group_map_.end()) { | |||||
| auto group_iter = loop_group_map_.find(frame_name); | |||||
| if (group_iter == loop_group_map_.end()) { | |||||
| LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>(); | LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>(); | ||||
| if (loop_group == nullptr) { | if (loop_group == nullptr) { | ||||
| GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); | GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| loop_group->enter_nodes.emplace_back(enter_node); | loop_group->enter_nodes.emplace_back(enter_node); | ||||
| loop_group_map_[frame_name] = loop_group; | |||||
| loop_group_map_[frame_name][batch_label] = loop_group; | |||||
| } else { | } else { | ||||
| iter->second->enter_nodes.emplace_back(enter_node); | |||||
| auto batch_iter = group_iter->second.find(batch_label); | |||||
| if (batch_iter == group_iter->second.end()) { | |||||
| LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>(); | |||||
| if (loop_group == nullptr) { | |||||
| GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); | |||||
| return FAILED; | |||||
| } | |||||
| loop_group->enter_nodes.emplace_back(enter_node); | |||||
| group_iter->second[batch_label] = loop_group; | |||||
| } else { | |||||
| batch_iter->second->enter_nodes.emplace_back(enter_node); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| /// | |||||
| /// @brief Group Enter nodes without batch_label attr | |||||
| /// @param [in] compute_graph | |||||
| /// @return Status | |||||
| /// | |||||
| Status NextIterationPass::GroupWithNoBatch(const ComputeGraphPtr &graph) { | |||||
| if (frame_enter_map_.empty()) { | |||||
| GELOGI("All enter nodes in graph %s has batch_label attr.", graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| for (const auto &item : frame_enter_map_) { | |||||
| const std::string &frame_name = item.first; | |||||
| auto iter = loop_group_map_.find(frame_name); | |||||
| if (iter == loop_group_map_.end()) { | |||||
| LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>(); | |||||
| if (loop_group == nullptr) { | |||||
| GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); | |||||
| return FAILED; | |||||
| } | |||||
| loop_group->enter_nodes = item.second; | |||||
| loop_group_map_[frame_name][""] = loop_group; | |||||
| } else { | |||||
| for (auto &batch_item : iter->second) { | |||||
| for (const auto &enter_node : item.second) { | |||||
| batch_item.second->enter_nodes.emplace_back(enter_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -99,39 +154,55 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { | |||||
| Status NextIterationPass::FindWhileGroups() { | Status NextIterationPass::FindWhileGroups() { | ||||
| for (const auto &loop_group_iter : loop_group_map_) { | for (const auto &loop_group_iter : loop_group_map_) { | ||||
| const std::string &frame_name = loop_group_iter.first; | const std::string &frame_name = loop_group_iter.first; | ||||
| for (const auto &enter_node : loop_group_iter.second->enter_nodes) { | |||||
| for (const auto &out_node : enter_node->GetOutAllNodes()) { | |||||
| const string &type = out_node->GetType(); | |||||
| if ((type != MERGE) && (type != REFMERGE)) { | |||||
| continue; | |||||
| } | |||||
| NodePtr next_node = nullptr; | |||||
| if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s", frame_name.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| loop_group_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); | |||||
| NodePtr switch_node = nullptr; | |||||
| if (FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Get Switch node failed, frame_name: %s.", frame_name.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (switch_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| NodePtr loop_cond = nullptr; | |||||
| if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (loop_group_iter.second->loop_cond == nullptr) { | |||||
| loop_group_iter.second->loop_cond = loop_cond; | |||||
| } else if (loop_group_iter.second->loop_cond != loop_cond) { | |||||
| GELOGE(FAILED, "Multi LoopCond nodes exist, frame_name: %s.", frame_name.c_str()); | |||||
| return FAILED; | |||||
| for (const auto &batch_iter : loop_group_iter.second) { | |||||
| const std::string &batch_label = batch_iter.first; | |||||
| for (const auto &enter_node : batch_iter.second->enter_nodes) { | |||||
| for (const auto &out_node : enter_node->GetOutAllNodes()) { | |||||
| GELOGI("Find while_group for enter_node %s, frame_name:%s, batch_label:%s.", enter_node->GetName().c_str(), | |||||
| frame_name.c_str(), batch_label.c_str()); | |||||
| if ((out_node->GetType() != MERGE) && (out_node->GetType() != REFMERGE)) { | |||||
| continue; | |||||
| } | |||||
| std::string tmp_label; | |||||
| GE_CHECK_NOTNULL(out_node->GetOpDesc()); | |||||
| (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); | |||||
| bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label)); | |||||
| if (need_skip) { | |||||
| continue; | |||||
| } | |||||
| NodePtr next_node = nullptr; | |||||
| if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, | |||||
| "Get NextIteration node failed: inputs of Merge should be Enter/NextIteration, current_Merge=%s", | |||||
| out_node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); | |||||
| NodePtr switch_node = nullptr; | |||||
| if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Get Switch node failed: output of Merge should be Switch, current_Merge=%s", | |||||
| out_node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (switch_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| NodePtr loop_cond = nullptr; | |||||
| if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, | |||||
| "Get LoopCond node failed: pred input of Switch should be LoopCond, current_Switch=%s", | |||||
| switch_node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (batch_iter.second->loop_cond == nullptr) { | |||||
| batch_iter.second->loop_cond = loop_cond; | |||||
| } else if (batch_iter.second->loop_cond != loop_cond) { | |||||
| GELOGE(FAILED, "Multi LoopCond nodes exist."); | |||||
| return FAILED; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -152,17 +223,19 @@ bool NextIterationPass::VerifyWhileGroup() { | |||||
| GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty."); | GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty."); | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (loop_group_iter.second->loop_cond == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); | |||||
| return false; | |||||
| } | |||||
| for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) { | |||||
| if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) { | |||||
| GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", | |||||
| frame_name.c_str()); | |||||
| for (const auto &batch_iter : loop_group_iter.second) { | |||||
| if (batch_iter.second->loop_cond == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| for (const auto &pair_iter : batch_iter.second->merge_next_pairs) { | |||||
| if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) { | |||||
| GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", | |||||
| frame_name.c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -176,53 +249,56 @@ bool NextIterationPass::VerifyWhileGroup() { | |||||
| /// | /// | ||||
| Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | ||||
| for (const auto &loop_cond_iter : loop_group_map_) { | for (const auto &loop_cond_iter : loop_group_map_) { | ||||
| const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); | |||||
| GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); | |||||
| // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge | |||||
| NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); | |||||
| NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); | |||||
| if ((enter_active == nullptr) || (next_active == nullptr)) { | |||||
| GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| for (const auto &enter_node : loop_cond_iter.second->enter_nodes) { | |||||
| // Enter --> Active | |||||
| if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add control edge from %s to %s failed.", enter_node->GetName().c_str(), | |||||
| enter_active->GetName().c_str()); | |||||
| for (const auto &batch_iter : loop_cond_iter.second) { | |||||
| const std::string &cond_name = batch_iter.second->loop_cond->GetName(); | |||||
| GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); | |||||
| // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge | |||||
| NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); | |||||
| NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); | |||||
| if ((enter_active == nullptr) || (next_active == nullptr)) { | |||||
| GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| } | |||||
| for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { | |||||
| NodePtr merge_node = pair.first; | |||||
| NodePtr next_node = pair.second; | |||||
| // Active --> Merge | |||||
| if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||||
| return INTERNAL_ERROR; | |||||
| for (const auto &enter_node : batch_iter.second->enter_nodes) { | |||||
| // Enter --> Active | |||||
| if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != | |||||
| GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| } | } | ||||
| // NextIteration --> Active | |||||
| if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||||
| return INTERNAL_ERROR; | |||||
| for (const auto &pair : batch_iter.second->merge_next_pairs) { | |||||
| NodePtr merge_node = pair.first; | |||||
| NodePtr next_node = pair.second; | |||||
| // Active --> Merge | |||||
| if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != | |||||
| GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| // NextIteration --> Active | |||||
| if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| // break link between NextIteration and Merge | |||||
| if (BreakNextIteration(next_node, merge_node) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| } | } | ||||
| // break link between NextIteration and Merge | |||||
| if (BreakNextIteration(next_node, merge_node) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); | |||||
| if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | |||||
| (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { | |||||
| GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | |||||
| (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { | |||||
| GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -289,11 +365,12 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr & | |||||
| /// @param [in] node | /// @param [in] node | ||||
| /// @param [in] target_type | /// @param [in] target_type | ||||
| /// @param [in] is_input | /// @param [in] is_input | ||||
| /// @param [in] batch_label | |||||
| /// @param [out] target_node | /// @param [out] target_node | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, | Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, | ||||
| NodePtr &target_node) { | |||||
| const std::string &batch_label, NodePtr &target_node) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| GELOGE(PARAM_INVALID, "node is null."); | GELOGE(PARAM_INVALID, "node is null."); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| @@ -310,6 +387,12 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string | |||||
| } | } | ||||
| for (const auto &tmp_node : nodes) { | for (const auto &tmp_node : nodes) { | ||||
| std::string tmp_label; | |||||
| (void)AttrUtils::GetStr(tmp_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); | |||||
| bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label)); | |||||
| if (need_skip) { | |||||
| continue; | |||||
| } | |||||
| const std::string type = tmp_node->GetType(); | const std::string type = tmp_node->GetType(); | ||||
| if ((target_type == LOOPCOND) && (type == target_type)) { | if ((target_type == LOOPCOND) && (type == target_type)) { | ||||
| target_node = tmp_node; | target_node = tmp_node; | ||||
| @@ -332,6 +415,7 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string | |||||
| /// @return SUCCESS | /// @return SUCCESS | ||||
| /// | /// | ||||
| Status NextIterationPass::ClearStatus() { | Status NextIterationPass::ClearStatus() { | ||||
| frame_enter_map_.clear(); | |||||
| loop_group_map_.clear(); | loop_group_map_.clear(); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -46,6 +46,13 @@ class NextIterationPass : public GraphPass { | |||||
| /// | /// | ||||
| Status GroupEnterNode(const NodePtr &enter_node); | Status GroupEnterNode(const NodePtr &enter_node); | ||||
| /// | |||||
| /// @brief Group Enter nodes without batch_label attr | |||||
| /// @param [in] compute_graph | |||||
| /// @return Status | |||||
| /// | |||||
| Status GroupWithNoBatch(const ComputeGraphPtr &graph); | |||||
| /// | /// | ||||
| /// @brief Find while groups | /// @brief Find while groups | ||||
| /// @return Status | /// @return Status | ||||
| @@ -90,10 +97,13 @@ class NextIterationPass : public GraphPass { | |||||
| /// @param [out] target_node | /// @param [out] target_node | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); | |||||
| Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, | |||||
| const std::string &batch_label, NodePtr &target_node); | |||||
| // map<frame_name, LoopCondGroup> | |||||
| std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_; | |||||
| // map<frame_name, vector<enter_node>> | |||||
| std::unordered_map<std::string, std::vector<NodePtr>> frame_enter_map_; | |||||
| // map<frame_name, map<batch_label, LoopCondGroup>> | |||||
| std::unordered_map<std::string, std::unordered_map<std::string, LoopCondGroupPtr>> loop_group_map_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ | #endif // GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ | ||||
| @@ -1,106 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "remove_same_const_pass.h" | |||||
| #include <sstream> | |||||
| #include <string> | |||||
| #include <set> | |||||
| #include "common/base64.h" | |||||
| #include "ge_local_engine/engine/host_cpu_engine.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| std::string GetCseKey(const NodePtr &node) { | |||||
| std::stringstream ss; | |||||
| ss << node->GetType() << "control-inputs-"; | |||||
| std::set<std::string> control_in_node_names; | |||||
| for (auto &src_node : node->GetInControlNodes()) { | |||||
| control_in_node_names.insert(src_node->GetName()); | |||||
| } | |||||
| for (auto &name : control_in_node_names) { | |||||
| ss << name << "-"; | |||||
| } | |||||
| ss << "attrs-" << AttrUtils::GetAllAttrsStr(node->GetOpDesc()); | |||||
| return ss.str(); | |||||
| } | |||||
| bool IsConstType(const NodePtr &node) { return (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP); } | |||||
| } // namespace | |||||
| Status RemoveSameConstPass::Run(ComputeGraphPtr graph) { | |||||
| GELOGD("Begin to run RemoveSameConstPass on the graph"); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| std::map<std::string, NodePtr> keys_to_node; | |||||
| for (const auto &node : graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if (!IsConstType(node)) { | |||||
| continue; | |||||
| } | |||||
| bool is_unknown = false; | |||||
| auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGW("Get node unknown status failed, node name:%s, type:%s.", | |||||
| node->GetName().c_str(), node->GetType().c_str()); | |||||
| continue; | |||||
| } | |||||
| if (is_unknown) { | |||||
| GELOGI("Current node %s, type %s is unknown shape which should be skip.", | |||||
| node->GetName().c_str(), node->GetType().c_str()); | |||||
| continue; | |||||
| } | |||||
| auto key = GetCseKey(node); | |||||
| GELOGD("The const node %s cse key %s", node->GetName().c_str(), ge::base64::EncodeToBase64(key).c_str()); | |||||
| auto iter = keys_to_node.find(key); | |||||
| if (iter == keys_to_node.end()) { | |||||
| keys_to_node[key] = node; | |||||
| continue; | |||||
| } | |||||
| if (node->GetAllOutDataAnchorsSize() != iter->second->GetAllOutDataAnchorsSize()) { | |||||
| GELOGW("The const node %s and %s have the same CSE key, but different output anchor count, skip to fusion them", | |||||
| iter->second->GetName().c_str(), node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| std::vector<int> output_map(node->GetAllOutDataAnchorsSize()); | |||||
| for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) { | |||||
| output_map[i] = i; | |||||
| } | |||||
| ret = GraphUtils::ReplaceNodeAnchors(iter->second, node, {}, output_map); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to replace node %s by node %s", node->GetName().c_str(), | |||||
| iter->second->GetName().c_str(), ret); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| NodeUtils::UnlinkAll(*node); | |||||
| ret = GraphUtils::RemoveNodeWithoutRelink(graph, node); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to remove node %s from graph", node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGI("Remove const node %s by RemoveSameConstPass, replace it with node %s", node->GetName().c_str(), | |||||
| iter->second->GetName().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,28 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_ | |||||
| #define GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_ | |||||
| #include "graph/types.h" | |||||
| #include "inc/graph_pass.h" | |||||
| namespace ge { | |||||
| class RemoveSameConstPass : public GraphPass { | |||||
| public: | |||||
| Status Run(ge::ComputeGraphPtr graph) override ; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif //GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_ | |||||
| @@ -17,8 +17,13 @@ | |||||
| #include "graph/passes/switch_to_stream_switch_pass.h" | #include "graph/passes/switch_to_stream_switch_pass.h" | ||||
| #include <stack> | #include <stack> | ||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| #include "framework/common/types.h" | |||||
| #include "ge/ge_api_types.h" | #include "ge/ge_api_types.h" | ||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| @@ -120,13 +125,12 @@ void SwitchToStreamSwitchPass::MarkCycleDependence( | |||||
| if (visited.count(tmp_node) > 0) { | if (visited.count(tmp_node) > 0) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| GELOGD("MarkCycleDependence: tmp_node=%s.", tmp_node->GetName().c_str()); | |||||
| for (const NodePtr &out_node : tmp_node->GetOutAllNodes()) { | for (const NodePtr &out_node : tmp_node->GetOutAllNodes()) { | ||||
| if (switch_nodes.find(out_node) == switch_nodes.end()) { | if (switch_nodes.find(out_node) == switch_nodes.end()) { | ||||
| out_nodes.push(out_node); | out_nodes.push(out_node); | ||||
| continue; | continue; | ||||
| } | } | ||||
| GELOGD("MarkCycleDependence: tmp_node=%s, switch_node=%s.", | |||||
| tmp_node->GetName().c_str(), out_node->GetName().c_str()); | |||||
| GE_IF_BOOL_EXEC(SetCyclicDependenceFlag(out_node) != SUCCESS, | GE_IF_BOOL_EXEC(SetCyclicDependenceFlag(out_node) != SUCCESS, | ||||
| GELOGW("set cyclic dependence attr failed."); return ); | GELOGW("set cyclic dependence attr failed."); return ); | ||||
| auto map_iter = switch_cyclic_map_.find(out_node); | auto map_iter = switch_cyclic_map_.find(out_node); | ||||
| @@ -598,7 +602,7 @@ Status SwitchToStreamSwitchPass::AddConstNode(const ComputeGraphPtr &graph, cons | |||||
| /// | /// | ||||
| Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node, | Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node, | ||||
| const std::set<NodePtr> &same_cond_switch) { | const std::set<NodePtr> &same_cond_switch) { | ||||
| GELOGD("ModifySwitchInCtlEdges: switch_node=%s, active_node=%s", switch_node->GetName().c_str(), | |||||
| GELOGI("ModifySwitchInCtlEdges: switch_node=%s, active_node=%s", switch_node->GetName().c_str(), | |||||
| cast_node->GetName().c_str()); | cast_node->GetName().c_str()); | ||||
| std::string orig_switch_name = switch_node->GetName(); | std::string orig_switch_name = switch_node->GetName(); | ||||
| OpDescPtr switch_desc = switch_node->GetOpDesc(); | OpDescPtr switch_desc = switch_node->GetOpDesc(); | ||||
| @@ -649,7 +653,7 @@ Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_no | |||||
| /// | /// | ||||
| Status SwitchToStreamSwitchPass::ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch, | Status SwitchToStreamSwitchPass::ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch, | ||||
| const NodePtr &active_node) { | const NodePtr &active_node) { | ||||
| GELOGD("ModifySwitchOutCtlEdges: switch_node=%s, stream_switch=%s, active_node=%s", switch_node->GetName().c_str(), | |||||
| GELOGI("ModifySwitchOutCtlEdges: switch_node=%s, stream_switch=%s, active_node=%s", switch_node->GetName().c_str(), | |||||
| stream_switch->GetName().c_str(), active_node->GetName().c_str()); | stream_switch->GetName().c_str(), active_node->GetName().c_str()); | ||||
| auto find_res = switch_node_map_.find(switch_node); | auto find_res = switch_node_map_.find(switch_node); | ||||
| GE_IF_BOOL_EXEC(find_res == switch_node_map_.end(), { | GE_IF_BOOL_EXEC(find_res == switch_node_map_.end(), { | ||||
| @@ -1,51 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/passes/useless_control_out_remove_pass.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| namespace ge { | |||||
| Status UselessControlOutRemovePass::Run(NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if ((node->GetType() != CONSTANT) && (node->GetType() != CONSTANTOP)) { | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGD("UselessControlOutRemovePass running, node: %s.", node->GetName().c_str()); | |||||
| // const has no control input | |||||
| if (node->GetInControlNodes().empty()) { | |||||
| if (node->GetOutDataNodes().empty()) { | |||||
| // It is an isolated const, just remove it. | |||||
| GELOGI("Delete isolated const: %s.", node->GetName().c_str()); | |||||
| GE_CHK_STATUS_RET(IsolateAndDeleteNode(node, {})) | |||||
| AddNodeDeleted(node); | |||||
| } else { | |||||
| auto out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
| if (out_ctrl_anchor != nullptr && !out_ctrl_anchor->GetPeerAnchors().empty()) { | |||||
| GELOGI("Node: %s unlink all out control edge.", node->GetName().c_str()); | |||||
| out_ctrl_anchor->UnlinkAll(); | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,29 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_PASSES_USELESS_CONTROL_OUT_REMOVE_PASS_H_ | |||||
| #define GE_GRAPH_PASSES_USELESS_CONTROL_OUT_REMOVE_PASS_H_ | |||||
| #include "graph/passes/base_pass.h" | |||||
| namespace ge { | |||||
| class UselessControlOutRemovePass : public BaseNodePass { | |||||
| public: | |||||
| Status Run(NodePtr &node) override; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_PASSES_USELESS_CONTROL_OUT_REMOVE_PASS_H_ | |||||
| @@ -44,8 +44,6 @@ | |||||
| using std::set; | using std::set; | ||||
| using std::string; | using std::string; | ||||
| using std::vector; | using std::vector; | ||||
| using std::map; | |||||
| using std::queue; | |||||
| namespace ge { | namespace ge { | ||||
| namespace multibatch { | namespace multibatch { | ||||
| @@ -59,15 +57,10 @@ const int kDataInIndex = 0; | |||||
| const int kMergeDataOutIndex = 0; | const int kMergeDataOutIndex = 0; | ||||
| const int kStaticOutput = -1; | const int kStaticOutput = -1; | ||||
| const int kDivisionConst = 2; | const int kDivisionConst = 2; | ||||
| const int32_t kOneInDataNode = 1; | |||||
| const int32_t kFindNoMatch = 0; | |||||
| inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); } | inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); } | ||||
| inline bool IsEnterType(const string &node_type) { return (node_type == ENTER) || (node_type == REFENTER); } | |||||
| const set<string> unchange_types({CONSTANT, CONSTANTOP, ENTER, REFENTER}); | |||||
| inline bool IsGetNextType(const NodePtr &node) { | inline bool IsGetNextType(const NodePtr &node) { | ||||
| std::string original_type; | std::string original_type; | ||||
| GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, | GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, | ||||
| @@ -225,6 +218,12 @@ Status MultiBatchGraphCopyer::CopyGraph() { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| ret = InsertIdentityAfterSwitchN(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to insert identity nodes after switchn node."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGI("Begin to remove useless nodes by prune pass after copy process"); | GELOGI("Begin to remove useless nodes by prune pass after copy process"); | ||||
| PrunePass prune_pass; | PrunePass prune_pass; | ||||
| ret = prune_pass.Run(graph_); | ret = prune_pass.Run(graph_); | ||||
| @@ -241,18 +240,6 @@ Status MultiBatchGraphCopyer::Init() { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| ret = RelinkConstCtrlEdge(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "Relink const's control edge failed."); | |||||
| return FAILED; | |||||
| } | |||||
| ret = ExtractUnchangedStructureOutofCycle(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "Extract unchanged structure out of cycle failed."); | |||||
| return FAILED; | |||||
| } | |||||
| for (auto &node : graph_->GetAllNodes()) { | for (auto &node : graph_->GetAllNodes()) { | ||||
| origin_all_nodes_.emplace_back(node); | origin_all_nodes_.emplace_back(node); | ||||
| if (IsDataLikeType(node->GetType())) { | if (IsDataLikeType(node->GetType())) { | ||||
| @@ -265,281 +252,6 @@ Status MultiBatchGraphCopyer::Init() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status MultiBatchGraphCopyer::RelinkConstCtrlEdge() { | |||||
| for (auto &node : graph_->GetAllNodes()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) { | |||||
| if (node->GetOutDataNodes().empty()) { | |||||
| continue; | |||||
| } | |||||
| if (!node->GetInControlNodes().empty()) { | |||||
| auto in_ctrl_nodes = node->GetInControlNodes(); | |||||
| auto out_nodes = node->GetOutAllNodes(); | |||||
| bool has_merge_out = false; | |||||
| for (const auto &out_node : out_nodes) { | |||||
| GE_CHECK_NOTNULL(out_node); | |||||
| if (out_node->GetType() == MERGE || out_node->GetType() == REFMERGE) { | |||||
| has_merge_out = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (has_merge_out) { | |||||
| continue; | |||||
| } | |||||
| auto in_ctrl_anchor = node->GetInControlAnchor(); | |||||
| GE_CHECK_NOTNULL(in_ctrl_anchor); | |||||
| in_ctrl_anchor->UnlinkAll(); | |||||
| for (auto &in_ctrl_node : in_ctrl_nodes) { | |||||
| auto out_ctrl_anchor_of_in_ctrl_node = in_ctrl_node->GetOutControlAnchor(); | |||||
| GE_CHECK_NOTNULL(out_ctrl_anchor_of_in_ctrl_node); | |||||
| for (auto &out_node : out_nodes) { | |||||
| if (IsEnterType(out_node->GetType())) { | |||||
| continue; | |||||
| } | |||||
| if (!out_ctrl_anchor_of_in_ctrl_node->IsLinkedWith(out_node->GetInControlAnchor())) { | |||||
| GE_CHK_STATUS_RET(out_ctrl_anchor_of_in_ctrl_node->LinkTo(out_node->GetInControlAnchor())) | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| auto out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
| if (out_ctrl_anchor != nullptr) { | |||||
| out_ctrl_anchor->UnlinkAll(); | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MultiBatchGraphCopyer::ExtractUnchangedStructureOutofCycle() { | |||||
| map<string, vector<NodePtr>> frame_enter; | |||||
| if (GetEnterNodesGroupByFrame(frame_enter) != SUCCESS) { | |||||
| GELOGE(FAILED, "Get enter nodes grouped by frame_name failed."); | |||||
| return FAILED; | |||||
| } | |||||
| queue<NodePtr> nodes_to_extract; | |||||
| if (GetNodeNeedExtract(frame_enter, nodes_to_extract) != SUCCESS) { | |||||
| GELOGE(FAILED, "Get nodes needed to extract failed."); | |||||
| return FAILED; | |||||
| } | |||||
| while (!nodes_to_extract.empty()) { | |||||
| auto node = nodes_to_extract.front(); | |||||
| nodes_to_extract.pop(); | |||||
| OpDescPtr enter_desc = nullptr; | |||||
| if (MoveInEntersInDataAnchorDown(node, enter_desc) != SUCCESS) { | |||||
| GELOGE(FAILED, "Move in enter nodes' in data anchors down of %s failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| set<NodePtr> out_nodes; | |||||
| if (InsertEnterAfterNode(node, enter_desc, out_nodes) != SUCCESS) { | |||||
| GELOGE(FAILED, "Insert enter node after %s failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (MoveCtrlEdgeToOutNodes(node, out_nodes) != SUCCESS) { | |||||
| GELOGE(FAILED, "Move %s's control edge to out nodes failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| for (auto &out_node : out_nodes) { | |||||
| GE_CHECK_NOTNULL(out_node); | |||||
| if (AllInDataNodesUnchangeAndNoMergeOut(out_node)) { | |||||
| nodes_to_extract.push(out_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (DeleteEnterWithoutDataOut() != SUCCESS) { | |||||
| GELOGE(FAILED, "Delete enter node without out data nodes failed."); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MultiBatchGraphCopyer::GetEnterNodesGroupByFrame(map<string, vector<NodePtr>> &frame_enter) { | |||||
| for (auto &node : graph_->GetAllNodes()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if (IsEnterType(node->GetType())) { | |||||
| if (!node->GetInControlNodes().empty() || !node->GetOutControlNodes().empty()) { | |||||
| continue; | |||||
| } | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| string frame_name; | |||||
| if (!AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) { | |||||
| GELOGE(FAILED, "Get attr frame_name of enter[%] failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| frame_enter[frame_name].emplace_back(node); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MultiBatchGraphCopyer::GetNodeNeedExtract(const map<string, vector<NodePtr>> &frame_enter, | |||||
| queue<NodePtr> &nodes_to_extract) { | |||||
| for (const auto &one_group : frame_enter) { | |||||
| auto enters = one_group.second; | |||||
| for (const auto &enter : enters) { | |||||
| auto out_data_nodes = enter->GetOutDataNodes(); | |||||
| for (const auto &out_data_node : out_data_nodes) { | |||||
| GE_CHECK_NOTNULL(out_data_node); | |||||
| if (AllInDataNodesUnchangeAndNoMergeOut(out_data_node)) { | |||||
| nodes_to_extract.push(out_data_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| bool MultiBatchGraphCopyer::AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node) { | |||||
| auto out_data_nodes = node->GetOutDataNodes(); | |||||
| for (const auto &out_data_node : out_data_nodes) { | |||||
| if (out_data_node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (out_data_node->GetType() == MERGE || out_data_node->GetType() == REFMERGE) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| auto in_data_nodes = node->GetInDataNodes(); | |||||
| if (in_data_nodes.size() == kOneInDataNode) { | |||||
| return true; | |||||
| } | |||||
| for (const auto &in_data_node : in_data_nodes) { | |||||
| if (in_data_node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (unchange_types.count(in_data_node->GetType()) == kFindNoMatch) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| Status MultiBatchGraphCopyer::MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc) { | |||||
| auto in_data_anchors = node->GetAllInDataAnchors(); | |||||
| for (auto &in_data_anchor : in_data_anchors) { | |||||
| auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(peer_out_data_anchor); | |||||
| auto peer_in_data_node = peer_out_data_anchor->GetOwnerNode(); | |||||
| if (IsEnterType(peer_in_data_node->GetType())) { | |||||
| GE_CHK_STATUS_RET(peer_out_data_anchor->Unlink(in_data_anchor)) | |||||
| GELOGD("Unlink data edge from %s to %s.", peer_in_data_node->GetName().c_str(), node->GetName().c_str()); | |||||
| auto enter_in_data_anchors = peer_in_data_node->GetAllInDataAnchors(); | |||||
| for (auto &enter_in_data_anchor : enter_in_data_anchors) { | |||||
| auto peer_out_data_anchor_of_enter = enter_in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(peer_out_data_anchor_of_enter); | |||||
| if (peer_out_data_anchor_of_enter->IsLinkedWith(in_data_anchor)) { | |||||
| continue; | |||||
| } | |||||
| GE_CHK_STATUS_RET(peer_out_data_anchor_of_enter->LinkTo(in_data_anchor)) | |||||
| GELOGD("Relink data edge from %s to %s.", peer_out_data_anchor_of_enter->GetOwnerNode()->GetName().c_str(), | |||||
| node->GetName().c_str()); | |||||
| } | |||||
| enter_desc = peer_in_data_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(enter_desc); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MultiBatchGraphCopyer::InsertEnterAfterNode(NodePtr &node, const OpDescPtr ©_desc, set<NodePtr> &out_nodes) { | |||||
| if (copy_desc == nullptr) { | |||||
| return SUCCESS; | |||||
| } | |||||
| map<OutDataAnchorPtr, vector<std::pair<InDataAnchorPtr, NodePtr>>> outanchors_inanchors_nodes; | |||||
| auto out_data_anchors = node->GetAllOutDataAnchors(); | |||||
| for (auto &out_data_anchor : out_data_anchors) { | |||||
| auto peer_in_data_anchors = out_data_anchor->GetPeerInDataAnchors(); | |||||
| for (auto peer_in_data_anchor : peer_in_data_anchors) { | |||||
| GE_CHECK_NOTNULL(peer_in_data_anchor); | |||||
| auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | |||||
| out_nodes.emplace(peer_in_data_node); | |||||
| outanchors_inanchors_nodes[out_data_anchor].emplace_back(std::make_pair(peer_in_data_anchor, peer_in_data_node)); | |||||
| } | |||||
| } | |||||
| int32_t i = 0; | |||||
| auto node_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(node_desc); | |||||
| // Insert one enter node after node's per out data anchor | |||||
| for (auto &outanchor_inanchors_nodes : outanchors_inanchors_nodes) { | |||||
| string name = node->GetName() + "_" + ENTER + "_" + std::to_string(i++); | |||||
| GELOGD("Create Enter op %s after %s.", name.c_str(), node->GetName().c_str()); | |||||
| auto enter_desc = AttrUtils::CopyOpDesc(copy_desc); | |||||
| enter_desc->SetName(name); | |||||
| GE_CHK_STATUS_RET( | |||||
| enter_desc->UpdateInputDesc("x", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx()))) | |||||
| GE_CHK_STATUS_RET( | |||||
| enter_desc->UpdateOutputDesc("y", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx()))) | |||||
| auto enter_node = graph_->AddNode(enter_desc); | |||||
| GE_CHECK_NOTNULL(enter_node); | |||||
| GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->LinkTo(enter_node->GetInDataAnchor(kDataInIndex))) | |||||
| GE_CHECK_NOTNULL(enter_node->GetOutDataAnchor(kDataInIndex)); | |||||
| for (auto &inanchor_node : outanchor_inanchors_nodes.second) { | |||||
| GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->Unlink(inanchor_node.first)) | |||||
| GE_CHK_STATUS_RET(enter_node->GetOutDataAnchor(kDataInIndex)->LinkTo(inanchor_node.first)) | |||||
| GELOGD("Unlink from %s to %s, link from %s to %s then to %s.", node->GetName().c_str(), | |||||
| inanchor_node.second->GetName().c_str(), node->GetName().c_str(), enter_node->GetName().c_str(), | |||||
| inanchor_node.second->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| // Move node's in control edges to out data nodes | |||||
| Status MultiBatchGraphCopyer::MoveCtrlEdgeToOutNodes(NodePtr &node, set<NodePtr> &out_nodes) { | |||||
| auto in_ctrl_anchor = node->GetInControlAnchor(); | |||||
| GE_CHECK_NOTNULL(in_ctrl_anchor); | |||||
| auto peer_out_ctrl_anchors = in_ctrl_anchor->GetPeerOutControlAnchors(); | |||||
| for (auto &peer_out_ctrl_anchor : peer_out_ctrl_anchors) { | |||||
| GE_CHK_STATUS_RET(peer_out_ctrl_anchor->Unlink(in_ctrl_anchor)) | |||||
| GELOGD("Unlink control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| node->GetName().c_str()); | |||||
| for (auto &out_node : out_nodes) { | |||||
| auto in_ctrl_anchor_of_out_node = out_node->GetInControlAnchor(); | |||||
| GE_CHECK_NOTNULL(in_ctrl_anchor_of_out_node); | |||||
| if (!peer_out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor_of_out_node)) { | |||||
| GE_CHK_STATUS_RET(peer_out_ctrl_anchor->LinkTo(in_ctrl_anchor_of_out_node)) | |||||
| GELOGD("Link control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| out_node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MultiBatchGraphCopyer::DeleteEnterWithoutDataOut() { | |||||
| for (auto &node : graph_->GetAllNodes()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if (IsEnterType(node->GetType())) { | |||||
| auto out_nodes = node->GetOutAllNodes(); | |||||
| if (out_nodes.empty()) { | |||||
| GELOGD("Delete enter node: %s which has no output.", node->GetName().c_str()); | |||||
| GE_CHK_STATUS_RET(GraphUtils::IsolateNode(node, {})) | |||||
| GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph_, node)) | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void MultiBatchGraphCopyer::LabelStatusForData(const NodePtr &data) { | void MultiBatchGraphCopyer::LabelStatusForData(const NodePtr &data) { | ||||
| auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | ||||
| GELOGI("Label status for %s, shape_dims is %s.", data->GetName().c_str(), | GELOGI("Label status for %s, shape_dims is %s.", data->GetName().c_str(), | ||||
| @@ -585,9 +297,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { | |||||
| LabelStatusForGetNextSink(data); | LabelStatusForGetNextSink(data); | ||||
| } | } | ||||
| } | } | ||||
| map<string, vector<NodePtr>> frame_enters; | |||||
| InitStatus(frame_enters); | |||||
| bool changed = true; | bool changed = true; | ||||
| // If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch | // If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch | ||||
| while (changed) { | while (changed) { | ||||
| @@ -597,13 +306,12 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { | |||||
| if (iter != origin_nodes_status_.end()) { | if (iter != origin_nodes_status_.end()) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| for (auto &in_node : node->GetInDataNodes()) { | |||||
| if (origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end()) { | |||||
| if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) { | |||||
| origin_nodes_status_[node.get()] == kNodeInBatchBranch; | |||||
| ResetEnterStatus(frame_enters, node); | |||||
| changed = true; | |||||
| } | |||||
| for (auto &in_node : node->GetInAllNodes()) { | |||||
| bool is_in_batch = origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end() && | |||||
| origin_nodes_status_[in_node.get()] == kNodeInBatchBranch; | |||||
| if (is_in_batch) { | |||||
| origin_nodes_status_[node.get()] = kNodeInBatchBranch; | |||||
| changed = true; | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| @@ -612,45 +320,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void MultiBatchGraphCopyer::InitStatus(map<string, vector<NodePtr>> &frame_enters) { | |||||
| for (const auto &node : origin_all_nodes_) { | |||||
| if (!IsEnterType(node->GetType())) { | |||||
| continue; | |||||
| } | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| string frame_name; | |||||
| if (AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) { | |||||
| frame_enters[frame_name].emplace_back(node); | |||||
| } | |||||
| } | |||||
| for (const auto &data : origin_data_nodes_) { | |||||
| auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | |||||
| if (!IsAllDimsPositive(data_shape.GetDims())) { | |||||
| origin_nodes_status_[data.get()] = kNodeInBatchBranch; | |||||
| } | |||||
| } | |||||
| } | |||||
| void MultiBatchGraphCopyer::ResetEnterStatus(map<string, vector<NodePtr>> &frame_enters, const NodePtr &node) { | |||||
| if (!IsEnterType(node->GetType())) { | |||||
| return; | |||||
| } | |||||
| for (const auto &frame_enter : frame_enters) { | |||||
| auto &enters = frame_enter.second; | |||||
| if (std::find(enters.begin(), enters.end(), node) != enters.end()) { | |||||
| for (const auto &enter : enters) { | |||||
| origin_nodes_status_[enter.get()] = kNodeInBatchBranch; | |||||
| } | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| Status MultiBatchGraphCopyer::LabelStatus() { | Status MultiBatchGraphCopyer::LabelStatus() { | ||||
| if (LabelInBatchBranchStatus() != SUCCESS) { | if (LabelInBatchBranchStatus() != SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "Failed to label no in batch branch"); | GELOGE(PARAM_INVALID, "Failed to label no in batch branch"); | ||||
| @@ -1691,6 +1360,52 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() { | |||||
| for (auto &node : graph_->GetAllNodes()) { | |||||
| if (node->GetType() != SWITCHN) { | |||||
| continue; | |||||
| } | |||||
| auto switchn_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(switchn_desc); | |||||
| size_t i = 0; | |||||
| for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||||
| for (auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
| auto out_node = in_data_anchor->GetOwnerNode(); | |||||
| auto op_desc = out_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if ((out_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { | |||||
| GELOGD("No need to insert identity between %s and %s.", node->GetName().c_str(), out_node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| auto identity_desc = MakeShared<OpDesc>(node->GetName() + "_identity_" + std::to_string(i), IDENTITY); | |||||
| GE_CHECK_NOTNULL(identity_desc); | |||||
| string batch_label; | |||||
| if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
| if (!AttrUtils::SetStr(identity_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
| GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", identity_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| auto data_desc = switchn_desc->GetOutputDesc(i); | |||||
| i++; | |||||
| GE_CHK_STATUS_RET(identity_desc->AddInputDesc("x", data_desc)); | |||||
| GE_CHK_STATUS_RET(identity_desc->AddOutputDesc("y", data_desc)); | |||||
| auto identity_node = graph_->AddNode(identity_desc); | |||||
| GE_CHECK_NOTNULL(identity_node); | |||||
| GE_CHK_STATUS_RET(out_data_anchor->LinkTo(identity_node->GetInDataAnchor(0))); | |||||
| GE_CHECK_NOTNULL(identity_node->GetOutControlAnchor()); | |||||
| GE_CHK_STATUS_RET(identity_node->GetOutControlAnchor()->LinkTo(out_node->GetInControlAnchor())); | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ProcessMultiBatch(ComputeGraphPtr &graph) { | Status ProcessMultiBatch(ComputeGraphPtr &graph) { | ||||
| if (GetLocalOmgContext().dynamic_node_type.empty()) { | if (GetLocalOmgContext().dynamic_node_type.empty()) { | ||||
| const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN"); | const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN"); | ||||
| @@ -1700,6 +1415,7 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { | |||||
| return pass_manager.Run(graph); | return pass_manager.Run(graph); | ||||
| } | } | ||||
| } | } | ||||
| if (!GetLocalOmgContext().need_multi_batch) { | if (!GetLocalOmgContext().need_multi_batch) { | ||||
| GELOGI("No need to process_multi for no_train graph."); | GELOGI("No need to process_multi for no_train graph."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -18,7 +18,6 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <queue> | #include <queue> | ||||
| #include <vector> | #include <vector> | ||||
| #include <set> | |||||
| #include "external/ge/ge_api_error_codes.h" | #include "external/ge/ge_api_error_codes.h" | ||||
| @@ -65,26 +64,12 @@ class MultiBatchGraphCopyer { | |||||
| private: | private: | ||||
| Status Init(); | Status Init(); | ||||
| Status CheckArguments(); | Status CheckArguments(); | ||||
| Status RelinkConstCtrlEdge(); | |||||
| Status ExtractUnchangedStructureOutofCycle(); | |||||
| Status GetEnterNodesGroupByFrame(std::map<std::string, std::vector<NodePtr>> &frame_enter); | |||||
| Status GetNodeNeedExtract(const std::map<std::string, std::vector<NodePtr>> &frame_enter, | |||||
| std::queue<NodePtr> &nodes_to_extract); | |||||
| bool AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node); | |||||
| Status MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc); | |||||
| Status InsertEnterAfterNode(NodePtr &node, const OpDescPtr &enter_desc, std::set<NodePtr> &out_nodes); | |||||
| Status MoveCtrlEdgeToOutNodes(NodePtr &node, std::set<NodePtr> &out_nodes); | |||||
| Status DeleteEnterWithoutDataOut(); | |||||
| // label status for origin_all_nodes_ | // label status for origin_all_nodes_ | ||||
| Status LabelStatus(); | Status LabelStatus(); | ||||
| Status LabelInBatchBranchStatus(); | Status LabelInBatchBranchStatus(); | ||||
| void LabelStatusForData(const NodePtr &data); | void LabelStatusForData(const NodePtr &data); | ||||
| void LabelStatusForGetNextSink(const NodePtr &data); | void LabelStatusForGetNextSink(const NodePtr &data); | ||||
| void InitStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters); | |||||
| void ResetEnterStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters, const NodePtr &node); | |||||
| // add nodes functions | // add nodes functions | ||||
| Status CreateNewNodes(); | Status CreateNewNodes(); | ||||
| @@ -96,6 +81,7 @@ class MultiBatchGraphCopyer { | |||||
| Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index, | Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index, | ||||
| std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn); | std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn); | ||||
| Status InsertIdentityAfterSwitchN(); | |||||
| Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index); | Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index); | ||||
| Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index); | Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index); | ||||
| @@ -22,7 +22,6 @@ | |||||
| #include "common/blocking_queue.h" | #include "common/blocking_queue.h" | ||||
| #include "common/properties_manager.h" | #include "common/properties_manager.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/ge_local_context.h" | |||||
| #include "hybrid/common/npu_memory_allocator.h" | #include "hybrid/common/npu_memory_allocator.h" | ||||
| #include "hybrid/common/tensor_value.h" | #include "hybrid/common/tensor_value.h" | ||||
| #include "hybrid/executor/hybrid_profiler.h" | #include "hybrid/executor/hybrid_profiler.h" | ||||
| @@ -39,7 +38,6 @@ struct GraphExecutionContext { | |||||
| uint64_t session_id = 0; | uint64_t session_id = 0; | ||||
| const HybridModel *model = nullptr; | const HybridModel *model = nullptr; | ||||
| const GEThreadLocalContext *ge_context = nullptr; | |||||
| rtStream_t stream = nullptr; | rtStream_t stream = nullptr; | ||||
| rtContext_t rt_context = nullptr; | rtContext_t rt_context = nullptr; | ||||
| rtContext_t rt_gen_context = nullptr; | rtContext_t rt_gen_context = nullptr; | ||||
| @@ -95,7 +95,6 @@ Status HybridModelExecutor::InitExecutionContext() { | |||||
| context_.stream = stream_; | context_.stream = stream_; | ||||
| context_.model = model_; | context_.model = model_; | ||||
| context_.session_id = ::ge::GetContext().SessionId(); | context_.session_id = ::ge::GetContext().SessionId(); | ||||
| context_.ge_context = &GetThreadLocalContext(); | |||||
| GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); | GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); | ||||
| context_.allocator = NpuMemoryAllocator::GetAllocator(device_id_); | context_.allocator = NpuMemoryAllocator::GetAllocator(device_id_); | ||||
| GE_CHECK_NOTNULL(context_.allocator); | GE_CHECK_NOTNULL(context_.allocator); | ||||
| @@ -18,7 +18,6 @@ | |||||
| #include <chrono> | #include <chrono> | ||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "hybrid_execution_context.h" | #include "hybrid_execution_context.h" | ||||
| #include "subgraph_context.h" | #include "subgraph_context.h" | ||||
| @@ -36,31 +35,29 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item( | |||||
| this->num_pending_shapes_); | this->num_pending_shapes_); | ||||
| } | } | ||||
| Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) { | |||||
| Status ShapeInferenceState::UpdateInputShape(int idx, | |||||
| const GeShape &ori_shape, | |||||
| const GeShape &shape) { | |||||
| if (node_item.IsInputShapeStatic(idx)) { | if (node_item.IsInputShapeStatic(idx)) { | ||||
| GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]", | GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| idx, | idx, | ||||
| node_item.MutableInputDesc(idx)->GetShape().ToString().c_str(), | node_item.MutableInputDesc(idx)->GetShape().ToString().c_str(), | ||||
| target.GetShape().ToString().c_str()); | |||||
| shape.ToString().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| int64_t tensor_size = -1; | |||||
| (void) TensorUtils::GetSize(target, tensor_size); | |||||
| GELOGD("[%s] Update input shape [%d] with Shape: [%s] and OriginalShape: [%s], size = %ld", | |||||
| GELOGD("[%s] Update input shape [%d] with Shape: [%s] and OriginalShape: [%s]", | |||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| idx, | idx, | ||||
| target.GetShape().ToString().c_str(), | |||||
| target.GetOriginShape().ToString().c_str(), | |||||
| tensor_size); | |||||
| shape.ToString().c_str(), | |||||
| ori_shape.ToString().c_str()); | |||||
| std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
| auto tensor_desc = node_item.MutableInputDesc(idx); | auto tensor_desc = node_item.MutableInputDesc(idx); | ||||
| GE_CHECK_NOTNULL(tensor_desc); | GE_CHECK_NOTNULL(tensor_desc); | ||||
| tensor_desc->SetShape(target.GetShape()); | |||||
| tensor_desc->SetOriginShape(target.GetOriginShape()); | |||||
| (void) TensorUtils::SetSize(*tensor_desc, tensor_size); | |||||
| tensor_desc->SetShape(shape); | |||||
| tensor_desc->SetOriginShape(ori_shape); | |||||
| if (--num_pending_shapes_ == 0) { | if (--num_pending_shapes_ == 0) { | ||||
| ready_cv_.notify_all(); | ready_cv_.notify_all(); | ||||
| } | } | ||||
| @@ -113,24 +110,24 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex | |||||
| for (auto &p : shape_futures) { | for (auto &p : shape_futures) { | ||||
| auto idx = p.first; | auto idx = p.first; | ||||
| auto &future = p.second; | auto &future = p.second; | ||||
| GeShape shape; | |||||
| GeShape ori_shape; | |||||
| RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); | RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); | ||||
| auto src_tensor_desc = future.GetTensorDesc(); | |||||
| GE_CHECK_NOTNULL(src_tensor_desc); | |||||
| GE_CHK_STATUS_RET(future.Get(ori_shape, shape), | |||||
| "[%s] Get shape failed. index = %u", | |||||
| node_item.NodeName().c_str(), | |||||
| idx); | |||||
| RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); | RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); | ||||
| auto input_desc = node_item.MutableInputDesc(idx); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| int64_t tensor_size = -1; | |||||
| (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); | |||||
| GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], index = %zu", | |||||
| GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s]", | |||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| idx, | idx, | ||||
| src_tensor_desc->GetShape().ToString().c_str(), | |||||
| src_tensor_desc->GetOriginShape().ToString().c_str(), | |||||
| tensor_size); | |||||
| input_desc->SetShape(src_tensor_desc->GetShape()); | |||||
| input_desc->SetOriginShape(src_tensor_desc->GetOriginShape()); | |||||
| (void) TensorUtils::SetSize(*input_desc, tensor_size); | |||||
| shape.ToString().c_str(), | |||||
| ori_shape.ToString().c_str()); | |||||
| auto input_desc = node_item.MutableInputDesc(idx); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| input_desc->SetShape(std::move(shape)); | |||||
| input_desc->SetOriginShape(ori_shape); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -193,14 +190,5 @@ Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) { | |||||
| GELOGD("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); | GELOGD("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GeTensorDescPtr ShapeFuture::GetTensorDesc() { | |||||
| GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); | |||||
| if (!subgraph_context_->Await(src_node_)) { | |||||
| GELOGE(INTERNAL_ERROR, "cancelled"); | |||||
| return nullptr; | |||||
| } | |||||
| return src_node_->GetOpDesc()->MutableOutputDesc(src_index_); | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -35,7 +35,6 @@ class ShapeFuture { | |||||
| ShapeFuture(NodePtr src_node, uint32_t src_index, SubgraphContext *subgraph_context); | ShapeFuture(NodePtr src_node, uint32_t src_index, SubgraphContext *subgraph_context); | ||||
| ~ShapeFuture() = default; | ~ShapeFuture() = default; | ||||
| Status Get(GeShape &ori_shape, GeShape &shape); | Status Get(GeShape &ori_shape, GeShape &shape); | ||||
| GeTensorDescPtr GetTensorDesc(); | |||||
| private: | private: | ||||
| NodePtr src_node_; | NodePtr src_node_; | ||||
| @@ -46,7 +45,7 @@ class ShapeFuture { | |||||
| struct ShapeInferenceState { | struct ShapeInferenceState { | ||||
| explicit ShapeInferenceState(const NodeItem &node_item); | explicit ShapeInferenceState(const NodeItem &node_item); | ||||
| Status UpdateInputShape(int idx, const GeTensorDesc &tensor_desc); | |||||
| Status UpdateInputShape(int idx, const GeShape &ori_shape, const GeShape &shape); | |||||
| void UpdateInputShapeFuture(int idx, ShapeFuture &&future); | void UpdateInputShapeFuture(int idx, ShapeFuture &&future); | ||||
| @@ -96,7 +96,7 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue | |||||
| GE_CHECK_NOTNULL(tensor_desc); | GE_CHECK_NOTNULL(tensor_desc); | ||||
| auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); | auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); | ||||
| GE_CHECK_NOTNULL(node_state); | GE_CHECK_NOTNULL(node_state); | ||||
| node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc); | |||||
| node_state->GetShapeInferenceState().UpdateInputShape(0, tensor_desc->GetOriginShape(), tensor_desc->GetShape()); | |||||
| } | } | ||||
| } | } | ||||
| @@ -268,6 +268,13 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta | |||||
| } else { | } else { | ||||
| node_state.SetKernelTask(node_item.kernel_task); | node_state.SetKernelTask(node_item.kernel_task); | ||||
| } | } | ||||
| GELOGD("[%s] Start to invoke CalcOpRunningParam.", node_item.NodeName().c_str()); | |||||
| RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start"); | |||||
| GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().CalcOpRunningParam(*node_item.node), | |||||
| "[%s] Failed to invoke CalcOpRunningParam.", node_item.NodeName().c_str()); | |||||
| RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[CalcOpRunningParam] End"); | |||||
| GELOGD("[%s] Done invoking CalcOpRunningParam successfully.", node_item.NodeName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -20,9 +20,12 @@ | |||||
| #include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
| #include "hybrid/executor//worker//shape_inference_engine.h" | |||||
| #include "common/dump/dump_manager.h" | |||||
| #include "common/dump/dump_op.h" | #include "common/dump/dump_op.h" | ||||
| #include "common/types.h" | |||||
| #include "common/ge_types.h" | |||||
| #include "common/profiling/profiling_manager.h" | #include "common/profiling/profiling_manager.h" | ||||
| #include "runtime/base.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| @@ -151,19 +154,18 @@ Status NodeDoneCallback::GetTaskDescInfo(const NodePtr node, const HybridModel * | |||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| GE_CHECK_NOTNULL(model); | GE_CHECK_NOTNULL(model); | ||||
| // only report aicpu and aicore node | |||||
| bool is_profiling_report = context_->GetNodeItem().is_profiling_report; | |||||
| if (!is_profiling_report) { | |||||
| GELOGD("Node[%s] is not aicore or aicpu, and no need to report data.", node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGD("GetTaskDescInfo of node [%s] start.", node->GetName().c_str()); | GELOGD("GetTaskDescInfo of node [%s] start.", node->GetName().c_str()); | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| std::string op_name = op_desc->GetName(); | std::string op_name = op_desc->GetName(); | ||||
| std::string dynamic_model_name = model->GetModelName(); | std::string dynamic_model_name = model->GetModelName(); | ||||
| uint32_t task_id = context_->GetTaskId(); | |||||
| uint32_t stream_id = context_->GetStreamId(); | |||||
| uint32_t task_id = 0; | |||||
| uint32_t stream_id = 0; | |||||
| if (rtGetTaskIdAndStreamID(&task_id, &stream_id) != RT_ERROR_NONE) { | |||||
| GELOGE(PARAM_INVALID, "Get task_id and stream_id failed."); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| TaskDescInfo tmp_task_desc_info; | TaskDescInfo tmp_task_desc_info; | ||||
| tmp_task_desc_info.model_name = dynamic_model_name; | tmp_task_desc_info.model_name = dynamic_model_name; | ||||
| tmp_task_desc_info.op_name = op_name; | tmp_task_desc_info.op_name = op_name; | ||||
| @@ -175,8 +177,6 @@ Status NodeDoneCallback::GetTaskDescInfo(const NodePtr node, const HybridModel * | |||||
| } | } | ||||
| tmp_task_desc_info.task_id = task_id; | tmp_task_desc_info.task_id = task_id; | ||||
| tmp_task_desc_info.stream_id = stream_id; | tmp_task_desc_info.stream_id = stream_id; | ||||
| tmp_task_desc_info.shape_type = "dynamic"; | |||||
| tmp_task_desc_info.cur_iter_num = graph_context_->iteration; | |||||
| GELOGD("GetTaskDescInfo of node [%s] end, task_id[%u], stream_id[%u]", | GELOGD("GetTaskDescInfo of node [%s] end, task_id[%u], stream_id[%u]", | ||||
| node->GetName().c_str(), task_id, stream_id); | node->GetName().c_str(), task_id, stream_id); | ||||
| task_desc_info.emplace_back(tmp_task_desc_info); | task_desc_info.emplace_back(tmp_task_desc_info); | ||||
| @@ -348,10 +348,6 @@ Status NodeDoneCallback::OnNodeDone() { | |||||
| } | } | ||||
| GE_CHK_STATUS_RET_NOLOG(PrepareConstInputs(node_item)); | GE_CHK_STATUS_RET_NOLOG(PrepareConstInputs(node_item)); | ||||
| if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE) { | |||||
| // update output tensor sizes | |||||
| GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(node_item)); | |||||
| } | |||||
| // PropagateOutputs for type == DEPEND_COMPUTE | // PropagateOutputs for type == DEPEND_COMPUTE | ||||
| if (node_item.shape_inference_type == DEPEND_COMPUTE) { | if (node_item.shape_inference_type == DEPEND_COMPUTE) { | ||||
| if (graph_context_->trace_enabled) { | if (graph_context_->trace_enabled) { | ||||
| @@ -17,15 +17,9 @@ | |||||
| #include "hybrid/executor/worker/shape_inference_engine.h" | #include "hybrid/executor/worker/shape_inference_engine.h" | ||||
| #include "graph/shape_refiner.h" | #include "graph/shape_refiner.h" | ||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| #include "common/math/math_util.h" | |||||
| #include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| const int kAlignment = 32; | |||||
| } | |||||
| namespace hybrid { | namespace hybrid { | ||||
| ShapeInferenceEngine::ShapeInferenceEngine(GraphExecutionContext *execution_context, SubgraphContext *subgraph_context) | ShapeInferenceEngine::ShapeInferenceEngine(GraphExecutionContext *execution_context, SubgraphContext *subgraph_context) | ||||
| : execution_context_(execution_context), | : execution_context_(execution_context), | ||||
| @@ -46,9 +40,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { | |||||
| } | } | ||||
| if (node_item.fused_subgraph != nullptr) { | if (node_item.fused_subgraph != nullptr) { | ||||
| GE_CHK_STATUS_RET_NOLOG(InferShapeForSubgraph(node_item, *node_item.fused_subgraph)); | |||||
| GE_CHK_STATUS_RET_NOLOG(CalcOutputTensorSizes(node_item)); | |||||
| return SUCCESS; | |||||
| return InferShapeForSubgraph(node_item, *node_item.fused_subgraph); | |||||
| } | } | ||||
| // Skip shape inference for node of type DEPEND_COMPUTE | // Skip shape inference for node of type DEPEND_COMPUTE | ||||
| @@ -71,15 +63,21 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { | |||||
| std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
| RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | ||||
| GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | ||||
| "Invoke InferShapeAndType failed."); | |||||
| "Invoke InferShapeAndType failed."); | |||||
| RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); | RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); | ||||
| } | } | ||||
| // Check again to make sure shape is valid after shape inference | |||||
| if (node_item.shape_inference_type != DEPEND_SHAPE_RANGE) { | |||||
| bool is_unknown_shape = false; | |||||
| GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node_item.node, is_unknown_shape), | |||||
| "Failed to get shape status. node = %s", | |||||
| node_item.NodeName().c_str()); | |||||
| // update output tensor sizes after shape inference | |||||
| // error if shape is still unknown and not of type DEPEND_SHAPE_RANGE | |||||
| RECORD_COMPILE_EVENT(execution_context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start"); | |||||
| GE_CHK_STATUS_RET_NOLOG(CalcOutputTensorSizes(node_item, node_item.shape_inference_type == DEPEND_SHAPE_RANGE)); | |||||
| RECORD_COMPILE_EVENT(execution_context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] End"); | |||||
| GE_CHK_BOOL_RET_STATUS(!is_unknown_shape, | |||||
| INTERNAL_ERROR, | |||||
| "[%s] Shape is still unknown after shape inference.", | |||||
| node_item.NodeName().c_str()); | |||||
| } | |||||
| GELOGD("[%s] [HybridTrace] After shape inference. Node = %s", | GELOGD("[%s] [HybridTrace] After shape inference. Node = %s", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| @@ -129,6 +127,8 @@ Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) { | |||||
| // propagate each output | // propagate each output | ||||
| for (int i = 0; i < node_item.num_outputs; ++i) { | for (int i = 0; i < node_item.num_outputs; ++i) { | ||||
| auto output_desc = node_item.op_desc->MutableOutputDesc(i); | auto output_desc = node_item.op_desc->MutableOutputDesc(i); | ||||
| const auto &shape = output_desc->MutableShape(); | |||||
| const auto &ori_shape = output_desc->GetOriginShape(); | |||||
| auto &output_nodes = node_item.outputs[i]; | auto &output_nodes = node_item.outputs[i]; | ||||
| // propagate output to all sub-inputs | // propagate output to all sub-inputs | ||||
| @@ -149,7 +149,9 @@ Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) { | |||||
| infer_state.UpdateInputShapeFuture(dst_input_index_and_node.first, | infer_state.UpdateInputShapeFuture(dst_input_index_and_node.first, | ||||
| std::move(future)); | std::move(future)); | ||||
| } else { | } else { | ||||
| GE_CHK_STATUS_RET_NOLOG(infer_state.UpdateInputShape(dst_input_index_and_node.first, *output_desc)); | |||||
| GE_CHK_STATUS_RET_NOLOG(infer_state.UpdateInputShape(dst_input_index_and_node.first, | |||||
| ori_shape, | |||||
| shape)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -228,92 +230,5 @@ Status ShapeInferenceEngine::UpdatePeerNodeShape(const Node &node) { | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ShapeInferenceEngine::CanonicalizeShape(GeTensorDesc &tensor_desc, | |||||
| std::vector<int64_t> &shape, | |||||
| bool fallback_with_range) { | |||||
| const auto &tensor_shape = tensor_desc.MutableShape(); | |||||
| if (tensor_shape.IsUnknownShape()) { | |||||
| if (!fallback_with_range) { | |||||
| GELOGE(INTERNAL_ERROR, "Output shape is still unknown after shape inference. shape = [%s]", | |||||
| tensor_shape.ToString().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGD("Calc output size by range"); | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
| GE_CHK_GRAPH_STATUS_RET(tensor_desc.GetShapeRange(shape_range), "Failed to get shape range"); | |||||
| if (shape_range.size() != shape.size()) { | |||||
| GELOGE(INTERNAL_ERROR, "Number of shape ranges (%zu) mismatches that of dims (%zu)", | |||||
| shape_range.size(), | |||||
| shape.size()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| for (size_t dim_index = 0; dim_index < shape.size(); ++dim_index) { | |||||
| if (shape[dim_index] == ge::UNKNOWN_DIM) { | |||||
| shape[dim_index] = shape_range[dim_index].second; | |||||
| } | |||||
| } | |||||
| GELOGD("After canonicalization, shape = [%s], before = [%s]", | |||||
| GeShape(shape).ToString().c_str(), | |||||
| tensor_shape.ToString().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ShapeInferenceEngine::CalcTensorSize(DataType data_type, | |||||
| const std::vector<int64_t> &shape, | |||||
| int64_t &tensor_size) { | |||||
| GELOGD("To calc tensor size by shape = [%s]", GeShape(shape).ToString().c_str()); | |||||
| uint32_t type_size; | |||||
| if (!TypeUtils::GetDataTypeLength(data_type, type_size)) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to get data type size"); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| tensor_size = type_size; | |||||
| for (const auto &dim : shape) { | |||||
| GE_CHECK_GE(dim, 0); | |||||
| GE_CHK_STATUS_RET(Int64MulCheckOverflow(tensor_size, dim), | |||||
| "Shape size overflow, shape = [%s]", | |||||
| GeShape(shape).ToString().c_str()); | |||||
| tensor_size *= dim; | |||||
| } | |||||
| GE_CHK_STATUS_RET(CheckInt64AddOverflow(tensor_size, kAlignment - 1), | |||||
| "Tensor size is too large: %ld, shape = [%s]", | |||||
| tensor_size, | |||||
| GeShape(shape).ToString().c_str()); | |||||
| tensor_size = (tensor_size + kAlignment - 1) / kAlignment * kAlignment; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ShapeInferenceEngine::CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range) { | |||||
| auto op_desc = node_item.GetOpDesc(); | |||||
| for (size_t output_index = 0; output_index < op_desc->GetOutputsSize(); ++output_index) { | |||||
| auto tensor_desc = op_desc->MutableOutputDesc(output_index); | |||||
| GE_CHECK_NOTNULL(tensor_desc); | |||||
| const auto &shape = tensor_desc->MutableShape(); | |||||
| // modify on copy | |||||
| auto dims = shape.GetDims(); | |||||
| GE_CHK_STATUS_RET(CanonicalizeShape(*tensor_desc, dims, fallback_with_range), | |||||
| "[%s] Failed to canonicalize shape for output %zu", | |||||
| node_item.NodeName().c_str(), | |||||
| output_index); | |||||
| int64_t tensor_size; | |||||
| GE_CHK_STATUS_RET(CalcTensorSize(tensor_desc->GetDataType(), dims, tensor_size), | |||||
| "[%s] Failed to calc tensor size for output %zu", | |||||
| node_item.NodeName().c_str(), | |||||
| output_index); | |||||
| GELOGD("[%s] Tensor size of output %zu = %ld", node_item.NodeName().c_str(), output_index, tensor_size); | |||||
| (void) TensorUtils::SetSize(*tensor_desc, tensor_size); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -34,11 +34,7 @@ class ShapeInferenceEngine { | |||||
| Status PropagateOutputShapes(const NodeItem &node_item); | Status PropagateOutputShapes(const NodeItem &node_item); | ||||
| static Status CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range = false); | |||||
| private: | private: | ||||
| static Status CanonicalizeShape(GeTensorDesc &tensor_desc, std::vector<int64_t> &shape, bool fallback_with_range); | |||||
| static Status CalcTensorSize(DataType data_type, const std::vector<int64_t> &shape, int64_t &tensor_size); | |||||
| static Status UpdatePeerNodeShape(const Node &node); | static Status UpdatePeerNodeShape(const Node &node); | ||||
| Status AwaitDependentNodes(NodeState &node_state); | Status AwaitDependentNodes(NodeState &node_state); | ||||
| @@ -26,9 +26,6 @@ Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext * | |||||
| RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); | RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); | ||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); | ||||
| if (context->ge_context != nullptr) { | |||||
| GetThreadLocalContext() = *context->ge_context; | |||||
| } | |||||
| shared_ptr<NodeTask> kernel_task; | shared_ptr<NodeTask> kernel_task; | ||||
| auto ret = node_item.node_executor->CompileTask(*context->model, node_item.node, kernel_task); | auto ret = node_item.node_executor->CompileTask(*context->model, node_item.node, kernel_task); | ||||
| RECORD_COMPILE_EVENT(context, node_state.GetName().c_str(), "[Compile] End"); | RECORD_COMPILE_EVENT(context, node_state.GetName().c_str(), "[Compile] End"); | ||||
| @@ -226,10 +226,7 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n | |||||
| new_node->node_id = node_index; | new_node->node_id = node_index; | ||||
| new_node->op_desc->SetId(node_index); | new_node->op_desc->SetId(node_index); | ||||
| node_index += 1; | node_index += 1; | ||||
| NodeExecutorManager::ExecutorType executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); | |||||
| new_node->is_profiling_report = (executor_type == NodeExecutorManager::ExecutorType::AICORE) || | |||||
| (executor_type == NodeExecutorManager::ExecutorType::AICPU_TF) || | |||||
| (executor_type == NodeExecutorManager::ExecutorType::AICPU_CUSTOM); | |||||
| *node_item = new_node.get(); | *node_item = new_node.get(); | ||||
| node_items[node] = std::move(new_node); | node_items[node] = std::move(new_node); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -22,7 +22,6 @@ | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| #include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
| #include "hybrid/executor/worker/shape_inference_engine.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| @@ -48,7 +47,7 @@ Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgr | |||||
| GE_CHECK_NOTNULL(dst_op_desc); | GE_CHECK_NOTNULL(dst_op_desc); | ||||
| auto in_idx = node_and_anchor.second->GetIdx(); | auto in_idx = node_and_anchor.second->GetIdx(); | ||||
| auto tensor_desc = dst_op_desc->MutableInputDesc(in_idx); | auto tensor_desc = dst_op_desc->MutableInputDesc(in_idx); | ||||
| fused_subgraph.input_mapping[static_cast<int>(parent_index)].emplace_back(tensor_desc); | |||||
| fused_subgraph.input_mapping[parent_index].emplace_back(tensor_desc); | |||||
| GELOGD("Input[%u] mapped to [%s:%u]", parent_index, dst_op_desc->GetName().c_str(), in_idx); | GELOGD("Input[%u] mapped to [%s:%u]", parent_index, dst_op_desc->GetName().c_str(), in_idx); | ||||
| } | } | ||||
| @@ -65,7 +64,7 @@ Status ParseOutputMapping(const OpDescPtr &op_desc, FusedSubgraph &fused_subgrap | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| fused_subgraph.output_mapping.emplace(static_cast<int>(parent_index), op_desc); | |||||
| fused_subgraph.output_mapping.emplace(parent_index, op_desc); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -127,7 +126,12 @@ Status NodeItem::Create(const NodePtr &node, std::unique_ptr<NodeItem> &node_ite | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void NodeItem::ResolveOptionalInputs() { | |||||
| Status NodeItem::Init() { | |||||
| GE_CHECK_LE(op_desc->GetInputsSize(), INT32_MAX); | |||||
| GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX); | |||||
| num_inputs = static_cast<int>(op_desc->GetInputsSize()); | |||||
| num_outputs = static_cast<int>(op_desc->GetOutputsSize()); | |||||
| if (op_desc->GetAllInputsSize() != op_desc->GetInputsSize()) { | if (op_desc->GetAllInputsSize() != op_desc->GetInputsSize()) { | ||||
| has_optional_inputs = true; | has_optional_inputs = true; | ||||
| for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) { | for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) { | ||||
| @@ -139,18 +143,7 @@ void NodeItem::ResolveOptionalInputs() { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| Status NodeItem::InitInputsAndOutputs() { | |||||
| GE_CHECK_LE(op_desc->GetInputsSize(), INT32_MAX); | |||||
| GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX); | |||||
| num_inputs = static_cast<int>(op_desc->GetInputsSize()); | |||||
| num_outputs = static_cast<int>(op_desc->GetOutputsSize()); | |||||
| ResolveOptionalInputs(); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status NodeItem::ResolveDynamicState() { | |||||
| (void) AttrUtils::GetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, is_dynamic); | (void) AttrUtils::GetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, is_dynamic); | ||||
| GELOGD("node name = %s, is_dynamic = %d.", this->node_name.c_str(), is_dynamic); | GELOGD("node name = %s, is_dynamic = %d.", this->node_name.c_str(), is_dynamic); | ||||
| if (!is_dynamic) { | if (!is_dynamic) { | ||||
| @@ -158,54 +151,38 @@ Status NodeItem::ResolveDynamicState() { | |||||
| "[%s] Failed to get shape status.", | "[%s] Failed to get shape status.", | ||||
| node->GetName().c_str()); | node->GetName().c_str()); | ||||
| } | } | ||||
| return SUCCESS; | |||||
| } | |||||
| Status NodeItem::ResolveStaticInputsAndOutputs() { | |||||
| for (int i = 0; i < num_inputs; ++i) { | |||||
| const auto &input_desc = MutableInputDesc(i); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| if (input_desc->MutableShape().IsUnknownShape()) { | |||||
| is_input_shape_static_.push_back(false); | |||||
| } else { | |||||
| num_static_input_shapes++; | |||||
| is_input_shape_static_.push_back(true); | |||||
| GELOGD("[%s] The shape of input[%d] is static. shape = [%s]", | |||||
| NodeName().c_str(), i, input_desc->MutableShape().ToString().c_str()); | |||||
| if (is_dynamic) { | |||||
| for (int i = 0; i < num_inputs; ++i) { | |||||
| const auto &input_desc = MutableInputDesc(i); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| if (input_desc->MutableShape().IsUnknownShape()) { | |||||
| is_input_shape_static_.push_back(false); | |||||
| } else { | |||||
| num_static_input_shapes++; | |||||
| is_input_shape_static_.push_back(true); | |||||
| GELOGD("[%s] The shape of input[%d] is static. shape = [%s]", | |||||
| NodeName().c_str(), i, input_desc->MutableShape().ToString().c_str()); | |||||
| } | |||||
| } | } | ||||
| } | |||||
| for (int i = 0; i < num_outputs; ++i) { | |||||
| const auto &output_desc = op_desc->MutableOutputDesc(i); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| if (output_desc->MutableShape().IsUnknownShape()) { | |||||
| is_output_shape_static = false; | |||||
| break; | |||||
| for (int i = 0; i < num_outputs; ++i) { | |||||
| const auto &output_desc = op_desc->MutableOutputDesc(i); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| if (output_desc->MutableShape().IsUnknownShape()) { | |||||
| is_output_shape_static = false; | |||||
| break; | |||||
| } | |||||
| } | } | ||||
| } | |||||
| if (is_output_shape_static) { | |||||
| GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(*this)); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void NodeItem::ResolveUnknownShapeType() { | |||||
| if (IsControlOp() || node_type == PARTITIONEDCALL) { | |||||
| shape_inference_type = DEPEND_COMPUTE; | |||||
| } else { | |||||
| int32_t unknown_shape_type_val = 0; | |||||
| (void) AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val); | |||||
| shape_inference_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val); | |||||
| } | |||||
| } | |||||
| if (IsControlOp() || node_type == PARTITIONEDCALL) { | |||||
| shape_inference_type = DEPEND_COMPUTE; | |||||
| } else { | |||||
| int32_t unknown_shape_type_val = 0; | |||||
| (void) AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val); | |||||
| shape_inference_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val); | |||||
| } | |||||
| Status NodeItem::Init() { | |||||
| GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs()); | |||||
| GE_CHK_STATUS_RET_NOLOG(ResolveDynamicState()); | |||||
| if (is_dynamic) { | |||||
| ResolveUnknownShapeType(); | |||||
| GE_CHK_STATUS_RET_NOLOG(ResolveStaticInputsAndOutputs()); | |||||
| GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), "[%s] Failed to parse fused subgraph", node_name.c_str()); | GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), "[%s] Failed to parse fused subgraph", node_name.c_str()); | ||||
| } | } | ||||
| @@ -99,16 +99,10 @@ struct NodeItem { | |||||
| std::map<int, int> reuse_inputs; | std::map<int, int> reuse_inputs; | ||||
| std::map<int, int> reuse_outputs; | std::map<int, int> reuse_outputs; | ||||
| int num_static_input_shapes = 0; | int num_static_input_shapes = 0; | ||||
| bool is_profiling_report = false; | |||||
| private: | private: | ||||
| explicit NodeItem(NodePtr node); | explicit NodeItem(NodePtr node); | ||||
| Status Init(); | Status Init(); | ||||
| Status InitInputsAndOutputs(); | |||||
| void ResolveOptionalInputs(); | |||||
| Status ResolveDynamicState(); | |||||
| Status ResolveStaticInputsAndOutputs(); | |||||
| void ResolveUnknownShapeType(); | |||||
| std::vector<bool> is_input_shape_static_; | std::vector<bool> is_input_shape_static_; | ||||
| std::vector<uint32_t> input_desc_indices_; | std::vector<uint32_t> input_desc_indices_; | ||||
| @@ -165,16 +165,6 @@ Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> | |||||
| } | } | ||||
| RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] Start"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] Start"); | ||||
| GE_CHK_STATUS_RET_NOLOG((*it)->LaunchKernel(context.GetStream())); | GE_CHK_STATUS_RET_NOLOG((*it)->LaunchKernel(context.GetStream())); | ||||
| uint32_t task_id = 0; | |||||
| uint32_t stream_id = 0; | |||||
| rtError_t rt_ret = rtGetTaskIdAndStreamID(&task_id, &stream_id); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(rt_ret, "Get task_id and stream_id failed."); | |||||
| return rt_ret; | |||||
| } | |||||
| context.SetTaskId(task_id); | |||||
| context.SetStreamId(stream_id); | |||||
| GELOGD("AiCore node[%s] task_id: %u, stream_id: %u.", context.GetNodeName(), task_id, stream_id); | |||||
| RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); | ||||
| RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); | ||||
| } | } | ||||
| @@ -189,17 +189,6 @@ Status AicpuNodeTaskBase::ExecuteAsync(TaskContext &context, std::function<void( | |||||
| GE_CHK_STATUS_RET(LaunchTask(context)); | GE_CHK_STATUS_RET(LaunchTask(context)); | ||||
| uint32_t task_id = 0; | |||||
| uint32_t stream_id = 0; | |||||
| rtError_t rt_ret = rtGetTaskIdAndStreamID(&task_id, &stream_id); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(rt_ret, "Get task_id and stream_id failed."); | |||||
| return rt_ret; | |||||
| } | |||||
| context.SetTaskId(task_id); | |||||
| context.SetStreamId(stream_id); | |||||
| GELOGD("AiCpu node[%s] task_id: %u, stream_id: %u.", context.GetNodeName(), task_id, stream_id); | |||||
| auto callback = [=, &context]() { | auto callback = [=, &context]() { | ||||
| GELOGD("Node[%s] callback start.", node_name_.c_str()); | GELOGD("Node[%s] callback start.", node_name_.c_str()); | ||||
| RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[TaskCallback] Start"); | RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[TaskCallback] Start"); | ||||
| @@ -148,10 +148,6 @@ Status TaskContext::AllocateWorkspaces() { | |||||
| } | } | ||||
| Status TaskContext::RegisterCallback(const std::function<void()> &callback_fun) const { | Status TaskContext::RegisterCallback(const std::function<void()> &callback_fun) const { | ||||
| if (callback_fun == nullptr) { | |||||
| GELOGW("[%s] Callback is NULL", GetNodeName()); | |||||
| return SUCCESS; | |||||
| } | |||||
| auto ret = execution_context_->callback_manager->RegisterCallback(callback_fun); | auto ret = execution_context_->callback_manager->RegisterCallback(callback_fun); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "[%s] Failed to register callback", GetNodeName()); | GELOGE(ret, "[%s] Failed to register callback", GetNodeName()); | ||||
| @@ -319,22 +315,6 @@ void TaskContext::SetStatus(Status status) { | |||||
| } | } | ||||
| } | } | ||||
| uint32_t TaskContext::GetTaskId() const { | |||||
| return task_id_; | |||||
| } | |||||
| void TaskContext::SetTaskId(uint32_t task_id) { | |||||
| task_id_ = task_id; | |||||
| } | |||||
| uint32_t TaskContext::GetStreamId() const { | |||||
| return stream_id_; | |||||
| } | |||||
| void TaskContext::SetStreamId(uint32_t stream_id) { | |||||
| stream_id_ = stream_id; | |||||
| } | |||||
| Status TaskContext::AllocateWorkspace(size_t size, void **buffer, void *ori_addr) { | Status TaskContext::AllocateWorkspace(size_t size, void **buffer, void *ori_addr) { | ||||
| GE_CHECK_NOTNULL(buffer); | GE_CHECK_NOTNULL(buffer); | ||||
| if (ori_addr == nullptr) { | if (ori_addr == nullptr) { | ||||
| @@ -404,20 +384,6 @@ const char *TaskContext::GetNodeName() const { | |||||
| return node_item_->NodeName().c_str(); | return node_item_->NodeName().c_str(); | ||||
| } | } | ||||
| void TaskContext::ReleaseInputsAndOutputs() { | |||||
| for (int i = 0; i < node_item_->num_inputs; ++i) { | |||||
| auto tensor = inputs_start_ + i; | |||||
| tensor->Destroy(); | |||||
| GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), i); | |||||
| } | |||||
| for (int i = 0; i < node_item_->num_outputs; ++i) { | |||||
| auto tensor = outputs_start_ + i; | |||||
| tensor->Destroy(); | |||||
| GELOGD("[%s] Tensor of output[%d] released", GetNodeName(), i); | |||||
| } | |||||
| } | |||||
| void TaskContext::ReleaseInput(int index) { | void TaskContext::ReleaseInput(int index) { | ||||
| auto input_tensor = MutableInput(index); | auto input_tensor = MutableInput(index); | ||||
| if (input_tensor != nullptr) { | if (input_tensor != nullptr) { | ||||
| @@ -490,9 +456,5 @@ Status TaskContext::TryExecuteCallback(const function<void()> &callback_fun) con | |||||
| const DumpProperties &TaskContext::GetDumpProperties() const { | const DumpProperties &TaskContext::GetDumpProperties() const { | ||||
| return execution_context_->dump_properties; | return execution_context_->dump_properties; | ||||
| } | } | ||||
| bool TaskContext::NeedCallback() { | |||||
| return node_item_->has_observer || IsDumpEnabled() || execution_context_->profiling_level > 0; | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -50,8 +50,6 @@ class TaskContext { | |||||
| ConstGeTensorDescPtr GetOutputDesc(int index) const; | ConstGeTensorDescPtr GetOutputDesc(int index) const; | ||||
| GeTensorDescPtr MutableInputDesc(int index) const; | GeTensorDescPtr MutableInputDesc(int index) const; | ||||
| GeTensorDescPtr MutableOutputDesc(int index) const; | GeTensorDescPtr MutableOutputDesc(int index) const; | ||||
| void ReleaseInputsAndOutputs(); | |||||
| bool NeedCallback(); | |||||
| void ReleaseInput(int index); | void ReleaseInput(int index); | ||||
| const TensorValue *GetInput(int index) const; | const TensorValue *GetInput(int index) const; | ||||
| const TensorValue *GetOutput(int index) const; | const TensorValue *GetOutput(int index) const; | ||||
| @@ -96,12 +94,6 @@ class TaskContext { | |||||
| void SetStatus(Status status); | void SetStatus(Status status); | ||||
| uint32_t GetTaskId() const; | |||||
| void SetTaskId(uint32_t task_id); | |||||
| uint32_t GetStreamId() const; | |||||
| void SetStreamId(uint32_t stream_id); | |||||
| bool IsForceInferShape() const; | bool IsForceInferShape() const; | ||||
| void SetForceInferShape(bool force_infer_shape); | void SetForceInferShape(bool force_infer_shape); | ||||
| void *handle_ = nullptr; | void *handle_ = nullptr; | ||||
| @@ -123,8 +115,6 @@ class TaskContext { | |||||
| Status status_ = SUCCESS; | Status status_ = SUCCESS; | ||||
| std::vector<void *> workspaces_; | std::vector<void *> workspaces_; | ||||
| uint64_t iteration_ = 0; | uint64_t iteration_ = 0; | ||||
| uint32_t task_id_= 0; | |||||
| uint32_t stream_id_ = 0; | |||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -63,19 +63,18 @@ vector<string> SplitInputShape(const std::string &input_shape) { | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| Status CheckInputFormat(const string &input_format) { | |||||
| Status CheckInputFormat(const std::string &input_format) { | |||||
| if (input_format.empty()) { | if (input_format.empty()) { | ||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| if (!ge::TypeUtils::IsFormatValid(input_format.c_str())) { | if (!ge::TypeUtils::IsFormatValid(input_format.c_str())) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage( | ErrorManager::GetInstance().ATCReportErrMessage( | ||||
| "E10001", {"parameter", "value", "reason"}, {"--input_format", input_format, "input format is invalid!"}); | |||||
| GELOGE(ge::PARAM_INVALID, "input format [%s] is invalid!", input_format.c_str()); | |||||
| "E10001", {"parameter", "value", "reason"}, {"--input_format", input_format, "input format not found"}); | |||||
| GELOGE(ge::PARAM_INVALID, "user input format [%s] is not found!", input_format.c_str()); | |||||
| return ge::PARAM_INVALID; | return ge::PARAM_INVALID; | ||||
| } | } | ||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| bool CheckDynamicBatchSizeInputShapeValid(unordered_map<string, vector<int64_t>> shape_map, | bool CheckDynamicBatchSizeInputShapeValid(unordered_map<string, vector<int64_t>> shape_map, | ||||
| std::string &dynamic_batch_size) { | std::string &dynamic_batch_size) { | ||||
| int32_t size = 0; | int32_t size = 0; | ||||
| @@ -75,7 +75,7 @@ Status CheckInsertOpConfParamValid(const std::string insert_op_conf); | |||||
| Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory); | Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory); | ||||
| Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream); | Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream); | ||||
| Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode); | Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode); | ||||
| Status CheckInputFormat(const string &input_format); | |||||
| Status CheckInputFormat(const std::string &input_format); | |||||
| void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); | void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); | ||||
| void EraseEndSemicolon(std::string ¶m); | void EraseEndSemicolon(std::string ¶m); | ||||
| } | } | ||||
| @@ -305,7 +305,7 @@ class GFlagUtils { | |||||
| " --debug_dir Set the save path of operator compilation intermediate files.\n" | " --debug_dir Set the save path of operator compilation intermediate files.\n" | ||||
| "Default value: ./kernel_meta\n" | "Default value: ./kernel_meta\n" | ||||
| " --op_compiler_cache_dir Set the save path of operator compilation cache files.\n" | " --op_compiler_cache_dir Set the save path of operator compilation cache files.\n" | ||||
| "Default value: $HOME/atc_data\n" | |||||
| "Default value: $HOME/atc_data/kernel_cache\n" | |||||
| " --op_compiler_cache_mode Set the operator compilation cache mode." | " --op_compiler_cache_mode Set the operator compilation cache mode." | ||||
| "Options are disable(default), enable and force(force to refresh the cache)"); | "Options are disable(default), enable and force(force to refresh the cache)"); | ||||
| @@ -15,7 +15,6 @@ message Output { | |||||
| int32 original_output_data_type = 7; | int32 original_output_data_type = 7; | ||||
| int32 original_output_format = 8; | int32 original_output_format = 8; | ||||
| uint64 size = 9; | uint64 size = 9; | ||||
| Shape origin_shape = 10; | |||||
| } | } | ||||
| message Input { | message Input { | ||||
| @@ -24,7 +23,6 @@ message Input { | |||||
| Shape shape = 3; | Shape shape = 3; | ||||
| uint64 address = 4; | uint64 address = 4; | ||||
| uint64 size = 5; | uint64 size = 5; | ||||
| Shape origin_shape = 6; | |||||
| } | } | ||||
| enum BufferType { | enum BufferType { | ||||
| @@ -32,16 +32,14 @@ namespace ge { | |||||
| namespace { | namespace { | ||||
| const size_t kDataMemAlignSize = 32; | const size_t kDataMemAlignSize = 32; | ||||
| const size_t kDataMemAlignUnit = 2; | const size_t kDataMemAlignUnit = 2; | ||||
| const string kShapeTypeDynamic = "dynamic"; | |||||
| const string kShapeTypeStatic = "static"; | |||||
| size_t GetAlignedSize(size_t size) { | size_t GetAlignedSize(size_t size) { | ||||
| size_t aligned_size = (size + kDataMemAlignUnit * kDataMemAlignSize - 1) / kDataMemAlignSize * kDataMemAlignSize; | size_t aligned_size = (size + kDataMemAlignUnit * kDataMemAlignSize - 1) / kDataMemAlignSize * kDataMemAlignSize; | ||||
| return aligned_size; | return aligned_size; | ||||
| } | } | ||||
| Status ProfilingTaskInfo(OpTask *op_task, const string &shape_type) { | |||||
| if (!ProfilingManager::Instance().ProfilingModelLoadOn()) { | |||||
| Status ProfilingTaskInfo(OpTask *op_task) { | |||||
| if (!ProfilingManager::Instance().ProfilingModelExecuteOn()) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -68,8 +66,6 @@ Status ProfilingTaskInfo(OpTask *op_task, const string &shape_type) { | |||||
| tmp_task_desc_info.block_dim = block_dim; | tmp_task_desc_info.block_dim = block_dim; | ||||
| tmp_task_desc_info.task_id = task_id; | tmp_task_desc_info.task_id = task_id; | ||||
| tmp_task_desc_info.stream_id = stream_id; | tmp_task_desc_info.stream_id = stream_id; | ||||
| tmp_task_desc_info.shape_type = shape_type; | |||||
| tmp_task_desc_info.cur_iter_num = 0; | |||||
| GELOGD("GetTaskDescInfo of op [%s] end, task_id[%u], stream_id[%u]", op_name.c_str(), task_id, stream_id); | GELOGD("GetTaskDescInfo of op [%s] end, task_id[%u], stream_id[%u]", op_name.c_str(), task_id, stream_id); | ||||
| task_desc_info.emplace_back(tmp_task_desc_info); | task_desc_info.emplace_back(tmp_task_desc_info); | ||||
| @@ -197,7 +193,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOp::ExecuteAsync(c | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| GE_CHK_STATUS_RET_NOLOG(ProfilingTaskInfo(task, kShapeTypeStatic)); | |||||
| GE_CHK_STATUS_RET_NOLOG(ProfilingTaskInfo(task)); | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| @@ -259,7 +255,7 @@ Status DynamicSingleOp::ExecuteAsync(const vector<GeTensorDesc> &input_desc, | |||||
| std::lock_guard<std::mutex> lk(*stream_mutex_); | std::lock_guard<std::mutex> lk(*stream_mutex_); | ||||
| GE_CHK_STATUS_RET_NOLOG(op_task_->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_)); | GE_CHK_STATUS_RET_NOLOG(op_task_->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_)); | ||||
| GE_CHK_STATUS_RET_NOLOG(ProfilingTaskInfo(op_task_.get(), kShapeTypeDynamic)); | |||||
| GE_CHK_STATUS_RET_NOLOG(ProfilingTaskInfo(op_task_.get())); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -119,11 +119,11 @@ Status OpTask::DoUpdateArgTable(const SingleOpModelParam ¶m, bool keep_works | |||||
| uintptr_t *arg_base = nullptr; | uintptr_t *arg_base = nullptr; | ||||
| size_t arg_num = 0; | size_t arg_num = 0; | ||||
| GetIoAddr(arg_base, arg_num); | GetIoAddr(arg_base, arg_num); | ||||
| if (arg_num < all_addresses.size()) { | |||||
| GELOGE(INTERNAL_ERROR, "[%s] arg number mismatches, expect at least = %zu, but got = %zu", | |||||
| if (arg_num != all_addresses.size()) { | |||||
| GELOGE(INTERNAL_ERROR, "[%s] arg number mismatches, expect = %zu, but got = %zu", | |||||
| op_desc_->GetName().c_str(), | op_desc_->GetName().c_str(), | ||||
| all_addresses.size(), | |||||
| arg_num); | |||||
| arg_num, | |||||
| all_addresses.size()); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| @@ -293,7 +293,6 @@ const std::string MDL_BANK_PATH_FLAG = "ge.mdl_bank_path"; | |||||
| // Configure op bank path | // Configure op bank path | ||||
| const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path"; | const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path"; | ||||
| const std::string OP_BANK_UPDATE_FLAG = "ge.op_bank_update"; | |||||
| // Graph run mode | // Graph run mode | ||||
| enum GraphRunMode { PREDICTION = 0, TRAIN }; | enum GraphRunMode { PREDICTION = 0, TRAIN }; | ||||
| @@ -367,7 +366,6 @@ static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR; | |||||
| static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE; | static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE; | ||||
| static const char *const MDL_BANK_PATH = ge::MDL_BANK_PATH_FLAG.c_str(); | static const char *const MDL_BANK_PATH = ge::MDL_BANK_PATH_FLAG.c_str(); | ||||
| static const char *const OP_BANK_PATH = ge::OP_BANK_PATH_FLAG.c_str(); | static const char *const OP_BANK_PATH = ge::OP_BANK_PATH_FLAG.c_str(); | ||||
| static const char *const OP_BANK_UPDATE = ge::OP_BANK_UPDATE_FLAG.c_str(); | |||||
| static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str(); | static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str(); | ||||
| // for interface: aclgrphBuildModel | // for interface: aclgrphBuildModel | ||||
| @@ -391,13 +389,22 @@ const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, | |||||
| OP_COMPILER_CACHE_DIR, | OP_COMPILER_CACHE_DIR, | ||||
| OP_COMPILER_CACHE_MODE, | OP_COMPILER_CACHE_MODE, | ||||
| MDL_BANK_PATH, | MDL_BANK_PATH, | ||||
| OP_BANK_PATH, | |||||
| OP_BANK_UPDATE}; | |||||
| OP_BANK_PATH}; | |||||
| // for interface: aclgrphParse | // for interface: aclgrphParse | ||||
| const std::set<std::string> ir_parser_suppported_options = { | |||||
| INPUT_FP16_NODES, IS_INPUT_ADJUST_HW_LAYOUT, IS_OUTPUT_ADJUST_HW_LAYOUT, OUTPUT, | |||||
| OUT_NODES, COMPRESS_WEIGHT_CONF, ENABLE_SCOPE_FUSION_PASSES}; | |||||
| const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT, | |||||
| INPUT_SHAPE, | |||||
| OP_NAME_MAP, | |||||
| IS_DYNAMIC_INPUT, | |||||
| INPUT_FP16_NODES, | |||||
| IS_INPUT_ADJUST_HW_LAYOUT, | |||||
| IS_OUTPUT_ADJUST_HW_LAYOUT, | |||||
| OUTPUT, | |||||
| OUTPUT_TYPE, | |||||
| OUT_NODES, | |||||
| COMPRESS_WEIGHT_CONF, | |||||
| ENABLE_SCOPE_FUSION_PASSES, | |||||
| LOG_LEVEL}; | |||||
| // for interface: aclgrphBuildInitialize | // for interface: aclgrphBuildInitialize | ||||
| const std::set<std::string> global_options = {CORE_TYPE, | const std::set<std::string> global_options = {CORE_TYPE, | ||||
| @@ -37,9 +37,7 @@ enum FrameworkType { | |||||
| MINDSPORE = 1, | MINDSPORE = 1, | ||||
| TENSORFLOW = 3, | TENSORFLOW = 3, | ||||
| ANDROID_NN, | ANDROID_NN, | ||||
| #ifndef ONLY_COMPILE_OPEN_SRC | |||||
| ONNX, | ONNX, | ||||
| #endif | |||||
| FRAMEWORK_RESERVED, | FRAMEWORK_RESERVED, | ||||
| }; | }; | ||||
| @@ -248,8 +246,6 @@ struct TaskDescInfo { | |||||
| uint32_t block_dim; | uint32_t block_dim; | ||||
| uint32_t task_id; | uint32_t task_id; | ||||
| uint32_t stream_id; | uint32_t stream_id; | ||||
| std::string shape_type; | |||||
| int64_t cur_iter_num; | |||||
| }; | }; | ||||
| // Profiling info of graph | // Profiling info of graph | ||||
| @@ -30,6 +30,8 @@ | |||||
| #include "runtime/base.h" | #include "runtime/base.h" | ||||
| namespace ge { | namespace ge { | ||||
| class ModelListenerAdapter; | |||||
| class SingleOp; | class SingleOp; | ||||
| class DynamicSingleOp; | class DynamicSingleOp; | ||||
| @@ -53,8 +55,14 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||||
| ge::Status Initialize(); | ge::Status Initialize(); | ||||
| ge::Status Finalize(); | ge::Status Finalize(); | ||||
| // Load model | |||||
| ge::Status LoadModelOffline(uint32_t &model_id, const std::string &path, const std::string &key, int32_t priority, | |||||
| std::shared_ptr<ge::ModelListener> listener); | |||||
| ge::Status UnloadModel(uint32_t modelId); | ge::Status UnloadModel(uint32_t modelId); | ||||
| ge::Status RunModel(const ge::RunModelData &input_data, ge::RunModelData &output_data); | |||||
| // Get input and output descriptor | // Get input and output descriptor | ||||
| ge::Status GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | ge::Status GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | ||||
| std::vector<ge::TensorDesc> &output_desc, bool new_model_desc = false); | std::vector<ge::TensorDesc> &output_desc, bool new_model_desc = false); | ||||
| @@ -160,6 +168,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||||
| ge::Status GetModelDescInfoForZeroCopy(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | ge::Status GetModelDescInfoForZeroCopy(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | ||||
| std::vector<ge::TensorDesc> &output_desc); | std::vector<ge::TensorDesc> &output_desc); | ||||
| ge::Status LoadModel(uint32_t &model_id, const ge::ModelData &model_data, | |||||
| std::shared_ptr<ge::ModelListener> listener); | |||||
| ge::Status CommandHandle(const ge::Command &command); | ge::Status CommandHandle(const ge::Command &command); | ||||
| ge::Status SetDump(const DumpConfig &dump_config); | ge::Status SetDump(const DumpConfig &dump_config); | ||||
| @@ -286,6 +297,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||||
| private: | private: | ||||
| static bool isInit_; | static bool isInit_; | ||||
| }; | }; | ||||
| ge::Status ModelInfoParser(const ge::ModelData &model, ge::ModelInfo &model_info); | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_FRAMEWORK_EXECUTOR_GE_EXECUTOR_H_ | #endif // INC_FRAMEWORK_EXECUTOR_GE_EXECUTOR_H_ | ||||
| @@ -36,7 +36,7 @@ using Status = domi::Status; | |||||
| namespace domi { | namespace domi { | ||||
| using GetGraphCallback = std::function<std::unique_ptr<google::protobuf::Message>( | using GetGraphCallback = std::function<std::unique_ptr<google::protobuf::Message>( | ||||
| const google::protobuf::Message *root_proto, const std::string &graph)>; | |||||
| const google::protobuf::Message *root_proto, const std::string &graph)>; | |||||
| class ModelParser { | class ModelParser { | ||||
| public: | public: | ||||
| ModelParser() {} | ModelParser() {} | ||||
| @@ -44,20 +44,19 @@ class ModelParser { | |||||
| virtual ~ModelParser() {} | virtual ~ModelParser() {} | ||||
| /** | /** | ||||
| * @ingroup domi_omg | |||||
| * @brief Analyze network model data | |||||
| * @param [in] file Network model file path | |||||
| * @param [in|out] graph Save the network information after analysis | |||||
| * @return SUCCESS | |||||
| * @return Others failed | |||||
| */ | |||||
| * @ingroup domi_omg | |||||
| * @brief Analyze network model data | |||||
| * @param [in] file Network model file path | |||||
| * @param [in|out] graph Save the network information after analysis | |||||
| * @return SUCCESS | |||||
| * @return Others failed | |||||
| */ | |||||
| virtual Status Parse(const char *file, ge::Graph &graph) = 0; | virtual Status Parse(const char *file, ge::Graph &graph) = 0; | ||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| * @brief Parse relevant data from memory and save it to graph | * @brief Parse relevant data from memory and save it to graph | ||||
| * @param [in] input Model file memory data | * @param [in] input Model file memory data | ||||
| * @param [in] input Model file memory size | |||||
| * @param [in|out] graph A graph for saving the model information after analysis | * @param [in|out] graph A graph for saving the model information after analysis | ||||
| * @return SUCCESS | * @return SUCCESS | ||||
| * @return FAILED | * @return FAILED | ||||
| @@ -65,7 +64,6 @@ class ModelParser { | |||||
| */ | */ | ||||
| virtual Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) = 0; | virtual Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) = 0; | ||||
| #ifndef ONLY_COMPILE_OPEN_SRC | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| * @brief Parse relevant data from memory and save it to graph | * @brief Parse relevant data from memory and save it to graph | ||||
| @@ -77,37 +75,37 @@ class ModelParser { | |||||
| * @author | * @author | ||||
| */ | */ | ||||
| virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) = 0; | virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) = 0; | ||||
| #endif | |||||
| /** | /** | ||||
| * @ingroup domi_omg | |||||
| * @brief Analyze network model data | |||||
| * @param [in] proto network model | |||||
| * @param [in|out] graph Save the network information after analysis | |||||
| * @return SUCCESS | |||||
| * @return Others failed | |||||
| */ | |||||
| * @ingroup domi_omg | |||||
| * @brief Analyze network model data | |||||
| * @param [in] proto network model | |||||
| * @param [in|out] graph Save the network information after analysis | |||||
| * @return SUCCESS | |||||
| * @return Others failed | |||||
| */ | |||||
| virtual Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) = 0; | virtual Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) = 0; | ||||
| /** | /** | ||||
| * @ingroup domi_omg | |||||
| * @brief Analyze callback model data in subgraph | |||||
| * @param [in] proto network model | |||||
| * @param [in] callback callback of subgraph | |||||
| * @param [in|out] graph Save the network information after analysis | |||||
| * @return SUCCESS | |||||
| * @return Others failed | |||||
| */ | |||||
| virtual Status ParseProtoWithSubgraph(const google::protobuf::Message *proto, GetGraphCallback callback, | |||||
| * @ingroup domi_omg | |||||
| * @brief Analyze callback model data in subgraph | |||||
| * @param [in] proto network model | |||||
| * @param [in] callback callback of subgraph | |||||
| * @param [in|out] graph Save the network information after analysis | |||||
| * @return SUCCESS | |||||
| * @return Others failed | |||||
| */ | |||||
| virtual Status ParseProtoWithSubgraph(const google::protobuf::Message *proto, | |||||
| GetGraphCallback callback, | |||||
| ge::ComputeGraphPtr &graph) = 0; | ge::ComputeGraphPtr &graph) = 0; | ||||
| /** | /** | ||||
| * @ingroup domi_omg | |||||
| * @brief Convert model files to JSON format | |||||
| * @param [in] model_file Model file path to be converted | |||||
| * @param [out] json_file Converted JSON file path | |||||
| * @return SUCCESS | |||||
| * @return Others failed | |||||
| */ | |||||
| * @ingroup domi_omg | |||||
| * @brief Convert model files to JSON format | |||||
| * @param [in] model_file Model file path to be converted | |||||
| * @param [out] json_file Converted JSON file path | |||||
| * @return SUCCESS | |||||
| * @return Others failed | |||||
| */ | |||||
| virtual Status ToJson(const char *model_file, const char *json_file) { return domi::SUCCESS; } | virtual Status ToJson(const char *model_file, const char *json_file) { return domi::SUCCESS; } | ||||
| /* | /* | ||||
| @@ -1 +1 @@ | |||||
| Subproject commit 8c89c521f5d682327b2f975cf06f7093960eb2f0 | |||||
| Subproject commit 5a1b0ab95e2d205ee9ee578ac4bcde4f4fbed6d8 | |||||
| @@ -1 +1 @@ | |||||
| Subproject commit 54ec7731e3a2951191693e02ff3165220975ed0c | |||||
| Subproject commit 77dc42c383e416ed4a0f606ddc3c02cdaa082ac3 | |||||
| @@ -384,8 +384,3 @@ rtError_t rtModelExit(rtModel_t model, rtStream_t stream) | |||||
| { | { | ||||
| return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
| } | } | ||||
| rtError_t rtGetTaskIdAndStreamID(uint32_t *taskId, uint32_t *streamId) | |||||
| { | |||||
| return RT_ERROR_NONE; | |||||
| } | |||||
| @@ -61,67 +61,58 @@ set(UT_FILES | |||||
| ) | ) | ||||
| set(SRC_FILES | set(SRC_FILES | ||||
| "${GE_CODE_DIR}/metadef/graph/option/ge_local_context.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/option/ge_context.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/anchor.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/ge_attr_value.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/attr_value.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/buffer.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/compute_graph.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/ge_attr_define.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/graph.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/gnode.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/ascend_string.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/model.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/model_serialize.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/node.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/op_desc.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/operator.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/operator_factory.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/tensor.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/format_refiner.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/inference_context.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/tensor_utils.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/option/ge_local_context.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/option/ge_context.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/anchor.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/ge_attr_value.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/attr_value.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/buffer.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/compute_graph.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/ge_attr_define.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/graph.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/gnode.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/ascend_string.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/model.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/model_serialize.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/node.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/op_desc.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/operator.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/operator_reg.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/operator_factory.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/range_vistor.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/tensor.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/format_refiner.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/inference_context.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" | |||||
| #"${GE_CODE_DIR}/metadef/graph/utils/tensor_utils.cc" | |||||
| "${GE_CODE_DIR}/metadef/ops/op_imp.cpp" | "${GE_CODE_DIR}/metadef/ops/op_imp.cpp" | ||||
| "${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/transformer_utils.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/runtime_inference_context.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/ref_relation.cc" | |||||
| "${GE_CODE_DIR}/metadef/third_party/transformer/src/transfer_shape_according_to_format.cpp" | |||||
| "${GE_CODE_DIR}/metadef/third_party/transformer/src/axis_util.cpp" | |||||
| #"${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc" | |||||
| ) | ) | ||||
| #add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | #add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | ||||
| add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | ||||
| target_compile_options(ut_libgraph PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| ) | |||||
| target_compile_definitions(ut_libgraph PRIVATE | target_compile_definitions(ut_libgraph PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_link_libraries(ut_libgraph | target_link_libraries(ut_libgraph | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| graph | |||||
| gtest | gtest | ||||
| gtest_main | gtest_main | ||||
| slog_stub | slog_stub | ||||
| ascend_protobuf | ascend_protobuf | ||||
| c_sec | c_sec | ||||
| error_manager_stub | |||||
| mmpa_stub | |||||
| -lrt | -lrt | ||||
| -ldl | -ldl | ||||
| -lgcov | |||||
| ) | ) | ||||
| @@ -245,8 +245,6 @@ set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/model/ge_model.cc" | "${GE_CODE_DIR}/ge/model/ge_model.cc" | ||||
| "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" | "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/load/new_model_manager/model_utils.cc" | "${GE_CODE_DIR}/ge/graph/load/new_model_manager/model_utils.cc" | ||||
| @@ -477,8 +475,6 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc" | ||||
| @@ -487,7 +483,7 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/common/transop_util.cc" | "${GE_CODE_DIR}/ge/graph/common/transop_util.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc" | ||||
| #"${GE_CODE_DIR}/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/folding_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/folding_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" | ||||
| @@ -675,13 +671,13 @@ set(MULTI_PARTS_TEST_FILES | |||||
| ) | ) | ||||
| set(SINGLE_OP_TEST_FILES | set(SINGLE_OP_TEST_FILES | ||||
| #"single_op/single_op_model_unittest.cc" | |||||
| "single_op/single_op_model_unittest.cc" | |||||
| "single_op/single_op_manager_unittest.cc" | "single_op/single_op_manager_unittest.cc" | ||||
| "single_op/stream_resource_unittest.cc" | "single_op/stream_resource_unittest.cc" | ||||
| ) | ) | ||||
| set(PROFILING_MNG_TEST_FILES | set(PROFILING_MNG_TEST_FILES | ||||
| #"profiling/ge_profiling_manager_unittest.cc" | |||||
| "profiling/ge_profiling_manager_unittest.cc" | |||||
| ) | ) | ||||
| set(OTHERS_TEST_FILES | set(OTHERS_TEST_FILES | ||||
| @@ -848,17 +844,13 @@ add_executable(ut_libge_multiparts_utest | |||||
| ${MULTI_PARTS_TEST_FILES} | ${MULTI_PARTS_TEST_FILES} | ||||
| ) | ) | ||||
| target_compile_options(ut_libge_multiparts_utest PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| ) | |||||
| target_compile_definitions(ut_libge_multiparts_utest PRIVATE | target_compile_definitions(ut_libge_multiparts_utest PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_link_libraries(ut_libge_multiparts_utest | target_link_libraries(ut_libge_multiparts_utest | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ge_build_common ge_load_common ge_execute_common ge_optimize_common ge_partition_common ge_prepare_common ge_single_op ge_ut_common gtest gtest_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov | |||||
| ge_build_common ge_load_common ge_execute_common ge_optimize_common ge_partition_common ge_prepare_common ge_single_op ge_ut_common gtest gtest_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl | |||||
| ) | ) | ||||
| # libge_others_utest | # libge_others_utest | ||||
| @@ -869,14 +861,9 @@ add_executable(ut_libge_others_utest | |||||
| ${EXECUTE_TEST_FILES} | ${EXECUTE_TEST_FILES} | ||||
| ${OTHERS_TEST_FILES} | ${OTHERS_TEST_FILES} | ||||
| ) | ) | ||||
| target_compile_options(ut_libge_others_utest PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| ) | |||||
| target_link_libraries(ut_libge_others_utest | target_link_libraries(ut_libge_others_utest | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ge_load_common ge_execute_common ge_ut_common gtest gtest_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov | |||||
| ge_load_common ge_execute_common ge_ut_common gtest gtest_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl | |||||
| ) | ) | ||||
| # libge_kernel_utest | # libge_kernel_utest | ||||
| @@ -886,14 +873,9 @@ add_executable(ut_libge_kernel_utest | |||||
| ${KERNEL_TEST_FILES} | ${KERNEL_TEST_FILES} | ||||
| ${KERNEL_SRC_FILES} | ${KERNEL_SRC_FILES} | ||||
| ) | ) | ||||
| target_compile_options(ut_libge_kernel_utest PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| ) | |||||
| target_link_libraries(ut_libge_kernel_utest | target_link_libraries(ut_libge_kernel_utest | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ge_load_common ge_ut_common gtest gtest_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov | |||||
| ge_load_common ge_ut_common gtest gtest_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl | |||||
| ) | ) | ||||
| # libge_distinct_load_utest | # libge_distinct_load_utest | ||||
| @@ -905,10 +887,6 @@ add_executable(ut_libge_distinct_load_utest | |||||
| ${PROFILING_MNG_TEST_FILES} | ${PROFILING_MNG_TEST_FILES} | ||||
| ) | ) | ||||
| target_compile_options(ut_libge_distinct_load_utest PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| ) | |||||
| target_compile_definitions(ut_libge_distinct_load_utest PRIVATE | target_compile_definitions(ut_libge_distinct_load_utest PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| @@ -919,5 +897,5 @@ target_link_libraries(ut_libge_distinct_load_utest | |||||
| ge_execute_common ge_ut_common_format ge_load_common | ge_execute_common ge_ut_common_format ge_load_common | ||||
| ge_single_op ge_prepare_common | ge_single_op ge_prepare_common | ||||
| ge_optimize_common ge_build_common ge_partition_common ge_ut_common | ge_optimize_common ge_build_common ge_partition_common ge_ut_common | ||||
| gtest gtest_main ascend_protobuf json c_sec -lrt -ldl -lpthread -lgcov | |||||
| gtest gtest_main ascend_protobuf json c_sec -lrt -ldl -lpthread | |||||
| ) | ) | ||||
| @@ -147,7 +147,6 @@ class UtestMemoryAssignerTest : public testing::Test { | |||||
| void TearDown() { GetContext().out_nodes_map.clear(); } | void TearDown() { GetContext().out_nodes_map.clear(); } | ||||
| }; | }; | ||||
| /* | |||||
| TEST_F(UtestMemoryAssignerTest, MemoryBlock_Resize_RealSizeList_is_empty) { | TEST_F(UtestMemoryAssignerTest, MemoryBlock_Resize_RealSizeList_is_empty) { | ||||
| ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>(""); | ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>(""); | ||||
| ge::OpDescPtr op_def_a = createOpWithWsSize("A", 6000); | ge::OpDescPtr op_def_a = createOpWithWsSize("A", 6000); | ||||
| @@ -161,7 +160,6 @@ TEST_F(UtestMemoryAssignerTest, MemoryBlock_Resize_RealSizeList_is_empty) { | |||||
| delete memory_block; | delete memory_block; | ||||
| } | } | ||||
| */ | |||||
| namespace ge { | namespace ge { | ||||
| @@ -52,6 +52,7 @@ | |||||
| using namespace testing; | using namespace testing; | ||||
| using namespace ge; | using namespace ge; | ||||
| using namespace cce; | |||||
| using namespace ge::test; | using namespace ge::test; | ||||
| #define TEST_OPERATOR(op_, input_shapes, output_shapes) \ | #define TEST_OPERATOR(op_, input_shapes, output_shapes) \ | ||||
| @@ -52,6 +52,7 @@ | |||||
| using namespace testing; | using namespace testing; | ||||
| using namespace ge; | using namespace ge; | ||||
| using namespace cce; | |||||
| class UtestBroadcastGradientArgsKernel : public testing::Test { | class UtestBroadcastGradientArgsKernel : public testing::Test { | ||||
| protected: | protected: | ||||
| @@ -53,6 +53,7 @@ | |||||
| using namespace testing; | using namespace testing; | ||||
| using namespace ge; | using namespace ge; | ||||
| using namespace cce; | |||||
| using namespace ge::test; | using namespace ge::test; | ||||
| class UtestEmptyKernel : public testing::Test { | class UtestEmptyKernel : public testing::Test { | ||||
| @@ -38,7 +38,6 @@ | |||||
| #include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| #include "graph_builder_utils.h" | #include "graph_builder_utils.h" | ||||
| #include "cce/dnn.h" | |||||
| #include "cce/dnn_struct_base.hpp" | #include "cce/dnn_struct_base.hpp" | ||||
| #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" | #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" | ||||
| #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" | #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" | ||||
| @@ -84,7 +84,7 @@ TEST(UtestGeOperatorFactory, register_func) { | |||||
| status = OperatorFactoryImpl::RegisterVerifyFunc("ABC", nullptr); | status = OperatorFactoryImpl::RegisterVerifyFunc("ABC", nullptr); | ||||
| EXPECT_EQ(GRAPH_SUCCESS, status); | EXPECT_EQ(GRAPH_SUCCESS, status); | ||||
| } | } | ||||
| /* | |||||
| TEST(UtestGeOperatorFactory, get_ops_type_list_fail) { | TEST(UtestGeOperatorFactory, get_ops_type_list_fail) { | ||||
| auto operator_creators_temp = OperatorFactoryImpl::operator_creators_; | auto operator_creators_temp = OperatorFactoryImpl::operator_creators_; | ||||
| OperatorFactoryImpl::operator_creators_ = nullptr; | OperatorFactoryImpl::operator_creators_ = nullptr; | ||||
| @@ -92,5 +92,4 @@ TEST(UtestGeOperatorFactory, get_ops_type_list_fail) { | |||||
| graphStatus status = OperatorFactoryImpl::GetOpsTypeList(all_ops); | graphStatus status = OperatorFactoryImpl::GetOpsTypeList(all_ops); | ||||
| EXPECT_EQ(GRAPH_FAILED, status); | EXPECT_EQ(GRAPH_FAILED, status); | ||||
| OperatorFactoryImpl::operator_creators_ = operator_creators_temp; | OperatorFactoryImpl::operator_creators_ = operator_creators_temp; | ||||
| } | |||||
| */ | |||||
| } | |||||