From: @ljl0711 Reviewed-by: @youui,@liujunzhu Signed-off-by: @liujunzhu tags/v1.1.0
| @@ -8,6 +8,19 @@ if (NOT BUILD_PATH) | |||||
| set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build") | set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build") | ||||
| endif() | endif() | ||||
| if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | |||||
| set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) | |||||
| else() | |||||
| set(ASCEND_DIR /usr/local/Ascend) | |||||
| endif() | |||||
| set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) | |||||
| set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) | |||||
| set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share) | |||||
| set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) | |||||
| set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) | |||||
| set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) | |||||
| set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) | |||||
| option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) | option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) | ||||
| if (ENABLE_OPEN_SRC) | if (ENABLE_OPEN_SRC) | ||||
| @@ -41,7 +54,7 @@ if (ENABLE_OPEN_SRC) | |||||
| message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") | message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") | ||||
| endif() | endif() | ||||
| set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) | set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) | ||||
| set(STATIC_ACL_LIB ${GE_LIB_PATH}) | |||||
| set(STATIC_ACL_LIB ${GE_LIB_PATH}) | |||||
| find_module(slog libslog.so ${GE_LIB_PATH}) | find_module(slog libslog.so ${GE_LIB_PATH}) | ||||
| find_module(mmpa libmmpa.so ${GE_LIB_PATH}) | find_module(mmpa libmmpa.so ${GE_LIB_PATH}) | ||||
| find_module(msprof libmsprof.so ${GE_LIB_PATH}) | find_module(msprof libmsprof.so ${GE_LIB_PATH}) | ||||
| @@ -56,18 +69,6 @@ if (ENABLE_OPEN_SRC) | |||||
| find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH}) | find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH}) | ||||
| #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | ||||
| else() | else() | ||||
| if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | |||||
| set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) | |||||
| else() | |||||
| set(ASCEND_DIR /usr/local/Ascend) | |||||
| endif() | |||||
| set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) | |||||
| set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) | |||||
| set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share) | |||||
| set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) | |||||
| set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) | |||||
| set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) | |||||
| set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) | |||||
| find_module(slog libslog.so ${ASCEND_ATC_DIR}) | find_module(slog libslog.so ${ASCEND_ATC_DIR}) | ||||
| find_module(mmpa libmmpa.so ${ASCEND_ATC_DIR}) | find_module(mmpa libmmpa.so ${ASCEND_ATC_DIR}) | ||||
| if(PLATFORM STREQUAL "train") | if(PLATFORM STREQUAL "train") | ||||
| @@ -127,6 +128,36 @@ if (ENABLE_OPEN_SRC) | |||||
| add_subdirectory(parser) | add_subdirectory(parser) | ||||
| #add_subdirectory(metadef/graph) | #add_subdirectory(metadef/graph) | ||||
| #add_subdirectory(metadef/register) | #add_subdirectory(metadef/register) | ||||
| elseif (ENABLE_D OR ENABLE_ACL) | |||||
| # compiling with MindSpore | |||||
| include(cmake/external_libs/protobuf_static.cmake) | |||||
| include(cmake/external_libs/protoc.cmake) | |||||
| include(cmake/external_libs/securec.cmake) | |||||
| include(cmake/external_libs/json.cmake) | |||||
| include(cmake/FindModule.cmake) | |||||
| include(cmake/intf_pub_linux.cmake) | |||||
| # common modules | |||||
| find_module(slog libslog.so ${ASCEND_DRIVER_COMMON_DIR}) | |||||
| find_module(mmpa libmmpa.so ${ASCEND_DRIVER_COMMON_DIR}) | |||||
| if (ENABLE_D) | |||||
| # training | |||||
| find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) | |||||
| find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) | |||||
| find_module(register libregister.so ${ASCEND_RUNTIME_DIR}) | |||||
| find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) | |||||
| elseif(ENABLE_ACL) | |||||
| # inference | |||||
| find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) | |||||
| find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) | |||||
| find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) | |||||
| find_module(resource libresource.so ${ASCEND_ATC_DIR}) | |||||
| find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | |||||
| endif () | |||||
| set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | |||||
| add_subdirectory(metadef) | |||||
| else() | else() | ||||
| set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef) | set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef) | ||||
| set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser) | set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser) | ||||
| @@ -48,5 +48,8 @@ set_target_properties(ascend_protobuf_static_lib PROPERTIES | |||||
| add_library(ascend_protobuf_static INTERFACE) | add_library(ascend_protobuf_static INTERFACE) | ||||
| target_include_directories(ascend_protobuf_static INTERFACE ${PROTOBUF_STATIC_PKG_DIR}/include) | target_include_directories(ascend_protobuf_static INTERFACE ${PROTOBUF_STATIC_PKG_DIR}/include) | ||||
| target_link_libraries(ascend_protobuf_static INTERFACE ascend_protobuf_static_lib) | target_link_libraries(ascend_protobuf_static INTERFACE ascend_protobuf_static_lib) | ||||
| if (ENABLE_D OR ENABLE_ACL) | |||||
| include_directories(${PROTOBUF_STATIC_PKG_DIR}/include) | |||||
| endif () | |||||
| add_dependencies(ascend_protobuf_static protobuf_static_build) | add_dependencies(ascend_protobuf_static protobuf_static_build) | ||||
| @@ -1,10 +1,15 @@ | |||||
| add_subdirectory(common) | |||||
| add_subdirectory(plugin/engine) | |||||
| add_subdirectory(graph/build/memory) | |||||
| add_subdirectory(ge_local_engine) | |||||
| add_subdirectory(host_cpu_engine) | |||||
| add_subdirectory(executor) | |||||
| add_subdirectory(offline) | |||||
| if (NOT ENABLE_D AND NOT ENABLE_ACL) | |||||
| add_subdirectory(common) | |||||
| add_subdirectory(plugin/engine) | |||||
| add_subdirectory(graph/build/memory) | |||||
| add_subdirectory(ge_local_engine) | |||||
| add_subdirectory(host_cpu_engine) | |||||
| add_subdirectory(executor) | |||||
| add_subdirectory(offline) | |||||
| else() | |||||
| add_subdirectory(common) | |||||
| add_subdirectory(ge_runtime) | |||||
| endif () | |||||
| set(PROTO_LIST | set(PROTO_LIST | ||||
| "${METADEF_DIR}/proto/fusion_model.proto" | "${METADEF_DIR}/proto/fusion_model.proto" | ||||
| @@ -28,7 +33,6 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
| protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) | protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) | ||||
| protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) | protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) | ||||
| ############ libge_runner.so ############ | |||||
| set(TRAIN_SRC_LIST | set(TRAIN_SRC_LIST | ||||
| "common/formats/format_transfers/datatype_transfer.cc" | "common/formats/format_transfers/datatype_transfer.cc" | ||||
| "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" | "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" | ||||
| @@ -333,72 +337,6 @@ set(TRAIN_SRC_LIST | |||||
| "ir_build/atc_ir_common.cc" | "ir_build/atc_ir_common.cc" | ||||
| ) | ) | ||||
| add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) | |||||
| target_compile_definitions(ge_runner PRIVATE | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| DAVINCI_SUPPORT_PROFILING | |||||
| REUSE_MEMORY=1 | |||||
| FMK_SUPPORT_DUMP | |||||
| DAVINCI_CLOUD | |||||
| google=ascend_private | |||||
| ) | |||||
| target_compile_options(ge_runner PRIVATE | |||||
| -O2 | |||||
| ) | |||||
| target_include_directories(ge_runner PRIVATE | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/ge/analyzer | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${GE_CODE_DIR}/inc/framework/common | |||||
| ${METADEF_DIR} | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/graph | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| #### yellow zone #### | |||||
| ${GE_CODE_DIR}/../inc | |||||
| ${GE_CODE_DIR}/../inc/external | |||||
| ${GE_CODE_DIR}/../inc/cce | |||||
| ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | |||||
| #### blue zone | |||||
| ${ASCEND_DIR}/driver/include | |||||
| ${ASCEND_DIR}/fwkacllib/include | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||||
| ) | |||||
| target_link_libraries(ge_runner | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| ge_memory | |||||
| adump_server | |||||
| msprofiler | |||||
| -Wl,--no-as-needed | |||||
| graph | |||||
| ge_common | |||||
| ascend_protobuf | |||||
| register | |||||
| c_sec | |||||
| slog | |||||
| mmpa | |||||
| msprof | |||||
| runtime | |||||
| resource | |||||
| error_manager | |||||
| ascend_hal_stub | |||||
| -Wl,--as-needed | |||||
| json | |||||
| -lrt | |||||
| -ldl | |||||
| ) | |||||
| ############ libge_compiler.so ############ | |||||
| set(INFER_SRC_LIST | set(INFER_SRC_LIST | ||||
| "graph/manager/trans_var_data_utils.cc" | "graph/manager/trans_var_data_utils.cc" | ||||
| "omm/csa_interact.cc" | "omm/csa_interact.cc" | ||||
| @@ -662,6 +600,74 @@ set(INFER_SRC_LIST | |||||
| "analyzer/analyzer.cc" | "analyzer/analyzer.cc" | ||||
| ) | ) | ||||
| if (NOT ENABLE_D AND NOT ENABLE_ACL) | |||||
| ############ libge_runner.so ############ | |||||
| add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) | |||||
| target_compile_definitions(ge_runner PRIVATE | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| DAVINCI_SUPPORT_PROFILING | |||||
| REUSE_MEMORY=1 | |||||
| FMK_SUPPORT_DUMP | |||||
| DAVINCI_CLOUD | |||||
| google=ascend_private | |||||
| ) | |||||
| target_compile_options(ge_runner PRIVATE | |||||
| -O2 | |||||
| ) | |||||
| target_include_directories(ge_runner PRIVATE | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/ge/analyzer | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${GE_CODE_DIR}/inc/framework/common | |||||
| ${METADEF_DIR} | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/graph | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| #### yellow zone #### | |||||
| ${GE_CODE_DIR}/../inc | |||||
| ${GE_CODE_DIR}/../inc/external | |||||
| ${GE_CODE_DIR}/../inc/cce | |||||
| ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | |||||
| #### blue zone | |||||
| ${ASCEND_DIR}/driver/include | |||||
| ${ASCEND_DIR}/fwkacllib/include | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||||
| ) | |||||
| target_link_libraries(ge_runner | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| ge_memory | |||||
| adump_server | |||||
| msprofiler | |||||
| -Wl,--no-as-needed | |||||
| graph | |||||
| ge_common | |||||
| ascend_protobuf | |||||
| register | |||||
| c_sec | |||||
| slog | |||||
| mmpa | |||||
| msprof | |||||
| runtime | |||||
| resource | |||||
| error_manager | |||||
| ascend_hal_stub | |||||
| -Wl,--as-needed | |||||
| json | |||||
| -lrt | |||||
| -ldl | |||||
| ) | |||||
| ############ libge_compiler.so ############ | |||||
| add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS}) | add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS}) | ||||
| target_compile_definitions(ge_compiler PRIVATE | target_compile_definitions(ge_compiler PRIVATE | ||||
| @@ -919,3 +925,70 @@ install(FILES | |||||
| ${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt OPTIONAL | ${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt OPTIONAL | ||||
| DESTINATION ${INSTALL_LIBRARY_DIR} | DESTINATION ${INSTALL_LIBRARY_DIR} | ||||
| ) | ) | ||||
| elseif (ENABLE_ACL) | |||||
| ############ libge_compiler.so w/static protobuf ############ | |||||
| add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS}) | |||||
| target_compile_definitions(ge_compiler PRIVATE | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| REUSE_MEMORY=1 | |||||
| FMK_SUPPORT_DUMP | |||||
| FMK_HOST_INFER | |||||
| COMPILE_OMG_PACKAGE | |||||
| google=ascend_private | |||||
| ) | |||||
| target_compile_options(ge_compiler PRIVATE | |||||
| -O2 | |||||
| ) | |||||
| target_include_directories(ge_compiler PRIVATE | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/ge/analyzer | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${GE_CODE_DIR}/inc/framework/common | |||||
| ${METADEF_DIR} | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/graph | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| ${ASCEND_DIR}/driver/include | |||||
| ${ASCEND_DIR}/fwkacllib/include | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||||
| ) | |||||
| target_link_libraries(ge_compiler | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| ge_memory | |||||
| -Wl,--no-as-needed | |||||
| graph | |||||
| ge_common | |||||
| static_ascend_protobuf | |||||
| register | |||||
| c_sec | |||||
| error_manager | |||||
| slog | |||||
| mmpa | |||||
| runtime_compile | |||||
| resource | |||||
| -Wl,--as-needed | |||||
| json | |||||
| -lrt | |||||
| -ldl | |||||
| ) | |||||
| ############ install libge_compiler for MindSpore############ | |||||
| set(INSTALL_BASE_DIR "") | |||||
| set(INSTALL_LIBRARY_DIR lib) | |||||
| install(TARGETS ge_compiler OPTIONAL | |||||
| LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} | |||||
| ) | |||||
| endif() | |||||
| @@ -63,6 +63,7 @@ set(SRC_LIST | |||||
| "ge/tbe_plugin_manager.cc" | "ge/tbe_plugin_manager.cc" | ||||
| ) | ) | ||||
| if (NOT ENABLE_D AND NOT ENABLE_ACL) | |||||
| ############ libge_common.so ############ | ############ libge_common.so ############ | ||||
| add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) | add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) | ||||
| target_compile_definitions(ge_common PRIVATE | target_compile_definitions(ge_common PRIVATE | ||||
| @@ -164,6 +165,57 @@ target_link_libraries(ge_common_static PRIVATE | |||||
| -ldl | -ldl | ||||
| ) | ) | ||||
| else () | |||||
| ############ libge_common.so w/static protobuf ############ | |||||
| add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_common PRIVATE | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| HOST_VISIBILITY | |||||
| FMK_SUPPORT_DUMP | |||||
| OS_CENTOS | |||||
| google=ascend_private | |||||
| ) | |||||
| target_compile_options(ge_common PRIVATE | |||||
| -fvisibility=hidden | |||||
| -O2 | |||||
| -Werror | |||||
| ) | |||||
| target_include_directories(ge_common PRIVATE | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/ge/common | |||||
| ${GE_CODE_DIR}/ge/common/op | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${METADEF_DIR}/inc/graph | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||||
| ) | |||||
| target_link_libraries(ge_common PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| ascend_protobuf_static | |||||
| -Wl,--no-as-needed | |||||
| graph | |||||
| register | |||||
| c_sec | |||||
| error_manager | |||||
| slog | |||||
| mmpa | |||||
| -Wl,--as-needed | |||||
| json | |||||
| -lrt | |||||
| -ldl | |||||
| ) | |||||
| endif () | |||||
| ############ install ############ | ############ install ############ | ||||
| set(INSTALL_BASE_DIR "") | set(INSTALL_BASE_DIR "") | ||||
| set(INSTALL_LIBRARY_DIR lib) | set(INSTALL_LIBRARY_DIR lib) | ||||
| @@ -27,14 +27,22 @@ target_compile_definitions(ge_runtime PRIVATE | |||||
| ) | ) | ||||
| target_include_directories(ge_runtime PRIVATE | target_include_directories(ge_runtime PRIVATE | ||||
| ${TOP_DIR} | |||||
| ${TOP_DIR}/inc | |||||
| ${TOP_DIR}/inc/graph | |||||
| ${TOP_DIR}/inc/external | |||||
| ${TOP_DIR}/inc/framework | |||||
| ${TOP_DIR}/inc/framework/common | |||||
| ${TOP_DIR}/inc/framework/ge_runtime | |||||
| ${TOP_DIR}/inc/cce | |||||
| ${CMAKE_CURRENT_LIST_DIR} | |||||
| ${GE_CODE_DIR} | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/graph | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${GE_CODE_DIR}/inc/framework/common | |||||
| ${GE_CODE_DIR}/inc/framework/ge_runtime | |||||
| ${GE_CODE_DIR}/inc/cce | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ${METADEF_DIR} | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/graph | |||||
| ${CMAKE_BINARY_DIR} | ${CMAKE_BINARY_DIR} | ||||
| ${CMAKE_BINARY_DIR}/proto/ge | ${CMAKE_BINARY_DIR}/proto/ge | ||||
| ) | ) | ||||
| @@ -45,6 +53,7 @@ target_link_libraries(ge_runtime PRIVATE | |||||
| slog | slog | ||||
| runtime | runtime | ||||
| c_sec | c_sec | ||||
| graph | |||||
| -Wl,--as-needed | -Wl,--as-needed | ||||
| -lrt | -lrt | ||||
| -ldl | -ldl | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -27,8 +27,13 @@ class ModelContext { | |||||
| ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle, | ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle, | ||||
| rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list, | rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list, | ||||
| const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list) | const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list) | ||||
| : device_id_(device_id), session_id_(session_id), priority_(priority), rt_model_handle_(rt_model_handle), | |||||
| rt_model_stream_(rt_model_stream), stream_list_(stream_list), label_list_(label_list), | |||||
| : device_id_(device_id), | |||||
| session_id_(session_id), | |||||
| priority_(priority), | |||||
| rt_model_handle_(rt_model_handle), | |||||
| rt_model_stream_(rt_model_stream), | |||||
| stream_list_(stream_list), | |||||
| label_list_(label_list), | |||||
| event_list_(event_list) {} | event_list_(event_list) {} | ||||
| ~ModelContext() {} | ~ModelContext() {} | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -24,6 +24,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace model_runner { | namespace model_runner { | ||||
| using RuntimeModelPtr = std::shared_ptr<RuntimeModel>; | using RuntimeModelPtr = std::shared_ptr<RuntimeModel>; | ||||
| using DavinciModelPtr = std::shared_ptr<DavinciModel>; | using DavinciModelPtr = std::shared_ptr<DavinciModel>; | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -89,5 +89,6 @@ bool Output::SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &dat | |||||
| bool support_mem_share) { | bool support_mem_share) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| } // namespace model_runner | } // namespace model_runner | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -24,6 +24,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace model_runner { | namespace model_runner { | ||||
| class Output { | class Output { | ||||
| public: | public: | ||||
| Output(const OpInfoPtr &op_info, const std::shared_ptr<DavinciModel> &model); | Output(const OpInfoPtr &op_info, const std::shared_ptr<DavinciModel> &model); | ||||
| @@ -32,8 +33,7 @@ class Output { | |||||
| bool CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share); | bool CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share); | ||||
| bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, | |||||
| bool support_mem_share); | |||||
| bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, bool support_mem_share); | |||||
| // Copy assignment operator and copy constructor are deleted | // Copy assignment operator and copy constructor are deleted | ||||
| Output &operator=(const Output &output) = delete; | Output &operator=(const Output &output) = delete; | ||||
| @@ -74,8 +74,8 @@ bool RuntimeModel::InitStream(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
| for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) { | for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) { | ||||
| rtStream_t stream = nullptr; | rtStream_t stream = nullptr; | ||||
| uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end()) | uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end()) | ||||
| ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY) | |||||
| : (RT_STREAM_PERSISTENT); | |||||
| ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY) | |||||
| : (RT_STREAM_PERSISTENT); | |||||
| rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag); | rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| @@ -115,23 +115,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool RuntimeModel::InitLabel(uint32_t batch_num) { | |||||
| GELOGI("batch number:%u.", batch_num); | |||||
| for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) { | |||||
| rtLabel_t rt_lLabel = nullptr; | |||||
| rtError_t rt_ret = rtLabelCreate(&rt_lLabel); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret); | |||||
| return false; | |||||
| bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
| GELOGI("batch number:%u.", davinci_model->GetBatchNum()); | |||||
| label_list_.resize(davinci_model->GetBatchNum()); | |||||
| for (auto &task_info : davinci_model->GetTaskInfoList()) { | |||||
| if (task_info == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "task_info is null."); | |||||
| continue; | |||||
| } | |||||
| if (task_info->type() != TaskInfoType::LABEL_SET) { | |||||
| continue; | |||||
| } | } | ||||
| auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info); | |||||
| if (rt_lLabel == nullptr) { | |||||
| GELOGE(RT_FAILED, "rtLabel is nullptr!"); | |||||
| if (label_set_task_info->stream_id() >= stream_list_.size()) { | |||||
| GELOGE(PARAM_INVALID, "Invalid stream id."); | |||||
| return false; | return false; | ||||
| } | } | ||||
| label_list_.emplace_back(rt_lLabel); | |||||
| rtLabel_t rt_label = nullptr; | |||||
| rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| label_list_[label_set_task_info->label_id()] = rt_label; | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -163,7 +174,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (!InitLabel(davinci_model->GetBatchNum())) { | |||||
| if (!InitLabel(davinci_model)) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -281,7 +292,6 @@ bool RuntimeModel::DistributeTask() { | |||||
| GELOGE(FAILED, "DistributeTask failed"); | GELOGE(FAILED, "DistributeTask failed"); | ||||
| return false; | return false; | ||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -293,10 +303,14 @@ bool RuntimeModel::Run() { | |||||
| return false; | return false; | ||||
| } | } | ||||
| GELOGI("Run rtModelExecute success"); | |||||
| GELOGI("Run rtModelExecute success, ret = 0x%X", ret); | |||||
| ret = rtStreamSynchronize(rt_model_stream_); | ret = rtStreamSynchronize(rt_model_stream_); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| if (ret == RT_ERROR_END_OF_SEQUENCE) { | |||||
| GELOGI("Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X", ret); | |||||
| return true; | |||||
| } | |||||
| GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); | GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -330,6 +344,9 @@ void RuntimeModel::RtStreamDestory() noexcept { | |||||
| void RuntimeModel::RtLabelDestory() noexcept { | void RuntimeModel::RtLabelDestory() noexcept { | ||||
| for (size_t i = 0; i < label_list_.size(); i++) { | for (size_t i = 0; i < label_list_.size(); i++) { | ||||
| if (label_list_[i] == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) { | if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Destroy label failed! Index: %zu.", i); | GELOGE(RT_FAILED, "Destroy label failed! Index: %zu.", i); | ||||
| return; | return; | ||||
| @@ -471,11 +488,8 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model | |||||
| /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero | /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero | ||||
| /// and that of unknown shape is zero too. | /// and that of unknown shape is zero too. | ||||
| /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. | /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. | ||||
| int64_t elem_num = constant->weight_tensors[0].GetShapeSize(); | |||||
| if (elem_num == 0 && constant->weight_tensors[0].size == 0) { | |||||
| elem_num = 1; | |||||
| } | |||||
| int64_t elem_num = | |||||
| (constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); | |||||
| if (constant->weight_data.size() < sizeof(uint64_t)) { | if (constant->weight_data.size() < sizeof(uint64_t)) { | ||||
| GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | ||||
| return false; | return false; | ||||
| @@ -40,13 +40,11 @@ class RuntimeModel { | |||||
| const std::vector<uint32_t> &GetTaskIdList() const; | const std::vector<uint32_t> &GetTaskIdList() const; | ||||
| const std::vector<uint32_t> &GetStreamIdList() const; | const std::vector<uint32_t> &GetStreamIdList() const; | ||||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; } | const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; } | ||||
| const rtModel_t GetModelHandle() const { return rt_model_handle_; } | |||||
| rtModel_t GetModelHandle() const { return rt_model_handle_; } | |||||
| bool Run(); | bool Run(); | ||||
| bool CopyInputData(const InputData &input_data); | bool CopyInputData(const InputData &input_data); | ||||
| bool GetInputOutputDescInfo(bool zero_copy, | |||||
| std::vector<InputOutputDescInfo> *input_desc, | |||||
| std::vector<InputOutputDescInfo> *output_desc, | |||||
| std::vector<uint32_t> *input_format, | |||||
| bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||||
| std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format, | |||||
| std::vector<uint32_t> *output_format); | std::vector<uint32_t> *output_format); | ||||
| private: | private: | ||||
| @@ -55,7 +53,7 @@ class RuntimeModel { | |||||
| bool LoadTask(); | bool LoadTask(); | ||||
| bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | ||||
| bool InitEvent(uint32_t event_num); | bool InitEvent(uint32_t event_num); | ||||
| bool InitLabel(uint32_t batch_num); | |||||
| bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model); | |||||
| bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | ||||
| bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | ||||
| bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | ||||
| @@ -87,6 +85,7 @@ class RuntimeModel { | |||||
| std::vector<uint32_t> stream_id_list_{}; | std::vector<uint32_t> stream_id_list_{}; | ||||
| std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_; | std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_; | ||||
| }; | }; | ||||
| } // namespace model_runner | } // namespace model_runner | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -26,6 +26,7 @@ AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<Ai | |||||
| task_info_(task_info), | task_info_(task_info), | ||||
| stream_(nullptr), | stream_(nullptr), | ||||
| args_(nullptr), | args_(nullptr), | ||||
| ext_info_(nullptr), | |||||
| input_output_addr_(nullptr) { | input_output_addr_(nullptr) { | ||||
| if (task_info_ == nullptr) { | if (task_info_ == nullptr) { | ||||
| GELOGW("task_info_ is null!"); | GELOGW("task_info_ is null!"); | ||||
| @@ -41,7 +42,10 @@ AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<Ai | |||||
| } | } | ||||
| } | } | ||||
| AicpuTask::~AicpuTask() { ReleaseRtMem(&args_); } | |||||
| AicpuTask::~AicpuTask() { | |||||
| ReleaseRtMem(&args_); | |||||
| ReleaseRtMem(&ext_info_); | |||||
| } | |||||
| bool AicpuTask::Distribute() { | bool AicpuTask::Distribute() { | ||||
| GELOGI("InitAicpuTask start."); | GELOGI("InitAicpuTask start."); | ||||
| @@ -51,10 +55,37 @@ bool AicpuTask::Distribute() { | |||||
| auto io_addrs_num = static_cast<uint32_t>(io_addrs.size()); | auto io_addrs_num = static_cast<uint32_t>(io_addrs.size()); | ||||
| auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *)); | auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *)); | ||||
| constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead); | constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead); | ||||
| uint32_t node_def_addr_offset = io_addr_offset + io_addrs_size; | |||||
| uint32_t args_size = | |||||
| sizeof(aicpu::AicpuParamHead) + io_addrs_size + static_cast<uint32_t>(task_info_->node_def().size()); | |||||
| aicpu::AicpuParamHead aicpu_param_head = {args_size, io_addrs_num}; | |||||
| uint32_t node_def_len_offset = io_addr_offset + io_addrs_size; | |||||
| uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t); | |||||
| uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size + | |||||
| static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t); | |||||
| aicpu::AicpuParamHead aicpu_param_head; | |||||
| aicpu_param_head.length = args_size; | |||||
| aicpu_param_head.ioAddrNum = io_addrs_num; | |||||
| auto ext_info = task_info_->ext_info(); | |||||
| uint32_t ext_size = ext_info.size(); | |||||
| if (ext_info.empty()) { | |||||
| aicpu_param_head.extInfoLength = 0; | |||||
| aicpu_param_head.extInfoAddr = 0; | |||||
| } else { | |||||
| rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM); | |||||
| if (flag != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", flag); | |||||
| return false; | |||||
| } | |||||
| flag = rtMemcpy(ext_info_, ext_size, const_cast<void *>(reinterpret_cast<const void *>(ext_info.data())), ext_size, | |||||
| RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (flag != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api(rtMemCpy) failed, ret: 0x%X.", flag); | |||||
| return false; | |||||
| } | |||||
| GELOGI("ext info size:", ext_size); | |||||
| aicpu_param_head.extInfoLength = ext_size; | |||||
| aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_); | |||||
| } | |||||
| // Malloc device memory for args | // Malloc device memory for args | ||||
| rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); | rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); | ||||
| @@ -80,6 +111,17 @@ bool AicpuTask::Distribute() { | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| // Memcpy node def | |||||
| auto size = task_info_->node_def().size(); | |||||
| rt_ret = | |||||
| rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_len_offset), sizeof(uint32_t), | |||||
| reinterpret_cast<const void *>(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret); | |||||
| return false; | |||||
| } | |||||
| // Memcpy node def | // Memcpy node def | ||||
| rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset), | rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset), | ||||
| task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()), | task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()), | ||||
| @@ -41,6 +41,7 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> { | |||||
| std::shared_ptr<AicpuTaskInfo> task_info_; | std::shared_ptr<AicpuTaskInfo> task_info_; | ||||
| void *stream_; | void *stream_; | ||||
| void *args_; | void *args_; | ||||
| void *ext_info_; | |||||
| void *input_output_addr_; | void *input_output_addr_; | ||||
| }; | }; | ||||
| } // namespace model_runner | } // namespace model_runner | ||||
| @@ -103,9 +103,9 @@ bool CceTask::Distribute() { | |||||
| // Modify flowtable addr in args | // Modify flowtable addr in args | ||||
| auto args = const_cast<uint8_t *>(task_info_->args().data()); | auto args = const_cast<uint8_t *>(task_info_->args().data()); | ||||
| auto task_offset = reinterpret_cast<uint16_t *>(const_cast<uint8_t *>(task_info_->args_offset().data())); | auto task_offset = reinterpret_cast<uint16_t *>(const_cast<uint8_t *>(task_info_->args_offset().data())); | ||||
| if (task_info_->args().size() < (task_offset[0] + sizeof(uint64_t))) { | if (task_info_->args().size() < (task_offset[0] + sizeof(uint64_t))) { | ||||
| GELOGE(FAILED, | |||||
| "(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu", | |||||
| GELOGE(FAILED, "(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu", | |||||
| static_cast<uint32_t>(task_offset[0]), sizeof(uint64_t), task_info_->args().size()); | static_cast<uint32_t>(task_offset[0]), sizeof(uint64_t), task_info_->args().size()); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -136,8 +136,7 @@ bool CceTask::Distribute() { | |||||
| return false; | return false; | ||||
| } | } | ||||
| rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(), | |||||
| task_info_->sm_desc().data(), | |||||
| rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(), task_info_->sm_desc().data(), | |||||
| task_info_->sm_desc().size(), RT_MEMCPY_HOST_TO_DEVICE); | task_info_->sm_desc().size(), RT_MEMCPY_HOST_TO_DEVICE); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
| @@ -146,12 +145,8 @@ bool CceTask::Distribute() { | |||||
| } | } | ||||
| // Kernel launch | // Kernel launch | ||||
| rt_ret = rtKernelLaunch(stub_func_, | |||||
| task_info_->block_dim(), | |||||
| args_, | |||||
| task_info_->args_size(), | |||||
| static_cast<rtSmDesc_t *>(sm_desc_), | |||||
| stream_); | |||||
| rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, task_info_->args_size(), | |||||
| static_cast<rtSmDesc_t *>(sm_desc_), stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
| return false; | return false; | ||||
| @@ -33,7 +33,7 @@ class EventRecordTask : public TaskRepeater<EventRecordTaskInfo> { | |||||
| private: | private: | ||||
| std::shared_ptr<EventRecordTaskInfo> task_info_; | std::shared_ptr<EventRecordTaskInfo> task_info_; | ||||
| rtStream_t stream_; | rtStream_t stream_; | ||||
| rtEvent_t event_; | |||||
| rtEvent_t event_; | |||||
| }; | }; | ||||
| } // namespace model_runner | } // namespace model_runner | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -33,7 +33,7 @@ class EventWaitTask : public TaskRepeater<EventWaitTaskInfo> { | |||||
| private: | private: | ||||
| std::shared_ptr<EventWaitTaskInfo> task_info_; | std::shared_ptr<EventWaitTaskInfo> task_info_; | ||||
| rtStream_t stream_; | rtStream_t stream_; | ||||
| rtEvent_t event_; | |||||
| rtEvent_t event_; | |||||
| }; | }; | ||||
| } // namespace model_runner | } // namespace model_runner | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -115,7 +115,6 @@ bool HcclTask::Distribute() { | |||||
| rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM); | rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
| (void)rtStreamDestroy(stream); | |||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -129,8 +128,6 @@ bool HcclTask::Distribute() { | |||||
| ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL); | ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL); | ||||
| ge_task.stream = stream_; | ge_task.stream = stream_; | ||||
| GETaskKernelHcclInfo kernel_hccl_info; | |||||
| ge_task.kernelHcclInfo.emplace_back(kernel_hccl_info); | |||||
| ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type(); | ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type(); | ||||
| ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr(); | ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr(); | ||||
| ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr(); | ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr(); | ||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/label_goto_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||||
| : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| label_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| auto label_list = model_context.label_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| uint32_t label_id = task_info->label_id(); | |||||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
| GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||||
| if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||||
| GELOGW("Stream/Label id invalid."); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| label_ = label_list[label_id]; | |||||
| } | |||||
| LabelGotoTask::~LabelGotoTask() {} | |||||
| bool LabelGotoTask::Distribute() { | |||||
| GELOGI("LabelGotoTask Distribute start."); | |||||
| if (stream_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||||
| return false; | |||||
| } | |||||
| if (label_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "label is null!"); | |||||
| return false; | |||||
| } | |||||
| rtError_t rt_ret = rtLabelGotoEx(label_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("DistributeTask end."); | |||||
| return true; | |||||
| } | |||||
| REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||||
| public: | |||||
| LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info); | |||||
| ~LabelGotoTask() override; | |||||
| bool Distribute() override; | |||||
| private: | |||||
| std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||||
| void *stream_; | |||||
| void *label_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/label_set_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info) | |||||
| : TaskRepeater<LabelSetTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| label_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| auto stream_list = model_context.stream_list(); | |||||
| auto label_list = model_context.label_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| uint32_t label_id = task_info->label_id(); | |||||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
| GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||||
| if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||||
| GELOGW("Stream/Label id invalid."); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| label_ = label_list[label_id]; | |||||
| } | |||||
| LabelSetTask::~LabelSetTask() {} | |||||
| bool LabelSetTask::Distribute() { | |||||
| GELOGI("LabelSetTask Distribute start."); | |||||
| if (stream_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||||
| return false; | |||||
| } | |||||
| if (label_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "label is null!"); | |||||
| return false; | |||||
| } | |||||
| rtError_t rt_ret = rtLabelSet(label_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("DistributeTask end."); | |||||
| return true; | |||||
| } | |||||
| REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> { | |||||
| public: | |||||
| LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info); | |||||
| ~LabelSetTask() override; | |||||
| bool Distribute() override; | |||||
| private: | |||||
| std::shared_ptr<LabelSetTaskInfo> task_info_; | |||||
| void *stream_; | |||||
| void *label_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||||
| @@ -0,0 +1,131 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_runtime/task/label_switch_task.h" | |||||
| #include "ge_runtime/task/task_factory.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||||
| const std::shared_ptr<LabelSwitchTaskInfo> &task_info) | |||||
| : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | |||||
| task_info_(task_info), | |||||
| stream_(nullptr), | |||||
| all_label_resource_(), | |||||
| label_info_(nullptr) { | |||||
| if (task_info_ == nullptr) { | |||||
| GELOGW("task_info_ is null!"); | |||||
| return; | |||||
| } | |||||
| all_label_resource_ = model_context.label_list(); | |||||
| auto stream_list = model_context.stream_list(); | |||||
| uint32_t stream_id = task_info->stream_id(); | |||||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||||
| if (stream_id >= stream_list.size()) { | |||||
| GELOGW("Stream id invalid."); | |||||
| return; | |||||
| } | |||||
| stream_ = stream_list[stream_id]; | |||||
| } | |||||
| LabelSwitchTask::~LabelSwitchTask() { | |||||
| if (label_info_ != nullptr) { | |||||
| rtError_t rt_ret = rtFree(label_info_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||||
| } | |||||
| label_info_ = nullptr; | |||||
| } | |||||
| } | |||||
| bool LabelSwitchTask::Distribute() { | |||||
| GELOGI("LabelSwitchTask Distribute start."); | |||||
| if (!CheckParamValid()) { | |||||
| return false; | |||||
| } | |||||
| const std::vector<uint32_t> &label_index_list = task_info_->label_list(); | |||||
| std::vector<void *> label_list(task_info_->label_size(), nullptr); | |||||
| for (size_t i = 0; i < task_info_->label_size(); ++i) { | |||||
| uint32_t label_index = label_index_list[i]; | |||||
| if (label_index >= all_label_resource_.size()) { | |||||
| GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, | |||||
| all_label_resource_.size()); | |||||
| return false; | |||||
| } | |||||
| label_list[i] = all_label_resource_[label_index]; | |||||
| GELOGI("Case %zu: label id %zu.", i, label_index); | |||||
| } | |||||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | |||||
| rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return false; | |||||
| } | |||||
| GELOGI("DistributeTask end."); | |||||
| return true; | |||||
| } | |||||
| bool LabelSwitchTask::CheckParamValid() { | |||||
| if (stream_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||||
| return false; | |||||
| } | |||||
| if (task_info_->label_list().empty()) { | |||||
| GELOGE(PARAM_INVALID, "label_list is empty."); | |||||
| return false; | |||||
| } | |||||
| if (task_info_->label_size() != task_info_->label_list().size()) { | |||||
| GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), | |||||
| task_info_->label_size()); | |||||
| return false; | |||||
| } | |||||
| if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { | |||||
| GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); | |||||
| return false; | |||||
| } | |||||
| if (label_info_ != nullptr) { | |||||
| GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||||
| #define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||||
| #include <memory> | |||||
| #include "ge_runtime/task/task.h" | |||||
| namespace ge { | |||||
| namespace model_runner { | |||||
| class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> { | |||||
| public: | |||||
| LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info); | |||||
| ~LabelSwitchTask() override; | |||||
| bool Distribute() override; | |||||
| private: | |||||
| bool CheckParamValid(); | |||||
| std::shared_ptr<LabelSwitchTaskInfo> task_info_; | |||||
| void *stream_; | |||||
| std::vector<void *> all_label_resource_; | |||||
| void *label_info_; | |||||
| }; | |||||
| } // namespace model_runner | |||||
| } // namespace ge | |||||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||||
| @@ -37,6 +37,7 @@ class StreamSwitchTask : public TaskRepeater<StreamSwitchTaskInfo> { | |||||
| void *stream_; | void *stream_; | ||||
| std::vector<rtStream_t> stream_list_; | std::vector<rtStream_t> stream_list_; | ||||
| }; | }; | ||||
| } // namespace model_runner | } // namespace model_runner | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ | #endif // GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ | ||||
| @@ -42,7 +42,7 @@ class Task { | |||||
| template <class T> | template <class T> | ||||
| class TaskRepeater : public Task { | class TaskRepeater : public Task { | ||||
| static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!"); /*lint !e30*/ | |||||
| static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!"); | |||||
| public: | public: | ||||
| TaskRepeater(const ModelContext &model_context, std::shared_ptr<T> task_info) {} | TaskRepeater(const ModelContext &model_context, std::shared_ptr<T> task_info) {} | ||||
| @@ -81,6 +81,7 @@ class TaskFactory { | |||||
| std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info); \ | std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info); \ | ||||
| return std::make_shared<task_clazz>(model_context, concrete_task_info); \ | return std::make_shared<task_clazz>(model_context, concrete_task_info); \ | ||||
| }); | }); | ||||
| } // namespace model_runner | } // namespace model_runner | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GE_RUNTIME_TASK_TASK_FACTORY_H_ | #endif // GE_GE_RUNTIME_TASK_TASK_FACTORY_H_ | ||||
| @@ -27,10 +27,10 @@ namespace ge { | |||||
| namespace model_runner { | namespace model_runner { | ||||
| class DavinciModel { | class DavinciModel { | ||||
| public: | public: | ||||
| DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list, /*lint !e151*/ | |||||
| DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list, | |||||
| const std::vector<std::shared_ptr<OpInfo>> &data_info_list, | const std::vector<std::shared_ptr<OpInfo>> &data_info_list, | ||||
| const std::vector<std::shared_ptr<OpInfo>> &output_info_list, /*lint !e151*/ | |||||
| const std::vector<std::shared_ptr<OpInfo>> &constant_info_list, /*lint !e1049*/ | |||||
| const std::vector<std::shared_ptr<OpInfo>> &output_info_list, | |||||
| const std::vector<std::shared_ptr<OpInfo>> &constant_info_list, | |||||
| const std::vector<model_runner::OpInfoPtr> &variable_info_list, | const std::vector<model_runner::OpInfoPtr> &variable_info_list, | ||||
| const std::vector<uint32_t> &wait_active_stream_list, | const std::vector<uint32_t> &wait_active_stream_list, | ||||
| const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0, | const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0, | ||||
| @@ -68,12 +68,12 @@ class DavinciModel { | |||||
| uint32_t GetBatchNum() const { return batch_num_; } | uint32_t GetBatchNum() const { return batch_num_; } | ||||
| uint32_t GetEventNum() const { return event_num_; } | uint32_t GetEventNum() const { return event_num_; } | ||||
| const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; } /*lint !e1413*/ | |||||
| const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; } /*lint !e1413*/ | |||||
| const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; } | |||||
| const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; } | |||||
| int32_t GetPriority() const { return priority_; } | int32_t GetPriority() const { return priority_; } | ||||
| const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; } /*lint !e151*/ | |||||
| const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; } | |||||
| const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; } | const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; } | ||||
| const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; } | const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; } | ||||
| const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return output_info_list_; } | const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return output_info_list_; } | ||||
| @@ -81,7 +81,7 @@ class DavinciModel { | |||||
| private: | private: | ||||
| std::vector<std::shared_ptr<TaskInfo>> task_info_list_; | std::vector<std::shared_ptr<TaskInfo>> task_info_list_; | ||||
| std::vector<std::shared_ptr<OpInfo>> data_info_list_; /*lint !e151*/ | |||||
| std::vector<std::shared_ptr<OpInfo>> data_info_list_; | |||||
| std::vector<std::shared_ptr<OpInfo>> output_info_list_; | std::vector<std::shared_ptr<OpInfo>> output_info_list_; | ||||
| std::vector<std::shared_ptr<OpInfo>> constant_info_list_; | std::vector<std::shared_ptr<OpInfo>> constant_info_list_; | ||||
| std::vector<model_runner::OpInfoPtr> variable_info_list_; | std::vector<model_runner::OpInfoPtr> variable_info_list_; | ||||
| @@ -52,11 +52,8 @@ class ModelRunner { | |||||
| bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | ||||
| bool GetInputOutputDescInfo(uint32_t model_id, | |||||
| bool zero_copy, | |||||
| std::vector<InputOutputDescInfo> *input_desc, | |||||
| std::vector<InputOutputDescInfo> *output_desc, | |||||
| std::vector<uint32_t> *input_format, | |||||
| bool GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||||
| std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format, | |||||
| std::vector<uint32_t> *output_format); | std::vector<uint32_t> *output_format); | ||||
| private: | private: | ||||
| @@ -161,12 +161,13 @@ class TbeTaskInfo : public TaskInfo { | |||||
| class AicpuTaskInfo : public TaskInfo { | class AicpuTaskInfo : public TaskInfo { | ||||
| public: | public: | ||||
| AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, | AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, | ||||
| const std::string &node_def, const std::vector<void *> &input_data_addrs, | |||||
| const std::string &node_def, const std::string &ext_info, const std::vector<void *> &input_data_addrs, | |||||
| const std::vector<void *> &output_data_addrs, bool dump_flag) | const std::vector<void *> &output_data_addrs, bool dump_flag) | ||||
| : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), | : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), | ||||
| so_name_(so_name), | so_name_(so_name), | ||||
| kernel_name_(kernel_name), | kernel_name_(kernel_name), | ||||
| node_def_(node_def), | node_def_(node_def), | ||||
| ext_info_(ext_info), | |||||
| input_data_addrs_(input_data_addrs), | input_data_addrs_(input_data_addrs), | ||||
| output_data_addrs_(output_data_addrs) {} | output_data_addrs_(output_data_addrs) {} | ||||
| ~AicpuTaskInfo() override {} | ~AicpuTaskInfo() override {} | ||||
| @@ -176,11 +177,13 @@ class AicpuTaskInfo : public TaskInfo { | |||||
| const std::string &node_def() const { return node_def_; } | const std::string &node_def() const { return node_def_; } | ||||
| const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; } | const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; } | ||||
| const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; } | const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; } | ||||
| const std::string &ext_info() const { return ext_info_; } | |||||
| private: | private: | ||||
| std::string so_name_; | std::string so_name_; | ||||
| std::string kernel_name_; | std::string kernel_name_; | ||||
| std::string node_def_; | std::string node_def_; | ||||
| std::string ext_info_; | |||||
| std::vector<void *> input_data_addrs_; | std::vector<void *> input_data_addrs_; | ||||
| std::vector<void *> output_data_addrs_; | std::vector<void *> output_data_addrs_; | ||||
| }; | }; | ||||
| @@ -293,19 +296,19 @@ class HcclTaskInfo : public TaskInfo { | |||||
| hcom_distribute_task_(hcom_distribute_task) {} | hcom_distribute_task_(hcom_distribute_task) {} | ||||
| ~HcclTaskInfo() override {} | ~HcclTaskInfo() override {} | ||||
| const std::string &hccl_type() const { return hccl_type_; } /*lint !e1413*/ | |||||
| const std::string &hccl_type() const { return hccl_type_; } | |||||
| void *input_data_addr() const { return input_data_addr_; } | void *input_data_addr() const { return input_data_addr_; } | ||||
| void *output_data_addr() const { return output_data_addr_; } | void *output_data_addr() const { return output_data_addr_; } | ||||
| void *workspace_addr() const { return workspace_addr_; } | void *workspace_addr() const { return workspace_addr_; } | ||||
| int64_t workspace_size() const { return workspace_size_; } | int64_t workspace_size() const { return workspace_size_; } | ||||
| int64_t hccl_stream_num() const { return hccl_stream_num_; } | int64_t hccl_stream_num() const { return hccl_stream_num_; } | ||||
| const std::vector<uint8_t> &private_def() const { return private_def_; } /*lint !e1413*/ | |||||
| const std::vector<uint8_t> &private_def() const { return private_def_; } | |||||
| void *ops_kernel_store() const { return ops_kernel_store_; } | void *ops_kernel_store() const { return ops_kernel_store_; } | ||||
| int32_t count() const { return count_; } | int32_t count() const { return count_; } | ||||
| int64_t root_id() const { return root_id_; } | int64_t root_id() const { return root_id_; } | ||||
| int64_t op_type() const { return op_type_; } | int64_t op_type() const { return op_type_; } | ||||
| int64_t data_type() const { return data_type_; } | int64_t data_type() const { return data_type_; } | ||||
| const std::string group() const { return group_; } | |||||
| const std::string &group() const { return group_; } | |||||
| std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; } | std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; } | ||||
| std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; } | std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; } | ||||
| std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const { | std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const { | ||||