| @@ -8,6 +8,19 @@ if (NOT BUILD_PATH) | |||
| set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build") | |||
| endif() | |||
| if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | |||
| set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) | |||
| else() | |||
| set(ASCEND_DIR /usr/local/Ascend) | |||
| endif() | |||
| set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) | |||
| set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) | |||
| set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share) | |||
| set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) | |||
| set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) | |||
| set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) | |||
| set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) | |||
| option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) | |||
| if (ENABLE_OPEN_SRC) | |||
| @@ -41,7 +54,7 @@ if (ENABLE_OPEN_SRC) | |||
| message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") | |||
| endif() | |||
| set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) | |||
| set(STATIC_ACL_LIB ${GE_LIB_PATH}) | |||
| set(STATIC_ACL_LIB ${GE_LIB_PATH}) | |||
| find_module(slog libslog.so ${GE_LIB_PATH}) | |||
| find_module(mmpa libmmpa.so ${GE_LIB_PATH}) | |||
| find_module(msprof libmsprof.so ${GE_LIB_PATH}) | |||
| @@ -56,18 +69,6 @@ if (ENABLE_OPEN_SRC) | |||
| find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH}) | |||
| #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | |||
| else() | |||
| if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | |||
| set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) | |||
| else() | |||
| set(ASCEND_DIR /usr/local/Ascend) | |||
| endif() | |||
| set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) | |||
| set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) | |||
| set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share) | |||
| set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) | |||
| set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) | |||
| set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) | |||
| set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) | |||
| find_module(slog libslog.so ${ASCEND_ATC_DIR}) | |||
| find_module(mmpa libmmpa.so ${ASCEND_ATC_DIR}) | |||
| if(PLATFORM STREQUAL "train") | |||
| @@ -127,6 +128,36 @@ if (ENABLE_OPEN_SRC) | |||
| add_subdirectory(parser) | |||
| #add_subdirectory(metadef/graph) | |||
| #add_subdirectory(metadef/register) | |||
| elseif (ENABLE_D OR ENABLE_ACL) | |||
| # compiling with MindSpore | |||
| include(cmake/external_libs/protobuf_static.cmake) | |||
| include(cmake/external_libs/protoc.cmake) | |||
| include(cmake/external_libs/securec.cmake) | |||
| include(cmake/external_libs/json.cmake) | |||
| include(cmake/FindModule.cmake) | |||
| include(cmake/intf_pub_linux.cmake) | |||
| # common modules | |||
| find_module(slog libslog.so ${ASCEND_DRIVER_COMMON_DIR}) | |||
| find_module(mmpa libmmpa.so ${ASCEND_DRIVER_COMMON_DIR}) | |||
| if (ENABLE_D) | |||
| # training | |||
| find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) | |||
| find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) | |||
| find_module(register libregister.so ${ASCEND_RUNTIME_DIR}) | |||
| find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) | |||
| elseif(ENABLE_ACL) | |||
| # inference | |||
| find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) | |||
| find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) | |||
| find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) | |||
| find_module(resource libresource.so ${ASCEND_ATC_DIR}) | |||
| find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | |||
| endif () | |||
| set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | |||
| add_subdirectory(metadef) | |||
| else() | |||
| set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef) | |||
| set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser) | |||
| @@ -48,5 +48,8 @@ set_target_properties(ascend_protobuf_static_lib PROPERTIES | |||
| add_library(ascend_protobuf_static INTERFACE) | |||
| target_include_directories(ascend_protobuf_static INTERFACE ${PROTOBUF_STATIC_PKG_DIR}/include) | |||
| target_link_libraries(ascend_protobuf_static INTERFACE ascend_protobuf_static_lib) | |||
| if (ENABLE_D OR ENABLE_ACL) | |||
| include_directories(${PROTOBUF_STATIC_PKG_DIR}/include) | |||
| endif () | |||
| add_dependencies(ascend_protobuf_static protobuf_static_build) | |||
| @@ -1,10 +1,15 @@ | |||
| add_subdirectory(common) | |||
| add_subdirectory(plugin/engine) | |||
| add_subdirectory(graph/build/memory) | |||
| add_subdirectory(ge_local_engine) | |||
| add_subdirectory(host_cpu_engine) | |||
| add_subdirectory(executor) | |||
| add_subdirectory(offline) | |||
| if (NOT ENABLE_D AND NOT ENABLE_ACL) | |||
| add_subdirectory(common) | |||
| add_subdirectory(plugin/engine) | |||
| add_subdirectory(graph/build/memory) | |||
| add_subdirectory(ge_local_engine) | |||
| add_subdirectory(host_cpu_engine) | |||
| add_subdirectory(executor) | |||
| add_subdirectory(offline) | |||
| else() | |||
| add_subdirectory(common) | |||
| add_subdirectory(ge_runtime) | |||
| endif () | |||
| set(PROTO_LIST | |||
| "${METADEF_DIR}/proto/fusion_model.proto" | |||
| @@ -28,7 +33,6 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
| protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) | |||
| protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) | |||
| ############ libge_runner.so ############ | |||
| set(TRAIN_SRC_LIST | |||
| "common/formats/format_transfers/datatype_transfer.cc" | |||
| "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" | |||
| @@ -333,72 +337,6 @@ set(TRAIN_SRC_LIST | |||
| "ir_build/atc_ir_common.cc" | |||
| ) | |||
| add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) | |||
| target_compile_definitions(ge_runner PRIVATE | |||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
| DAVINCI_SUPPORT_PROFILING | |||
| REUSE_MEMORY=1 | |||
| FMK_SUPPORT_DUMP | |||
| DAVINCI_CLOUD | |||
| google=ascend_private | |||
| ) | |||
| target_compile_options(ge_runner PRIVATE | |||
| -O2 | |||
| ) | |||
| target_include_directories(ge_runner PRIVATE | |||
| ${GE_CODE_DIR}/ge | |||
| ${GE_CODE_DIR}/ge/analyzer | |||
| ${GE_CODE_DIR}/inc | |||
| ${GE_CODE_DIR}/inc/external | |||
| ${GE_CODE_DIR}/inc/framework | |||
| ${GE_CODE_DIR}/inc/framework/common | |||
| ${METADEF_DIR} | |||
| ${METADEF_DIR}/inc | |||
| ${METADEF_DIR}/inc/external/graph | |||
| ${METADEF_DIR}/inc/external | |||
| ${METADEF_DIR}/inc/graph | |||
| ${CMAKE_BINARY_DIR} | |||
| ${CMAKE_BINARY_DIR}/proto/ge | |||
| #### yellow zone #### | |||
| ${GE_CODE_DIR}/../inc | |||
| ${GE_CODE_DIR}/../inc/external | |||
| ${GE_CODE_DIR}/../inc/cce | |||
| ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | |||
| #### blue zone | |||
| ${ASCEND_DIR}/driver/include | |||
| ${ASCEND_DIR}/fwkacllib/include | |||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||
| ) | |||
| target_link_libraries(ge_runner | |||
| $<BUILD_INTERFACE:intf_pub> | |||
| ge_memory | |||
| adump_server | |||
| msprofiler | |||
| -Wl,--no-as-needed | |||
| graph | |||
| ge_common | |||
| ascend_protobuf | |||
| register | |||
| c_sec | |||
| slog | |||
| mmpa | |||
| msprof | |||
| runtime | |||
| resource | |||
| error_manager | |||
| ascend_hal_stub | |||
| -Wl,--as-needed | |||
| json | |||
| -lrt | |||
| -ldl | |||
| ) | |||
| ############ libge_compiler.so ############ | |||
| set(INFER_SRC_LIST | |||
| "graph/manager/trans_var_data_utils.cc" | |||
| "omm/csa_interact.cc" | |||
| @@ -662,6 +600,74 @@ set(INFER_SRC_LIST | |||
| "analyzer/analyzer.cc" | |||
| ) | |||
| if (NOT ENABLE_D AND NOT ENABLE_ACL) | |||
| ############ libge_runner.so ############ | |||
| add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) | |||
| target_compile_definitions(ge_runner PRIVATE | |||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
| DAVINCI_SUPPORT_PROFILING | |||
| REUSE_MEMORY=1 | |||
| FMK_SUPPORT_DUMP | |||
| DAVINCI_CLOUD | |||
| google=ascend_private | |||
| ) | |||
| target_compile_options(ge_runner PRIVATE | |||
| -O2 | |||
| ) | |||
| target_include_directories(ge_runner PRIVATE | |||
| ${GE_CODE_DIR}/ge | |||
| ${GE_CODE_DIR}/ge/analyzer | |||
| ${GE_CODE_DIR}/inc | |||
| ${GE_CODE_DIR}/inc/external | |||
| ${GE_CODE_DIR}/inc/framework | |||
| ${GE_CODE_DIR}/inc/framework/common | |||
| ${METADEF_DIR} | |||
| ${METADEF_DIR}/inc | |||
| ${METADEF_DIR}/inc/external/graph | |||
| ${METADEF_DIR}/inc/external | |||
| ${METADEF_DIR}/inc/graph | |||
| ${CMAKE_BINARY_DIR} | |||
| ${CMAKE_BINARY_DIR}/proto/ge | |||
| #### yellow zone #### | |||
| ${GE_CODE_DIR}/../inc | |||
| ${GE_CODE_DIR}/../inc/external | |||
| ${GE_CODE_DIR}/../inc/cce | |||
| ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | |||
| #### blue zone | |||
| ${ASCEND_DIR}/driver/include | |||
| ${ASCEND_DIR}/fwkacllib/include | |||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||
| ) | |||
| target_link_libraries(ge_runner | |||
| $<BUILD_INTERFACE:intf_pub> | |||
| ge_memory | |||
| adump_server | |||
| msprofiler | |||
| -Wl,--no-as-needed | |||
| graph | |||
| ge_common | |||
| ascend_protobuf | |||
| register | |||
| c_sec | |||
| slog | |||
| mmpa | |||
| msprof | |||
| runtime | |||
| resource | |||
| error_manager | |||
| ascend_hal_stub | |||
| -Wl,--as-needed | |||
| json | |||
| -lrt | |||
| -ldl | |||
| ) | |||
| ############ libge_compiler.so ############ | |||
| add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS}) | |||
| target_compile_definitions(ge_compiler PRIVATE | |||
| @@ -919,3 +925,70 @@ install(FILES | |||
| ${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt OPTIONAL | |||
| DESTINATION ${INSTALL_LIBRARY_DIR} | |||
| ) | |||
| elseif (ENABLE_ACL) | |||
| ############ libge_compiler.so w/static protobuf ############ | |||
| add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS}) | |||
| target_compile_definitions(ge_compiler PRIVATE | |||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
| REUSE_MEMORY=1 | |||
| FMK_SUPPORT_DUMP | |||
| FMK_HOST_INFER | |||
| COMPILE_OMG_PACKAGE | |||
| google=ascend_private | |||
| ) | |||
| target_compile_options(ge_compiler PRIVATE | |||
| -O2 | |||
| ) | |||
| target_include_directories(ge_compiler PRIVATE | |||
| ${GE_CODE_DIR}/ge | |||
| ${GE_CODE_DIR}/ge/analyzer | |||
| ${GE_CODE_DIR}/inc | |||
| ${GE_CODE_DIR}/inc/external | |||
| ${GE_CODE_DIR}/inc/framework | |||
| ${GE_CODE_DIR}/inc/framework/common | |||
| ${METADEF_DIR} | |||
| ${METADEF_DIR}/inc | |||
| ${METADEF_DIR}/inc/external/graph | |||
| ${METADEF_DIR}/inc/external | |||
| ${METADEF_DIR}/inc/graph | |||
| ${CMAKE_BINARY_DIR} | |||
| ${CMAKE_BINARY_DIR}/proto/ge | |||
| ${ASCEND_DIR}/driver/include | |||
| ${ASCEND_DIR}/fwkacllib/include | |||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||
| ) | |||
| target_link_libraries(ge_compiler | |||
| $<BUILD_INTERFACE:intf_pub> | |||
| ge_memory | |||
| -Wl,--no-as-needed | |||
| graph | |||
| ge_common | |||
| static_ascend_protobuf | |||
| register | |||
| c_sec | |||
| error_manager | |||
| slog | |||
| mmpa | |||
| runtime_compile | |||
| resource | |||
| -Wl,--as-needed | |||
| json | |||
| -lrt | |||
| -ldl | |||
| ) | |||
| ############ install libge_compiler for MindSpore############ | |||
| set(INSTALL_BASE_DIR "") | |||
| set(INSTALL_LIBRARY_DIR lib) | |||
| install(TARGETS ge_compiler OPTIONAL | |||
| LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} | |||
| ) | |||
| endif() | |||
| @@ -63,6 +63,7 @@ set(SRC_LIST | |||
| "ge/tbe_plugin_manager.cc" | |||
| ) | |||
| if (NOT ENABLE_D AND NOT ENABLE_ACL) | |||
| ############ libge_common.so ############ | |||
| add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) | |||
| target_compile_definitions(ge_common PRIVATE | |||
| @@ -164,6 +165,57 @@ target_link_libraries(ge_common_static PRIVATE | |||
| -ldl | |||
| ) | |||
| else () | |||
| ############ libge_common.so w/static protobuf ############ | |||
| add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) | |||
| target_compile_definitions(ge_common PRIVATE | |||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
| HOST_VISIBILITY | |||
| FMK_SUPPORT_DUMP | |||
| OS_CENTOS | |||
| google=ascend_private | |||
| ) | |||
| target_compile_options(ge_common PRIVATE | |||
| -fvisibility=hidden | |||
| -O2 | |||
| -Werror | |||
| ) | |||
| target_include_directories(ge_common PRIVATE | |||
| ${GE_CODE_DIR}/ge | |||
| ${GE_CODE_DIR}/ge/common | |||
| ${GE_CODE_DIR}/ge/common/op | |||
| ${GE_CODE_DIR}/inc/external | |||
| ${GE_CODE_DIR}/inc | |||
| ${GE_CODE_DIR}/inc/framework | |||
| ${METADEF_DIR}/inc | |||
| ${METADEF_DIR}/inc/external | |||
| ${METADEF_DIR}/inc/external/graph | |||
| ${METADEF_DIR}/inc/graph | |||
| ${CMAKE_BINARY_DIR} | |||
| ${CMAKE_BINARY_DIR}/proto/ge | |||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||
| ) | |||
| target_link_libraries(ge_common PRIVATE | |||
| $<BUILD_INTERFACE:intf_pub> | |||
| ascend_protobuf_static | |||
| -Wl,--no-as-needed | |||
| graph | |||
| register | |||
| c_sec | |||
| error_manager | |||
| slog | |||
| mmpa | |||
| -Wl,--as-needed | |||
| json | |||
| -lrt | |||
| -ldl | |||
| ) | |||
| endif () | |||
| ############ install ############ | |||
| set(INSTALL_BASE_DIR "") | |||
| set(INSTALL_LIBRARY_DIR lib) | |||
| @@ -27,14 +27,22 @@ target_compile_definitions(ge_runtime PRIVATE | |||
| ) | |||
| target_include_directories(ge_runtime PRIVATE | |||
| ${TOP_DIR} | |||
| ${TOP_DIR}/inc | |||
| ${TOP_DIR}/inc/graph | |||
| ${TOP_DIR}/inc/external | |||
| ${TOP_DIR}/inc/framework | |||
| ${TOP_DIR}/inc/framework/common | |||
| ${TOP_DIR}/inc/framework/ge_runtime | |||
| ${TOP_DIR}/inc/cce | |||
| ${CMAKE_CURRENT_LIST_DIR} | |||
| ${GE_CODE_DIR} | |||
| ${GE_CODE_DIR}/ge | |||
| ${GE_CODE_DIR}/inc | |||
| ${GE_CODE_DIR}/inc/graph | |||
| ${GE_CODE_DIR}/inc/external | |||
| ${GE_CODE_DIR}/inc/framework | |||
| ${GE_CODE_DIR}/inc/framework/common | |||
| ${GE_CODE_DIR}/inc/framework/ge_runtime | |||
| ${GE_CODE_DIR}/inc/cce | |||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||
| ${METADEF_DIR} | |||
| ${METADEF_DIR}/inc | |||
| ${METADEF_DIR}/inc/external/graph | |||
| ${METADEF_DIR}/inc/external | |||
| ${METADEF_DIR}/inc/graph | |||
| ${CMAKE_BINARY_DIR} | |||
| ${CMAKE_BINARY_DIR}/proto/ge | |||
| ) | |||
| @@ -45,6 +53,7 @@ target_link_libraries(ge_runtime PRIVATE | |||
| slog | |||
| runtime | |||
| c_sec | |||
| graph | |||
| -Wl,--as-needed | |||
| -lrt | |||
| -ldl | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -27,8 +27,13 @@ class ModelContext { | |||
| ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle, | |||
| rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list, | |||
| const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list) | |||
| : device_id_(device_id), session_id_(session_id), priority_(priority), rt_model_handle_(rt_model_handle), | |||
| rt_model_stream_(rt_model_stream), stream_list_(stream_list), label_list_(label_list), | |||
| : device_id_(device_id), | |||
| session_id_(session_id), | |||
| priority_(priority), | |||
| rt_model_handle_(rt_model_handle), | |||
| rt_model_stream_(rt_model_stream), | |||
| stream_list_(stream_list), | |||
| label_list_(label_list), | |||
| event_list_(event_list) {} | |||
| ~ModelContext() {} | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -24,6 +24,7 @@ | |||
| namespace ge { | |||
| namespace model_runner { | |||
| using RuntimeModelPtr = std::shared_ptr<RuntimeModel>; | |||
| using DavinciModelPtr = std::shared_ptr<DavinciModel>; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -89,5 +89,6 @@ bool Output::SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &dat | |||
| bool support_mem_share) { | |||
| return true; | |||
| } | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -24,6 +24,7 @@ | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class Output { | |||
| public: | |||
| Output(const OpInfoPtr &op_info, const std::shared_ptr<DavinciModel> &model); | |||
| @@ -32,8 +33,7 @@ class Output { | |||
| bool CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share); | |||
| bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, | |||
| bool support_mem_share); | |||
| bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, bool support_mem_share); | |||
| // Copy assignment operator and copy constructor are deleted | |||
| Output &operator=(const Output &output) = delete; | |||
| @@ -74,8 +74,8 @@ bool RuntimeModel::InitStream(std::shared_ptr<DavinciModel> &davinci_model) { | |||
| for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) { | |||
| rtStream_t stream = nullptr; | |||
| uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end()) | |||
| ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY) | |||
| : (RT_STREAM_PERSISTENT); | |||
| ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY) | |||
| : (RT_STREAM_PERSISTENT); | |||
| rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| @@ -115,23 +115,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) { | |||
| return true; | |||
| } | |||
| bool RuntimeModel::InitLabel(uint32_t batch_num) { | |||
| GELOGI("batch number:%u.", batch_num); | |||
| for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) { | |||
| rtLabel_t rt_lLabel = nullptr; | |||
| rtError_t rt_ret = rtLabelCreate(&rt_lLabel); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret); | |||
| return false; | |||
| bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) { | |||
| GELOGI("batch number:%u.", davinci_model->GetBatchNum()); | |||
| label_list_.resize(davinci_model->GetBatchNum()); | |||
| for (auto &task_info : davinci_model->GetTaskInfoList()) { | |||
| if (task_info == nullptr) { | |||
| GELOGE(PARAM_INVALID, "task_info is null."); | |||
| continue; | |||
| } | |||
| if (task_info->type() != TaskInfoType::LABEL_SET) { | |||
| continue; | |||
| } | |||
| auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info); | |||
| if (rt_lLabel == nullptr) { | |||
| GELOGE(RT_FAILED, "rtLabel is nullptr!"); | |||
| if (label_set_task_info->stream_id() >= stream_list_.size()) { | |||
| GELOGE(PARAM_INVALID, "Invalid stream id."); | |||
| return false; | |||
| } | |||
| label_list_.emplace_back(rt_lLabel); | |||
| // label_list_ was resized to the model's batch number above; reject a label id | |||
| // outside that range instead of writing past the end of the vector. | |||
| if (label_set_task_info->label_id() >= label_list_.size()) { | |||
|   GELOGE(PARAM_INVALID, "Invalid label id."); | |||
|   return false; | |||
| } | |||
| // Create the label on the stream the LABEL_SET task is bound to. | |||
| rtLabel_t rt_label = nullptr; | |||
| rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
|   GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); | |||
|   return false; | |||
| } | |||
| // Labels are stored by label id so later lookups can index the list directly. | |||
| label_list_[label_set_task_info->label_id()] = rt_label; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -163,7 +174,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) { | |||
| return false; | |||
| } | |||
| if (!InitLabel(davinci_model->GetBatchNum())) { | |||
| if (!InitLabel(davinci_model)) { | |||
| return false; | |||
| } | |||
| @@ -281,7 +292,6 @@ bool RuntimeModel::DistributeTask() { | |||
| GELOGE(FAILED, "DistributeTask failed"); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -293,10 +303,14 @@ bool RuntimeModel::Run() { | |||
| return false; | |||
| } | |||
| GELOGI("Run rtModelExecute success"); | |||
| GELOGI("Run rtModelExecute success, ret = 0x%X", ret); | |||
| ret = rtStreamSynchronize(rt_model_stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| if (ret == RT_ERROR_END_OF_SEQUENCE) { | |||
| GELOGI("Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X", ret); | |||
| return true; | |||
| } | |||
| GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); | |||
| return false; | |||
| } | |||
| @@ -330,6 +344,9 @@ void RuntimeModel::RtStreamDestory() noexcept { | |||
| void RuntimeModel::RtLabelDestory() noexcept { | |||
| for (size_t i = 0; i < label_list_.size(); i++) { | |||
| if (label_list_[i] == nullptr) { | |||
| continue; | |||
| } | |||
| if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Destroy label failed! Index: %zu.", i); | |||
| return; | |||
| @@ -471,11 +488,8 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model | |||
| /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero | |||
| /// and that of unknown shape is zero too. | |||
| /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. | |||
| int64_t elem_num = constant->weight_tensors[0].GetShapeSize(); | |||
| if (elem_num == 0 && constant->weight_tensors[0].size == 0) { | |||
| elem_num = 1; | |||
| } | |||
| int64_t elem_num = | |||
| (constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); | |||
| if (constant->weight_data.size() < sizeof(uint64_t)) { | |||
| GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | |||
| return false; | |||
| @@ -40,13 +40,11 @@ class RuntimeModel { | |||
| const std::vector<uint32_t> &GetTaskIdList() const; | |||
| const std::vector<uint32_t> &GetStreamIdList() const; | |||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; } | |||
| const rtModel_t GetModelHandle() const { return rt_model_handle_; } | |||
| rtModel_t GetModelHandle() const { return rt_model_handle_; } | |||
| bool Run(); | |||
| bool CopyInputData(const InputData &input_data); | |||
| bool GetInputOutputDescInfo(bool zero_copy, | |||
| std::vector<InputOutputDescInfo> *input_desc, | |||
| std::vector<InputOutputDescInfo> *output_desc, | |||
| std::vector<uint32_t> *input_format, | |||
| bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||
| std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format, | |||
| std::vector<uint32_t> *output_format); | |||
| private: | |||
| @@ -55,7 +53,7 @@ class RuntimeModel { | |||
| bool LoadTask(); | |||
| bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | |||
| bool InitEvent(uint32_t event_num); | |||
| bool InitLabel(uint32_t batch_num); | |||
| bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model); | |||
| bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
| bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
| bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
| @@ -87,6 +85,7 @@ class RuntimeModel { | |||
| std::vector<uint32_t> stream_id_list_{}; | |||
| std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -26,6 +26,7 @@ AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<Ai | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| args_(nullptr), | |||
| ext_info_(nullptr), | |||
| input_output_addr_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| @@ -41,7 +42,10 @@ AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<Ai | |||
| } | |||
| } | |||
| AicpuTask::~AicpuTask() { ReleaseRtMem(&args_); } | |||
| AicpuTask::~AicpuTask() { | |||
| ReleaseRtMem(&args_); | |||
| ReleaseRtMem(&ext_info_); | |||
| } | |||
| bool AicpuTask::Distribute() { | |||
| GELOGI("InitAicpuTask start."); | |||
| @@ -51,10 +55,37 @@ bool AicpuTask::Distribute() { | |||
| auto io_addrs_num = static_cast<uint32_t>(io_addrs.size()); | |||
| auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *)); | |||
| constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead); | |||
| uint32_t node_def_addr_offset = io_addr_offset + io_addrs_size; | |||
| uint32_t args_size = | |||
| sizeof(aicpu::AicpuParamHead) + io_addrs_size + static_cast<uint32_t>(task_info_->node_def().size()); | |||
| aicpu::AicpuParamHead aicpu_param_head = {args_size, io_addrs_num}; | |||
| uint32_t node_def_len_offset = io_addr_offset + io_addrs_size; | |||
| uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t); | |||
| uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size + | |||
| static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t); | |||
| aicpu::AicpuParamHead aicpu_param_head; | |||
| aicpu_param_head.length = args_size; | |||
| aicpu_param_head.ioAddrNum = io_addrs_num; | |||
| // Stage the optional extension info in device memory and record its address | |||
| // and length in the param head; both fields are zeroed when there is none. | |||
| const auto &ext_info = task_info_->ext_info(); | |||
| const uint32_t ext_size = static_cast<uint32_t>(ext_info.size()); | |||
| if (ext_info.empty()) { | |||
|   aicpu_param_head.extInfoLength = 0; | |||
|   aicpu_param_head.extInfoAddr = 0; | |||
| } else { | |||
|   // ext_info_ is a member and is released in the destructor, so the early | |||
|   // returns below do not need to free it here. | |||
|   rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM); | |||
|   if (flag != RT_ERROR_NONE) { | |||
|     GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", flag); | |||
|     return false; | |||
|   } | |||
|   flag = rtMemcpy(ext_info_, ext_size, const_cast<void *>(reinterpret_cast<const void *>(ext_info.data())), ext_size, | |||
|                   RT_MEMCPY_HOST_TO_DEVICE); | |||
|   if (flag != RT_ERROR_NONE) { | |||
|     GELOGE(RT_FAILED, "Call rt api(rtMemCpy) failed, ret: 0x%X.", flag); | |||
|     return false; | |||
|   } | |||
|   // The format string needs a conversion specifier, otherwise ext_size is never printed. | |||
|   GELOGI("ext info size:%u", ext_size); | |||
|   aicpu_param_head.extInfoLength = ext_size; | |||
|   aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_); | |||
| } | |||
| // Malloc device memory for args | |||
| rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); | |||
| @@ -80,6 +111,17 @@ bool AicpuTask::Distribute() { | |||
| return false; | |||
| } | |||
| } | |||
| // Memcpy node def length | |||
| // The length field written to the device is 32 bits wide. Narrow explicitly so | |||
| // the copied bytes are the value itself rather than the first four bytes of a | |||
| // size_t, which would only be correct on little-endian hosts. | |||
| const uint32_t size = static_cast<uint32_t>(task_info_->node_def().size()); | |||
| rt_ret = | |||
|   rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_len_offset), sizeof(uint32_t), | |||
|            reinterpret_cast<const void *>(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
|   GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret); | |||
|   return false; | |||
| } | |||
| // Memcpy node def | |||
| rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset), | |||
| task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()), | |||
| @@ -41,6 +41,7 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> { | |||
| std::shared_ptr<AicpuTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *args_; | |||
| void *ext_info_; | |||
| void *input_output_addr_; | |||
| }; | |||
| } // namespace model_runner | |||
| @@ -103,9 +103,9 @@ bool CceTask::Distribute() { | |||
| // Modify flowtable addr in args | |||
| auto args = const_cast<uint8_t *>(task_info_->args().data()); | |||
| auto task_offset = reinterpret_cast<uint16_t *>(const_cast<uint8_t *>(task_info_->args_offset().data())); | |||
| if (task_info_->args().size() < (task_offset[0] + sizeof(uint64_t))) { | |||
| GELOGE(FAILED, | |||
| "(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu", | |||
| GELOGE(FAILED, "(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu", | |||
| static_cast<uint32_t>(task_offset[0]), sizeof(uint64_t), task_info_->args().size()); | |||
| return false; | |||
| } | |||
| @@ -136,8 +136,7 @@ bool CceTask::Distribute() { | |||
| return false; | |||
| } | |||
| rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(), | |||
| task_info_->sm_desc().data(), | |||
| rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(), task_info_->sm_desc().data(), | |||
| task_info_->sm_desc().size(), RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| @@ -146,12 +145,8 @@ bool CceTask::Distribute() { | |||
| } | |||
| // Kernel launch | |||
| rt_ret = rtKernelLaunch(stub_func_, | |||
| task_info_->block_dim(), | |||
| args_, | |||
| task_info_->args_size(), | |||
| static_cast<rtSmDesc_t *>(sm_desc_), | |||
| stream_); | |||
| rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, task_info_->args_size(), | |||
| static_cast<rtSmDesc_t *>(sm_desc_), stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| @@ -33,7 +33,7 @@ class EventRecordTask : public TaskRepeater<EventRecordTaskInfo> { | |||
| private: | |||
| std::shared_ptr<EventRecordTaskInfo> task_info_; | |||
| rtStream_t stream_; | |||
| rtEvent_t event_; | |||
| rtEvent_t event_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -33,7 +33,7 @@ class EventWaitTask : public TaskRepeater<EventWaitTaskInfo> { | |||
| private: | |||
| std::shared_ptr<EventWaitTaskInfo> task_info_; | |||
| rtStream_t stream_; | |||
| rtEvent_t event_; | |||
| rtEvent_t event_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -115,7 +115,6 @@ bool HcclTask::Distribute() { | |||
| rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| (void)rtStreamDestroy(stream); | |||
| return false; | |||
| } | |||
| @@ -129,8 +128,6 @@ bool HcclTask::Distribute() { | |||
| ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL); | |||
| ge_task.stream = stream_; | |||
| GETaskKernelHcclInfo kernel_hccl_info; | |||
| ge_task.kernelHcclInfo.emplace_back(kernel_hccl_info); | |||
| ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type(); | |||
| ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr(); | |||
| ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr(); | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/label_goto_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| // Builds a LabelGotoTask and resolves its stream and label handles from the model context. | |||
| // | |||
| // model_context: supplies stream_list() and label_list(), which are indexed by the ids in task_info. | |||
| // task_info:     description of the goto task; may be null, in which case the task stays unusable. | |||
| // | |||
| // The constructor never fails: on a null task_info or an out-of-range id it logs a warning and | |||
| // leaves stream_/label_ as nullptr, which Distribute() rejects later. | |||
| LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||
|     : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||
|       task_info_(task_info), | |||
|       stream_(nullptr), | |||
|       label_(nullptr) { | |||
|   if (task_info_ == nullptr) { | |||
|     GELOGW("task_info_ is null!"); | |||
|     return; | |||
|   } | |||
|   // Bind by const reference: both lists are only read here, so copying them is wasted work. | |||
|   const auto &stream_list = model_context.stream_list(); | |||
|   const auto &label_list = model_context.label_list(); | |||
|   const uint32_t stream_id = task_info_->stream_id(); | |||
|   const uint32_t label_id = task_info_->label_id(); | |||
|   GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
|   GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||
|   // Reject ids that would index past the end of either list. | |||
|   if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
|     GELOGW("Stream/Label id invalid."); | |||
|     return; | |||
|   } | |||
|   stream_ = stream_list[stream_id]; | |||
|   label_ = label_list[label_id]; | |||
| } | |||
| // Nothing to release: stream_ and label_ are taken from the model context's lists, not allocated here. | |||
| LabelGotoTask::~LabelGotoTask() = default; | |||
| // Issues the label-goto on this task's stream through rtLabelGotoEx. | |||
| // Returns true on success; logs and returns false when the stream or label was never resolved | |||
| // or when the runtime call reports an error. | |||
| bool LabelGotoTask::Distribute() { | |||
|   GELOGI("LabelGotoTask Distribute start."); | |||
|   if (!stream_) { | |||
|     GELOGE(PARAM_INVALID, "stream is null!"); | |||
|     return false; | |||
|   } | |||
|   if (!label_) { | |||
|     GELOGE(PARAM_INVALID, "label is null!"); | |||
|     return false; | |||
|   } | |||
|   const rtError_t status = rtLabelGotoEx(label_, stream_); | |||
|   if (status == RT_ERROR_NONE) { | |||
|     GELOGI("DistributeTask end."); | |||
|     return true; | |||
|   } | |||
|   GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", status); | |||
|   return false; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo);  /* Registers LabelGotoTask with the task factory. NOTE(review): presumably keyed by TaskInfoType::LABEL_GOTO - confirm against the REGISTER_TASK macro. */ | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {  /* Task that issues a runtime label-goto on one stream. */ | |||
| public: | |||
| LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info);  /* Resolves stream_ and label_ from model_context using the ids carried by task_info. */ | |||
| ~LabelGotoTask() override; | |||
| bool Distribute() override;  /* Returns true when the runtime call succeeds, false otherwise. */ | |||
| private: | |||
| std::shared_ptr<LabelGotoTaskInfo> task_info_;  /* Task description handed in at construction; may be null. */ | |||
| void *stream_;  /* Stream selected by the task's stream id; nullptr when it could not be resolved. */ | |||
| void *label_;  /* Label selected by the task's label id; nullptr when it could not be resolved. */ | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/label_set_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| // Builds a LabelSetTask and resolves its stream and label handles from the model context. | |||
| // | |||
| // model_context: supplies stream_list() and label_list(), which are indexed by the ids in task_info. | |||
| // task_info:     description of the label-set task; may be null, in which case the task stays unusable. | |||
| // | |||
| // The constructor never fails: on a null task_info or an out-of-range id it logs a warning and | |||
| // leaves stream_/label_ as nullptr, which Distribute() rejects later. | |||
| LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info) | |||
|     : TaskRepeater<LabelSetTaskInfo>(model_context, task_info), | |||
|       task_info_(task_info), | |||
|       stream_(nullptr), | |||
|       label_(nullptr) { | |||
|   if (task_info_ == nullptr) { | |||
|     GELOGW("task_info_ is null!"); | |||
|     return; | |||
|   } | |||
|   // Bind by const reference: both lists are only read here, so copying them is wasted work. | |||
|   const auto &stream_list = model_context.stream_list(); | |||
|   const auto &label_list = model_context.label_list(); | |||
|   const uint32_t stream_id = task_info_->stream_id(); | |||
|   const uint32_t label_id = task_info_->label_id(); | |||
|   GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
|   GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||
|   // Reject ids that would index past the end of either list. | |||
|   if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
|     GELOGW("Stream/Label id invalid."); | |||
|     return; | |||
|   } | |||
|   stream_ = stream_list[stream_id]; | |||
|   label_ = label_list[label_id]; | |||
| } | |||
| // Nothing to release: stream_ and label_ are taken from the model context's lists, not allocated here. | |||
| LabelSetTask::~LabelSetTask() = default; | |||
| // Places the label on this task's stream through rtLabelSet. | |||
| // Returns true on success; logs and returns false when the stream or label was never resolved | |||
| // or when the runtime call reports an error. | |||
| bool LabelSetTask::Distribute() { | |||
|   GELOGI("LabelSetTask Distribute start."); | |||
|   if (!stream_) { | |||
|     GELOGE(PARAM_INVALID, "stream is null!"); | |||
|     return false; | |||
|   } | |||
|   if (!label_) { | |||
|     GELOGE(PARAM_INVALID, "label is null!"); | |||
|     return false; | |||
|   } | |||
|   const rtError_t status = rtLabelSet(label_, stream_); | |||
|   if (status == RT_ERROR_NONE) { | |||
|     GELOGI("DistributeTask end."); | |||
|     return true; | |||
|   } | |||
|   GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", status); | |||
|   return false; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo);  /* Registers LabelSetTask with the task factory. NOTE(review): presumably keyed by TaskInfoType::LABEL_SET - confirm against the REGISTER_TASK macro. */ | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> {  /* Task that sets a runtime label on one stream. */ | |||
| public: | |||
| LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info);  /* Resolves stream_ and label_ from model_context using the ids carried by task_info. */ | |||
| ~LabelSetTask() override; | |||
| bool Distribute() override;  /* Returns true when the runtime call succeeds, false otherwise. */ | |||
| private: | |||
| std::shared_ptr<LabelSetTaskInfo> task_info_;  /* Task description handed in at construction; may be null. */ | |||
| void *stream_;  /* Stream selected by the task's stream id; nullptr when it could not be resolved. */ | |||
| void *label_;  /* Label selected by the task's label id; nullptr when it could not be resolved. */ | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| @@ -0,0 +1,131 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/label_switch_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| // Builds a LabelSwitchTask: snapshots the model's label handles and resolves the task's stream. | |||
| // | |||
| // model_context: supplies label_list() (copied into all_label_resource_) and stream_list(). | |||
| // task_info:     description of the switch task; may be null, in which case the task stays unusable. | |||
| // | |||
| // The constructor never fails: on a null task_info or an out-of-range stream id it logs a warning | |||
| // and leaves stream_ as nullptr, which CheckParamValid() rejects when Distribute() runs. | |||
| LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||
|                                  const std::shared_ptr<LabelSwitchTaskInfo> &task_info) | |||
|     : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | |||
|       task_info_(task_info), | |||
|       stream_(nullptr), | |||
|       all_label_resource_(), | |||
|       label_info_(nullptr) { | |||
|   if (task_info_ == nullptr) { | |||
|     GELOGW("task_info_ is null!"); | |||
|     return; | |||
|   } | |||
|   // The label handles are kept for Distribute(), so this copy is intentional. | |||
|   all_label_resource_ = model_context.label_list(); | |||
|   // The stream list is only indexed once here; bind by const reference instead of copying it. | |||
|   const auto &stream_list = model_context.stream_list(); | |||
|   const uint32_t stream_id = task_info_->stream_id(); | |||
|   GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
|   if (stream_id >= stream_list.size()) { | |||
|     GELOGW("Stream id invalid."); | |||
|     return; | |||
|   } | |||
|   stream_ = stream_list[stream_id]; | |||
| } | |||
| // Releases the label-info buffer that Distribute() obtained from rtMalloc, if any. | |||
| // A failing rtFree is only logged, since a destructor has no way to report it. | |||
| LabelSwitchTask::~LabelSwitchTask() { | |||
|   if (label_info_ != nullptr) { | |||
|     rtError_t rt_ret = rtFree(label_info_); | |||
|     if (rt_ret != RT_ERROR_NONE) { | |||
|       // Name the buffer actually being freed; the previous text referred to an unrelated "fwkOpBuf". | |||
|       GELOGE(RT_FAILED, "rtFree label_info_ failed! ret: 0x%X.", rt_ret); | |||
|     } | |||
|     label_info_ = nullptr; | |||
|   } | |||
| } | |||
| // Distributes the label-switch task in four steps: | |||
| //   1. validate the task state through CheckParamValid(); | |||
| //   2. translate each label index from task_info_ into a label handle from all_label_resource_; | |||
| //   3. allocate label_info_ with rtMalloc and fill it through rtLabelListCpy; | |||
| //   4. issue rtLabelSwitchByIndex on stream_ using task_info_->cond(). | |||
| // Returns true on success; logs and returns false on any validation or runtime failure. | |||
| // label_info_ stays allocated after a failure in step 3 or 4 and is released by the destructor. | |||
| bool LabelSwitchTask::Distribute() { | |||
|   GELOGI("LabelSwitchTask Distribute start."); | |||
|   if (!CheckParamValid()) { | |||
|     return false; | |||
|   } | |||
|   const std::vector<uint32_t> &label_index_list = task_info_->label_list(); | |||
|   std::vector<void *> label_list(task_info_->label_size(), nullptr); | |||
|   // CheckParamValid() guarantees label_size() == label_index_list.size(), so indexing by i is in range. | |||
|   for (size_t i = 0; i < task_info_->label_size(); ++i) { | |||
|     uint32_t label_index = label_index_list[i]; | |||
|     if (label_index >= all_label_resource_.size()) { | |||
|       GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, | |||
|              all_label_resource_.size()); | |||
|       return false; | |||
|     } | |||
|     label_list[i] = all_label_resource_[label_index]; | |||
|     // label_index is a uint32_t, so it must be printed with %u; %zu expects a size_t. | |||
|     GELOGI("Case %zu: label id %u.", i, label_index); | |||
|   } | |||
|   // CheckParamValid() already rejected label counts that would overflow this 32-bit size. | |||
|   uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | |||
|   rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||
|   if (rt_ret != RT_ERROR_NONE) { | |||
|     GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
|     return false; | |||
|   } | |||
|   rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||
|   if (rt_ret != RT_ERROR_NONE) { | |||
|     GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
|     return false; | |||
|   } | |||
|   rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); | |||
|   if (rt_ret != RT_ERROR_NONE) { | |||
|     GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
|     return false; | |||
|   } | |||
|   GELOGI("DistributeTask end."); | |||
|   return true; | |||
| } | |||
| // Verifies the state Distribute() depends on and logs the first violation found. | |||
| // Returns true only when the stream is resolved, the label list is non-empty and matches | |||
| // label_size(), the label-info byte size fits in 32 bits, and label_info_ is not yet allocated. | |||
| bool LabelSwitchTask::CheckParamValid() { | |||
|   // The stream check comes first: the constructor leaves stream_ null whenever task_info_ is null, | |||
|   // so passing it means task_info_ can be dereferenced below. | |||
|   if (stream_ == nullptr) { | |||
|     GELOGE(PARAM_INVALID, "stream is null!"); | |||
|     return false; | |||
|   } | |||
|   const auto &label_indices = task_info_->label_list(); | |||
|   const auto label_count = task_info_->label_size(); | |||
|   if (label_indices.empty()) { | |||
|     GELOGE(PARAM_INVALID, "label_list is empty."); | |||
|     return false; | |||
|   } | |||
|   if (label_count != label_indices.size()) { | |||
|     GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", label_indices.size(), label_count); | |||
|     return false; | |||
|   } | |||
|   // Distribute() computes sizeof(rtLabelDevInfo) * label_size() in a uint32_t. | |||
|   if (label_count >= UINT32_MAX / sizeof(rtLabelDevInfo)) { | |||
|     GELOGE(PARAM_INVALID, "label_size %u will overflow.", label_count); | |||
|     return false; | |||
|   } | |||
|   // A non-null label_info_ means an earlier Distribute() already allocated the buffer. | |||
|   if (label_info_ != nullptr) { | |||
|     GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||
|     return false; | |||
|   } | |||
|   return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo);  /* Registers LabelSwitchTask with the task factory. NOTE(review): presumably keyed by TaskInfoType::LABEL_SWITCH - confirm against the REGISTER_TASK macro. */ | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> {  /* Task that issues a runtime switch-by-index over a list of labels on one stream. */ | |||
| public: | |||
| LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info);  /* Copies the model's label list and resolves stream_ from the stream id in task_info. */ | |||
| ~LabelSwitchTask() override;  /* Frees label_info_ with rtFree when it was allocated. */ | |||
| bool Distribute() override;  /* Returns true when validation and all runtime calls succeed, false otherwise. */ | |||
| private: | |||
| bool CheckParamValid();  /* Precondition check run at the start of Distribute(). */ | |||
| std::shared_ptr<LabelSwitchTaskInfo> task_info_;  /* Task description handed in at construction; may be null. */ | |||
| void *stream_;  /* Stream selected by the task's stream id; nullptr when it could not be resolved. */ | |||
| std::vector<void *> all_label_resource_;  /* Copy of the model context's label list, indexed by label id. */ | |||
| void *label_info_;  /* Buffer obtained from rtMalloc in Distribute(); nullptr until then. */ | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| @@ -37,6 +37,7 @@ class StreamSwitchTask : public TaskRepeater<StreamSwitchTaskInfo> { | |||
| void *stream_; | |||
| std::vector<rtStream_t> stream_list_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ | |||
| @@ -42,7 +42,7 @@ class Task { | |||
| template <class T> | |||
| class TaskRepeater : public Task { | |||
| static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!"); /*lint !e30*/ | |||
| static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!"); | |||
| public: | |||
| TaskRepeater(const ModelContext &model_context, std::shared_ptr<T> task_info) {} | |||
| @@ -81,6 +81,7 @@ class TaskFactory { | |||
| std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info); \ | |||
| return std::make_shared<task_clazz>(model_context, concrete_task_info); \ | |||
| }); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_TASK_FACTORY_H_ | |||
| @@ -27,10 +27,10 @@ namespace ge { | |||
| namespace model_runner { | |||
| class DavinciModel { | |||
| public: | |||
| DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list, /*lint !e151*/ | |||
| DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list, | |||
| const std::vector<std::shared_ptr<OpInfo>> &data_info_list, | |||
| const std::vector<std::shared_ptr<OpInfo>> &output_info_list, /*lint !e151*/ | |||
| const std::vector<std::shared_ptr<OpInfo>> &constant_info_list, /*lint !e1049*/ | |||
| const std::vector<std::shared_ptr<OpInfo>> &output_info_list, | |||
| const std::vector<std::shared_ptr<OpInfo>> &constant_info_list, | |||
| const std::vector<model_runner::OpInfoPtr> &variable_info_list, | |||
| const std::vector<uint32_t> &wait_active_stream_list, | |||
| const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0, | |||
| @@ -68,12 +68,12 @@ class DavinciModel { | |||
| uint32_t GetBatchNum() const { return batch_num_; } | |||
| uint32_t GetEventNum() const { return event_num_; } | |||
| const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; } /*lint !e1413*/ | |||
| const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; } /*lint !e1413*/ | |||
| const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; } | |||
| const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; } | |||
| int32_t GetPriority() const { return priority_; } | |||
| const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; } /*lint !e151*/ | |||
| const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; } | |||
| const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; } | |||
| const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; } | |||
| // Returns the constant op infos; this previously returned output_info_list_ by mistake. | |||
| const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return constant_info_list_; } | |||
| @@ -81,7 +81,7 @@ class DavinciModel { | |||
| private: | |||
| std::vector<std::shared_ptr<TaskInfo>> task_info_list_; | |||
| std::vector<std::shared_ptr<OpInfo>> data_info_list_; /*lint !e151*/ | |||
| std::vector<std::shared_ptr<OpInfo>> data_info_list_; | |||
| std::vector<std::shared_ptr<OpInfo>> output_info_list_; | |||
| std::vector<std::shared_ptr<OpInfo>> constant_info_list_; | |||
| std::vector<model_runner::OpInfoPtr> variable_info_list_; | |||
| @@ -52,11 +52,8 @@ class ModelRunner { | |||
| bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | |||
| bool GetInputOutputDescInfo(uint32_t model_id, | |||
| bool zero_copy, | |||
| std::vector<InputOutputDescInfo> *input_desc, | |||
| std::vector<InputOutputDescInfo> *output_desc, | |||
| std::vector<uint32_t> *input_format, | |||
| bool GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||
| std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format, | |||
| std::vector<uint32_t> *output_format); | |||
| private: | |||
| @@ -161,12 +161,13 @@ class TbeTaskInfo : public TaskInfo { | |||
| class AicpuTaskInfo : public TaskInfo { | |||
| public: | |||
| AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, | |||
| const std::string &node_def, const std::vector<void *> &input_data_addrs, | |||
| const std::string &node_def, const std::string &ext_info, const std::vector<void *> &input_data_addrs, | |||
| const std::vector<void *> &output_data_addrs, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), | |||
| so_name_(so_name), | |||
| kernel_name_(kernel_name), | |||
| node_def_(node_def), | |||
| ext_info_(ext_info), | |||
| input_data_addrs_(input_data_addrs), | |||
| output_data_addrs_(output_data_addrs) {} | |||
| ~AicpuTaskInfo() override {} | |||
| @@ -176,11 +177,13 @@ class AicpuTaskInfo : public TaskInfo { | |||
| const std::string &node_def() const { return node_def_; } | |||
| const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; } | |||
| const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; } | |||
| const std::string &ext_info() const { return ext_info_; } | |||
| private: | |||
| std::string so_name_; | |||
| std::string kernel_name_; | |||
| std::string node_def_; | |||
| std::string ext_info_; | |||
| std::vector<void *> input_data_addrs_; | |||
| std::vector<void *> output_data_addrs_; | |||
| }; | |||
| @@ -293,19 +296,19 @@ class HcclTaskInfo : public TaskInfo { | |||
| hcom_distribute_task_(hcom_distribute_task) {} | |||
| ~HcclTaskInfo() override {} | |||
| const std::string &hccl_type() const { return hccl_type_; } /*lint !e1413*/ | |||
| const std::string &hccl_type() const { return hccl_type_; } | |||
| void *input_data_addr() const { return input_data_addr_; } | |||
| void *output_data_addr() const { return output_data_addr_; } | |||
| void *workspace_addr() const { return workspace_addr_; } | |||
| int64_t workspace_size() const { return workspace_size_; } | |||
| int64_t hccl_stream_num() const { return hccl_stream_num_; } | |||
| const std::vector<uint8_t> &private_def() const { return private_def_; } /*lint !e1413*/ | |||
| const std::vector<uint8_t> &private_def() const { return private_def_; } | |||
| void *ops_kernel_store() const { return ops_kernel_store_; } | |||
| int32_t count() const { return count_; } | |||
| int64_t root_id() const { return root_id_; } | |||
| int64_t op_type() const { return op_type_; } | |||
| int64_t data_type() const { return data_type_; } | |||
| const std::string group() const { return group_; } | |||
| const std::string &group() const { return group_; } | |||
| std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; } | |||
| std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; } | |||
| std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const { | |||