Browse Source

Pre Merge pull request !645 from yanghaoran/sync_dev_1217

pull/645/MERGE
yanghaoran Gitee 4 years ago
parent
commit
bc7dfe38de
100 changed files with 2571 additions and 2053 deletions
  1. +14
    -15
      CMakeLists.txt
  2. +1
    -0
      cmake/external_libs/gflags.cmake
  3. +6
    -2
      cmake/external_libs/gtest.cmake
  4. +9
    -4
      cmake/external_libs/json.cmake
  5. +5
    -1
      cmake/external_libs/onnx.cmake
  6. +1
    -0
      cmake/external_libs/protobuf_shared.cmake
  7. +1
    -0
      cmake/external_libs/protobuf_static.cmake
  8. +116
    -115
      cmake/external_libs/protoc.cmake
  9. +11
    -2
      cmake/external_libs/securec.cmake
  10. +5
    -9
      ge/CMakeLists.txt
  11. +1
    -1
      ge/client/ge_api.cc
  12. +0
    -369
      ge/client/ge_prof.cc
  13. +6
    -5
      ge/client/module.mk
  14. +1
    -0
      ge/common/CMakeLists.txt
  15. +61
    -2
      ge/common/auth/file_saver.cc
  16. +7
    -0
      ge/common/auth/file_saver.h
  17. +30
    -21
      ge/common/base64.h
  18. +2
    -1
      ge/common/debug/memory_dumper.cc
  19. +43
    -25
      ge/common/formats/format_transfers/format_transfer_fractal_nz.cc
  20. +52
    -35
      ge/common/formats/format_transfers/format_transfer_fractal_zz.cc
  21. +0
    -1
      ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc
  22. +13
    -12
      ge/common/formats/format_transfers/format_transfer_transpose.cc
  23. +9
    -0
      ge/common/formats/utils/formats_definitions.h
  24. +7
    -1
      ge/common/ge/plugin_manager.cc
  25. +3
    -1
      ge/common/ge/tbe_plugin_manager.cc
  26. +1
    -0
      ge/common/ge_common.mk
  27. +396
    -31
      ge/common/helper/model_helper.cc
  28. +195
    -1
      ge/common/helper/om_file_helper.cc
  29. +1
    -1
      ge/common/op/ge_op_utils.cc
  30. +199
    -0
      ge/common/profiling/ge_profiling.cc
  31. +26
    -0
      ge/common/profiling/ge_runner_profiling.cc
  32. +232
    -406
      ge/common/profiling/profiling_manager.cc
  33. +51
    -69
      ge/common/profiling/profiling_manager.h
  34. +1
    -1
      ge/common/types.cc
  35. +43
    -40
      ge/common/util.cc
  36. +1
    -1
      ge/executor/CMakeLists.txt
  37. +5
    -3
      ge/executor/ge_executor.cc
  38. +1
    -4
      ge/executor/module.mk
  39. +1
    -0
      ge/ge_inference.mk
  40. +5
    -6
      ge/ge_local_engine/engine/host_cpu_engine.cc
  41. +4
    -6
      ge/ge_runner.mk
  42. +2
    -1
      ge/ge_runtime/runtime_model.cc
  43. +106
    -47
      ge/generator/ge_generator.cc
  44. +4
    -4
      ge/graph/build/graph_builder.cc
  45. +7
    -5
      ge/graph/build/memory/binary_block_mem_assigner.cc
  46. +328
    -194
      ge/graph/build/memory/block_mem_assigner.cc
  47. +30
    -6
      ge/graph/build/memory/block_mem_assigner.h
  48. +6
    -4
      ge/graph/build/memory/graph_mem_assigner.cc
  49. +4
    -4
      ge/graph/build/model_builder.cc
  50. +2
    -1
      ge/graph/build/stream_allocator.cc
  51. +1
    -1
      ge/graph/build/stream_graph_optimizer.cc
  52. +3
    -17
      ge/graph/build/task_generator.cc
  53. +0
    -1
      ge/graph/label/case_label_maker.h
  54. +0
    -1
      ge/graph/label/if_label_maker.h
  55. +0
    -1
      ge/graph/label/partitioned_call_label_maker.h
  56. +0
    -1
      ge/graph/label/while_label_maker.h
  57. +2
    -1
      ge/graph/load/graph_loader.cc
  58. +2
    -2
      ge/graph/load/new_model_manager/data_dumper.cc
  59. +108
    -133
      ge/graph/load/new_model_manager/davinci_model.cc
  60. +1
    -1
      ge/graph/load/new_model_manager/davinci_model.h
  61. +24
    -56
      ge/graph/load/new_model_manager/model_manager.cc
  62. +1
    -3
      ge/graph/load/new_model_manager/model_manager.h
  63. +2
    -2
      ge/graph/load/new_model_manager/model_utils.cc
  64. +3
    -2
      ge/graph/load/new_model_manager/task_info/hccl_task_info.cc
  65. +30
    -21
      ge/graph/load/new_model_manager/task_info/kernel_task_info.cc
  66. +4
    -4
      ge/graph/load/new_model_manager/task_info/kernel_task_info.h
  67. +2
    -2
      ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h
  68. +5
    -4
      ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc
  69. +14
    -12
      ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc
  70. +4
    -4
      ge/graph/load/new_model_manager/task_info/task_info.h
  71. +1
    -1
      ge/graph/load/new_model_manager/ts_mem_mall.h
  72. +2
    -0
      ge/graph/load/new_model_manager/zero_copy_offset.cc
  73. +4
    -1
      ge/graph/load/new_model_manager/zero_copy_offset.h
  74. +1
    -1
      ge/graph/load/new_model_manager/zero_copy_task.cc
  75. +6
    -6
      ge/graph/manager/graph_caching_allocator.cc
  76. +9
    -2
      ge/graph/manager/graph_caching_allocator.h
  77. +15
    -64
      ge/graph/manager/graph_manager.cc
  78. +2
    -2
      ge/graph/manager/graph_var_manager.cc
  79. +1
    -0
      ge/graph/manager/graph_var_manager.h
  80. +1
    -1
      ge/graph/manager/host_mem_manager.cc
  81. +2
    -1
      ge/graph/manager/util/debug.cc
  82. +2
    -1
      ge/graph/manager/util/hcom_util.cc
  83. +14
    -7
      ge/graph/optimize/mem_rw_conflict_optimize.cc
  84. +26
    -26
      ge/graph/partition/graph_partition.cc
  85. +80
    -24
      ge/graph/passes/atomic_addr_clean_pass.cc
  86. +5
    -0
      ge/graph/passes/atomic_addr_clean_pass.h
  87. +23
    -31
      ge/graph/passes/attach_stream_label_pass.cc
  88. +8
    -2
      ge/graph/passes/cond_remove_pass.cc
  89. +0
    -1
      ge/graph/passes/ctrl_edge_transfer_pass.cc
  90. +2
    -1
      ge/graph/passes/data_pass.cc
  91. +9
    -16
      ge/graph/passes/enter_pass.cc
  92. +9
    -21
      ge/graph/passes/for_pass.cc
  93. +3
    -1
      ge/graph/passes/mark_agnostic_pass.cc
  94. +6
    -9
      ge/graph/passes/merge_pass.cc
  95. +32
    -56
      ge/graph/passes/multi_batch_pass.cc
  96. +0
    -4
      ge/graph/passes/pass_utils.cc
  97. +4
    -4
      ge/graph/passes/subgraph_pass.cc
  98. +37
    -39
      ge/graph/passes/switch_to_stream_switch_pass.cc
  99. +3
    -4
      ge/graph/passes/switch_to_stream_switch_pass.h
  100. +3
    -1
      ge/graph/passes/transop_breadth_fusion_pass.cc

+ 14
- 15
CMakeLists.txt View File

@@ -16,8 +16,11 @@ endif()


if(DEFINED ENV{D_PKG_SERVER}) if(DEFINED ENV{D_PKG_SERVER})
set(GE_PB_PKG $ENV{D_PKG_SERVER}) set(GE_PB_PKG $ENV{D_PKG_SERVER})
message("Download packages from PKG server")
endif()
message("Download packages from DPKG server")
elseif(DEFINED ENV{MSLIBS_SERVER})
set(GE_PB_PKG "http://$ENV{MSLIBS_SERVER}:8081")
message("Download packages from MSPKG server")
endif ()


set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
@@ -37,7 +40,7 @@ set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR}
option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE)


if (ENABLE_OPEN_SRC) if (ENABLE_OPEN_SRC)
set(HI_PYTHON python3.7)
set(HI_PYTHON python3)


include(cmake/external_libs/protobuf_shared.cmake) include(cmake/external_libs/protobuf_shared.cmake)
include(cmake/external_libs/protobuf_static.cmake) include(cmake/external_libs/protobuf_static.cmake)
@@ -71,7 +74,7 @@ if (ENABLE_OPEN_SRC)
set(STATIC_ACL_LIB ${GE_LIB_PATH}) set(STATIC_ACL_LIB ${GE_LIB_PATH})
find_module(slog libslog.so ${GE_LIB_PATH}) find_module(slog libslog.so ${GE_LIB_PATH})
find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) find_module(static_mmpa libmmpa.a ${GE_LIB_PATH})
find_module(msprof libmsprof.so ${GE_LIB_PATH})
find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH})
find_module(hccl libhccl.so ${GE_LIB_PATH}) find_module(hccl libhccl.so ${GE_LIB_PATH})
find_module(adump_server libadump_server.a ${GE_LIB_PATH}) find_module(adump_server libadump_server.a ${GE_LIB_PATH})
find_module(runtime libruntime.so ${GE_LIB_PATH}) find_module(runtime libruntime.so ${GE_LIB_PATH})
@@ -80,20 +83,19 @@ if (ENABLE_OPEN_SRC)
find_module(error_manager liberror_manager.so ${GE_LIB_PATH}) find_module(error_manager liberror_manager.so ${GE_LIB_PATH})
find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH})
find_module(error_manager_static liberror_manager.a ${GE_LIB_PATH}) find_module(error_manager_static liberror_manager.a ${GE_LIB_PATH})
find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH})
find_module(msprofiler_fwk libmsprofiler_fwk.a ${GE_LIB_PATH})
#find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
else() else()
find_module(slog libslog.so ${ASCEND_ATC_DIR} ${ASCEND_DRIVER_COMMON_DIR}) find_module(slog libslog.so ${ASCEND_ATC_DIR} ${ASCEND_DRIVER_COMMON_DIR})
find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR} ${ASCEND_RUNTIME_DIR}) find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR} ${ASCEND_RUNTIME_DIR})
find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR} ${ASCEND_RUNTIME_DIR}) find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR} ${ASCEND_RUNTIME_DIR})
if(PLATFORM STREQUAL "train") if(PLATFORM STREQUAL "train")
find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR})
find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR})
find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) find_module(resource libresource.so ${ASCEND_RUNTIME_DIR})
find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR})
find_module(msprofiler libmsprofiler.a ${ASCEND_RUNTIME_DIR})
find_module(msprofiler_fwk libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
if(PRODUCT STREQUAL "flr3") if(PRODUCT STREQUAL "flr3")
message(FATAL_ERROR "This platform is not supported in train mode, build terminated") message(FATAL_ERROR "This platform is not supported in train mode, build terminated")
@@ -106,20 +108,17 @@ if (ENABLE_OPEN_SRC)
find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR})
find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR})
find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR})
#find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR})
#find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR})
if(PRODUCT STREQUAL "flr3") if(PRODUCT STREQUAL "flr3")
find_module(msprof libmsprof.so ${ASCEND_DRIVER_SHARE_DIR})
elseif(PRODUCT STREQUAL "flr1") elseif(PRODUCT STREQUAL "flr1")
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR})
elseif(PRODUCT STREQUAL "flr2") elseif(PRODUCT STREQUAL "flr2")
# flr2 ascend_hal_stub limsprof ? # flr2 ascend_hal_stub limsprof ?
else() else()
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
find_module(msprof libmsprof.so ${ASCEND_DRIVER_DIR})
endif() endif()
elseif(PLATFORM STREQUAL "all") elseif(PLATFORM STREQUAL "all")
find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR})
find_module(msprofiler libmsprofiler.a ${ASCEND_DRIVER_COMMON_DIR})
find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR})
find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR})
find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) find_module(runtime libruntime.so ${ASCEND_ACL_DIR})
@@ -127,14 +126,14 @@ if (ENABLE_OPEN_SRC)
find_module(resource libresource.so ${ASCEND_ATC_DIR}) find_module(resource libresource.so ${ASCEND_ATC_DIR})
find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR})
find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR})
find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR})
find_module(msprofiler_fwk libmsprofiler_fwk.a ${ASCEND_ACL_DIR})
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
#find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) #find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR})
else() else()
message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!")
message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!")
endif() endif()


if (ENABLE_GE_COV OR ENABLE_GE_UT)
if (ENABLE_GE_COV OR ENABLE_GE_UT)
add_subdirectory(tests) add_subdirectory(tests)
endif() endif()




+ 1
- 0
cmake/external_libs/gflags.cmake View File

@@ -23,6 +23,7 @@ ExternalProject_Add(gflags_build
URL ${REQ_URL} URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${GE_CODE_DIR}/../../third_party/gflags/src/gflags-2.2.2 #SOURCE_DIR ${GE_CODE_DIR}/../../third_party/gflags/src/gflags-2.2.2
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gflags_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR> CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gflags_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR>
BUILD_COMMAND $(MAKE) BUILD_COMMAND $(MAKE)
INSTALL_COMMAND $(MAKE) install INSTALL_COMMAND $(MAKE) install


+ 6
- 2
cmake/external_libs/gtest.cmake View File

@@ -10,7 +10,10 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
endif() endif()


if (ENABLE_GITEE)
if (GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/gtest/release-1.8.0.tar.gz")
set(MD5 "")
elseif (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.0.tar.gz") set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.0.tar.gz")
set(MD5 "") set(MD5 "")
else() else()
@@ -22,8 +25,9 @@ set (gtest_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-
set (gtest_CFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") set (gtest_CFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack")
ExternalProject_Add(gtest_build ExternalProject_Add(gtest_build
URL ${REQ_URL} URL ${REQ_URL}
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gtest_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gtest <SOURCE_DIR> CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gtest_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gtest <SOURCE_DIR>
-DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON
-DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON
BUILD_COMMAND $(MAKE) BUILD_COMMAND $(MAKE)
INSTALL_COMMAND $(MAKE) install INSTALL_COMMAND $(MAKE) install
EXCLUDE_FROM_ALL TRUE EXCLUDE_FROM_ALL TRUE


+ 9
- 4
cmake/external_libs/json.cmake View File

@@ -5,10 +5,14 @@ endif()
include(ExternalProject) include(ExternalProject)


set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include) set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include)
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip")
set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7")
set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include")
if (GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/ge_nlohmann_json/include.zip")
set(MD5 "0dc903888211db3a0f170304cd9f3a89")
set(JSON_INCLUDE_DIR ${JSON_SRC_DIR})
#elseif (ENABLE_GITEE)
# set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip")
# set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7")
#set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include")
else() else()
set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip") set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip")
set(MD5 "0dc903888211db3a0f170304cd9f3a89") set(MD5 "0dc903888211db3a0f170304cd9f3a89")
@@ -18,6 +22,7 @@ ExternalProject_Add(json_build
URL ${REQ_URL} URL ${REQ_URL}
#URL /home/txd/workspace/cloud_code/pkg/include.zip #URL /home/txd/workspace/cloud_code/pkg/include.zip
SOURCE_DIR ${JSON_SRC_DIR} SOURCE_DIR ${JSON_SRC_DIR}
TLS_VERIFY OFF
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
BUILD_COMMAND "" BUILD_COMMAND ""
INSTALL_COMMAND "" INSTALL_COMMAND ""


+ 5
- 1
cmake/external_libs/onnx.cmake View File

@@ -6,7 +6,10 @@ set(ONNX_PROTO_DIR ${CMAKE_BINARY_DIR}/onnx)
set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto)
file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) file(MAKE_DIRECTORY ${ONNX_PROTO_DIR})


if (ENABLE_GITEE)
if (GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/onnx/onnx-1.6.0.tar.gz")
set(MD5 "512f2779d6215d4a36f366b6b9acdf1e")
elseif (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz") set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz")
set(MD5 "1bdbcecdd68ea8392630467646776e02") set(MD5 "1bdbcecdd68ea8392630467646776e02")
else() else()
@@ -19,6 +22,7 @@ ExternalProject_Add(onnx
#URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz #URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz
#URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345
#SOURCE_DIR ${ONNX_SRC_DIR} #SOURCE_DIR ${ONNX_SRC_DIR}
TLS_VERIFY OFF
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
BUILD_COMMAND "" BUILD_COMMAND ""
#INSTALL_COMMAND "" #INSTALL_COMMAND ""


+ 1
- 0
cmake/external_libs/protobuf_shared.cmake View File

@@ -26,6 +26,7 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
ExternalProject_Add(protobuf_build ExternalProject_Add(protobuf_build
URL ${REQ_URL} URL ${REQ_URL}
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} CONFIGURE_COMMAND ${CMAKE_COMMAND}
-Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_WITH_ZLIB=OFF
-DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR}


+ 1
- 0
cmake/external_libs/protobuf_static.cmake View File

@@ -27,6 +27,7 @@ ExternalProject_Add(protobuf_static_build
URL ${REQ_URL} URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} CONFIGURE_COMMAND ${CMAKE_COMMAND}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}


+ 116
- 115
cmake/external_libs/protoc.cmake View File

@@ -1,115 +1,116 @@
if (HAVE_PROTOC)
return()
endif()
include(ExternalProject)
include(GNUInstallDirs)
#set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output)
if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
(${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend"))
set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE)
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
endif()
if(GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
else()
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
else()
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
set(MD5 "3d9e32700639618a4d2d342c99d4507a")
endif ()
endif()
set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
ExternalProject_Add(protoc_build
URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0
CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake
BUILD_COMMAND $(MAKE)
INSTALL_COMMAND $(MAKE) install
EXCLUDE_FROM_ALL TRUE
)
set(PROTOC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protoc)
set(protoc_EXECUTABLE ${PROTOC_PKG_DIR}/${CMAKE_INSTALL_BINDIR}/protoc)
function(protobuf_generate comp c_var h_var)
if(NOT ARGN)
message(SEND_ERROR "Error: protobuf_generate() called without any proto files")
return()
endif()
set(${c_var})
set(${h_var})
foreach(file ${ARGN})
get_filename_component(abs_file ${file} ABSOLUTE)
get_filename_component(file_name ${file} NAME_WE)
get_filename_component(file_dir ${abs_file} PATH)
get_filename_component(parent_subdir ${file_dir} NAME)
if("${parent_subdir}" STREQUAL "proto")
set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto)
else()
set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${parent_subdir})
endif()
list(APPEND ${c_var} "${proto_output_path}/${file_name}.pb.cc")
list(APPEND ${h_var} "${proto_output_path}/${file_name}.pb.h")
add_custom_command(
OUTPUT "${proto_output_path}/${file_name}.pb.cc" "${proto_output_path}/${file_name}.pb.h"
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
COMMAND ${CMAKE_COMMAND} -E make_directory "${proto_output_path}"
COMMAND ${protoc_EXECUTABLE} -I${file_dir} --cpp_out=${proto_output_path} ${abs_file}
DEPENDS protoc_build ${abs_file}
COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM )
endforeach()
set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE)
set(${c_var} ${${c_var}} PARENT_SCOPE)
set(${h_var} ${${h_var}} PARENT_SCOPE)
endfunction()
function(protobuf_generate_py comp py_var)
if(NOT ARGN)
message(SEND_ERROR "Error: protobuf_generate_py() called without any proto files")
return()
endif()
set(${py_var})
foreach(file ${ARGN})
get_filename_component(abs_file ${file} ABSOLUTE)
get_filename_component(file_name ${file} NAME_WE)
get_filename_component(file_dir ${abs_file} PATH)
get_filename_component(parent_subdir ${file_dir} NAME)
if("${parent_subdir}" STREQUAL "proto")
set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto)
else()
set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${parent_subdir})
endif()
list(APPEND ${py_var} "${proto_output_path}/${file_name}_pb2.py")
add_custom_command(
OUTPUT "${proto_output_path}/${file_name}_pb2.py"
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
COMMAND ${CMAKE_COMMAND} -E make_directory "${proto_output_path}"
COMMAND ${protoc_EXECUTABLE} -I${file_dir} --python_out=${proto_output_path} ${abs_file}
DEPENDS protoc_build ${abs_file}
COMMENT "Running PYTHON protocol buffer compiler on ${file}" VERBATIM )
endforeach()
set_source_files_properties(${${py_var}} PROPERTIES GENERATED TRUE)
set(${py_var} ${${py_var}} PARENT_SCOPE)
endfunction()
#set(HAVE_PROTOC TRUE CACHE BOOL "protoc build add")
set(HAVE_PROTOC TRUE)
if (HAVE_PROTOC)
return()
endif()

include(ExternalProject)
include(GNUInstallDirs)
#set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output)

if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
(${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend"))
set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE)
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
endif()

if(GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
else()
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
else()
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
set(MD5 "3d9e32700639618a4d2d342c99d4507a")
endif ()
endif()

set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
ExternalProject_Add(protoc_build
URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake
BUILD_COMMAND $(MAKE)
INSTALL_COMMAND $(MAKE) install
EXCLUDE_FROM_ALL TRUE
)

set(PROTOC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protoc)

set(protoc_EXECUTABLE ${PROTOC_PKG_DIR}/${CMAKE_INSTALL_BINDIR}/protoc)

function(protobuf_generate comp c_var h_var)
if(NOT ARGN)
message(SEND_ERROR "Error: protobuf_generate() called without any proto files")
return()
endif()
set(${c_var})
set(${h_var})

foreach(file ${ARGN})
get_filename_component(abs_file ${file} ABSOLUTE)
get_filename_component(file_name ${file} NAME_WE)
get_filename_component(file_dir ${abs_file} PATH)
get_filename_component(parent_subdir ${file_dir} NAME)

if("${parent_subdir}" STREQUAL "proto")
set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto)
else()
set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${parent_subdir})
endif()
list(APPEND ${c_var} "${proto_output_path}/${file_name}.pb.cc")
list(APPEND ${h_var} "${proto_output_path}/${file_name}.pb.h")

add_custom_command(
OUTPUT "${proto_output_path}/${file_name}.pb.cc" "${proto_output_path}/${file_name}.pb.h"
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
COMMAND ${CMAKE_COMMAND} -E make_directory "${proto_output_path}"
COMMAND ${protoc_EXECUTABLE} -I${file_dir} --cpp_out=${proto_output_path} ${abs_file}
DEPENDS protoc_build ${abs_file}
COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM )
endforeach()

set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE)
set(${c_var} ${${c_var}} PARENT_SCOPE)
set(${h_var} ${${h_var}} PARENT_SCOPE)

endfunction()

function(protobuf_generate_py comp py_var)
if(NOT ARGN)
message(SEND_ERROR "Error: protobuf_generate_py() called without any proto files")
return()
endif()
set(${py_var})

foreach(file ${ARGN})
get_filename_component(abs_file ${file} ABSOLUTE)
get_filename_component(file_name ${file} NAME_WE)
get_filename_component(file_dir ${abs_file} PATH)
get_filename_component(parent_subdir ${file_dir} NAME)

if("${parent_subdir}" STREQUAL "proto")
set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto)
else()
set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${parent_subdir})
endif()
list(APPEND ${py_var} "${proto_output_path}/${file_name}_pb2.py")

add_custom_command(
OUTPUT "${proto_output_path}/${file_name}_pb2.py"
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
COMMAND ${CMAKE_COMMAND} -E make_directory "${proto_output_path}"
COMMAND ${protoc_EXECUTABLE} -I${file_dir} --python_out=${proto_output_path} ${abs_file}
DEPENDS protoc_build ${abs_file}
COMMENT "Running PYTHON protocol buffer compiler on ${file}" VERBATIM )
endforeach()

set_source_files_properties(${${py_var}} PROPERTIES GENERATED TRUE)
set(${py_var} ${${py_var}} PARENT_SCOPE)

endfunction()

#set(HAVE_PROTOC TRUE CACHE BOOL "protoc build add")
set(HAVE_PROTOC TRUE)

+ 11
- 2
cmake/external_libs/securec.cmake View File

@@ -10,11 +10,20 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
endif() endif()


if (GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/securec/v1.1.10.tar.gz")
set(MD5 "")
else()
set(REQ_URL "https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz")
set(MD5 "")
endif ()

ExternalProject_Add(c_sec_build ExternalProject_Add(c_sec_build
URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
URL ${REQ_URL}
#URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz
#SOURCE_DIR ${GE_CODE_DIR}/../libc_sec #SOURCE_DIR ${GE_CODE_DIR}/../libc_sec
PATCH_COMMAND patch -p1 < ${GE_CODE_DIR}/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch PATCH_COMMAND patch -p1 < ${GE_CODE_DIR}/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} CONFIGURE_COMMAND ${CMAKE_COMMAND}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}


+ 5
- 9
ge/CMakeLists.txt View File

@@ -60,6 +60,8 @@ set(TRAIN_SRC_LIST
"common/dump/dump_manager.cc" "common/dump/dump_manager.cc"
"common/dump/dump_properties.cc" "common/dump/dump_properties.cc"
"common/dump/dump_op.cc" "common/dump/dump_op.cc"
"common/profiling/ge_profiling.cc"
"common/profiling/ge_runner_profiling.cc"
"engine_manager/dnnengine_manager.cc" "engine_manager/dnnengine_manager.cc"
"ge_local_engine/engine/host_cpu_engine.cc" "ge_local_engine/engine/host_cpu_engine.cc"
"generator/ge_generator.cc" "generator/ge_generator.cc"
@@ -201,6 +203,7 @@ set(TRAIN_SRC_LIST
"host_kernels/sub_kernel.cc" "host_kernels/sub_kernel.cc"
"host_kernels/transdata_kernel.cc" "host_kernels/transdata_kernel.cc"
"host_kernels/unpack_kernel.cc" "host_kernels/unpack_kernel.cc"
"host_kernels/reformat_kernel.cc"
"graph/passes/folding_pass.cc" "graph/passes/folding_pass.cc"
"graph/passes/get_original_format_pass.cc" "graph/passes/get_original_format_pass.cc"
"graph/passes/guarantee_const_pass.cc" "graph/passes/guarantee_const_pass.cc"
@@ -331,7 +334,6 @@ set(TRAIN_SRC_LIST
"hybrid/hybrid_davinci_model.cc" "hybrid/hybrid_davinci_model.cc"
"executor/ge_executor.cc" "executor/ge_executor.cc"
"client/ge_api.cc" "client/ge_api.cc"
"client/ge_prof.cc"
"analyzer/analyzer.cc" "analyzer/analyzer.cc"
"ir_build/ge_ir_build.cc" "ir_build/ge_ir_build.cc"
"ir_build/atc_ir_common.cc" "ir_build/atc_ir_common.cc"
@@ -487,6 +489,7 @@ set(INFER_SRC_LIST
"host_kernels/slice_d_kernel.cc" "host_kernels/slice_d_kernel.cc"
"host_kernels/dynamic_stitch_kernel.cc" "host_kernels/dynamic_stitch_kernel.cc"
"host_kernels/identity_kernel.cc" "host_kernels/identity_kernel.cc"
"host_kernels/reformat_kernel.cc"
"graph/passes/stop_gradient_pass.cc" "graph/passes/stop_gradient_pass.cc"
"graph/passes/prevent_gradient_pass.cc" "graph/passes/prevent_gradient_pass.cc"
"graph/passes/identity_pass.cc" "graph/passes/identity_pass.cc"
@@ -602,7 +605,7 @@ set(INFER_SRC_LIST


if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES)
############ libge_runner.so ############ ############ libge_runner.so ############
add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS})
add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS} $<TARGET_OBJECTS:msprofiler_fwk>)


target_compile_definitions(ge_runner PRIVATE target_compile_definitions(ge_runner PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0 PROTOBUF_INLINE_NOT_IN_HEADERS=0
@@ -647,7 +650,6 @@ target_link_libraries(ge_runner
$<BUILD_INTERFACE:intf_pub> $<BUILD_INTERFACE:intf_pub>
ge_memory ge_memory
adump_server adump_server
msprofiler
static_mmpa static_mmpa
-Wl,--no-as-needed -Wl,--no-as-needed
graph graph
@@ -656,7 +658,6 @@ target_link_libraries(ge_runner
register register
c_sec c_sec
slog slog
msprof
runtime runtime
resource resource
error_manager error_manager
@@ -781,7 +782,6 @@ target_link_libraries(opensrc_ascendcl PRIVATE
c_sec c_sec
runtime runtime
slog slog
msprof
ascend_hal_stub ascend_hal_stub
-Wl,--as-needed -Wl,--as-needed
-lrt -lrt
@@ -797,12 +797,10 @@ set_target_properties(opensrc_ascendcl PROPERTIES
add_custom_command( add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc
${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc
${CMAKE_CURRENT_BINARY_DIR}/stub_ge_prof.cc
COMMAND echo "Generating stub files." COMMAND echo "Generating stub files."
&& ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/stub/gen_stubapi.py ${GE_CODE_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} && ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/stub/gen_stubapi.py ${GE_CODE_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR}
&& mv ge_ir_build.cc stub_ge_ir_build.cc && mv ge_ir_build.cc stub_ge_ir_build.cc
&& mv ge_api.cc stub_ge_api.cc && mv ge_api.cc stub_ge_api.cc
&& mv ge_prof.cc stub_ge_prof.cc
&& echo "Generating stub files end." && echo "Generating stub files end."
#WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} #WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
#DEPENDS stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} #DEPENDS stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR}
@@ -811,7 +809,6 @@ add_custom_command(
add_custom_target(ge_stub add_custom_target(ge_stub
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc
${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc
${CMAKE_CURRENT_BINARY_DIR}/stub_ge_prof.cc
) )


################################################################## ##################################################################
@@ -853,7 +850,6 @@ target_include_directories(atc_stub_ge_compiler PRIVATE
############ stub/libge_runner.so ############ ############ stub/libge_runner.so ############
add_library(fwk_stub_ge_runner SHARED add_library(fwk_stub_ge_runner SHARED
stub_ge_api.cc stub_ge_api.cc
stub_ge_prof.cc
stub_ge_ir_build.cc stub_ge_ir_build.cc
) )




+ 1
- 1
ge/client/ge_api.cc View File

@@ -134,7 +134,7 @@ Status GEInitialize(const std::map<string, string> &options) {


Status GEInitialize(const std::map<AscendString, AscendString> &options) { Status GEInitialize(const std::map<AscendString, AscendString> &options) {
std::map<std::string, std::string> str_options; std::map<std::string, std::string> str_options;
for (auto & option : options) {
for (auto &option : options) {
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) {
GELOGE(FAILED, "GEInitialize options is nullptr."); GELOGE(FAILED, "GEInitialize options is nullptr.");
return FAILED; return FAILED;


+ 0
- 369
ge/client/ge_prof.cc View File

@@ -1,369 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge/ge_prof.h"
#include "ge/ge_api.h"
#include "init/gelib.h"
#include "common/debug/log.h"
#include "framework/common/debug/ge_log.h"
#include "common/profiling/profiling_manager.h"
#include "graph/load/graph_loader.h"
#include "toolchain/prof_acl_api.h"

using std::map;
using std::string;
using std::vector;

// File-local constants used by the graph profiling entry points in this file.
namespace {
// Largest device count accepted by isProfConfigValid.
const uint32_t kMaxDeviceNum = 64;
// Position of the device-id list string inside the parameter vector built by TransProfConfigToParam.
const uint32_t kDeviceListIndex = 3;
// Values assigned to Command::cmd_type for the individual profiling commands.
const std::string kProfilingInit = "prof_init";
const std::string kProfilingFinalize = "prof_finalize";
const std::string kProfilingStart = "prof_start";
const std::string kProfilingStop = "prof_stop";
// Keys pushed into the command parameter vector ahead of their corresponding values.
const std::string kDeviceNums = "devNums";
const std::string kDeviceIdList = "devIdList";
const std::string kAicoreMetrics = "aicoreMetrics";

// Text pushed into the command parameters for each supported aicore metrics value;
// a value missing from this table is rejected by TransProfConfigToParam.
const std::map<ge::ProfilingAicoreMetrics, std::string> kProfAicoreMetricsToString = {
{ge::kAicoreArithmaticThroughput, "AICORE_ARITHMATIC_THROUGHPUT"},
{ge::kAicorePipeline, "AICORE_PIPELINE"},
{ge::kAicoreSynchronization, "AICORE_SYNCHRONIZATION"},
{ge::kAicoreMemory, "AICORE_MEMORY"},
{ge::kAicoreInternalMemory, "AICORE_INTERNAL_MEMORY"},
{ge::kAicoreStall, "AICORE_STALL"}};
} // namespace

// Set to true once aclgrphProfInit succeeds and back to false after a successful aclgrphProfFinalize.
static bool g_graph_prof_init_ = false;
// Locked by the aclgrphProf init/finalize/start/stop functions before they read or write the flag above.
static std::mutex g_prof_mutex_;

namespace ge {
// Handle created by aclgrphProfCreateConfig and released by aclgrphProfDestroyConfig;
// it wraps the ProfConfig that is handed to ProfStartProfiling / ProfStopProfiling.
struct aclgrphProfConfig {
ProfConfig config;
};

// Initializes graph profiling.
//
// @param profiler_path  profiling path; validated together with `length` by CheckPath
// @param length         length handed to CheckPath along with `profiler_path`
// @return SUCCESS on success;
//         FAILED when the GE library instance is missing or not initialized;
//         GE_PROF_MULTI_INIT when graph profiling has already been initialized;
//         GE_PROF_MODE_CONFLICT when ProfilingManager reports that profiling is already on;
//         otherwise the error returned by CheckPath, ProfInit or the profiling command handler.
Status aclgrphProfInit(const char *profiler_path, uint32_t length) {
  GELOGT(TRACE_INIT, "Graph prof init start");

  std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
    GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge client is not initialized.");
    return FAILED;
  }

  std::lock_guard<std::mutex> lock(g_prof_mutex_);
  if (g_graph_prof_init_) {
    GELOGW("Multi graph profiling initializations.");
    return GE_PROF_MULTI_INIT;
  }

  Status ret = CheckPath(profiler_path, length);
  if (ret != SUCCESS) {
    GELOGE(ret, "Profiling config path is invalid.");
    return ret;
  }
  // if command mode is set, just return
  if (ProfilingManager::Instance().ProfilingOn()) {
    GELOGW("Graph prof init failed, cause profiling command pattern is running.");
    return GE_PROF_MODE_CONFLICT;
  }

  ret = ProfInit(profiler_path);
  if (ret != SUCCESS) {
    GELOGE(ret, "ProfInit init fail");
    return ret;
  }

  GraphLoader graph_loader;
  Command command;
  command.cmd_params.clear();
  command.cmd_type = kProfilingInit;
  command.module_index = PROF_MODEL_LOAD;
  ret = graph_loader.CommandHandle(command);
  if (ret != SUCCESS) {
    GELOGE(ret, "Handle profiling command %s failed, config = %s", kProfilingInit.c_str(), profiler_path);
    return ret;
  }

  // The flag was verified to be false above and g_prof_mutex_ has been held ever since,
  // so it can be set without re-checking it.
  g_graph_prof_init_ = true;
  GELOGI("Profiling init successfully.");

  GELOGI("Successfully execute GraphProfInit.");
  return SUCCESS;
}

// Finalizes graph profiling that was set up by aclgrphProfInit.
//
// @return SUCCESS on success;
//         FAILED when the GE library instance is missing or not initialized;
//         GE_PROF_MODE_CONFLICT when ProfilingManager reports that profiling is on;
//         GE_PROF_NOT_INIT when graph profiling was never initialized;
//         otherwise the error returned by the command handler or by ProfFinalize.
Status aclgrphProfFinalize() {
  const std::shared_ptr<GELib> ge_lib = ge::GELib::GetInstance();
  const bool ge_ready = (ge_lib != nullptr) && ge_lib->InitFlag();
  if (!ge_ready) {
    GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge client is not initialized.");
    return FAILED;
  }

  std::lock_guard<std::mutex> guard(g_prof_mutex_);
  // Command mode owns profiling in this case, so the API path must not touch it.
  if (ProfilingManager::Instance().ProfilingOn()) {
    GELOGW("Graph prof finalize failed, cause profiling command pattern is running.");
    return GE_PROF_MODE_CONFLICT;
  }
  if (!g_graph_prof_init_) {
    GELOGE(GE_PROF_NOT_INIT, "Graph not profiling initialize.");
    return GE_PROF_NOT_INIT;
  }

  // Send the finalize command to the graph loader before finalizing the profiler itself.
  GraphLoader loader;
  Command finalize_cmd;
  finalize_cmd.cmd_params.clear();
  finalize_cmd.cmd_type = kProfilingFinalize;
  const Status cmd_status = loader.CommandHandle(finalize_cmd);
  if (cmd_status != SUCCESS) {
    GELOGE(cmd_status, "Handle profiling command %s failed.", kProfilingFinalize.c_str());
    return cmd_status;
  }

  const Status prof_status = ProfFinalize();
  if (prof_status != SUCCESS) {
    GELOGE(prof_status, "Finalize profiling failed, result = %d", prof_status);
    return prof_status;
  }

  // Only a fully successful finalize clears the initialized flag.
  g_graph_prof_init_ = false;
  GELOGI("Successfully execute GraphProfFinalize.");
  return prof_status;
}

bool TransProfConfigToParam(const aclgrphProfConfig *profiler_config, vector<string> &prof_config_params) {
prof_config_params.clear();
prof_config_params.emplace_back(kDeviceNums);
prof_config_params.emplace_back(std::to_string(profiler_config->config.devNums));
prof_config_params.emplace_back(kDeviceIdList);
std::string devID = "";
if (profiler_config->config.devNums == 0) {
GELOGW("The device num is invalid.");
return false;
}
for (uint32_t i = 0; i < profiler_config->config.devNums; i++) {
devID.append(std::to_string(profiler_config->config.devIdList[i]));
if (i != profiler_config->config.devNums - 1) {
devID.append(",");
}
}

prof_config_params.push_back(devID);
prof_config_params.push_back(kAicoreMetrics);
auto iter =
kProfAicoreMetricsToString.find(static_cast<ProfilingAicoreMetrics>(profiler_config->config.aicoreMetrics));
if (iter == kProfAicoreMetricsToString.end()) {
GELOGW("The prof aicore metrics is invalid.");
return false;
}
prof_config_params.push_back(iter->second);
return true;
}

bool isProfConfigValid(const uint32_t *deviceid_list, uint32_t device_nums) {
if (deviceid_list == nullptr) {
GELOGE(PARAM_INVALID, "deviceIdList is nullptr");
return false;
}
if (device_nums == 0 || device_nums > kMaxDeviceNum) {
GELOGE(PARAM_INVALID, "The device nums is invalid.");
return false;
}

// real device num
int32_t dev_count = 0;
rtError_t rt_err = rtGetDeviceCount(&dev_count);
if (rt_err != RT_ERROR_NONE) {
GELOGE(INTERNAL_ERROR, "Get the Device count fail.");
return false;
}

if (device_nums > static_cast<uint32_t>(dev_count)) {
GELOGE(PARAM_INVALID, "Device num(%u) is not in range 1 ~ %d.", device_nums, dev_count);
return false;
}

std::unordered_set<uint32_t> record;
for (size_t i = 0; i < device_nums; ++i) {
uint32_t dev_id = deviceid_list[i];
if (dev_id >= static_cast<uint32_t>(dev_count)) {
GELOGE(PARAM_INVALID, "Device id %u is not in range 0 ~ %d(exclude %d)", dev_id, dev_count, dev_count);
return false;
}
if (record.count(dev_id) > 0) {
GELOGE(PARAM_INVALID, "Device id %u is duplicatedly set", dev_id);
return false;
}
record.insert(dev_id);
}
return true;
}

// Allocates a profiling config after validating the device list with isProfConfigValid.
//
// @param deviceid_list     device ids to profile
// @param device_nums       number of entries in deviceid_list
// @param aicore_metrics    aicore metrics selection stored in the config
// @param aicore_events     not read by this function
// @param data_type_config  data type mask stored in the config
// @return the new config, or nullptr when validation, allocation or the id copy fails;
//         a returned config is released with aclgrphProfDestroyConfig
aclgrphProfConfig *aclgrphProfCreateConfig(uint32_t *deviceid_list, uint32_t device_nums,
                                           ProfilingAicoreMetrics aicore_metrics, ProfAicoreEvents *aicore_events,
                                           uint64_t data_type_config) {
  if (!isProfConfigValid(deviceid_list, device_nums)) {
    return nullptr;
  }

  aclgrphProfConfig *prof_config = new (std::nothrow) aclgrphProfConfig();
  if (prof_config == nullptr) {
    GELOGE(INTERNAL_ERROR, "new aclgrphProfConfig fail");
    return nullptr;
  }

  ProfConfig &inner = prof_config->config;
  inner.devNums = device_nums;
  const auto copy_ret =
      memcpy_s(inner.devIdList, sizeof(inner.devIdList), deviceid_list, device_nums * sizeof(uint32_t));
  if (copy_ret != EOK) {
    GELOGE(INTERNAL_ERROR, "copy devID failed. size = %u", device_nums);
    // Release the partially filled config so the failure does not leak it.
    delete prof_config;
    return nullptr;
  }

  inner.aicoreMetrics = static_cast<ProfAicoreMetrics>(aicore_metrics);
  inner.dataTypeConfig = data_type_config;
  GELOGI("Successfully create prof config.");
  return prof_config;
}

// Releases a config obtained from aclgrphProfCreateConfig.
//
// @param profiler_config  config to delete
// @return SUCCESS after the config is deleted, PARAM_INVALID when the pointer is null
Status aclgrphProfDestroyConfig(aclgrphProfConfig *profiler_config) {
  if (profiler_config != nullptr) {
    delete profiler_config;
    GELOGI("Successfully destroy prof config.");
    return SUCCESS;
  }

  GELOGE(PARAM_INVALID, "destroy profilerConfig failed, profilerConfig must not be nullptr");
  return PARAM_INVALID;
}

// Starts graph profiling for the devices described by `profiler_config`.
//
// The config is converted to command parameters before ProfStartProfiling is called, so an
// invalid config is rejected without leaving profiling started.
//
// @param profiler_config  config created by aclgrphProfCreateConfig
// @return SUCCESS on success;
//         FAILED when the config is null, the GE library is not initialized,
//         ProfStartProfiling fails or the profiling command fails;
//         GE_PROF_MODE_CONFLICT when ProfilingManager reports that profiling is on;
//         GE_PROF_NOT_INIT when aclgrphProfInit has not succeeded;
//         PARAM_INVALID when the config cannot be converted to command parameters.
Status aclgrphProfStart(aclgrphProfConfig *profiler_config) {
  // Position of the device count value in the vector built by TransProfConfigToParam
  // (index 0 holds the "devNums" key itself).
  const size_t kDeviceNumsValueIndex = 1;

  if (profiler_config == nullptr) {
    GELOGE(PARAM_INVALID, "aclgrphProfConfig is invalid.");
    return FAILED;
  }
  std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
    GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge client is not initialized.");
    return FAILED;
  }

  std::lock_guard<std::mutex> lock(g_prof_mutex_);
  // if command mode is set, just return
  if (ProfilingManager::Instance().ProfilingOn()) {
    GELOGW("Graph prof start failed, cause profiling command pattern is running.");
    return GE_PROF_MODE_CONFLICT;
  }
  if (!g_graph_prof_init_) {
    GELOGE(GE_PROF_NOT_INIT, "Graph not profiling initialize.");
    return GE_PROF_NOT_INIT;
  }

  // Validate and convert the config first; nothing has been started yet if this fails.
  std::vector<string> prof_params;
  if (!TransProfConfigToParam(profiler_config, prof_params)) {
    GELOGE(PARAM_INVALID, "Transfer profilerConfig to string vector failed");
    return PARAM_INVALID;
  }

  Status ret = ProfStartProfiling(&profiler_config->config);
  if (ret != SUCCESS) {
    GELOGE(ret, "Start profiling failed, prof result = %d", ret);
    return FAILED;
  }

  GraphLoader graph_loader;
  Command command;
  command.cmd_params.clear();
  command.cmd_type = kProfilingStart;
  command.cmd_params = prof_params;
  command.module_index = profiler_config->config.dataTypeConfig;
  GELOGI("Profiling will start, device nums:%s , deviceID:[%s], data type config: 0x%llx",
         prof_params[kDeviceNumsValueIndex].c_str(), prof_params[kDeviceListIndex].c_str(), command.module_index);
  ret = graph_loader.CommandHandle(command);
  if (ret != SUCCESS) {
    GELOGE(ret, "Handle profiling command failed");
    return FAILED;
  }

  GELOGI("Successfully execute GraphProfStartProfiling.");

  return SUCCESS;
}

// Stops graph profiling for the devices described by `profiler_config`.
//
// Before stopping, the data type config currently reported for every device in the config
// (ProfGetDataTypeConfig) must equal the one stored in `profiler_config`.
//
// @param profiler_config  config created by aclgrphProfCreateConfig
// @return SUCCESS on success;
//         FAILED when the config is null, the GE library is not initialized, the data type
//         config does not match or the profiling command fails;
//         GE_PROF_MODE_CONFLICT when ProfilingManager reports that profiling is on;
//         GE_PROF_NOT_INIT when aclgrphProfInit has not succeeded;
//         PARAM_INVALID when the config cannot be converted to command parameters;
//         otherwise the error returned by ProfGetDataTypeConfig or ProfStopProfiling.
Status aclgrphProfStop(aclgrphProfConfig *profiler_config) {
  // Position of the device count value in the vector built by TransProfConfigToParam
  // (index 0 holds the "devNums" key itself).
  const size_t kDeviceNumsValueIndex = 1;

  if (profiler_config == nullptr) {
    GELOGE(PARAM_INVALID, "aclgrphProfConfig is invalid.");
    return FAILED;
  }
  std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
    GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge client is not initialized.");
    return FAILED;
  }

  std::lock_guard<std::mutex> lock(g_prof_mutex_);
  // if command mode is set, just return
  if (ProfilingManager::Instance().ProfilingOn()) {
    GELOGW("Graph prof stop failed, cause profiling command pattern is running.");
    return GE_PROF_MODE_CONFLICT;
  }
  if (!g_graph_prof_init_) {
    GELOGE(GE_PROF_NOT_INIT, "Graph not profiling initialize.");
    return GE_PROF_NOT_INIT;
  }

  for (uint32_t i = 0; i < profiler_config->config.devNums; i++) {
    // Initialized so that a callee which leaves the out-argument untouched cannot make the
    // comparison below read an indeterminate value.
    uint64_t data_type_config = 0;
    Status status = ProfGetDataTypeConfig(profiler_config->config.devIdList[i], data_type_config);
    if (status != SUCCESS) {
      GELOGE(status, "Prof get data type config failed, prof result = %d", status);
      return status;
    }
    if (data_type_config != profiler_config->config.dataTypeConfig) {
      GELOGE(FAILED, "data type config verify failed");
      return FAILED;
    }
  }

  std::vector<string> prof_params;
  if (!TransProfConfigToParam(profiler_config, prof_params)) {
    GELOGE(PARAM_INVALID, "Transfer profilerConfig to string vector failed");
    return PARAM_INVALID;
  }

  GraphLoader graph_loader;
  Command command;
  command.cmd_params.clear();
  command.cmd_type = kProfilingStop;
  command.cmd_params = prof_params;
  command.module_index = profiler_config->config.dataTypeConfig;
  GELOGI("Profiling will stop, device nums:%s , deviceID:[%s], data type config: 0x%llx",
         prof_params[kDeviceNumsValueIndex].c_str(), prof_params[kDeviceListIndex].c_str(), command.module_index);
  Status ret = graph_loader.CommandHandle(command);
  if (ret != SUCCESS) {
    GELOGE(ret, "Handle profiling command failed");
    return FAILED;
  }

  ret = ProfStopProfiling(&profiler_config->config);
  if (ret != SUCCESS) {
    GELOGE(ret, "Stop profiling failed, prof result = %d", ret);
    return ret;
  }

  GELOGI("Successfully execute GraphProfStopProfiling.");
  return SUCCESS;
}
} // namespace ge

+ 6
- 5
ge/client/module.mk View File

@@ -4,7 +4,6 @@ LOCAL_PATH := $(call my-dir)
COMMON_LOCAL_SRC_FILES := \ COMMON_LOCAL_SRC_FILES := \
proto/ge_api.proto \ proto/ge_api.proto \
ge_api.cc \ ge_api.cc \
ge_prof.cc \




COMMON_LOCAL_C_INCLUDES := \ COMMON_LOCAL_C_INCLUDES := \
@@ -69,9 +68,9 @@ LOCAL_SHARED_LIBRARIES := \
libgraph \ libgraph \
libregister \ libregister \
libge_compiler \ libge_compiler \
libge_common \
libmsprof
libge_common


LOCAL_STATIC_LIBRARIES += libmsprofiler_fwk \




LOCAL_LDFLAGS := -lrt -ldl LOCAL_LDFLAGS := -lrt -ldl
@@ -104,8 +103,10 @@ LOCAL_SHARED_LIBRARIES := \
libregister \ libregister \
libruntime \ libruntime \
libge_compiler \ libge_compiler \
libge_common \
libmsprof
libge_common


LOCAL_STATIC_LIBRARIES += libmsprofiler_fwk \




LOCAL_LDFLAGS := -lrt -ldl LOCAL_LDFLAGS := -lrt -ldl


+ 1
- 0
ge/common/CMakeLists.txt View File

@@ -24,6 +24,7 @@ set(SRC_LIST
"helper/om_file_helper.cc" "helper/om_file_helper.cc"
"helper/model_helper.cc" "helper/model_helper.cc"
"../model/ge_model.cc" "../model/ge_model.cc"
"../model/ge_root_model.cc"
"auth/file_saver.cc" "auth/file_saver.cc"
"fp16_t.cc" "fp16_t.cc"
"math/fp16_math.cc" "math/fp16_math.cc"


+ 61
- 2
ge/common/auth/file_saver.cc View File

@@ -54,8 +54,8 @@ Status FileSaver::OpenFile(int32_t &fd, const std::string &file_path) {
Status FileSaver::WriteData(const void *data, uint32_t size, int32_t fd) { Status FileSaver::WriteData(const void *data, uint32_t size, int32_t fd) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size == 0 || data == nullptr, return PARAM_INVALID); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size == 0 || data == nullptr, return PARAM_INVALID);
mmSsize_t write_count; mmSsize_t write_count;
uint32_t size_2g = ((uint32_t) 0x1 << 31);
uint32_t size_1g = ((uint32_t) 0x1 << 30);
uint32_t size_2g = 2147483648; // 0x1 << 31
uint32_t size_1g = 1073741824; // 0x1 << 30
// Write data // Write data
if (size > size_2g) { if (size > size_2g) {
auto seek = reinterpret_cast<uint8_t *>(const_cast<void *>(data)); auto seek = reinterpret_cast<uint8_t *>(const_cast<void *>(data));
@@ -258,6 +258,65 @@ FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, Mod
return SUCCESS; return SUCCESS;
} }


// Saves several partition tables and their partition data to `file_path` without encryption.
//
// @param file_path               destination file
// @param file_header             header written first; its is_encrypt field is set to UNENCRYPTED
// @param model_partition_tables  one table per model
// @param all_partition_datas     partition data matching each table
// @return SUCCESS when SaveWithFileHeader succeeds, FAILED otherwise
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status
FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header,
                      vector<ModelPartitionTable *> &model_partition_tables,
                      const vector<vector<ModelPartition>> &all_partition_datas) {
  // This overload always writes a plain (unencrypted) model file.
  file_header.is_encrypt = ModelEncryptType::UNENCRYPTED;

  const Status save_status =
      SaveWithFileHeader(file_path, file_header, model_partition_tables, all_partition_datas);
  GE_CHK_BOOL_RET_STATUS(save_status == SUCCESS, FAILED, "save file failed, file_path:%s, file header len:%u.",
                         file_path.c_str(), file_header.length);
  return SUCCESS;
}

// Writes the file header followed, for every model, by its partition table and then each of
// its partitions.
//
// All inputs are validated before the file is opened. Writing stops at the first failed
// write, and the file is closed on every path after a successful open.
//
// @param file_path               destination file
// @param file_header             header written at the start of the file
// @param model_partition_tables  one table pointer per model; null entries are rejected
// @param all_partition_datas     partition data per model; all_partition_datas[i] must be
//                                non-empty and hold exactly model_partition_tables[i]->num entries
// @return SUCCESS when everything was written and the file closed;
//         PARAM_INVALID when the two vectors differ in size or a table pointer is null;
//         FAILED for a table/data count mismatch or an open, write or close failure
Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header,
                                     vector<ModelPartitionTable *> &model_partition_tables,
                                     const vector<vector<ModelPartition>> &all_partition_datas) {
  GE_CHK_BOOL_EXEC(model_partition_tables.size() == all_partition_datas.size(),
                   return PARAM_INVALID,
                   "model table size %zu does not match partition size %zu",
                   model_partition_tables.size(), all_partition_datas.size());
  for (size_t index = 0; index < model_partition_tables.size(); ++index) {
    // The tables are dereferenced below, so reject null entries up front.
    GE_CHK_BOOL_RET_STATUS(model_partition_tables[index] != nullptr, PARAM_INVALID,
                           "Invalid param:model partition table[%zu] is nullptr.", index);
    auto &cur_partition_data = all_partition_datas[index];
    auto &cur_model_partition_table = *model_partition_tables[index];
    GE_CHK_BOOL_RET_STATUS(!cur_partition_data.empty() && cur_model_partition_table.num != 0
                           && cur_model_partition_table.num == cur_partition_data.size(), FAILED,
                           "Invalid param:partition data size is (%zu), model_partition_table.num is (%u).",
                           cur_partition_data.size(), cur_model_partition_table.num);
  }

  // Open file
  int32_t fd = 0;
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(OpenFile(fd, file_path) != SUCCESS, return FAILED);
  Status ret = SUCCESS;
  do {
    // Write file header
    GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
        WriteData(static_cast<const void *>(&file_header), sizeof(ModelFileHeader), fd) != SUCCESS, ret = FAILED;
        break);
    // `ret == SUCCESS` in the loop condition stops the outer loop as well when a partition
    // write fails; the `break` in the inner loop only leaves the inner loop.
    for (size_t index = 0; (ret == SUCCESS) && (index < model_partition_tables.size()); ++index) {
      // Write model partition table
      auto &cur_table = *model_partition_tables[index];
      uint32_t table_size = static_cast<uint32_t>(SIZE_OF_MODEL_PARTITION_TABLE(cur_table));
      GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
          WriteData(static_cast<const void *>(&cur_table), table_size, fd) != SUCCESS, ret = FAILED; break);
      // Write partition data
      auto &cur_partition_datas = all_partition_datas[index];
      for (const auto &partition_data : cur_partition_datas) {
        GELOGI("GC:size[%zu]", partition_data.size);
        GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
            WriteData(static_cast<const void *>(partition_data.data), partition_data.size, fd) != SUCCESS,
            ret = FAILED; break);
      }
    }
  } while (0);
  // Close file
  GE_CHK_BOOL_RET_STATUS(mmClose(fd) == EN_OK, FAILED, "Close file failed.");
  return ret;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(const string &file_path, const void *data, FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(const string &file_path, const void *data,
int len) { int len) {
if (data == nullptr || len <= 0) { if (data == nullptr || len <= 0) {


+ 7
- 0
ge/common/auth/file_saver.h View File

@@ -74,6 +74,10 @@ class FileSaver {
ModelPartitionTable &model_partition_table, ModelPartitionTable &model_partition_table,
const std::vector<ModelPartition> &partition_datas); const std::vector<ModelPartition> &partition_datas);


static Status SaveToFile(const string &file_path, ModelFileHeader &file_header,
vector<ModelPartitionTable *> &model_partition_tables,
const vector<vector<ModelPartition>> &all_partition_datas);

static Status SaveToBuffWithFileHeader(const ModelFileHeader &file_header, static Status SaveToBuffWithFileHeader(const ModelFileHeader &file_header,
ModelPartitionTable &model_partition_table, ModelPartitionTable &model_partition_table,
const std::vector<ModelPartition> &partitionDatas, const std::vector<ModelPartition> &partitionDatas,
@@ -108,6 +112,9 @@ class FileSaver {
static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header,
ModelPartitionTable &model_partition_table, ModelPartitionTable &model_partition_table,
const std::vector<ModelPartition> &partition_datas); const std::vector<ModelPartition> &partition_datas);
static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header,
vector<ModelPartitionTable *> &model_partition_tables,
const vector<vector<ModelPartition>> &all_partition_datas);
}; };
} // namespace ge } // namespace ge
#endif // GE_COMMON_AUTH_FILE_SAVER_H_ #endif // GE_COMMON_AUTH_FILE_SAVER_H_

+ 30
- 21
ge/common/base64.h View File

@@ -25,32 +25,38 @@


namespace ge { namespace ge {
namespace { namespace {
const char* kBase64Chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
const char *kBase64Chars =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
const char kEqualSymbol = '='; const char kEqualSymbol = '=';
const size_t kBase64CharsNum = 64; const size_t kBase64CharsNum = 64;
const size_t kThreeByteOneGroup = 3; const size_t kThreeByteOneGroup = 3;
const size_t kFourByteOneGroup = 4; const size_t kFourByteOneGroup = 4;
}
const size_t kThreeByteOneGroupIndex0 = 0;
const size_t kThreeByteOneGroupIndex1 = 1;
const size_t kThreeByteOneGroupIndex2 = 2;
const size_t kFourByteOneGroupIndex0 = 0;
const size_t kFourByteOneGroupIndex1 = 1;
const size_t kFourByteOneGroupIndex2 = 2;
const size_t kFourByteOneGroupIndex3 = 3;
} // namespace


namespace base64 { namespace base64 {
static inline bool IsBase64Char(const char &c) {
return (isalnum(c) || (c == '+') || (c == '/'));
}
static inline bool IsBase64Char(const char &c) { return (isalnum(c) || (c == '+') || (c == '/')); }


static std::string EncodeToBase64(const std::string &raw_data) { static std::string EncodeToBase64(const std::string &raw_data) {
size_t encode_length = raw_data.size() / kThreeByteOneGroup * kFourByteOneGroup; size_t encode_length = raw_data.size() / kThreeByteOneGroup * kFourByteOneGroup;
encode_length += raw_data.size() % kThreeByteOneGroup == 0 ? 0 : kFourByteOneGroup; encode_length += raw_data.size() % kThreeByteOneGroup == 0 ? 0 : kFourByteOneGroup;
size_t raw_data_index = 0 ;
size_t raw_data_index = 0;
size_t encode_data_index = 0; size_t encode_data_index = 0;
std::string encode_data; std::string encode_data;
encode_data.resize(encode_length); encode_data.resize(encode_length);


for (; raw_data_index + kThreeByteOneGroup <= raw_data.size(); raw_data_index += kThreeByteOneGroup) { for (; raw_data_index + kThreeByteOneGroup <= raw_data.size(); raw_data_index += kThreeByteOneGroup) {
auto char_1 = static_cast<uint8_t>(raw_data[raw_data_index]); auto char_1 = static_cast<uint8_t>(raw_data[raw_data_index]);
auto char_2 = static_cast<uint8_t>(raw_data[raw_data_index + 1]);
auto char_3 = static_cast<uint8_t>(raw_data[raw_data_index + 2]);
auto char_2 = static_cast<uint8_t>(raw_data[raw_data_index + kThreeByteOneGroupIndex1]);
auto char_3 = static_cast<uint8_t>(raw_data[raw_data_index + kThreeByteOneGroupIndex2]);
encode_data[encode_data_index++] = kBase64Chars[char_1 >> 2u]; encode_data[encode_data_index++] = kBase64Chars[char_1 >> 2u];
encode_data[encode_data_index++] = kBase64Chars[((char_1 << 4u) & 0x30) | (char_2 >> 4u)]; encode_data[encode_data_index++] = kBase64Chars[((char_1 << 4u) & 0x30) | (char_2 >> 4u)];
encode_data[encode_data_index++] = kBase64Chars[((char_2 << 2u) & 0x3c) | (char_3 >> 6u)]; encode_data[encode_data_index++] = kBase64Chars[((char_2 << 2u) & 0x3c) | (char_3 >> 6u)];
@@ -80,8 +86,7 @@ static std::string EncodeToBase64(const std::string &raw_data) {
#pragma GCC diagnostic ignored "-Wunused-function" #pragma GCC diagnostic ignored "-Wunused-function"
static Status DecodeFromBase64(const std::string &base64_data, std::string &decode_data) { static Status DecodeFromBase64(const std::string &base64_data, std::string &decode_data) {
if (base64_data.size() % kFourByteOneGroup != 0) { if (base64_data.size() % kFourByteOneGroup != 0) {
GELOGE(PARAM_INVALID, "base64 data size must can be divided by 4, but given data size is %zu",
base64_data.size());
GELOGE(PARAM_INVALID, "base64 data size must can be divided by 4, but given data size is %zu", base64_data.size());
return PARAM_INVALID; return PARAM_INVALID;
} }
decode_data.clear(); decode_data.clear();
@@ -92,10 +97,10 @@ static Status DecodeFromBase64(const std::string &base64_data, std::string &deco
return static_cast<uint8_t>(std::distance(kBase64Chars, char_pos)) & 0xff; return static_cast<uint8_t>(std::distance(kBase64Chars, char_pos)) & 0xff;
}; };


for (std::size_t input_data_index = 0; input_data_index < base64_data_len; input_data_index += 4) {
for (std::size_t input_data_index = 0; input_data_index < base64_data_len; input_data_index += kFourByteOneGroup) {
for (size_t i = 0; i < kFourByteOneGroup; ++i) { for (size_t i = 0; i < kFourByteOneGroup; ++i) {
if (base64_data[input_data_index + i] == kEqualSymbol && if (base64_data[input_data_index + i] == kEqualSymbol &&
input_data_index >= base64_data_len - 4 && i > 1) {
input_data_index >= base64_data_len - kFourByteOneGroup && i > 1) {
byte_4[i] = kBase64CharsNum; byte_4[i] = kBase64CharsNum;
} else if (IsBase64Char(base64_data[input_data_index + i])) { } else if (IsBase64Char(base64_data[input_data_index + i])) {
byte_4[i] = FindCharInBase64Chars(base64_data[input_data_index + i]); byte_4[i] = FindCharInBase64Chars(base64_data[input_data_index + i]);
@@ -104,19 +109,23 @@ static Status DecodeFromBase64(const std::string &base64_data, std::string &deco
return PARAM_INVALID; return PARAM_INVALID;
} }
} }
decode_data += static_cast<char>((byte_4[0] << 2u) + ((byte_4[1] & 0x30) >> 4u));
if (byte_4[2] >= kBase64CharsNum){
decode_data +=
static_cast<char>((byte_4[kFourByteOneGroupIndex0] << 2u) + ((byte_4[kFourByteOneGroupIndex1] & 0x30) >> 4u));
if (byte_4[kFourByteOneGroupIndex2] >= kBase64CharsNum) {
break; break;
} else if (byte_4[3] >= kBase64CharsNum) {
decode_data += static_cast<char>(((byte_4[1] & 0x0f) << 4u) + ((byte_4[2] & 0x3c) >> 2u));
} else if (byte_4[kFourByteOneGroupIndex3] >= kBase64CharsNum) {
decode_data += static_cast<char>(((byte_4[kFourByteOneGroupIndex1] & 0x0f) << 4u) +
((byte_4[kFourByteOneGroupIndex2] & 0x3c) >> 2u));
break; break;
} }
decode_data += static_cast<char>(((byte_4[1] & 0x0f) << 4u) + ((byte_4[2] & 0x3c) >> 2u));
decode_data += static_cast<char>(((byte_4[2] & 0x03) << 6u) + byte_4[3]);
decode_data += static_cast<char>(((byte_4[kFourByteOneGroupIndex1] & 0x0f) << 4u) +
((byte_4[kFourByteOneGroupIndex2] & 0x3c) >> 2u));
decode_data +=
static_cast<char>(((byte_4[kFourByteOneGroupIndex2] & 0x03) << 6u) + byte_4[kFourByteOneGroupIndex3]);
} }
return SUCCESS; return SUCCESS;
} }
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
}
} // namespace base64
} // namespace ge } // namespace ge
#endif // GE_COMMON_BASE64_H_ #endif // GE_COMMON_BASE64_H_

+ 2
- 1
ge/common/debug/memory_dumper.cc View File

@@ -139,7 +139,8 @@ int MemoryDumper::OpenFile(const char *filename) {
GE_IF_BOOL_EXEC( GE_IF_BOOL_EXEC(
-1 != path_split_pos, string prefix_path = std::string(filename).substr(0, path_split_pos); -1 != path_split_pos, string prefix_path = std::string(filename).substr(0, path_split_pos);
string last_path = std::string(filename).substr(path_split_pos, strlen(filename) - 1); string last_path = std::string(filename).substr(path_split_pos, strlen(filename) - 1);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(prefix_path.length() >= MMPA_MAX_PATH, return kInvalidFd, "Prefix path is too long!");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(prefix_path.length() >= MMPA_MAX_PATH,
return kInvalidFd, "Prefix path is too long!");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmRealPath(prefix_path.c_str(), tmp_path, MMPA_MAX_PATH) != EN_OK, return kInvalidFd, GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmRealPath(prefix_path.c_str(), tmp_path, MMPA_MAX_PATH) != EN_OK, return kInvalidFd,
"Dir %s does not exit.", prefix_path.c_str()); "Dir %s does not exit.", prefix_path.c_str());
real_path = std::string(tmp_path) + last_path;) real_path = std::string(tmp_path) + last_path;)


+ 43
- 25
ge/common/formats/format_transfers/format_transfer_fractal_nz.cc View File

@@ -23,12 +23,30 @@
#include "common/formats/utils/formats_trans_utils.h" #include "common/formats/utils/formats_trans_utils.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h" #include "framework/common/debug/log.h"
#include "framework/common/types.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"


namespace ge { namespace ge {
namespace formats { namespace formats {
namespace { namespace {
const int kDimSize4D = 4; const int kDimSize4D = 4;

const size_t kSingleDim = 1;

const size_t kNdDimIndexN = 0;
const size_t kNdDimIndexH = 1;
const size_t kNdDimIndexW = 2;

const size_t kDimDValueBNdFNz = 2; // dim d-value between Nd and FractalZz

const size_t kNdDimCountBackwardsW = 1;
const size_t kNdDimCountBackwardsWH = 2;

const size_t kFNzDimCountBackwardsW0 = 1;
const size_t kFNzDimCountBackwardsW0H0 = 2;
const size_t kFNzDimCountBackwardsW0H0H1 = 3;
const size_t kFNzDimCountBackwardsW0H0H1W1 = 4;

bool IsDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0; } bool IsDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0; }


using ShapeVector = std::vector<int64_t>; using ShapeVector = std::vector<int64_t>;
@@ -60,14 +78,14 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap
auto w0 = GetCubeSizeByDataType(data_type); auto w0 = GetCubeSizeByDataType(data_type);
int64_t h0 = kCubeSize; int64_t h0 = kCubeSize;
switch (src_shape.size()) { switch (src_shape.size()) {
case 1:
dst_shape.push_back(Ceil(src_shape[0], w0));
dst_shape.push_back(1);
case kSingleDim:
dst_shape.push_back(Ceil(src_shape[kNdDimIndexN], w0));
dst_shape.push_back(DIM_DEFAULT_VALUE);
dst_shape.push_back(h0); dst_shape.push_back(h0);
dst_shape.push_back(w0); dst_shape.push_back(w0);
hw_shape.push_back(1);
hw_shape.push_back(1);
hw_shape.push_back(src_shape[0]);
hw_shape.push_back(DIM_DEFAULT_VALUE);
hw_shape.push_back(DIM_DEFAULT_VALUE);
hw_shape.push_back(src_shape[kNdDimIndexN]);
if (!IsShapeValid(dst_shape)) { if (!IsShapeValid(dst_shape)) {
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str());
return PARAM_INVALID; return PARAM_INVALID;
@@ -76,17 +94,17 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap
default: default:
auto size = src_shape.size(); auto size = src_shape.size();
int64_t times = 1; int64_t times = 1;
for (size_t i = 0; i != size - 2; i++) {
for (size_t i = 0; i != size - kDimDValueBNdFNz; i++) {
dst_shape.push_back(src_shape[i]); dst_shape.push_back(src_shape[i]);
times *= src_shape[i]; times *= src_shape[i];
} }
dst_shape.push_back(Ceil(src_shape[size - 1], w0));
dst_shape.push_back(Ceil(src_shape[size - 2], h0));
dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsW], w0));
dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsWH], h0));
dst_shape.push_back(h0); dst_shape.push_back(h0);
dst_shape.push_back(w0); dst_shape.push_back(w0);
hw_shape.push_back(times); hw_shape.push_back(times);
hw_shape.push_back(src_shape[size - 2]);
hw_shape.push_back(src_shape[size - 1]);
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]);
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]);
if (!IsShapeValid(dst_shape)) { if (!IsShapeValid(dst_shape)) {
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str());
return PARAM_INVALID; return PARAM_INVALID;
@@ -128,16 +146,16 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con
} }


// src&dst_shape can be written as times*H*W & times*W1*H1*H0*W0, respectively. dst_shape_size >= kDimNum4D // src&dst_shape can be written as times*H*W & times*W1*H1*H0*W0, respectively. dst_shape_size >= kDimNum4D
auto times = hw_shape.at(0);
auto h = hw_shape.at(1);
auto w = hw_shape.at(2);
auto times = hw_shape.at(kNdDimIndexN);
auto h = hw_shape.at(kNdDimIndexH);
auto w = hw_shape.at(kNdDimIndexW);
auto hw = h * w; auto hw = h * w;


auto shape_size = args.dst_shape.size(); auto shape_size = args.dst_shape.size();
auto w1 = args.dst_shape[shape_size - 4];
auto h1 = args.dst_shape[shape_size - 3];
auto h0 = args.dst_shape[shape_size - 2];
auto w0 = args.dst_shape[shape_size - 1];
auto w1 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0H1W1];
auto h1 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0H1];
auto h0 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0];
auto w0 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0];
auto h1h0 = h1 * h0; auto h1h0 = h1 * h0;
auto h1h0w0 = h1h0 * w0; auto h1h0w0 = h1h0 * w0;
auto w1h1h0w0 = w1 * h1h0w0; auto w1h1h0w0 = w1 * h1h0w0;
@@ -198,16 +216,16 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con
return OUT_OF_MEMORY; return OUT_OF_MEMORY;
} }


auto times = dst_hw_shape.at(0);
auto h = dst_hw_shape.at(1);
auto w = dst_hw_shape.at(2);
auto times = dst_hw_shape.at(kNdDimIndexN);
auto h = dst_hw_shape.at(kNdDimIndexH);
auto w = dst_hw_shape.at(kNdDimIndexW);
auto hw = h * w; auto hw = h * w;


auto shape_size = args.src_shape.size(); auto shape_size = args.src_shape.size();
auto w1 = args.src_shape[shape_size - 4];
auto h1 = args.src_shape[shape_size - 3];
auto h0 = args.src_shape[shape_size - 2];
auto w0 = args.src_shape[shape_size - 1];
auto w1 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0H1W1];
auto h1 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0H1];
auto h0 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0];
auto w0 = args.src_shape[shape_size - kFNzDimCountBackwardsW0];
auto h1h0 = h1 * h0; auto h1h0 = h1 * h0;
auto h1h0w0 = h1h0 * w0; auto h1h0w0 = h1h0 * w0;
auto w1h1h0w0 = w1 * h1h0w0; auto w1h1h0w0 = w1 * h1h0w0;


+ 52
- 35
ge/common/formats/format_transfers/format_transfer_fractal_zz.cc View File

@@ -23,12 +23,29 @@
#include "common/formats/utils/formats_trans_utils.h" #include "common/formats/utils/formats_trans_utils.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h" #include "framework/common/debug/log.h"
#include "framework/common/types.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"


namespace ge { namespace ge {
namespace formats { namespace formats {
namespace { namespace {
const int kDimSize4D = 4; const int kDimSize4D = 4;

// Rank of a one-dimensional ND source shape; selects the dedicated 1-D case when converting a shape to FractalZz.
const size_t kSingleDim = 1;

// Positions inside the collapsed {times, H, W} shape vector produced by the shape conversion.
// kNdDimIndexN is also used to read the only dimension of a 1-D source shape.
const size_t kNdDimIndexN = 0;
const size_t kNdDimIndexH = 1;
const size_t kNdDimIndexW = 2;

// Number of trailing ND dimensions (H and W) that are not copied unchanged into the FractalZz shape.
const size_t kDimDValueBNdFZz = 2; // dim d-value between Nd and FractalZz

// Offsets counted back from the end of an ND shape: shape[size - 1] is W, shape[size - 2] is H.
const size_t kNdDimCountBackwardsW = 1;
const size_t kNdDimCountBackwardsWH = 2;

// Offsets counted back from the end of a FractalZz shape, whose last four dimensions are H1, W1, H0, W0:
// shape[size - 1] is W0, shape[size - 2] is H0, shape[size - 3] is W1, shape[size - 4] is H1.
const size_t kFZzDimCountBackwardsW0 = 1;
const size_t kFZzDimCountBackwardsW0H0 = 2;
const size_t kFZzDimCountBackwardsW0H0W1 = 3;
const size_t kFZzDimCountBackwardsW0H0W1H1 = 4;
bool IsDataTypeSupport(DataType d_type) { return GetSizeByDataType(d_type) > 0; } bool IsDataTypeSupport(DataType d_type) { return GetSizeByDataType(d_type) > 0; }


using ShapeVector = std::vector<int64_t>; using ShapeVector = std::vector<int64_t>;
@@ -40,8 +57,8 @@ bool CheckShape(Format format, const ShapeVector &shape) {
case FORMAT_NHWC: case FORMAT_NHWC:
return CheckShapeValid(shape, kDimSize4D); return CheckShapeValid(shape, kDimSize4D);
default: default:
std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) +
" and FORMAT_FRACTAL_ZZ is not supported.";
std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) +
" and FORMAT_FRACTAL_ZZ is not supported.";
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
return false; return false;
} }
@@ -60,14 +77,14 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap
auto w0 = GetCubeSizeByDataType(data_type); auto w0 = GetCubeSizeByDataType(data_type);
auto h0 = GetCubeSizeByDataType(data_type); auto h0 = GetCubeSizeByDataType(data_type);
switch (src_shape.size()) { switch (src_shape.size()) {
case 1:
dst_shape.push_back(1);
dst_shape.push_back(Ceil(src_shape[0], w0));
case kSingleDim:
dst_shape.push_back(DIM_DEFAULT_VALUE);
dst_shape.push_back(Ceil(src_shape[kNdDimIndexN], w0));
dst_shape.push_back(h0); dst_shape.push_back(h0);
dst_shape.push_back(w0); dst_shape.push_back(w0);
hw_shape.push_back(1);
hw_shape.push_back(1);
hw_shape.push_back(src_shape[0]);
hw_shape.push_back(DIM_DEFAULT_VALUE);
hw_shape.push_back(DIM_DEFAULT_VALUE);
hw_shape.push_back(src_shape[kNdDimIndexN]);
if (!IsShapeValid(dst_shape)) { if (!IsShapeValid(dst_shape)) {
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str());
return PARAM_INVALID; return PARAM_INVALID;
@@ -76,17 +93,17 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap
default: default:
auto size = src_shape.size(); auto size = src_shape.size();
int64_t times = 1; int64_t times = 1;
for (size_t i = 0; i != size - 2; i++) {
for (size_t i = 0; i != size - kDimDValueBNdFZz; i++) {
dst_shape.push_back(src_shape[i]); dst_shape.push_back(src_shape[i]);
times *= src_shape[i]; times *= src_shape[i];
} }
dst_shape.push_back(Ceil(src_shape[size - 2], h0));
dst_shape.push_back(Ceil(src_shape[size - 1], w0));
dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsWH], h0));
dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsW], w0));
dst_shape.push_back(h0); dst_shape.push_back(h0);
dst_shape.push_back(w0); dst_shape.push_back(w0);
hw_shape.push_back(times); hw_shape.push_back(times);
hw_shape.push_back(src_shape[size - 2]);
hw_shape.push_back(src_shape[size - 1]);
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]);
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]);
if (!IsShapeValid(dst_shape)) { if (!IsShapeValid(dst_shape)) {
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str());
return PARAM_INVALID; return PARAM_INVALID;
@@ -127,16 +144,16 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
return OUT_OF_MEMORY; return OUT_OF_MEMORY;
} }
// The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D
auto times = hw_shape.at(0);
auto h = hw_shape.at(1);
auto w = hw_shape.at(2);
auto times = hw_shape.at(kNdDimIndexN);
auto h = hw_shape.at(kNdDimIndexH);
auto w = hw_shape.at(kNdDimIndexW);
auto hw = h * w; auto hw = h * w;


auto shape_size = args.dst_shape.size(); auto shape_size = args.dst_shape.size();
auto h1 = args.dst_shape[shape_size - 4];
auto w1 = args.dst_shape[shape_size - 3];
auto h0 = args.dst_shape[shape_size - 2];
auto w0 = args.dst_shape[shape_size - 1];
auto h1 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0W1H1];
auto w1 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0W1];
auto h0 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0];
auto w0 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0];
auto h0w0 = h0 * w0; auto h0w0 = h0 * w0;
auto w1h0w0 = w1 * h0w0; auto w1h0w0 = w1 * h0w0;
auto h1w1h0w0 = h1 * w1h0w0; auto h1w1h0w0 = h1 * w1h0w0;
@@ -155,8 +172,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
auto src_offset = (src_h_head + w1_idx * w0) * size; auto src_offset = (src_h_head + w1_idx * w0) * size;
auto dst_offset = (h0_head + w1_idx * h0w0) * size; auto dst_offset = (h0_head + w1_idx * h0w0) * size;
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size * w0)); static_cast<size_t>(size * w0));
if (ret != EOK) { if (ret != EOK) {
@@ -171,8 +188,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
auto src_offset = (src_h_head + src_w_idx) * size; auto src_offset = (src_h_head + src_w_idx) * size;
auto dst_offset = (w0_head + w0_idx) * size; auto dst_offset = (w0_head + w0_idx) * size;
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size)); static_cast<size_t>(size));
if (ret != EOK) { if (ret != EOK) {
@@ -205,16 +222,16 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
} }


// The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D
auto times = dst_hw_shape.at(0);
auto h = dst_hw_shape.at(1);
auto w = dst_hw_shape.at(2);
auto times = dst_hw_shape.at(kNdDimIndexN);
auto h = dst_hw_shape.at(kNdDimIndexH);
auto w = dst_hw_shape.at(kNdDimIndexW);
auto hw = h * w; auto hw = h * w;


auto shape_size = args.src_shape.size(); auto shape_size = args.src_shape.size();
auto h1 = args.src_shape[shape_size - 4];
auto w1 = args.src_shape[shape_size - 3];
auto h0 = args.src_shape[shape_size - 2];
auto w0 = args.src_shape[shape_size - 1];
auto h1 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0W1H1];
auto w1 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0W1];
auto h0 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0];
auto w0 = args.src_shape[shape_size - kFZzDimCountBackwardsW0];
auto h0w0 = h0 * w0; auto h0w0 = h0 * w0;
auto w1h0w0 = w1 * h0w0; auto w1h0w0 = w1 * h0w0;
auto h1w1h0w0 = h1 * w1h0w0; auto h1w1h0w0 = h1 * w1h0w0;
@@ -233,8 +250,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
auto src_offset = (h0_head + w1_idx * h0w0) * size; auto src_offset = (h0_head + w1_idx * h0w0) * size;
auto dst_offset = (dst_h_head + w1_idx * w0) * size; auto dst_offset = (dst_h_head + w1_idx * w0) * size;
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size * w0)); static_cast<size_t>(size * w0));
if (ret != EOK) { if (ret != EOK) {
@@ -249,8 +266,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
auto dst_w_idx = w1_head + w0_idx; auto dst_w_idx = w1_head + w0_idx;
auto dst_offset = (dst_h_head + dst_w_idx) * size; auto dst_offset = (dst_h_head + dst_w_idx) * size;
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size)); static_cast<size_t>(size));
if (ret != EOK) { if (ret != EOK) {


+ 0
- 1
ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc View File

@@ -35,7 +35,6 @@
* Padding to (N, ceil(Z/16)*16) * Padding to (N, ceil(Z/16)*16)
* Last Step: View the (N, ceil(Z/16)*16) as 4D (N/16, 16, C/16, 16) and transpose to (C/16, N/16, 16, 16) * Last Step: View the (N, ceil(Z/16)*16) as 4D (N/16, 16, C/16, 16) and transpose to (C/16, N/16, 16, 16)
*/ */

namespace ge { namespace ge {
namespace formats { namespace formats {
namespace { namespace {


+ 13
- 12
ge/common/formats/format_transfers/format_transfer_transpose.cc View File

@@ -19,6 +19,7 @@
#include <securec.h> #include <securec.h>
#include <memory> #include <memory>


#include "common/formats/utils/formats_definitions.h"
#include "common/formats/utils/formats_trans_utils.h" #include "common/formats/utils/formats_trans_utils.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h" #include "framework/common/debug/log.h"
@@ -29,21 +30,21 @@ namespace formats {
namespace { namespace {
std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{ std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{
{FORMAT_NCHW, {FORMAT_NCHW,
{{FORMAT_NHWC, std::vector<int64_t>({0, 2, 3, 1})},
{FORMAT_HWCN, std::vector<int64_t>({2, 3, 1, 0})},
{FORMAT_CHWN, std::vector<int64_t>({1, 2, 3, 0})}}},
{{FORMAT_NHWC, std::vector<int64_t>({kNchwN, kNchwH, kNchwW, kNchwC})},
{FORMAT_HWCN, std::vector<int64_t>({kNchwH, kNchwW, kNchwC, kNchwN})},
{FORMAT_CHWN, std::vector<int64_t>({kNchwC, kNchwH, kNchwW, kNchwN})}}},
{FORMAT_NHWC, {FORMAT_NHWC,
{{FORMAT_NCHW, std::vector<int64_t>({0, 3, 1, 2})},
{FORMAT_CHWN, std::vector<int64_t>({3, 1, 2, 0})},
{FORMAT_HWCN, std::vector<int64_t>({1, 2, 3, 0})}}},
{{FORMAT_NCHW, std::vector<int64_t>({kNhwcN, kNhwcC, kNhwcH, kNhwcW})},
{FORMAT_CHWN, std::vector<int64_t>({kNhwcC, kNhwcH, kNhwcW, kNhwcN})},
{FORMAT_HWCN, std::vector<int64_t>({kNhwcH, kNhwcW, kNhwcC, kNhwcN})}}},
{FORMAT_HWCN, {FORMAT_HWCN,
{{FORMAT_NCHW, std::vector<int64_t>({3, 2, 0, 1})},
{FORMAT_NHWC, std::vector<int64_t>({3, 0, 1, 2})},
{FORMAT_CHWN, std::vector<int64_t>({2, 0, 1, 3})}}},
{{FORMAT_NCHW, std::vector<int64_t>({kHwcnN, kHwcnC, kHwcnH, kHwcnW})},
{FORMAT_NHWC, std::vector<int64_t>({kHwcnN, kHwcnH, kHwcnW, kHwcnC})},
{FORMAT_CHWN, std::vector<int64_t>({kHwcnC, kHwcnH, kHwcnW, kHwcnN})}}},
{FORMAT_CHWN, {FORMAT_CHWN,
{{FORMAT_NCHW, std::vector<int64_t>({3, 0, 1, 2})},
{FORMAT_NHWC, std::vector<int64_t>({3, 1, 2, 0})},
{FORMAT_HWCN, std::vector<int64_t>({1, 2, 0, 3})}}},
{{FORMAT_NCHW, std::vector<int64_t>({kChwnN, kChwnC, kChwnH, kChwnW})},
{FORMAT_NHWC, std::vector<int64_t>({kChwnN, kChwnH, kChwnW, kChwnC})},
{FORMAT_HWCN, std::vector<int64_t>({kChwnH, kChwnW, kChwnC, kChwnN})}}},
}; };


bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) {


+ 9
- 0
ge/common/formats/utils/formats_definitions.h View File

@@ -23,6 +23,7 @@ static const int kCubeSize = 16;
static const int kNiSize = 16; static const int kNiSize = 16;
static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL; static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL;



enum NchwDimIndex { enum NchwDimIndex {
kNchwN, kNchwN,
kNchwC, kNchwC,
@@ -47,6 +48,14 @@ enum HwcnDimIndex {
kHwcnDimsNum kHwcnDimsNum
}; };


// Dimension positions within a CHWN-ordered shape: C = 0, H = 1, W = 2, N = 3.
// kChwnDimsNum is the trailing enumerator, so its value (4) equals the number of dimensions.
enum ChwnDimIndex {
kChwnC,
kChwnH,
kChwnW,
kChwnN,
kChwnDimsNum
};

enum Nc1hwc0DimIndex { enum Nc1hwc0DimIndex {
kNc1hwc0N, kNc1hwc0N,
kNc1hwc0C1, kNc1hwc0C1,


+ 7
- 1
ge/common/ge/plugin_manager.cc View File

@@ -123,7 +123,10 @@ Status PluginManager::LoadSo(const string &path, const vector<string> &func_chec
if (handle == nullptr) { if (handle == nullptr) {
const char *error = mmDlerror(); const char *error = mmDlerror();
GE_IF_BOOL_EXEC(error == nullptr, error = ""); GE_IF_BOOL_EXEC(error == nullptr, error = "");
GELOGE(GE_PLGMGR_PATH_INVALID, "Failed to dlopen %s!", error);
ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
{"mmDlopen", "shared library path is " + FmtToStr(file_path_dlopen) + ". Errormessage" + FmtToStr(error)});
GELOGE(GE_PLGMGR_PATH_INVALID, "Failed to dlopen the shared library path[%s]. Errormessage[%s]!",
file_path_dlopen.c_str(), error);
continue; continue;
} }


@@ -132,6 +135,9 @@ Status PluginManager::LoadSo(const string &path, const vector<string> &func_chec
for (const auto &func_name : func_check_list) { for (const auto &func_name : func_check_list) {
auto real_fn = (void (*)())mmDlsym(handle, const_cast<char *>(func_name.c_str())); auto real_fn = (void (*)())mmDlsym(handle, const_cast<char *>(func_name.c_str()));
if (real_fn == nullptr) { if (real_fn == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
{"mmDlsym", FmtToStr(func_name) + " is skipped since function" +
FmtToStr(func_name) + " is not existed!"});
GELOGE(GE_PLGMGR_PATH_INVALID, "%s is skipped since function %s is not existed!", func_name.c_str(), GELOGE(GE_PLGMGR_PATH_INVALID, "%s is skipped since function %s is not existed!", func_name.c_str(),
func_name.c_str()); func_name.c_str());
is_valid = false; is_valid = false;


+ 3
- 1
ge/common/ge/tbe_plugin_manager.cc View File

@@ -37,6 +37,8 @@
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"


namespace ge { namespace ge {
const int kBaseInt = 10;

std::map<string, string> TBEPluginManager::options_ = {}; std::map<string, string> TBEPluginManager::options_ = {};


// Get Singleton Instance // Get Singleton Instance
@@ -155,7 +157,7 @@ void TBEPluginManager::GetCustomOpPath(std::string &customop_path) {
domi::FrameworkType type = domi::TENSORFLOW; domi::FrameworkType type = domi::TENSORFLOW;
auto it = options_.find(FRAMEWORK_TYPE); auto it = options_.find(FRAMEWORK_TYPE);
if (it != options_.end()) { if (it != options_.end()) {
type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, 10));
type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, kBaseInt));
} }
fmk_type = ge::TypeUtils::FmkTypeToSerialString(type); fmk_type = ge::TypeUtils::FmkTypeToSerialString(type);
GELOGI("Framework type is %s.", fmk_type.c_str()); GELOGI("Framework type is %s.", fmk_type.c_str());


+ 1
- 0
ge/common/ge_common.mk View File

@@ -7,6 +7,7 @@ GE_COMMON_LOCAL_SRC_FILES := \
helper/om_file_helper.cc \ helper/om_file_helper.cc \
helper/model_helper.cc \ helper/model_helper.cc \
../model/ge_model.cc \ ../model/ge_model.cc \
../model/ge_root_model.cc \
auth/file_saver.cc \ auth/file_saver.cc \
fp16_t.cc \ fp16_t.cc \
math/fp16_math.cc \ math/fp16_math.cc \


+ 396
- 31
ge/common/helper/model_helper.cc View File

@@ -32,6 +32,7 @@ using domi::ModelTaskDef;


namespace { namespace {
const int64_t kOriginalOmPartitionNum = 1; const int64_t kOriginalOmPartitionNum = 1;
const uint32_t kStatiOmFileModelNum = 1;
} }




@@ -39,7 +40,7 @@ namespace ge {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); }


Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, ModelPartitionType type, Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, ModelPartitionType type,
const uint8_t *data, size_t size) {
const uint8_t *data, size_t size, size_t model_index) {
if (size < 1 || size > UINT32_MAX) { if (size < 1 || size > UINT32_MAX) {
GELOGE(PARAM_INVALID, "Add model partition failed, partition size %zu invalid", size); GELOGE(PARAM_INVALID, "Add model partition failed, partition size %zu invalid", size);
if (size > UINT32_MAX) { if (size > UINT32_MAX) {
@@ -68,25 +69,16 @@ Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_fil
partition_model.data = const_cast<uint8_t *>(data); partition_model.data = const_cast<uint8_t *>(data);
partition_model.size = static_cast<uint32_t>(size); partition_model.size = static_cast<uint32_t>(size);
partition_model.type = type; partition_model.type = type;
if (om_file_save_helper->AddPartition(partition_model) != SUCCESS) {
if (om_file_save_helper->AddPartition(partition_model, model_index) != SUCCESS) {
GELOGE(PARAM_INVALID, "Add model partition failed, partition size %zu", size); GELOGE(PARAM_INVALID, "Add model partition failed, partition size %zu", size);
return PARAM_INVALID; return PARAM_INVALID;
} }
return SUCCESS; return SUCCESS;
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmModel(const GeModelPtr &ge_model,
const SaveParam &save_param,
const std::string &output_file,
ModelBufferData& model) {
if (output_file.empty()) {
GELOGE(FAILED, "GraphBuilder SaveModel received invalid file name prefix");
return FAILED;
}


GE_IF_BOOL_EXEC(ge_model == nullptr, GELOGE(FAILED, "Ge_model is nullptr"); return FAILED);
std::shared_ptr<OmFileSaveHelper> om_file_save_helper = ge::MakeShared<OmFileSaveHelper>();
GE_CHECK_NOTNULL(om_file_save_helper);
Status ModelHelper::SaveModelDef(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper,
const GeModelPtr &ge_model, ge::Buffer &model_buffer, size_t model_index) {
ModelPtr model_tmp = ge::MakeShared<ge::Model>(ge_model->GetName(), ge_model->GetPlatformVersion()); ModelPtr model_tmp = ge::MakeShared<ge::Model>(ge_model->GetName(), ge_model->GetPlatformVersion());
if (model_tmp == nullptr) { if (model_tmp == nullptr) {
GELOGE(FAILED, "Create Model %s Ptr failed", ge_model->GetName().c_str()); GELOGE(FAILED, "Create Model %s Ptr failed", ge_model->GetName().c_str());
@@ -96,16 +88,21 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod
model_tmp->SetVersion(ge_model->GetVersion()); model_tmp->SetVersion(ge_model->GetVersion());
model_tmp->SetAttr(ge_model->MutableAttrMap()); model_tmp->SetAttr(ge_model->MutableAttrMap());


ge::Buffer model_buffer;
(void)model_tmp->Save(model_buffer); (void)model_tmp->Save(model_buffer);
GELOGD("MODEL_DEF size is %zu", model_buffer.GetSize()); GELOGD("MODEL_DEF size is %zu", model_buffer.GetSize());
if (model_buffer.GetSize() > 0) { if (model_buffer.GetSize() > 0) {
if (SaveModelPartition(om_file_save_helper, ModelPartitionType::MODEL_DEF, model_buffer.GetData(), if (SaveModelPartition(om_file_save_helper, ModelPartitionType::MODEL_DEF, model_buffer.GetData(),
model_buffer.GetSize()) != SUCCESS) {
model_buffer.GetSize(), model_index) != SUCCESS) {
GELOGE(PARAM_INVALID, "Add model graph partition failed"); GELOGE(PARAM_INVALID, "Add model graph partition failed");
return PARAM_INVALID; return PARAM_INVALID;
} }
} }
return SUCCESS;
}

Status ModelHelper::SaveModelWeights(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper,
const GeModelPtr &ge_model, size_t model_index) {
auto ge_model_weight = ge_model->GetWeight(); auto ge_model_weight = ge_model->GetWeight();
GELOGD("WEIGHTS_DATA size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); GELOGD("WEIGHTS_DATA size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData());
// weight is not necessary // weight is not necessary
@@ -113,31 +110,43 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod
GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper,
ModelPartitionType::WEIGHTS_DATA, ModelPartitionType::WEIGHTS_DATA,
ge_model_weight.GetData(), ge_model_weight.GetData(),
ge_model_weight.GetSize()), "Add weight partition failed");
ge_model_weight.GetSize(), model_index), "Add weight partition failed");
} }
return SUCCESS;
}


Status ModelHelper::SaveModelTbeKernel(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper,
const GeModelPtr &ge_model, size_t model_index) {
TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore(); TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore();
GELOGD("TBE_KERNELS size is %zu", tbe_kernel_store.DataSize()); GELOGD("TBE_KERNELS size is %zu", tbe_kernel_store.DataSize());
if (tbe_kernel_store.DataSize() > 0) { if (tbe_kernel_store.DataSize() > 0) {
GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper,
ModelPartitionType::TBE_KERNELS,
tbe_kernel_store.Data(),
tbe_kernel_store.DataSize()), "Add tbe kernel partition failed");
GE_CHK_STATUS_RET(
SaveModelPartition(om_file_save_helper, ModelPartitionType::TBE_KERNELS,
ge_model->GetTBEKernelStore().Data(), ge_model->GetTBEKernelStore().DataSize(),
model_index), "Add tbe kernel partition failed");
} }

// no need to check value, DATA->NetOutput // no need to check value, DATA->NetOutput
(void)tbe_kernel_store.Load(tbe_kernel_store.Data(), tbe_kernel_store.DataSize()); (void)tbe_kernel_store.Load(tbe_kernel_store.Data(), tbe_kernel_store.DataSize());


return SUCCESS;
}

Status ModelHelper::SaveModelCustAICPU(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper,
const GeModelPtr &ge_model, size_t model_index) {
CustAICPUKernelStore cust_aicpu_kernel_store = ge_model->GetCustAICPUKernelStore(); CustAICPUKernelStore cust_aicpu_kernel_store = ge_model->GetCustAICPUKernelStore();
GELOGD("cust aicpu kernels size is %zu", cust_aicpu_kernel_store.DataSize()); GELOGD("cust aicpu kernels size is %zu", cust_aicpu_kernel_store.DataSize());
if (cust_aicpu_kernel_store.DataSize() > 0) { if (cust_aicpu_kernel_store.DataSize() > 0) {
GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper,
ModelPartitionType::CUST_AICPU_KERNELS, ModelPartitionType::CUST_AICPU_KERNELS,
cust_aicpu_kernel_store.Data(),
cust_aicpu_kernel_store.DataSize()),
ge_model->GetCustAICPUKernelStore().Data(),
cust_aicpu_kernel_store.DataSize(), model_index),
"Add cust aicpu kernel partition failed"); "Add cust aicpu kernel partition failed");
} }
return SUCCESS;
}


Status ModelHelper::SaveModelTaskDef(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper,
const GeModelPtr &ge_model, ge::Buffer &task_buffer, size_t model_index) {
std::shared_ptr<ModelTaskDef> model_task_def = ge_model->GetModelTaskDefPtr(); std::shared_ptr<ModelTaskDef> model_task_def = ge_model->GetModelTaskDefPtr();
if (model_task_def == nullptr) { if (model_task_def == nullptr) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Create model task def ptr failed"); GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Create model task def ptr failed");
@@ -146,9 +155,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod
size_t partition_task_size = model_task_def->ByteSizeLong(); size_t partition_task_size = model_task_def->ByteSizeLong();
GE_IF_BOOL_EXEC(partition_task_size == 0 || partition_task_size > INT_MAX, GE_IF_BOOL_EXEC(partition_task_size == 0 || partition_task_size > INT_MAX,
GELOGE(FAILED, "Model_def's byte size (%zu) is invalid!", partition_task_size); GELOGE(FAILED, "Model_def's byte size (%zu) is invalid!", partition_task_size);
return FAILED);
return FAILED);


ge::Buffer task_buffer(partition_task_size);
task_buffer = ge::Buffer(partition_task_size);
if (task_buffer.GetSize() == 0) { if (task_buffer.GetSize() == 0) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc model task def buffer failed"); GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc model task def buffer failed");
return ACL_ERROR_GE_MEMORY_ALLOCATION; return ACL_ERROR_GE_MEMORY_ALLOCATION;
@@ -159,21 +168,28 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod
GELOGD("TASK_INFO size is %zu", partition_task_size); GELOGD("TASK_INFO size is %zu", partition_task_size);


if (SaveModelPartition(om_file_save_helper, ModelPartitionType::TASK_INFO, task_buffer.GetData(), if (SaveModelPartition(om_file_save_helper, ModelPartitionType::TASK_INFO, task_buffer.GetData(),
partition_task_size) != SUCCESS) {
partition_task_size, model_index) != SUCCESS) {
GELOGE(PARAM_INVALID, "Add model task def partition failed"); GELOGE(PARAM_INVALID, "Add model task def partition failed");
return PARAM_INVALID; return PARAM_INVALID;
} }
return SUCCESS;
}

Status ModelHelper::SaveModelHeader(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper,
const GeModelPtr &ge_model, size_t model_num) {
// Save target/version to model_header // Save target/version to model_header
ModelFileHeader &model_header = om_file_save_helper->GetModelFileHeader(); ModelFileHeader &model_header = om_file_save_helper->GetModelFileHeader();
model_header.platform_type = ge_model->GetPlatformType(); model_header.platform_type = ge_model->GetPlatformType();
model_header.om_ir_version = ge_model->GetVersion(); model_header.om_ir_version = ge_model->GetVersion();
model_header.model_num = model_num;
std::string platform_version = ge_model->GetPlatformVersion(); std::string platform_version = ge_model->GetPlatformVersion();


errno_t err; errno_t err;
err = memcpy_s(model_header.platform_version, PLATFORM_VERSION_LEN, platform_version.c_str(), err = memcpy_s(model_header.platform_version, PLATFORM_VERSION_LEN, platform_version.c_str(),
platform_version.size() + 1); platform_version.size() + 1);
if (err != EOK) { if (err != EOK) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "ModelHelper SaveModel failed while allocating memory for platform_version.");
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION,
"ModelHelper SaveModel failed while allocating memory for platform_version.");
return ACL_ERROR_GE_MEMORY_ALLOCATION; return ACL_ERROR_GE_MEMORY_ALLOCATION;
} }
string version = reinterpret_cast<char *>(model_header.platform_version); string version = reinterpret_cast<char *>(model_header.platform_version);
@@ -188,8 +204,142 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod
} }
string model_name = reinterpret_cast<char *>(model_header.name); string model_name = reinterpret_cast<char *>(model_header.name);
GELOGD("Model name save:%s", model_name.c_str()); GELOGD("Model name save:%s", model_name.c_str());
return SUCCESS;
}

// Adds every partition of one GeModel to the save helper under the slot
// `model_index`, in this fixed order: model def, weights, TBE kernels,
// custom AI CPU kernels, task def. The first step that fails is logged and
// FAILED is returned; the remaining steps are not attempted.
Status ModelHelper::SaveAllModelPartiton(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper,
                                         const GeModelPtr &ge_model, ge::Buffer &model_buffer,
                                         ge::Buffer &task_buffer, size_t model_index) {
  Status step_status = SaveModelDef(om_file_save_helper, ge_model, model_buffer, model_index);
  if (step_status != SUCCESS) {
    GELOGE(FAILED, "save model def failed");
    return FAILED;
  }

  step_status = SaveModelWeights(om_file_save_helper, ge_model, model_index);
  if (step_status != SUCCESS) {
    GELOGE(FAILED, "save model weights failed");
    return FAILED;
  }

  step_status = SaveModelTbeKernel(om_file_save_helper, ge_model, model_index);
  if (step_status != SUCCESS) {
    GELOGE(FAILED, "save model tbe kernel failed");
    return FAILED;
  }

  step_status = SaveModelCustAICPU(om_file_save_helper, ge_model, model_index);
  if (step_status != SUCCESS) {
    GELOGE(FAILED, "save model cust ai cpu failed");
    return FAILED;
  }

  step_status = SaveModelTaskDef(om_file_save_helper, ge_model, task_buffer, model_index);
  if (step_status != SUCCESS) {
    GELOGE(FAILED, "save task def failed");
    return FAILED;
  }

  return SUCCESS;
}

// Serializes a single GeModel into an om model.
//
// Parameters:
//   ge_model    - model to save; must not be null.
//   save_param  - forwarded unchanged to OmFileSaveHelper::SaveModel.
//   output_file - output file name prefix; must not be empty.
//   model       - forwarded to OmFileSaveHelper::SaveModel.
// Returns SUCCESS, FAILED for invalid arguments, or the status of the first
// failing step.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmModel(const GeModelPtr &ge_model,
                                                                                   const SaveParam &save_param,
                                                                                   const std::string &output_file,
                                                                                   ModelBufferData& model) {
  if (output_file.empty()) {
    GELOGE(FAILED, "GraphBuilder SaveModel received invalid file name prefix");
    return FAILED;
  }

  GE_IF_BOOL_EXEC(ge_model == nullptr, GELOGE(FAILED, "Ge_model is nullptr"); return FAILED);
  std::shared_ptr<OmFileSaveHelper> om_file_save_helper = ge::MakeShared<OmFileSaveHelper>();
  GE_CHECK_NOTNULL(om_file_save_helper);
  // Declared here so they stay alive until SaveModel() below has run.
  // NOTE(review): presumably the partitions recorded in the helper point into these buffers - confirm.
  ge::Buffer model_buffer;
  ge::Buffer task_buffer;

  auto ret = SaveAllModelPartiton(om_file_save_helper, ge_model, model_buffer, task_buffer);
  if (ret != SUCCESS) {
    GELOGE(ret, "save all model partition failed");
    return ret;
  }

  ret = SaveModelHeader(om_file_save_helper, ge_model);
  if (ret != SUCCESS) {
    GELOGE(ret, "save model header failed");
    return ret;
  }

  ret = om_file_save_helper->SaveModel(save_param, output_file.c_str(), model, is_offline_);
  if (ret != SUCCESS) {
    GELOGE(FAILED, "OmFileSaveHelper SaveModel return fail.");
    return ret;
  }
  return SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmRootModel(
const GeRootModelPtr &ge_root_model,
const SaveParam &save_param,
const std::string &output_file,
ModelBufferData& model,
bool is_unknown_shape) {

GE_CHECK_NOTNULL(ge_root_model);
GE_IF_BOOL_EXEC(ge_root_model == nullptr, GELOGE(FAILED, "Ge_root_model is nullptr"); return FAILED);

auto &name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
GE_IF_BOOL_EXEC(name_to_ge_model.empty(), GELOGE(FAILED, "Ge_root_model has no sub model"); return FAILED);
GE_IF_BOOL_EXEC(output_file.empty(),
GELOGE(FAILED, "GraphBuilder SaveModel received invalid file name prefix");
return FAILED);

if (!is_unknown_shape) {
auto &model_root = name_to_ge_model.begin()->second;
return SaveToOmModel(model_root, save_param, output_file, model);
}

std::shared_ptr<OmFileSaveHelper> om_file_save_helper = ge::MakeShared<OmFileSaveHelper>();
GE_CHECK_NOTNULL(om_file_save_helper);

auto &first_ge_model = name_to_ge_model.at(ge_root_model->GetRootGraph()->GetName());

// ge root model must be the first to be loaded
vector<string> model_names{ge_root_model->GetRootGraph()->GetName()};
for (auto &item : name_to_ge_model) {
if (item.first != model_names.front()) {
model_names.emplace_back(item.first);
}
}
vector<ge::Buffer> model_buffers(model_names.size());
vector<ge::Buffer> task_buffers(model_names.size());

size_t cur_index = 0;

if (model_names.size() > 1) {
GELOGD("only save first model MODEL_DEF");
if (SaveModelDef(om_file_save_helper, first_ge_model, model_buffers[cur_index], cur_index) != SUCCESS) {
GELOGE(FAILED, "save model def failed");
return FAILED;
}
++cur_index;
}

for (; cur_index < model_names.size(); ++cur_index) {
auto model_name = model_names[cur_index];
GELOGD("cur model %s index is %zu", model_name.c_str(), cur_index);
const GeModelPtr &ge_model = name_to_ge_model.at(model_name);
auto ret = SaveAllModelPartiton(om_file_save_helper, ge_model, model_buffers[cur_index],
task_buffers[cur_index], cur_index);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Save model %s failed", model_name.c_str());
return INTERNAL_ERROR;
}
}

auto ret = SaveModelHeader(om_file_save_helper, first_ge_model, model_names.size());
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Save model %s header failed", first_ge_model->GetName().c_str());
return INTERNAL_ERROR;
}

ret = om_file_save_helper->SaveRootModel(save_param, output_file.c_str(), model, is_offline_);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(FAILED, "OmFileSaveHelper SaveModel return fail."); GELOGE(FAILED, "OmFileSaveHelper SaveModel return fail.");
return FAILED; return FAILED;
@@ -288,7 +438,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c
} }


file_header_ = reinterpret_cast<ModelFileHeader *>(model_data.model_data); file_header_ = reinterpret_cast<ModelFileHeader *>(model_data.model_data);

OmFileLoadHelper om_load_helper; OmFileLoadHelper om_load_helper;
status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_); status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_);
if (status != SUCCESS) { if (status != SUCCESS) {
@@ -310,7 +459,61 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c
GELOGE(status, "GenerateGeModel failed"); GELOGE(status, "GenerateGeModel failed");
return status; return status;
} }
GELOGD("in ModelHelper::LoadModel, is_assign_model_ is setted to true!");
is_assign_model_ = true;
return SUCCESS;
}

// Loads an om file that may contain either a single model or a root model
// made of several sub-models, and builds root_model_ from it.
//
// Returns:
//   GE_EXEC_MODEL_DATA_SIZE_INVALID - model_data is null or has zero length.
//   GE_EXEC_LOAD_MODEL_REPEATED     - this helper has already loaded a model.
//   INTERNAL_ERROR                  - releasing previous local data failed.
//   otherwise the status of parsing / partition-table init / model generation.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadRootModel(const ge::ModelData &model_data) {
  if (model_data.model_data == nullptr || model_data.model_len == 0) {
    GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "Model_data is nullptr, or model_data_size is 0");
    return GE_EXEC_MODEL_DATA_SIZE_INVALID;
  }

  if (is_assign_model_) {
    GELOGE(GE_EXEC_LOAD_MODEL_REPEATED, "Model helper has already loaded!");
    return GE_EXEC_LOAD_MODEL_REPEATED;
  }

  if (ReleaseLocalModelData() != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "ReleaseLocalModelData failed.");
    return INTERNAL_ERROR;
  }

  Status status = ge::DavinciModelParser::ParseModelContent(model_data, model_addr_tmp_, model_len_tmp_);
  if (status != SUCCESS) {
    GELOGE(status, "Parse model content failed!");
    return status;
  }

  file_header_ = reinterpret_cast<ModelFileHeader *>(model_data.model_data);

  // Model version 1.0 file headers do not have the model_num member, so
  // model_num is only consulted when the version is recent enough.
  is_unknown_shape_model_ = file_header_->version >= ge::MODEL_VERSION &&
                            file_header_->model_num > kStatiOmFileModelNum;
  // The cast makes the argument match the %zu specifier whatever the
  // declared width of ModelFileHeader::version is.
  GELOGD("cur om model is ge root model or no %d, model version %zu", is_unknown_shape_model_,
         static_cast<size_t>(file_header_->version));

  OmFileLoadHelper om_load_helper;
  if (is_unknown_shape_model_) {
    auto model_num = file_header_->model_num;
    status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_, model_num);
  } else {
    status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_);
  }
  if (status != SUCCESS) {
    GELOGE(status, "Om_load_helper init failed");
    model_addr_tmp_ = nullptr;
    return status;
  }
  // Encrypted models need the temporary model deleted; unencrypted models do not.
  model_addr_tmp_ = nullptr;

  status = GenerateGeRootModel(om_load_helper);
  if (status != SUCCESS) {
    GELOGE(status, "GenerateGeRootModel failed");
    return status;
  }
  GELOGD("in ModelHelper::LoadRootModel, is_assign_model_ is setted to true!");
  is_assign_model_ = true;
  return SUCCESS;
}
@@ -341,6 +544,61 @@ Status ModelHelper::GenerateGeModel(OmFileLoadHelper &om_load_helper) {
return SUCCESS; return SUCCESS;
} }


// Builds root_model_ from the partitions held by `om_load_helper`.
//
// For a non-unknown-shape file the single model is generated through
// GenerateGeModel() and its graph becomes the root graph.
// Otherwise file_header_->model_num models are read. Model 0 contributes only
// its model definition (root graph and model id) and is kept in model_. Every
// later model also gets weights, TBE kernels, custom AI CPU kernels and task
// info, and is registered in root_model_ under its own name.
//
// Returns SUCCESS, FAILED, or a GE_EXEC_LOAD_*_FAILED code naming the
// partition kind that could not be loaded.
Status ModelHelper::GenerateGeRootModel(OmFileLoadHelper &om_load_helper) {
  GELOGD("Begin to generate ge root model");
  root_model_ = ge::MakeShared<ge::GeRootModel>();
  GE_CHECK_NOTNULL(root_model_);
  if (!is_unknown_shape_model_) {
    if (GenerateGeModel(om_load_helper) != SUCCESS) {
      GELOGE(FAILED, "GenerateGeModel failed");
      return FAILED;
    }
    GE_CHECK_NOTNULL(model_);
    root_model_->SetRootGraph(GraphUtils::GetComputeGraph(model_->GetGraph()));
    return SUCCESS;
  }

  // The loop bound is read through file_header_, so it must be set.
  GE_CHECK_NOTNULL(file_header_);
  for (size_t mode_index = 0; mode_index < file_header_->model_num; ++mode_index) {
    GeModelPtr cur_model = ge::MakeShared<ge::GeModel>();
    // Every loader below dereferences cur_model, so stop here if creation failed.
    GE_CHECK_NOTNULL(cur_model);
    Status ret = LoadModelData(om_load_helper, cur_model, mode_index);
    if (ret != SUCCESS) {
      return GE_EXEC_LOAD_MODEL_PARTITION_FAILED;
    }

    // Only the model definition is read for the first (root) model.
    if (mode_index == 0) {
      root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph()));
      root_model_->SetModelId(cur_model->GetModelId());
      model_ = cur_model;
      continue;
    }

    ret = LoadWeights(om_load_helper, cur_model, mode_index);
    if (ret != SUCCESS) {
      return GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED;
    }

    ret = LoadTBEKernelStore(om_load_helper, cur_model, mode_index);
    if (ret != SUCCESS) {
      return GE_EXEC_LOAD_KERNEL_PARTITION_FAILED;
    }

    ret = LoadCustAICPUKernelStore(om_load_helper, cur_model, mode_index);
    if (ret != SUCCESS) {
      return GE_EXEC_LOAD_KERNEL_PARTITION_FAILED;
    }

    ret = LoadTask(om_load_helper, cur_model, mode_index);
    if (ret != SUCCESS) {
      return GE_EXEC_LOAD_TASK_PARTITION_FAILED;
    }
    root_model_->SetSubgraphInstanceNameToModel(cur_model->GetName(), cur_model);
  }

  return SUCCESS;
}

Status ModelHelper::LoadModelData(OmFileLoadHelper &om_load_helper) { Status ModelHelper::LoadModelData(OmFileLoadHelper &om_load_helper) {
ModelPartition partition_model_def; ModelPartition partition_model_def;
// no need to check value, DATA->NetOutput // no need to check value, DATA->NetOutput
@@ -366,6 +624,28 @@ void ModelHelper::SetModelToGeModel(ge::Model &model) {
model_->SetAttr(model.MutableAttrMap()); model_->SetAttr(model.MutableAttrMap());
} }


// Reads the MODEL_DEF partition of model `mode_index`, parses it into a
// ge::Model and copies graph, name, version, platform version and attributes
// onto `cur_model`. Returns INTERNAL_ERROR when parsing fails.
Status ModelHelper::LoadModelData(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index) {
  // no need to check value, DATA->NetOutput
  ModelPartition model_def_part;
  om_load_helper.GetModelPartition(ModelPartitionType::MODEL_DEF, model_def_part, mode_index);
  GELOGD("Model_def partition addr:%p,size:%u", model_def_part.data, model_def_part.size);

  ge::Model parsed_model;
  const auto load_status = ge::Model::Load(model_def_part.data, model_def_part.size, parsed_model);
  if (load_status != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Load model failed.");
    return INTERNAL_ERROR;
  }

  // Transfer the parsed definition onto the caller's model.
  cur_model->SetGraph(parsed_model.GetGraph());
  cur_model->SetName(parsed_model.GetName());
  cur_model->SetVersion(parsed_model.GetVersion());
  cur_model->SetPlatformVersion(parsed_model.GetPlatformVersion());
  cur_model->SetAttr(parsed_model.MutableAttrMap());

  return SUCCESS;
}


Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper) { Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper) {
ModelPartition partition; ModelPartition partition;
if (om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition) != SUCCESS) { if (om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition) != SUCCESS) {
@@ -379,6 +659,19 @@ Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper) {
return SUCCESS; return SUCCESS;
} }


// Copies the WEIGHTS_DATA partition of model `mode_index` into a ge::Buffer
// and stores it on `cur_model`. Returns FAILED when the partition lookup
// reports an error.
Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index) {
  ModelPartition weights_part;
  const Status lookup_status =
      om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, weights_part, mode_index);
  if (lookup_status != SUCCESS) {
    GELOGE(FAILED, "Get weight model partition failed.");
    return FAILED;
  }

  ge::Buffer weight_copy = ge::Buffer::CopyFrom(weights_part.data, weights_part.size);
  cur_model->SetWeight(weight_copy);

  GELOGD("GetWeight size:%u", weights_part.size);
  return SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper) { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper) {
ModelPartition task_partition; ModelPartition task_partition;
if (om_load_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition) != SUCCESS) { if (om_load_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition) != SUCCESS) {
@@ -398,6 +691,27 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(Om
return SUCCESS; return SUCCESS;
} }


// Reads the TASK_INFO partition of model `mode_index` and attaches the
// resulting ModelTaskDef to `cur_model`. An empty partition is accepted and
// yields a default-constructed task definition.
// Returns FAILED when the partition cannot be obtained and INTERNAL_ERROR
// when a non-empty partition cannot be parsed.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper,
                                                                              GeModelPtr &cur_model,
                                                                              size_t mode_index) {
  ModelPartition task_part;
  const Status lookup_status =
      om_load_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_part, mode_index);
  if (lookup_status != SUCCESS) {
    GELOGE(FAILED, "Get task model partition failed.");
    return FAILED;
  }

  std::shared_ptr<ModelTaskDef> task_def = ge::MakeShared<ModelTaskDef>();
  GE_CHECK_NOTNULL(task_def);

  // Parsing is skipped entirely for an empty partition.
  const bool has_payload = (task_part.size != 0);
  if (has_payload && !ReadProtoFromArray(task_part.data, task_part.size, task_def.get())) {
    GELOGE(INTERNAL_ERROR, "ReadProtoFromArray failed.");
    return INTERNAL_ERROR;
  }
  if (has_payload) {
    GELOGD("TASK_INFO op_size:%zu, stream_num:%u", task_def->op().size(), task_def->stream_num());
  }

  cur_model->SetModelTaskDef(task_def);
  return SUCCESS;
}

Status ModelHelper::LoadTBEKernelStore(OmFileLoadHelper &om_load_helper) { Status ModelHelper::LoadTBEKernelStore(OmFileLoadHelper &om_load_helper) {
// Load tbe kernels // Load tbe kernels
ModelPartition partition_kernel_def; ModelPartition partition_kernel_def;
@@ -414,6 +728,23 @@ Status ModelHelper::LoadTBEKernelStore(OmFileLoadHelper &om_load_helper) {
return SUCCESS; return SUCCESS;
} }


// Loads the TBE_KERNELS partition of model `mode_index` into a kernel store
// and sets that store on `cur_model`. Best effort: a failed lookup or a
// failed kernel load leaves the store as it is and SUCCESS is still returned.
Status ModelHelper::LoadTBEKernelStore(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index) {
  ModelPartition kernels_part;
  TBEKernelStore tbe_store;

  const bool lookup_ok =
      om_load_helper.GetModelPartition(ModelPartitionType::TBE_KERNELS, kernels_part, mode_index) == SUCCESS;
  if (lookup_ok) {
    GELOGD("Kernels partition size:%u", kernels_part.size);
    if (!tbe_store.Load(kernels_part.data, kernels_part.size)) {
      GELOGW("Load tbe kernels failed");
    } else {
      GELOGD("Load tbe kernels success");
    }
  }

  cur_model->SetTBEKernelStore(tbe_store);
  return SUCCESS;
}

Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper) { Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper) {
// Load cust aicpu kernels // Load cust aicpu kernels
ModelPartition partition_kernel_def; ModelPartition partition_kernel_def;
@@ -421,19 +752,39 @@ Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper) {
if (om_load_helper.GetModelPartition(ModelPartitionType::CUST_AICPU_KERNELS, partition_kernel_def) == SUCCESS) { if (om_load_helper.GetModelPartition(ModelPartitionType::CUST_AICPU_KERNELS, partition_kernel_def) == SUCCESS) {
GELOGD("Kernels partition size:%u", partition_kernel_def.size); GELOGD("Kernels partition size:%u", partition_kernel_def.size);
if (kernel_store.Load(partition_kernel_def.data, partition_kernel_def.size)) { if (kernel_store.Load(partition_kernel_def.data, partition_kernel_def.size)) {
GELOGI("Load cust aicpu kernels success");
GELOGD("Load cust aicpu kernels success");
} else {
GELOGW("Load cust aicpu kernels failed");
} }
} }
model_->SetCustAICPUKernelStore(kernel_store); model_->SetCustAICPUKernelStore(kernel_store);
return SUCCESS; return SUCCESS;
} }


// Loads the CUST_AICPU_KERNELS partition of model `mode_index` into a kernel
// store and sets that store on `cur_model`. Best effort: a failed lookup or a
// failed kernel load leaves the store as it is and SUCCESS is still returned.
Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper,
                                             GeModelPtr &cur_model, size_t mode_index) {
  ModelPartition kernels_part;
  CustAICPUKernelStore aicpu_store;

  const bool lookup_ok =
      om_load_helper.GetModelPartition(ModelPartitionType::CUST_AICPU_KERNELS, kernels_part, mode_index) == SUCCESS;
  if (lookup_ok) {
    GELOGD("Kernels partition size:%u", kernels_part.size);
    if (!aicpu_store.Load(kernels_part.data, kernels_part.size)) {
      GELOGW("Load cust aicpu kernels failed");
    } else {
      GELOGD("Load cust aicpu kernels success");
    }
  }

  cur_model->SetCustAICPUKernelStore(aicpu_store);
  return SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeModel() { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeModel() {
if (model_ != nullptr) { if (model_ != nullptr) {
return model_; return model_;
} }


GELOGI("Model has not been loaded!");
GELOGD("Model has not been loaded!");
std::shared_ptr<ge::GeModel> out_model = ge::MakeShared<ge::GeModel>(); std::shared_ptr<ge::GeModel> out_model = ge::MakeShared<ge::GeModel>();
if (out_model == nullptr) { if (out_model == nullptr) {
return nullptr; return nullptr;
@@ -441,6 +792,20 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeMo
return out_model; return out_model;
} }


// Returns the loaded root model. When nothing has been loaded, a freshly
// created empty GeRootModel is returned instead; that placeholder is not
// stored in root_model_. The result is null only if the placeholder could
// not be created.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeRootModelPtr ModelHelper::GetGeRootModel() {
  if (root_model_ == nullptr) {
    GELOGD("Model has not been loaded!");
    // Whatever MakeShared produced, null included, is handed to the caller.
    std::shared_ptr<ge::GeRootModel> placeholder = ge::MakeShared<ge::GeRootModel>();
    return placeholder;
  }
  return root_model_;
}


Status ModelHelper::ReleaseLocalModelData() noexcept { Status ModelHelper::ReleaseLocalModelData() noexcept {
Status result = SUCCESS; Status result = SUCCESS;
if (model_addr_tmp_ != nullptr) { if (model_addr_tmp_ != nullptr) {


+ 195
- 1
ge/common/helper/om_file_helper.cc View File

@@ -52,6 +52,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(u
return SUCCESS; return SUCCESS;
} }


// Initializes the helper from a buffer holding `model_num` models by loading
// one partition table per model. The helper is marked as initialized only
// when every table was loaded; otherwise the load status is returned as is.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(uint8_t *model_data,
                                                                               uint32_t model_data_size,
                                                                               uint32_t model_num) {
  const Status load_status = LoadModelPartitionTable(model_data, model_data_size, model_num);
  if (load_status == SUCCESS) {
    is_inited_ = true;
  }
  return load_status;
}

// Use both // Use both
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type,
ModelPartition &partition) { ModelPartition &partition) {
@@ -79,6 +90,37 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetMod
return SUCCESS; return SUCCESS;
} }


// Looks up the first partition of kind `type` in model `model_index` and
// copies it into `partition`.
//
// Returns PARAM_INVALID when the helper is not initialized or the index is
// out of range. A missing TBE_KERNELS, WEIGHTS_DATA or CUST_AICPU_KERNELS
// partition is tolerated: SUCCESS is returned and `partition` is left
// untouched. Any other missing kind yields FAILED.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type,
                                                                                            ModelPartition &partition,
                                                                                            size_t model_index) {
  if (!is_inited_) {
    GELOGE(PARAM_INVALID, "OmFileLoadHelper has not been initialized!");
    return PARAM_INVALID;
  }
  if (model_index >= model_contexts_.size()) {
    GELOGE(PARAM_INVALID, "cur index : %zu, model_contexts size:%zu", model_index, model_contexts_.size());
    return PARAM_INVALID;
  }

  for (const ModelPartition &candidate : model_contexts_[model_index].partition_datas_) {
    if (candidate.type == type) {
      partition = candidate;
      return SUCCESS;
    }
  }

  // Nothing matched: only these three kinds are allowed to be absent.
  const bool may_be_absent = type == ModelPartitionType::TBE_KERNELS ||
                             type == ModelPartitionType::WEIGHTS_DATA ||
                             type == ModelPartitionType::CUST_AICPU_KERNELS;
  if (may_be_absent) {
    return SUCCESS;
  }
  GELOGE(FAILED, "GetModelPartition:type:%d is not in partition_datas!", static_cast<int>(type));
  return FAILED;
}

Status OmFileLoadHelper::CheckModelValid(const ge::ModelData &model) const { Status OmFileLoadHelper::CheckModelValid(const ge::ModelData &model) const {
// Parameter validity check // Parameter validity check
if (model.model_data == nullptr) { if (model.model_data == nullptr) {
@@ -138,7 +180,8 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint
context_.partition_datas_.push_back(partition); context_.partition_datas_.push_back(partition);


if (partition.size > model_data_size || mem_offset > model_data_size - partition.size) { if (partition.size > model_data_size || mem_offset > model_data_size - partition.size) {
GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "The partition size %zu is greater than the model data size %u.",
GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID,
"The partition size %zu is greater than the model data size %u.",
partition.size + mem_offset, model_data_size); partition.size + mem_offset, model_data_size);
return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID; return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID;
} }
@@ -148,6 +191,61 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint
return SUCCESS; return SUCCESS;
} }


// Walks `model_num` consecutive [partition table | partition payloads]
// sections in `model_data` and records each partition in
// model_contexts_[index].
//
// Parameters:
//   model_data      - start of the first partition table; must not be null.
//   model_data_size - number of readable bytes at model_data.
//   model_num       - number of models expected in the buffer.
// Returns:
//   PARAM_INVALID                   - model_data is null.
//   GE_EXEC_MODEL_DATA_SIZE_INVALID - a table or partition does not fit in the buffer.
//   FAILED                          - contexts would be created out of order, or the
//                                     sections do not consume the buffer exactly.
//
// cur_offset never exceeds model_data_size, so "model_data_size - cur_offset"
// below is the number of unread bytes and cannot wrap.
Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, uint32_t model_data_size, uint32_t model_num) {
  if (model_data == nullptr) {
    GELOGE(PARAM_INVALID, "Param model_data must not be null!");
    return PARAM_INVALID;
  }

  uint32_t cur_offset = 0;
  for (uint32_t index = 0; index < model_num; ++index) {
    // The fixed part of the table must be inside the buffer before its
    // `num` field may be read.
    if (model_data_size - cur_offset < sizeof(ModelPartitionTable)) {
      GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID,
             "invalid model data, no room for partition table of model %u, offset %u, model data size %u",
             index, cur_offset, model_data_size);
      return GE_EXEC_MODEL_DATA_SIZE_INVALID;
    }

    // Init partition table
    auto partition_table = reinterpret_cast<ModelPartitionTable *>(model_data + cur_offset);
    const size_t partition_table_size = SIZE_OF_MODEL_PARTITION_TABLE(*partition_table);
    GELOGD("Cur model index %zu: ModelPartitionTable num :%u, "
           "ModelFileHeader length :%zu, ModelPartitionTable length :%zu",
           static_cast<size_t>(index), partition_table->num, sizeof(ModelFileHeader), partition_table_size);
    // Compared against the remaining bytes in size_t, so a huge `num` cannot
    // wrap the 32-bit offset. The table must leave at least one byte after it.
    if (partition_table_size >= model_data_size - cur_offset) {
      GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "invalid model data, partition_table->num:%u, model data size %u",
             partition_table->num, model_data_size);
      return GE_EXEC_MODEL_DATA_SIZE_INVALID;
    }
    cur_offset += static_cast<uint32_t>(partition_table_size);

    for (uint32_t i = 0; i < partition_table->num; i++) {
      ModelPartition partition;
      partition.size = partition_table->partition[i].mem_size;
      partition.data = model_data + cur_offset;
      partition.type = partition_table->partition[i].type;

      // Validate the payload range before the partition is stored anywhere.
      if (partition.size > model_data_size || cur_offset > model_data_size - partition.size) {
        GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "The partition size %zu is greater than the model data size %u.",
               static_cast<size_t>(partition.size) + cur_offset, model_data_size);
        return GE_EXEC_MODEL_DATA_SIZE_INVALID;
      }

      if (index >= model_contexts_.size()) {
        // A context may only be appended directly after the last existing one.
        if (index != model_contexts_.size()) {
          GELOGE(FAILED, "cur index is %zu make model_contexts_ overflow", static_cast<size_t>(index));
          return FAILED;
        }

        OmFileContext tmp_ctx;
        tmp_ctx.partition_datas_.push_back(partition);
        model_contexts_.push_back(tmp_ctx);
      } else {
        model_contexts_[index].partition_datas_.push_back(partition);
      }

      cur_offset += partition.size;
      GELOGD("Partition, type:%d, size:%u, model_index:%zu", static_cast<int>(partition.type), partition.size,
             static_cast<size_t>(index));
    }
  }
  // All sections together must account for every byte of the buffer.
  if (cur_offset != model_data_size) {
    GELOGE(FAILED, "do not get the complete model, read end offset:%zu, all size:%zu",
           static_cast<size_t>(cur_offset), static_cast<size_t>(model_data_size));
    return FAILED;
  }
  return SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::vector<ModelPartition> FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::vector<ModelPartition>
&OmFileSaveHelper::GetModelPartitions() const { &OmFileSaveHelper::GetModelPartitions() const {
return context_.partition_datas_; return context_.partition_datas_;
@@ -172,6 +270,28 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSave
return partition_table; return partition_table;
} }


// Rebuilds the partition table of context `cur_ctx_index` from its recorded
// partitions and returns a pointer to it.
//
// The table lives in that context's partition_table_ byte vector, so the
// returned pointer stays valid only until that vector is modified again.
// Returns nullptr when `cur_ctx_index` is out of range.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSaveHelper::GetPartitionTable(
    size_t cur_ctx_index) {
  if (cur_ctx_index >= model_contexts_.size()) {
    GELOGE(PARAM_INVALID, "cur index : %zu, model_contexts size:%zu", cur_ctx_index, model_contexts_.size());
    return nullptr;
  }
  auto &cur_ctx = model_contexts_[cur_ctx_index];
  auto partition_size = static_cast<uint32_t>(cur_ctx.partition_datas_.size());
  // Build ModelPartitionTable, flex array
  cur_ctx.partition_table_.clear();
  cur_ctx.partition_table_.resize(sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * partition_size, 0);

  auto partition_table = reinterpret_cast<ModelPartitionTable *>(cur_ctx.partition_table_.data());
  partition_table->num = partition_size;

  // Each entry records where its payload starts relative to the first payload.
  uint32_t mem_offset = 0;
  for (uint32_t i = 0; i < partition_size; i++) {
    const ModelPartition &partition = cur_ctx.partition_datas_[i];
    partition_table->partition[i] = {partition.type, mem_offset, partition.size};
    mem_offset += partition.size;
    GELOGD("Partition, type:%d, size:%u", static_cast<int>(partition.type), partition.size);
  }
  return partition_table;
}


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileSaveHelper::AddPartition(ModelPartition &partition) { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileSaveHelper::AddPartition(ModelPartition &partition) {
if (ge::CheckUint32AddOverflow(context_.model_data_len_, partition.size) != SUCCESS) { if (ge::CheckUint32AddOverflow(context_.model_data_len_, partition.size) != SUCCESS) {
GELOGE(FAILED, "UINT32 %u and %u addition can result in overflow!", context_.model_data_len_, partition.size); GELOGE(FAILED, "UINT32 %u and %u addition can result in overflow!", context_.model_data_len_, partition.size);
@@ -182,6 +302,27 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileSaveHelper::AddPar
return SUCCESS; return SUCCESS;
} }


// Appends `partition` to the model context at `cur_index` and adds its size
// to that context's running data length.
//
// `cur_index` may name an existing context or be exactly
// model_contexts_.size(), in which case a new context is created. Any larger
// index is rejected. Returns FAILED for a bad index or when the context's
// data length would overflow uint32, SUCCESS otherwise.
Status OmFileSaveHelper::AddPartition(ModelPartition &partition, size_t cur_index) {
  if (cur_index > model_contexts_.size()) {
    GELOGE(FAILED, "cur index is %zu make model_contexts_ overflow", cur_index);
    return FAILED;
  }

  const bool is_new_ctx = (cur_index == model_contexts_.size());
  OmFileContext new_ctx;
  OmFileContext &target_ctx = is_new_ctx ? new_ctx : model_contexts_[cur_index];

  // The overflow check must use the length of the context that receives the
  // partition, not the single-model context_.
  if (ge::CheckUint32AddOverflow(target_ctx.model_data_len_, partition.size) != SUCCESS) {
    GELOGE(FAILED, "UINT32 %u and %u addition can result in overflow!", target_ctx.model_data_len_, partition.size);
    return FAILED;
  }

  target_ctx.model_data_len_ += partition.size;
  target_ctx.partition_datas_.push_back(partition);
  if (is_new_ctx) {
    model_contexts_.push_back(new_ctx);
  }
  return SUCCESS;
}

Status OmFileSaveHelper::SaveModel(const SaveParam &save_param, const char *output_file, ModelBufferData &model, Status OmFileSaveHelper::SaveModel(const SaveParam &save_param, const char *output_file, ModelBufferData &model,
bool is_offline) { bool is_offline) {
(void)save_param.cert_file; (void)save_param.cert_file;
@@ -198,6 +339,10 @@ Status OmFileSaveHelper::SaveModel(const SaveParam &save_param, const char *outp


Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferData &model, bool is_offline) { Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferData &model, bool is_offline) {
#if !defined(NONSUPPORT_SAVE_TO_FILE) #if !defined(NONSUPPORT_SAVE_TO_FILE)
if (context_.partition_datas_.empty()) {
GE_CHK_BOOL_EXEC(!model_contexts_.empty(), return FAILED, "mode contexts empty");
context_ = model_contexts_.front();
}
uint32_t model_data_len = context_.model_data_len_; uint32_t model_data_len = context_.model_data_len_;
if (model_data_len == 0) { if (model_data_len == 0) {
GELOGE(domi::PARAM_INVALID, "Model data len error! should not be 0"); GELOGE(domi::PARAM_INVALID, "Model data len error! should not be 0");
@@ -231,4 +376,53 @@ Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferDat
return SUCCESS; return SUCCESS;
#endif #endif
} }

// Writes every recorded model context to `output_file` as one root model.
//
// For each context the partition table is rebuilt and its size plus the
// context's payload length is added to model_header_.length. Only offline
// saving is supported: with is_offline == false a warning is logged and
// FAILED is returned. `model` is not used here, and the members of
// `save_param` are only referenced, never read.
// When NONSUPPORT_SAVE_TO_FILE is defined the function does nothing and
// returns SUCCESS.
Status OmFileSaveHelper::SaveRootModel(const SaveParam &save_param, const char *output_file,
                                       ModelBufferData &model, bool is_offline) {
  (void)save_param.cert_file;
  (void)save_param.ek_file;
  (void)save_param.encode_mode;
  (void)save_param.hw_key_file;
  (void)save_param.pri_key_file;

#if !defined(NONSUPPORT_SAVE_TO_FILE)
  vector<ModelPartitionTable *> tables_per_model;
  vector<vector<ModelPartition>> partitions_per_model;
  const size_t ctx_count = model_contexts_.size();
  for (size_t ctx_index = 0; ctx_index < ctx_count; ++ctx_index) {
    auto &cur_ctx = model_contexts_[ctx_index];
    uint32_t cur_model_data_len = cur_ctx.model_data_len_;
    if (cur_model_data_len == 0) {
      GELOGE(domi::PARAM_INVALID, "Model data len error! should not be 0");
      return domi::PARAM_INVALID;
    }

    auto cur_table = GetPartitionTable(ctx_index);
    if (cur_table == nullptr) {
      GELOGE(ge::GE_GRAPH_SAVE_FAILED, "SaveModelToFile execute failed: partition_table is NULL.");
      return ge::GE_GRAPH_SAVE_FAILED;
    }

    // Grow the header length by this model's table and payload; the two
    // macros guard the uint32 additions.
    uint32_t size_of_table = SIZE_OF_MODEL_PARTITION_TABLE(*cur_table);
    FMK_UINT32_ADDCHECK(size_of_table, cur_model_data_len)
    FMK_UINT32_ADDCHECK(size_of_table + cur_model_data_len, model_header_.length)
    model_header_.length += size_of_table + cur_model_data_len;

    tables_per_model.push_back(cur_table);
    partitions_per_model.push_back(cur_ctx.partition_datas_);
    GELOGD("sizeof(ModelPartitionTable):%u, cur_model_data_len:%u, cur_context_index:%zu",
           size_of_table, cur_model_data_len, ctx_index);
  }

  if (!is_offline) {
    GELOGW("do not support save ge root model to buff now");
    return FAILED;
  }

  Status ret = FileSaver::SaveToFile(output_file, model_header_, tables_per_model, partitions_per_model);
  if (ret == SUCCESS) {
    GELOGD("Save model success without encrypt.");
  }
  return ret;
#else
  return SUCCESS;
#endif
}
} // namespace ge } // namespace ge

+ 1
- 1
ge/common/op/ge_op_utils.cc View File

@@ -357,7 +357,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OpUtils::TransDataHWCK2KCH
const char *w_data = (const char *)input; const char *w_data = (const char *)input;


int64_t count = h * w * c * k; int64_t count = h * w * c * k;
GE_IF_BOOL_EXEC(count <= 0, GELOGW("Count value must be greater than 0, but count = %ld", count); return );
GE_IF_BOOL_EXEC(count <= 0, GELOGW("Count value must be greater than 0, but count = %ld", count); return);
float *buf = new (std::nothrow) float[count](); float *buf = new (std::nothrow) float[count]();
GE_RT_VOID_CHECK_NOTNULL(buf); GE_RT_VOID_CHECK_NOTNULL(buf);
float *src_buff = nullptr; float *src_buff = nullptr;


+ 199
- 0
ge/common/profiling/ge_profiling.cc View File

@@ -0,0 +1,199 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/profiling/ge_profiling.h"
#include "runtime/base.h"
#include "common/profiling/profiling_manager.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "graph/load/graph_loader.h"
#include "init/gelib.h"
#include "framework/common/ge_inner_error_codes.h"

// File-local constants for the profiling command handling in this file.
namespace {
// NOTE(review): not referenced in the visible part of this file - confirm it is still needed.
const uint32_t kDeviceListIndex = 3;
// Keys written into the profiling config parameter list by TransProfConfigToParam.
const std::string kDeviceNums = "devNums";
const std::string kDeviceIdList = "devIdList";
// Command names that kProfCommandTypeMap maps the handle types to.
const std::string kProfilingInit = "prof_init";
const std::string kProfilingFinalize = "prof_finalize";
const std::string kProfilingStart = "prof_start";
const std::string kProfilingStop = "prof_stop";
const std::string kProfModelSubscribe = "prof_model_subscribe";
const std::string kProfModelUnsubscribe = "prof_model_cancel_subscribe";
// Name passed to rtRegDeviceStateCallback when the set-device callback is registered.
const std::string kRtSetDeviceRegName = "profiling";

// Maps each profiling command handle type to its command name string.
const std::map<ProfCommandHandleType, std::string> kProfCommandTypeMap = {
{kProfCommandhandleInit, kProfilingInit},
{kProfCommandhandleStart, kProfilingStart},
{kProfCommandhandleStop, kProfilingStop},
{kProfCommandhandleFinalize, kProfilingFinalize},
{kProfCommandhandleModelSubscribe, kProfModelSubscribe},
{kProfCommandhandleModelUnsubscribe, kProfModelUnsubscribe}};
} // namespace

bool TransProfConfigToParam(const ProfCommandHandleData &profCommand, vector<string> &prof_config_params) {
prof_config_params.clear();
prof_config_params.emplace_back(kDeviceNums);
prof_config_params.emplace_back(std::to_string(profCommand.devNums));
prof_config_params.emplace_back(kDeviceIdList);
std::string devID = "";
if (profCommand.devNums == 0) {
GELOGW("The device num is invalid.");
return false;
}
for (uint32_t i = 0; i < profCommand.devNums; i++) {
devID.append(std::to_string(profCommand.devIdList[i]));
if (i != profCommand.devNums - 1) {
devID.append(",");
}
}

prof_config_params.push_back(devID);
return true;
}

bool isProfConfigValid(const uint32_t *deviceid_list, uint32_t device_nums) {
if (deviceid_list == nullptr) {
GELOGE(ge::PARAM_INVALID, "deviceIdList is nullptr");
return false;
}
if (device_nums == 0 || device_nums > MAX_DEV_NUM) {
GELOGE(ge::PARAM_INVALID, "The device nums: %u is invalid.", device_nums);
return false;
}

// real device num
int32_t dev_count = 0;
rtError_t rt_err = rtGetDeviceCount(&dev_count);
if (rt_err != RT_ERROR_NONE) {
GELOGE(ge::INTERNAL_ERROR, "Get the Device count fail.");
return false;
}

if (device_nums > static_cast<uint32_t>(dev_count)) {
GELOGE(ge::PARAM_INVALID, "Device num(%u) is not in range 1 ~ %d.", device_nums, dev_count);
return false;
}

std::unordered_set<uint32_t> record;
for (size_t i = 0; i < device_nums; ++i) {
uint32_t dev_id = deviceid_list[i];
if (dev_id >= static_cast<uint32_t>(dev_count)) {
GELOGE(ge::PARAM_INVALID, "Device id %u is not in range 0 ~ %d(exclude %d)", dev_id, dev_count, dev_count);
return false;
}
if (record.count(dev_id) > 0) {
GELOGE(ge::PARAM_INVALID, "Device id %u is duplicatedly set", dev_id);
return false;
}
record.insert(dev_id);
}
return true;
}

/// Stores the Msprof control callback in the ProfilingManager singleton.
///
/// A callback that is already stored is kept; the new one is ignored with a warning.
///
/// @param func  Callback to register; must not be nullptr.
/// @return ge::PARAM_INVALID for a null callback, ge::SUCCESS otherwise.
ge::Status RegProfCtrlCallback(MsprofCtrlCallback func) {
  if (func == nullptr) {
    GELOGE(ge::PARAM_INVALID, "Msprof ctrl callback is nullptr.");
    return ge::PARAM_INVALID;
  }

  const bool not_registered_yet =
      (ge::ProfilingManager::Instance().GetMsprofCallback().msprofCtrlCallback == nullptr);
  if (not_registered_yet) {
    GELOGI("GE register Msprof ctrl callback.");
    ge::ProfilingManager::Instance().SetMsprofCtrlCallback(func);
  } else {
    GELOGW("Msprof ctrl callback is exist, just ignore it.");
  }
  return ge::SUCCESS;
}

/// Hands the Msprof set-device callback over to the runtime.
///
/// The callback is registered under the name held in kRtSetDeviceRegName.
///
/// @param func  Callback to forward; must not be nullptr.
/// @return ge::PARAM_INVALID for a null callback, otherwise the status produced
///         by rtRegDeviceStateCallback (ge::SUCCESS when registration worked).
ge::Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) {
  if (func == nullptr) {
    GELOGE(ge::PARAM_INVALID, "MsprofSetDeviceCallback callback is nullptr.");
    return ge::PARAM_INVALID;
  }

  GELOGI("GE pass setdevice callback to runtime.");
  const ge::Status status =
      rtRegDeviceStateCallback(kRtSetDeviceRegName.c_str(), static_cast<rtDeviceStateCallback>(func));
  if (status != ge::SUCCESS) {
    GELOGE(status, "Pass MsprofSetDeviceCallback to runtime failed!");
  }
  // On the success path this is ge::SUCCESS; on failure it is the runtime's status.
  return status;
}

/// Stores the Msprof reporter callback in the ProfilingManager singleton and
/// forwards it to the runtime.
///
/// A reporter callback that is already stored is kept; the new one is ignored
/// with a warning and nothing is forwarded.
///
/// @param func  Callback to register; must not be nullptr.
/// @return ge::PARAM_INVALID for a null callback, the runtime's status when
///         rtSetMsprofReporterCallback fails, ge::SUCCESS otherwise.
ge::Status RegProfReporterCallback(MsprofReporterCallback func) {
  if (func == nullptr) {
    GELOGE(ge::PARAM_INVALID, "MsprofReporterCallback callback is nullptr.");
    return ge::PARAM_INVALID;
  }

  const bool already_registered =
      (ge::ProfilingManager::Instance().GetMsprofCallback().msprofReporterCallback != nullptr);
  if (already_registered) {
    GELOGW("Msprof reporter callback is exist, just ignore it.");
    return ge::SUCCESS;
  }

  GELOGI("GE register Msprof reporter callback.");
  ge::ProfilingManager::Instance().SetMsprofReporterCallback(func);

  // Forward the same callback to the runtime.
  const ge::Status status = rtSetMsprofReporterCallback(func);
  if (status != ge::SUCCESS) {
    GELOGE(status, "Pass MsprofReporterCallback to runtime failed!!");
    return status;
  }
  // NOTE(review): the original carries a "pass to hccl" marker here with no code behind it.
  return ge::SUCCESS;
}

/// Translates a profiling command handle request into a ge::Command and runs it
/// through ge::GraphLoader::CommandHandle.
///
/// @param type  Requested profiling operation; must be a key of kProfCommandTypeMap.
/// @param data  Points to a ProfCommandHandleData. May be nullptr only for
///              kProfCommandhandleFinalize; other types are rejected by GE_CHECK_NOTNULL.
/// @param len   Size supplied by the caller; this function does not inspect it.
/// @return ge::SUCCESS when the command was executed;
///         ge::PARAM_INVALID for an unknown type or when the device list cannot be converted;
///         ge::FAILED when the device list is invalid or the command handler fails.
ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) {
  // Index of the device count value in the list built by TransProfConfigToParam
  // (index 0 holds the "devNums" key, not the count).
  const size_t kDeviceNumsValueIndex = 1;

  if (type != kProfCommandhandleFinalize) {
    GE_CHECK_NOTNULL(data);
  }
  ProfCommandHandleData *prof_config_param = static_cast<ProfCommandHandleData *>(data);

  const auto iter = kProfCommandTypeMap.find(type);
  if (iter == kProfCommandTypeMap.end()) {
    GELOGW("The prof comand type is invalid.");
    return ge::PARAM_INVALID;
  }

  // Only start/stop carry a device list that has to be validated and forwarded.
  const bool has_device_list = (type == kProfCommandhandleStart) || (type == kProfCommandhandleStop);
  std::vector<std::string> prof_params;
  if (has_device_list) {
    if (!isProfConfigValid(prof_config_param->devIdList, prof_config_param->devNums)) {
      return ge::FAILED;
    }
    if (!TransProfConfigToParam(*prof_config_param, prof_params)) {
      GELOGE(ge::PARAM_INVALID, "Transfer profilerConfig to string vector failed");
      return ge::PARAM_INVALID;
    }
  }

  ge::GraphLoader graph_loader;
  // Value-initialise so module_index holds a defined value on the finalize path,
  // where no ProfCommandHandleData is available to supply profSwitch.
  ge::Command command{};
  command.cmd_type = iter->second;
  command.cmd_params = prof_params;
  if (type != kProfCommandhandleFinalize) {
    command.module_index = prof_config_param->profSwitch;
  }
  GELOGI("GE commandhandle execute, Command Type: %d, data type config: 0x%llx", type, command.module_index);
  if (has_device_list) {
    GELOGI("Profiling device nums:%s , deviceID:[%s]", prof_params[kDeviceNumsValueIndex].c_str(),
           prof_params[kDeviceListIndex].c_str());
  }

  ge::Status ret = graph_loader.CommandHandle(command);
  if (ret != ge::SUCCESS) {
    GELOGE(ret, "Handle profiling command failed");
    return ge::FAILED;
  }

  GELOGI("Successfully execute profiling command type: %d, command 0x%llx.", type, command.module_index);
  return ge::SUCCESS;
}


+ 26
- 0
ge/common/profiling/ge_runner_profiling.cc View File

@@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/profiling/ge_runner_profiling.h"
#include "init/gelib.h"

bool IsInitialize() {
std::shared_ptr<ge::GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || instance_ptr->InitFlag() == false) {
return false;
}
return true;
}

+ 232
- 406
ge/common/profiling/profiling_manager.cc View File

@@ -24,16 +24,9 @@
#include "graph/load/new_model_manager/davinci_model.h" #include "graph/load/new_model_manager/davinci_model.h"


namespace { namespace {
const char *const kJobID = "jobID";
const char *const kDeviceID = "deviceID";
const char *const kStartCfg = "startCfg";
const char *const kFeatures = "features";
const char *const kConf = "conf";
const char *const kEvents = "events";
const char *const kAiCoreEvents = "ai_core_events";
const char *const kName = "name";
const char *const kTraceID = "traceId";
const char *const kProfDir = "resultPath";
const char *const kTrainingTrace = "training_trace";
const char *const kFpPoint = "fp_point";
const char *const kBpPoint = "bp_point";
const size_t kReportMaxLen = 2048; const size_t kReportMaxLen = 2048;
const int32_t kMaxDeviceNum = 256; const int32_t kMaxDeviceNum = 256;
const std::string kConfigNumsdev = "devNums"; const std::string kConfigNumsdev = "devNums";
@@ -45,7 +38,13 @@ const std::string kProfModelUnsubscribe = "prof_model_cancel_subscribe";
} // namespace } // namespace


namespace ge { namespace ge {
ProfilingManager::ProfilingManager() : subscribe_count_(0) {}
ProfilingManager::ProfilingManager() : is_load_profiling_(false),
is_execute_profiling_(false),
is_training_trace_(false),
subscribe_count_(0) {
prof_cb_.msprofCtrlCallback = nullptr;
prof_cb_.msprofReporterCallback = nullptr;
}


ProfilingManager::~ProfilingManager() {} ProfilingManager::~ProfilingManager() {}


@@ -58,44 +57,29 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In
#ifdef DAVINCI_SUPPORT_PROFILING #ifdef DAVINCI_SUPPORT_PROFILING
vector<int32_t>().swap(device_id_); vector<int32_t>().swap(device_id_);
subscribe_count_ = 0; subscribe_count_ = 0;
job_id_ = options.job_id;

GELOGI("ProfilingManager::Init job_id:%s", job_id_.c_str());

GELOGI("ProfilingManager::Init job_id:%s", options.job_id.c_str());



Status ret;
if (!recv_profiling_config_.empty()) {
GELOGI("Profiling json config from acl:%s", recv_profiling_config_.c_str());
ret = InitFromAclCfg(recv_profiling_config_);
} else {
ret = InitFromOptions(options);
if (ret == SUCCESS && is_load_profiling_) {
device_id_.push_back(options.device_id);
}
}
struct MsprofGeOptions prof_conf = {{ 0 }};
Status ret = InitFromOptions(options, prof_conf);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "Failed to init profiling."); GELOGE(ret, "Failed to init profiling.");
return ret; return ret;
} }


if (is_load_profiling_) {
// register Framework to profiling
int result = Msprof::Engine::Init(GE_PROFILING_MODULE, &engine_);
if (result != 0) {
GELOGE(FAILED, "Register profiling engine failed.");
return FAILED;
if (is_execute_profiling_) {
if (prof_cb_.msprofCtrlCallback == nullptr) {
GELOGE(ge::PARAM_INVALID, "MsprofCtrlCallback callback is nullptr.");
return ge::PARAM_INVALID;
} }
// profiling startup first time
GELOGI("Begin to init profiling, device num %zu", device_id_.size());
for (size_t i = 0; i < device_id_.size(); ++i) {
ret = StartProfiling(0, device_id_[i]);
if (ret != SUCCESS) {
GELOGW("Profiling start failed on device %d.", device_id_[i]);
continue;
}
GELOGI("Profiling init succ on device %d.", device_id_[i]);
int32_t cb_ret = prof_cb_.msprofCtrlCallback(
static_cast<uint32_t>(MsprofCtrlCallbackType::MSPROF_CTRL_INIT_GE_OPTIONS),
static_cast<void *>(&prof_conf), sizeof(MsprofGeOptions));
if (cb_ret != 0) {
GELOGE(FAILED, "Call msprofCtrlCallback failed, type:%u, return:%d",
static_cast<uint32_t>(MsprofCtrlCallbackType::MSPROF_CTRL_INIT_GE_OPTIONS), cb_ret);
return FAILED;
} }
GELOGI("Profiling init success");
} else { } else {
GELOGI("The profiling is off, skip the initialization"); GELOGI("The profiling is off, skip the initialization");
} }
@@ -103,288 +87,116 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In
return SUCCESS; return SUCCESS;
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::InitFromAclCfg(
const std::string &config) {
ge::Status ProfilingManager::InitFromOptions(const Options &options, MsprofGeOptions &prof_conf) {
#ifdef DAVINCI_SUPPORT_PROFILING #ifdef DAVINCI_SUPPORT_PROFILING
try {
is_load_profiling_ = false;
is_execute_profiling_ = false;
profiling_opts_.clear();
op_trace_conf_.clear();
Json start_prof_conf = Json::parse(config);
Json &prof_conf = start_prof_conf[kStartCfg][0];
job_id_ = prof_conf[kJobID];
auto iter = prof_conf.find(kProfDir);
if (iter != prof_conf.end()) {
prof_dir_ = prof_conf[kProfDir];
}
Json &device_id = prof_conf[kDeviceID];
if (device_id.size() != 0) {
vector<int32_t>().swap(device_id_);
bool is_all = false;
for (size_t i = 0; i < device_id.size(); i++) {
std::string device_id_str = device_id[i].get<std::string>();
if (device_id_str == "all") {
is_all = true;
break;
}
device_id_.push_back(std::stoi(device_id_str));
}
if (is_all) {
int32_t count = 0;
rtError_t rt_err = rtGetDeviceCount(&count);
if (rt_err != RT_ERROR_NONE) {
GELOGE(FAILED, "Call rtGetDeviceCount to get device failed.");
}

vector<int32_t>().swap(device_id_);
for (int32_t i = 0; i < count; ++i) {
device_id_.push_back(i);
}
}
// enable profiling by env
char env_profiling_mode[MMPA_MAX_PATH] = { 0x00 };
is_load_profiling_ = false; // Change in ProfInit
is_execute_profiling_ = false;

if (options.profiling_mode == "1" && !options.profiling_options.empty()) {
// enable profiling by ge option
if (memcpy_s(prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX, options.profiling_options.c_str(),
options.profiling_options.size()) != EOK) {
GELOGE(INTERNAL_ERROR, "copy profiling_options failed.");
return INTERNAL_ERROR;
} }

Json &features = prof_conf[kFeatures];
if (ParseFeaturesFromAclCfg(features) != SUCCESS) {
GELOGE(FAILED, "Parse feature from acl cfg failed.");
return FAILED;
is_execute_profiling_ = true;
GELOGI("The profiling in options is %s, %s. origin option: %s", options.profiling_mode.c_str(),
prof_conf.options, options.profiling_options.c_str());
} else {
(void)mmGetEnv("PROFILING_MODE", env_profiling_mode, MMPA_MAX_PATH);
(void)mmGetEnv("PROFILING_OPTIONS", prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX);
// The env is invalid
if ((strcmp("true", env_profiling_mode) != 0) || (strcmp(prof_conf.options, "\0") == 0)) {
return SUCCESS;
} }
is_load_profiling_ = true;
// enable profiling by env
is_execute_profiling_ = true; is_execute_profiling_ = true;
} catch (...) {
GELOGE(FAILED, "Json conf is not invalid !");
GELOGI("The profiling in env is %s, %s", env_profiling_mode, prof_conf.options);
}

if (!is_execute_profiling_) {
return SUCCESS;
}

// Parse json str for bp fp
Status ret = ParseOptions(prof_conf.options);
if (ret != ge::SUCCESS) {
GELOGE(ge::PARAM_INVALID, "Parse training trace param failed.");
return ge::PARAM_INVALID; return ge::PARAM_INVALID;
} }

if (memcpy_s(prof_conf.jobId, sizeof(prof_conf.jobId), options.job_id.c_str(),
sizeof(options.job_id.c_str())) != EOK) {
GELOGE(INTERNAL_ERROR, "copy job_id failed.");
return INTERNAL_ERROR;
}
#endif #endif
return ge::SUCCESS; return ge::SUCCESS;
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::ParseFeaturesFromAclCfg(
const Json &features) {
#ifdef DAVINCI_SUPPORT_PROFILING
ge::Status ProfilingManager::ParseOptions(const std::string &options) {
if (options.empty()) {
GELOGE(ge::PARAM_INVALID, "Profiling options is empty.");
return ge::PARAM_INVALID;
}
try { try {
for (size_t i = 0; i < features.size(); ++i) {
const Json &feature = features[i];
if ((feature.find(kName) == feature.end()) || feature[kName].is_null()) {
continue;
}
const std::string &name = feature[kName];
if (name == "op_trace") {
const Json &conf = feature[kConf];
const Json &events = conf[0][kEvents];
const std::string &ai_core_events = events[0][kAiCoreEvents];
GELOGI("Op trace config from acl ai_core_events:%s", ai_core_events.c_str());
is_op_trace_ = true;
ProfMgrConf prof_mgr_conf;
int result = ProfMgrGetConf(ai_core_events, &prof_mgr_conf);
if (result != 0) {
GELOGE(FAILED, "ProfMgrGetConf failed.");
return FAILED;
}
op_trace_conf_ = prof_mgr_conf.conf;
op_trace_iter_num_ = static_cast<int32_t>(op_trace_conf_.size());
GELOGI("Op trace profiling iter num %d,", op_trace_iter_num_);
} else if (name == "task_trace") {
is_op_trace_ = false;
if (feature.find(kConf) != feature.end()) {
const Json &conf = feature[kConf];
std::stringstream task_trace_conf;
task_trace_conf << conf;
task_trace_conf_ = task_trace_conf.str();
}
GELOGI("Task trace config from acl");
} else if (name == "system_trace") {
is_op_trace_ = false;
const Json &conf = feature[kConf];
std::stringstream system_trace_conf;
system_trace_conf << conf;
system_trace_conf_ = system_trace_conf.str();
GELOGI("System trace config from acl");
}
profiling_opts_.push_back(name);
Json prof_options = Json::parse(options);
const std::string training_trace = prof_options[kTrainingTrace];
if (training_trace.empty()) {
GELOGI("Training trace will not take effect.");
return ge::SUCCESS;
}
GELOGI("GE profiling training trace:%s", training_trace.c_str());
if (training_trace != "on") {
GELOGE(ge::PARAM_INVALID, "Training trace param:%s is invalid.", training_trace.c_str());
return ge::PARAM_INVALID;
}
fp_point_ = prof_options[kFpPoint];
bp_point_ = prof_options[kBpPoint];
if (!fp_point_.empty() && !bp_point_.empty()) {
GELOGI("Training trace bp fp is set, bp_point:%s, fp_point:%s.", bp_point_.c_str(), fp_point_.c_str());
} }
} catch (...) { } catch (...) {
GELOGE(ge::PARAM_INVALID, "Json conf feature is not invalid !");
GELOGE(FAILED, "Json prof_conf options is invalid.");
return ge::PARAM_INVALID; return ge::PARAM_INVALID;
} }
#endif
return ge::SUCCESS; return ge::SUCCESS;
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::InitFromOptions(const Options &options) {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProfiling() {
#ifdef DAVINCI_SUPPORT_PROFILING #ifdef DAVINCI_SUPPORT_PROFILING
// enable profiling support two ways: env and front end
char profiling_mode_temp[MMPA_MAX_PATH] = { 0x00 };
char prof_options_temp[MMPA_MAX_PATH] = { 0x00 };
(void)mmGetEnv("PROFILING_MODE", profiling_mode_temp, MMPA_MAX_PATH);
(void)mmGetEnv("PROFILING_OPTIONS", prof_options_temp, MMPA_MAX_PATH );
const char *profiling_mode = profiling_mode_temp;
const char *prof_options = prof_options_temp;
if ((profiling_mode == nullptr) || (strcmp("true", profiling_mode) != 0) || (prof_options == nullptr)) {
is_load_profiling_ = false;
is_execute_profiling_ = false;
} else {
std::string prof_options_str = std::string(prof_options);
profiling_opts_ = StringUtils::Split(prof_options_str, ':');
is_load_profiling_ = true;
is_execute_profiling_ = true;
GELOGI("The profiling in env is %s, %s", profiling_mode, prof_options);
}
if (!is_load_profiling_) {
const std::string enable_profiling = "1";
if (options.profiling_mode != enable_profiling || options.profiling_options.empty()) {
is_load_profiling_ = false;
is_execute_profiling_ = false;
return SUCCESS;
} else {
profiling_opts_ = StringUtils::Split(options.profiling_options, ':');
is_load_profiling_ = true;
is_execute_profiling_ = true;
GELOGI("The profiling in options is %s, %s", options.profiling_mode.c_str(), options.profiling_options.c_str());
}
}
// features:'training_trace', 'task_trace' or 'op_trace' etc
if (!profiling_opts_.empty()) {
if (profiling_opts_[0] == "op_trace") {
is_op_trace_ = true;
// op trace get conf
ProfMgrConf prof_mgr_conf;
int result = ProfMgrGetConf("", &prof_mgr_conf);
if (result != 0) {
GELOGE(FAILED, "ProfMgrGetConf failed.");
return FAILED;
}
op_trace_conf_ = prof_mgr_conf.conf;
op_trace_iter_num_ = static_cast<int32_t>(op_trace_conf_.size());
GELOGI("op trace profiling iter num %d,", op_trace_iter_num_);
} else {
is_op_trace_ = false;
op_trace_iter_num_ = 1;
uint64_t module = GetProfilingModule();
// The following if case will not be executed in normal case, inc case of ProfStopProfiling is abnormal
int32_t device_num = static_cast<int32_t>(device_id_.size());
if (device_num != 0) {
auto device_id_ptr = std::unique_ptr<uint32_t[]>(new (std::nothrow) uint32_t[device_num]);
if (device_id_ptr == nullptr) {
GELOGE(FAILED, "Stop profiling: device id ptr is null.");
return;
} }
}
#endif
return ge::SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::StartProfiling(int32_t iter_num,
int32_t device_id) {
#ifdef DAVINCI_SUPPORT_PROFILING
if (!profiling_opts_.empty()) {
GELOGI("Start profiling index is %d", iter_num);
// current one docker only use one device
Json p_device;

try {
// profiling need physical_device_id
p_device[kDeviceID] = std::to_string(device_id);
p_device[kJobID] = job_id_;
p_device[kTraceID] = std::to_string(GetContext().TraceId());
if (!prof_dir_.empty()) {
p_device[kProfDir] = prof_dir_;
GELOGI("Prof dir: %s.", prof_dir_.c_str());
}

Json features;
if (is_op_trace_) {
Json f;
f[kName] = "op_trace";
Json conf;
if (op_trace_conf_.size() <= static_cast<size_t>(iter_num)) {
GELOGE(FAILED, "Op trace iter num is invalid!");
return FAILED;
}
Json events;
events[0] = nlohmann::json::parse(op_trace_conf_[iter_num]);
conf[0][kEvents] = events;
f[kConf] = conf;
features[0] = f;
if (iter_num == 0) {
is_load_ = true;
}
} else {
for (std::vector<std::string>::size_type i = 0; i < profiling_opts_.size(); i++) {
Json f;
if (profiling_opts_[i] == "system_trace") {
f[kConf] = nlohmann::json::parse(system_trace_conf_);
} else if (profiling_opts_[i] == "task_trace") {
if (!task_trace_conf_.empty()) {
f[kConf] = nlohmann::json::parse(task_trace_conf_);
}
}
f[kName] = profiling_opts_[i];
features[i] = f;
}
is_load_ = true;
}
p_device[kFeatures] = features;
// only one device, but sProfMgrStartUp API require for device list
Json devices;
devices[0] = p_device;

Json start_cfg;
start_cfg[kStartCfg] = devices;

// convert json to string
std::stringstream ss;
ss << start_cfg;
send_profiling_config_ = ss.str();
GELOGI("Profiling config %s\n", send_profiling_config_.c_str());
} catch (...) {
GELOGE(FAILED, "Op trace json conf is not invalid !");
return FAILED;
for (int32_t i = 0; i < device_num; i++) {
device_id_ptr[i] = static_cast<uint32_t>(device_id_[i]);
} }

// runtime startup for profiling
uint64_t module = GetProfilingModule();
int32_t device_num = 1;
uint32_t device_id_rt = static_cast<uint32_t>(device_id);
GE_CHK_RT_RET(rtProfilerStart(module, device_num, &device_id_rt));

// call profiling startup API
ProfMgrCfg prof_cfg = {send_profiling_config_};
void *prof_handle = ProfMgrStartUp(&prof_cfg);
if (prof_handle == nullptr) {
GELOGW("ProfMgrStartUp failed on device %d ", device_id);
return FAILED;
rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get());
if (rt_ret != RT_ERROR_NONE) {
GELOGW("Call rtProfilerStop failed, ret:%d", rt_ret);
} }
GELOGD("StartProfiling, prof_handle: %p", prof_handle);
prof_handle_vec_.push_back(prof_handle);
} }
#endif
return SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProfiling() {
#ifdef DAVINCI_SUPPORT_PROFILING
Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter();
if (reporter != nullptr) {
int ret = reporter->Flush();
GELOGI("Report data end, ret is %d", ret);
// stop profiling
if (prof_cb_.msprofCtrlCallback == nullptr) {
GELOGE(ge::PARAM_INVALID, "MsprofCtrlCallback callback is nullptr.");
return;
} }
uint64_t module = GetProfilingModule();
int32_t device_num = static_cast<int32_t>(device_id_.size());
auto device_id_ptr = std::unique_ptr<uint32_t[]>(new (std::nothrow) uint32_t[device_num]);
if (device_id_ptr == nullptr) {
GELOGE(FAILED, "Stop profiling: device id ptr is null.");
int32_t cb_ret = prof_cb_.msprofCtrlCallback(static_cast<uint32_t>(MsprofCtrlCallbackType::MSPROF_CTRL_FINALIZE),
nullptr, 0);
if (cb_ret != 0) {
GELOGW("call msprofCtrlCallback failed, type:%u, return:%d",
static_cast<uint32_t>(MsprofCtrlCallbackType::MSPROF_CTRL_FINALIZE), cb_ret);
return; return;
} }
for (int32_t i = 0; i < device_num; i++) {
device_id_ptr[i] = static_cast<uint32_t>(device_id_[i]);
}
rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get());
if (rt_ret != RT_ERROR_NONE) {
GELOGW("Call rtProfilerStop failed, ret:%d", rt_ret);
}

for (size_t i = 0; i < prof_handle_vec_.size(); ++i) {
int result = ProfMgrStop(prof_handle_vec_[i]);
if (result != 0) {
GELOGW("ProfMgr stop return fail:%d, handle:%p", result, prof_handle_vec_[i]);
}
}
vector<void *>().swap(prof_handle_vec_);
is_load_ = false;
recv_profiling_config_ = "";
GELOGI("Stop Profiling success."); GELOGI("Stop Profiling success.");
#endif #endif
} }
@@ -392,12 +204,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProf
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingTaskDescInfo( FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingTaskDescInfo(
uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, const int32_t &device_id) { uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, const int32_t &device_id) {
#ifdef DAVINCI_SUPPORT_PROFILING #ifdef DAVINCI_SUPPORT_PROFILING
Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter();
if (reporter == nullptr) {
GELOGI("Profiling report is nullptr!");
return;
}

std::string data; std::string data;
for (const auto &task : task_desc_info) { for (const auto &task : task_desc_info) {
std::string model_name = task.model_name; std::string model_name = task.model_name;
@@ -412,7 +218,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin
.append(std::to_string(stream_id)).append(" ") .append(std::to_string(stream_id)).append(" ")
.append(std::to_string(model_id)).append("\n")); .append(std::to_string(model_id)).append("\n"));


Msprof::Engine::ReporterData reporter_data{};
ReporterData reporter_data{};
reporter_data.deviceId = device_id; reporter_data.deviceId = device_id;
reporter_data.data = (unsigned char *)data.c_str(); reporter_data.data = (unsigned char *)data.c_str();
reporter_data.dataLen = data.size(); reporter_data.dataLen = data.size();
@@ -422,9 +228,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin
return; return;
} }


ret = reporter->Report(&reporter_data);
if (ret != SUCCESS) {
GELOGE(ret, "Reporter data of task_desc_info fail!");
int32_t cb_ret = CallMsprofReport(reporter_data);
if (cb_ret != 0) {
GELOGE(cb_ret, "Reporter data of task_desc_info failed, ret:%d", cb_ret);
return; return;
} }
} }
@@ -436,9 +242,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingGraphDescInfo( FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingGraphDescInfo(
uint32_t model_id, const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, const int32_t &device_id) { uint32_t model_id, const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, const int32_t &device_id) {
#ifdef DAVINCI_SUPPORT_PROFILING #ifdef DAVINCI_SUPPORT_PROFILING
Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter();
GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return;);

std::string data; std::string data;
for (const auto &graph : compute_graph_desc_info) { for (const auto &graph : compute_graph_desc_info) {
data.append("model_name:") data.append("model_name:")
@@ -493,64 +296,52 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin
} }


data.append(" model_id:").append(std::to_string(model_id)); data.append(" model_id:").append(std::to_string(model_id));

data.append("\n"); data.append("\n");


Msprof::Engine::ReporterData reporter_data{};
Report(device_id, data, *reporter, reporter_data);

GraphDescReport(device_id, data);
data.clear(); data.clear();
} }
#endif #endif
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Report(
const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter,
Msprof::Engine::ReporterData &reporter_data) {
void ProfilingManager::GraphDescReport(const int32_t &device_id, const string &data) {
#ifdef DAVINCI_SUPPORT_PROFILING #ifdef DAVINCI_SUPPORT_PROFILING
ReporterData reporter_data{};
int ret = -1;
int32_t cb_ret = -1;
size_t index = data.size() / kReportMaxLen; size_t index = data.size() / kReportMaxLen;
if (index >= 1) { if (index >= 1) {
reporter_data.deviceId = device_id; reporter_data.deviceId = device_id;
int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info"));
ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info"));
GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;); GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;);
for (size_t i = 0; i < index; ++i) { for (size_t i = 0; i < index; ++i) {
reporter_data.data = (unsigned char *)data.c_str() + kReportMaxLen * i; reporter_data.data = (unsigned char *)data.c_str() + kReportMaxLen * i;
reporter_data.dataLen = kReportMaxLen; reporter_data.dataLen = kReportMaxLen;
ret = reporter.Report(&reporter_data);
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Reporter data of graph_desc_info fail!"); return;);
cb_ret = CallMsprofReport(reporter_data);
GE_IF_BOOL_EXEC(cb_ret != 0, GELOGE(cb_ret, "Reporter data of graph_desc_info failed, ret:%d", cb_ret); return;);
} }
reporter_data.dataLen = data.size() - kReportMaxLen * index; reporter_data.dataLen = data.size() - kReportMaxLen * index;
if (reporter_data.dataLen != 0) { if (reporter_data.dataLen != 0) {
reporter_data.data = (unsigned char *)data.c_str() + kReportMaxLen * index; reporter_data.data = (unsigned char *)data.c_str() + kReportMaxLen * index;
ret = reporter.Report(&reporter_data);
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Reporter data of graph_desc_info fail!"); return;);
cb_ret = CallMsprofReport(reporter_data);
GE_IF_BOOL_EXEC(cb_ret != 0, GELOGE(cb_ret, "Reporter data of graph_desc_info failed, ret:%d", cb_ret); return;);
} }
} else { } else {
reporter_data.deviceId = device_id; reporter_data.deviceId = device_id;
reporter_data.data = (unsigned char *)data.c_str(); reporter_data.data = (unsigned char *)data.c_str();
reporter_data.dataLen = data.size(); reporter_data.dataLen = data.size();
int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info"));
ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info"));
GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;); GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;);


ret = reporter.Report(&reporter_data);
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Reporter data of graph_desc_info fail!"); return;);
}
#endif
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::PluginUnInit(const std::string &module) const {
#ifdef DAVINCI_SUPPORT_PROFILING
int ret = Msprof::Engine::UnInit(module);
if (ret != SUCCESS) {
GELOGE(ret, "profiling plugin uninit failed, ret:%d", ret);
cb_ret = CallMsprofReport(reporter_data);
GE_IF_BOOL_EXEC(cb_ret != 0, GELOGE(cb_ret, "Reporter data of graph_desc_info failed, ret:%d", cb_ret); return;);
} }
#endif #endif
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportProfilingData( FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportProfilingData(
uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info,
const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info,
bool check_device) {
const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info) {
#ifdef DAVINCI_SUPPORT_PROFILING #ifdef DAVINCI_SUPPORT_PROFILING
int32_t logic_device_id = 0; int32_t logic_device_id = 0;
rtError_t rt_ret = rtGetDevice(&logic_device_id); rtError_t rt_ret = rtGetDevice(&logic_device_id);
@@ -559,13 +350,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportPr
return; return;
} }
GELOGD("current logic_device_id:%d", logic_device_id); GELOGD("current logic_device_id:%d", logic_device_id);
if (check_device) {
auto ret = std::find(device_id_.begin(), device_id_.end(), logic_device_id);
if (ret == device_id_.end()) {
GELOGE(FAILED, "get valid phy_device_id failed, profiling report failed.");
return;
}
}
GELOGD("start ProfilingTaskDescInfo."); GELOGD("start ProfilingTaskDescInfo.");
ProfilingTaskDescInfo(model_id, task_desc_info, logic_device_id); ProfilingTaskDescInfo(model_id, task_desc_info, logic_device_id);
GELOGD("start ProfilingGraphDescInfo."); GELOGD("start ProfilingGraphDescInfo.");
@@ -574,11 +358,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportPr
#endif #endif
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::SetProfilingConfig(
const std::string &profiling_cfg) {
recv_profiling_config_ = profiling_cfg;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t ProfilingManager::GetProfilingModule() { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t ProfilingManager::GetProfilingModule() {
uint64_t module = PROF_MODEL_EXECUTE_MASK | uint64_t module = PROF_MODEL_EXECUTE_MASK |
PROF_RUNTIME_API_MASK | PROF_RUNTIME_API_MASK |
@@ -594,9 +373,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t ProfilingManager::GetP
return module; return module;
} }


void ProfilingManager::UpdateSubscribeDeviceModuleMap(std::string prof_type,
uint32_t device_id,
uint64_t module) {
void ProfilingManager::UpdateSubscribeDeviceModuleMap(std::string prof_type, uint32_t device_id, uint64_t module) {
#ifdef DAVINCI_SUPPORT_PROFILING #ifdef DAVINCI_SUPPORT_PROFILING
if (prof_type == kProfModelSubscribe) { if (prof_type == kProfModelSubscribe) {
if (subs_dev_module_.find(device_id) != subs_dev_module_.end()) { if (subs_dev_module_.find(device_id) != subs_dev_module_.end()) {
@@ -608,9 +385,13 @@ void ProfilingManager::UpdateSubscribeDeviceModuleMap(std::string prof_type,
subs_dev_module_[device_id] = dev_info; subs_dev_module_[device_id] = dev_info;
} }
} else if (prof_type == kProfModelUnsubscribe) { } else if (prof_type == kProfModelUnsubscribe) {
if (subs_dev_module_.find(device_id) != subs_dev_module_.end()) {
if (subs_dev_module_[device_id].subscribe_count > 0) {
subs_dev_module_[device_id].subscribe_count--;
auto iter = subs_dev_module_.find(device_id);
if (iter != subs_dev_module_.end()) {
if (iter->second.subscribe_count > 0) {
iter->second.subscribe_count--;
}
if (iter->second.subscribe_count == 0) {
subs_dev_module_.erase(iter);
} }
} }
} else { } else {
@@ -626,10 +407,11 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo
uint64_t model_load_mask = module & PROF_MODEL_LOAD_MASK; uint64_t model_load_mask = module & PROF_MODEL_LOAD_MASK;
if ((subscribe_count_ == 0) && (model_load_mask == PROF_MODEL_LOAD_MASK)) { if ((subscribe_count_ == 0) && (model_load_mask == PROF_MODEL_LOAD_MASK)) {
// register framework to profiling // register framework to profiling
int32_t result = Msprof::Engine::Init(GE_PROFILING_MODULE, &engine_);
if (result != SUCCESS) {
GELOGE(FAILED, "Register profiling engine failed.");
return FAILED;
// register Framework to profiling
int32_t cb_ret = PluginInit();
if (cb_ret != 0) {
GELOGE(cb_ret, "profiling plugin init failed, ret:%d", cb_ret);
return cb_ret;
} }
GELOGI("Prof subscribe: model load profiling on."); GELOGI("Prof subscribe: model load profiling on.");
} }
@@ -647,7 +429,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo
UpdateSubscribeDeviceModuleMap(kProfModelSubscribe, device[0], module); UpdateSubscribeDeviceModuleMap(kProfModelSubscribe, device[0], module);


// Report profiling data // Report profiling data
Status p_ret = davinci_model->ReportProfilingData(false);
Status p_ret = davinci_model->ReportProfilingData();
if (p_ret != SUCCESS) { if (p_ret != SUCCESS) {
GELOGE(p_ret, "Report profiling data failed."); GELOGE(p_ret, "Report profiling data failed.");
return p_ret; return p_ret;
@@ -672,6 +454,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo
auto iter = subs_dev_module_.find(device[0]); auto iter = subs_dev_module_.find(device[0]);
if (iter != subs_dev_module_.end()) { if (iter != subs_dev_module_.end()) {
if (subs_dev_module_[device[0]].subscribe_count == 1) { if (subs_dev_module_[device[0]].subscribe_count == 1) {
// The same device_id, only stop at last time
rtError_t rt_ret = rtProfilerStop(subs_dev_module_[device[0]].module, dev_num, device); rtError_t rt_ret = rtProfilerStop(subs_dev_module_[device[0]].module, dev_num, device);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(FAILED, "Runtime profiler stop failed."); GELOGE(FAILED, "Runtime profiler stop failed.");
@@ -679,15 +462,15 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo
} }
} }
UpdateSubscribeDeviceModuleMap(kProfModelUnsubscribe, device[0], subs_dev_module_[device[0]].module); UpdateSubscribeDeviceModuleMap(kProfModelUnsubscribe, device[0], subs_dev_module_[device[0]].module);
} else {
GELOGE(FAILED, "The device_id:%u has not been subscribed, do not need to cancel.", device[0]);
return FAILED;
} }


subscribe_count_--; subscribe_count_--;
if (subscribe_count_ == 0) { if (subscribe_count_ == 0) {
int32_t ret = Msprof::Engine::UnInit(GE_PROFILING_MODULE);
if (ret != SUCCESS) {
GELOGE(ret, "Profiling plugin uninit failed, ret:%d", ret);
return ret;
}
// profiling plugin uninit at last subscription
PluginUnInit();
} }
#endif #endif
return SUCCESS; return SUCCESS;
@@ -700,11 +483,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfIn


if (model_load_mask == PROF_MODEL_LOAD_MASK) { if (model_load_mask == PROF_MODEL_LOAD_MASK) {
// register Framework to profiling // register Framework to profiling
int32_t result = Msprof::Engine::Init(GE_PROFILING_MODULE, &engine_);
if (result != SUCCESS) {
GELOGE(FAILED, "Register profiling engine failed.");
return FAILED;
int32_t cb_ret = PluginInit();
if (cb_ret != 0) {
GELOGE(cb_ret, "profiling plugin init failed, ret:%d", cb_ret);
return cb_ret;
} }

int32_t device_num = -1; int32_t device_num = -1;
rtError_t rt_ret = rtProfilerStart(model_load_mask, device_num, nullptr); rtError_t rt_ret = rtProfilerStart(model_load_mask, device_num, nullptr);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
@@ -719,7 +503,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfIn
if (training_trace_mask == PROF_TRAINING_TRACE_MASK) { if (training_trace_mask == PROF_TRAINING_TRACE_MASK) {
is_training_trace_ = true; is_training_trace_ = true;
} }
is_acl_api_mode_ = true;
GELOGI("Prof init success."); GELOGI("Prof init success.");
#endif #endif
return SUCCESS; return SUCCESS;
@@ -730,19 +513,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfFi
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
is_load_profiling_ = false; is_load_profiling_ = false;
is_training_trace_ = false; is_training_trace_ = false;
is_acl_api_mode_ = false;
is_execute_profiling_ = false;

// profiling plugin uninit
PluginUnInit();


int32_t ret = Msprof::Engine::UnInit(GE_PROFILING_MODULE);
if (ret != SUCCESS) {
GELOGE(ret, "Profiling plugin uninit failed, ret:%d", ret);
}
int32_t dev_num = -1; int32_t dev_num = -1;
rtError_t rt_ret = rtProfilerStop(PROF_MODEL_LOAD_MASK, dev_num, nullptr); rtError_t rt_ret = rtProfilerStop(PROF_MODEL_LOAD_MASK, dev_num, nullptr);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(FAILED, "Runtime profiler stop failed."); GELOGE(FAILED, "Runtime profiler stop failed.");
return FAILED; return FAILED;
} }

for (auto device_id_module : device_id_module_map_) { for (auto device_id_module : device_id_module_map_) {
if (device_id_module.second != 0) { if (device_id_module.second != 0) {
uint32_t device_id = static_cast<uint32_t>(device_id_module.first); uint32_t device_id = static_cast<uint32_t>(device_id_module.first);
@@ -792,6 +573,7 @@ Status ProfilingManager::ProfParseDeviceId(const std::map<std::string, std::stri
return FAILED; return FAILED;
} catch (std::out_of_range &) { } catch (std::out_of_range &) {
GELOGE(FAILED, "Device id: %s is out of range.", decvice_id[i].c_str()); GELOGE(FAILED, "Device id: %s is out of range.", decvice_id[i].c_str());
return FAILED;
} catch (...) { } catch (...) {
GELOGE(FAILED, "Device id: %s cannot change to int.", decvice_id[i].c_str()); GELOGE(FAILED, "Device id: %s cannot change to int.", decvice_id[i].c_str());
return FAILED; return FAILED;
@@ -818,6 +600,7 @@ Status ProfilingManager::ProfParseParam(const std::map<std::string, std::string>
return FAILED; return FAILED;
} catch (std::out_of_range &) { } catch (std::out_of_range &) {
GELOGE(FAILED, "Device num: %s is out of range.", iter->second.c_str()); GELOGE(FAILED, "Device num: %s is out of range.", iter->second.c_str());
return FAILED;
} catch (...) { } catch (...) {
GELOGE(FAILED, "Device num: %s cannot change to int.", iter->second.c_str()); GELOGE(FAILED, "Device num: %s cannot change to int.", iter->second.c_str());
return FAILED; return FAILED;
@@ -859,7 +642,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt
for (int32_t i = 0; i < device_num; i++) { for (int32_t i = 0; i < device_num; i++) {
device_id_ptr[i] = static_cast<uint32_t>(device_list[i]); device_id_ptr[i] = static_cast<uint32_t>(device_list[i]);
} }
GELOGD("Runtime config param: 0x%llx, device num: %d.", module, device_num);
GELOGI("Runtime config param: 0x%llx, device num: %d.", module, device_num);


rtError_t rt_ret = rtProfilerStart(module, device_num, device_id_ptr.get()); rtError_t rt_ret = rtProfilerStart(module, device_num, device_id_ptr.get());
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
@@ -878,7 +661,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt
GELOGW("Prof start: load model module is invalid."); GELOGW("Prof start: load model module is invalid.");
} }
UpdateDeviceIdModuleMap(kProfStart, module, device_list); UpdateDeviceIdModuleMap(kProfStart, module, device_list);
GELOGD("Prof start profiling success.");
GELOGI("Prof start profiling success.");
#endif #endif
return SUCCESS; return SUCCESS;
} }
@@ -901,7 +684,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt
for (int32_t i = 0; i < device_num; i++) { for (int32_t i = 0; i < device_num; i++) {
device_id_ptr[i] = static_cast<uint32_t>(device_list[i]); device_id_ptr[i] = static_cast<uint32_t>(device_list[i]);
} }
GELOGD("Prof stop: runtime config param: 0x%llx, device num: %d", module, device_num);
GELOGI("Prof stop: runtime config param: 0x%llx, device num: %d", module, device_num);
rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get()); rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get());
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(FAILED, "Prof stop: runtime profiler config proc failed."); GELOGE(FAILED, "Prof stop: runtime profiler config proc failed.");
@@ -921,7 +704,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt
GELOGW("Prof stop: load model module is invalid."); GELOGW("Prof stop: load model module is invalid.");
} }
UpdateDeviceIdModuleMap(kProfStop, module, device_list); UpdateDeviceIdModuleMap(kProfStop, module, device_list);
GELOGD("Prof stop profiling success.");
GELOGI("Prof stop profiling success.");
#endif #endif
return SUCCESS; return SUCCESS;
} }
@@ -963,47 +746,90 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ProfilingManager::Profilin
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(rt_ret, "Runtime get logic_device_id failed, current logic_device_id:%d", logic_device_id); GELOGE(rt_ret, "Runtime get logic_device_id failed, current logic_device_id:%d", logic_device_id);
} }
GELOGD("Current logic_device_id:%d", logic_device_id);
GELOGI("Current logic_device_id:%d", logic_device_id);


bool execute_model_prof_on = false; bool execute_model_prof_on = false;
auto iter = std::find(device_id_.begin(), device_id_.end(), logic_device_id); auto iter = std::find(device_id_.begin(), device_id_.end(), logic_device_id);
if (iter != device_id_.end()) { if (iter != device_id_.end()) {
execute_model_prof_on = true; execute_model_prof_on = true;
} }
GELOGD("Flag is_execute_profiling: %d, execute_model_prof_on: %d", is_execute_profiling_, execute_model_prof_on);
return is_execute_profiling_ || execute_model_prof_on;
GELOGI("Flag is_execute_profiling: %d, execute_model_prof_on: %d", is_execute_profiling_, execute_model_prof_on);
return execute_model_prof_on;
} }


/**
* @brief Profiling PluginImpl
*/
// PluginImpl static variable init
Msprof::Engine::Reporter *PluginImpl::reporter_ = nullptr;

PluginImpl::PluginImpl(const std::string &module) : module_(module) { GELOGI("Create PluginImpl\n"); }

int PluginImpl::Init(const Msprof::Engine::Reporter *reporter) {
GELOGI("PluginImpl init");
reporter_ = const_cast<Msprof::Engine::Reporter *>(reporter);
return 0;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::PluginInit() const {
if (prof_cb_.msprofReporterCallback == nullptr) {
GELOGE(ge::PARAM_INVALID, "MsprofReporterCallback callback is nullptr.");
return ge::PARAM_INVALID;
}
return prof_cb_.msprofReporterCallback(
static_cast<uint32_t>(MsprofReporterModuleId::MSPROF_MODULE_FRAMEWORK),
static_cast<uint32_t>(MsprofReporterCallbackType::MSPROF_REPORTER_INIT),
nullptr, 0);
} }


int PluginImpl::UnInit() {
GELOGI("PluginImpl Uninit");
reporter_ = nullptr;
return 0;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::PluginUnInit() const {
#ifdef DAVINCI_SUPPORT_PROFILING
if (prof_cb_.msprofReporterCallback == nullptr) {
GELOGE(ge::PARAM_INVALID, "MsprofReporterCallback callback is nullptr.");
return;
}
int32_t cb_ret = prof_cb_.msprofReporterCallback(
static_cast<uint32_t>(MsprofReporterModuleId::MSPROF_MODULE_FRAMEWORK),
static_cast<uint32_t>(MsprofReporterCallbackType::MSPROF_REPORTER_UNINIT),
nullptr, 0);
if (cb_ret != 0) {
GELOGW("profiling plugin uninit failed, ret:%d", cb_ret);
}
#endif
} }


Msprof::Engine::PluginIntf *ProfilingEngineImpl::CreatePlugin() {
GELOGI(" Create Plugin");
return new (std::nothrow) PluginImpl(GE_PROFILING_MODULE);
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::CallMsprofReport(
ReporterData &reporter_data) const {
if (prof_cb_.msprofReporterCallback == nullptr) {
GELOGE(ge::PARAM_INVALID, "MsprofReporterCallback callback is nullptr.");
return ge::PARAM_INVALID;
}
return prof_cb_.msprofReporterCallback(
static_cast<uint32_t>(MsprofReporterModuleId::MSPROF_MODULE_FRAMEWORK),
static_cast<uint32_t>(MsprofReporterCallbackType::MSPROF_REPORTER_REPORT),
static_cast<void *>(&reporter_data), sizeof(ReporterData));
} }


int ProfilingEngineImpl::ReleasePlugin(Msprof::Engine::PluginIntf *plugin) {
if (plugin != nullptr) {
delete plugin;
plugin = nullptr;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpPoint(
std::string &fp_point, std::string &bp_point) {
// Env or options mode, fp_point_/bp_point_ have initiliazed on profiling init
if (!fp_point_.empty() && !bp_point_.empty()) {
fp_point = fp_point_;
bp_point = bp_point_;
GELOGI("Bp Fp have been initialized in env or options. bp_point: %s, fp_point: %s", bp_point.c_str(), fp_point.c_str());
return;
}
// ProfApi mode and training trace is set
try {
char env_profiling_options[MSPROF_OPTIONS_DEF_LEN_MAX] = { 0x00 };
INT32 ret = mmGetEnv("PROFILING_OPTIONS", env_profiling_options, MSPROF_OPTIONS_DEF_LEN_MAX);
if (ret != EN_OK) {
GELOGI("PROFILING_OPTIONS env is not exist.");
return;
}
GELOGI("Parse env PROFILING_OPTIONS:%s.", env_profiling_options);
Json prof_options = Json::parse(env_profiling_options);

fp_point_ = prof_options[kFpPoint];
bp_point_ = prof_options[kBpPoint];

fp_point = fp_point_;
bp_point = bp_point_;
if (!fp_point_.empty() && !bp_point_.empty()) {
GELOGI("Training trace bp fp is set, bp_point:%s, fp_point:%s.", bp_point_.c_str(), fp_point_.c_str());
}
} catch (...) {
GELOGE(FAILED, "Json prof options is invalid.");
return;
} }
return 0;
return;
} }


} // namespace ge } // namespace ge

+ 51
- 69
ge/common/profiling/profiling_manager.h View File

@@ -26,9 +26,7 @@
#include "framework/common/ge_inner_error_codes.h" #include "framework/common/ge_inner_error_codes.h"
#include "framework/common/ge_types.h" #include "framework/common/ge_types.h"
#include "external/register/register_types.h" #include "external/register/register_types.h"
#include "toolchain/prof_engine.h"
#include "toolchain/prof_mgr_core.h"
#include "toolchain/prof_acl_api.h"
#include "toolchain/prof_callback.h"


using std::map; using std::map;
using std::string; using std::string;
@@ -37,35 +35,33 @@ using Json = nlohmann::json;


namespace { namespace {
const std::string GE_PROFILING_MODULE = "Framework"; const std::string GE_PROFILING_MODULE = "Framework";
// DataTypeConfig MASK
#define PROF_ACL_API_MASK 0x0001
#define PROF_TASK_TIME_MASK 0x0002
#define PROF_AICORE_METRICS_MASK 0x0004
#define PROF_AICPU_TRACE_MASK 0x0008
#define PROF_MODEL_EXECUTE_MASK 0x0010
#define PROF_RUNTIME_API_MASK 0x0020
#define PROF_RUNTIME_TRACE_MASK 0x0040
#define PROF_SCHEDULE_TIMELINE_MASK 0x0080
#define PROF_SCHEDULE_TRACE_MASK 0x0100
#define PROF_AIVECTORCORE_METRICS_MASK 0x0200
#define PROF_SUBTASK_TIME_MASK 0x0400
#define PROF_TRAINING_TRACE_MASK 0x0800
#define PROF_HCCL_TRACE_MASK 0x1000
#define PROF_DATA_PROCESS_MASK 0x2000
#define PROF_MODEL_LOAD_MASK 0x8000000000000000

} // namespace } // namespace
namespace ge { namespace ge {
struct DeviceSubsInfo { struct DeviceSubsInfo {
uint64_t module; uint64_t module;
uint32_t subscribe_count; uint32_t subscribe_count;
}; };
// register Plugin
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY PluginImpl : public Msprof::Engine::PluginIntf {
public:
explicit PluginImpl(const std::string &module);
~PluginImpl() {}

int Init(const Msprof::Engine::Reporter *reporter);
int UnInit();
static Msprof::Engine::Reporter *GetPluginReporter() { return reporter_; }


private:
static Msprof::Engine::Reporter *reporter_;
std::string module_;
};

// register Engine
class ProfilingEngineImpl : public Msprof::Engine::EngineIntf {
public:
ProfilingEngineImpl() {}
~ProfilingEngineImpl() {}

Msprof::Engine::PluginIntf *CreatePlugin();
int ReleasePlugin(Msprof::Engine::PluginIntf *plugin);
struct MsprofCallback {
MsprofCtrlCallback msprofCtrlCallback;
MsprofReporterCallback msprofReporterCallback;
}; };


class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager {
@@ -73,68 +69,54 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager {
ProfilingManager(); ProfilingManager();
virtual ~ProfilingManager(); virtual ~ProfilingManager();
static ProfilingManager &Instance(); static ProfilingManager &Instance();
ge::Status Init(const Options &options);
ge::Status InitFromOptions(const Options &options);
ge::Status InitFromAclCfg(const std::string &config);
ge::Status StartProfiling(int32_t iter, int32_t device_id);
void UpdateSubscribeDeviceModuleMap(std::string prof_type, uint32_t device_id, uint64_t module);
ge::Status ProfModelSubscribe(uint64_t module, void *model);
ge::Status ProfModelUnsubscribe(void *model);
ge::Status ProfInit(uint64_t module);
ge::Status ProfFinalize();
ge::Status ProfStartProfiling(uint64_t module, const std::map<std::string, std::string> &config_para);
ge::Status ProfStopProfiling(uint64_t module, const std::map<std::string, std::string> &config_para);
Status Init(const Options &options);
Status ProfInit(uint64_t module);
Status ProfFinalize();
Status ProfStartProfiling(uint64_t module, const std::map<std::string, std::string> &config_para);
Status ProfStopProfiling(uint64_t module, const std::map<std::string, std::string> &config_para);
Status ProfModelSubscribe(uint64_t module, void *model);
Status ProfModelUnsubscribe(void *model);
void StopProfiling(); void StopProfiling();
bool ProfilingOpTraceOn() const { return is_op_trace_; }
bool ProfilingLoadFlag() const { return is_load_; }
bool ProfilingTrainingTraceOn() const { return is_training_trace_; } bool ProfilingTrainingTraceOn() const { return is_training_trace_; }
bool ProfilingModelLoadOn() const { return is_load_profiling_; } bool ProfilingModelLoadOn() const { return is_load_profiling_; }
bool ProfilingModelExecuteOn() const; bool ProfilingModelExecuteOn() const;
bool ProfilingOn() const { return is_load_profiling_ && is_execute_profiling_; } // only used by command pattern
bool IsAclApiMode() const { return is_acl_api_mode_; }
int32_t GetOpTraceIterNum() const { return op_trace_iter_num_; }
bool ProfilingOn() const { return is_load_profiling_ && is_execute_profiling_; } // is_execute_profiling_ only used by ge option and env
void ReportProfilingData(uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, void ReportProfilingData(uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info,
const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info,
bool check_device);
void Report(const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter,
Msprof::Engine::ReporterData &reporter_data);
const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info);
void ProfilingTaskDescInfo(uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, void ProfilingTaskDescInfo(uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info,
const int32_t &device_id); const int32_t &device_id);
void ProfilingGraphDescInfo(uint32_t model_id, const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, void ProfilingGraphDescInfo(uint32_t model_id, const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info,
const int32_t &device_id); const int32_t &device_id);
void SetProfilingConfig(const string &profiling_cfg);
vector<int32_t> GetProfilingDeviceId() const { return device_id_; }
void PluginUnInit(const std::string &module) const;
Status PluginInit() const;
void PluginUnInit() const;
Status CallMsprofReport(ReporterData &reporter_data) const;
struct MsprofCallback &GetMsprofCallback() { return prof_cb_; }
void SetMsprofCtrlCallback(MsprofCtrlCallback func) { prof_cb_.msprofCtrlCallback = func; }
void SetMsprofReporterCallback(MsprofReporterCallback func) { prof_cb_.msprofReporterCallback = func; }
void GetFpBpPoint(std::string &fp_point, std::string &bp_point);
private: private:
ge::Status ParseFeaturesFromAclCfg(const Json &feature);
ge::Status ProfParseParam(const std::map<std::string, std::string> &config_para, int32_t &device_num,
vector<int32_t> &device_list);
ge::Status ProfParseDeviceId(const std::map<std::string, std::string> &config_para,
Status InitFromOptions(const Options &options, MsprofGeOptions &prof_conf);
Status ParseOptions(const std::string &options);
Status ProfParseParam(const std::map<std::string, std::string> &config_para, int32_t &device_num,
vector<int32_t> &device_list);
Status ProfParseDeviceId(const std::map<std::string, std::string> &config_para,
vector<int32_t> &device_list); vector<int32_t> &device_list);
uint64_t GetProfilingModule(); uint64_t GetProfilingModule();
void GraphDescReport(const int32_t &device_id, const string &data);
void UpdateDeviceIdModuleMap(string prof_type, uint64_t module, const vector<int32_t> &device_list); void UpdateDeviceIdModuleMap(string prof_type, uint64_t module, const vector<int32_t> &device_list);
bool is_load_profiling_ = false;
bool is_execute_profiling_ = false;
bool is_op_trace_ = false;
bool is_load_ = false;
bool is_training_trace_ = false;
bool is_acl_api_mode_ = false;
int32_t op_trace_iter_num_ = 0;
string job_id_;
string prof_dir_;
void UpdateSubscribeDeviceModuleMap(std::string prof_type, uint32_t device_id, uint64_t module);

bool is_load_profiling_;
bool is_execute_profiling_;
bool is_training_trace_;
vector<int32_t> device_id_; vector<int32_t> device_id_;
vector<string> op_trace_conf_;
vector<string> profiling_opts_;
vector<void *> prof_handle_vec_;
string recv_profiling_config_;
string send_profiling_config_;
string system_trace_conf_;
string task_trace_conf_;
const ProfilingEngineImpl engine_;
map<int32_t, uint64_t> device_id_module_map_; // key: device_id, value: profiling on module map<int32_t, uint64_t> device_id_module_map_; // key: device_id, value: profiling on module
map<uint32_t, DeviceSubsInfo> subs_dev_module_; // key: device_id, value: profiling on module map<uint32_t, DeviceSubsInfo> subs_dev_module_; // key: device_id, value: profiling on module
uint32_t subscribe_count_; uint32_t subscribe_count_;
std::mutex mutex_; std::mutex mutex_;
MsprofCallback prof_cb_;
std::string fp_point_;
std::string bp_point_;
}; };
} // namespace ge } // namespace ge
#endif // GE_COMMON_PROFILING_PROFILING_MANAGER_H_ #endif // GE_COMMON_PROFILING_PROFILING_MANAGER_H_

+ 1
- 1
ge/common/types.cc View File

@@ -801,7 +801,7 @@ const uint32_t XRGB_CHN_NUM = 4;
/// ///
const bool DEFAULT_GLOBAL_POOLING = false; const bool DEFAULT_GLOBAL_POOLING = false;


const uint32_t MODEL_VERSION = 0x10000000; ///< Model version 1.0///
const uint32_t MODEL_VERSION = 0x20000000; ///< Model version 2.0///


// Eltwise's input size // Eltwise's input size
const int ELTWISE_MIN_INPUT_SIZE = 2; const int ELTWISE_MIN_INPUT_SIZE = 2;


+ 43
- 40
ge/common/util.cc View File

@@ -51,14 +51,15 @@ namespace {
* If such an exception is encountered during operation, * If such an exception is encountered during operation,
* the proto file can be divided into several small files or the limit value can be increased. * the proto file can be divided into several small files or the limit value can be increased.
*/ */
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M
const int kFileSizeOutLimitedOrOpenFailed = -1;
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
const int kWarningThreshold = 1073741824; // 536870912 * 2 536870912 represent 512M


/// The maximum length of the file. /// The maximum length of the file.
const uint32_t kMaxFileSizeLimit = UINT32_MAX; // 4G for now
const uint32_t kMaxFileSizeLimit = UINT32_MAX; // 4G for now
const int kMaxBuffSize = 256; const int kMaxBuffSize = 256;
const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character"; const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character";
constexpr uint32_t kMaxConfigFileByte = 10 * 1024 * 1024;
constexpr uint32_t kMaxConfigFileByte = 10485760; // 10 * 1024 * 1024
} // namespace } // namespace


namespace ge { namespace ge {
@@ -76,7 +77,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co
std::string real_path = RealPath(file); std::string real_path = RealPath(file);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "pb file path '%s' not valid", file); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "pb file path '%s' not valid", file);


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == kFileSizeOutLimitedOrOpenFailed, return false,
"file size not valid.");


std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary);
if (!fs.is_open()) { if (!fs.is_open()) {
@@ -118,20 +120,20 @@ long GetFileLength(const std::string &input_file) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str());
unsigned long long file_length = 0; unsigned long long file_length = 0;
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
mmGetFileSize(input_file.c_str(), &file_length) != EN_OK,
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, strerror(errno)});
return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno));
mmGetFileSize(input_file.c_str(), &file_length) != EN_OK,
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, strerror(errno)});
return kFileSizeOutLimitedOrOpenFailed, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno));


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0),
ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file}); ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file});
return -1, "File[%s] size is 0, not valid.", input_file.c_str()); return -1, "File[%s] size is 0, not valid.", input_file.c_str());


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > kMaxFileSizeLimit,
ErrorManager::GetInstance().ATCReportErrMessage(
"E19016", {"filepath", "filesize", "maxlen"},
{input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)});
return -1, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length,
kMaxFileSizeLimit);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
file_length > kMaxFileSizeLimit, ErrorManager::GetInstance().ATCReportErrMessage(
"E19016", {"filepath", "filesize", "maxlen"},
{input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)});
return kFileSizeOutLimitedOrOpenFailed, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length,
kMaxFileSizeLimit);
return static_cast<long>(file_length); return static_cast<long>(file_length);
} }


@@ -187,7 +189,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co
std::streamsize size = file.tellg(); std::streamsize size = file.tellg();


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((size <= 0), file.close(); return false, "file length <= 0, not valid."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((size <= 0), file.close(); return false, "file length <= 0, not valid.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size > static_cast<int64_t >(kMaxFileSizeLimit), file.close();
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size > static_cast<int64_t>(kMaxFileSizeLimit), file.close();
return false, "file size %ld is out of limit: %d.", size, kMaxFileSizeLimit); return false, "file size %ld is out of limit: %d.", size, kMaxFileSizeLimit);


file.seekg(0, std::ios::beg); // [no need to check value] file.seekg(0, std::ios::beg); // [no need to check value]
@@ -210,8 +212,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std::
GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty.");
auto dir_path_len = directory_path.length(); auto dir_path_len = directory_path.length();
if (dir_path_len >= MMPA_MAX_PATH) { if (dir_path_len >= MMPA_MAX_PATH) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19002", {"filepath", "size"}, {directory_path, std::to_string(MMPA_MAX_PATH)});
ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"},
{directory_path, std::to_string(MMPA_MAX_PATH)});
GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), MMPA_MAX_PATH); GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), MMPA_MAX_PATH);
return -1; return -1;
} }
@@ -224,8 +226,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std::
if (ret != 0) { if (ret != 0) {
if (errno != EEXIST) { if (errno != EEXIST) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path});
GELOGW("Can not create directory %s. Make sure the directory exists and writable.",
directory_path.c_str());
GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str());
return ret; return ret;
} }
} }
@@ -265,7 +266,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch


std::string real_path = RealPath(file); std::string real_path = RealPath(file);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), ErrorManager::GetInstance().ATCReportErrMessage( GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), ErrorManager::GetInstance().ATCReportErrMessage(
"E19000", {"path", "errmsg"}, {file, strerror(errno)});
"E19000", {"path", "errmsg"}, {file, strerror(errno)});
return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, strerror(errno)); return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, strerror(errno));


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid.");
@@ -301,13 +302,13 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const cha
google::protobuf::io::IstreamInputStream input(&fs); google::protobuf::io::IstreamInputStream input(&fs);
bool ret = google::protobuf::TextFormat::Parse(&input, message); bool ret = google::protobuf::TextFormat::Parse(&input, message);
GE_IF_BOOL_EXEC( GE_IF_BOOL_EXEC(
!ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file."));
!ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file."));


return ret; return ret;
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() {
mmTimeval tv {};
mmTimeval tv{};
int ret = mmGetTimeOfDay(&tv, nullptr); int ret = mmGetTimeOfDay(&tv, nullptr);
GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret); GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret);
auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds
@@ -315,7 +316,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp()
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t GetCurrentSecondTimestap() { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t GetCurrentSecondTimestap() {
mmTimeval tv {};
mmTimeval tv{};
int ret = mmGetTimeOfDay(&tv, nullptr); int ret = mmGetTimeOfDay(&tv, nullptr);
GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret); GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret);
auto total_use_time = tv.tv_sec; // seconds auto total_use_time = tv.tv_sec; // seconds
@@ -350,8 +351,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInt64MulOverflow(int6
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= MMPA_MAX_PATH, GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= MMPA_MAX_PATH,
ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(MMPA_MAX_PATH)});
return "", "Path[%s] len is too long, it must be less than %d", path, MMPA_MAX_PATH);
ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"},
{path, std::to_string(MMPA_MAX_PATH)});
return "", "Path[%s] len is too long, it must be less than %d", path, MMPA_MAX_PATH);


// Nullptr is returned when the path does not exist or there is no permission // Nullptr is returned when the path does not exist or there is no permission
// Return absolute path when path is accessible // Return absolute path when path is accessible
@@ -385,16 +387,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const
// Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores
// File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.)
#ifdef __GNUC__ #ifdef __GNUC__
std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$";
std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$";
#else #else
std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$";
std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$";
#endif #endif


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
!ValidateStr(real_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, real_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason);
!ValidateStr(real_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, real_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason);


// The absolute path points to a file that is not readable // The absolute path points to a file that is not readable
if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) { if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) {
@@ -416,24 +418,25 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const
} }


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path.c_str()) >= MMPA_MAX_PATH, GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path.c_str()) >= MMPA_MAX_PATH,
ErrorManager::GetInstance().ATCReportErrMessage(
"E19002", {"filepath", "size"}, {file_path, std::to_string(MMPA_MAX_PATH)});
return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(), MMPA_MAX_PATH);
ErrorManager::GetInstance().ATCReportErrMessage(
"E19002", {"filepath", "size"}, {file_path, std::to_string(MMPA_MAX_PATH)});
return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(),
MMPA_MAX_PATH);


// A regular matching expression to verify the validity of the input file path // A regular matching expression to verify the validity of the input file path
// Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores
// File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.)
#ifdef __GNUC__ #ifdef __GNUC__
std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$";
std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$";
#else #else
std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$";
std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$";
#endif #endif


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
!ValidateStr(file_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, file_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason);
!ValidateStr(file_path, mode),
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{atc_param, file_path, kPathValidReason});
return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason);


std::string real_path = RealPath(file_path.c_str()); std::string real_path = RealPath(file_path.c_str());
// Can get absolute path (file exists) // Can get absolute path (file exists)


+ 1
- 1
ge/executor/CMakeLists.txt View File

@@ -17,6 +17,7 @@ set(SRC_LIST
"../common/dump/dump_properties.cc" "../common/dump/dump_properties.cc"
"../common/dump/dump_manager.cc" "../common/dump/dump_manager.cc"
"../common/dump/dump_op.cc" "../common/dump/dump_op.cc"
"../common/profiling/ge_profiling.cc"
"../graph/load/graph_loader.cc" "../graph/load/graph_loader.cc"
"../graph/execute/graph_execute.cc" "../graph/execute/graph_execute.cc"
"../omm/csa_interact.cc" "../omm/csa_interact.cc"
@@ -244,7 +245,6 @@ target_link_libraries(ge_executor_shared PRIVATE
mmpa mmpa
graph graph
register register
msprof
error_manager error_manager
ascend_hal_stub ascend_hal_stub
ascend_protobuf ascend_protobuf


+ 5
- 3
ge/executor/ge_executor.cc View File

@@ -283,7 +283,8 @@ Status GeExecutor::Initialize() {
// Start profiling // Start profiling
Options profiling_options; Options profiling_options;
profiling_options.device_id = 0; profiling_options.device_id = 0;
profiling_options.job_id = "";
// The job id must be set, although its value is meaningless.
profiling_options.job_id = "1";
ProfilingManager::Instance().Init(profiling_options); ProfilingManager::Instance().Init(profiling_options);


isInit_ = true; isInit_ = true;
@@ -303,7 +304,7 @@ Status GeExecutor::Finalize() {
// Stop profiling // Stop profiling
if (ProfilingManager::Instance().ProfilingOn()) { if (ProfilingManager::Instance().ProfilingOn()) {
ProfilingManager::Instance().StopProfiling(); ProfilingManager::Instance().StopProfiling();
ProfilingManager::Instance().PluginUnInit(GE_PROFILING_MODULE);
ProfilingManager::Instance().PluginUnInit();
} }


GELOGI("Uninit GeExecutor over."); GELOGI("Uninit GeExecutor over.");
@@ -638,7 +639,8 @@ Status GeExecutor::UnloadModel(uint32_t model_id) {
return ACL_ERROR_GE_INTERNAL_ERROR; return ACL_ERROR_GE_INTERNAL_ERROR;
} }


std::shared_ptr<hybrid::HybridDavinciModel> hybrid_davinci_model = ModelManager::GetInstance()->GetHybridModel(model_id);
std::shared_ptr<hybrid::HybridDavinciModel> hybrid_davinci_model =
ModelManager::GetInstance()->GetHybridModel(model_id);
if (hybrid_davinci_model != nullptr) { if (hybrid_davinci_model != nullptr) {
uint64_t session_id = hybrid_davinci_model->GetSessionId(); uint64_t session_id = hybrid_davinci_model->GetSessionId();
VarManagerPool::Instance().RemoveVarManager(session_id); VarManagerPool::Instance().RemoveVarManager(session_id);


+ 1
- 4
ge/executor/module.mk View File

@@ -8,6 +8,7 @@ local_ge_executor_src_files := \
../common/dump/dump_op.cc \ ../common/dump/dump_op.cc \
../common/ge/plugin_manager.cc \ ../common/ge/plugin_manager.cc \
../common/ge/op_tiling_manager.cc \ ../common/ge/op_tiling_manager.cc \
../common/profiling/ge_profiling.cc \
../graph/load/graph_loader.cc \ ../graph/load/graph_loader.cc \
../graph/execute/graph_execute.cc \ ../graph/execute/graph_execute.cc \
../omm/csa_interact.cc \ ../omm/csa_interact.cc \
@@ -177,7 +178,6 @@ local_ge_executor_shared_library := \
libmmpa \ libmmpa \
libgraph \ libgraph \
libregister \ libregister \
libmsprof \
liberror_manager \ liberror_manager \


local_ge_executor_ldflags := -lrt -ldl \ local_ge_executor_ldflags := -lrt -ldl \
@@ -234,7 +234,6 @@ LOCAL_SHARED_LIBRARIES := \
libmmpa \ libmmpa \
libgraph \ libgraph \
libregister \ libregister \
libmsprof \
liberror_manager \ liberror_manager \
stub/libascend_hal \ stub/libascend_hal \


@@ -272,7 +271,6 @@ LOCAL_SHARED_LIBRARIES := \
libruntime \ libruntime \
libslog \ libslog \
libmmpa \ libmmpa \
libmsprof \


LOCAL_LDFLAGS += $(local_ge_executor_ldflags) LOCAL_LDFLAGS += $(local_ge_executor_ldflags)


@@ -304,7 +302,6 @@ LOCAL_SHARED_LIBRARIES := \
libruntime \ libruntime \
libslog \ libslog \
libmmpa \ libmmpa \
libmsprof \


ifeq ($(device_os),android) ifeq ($(device_os),android)
LOCAL_LDFLAGS += -ldl LOCAL_LDFLAGS += -ldl


+ 1
- 0
ge/ge_inference.mk View File

@@ -164,6 +164,7 @@ OMG_HOST_SRC_FILES := \
host_kernels/slice_d_kernel.cc \ host_kernels/slice_d_kernel.cc \
host_kernels/dynamic_stitch_kernel.cc \ host_kernels/dynamic_stitch_kernel.cc \
host_kernels/identity_kernel.cc \ host_kernels/identity_kernel.cc \
host_kernels/reformat_kernel.cc \
graph/passes/stop_gradient_pass.cc \ graph/passes/stop_gradient_pass.cc \
graph/passes/prevent_gradient_pass.cc \ graph/passes/prevent_gradient_pass.cc \
graph/passes/identity_pass.cc \ graph/passes/identity_pass.cc \


+ 5
- 6
ge/ge_local_engine/engine/host_cpu_engine.cc View File

@@ -14,7 +14,6 @@
* limitations under the License. * limitations under the License.
*/ */
#include "host_cpu_engine.h" #include "host_cpu_engine.h"
#include <dlfcn.h>
#include "graph/common/omg_util.h" #include "graph/common/omg_util.h"
#include "graph/utils/op_desc_utils.h" #include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_adapter.h" #include "graph/utils/tensor_adapter.h"
@@ -96,8 +95,8 @@ Status GetDataNumber(const GeTensorDesc &out_desc, uint64_t &data_num) {


void HostCpuEngine::CloseSo() { void HostCpuEngine::CloseSo() {
for (auto handle : lib_handles_) { for (auto handle : lib_handles_) {
if (dlclose(handle) != 0) {
GELOGW("failed to close handle, message: %s", dlerror());
if (mmDlclose(handle) != 0) {
GELOGW("failed to close handle, message: %s", mmDlerror());
} }
} }
lib_handles_.clear(); lib_handles_.clear();
@@ -323,13 +322,13 @@ Status HostCpuEngine::LoadLibs(std::vector<std::string> &lib_paths) {


Status HostCpuEngine::LoadLib(const std::string &lib_path) { Status HostCpuEngine::LoadLib(const std::string &lib_path) {
GELOGI("To invoke dlopen on lib: %s", lib_path.c_str()); GELOGI("To invoke dlopen on lib: %s", lib_path.c_str());
auto handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
auto handle = mmDlopen(lib_path.c_str(), MMPA_RTLD_NOW | MMPA_RTLD_GLOBAL);
if (handle == nullptr) { if (handle == nullptr) {
GELOGE(INTERNAL_ERROR, "Failed to invoke dlopen. path = %s, error = %s", lib_path.c_str(), dlerror());
GELOGE(INTERNAL_ERROR, "Failed to invoke dlopen. path = %s, error = %s", lib_path.c_str(), mmDlerror());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }


auto initialize = (Status (*)(const HostCpuContext &))dlsym(handle, "Initialize");
auto initialize = (Status (*)(const HostCpuContext &))mmDlsym(handle, "Initialize");
if (initialize != nullptr) { if (initialize != nullptr) {
GELOGI("Invoke function Initialize in lib: %s", lib_path.c_str()); GELOGI("Invoke function Initialize in lib: %s", lib_path.c_str());
if (initialize(HostCpuContext()) != SUCCESS) { if (initialize(HostCpuContext()) != SUCCESS) {


+ 4
- 6
ge/ge_runner.mk View File

@@ -29,6 +29,8 @@ LIBGE_LOCAL_SRC_FILES := \
common/dump/dump_manager.cc \ common/dump/dump_manager.cc \
common/dump/dump_properties.cc \ common/dump/dump_properties.cc \
common/dump/dump_op.cc \ common/dump/dump_op.cc \
common/profiling/ge_profiling.cc \
common/profiling/ge_runner_profiling.cc \
engine_manager/dnnengine_manager.cc \ engine_manager/dnnengine_manager.cc \
ge_local_engine/engine/host_cpu_engine.cc \ ge_local_engine/engine/host_cpu_engine.cc \
generator/ge_generator.cc \ generator/ge_generator.cc \
@@ -170,6 +172,7 @@ LIBGE_LOCAL_SRC_FILES := \
host_kernels/sub_kernel.cc \ host_kernels/sub_kernel.cc \
host_kernels/transdata_kernel.cc \ host_kernels/transdata_kernel.cc \
host_kernels/unpack_kernel.cc \ host_kernels/unpack_kernel.cc \
host_kernels/reformat_kernel.cc \
graph/passes/folding_pass.cc \ graph/passes/folding_pass.cc \
graph/passes/get_original_format_pass.cc \ graph/passes/get_original_format_pass.cc \
graph/passes/guarantee_const_pass.cc \ graph/passes/guarantee_const_pass.cc \
@@ -306,7 +309,6 @@ LIBGE_LOCAL_SRC_FILES := \
LIBCLIENT_LOCAL_SRC_FILES := \ LIBCLIENT_LOCAL_SRC_FILES := \
proto/ge_api.proto \ proto/ge_api.proto \
client/ge_api.cc \ client/ge_api.cc \
client/ge_prof.cc \


RUNNER_LOCAL_C_INCLUDES := \ RUNNER_LOCAL_C_INCLUDES := \
$(LOCAL_PATH) ./ \ $(LOCAL_PATH) ./ \
@@ -371,7 +373,7 @@ LOCAL_SRC_FILES += $(LIBCLIENT_LOCAL_SRC_FILES)


LOCAL_STATIC_LIBRARIES := libge_memory \ LOCAL_STATIC_LIBRARIES := libge_memory \
libadump_server \ libadump_server \
libmsprofiler \
libmsprofiler_fwk \
libmmpa \ libmmpa \


LOCAL_SHARED_LIBRARIES := \ LOCAL_SHARED_LIBRARIES := \
@@ -381,7 +383,6 @@ LOCAL_SHARED_LIBRARIES := \
libgraph \ libgraph \
libregister \ libregister \
libge_common \ libge_common \
libmsprof \
liberror_manager \ liberror_manager \


LOCAL_LDFLAGS := -lrt -ldl LOCAL_LDFLAGS := -lrt -ldl
@@ -408,7 +409,6 @@ endif
LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES) LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES)


LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_api.cc \ LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_api.cc \
../../out/ge/lib64/stub/ge_prof.cc \
../../out/ge/lib64/stub/ge_ir_build.cc \ ../../out/ge/lib64/stub/ge_ir_build.cc \


LOCAL_SHARED_LIBRARIES := LOCAL_SHARED_LIBRARIES :=
@@ -464,7 +464,6 @@ LOCAL_SHARED_LIBRARIES := \
libc_sec \ libc_sec \
libslog \ libslog \
libmmpa \ libmmpa \
libmsprof \


LOCAL_LDFLAGS := -lrt -ldl LOCAL_LDFLAGS := -lrt -ldl


@@ -497,7 +496,6 @@ LOCAL_SHARED_LIBRARIES := \
libc_sec \ libc_sec \
libslog \ libslog \
libmmpa \ libmmpa \
libmsprof \


LOCAL_LDFLAGS := -lrt -ldl LOCAL_LDFLAGS := -lrt -ldl




+ 2
- 1
ge/ge_runtime/runtime_model.cc View File

@@ -28,6 +28,7 @@


namespace ge { namespace ge {
namespace model_runner { namespace model_runner {
const int kOffsetUnit = 8;
RuntimeModel::~RuntimeModel() { RuntimeModel::~RuntimeModel() {
GELOGI("RuntimeModel destructor start"); GELOGI("RuntimeModel destructor start");


@@ -495,7 +496,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model
return false; return false;
} }
uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data())); uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data()));
int64_t offset = elem_num * 8;
int64_t offset = elem_num * kOffsetUnit;
uintptr_t hbm_raw_data_base_addr = reinterpret_cast<uintptr_t>(constant->output_addrs[0]) + offset; uintptr_t hbm_raw_data_base_addr = reinterpret_cast<uintptr_t>(constant->output_addrs[0]) + offset;
for (int64_t i = elem_num - 1; i >= 0; --i) { for (int64_t i = elem_num - 1; i >= 0; --i) {
buff[i] = hbm_raw_data_base_addr + (buff[i] - buff[0]); buff[i] = hbm_raw_data_base_addr + (buff[i] - buff[0]);


+ 106
- 47
ge/generator/ge_generator.cc View File

@@ -156,7 +156,12 @@ static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, GeTen
} }


string op_type; string op_type;
if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) {
bool is_const = false;
(void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const);
if (is_const) {
GELOGD("Get input[%d] is const", index);
op_type = CONSTANTOP;
} else if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) {
op_type = DATA; op_type = DATA;
} }


@@ -165,6 +170,18 @@ static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, GeTen
if (data_op == nullptr) { if (data_op == nullptr) {
return FAILED; return FAILED;
} }
if (is_const) {
ConstGeTensorPtr tensor_value;
if (!AttrUtils::GetTensor(tensor, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
GELOGE(FAILED, "Get value failed, node name:%s.", tensor.GetName().c_str());
return FAILED;
}
if (!AttrUtils::SetTensor(data_op, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS fail.");
return FAILED;
}
}

(void)AttrUtils::SetBool(data_op, "_is_single_op", true); (void)AttrUtils::SetBool(data_op, "_is_single_op", true);


GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add input desc fail."); GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add input desc fail.");
@@ -240,6 +257,8 @@ class GeGenerator::Impl {


Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model); Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model);


Status SaveRootModel(const string &file_name_prefix, GeRootModelPtr &model, ModelBufferData &model_buff);

Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs, Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);


@@ -505,19 +524,7 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr


GE_CHECK_NOTNULL(ge_root_model); GE_CHECK_NOTNULL(ge_root_model);
GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
ModelHelper model_helper;
string model_name = "";
Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), model_name);
if (name_ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
GELOGE(FAILED, "Get model_name failed. Param --output is invalid");
return PARAM_INVALID;
}
map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model can not be null");
ge_model->SetName(model_name);
ret = impl_->SaveModel(file_name_prefix, ge_model, model);
ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "Save model failed"); GELOGE(ret, "Save model failed");
if (impl_->graph_manager_.Finalize() != SUCCESS) { if (impl_->graph_manager_.Finalize() != SUCCESS) {
@@ -567,6 +574,9 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor>
Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
bool is_offline) { bool is_offline) {
if (!is_offline) {
(void)AttrUtils::SetBool(op_desc, ATTR_DYNAMIC_SHAPE_SINGLE_AICPU, true);
}


if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) { if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) {
GELOGE(PARAM_INVALID, "input param is invalid when build single op!"); GELOGE(PARAM_INVALID, "input param is invalid when build single op!");
@@ -594,40 +604,11 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in


// 2. Create ComputeGraph. // 2. Create ComputeGraph.
string name = ge::CurrentTimeInStr() + "_" + model_file_name; string name = ge::CurrentTimeInStr() + "_" + model_file_name;
ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(name);
GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR);

// 3. Add Node to ComputeGraph.
NodePtr op_node = compute_graph->AddNode(op_desc);
GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR);

// 4. Create InputData node.
int32_t arg_index = 0;
if (inputs.empty()) {
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR);
if (!IsNeedConnectInputOpForSingleOp(*input_desc)) {
continue;
}
GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false));
arg_index++;
}
} else {
for (const auto &in_desc : inputs) {
GeTensorDesc input_desc = in_desc.GetTensorDesc();
GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true));
arg_index++;
}
Graph graph;
if (BuildSingleOpGraph(op_desc, inputs, outputs, name, graph) != ge::SUCCESS) {
GELOGE(GRAPH_FAILED, "make graph fail.");
return GRAPH_FAILED;
} }

// 5. Create Output node.
if (!outputs.empty()) {
GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs));
}

// dump ComputeGraph.
compute_graph->Dump();
Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
GELOGI("ATC parser success in single op build."); GELOGI("ATC parser success in single op build.");


GeRootModelPtr ge_root_model = nullptr; GeRootModelPtr ge_root_model = nullptr;
@@ -683,6 +664,46 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor
return BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false); return BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false);
} }


/// Builds a Graph that wraps a single operator together with its input and output nodes.
///
/// @param op_desc     descriptor of the operator that becomes the graph's op node
/// @param inputs      input tensors; when empty, the operator's own input descs are used instead
/// @param outputs     output tensors; output nodes are added only when this is non-empty
/// @param graph_name  name given to the created ComputeGraph
/// @param graph       receives the Graph created from the ComputeGraph
/// @return SUCCESS once the graph is assembled; INTERNAL_ERROR when the ComputeGraph, the op node
///         or one of the operator's input descs is null. Failures of AddInputs / AddOutputs are
///         handled by GE_CHK_STATUS_RET_NOLOG (presumably returning their status - confirm
///         against the macro definition).
Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
                                       const vector<GeTensor> &outputs, std::string graph_name, Graph &graph) {
  // Create the ComputeGraph and place the operator in it.
  ge::ComputeGraphPtr single_op_graph = MakeShared<ComputeGraph>(graph_name);
  GE_CHECK_NOTNULL_EXEC(single_op_graph, return INTERNAL_ERROR);

  NodePtr single_op_node = single_op_graph->AddNode(op_desc);
  GE_CHECK_NOTNULL_EXEC(single_op_node, return INTERNAL_ERROR);

  // Attach one input node per tensor; input_index counts only the inputs actually attached.
  int32_t input_index = 0;
  if (!inputs.empty()) {
    // Caller supplied tensors: use a copy of each tensor's desc.
    for (const auto &input_tensor : inputs) {
      GeTensorDesc tensor_desc = input_tensor.GetTensorDesc();
      GE_CHK_STATUS_RET_NOLOG(AddInputs(single_op_graph, single_op_node, tensor_desc, input_index, true));
      ++input_index;
    }
  } else {
    // No tensors supplied: fall back to the operator's own input descs, skipping those
    // IsNeedConnectInputOpForSingleOp rejects.
    for (const auto &op_input_desc : op_desc->GetAllInputsDescPtr()) {
      GE_CHECK_NOTNULL_EXEC(op_input_desc, return INTERNAL_ERROR);
      if (IsNeedConnectInputOpForSingleOp(*op_input_desc)) {
        GE_CHK_STATUS_RET_NOLOG(AddInputs(single_op_graph, single_op_node, *op_input_desc, input_index, false));
        ++input_index;
      }
    }
  }

  // Attach output nodes only when outputs were requested.
  if (!outputs.empty()) {
    GE_CHK_STATUS_RET_NOLOG(AddOutputs(single_op_graph, single_op_node, outputs));
  }

  // Dump the assembled ComputeGraph, then hand it back as a Graph.
  single_op_graph->Dump();
  graph = ge::GraphUtils::CreateGraphFromComputeGraph(single_op_graph);

  return SUCCESS;
}

Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs, Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) { const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) {
GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID); GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID);
@@ -712,6 +733,44 @@ Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &
return SUCCESS; return SUCCESS;
} }


/// Saves a root model (and its sub models) through ModelHelper::SaveToOmRootModel.
///
/// For an unknown-shape root model a fresh GeModel is created for the root graph and registered
/// under the root graph's name; otherwise the first registered sub model is used as the root.
/// ATC and OPP version info are then stamped on that root GeModel before saving.
///
/// @param file_name_prefix  prefix passed on to SaveToOmRootModel
/// @param ge_root_model     root model to save; must be non-null with a non-null root graph
///                          (it is dereferenced here without a check)
/// @param model_buff        buffer passed on to SaveToOmRootModel
/// @return SUCCESS on success; FAILED when the unknown-shape check fails or there is no sub model;
///         an error from GE_CHECK_NOTNULL when the chosen root GeModel is null; otherwise the
///         status returned by SaveToOmRootModel.
Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootModelPtr &ge_root_model,
                                        ModelBufferData &model_buff) {
  bool is_unknown_shape = false;
  auto ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  if (ret != SUCCESS) {
    GELOGE(FAILED, "Check root model is unkonwn shape failed");
    return FAILED;
  }
  GELOGD("begin save root model, cur model is unkonwn shape model ? : %d", is_unknown_shape);
  GE_CHK_BOOL_EXEC(!ge_root_model->GetSubgraphInstanceNameToModel().empty(), return FAILED,
                   "ge root model has no sub model");
  GeModelPtr model_root = nullptr;
  if (is_unknown_shape) {
    // Unknown shape: create a GeModel for the root graph and register it under the graph's name.
    model_root = make_shared<GeModel>();
    model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph()));
    ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root);
    model_root->SetName(ge_root_model->GetRootGraph()->GetName());
  } else {
    // Known shape: the first registered sub model acts as the root model.
    model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second;
  }
  // A registered sub model pointer may be null; it is dereferenced below, so bail out early.
  GE_CHECK_NOTNULL(model_root);
  // set atc version (failure is only a warning, saving continues)
  if (!SetAtcVersionInfo(*(model_root.get()))) {
    GELOGW("SetPackageVersionInfo of atc failed!");
  }
  // set opp version (failure is only a warning, saving continues)
  if (!SetOppVersionInfo(*(model_root.get()))) {
    GELOGW("SetPackageVersionInfo of ops failed!");
  }
  ModelHelper model_helper;
  model_helper.SetSaveMode(is_offline_);
  ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape);
  if (ret != SUCCESS) {
    GELOGE(ret, "Save to om model failed");
    return ret;
  }
  return SUCCESS;
}

Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs, Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs,
GeRootModelPtr &ge_root_model) { GeRootModelPtr &ge_root_model) {
static std::atomic<GraphId> atomic_graph_id(0); static std::atomic<GraphId> atomic_graph_id(0);


+ 4
- 4
ge/graph/build/graph_builder.cc View File

@@ -349,7 +349,8 @@ static Status GenerateTaskForConstant(const std::shared_ptr<ComputeGraph> &graph
GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str());
std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy";
if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) {
GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str());
GELOGE(FAILED, "Insert memcpy between %s and %s failed.",
in_node->GetName().c_str(), node->GetName().c_str());
return FAILED; return FAILED;
} }
} }
@@ -475,7 +476,7 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr
} }


Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) {
// set input_desc.size = src_node.output_desc.size
// Set the size of input_desc to 'src_node.output_desc.size'
if (node_ptr->GetType() == DATA) { if (node_ptr->GetType() == DATA) {
bool is_unknown_shape = false; bool is_unknown_shape = false;
GE_CHK_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node_ptr, is_unknown_shape), GE_CHK_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node_ptr, is_unknown_shape),
@@ -498,7 +499,7 @@ Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) {
GE_IF_BOOL_EXEC(src_op == nullptr, continue); GE_IF_BOOL_EXEC(src_op == nullptr, continue);
auto node_op_desc = node_ptr->GetOpDesc(); auto node_op_desc = node_ptr->GetOpDesc();
GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue);
// set dst_node.input_desc = src_node.output_desc
// Set the input_desc of dst_node to 'src_node.output_desc'
auto output_desc = src_op->GetOutputDescPtr(peer_out_anchor->GetIdx()); auto output_desc = src_op->GetOutputDescPtr(peer_out_anchor->GetIdx());
int64_t size = 0; int64_t size = 0;
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_desc, size) != SUCCESS, GELOGI("Get size failed!")); GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_desc, size) != SUCCESS, GELOGI("Get size failed!"));
@@ -512,7 +513,6 @@ Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) {
auto input_desc = node_op_desc->MutableInputDesc(in_data_anchor->GetIdx()); auto input_desc = node_op_desc->MutableInputDesc(in_data_anchor->GetIdx());
GE_CHECK_NOTNULL(input_desc); GE_CHECK_NOTNULL(input_desc);
(void) ge::TensorUtils::SetSize(*input_desc, size); (void) ge::TensorUtils::SetSize(*input_desc, size);
GE_CHK_STATUS_RET(node_op_desc->UpdateInputDesc(in_data_anchor->GetIdx(), *input_desc));
GELOGD("%s input desc, dim_size: %zu, mem_size: %ld, format: %s, type: %s.", node_ptr->GetName().c_str(), GELOGD("%s input desc, dim_size: %zu, mem_size: %ld, format: %s, type: %s.", node_ptr->GetName().c_str(),
input_desc->GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(), input_desc->GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(),
TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str()); TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str());


+ 7
- 5
ge/graph/build/memory/binary_block_mem_assigner.cc View File

@@ -21,8 +21,8 @@
namespace { namespace {
const uint32_t kRangeCeilInterval = 2; const uint32_t kRangeCeilInterval = 2;
const uint32_t kLogBase = 2; const uint32_t kLogBase = 2;
const int64_t kLargeBlockSize = 8 * 1024 * 1024;
const int64_t kLargeBlockRangeSize = 10;
const int64_t kLargeBlockSize = 8388608; // 8 * 1024 * 1024
const int64_t kLargeBlockRangeSize = 2;
} // namespace } // namespace


namespace ge { namespace ge {
@@ -73,15 +73,17 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector<int64_t> &range_ceils) {
GELOGE(FAILED, "dividend is 0!"); GELOGE(FAILED, "dividend is 0!");
return FAILED; return FAILED;
} }
// Memory size is 512 aligned, so it is not necessary to take less than 512
int64_t min_memory_size = (all_memory_size.back() > MEM_ALIGN_SIZE) ? MEM_ALIGN_SIZE : all_memory_size.front();
auto range_number = static_cast<size_t>( auto range_number = static_cast<size_t>(
ceil(log(all_memory_size.back() / static_cast<double>(all_memory_size.front())) / log(kLogBase)));
ceil(log(all_memory_size.back() / static_cast<double>(min_memory_size)) / log(kLogBase)));
range_number = (range_number == 0) ? 1 : range_number; range_number = (range_number == 0) ? 1 : range_number;
GELOGD("Range number: %zu", range_number); GELOGD("Range number: %zu", range_number);


vector<vector<int64_t>> ranges(range_number); vector<vector<int64_t>> ranges(range_number);
GE_CHK_BOOL_EXEC((range_number != 0), return PARAM_INVALID, "range_number can't be 0."); GE_CHK_BOOL_EXEC((range_number != 0), return PARAM_INVALID, "range_number can't be 0.");
size_t range_number_limit = all_memory_size.size() / range_number; size_t range_number_limit = all_memory_size.size() / range_number;
int64_t range_ceil = all_memory_size[0];
int64_t range_ceil = min_memory_size;
for (size_t i = 1; i <= range_number; i++) { for (size_t i = 1; i <= range_number; i++) {
GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(static_cast<uint64_t>(range_ceil), kRangeCeilInterval), GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(static_cast<uint64_t>(range_ceil), kRangeCeilInterval),
GELOGE(FAILED, "Multiply result is out of range."); GELOGE(FAILED, "Multiply result is out of range.");
@@ -114,7 +116,7 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector<int64_t> &range_ceils) {
range_ceils.push_back(range.back()); range_ceils.push_back(range.back());
} }
} }
GELOGD("Range ceils: %s", ToString(range_ceils).c_str());
GELOGI("Range ceils: %s", ToString(range_ceils).c_str());


return SUCCESS; return SUCCESS;
} }


+ 328
- 194
ge/graph/build/memory/block_mem_assigner.cc View File

@@ -65,6 +65,98 @@ void AlignMemOffset(size_t &mem_align_size) {
mem_align_size = (mem_align_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; mem_align_size = (mem_align_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE;
} }


///
/// @brief Ordering predicate for NodeTypeIndex entries, ascending by the id of the node's op desc.
///        It is passed to std::sort by CanIntervalLifeReuse, so it must be a strict weak ordering.
/// @param [in] left   first entry
/// @param [in] right  second entry
/// @return true when left must be placed before right
///
static bool CompareLifeTime(const NodeTypeIndex &left, const NodeTypeIndex &right) {
  // An entry without a node or without an op desc has no id to compare.
  // All such entries are treated as equivalent to each other and ordered before every entry that has an id,
  // which keeps the relation transitive (returning false for them unconditionally would not).
  const bool left_has_id = (left.node != nullptr) && (left.node->GetOpDesc() != nullptr);
  const bool right_has_id = (right.node != nullptr) && (right.node->GetOpDesc() != nullptr);
  if (!left_has_id || !right_has_id) {
    return !left_has_id && right_has_id;
  }
  return left.node->GetOpDesc()->GetId() < right.node->GetOpDesc()->GetId();
}

///
/// @brief Append the node entries of block (and, when requested, of all nested child blocks) to life_list.
/// @param [in] block          block whose NodeTypeIndexList() is collected
/// @param [in&out] life_list  destination list, entries are appended
/// @param [in] child          true to descend into ChildBlockList() recursively
/// @return false as soon as a child block is on another stream or either side is not marked same_stream_
///
static bool CollectLifeList(const MemoryBlock &block, std::vector<NodeTypeIndex> &life_list, bool child) {
  for (const auto &node : block.NodeTypeIndexList()) {
    life_list.emplace_back(node);
  }
  if (!child) {
    return true;
  }

  for (auto child_block : block.ChildBlockList()) {
    if (child_block == nullptr) {
      continue;
    }
    if (block.stream_id_ != child_block->stream_id_ || !block.same_stream_ || !child_block->same_stream_) {
      return false;
    }
    // propagate a failure from any depth instead of continuing with the remaining children
    if (!CollectLifeList(*child_block, life_list, child)) {
      return false;
    }
  }
  return true;
}

///
/// @brief Collect the node entries of block into life_list; with child == true the child blocks are
///        collected recursively as well.
///        If any visited child block cannot be judged by stream (different stream id, or a block not
///        marked same_stream_), life_list is left EMPTY so the caller can detect the failure.
///        The failure is reported from every nesting depth; previously a failure in a nested call cleared
///        the list but the outer loop kept appending the following children, leaving a partial list.
/// @param [in] block          block to collect from
/// @param [in&out] life_list  destination list; cleared completely on failure
/// @param [in] child          true to include child blocks
///
void GetLifeList(const MemoryBlock &block, std::vector<NodeTypeIndex> &life_list, bool child) {
  if (!CollectLifeList(block, life_list, child)) {
    life_list.clear();
  }
}

///
/// @brief Check whether the life times of two node entries overlap.
///        The entry with the smaller op desc id is taken as the earlier one; the two overlap when its
///        life_time_end reaches (>=) the id of the later entry. Equal ids always overlap.
/// @param [in] left   first entry
/// @param [in] right  second entry
/// @return true when the life times overlap or a node pointer is null;
///         false when they do not overlap or when either op desc is null
///
bool CrossLifeTime(const NodeTypeIndex &left, const NodeTypeIndex &right) {
  // without a node nothing can be proven, report an overlap
  if ((left.node == nullptr) || (right.node == nullptr)) {
    return true;
  }

  const auto lhs_desc = left.node->GetOpDesc();
  const auto rhs_desc = right.node->GetOpDesc();
  if ((lhs_desc == nullptr) || (rhs_desc == nullptr)) {
    return false;
  }

  const auto lhs_id = lhs_desc->GetId();
  const auto rhs_id = rhs_desc->GetId();
  if (lhs_id == rhs_id) {
    return true;
  }
  if (lhs_id < rhs_id) {
    // left starts first: it must have ended before right starts
    return left.life_time_end >= static_cast<size_t>(rhs_id);
  }
  // right starts first: it must have ended before left starts
  return right.life_time_end >= static_cast<size_t>(lhs_id);
}

///
/// @brief Judge whether child_block can reuse the memory of parent_block by interval life time.
///        When the life times of the child block (including its own child blocks) do not cross the
///        life times of the parent block, the memory can be reused. Only blocks of the same stream
///        can be judged this way.
/// |-----------------------------parent block---------------------|
/// |------child block1--------------||------child block2------|
/// |--child block1-1-|
/// @param [in] parent_block  block that would hold the child
/// @param [in] child_block   block that would be placed into the parent
/// @return true when no life times cross; false when they cross or when it cannot be judged
///
bool CanIntervalLifeReuse(MemoryBlock &parent_block, MemoryBlock &child_block) {
  // judge by interval life time, only same stream can be judged by interval life time
  const bool same_stream = (parent_block.stream_id_ == child_block.stream_id_) &&
                           parent_block.same_stream_ && child_block.same_stream_;
  if (!same_stream || parent_block.NodeTypeIndexList().empty() || child_block.NodeTypeIndexList().empty()) {
    return false;
  }

  // quick judge by front and back node
  if (CrossLifeTime(parent_block.NodeTypeIndexList().front(), child_block.NodeTypeIndexList().front())) {
    return false;
  }
  if (CrossLifeTime(parent_block.NodeTypeIndexList().back(), child_block.NodeTypeIndexList().back())) {
    return false;
  }

  std::vector<NodeTypeIndex> life_list;
  GetLifeList(parent_block, life_list, false);
  GetLifeList(child_block, life_list, true);
  // GetLifeList clears the list when a child block can not be judged by stream
  if (life_list.empty()) {
    return false;
  }

  // Only the front/back entries were null-checked above (inside CrossLifeTime). Any other entry
  // without a node can not be ordered or judged, so treat it like CrossLifeTime does: as crossing.
  for (const auto &node : life_list) {
    if (node.node == nullptr) {
      return false;
    }
  }

  std::sort(life_list.begin(), life_list.end(), CompareLifeTime);
  // walk the entries in ascending id order; each one must start after the previous one ended
  size_t pre_life_end = 0;
  for (const auto &node : life_list) {
    const auto node_op_desc = node.node->GetOpDesc();
    if (node_op_desc != nullptr && pre_life_end >= static_cast<size_t>(node_op_desc->GetId())) {
      // life time cross
      return false;
    }
    pre_life_end = node.life_time_end;
  }
  GELOGI("Block size[%zu, %zu] life time are not cross.", parent_block.Size(), child_block.Size());
  return true;
}

void MemoryBlock::SetHeadOffset(size_t offset) { void MemoryBlock::SetHeadOffset(size_t offset) {
head_offset_ = offset; head_offset_ = offset;
size_t child_offset = head_offset_; size_t child_offset = head_offset_;
@@ -125,20 +217,12 @@ size_t MemoryBlock::AlignSize() const {
return align_block_size; return align_block_size;
} }


bool MemoryBlock::IsSameLabel(std::string &first_batch_label) {
if (node_type_index_list_.empty()) {
bool MemoryBlock::IsSameBatchLabel() {
// only same batch label can reuse
if (batch_label_.empty() || node_type_index_list_.empty()) {
return false; return false;
} }


auto node_op_desc = node_type_index_list_[0].node->GetOpDesc();
if (node_op_desc == nullptr) {
return false;
}
// not all op has ATTR_NAME_BATCH_LABEL, no need check return value, only check out parameter
(void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, first_batch_label);
if (first_batch_label.empty()) {
return false;
}
bool all_same_label = true; bool all_same_label = true;
for (size_t index = 1; index < node_type_index_list_.size(); ++index) { for (size_t index = 1; index < node_type_index_list_.size(); ++index) {
if (node_type_index_list_[index].node == nullptr) { if (node_type_index_list_[index].node == nullptr) {
@@ -147,8 +231,9 @@ bool MemoryBlock::IsSameLabel(std::string &first_batch_label) {
std::string batch_label; std::string batch_label;
auto index_op_desc = node_type_index_list_[index].node->GetOpDesc(); auto index_op_desc = node_type_index_list_[index].node->GetOpDesc();
GE_IF_BOOL_EXEC(index_op_desc == nullptr, continue); GE_IF_BOOL_EXEC(index_op_desc == nullptr, continue);
// not all op has ATTR_NAME_BATCH_LABEL, no need check return value, only check out parameter
(void)ge::AttrUtils::GetStr(index_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); (void)ge::AttrUtils::GetStr(index_op_desc, ATTR_NAME_BATCH_LABEL, batch_label);
if (first_batch_label != batch_label) {
if (batch_label_ != batch_label) {
all_same_label = false; all_same_label = false;
break; break;
} }
@@ -197,7 +282,7 @@ void MemoryBlock::AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLi
} }


void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life) { void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life) {
if (CanNotLifeReuse(this) || CanNotLifeReuse(block)) {
if (CanNotLifeReuse(this) || CanNotLifeReuse(block) || (batch_label_ != block->batch_label_)) {
return; return;
} }
if (block->continuous_block_) { if (block->continuous_block_) {
@@ -207,16 +292,27 @@ void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_
MemoryBlock *parent = nullptr; MemoryBlock *parent = nullptr;
MemoryBlock *child = nullptr; MemoryBlock *child = nullptr;
// merge small block to large block // merge small block to large block
if (block->GetDependLifeBegin(stream_id_, total_node_depend_stream_life) > GetLifeEnd()) {
if ((child_offset_ + block->AlignSize()) <= AlignSize()) {
parent = this;
child = block;
} else if ((block->child_offset_ + AlignSize()) <= block->AlignSize()) {
parent = block;
child = this;
// noalign size 802816 + 802816 = 1605632 can reuse
// after 32 align size 802848 + 802848 > 1605664 can't reuse
// after 512 align size 803328 + 803328 > 1606144 can't reuse
// so 803328 + 803328 = 1606144 + 512 can reuse
if ((child_offset_ + block->AlignSize()) <= (AlignSize() + MEM_ALIGN_SIZE)) {
parent = this;
child = block;
} else if ((block->child_offset_ + AlignSize()) <= (block->AlignSize() + MEM_ALIGN_SIZE)) {
parent = block;
child = this;
}

if ((parent != nullptr) && (child != nullptr)) {
// Different streams must use stream dependency to judge the life cycle
// In case same stream if it has child block, can judge all the child block's life time in CanIntervalLifeReuse
bool can_block_life_reuse = (child->child_blocks_.empty()
&& (block->GetDependLifeBegin(stream_id_, total_node_depend_stream_life) > GetLifeEnd()));
if (!can_block_life_reuse && !CanIntervalLifeReuse(*parent, *child)) {
return;
} }
}
if ((parent != nullptr) && (child != nullptr) && child->child_blocks_.empty()) {

parent->child_blocks_.emplace_back(child); parent->child_blocks_.emplace_back(child);
parent->child_offset_ += child->AlignSize(); parent->child_offset_ += child->AlignSize();
child->deleted_block_ = true; child->deleted_block_ = true;
@@ -261,6 +357,7 @@ size_t MemoryBlock::GetDependLifeBegin(int64_t stream_id, DependStreamLife &tota
void AddDependLife(const ge::NodePtr &org_node, const ge::NodePtr &node, int64_t stream_id, void AddDependLife(const ge::NodePtr &org_node, const ge::NodePtr &node, int64_t stream_id,
std::map<int64_t, size_t> &depend_stream_life, DependStreamLife &total_node_depend_stream_life) { std::map<int64_t, size_t> &depend_stream_life, DependStreamLife &total_node_depend_stream_life) {
GE_CHECK_NOTNULL_EXEC(node, return); GE_CHECK_NOTNULL_EXEC(node, return);
GE_CHECK_NOTNULL_EXEC(org_node, return);
auto node_desc = node->GetOpDesc(); auto node_desc = node->GetOpDesc();
GE_CHECK_NOTNULL_EXEC(node_desc, return); GE_CHECK_NOTNULL_EXEC(node_desc, return);
auto node_id = node_desc->GetId(); auto node_id = node_desc->GetId();
@@ -415,12 +512,60 @@ BlockMemAssigner::~BlockMemAssigner() {
} }
} }


///
/// @brief Build the sorted memory-size list used for the reuse ranges.
///        The labelled batch with the largest total size is selected; the result contains the sizes of
///        that batch plus the sizes recorded under the empty label (no batch label).
/// @param [in] batch_all_memory_size  batch label -> all memory sizes recorded for that label
/// @param [in] batch_total_size       batch label -> sum of the sizes recorded for that label
/// @param [in&out] all_memory_size    selected sizes are appended, then the list is sorted ascending
///                                    and zero sizes are dropped; never empty on return
/// @param [out] max_batch_label       label of the largest labelled batch; left unchanged when no
///                                    labelled batch has a total size above 0
///
// NOTE(review): batch_total_size is taken by value (one map copy per call); it could be a const
// reference if no other declaration of this function depends on the current signature - confirm.
void GetMaxBatchAllMemorySize(std::map<std::string, vector<int64_t>> &batch_all_memory_size,
                              std::map<std::string, int64_t> batch_total_size, vector<int64_t> &all_memory_size,
                              std::string &max_batch_label) {
  // use max batch all memory size for reuse range
  int64_t max_batch_size = 0;
  for (const auto &it : batch_total_size) {
    GELOGI("Batch[%s] total memory size[%ld]", it.first.c_str(), it.second);
    // no batch label
    if (it.first.empty()) {
      continue;
    }
    if (it.second > max_batch_size) {
      max_batch_size = it.second;
      max_batch_label = it.first;
    }
  }
  GELOGI("Max batch[%s] total memory size[%ld]", max_batch_label.c_str(), max_batch_size);

  for (const auto &it : batch_all_memory_size) {
    if (it.first.empty() || (it.first == max_batch_label)) {
      all_memory_size.insert(all_memory_size.end(), it.second.begin(), it.second.end());
    }
  }
  std::sort(all_memory_size.begin(), all_memory_size.end());
  GELOGD("All memory size: %s", ToString(all_memory_size).c_str());

  // drop zero sizes in a single pass
  all_memory_size.erase(std::remove(all_memory_size.begin(), all_memory_size.end(), 0), all_memory_size.end());

  // all_memory_size can't be empty; this has to be checked after the zero sizes are removed,
  // otherwise a list that only held zeros would still end up empty
  if (all_memory_size.empty()) {
    all_memory_size.emplace_back(MEM_ALIGN_SIZE);
  }
}

void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) { void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) {
vector<int64_t> temp; vector<int64_t> temp;
std::map<std::string, vector<int64_t>> batch_all_memory_size;
std::map<std::string, int64_t> batch_total_size;
for (const NodePtr &n : compute_graph_->GetAllNodes()) { for (const NodePtr &n : compute_graph_->GetAllNodes()) {
auto node_op_desc = n->GetOpDesc(); auto node_op_desc = n->GetOpDesc();
GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue);


if (CheckIsZeroMemNodeType(node_op_desc->GetType())) {
continue;
}

std::string batch_label;
(void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, batch_label);

if (node_op_desc->GetType() == ATOMICADDRCLEAN) { if (node_op_desc->GetType() == ATOMICADDRCLEAN) {
atomic_addr_clean_id_ = node_op_desc->GetId(); atomic_addr_clean_id_ = node_op_desc->GetId();
} }
@@ -434,9 +579,14 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) {
if (!reuse_input) { if (!reuse_input) {
int64_t size = 0; int64_t size = 0;
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(output_desc, size) != SUCCESS, GELOGI("Get size failed")); GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(output_desc, size) != SUCCESS, GELOGI("Get size failed"));
if (anchor_to_symbol_.empty()) {
all_memory_size.emplace_back(size);
batch_all_memory_size[batch_label].emplace_back(size);
if (batch_total_size.find(batch_label) == batch_total_size.end()) {
batch_total_size[batch_label] = size;
} else { } else {
batch_total_size[batch_label] += size;
}

if (!anchor_to_symbol_.empty()) {
auto iter1 = anchor_to_symbol_.find(NodeIndexIO(n, out_anchor->GetIdx(), kOut).ToString()); auto iter1 = anchor_to_symbol_.find(NodeIndexIO(n, out_anchor->GetIdx(), kOut).ToString());
if (iter1 == anchor_to_symbol_.end()) { if (iter1 == anchor_to_symbol_.end()) {
continue; continue;
@@ -452,23 +602,11 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) {
} }
} }
temp.clear(); temp.clear();
GetNodeWorkSpaceSize(n, temp);
all_memory_size.insert(all_memory_size.end(), temp.begin(), temp.end());
}
for (const auto &pair : symbol_size_) {
all_memory_size.emplace_back(pair.second);
}
sort(all_memory_size.begin(), all_memory_size.end());
GELOGD("All memory size: %s", ToString(all_memory_size).c_str());

for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) {
if (*iter == 0) {
iter = all_memory_size.erase(iter);
} else {
++iter;
}
GetNodeWorkSpaceSize(n, temp, batch_total_size[batch_label]);
batch_all_memory_size[batch_label].insert(batch_all_memory_size[batch_label].end(), temp.begin(), temp.end());
} }

GELOGI("The last atomic_addr_clean node id: %ld", atomic_addr_clean_id_);
GetMaxBatchAllMemorySize(batch_all_memory_size, batch_total_size, all_memory_size, max_batch_label_);
InitReuseFlag(); InitReuseFlag();
PrintSymbolMap(); PrintSymbolMap();
} }
@@ -529,16 +667,6 @@ bool CanReuseBySize(const map<string, uint64_t> &reusable_block_counts, const Me
bool can_reuse = false; bool can_reuse = false;
if (reusable_block.Size() == block_size) { if (reusable_block.Size() == block_size) {
can_reuse = true; can_reuse = true;
} else {
string key = std::to_string(reusable_block.Size());
key += "_" + std::to_string(reusable_block.stream_id_);
key += "_" + std::to_string(reusable_block.memory_type_);
auto it = reusable_block_counts.find(key);
GE_IF_BOOL_EXEC((it != reusable_block_counts.end() && (it->second > kReuseMaxCount)) &&
(reusable_block.Size() > block_size),
can_reuse = true;
GELOGD("Less size mem reuse, reuse block size:%zu, current block size:%zu",
reusable_block.Size(), block_size););
} }
return can_reuse; return can_reuse;
} }
@@ -860,34 +988,35 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size,
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "Input parameter n is null."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "Input parameter n is null.");
auto node_op_desc = n->GetOpDesc(); auto node_op_desc = n->GetOpDesc();
GE_IF_BOOL_EXEC(node_op_desc == nullptr, return nullptr); GE_IF_BOOL_EXEC(node_op_desc == nullptr, return nullptr);
std::string batch_label;
(void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, batch_label);
if (batch_label.empty() || (batch_label == max_batch_label_)) {
size_t align_size = real_size;
AlignMemOffset(align_size);
theory_memory_size_ += align_size;
if (theory_memory_size_ > theory_min_memory_size_) {
theory_min_memory_size_ = theory_memory_size_;
}
}


bool is_reuse_memory = false; bool is_reuse_memory = false;
string ge_disable_reuse_mem_env = "0";
(void)ge::GetContext().GetOption(OPTION_EXEC_DISABLE_REUSED_MEMORY, ge_disable_reuse_mem_env);
if (ge_disable_reuse_mem_env != "1") {
if (ge_disable_reuse_mem_env_ != "1") {
bool reuse_mem_flag = (mem_type == kOutput) ? IsPreReuse(n, out_index) : bool reuse_mem_flag = (mem_type == kOutput) ? IsPreReuse(n, out_index) :
!((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]); !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]);
is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) &&
!node_op_desc->HasAttr(kOpNoReuseMem) && reuse_mem_flag && is_op_reuse_mem; !node_op_desc->HasAttr(kOpNoReuseMem) && reuse_mem_flag && is_op_reuse_mem;
auto stream_id = node_op_desc->GetStreamId();
if (is_reuse_memory && !continuous && !reusable_blocks_[memory_type].empty()) {
for (auto it = reusable_blocks_[memory_type][stream_id].begin();
it != reusable_blocks_[memory_type][stream_id].end(); ++it) {
bool do_reuse = is_reuse_memory && !continuous && !reusable_blocks_[memory_type].empty();
if (do_reuse) {
auto stream_id = node_op_desc->GetStreamId();
for (auto it = reusable_blocks_[memory_type][stream_id].rbegin();
it != reusable_blocks_[memory_type][stream_id].rend(); ++it) {
MemoryBlock *reusable_block = *it; MemoryBlock *reusable_block = *it;
if (!IsPostReuse(reusable_block)) { if (!IsPostReuse(reusable_block)) {
reusable_block->reuse_mem_ = false; reusable_block->reuse_mem_ = false;
GELOGI("Unreusable block."); GELOGI("Unreusable block.");
continue; continue;
} }
std::string batch_label;
if (reusable_block->IsSameLabel(batch_label)) {
std::string op_label;
(void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, op_label);
if (batch_label != op_label) {
GELOGI("label diff, op name %s", node_op_desc->GetName().c_str());
continue;
}
}
GE_IF_BOOL_EXEC(reusable_block->batch_label_ != batch_label, continue);


// A node can reuse blocks of the same stream and preorder streams // A node can reuse blocks of the same stream and preorder streams
if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous)) { if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous)) {
@@ -901,7 +1030,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size,
reusable_block->continuous_block_ = continuous; reusable_block->continuous_block_ = continuous;
reusable_block->ref_count_++; reusable_block->ref_count_++;
ReduceReusableBlockCount(*reusable_block, reusable_block_counts_); ReduceReusableBlockCount(*reusable_block, reusable_block_counts_);
reusable_blocks_[memory_type][stream_id].erase(it);
reusable_blocks_[memory_type][stream_id].erase((++it).base());
return reusable_block; return reusable_block;
} }
} }
@@ -914,10 +1043,11 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size,
// Data and netoutput need zero copy block // Data and netoutput need zero copy block
block->is_zero_copy_ = IsZeroCopyBlock(n, continuous); block->is_zero_copy_ = IsZeroCopyBlock(n, continuous);


block->Init(real_size, mem_type, n, out_index, no_align_size);
block->Init(real_size, mem_type, n, out_index, no_align_size, node_op_desc->GetStreamId());
block->stream_id_ = node_op_desc->GetStreamId(); block->stream_id_ = node_op_desc->GetStreamId();
block->ref_count_++; block->ref_count_++;
block->continuous_block_ = continuous; block->continuous_block_ = continuous;
block->batch_label_ = batch_label;
if (mem_type == kOutput) { if (mem_type == kOutput) {
auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString()); auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString());
if (iter != anchor_to_symbol_.end()) { if (iter != anchor_to_symbol_.end()) {
@@ -945,6 +1075,11 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec
return nullptr; return nullptr;
} }


if (CheckIsZeroMemNodeType(n->GetType())) {
zero_memory_list_.emplace_back(n, kOutput, index);
continue;
}

int64_t size = 0; int64_t size = 0;
if (ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS) { if (ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS) {
GELOGI("Get size failed"); GELOGI("Get size failed");
@@ -957,9 +1092,7 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec
// only apply total size in first block // only apply total size in first block
if (index != 0) { if (index != 0) {
zero_memory_list_.emplace_back(n, kOutput, index); zero_memory_list_.emplace_back(n, kOutput, index);
}

if (index == 0) {
} else {
NodeIndexIO node_index_io(n, index, kOut); NodeIndexIO node_index_io(n, index, kOut);
auto iter = anchor_to_symbol_.find(node_index_io.ToString()); auto iter = anchor_to_symbol_.find(node_index_io.ToString());
if (iter != anchor_to_symbol_.end()) { if (iter != anchor_to_symbol_.end()) {
@@ -972,6 +1105,10 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec
} }
} }


if (total_size == 0) {
return nullptr;
}

auto block_size = GetBlockSize(total_size, ranges); auto block_size = GetBlockSize(total_size, ranges);
GELOGI("Node[%s] continuous out memory size[%ld] block size[%zu]", node_op_desc->GetName().c_str(), GELOGI("Node[%s] continuous out memory size[%ld] block size[%zu]", node_op_desc->GetName().c_str(),
total_size, block_size); total_size, block_size);
@@ -1119,15 +1256,28 @@ bool IsKnownSubgraphData(const NodePtr &node) {
return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX);
} }


void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory) {
void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory,
bool same_stream) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(to_release == nullptr, return, "Input parameter to_release is null."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(to_release == nullptr, return, "Input parameter to_release is null.");
GE_CHK_TRUE_EXEC_INFO(to_release->ref_count_ <= 0, return, "Release memory"); GE_CHK_TRUE_EXEC_INFO(to_release->ref_count_ <= 0, return, "Release memory");
GE_CHK_TRUE_EXEC_INFO(!to_release->reuse_mem_, return, "doesn't reuse memory"); GE_CHK_TRUE_EXEC_INFO(!to_release->reuse_mem_, return, "doesn't reuse memory");
--to_release->ref_count_; --to_release->ref_count_;
if (!same_stream) {
to_release->same_stream_ = false;
}
if (to_release->ref_count_ == 0) { if (to_release->ref_count_ == 0) {
to_release->SetLifeTimeEnd(life_time_);
reusable_memory.emplace_back(to_release);
AddReusableBlockCount(*to_release, reusable_block_counts_);
if (to_release->reuse_mem_ && !to_release->RealSizeList().empty()) {
if (to_release->batch_label_.empty() || (to_release->batch_label_ == max_batch_label_)) {
size_t align_size = to_release->RealSizeList().back();
AlignMemOffset(align_size);
theory_memory_size_ -= align_size;
}
}
if (to_release->same_stream_) {
to_release->SetLifeTimeEnd(life_time_);
reusable_memory.emplace_back(to_release);
AddReusableBlockCount(*to_release, reusable_block_counts_);
}
} }
} }


@@ -1167,10 +1317,9 @@ void BlockMemAssigner::ReleaseInputNodeOutMemory(const unordered_map<string, vec
node_type_indexs.back().node->GetName().c_str()); node_type_indexs.back().node->GetName().c_str());


if ((node_type_indexs.back().node == in_anchor->GetPeerOutAnchor()->GetOwnerNode()) && if ((node_type_indexs.back().node == in_anchor->GetPeerOutAnchor()->GetOwnerNode()) &&
(node_type_indexs.back().index == static_cast<uint32_t>(in_anchor->GetPeerOutAnchor()->GetIdx())) &&
(node->GetOpDesc()->GetStreamId() == block->stream_id_)) {
ReleaseMemory(block, reusable_memory);
if (block->ref_count_ == 0) {
(node_type_indexs.back().index == static_cast<uint32_t>(in_anchor->GetPeerOutAnchor()->GetIdx()))) {
ReleaseMemory(block, reusable_memory, (node->GetOpDesc()->GetStreamId() == block->stream_id_));
if (block->ref_count_ == 0 && block->same_stream_) {
SetLastUsedInputMemAttr(node, in_anchor->GetIdx()); SetLastUsedInputMemAttr(node, in_anchor->GetIdx());
} }
} }
@@ -1267,7 +1416,8 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector
bool no_need_assign_memory = ((size == 0) || CheckIsZeroMemNodeType(node->GetType())); bool no_need_assign_memory = ((size == 0) || CheckIsZeroMemNodeType(node->GetType()));
if (!no_need_assign_memory) { if (!no_need_assign_memory) {
out_node_set_continuous_input = out_node_set_continuous_input =
IsOutNodeSetContinuousInput(node, i, peer_name, peer_input_index, no_need_assign_memory, reset_zero_copy_flag);
IsOutNodeSetContinuousInput(node, i, peer_name, peer_input_index,
no_need_assign_memory, reset_zero_copy_flag);
GE_IF_BOOL_EXEC(!no_need_assign_memory, GE_IF_BOOL_EXEC(!no_need_assign_memory,
no_need_assign_memory = IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input);); no_need_assign_memory = IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input););
} }
@@ -1328,7 +1478,8 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) {
iter->second[stream_id].clear(); iter->second[stream_id].clear();
} }
vector<int64_t> temp; vector<int64_t> temp;
GetNodeWorkSpaceSize(n, temp);
int64_t tatal_size = 0;
GetNodeWorkSpaceSize(n, temp, tatal_size);
vector<int64_t> workspace_bytes; vector<int64_t> workspace_bytes;
vector<int64_t> tvm_workspace_memory_type; vector<int64_t> tvm_workspace_memory_type;
bool has_tvm_workspace_mem_type_attr = bool has_tvm_workspace_mem_type_attr =
@@ -1349,7 +1500,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) {
bool workspace_skip_flag = false; bool workspace_skip_flag = false;
if (has_tvm_workspace_mem_type_attr && tvm_workspace_memory_type[i] == RT_MEMORY_L1) { if (has_tvm_workspace_mem_type_attr && tvm_workspace_memory_type[i] == RT_MEMORY_L1) {
GELOGI( GELOGI(
"fusion: node[%s]workspace index[%zu] is not hbm type, add to zero_memory_list, workspace memory type [%ld]",
"fusion:node[%s]workspace index[%zu] is not hbm type, add to zero_memory_list, workspace memory type [%ld]",
node_op_desc->GetName().c_str(), i, tvm_workspace_memory_type[i]); node_op_desc->GetName().c_str(), i, tvm_workspace_memory_type[i]);
workspace_skip_flag = true; workspace_skip_flag = true;
} }
@@ -1380,9 +1531,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) {
(void)mem_block; // Fix warning (void)mem_block; // Fix warning
} }


bool merge_dynamic_batch = false;
GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), merge_dynamic_batch = MergeDynamicBatchBlocks());
GE_IF_BOOL_EXEC((!(ge_disable_reuse_mem_env_ == "1") && !merge_dynamic_batch), ReuseBlocksByLifeTime(ranges.size()));
GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), ReuseBlocksByLifeTime(ranges.size()));
AssignContinuousBlocks(); AssignContinuousBlocks();
ResizeMemoryBlocks(); ResizeMemoryBlocks();


@@ -1402,92 +1551,19 @@ void BlockMemAssigner::CheckWorkspaceReuse(const vector<bool> &workspace_reuse_f
} }
} }


void BlockMemAssigner::GetNodeWorkSpaceSize(const NodePtr &node, vector<int64_t> &workspace_memory) {
void BlockMemAssigner::GetNodeWorkSpaceSize(const NodePtr &node, vector<int64_t> &workspace_memory,
int64_t &total_size) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node->GetOpDesc() == nullptr, return, "Op desc is null."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node->GetOpDesc() == nullptr, return, "Op desc is null.");
vector<int64_t> workspace_byte_nums = node->GetOpDesc()->GetWorkspaceBytes(); vector<int64_t> workspace_byte_nums = node->GetOpDesc()->GetWorkspaceBytes();


GELOGD("node[%s] size:%zu", node->GetOpDesc()->GetName().c_str(), workspace_byte_nums.size()); GELOGD("node[%s] size:%zu", node->GetOpDesc()->GetName().c_str(), workspace_byte_nums.size());
for (int64_t byte_size : workspace_byte_nums) { for (int64_t byte_size : workspace_byte_nums) {
workspace_memory.emplace_back(byte_size); workspace_memory.emplace_back(byte_size);
total_size += byte_size;
GELOGD("push back size:%ld", byte_size); GELOGD("push back size:%ld", byte_size);
} }
} }


// descending order
static bool CompareBlockMaxSize(MemoryBlock *left, MemoryBlock *right) {
if (left == nullptr || right == nullptr) {
return false;
}
auto left_max_size = std::max_element(left->RealSizeList().begin(), left->RealSizeList().end());
if (left_max_size != left->RealSizeList().end()) {
auto right_max_size = std::max_element(right->RealSizeList().begin(), right->RealSizeList().end());
if (right_max_size == right->RealSizeList().end() || (*left_max_size > *right_max_size)) {
return true;
}
}
return false;
}

void MergeBlocks(std::vector<MemoryBlock *> &dest, std::vector<MemoryBlock *> &src) {
for (size_t i = 0; i < dest.size(); ++i) {
if (i >= src.size()) {
return;
}
if (dest[i] != nullptr && src[i] != nullptr) {
if (!dest[i]->reuse_mem_ || !src[i]->reuse_mem_) {
GELOGD("Diff batch's workspace can't be reused, i: %zu, dest[i]: %s, stream: %ld, src[i]: %s, stream: %ld.",
i, dest[i]->String().c_str(), dest[i]->stream_id_, src[i]->String().c_str(), src[i]->stream_id_);
continue;
}
for (auto &symbol : src[i]->SymbolList()) {
dest[i]->AddSymbol(symbol);
}
for (size_t j = 0; j < src[i]->NodeTypeIndexList().size(); ++j) {
dest[i]->AddNodeTypeIndex(src[i]->NodeTypeIndexList()[j],
src[i]->RealSizeList()[j],
src[i]->NoAlignSizeList()[j]);
src[i]->deleted_block_ = true;
}
}
}
}

bool BlockMemAssigner::MergeDynamicBatchBlocks() {
bool merged = false;
std::map<std::string, std::vector<MemoryBlock *>> dynamic_batch_blocks;
for (auto block : memory_blocks_) {
if (block == nullptr) {
continue;
}
std::string batch_label;
if (block->IsSameLabel(batch_label)) {
dynamic_batch_blocks[batch_label].emplace_back(block);
}
}

auto it = dynamic_batch_blocks.begin();
auto it_max = it;

// find max block counts
for (; it != dynamic_batch_blocks.end(); ++it) {
if (it->second.size() > it_max->second.size()) {
it_max = it;
}
std::sort(it->second.begin(), it->second.end(), CompareBlockMaxSize);
}
if (it_max != dynamic_batch_blocks.end()) {
GELOGD("MergeDynamicBatch %s block counts %zu", it_max->first.c_str(), it_max->second.size());
}
for (it = dynamic_batch_blocks.begin(); it != dynamic_batch_blocks.end(); ++it) {
if (it != it_max) {
GELOGD("MergeDynamicBatch from %s to %s", it->first.c_str(), it_max->first.c_str());
MergeBlocks(it_max->second, it->second);
merged = true;
}
}
return merged;
}

// asending order // asending order
static bool CompareBlockIndex(MemoryBlock *left, MemoryBlock *right) { static bool CompareBlockIndex(MemoryBlock *left, MemoryBlock *right) {
if (left == nullptr || right == nullptr) { if (left == nullptr || right == nullptr) {
@@ -1597,38 +1673,93 @@ void BlockMemAssigner::ReuseBlocksByLifeTime(size_t range_size) {
} }
} }


///
/// @brief Resize block and place it at the current offset of the pool matching its memory type,
///        then advance that offset by the block size.
/// @param [in&out] mem_offset      running offset used for RT_MEMORY_HBM blocks
/// @param [in&out] p2p_mem_offset  running offset used for RT_MEMORY_P2P_DDR blocks
/// @param [in&out] block           block to place; blocks of any other memory type are left untouched
///
void AddBlockMemOffset(size_t &mem_offset, size_t &p2p_mem_offset, MemoryBlock &block) {
  // pick the running offset that belongs to the block's memory type
  size_t *cur_offset = nullptr;
  if (block.memory_type_ == RT_MEMORY_HBM) {
    cur_offset = &mem_offset;
  } else if (block.memory_type_ == RT_MEMORY_P2P_DDR) {
    cur_offset = &p2p_mem_offset;
  } else {
    return;
  }

  // the first block of a continuous group is preceded by one extra MEM_ALIGN_SIZE gap
  if (block.first_continuous_block_) {
    *cur_offset += MEM_ALIGN_SIZE;
  }
  block.Resize();
  block.SetHeadOffset(*cur_offset);
  *cur_offset += block.Size();
  // tail offset is inclusive
  block.SetTailOffset(*cur_offset - 1);
}

///
/// @brief Tell whether block takes part in dynamic batch reuse.
/// @param [in] block  block to check
/// @return true only when block.IsSameBatchLabel() holds and the block is marked reuse_mem_
///
bool DynamicBatchBlockReuse(MemoryBlock &block) {
  if (!block.IsSameBatchLabel()) {
    return false;
  }
  return block.reuse_mem_;
}

/// ///
/// @ingroup domi_omg /// @ingroup domi_omg
/// @brief traverse memory size, resize, calculate offset
/// @brief get max batch memory size, others reuse this block memory
/// @param [in&out] memory_blocks_ memory block, after calculating offset /// @param [in&out] memory_blocks_ memory block, after calculating offset
/// |-dynamic batch block batch1|
/// |-dynamic batch block batch2----|
/// |-dynamic batch block batch3--|
/// ///
void BlockMemAssigner::ResizeMemoryBlocks() {
for (auto &memory_block : memory_blocks_) {
if (memory_block == nullptr || memory_block->deleted_block_ || memory_block->is_zero_copy_) {
void BlockMemAssigner::ResizeDynamicBatchBlocks() {
std::map<std::string, std::vector<MemoryBlock *>> dynamic_batch_blocks;
for (auto block : memory_blocks_) {
if (block == nullptr) {
continue; continue;
} }
if (memory_block->memory_type_ == RT_MEMORY_HBM) {
if (memory_block->first_continuous_block_) {
mem_offset_ += MEM_ALIGN_SIZE;
}
// when memory is not reusable, it can't be reused by different branches
if (DynamicBatchBlockReuse(*block)) {
dynamic_batch_blocks[block->batch_label_].emplace_back(block);
}
}


memory_block->Resize();
memory_block->SetHeadOffset(mem_offset_);
mem_offset_ += memory_block->Size();
memory_block->SetTailOffset(mem_offset_ - 1);
} else if (memory_block->memory_type_ == RT_MEMORY_P2P_DDR) {
if (memory_block->first_continuous_block_) {
p2p_mem_offset_ += MEM_ALIGN_SIZE;
size_t max_mem_offset = mem_offset_;
size_t max_p2p_mem_offset = p2p_mem_offset_;
for (auto &batch_blocks : dynamic_batch_blocks) {
size_t mem_offset = mem_offset_;
size_t p2p_mem_offset = p2p_mem_offset_;
for (auto block : batch_blocks.second) {
if (block == nullptr || block->deleted_block_ || block->is_zero_copy_) {
continue;
} }
AddBlockMemOffset(mem_offset, p2p_mem_offset, *block);
}
if (mem_offset > max_mem_offset) {
max_mem_offset = mem_offset;
}
if (p2p_mem_offset > max_p2p_mem_offset) {
max_p2p_mem_offset = p2p_mem_offset;
}
GELOGI("Batch[%s] offset[%zu] p2p_offset[%zu]", batch_blocks.first.c_str(), mem_offset, p2p_mem_offset);
}
mem_offset_ = max_mem_offset;
p2p_mem_offset_ = max_p2p_mem_offset;
}


memory_block->Resize();
memory_block->SetHeadOffset(p2p_mem_offset_);
p2p_mem_offset_ += memory_block->Size();
memory_block->SetTailOffset(p2p_mem_offset_ - 1);
///
/// @ingroup domi_omg
/// @brief traverse memory size, resize, calculate offset
/// @param [in&out] memory_blocks_ memory block, after calculating offset
/// |-not dynamic batch block-||-dynamic batch block batch1| |-zero copy block-|
/// |-not dynamic batch block-||-dynamic batch block batch2----||-zero copy block-|
/// |-not dynamic batch block-||-dynamic batch block batch3--| |-zero copy block-|
///
void BlockMemAssigner::ResizeMemoryBlocks() {
for (auto &memory_block : memory_blocks_) {
if (memory_block == nullptr || memory_block->deleted_block_ || memory_block->is_zero_copy_
|| DynamicBatchBlockReuse(*memory_block)) {
continue;
} }

AddBlockMemOffset(mem_offset_, p2p_mem_offset_, *memory_block);
} }
GELOGD("mem_offset_ exclude zero_copy_memory is %zu, p2p_mem_offset_ exclude zero_copy_memory is %zu.",
mem_offset_, p2p_mem_offset_);
ResizeDynamicBatchBlocks();
GELOGI("mem_offset_ exclude zero_copy_memory is %zu, p2p_mem_offset_ exclude zero_copy_memory is %zu,"
"theory_min_memory_size %zu", mem_offset_, p2p_mem_offset_, theory_min_memory_size_);
} }


/// ///
@@ -1641,7 +1772,7 @@ void BlockMemAssigner::ResizeMemoryBlocks() {
/// @return Status result /// @return Status result
/// ///
void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block,
size_t real_size, size_t no_align_size, bool child_block) {
size_t real_size, size_t no_align_size, int32_t child_block_level) {
ge::OpDescPtr op_desc = node_type.node->GetOpDesc(); ge::OpDescPtr op_desc = node_type.node->GetOpDesc();
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc == nullptr, return, "op_desc is null."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc == nullptr, return, "op_desc is null.");
string graph_name = node_type.node->GetOwnerComputeGraph()->GetName(); string graph_name = node_type.node->GetOwnerComputeGraph()->GetName();
@@ -1689,14 +1820,15 @@ void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block,
} }
op_desc->SetWorkspace(workspace_list); op_desc->SetWorkspace(workspace_list);
} }
GELOGI("[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu]"
" noalignsize[%zu] life time begin[%zu] life time end[%zu] child[%d:%d:%d:%d] isref[%d].", graph_name.c_str(),
GELOGI("[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu] noalignsize[%zu] "
"life time begin[%zu] life time end[%zu] child[%d:%d:%d:%d:%d] isref[%d] batch[%s]", graph_name.c_str(),
op_desc->GetName().c_str(), node_type.GetMemType().c_str(), node_type.index, offset, op_desc->GetStreamId(), op_desc->GetName().c_str(), node_type.GetMemType().c_str(), node_type.index, offset, op_desc->GetStreamId(),
block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block, block->reuse_mem_,
block->continuous_block_, block->deleted_block_, node_type.ref_input);
block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block_level, block->reuse_mem_,
block->continuous_block_, block->is_zero_copy_, block->same_stream_, node_type.ref_input,
block->batch_label_.c_str());
} }


void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) {
void SetBlockOpMemOffset(MemoryBlock *block, int32_t child_block_level) {
if (block == nullptr) { if (block == nullptr) {
return; return;
} }
@@ -1709,9 +1841,14 @@ void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) {
real_size = block->RealSizeList()[index]; real_size = block->RealSizeList()[index];
no_align_size = block->NoAlignSizeList()[index]; no_align_size = block->NoAlignSizeList()[index];
} }
SetOffsetSize(node_type_index, block, real_size, no_align_size, child_block);
SetOffsetSize(node_type_index, block, real_size, no_align_size, child_block_level);
index++; index++;
} }

child_block_level++;
for (MemoryBlock *child_block : block->ChildBlockList()) {
SetBlockOpMemOffset(child_block, child_block_level);
}
} }


void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) { void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) {
@@ -1724,16 +1861,13 @@ void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) {
continue; continue;
} }


SetBlockOpMemOffset(memory_block, false);
for (MemoryBlock *child_block : memory_block->ChildBlockList()) {
SetBlockOpMemOffset(child_block, true);
}
SetBlockOpMemOffset(memory_block, 0);
} }


if (!is_zero_copy) { if (!is_zero_copy) {
for (const NodeTypeIndex &node_type_index : zero_memory_list_) { for (const NodeTypeIndex &node_type_index : zero_memory_list_) {
MemoryBlock block(0, 0); MemoryBlock block(0, 0);
SetOffsetSize(node_type_index, &block, 0, 0, false);
SetOffsetSize(node_type_index, &block, 0, 0, 0);
} }
} }
} }


+ 30
- 6
ge/graph/build/memory/block_mem_assigner.h View File

@@ -65,6 +65,7 @@ class MemoryBlock {
stream_id_(stream_id), stream_id_(stream_id),
deleted_block_(false), deleted_block_(false),
reuse_mem_(reuse_mem), reuse_mem_(reuse_mem),
same_stream_(true),
input_index_(0), input_index_(0),
continuous_block_(false), continuous_block_(false),
first_continuous_block_(false), first_continuous_block_(false),
@@ -85,10 +86,14 @@ class MemoryBlock {
symbol_list_.clear(); symbol_list_.clear();
} }


void Init(size_t real_size, OpMemoryType type, const ge::NodePtr &node, uint32_t out_index, size_t no_align_size) {
void Init(size_t real_size, OpMemoryType type, const ge::NodePtr &node, uint32_t out_index, size_t no_align_size,
int64_t stream_id) {
real_size_list_.emplace_back(real_size); real_size_list_.emplace_back(real_size);
no_align_size_list_.emplace_back(no_align_size); no_align_size_list_.emplace_back(no_align_size);
node_type_index_list_.emplace_back(node, type, out_index, false); node_type_index_list_.emplace_back(node, type, out_index, false);
if (stream_id != stream_id_) {
same_stream_ = false;
}
} }
size_t Size() const { return block_size_; } size_t Size() const { return block_size_; }


@@ -106,6 +111,12 @@ class MemoryBlock {
node_type_index_list_.emplace_back(node_type_index); node_type_index_list_.emplace_back(node_type_index);
real_size_list_.emplace_back(real_size); real_size_list_.emplace_back(real_size);
no_align_size_list_.emplace_back(no_align_size); no_align_size_list_.emplace_back(no_align_size);
if ((node_type_index.node != nullptr) && (node_type_index.node->GetOpDesc() != nullptr)) {
auto stream_id = node_type_index.node->GetOpDesc()->GetStreamId();
if (stream_id != stream_id_) {
same_stream_ = false;
}
}
} }


void AddSymbol(const std::string &symbol) { void AddSymbol(const std::string &symbol) {
@@ -122,7 +133,7 @@ class MemoryBlock {


std::string String(); std::string String();


bool IsSameLabel(std::string &first_batch_label);
bool IsSameBatchLabel();


void AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life); void AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life);


@@ -142,6 +153,7 @@ class MemoryBlock {
int64_t stream_id_; int64_t stream_id_;
bool deleted_block_; bool deleted_block_;
bool reuse_mem_; bool reuse_mem_;
bool same_stream_;
uint32_t input_index_; uint32_t input_index_;
bool continuous_block_; bool continuous_block_;
bool first_continuous_block_; bool first_continuous_block_;
@@ -149,6 +161,7 @@ class MemoryBlock {
bool is_zero_copy_; bool is_zero_copy_;
std::map<int64_t, size_t> depend_stream_life_; std::map<int64_t, size_t> depend_stream_life_;
int64_t memory_type_; int64_t memory_type_;
std::string batch_label_;
private: private:
size_t block_size_; size_t block_size_;
std::vector<size_t> real_size_list_; std::vector<size_t> real_size_list_;
@@ -209,7 +222,7 @@ class BlockMemAssigner : public MemAssigner {


void GetOutAndWorkSpaceMem(std::vector<int64_t> &all_memory_size); void GetOutAndWorkSpaceMem(std::vector<int64_t> &all_memory_size);


void GetNodeWorkSpaceSize(const ge::NodePtr &node, std::vector<int64_t> &workspace_memory);
void GetNodeWorkSpaceSize(const ge::NodePtr &node, std::vector<int64_t> &workspace_memory, int64_t &total_size);


/// ///
/// @ingroup GE /// @ingroup GE
@@ -353,7 +366,7 @@ class BlockMemAssigner : public MemAssigner {
/// @return void /// @return void
/// @author /// @author
/// ///
void ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory);
void ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory, bool same_stream = true);


/// ///
/// @ingroup GE /// @ingroup GE
@@ -379,11 +392,11 @@ class BlockMemAssigner : public MemAssigner {


/// ///
/// @ingroup GE /// @ingroup GE
/// @brief Merge memory blocks between different batchs
/// @brief Resize memory blocks for each batch
/// @return merge or not /// @return merge or not
/// @author /// @author
/// ///
bool MergeDynamicBatchBlocks();
void ResizeDynamicBatchBlocks();


void AssignContinuousBlocks(); void AssignContinuousBlocks();


@@ -436,6 +449,17 @@ class BlockMemAssigner : public MemAssigner {


int64_t atomic_addr_clean_id_ = 0; int64_t atomic_addr_clean_id_ = 0;


size_t theory_min_memory_size_ = 0;

size_t theory_memory_size_ = 0;

std::string max_batch_label_;

///
/// @ [stream1][nodeid]
/// @[nodeid] [stream2][nodeid]
/// @ [stream2][nodeid]
///
DependStreamLife total_node_depend_stream_life_; DependStreamLife total_node_depend_stream_life_;
}; };
} // namespace ge } // namespace ge


+ 6
- 4
ge/graph/build/memory/graph_mem_assigner.cc View File

@@ -419,7 +419,8 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node,
GE_IF_BOOL_EXEC(is_peer_output_continuous && (peer_output_size != 1), GE_IF_BOOL_EXEC(is_peer_output_continuous && (peer_output_size != 1),
std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) +
" requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) +
" requires continuous output. There may be conflict between the two. This node is not supported now.";
" requires continuous output. There may be conflict between the two." +
"This node is not supported now.";
GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str());
return PARAM_INVALID;); return PARAM_INVALID;);


@@ -429,7 +430,8 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node,
GE_IF_BOOL_EXEC(is_peer_reference, GE_IF_BOOL_EXEC(is_peer_reference,
std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) +
" requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) +
" requires continuous output. There may be conflict between the two. This node is not supported now.";
" requires continuous output. There may be conflict between the two." +
"This node is not supported now.";
GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str());
return PARAM_INVALID;); return PARAM_INVALID;);


@@ -1646,9 +1648,9 @@ ge::Status GraphMemoryAssigner::SetAtomicCleanAttr(const NodePtr &node, const ve
} }
string atomic_mem_size_str = ss.str(); string atomic_mem_size_str = ss.str();


GELOGI("[IMAS]SetAtomicCleanAttr : Set graph[%s] atomic_node[%s] output offset [%s] size[%s] streamid[%ld]",
GELOGI("[IMAS]SetAtomicCleanAttr : Set %s atomic_node name[%s] output[0] offset to [%s] streamid[%ld] size[%s]",
node->GetOwnerComputeGraph()->GetName().c_str(), node_op_desc->GetName().c_str(), node->GetOwnerComputeGraph()->GetName().c_str(), node_op_desc->GetName().c_str(),
atomic_mem_start_str.c_str(), atomic_mem_size_str.c_str(), node->GetOpDesc()->GetStreamId());
atomic_mem_start_str.c_str(), node->GetOpDesc()->GetStreamId(), atomic_mem_size_str.c_str());
} }
return SUCCESS; return SUCCESS;
} }


+ 4
- 4
ge/graph/build/model_builder.cc View File

@@ -282,7 +282,7 @@ Status ModelBuilder::SetInputOutputDesc() {
void ModelBuilder::AddNodeInputProperty() { void ModelBuilder::AddNodeInputProperty() {
for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) {
auto node_op_desc = node->GetOpDesc(); auto node_op_desc = node->GetOpDesc();
GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return );
GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return);
vector<string> src_name_list; vector<string> src_name_list;
vector<int64_t> src_index_list; vector<int64_t> src_index_list;
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
@@ -309,10 +309,10 @@ void ModelBuilder::AddNodeInputProperty() {


for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) {
auto node_op_desc = node->GetOpDesc(); auto node_op_desc = node->GetOpDesc();
GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return );
GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return);
GE_IF_BOOL_EXEC(node_op_desc->GetType() == NETOUTPUT, continue); GE_IF_BOOL_EXEC(node_op_desc->GetType() == NETOUTPUT, continue);
auto out_control_anchor = node->GetOutControlAnchor(); auto out_control_anchor = node->GetOutControlAnchor();
GE_IF_BOOL_EXEC(out_control_anchor == nullptr, GELOGW("out_control_anchor is nullptr"); return );
GE_IF_BOOL_EXEC(out_control_anchor == nullptr, GELOGW("out_control_anchor is nullptr"); return);
vector<string> dst_name_list; vector<string> dst_name_list;
vector<int64_t> dst_index_list; vector<int64_t> dst_index_list;
string dst_name_temp; string dst_name_temp;
@@ -330,7 +330,7 @@ void ModelBuilder::AddNodeInputProperty() {
dst_name_temp = ""; dst_name_temp = "";
int64_t dst_index = kWrongIndex; // assign an impossible value to dst_index. int64_t dst_index = kWrongIndex; // assign an impossible value to dst_index.
for (const auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { for (const auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
GE_IF_BOOL_EXEC(in_data_anchor == nullptr, GELOGW("in_data_anchor is nullptr"); return );
GE_IF_BOOL_EXEC(in_data_anchor == nullptr, GELOGW("in_data_anchor is nullptr"); return);
ge::NodePtr dst_node = in_data_anchor->GetOwnerNode(); ge::NodePtr dst_node = in_data_anchor->GetOwnerNode();
dst_name_temp = dst_name_temp.empty() ? dst_node->GetName() : dst_name_temp + ":" + dst_node->GetName(); dst_name_temp = dst_name_temp.empty() ? dst_node->GetName() : dst_name_temp + ":" + dst_node->GetName();
dst_index = in_data_anchor->GetIdx(); dst_index = in_data_anchor->GetIdx();


+ 2
- 1
ge/graph/build/stream_allocator.cc View File

@@ -49,7 +49,8 @@ inline bool HasContinuousStreamLabel(const ge::OpDescPtr &op_desc, std::string &
} }


bool IsHcclOp(const string &op_type) { bool IsHcclOp(const string &op_type) {
const set<string> hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER, ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE});
const set<string> hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER,
ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE});
return hccl_op_types.find(op_type) != hccl_op_types.end(); return hccl_op_types.find(op_type) != hccl_op_types.end();
} }
} // namespace } // namespace


+ 1
- 1
ge/graph/build/stream_graph_optimizer.cc View File

@@ -38,7 +38,7 @@ void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Grap
continue; continue;
} }
for (ge::NodePtr &node : subgraph->GetDirectNode()) { for (ge::NodePtr &node : subgraph->GetDirectNode()) {
GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return );
GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return);
if ((node->GetType() == END) || (node->GetType() == PLACEHOLDER)) { if ((node->GetType() == END) || (node->GetType() == PLACEHOLDER)) {
node->GetOpDesc()->SetId(static_cast<int64_t>(node_size)); node->GetOpDesc()->SetId(static_cast<int64_t>(node_size));
node_size++; node_size++;


+ 3
- 17
ge/graph/build/task_generator.cc View File

@@ -49,8 +49,6 @@ const char *const kIsLastNode = "is_last_node";
const char *const kIsInputVar = "INPUT_IS_VAR"; const char *const kIsInputVar = "INPUT_IS_VAR";
const char *const kIsOutputVar = "OUTPUT_IS_VAR"; const char *const kIsOutputVar = "OUTPUT_IS_VAR";
const char *const kProfilingMode = "PROFILING_MODE"; const char *const kProfilingMode = "PROFILING_MODE";
const char *const kProfilingFpPoint = "FP_POINT";
const char *const kProfilingBpPoint = "BP_POINT";
const uint32_t kProfilingArStep = 2; const uint32_t kProfilingArStep = 2;
const uint64_t kProfilingFpStartLogid = 1; const uint64_t kProfilingFpStartLogid = 1;
const uint64_t kProfilingBpEndLogid = 2; const uint64_t kProfilingBpEndLogid = 2;
@@ -810,35 +808,23 @@ Status TaskGenerator::GetFpBpIndex(const ComputeGraphPtr &graph, ProfilingPoint
vector<uint32_t> &all_reduce_nodes, std::string &fp_point_str, vector<uint32_t> &all_reduce_nodes, std::string &fp_point_str,
std::string &bp_point_str) const { std::string &bp_point_str) const {


if (ge::GetContext().GetOption(OPTION_EXEC_PROFILING_FPPONIT_OPTIONS, fp_point_str) == SUCCESS &&
ge::GetContext().GetOption(OPTION_EXEC_PROFILING_BPPONIT_OPTIONS, bp_point_str) == SUCCESS &&
!fp_point_str.empty() && !bp_point_str.empty()) {
return SUCCESS;
}
ProfilingManager::Instance().GetFpBpPoint(fp_point_str, bp_point_str);


Status ret = SUCCESS; Status ret = SUCCESS;
const char *fp_point = std::getenv(kProfilingFpPoint);
if (fp_point == nullptr) {
if (fp_point_str.empty()) {
ret = AutoFindFpOpIndex(graph, profiling_point); ret = AutoFindFpOpIndex(graph, profiling_point);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGW("First forward profiling op_index not set and FindFpOpIndex failed."); GELOGW("First forward profiling op_index not set and FindFpOpIndex failed.");
return FAILED; return FAILED;
} }
} else {
fp_point_str = string(fp_point);
GELOGI("Get fp_point_str from env %s", fp_point_str.c_str());
} }


const char *bp_point = std::getenv(kProfilingBpPoint);
if (bp_point == nullptr) {
if (bp_point_str.empty()) {
ret = AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes); ret = AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGW("Last backward profiling op_index not set and FindBpOpIndex failed."); GELOGW("Last backward profiling op_index not set and FindBpOpIndex failed.");
return FAILED; return FAILED;
} }
} else {
bp_point_str = string(bp_point);
GELOGI("Get bp_point_str from env %s", bp_point_str.c_str());
} }


return SUCCESS; return SUCCESS;


+ 0
- 1
ge/graph/label/case_label_maker.h View File

@@ -86,7 +86,6 @@
| Node | | Node |
+------------+ +------------+
*******************************************************************************/ *******************************************************************************/

namespace ge { namespace ge {
class CaseOpLabelMaker : public LabelMaker { class CaseOpLabelMaker : public LabelMaker {
public: public:


+ 0
- 1
ge/graph/label/if_label_maker.h View File

@@ -70,7 +70,6 @@
| Node | | Node |
+------------+ +------------+
*******************************************************************************/ *******************************************************************************/

namespace ge { namespace ge {
class IfOpLabelMaker : public LabelMaker { class IfOpLabelMaker : public LabelMaker {
public: public:


+ 0
- 1
ge/graph/label/partitioned_call_label_maker.h View File

@@ -54,7 +54,6 @@
| c | | c |
+---------------+ +---------------+
*******************************************************************************/ *******************************************************************************/

namespace ge { namespace ge {
class PartitionedCallLabelMaker : public LabelMaker { class PartitionedCallLabelMaker : public LabelMaker {
public: public:


+ 0
- 1
ge/graph/label/while_label_maker.h View File

@@ -70,7 +70,6 @@
| Node | | Node |
+------------+ +------------+
*******************************************************************************/ *******************************************************************************/

namespace ge { namespace ge {
class WhileOpLabelMaker : public LabelMaker { class WhileOpLabelMaker : public LabelMaker {
public: public:


+ 2
- 1
ge/graph/load/graph_loader.cc View File

@@ -283,7 +283,8 @@ Status GraphLoader::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asyn
std::vector<GeTensorDesc> &output_desc) { std::vector<GeTensorDesc> &output_desc) {
auto model_manager = ModelManager::GetInstance(); auto model_manager = ModelManager::GetInstance();
GE_CHECK_NOTNULL(model_manager); GE_CHECK_NOTNULL(model_manager);
Status ret = model_manager->ExecuteModel(model_id, stream, async_mode, input_data, input_desc, output_data, output_desc);
Status ret = model_manager->ExecuteModel(model_id, stream, async_mode,
input_data, input_desc, output_data, output_desc);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "Execute model failed, model_id:%u.", model_id); GELOGE(ret, "Execute model failed, model_id:%u.", model_id);
return ret; return ret;


+ 2
- 2
ge/graph/load/new_model_manager/data_dumper.cc View File

@@ -919,11 +919,11 @@ Status DataDumper::DumpExceptionInfo(const std::vector<rtExceptionInfo> exceptio
ReplaceStringElem(op_name); ReplaceStringElem(op_name);
ReplaceStringElem(op_type); ReplaceStringElem(op_type);
string dump_file_path = string dump_file_path =
"./" + op_type + "." + op_name + "." + to_string(op_desc_info.task_id) + "." + to_string(now_time);
"./" + op_type + "." + op_name + "." + std::to_string(op_desc_info.task_id) + "." + std::to_string(now_time);
GELOGI("The exception dump file path is %s", dump_file_path.c_str()); GELOGI("The exception dump file path is %s", dump_file_path.c_str());


uint64_t proto_size = dump_data.ByteSizeLong(); uint64_t proto_size = dump_data.ByteSizeLong();
unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]);
std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]);
bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size);
if (!ret || proto_size == 0) { if (!ret || proto_size == 0) {
GELOGE(PARAM_INVALID, "Dump data proto serialize failed"); GELOGE(PARAM_INVALID, "Dump data proto serialize failed");


+ 108
- 133
ge/graph/load/new_model_manager/davinci_model.cc View File

@@ -16,7 +16,6 @@


#include "graph/load/new_model_manager/davinci_model.h" #include "graph/load/new_model_manager/davinci_model.h"


#include <cce/dnn.h>
#include <graph/utils/node_utils.h> #include <graph/utils/node_utils.h>
#include <algorithm> #include <algorithm>
#include <map> #include <map>
@@ -84,7 +83,7 @@ const uint32_t kAddrLen = sizeof(void *);
const int kDecimal = 10; const int kDecimal = 10;
const int kBytes = 8; const int kBytes = 8;
const uint32_t kDataMemAlignSizeCompare = 64; const uint32_t kDataMemAlignSizeCompare = 64;
const uint32_t kDumpL1FusionOpMByteSize = 2 * 1024 * 1024;
const uint32_t kDumpL1FusionOpMByteSize = 2097152; // 2 * 1024 * 1024
const uint32_t kDumpFlagOfL1Fusion = 0; const uint32_t kDumpFlagOfL1Fusion = 0;
const char *const kDefaultBatchLable = "Batch_default"; const char *const kDefaultBatchLable = "Batch_default";
const char *const kGetDynamicDimsName = "ascend_mbatch_get_dynamic_dims_node"; const char *const kGetDynamicDimsName = "ascend_mbatch_get_dynamic_dims_node";
@@ -331,8 +330,8 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) {
GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc feature map memory failed. size: %zu", data_size); GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc feature map memory failed. size: %zu", data_size);
return GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED; return GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED;
} }
GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id,
mem_base_, data_size);
GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]",
runtime_param_.graph_id, mem_base_, data_size);


if (!is_inner_weight_base_) { if (!is_inner_weight_base_) {
weights_mem_base_ = mem_base_; weights_mem_base_ = mem_base_;
@@ -713,7 +712,7 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size
// collect profiling for ge // collect profiling for ge
auto &profiling_manager = ProfilingManager::Instance(); auto &profiling_manager = ProfilingManager::Instance();
if (profiling_manager.ProfilingModelLoadOn()) { if (profiling_manager.ProfilingModelLoadOn()) {
Status p_ret = ReportProfilingData(!profiling_manager.IsAclApiMode());
Status p_ret = ReportProfilingData();
if (p_ret != SUCCESS) { if (p_ret != SUCCESS) {
GELOGE(p_ret, "Report profiling data failed."); GELOGE(p_ret, "Report profiling data failed.");
return p_ret; return p_ret;
@@ -724,14 +723,14 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size
return ret; return ret;
} }


Status DavinciModel::ReportProfilingData(bool check_device) {
Status DavinciModel::ReportProfilingData() {
std::vector<ComputeGraphDescInfo> compute_graph_desc_info; std::vector<ComputeGraphDescInfo> compute_graph_desc_info;
Status ret = GetComputeGraphInfo(compute_graph_desc_info); Status ret = GetComputeGraphInfo(compute_graph_desc_info);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "GetComputeGraphInfo failed."); GELOGE(ret, "GetComputeGraphInfo failed.");
return ret; return ret;
} }
ProfilingManager::Instance().ReportProfilingData(model_id_, GetTaskDescInfo(), compute_graph_desc_info, check_device);
ProfilingManager::Instance().ReportProfilingData(model_id_, GetTaskDescInfo(), compute_graph_desc_info);
GE_CHK_STATUS(SinkModelProfile(), "Sink model profiler failed."); GE_CHK_STATUS(SinkModelProfile(), "Sink model profiler failed.");
op_list_.clear(); op_list_.clear();


@@ -1544,7 +1543,8 @@ Status DavinciModel::LoadWithQueue() {
} }


if (output_queue_ids_.size() != new_output_data_info_.size()) { if (output_queue_ids_.size() != new_output_data_info_.size()) {
GELOGE(ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID, "Output queue ids not match model: output_queue=%zu output_data=%zu",
GELOGE(ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID,
"Output queue ids not match model: output_queue=%zu output_data=%zu",
output_queue_ids_.size(), new_output_data_info_.size()); output_queue_ids_.size(), new_output_data_info_.size());
return ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID; return ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID;
} }
@@ -2186,8 +2186,9 @@ Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data
const std::vector<DataBuffer> &blobs = input_data.blobs; const std::vector<DataBuffer> &blobs = input_data.blobs;
for (const auto &data : new_input_data_info_) { for (const auto &data : new_input_data_info_) {
if (data.first >= blobs.size()) { if (data.first >= blobs.size()) {
GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld", blobs.size(),
new_input_data_info_.size(), data.first, data.second.GetDataInfo().at(0).first);
GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld, op_name(%s)", blobs.size(),
new_input_data_info_.size(), data.first, data.second.GetDataInfo().at(0).first,
data.second.GetOpName().c_str());
return FAILED; return FAILED;
} }


@@ -2198,13 +2199,14 @@ Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data
} }
uint64_t data_size = data.second.GetDataSize(); uint64_t data_size = data.second.GetDataSize();
GE_CHK_BOOL_RET_STATUS(data_size >= data_buf.length, PARAM_INVALID, GE_CHK_BOOL_RET_STATUS(data_size >= data_buf.length, PARAM_INVALID,
"input data size(%lu) does not match model required size(%lu), ret failed.", data_buf.length,
data_size);
"input data size(%lu) does not match model required size(%lu), op_name(%s) ret failed.",
data_buf.length, data_size, data.second.GetOpName().c_str());
void *mem_addr = data.second.GetBasicAddr(); void *mem_addr = data.second.GetBasicAddr();
void *data_buf_addr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(data_buf.data)); void *data_buf_addr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(data_buf.data));
uint64_t data_buf_length = data_buf.length; uint64_t data_buf_length = data_buf.length;
GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] dst[%p] src[%p] mem_size[%lu] datasize[%lu]",
runtime_param_.graph_id, data.first, mem_addr, data_buf_addr, data_size, data_buf_length);
GELOGI("CopyPlainData memcpy graph_%u type[F] input[%s] rank[%u] dst[%p] src[%p] mem_size[%lu] datasize[%lu]",
runtime_param_.graph_id, data.second.GetOpName().c_str(), data.first, mem_addr, data_buf_addr, data_size,
data_buf_length);
GE_CHK_RT_RET(rtMemcpy(mem_addr, data_size, data_buf_addr, data_buf_length, kind)); GE_CHK_RT_RET(rtMemcpy(mem_addr, data_size, data_buf_addr, data_buf_length, kind));
} }


@@ -2248,10 +2250,8 @@ inline int64_t SumSize(const vector<int64_t> &size_list) {


Status DavinciModel::SinkModelProfile() { Status DavinciModel::SinkModelProfile() {
// profiling plugin must be registered // profiling plugin must be registered
Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter();
GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return SUCCESS);

Msprof::Engine::ReporterData reporter_data{};
auto &prof_mgr = ProfilingManager::Instance();
ReporterData reporter_data{};
// report model data tag name // report model data tag name
std::string tag_name; std::string tag_name;
tag_name.append("model_load_info_").append(std::to_string(this->Id())); tag_name.append("model_load_info_").append(std::to_string(this->Id()));
@@ -2269,32 +2269,32 @@ Status DavinciModel::SinkModelProfile() {
reporter_data.deviceId = device_id_; reporter_data.deviceId = device_id_;
reporter_data.data = (unsigned char *)&name_len; reporter_data.data = (unsigned char *)&name_len;
reporter_data.dataLen = sizeof(int32_t); reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());


reporter_data.data = (unsigned char *)name.c_str(); reporter_data.data = (unsigned char *)name.c_str();
reporter_data.dataLen = name.size(); reporter_data.dataLen = name.size();
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());


uint32_t model_id = this->Id(); uint32_t model_id = this->Id();
reporter_data.data = (unsigned char *)&model_id; reporter_data.data = (unsigned char *)&model_id;
reporter_data.dataLen = sizeof(uint32_t); reporter_data.dataLen = sizeof(uint32_t);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());


// Load Start/End Time // Load Start/End Time
int64_t start_time = this->GetLoadBeginTime(); int64_t start_time = this->GetLoadBeginTime();
reporter_data.data = (unsigned char *)&start_time; reporter_data.data = (unsigned char *)&start_time;
reporter_data.dataLen = sizeof(int64_t); reporter_data.dataLen = sizeof(int64_t);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());


int64_t end_time = this->GetLoadEndTime(); int64_t end_time = this->GetLoadEndTime();
reporter_data.data = (unsigned char *)&end_time; reporter_data.data = (unsigned char *)&end_time;
reporter_data.dataLen = sizeof(int64_t); reporter_data.dataLen = sizeof(int64_t);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());


int32_t task_num = task_list_.size(); int32_t task_num = task_list_.size();
std::multimap<uint32_t, uint32_t> op_id_map; std::multimap<uint32_t, uint32_t> op_id_map;
@@ -2308,6 +2308,7 @@ Status DavinciModel::SinkModelProfile() {
uint32_t op_num = fusion_op_info->original_op_names.size(); uint32_t op_num = fusion_op_info->original_op_names.size();
uint32_t task_id = task->GetTaskID(); uint32_t task_id = task->GetTaskID();
if (op_num > 0) { if (op_num > 0) {
GELOGI("task.id = %u, opNum = %u", task_id, op_num);
op_id_map.insert(std::make_pair(fusion_op_info->op_index, task_id)); op_id_map.insert(std::make_pair(fusion_op_info->op_index, task_id));
} }
} }
@@ -2350,39 +2351,39 @@ Status DavinciModel::SinkModelProfile() {
int32_t fusion_op_name_len = fusion_op_name.size() == 0 ? 1 : fusion_op_name.size(); int32_t fusion_op_name_len = fusion_op_name.size() == 0 ? 1 : fusion_op_name.size();
reporter_data.data = (unsigned char *)&fusion_op_name_len; reporter_data.data = (unsigned char *)&fusion_op_name_len;
reporter_data.dataLen = sizeof(int32_t); reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());


reporter_data.data = (unsigned char *)fusion_op_name.c_str(); reporter_data.data = (unsigned char *)fusion_op_name.c_str();
reporter_data.dataLen = fusion_op_name_len; reporter_data.dataLen = fusion_op_name_len;
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());


// original op name before fusion // original op name before fusion
reporter_data.data = (unsigned char *)&op_num; reporter_data.data = (unsigned char *)&op_num;
reporter_data.dataLen = sizeof(int32_t); reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());


for (uint32_t k = 0; k < op_num; k++) { for (uint32_t k = 0; k < op_num; k++) {
std::string op_name = fusion_op_info->original_op_names[k]; std::string op_name = fusion_op_info->original_op_names[k];
int32_t op_name_len = op_name.size() == 0 ? 1 : op_name.size(); int32_t op_name_len = op_name.size() == 0 ? 1 : op_name.size();
reporter_data.data = (unsigned char *)&op_name_len; reporter_data.data = (unsigned char *)&op_name_len;
reporter_data.dataLen = sizeof(int32_t); reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());
reporter_data.data = (unsigned char *)op_name.c_str(); reporter_data.data = (unsigned char *)op_name.c_str();
reporter_data.dataLen = op_name_len; reporter_data.dataLen = op_name_len;
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());
} }


// stream id info // stream id info
uint32_t streamId = task->GetStreamId(); uint32_t streamId = task->GetStreamId();
reporter_data.data = (unsigned char *)&streamId; reporter_data.data = (unsigned char *)&streamId;
reporter_data.dataLen = sizeof(int32_t); reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());


// memory info // memory info
struct memoryInfo memory_info; struct memoryInfo memory_info;
@@ -2398,22 +2399,22 @@ Status DavinciModel::SinkModelProfile() {
memory_info.weight_size + memory_info.input_size + memory_info.output_size + memory_info.workspace_size; memory_info.weight_size + memory_info.input_size + memory_info.output_size + memory_info.workspace_size;
reporter_data.data = (unsigned char *)&memory_info; reporter_data.data = (unsigned char *)&memory_info;
reporter_data.dataLen = sizeof(struct memoryInfo); reporter_data.dataLen = sizeof(struct memoryInfo);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());


// task info // task info
reporter_data.data = (unsigned char *)&task_count; reporter_data.data = (unsigned char *)&task_count;
reporter_data.dataLen = sizeof(uint32_t); reporter_data.dataLen = sizeof(uint32_t);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());


Range task_range = op_id_map.equal_range(op_id); Range task_range = op_id_map.equal_range(op_id);
for (CIT idx = task_range.first; idx != task_range.second; ++idx) { for (CIT idx = task_range.first; idx != task_range.second; ++idx) {
uint32_t task_id = idx->second; uint32_t task_id = idx->second;
reporter_data.data = (unsigned char *)&task_id; reporter_data.data = (unsigned char *)&task_id;
reporter_data.dataLen = sizeof(uint32_t); reporter_data.dataLen = sizeof(uint32_t);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());
} }
} }
} }
@@ -2422,10 +2423,8 @@ Status DavinciModel::SinkModelProfile() {


Status DavinciModel::SinkTimeProfile(const InputData &current_data) { Status DavinciModel::SinkTimeProfile(const InputData &current_data) {
// profiling plugin must be registered // profiling plugin must be registered
Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter();
GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return SUCCESS);

Msprof::Engine::ReporterData reporter_data{};
auto &prof_mgr = ProfilingManager::Instance();
ReporterData reporter_data{};
// report model data tag name // report model data tag name
std::string tag_name; std::string tag_name;
tag_name.append("model_time_info_") tag_name.append("model_time_info_")
@@ -2448,33 +2447,33 @@ Status DavinciModel::SinkTimeProfile(const InputData &current_data) {
size_t name_len = name.size(); size_t name_len = name.size();
reporter_data.data = (unsigned char *)&name_len; reporter_data.data = (unsigned char *)&name_len;
reporter_data.dataLen = sizeof(int32_t); reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());


reporter_data.data = (unsigned char *)name.c_str(); reporter_data.data = (unsigned char *)name.c_str();
reporter_data.dataLen = name.size(); reporter_data.dataLen = name.size();
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.",
this->Id());
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());


// request id // request id
uint64_t request_id = current_data.request_id; uint64_t request_id = current_data.request_id;
reporter_data.data = (unsigned char *)&request_id; reporter_data.data = (unsigned char *)&request_id;
reporter_data.dataLen = sizeof(uint32_t); reporter_data.dataLen = sizeof(uint32_t);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED,
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index); "Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index);


// thread id // thread id
int32_t thread_id = GetDataInputTid(); int32_t thread_id = GetDataInputTid();
reporter_data.data = (unsigned char *)&thread_id; reporter_data.data = (unsigned char *)&thread_id;
reporter_data.dataLen = sizeof(int32_t); reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED,
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index); "Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index);


// time info // time info
time_info_.modelId = this->Id(); time_info_.modelId = this->Id();
reporter_data.data = (unsigned char *)&time_info_; reporter_data.data = (unsigned char *)&time_info_;
reporter_data.dataLen = sizeof(struct timeInfo); reporter_data.dataLen = sizeof(struct timeInfo);
GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED,
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index); "Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index);


return SUCCESS; return SUCCESS;
@@ -2696,8 +2695,9 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b
is_getnext_sink_dynamic_ = true; is_getnext_sink_dynamic_ = true;
cur_dynamic_dims_.clear(); cur_dynamic_dims_.clear();
cur_dynamic_dims_.resize(shape_of_cur_dynamic_dims_); cur_dynamic_dims_.resize(shape_of_cur_dynamic_dims_);
GE_CHK_RT_RET(rtMemcpy(cur_dynamic_dims_.data(), shape_of_cur_dynamic_dims_ * sizeof(int64_t),
netoutput_last_input_addr_, netoutput_last_input_size_, RT_MEMCPY_DEVICE_TO_HOST));
auto ret = rtMemcpy(cur_dynamic_dims_.data(), shape_of_cur_dynamic_dims_ * sizeof(int64_t),
netoutput_last_input_addr_, netoutput_last_input_size_, RT_MEMCPY_DEVICE_TO_HOST);
GE_CHK_RT_RET(ret);
} }
GELOGD("Cur dynamic dims is %s.", formats::JoinToString(cur_dynamic_dims_).c_str()); GELOGD("Cur dynamic dims is %s.", formats::JoinToString(cur_dynamic_dims_).c_str());
if (GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) { if (GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) {
@@ -2801,76 +2801,42 @@ void *DavinciModel::Run(DavinciModel *model) {
reinterpret_cast<int64_t *>(shape_data_buffer_data) + reinterpret_cast<int64_t *>(shape_data_buffer_data) +
shape_data_buffer_length / sizeof(int64_t)); shape_data_buffer_length / sizeof(int64_t));
GELOGD("Data: cur dynamic dims is %s", formats::JoinToString(model->cur_dynamic_dims_).c_str()); GELOGD("Data: cur dynamic dims is %s", formats::JoinToString(model->cur_dynamic_dims_).c_str());
delete[] (int64_t *)current_data.blobs.back().data;
delete[] reinterpret_cast<int64_t *>(current_data.blobs.back().data);
current_data.blobs.pop_back(); current_data.blobs.pop_back();
} }
GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_PRE_PROC_END)); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_PRE_PROC_END));
GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_START)); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_START));
if (ProfilingManager::Instance().ProfilingOpTraceOn()) {
GELOGI("GetOpTraceIterNum:%d", ProfilingManager::Instance().GetOpTraceIterNum());
for (int32_t i = 0; i < ProfilingManager::Instance().GetOpTraceIterNum(); i++) {
if (!ProfilingManager::Instance().ProfilingLoadFlag()) {
vector<int32_t> prof_device_id_vec = ProfilingManager::Instance().GetProfilingDeviceId();
for (size_t j = 0; j < prof_device_id_vec.size(); ++j) {
// just profiling, no need to check value
(void)ProfilingManager::Instance().StartProfiling(i, prof_device_id_vec[j]);
}
}

GELOGI("rtModelExecute start.");
rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false;
(void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput());
continue); // [No need to check value]
GELOGI("rtModelExecute end");

GELOGI("rtStreamSynchronize start.");
rt_ret = rtStreamSynchronize(model->rt_model_stream_);
if (rt_ret == kModelAbortNormal || rt_ret == kModelAbortNormalNew) {
GELOGI("The model with multiple datasets aborts normally.");
} else {
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false;
(void)model->ReturnResult(current_data.index, false, seq_end_flag, data_wrapper->GetOutput());
continue); // [No need to check value]
}

GELOGI("rtStreamSynchronize end.");
(void)ProfilingManager::Instance().StopProfiling(); // just profiling, no need to check value
}
GE_TIMESTAMP_START(rtModelExecute);
GELOGI("rtModelExecute start.");
rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false;
(void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput());
CsaInteract::GetInstance().WriteErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC);
continue);
GELOGI("rtModelExecute end");
GE_IF_BOOL_EXEC(model->is_first_execute_, GE_TIMESTAMP_EVENT_END(rtModelExecute, "GraphExcute::rtModelExecute"));

GE_TIMESTAMP_START(rtStreamSynchronize);
GELOGI("rtStreamSynchronize start.");
rt_ret = rtStreamSynchronize(model->rt_model_stream_);
if (rt_ret == kEndOfSequence || rt_ret == kEndOfSequenceNew) {
seq_end_flag = true;
}
if (rt_ret == kModelAbortNormal || rt_ret == kModelAbortNormalNew) {
GELOGI("The model with multiple datasets aborts normally.");
} else { } else {
GE_TIMESTAMP_START(rtModelExecute);
GELOGI("rtModelExecute start.");
rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false;
(void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput());
CsaInteract::GetInstance().WriteErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC);
continue);
GELOGI("rtModelExecute end");
GE_IF_BOOL_EXEC(model->is_first_execute_, GE_TIMESTAMP_EVENT_END(rtModelExecute, "GraphExcute::rtModelExecute"));

GE_TIMESTAMP_START(rtStreamSynchronize);
GELOGI("rtStreamSynchronize start.");
rt_ret = rtStreamSynchronize(model->rt_model_stream_);
if (rt_ret == kEndOfSequence || rt_ret == kEndOfSequenceNew) {
seq_end_flag = true;
}
if (rt_ret == kModelAbortNormal || rt_ret == kModelAbortNormalNew) {
GELOGI("The model with multiple datasets aborts normally.");
} else {
GE_IF_BOOL_EXEC(
rt_ret != RT_ERROR_NONE, rslt_flg = false; GELOGI("seq_end_flg: %d", seq_end_flag);
(void)model->ReturnResult(current_data.index, false, seq_end_flag,
data_wrapper->GetOutput()); // [No need to check value]
CsaInteract::GetInstance().StoreInternalErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC);
continue);
}

GELOGI("rtStreamSynchronize end.");
GE_IF_BOOL_EXEC(model->is_first_execute_,
GE_TIMESTAMP_EVENT_END(rtStreamSynchronize, "GraphExcute::Wait for rtStreamSynchronize"));
GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_END));
GE_IF_BOOL_EXEC(
rt_ret != RT_ERROR_NONE, rslt_flg = false; GELOGI("seq_end_flg: %d", seq_end_flag);
(void)model->ReturnResult(current_data.index, false, seq_end_flag,
data_wrapper->GetOutput()); // [No need to check value]
CsaInteract::GetInstance().StoreInternalErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC);
continue);
} }


GELOGI("rtStreamSynchronize end.");
GE_IF_BOOL_EXEC(model->is_first_execute_,
GE_TIMESTAMP_EVENT_END(rtStreamSynchronize, "GraphExcute::Wait for rtStreamSynchronize"));
GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_END));
GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(),
model->SetProfileTime(MODEL_AFTER_PROC_START)); model->SetProfileTime(MODEL_AFTER_PROC_START));
GE_TIMESTAMP_START(ReturnResult3); GE_TIMESTAMP_START(ReturnResult3);
@@ -3170,21 +3136,29 @@ Status DavinciModel::DistributeTask() {


const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); const auto &model_task_def = ge_model_->GetModelTaskDefPtr();
for (size_t task_index = 0; task_index < task_list_.size(); ++task_index) { for (size_t task_index = 0; task_index < task_list_.size(); ++task_index) {
auto &task_def = model_task_def->task(task_index);
auto &task = task_list_.at(task_index); auto &task = task_list_.at(task_index);
GE_CHK_STATUS_RET(task->Distribute(), "Task[%zu] distribute fail", task_index); GE_CHK_STATUS_RET(task->Distribute(), "Task[%zu] distribute fail", task_index);
// for data dump // for data dump
auto op_index = std::max(model_task_def->task(task_index).kernel().context().op_index(),
model_task_def->task(task_index).kernel_ex().op_index());
auto op_index = std::max(task_def.kernel().context().op_index(),
task_def.kernel_ex().op_index());
OpDescPtr op = GetOpByIndex(op_index); OpDescPtr op = GetOpByIndex(op_index);
GE_CHECK_NOTNULL(op); GE_CHECK_NOTNULL(op);


SaveDumpOpInfo(runtime_param_, op, task->GetTaskID(), task->GetStreamId());
if (reinterpret_cast<void *>(task->GetDumpArgs()) != nullptr) { if (reinterpret_cast<void *>(task->GetDumpArgs()) != nullptr) {
bool call_dump = GetDumpProperties().IsLayerNeedDump(name_, om_name_, op->GetName()) && task->CallSaveDumpInfo(); bool call_dump = GetDumpProperties().IsLayerNeedDump(name_, om_name_, op->GetName()) && task->CallSaveDumpInfo();
if (call_dump || is_op_debug_reg_) { if (call_dump || is_op_debug_reg_) {
SaveDumpTask(task->GetTaskID(), task->GetStreamId(), op, task->GetDumpArgs()); SaveDumpTask(task->GetTaskID(), task->GetStreamId(), op, task->GetDumpArgs());
} }
} }

auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
bool no_need_profiling = (task_type != RT_MODEL_TASK_KERNEL)
&& (task_type != RT_MODEL_TASK_KERNEL_EX)
&& (task_type != RT_MODEL_TASK_HCCL);
GE_IF_BOOL_EXEC(no_need_profiling, continue);

SaveDumpOpInfo(runtime_param_, op, task->GetTaskID(), task->GetStreamId());
// Load task info for profiling // Load task info for profiling
TaskDescInfo task_desc_info; TaskDescInfo task_desc_info;
if (!om_name_.empty()) { if (!om_name_.empty()) {
@@ -3193,7 +3167,7 @@ Status DavinciModel::DistributeTask() {
task_desc_info.model_name = name_; task_desc_info.model_name = name_;
} }
task_desc_info.op_name = op->GetName(); task_desc_info.op_name = op->GetName();
task_desc_info.block_dim = model_task_def->task(task_index).kernel().block_dim();
task_desc_info.block_dim = task_def.kernel().block_dim();
task_desc_info.task_id = task->GetTaskID(); task_desc_info.task_id = task->GetTaskID();
task_desc_info.stream_id = task->GetStreamId(); task_desc_info.stream_id = task->GetStreamId();
task_desc_info_.emplace_back(task_desc_info); task_desc_info_.emplace_back(task_desc_info);
@@ -3391,14 +3365,14 @@ bool DavinciModel::CheckInputAndModelSize(const int64_t &input_size, const int64
/// ///
Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic) { Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic) {
if (UpdateIoTaskArgs(new_input_data_info_, true, input_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) { if (UpdateIoTaskArgs(new_input_data_info_, true, input_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) {
GELOGE(PARAM_INVALID, "[ZCPY] Update input data to model failed.");
return PARAM_INVALID;
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[ZCPY] Update input data to model failed.");
return ACL_ERROR_GE_PARAM_INVALID;
} }


if (UpdateIoTaskArgs(new_output_data_info_, false, output_data.blobs, is_dynamic, input_data.batch_label) != if (UpdateIoTaskArgs(new_output_data_info_, false, output_data.blobs, is_dynamic, input_data.batch_label) !=
SUCCESS) { SUCCESS) {
GELOGE(PARAM_INVALID, "[ZCPY] Update output data to model failed.");
return PARAM_INVALID;
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[ZCPY] Update output data to model failed.");
return ACL_ERROR_GE_PARAM_INVALID;
} }


for (ZeroCopyTask &task : zero_copy_tasks_) { for (ZeroCopyTask &task : zero_copy_tasks_) {
@@ -3444,7 +3418,7 @@ Status DavinciModel::UpdateIoTaskArgs(const std::map<uint32_t, ZeroCopyOffset> &
} }


if (!CheckInputAndModelSize(buffer.length, data.second.GetDataSize(), is_dynamic)) { if (!CheckInputAndModelSize(buffer.length, data.second.GetDataSize(), is_dynamic)) {
GELOGE(FAILED, "Check input size and model size failed");
GELOGE(FAILED, "Check input size and model size failed, op[%s]", data.second.GetOpName().c_str());
return FAILED; return FAILED;
} }


@@ -3861,7 +3835,8 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa
if (!is_async_mode_) { if (!is_async_mode_) {
GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_START)); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_START));
ret = CopyOutputData(input_data.index, output_data, RT_MEMCPY_DEVICE_TO_DEVICE); ret = CopyOutputData(input_data.index, output_data, RT_MEMCPY_DEVICE_TO_DEVICE);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Copy Output data to user failed.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ACL_ERROR_GE_INTERNAL_ERROR,
"Copy Output data to user failed.");
GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_END)); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_END));
} }


@@ -4061,7 +4036,7 @@ void DavinciModel::SetDataDumperArgs(const ComputeGraphPtr &compute_graph) {
data_dumper_.SetDeviceId(device_id); data_dumper_.SetDeviceId(device_id);


// set loop count addr // set loop count addr
auto get_var_addr = [](const OpDescPtr &op, const RuntimeParam &runtime_param) -> void * {
auto get_var_addr = [](const OpDescPtr &op, const RuntimeParam &runtime_param) -> void *{
if (op != nullptr) { if (op != nullptr) {
auto v_output_size = ModelUtils::GetOutputSize(op); auto v_output_size = ModelUtils::GetOutputSize(op);
auto v_output_addr = ModelUtils::GetOutputDataAddrs(runtime_param, op); auto v_output_addr = ModelUtils::GetOutputDataAddrs(runtime_param, op);


+ 1
- 1
ge/graph/load/new_model_manager/davinci_model.h View File

@@ -440,7 +440,7 @@ class DavinciModel {


Status SinkTimeProfile(const InputData &current_data); Status SinkTimeProfile(const InputData &current_data);


Status ReportProfilingData(bool check_device = true);
Status ReportProfilingData();


void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) { void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) {
data_dumper_.SaveDumpOpInfo(model_param, op, task_id, stream_id); data_dumper_.SaveDumpOpInfo(model_param, op, task_id, stream_id);


+ 24
- 56
ge/graph/load/new_model_manager/model_manager.cc View File

@@ -40,9 +40,7 @@ const int kCmdParSize = 2;
const int kDumpCmdPairSize = 2; const int kDumpCmdPairSize = 2;
const std::size_t kProfCmdParaMaxSize = 1000; const std::size_t kProfCmdParaMaxSize = 1000;
const std::size_t kProfStartCmdParaSize = 2; const std::size_t kProfStartCmdParaSize = 2;
const std::string kCmdTypeProfile = "profile";
const std::string kCmdTypeDump = "dump"; const std::string kCmdTypeDump = "dump";
const std::string kCmdTypeProfiling = "profiling";
const std::string kCmdTypeProfInit = "prof_init"; const std::string kCmdTypeProfInit = "prof_init";
const std::string kCmdTypeProfFinalize = "prof_finalize"; const std::string kCmdTypeProfFinalize = "prof_finalize";
const std::string kCmdTypeProfStart = "prof_start"; const std::string kCmdTypeProfStart = "prof_start";
@@ -51,6 +49,9 @@ const std::string kCmdTypeProfModelSubscribe = "prof_model_subscribe";
const std::string kCmdTypeProfModelUnsubscribe = "prof_model_cancel_subscribe"; const std::string kCmdTypeProfModelUnsubscribe = "prof_model_cancel_subscribe";
const char *const kBatchLoadBuf = "batchLoadsoFrombuf"; const char *const kBatchLoadBuf = "batchLoadsoFrombuf";
const char *const kDeleteCustOp = "deleteCustOp"; const char *const kDeleteCustOp = "deleteCustOp";
const int kTimeSpecNano = 1000000000;
const int kTimeSpecMiro = 1000000;
const int kSessionMaxBias = 100;
struct CustAicpuSoBuf { struct CustAicpuSoBuf {
uint64_t kernelSoBuf; uint64_t kernelSoBuf;
uint32_t kernelSoBufLen; uint32_t kernelSoBufLen;
@@ -224,7 +225,7 @@ ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) {


ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id) { ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id) {
GELOGD("destroy aicpu kernel in session_id %lu, model_id %u.", session_id, model_id); GELOGD("destroy aicpu kernel in session_id %lu, model_id %u.", session_id, model_id);
std::lock_guard<std::mutex> lock(sess_ids_mutex_);
std::lock_guard<std::mutex> lock(map_mutex_);
std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id);
if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) {
Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY, session_id, model_id); Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY, session_id, model_id);
@@ -237,7 +238,7 @@ ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_
} }


ge::Status ModelManager::CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint64_t kernel_id) { ge::Status ModelManager::CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint64_t kernel_id) {
std::lock_guard<std::mutex> lock(sess_ids_mutex_);
std::lock_guard<std::mutex> lock(map_mutex_);
std::vector<uint64_t> v_aicpu_kernel; std::vector<uint64_t> v_aicpu_kernel;
std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id);
if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) {
@@ -345,7 +346,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge


GELOGI("Parse model %u success.", model_id); GELOGI("Parse model %u success.", model_id);


davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 +
davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano +
timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond
davinci_model->SetProfileTime(MODEL_LOAD_END); davinci_model->SetProfileTime(MODEL_LOAD_END);
} while (0); } while (0);
@@ -629,8 +630,7 @@ Status ModelManager::Stop(uint32_t model_id) {
/// ///
Status ModelManager::HandleCommand(const Command &command) { Status ModelManager::HandleCommand(const Command &command) {
static const std::map<std::string, std::function<uint32_t(const Command &)>> cmds = { static const std::map<std::string, std::function<uint32_t(const Command &)>> cmds = {
{kCmdTypeProfile, HandleProfileCommand}, {kCmdTypeDump, HandleDumpCommand},
{kCmdTypeProfiling, HandleAclProfilingCommand}, {kCmdTypeProfInit, HandleProfInitCommand},
{kCmdTypeDump, HandleDumpCommand}, {kCmdTypeProfInit, HandleProfInitCommand},
{kCmdTypeProfFinalize, HandleProfFinalizeCommand}, {kCmdTypeProfStart, HandleProfStartCommand}, {kCmdTypeProfFinalize, HandleProfFinalizeCommand}, {kCmdTypeProfStart, HandleProfStartCommand},
{kCmdTypeProfStop, HandleProfStopCommand}, {kCmdTypeProfStop, HandleProfStopCommand},
{kCmdTypeProfModelSubscribe, HandleProfModelSubscribeCommand}, {kCmdTypeProfModelSubscribe, HandleProfModelSubscribeCommand},
@@ -645,21 +645,6 @@ Status ModelManager::HandleCommand(const Command &command) {
} }
} }


Status ModelManager::HandleAclProfilingCommand(const Command &command) {
if (command.cmd_params.size() < kCmdParSize) {
GELOGE(PARAM_INVALID, "When the cmd_type is 'profiling', the size of cmd_params must larger than 2.");
return PARAM_INVALID;
}

std::string map_key = command.cmd_params[0];
std::string value = command.cmd_params[1];
if (map_key == PROFILE_CONFIG) {
ProfilingManager::Instance().SetProfilingConfig(value);
}

return SUCCESS;
}

Status ModelManager::GetModelByCmd(const Command &command, Status ModelManager::GetModelByCmd(const Command &command,
std::shared_ptr<DavinciModel> &davinci_model) { std::shared_ptr<DavinciModel> &davinci_model) {
if (command.cmd_params.size() < kCmdParSize) { if (command.cmd_params.size() < kCmdParSize) {
@@ -806,29 +791,6 @@ Status ModelManager::HandleProfStopCommand(const Command &command) {
return SUCCESS; return SUCCESS;
} }


///
/// @brief Handle a command whose cmd_type is 'profile'.
///        cmd_params[0] is read as the key and cmd_params[1] as the value.
///        - If the key is found in PROFILE_COMPONENT_MAP, the mapped property is set to "1"
///          when the value is "on" and to "0" for any other value.
///        - If the key is PROFILER_JOBCTX, PROFILER_TARGET_PATH or RTS_PROFILE_PATH, the value
///          is stored unchanged under the key itself.
/// @param [in] command  command carrying the key/value pair in cmd_params
/// @return PARAM_INVALID when cmd_params holds fewer than kCmdParSize entries, SUCCESS otherwise
///
Status ModelManager::HandleProfileCommand(const Command &command) {
  if (command.cmd_params.size() < kCmdParSize) {
    GELOGE(PARAM_INVALID, "When the cmd_type is 'profile', the size of cmd_params must larger than 2.");
    return PARAM_INVALID;
  }

  std::string cmd_key = command.cmd_params[0];
  std::string cmd_value = command.cmd_params[1];
  GELOGI("Profiling mode, Command key:%s , value:%s ", cmd_key.c_str(), cmd_value.c_str());

  // Component switches: translate the textual on/off value into the "1"/"0" property form.
  auto component = PROFILE_COMPONENT_MAP.find(cmd_key);
  if (component != PROFILE_COMPONENT_MAP.end()) {
    std::string switch_value = "0";
    if (cmd_value == "on") {
      switch_value = "1";
    }
    PropertiesManager::Instance().SetPropertyValue(component->second, switch_value);
  }

  // These keys are stored verbatim, with the key itself used as the property name.
  const bool store_verbatim =
      (cmd_key == PROFILER_JOBCTX) || (cmd_key == PROFILER_TARGET_PATH) || (cmd_key == RTS_PROFILE_PATH);
  if (store_verbatim) {
    PropertiesManager::Instance().SetPropertyValue(cmd_key, cmd_value);
  }

  return SUCCESS;
}

static Status ParserPara(const Command &command, const string &dump_key, string &dump_value) { static Status ParserPara(const Command &command, const string &dump_key, string &dump_value) {
auto iter = std::find(command.cmd_params.begin(), command.cmd_params.end(), dump_key); auto iter = std::find(command.cmd_params.begin(), command.cmd_params.end(), dump_key);
if (iter != command.cmd_params.end()) { if (iter != command.cmd_params.end()) {
@@ -1072,12 +1034,12 @@ Status ModelManager::GenSessionId(uint64_t &session_id) {
GELOGE(INTERNAL_ERROR, "Failed to get current time."); GELOGE(INTERNAL_ERROR, "Failed to get current time.");
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
session_id = static_cast<uint64_t>(tv.tv_sec * 1000000 + tv.tv_usec); // 1000000us
session_id = static_cast<uint64_t>(tv.tv_sec * kTimeSpecMiro + tv.tv_usec); // 1000000us


session_id_bias_++; session_id_bias_++;
// max bais 100. // max bais 100.
session_id_bias_ = session_id_bias_ % 100;
session_id = session_id * 100 + session_id_bias_;
session_id_bias_ = session_id_bias_ % kSessionMaxBias;
session_id = session_id * kSessionMaxBias + session_id_bias_;


GELOGD("Generate new session id: %lu.", session_id); GELOGD("Generate new session id: %lu.", session_id);
return SUCCESS; return SUCCESS;
@@ -1086,8 +1048,7 @@ Status ModelManager::GenSessionId(uint64_t &session_id) {
Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model, shared_ptr<ModelListener> listener, Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model, shared_ptr<ModelListener> listener,
void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) {
GE_CHK_BOOL_RET_STATUS(model.key.empty() || mmAccess2(model.key.c_str(), M_F_OK) == EN_OK, GE_CHK_BOOL_RET_STATUS(model.key.empty() || mmAccess2(model.key.c_str(), M_F_OK) == EN_OK,
ACL_ERROR_GE_PARAM_INVALID,
"input key file path %s is invalid, %s", model.key.c_str(), strerror(errno));
ACL_ERROR_GE_PARAM_INVALID, "input key file path %s is invalid, %s", model.key.c_str(), strerror(errno));
GenModelId(&model_id); GenModelId(&model_id);


shared_ptr<DavinciModel> davinci_model = nullptr; shared_ptr<DavinciModel> davinci_model = nullptr;
@@ -1148,7 +1109,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model


GELOGI("Parse model %u success.", model_id); GELOGI("Parse model %u success.", model_id);


davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 +
davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano +
timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond
davinci_model->SetProfileTime(MODEL_LOAD_END); davinci_model->SetProfileTime(MODEL_LOAD_END);


@@ -1252,7 +1213,8 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy
} }


std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id);
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid model id %u.", model_id);
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID,
"Invalid model id %u, check weather model has been loaded or not.", model_id);


if (davinci_model->NeedDestroyAicpuKernel()) { if (davinci_model->NeedDestroyAicpuKernel()) {
GELOGI("Start to destroy specified aicpu kernel."); GELOGI("Start to destroy specified aicpu kernel.");
@@ -1289,8 +1251,8 @@ Status ModelManager::CreateAicpuSession(uint64_t session_id) {
return SUCCESS; return SUCCESS;
} }


Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name) {
GELOGI("LoadCustAicpuSo in, op name %s, so name %s", op_desc->GetName().c_str(), so_name.c_str());
Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name, bool &loaded) {
GELOGD("LoadCustAicpuSo in, op name %s, so name %s", op_desc->GetName().c_str(), so_name.c_str());
std::lock_guard<std::mutex> lock(cust_aicpu_mutex_); std::lock_guard<std::mutex> lock(cust_aicpu_mutex_);
CustAICPUKernelPtr aicpu_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_CUSTAICPU_KERNEL, CustAICPUKernelPtr()); CustAICPUKernelPtr aicpu_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_CUSTAICPU_KERNEL, CustAICPUKernelPtr());
if (aicpu_kernel == nullptr) { if (aicpu_kernel == nullptr) {
@@ -1313,18 +1275,24 @@ Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_
std::map<string, CustAICPUKernelPtr> new_so_name; std::map<string, CustAICPUKernelPtr> new_so_name;
new_so_name.insert({so_name, aicpu_kernel}); new_so_name.insert({so_name, aicpu_kernel});
cust_aicpu_so_[resource_id] = new_so_name; cust_aicpu_so_[resource_id] = new_so_name;
GELOGI("LoadCustAicpuSo new aicpu so resource id %lu", resource_id);
loaded = false;
GELOGD("LoadCustAicpuSo new aicpu so name %s, resource id %lu", so_name.c_str(), resource_id);
return SUCCESS; return SUCCESS;
} }
auto it_so_name = it->second.find(so_name); auto it_so_name = it->second.find(so_name);
if (it_so_name == it->second.end()) { if (it_so_name == it->second.end()) {
it->second.insert({so_name, aicpu_kernel}); it->second.insert({so_name, aicpu_kernel});
GELOGI("LoadCustAicpuSo add aicpu so resource id %lu", resource_id);
loaded = false;
GELOGD("LoadCustAicpuSo add aicpu so name %s, resource id %lu", so_name.c_str(), resource_id);
return SUCCESS;
} }
loaded = true;
GELOGD("LoadCustAicpuSo so name %s has been loaded.", so_name.c_str());
return SUCCESS; return SUCCESS;
} }


Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str());
std::lock_guard<std::mutex> lock(cust_aicpu_mutex_); std::lock_guard<std::mutex> lock(cust_aicpu_mutex_);
if (cust_aicpu_so_.size() == 0) return SUCCESS; if (cust_aicpu_so_.size() == 0) return SUCCESS;
// get current context // get current context


+ 1
- 3
ge/graph/load/new_model_manager/model_manager.h View File

@@ -169,8 +169,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager {
/// @brief comment handle function /// @brief comment handle function
/// ///
ge::Status HandleCommand(const Command &command); ge::Status HandleCommand(const Command &command);
static ge::Status HandleAclProfilingCommand(const Command &command);
static ge::Status HandleProfileCommand(const Command &command);
static ge::Status HandleDumpCommand(const Command &command); static ge::Status HandleDumpCommand(const Command &command);
static ge::Status HandleProfModelSubscribeCommand(const Command &command); static ge::Status HandleProfModelSubscribeCommand(const Command &command);
static ge::Status HandleProfModelUnsubscribeCommand(const Command &command); static ge::Status HandleProfModelUnsubscribeCommand(const Command &command);
@@ -289,7 +287,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager {


ge::Status DestroyAicpuSessionForInfer(uint32_t model_id); ge::Status DestroyAicpuSessionForInfer(uint32_t model_id);


ge::Status LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name);
ge::Status LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name, bool &loaded);


ge::Status LaunchCustAicpuSo(); ge::Status LaunchCustAicpuSo();




+ 2
- 2
ge/graph/load/new_model_manager/model_utils.cc View File

@@ -61,7 +61,7 @@ vector<int64_t> ModelUtils::GetInputSize(ConstOpDescPtr op_desc) {
GELOGI("Get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); GELOGI("Get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i);
continue); continue);


GELOGI("[IMAS]GetInputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size);
GELOGI("GetInputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size);
v_input_size.push_back(tensor_size); v_input_size.push_back(tensor_size);
} }


@@ -96,7 +96,7 @@ vector<int64_t> ModelUtils::GetOutputSize(ConstOpDescPtr op_desc) {
GELOGI("Get size from TensorDesc failed, op : %s, output index : %zu", op_desc->GetName().c_str(), i); GELOGI("Get size from TensorDesc failed, op : %s, output index : %zu", op_desc->GetName().c_str(), i);
continue); continue);


GELOGI("[IMAS]GetOutputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size);
GELOGI("GetOutputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size);
v_output_size.push_back(tensor_size); v_output_size.push_back(tensor_size);
} }




+ 3
- 2
ge/graph/load/new_model_manager/task_info/hccl_task_info.cc View File

@@ -279,9 +279,10 @@ Status HcclTaskInfo::SetAddrs(const std::shared_ptr<OpDesc> &op_desc,
output_data_addr = output_data_addrs_.empty() ? nullptr : output_data_addrs_[i]; output_data_addr = output_data_addrs_.empty() ? nullptr : output_data_addrs_[i];
} }
kernel_hccl_infos[i].inputDataAddr = input_data_addr; kernel_hccl_infos[i].inputDataAddr = input_data_addr;
if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER || hccl_type == HCOMREDUCE) {
if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER) {
kernel_hccl_infos[i].outputDataAddr = output_data_addr; kernel_hccl_infos[i].outputDataAddr = output_data_addr;
} else if (hccl_type == HCOMALLREDUCE || hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE) {
} else if (hccl_type == HCOMALLREDUCE ||
hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE || hccl_type == HCOMREDUCE) {
GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type),
"davinci_model: GetHcomOperationType fail!"); "davinci_model: GetHcomOperationType fail!");
kernel_hccl_infos[i].outputDataAddr = output_data_addr; kernel_hccl_infos[i].outputDataAddr = output_data_addr;


+ 30
- 21
ge/graph/load/new_model_manager/task_info/kernel_task_info.cc View File

@@ -43,6 +43,13 @@ const char *kIsLastNode = "is_last_node";
const char *kIsFirstNode = "is_first_node"; const char *kIsFirstNode = "is_first_node";
const int64_t kCloseSkt = 100; const int64_t kCloseSkt = 100;
const uint32_t kAddrLen = sizeof(void *); const uint32_t kAddrLen = sizeof(void *);
const int kBaseInt = 10;
const int kStrtolFail = 0;
const int kArgsInputDesc = 0;
const int kArgsInputAddr = 1;
const int kArgsOutputDesc = 2;
const int kArgsOutputAddr = 3;
const int kArgsAttrHandle = 4;
} // namespace } // namespace


namespace ge { namespace ge {
@@ -66,7 +73,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci
// get opcontext stored in model // get opcontext stored in model
const domi::KernelContext &context = kernel_def.context(); const domi::KernelContext &context = kernel_def.context();
// get kernel_type // get kernel_type
kernel_type_ = static_cast<cce::ccKernelType>(context.kernel_type());
kernel_type_ = static_cast<ccKernelType>(context.kernel_type());
// get opdesc // get opdesc
op_desc_ = davinci_model_->GetOpByIndex(context.op_index()); op_desc_ = davinci_model_->GetOpByIndex(context.op_index());
GE_CHECK_NOTNULL(op_desc_); GE_CHECK_NOTNULL(op_desc_);
@@ -88,13 +95,13 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci
// get bin_file_key // get bin_file_key
const char *bin_file_key = davinci_model_->GetRegisterStub(op_desc_->GetName(), session_graph_model_id); const char *bin_file_key = davinci_model_->GetRegisterStub(op_desc_->GetName(), session_graph_model_id);
// new aicpu kernel(rtCpuKernelLaunch) no need to check function // new aicpu kernel(rtCpuKernelLaunch) no need to check function
if (kernel_type_ == cce::ccKernelType::CCE_AI_CORE) {
if (kernel_type_ == ccKernelType::CCE_AI_CORE) {
rtError_t rt_ret; rtError_t rt_ret;
rt_ret = rtGetFunctionByName(const_cast<char *>(kernel_def.stub_func().c_str()), &stub_func_); rt_ret = rtGetFunctionByName(const_cast<char *>(kernel_def.stub_func().c_str()), &stub_func_);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. stub_func: %s", GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. stub_func: %s",
kernel_def.stub_func().c_str()); kernel_def.stub_func().c_str());
return RT_ERROR_TO_GE_STATUS(rt_ret);); return RT_ERROR_TO_GE_STATUS(rt_ret););
} else if (kernel_type_ == cce::ccKernelType::TE) {
} else if (kernel_type_ == ccKernelType::TE) {
rtError_t rt_ret; rtError_t rt_ret;
rt_ret = rtGetFunctionByName(bin_file_key, &stub_func_); rt_ret = rtGetFunctionByName(bin_file_key, &stub_func_);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE,
@@ -111,7 +118,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci
ctx_.opIndex2[i] = context.origin_op_index(i); ctx_.opIndex2[i] = context.origin_op_index(i);
} }
ctx_.opCount = context.origin_op_index_size(); ctx_.opCount = context.origin_op_index_size();
if (kernel_type_ == cce::ccKernelType::TE) {
if (kernel_type_ == ccKernelType::TE) {
ctx_.opIndex = context.op_index(); ctx_.opIndex = context.op_index();
uint16_t *args_offset_tmp = reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data())); uint16_t *args_offset_tmp = reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data()));
if (context.args_offset().size() / sizeof(uint16_t) < 1) { if (context.args_offset().size() / sizeof(uint16_t) < 1) {
@@ -120,9 +127,9 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci
} }


ret = InitTVMTask(args_offset_tmp[0], kernel_def); ret = InitTVMTask(args_offset_tmp[0], kernel_def);
} else if (kernel_type_ == cce::ccKernelType::CUSTOMIZED) {
} else if (kernel_type_ == ccKernelType::CUSTOMIZED) {
ret = InitAICPUCustomTask(context.op_index(), kernel_def); ret = InitAICPUCustomTask(context.op_index(), kernel_def);
} else if (kernel_type_ == cce::ccKernelType::AI_CPU || kernel_type_ == cce::ccKernelType::CUST_AI_CPU) {
} else if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) {
ret = InitAicpuTask(context.op_index(), kernel_def); ret = InitAicpuTask(context.op_index(), kernel_def);
} else { } else {
if (kernel_def.args().empty() || args_size_ == 0) { if (kernel_def.args().empty() || args_size_ == 0) {
@@ -371,9 +378,9 @@ Status KernelTaskInfo::Distribute() {
rtError_t rt_ret = RT_ERROR_NONE; rtError_t rt_ret = RT_ERROR_NONE;
char skt_enable_env[MMPA_MAX_PATH] = { 0x00 }; char skt_enable_env[MMPA_MAX_PATH] = { 0x00 };
INT32 res = mmGetEnv("SKT_ENABLE", skt_enable_env, MMPA_MAX_PATH); INT32 res = mmGetEnv("SKT_ENABLE", skt_enable_env, MMPA_MAX_PATH);
int64_t env_flag = (res == EN_OK) ? strtol(skt_enable_env, nullptr, 10) : 0;
int64_t env_flag = (res == EN_OK) ? strtol(skt_enable_env, nullptr, kBaseInt) : kStrtolFail;
bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_);
if (kernel_type_ == cce::ccKernelType::AI_CPU || kernel_type_ == cce::ccKernelType::CUST_AI_CPU) {
if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) {
GELOGI("distribute task info kernel_type %d, flag %d", kernel_type_, dump_flag_); GELOGI("distribute task info kernel_type %d, flag %d", kernel_type_, dump_flag_);
// blockDim is reserved parameter, set to 1 // blockDim is reserved parameter, set to 1
rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name_.c_str()), rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name_.c_str()),
@@ -749,15 +756,15 @@ Status KernelTaskInfo::InitAICPUCustomTask(uint32_t op_index, const domi::Kernel
return FAILED; return FAILED;
} }
} }
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[0])) =
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsInputDesc])) =
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_descs)); // arg 0 static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_descs)); // arg 0
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[1])) =
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsInputAddr])) =
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_addrs)); // arg 1 static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_addrs)); // arg 1
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[2])) =
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsOutputDesc])) =
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_descs)); // arg 2 static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_descs)); // arg 2
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[3])) =
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsOutputAddr])) =
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_addrs)); // arg 3 static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_addrs)); // arg 3
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[4])) =
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[kArgsAttrHandle])) =
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.attr_handle)); // arg 4 static_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.attr_handle)); // arg 4


rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM);
@@ -874,8 +881,10 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }


if (kernel_type_ == cce::ccKernelType::CUST_AI_CPU) {
GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc, so_name_), "launch cust aicpu so failed");
if (kernel_type_ == ccKernelType::CUST_AI_CPU) {
bool loaded = false;
GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc, so_name_, loaded),
"launch cust aicpu so failed");
} }


// copy args to new host memory // copy args to new host memory
@@ -946,7 +955,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k
GELOGI("Op debug is open in aicpu task info"); GELOGI("Op debug is open in aicpu task info");
dump_args_ = static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead); dump_args_ = static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead);
} }
if (kernel_type_ == cce::ccKernelType::CUST_AI_CPU) {
if (kernel_type_ == ccKernelType::CUST_AI_CPU) {
dump_flag_ |= RT_KERNEL_CUSTOM_AICPU; dump_flag_ |= RT_KERNEL_CUSTOM_AICPU;
} }


@@ -1076,7 +1085,7 @@ Status KernelTaskInfo::StoreInputOutputTensor(const std::vector<void *> &input_d


Status KernelTaskInfo::SetContext(const domi::KernelDef &kernel_def) { Status KernelTaskInfo::SetContext(const domi::KernelDef &kernel_def) {
const domi::KernelContext &context = kernel_def.context(); const domi::KernelContext &context = kernel_def.context();
ctx_.kernelType = static_cast<cce::ccKernelType>(context.kernel_type());
ctx_.kernelType = static_cast<ccKernelType>(context.kernel_type());
ctx_.opId = context.op_id(); ctx_.opId = context.op_id();
ctx_.kernelFuncId = context.kernel_func_id(); ctx_.kernelFuncId = context.kernel_func_id();
ctx_.isFlowtable = context.is_flowtable(); ctx_.isFlowtable = context.is_flowtable();
@@ -1161,10 +1170,10 @@ Status KernelTaskInfo::CceUpdateKernelArgs(const domi::KernelContext &context, u
GELOGE(GE_PLGMGR_SO_NOT_EXIST, "Failed in dlopen %s! ", error); GELOGE(GE_PLGMGR_SO_NOT_EXIST, "Failed in dlopen %s! ", error);
return FAILED; return FAILED;
} }
cce::ccStatus_t cc_ret;
ccStatus_t cc_ret;
std::string update_kernel_args = "ccUpdateKernelArgs"; std::string update_kernel_args = "ccUpdateKernelArgs";
auto cceUpdateKernelArgs = (cce::ccStatus_t(*)(cce::ccOpContext &, uint64_t, uint64_t, uint64_t, void *, uint64_t,
void *))mmDlsym(handle, const_cast<char *>(update_kernel_args.c_str()));
auto cceUpdateKernelArgs = (ccStatus_t(*)(ccOpContext &, uint64_t, uint64_t,
uint64_t, void *, uint64_t, void *))mmDlsym(handle, const_cast<char *>(update_kernel_args.c_str()));
if (cceUpdateKernelArgs == nullptr) { if (cceUpdateKernelArgs == nullptr) {
GELOGE(FAILED, "Failed to invoke function ccUpdateKernelArgs"); GELOGE(FAILED, "Failed to invoke function ccUpdateKernelArgs");
if (mmDlclose(handle) != 0) { if (mmDlclose(handle) != 0) {
@@ -1189,7 +1198,7 @@ Status KernelTaskInfo::CceUpdateKernelArgs(const domi::KernelContext &context, u
GELOGW("Failed to close handle %s", error); GELOGW("Failed to close handle %s", error);
return FAILED; return FAILED;
} }
if (cc_ret != cce::CC_STATUS_SUCCESS) {
if (cc_ret != CC_STATUS_SUCCESS) {
GELOGE(CCE_FAILED, "Call cce api failed, ret: 0x%X", cc_ret); GELOGE(CCE_FAILED, "Call cce api failed, ret: 0x%X", cc_ret);
return CCE_FAILED; return CCE_FAILED;
} }


+ 4
- 4
ge/graph/load/new_model_manager/task_info/kernel_task_info.h View File

@@ -43,7 +43,7 @@ class KernelTaskInfo : public TaskInfo {
stream_id_(0), stream_id_(0),
so_name_(""), so_name_(""),
kernel_name_(""), kernel_name_(""),
kernel_type_(cce::ccKernelType::CCE_AI_CORE),
kernel_type_(ccKernelType::CCE_AI_CORE),
dump_flag_(RT_KERNEL_DEFAULT), dump_flag_(RT_KERNEL_DEFAULT),
dump_args_(nullptr), dump_args_(nullptr),
op_desc_(nullptr), op_desc_(nullptr),
@@ -75,7 +75,7 @@ class KernelTaskInfo : public TaskInfo {


Status Release() override; Status Release() override;


cce::ccOpContext *GetCtx() override { return &ctx_; }
ccOpContext *GetCtx() override { return &ctx_; }


FusionOpInfo *GetFusionOpInfo() override { return &fusion_op_info_; } FusionOpInfo *GetFusionOpInfo() override { return &fusion_op_info_; }


@@ -92,7 +92,7 @@ class KernelTaskInfo : public TaskInfo {


bool CallSaveDumpInfo() override { return call_save_dump_; }; bool CallSaveDumpInfo() override { return call_save_dump_; };


cce::ccOpContext ctx_;
ccOpContext ctx_;
FusionOpInfo fusion_op_info_; FusionOpInfo fusion_op_info_;


private: private:
@@ -153,7 +153,7 @@ class KernelTaskInfo : public TaskInfo {
uint32_t stream_id_; uint32_t stream_id_;
std::string so_name_; std::string so_name_;
std::string kernel_name_; std::string kernel_name_;
cce::ccKernelType kernel_type_;
ccKernelType kernel_type_;
uint32_t dump_flag_; uint32_t dump_flag_;
void *dump_args_; void *dump_args_;
OpDescPtr op_desc_; OpDescPtr op_desc_;


+ 2
- 2
ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h View File

@@ -41,7 +41,7 @@ class StreamSwitchTaskInfo : public TaskInfo {


Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override;
private: private:
void SetInputAndValuePtr(DavinciModel *davinci_model, const vector<void *> &input_data_addrs);
void SetInputAndValuePtr(DavinciModel *davinci_model, const std::vector<void *> &input_data_addrs);
void *input_ptr_; void *input_ptr_;
rtCondition_t cond_; rtCondition_t cond_;
void *value_ptr_; void *value_ptr_;
@@ -49,7 +49,7 @@ class StreamSwitchTaskInfo : public TaskInfo {
uint32_t true_stream_id_; uint32_t true_stream_id_;
rtSwitchDataType_t data_type_; rtSwitchDataType_t data_type_;
static const uint32_t kInputNum = 2; static const uint32_t kInputNum = 2;
vector<int64_t> fixed_addr_offset_;
std::vector<int64_t> fixed_addr_offset_;
}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_STREAM_SWITCH_TASK_INFO_H_ #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_STREAM_SWITCH_TASK_INFO_H_

+ 5
- 4
ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc View File

@@ -25,10 +25,11 @@ Status SuperKernel::Launch(rtStream_t stream, uint32_t dump_flag) {
const void *args[] = {this->GetNavTablePtr(), const void *args[] = {this->GetNavTablePtr(),
reinterpret_cast<const void *>(static_cast<uintptr_t>(this->GetNavTableSize()))}; reinterpret_cast<const void *>(static_cast<uintptr_t>(this->GetNavTableSize()))};


rtError_t rt_ret = rtMalloc((void **)&(device_args_addr_), sizeof(args), RT_MEMORY_HBM);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failied. error: 0x%X", rt_ret); return
RT_ERROR_TO_GE_STATUS(rt_ret);)
rt_ret = rtMemcpy((void *)device_args_addr_, sizeof(args), (void *)args, sizeof(args), RT_MEMCPY_HOST_TO_DEVICE);
rtError_t rt_ret = rtMalloc(reinterpret_cast<void **>(&device_args_addr_), sizeof(args), RT_MEMORY_HBM);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failied. error: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);)
rt_ret = rtMemcpy(reinterpret_cast<void *>(device_args_addr_), sizeof(args), reinterpret_cast<void *>(args),
sizeof(args), RT_MEMCPY_HOST_TO_DEVICE);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy failied. error: 0x%X", rt_ret); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy failied. error: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);) return RT_ERROR_TO_GE_STATUS(rt_ret);)
rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr_, sizeof(args), NULL, stream, rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr_, sizeof(args), NULL, stream,


+ 14
- 12
ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc View File

@@ -19,6 +19,8 @@


namespace ge { namespace ge {
namespace skt { namespace skt {
const size_t kFusedKernelMinimumSize = 2;
const size_t kFusedKernelSizeUnit = 2;
SuperKernelFactory &SuperKernelFactory::GetInstance() { SuperKernelFactory &SuperKernelFactory::GetInstance() {
static SuperKernelFactory factory; static SuperKernelFactory factory;
return factory; return factory;
@@ -79,17 +81,17 @@ Status SuperKernelFactory::FuseKernels(const std::vector<void *> &stub_func_list
return FAILED; return FAILED;
} }


if (super_kernel_size < 2) {
if (super_kernel_size < kFusedKernelMinimumSize) {
GELOGW( GELOGW(
"SKT: the number of kernels being fused must be greater than or " "SKT: the number of kernels being fused must be greater than or "
"equal to 2"); "equal to 2");
return FAILED; return FAILED;
} }
GELOGI("SKT: superkernel start fuse, superkernel size %zu.", stub_func_list.size()); GELOGI("SKT: superkernel start fuse, superkernel size %zu.", stub_func_list.size());
const size_t nav_table_len = 2 * stub_func_list.size();
const size_t nav_table_len = kFusedKernelSizeUnit * stub_func_list.size();
std::unique_ptr<uint64_t[]> nav_table(new(std::nothrow) uint64_t[nav_table_len]); std::unique_ptr<uint64_t[]> nav_table(new(std::nothrow) uint64_t[nav_table_len]);
GE_CHECK_NOTNULL(nav_table); GE_CHECK_NOTNULL(nav_table);
uint64_t nav_table_size = 2 * stub_func_list.size() * sizeof(int64_t);
uint64_t nav_table_size = kFusedKernelSizeUnit * stub_func_list.size() * sizeof(int64_t);


rtError_t rt_ret; rtError_t rt_ret;
void *hbm_nav_table_addr = nullptr; void *hbm_nav_table_addr = nullptr;
@@ -101,21 +103,21 @@ Status SuperKernelFactory::FuseKernels(const std::vector<void *> &stub_func_list
GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func);
// store two uint64_t address // store two uint64_t address
// address divided by 4 because of 32bits encoding, call offset will *4 when calculating // address divided by 4 because of 32bits encoding, call offset will *4 when calculating
nav_table[i * 2] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(sub_device_func)) / 4;
GELOGD("SKT: CALL offet %lu", nav_table[i * 2]);
nav_table[i * 2 + 1] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(args_addr_list[i]));
GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]);
nav_table[i * kFusedKernelSizeUnit] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(sub_device_func)) / 4;
GELOGD("SKT: CALL offet %lu", nav_table[i * kFusedKernelSizeUnit]);
nav_table[i * kFusedKernelSizeUnit + 1] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(args_addr_list[i]));
GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * kFusedKernelSizeUnit + 1]);
} }
rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM);
rt_ret = rtMalloc(reinterpret_cast<void **>(&hbm_nav_table_addr), nav_table_size, RT_MEMORY_HBM);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failed. error: 0x%X", rt_ret); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failed. error: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);) return RT_ERROR_TO_GE_STATUS(rt_ret);)
rt_ret =
rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table.get(), nav_table_size, RT_MEMCPY_HOST_TO_DEVICE);
rt_ret = rtMemcpy(reinterpret_cast<void *>(hbm_nav_table_addr), nav_table_size,
reinterpret_cast<void *>(nav_table.get()), nav_table_size, RT_MEMCPY_HOST_TO_DEVICE);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy failed. error: 0x%X", rt_ret); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy failed. error: 0x%X", rt_ret);
GE_CHK_RT(rtFree(hbm_nav_table_addr)); return RT_ERROR_TO_GE_STATUS(rt_ret);) GE_CHK_RT(rtFree(hbm_nav_table_addr)); return RT_ERROR_TO_GE_STATUS(rt_ret);)
// Create the necessary metadata for the super kernel // Create the necessary metadata for the super kernel
h = std::unique_ptr<skt::SuperKernel>(
new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim));
h =
std::unique_ptr<skt::SuperKernel>(new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim));
return SUCCESS; return SUCCESS;
} }
} // namespace skt } // namespace skt


+ 4
- 4
ge/graph/load/new_model_manager/task_info/task_info.h View File

@@ -20,7 +20,7 @@
#include <vector> #include <vector>


#include "cce/customize.h" #include "cce/customize.h"
#include "cce/taskdown_common.hpp"
#include "framework/common/taskdown_common.h"
#include "framework/common/ge_inner_error_codes.h" #include "framework/common/ge_inner_error_codes.h"
#include "graph/load/new_model_manager/ts_mem_mall.h" #include "graph/load/new_model_manager/ts_mem_mall.h"
#include "graph/load/new_model_manager/task_info/task_info_factory.h" #include "graph/load/new_model_manager/task_info/task_info_factory.h"
@@ -63,8 +63,8 @@ struct RuntimeParam {
}; };


typedef struct FusionOpInfo { typedef struct FusionOpInfo {
vector<string> original_op_names;
string op_name;
std::vector<std::string> original_op_names;
std::string op_name;
uint32_t op_index; uint32_t op_index;
uint32_t stream_id; uint32_t stream_id;
} FusionOpInfo; } FusionOpInfo;
@@ -87,7 +87,7 @@ class TaskInfo {


virtual Status Release() { return SUCCESS; } virtual Status Release() { return SUCCESS; }


virtual cce::ccOpContext *GetCtx() { return nullptr; }
virtual ccOpContext *GetCtx() { return nullptr; }


virtual uint32_t GetTaskID() { return 0xFFFFFFFF; } virtual uint32_t GetTaskID() { return 0xFFFFFFFF; }




+ 1
- 1
ge/graph/load/new_model_manager/ts_mem_mall.h View File

@@ -25,7 +25,7 @@
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"


namespace { namespace {
constexpr uint32_t kMaxTsMemBlock = 2 * 1024 * 1024; // Max block 2M
constexpr uint32_t kMaxTsMemBlock = 2097152; // Max block 2M 2 * 1024 * 1024
constexpr uint32_t kTsMemAligment = 64; // Malloc for 64 bits align constexpr uint32_t kTsMemAligment = 64; // Malloc for 64 bits align
constexpr uint32_t kTsMemAlignMask = kTsMemAligment - 1; constexpr uint32_t kTsMemAlignMask = kTsMemAligment - 1;
} }


+ 2
- 0
ge/graph/load/new_model_manager/zero_copy_offset.cc View File

@@ -35,6 +35,7 @@ Status ZeroCopyOffset::InitInputDataInfo(int64_t output_size, void *virtual_addr
GELOGI("[ZCPY] Start to InitInputDataInfo of %s, total_data_size is %ld, virtual_addr is %p", GELOGI("[ZCPY] Start to InitInputDataInfo of %s, total_data_size is %ld, virtual_addr is %p",
op_desc->GetName().c_str(), output_size, virtual_addr); op_desc->GetName().c_str(), output_size, virtual_addr);
basic_addr_ = virtual_addr; basic_addr_ = virtual_addr;
op_name_ = op_desc->GetName();
(void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_);
(void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_);
GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID, GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID,
@@ -82,6 +83,7 @@ Status ZeroCopyOffset::InitOutputDataInfo(const vector<int64_t> &input_size_list
GELOGD("Tensor data size: GetSize=%ld, GetTensorSizeInBytes=%ld", input_size_list[idx], size); GELOGD("Tensor data size: GetSize=%ld, GetTensorSizeInBytes=%ld", input_size_list[idx], size);


basic_addr_ = virtual_addr_list[idx]; basic_addr_ = virtual_addr_list[idx];
op_name_ = op_desc->GetName();
(void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_);
(void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_);
GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID, GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID,


+ 4
- 1
ge/graph/load/new_model_manager/zero_copy_offset.h View File

@@ -66,9 +66,12 @@ class ZeroCopyOffset {
int64_t GetDataSize() const { return data_size_; } int64_t GetDataSize() const { return data_size_; }
// value of *outside_addrs_ from davinci_model // value of *outside_addrs_ from davinci_model
std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; } std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; }
// name of op
std::string GetOpName() const { return op_name_; }


private: private:
void *basic_addr_ = nullptr; void *basic_addr_ = nullptr;
std::string op_name_;
uint32_t data_count_ = 0; uint32_t data_count_ = 0;
std::vector<std::pair<int64_t, void *>> data_info_; std::vector<std::pair<int64_t, void *>> data_info_;
vector<int64_t> relative_offset_; vector<int64_t> relative_offset_;
@@ -80,4 +83,4 @@ class ZeroCopyOffset {
std::vector<int64_t> zero_copy_relative_offset_; std::vector<int64_t> zero_copy_relative_offset_;
}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_OFFSET_H_
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_OFFSET_H_

+ 1
- 1
ge/graph/load/new_model_manager/zero_copy_task.cc View File

@@ -131,7 +131,7 @@ Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, void *buffer_addr, const ma
auto dst_addr = static_cast<uint8_t *>(buffer_addr); auto dst_addr = static_cast<uint8_t *>(buffer_addr);
GELOGI("[ZCPY] %s update task, args_addr: %p, size: %zu, offset: %zu, virtual_addr: 0x%lx, user_data_addr: %p", GELOGI("[ZCPY] %s update task, args_addr: %p, size: %zu, offset: %zu, virtual_addr: 0x%lx, user_data_addr: %p",
name_.c_str(), args_addr_, args_size_, offset, addr, buffer_addr); name_.c_str(), args_addr_, args_size_, offset, addr, buffer_addr);
*(uintptr_t *)(args_info + offset) = reinterpret_cast<uintptr_t>(dst_addr);
*reinterpret_cast<uintptr_t *>(args_info + offset)= reinterpret_cast<uintptr_t>(dst_addr);
is_updated_ = true; is_updated_ = true;
} }
} }


+ 6
- 6
ge/graph/manager/graph_caching_allocator.cc View File

@@ -25,13 +25,13 @@


namespace ge { namespace ge {
const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize,
8 * kMByteSize,
32 * kMByteSize,
128 * kMByteSize,
kBinSizeUnit8 * kMByteSize,
kBinSizeUnit32 * kMByteSize,
kBinSizeUnit128 * kMByteSize,
kGByteSize, kGByteSize,
4 * kGByteSize,
16 * kGByteSize,
26 * kGByteSize};
kBinSizeUnit4 * kGByteSize,
kBinSizeUnit16 * kGByteSize,
kBinSizeUnit26 * kGByteSize};


static bool BlockComparator(const Block *left, const Block *right) { static bool BlockComparator(const Block *left, const Block *right) {
if (left->size != right->size) { if (left->size != right->size) {


+ 9
- 2
ge/graph/manager/graph_caching_allocator.h View File

@@ -34,10 +34,17 @@


namespace ge { namespace ge {
constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes
constexpr size_t kBinSizeUnit4 = 4;
constexpr size_t kBinSizeUnit8 = 8;
constexpr size_t kBinSizeUnit16 = 16;
constexpr size_t kBinSizeUnit26 = 26;
constexpr size_t kBinSizeUnit32 = 32;
constexpr size_t kBinSizeUnit128 = 128;

constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold
constexpr size_t kKByteSize = 1024; constexpr size_t kKByteSize = 1024;
constexpr size_t kMByteSize = 1024 * 1024;
constexpr size_t kGByteSize = 1024 * 1024 * 1024;
constexpr size_t kMByteSize = 1048576; // 1024 * 1024
constexpr size_t kGByteSize = 1073741824; // 1024 * 1024 * 1024


static const uint32_t kNumBins = 8; static const uint32_t kNumBins = 8;




+ 15
- 64
ge/graph/manager/graph_manager.cc View File

@@ -533,9 +533,8 @@ Status GraphManager::CopySubGraphAndMarkFusion(const ComputeGraphPtr &compute_gr
return SUCCESS; return SUCCESS;
} }


Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_graph,
Graph2SubGraphInfoList &sub_graph_map,
uint64_t session_id) {
Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_graph,
Graph2SubGraphInfoList &sub_graph_map, uint64_t session_id) {
GE_CHECK_NOTNULL(compute_graph); GE_CHECK_NOTNULL(compute_graph);
// use default 16 multi thread // use default 16 multi thread
const uint32_t thread_num = 16; const uint32_t thread_num = 16;
@@ -550,14 +549,14 @@ Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_gr
(void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy);
} }
std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this,
compute_graph->GetGraphID(), subgraph, compute_graph, session_id, GetThreadLocalContext());
compute_graph->GetGraphID(), subgraph, compute_graph, session_id,
GetThreadLocalContext());
if (!f.valid()) { if (!f.valid()) {
GELOGE(FAILED, "Future is invalid"); GELOGE(FAILED, "Future is invalid");
return FAILED; return FAILED;
} }
vector_future.emplace_back(std::move(f)); vector_future.emplace_back(std::move(f));
} }

for (auto &function_graph : compute_graph->GetAllSubgraphs()) { for (auto &function_graph : compute_graph->GetAllSubgraphs()) {
auto subgraph_list = sub_graph_map[function_graph]; auto subgraph_list = sub_graph_map[function_graph];
for (const auto &subgraph : subgraph_list) { for (const auto &subgraph : subgraph_list) {
@@ -651,62 +650,13 @@ Status GraphManager::ReplaceSubgraphWithOriGraph(const ComputeGraphPtr &compute_
Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph, GraphPartitioner &partitioner) { Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph, GraphPartitioner &partitioner) {
GE_CHECK_NOTNULL(compute_graph); GE_CHECK_NOTNULL(compute_graph);
auto sub_graph_map = partitioner.GetSubGraphMap(); auto sub_graph_map = partitioner.GetSubGraphMap();
std::string buffer_optimize;
graphStatus graph_status = ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize);
bool need_lx_fusion = (graph_status == GRAPH_SUCCESS) && (buffer_optimize != kOffOptimize);
if (options_.build_mode.empty() && need_lx_fusion) {
GELOGI("Enter normal mode with buffer_optimize:%s.", buffer_optimize.c_str());
/// 1. Copy subgraph for buffer optimize while lx fusion failed.
/// 2. Set graph with attr "lx_fusion" for fusion optimize.
std::unordered_map<std::string, ComputeGraphPtr> copy_graphs;
GE_TIMESTAMP_START(CopySubGraphAndMarkFusion);
Status ret = CopySubGraphAndMarkFusion(compute_graph, sub_graph_map, copy_graphs);
GE_TIMESTAMP_EVENT_END(CopySubGraphAndMarkFusion, "SetSubgraph:CopySubGraphAndMarkFusion");
if (ret != SUCCESS) {
GELOGE(ret, "CopySubGraphAndMarkFusion failed.");
return ret;
}

// Multiply optimize subgraph with lx fusion
ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id);
if (ret != SUCCESS) {
GELOGE(ret, "Multiply optimize subgraph with lx fusion failed.");
return ret;
}

// Check whether all subgraph lx fusion success
GE_TIMESTAMP_START(CheckAllFusionOptimizeSuccess);
if (CheckAllFusionOptimizeSuccess(compute_graph, sub_graph_map)) {
GE_TIMESTAMP_EVENT_END(CheckAllFusionOptimizeSuccess, "SetSubgraph:CheckAllFusionOptimizeSuccess");
return SUCCESS;
}

// Replace subgraph with original graph for lx buffer
ret = ReplaceSubgraphWithOriGraph(compute_graph, sub_graph_map, copy_graphs);
if (ret != SUCCESS) {
GELOGE(ret, "Replace subgraph with original graph failed.");
return ret;
}

// Multiply optimize subgraph with lx buffer
ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id);
if (ret != SUCCESS) {
GELOGE(ret, "Multiply optimize subgraph with lx buffer failed.");
return ret;
}
} else {
/// Multiply optimize subgraph:
/// 1. run lx buffer while build_mode is normal and buffer_optimize is empty or "off_optimize";
/// 2. run lx fusion or buffer according build_mode and build_step in fe.
GELOGD("Directly optimize subgraph with build mode:%s, and step:%s, buffer_optimize:%s.",
options_.build_mode.c_str(),
options_.build_step.c_str(),
buffer_optimize.c_str());
Status ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id);
if (ret != SUCCESS) {
GELOGE(ret, "Multiply optimize subgraph with lx buffer");
return ret;
}
GELOGD("Directly optimize subgraph with build mode:%s, and step:%s.",
options_.build_mode.c_str(),
options_.build_step.c_str());
Status ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id);
if (ret != SUCCESS) {
GELOGE(ret, "Multiply optimize subgraph failed");
return ret;
} }
return SUCCESS; return SUCCESS;
} }
@@ -2515,7 +2465,6 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager
GetContext().SetSessionId(session_id); GetContext().SetSessionId(session_id);
GetThreadLocalContext() = ge_context; GetThreadLocalContext() = ge_context;
graph_manager->UpdateLocalOmgContext(root_graph_id); graph_manager->UpdateLocalOmgContext(root_graph_id);

ComputeGraphPtr compute_graph_tmp = sub_graph_info_ptr->GetSubGraph(); ComputeGraphPtr compute_graph_tmp = sub_graph_info_ptr->GetSubGraph();
const std::string &engine_name = sub_graph_info_ptr->GetEngineName(); const std::string &engine_name = sub_graph_info_ptr->GetEngineName();
GELOGD("ProcessSubGraphWithMultiThreads start, graph name is %s, engine_name is %s, thread id is %lu", GELOGD("ProcessSubGraphWithMultiThreads start, graph name is %s, engine_name is %s, thread id is %lu",
@@ -2523,6 +2472,10 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager
pthread_self()); pthread_self());
GE_DUMP(compute_graph_tmp, "OptimizeSubGraphBefore"); GE_DUMP(compute_graph_tmp, "OptimizeSubGraphBefore");
GE_CHECK_NOTNULL(compute_graph_tmp); GE_CHECK_NOTNULL(compute_graph_tmp);
if (!AttrUtils::SetInt(*compute_graph_tmp, ATTR_NAME_ROOT_GRAPH_ID, root_graph_id)) {
GELOGE(FAILED, "Failed to set attr ATTR_NAME_ROOT_GRAPH_ID for subgraph, graph_id: %u.", root_graph_id);
return FAILED;
}
compute_graph_tmp->SetSessionID(session_id); compute_graph_tmp->SetSessionID(session_id);
Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp, Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp,
compute_graph, compute_graph,
@@ -2688,9 +2641,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) {
} }


// it will not execute graph preprocess, optimize, parition, build if the graph has built successful. // it will not execute graph preprocess, optimize, parition, build if the graph has built successful.

GELOGI("Start for run graph async."); GELOGI("Start for run graph async.");

GeRootModelPtr ge_root_model = nullptr; GeRootModelPtr ge_root_model = nullptr;
if (graph_manager->IsGraphNeedBuild(graph_node)) { if (graph_manager->IsGraphNeedBuild(graph_node)) {
if (graph_node->GetBuildFlag()) { if (graph_node->GetBuildFlag()) {


+ 2
- 2
ge/graph/manager/graph_var_manager.cc View File

@@ -280,9 +280,9 @@ Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uin
return PARAM_INVALID; return PARAM_INVALID;
} }
uint64_t free_size = total_size_ - var_mem_size_; uint64_t free_size = total_size_ - var_mem_size_;
if (free_size < (size + kSessionMemAlignSize * 2)) {
if (free_size < (size + kSessionMemAlignSize * kSessionMemAlignUnit)) {
GELOGE(PARAM_INVALID, "Out of memory : current var size[%lu] exceeds total var size[%lu]", GELOGE(PARAM_INVALID, "Out of memory : current var size[%lu] exceeds total var size[%lu]",
size + kSessionMemAlignSize * 2 + var_mem_size_, total_size_);
size + kSessionMemAlignSize * kSessionMemAlignUnit + var_mem_size_, total_size_);
return PARAM_INVALID; return PARAM_INVALID;
} }




+ 1
- 0
ge/graph/manager/graph_var_manager.h View File

@@ -42,6 +42,7 @@ const size_t kGraphMemoryBuffer = 4UL * 1024UL * 1024UL * 1024UL;
const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL; const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL;
const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY"; const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY";
const uint64_t kSessionMemAlignSize = 512; const uint64_t kSessionMemAlignSize = 512;
const size_t kSessionMemAlignUnit = 2;


enum MemStatus { enum MemStatus {
NORMAL = 0, NORMAL = 0,


+ 1
- 1
ge/graph/manager/host_mem_manager.cc View File

@@ -106,7 +106,7 @@ Status HostMemManager::QueryVarMemInfo(const string &op_name, uint64_t &base_add
GELOGE(INTERNAL_ERROR, "Find host base base_addr failed,node name:%s!", op_name.c_str()); GELOGE(INTERNAL_ERROR, "Find host base base_addr failed,node name:%s!", op_name.c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(var_memory_base_map_[op_name].device_address));
base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(var_memory_base_map_[op_name].device_address));
data_size = var_memory_base_map_[op_name].mem_size; data_size = var_memory_base_map_[op_name].mem_size;
return SUCCESS; return SUCCESS;
} }


+ 2
- 1
ge/graph/manager/util/debug.cc View File

@@ -32,7 +32,8 @@ Debug::~Debug() = default;


void Debug::DumpProto(const Message &proto, const char *file) { void Debug::DumpProto(const Message &proto, const char *file) {
std::string file_path = RealPath(file); std::string file_path = RealPath(file);
int fd = mmOpen2(file_path.c_str(), M_WRONLY | M_CREAT | O_TRUNC, M_IRUSR | M_IWUSR | M_UMASK_GRPREAD | M_UMASK_OTHREAD);
int fd = mmOpen2(file_path.c_str(), M_WRONLY | M_CREAT | O_TRUNC, M_IRUSR | M_IWUSR | M_UMASK_GRPREAD |
M_UMASK_OTHREAD);
if (fd == -1) { if (fd == -1) {
GELOGW("Write %s failed", file_path.c_str()); GELOGW("Write %s failed", file_path.c_str());
return; return;


+ 2
- 1
ge/graph/manager/util/hcom_util.cc View File

@@ -263,7 +263,8 @@ Status HcomOmeUtil::GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &ro
Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc,
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) {
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);
if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) {
if (op_desc->GetType() == HCOMBROADCAST ||
op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) {
GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str());
int64_t root_id = 0; int64_t root_id = 0;
Status dmrt = GetHcclRootId(op_desc, root_id); Status dmrt = GetHcclRootId(op_desc, root_id);


+ 14
- 7
ge/graph/optimize/mem_rw_conflict_optimize.cc View File

@@ -26,6 +26,13 @@
namespace { namespace {
using namespace ge; using namespace ge;
const int kIdentityAnchorIndex = 0; const int kIdentityAnchorIndex = 0;
const size_t kSerialStringVecSize = 4;

const int kCaseReadOnly = 0;
const int kCaseScopeWriteable = 2;
const int kCaseWriteable = 3;
const int kCaseInvalidRWType = 5;

// rw type of input. // rw type of input.
enum class InputRWType { enum class InputRWType {
kReadOnly, // Normal op input only read kReadOnly, // Normal op input only read
@@ -55,7 +62,7 @@ thread_local map<string, NodeInputOutputRWType> node_rwtype_map_;
/// @return rw_type_name /// @return rw_type_name
/// ///
static std::string InputRWTypeToSerialString(InputRWType rw_type) { static std::string InputRWTypeToSerialString(InputRWType rw_type) {
const static char *names[4] = {"ReadOnly", "Writeable", "ScopeWriteable", "InvalidRWType"};
const static char *names[kSerialStringVecSize] = {"ReadOnly", "Writeable", "ScopeWriteable", "InvalidRWType"};
return names[static_cast<int>(rw_type)]; return names[static_cast<int>(rw_type)];
} }


@@ -65,7 +72,7 @@ static std::string InputRWTypeToSerialString(InputRWType rw_type) {
/// @return rw_type_name /// @return rw_type_name
/// ///
static std::string OutputRWTypeToSerialString(OutputRWType rw_type) { static std::string OutputRWTypeToSerialString(OutputRWType rw_type) {
const static char *names[4] = {"ReadOnly", "SoftRead", "Writeable", "InvalidRWType"};
const static char *names[kSerialStringVecSize] = {"ReadOnly", "SoftRead", "Writeable", "InvalidRWType"};
return names[static_cast<int>(rw_type)]; return names[static_cast<int>(rw_type)];
} }


@@ -118,13 +125,13 @@ InputRWType GetInputRwTypeInConflict(const std::set<int> &rw_type_set) {
} }


switch (total_rw_type) { switch (total_rw_type) {
case 0:
case kCaseReadOnly:
return InputRWType::kReadOnly; // all input rw type is readonly return InputRWType::kReadOnly; // all input rw type is readonly
case 2:
case kCaseScopeWriteable:
return InputRWType::kScopeWriteable; // readonly 2 scope_writeable return InputRWType::kScopeWriteable; // readonly 2 scope_writeable
case 3:
case kCaseWriteable:
return InputRWType::kWriteable; // all input rw type is writeable or readonly 2 writeable return InputRWType::kWriteable; // all input rw type is writeable or readonly 2 writeable
case 5:
case kCaseInvalidRWType:
return InputRWType::kInvalidRWType; // writeable 2 scope_writeable return InputRWType::kInvalidRWType; // writeable 2 scope_writeable
default: default:
return InputRWType::kInvalidRWType; return InputRWType::kInvalidRWType;
@@ -643,7 +650,7 @@ Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) {
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
GE_CHK_STATUS_RET(ret, "Fail to insert identity."); GE_CHK_STATUS_RET(ret, "Fail to insert identity.");
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
pre_node->GetName().c_str(), node->GetName().c_str());
pre_node->GetName().c_str(), node->GetName().c_str());
} }
} }
} }


+ 26
- 26
ge/graph/partition/graph_partition.cc View File

@@ -614,32 +614,32 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vector<ge::SubGraphInfoPtr
} }
// flush parent node of subgraph // flush parent node of subgraph
sub_graph->SetParentNode(compute_graph->GetParentNode()); sub_graph->SetParentNode(compute_graph->GetParentNode());
(void) AttrUtils::SetStr(*sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName());
auto sgi = MakeShared<SubGraphInfo>();
if (sgi == nullptr) {
GELOGE(GE_GRAPH_PARAM_NULLPTR, "[GraphPartitioner]: MakeShared sub graph info failed.");
return FAILED;
}
// set engine name
sgi->SetEngineName(engine_name);
// set stream label
string sub_graph_stream;
if (AttrUtils::GetStr(sub_graph->GetDirectNode().at(0)->GetOpDesc(), ATTR_NAME_STREAM_LABEL, sub_graph_stream)) {
sgi->SetStreamLabel(sub_graph_stream);
}
/// for now inputFlag is the same before and after partition. It should
/// be changed according to the real partition
std::vector<bool> sub_graph_input(graph_info_.input_size_, true);
std::vector<bool> sub_graph_output(graph_info_.output_size_, true);
sgi->SetSubGraph(sub_graph);
sgi->SetOutputFlag(sub_graph_output);
sgi->SetInputFlag(sub_graph_input);
sgi->SetOutputContext(graph_info_.output_name_);
AddEndPldInformationToSubGraphInfo(sgi);
GELOGI("[GraphPartitioner]: subGraph engine name is %s, graph name is %s, stream label is %s",
engine_name.c_str(),
sub_graph->GetName().c_str(),
sgi->GetStreamLabel().empty() ? "null" : sgi->GetStreamLabel().c_str());
(void)AttrUtils::SetStr(*sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName());
GELOGD("set attr success. subgraph(%s) with parent graph(%s)", sub_graph->GetName().c_str(),
compute_graph->GetName().c_str());
auto sgi = MakeShared<SubGraphInfo>();
if (sgi == nullptr) {
GELOGE(GE_GRAPH_PARAM_NULLPTR, "[GraphPartitioner]: MakeShared sub graph info failed.");
return FAILED;
}
// set engine name
sgi->SetEngineName(engine_name);
// set stream label
string sub_graph_stream;
if (AttrUtils::GetStr(sub_graph->GetDirectNode().at(0)->GetOpDesc(), ATTR_NAME_STREAM_LABEL, sub_graph_stream)) {
sgi->SetStreamLabel(sub_graph_stream);
}
/// for now inputFlag is the same before and after partition. It should
/// be changed according to the real partition
std::vector<bool> sub_graph_input(graph_info_.input_size_, true);
std::vector<bool> sub_graph_output(graph_info_.output_size_, true);
sgi->SetSubGraph(sub_graph);
sgi->SetOutputFlag(sub_graph_output);
sgi->SetInputFlag(sub_graph_input);
sgi->SetOutputContext(graph_info_.output_name_);
AddEndPldInformationToSubGraphInfo(sgi);
GELOGI("[GraphPartitioner]: subGraph engine name is %s, graph name is %s, stream label is %s", engine_name.c_str(),
sub_graph->GetName().c_str(), sgi->GetStreamLabel().empty() ? "null" : sgi->GetStreamLabel().c_str());
if (engine_name != input_subgraph_name) { // do not add Data subGraph into SubGraphInfo if (engine_name != input_subgraph_name) { // do not add Data subGraph into SubGraphInfo
output_subgraphs.push_back(sgi); output_subgraphs.push_back(sgi);
} else { } else {


+ 80
- 24
ge/graph/passes/atomic_addr_clean_pass.cc View File

@@ -74,10 +74,88 @@ Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) {
return SUCCESS; return SUCCESS;
} }


// just hccl may mark atomic from ops kernel now, and hccl's atomic if for all input
bool AtomicAddrCleanPass::CheckAtomicFromOpsKernel(const NodePtr &node) {
// 1.Check if isAtomic attrs exist for HCOM
std::shared_ptr<GELib> instance_ptr = GELib::GetInstance();
if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
GELOGW("GELib not initialized, atomic from ops kernel judge false, node_name: %s", node->GetName().c_str());
return false;
}

OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
vector<OpInfo> op_info_vec = ops_kernel_manager.GetOpsKernelInfo(node->GetType());
for (const auto &op_info : op_info_vec) {
if (op_info.isAtomic) {
// check peer input is DATA
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
if (in_data_anchor->GetPeerOutAnchor() != nullptr &&
in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) {
auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode();
if (peer_in_node->GetType() == DATA) {
GELOGI("Recognized atomic op %s from %s engine and input is DATA.", node->GetName().c_str(),
op_info.engine.c_str());
return false;
}
}
}
GELOGI("Recognized atomic op %s from %s engine.", node->GetName().c_str(), op_info.engine.c_str());
hcom_node_vec_.push_back(node);
return true;
}
}
return false;
}

bool AtomicAddrCleanPass::IsOutputIndexPeerInputAtomic(const NodePtr &node, int64_t output_index) {
auto out_data_anchor = node->GetAllOutDataAnchors().at(output_index);
if (out_data_anchor == nullptr) {
return false;
}

for (auto input_anchor : out_data_anchor->GetPeerInDataAnchors()) {
auto output_node = input_anchor->GetOwnerNode();
// just hccl may mark atomic from ops kernel now, and hccl's atomic if for all input
// hccl's attr ATOMIC_ATTR_INPUT_INDEX mark on CalcOpRunningParam, can't be get here
if (CheckAtomicFromOpsKernel(output_node)) {
return true;
}
}
return false;
}

bool AtomicAddrCleanPass::CheckSkipInsertInLoopGraph(const NodePtr &node) {
OpDescPtr op_desc = node->GetOpDesc();
std::map<string, std::map<int, int>> node_workspace_offset;
bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX);
bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX);
node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset);
if (!has_atomic_input && has_atomic_output && node_workspace_offset.empty()) {
std::vector<int64_t> atomic_output_index;
(void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index);
bool is_all_output_peer_also_atomic = true;
for (const auto &output_index : atomic_output_index) {
if (!IsOutputIndexPeerInputAtomic(node, output_index)) {
is_all_output_peer_also_atomic = false;
break;
}
}
if (is_all_output_peer_also_atomic) {
GELOGI("all out peer node input atomic, skip this out atomic process, node name: %s", node->GetName().c_str());
return true;
}
}
return false;
}

Status AtomicAddrCleanPass::HandleLoopGraph(ComputeGraphPtr &graph, const vector<NodePtr> &atomic_node_vec) { Status AtomicAddrCleanPass::HandleLoopGraph(ComputeGraphPtr &graph, const vector<NodePtr> &atomic_node_vec) {
// Loop graph , insert clean node follow atomic node // Loop graph , insert clean node follow atomic node
int index = 0; int index = 0;
for (const auto &node : atomic_node_vec) { for (const auto &node : atomic_node_vec) {
if (CheckSkipInsertInLoopGraph(node)) {
continue;
}

// Insert atomic clean op // Insert atomic clean op
NodePtr clean_addr_node = InsertAtomicAddrCleanNode(graph); NodePtr clean_addr_node = InsertAtomicAddrCleanNode(graph);
if (clean_addr_node == nullptr) { if (clean_addr_node == nullptr) {
@@ -249,32 +327,10 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) {
return false; return false;
} }
// 1.Check if isAtomic attrs exist for HCOM // 1.Check if isAtomic attrs exist for HCOM
std::shared_ptr<GELib> instance_ptr = GELib::GetInstance();
if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
GELOGW("GELib not initialized");
return false;
if (CheckAtomicFromOpsKernel(node)) {
return true;
} }


OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
vector<OpInfo> op_info_vec = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
for (const auto &op_info : op_info_vec) {
if (op_info.isAtomic) {
GELOGI("Recognized atomic op %s from DNN_HCCL engine.", op_desc->GetName().c_str());
// check peer input is DATA
for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
if (in_data_anchor->GetPeerOutAnchor() != nullptr &&
in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) {
auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode();
if (peer_in_node->GetType() == DATA) {
GELOGI("Recognized atomic op %s from DNN_HCCL engine and input is DATA.", op_desc->GetName().c_str());
return false;
}
}
}
hcom_node_vec_.push_back(node);
return true;
}
}
// 2.Check atomic attr in node // 2.Check atomic attr in node
std::map<string, std::map<int, int>> node_workspace_offset; std::map<string, std::map<int, int>> node_workspace_offset;
bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX);


+ 5
- 0
ge/graph/passes/atomic_addr_clean_pass.h View File

@@ -84,6 +84,11 @@ class AtomicAddrCleanPass : public GraphPass {
Status HandleDispersedAtomicNodes(ComputeGraphPtr &graph, const std::vector<NodePtr> &atomic_node_vec, Status HandleDispersedAtomicNodes(ComputeGraphPtr &graph, const std::vector<NodePtr> &atomic_node_vec,
std::vector<NodePtr> &common_atomic_nodes); std::vector<NodePtr> &common_atomic_nodes);


bool CheckAtomicFromOpsKernel(const NodePtr &node);

bool IsOutputIndexPeerInputAtomic(const NodePtr &node, int64_t output_index);

bool CheckSkipInsertInLoopGraph(const NodePtr &node);


vector<NodePtr> hcom_node_vec_; vector<NodePtr> hcom_node_vec_;
bool is_loop_graph_ = false; bool is_loop_graph_ = false;


+ 23
- 31
ge/graph/passes/attach_stream_label_pass.cc View File

@@ -24,11 +24,7 @@ Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) {


FindNodes(graph); FindNodes(graph);
for (const auto &node : need_label_nodes_) { for (const auto &node : need_label_nodes_) {
OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) {
GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str());
}
GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str());
} }
GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed.");


@@ -55,13 +51,15 @@ Status AttachStreamLabelPass::ClearStatus() {
/// ///
void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) {
for (const NodePtr &node : graph->GetDirectNode()) { for (const NodePtr &node : graph->GetDirectNode()) {
const std::string &type = node->GetType();
if (type == STREAMSWITCH) {
const auto &op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
continue;
}
const std::string &type = op_desc->GetType();
if ((type == STREAMSWITCH) && op_desc->HasAttr(ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG)) {
stream_switch_nodes_.emplace_back(node); stream_switch_nodes_.emplace_back(node);
} else if (type == STREAMMERGE) {
if ((node->GetOpDesc() != nullptr) && !node->GetOpDesc()->HasAttr(ATTR_NAME_NEXT_ITERATION)) {
need_label_nodes_.emplace_back(node);
}
} else if ((type == STREAMMERGE) && !op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) {
need_label_nodes_.emplace_back(node);
} else if ((type == ENTER) || (type == REFENTER)) { } else if ((type == ENTER) || (type == REFENTER)) {
enter_nodes_.emplace_back(node); enter_nodes_.emplace_back(node);
} }
@@ -83,11 +81,15 @@ void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) {
/// ///
Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) {
std::string stream_label; std::string stream_label;
if (AttachFlag(node, stream_label) != SUCCESS) {
GELOGE(FAILED, "Attach flag for node %s failed.", node->GetName().c_str());
return FAILED;
}

std::unordered_set<NodePtr> branch_nodes; std::unordered_set<NodePtr> branch_nodes;
std::unordered_set<NodePtr> visited; std::unordered_set<NodePtr> visited;
std::stack<NodePtr> nodes; std::stack<NodePtr> nodes;
nodes.push(node); nodes.push(node);

static const std::set<std::string> end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; static const std::set<std::string> end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE};
while (!nodes.empty()) { while (!nodes.empty()) {
NodePtr cur_node = nodes.top(); NodePtr cur_node = nodes.top();
@@ -95,10 +97,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) {
if (visited.count(cur_node) > 0) { if (visited.count(cur_node) > 0) {
continue; continue;
} }
if (AttachFlag(cur_node, stream_label) != SUCCESS) {
GELOGE(FAILED, "Attach flag for node %s failed.", cur_node->GetName().c_str());
return FAILED;
}


const std::string &type = cur_node->GetType(); const std::string &type = cur_node->GetType();
for (const auto &out_node : cur_node->GetOutAllNodes()) { for (const auto &out_node : cur_node->GetOutAllNodes()) {
@@ -115,10 +113,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) {
visited.insert(cur_node); visited.insert(cur_node);
} }


if (node->GetType() == STREAMSWITCH) {
GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed.");
}

for (const NodePtr &tmp_node : branch_nodes) { for (const NodePtr &tmp_node : branch_nodes) {
GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str()); GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str());
GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed."); GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed.");
@@ -148,11 +142,10 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea
GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED,
"StreamSwitch get attr TRUE_BRANCH_STREAM failed."); "StreamSwitch get attr TRUE_BRANCH_STREAM failed.");
stream_label += (value ? "_t" : "_f"); stream_label += (value ? "_t" : "_f");
GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed.");
} else if (type == STREAMMERGE) { } else if (type == STREAMMERGE) {
stream_label = node->GetName(); stream_label = node->GetName();
GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed.");
} else if ((type == EXIT) || (type == REFEXIT)) {
GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed.");
} }


return SUCCESS; return SUCCESS;
@@ -166,12 +159,13 @@ Status AttachStreamLabelPass::UpdateEnterNode() {
std::unordered_map<NodePtr, std::vector<NodePtr>> enter_active_map; std::unordered_map<NodePtr, std::vector<NodePtr>> enter_active_map;
for (const auto &enter_node : enter_nodes_) { for (const auto &enter_node : enter_nodes_) {
for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) {
if (out_ctrl_node->GetType() == STREAMACTIVE) {
if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) {
enter_active_map[out_ctrl_node] = {enter_node};
} else {
enter_active_map[out_ctrl_node].emplace_back(enter_node);
}
if (out_ctrl_node->GetType() != STREAMACTIVE) {
continue;
}
if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) {
enter_active_map[out_ctrl_node] = {enter_node};
} else {
enter_active_map[out_ctrl_node].emplace_back(enter_node);
} }
} }
} }
@@ -226,9 +220,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no
std::string stream_label; std::string stream_label;
GE_CHECK_NOTNULL(active_node); GE_CHECK_NOTNULL(active_node);
(void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); (void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label);

if (stream_label.empty()) { if (stream_label.empty()) {
GELOGW("stream_label of enter_active & enter_nodes is empty.");
GELOGD("stream_label of enter_active %s is empty.", active_node->GetName().c_str());
return SUCCESS; return SUCCESS;
} }


@@ -238,7 +231,6 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed.");
} }
} }
GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "Set stream label failed.");
return SUCCESS; return SUCCESS;
} }




+ 8
- 2
ge/graph/passes/cond_remove_pass.cc View File

@@ -37,6 +37,12 @@ Status CondRemovePass::Run(NodePtr &node) {
OutDataAnchorPtr cond_out_anchor = nullptr; OutDataAnchorPtr cond_out_anchor = nullptr;
InDataAnchorPtr cond_in_anchor = nullptr; InDataAnchorPtr cond_in_anchor = nullptr;
Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor); Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor);
if (ret == NOT_CHANGED) {
return SUCCESS;
} else if (ret != SUCCESS) {
GELOGE(FAILED, "Get cond_info for node %s failed.", node->GetName().c_str());
return FAILED;
}
int32_t cond_index = 0; int32_t cond_index = 0;
GELOGD("Handle cond remove for node %s.", node->GetOpDesc()->GetName().c_str()); GELOGD("Handle cond remove for node %s.", node->GetOpDesc()->GetName().c_str());
bool if_cond_const = CheckIfCondConstInput(cond_out_anchor, cond_in_anchor, cond_index); bool if_cond_const = CheckIfCondConstInput(cond_out_anchor, cond_in_anchor, cond_index);
@@ -322,11 +328,11 @@ Status CondRemovePass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph,
std::string type = node->GetType(); std::string type = node->GetType();
if ((kIfOpTypes.count(type) != 0) || (kCaseOpTypes.count(type) != 0)) { if ((kIfOpTypes.count(type) != 0) || (kCaseOpTypes.count(type) != 0)) {
if (GetCondInfoForIfCase(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { if (GetCondInfoForIfCase(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) {
GELOGE(FAILED, "Get cond_info for if node failed.");
GELOGE(FAILED, "Get cond_info for if/case node failed.");
return FAILED; return FAILED;
} }
} else { } else {
GELOGD("no need cond_pass for node %s.", node->GetName().c_str());
GELOGD("no need cond_remove_pass for node %s.", node->GetName().c_str());
return NOT_CHANGED; return NOT_CHANGED;
} }




+ 0
- 1
ge/graph/passes/ctrl_edge_transfer_pass.cc View File

@@ -38,7 +38,6 @@ namespace ge {
* \ / * \ /
* B * B
*/ */

Status CtrlEdgeTransferPass::Run(ge::ComputeGraphPtr graph) { Status CtrlEdgeTransferPass::Run(ge::ComputeGraphPtr graph) {
GELOGD("CtrlEdgeTransferPass start running"); GELOGD("CtrlEdgeTransferPass start running");
GE_CHECK_NOTNULL(graph); GE_CHECK_NOTNULL(graph);


+ 2
- 1
ge/graph/passes/data_pass.cc View File

@@ -21,6 +21,7 @@


namespace ge { namespace ge {
namespace { namespace {
const int kDataIndexOffset = 2;
Status MappingSubgraphInput(const ComputeGraphPtr &graph, const std::function<int(int data_index)> &input) { Status MappingSubgraphInput(const ComputeGraphPtr &graph, const std::function<int(int data_index)> &input) {
for (const auto &node : graph->GetDirectNode()) { for (const auto &node : graph->GetDirectNode()) {
if (node->GetType() != DATA) { if (node->GetType() != DATA) {
@@ -111,7 +112,7 @@ Status ParseSubgraphPostFnWhile(const string &subgraph_name, const ComputeGraphP


Status ParseSubgraphPostFnFor(const string &subgraph_name, const ComputeGraphPtr &graph) { Status ParseSubgraphPostFnFor(const string &subgraph_name, const ComputeGraphPtr &graph) {
return MappingSubgraphIndex(graph, return MappingSubgraphIndex(graph,
[](int data_index) { return (data_index == 0) ? 0 : data_index + 2; },
[](int data_index) { return (data_index == 0) ? 0 : data_index + kDataIndexOffset; },
[](int retval_index) { return retval_index; }); [](int retval_index) { return retval_index; });
} }




+ 9
- 16
ge/graph/passes/enter_pass.cc View File

@@ -16,6 +16,7 @@


#include "graph/passes/enter_pass.h" #include "graph/passes/enter_pass.h"


#include "graph/debug/ge_attr_define.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h" #include "framework/common/debug/log.h"
#include "graph/utils/graph_utils.h" #include "graph/utils/graph_utils.h"
@@ -72,33 +73,25 @@ Status EnterPass::Run(NodePtr &node) {
} }


Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) {
auto out_nodes_of_in_node = in_node->GetOutAllNodes();
if (out_nodes_of_in_node.size() != kOutNodesNum) {
if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) {
return SUCCESS; return SUCCESS;
} }

if (!node->GetOutControlNodes().empty()) {
bool is_constant_flag = true;
(void)AttrUtils::GetBool(node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant_flag);
if (!is_constant_flag) {
return SUCCESS; return SUCCESS;
} }


for (const auto &out_node : node->GetOutDataNodes()) {
GE_CHECK_NOTNULL(out_node);
if (out_node->GetType() == MERGE) {
return SUCCESS;
}
}

GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0));
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0)));
auto out_data_anchor = node->GetOutDataAnchor(0);
const auto &out_data_anchor = node->GetOutDataAnchor(0);
GE_CHECK_NOTNULL(out_data_anchor); GE_CHECK_NOTNULL(out_data_anchor);
for (auto peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor));
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor));
} }

auto graph = node->GetOwnerComputeGraph();
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, node))
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node));
AddNodeDeleted(node);
AddRePassNodesWithInOut(in_node); AddRePassNodesWithInOut(in_node);


return SUCCESS; return SUCCESS;


+ 9
- 21
ge/graph/passes/for_pass.cc View File

@@ -37,6 +37,7 @@ namespace {
const uint32_t kSubgraphLoopVarInputIndex = 0; const uint32_t kSubgraphLoopVarInputIndex = 0;
const uint32_t kSubgraphInputIndex = 1; const uint32_t kSubgraphInputIndex = 1;
const uint32_t kWhileOutputIndex = 5; const uint32_t kWhileOutputIndex = 5;
const size_t kIDiffValue = 2;
const std::string kAbs = "Abs"; const std::string kAbs = "Abs";
} }


@@ -137,7 +138,7 @@ Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &n
for_info.ctrl_inputs = std::move(ctrl_inputs); for_info.ctrl_inputs = std::move(ctrl_inputs);
for_info.ctrl_outputs = std::move(ctrl_outputs); for_info.ctrl_outputs = std::move(ctrl_outputs);


GELOGI("Build for_info for node %s succ.", node->GetName().c_str());
GELOGI("Build for_info for node %s success.", node->GetName().c_str());
return SUCCESS; return SUCCESS;
} }


@@ -159,13 +160,7 @@ OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index
return nullptr; return nullptr;
} }


OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) {
GELOGE(FAILED, "FindInputWithIndex %s:%u failed: peer_out_anchor is NULL.", node->GetName().c_str(), index);
return nullptr;
}

return peer_out_anchor;
return in_data_anchor->GetPeerOutAnchor();
} }


/// ///
@@ -186,20 +181,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnc
uint32_t input_data_num = node->GetAllInDataAnchorsSize(); uint32_t input_data_num = node->GetAllInDataAnchorsSize();
for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) { for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) {
InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index); InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
if (in_data_anchor == nullptr) {
GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index);
return FAILED;
}
GE_IF_BOOL_EXEC(in_data_anchor->GetPeerOutAnchor() == nullptr,
GELOGW("Get null input by index %d from node %s ",
in_data_anchor->GetIdx(), node->GetName().c_str());
continue);
GE_CHECK_NOTNULL(in_data_anchor);
data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor()); data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor());
} }


for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
std::vector<ge::InDataAnchorPtr> peer_in_data_anchors; std::vector<ge::InDataAnchorPtr> peer_in_data_anchors;
for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
peer_in_data_anchors.emplace_back(peer_in_data_anchor); peer_in_data_anchors.emplace_back(peer_in_data_anchor);
} }
data_outputs.emplace_back(peer_in_data_anchors); data_outputs.emplace_back(peer_in_data_anchors);
@@ -207,13 +195,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnc


InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor(); InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor();
GE_CHECK_NOTNULL(in_ctrl_anchor); GE_CHECK_NOTNULL(in_ctrl_anchor);
for (auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
ctrl_inputs.emplace_back(peer_out_ctrl_anchor); ctrl_inputs.emplace_back(peer_out_ctrl_anchor);
} }


OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor(); OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor();
GE_CHECK_NOTNULL(out_ctrl_anchor); GE_CHECK_NOTNULL(out_ctrl_anchor);
for (auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
ctrl_outputs.emplace_back(peer_in_ctrl_anchor); ctrl_outputs.emplace_back(peer_in_ctrl_anchor);
} }


@@ -707,7 +695,7 @@ Status ForPass::UpdateForBodyInputMapping(const WhileInfo &while_info) {
} else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) { } else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) {
continue; continue;
} else { } else {
input_mapping[i] = i - 2;
input_mapping[i] = i - kIDiffValue;
} }
} }
for_body->UpdateInputMapping(input_mapping); for_body->UpdateInputMapping(input_mapping);


+ 3
- 1
ge/graph/passes/mark_agnostic_pass.cc View File

@@ -19,6 +19,8 @@
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"


namespace ge { namespace ge {
const size_t kTwoInputNodesSize = 2;

Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { Status MarkAgnosticPass::Run(ComputeGraphPtr graph) {
for (const auto &node : graph->GetDirectNode()) { for (const auto &node : graph->GetDirectNode()) {
auto node_type = NodeUtils::GetNodeType(*node); auto node_type = NodeUtils::GetNodeType(*node);
@@ -52,7 +54,7 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) {
/// Enter-----------+ /// Enter-----------+
/// +-> Merge /// +-> Merge
/// NextIteration---+ /// NextIteration---+
if (input_nodes.size() == 2) {
if (input_nodes.size() == kTwoInputNodesSize) {
if (input_nodes.at(0)->GetType() == ENTER && input_nodes.at(1)->GetType() == NEXTITERATION) { if (input_nodes.at(0)->GetType() == ENTER && input_nodes.at(1)->GetType() == NEXTITERATION) {
continue; continue;
} }


+ 6
- 9
ge/graph/passes/merge_pass.cc View File

@@ -21,18 +21,16 @@
#include <vector> #include <vector>


#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "common/ge_inner_error_codes.h"
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "graph/common/omg_util.h" #include "graph/common/omg_util.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h" #include "graph/utils/graph_utils.h"
#include "graph/passes/pass_utils.h" #include "graph/passes/pass_utils.h"


using domi::PARAM_INVALID;
using domi::SUCCESS;

namespace ge { namespace ge {
const int kValueIndexOutputIndex = 1; const int kValueIndexOutputIndex = 1;
const size_t kCaseNoInput = 0;
const size_t kCaseOneInput = 1;


Status MergePass::Run(NodePtr &node) { Status MergePass::Run(NodePtr &node) {
GELOGD("MergePass running"); GELOGD("MergePass running");
@@ -47,15 +45,14 @@ Status MergePass::Run(NodePtr &node) {
return SUCCESS; return SUCCESS;
} }


auto out_data_anchors = node->GetAllOutDataAnchors();
if (out_data_anchors.empty()) {
if (node->GetAllOutDataAnchors().empty()) {
GELOGE(PARAM_INVALID, "[%s] Merge node output anchor is empty", node->GetName().c_str()); GELOGE(PARAM_INVALID, "[%s] Merge node output anchor is empty", node->GetName().c_str());
return PARAM_INVALID; return PARAM_INVALID;
} }


auto in_data_nodes = node->GetInDataNodes();
const auto &in_data_nodes = node->GetInDataNodes();
switch (in_data_nodes.size()) { switch (in_data_nodes.size()) {
case 0: {
case kCaseNoInput: {
/// Case A: input_count = 0, the output of merge node is inactive as well /// Case A: input_count = 0, the output of merge node is inactive as well
/// In which case the output branch can be removed /// In which case the output branch can be removed
/// until another merge node is met /// until another merge node is met
@@ -70,7 +67,7 @@ Status MergePass::Run(NodePtr &node) {
} }
return ret; return ret;
} }
case 1: { // Case B: input_count = 1, the merge node can be optimized out
case kCaseOneInput: { // Case B: input_count = 1, the merge node can be optimized out
std::vector<int> merge_io_map = {PassUtils::GetUniqueInDataAnchorIndex(node), -1}; std::vector<int> merge_io_map = {PassUtils::GetUniqueInDataAnchorIndex(node), -1};
if (merge_io_map[0] != -1 && IsNeedChangeIndexToConstant(node)) { if (merge_io_map[0] != -1 && IsNeedChangeIndexToConstant(node)) {
int index = merge_io_map[0]; int index = merge_io_map[0];


+ 32
- 56
ge/graph/passes/multi_batch_pass.cc View File

@@ -22,9 +22,6 @@
#include "graph/common/omg_util.h" #include "graph/common/omg_util.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"


using std::string;
using std::vector;

namespace ge { namespace ge {
Status MultiBatchPass::Run(ComputeGraphPtr graph) { Status MultiBatchPass::Run(ComputeGraphPtr graph) {
GELOGD("MultiBatchPass Enter"); GELOGD("MultiBatchPass Enter");
@@ -53,7 +50,7 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) {
return FAILED; return FAILED;
} }
std::vector<std::vector<int64_t>> batch_shape; std::vector<std::vector<int64_t>> batch_shape;
vector<vector<int64_t>> combined_batch;
std::vector<std::vector<int64_t>> combined_batch;
if (!CheckSwitchN(batch_shape, combined_batch)) { if (!CheckSwitchN(batch_shape, combined_batch)) {
GELOGE(FAILED, "CheckSwitchN failed."); GELOGE(FAILED, "CheckSwitchN failed.");
return FAILED; return FAILED;
@@ -104,6 +101,7 @@ Status MultiBatchPass::ClearStatus() {
/// ///
Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) { Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) {
const auto &func_desc = case_node->GetOpDesc(); const auto &func_desc = case_node->GetOpDesc();
GE_CHECK_NOTNULL(func_desc);
if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) { if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) {
GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str()); GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str());
return SUCCESS; return SUCCESS;
@@ -114,7 +112,7 @@ Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr
const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]); const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]);
GE_CHECK_NOTNULL(subgraph); GE_CHECK_NOTNULL(subgraph);


const string batch_label = "Batch_" + std::to_string(i);
const std::string batch_label = "Batch_" + std::to_string(i);
for (const auto &node : subgraph->GetDirectNode()) { for (const auto &node : subgraph->GetDirectNode()) {
(void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); (void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
} }
@@ -139,12 +137,12 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor
continue; continue;
} }


InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
const auto &in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
if (in_data_anchor == nullptr) { if (in_data_anchor == nullptr) {
GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str()); GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str());
return FAILED; return FAILED;
} }
OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor();
const auto &pred_input = in_data_anchor->GetPeerOutAnchor();
if (pred_input == nullptr) { if (pred_input == nullptr) {
GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str()); GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str());
return FAILED; return FAILED;
@@ -178,12 +176,10 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor
/// @return Status /// @return Status
/// ///
Status MultiBatchPass::GetDynamicType() { Status MultiBatchPass::GetDynamicType() {
for (const auto &switchn : switch_n_nodes_) {
auto switchn_desc = switchn->GetOpDesc();
GE_CHECK_NOTNULL(switchn_desc);
for (const auto &switch_n : switch_n_nodes_) {
int32_t dynamic_type = static_cast<int32_t>(FIXED); int32_t dynamic_type = static_cast<int32_t>(FIXED);
if (!AttrUtils::GetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) {
GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switchn->GetName().c_str());
if (!AttrUtils::GetInt(switch_n->GetOpDesc(), ATTR_DYNAMIC_TYPE, dynamic_type)) {
GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switch_n->GetName().c_str());
return FAILED; return FAILED;
} }
if (dynamic_type == static_cast<int32_t>(FIXED)) { if (dynamic_type == static_cast<int32_t>(FIXED)) {
@@ -191,7 +187,7 @@ Status MultiBatchPass::GetDynamicType() {
return FAILED; return FAILED;
} }
if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) { if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) {
GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switchn node should be same, while one is %d and another is %d.",
GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switch_n node should be same, while one is %d and another is %d.",
dynamic_type, dynamic_type_); dynamic_type, dynamic_type_);
return FAILED; return FAILED;
} }
@@ -212,21 +208,19 @@ Status MultiBatchPass::GetDynamicType() {
Status MultiBatchPass::GetUserDesignateShape() { Status MultiBatchPass::GetUserDesignateShape() {
data_name_order_.clear(); data_name_order_.clear();
bool first_check = true; bool first_check = true;
for (const auto &switchn : switch_n_nodes_) {
auto switchn_desc = switchn->GetOpDesc();
GE_CHECK_NOTNULL(switchn_desc);
vector<string> cur_switchn_data_name_order;
if (!AttrUtils::GetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_switchn_data_name_order)) {
GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switchn->GetName().c_str());
for (const auto &switch_n : switch_n_nodes_) {
std::vector<std::string> cur_data_name_order;
if (!AttrUtils::GetListStr(switch_n->GetOpDesc(), ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_data_name_order)) {
GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switch_n->GetName().c_str());
return FAILED; return FAILED;
} }
if (first_check) { if (first_check) {
data_name_order_ = cur_switchn_data_name_order;
data_name_order_ = cur_data_name_order;
first_check = false; first_check = false;
} else { } else {
if (data_name_order_ != cur_switchn_data_name_order) {
if (data_name_order_ != cur_data_name_order) {
GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.", GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.",
switchn->GetName().c_str());
switch_n->GetName().c_str());
return FAILED; return FAILED;
} }
} }
@@ -245,7 +239,8 @@ Status MultiBatchPass::GetUserDesignateShape() {
/// @param [out] combined_batch /// @param [out] combined_batch
/// @return bool /// @return bool
/// ///
bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<vector<int64_t>> &combined_batch) {
bool MultiBatchPass::CheckSwitchN(std::vector<std::vector<int64_t>> &batch_shape,
std::vector<std::vector<int64_t>> &combined_batch) {
// Check if output_num of different SwitchN is same // Check if output_num of different SwitchN is same
uint32_t batch_num = 0; uint32_t batch_num = 0;
for (const NodePtr &node : switch_n_nodes_) { for (const NodePtr &node : switch_n_nodes_) {
@@ -281,7 +276,8 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v
} }
size_t tmp_combined_dim_num = combined_batch[i].size(); size_t tmp_combined_dim_num = combined_batch[i].size();
if (combined_dim_num != tmp_combined_dim_num) { if (combined_dim_num != tmp_combined_dim_num) {
GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num);
GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.",
combined_dim_num, i, tmp_combined_dim_num);
return false; return false;
} }
} }
@@ -296,11 +292,11 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v
/// @param [out] combined_batch /// @param [out] combined_batch
/// @return bool /// @return bool
/// ///
bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &batch_shape,
vector<vector<int64_t>> &combined_batch) {
bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, std::vector<std::vector<int64_t>> &batch_shape,
std::vector<std::vector<int64_t>> &combined_batch) {
// Check if output_shape of different SwitchN is same // Check if output_shape of different SwitchN is same
vector<vector<int64_t>> idx_batch_shape;
vector<vector<int64_t>> idx_combined_batch;
std::vector<std::vector<int64_t>> idx_batch_shape;
std::vector<std::vector<int64_t>> idx_combined_batch;
for (uint32_t i = 0; i < batch_num; i++) { for (uint32_t i = 0; i < batch_num; i++) {
idx_batch_shape.clear(); idx_batch_shape.clear();
idx_combined_batch.clear(); idx_combined_batch.clear();
@@ -310,7 +306,7 @@ bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &b
GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str()); GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str());
return false; return false;
} }
vector<int64_t> output_dims;
std::vector<int64_t> output_dims;
if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) { if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) {
GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i); GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i);
return false; return false;
@@ -385,8 +381,8 @@ Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) {
/// @return Status /// @return Status
/// ///
Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value,
const vector<vector<int64_t>> &batch_shape,
const vector<vector<int64_t>> &combined_batch) {
const std::vector<std::vector<int64_t>> &batch_shape,
const std::vector<std::vector<int64_t>> &combined_batch) {
NodePtr pred_value_node = pred_value->GetOwnerNode(); NodePtr pred_value_node = pred_value->GetOwnerNode();
// Create SwitchCase node // Create SwitchCase node
const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN; const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN;
@@ -429,31 +425,11 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s
return false; return false;
} }


size_t num = output_shape.size();
size_t dim_num = output_shape[0].size();
for (size_t i = 1; i < num; i++) {
size_t tmp_dim_num = output_shape[i].size();
if (dim_num != tmp_dim_num) {
GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num);
for (auto iter = output_shape.begin() + 1; iter != output_shape.end(); ++iter) {
if (output_shape[0] != *iter) {
return false; return false;
} }
} }

if (dim_num == 0) {
return true;
}

for (size_t i = 0; i < dim_num; i++) {
int64_t dim_value = output_shape[0][i];
for (size_t j = 1; j < num; j++) {
int64_t tmp_dim_value = output_shape[j][i];
if (dim_value != tmp_dim_value) {
GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i,
dim_value, j, tmp_dim_value);
return false;
}
}
}
return true; return true;
} }


@@ -468,8 +444,8 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s
/// ///
NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name,
const OutDataAnchorPtr &pred_value, const OutDataAnchorPtr &pred_value,
const vector<vector<int64_t>> &batch_shape,
const vector<vector<int64_t>> &combined_batch) {
const std::vector<std::vector<int64_t>> &batch_shape,
const std::vector<std::vector<int64_t>> &combined_batch) {
OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN); OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN);
if (op_desc == nullptr) { if (op_desc == nullptr) {
GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str()); GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str());
@@ -512,7 +488,7 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const
GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str()); GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str());
return nullptr; return nullptr;
} }
const string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i);
const std::string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i);
if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) { if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) {
GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str()); GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str());
return nullptr; return nullptr;


+ 0
- 4
ge/graph/passes/pass_utils.cc View File

@@ -37,10 +37,6 @@
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"


namespace ge { namespace ge {
namespace {
const uint32_t kShapeDimSize = 1;
const uint32_t DIM_SIZE_TWO = 2;
} // namespace


Status PassUtils::ConstructTensorDescWithData(const GeTensorDesc &out_desc, std::vector<int64_t> &data, Status PassUtils::ConstructTensorDescWithData(const GeTensorDesc &out_desc, std::vector<int64_t> &data,
std::vector<GeTensorPtr> &v_output, const bool scalar_output) { std::vector<GeTensorPtr> &v_output, const bool scalar_output) {


+ 4
- 4
ge/graph/passes/subgraph_pass.cc View File

@@ -149,10 +149,10 @@ Status SubgraphPass::SubgraphOutputNode(const ComputeGraphPtr &graph, const Node
// 5. While->NetOutput in known subgraph // 5. While->NetOutput in known subgraph
std::string op_type; std::string op_type;
bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) || bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) ||
IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) ||
((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)) ||
(!graph->GetGraphUnknownFlag() && NodeUtils::IsDynamicShape(node) &&
(kWhileOpTypes.count(in_node->GetType()) != 0));
IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) ||
((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)) ||
(!graph->GetGraphUnknownFlag() && NodeUtils::IsDynamicShape(node) &&
(kWhileOpTypes.count(in_node->GetType()) != 0));
if (insert_flag) { if (insert_flag) {
GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str());
std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy";


+ 37
- 39
ge/graph/passes/switch_to_stream_switch_pass.cc View File

@@ -72,25 +72,26 @@ Status SwitchToStreamSwitchPass::CheckCycleDependence(const ComputeGraphPtr &gra
std::unordered_map<NodePtr, std::vector<NodePtr>> cond_switch_map; std::unordered_map<NodePtr, std::vector<NodePtr>> cond_switch_map;
for (const NodePtr &node : graph->GetDirectNode()) { for (const NodePtr &node : graph->GetDirectNode()) {
GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed."); GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed.");
if ((type == SWITCH) || (type == REFSWITCH)) {
InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
GE_CHECK_NOTNULL(in_cond_anchor);
OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_anchor);
if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) {
GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str());
return FAILED;
}
if ((type != SWITCH) && (type != REFSWITCH)) {
continue;
}
InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
GE_CHECK_NOTNULL(in_cond_anchor);
OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_anchor);
if (FindSwitchCondInput(peer_out_anchor) != SUCCESS) {
GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str());
return FAILED;
}


NodePtr cond_node = peer_out_anchor->GetOwnerNode();
auto iter = cond_switch_map.find(cond_node);
if (iter == cond_switch_map.end()) {
cond_switch_map[cond_node] = { node };
} else {
iter->second.emplace_back(node);
}
switch_nodes_.emplace_back(node);
NodePtr cond_node = peer_out_anchor->GetOwnerNode();
auto iter = cond_switch_map.find(cond_node);
if (iter == cond_switch_map.end()) {
cond_switch_map[cond_node] = { node };
} else {
iter->second.emplace_back(node);
} }
switch_nodes_.emplace_back(node);
} }


MarkCycleDependence(cond_switch_map); MarkCycleDependence(cond_switch_map);
@@ -241,10 +242,6 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou
if (idx == SWITCH_DATA_INPUT) { if (idx == SWITCH_DATA_INPUT) {
peer_data_anchor = peer_out_anchor; peer_data_anchor = peer_out_anchor;
} else { } else {
if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) {
GELOGE(FAILED, "Find pred_input for switch_node %s failed.", switch_node->GetName().c_str());
return FAILED;
}
peer_cond_anchor = peer_out_anchor; peer_cond_anchor = peer_out_anchor;
} }
} }
@@ -254,15 +251,14 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou


/// ///
/// @brief Find Switch cond input /// @brief Find Switch cond input
/// @param [in] pass_switch_flag
/// @param [out] peer_cond_anchor /// @param [out] peer_cond_anchor
/// @return Status /// @return Status
/// ///
Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) {
Status SwitchToStreamSwitchPass::FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor) {
NodePtr tmp_node = nullptr; NodePtr tmp_node = nullptr;
string type;
bool need_pass_type = true;
while (need_pass_type) {
std::string type;
bool pass_flag = true;
while (pass_flag) {
if (tmp_node == nullptr) { if (tmp_node == nullptr) {
tmp_node = peer_cond_anchor->GetOwnerNode(); tmp_node = peer_cond_anchor->GetOwnerNode();
} else { } else {
@@ -274,7 +270,7 @@ Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutD
} }


GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed."); GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed.");
need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH)));
pass_flag = ((type == SWITCH) || (type == REFSWITCH));
} }


return SUCCESS; return SUCCESS;
@@ -369,7 +365,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_
} }
} else { } else {
int64_t switch_group_id = GetGroupId(stream_switch); int64_t switch_group_id = GetGroupId(stream_switch);
map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map;
std::map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map;
std::list<NodePtr> false_node_list; std::list<NodePtr> false_node_list;
std::list<NodePtr> true_node_list; std::list<NodePtr> true_node_list;
std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list; std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list;
@@ -389,7 +385,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_
/// @return group_id /// @return group_id
/// ///
int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) {
string tailing_optimization_option;
std::string tailing_optimization_option;
bool is_tailing_optimization = false; bool is_tailing_optimization = false;
if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) { if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) {
// "1" means it's True from frontend option // "1" means it's True from frontend option
@@ -400,7 +396,7 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) {
return 0; return 0;
} }


string hccl_group_id;
std::string hccl_group_id;
if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) {
GELOGI("Node %s can not find hccl group id.", node->GetName().c_str()); GELOGI("Node %s can not find hccl group id.", node->GetName().c_str());
return 0; return 0;
@@ -432,6 +428,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end());


OutDataAnchorPtr peer_cond_anchor = iter->first; OutDataAnchorPtr peer_cond_anchor = iter->first;
GE_CHECK_NOTNULL(peer_cond_anchor);
NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); NodePtr cond_node = peer_cond_anchor->GetOwnerNode();
GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str()); GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str());


@@ -549,6 +546,7 @@ NodePtr SwitchToStreamSwitchPass::CreateCastOp(const ComputeGraphPtr &graph, con


NodePtr cast_node = graph->AddNode(cast_desc); NodePtr cast_node = graph->AddNode(cast_desc);
GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed."); GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed.");
// Cast node has one and only one input
GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed."); GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed.");


return cast_node; return cast_node;
@@ -614,24 +612,24 @@ Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_no
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }


for (const NodePtr &in_ctl_node : switch_node->GetInControlNodes()) {
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()),
for (const NodePtr &in_ctrl_node : switch_node->GetInControlNodes()) {
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()),
"Remove ctl edge failed."); "Remove ctl edge failed.");
GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), {
GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
GE_IF_BOOL_EXEC(!in_ctrl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), {
GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
"Add ctl edge failed."); "Add ctl edge failed.");
}); });


GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue);
if (same_cond_switch.count(in_ctl_node) > 0) {
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
GE_IF_BOOL_EXEC(in_ctrl_node->GetType() != STREAMSWITCH, continue);
if (same_cond_switch.count(in_ctrl_node) > 0) {
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
"Remove ctl edge failed."); "Remove ctl edge failed.");
continue; continue;
} }


auto find_res1 = switch_node_map_.find(in_ctl_node);
auto find_res1 = switch_node_map_.find(in_ctrl_node);
GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), {
GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str());
GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctrl_node->GetName().c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
}); });
auto find_res2 = find_res1->second.find(orig_switch_name); auto find_res2 = find_res1->second.find(orig_switch_name);


+ 3
- 4
ge/graph/passes/switch_to_stream_switch_pass.h View File

@@ -42,9 +42,9 @@ namespace ge {
+-----------+ +-----------+ +-----------+ +-----------+
| Const | | VariableV2| | Const | | VariableV2|
+-----------+ +-----------+ +-----------+ +-----------+
*/


/* Switch branch op optimize, Switches in same case merge to one StreamSwitch, update following nodes' input

Switch branch op optimize, Switches in same case merge to one StreamSwitch, update following nodes' input


+-----------+ +-----------+
/ | task2 | \ / | task2 | \
@@ -131,11 +131,10 @@ class SwitchToStreamSwitchPass : public GraphPass {


/// ///
/// @brief Find Switch cond input /// @brief Find Switch cond input
/// @param [in] pass_switch_flag
/// @param [out] peer_cond_anchor /// @param [out] peer_cond_anchor
/// @return Status /// @return Status
/// ///
Status FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor);
Status FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor);


/// ///
/// @brief Create StreamSwitch Node /// @brief Create StreamSwitch Node


+ 3
- 1
ge/graph/passes/transop_breadth_fusion_pass.cc View File

@@ -70,8 +70,10 @@ std::string TransOpBreadthFusionPass::GetNodeId(const int anchor_index, const No
trans_data_type = true; trans_data_type = true;
trans_format = true; trans_format = true;
trans_shape = true; trans_shape = true;
} else if (node->GetType() == RESHAPE) {
} else if (node->GetType() == RESHAPE || node->GetType() == EXPANDDIMS || node->GetType() == SQUEEZE) {
trans_shape = true; trans_shape = true;
} else if (node->GetType() == REFORMAT) {
trans_format = true;
} }


id << node->GetType() << '-' << anchor_index; id << node->GetType() << '-' << anchor_index;


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save