Merge pull request !1880 from mindspore_ding/code_sync_0626tags/v1.3.0
| @@ -1,8 +1,8 @@ | |||||
| [submodule "parser"] | [submodule "parser"] | ||||
| path = parser | path = parser | ||||
| url = https://gitee.com/ascend/parser.git | url = https://gitee.com/ascend/parser.git | ||||
| branch = master | |||||
| branch = r1.5.0 | |||||
| [submodule "metadef"] | [submodule "metadef"] | ||||
| path = metadef | path = metadef | ||||
| url = https://gitee.com/ascend/metadef.git | url = https://gitee.com/ascend/metadef.git | ||||
| branch = master | |||||
| branch = r1.5.0 | |||||
| @@ -95,6 +95,7 @@ else () | |||||
| #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | ||||
| else() | else() | ||||
| find_module(slog libalog.so ${ASCEND_ATC_DIR}) | find_module(slog libalog.so ${ASCEND_ATC_DIR}) | ||||
| find_module(opt_feature libopt_feature.so ${ASCEND_ATC_DIR}) | |||||
| find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) | find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) | ||||
| if(PLATFORM STREQUAL "train") | if(PLATFORM STREQUAL "train") | ||||
| find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) | find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) | ||||
| @@ -355,13 +355,13 @@ generate_package() | |||||
| if [ "x${PLATFORM}" = "xtrain" ] | if [ "x${PLATFORM}" = "xtrain" ] | ||||
| then | then | ||||
| tar -cf graphengine_lib.tar fwkacllib | |||||
| tar -zcf graphengine_lib.tar fwkacllib | |||||
| elif [ "x${PLATFORM}" = "xinference" ] | elif [ "x${PLATFORM}" = "xinference" ] | ||||
| then | then | ||||
| tar -cf graphengine_lib.tar acllib atc | |||||
| tar -zcf graphengine_lib.tar acllib atc | |||||
| elif [ "x${PLATFORM}" = "xall" ] | elif [ "x${PLATFORM}" = "xall" ] | ||||
| then | then | ||||
| tar -cf graphengine_lib.tar fwkacllib acllib atc | |||||
| tar -zcf graphengine_lib.tar fwkacllib acllib atc | |||||
| fi | fi | ||||
| } | } | ||||
| @@ -371,6 +371,6 @@ elif [ "X$MINDSPORE_MODE" = "Xon" ] | |||||
| then | then | ||||
| cd "${OUTPUT_PATH}" | cd "${OUTPUT_PATH}" | ||||
| find ./ -name graphengine_lib.tar -exec rm {} \; | find ./ -name graphengine_lib.tar -exec rm {} \; | ||||
| tar -cf graphengine_lib.tar lib | |||||
| tar -zcf graphengine_lib.tar lib | |||||
| fi | fi | ||||
| echo "---------------- GraphEngine package archive generated ----------------" | echo "---------------- GraphEngine package archive generated ----------------" | ||||
| @@ -11,14 +11,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | ||||
| endif() | endif() | ||||
| if (GE_PB_PKG) | if (GE_PB_PKG) | ||||
| set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz") | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz") | |||||
| else() | else() | ||||
| if (ENABLE_GITEE) | if (ENABLE_GITEE) | ||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") | |||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz") | |||||
| set(MD5 "f4489cb88922ad9c58cbe3308d59cee5") | |||||
| else() | else() | ||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "3d9e32700639618a4d2d342c99d4507a") | |||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz") | |||||
| set(MD5 "1a6274bc4a65b55a6fa70e264d796490") | |||||
| endif () | endif () | ||||
| endif() | endif() | ||||
| @@ -58,7 +58,7 @@ target_include_directories(ascend_protobuf INTERFACE ${PROTOBUF_SHARED_PKG_DIR}/ | |||||
| set(INSTALL_BASE_DIR "") | set(INSTALL_BASE_DIR "") | ||||
| set(INSTALL_LIBRARY_DIR lib) | set(INSTALL_LIBRARY_DIR lib) | ||||
| install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.8.0.0 OPTIONAL | |||||
| install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.13.0.0 OPTIONAL | |||||
| DESTINATION ${INSTALL_LIBRARY_DIR}) | DESTINATION ${INSTALL_LIBRARY_DIR}) | ||||
| install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so OPTIONAL | install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so OPTIONAL | ||||
| DESTINATION ${INSTALL_LIBRARY_DIR}) | DESTINATION ${INSTALL_LIBRARY_DIR}) | ||||
| @@ -13,14 +13,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| endif() | endif() | ||||
| if(GE_PB_PKG) | if(GE_PB_PKG) | ||||
| set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz") | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz") | |||||
| else() | else() | ||||
| if (ENABLE_GITEE) | if (ENABLE_GITEE) | ||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") | |||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz") | |||||
| set(MD5 "f4489cb88922ad9c58cbe3308d59cee5") | |||||
| else() | else() | ||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "3d9e32700639618a4d2d342c99d4507a") | |||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz") | |||||
| set(MD5 "1a6274bc4a65b55a6fa70e264d796490") | |||||
| endif () | endif () | ||||
| endif() | endif() | ||||
| @@ -29,8 +29,6 @@ set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | |||||
| set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static) | set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static) | ||||
| ExternalProject_Add(protobuf_static_build | ExternalProject_Add(protobuf_static_build | ||||
| URL ${REQ_URL} | URL ${REQ_URL} | ||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||||
| #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 | |||||
| TLS_VERIFY OFF | TLS_VERIFY OFF | ||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} | CONFIGURE_COMMAND ${CMAKE_COMMAND} | ||||
| -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} | -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} | ||||
| @@ -13,14 +13,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| endif() | endif() | ||||
| if(GE_PB_PKG) | if(GE_PB_PKG) | ||||
| set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz") | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz") | |||||
| else() | else() | ||||
| if (ENABLE_GITEE) | if (ENABLE_GITEE) | ||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") | |||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz") | |||||
| set(MD5 "f4489cb88922ad9c58cbe3308d59cee5") | |||||
| else() | else() | ||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "3d9e32700639618a4d2d342c99d4507a") | |||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz") | |||||
| set(MD5 "1a6274bc4a65b55a6fa70e264d796490") | |||||
| endif () | endif () | ||||
| endif() | endif() | ||||
| @@ -28,8 +28,6 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst | |||||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | ||||
| ExternalProject_Add(protoc_build | ExternalProject_Add(protoc_build | ||||
| URL ${REQ_URL} | URL ${REQ_URL} | ||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||||
| #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 | |||||
| TLS_VERIFY OFF | TLS_VERIFY OFF | ||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake | CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake | ||||
| BUILD_COMMAND $(MAKE) | BUILD_COMMAND $(MAKE) | ||||
| @@ -174,6 +174,7 @@ set(TRAIN_SRC_LIST | |||||
| "graph/load/model_manager/task_info/model_exit_task_info.cc" | "graph/load/model_manager/task_info/model_exit_task_info.cc" | ||||
| "graph/load/model_manager/task_info/event_record_task_info.cc" | "graph/load/model_manager/task_info/event_record_task_info.cc" | ||||
| "graph/load/model_manager/task_info/event_wait_task_info.cc" | "graph/load/model_manager/task_info/event_wait_task_info.cc" | ||||
| "graph/load/model_manager/task_info/ffts_task_info.cc" | |||||
| "graph/load/model_manager/task_info/fusion_start_task_info.cc" | "graph/load/model_manager/task_info/fusion_start_task_info.cc" | ||||
| "graph/load/model_manager/task_info/fusion_stop_task_info.cc" | "graph/load/model_manager/task_info/fusion_stop_task_info.cc" | ||||
| "graph/load/model_manager/task_info/hccl_task_info.cc" | "graph/load/model_manager/task_info/hccl_task_info.cc" | ||||
| @@ -433,6 +434,7 @@ set(TRAIN_SRC_LIST | |||||
| "graph/build/memory/max_block_mem_assigner.cc" | "graph/build/memory/max_block_mem_assigner.cc" | ||||
| "graph/build/memory/var_mem_assign_util.cc" | "graph/build/memory/var_mem_assign_util.cc" | ||||
| "graph/build/memory/buffer_pool_mem_assigner.cc" | "graph/build/memory/buffer_pool_mem_assigner.cc" | ||||
| "ge_opt_info/ge_opt_info.cc" | |||||
| ) | ) | ||||
| set(INFER_SRC_LIST | set(INFER_SRC_LIST | ||||
| @@ -662,6 +664,7 @@ set(INFER_SRC_LIST | |||||
| "graph/load/model_manager/task_info/task_info.cc" | "graph/load/model_manager/task_info/task_info.cc" | ||||
| "graph/load/model_manager/task_info/event_record_task_info.cc" | "graph/load/model_manager/task_info/event_record_task_info.cc" | ||||
| "graph/load/model_manager/task_info/event_wait_task_info.cc" | "graph/load/model_manager/task_info/event_wait_task_info.cc" | ||||
| "graph/load/model_manager/task_info/ffts_task_info.cc" | |||||
| "graph/load/model_manager/task_info/fusion_start_task_info.cc" | "graph/load/model_manager/task_info/fusion_start_task_info.cc" | ||||
| "graph/load/model_manager/task_info/fusion_stop_task_info.cc" | "graph/load/model_manager/task_info/fusion_stop_task_info.cc" | ||||
| "graph/load/model_manager/task_info/kernel_ex_task_info.cc" | "graph/load/model_manager/task_info/kernel_ex_task_info.cc" | ||||
| @@ -709,6 +712,7 @@ set(INFER_SRC_LIST | |||||
| "graph/build/memory/max_block_mem_assigner.cc" | "graph/build/memory/max_block_mem_assigner.cc" | ||||
| "graph/build/memory/var_mem_assign_util.cc" | "graph/build/memory/var_mem_assign_util.cc" | ||||
| "graph/build/memory/buffer_pool_mem_assigner.cc" | "graph/build/memory/buffer_pool_mem_assigner.cc" | ||||
| "ge_opt_info/ge_opt_info.cc" | |||||
| ) | ) | ||||
| if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) | if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) | ||||
| @@ -770,11 +774,13 @@ target_include_directories(ge_runner SYSTEM PRIVATE | |||||
| ${GE_CODE_DIR}/../inc/cce | ${GE_CODE_DIR}/../inc/cce | ||||
| ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | ||||
| ${GE_CODE_DIR}/../abl/adump/external | ${GE_CODE_DIR}/../abl/adump/external | ||||
| ${GE_CODE_DIR}/../abl/licctrl | |||||
| #### blue zone | #### blue zone | ||||
| ${ASCEND_DIR}/driver/include | ${ASCEND_DIR}/driver/include | ||||
| ${ASCEND_DIR}/fwkacllib/include | ${ASCEND_DIR}/fwkacllib/include | ||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | ${GE_CODE_DIR}/third_party/fwkacllib/inc | ||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info | |||||
| ) | ) | ||||
| target_link_options(ge_runner PRIVATE | target_link_options(ge_runner PRIVATE | ||||
| @@ -797,6 +803,7 @@ target_link_libraries(ge_runner PRIVATE | |||||
| runtime | runtime | ||||
| error_manager | error_manager | ||||
| ascend_hal_stub | ascend_hal_stub | ||||
| opt_feature | |||||
| -Wl,--as-needed | -Wl,--as-needed | ||||
| json | json | ||||
| -lrt | -lrt | ||||
| @@ -851,11 +858,13 @@ target_include_directories(ge_compiler SYSTEM PRIVATE | |||||
| ${GE_CODE_DIR}/../inc/cce | ${GE_CODE_DIR}/../inc/cce | ||||
| ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | ||||
| ${GE_CODE_DIR}/../abl/adump/external | ${GE_CODE_DIR}/../abl/adump/external | ||||
| ${GE_CODE_DIR}/../abl/licctrl | |||||
| #### blue zone #### | #### blue zone #### | ||||
| ${ASCEND_DIR}/driver/include | ${ASCEND_DIR}/driver/include | ||||
| ${ASCEND_DIR}/fwkacllib/include | ${ASCEND_DIR}/fwkacllib/include | ||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | ${GE_CODE_DIR}/third_party/fwkacllib/inc | ||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info | |||||
| ) | ) | ||||
| target_link_options(ge_compiler PRIVATE | target_link_options(ge_compiler PRIVATE | ||||
| @@ -875,6 +884,7 @@ target_link_libraries(ge_compiler PRIVATE | |||||
| error_manager | error_manager | ||||
| slog | slog | ||||
| runtime_compile | runtime_compile | ||||
| opt_feature | |||||
| -Wl,--as-needed | -Wl,--as-needed | ||||
| json | json | ||||
| -lrt | -lrt | ||||
| @@ -1 +0,0 @@ | |||||
| ../../proto/ge_api.proto | |||||
| @@ -1,193 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package ge.proto; | |||||
| enum DataType | |||||
| { | |||||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||||
| DT_FLOAT = 1; // float type | |||||
| DT_FLOAT16 = 2; // fp16 type | |||||
| DT_INT8 = 3; // int8 type | |||||
| DT_UINT8 = 4; // uint8 type | |||||
| DT_INT16 = 5; // int16 type | |||||
| DT_UINT16 = 6; // uint16 type | |||||
| DT_INT32 = 7; // | |||||
| DT_INT64 = 8; // int64 type | |||||
| DT_UINT32 = 9; // unsigned int32 | |||||
| DT_UINT64 = 10; // unsigned int64 | |||||
| DT_BOOL = 11; // bool type | |||||
| DT_DOUBLE = 12; // double type | |||||
| DT_STRING = 13; // string type | |||||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||||
| DT_COMPLEX64 = 16; // complex64 type | |||||
| DT_COMPLEX128 = 17; // complex128 type | |||||
| DT_QINT8 = 18; // qint8 type | |||||
| DT_QINT16 = 19; // qint16 type | |||||
| DT_QINT32 = 20; // qint32 type | |||||
| DT_QUINT8 = 21; // quint8 type | |||||
| DT_QUINT16 = 22; // quint16 type | |||||
| DT_RESOURCE = 23; // resource type | |||||
| DT_STRING_REF = 24; // string_ref type | |||||
| DT_DUAL = 25; /**< dual output type */ | |||||
| DT_VARIANT = 26; // variant type | |||||
| DT_BF16 = 27; // bf16 type | |||||
| DT_INT4 = 28; // int4 type | |||||
| } | |||||
| message AttrDef | |||||
| { | |||||
| message ListValue | |||||
| { | |||||
| enum ListValueType{ | |||||
| VT_LIST_NONE = 0; | |||||
| VT_LIST_STRING = 1; | |||||
| VT_LIST_INT = 2; | |||||
| VT_LIST_FLOAT = 3; | |||||
| VT_LIST_BOOL = 4; | |||||
| VT_LIST_BYTES = 5; | |||||
| VT_LIST_TENSOR_DESC = 6; | |||||
| VT_LIST_TENSOR = 7; | |||||
| VT_LIST_GRAPH = 8; | |||||
| VT_LIST_NAMED_ATTRS = 9; | |||||
| VT_LIST_DATA_TYPE = 10; | |||||
| } | |||||
| repeated bytes s = 2; // "list(string)" | |||||
| repeated int64 i = 3; // "list(int)" | |||||
| repeated float f = 4; // "list(float)" | |||||
| repeated bool b = 5; // "list(bool)" | |||||
| repeated bytes bt = 7; | |||||
| repeated TensorDescriptor td = 8; | |||||
| repeated TensorDef t = 9; | |||||
| repeated GraphDef g = 10; | |||||
| repeated NamedAttrs na = 11; | |||||
| repeated int64 dt = 12; // list ge::DataType | |||||
| ListValueType val_type = 20; | |||||
| } | |||||
| message ListListInt{ | |||||
| message ListInt{ | |||||
| repeated int64 list_i = 1; // list int | |||||
| } | |||||
| repeated ListInt list_list_i = 1; // list list int | |||||
| } | |||||
| oneof value | |||||
| { | |||||
| bytes s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| bytes bt = 7; | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; // Used to support attr nesting | |||||
| TensorDescriptor td = 11; // GeTensorDesc type | |||||
| TensorDef t = 12; // GeTensor type | |||||
| GraphDef g = 13; // Graph type | |||||
| ListListInt list_list_int = 14; // List List Int type | |||||
| int64 dt = 15; // ge::DataType | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs | |||||
| { | |||||
| string name = 1; | |||||
| map<string, AttrDef> attr = 2; | |||||
| } | |||||
| // Shape / dimension description, using row-major order | |||||
| message ShapeDef | |||||
| { | |||||
| repeated int64 dim = 1; // Size of each dimension | |||||
| } | |||||
| // Multidimensional data description | |||||
| message TensorDescriptor | |||||
| { | |||||
| string name = 1; // Optional parameter, tensor name | |||||
| DataType dtype = 2; // tensor datatype | |||||
| ShapeDef shape = 3; // Shape / dimension | |||||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||||
| bool has_out_attr = 9; | |||||
| int64 size = 10; | |||||
| int64 weight_size = 11; | |||||
| bool reuse_input = 12; | |||||
| bool output_tensor = 13; | |||||
| string device_type = 14; | |||||
| bool input_tensor =15; | |||||
| int64 real_dim_cnt = 16; | |||||
| int64 reuse_input_index = 17; | |||||
| int64 data_offset = 18; | |||||
| int64 cmps_size = 19; | |||||
| string cmps_tab = 20; | |||||
| int64 cmps_tab_offset = 21; | |||||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||||
| } | |||||
| // GeTensor definition | |||||
| message TensorDef | |||||
| { | |||||
| TensorDescriptor desc = 1; // Tensor description | |||||
| bytes data = 2; // Tensor data | |||||
| } | |||||
| // Operator description | |||||
| message OpDef | |||||
| { | |||||
| string name = 1; // name | |||||
| string type = 2; // type | |||||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||||
| bool has_out_attr = 20; | |||||
| int64 id = 21; | |||||
| int64 stream_id =22; | |||||
| repeated string input_name = 23; | |||||
| repeated string src_name = 24; | |||||
| repeated int64 src_index = 25; | |||||
| repeated string dst_name = 26; | |||||
| repeated int64 dst_index = 27; | |||||
| repeated int64 input_i = 28; | |||||
| repeated int64 output_i = 29; | |||||
| repeated int64 workspace = 30; | |||||
| repeated int64 workspace_bytes = 31; | |||||
| repeated bool is_input_const = 32; | |||||
| repeated TensorDescriptor input_desc = 33; | |||||
| repeated TensorDescriptor output_desc = 34; | |||||
| repeated string subgraph_name = 35; | |||||
| } | |||||
| // Graph definition | |||||
| message GraphDef | |||||
| { | |||||
| string name = 1; // name | |||||
| repeated string input = 4; // Graph input | |||||
| repeated string output = 5; // Graph output | |||||
| repeated OpDef op = 6; // List of operators | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| // model definition | |||||
| message ModelDef | |||||
| { | |||||
| string name = 1; // name | |||||
| uint32 version = 2; // IR Proto verion | |||||
| string custom_version = 3; // User model version number, passed in by user | |||||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| @@ -1,140 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| message InsertNewOps { | |||||
| repeated AippOpParams aipp_op = 1; | |||||
| repeated MultiShapeOpParams multi_shape_op = 2; | |||||
| } | |||||
| message AippOpParams { | |||||
| enum InputFormat { | |||||
| UNDEFINED = 0; | |||||
| YUV420SP_U8 = 1; | |||||
| XRGB8888_U8 = 2; | |||||
| RGB888_U8 = 3; | |||||
| YUV400_U8 = 4; | |||||
| NC1HWC0DI_FP16 = 5; | |||||
| NC1HWC0DI_S8 = 6; | |||||
| ARGB8888_U8 = 7; | |||||
| YUYV_U8 = 8; | |||||
| YUV422SP_U8 = 9; | |||||
| AYUV444_U8 = 10; | |||||
| RAW10 = 11; | |||||
| RAW12 = 12; | |||||
| RAW16 = 13; | |||||
| RAW24 = 14; | |||||
| RGB16 = 15; | |||||
| RGB20 = 16; | |||||
| RGB24 = 17; | |||||
| RGB8_IR = 18; | |||||
| RGB16_IR = 19; | |||||
| RGB24_IR = 20; | |||||
| } | |||||
| enum AippMode { | |||||
| undefined = 0; | |||||
| static = 1; | |||||
| dynamic = 2; | |||||
| } | |||||
| // AIPP模式,区分静态AIPP和动态AIPP | |||||
| AippMode aipp_mode = 1; | |||||
| // related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。 | |||||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | |||||
| uint32 related_input_rank = 2; | |||||
| // related_input_name is optional and the top name of data node which inserts aipp | |||||
| string related_input_name = 6; | |||||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | |||||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | |||||
| // 配置值 <= Data算子输出边的个数。 | |||||
| repeated uint32 input_edge_idx = 3; | |||||
| // [Begin] 动态AIPP参数,配置静态AIPP时无效 | |||||
| uint32 max_src_image_size = 4; | |||||
| // 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失 | |||||
| bool support_rotation = 5; | |||||
| // [End] 动态AIPP参数 | |||||
| // [Begin] 静态AIPP参数,配置动态AIPP时无效 | |||||
| InputFormat input_format = 51; | |||||
| bool csc_switch = 52; | |||||
| float cpadding_value = 53; | |||||
| bool rbuv_swap_switch = 54; | |||||
| bool ax_swap_switch = 55; | |||||
| bool single_line_mode = 56; | |||||
| int32 src_image_size_w = 57; | |||||
| int32 src_image_size_h = 58; | |||||
| bool crop = 59; | |||||
| int32 load_start_pos_w = 60; | |||||
| int32 load_start_pos_h = 61; | |||||
| int32 crop_size_w = 62; | |||||
| int32 crop_size_h = 63; | |||||
| bool resize = 64; | |||||
| int32 resize_output_w = 65; | |||||
| int32 resize_output_h = 66; | |||||
| bool padding = 67; | |||||
| int32 left_padding_size = 68; | |||||
| int32 right_padding_size = 69; | |||||
| int32 top_padding_size = 70; | |||||
| int32 bottom_padding_size = 71; | |||||
| float padding_value = 72; | |||||
| int32 mean_chn_0 = 10; | |||||
| int32 mean_chn_1 = 11; | |||||
| int32 mean_chn_2 = 12; | |||||
| int32 mean_chn_3 = 19; | |||||
| float min_chn_0 = 13; | |||||
| float min_chn_1 = 14; | |||||
| float min_chn_2 = 15; | |||||
| float min_chn_3 = 20; | |||||
| repeated float var_reci_chn_0 = 16; | |||||
| repeated float var_reci_chn_1 = 17; | |||||
| repeated float var_reci_chn_2 = 18; | |||||
| repeated float var_reci_chn_3 = 21; | |||||
| repeated int32 matrix_r0c0 = 30; | |||||
| repeated int32 matrix_r0c1 = 31; | |||||
| repeated int32 matrix_r0c2 = 32; | |||||
| repeated int32 matrix_r1c0 = 33; | |||||
| repeated int32 matrix_r1c1 = 34; | |||||
| repeated int32 matrix_r1c2 = 35; | |||||
| repeated int32 matrix_r2c0 = 36; | |||||
| repeated int32 matrix_r2c1 = 37; | |||||
| repeated int32 matrix_r2c2 = 38; | |||||
| repeated int32 output_bias_0 = 39; | |||||
| repeated int32 output_bias_1 = 40; | |||||
| repeated int32 output_bias_2 = 41; | |||||
| repeated int32 input_bias_0 = 42; | |||||
| repeated int32 input_bias_1 = 43; | |||||
| repeated int32 input_bias_2 = 44; | |||||
| // [End] 静态AIPP参数 | |||||
| // The n number that is used for raw/rgbir data into f16 transformation. | |||||
| // The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||||
| uint32 raw_rgbir_to_f16_n = 45; | |||||
| } | |||||
| message MultiShapeOpParams { | |||||
| enum MultiShapeMode { | |||||
| batch = 0; //动态batch | |||||
| resolution = 1; //动态分辨率,扩展用 | |||||
| } | |||||
| MultiShapeMode mode = 1; //算子模式 | |||||
| uint32 related_input_rank = 2; //新增算子插入到哪个输入 | |||||
| repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间 | |||||
| } | |||||
| @@ -1,396 +0,0 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| enum TargetType | |||||
| { | |||||
| MINI = 0; | |||||
| TINY = 1; | |||||
| LITE = 2; | |||||
| } | |||||
| // offline model | |||||
| message ModelDef { | |||||
| string name = 1; | |||||
| uint32 version = 2; | |||||
| uint64 memory_size = 10; | |||||
| uint32 stream_num = 11; | |||||
| uint32 event_num = 12; | |||||
| uint64 weight_size = 13; | |||||
| uint32 label_num = 15; | |||||
| repeated OpDef op = 20; | |||||
| TargetType target_type = 23; | |||||
| map<string, AttrDef> attr = 30; | |||||
| }; | |||||
| // operator define | |||||
| message OpDef { | |||||
| string name = 1; | |||||
| string type = 2; | |||||
| uint32 id = 3; | |||||
| uint32 stream_id = 4; | |||||
| repeated string input_name = 5; | |||||
| repeated string src_name = 8; | |||||
| repeated int32 src_index = 9; | |||||
| repeated int64 input = 10; | |||||
| repeated int64 output = 11; | |||||
| repeated TensorDescriptor input_desc = 12; | |||||
| repeated TensorDescriptor output_desc = 13; | |||||
| repeated WeightDef weights = 14; | |||||
| repeated string dst_name = 15; | |||||
| repeated int32 dst_index = 16; | |||||
| repeated int64 workspace = 20; | |||||
| repeated uint32 workspace_bytes = 21; | |||||
| repeated string weight_name = 22; | |||||
| repeated bool is_input_const = 23; | |||||
| map<string, AttrDef> attr = 30; | |||||
| QuantizeFactorParams quantize_factor = 31; | |||||
| oneof op_params { | |||||
| // start at 100 here | |||||
| SendOpParams sender_param = 100; | |||||
| RecvOpParams receiver_param = 200; | |||||
| ConvolutionOpParams convolution_param = 300; | |||||
| PoolingOpParams pooling_param = 400; | |||||
| EltwiseOpParams eltwise_param = 500; | |||||
| BatchNormOpParams batchnorm_param = 600; | |||||
| ScaleOpParams scale_param = 700; | |||||
| FullConnectionOpParams full_connection_param = 800; | |||||
| SoftmaxOpParams softmax_param = 900; | |||||
| ActivationOpParams activation_param = 1000; | |||||
| ReshapeOpParams reshape_param = 1100; | |||||
| } | |||||
| }; | |||||
| message SendOpParams { | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| message RecvOpParams { | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| enum QuantizeScaleType | |||||
| { | |||||
| VECTOR_SCALE = 0; | |||||
| SCALAR_SCALE = 1; | |||||
| } | |||||
| enum QuantizeScaleMode | |||||
| { | |||||
| NORMAL_MODE = 0; | |||||
| SQRT_MODE = 1; | |||||
| } | |||||
| enum QuantizeAlgorithm | |||||
| { | |||||
| NON_OFFSET_ALGO = 0; | |||||
| HALF_OFFSET_ALGO = 1; | |||||
| ALL_OFFSET_ALGO = 2; | |||||
| } | |||||
| message QuantizeFactor | |||||
| { | |||||
| QuantizeScaleMode scale_mode = 1; | |||||
| bytes scale_value = 2; | |||||
| int64 scale_offset = 3; | |||||
| bytes offset_data_value = 4; | |||||
| int64 offset_data_offset = 5; | |||||
| bytes offset_weight_value = 6; | |||||
| int64 offset_weight_offset = 7; | |||||
| bytes offset_pad_value = 8; | |||||
| int64 offset_pad_offset = 9; | |||||
| }; | |||||
| message QuantizeCalcFactor | |||||
| { | |||||
| bytes offsetw = 1; | |||||
| int64 offsetw_offset = 2; | |||||
| bytes offsetd = 3; | |||||
| int64 offsetd_offset = 4; | |||||
| bytes scalereq = 5; | |||||
| int64 scaledreq_offset = 6; | |||||
| bytes offsetdnext = 7; | |||||
| int64 offsetdnext_offset = 8; | |||||
| } | |||||
| message QuantizeFactorParams | |||||
| { | |||||
| QuantizeAlgorithm quantize_algo = 1; | |||||
| QuantizeScaleType scale_type = 2; | |||||
| QuantizeFactor quantize_param = 3; | |||||
| QuantizeFactor dequantize_param = 4; | |||||
| QuantizeFactor requantize_param = 5; | |||||
| QuantizeCalcFactor quantizecalc_param = 6; | |||||
| }; | |||||
| message ConvolutionOpParams { | |||||
| int32 mode = 1; | |||||
| int32 algo = 2; | |||||
| int32 pad_mode = 3; | |||||
| uint32 group = 4; | |||||
| uint32 num_output = 5; | |||||
| repeated uint32 pad = 10; | |||||
| repeated uint32 stride = 11; | |||||
| repeated uint32 dilation = 12; | |||||
| repeated uint32 kernel = 13; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| WeightDef filter = 40; | |||||
| WeightDef bias = 41; | |||||
| bool relu_flag = 62; | |||||
| repeated uint32 adj = 70; | |||||
| repeated uint32 target_shape = 71; | |||||
| repeated uint32 before_pad = 72; | |||||
| }; | |||||
| message PoolingOpParams { | |||||
| int32 mode = 1; | |||||
| int32 nan_opt = 2; | |||||
| int32 pad_mode = 3; | |||||
| bool global_pooling = 4; | |||||
| repeated uint32 window = 10; | |||||
| repeated uint32 pad = 11; | |||||
| repeated uint32 stride = 12; | |||||
| bool ceil_mode = 13; | |||||
| int32 data_mode = 14; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| repeated uint32 before_pad = 22; | |||||
| }; | |||||
| message EltwiseOpParams { | |||||
| int32 mode = 1; | |||||
| repeated float coeff = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| repeated WeightDef weight = 5; | |||||
| bool relu_flag = 6; | |||||
| }; | |||||
| message ActivationOpParams { | |||||
| int32 mode = 1; | |||||
| float coef = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message BatchNormOpParams { | |||||
| int32 mode = 1; | |||||
| float alpha = 2; | |||||
| float beta = 3; | |||||
| double epsilon = 4;//optinal,[default = 1e-5] | |||||
| bool use_global_stats = 5; //optinal,by default true,testing mode | |||||
| float moving_average_fraction = 6; //optinal,[default = .999]; | |||||
| WeightDef estimated_mean = 7; | |||||
| WeightDef estimated_variance = 8; | |||||
| WeightDef scale = 9; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message ScaleOpParams { | |||||
| WeightDef scale = 1; | |||||
| WeightDef bias = 2; | |||||
| }; | |||||
| message ReshapeOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| ShapeDef shape = 3; | |||||
| int32 axis = 4; | |||||
| int32 num_axes = 5; | |||||
| int32 format = 6; | |||||
| }; | |||||
| message SoftmaxOpParams { | |||||
| int32 algo = 1; | |||||
| int32 mode = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message FullConnectionOpParams { | |||||
| WeightDef filter = 1; | |||||
| WeightDef bias = 2; | |||||
| uint32 num_output = 3; | |||||
| bool relu_flag = 12; | |||||
| }; | |||||
| message FlattenOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 start_axis = 3; | |||||
| int32 end_axis = 4; | |||||
| } | |||||
| message AddLimitedOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulLimitedOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message AddOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message SubOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message BiasAddOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message MatMulOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| bool transposeX = 3; | |||||
| bool transposeW = 4; | |||||
| WeightDef filter = 10; | |||||
| WeightDef bias = 12; | |||||
| }; | |||||
| message RsqrtOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| }; | |||||
| message WeightDef { | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| ShapeDef shape = 3; | |||||
| bytes data = 4; | |||||
| int64 data_offset = 5; | |||||
| uint32 cmps_size = 6; | |||||
| bytes cmps_tab = 7; | |||||
| int64 cmps_tab_offset = 10; | |||||
| CompressInfo cmps_info = 8; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 11; | |||||
| } | |||||
| message ShapeDef { | |||||
| repeated int64 dim = 1; | |||||
| } | |||||
| enum DeviceType { | |||||
| NPU = 0; // In default, we will use NPU. | |||||
| CPU = 1; // CPU | |||||
| } | |||||
| message AllOffsetQuantizeInfo { | |||||
| float scale = 1; | |||||
| int32 offset = 2; | |||||
| } | |||||
| message TensorDescriptor { | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| repeated int64 dim = 3; | |||||
| uint32 size = 4; | |||||
| bool reuse_input = 5; | |||||
| bool output_tensor = 7; | |||||
| DeviceType device_type = 8; | |||||
| bool input_tensor = 9; | |||||
| uint32 real_dim_cnt = 10; | |||||
| uint32 reuse_input_index = 11; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 12; | |||||
| } | |||||
| message CompressInfo { | |||||
| int32 blockRow = 1; // block row | |||||
| int32 blockCol = 2; // block col | |||||
| int32 fractalK = 3; // fractal K | |||||
| int32 fractalN = 4; // fractal N | |||||
| int32 lastFractalK = 5; // K of last fractal | |||||
| int32 lastFractalN = 6; // N of last fractal | |||||
| int32 cubeSize = 7; // cube's length | |||||
| int32 loadDir = 8; // data load directtiono 0:col load 1:row load | |||||
| } | |||||
| message AttrDef { | |||||
| message ListValue { | |||||
| repeated string s = 2; // "list(string)" | |||||
| repeated int64 i = 3 [packed = true]; // "list(int)" | |||||
| repeated float f = 4 [packed = true]; // "list(float)" | |||||
| repeated bool b = 5 [packed = true]; // "list(bool)" | |||||
| repeated uint32 u = 6 [packed = true]; // "list(uint)" | |||||
| repeated bytes bt = 7; | |||||
| } | |||||
| oneof value { | |||||
| string s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| uint32 u = 6; // "uint32" | |||||
| bytes bt = 7; | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs { | |||||
| string name = 1; | |||||
| map<string, AttrDef> attr = 2; | |||||
| } | |||||
| @@ -1,179 +0,0 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| message ModelTaskDef { | |||||
| string version = 1; | |||||
| map<string, string> attr = 9; // Extended field | |||||
| repeated TaskDef task = 10; | |||||
| uint64 memory_size = 11; | |||||
| uint32 stream_num = 12; | |||||
| uint32 event_num = 13; | |||||
| uint64 weight_size = 14; | |||||
| repeated bytes op = 15; // input/output opdef in bytes | |||||
| uint64 base_addr = 16; // base addr | |||||
| uint64 weight_addr = 17; // weight addr | |||||
| uint32 batch_num = 18; | |||||
| } | |||||
| message TaskDef { | |||||
| uint32 id = 1; | |||||
| uint32 type = 2; | |||||
| uint32 stream_id = 10; | |||||
| uint32 event_id = 11; | |||||
| KernelDef kernel = 20; | |||||
| KernelExDef kernel_ex = 21; | |||||
| KernelHcclDef kernel_hccl = 25; | |||||
| EventExDef event_ex = 26; | |||||
| LogTimeStampDef log_timestamp = 28; | |||||
| uint32 label_id = 30; | |||||
| MemcpyAsyncDef memcpy_async = 31; | |||||
| StreamSwitchDef stream_switch = 32; | |||||
| StreamActiveDef stream_active = 33; | |||||
| bytes private_def = 34; | |||||
| uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future | |||||
| StreamSwitchNDef stream_switch_n = 36; | |||||
| LabelSetDef label_set = 37; | |||||
| LabelGotoExDef label_goto_ex = 38; | |||||
| LabelSwitchByIndexDef label_switch_by_index = 39; | |||||
| KernelDefWithHandle kernel_with_handle = 40; | |||||
| } | |||||
| message KernelDef { | |||||
| KernelContext context = 1; | |||||
| string stub_func = 10; | |||||
| uint32 block_dim = 11; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes sm_desc = 14; | |||||
| bytes flowtable = 15; | |||||
| string so_name = 16; | |||||
| string kernel_name = 17; | |||||
| bytes kernel_ext_info = 18; | |||||
| uint32 kernel_ext_info_size = 19; | |||||
| } | |||||
| message KernelDefWithHandle { | |||||
| KernelContext context = 1; | |||||
| uint64 handle = 10; | |||||
| string dev_func = 11; | |||||
| uint32 block_dim = 12; | |||||
| uint32 args_size = 13; | |||||
| bytes args = 14; | |||||
| bytes sm_desc = 15; | |||||
| string original_kernel_key = 16; | |||||
| string node_info = 17; | |||||
| } | |||||
| message KernelContext { | |||||
| uint32 kernel_type = 1; | |||||
| uint32 op_id = 2; // OP type in CCE | |||||
| uint32 kernel_func_id = 3; | |||||
| uint32 op_index = 4; // TE/Custom operator | |||||
| bool is_flowtable = 5; // Identify whether args is a flowtable structure | |||||
| bytes args_offset = 6; // args offset information | |||||
| uint32 args_count = 7; // args count | |||||
| repeated uint32 origin_op_index = 8; | |||||
| } | |||||
| message KernelExDef { | |||||
| uint32 flags = 1; | |||||
| uint32 op_index = 4; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput | |||||
| uint32 task_info_size = 15; | |||||
| bytes kernel_ext_info = 16; | |||||
| uint32 kernel_ext_info_size = 17; | |||||
| } | |||||
| message KernelHcclDef { | |||||
| uint32 op_index = 8; | |||||
| string hccl_type = 9; | |||||
| } | |||||
| message EventExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 event_type = 2; | |||||
| } | |||||
| message LogTimeStampDef { | |||||
| uint64 logid = 1; | |||||
| bool notify = 2; | |||||
| uint32 flat = 3; | |||||
| } | |||||
| message MemcpyAsyncDef { | |||||
| uint64 dst = 1; | |||||
| uint64 dst_max = 2; | |||||
| uint64 src = 3; | |||||
| uint64 count = 4; | |||||
| uint32 kind = 5; | |||||
| uint32 op_index = 6; | |||||
| } | |||||
| message StreamSwitchDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 true_stream_id = 2; | |||||
| int64 value = 3; | |||||
| uint64 value_ptr = 4; | |||||
| uint32 data_type = 5; | |||||
| } | |||||
| message StreamActiveDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 active_stream_id = 2; | |||||
| } | |||||
| message StreamSwitchNDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 size = 2; | |||||
| repeated int64 target_value = 3; | |||||
| repeated uint32 true_stream_id = 4; | |||||
| uint32 element_size = 5; | |||||
| uint32 data_type = 6; | |||||
| } | |||||
| message LabelSetDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelGotoExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelSwitchByIndexDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_max = 2; | |||||
| } | |||||
| @@ -106,6 +106,7 @@ target_link_libraries(ge_common PRIVATE | |||||
| c_sec | c_sec | ||||
| error_manager | error_manager | ||||
| slog | slog | ||||
| opt_feature | |||||
| -Wl,--as-needed | -Wl,--as-needed | ||||
| json | json | ||||
| $<$<NOT:$<STREQUAL:${TARGET_SYSTEM_NAME},Android>>:-lrt> | $<$<NOT:$<STREQUAL:${TARGET_SYSTEM_NAME},Android>>:-lrt> | ||||
| @@ -33,7 +33,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpManager &DumpManager::GetIn | |||||
| bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump_properties) { | bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump_properties) { | ||||
| if (dump_config.dump_status.empty() && dump_config.dump_debug.empty()) { | if (dump_config.dump_status.empty() && dump_config.dump_debug.empty()) { | ||||
| dump_properties_map_.emplace(kInferSessionId, dump_properties); | |||||
| dump_properties_map_[kInferSessionId] = dump_properties; | |||||
| GELOGI("Dump does not open"); | GELOGI("Dump does not open"); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -41,7 +41,7 @@ bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump | |||||
| if ((dump_config.dump_status == kDumpoff || dump_config.dump_status == kDumpOFF) && | if ((dump_config.dump_status == kDumpoff || dump_config.dump_status == kDumpOFF) && | ||||
| dump_config.dump_debug == kDumpoff) { | dump_config.dump_debug == kDumpoff) { | ||||
| dump_properties.ClearDumpPropertyValue(); | dump_properties.ClearDumpPropertyValue(); | ||||
| dump_properties_map_.emplace(kInferSessionId, dump_properties); | |||||
| dump_properties_map_[kInferSessionId] = dump_properties; | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (dump_config.dump_status == kDumpOn && dump_config.dump_debug == kDumpOn) { | if (dump_config.dump_status == kDumpOn && dump_config.dump_debug == kDumpOn) { | ||||
| @@ -104,7 +104,15 @@ void TBEPluginManager::ProcessSoFullName(vector<string> &file_list, string &caff | |||||
| } | } | ||||
| } | } | ||||
| void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path) { | |||||
| void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_list, | |||||
| string &caffe_parser_path, int recursive_depth) { | |||||
| static const int kMaxRecursiveDepth = 20; // For recursive depth protection | |||||
| if (recursive_depth >= kMaxRecursiveDepth) { | |||||
| GELOGW("Recursive depth is become %d, Please check input!", recursive_depth); | |||||
| return; | |||||
| } | |||||
| // Path, change to absolute path | // Path, change to absolute path | ||||
| string real_path = RealPath(path.c_str()); | string real_path = RealPath(path.c_str()); | ||||
| // Plugin path does not exist | // Plugin path does not exist | ||||
| @@ -138,7 +146,7 @@ void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_lis | |||||
| ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff, | ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff, | ||||
| aicpu_host_so_suff); | aicpu_host_so_suff); | ||||
| } else { | } else { | ||||
| FindParserSo(full_name, file_list, caffe_parser_path); | |||||
| FindParserSo(full_name, file_list, caffe_parser_path, recursive_depth + 1); | |||||
| } | } | ||||
| } | } | ||||
| mmScandirFree(entries, ret); | mmScandirFree(entries, ret); | ||||
| @@ -57,7 +57,8 @@ class TBEPluginManager { | |||||
| static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name, | static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name, | ||||
| const string &caffe_parser_so_suff, const string &aicpu_so_suff, | const string &caffe_parser_so_suff, const string &aicpu_so_suff, | ||||
| const string &aicpu_host_so_suff); | const string &aicpu_host_so_suff); | ||||
| static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path); | |||||
| static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path, | |||||
| int recursive_depth = 0); | |||||
| static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path); | static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path); | ||||
| static void GetCustomOpPath(std::string &customop_path); | static void GetCustomOpPath(std::string &customop_path); | ||||
| void LoadCustomOpLib(); | void LoadCustomOpLib(); | ||||
| @@ -1,193 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package ge.proto; | |||||
| enum DataType | |||||
| { | |||||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||||
| DT_FLOAT = 1; // float type | |||||
| DT_FLOAT16 = 2; // fp16 type | |||||
| DT_INT8 = 3; // int8 type | |||||
| DT_UINT8 = 4; // uint8 type | |||||
| DT_INT16 = 5; // int16 type | |||||
| DT_UINT16 = 6; // uint16 type | |||||
| DT_INT32 = 7; // | |||||
| DT_INT64 = 8; // int64 type | |||||
| DT_UINT32 = 9; // unsigned int32 | |||||
| DT_UINT64 = 10; // unsigned int64 | |||||
| DT_BOOL = 11; // bool type | |||||
| DT_DOUBLE = 12; // double type | |||||
| DT_STRING = 13; // string type | |||||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||||
| DT_COMPLEX64 = 16; // complex64 type | |||||
| DT_COMPLEX128 = 17; // complex128 type | |||||
| DT_QINT8 = 18; // qint8 type | |||||
| DT_QINT16 = 19; // qint16 type | |||||
| DT_QINT32 = 20; // qint32 type | |||||
| DT_QUINT8 = 21; // quint8 type | |||||
| DT_QUINT16 = 22; // quint16 type | |||||
| DT_RESOURCE = 23; // resource type | |||||
| DT_STRING_REF = 24; // string_ref type | |||||
| DT_DUAL = 25; /**< dual output type */ | |||||
| DT_VARIANT = 26; // variant type | |||||
| DT_BF16 = 27; // bf16 type | |||||
| DT_INT4 = 28; // int4 type | |||||
| } | |||||
| message AttrDef | |||||
| { | |||||
| message ListValue | |||||
| { | |||||
| enum ListValueType{ | |||||
| VT_LIST_NONE = 0; | |||||
| VT_LIST_STRING = 1; | |||||
| VT_LIST_INT = 2; | |||||
| VT_LIST_FLOAT = 3; | |||||
| VT_LIST_BOOL = 4; | |||||
| VT_LIST_BYTES = 5; | |||||
| VT_LIST_TENSOR_DESC = 6; | |||||
| VT_LIST_TENSOR = 7; | |||||
| VT_LIST_GRAPH = 8; | |||||
| VT_LIST_NAMED_ATTRS = 9; | |||||
| VT_LIST_DATA_TYPE = 10; | |||||
| } | |||||
| repeated bytes s = 2; // "list(string)" | |||||
| repeated int64 i = 3; // "list(int)" | |||||
| repeated float f = 4; // "list(float)" | |||||
| repeated bool b = 5; // "list(bool)" | |||||
| repeated bytes bt = 7; | |||||
| repeated TensorDescriptor td = 8; | |||||
| repeated TensorDef t = 9; | |||||
| repeated GraphDef g = 10; | |||||
| repeated NamedAttrs na = 11; | |||||
| repeated int64 dt = 12; // list ge::DataType | |||||
| ListValueType val_type = 20; | |||||
| } | |||||
| message ListListInt{ | |||||
| message ListInt{ | |||||
| repeated int64 list_i = 1; // list int | |||||
| } | |||||
| repeated ListInt list_list_i = 1; // list list int | |||||
| } | |||||
| oneof value | |||||
| { | |||||
| bytes s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| bytes bt = 7; | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; // Used to support attr nesting | |||||
| TensorDescriptor td = 11; // GeTensorDesc type | |||||
| TensorDef t = 12; // GeTensor type | |||||
| GraphDef g = 13; // Graph type | |||||
| ListListInt list_list_int = 14; // List List Int type | |||||
| int64 dt = 15; // ge::DataType | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs | |||||
| { | |||||
| string name = 1; | |||||
| map<string, AttrDef> attr = 2; | |||||
| } | |||||
| // Shape / dimension description, using row-major order | |||||
| message ShapeDef | |||||
| { | |||||
| repeated int64 dim = 1; // Size of each dimension | |||||
| } | |||||
| // Multidimensional data description | |||||
| message TensorDescriptor | |||||
| { | |||||
| string name = 1; // Optional parameter, tensor name | |||||
| DataType dtype = 2; // tensor datatype | |||||
| ShapeDef shape = 3; // Shape / dimension | |||||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||||
| bool has_out_attr = 9; | |||||
| int64 size = 10; | |||||
| int64 weight_size = 11; | |||||
| bool reuse_input = 12; | |||||
| bool output_tensor = 13; | |||||
| string device_type = 14; | |||||
| bool input_tensor =15; | |||||
| int64 real_dim_cnt = 16; | |||||
| int64 reuse_input_index = 17; | |||||
| int64 data_offset = 18; | |||||
| int64 cmps_size = 19; | |||||
| string cmps_tab = 20; | |||||
| int64 cmps_tab_offset = 21; | |||||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||||
| } | |||||
| // GeTensor definition | |||||
| message TensorDef | |||||
| { | |||||
| TensorDescriptor desc = 1; // Tensor description | |||||
| bytes data = 2; // Tensor data | |||||
| } | |||||
| // Operator description | |||||
| message OpDef | |||||
| { | |||||
| string name = 1; // name | |||||
| string type = 2; // type | |||||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||||
| bool has_out_attr = 20; | |||||
| int64 id = 21; | |||||
| int64 stream_id =22; | |||||
| repeated string input_name = 23; | |||||
| repeated string src_name = 24; | |||||
| repeated int64 src_index = 25; | |||||
| repeated string dst_name = 26; | |||||
| repeated int64 dst_index = 27; | |||||
| repeated int64 input_i = 28; | |||||
| repeated int64 output_i = 29; | |||||
| repeated int64 workspace = 30; | |||||
| repeated int64 workspace_bytes = 31; | |||||
| repeated bool is_input_const = 32; | |||||
| repeated TensorDescriptor input_desc = 33; | |||||
| repeated TensorDescriptor output_desc = 34; | |||||
| repeated string subgraph_name = 35; | |||||
| } | |||||
| // Graph definition | |||||
| message GraphDef | |||||
| { | |||||
| string name = 1; // name | |||||
| repeated string input = 4; // Graph input | |||||
| repeated string output = 5; // Graph output | |||||
| repeated OpDef op = 6; // List of operators | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| // model definition | |||||
| message ModelDef | |||||
| { | |||||
| string name = 1; // name | |||||
| uint32 version = 2; // IR Proto verion | |||||
| string custom_version = 3; // User model version number, passed in by user | |||||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| @@ -1,140 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| message InsertNewOps { | |||||
| repeated AippOpParams aipp_op = 1; | |||||
| repeated MultiShapeOpParams multi_shape_op = 2; | |||||
| } | |||||
| message AippOpParams { | |||||
| enum InputFormat { | |||||
| UNDEFINED = 0; | |||||
| YUV420SP_U8 = 1; | |||||
| XRGB8888_U8 = 2; | |||||
| RGB888_U8 = 3; | |||||
| YUV400_U8 = 4; | |||||
| NC1HWC0DI_FP16 = 5; | |||||
| NC1HWC0DI_S8 = 6; | |||||
| ARGB8888_U8 = 7; | |||||
| YUYV_U8 = 8; | |||||
| YUV422SP_U8 = 9; | |||||
| AYUV444_U8 = 10; | |||||
| RAW10 = 11; | |||||
| RAW12 = 12; | |||||
| RAW16 = 13; | |||||
| RAW24 = 14; | |||||
| RGB16 = 15; | |||||
| RGB20 = 16; | |||||
| RGB24 = 17; | |||||
| RGB8_IR = 18; | |||||
| RGB16_IR = 19; | |||||
| RGB24_IR = 20; | |||||
| } | |||||
| enum AippMode { | |||||
| undefined = 0; | |||||
| static = 1; | |||||
| dynamic = 2; | |||||
| } | |||||
| // AIPP模式,区分静态AIPP和动态AIPP | |||||
| AippMode aipp_mode = 1; | |||||
| // related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。 | |||||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | |||||
| uint32 related_input_rank = 2; | |||||
| // related_input_name is optional and the top name of data node which inserts aipp | |||||
| string related_input_name = 6; | |||||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | |||||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | |||||
| // 配置值 <= Data算子输出边的个数。 | |||||
| repeated uint32 input_edge_idx = 3; | |||||
| // [Begin] 动态AIPP参数,配置静态AIPP时无效 | |||||
| uint32 max_src_image_size = 4; | |||||
| // 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失 | |||||
| bool support_rotation = 5; | |||||
| // [End] 动态AIPP参数 | |||||
| // [Begin] 静态AIPP参数,配置动态AIPP时无效 | |||||
| InputFormat input_format = 51; | |||||
| bool csc_switch = 52; | |||||
| float cpadding_value = 53; | |||||
| bool rbuv_swap_switch = 54; | |||||
| bool ax_swap_switch = 55; | |||||
| bool single_line_mode = 56; | |||||
| int32 src_image_size_w = 57; | |||||
| int32 src_image_size_h = 58; | |||||
| bool crop = 59; | |||||
| int32 load_start_pos_w = 60; | |||||
| int32 load_start_pos_h = 61; | |||||
| int32 crop_size_w = 62; | |||||
| int32 crop_size_h = 63; | |||||
| bool resize = 64; | |||||
| int32 resize_output_w = 65; | |||||
| int32 resize_output_h = 66; | |||||
| bool padding = 67; | |||||
| int32 left_padding_size = 68; | |||||
| int32 right_padding_size = 69; | |||||
| int32 top_padding_size = 70; | |||||
| int32 bottom_padding_size = 71; | |||||
| float padding_value = 72; | |||||
| int32 mean_chn_0 = 10; | |||||
| int32 mean_chn_1 = 11; | |||||
| int32 mean_chn_2 = 12; | |||||
| int32 mean_chn_3 = 19; | |||||
| float min_chn_0 = 13; | |||||
| float min_chn_1 = 14; | |||||
| float min_chn_2 = 15; | |||||
| float min_chn_3 = 20; | |||||
| repeated float var_reci_chn_0 = 16; | |||||
| repeated float var_reci_chn_1 = 17; | |||||
| repeated float var_reci_chn_2 = 18; | |||||
| repeated float var_reci_chn_3 = 21; | |||||
| repeated int32 matrix_r0c0 = 30; | |||||
| repeated int32 matrix_r0c1 = 31; | |||||
| repeated int32 matrix_r0c2 = 32; | |||||
| repeated int32 matrix_r1c0 = 33; | |||||
| repeated int32 matrix_r1c1 = 34; | |||||
| repeated int32 matrix_r1c2 = 35; | |||||
| repeated int32 matrix_r2c0 = 36; | |||||
| repeated int32 matrix_r2c1 = 37; | |||||
| repeated int32 matrix_r2c2 = 38; | |||||
| repeated int32 output_bias_0 = 39; | |||||
| repeated int32 output_bias_1 = 40; | |||||
| repeated int32 output_bias_2 = 41; | |||||
| repeated int32 input_bias_0 = 42; | |||||
| repeated int32 input_bias_1 = 43; | |||||
| repeated int32 input_bias_2 = 44; | |||||
| // [End] 静态AIPP参数 | |||||
| // The n number that is used for raw/rgbir data into f16 transformation. | |||||
| // The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||||
| uint32 raw_rgbir_to_f16_n = 45; | |||||
| } | |||||
| message MultiShapeOpParams { | |||||
| enum MultiShapeMode { | |||||
| batch = 0; //动态batch | |||||
| resolution = 1; //动态分辨率,扩展用 | |||||
| } | |||||
| MultiShapeMode mode = 1; //算子模式 | |||||
| uint32 related_input_rank = 2; //新增算子插入到哪个输入 | |||||
| repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间 | |||||
| } | |||||
| @@ -1,396 +0,0 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| enum TargetType | |||||
| { | |||||
| MINI = 0; | |||||
| TINY = 1; | |||||
| LITE = 2; | |||||
| } | |||||
| // offline model | |||||
| message ModelDef { | |||||
| string name = 1; | |||||
| uint32 version = 2; | |||||
| uint64 memory_size = 10; | |||||
| uint32 stream_num = 11; | |||||
| uint32 event_num = 12; | |||||
| uint64 weight_size = 13; | |||||
| uint32 label_num = 15; | |||||
| repeated OpDef op = 20; | |||||
| TargetType target_type = 23; | |||||
| map<string, AttrDef> attr = 30; | |||||
| }; | |||||
| // operator define | |||||
| message OpDef { | |||||
| string name = 1; | |||||
| string type = 2; | |||||
| uint32 id = 3; | |||||
| uint32 stream_id = 4; | |||||
| repeated string input_name = 5; | |||||
| repeated string src_name = 8; | |||||
| repeated int32 src_index = 9; | |||||
| repeated int64 input = 10; | |||||
| repeated int64 output = 11; | |||||
| repeated TensorDescriptor input_desc = 12; | |||||
| repeated TensorDescriptor output_desc = 13; | |||||
| repeated WeightDef weights = 14; | |||||
| repeated string dst_name = 15; | |||||
| repeated int32 dst_index = 16; | |||||
| repeated int64 workspace = 20; | |||||
| repeated uint32 workspace_bytes = 21; | |||||
| repeated string weight_name = 22; | |||||
| repeated bool is_input_const = 23; | |||||
| map<string, AttrDef> attr = 30; | |||||
| QuantizeFactorParams quantize_factor = 31; | |||||
| oneof op_params { | |||||
| // start at 100 here | |||||
| SendOpParams sender_param = 100; | |||||
| RecvOpParams receiver_param = 200; | |||||
| ConvolutionOpParams convolution_param = 300; | |||||
| PoolingOpParams pooling_param = 400; | |||||
| EltwiseOpParams eltwise_param = 500; | |||||
| BatchNormOpParams batchnorm_param = 600; | |||||
| ScaleOpParams scale_param = 700; | |||||
| FullConnectionOpParams full_connection_param = 800; | |||||
| SoftmaxOpParams softmax_param = 900; | |||||
| ActivationOpParams activation_param = 1000; | |||||
| ReshapeOpParams reshape_param = 1100; | |||||
| } | |||||
| }; | |||||
| message SendOpParams { | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| message RecvOpParams { | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| enum QuantizeScaleType | |||||
| { | |||||
| VECTOR_SCALE = 0; | |||||
| SCALAR_SCALE = 1; | |||||
| } | |||||
| enum QuantizeScaleMode | |||||
| { | |||||
| NORMAL_MODE = 0; | |||||
| SQRT_MODE = 1; | |||||
| } | |||||
| enum QuantizeAlgorithm | |||||
| { | |||||
| NON_OFFSET_ALGO = 0; | |||||
| HALF_OFFSET_ALGO = 1; | |||||
| ALL_OFFSET_ALGO = 2; | |||||
| } | |||||
| message QuantizeFactor | |||||
| { | |||||
| QuantizeScaleMode scale_mode = 1; | |||||
| bytes scale_value = 2; | |||||
| int64 scale_offset = 3; | |||||
| bytes offset_data_value = 4; | |||||
| int64 offset_data_offset = 5; | |||||
| bytes offset_weight_value = 6; | |||||
| int64 offset_weight_offset = 7; | |||||
| bytes offset_pad_value = 8; | |||||
| int64 offset_pad_offset = 9; | |||||
| }; | |||||
| message QuantizeCalcFactor | |||||
| { | |||||
| bytes offsetw = 1; | |||||
| int64 offsetw_offset = 2; | |||||
| bytes offsetd = 3; | |||||
| int64 offsetd_offset = 4; | |||||
| bytes scalereq = 5; | |||||
| int64 scaledreq_offset = 6; | |||||
| bytes offsetdnext = 7; | |||||
| int64 offsetdnext_offset = 8; | |||||
| } | |||||
| message QuantizeFactorParams | |||||
| { | |||||
| QuantizeAlgorithm quantize_algo = 1; | |||||
| QuantizeScaleType scale_type = 2; | |||||
| QuantizeFactor quantize_param = 3; | |||||
| QuantizeFactor dequantize_param = 4; | |||||
| QuantizeFactor requantize_param = 5; | |||||
| QuantizeCalcFactor quantizecalc_param = 6; | |||||
| }; | |||||
| message ConvolutionOpParams { | |||||
| int32 mode = 1; | |||||
| int32 algo = 2; | |||||
| int32 pad_mode = 3; | |||||
| uint32 group = 4; | |||||
| uint32 num_output = 5; | |||||
| repeated uint32 pad = 10; | |||||
| repeated uint32 stride = 11; | |||||
| repeated uint32 dilation = 12; | |||||
| repeated uint32 kernel = 13; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| WeightDef filter = 40; | |||||
| WeightDef bias = 41; | |||||
| bool relu_flag = 62; | |||||
| repeated uint32 adj = 70; | |||||
| repeated uint32 target_shape = 71; | |||||
| repeated uint32 before_pad = 72; | |||||
| }; | |||||
| message PoolingOpParams { | |||||
| int32 mode = 1; | |||||
| int32 nan_opt = 2; | |||||
| int32 pad_mode = 3; | |||||
| bool global_pooling = 4; | |||||
| repeated uint32 window = 10; | |||||
| repeated uint32 pad = 11; | |||||
| repeated uint32 stride = 12; | |||||
| bool ceil_mode = 13; | |||||
| int32 data_mode = 14; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| repeated uint32 before_pad = 22; | |||||
| }; | |||||
| message EltwiseOpParams { | |||||
| int32 mode = 1; | |||||
| repeated float coeff = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| repeated WeightDef weight = 5; | |||||
| bool relu_flag = 6; | |||||
| }; | |||||
| message ActivationOpParams { | |||||
| int32 mode = 1; | |||||
| float coef = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message BatchNormOpParams { | |||||
| int32 mode = 1; | |||||
| float alpha = 2; | |||||
| float beta = 3; | |||||
| double epsilon = 4;//optinal,[default = 1e-5] | |||||
| bool use_global_stats = 5; //optinal,by default true,testing mode | |||||
| float moving_average_fraction = 6; //optinal,[default = .999]; | |||||
| WeightDef estimated_mean = 7; | |||||
| WeightDef estimated_variance = 8; | |||||
| WeightDef scale = 9; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message ScaleOpParams { | |||||
| WeightDef scale = 1; | |||||
| WeightDef bias = 2; | |||||
| }; | |||||
| message ReshapeOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| ShapeDef shape = 3; | |||||
| int32 axis = 4; | |||||
| int32 num_axes = 5; | |||||
| int32 format = 6; | |||||
| }; | |||||
| message SoftmaxOpParams { | |||||
| int32 algo = 1; | |||||
| int32 mode = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message FullConnectionOpParams { | |||||
| WeightDef filter = 1; | |||||
| WeightDef bias = 2; | |||||
| uint32 num_output = 3; | |||||
| bool relu_flag = 12; | |||||
| }; | |||||
| message FlattenOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 start_axis = 3; | |||||
| int32 end_axis = 4; | |||||
| } | |||||
| message AddLimitedOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulLimitedOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message AddOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message SubOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message BiasAddOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message MatMulOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| bool transposeX = 3; | |||||
| bool transposeW = 4; | |||||
| WeightDef filter = 10; | |||||
| WeightDef bias = 12; | |||||
| }; | |||||
| message RsqrtOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| }; | |||||
| message WeightDef { | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| ShapeDef shape = 3; | |||||
| bytes data = 4; | |||||
| int64 data_offset = 5; | |||||
| uint32 cmps_size = 6; | |||||
| bytes cmps_tab = 7; | |||||
| int64 cmps_tab_offset = 10; | |||||
| CompressInfo cmps_info = 8; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 11; | |||||
| } | |||||
| message ShapeDef { | |||||
| repeated int64 dim = 1; | |||||
| } | |||||
| enum DeviceType { | |||||
| NPU = 0; // In default, we will use NPU. | |||||
| CPU = 1; // CPU | |||||
| } | |||||
| message AllOffsetQuantizeInfo { | |||||
| float scale = 1; | |||||
| int32 offset = 2; | |||||
| } | |||||
| message TensorDescriptor { | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| repeated int64 dim = 3; | |||||
| uint32 size = 4; | |||||
| bool reuse_input = 5; | |||||
| bool output_tensor = 7; | |||||
| DeviceType device_type = 8; | |||||
| bool input_tensor = 9; | |||||
| uint32 real_dim_cnt = 10; | |||||
| uint32 reuse_input_index = 11; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 12; | |||||
| } | |||||
| message CompressInfo { | |||||
| int32 blockRow = 1; // block row | |||||
| int32 blockCol = 2; // block col | |||||
| int32 fractalK = 3; // fractal K | |||||
| int32 fractalN = 4; // fractal N | |||||
| int32 lastFractalK = 5; // K of last fractal | |||||
| int32 lastFractalN = 6; // N of last fractal | |||||
| int32 cubeSize = 7; // cube's length | |||||
| int32 loadDir = 8; // data load directtiono 0:col load 1:row load | |||||
| } | |||||
| message AttrDef { | |||||
| message ListValue { | |||||
| repeated string s = 2; // "list(string)" | |||||
| repeated int64 i = 3 [packed = true]; // "list(int)" | |||||
| repeated float f = 4 [packed = true]; // "list(float)" | |||||
| repeated bool b = 5 [packed = true]; // "list(bool)" | |||||
| repeated uint32 u = 6 [packed = true]; // "list(uint)" | |||||
| repeated bytes bt = 7; | |||||
| } | |||||
| oneof value { | |||||
| string s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| uint32 u = 6; // "uint32" | |||||
| bytes bt = 7; | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs { | |||||
| string name = 1; | |||||
| map<string, AttrDef> attr = 2; | |||||
| } | |||||
| @@ -1,75 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package toolkit.aicpu.dump; | |||||
| message Shape { | |||||
| repeated uint64 dim = 1; | |||||
| } | |||||
| message Output { | |||||
| int32 data_type = 1; | |||||
| int32 format = 2; | |||||
| Shape shape = 3; | |||||
| uint64 address = 4; | |||||
| string original_name = 5; | |||||
| int32 original_output_index = 6; | |||||
| int32 original_output_data_type = 7; | |||||
| int32 original_output_format = 8; | |||||
| uint64 size = 9; | |||||
| Shape origin_shape = 10; | |||||
| } | |||||
| message Input { | |||||
| int32 data_type =1; | |||||
| int32 format = 2; | |||||
| Shape shape = 3; | |||||
| uint64 address = 4; | |||||
| uint64 size = 5; | |||||
| Shape origin_shape = 6; | |||||
| } | |||||
| enum BufferType { | |||||
| L1 = 0; | |||||
| } | |||||
| message OpBuffer { | |||||
| BufferType buffer_type = 1; | |||||
| uint64 address = 2; | |||||
| uint64 size = 3; | |||||
| } | |||||
| message Op { | |||||
| string op_name = 1; | |||||
| string op_type = 2; | |||||
| } | |||||
| message Task { | |||||
| uint32 task_id = 1; | |||||
| uint32 stream_id = 2; | |||||
| Op op = 3; | |||||
| repeated Output output = 4; | |||||
| bool end_graph = 5; | |||||
| repeated Input input = 6; | |||||
| repeated OpBuffer buffer = 7; | |||||
| } | |||||
| message OpMappingInfo { | |||||
| string dump_path = 1; | |||||
| oneof model_name_param { | |||||
| string model_name = 2; | |||||
| } | |||||
| oneof model_id_param { | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| oneof step_id { | |||||
| uint64 step_id_addr = 4; | |||||
| } | |||||
| oneof iterations_per_loop { | |||||
| uint64 iterations_per_loop_addr = 5; | |||||
| } | |||||
| oneof loop_cond { | |||||
| uint64 loop_cond_addr = 6; | |||||
| } | |||||
| uint32 flag = 7; // 0x01 load, 0x00 unload | |||||
| repeated Task task = 8; | |||||
| string dump_step = 9; | |||||
| } | |||||
| @@ -1,179 +0,0 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| message ModelTaskDef { | |||||
| string version = 1; | |||||
| map<string, string> attr = 9; // Extended field | |||||
| repeated TaskDef task = 10; | |||||
| uint64 memory_size = 11; | |||||
| uint32 stream_num = 12; | |||||
| uint32 event_num = 13; | |||||
| uint64 weight_size = 14; | |||||
| repeated bytes op = 15; // input/output opdef in bytes | |||||
| uint64 base_addr = 16; // base addr | |||||
| uint64 weight_addr = 17; // weight addr | |||||
| uint32 batch_num = 18; | |||||
| } | |||||
| message TaskDef { | |||||
| uint32 id = 1; | |||||
| uint32 type = 2; | |||||
| uint32 stream_id = 10; | |||||
| uint32 event_id = 11; | |||||
| KernelDef kernel = 20; | |||||
| KernelExDef kernel_ex = 21; | |||||
| KernelHcclDef kernel_hccl = 25; | |||||
| EventExDef event_ex = 26; | |||||
| LogTimeStampDef log_timestamp = 28; | |||||
| uint32 label_id = 30; | |||||
| MemcpyAsyncDef memcpy_async = 31; | |||||
| StreamSwitchDef stream_switch = 32; | |||||
| StreamActiveDef stream_active = 33; | |||||
| bytes private_def = 34; | |||||
| uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future | |||||
| StreamSwitchNDef stream_switch_n = 36; | |||||
| LabelSetDef label_set = 37; | |||||
| LabelGotoExDef label_goto_ex = 38; | |||||
| LabelSwitchByIndexDef label_switch_by_index = 39; | |||||
| KernelDefWithHandle kernel_with_handle = 40; | |||||
| } | |||||
| message KernelDef { | |||||
| KernelContext context = 1; | |||||
| string stub_func = 10; | |||||
| uint32 block_dim = 11; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes sm_desc = 14; | |||||
| bytes flowtable = 15; | |||||
| string so_name = 16; | |||||
| string kernel_name = 17; | |||||
| bytes kernel_ext_info = 18; | |||||
| uint32 kernel_ext_info_size = 19; | |||||
| } | |||||
| message KernelDefWithHandle { | |||||
| KernelContext context = 1; | |||||
| uint64 handle = 10; | |||||
| string dev_func = 11; | |||||
| uint32 block_dim = 12; | |||||
| uint32 args_size = 13; | |||||
| bytes args = 14; | |||||
| bytes sm_desc = 15; | |||||
| string original_kernel_key = 16; | |||||
| string node_info = 17; | |||||
| } | |||||
| message KernelContext { | |||||
| uint32 kernel_type = 1; | |||||
| uint32 op_id = 2; // OP type in CCE | |||||
| uint32 kernel_func_id = 3; | |||||
| uint32 op_index = 4; // TE/Custom operator | |||||
| bool is_flowtable = 5; // Identify whether args is a flowtable structure | |||||
| bytes args_offset = 6; // args offset information | |||||
| uint32 args_count = 7; // args count | |||||
| repeated uint32 origin_op_index = 8; | |||||
| } | |||||
| message KernelExDef { | |||||
| uint32 flags = 1; | |||||
| uint32 op_index = 4; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput | |||||
| uint32 task_info_size = 15; | |||||
| bytes kernel_ext_info = 16; | |||||
| uint32 kernel_ext_info_size = 17; | |||||
| } | |||||
| message KernelHcclDef { | |||||
| uint32 op_index = 8; | |||||
| string hccl_type = 9; | |||||
| } | |||||
| message EventExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 event_type = 2; | |||||
| } | |||||
| message LogTimeStampDef { | |||||
| uint64 logid = 1; | |||||
| bool notify = 2; | |||||
| uint32 flat = 3; | |||||
| } | |||||
| message MemcpyAsyncDef { | |||||
| uint64 dst = 1; | |||||
| uint64 dst_max = 2; | |||||
| uint64 src = 3; | |||||
| uint64 count = 4; | |||||
| uint32 kind = 5; | |||||
| uint32 op_index = 6; | |||||
| } | |||||
| message StreamSwitchDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 true_stream_id = 2; | |||||
| int64 value = 3; | |||||
| uint64 value_ptr = 4; | |||||
| uint32 data_type = 5; | |||||
| } | |||||
| message StreamActiveDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 active_stream_id = 2; | |||||
| } | |||||
| message StreamSwitchNDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 size = 2; | |||||
| repeated int64 target_value = 3; | |||||
| repeated uint32 true_stream_id = 4; | |||||
| uint32 element_size = 5; | |||||
| uint32 data_type = 6; | |||||
| } | |||||
| message LabelSetDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelGotoExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelSwitchByIndexDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_max = 2; | |||||
| } | |||||
| @@ -1,70 +0,0 @@ | |||||
| /** | |||||
| * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||||
| * | |||||
| * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||||
| * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||||
| * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "AttrValueProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "tensor.proto"; | |||||
| import "tensor_shape.proto"; | |||||
| import "types.proto"; | |||||
| // Protocol buffer representing the value for an attr used to configure an Op. | |||||
| // Comment indicates the corresponding attr type. Only the field matching the | |||||
| // attr type may be filled. | |||||
| message AttrValue { | |||||
| // LINT.IfChange | |||||
| message ListValue { | |||||
| repeated bytes s = 2; // "list(string)" | |||||
| repeated int64 i = 3 [packed = true]; // "list(int)" | |||||
| repeated float f = 4 [packed = true]; // "list(float)" | |||||
| repeated bool b = 5 [packed = true]; // "list(bool)" | |||||
| repeated DataType type = 6 [packed = true]; // "list(type)" | |||||
| repeated TensorShapeProto shape = 7; // "list(shape)" | |||||
| repeated TensorProto tensor = 8; // "list(tensor)" | |||||
| repeated NameAttrList func = 9; // "list(attr)" | |||||
| } | |||||
| // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) | |||||
| oneof value { | |||||
| bytes s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| DataType type = 6; // "type" | |||||
| TensorShapeProto shape = 7; // "shape" | |||||
| TensorProto tensor = 8; // "tensor" | |||||
| ListValue list = 1; // any "list(...)" | |||||
| // "func" represents a function. func.name is a function's name or | |||||
| // a primitive op's name. func.attr.first is the name of an attr | |||||
| // defined for that function. func.attr.second is the value for | |||||
| // that attr in the instantiation. | |||||
| NameAttrList func = 10; | |||||
| // This is a placeholder only used in nodes defined inside a | |||||
| // function. It indicates the attr value will be supplied when | |||||
| // the function is instantiated. For example, let us suppose a | |||||
| // node "N" in function "FN". "N" has an attr "A" with value | |||||
| // placeholder = "foo". When FN is instantiated with attr "foo" | |||||
| // set to "bar", the instantiated node N's attr A will have been | |||||
| // given the value "bar". | |||||
| string placeholder = 9; | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NameAttrList { | |||||
| string name = 1; | |||||
| map<string, AttrValue> attr = 2; | |||||
| } | |||||
| @@ -1,108 +0,0 @@ | |||||
| /** | |||||
| * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||||
| * | |||||
| * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||||
| * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||||
| * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "FunctionProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "attr_value.proto"; | |||||
| import "node_def.proto"; | |||||
| import "op_def.proto"; | |||||
| // A library is a set of named functions. | |||||
| message FunctionDefLibrary { | |||||
| repeated FunctionDef function = 1; | |||||
| repeated GradientDef gradient = 2; | |||||
| } | |||||
| // A function can be instantiated when the runtime can bind every attr | |||||
| // with a value. When a GraphDef has a call to a function, it must | |||||
| // have binding for every attr defined in the signature. | |||||
| // * device spec, etc. | |||||
| message FunctionDef { | |||||
| // The definition of the function's name, arguments, return values, | |||||
| // attrs etc. | |||||
| OpDef signature = 1; | |||||
| // Attributes specific to this function definition. | |||||
| map<string, AttrValue> attr = 5; | |||||
| // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. | |||||
| reserved 2; | |||||
| // In both of the following fields, there is the need to specify an | |||||
| // output that is used as either the input to another node (in | |||||
| // `node_def`) or as a return value of the function (in `ret`). | |||||
| // Unlike the NodeDefs in GraphDef, we need to be able to specify a | |||||
| // list in some cases (instead of just single outputs). Also, we | |||||
| // need to be able to deal with lists of unknown length (so the | |||||
| // output index may not be known at function definition time). So | |||||
| // we use the following format instead: | |||||
| // * "fun_in" where "fun_in" is the name of a function input arg in | |||||
| // the `signature` field above. This represents that input, whether | |||||
| // it is a single tensor or a list. | |||||
| // * "fun_in:0" gives the first element of a function input arg (a | |||||
| // non-list input is considered a list of length 1 for these | |||||
| // purposes). | |||||
| // * "node:out" where "node" is the name of a node in `node_def` and | |||||
| // "out" is the name one of its op's output arguments (the name | |||||
| // comes from the OpDef of the node's op). This represents that | |||||
| // node's output, whether it is a single tensor or a list. | |||||
| // Note: We enforce that an op's output arguments are never | |||||
| // renamed in the backwards-compatibility test. | |||||
| // * "node:out:0" gives the first element of a node output arg (a | |||||
| // non-list output is considered a list of length 1 for these | |||||
| // purposes). | |||||
| // | |||||
| // NOT CURRENTLY SUPPORTED (but may be in the future): | |||||
| // * "node:out:-1" gives last element in a node output list | |||||
| // * "node:out:1:" gives a list with all but the first element in a | |||||
| // node output list | |||||
| // * "node:out::-1" gives a list with all but the last element in a | |||||
| // node output list | |||||
| // The body of the function. Unlike the NodeDefs in a GraphDef, attrs | |||||
| // may have values of type `placeholder` and the `input` field uses | |||||
| // the "output" format above. | |||||
| // By convention, "op" in node_def is resolved by consulting with a | |||||
| // user-defined library first. If not resolved, "func" is assumed to | |||||
| // be a builtin op. | |||||
| repeated NodeDef node_def = 3; | |||||
| // A mapping from the output arg names from `signature` to the | |||||
| // outputs from `node_def` that should be returned by the function. | |||||
| map<string, string> ret = 4; | |||||
| } | |||||
| // GradientDef defines the gradient function of a function defined in | |||||
| // a function library. | |||||
| // | |||||
| // A gradient function g (specified by gradient_func) for a function f | |||||
| // (specified by function_name) must follow the following: | |||||
| // | |||||
| // The function 'f' must be a numerical function which takes N inputs | |||||
| // and produces M outputs. Its gradient function 'g', which is a | |||||
| // function taking N + M inputs and produces N outputs. | |||||
| // | |||||
| // I.e. if we have | |||||
| // (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), | |||||
| // then, g is | |||||
| // (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, | |||||
| // dL/dy1, dL/dy2, ..., dL/dy_M), | |||||
| // where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the | |||||
| // loss function). dL/dx_i is the partial derivative of L with respect | |||||
| // to x_i. | |||||
| message GradientDef { | |||||
| string function_name = 1; // The function name. | |||||
| string gradient_func = 2; // The gradient function's name. | |||||
| } | |||||
| @@ -1,64 +0,0 @@ | |||||
| /** | |||||
| * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||||
| * | |||||
| * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||||
| * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||||
| * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "GraphProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "node_def.proto"; | |||||
| import "function.proto"; | |||||
| import "versions.proto"; | |||||
| // Represents the graph of operations | |||||
| message GraphDef { | |||||
| repeated NodeDef node = 1; | |||||
| // Compatibility versions of the graph. See core/public/version.h for version | |||||
| // history. The GraphDef version is distinct from the TensorFlow version, and | |||||
| // each release of TensorFlow will support a range of GraphDef versions. | |||||
| VersionDef versions = 4; | |||||
| // Deprecated single version field; use versions above instead. Since all | |||||
| // GraphDef changes before "versions" was introduced were forward | |||||
| // compatible, this field is entirely ignored. | |||||
| int32 version = 3 [deprecated = true]; | |||||
| // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. | |||||
| // | |||||
| // "library" provides user-defined functions. | |||||
| // | |||||
| // Naming: | |||||
| // * library.function.name are in a flat namespace. | |||||
| // NOTE: We may need to change it to be hierarchical to support | |||||
| // different orgs. E.g., | |||||
| // { "/google/nn", { ... }}, | |||||
| // { "/google/vision", { ... }} | |||||
| // { "/org_foo/module_bar", { ... }} | |||||
| // map<string, FunctionDefLib> named_lib; | |||||
| // * If node[i].op is the name of one function in "library", | |||||
| // node[i] is deemed as a function call. Otherwise, node[i].op | |||||
| // must be a primitive operation supported by the runtime. | |||||
| // | |||||
| // | |||||
| // Function call semantics: | |||||
| // | |||||
| // * The callee may start execution as soon as some of its inputs | |||||
| // are ready. The caller may want to use Tuple() mechanism to | |||||
| // ensure all inputs are ready in the same time. | |||||
| // | |||||
| // * The consumer of return values may start executing as soon as | |||||
| // the return values the consumer depends on are ready. The | |||||
| // consumer may want to use Tuple() mechanism to ensure the | |||||
| // consumer does not start until all return values of the callee | |||||
| // function are ready. | |||||
| FunctionDefLibrary library = 2; | |||||
| }; | |||||
| @@ -1,22 +0,0 @@ | |||||
| /** | |||||
| * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||||
| * | |||||
| * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||||
| * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||||
| * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| import "graph.proto"; | |||||
| message GeGraphDef { | |||||
| string name = 1; | |||||
| GraphDef graph = 2; | |||||
| } | |||||
| message GraphDefLibrary { | |||||
| repeated GeGraphDef graph_def = 1; | |||||
| }; | |||||
| @@ -1,71 +0,0 @@ | |||||
| /** | |||||
| * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||||
| * | |||||
| * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||||
| * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||||
| * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "NodeProto"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "attr_value.proto"; | |||||
| message NodeDef { | |||||
| // The name given to this operator. Used for naming inputs, | |||||
| // logging, visualization, etc. Unique within a single GraphDef. | |||||
| // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". | |||||
| string name = 1; | |||||
| // The operation name. There may be custom parameters in attrs. | |||||
| // Op names starting with an underscore are reserved for internal use. | |||||
| string op = 2; | |||||
| // Each input is "node:src_output" with "node" being a string name and | |||||
| // "src_output" indicating which output tensor to use from "node". If | |||||
| // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs | |||||
| // may optionally be followed by control inputs that have the format | |||||
| // "^node". | |||||
| repeated string input = 3; | |||||
| // A (possibly partial) specification for the device on which this | |||||
| // node should be placed. | |||||
| // The expected syntax for this string is as follows: | |||||
| // | |||||
| // DEVICE_SPEC ::= PARTIAL_SPEC | |||||
| // | |||||
| // PARTIAL_SPEC ::= ("/" CONSTRAINT) * | |||||
| // CONSTRAINT ::= ("job:" JOB_NAME) | |||||
| // | ("replica:" [1-9][0-9]*) | |||||
| // | ("task:" [1-9][0-9]*) | |||||
| // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) | |||||
| // | |||||
| // Valid values for this string include: | |||||
| // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) | |||||
| // * "/job:worker/device:GPU:3" (partial specification) | |||||
| // * "" (no specification) | |||||
| // | |||||
| // If the constraints do not resolve to a single device (or if this | |||||
| // field is empty or not present), the runtime will attempt to | |||||
| // choose a device automatically. | |||||
| string device = 4; | |||||
| // Operation-specific graph-construction-time configuration. | |||||
| // Note that this should include all attrs defined in the | |||||
| // corresponding OpDef, including those with a value matching | |||||
| // the default -- this allows the default to change and makes | |||||
| // NodeDefs easier to interpret on their own. However, if | |||||
| // an attr with a default is not specified in this list, the | |||||
| // default will be used. | |||||
| // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and | |||||
| // one of the names from the corresponding OpDef's attr field). | |||||
| // The values must have a type matching the corresponding OpDef | |||||
| // attr's type field. | |||||
| // Add some examples here showing best practices. | |||||
| map<string, AttrValue> attr = 5; | |||||
| }; | |||||
| @@ -1,172 +0,0 @@ | |||||
| /** | |||||
| * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||||
| * | |||||
| * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||||
| * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||||
| * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "OpDefProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "attr_value.proto"; | |||||
| import "types.proto"; | |||||
| // Defines an operation. A NodeDef in a GraphDef specifies an Op by | |||||
| // using the "op" field which should match the name of a OpDef. | |||||
| // LINT.IfChange | |||||
| message OpDef { | |||||
| // Op names starting with an underscore are reserved for internal use. | |||||
| // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". | |||||
| string name = 1; | |||||
| // For describing inputs and outputs. | |||||
| message ArgDef { | |||||
| // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". | |||||
| string name = 1; | |||||
| // Human readable description. | |||||
| string description = 2; | |||||
| // Describes the type of one or more tensors that are accepted/produced | |||||
| // by this input/output arg. The only legal combinations are: | |||||
| // * For a single tensor: either the "type" field is set or the | |||||
| // "type_attr" field is set to the name of an attr with type "type". | |||||
| // * For a sequence of tensors with the same type: the "number_attr" | |||||
| // field will be set to the name of an attr with type "int", and | |||||
| // either the "type" or "type_attr" field will be set as for | |||||
| // single tensors. | |||||
| // * For a sequence of tensors, the "type_list_attr" field will be set | |||||
| // to the name of an attr with type "list(type)". | |||||
| DataType type = 3; | |||||
| string type_attr = 4; // if specified, attr must have type "type" | |||||
| string number_attr = 5; // if specified, attr must have type "int" | |||||
| // If specified, attr must have type "list(type)", and none of | |||||
| // type, type_attr, and number_attr may be specified. | |||||
| string type_list_attr = 6; | |||||
| // For inputs: if true, the inputs are required to be refs. | |||||
| // By default, inputs can be either refs or non-refs. | |||||
| // For outputs: if true, outputs are refs, otherwise they are not. | |||||
| bool is_ref = 16; | |||||
| }; | |||||
| // Description of the input(s). | |||||
| repeated ArgDef input_arg = 2; | |||||
| // Description of the output(s). | |||||
| repeated ArgDef output_arg = 3; | |||||
| // Description of the graph-construction-time configuration of this | |||||
| // Op. That is to say, this describes the attr fields that will | |||||
| // be specified in the NodeDef. | |||||
| message AttrDef { | |||||
| // A descriptive name for the argument. May be used, e.g. by the | |||||
| // Python client, as a keyword argument name, and so should match | |||||
| // the regexp "[a-z][a-z0-9_]+". | |||||
| string name = 1; | |||||
| // One of the type names from attr_value.proto ("string", "list(string)", | |||||
| // "int", etc.). | |||||
| string type = 2; | |||||
| // A reasonable default for this attribute if the user does not supply | |||||
| // a value. If not specified, the user must supply a value. | |||||
| AttrValue default_value = 3; | |||||
| // Human-readable description. | |||||
| string description = 4; | |||||
| // --- Constraints --- | |||||
| // These constraints are only in effect if specified. Default is no | |||||
| // constraints. | |||||
| // For type == "int", this is a minimum value. For "list(___)" | |||||
| // types, this is the minimum length. | |||||
| bool has_minimum = 5; | |||||
| int64 minimum = 6; | |||||
| // The set of allowed values. Has type that is the "list" version | |||||
| // of the "type" field above (uses the "list" field of AttrValue). | |||||
| // If type == "type" or "list(type)" above, then the "type" field | |||||
| // of "allowed_values.list" has the set of allowed DataTypes. | |||||
| // If type == "string" or "list(string)", then the "s" field of | |||||
| // "allowed_values.list" has the set of allowed strings. | |||||
| AttrValue allowed_values = 7; | |||||
| } | |||||
| repeated AttrDef attr = 4; | |||||
| // Optional deprecation based on GraphDef versions. | |||||
| OpDeprecation deprecation = 8; | |||||
| // One-line human-readable description of what the Op does. | |||||
| string summary = 5; | |||||
| // Additional, longer human-readable description of what the Op does. | |||||
| string description = 6; | |||||
| // ------------------------------------------------------------------------- | |||||
| // Which optimizations this operation can participate in. | |||||
| // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) | |||||
| bool is_commutative = 18; | |||||
| // If is_aggregate is true, then this operation accepts N >= 2 | |||||
| // inputs and produces 1 output all of the same type. Should be | |||||
| // associative and commutative, and produce output with the same | |||||
| // shape as the input. The optimizer may replace an aggregate op | |||||
| // taking input from multiple devices with a tree of aggregate ops | |||||
| // that aggregate locally within each device (and possibly within | |||||
| // groups of nearby devices) before communicating. | |||||
| bool is_aggregate = 16; // for things like add | |||||
| // Other optimizations go here, like | |||||
| // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. | |||||
| // ------------------------------------------------------------------------- | |||||
| // Optimization constraints. | |||||
| // Ops are marked as stateful if their behavior depends on some state beyond | |||||
| // their input tensors (e.g. variable reading op) or if they have | |||||
| // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops | |||||
| // must always produce the same output for the same input and have | |||||
| // no side-effects. | |||||
| // | |||||
| // By default Ops may be moved between devices. Stateful ops should | |||||
| // either not be moved, or should only be moved if that state can also | |||||
| // be moved (e.g. via some sort of save / restore). | |||||
| // Stateful ops are guaranteed to never be optimized away by Common | |||||
| // Subexpression Elimination (CSE). | |||||
| bool is_stateful = 17; // for things like variables, queue | |||||
| // ------------------------------------------------------------------------- | |||||
| // Non-standard options. | |||||
| // By default, all inputs to an Op must be initialized Tensors. Ops | |||||
| // that may initialize tensors for the first time should set this | |||||
| // field to true, to allow the Op to take an uninitialized Tensor as | |||||
| // input. | |||||
| bool allows_uninitialized_input = 19; // for Assign, etc. | |||||
| }; | |||||
| // LINT.ThenChange( | |||||
| // https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) | |||||
| // Information about version-dependent deprecation of an op | |||||
| message OpDeprecation { | |||||
| // First GraphDef version at which the op is disallowed. | |||||
| int32 version = 1; | |||||
| // Explanation of why it was deprecated and what to use instead. | |||||
| string explanation = 2; | |||||
| }; | |||||
| // A collection of OpDefs | |||||
| message OpList { | |||||
| repeated OpDef op = 1; | |||||
| }; | |||||
| @@ -1,37 +0,0 @@ | |||||
| /** | |||||
| * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||||
| * | |||||
| * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||||
| * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||||
| * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "ResourceHandle"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| // Protocol buffer representing a handle to a tensorflow resource. Handles are | |||||
| // not valid across executions, but can be serialized back and forth from within | |||||
| // a single run. | |||||
| message ResourceHandleProto { | |||||
| // Unique name for the device containing the resource. | |||||
| string device = 1; | |||||
| // Container in which this resource is placed. | |||||
| string container = 2; | |||||
| // Unique name of this resource. | |||||
| string name = 3; | |||||
| // Hash code for the type of the resource. Is only valid in the same device | |||||
| // and in the same execution. | |||||
| uint64 hash_code = 4; | |||||
| // For debug-only, the name of the type pointed to by this handle, if | |||||
| // available. | |||||
| string maybe_type_name = 5; | |||||
| }; | |||||
| @@ -1,102 +0,0 @@ | |||||
| /** | |||||
| * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||||
| * | |||||
| * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||||
| * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||||
| * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "TensorProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "resource_handle.proto"; | |||||
| import "tensor_shape.proto"; | |||||
| import "types.proto"; | |||||
| // Protocol buffer representing a tensor. | |||||
| message TensorProto { | |||||
| DataType dtype = 1; | |||||
| // Shape of the tensor. | |||||
| TensorShapeProto tensor_shape = 2; | |||||
| // Only one of the representations below is set, one of "tensor_contents" and | |||||
| // the "xxx_val" attributes. We are not using oneof because as oneofs cannot | |||||
| // contain repeated fields it would require another extra set of messages. | |||||
| // Version number. | |||||
| // | |||||
| // In version 0, if the "repeated xxx" representations contain only one | |||||
| // element, that element is repeated to fill the shape. This makes it easy | |||||
| // to represent a constant Tensor with a single value. | |||||
| int32 version_number = 3; | |||||
| // Serialized raw tensor content from either Tensor::AsProtoTensorContent or | |||||
| // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation | |||||
| // can be used for all tensor types. The purpose of this representation is to | |||||
| // reduce serialization overhead during RPC call by avoiding serialization of | |||||
| // many repeated small items. | |||||
| bytes tensor_content = 4; | |||||
| // Type specific representations that make it easy to create tensor protos in | |||||
| // all languages. Only the representation corresponding to "dtype" can | |||||
| // be set. The values hold the flattened representation of the tensor in | |||||
| // row major order. | |||||
| // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll | |||||
| // have some pointless zero padding for each value here. | |||||
| repeated int32 half_val = 13 [packed = true]; | |||||
| // DT_FLOAT. | |||||
| repeated float float_val = 5 [packed = true]; | |||||
| // DT_DOUBLE. | |||||
| repeated double double_val = 6 [packed = true]; | |||||
| // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. | |||||
| repeated int32 int_val = 7 [packed = true]; | |||||
| // DT_STRING | |||||
| repeated bytes string_val = 8; | |||||
| // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real | |||||
| // and imaginary parts of i-th single precision complex. | |||||
| repeated float scomplex_val = 9 [packed = true]; | |||||
| // DT_INT64 | |||||
| repeated int64 int64_val = 10 [packed = true]; | |||||
| // DT_BOOL | |||||
| repeated bool bool_val = 11 [packed = true]; | |||||
| // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real | |||||
| // and imaginary parts of i-th double precision complex. | |||||
| repeated double dcomplex_val = 12 [packed = true]; | |||||
| // DT_RESOURCE | |||||
| repeated ResourceHandleProto resource_handle_val = 14; | |||||
| // DT_VARIANT | |||||
| repeated VariantTensorDataProto variant_val = 15; | |||||
| // DT_UINT32 | |||||
| repeated uint32 uint32_val = 16 [packed = true]; | |||||
| // DT_UINT64 | |||||
| repeated uint64 uint64_val = 17 [packed = true]; | |||||
| }; | |||||
| // Protocol buffer representing the serialization format of DT_VARIANT tensors. | |||||
| message VariantTensorDataProto { | |||||
| // Name of the type of objects being serialized. | |||||
| string type_name = 1; | |||||
| // Portions of the object that are not Tensors. | |||||
| bytes metadata = 2; | |||||
| // Tensors contained within objects being serialized. | |||||
| repeated TensorProto tensors = 3; | |||||
| } | |||||
| @@ -1,53 +0,0 @@ | |||||
| /** | |||||
| * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||||
| * | |||||
| * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||||
| * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||||
| * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||||
| */ | |||||
| // Protocol buffer representing the shape of tensors. | |||||
| syntax = "proto3"; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "TensorShapeProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| package domi.tensorflow; | |||||
| // Dimensions of a tensor. | |||||
| message TensorShapeProto { | |||||
| // One dimension of the tensor. | |||||
| message Dim { | |||||
| // Size of the tensor in that dimension. | |||||
| // This value must be >= -1, but values of -1 are reserved for "unknown" | |||||
| // shapes (values of -1 mean "unknown" dimension). Certain wrappers | |||||
| // that work with TensorShapeProto may fail at runtime when deserializing | |||||
| // a TensorShapeProto containing a dim value of -1. | |||||
| int64 size = 1; | |||||
| // Optional name of the tensor dimension. | |||||
| string name = 2; | |||||
| }; | |||||
| // Dimensions of the tensor, such as {"input", 30}, {"output", 40} | |||||
| // for a 30 x 40 2D tensor. If an entry has size -1, this | |||||
| // corresponds to a dimension of unknown size. The names are | |||||
| // optional. | |||||
| // | |||||
| // The order of entries in "dim" matters: It indicates the layout of the | |||||
| // values in the tensor in-memory representation. | |||||
| // | |||||
| // The first entry in "dim" is the outermost dimension used to layout the | |||||
| // values, the last entry is the innermost dimension. This matches the | |||||
| // in-memory layout of RowMajor Eigen tensors. | |||||
| // | |||||
| // If "dim.size()" > 0, "unknown_rank" must be false. | |||||
| repeated Dim dim = 2; | |||||
| // If true, the number of dimensions in the shape is unknown. | |||||
| // | |||||
| // If true, "dim.size()" must be 0. | |||||
| bool unknown_rank = 3; | |||||
| }; | |||||
| @@ -1,82 +0,0 @@ | |||||
| /** | |||||
| * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||||
| * | |||||
| * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||||
| * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||||
| * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "TypesProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| // LINT.IfChange | |||||
| enum DataType { | |||||
| // Not a legal value for DataType. Used to indicate a DataType field | |||||
| // has not been set. | |||||
| DT_INVALID = 0; | |||||
| // Data types that all computation devices are expected to be | |||||
| // capable to support. | |||||
| DT_FLOAT = 1; | |||||
| DT_DOUBLE = 2; | |||||
| DT_INT32 = 3; | |||||
| DT_UINT8 = 4; | |||||
| DT_INT16 = 5; | |||||
| DT_INT8 = 6; | |||||
| DT_STRING = 7; | |||||
| DT_COMPLEX64 = 8; // Single-precision complex | |||||
| DT_INT64 = 9; | |||||
| DT_BOOL = 10; | |||||
| DT_QINT8 = 11; // Quantized int8 | |||||
| DT_QUINT8 = 12; // Quantized uint8 | |||||
| DT_QINT32 = 13; // Quantized int32 | |||||
| DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. | |||||
| DT_QINT16 = 15; // Quantized int16 | |||||
| DT_QUINT16 = 16; // Quantized uint16 | |||||
| DT_UINT16 = 17; | |||||
| DT_COMPLEX128 = 18; // Double-precision complex | |||||
| DT_HALF = 19; | |||||
| DT_RESOURCE = 20; | |||||
| DT_VARIANT = 21; // Arbitrary C++ data types | |||||
| DT_UINT32 = 22; | |||||
| DT_UINT64 = 23; | |||||
| // Do not use! These are only for parameters. Every enum above | |||||
| // should have a corresponding value below (verified by types_test). | |||||
| DT_FLOAT_REF = 101; | |||||
| DT_DOUBLE_REF = 102; | |||||
| DT_INT32_REF = 103; | |||||
| DT_UINT8_REF = 104; | |||||
| DT_INT16_REF = 105; | |||||
| DT_INT8_REF = 106; | |||||
| DT_STRING_REF = 107; | |||||
| DT_COMPLEX64_REF = 108; | |||||
| DT_INT64_REF = 109; | |||||
| DT_BOOL_REF = 110; | |||||
| DT_QINT8_REF = 111; | |||||
| DT_QUINT8_REF = 112; | |||||
| DT_QINT32_REF = 113; | |||||
| DT_BFLOAT16_REF = 114; | |||||
| DT_QINT16_REF = 115; | |||||
| DT_QUINT16_REF = 116; | |||||
| DT_UINT16_REF = 117; | |||||
| DT_COMPLEX128_REF = 118; | |||||
| DT_HALF_REF = 119; | |||||
| DT_RESOURCE_REF = 120; | |||||
| DT_VARIANT_REF = 121; | |||||
| DT_UINT32_REF = 122; | |||||
| DT_UINT64_REF = 123; | |||||
| } | |||||
| // LINT.ThenChange( | |||||
| // https://www.tensorflow.org/code/tensorflow/c/c_api.h, | |||||
| // https://www.tensorflow.org/code/tensorflow/go/tensor.go, | |||||
| // https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, | |||||
| // https://www.tensorflow.org/code/tensorflow/core/framework/types.h, | |||||
| // https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, | |||||
| // https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, | |||||
| // https://www.tensorflow.org/code/tensorflow/python/framework/function.py) | |||||
| @@ -1,39 +0,0 @@ | |||||
| /** | |||||
| * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow | |||||
| * | |||||
| * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. | |||||
| * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). | |||||
| * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "VersionsProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| // Version information for a piece of serialized data | |||||
| // | |||||
| // There are different types of versions for each type of data | |||||
| // (GraphDef, etc.), but they all have the same common shape | |||||
| // described here. | |||||
| // | |||||
| // Each consumer has "consumer" and "min_producer" versions (specified | |||||
| // elsewhere). A consumer is allowed to consume this data if | |||||
| // | |||||
| // producer >= min_producer | |||||
| // consumer >= min_consumer | |||||
| // consumer not in bad_consumers | |||||
| // | |||||
| message VersionDef { | |||||
| // The version of the code that produced this data. | |||||
| int32 producer = 1; | |||||
| // Any consumer below this version is not allowed to consume this data. | |||||
| int32 min_consumer = 2; | |||||
| // Specific consumer versions which are disallowed (e.g. due to bugs). | |||||
| repeated int32 bad_consumers = 3; | |||||
| }; | |||||
| @@ -340,15 +340,24 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char | |||||
| return res; | return res; | ||||
| } | } | ||||
| void PathValidErrReport(const std::string &file_path, const std::string &atc_param, const std::string &reason) { | |||||
| if (!atc_param.empty()) { | |||||
| REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
| std::vector<std::string>({atc_param, file_path, reason})); | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Path[%s] invalid, reason:%s", file_path.c_str(), reason.c_str()); | |||||
| } | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path, | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path, | ||||
| const std::string &atc_param) { | const std::string &atc_param) { | ||||
| // The specified path is empty | // The specified path is empty | ||||
| std::map<std::string, std::string> args_map; | std::map<std::string, std::string> args_map; | ||||
| if (file_path.empty()) { | if (file_path.empty()) { | ||||
| if (atc_param != "") { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); | |||||
| if (!atc_param.empty()) { | |||||
| REPORT_INPUT_ERROR("E10004", std::vector<std::string>({"parameter"}), std::vector<std::string>({atc_param})); | |||||
| } else { | } else { | ||||
| REPORT_INNER_ERROR("E19999", "Param file_path is empty, check invalid"); | |||||
| REPORT_INNER_ERROR("E19999", "Param file_path is empty, check invalid."); | |||||
| } | } | ||||
| GELOGW("Input parameter %s is empty.", file_path.c_str()); | GELOGW("Input parameter %s is empty.", file_path.c_str()); | ||||
| return false; | return false; | ||||
| @@ -356,13 +365,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const | |||||
| std::string real_path = RealPath(file_path.c_str()); | std::string real_path = RealPath(file_path.c_str()); | ||||
| // Unable to get absolute path (does not exist or does not have permission to access) | // Unable to get absolute path (does not exist or does not have permission to access) | ||||
| if (real_path.empty()) { | if (real_path.empty()) { | ||||
| if (atc_param != "") { | |||||
| std::string reason = "realpath error, errmsg:" + std::string(strerror(errno)); | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||||
| {atc_param, file_path, reason}); | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno)); | |||||
| } | |||||
| std::string reason = "realpath error, errmsg:" + std::string(strerror(errno)); | |||||
| PathValidErrReport(file_path, atc_param, reason); | |||||
| GELOGW("Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno)); | GELOGW("Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno)); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -378,23 +382,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
| !ValidateStr(real_path, mode), | !ValidateStr(real_path, mode), | ||||
| if (atc_param != "") { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||||
| {atc_param, real_path, kPathValidReason}); | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Path[%s] has invalid char, %s", file_path.c_str(), kPathValidReason); | |||||
| } | |||||
| PathValidErrReport(file_path, atc_param, kPathValidReason); | |||||
| return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason); | return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason); | ||||
| // The absolute path points to a file that is not readable | // The absolute path points to a file that is not readable | ||||
| if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) { | if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) { | ||||
| if (atc_param != "") { | |||||
| std::string reason = "cat not access, errmsg:" + std::string(strerror(errno)); | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||||
| {atc_param, file_path, reason}); | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Path[%s] can't acccess, errmsg:%s", file_path.c_str(), strerror(errno)); | |||||
| } | |||||
| PathValidErrReport(file_path, atc_param, "cat not access, errmsg:" + std::string(strerror(errno))); | |||||
| GELOGW("Read file[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno)); | GELOGW("Read file[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno)); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -406,10 +399,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const | |||||
| const std::string &atc_param) { | const std::string &atc_param) { | ||||
| // The specified path is empty | // The specified path is empty | ||||
| if (file_path.empty()) { | if (file_path.empty()) { | ||||
| if (atc_param != "") { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); | |||||
| if (!atc_param.empty()) { | |||||
| REPORT_INPUT_ERROR("E10004", std::vector<std::string>({"parameter"}), std::vector<std::string>({atc_param})); | |||||
| } else { | } else { | ||||
| REPORT_INNER_ERROR("E19999", "Param file_path is empty, check invalid"); | |||||
| REPORT_INNER_ERROR("E19999", "Param file_path is empty, check invalid."); | |||||
| } | } | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); | ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); | ||||
| GELOGW("Input parameter's value is empty."); | GELOGW("Input parameter's value is empty."); | ||||
| @@ -417,17 +410,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const | |||||
| } | } | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path.c_str()) >= MMPA_MAX_PATH, | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path.c_str()) >= MMPA_MAX_PATH, | ||||
| if (atc_param != "") { | |||||
| std::string reason = "len is too long, it must be less than " + | |||||
| std::to_string(MMPA_MAX_PATH); | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E10001", {"parameter", "value", "reason"}, | |||||
| {atc_param, file_path, reason}); | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Path[%s] len is too long, it must be less than %d", | |||||
| file_path.c_str(), MMPA_MAX_PATH); | |||||
| } | |||||
| return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(), | |||||
| std::string reason = "len is too long, it must be less than " + | |||||
| std::to_string(MMPA_MAX_PATH); | |||||
| PathValidErrReport(file_path, atc_param, reason); | |||||
| return false, "Path[%s] len is too long, it must be less than %d", file_path.c_str(), | |||||
| MMPA_MAX_PATH); | MMPA_MAX_PATH); | ||||
| // A regular matching expression to verify the validity of the input file path | // A regular matching expression to verify the validity of the input file path | ||||
| @@ -441,12 +427,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
| !ValidateStr(file_path, mode), | !ValidateStr(file_path, mode), | ||||
| if (atc_param != "") { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||||
| {atc_param, file_path, kPathValidReason}); | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Path[%s] has invalid char, %s", file_path.c_str(), kPathValidReason); | |||||
| } | |||||
| PathValidErrReport(file_path, atc_param, kPathValidReason); | |||||
| return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason); | return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason); | ||||
| std::string real_path = RealPath(file_path.c_str()); | std::string real_path = RealPath(file_path.c_str()); | ||||
| @@ -454,13 +435,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const | |||||
| if (!real_path.empty()) { | if (!real_path.empty()) { | ||||
| // File is not readable or writable | // File is not readable or writable | ||||
| if (mmAccess2(real_path.c_str(), M_W_OK | M_F_OK) != EN_OK) { | if (mmAccess2(real_path.c_str(), M_W_OK | M_F_OK) != EN_OK) { | ||||
| if (atc_param != "") { | |||||
| std::string reason = "cat not access, errmsg:" + std::string(strerror(errno)); | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||||
| {atc_param, file_path, reason}); | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Path[%s] can't acccess, errmsg:%s", file_path.c_str(), strerror(errno)); | |||||
| } | |||||
| PathValidErrReport(file_path, atc_param, "cat not access, errmsg:" + std::string(strerror(errno))); | |||||
| GELOGW("Write file[%s] failed, errmsg[%s]", real_path.c_str(), strerror(errno)); | GELOGW("Write file[%s] failed, errmsg[%s]", real_path.c_str(), strerror(errno)); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -479,12 +454,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const | |||||
| std::string prefix_path = std::string(file_path).substr(0, static_cast<size_t>(path_split_pos)); | std::string prefix_path = std::string(file_path).substr(0, static_cast<size_t>(path_split_pos)); | ||||
| // Determine whether the specified path is valid by creating the path | // Determine whether the specified path is valid by creating the path | ||||
| if (CreateDirectory(prefix_path) != 0) { | if (CreateDirectory(prefix_path) != 0) { | ||||
| if (atc_param != "") { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||||
| {atc_param, file_path, "Can not create directory"}); | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Path[%s] Can not create directory", file_path.c_str()); | |||||
| } | |||||
| PathValidErrReport(file_path, atc_param, "Can not create directory"); | |||||
| GELOGW("Can not create directory[%s].", file_path.c_str()); | GELOGW("Can not create directory[%s].", file_path.c_str()); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -37,6 +37,7 @@ set(SRC_LIST | |||||
| "../graph/load/model_manager/task_info/task_info.cc" | "../graph/load/model_manager/task_info/task_info.cc" | ||||
| "../graph/load/model_manager/task_info/event_record_task_info.cc" | "../graph/load/model_manager/task_info/event_record_task_info.cc" | ||||
| "../graph/load/model_manager/task_info/event_wait_task_info.cc" | "../graph/load/model_manager/task_info/event_wait_task_info.cc" | ||||
| "../graph/load/model_manager/task_info/ffts_task_info.cc" | |||||
| "../graph/load/model_manager/task_info/fusion_start_task_info.cc" | "../graph/load/model_manager/task_info/fusion_start_task_info.cc" | ||||
| "../graph/load/model_manager/task_info/fusion_stop_task_info.cc" | "../graph/load/model_manager/task_info/fusion_stop_task_info.cc" | ||||
| "../graph/load/model_manager/task_info/kernel_ex_task_info.cc" | "../graph/load/model_manager/task_info/kernel_ex_task_info.cc" | ||||
| @@ -1,113 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package toolkit.dump; | |||||
| enum OutputDataType { | |||||
| DT_UNDEFINED = 0; | |||||
| DT_FLOAT = 1; | |||||
| DT_FLOAT16 = 2; | |||||
| DT_INT8 = 3; | |||||
| DT_UINT8 = 4; | |||||
| DT_INT16 = 5; | |||||
| DT_UINT16 = 6; | |||||
| DT_INT32 = 7; | |||||
| DT_INT64 = 8; | |||||
| DT_UINT32 = 9; | |||||
| DT_UINT64 = 10; | |||||
| DT_BOOL = 11; | |||||
| DT_DOUBLE = 12; | |||||
| DT_STRING = 13; | |||||
| DT_DUAL_SUB_INT8 = 14; | |||||
| DT_DUAL_SUB_UINT8 = 15; | |||||
| DT_COMPLEX64 = 16; | |||||
| DT_COMPLEX128 = 17; | |||||
| DT_QINT8 = 18; | |||||
| DT_QINT16 = 19; | |||||
| DT_QINT32 = 20; | |||||
| DT_QUINT8 = 21; | |||||
| DT_QUINT16 = 22; | |||||
| DT_RESOURCE = 23; | |||||
| DT_STRING_REF = 24; | |||||
| DT_DUAL = 25; | |||||
| DT_VARIANT = 26; | |||||
| } | |||||
| enum OutputFormat { | |||||
| FORMAT_NCHW = 0; | |||||
| FORMAT_NHWC = 1; | |||||
| FORMAT_ND = 2; | |||||
| FORMAT_NC1HWC0 = 3; | |||||
| FORMAT_FRACTAL_Z = 4; | |||||
| FORMAT_NC1C0HWPAD = 5; | |||||
| FORMAT_NHWC1C0 = 6; | |||||
| FORMAT_FSR_NCHW = 7; | |||||
| FORMAT_FRACTAL_DECONV = 8; | |||||
| FORMAT_C1HWNC0 = 9; | |||||
| FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; | |||||
| FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; | |||||
| FORMAT_NC1HWC0_C04 = 12; | |||||
| FORMAT_FRACTAL_Z_C04 = 13; | |||||
| FORMAT_CHWN = 14; | |||||
| FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; | |||||
| FORMAT_HWCN = 16; | |||||
| FORMAT_NC1KHKWHWC0 = 17; | |||||
| FORMAT_BN_WEIGHT = 18; | |||||
| FORMAT_FILTER_HWCK = 19; | |||||
| FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20; | |||||
| FORMAT_HASHTABLE_LOOKUP_KEYS = 21; | |||||
| FORMAT_HASHTABLE_LOOKUP_VALUE = 22; | |||||
| FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; | |||||
| FORMAT_HASHTABLE_LOOKUP_HITS=24; | |||||
| FORMAT_C1HWNCoC0 = 25; | |||||
| FORMAT_MD = 26; | |||||
| FORMAT_NDHWC = 27; | |||||
| FORMAT_FRACTAL_ZZ = 28; | |||||
| FORMAT_FRACTAL_NZ = 29; | |||||
| FORMAT_RESERVED = 30; | |||||
| } | |||||
| message OriginalOp { | |||||
| string name = 1; | |||||
| uint32 output_index = 2; | |||||
| OutputDataType data_type = 3; | |||||
| OutputFormat format = 4; | |||||
| } | |||||
| message Shape { | |||||
| repeated uint64 dim = 1; | |||||
| } | |||||
| message OpOutput { | |||||
| OutputDataType data_type = 1; | |||||
| OutputFormat format = 2; | |||||
| Shape shape = 3; | |||||
| OriginalOp original_op = 4; // the original op corresponding to the output | |||||
| bytes data = 5; | |||||
| uint64 size = 6; | |||||
| } | |||||
| message OpInput { | |||||
| OutputDataType data_type = 1; | |||||
| OutputFormat format = 2; | |||||
| Shape shape = 3; | |||||
| bytes data = 4; | |||||
| uint64 size = 5; | |||||
| } | |||||
| enum BufferType { | |||||
| L1 = 0; | |||||
| } | |||||
| message OpBuffer { | |||||
| BufferType buffer_type = 1; | |||||
| bytes data = 2; | |||||
| uint64 size = 3; | |||||
| } | |||||
| message DumpData{ | |||||
| string version = 1; | |||||
| uint64 dump_time = 2; | |||||
| repeated OpOutput output = 3; | |||||
| repeated OpInput input = 4; | |||||
| repeated OpBuffer buffer = 5; | |||||
| string op_name = 6; | |||||
| } | |||||
| @@ -1,193 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package ge.proto; | |||||
| enum DataType | |||||
| { | |||||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||||
| DT_FLOAT = 1; // float type | |||||
| DT_FLOAT16 = 2; // fp16 type | |||||
| DT_INT8 = 3; // int8 type | |||||
| DT_UINT8 = 4; // uint8 type | |||||
| DT_INT16 = 5; // int16 type | |||||
| DT_UINT16 = 6; // uint16 type | |||||
| DT_INT32 = 7; // | |||||
| DT_INT64 = 8; // int64 type | |||||
| DT_UINT32 = 9; // unsigned int32 | |||||
| DT_UINT64 = 10; // unsigned int64 | |||||
| DT_BOOL = 11; // bool type | |||||
| DT_DOUBLE = 12; // double type | |||||
| DT_STRING = 13; // string type | |||||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||||
| DT_COMPLEX64 = 16; // complex64 type | |||||
| DT_COMPLEX128 = 17; // complex128 type | |||||
| DT_QINT8 = 18; // qint8 type | |||||
| DT_QINT16 = 19; // qint16 type | |||||
| DT_QINT32 = 20; // qint32 type | |||||
| DT_QUINT8 = 21; // quint8 type | |||||
| DT_QUINT16 = 22; // quint16 type | |||||
| DT_RESOURCE = 23; // resource type | |||||
| DT_STRING_REF = 24; // string_ref type | |||||
| DT_DUAL = 25; /**< dual output type */ | |||||
| DT_VARIANT = 26; // variant type | |||||
| DT_BF16 = 27; // bf16 type | |||||
| DT_INT4 = 28; // int4 type | |||||
| } | |||||
| message AttrDef | |||||
| { | |||||
| message ListValue | |||||
| { | |||||
| enum ListValueType{ | |||||
| VT_LIST_NONE = 0; | |||||
| VT_LIST_STRING = 1; | |||||
| VT_LIST_INT = 2; | |||||
| VT_LIST_FLOAT = 3; | |||||
| VT_LIST_BOOL = 4; | |||||
| VT_LIST_BYTES = 5; | |||||
| VT_LIST_TENSOR_DESC = 6; | |||||
| VT_LIST_TENSOR = 7; | |||||
| VT_LIST_GRAPH = 8; | |||||
| VT_LIST_NAMED_ATTRS = 9; | |||||
| VT_LIST_DATA_TYPE = 10; | |||||
| } | |||||
| repeated bytes s = 2; // "list(string)" | |||||
| repeated int64 i = 3; // "list(int)" | |||||
| repeated float f = 4; // "list(float)" | |||||
| repeated bool b = 5; // "list(bool)" | |||||
| repeated bytes bt = 7; | |||||
| repeated TensorDescriptor td = 8; | |||||
| repeated TensorDef t = 9; | |||||
| repeated GraphDef g = 10; | |||||
| repeated NamedAttrs na = 11; | |||||
| repeated int64 dt = 12; // list ge::DataType | |||||
| ListValueType val_type = 20; | |||||
| } | |||||
| message ListListInt{ | |||||
| message ListInt{ | |||||
| repeated int64 list_i = 1; // list int | |||||
| } | |||||
| repeated ListInt list_list_i = 1; // list list int | |||||
| } | |||||
| oneof value | |||||
| { | |||||
| bytes s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| bytes bt = 7; | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; // Used to support attr nesting | |||||
| TensorDescriptor td = 11; // GeTensorDesc type | |||||
| TensorDef t = 12; // GeTensor type | |||||
| GraphDef g = 13; // Graph type | |||||
| ListListInt list_list_int = 14; // List List Int type | |||||
| int64 dt = 15; // ge::DataType | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs | |||||
| { | |||||
| string name = 1; | |||||
| map<string, AttrDef> attr = 2; | |||||
| } | |||||
| // Shape / dimension description, using row-major order | |||||
| message ShapeDef | |||||
| { | |||||
| repeated int64 dim = 1; // Size of each dimension | |||||
| } | |||||
| // Multidimensional data description | |||||
| message TensorDescriptor | |||||
| { | |||||
| string name = 1; // Optional parameter, tensor name | |||||
| DataType dtype = 2; // tensor datatype | |||||
| ShapeDef shape = 3; // Shape / dimension | |||||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||||
| bool has_out_attr = 9; | |||||
| int64 size = 10; | |||||
| int64 weight_size = 11; | |||||
| bool reuse_input = 12; | |||||
| bool output_tensor = 13; | |||||
| string device_type = 14; | |||||
| bool input_tensor =15; | |||||
| int64 real_dim_cnt = 16; | |||||
| int64 reuse_input_index = 17; | |||||
| int64 data_offset = 18; | |||||
| int64 cmps_size = 19; | |||||
| string cmps_tab = 20; | |||||
| int64 cmps_tab_offset = 21; | |||||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||||
| } | |||||
| // GeTensor definition | |||||
| message TensorDef | |||||
| { | |||||
| TensorDescriptor desc = 1; // Tensor description | |||||
| bytes data = 2; // Tensor data | |||||
| } | |||||
| // Operator description | |||||
| message OpDef | |||||
| { | |||||
| string name = 1; // name | |||||
| string type = 2; // type | |||||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||||
| bool has_out_attr = 20; | |||||
| int64 id = 21; | |||||
| int64 stream_id =22; | |||||
| repeated string input_name = 23; | |||||
| repeated string src_name = 24; | |||||
| repeated int64 src_index = 25; | |||||
| repeated string dst_name = 26; | |||||
| repeated int64 dst_index = 27; | |||||
| repeated int64 input_i = 28; | |||||
| repeated int64 output_i = 29; | |||||
| repeated int64 workspace = 30; | |||||
| repeated int64 workspace_bytes = 31; | |||||
| repeated bool is_input_const = 32; | |||||
| repeated TensorDescriptor input_desc = 33; | |||||
| repeated TensorDescriptor output_desc = 34; | |||||
| repeated string subgraph_name = 35; | |||||
| } | |||||
| // Graph definition | |||||
| message GraphDef | |||||
| { | |||||
| string name = 1; // name | |||||
| repeated string input = 4; // Graph input | |||||
| repeated string output = 5; // Graph output | |||||
| repeated OpDef op = 6; // List of operators | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| // model definition | |||||
| message ModelDef | |||||
| { | |||||
| string name = 1; // name | |||||
| uint32 version = 2; // IR Proto verion | |||||
| string custom_version = 3; // User model version number, passed in by user | |||||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| @@ -1,140 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| message InsertNewOps { | |||||
| repeated AippOpParams aipp_op = 1; | |||||
| repeated MultiShapeOpParams multi_shape_op = 2; | |||||
| } | |||||
| message AippOpParams { | |||||
| enum InputFormat { | |||||
| UNDEFINED = 0; | |||||
| YUV420SP_U8 = 1; | |||||
| XRGB8888_U8 = 2; | |||||
| RGB888_U8 = 3; | |||||
| YUV400_U8 = 4; | |||||
| NC1HWC0DI_FP16 = 5; | |||||
| NC1HWC0DI_S8 = 6; | |||||
| ARGB8888_U8 = 7; | |||||
| YUYV_U8 = 8; | |||||
| YUV422SP_U8 = 9; | |||||
| AYUV444_U8 = 10; | |||||
| RAW10 = 11; | |||||
| RAW12 = 12; | |||||
| RAW16 = 13; | |||||
| RAW24 = 14; | |||||
| RGB16 = 15; | |||||
| RGB20 = 16; | |||||
| RGB24 = 17; | |||||
| RGB8_IR = 18; | |||||
| RGB16_IR = 19; | |||||
| RGB24_IR = 20; | |||||
| } | |||||
| enum AippMode { | |||||
| undefined = 0; | |||||
| static = 1; | |||||
| dynamic = 2; | |||||
| } | |||||
| // AIPP模式,区分静态AIPP和动态AIPP | |||||
| AippMode aipp_mode = 1; | |||||
| // related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。 | |||||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | |||||
| uint32 related_input_rank = 2; | |||||
| // related_input_name is optional and the top name of data node which inserts aipp | |||||
| string related_input_name = 6; | |||||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | |||||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | |||||
| // 配置值 <= Data算子输出边的个数。 | |||||
| repeated uint32 input_edge_idx = 3; | |||||
| // [Begin] 动态AIPP参数,配置静态AIPP时无效 | |||||
| uint32 max_src_image_size = 4; | |||||
| // 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失 | |||||
| bool support_rotation = 5; | |||||
| // [End] 动态AIPP参数 | |||||
| // [Begin] 静态AIPP参数,配置动态AIPP时无效 | |||||
| InputFormat input_format = 51; | |||||
| bool csc_switch = 52; | |||||
| float cpadding_value = 53; | |||||
| bool rbuv_swap_switch = 54; | |||||
| bool ax_swap_switch = 55; | |||||
| bool single_line_mode = 56; | |||||
| int32 src_image_size_w = 57; | |||||
| int32 src_image_size_h = 58; | |||||
| bool crop = 59; | |||||
| int32 load_start_pos_w = 60; | |||||
| int32 load_start_pos_h = 61; | |||||
| int32 crop_size_w = 62; | |||||
| int32 crop_size_h = 63; | |||||
| bool resize = 64; | |||||
| int32 resize_output_w = 65; | |||||
| int32 resize_output_h = 66; | |||||
| bool padding = 67; | |||||
| int32 left_padding_size = 68; | |||||
| int32 right_padding_size = 69; | |||||
| int32 top_padding_size = 70; | |||||
| int32 bottom_padding_size = 71; | |||||
| float padding_value = 72; | |||||
| int32 mean_chn_0 = 10; | |||||
| int32 mean_chn_1 = 11; | |||||
| int32 mean_chn_2 = 12; | |||||
| int32 mean_chn_3 = 19; | |||||
| float min_chn_0 = 13; | |||||
| float min_chn_1 = 14; | |||||
| float min_chn_2 = 15; | |||||
| float min_chn_3 = 20; | |||||
| repeated float var_reci_chn_0 = 16; | |||||
| repeated float var_reci_chn_1 = 17; | |||||
| repeated float var_reci_chn_2 = 18; | |||||
| repeated float var_reci_chn_3 = 21; | |||||
| repeated int32 matrix_r0c0 = 30; | |||||
| repeated int32 matrix_r0c1 = 31; | |||||
| repeated int32 matrix_r0c2 = 32; | |||||
| repeated int32 matrix_r1c0 = 33; | |||||
| repeated int32 matrix_r1c1 = 34; | |||||
| repeated int32 matrix_r1c2 = 35; | |||||
| repeated int32 matrix_r2c0 = 36; | |||||
| repeated int32 matrix_r2c1 = 37; | |||||
| repeated int32 matrix_r2c2 = 38; | |||||
| repeated int32 output_bias_0 = 39; | |||||
| repeated int32 output_bias_1 = 40; | |||||
| repeated int32 output_bias_2 = 41; | |||||
| repeated int32 input_bias_0 = 42; | |||||
| repeated int32 input_bias_1 = 43; | |||||
| repeated int32 input_bias_2 = 44; | |||||
| // [End] 静态AIPP参数 | |||||
| // The n number that is used for raw/rgbir data into f16 transformation. | |||||
| // The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||||
| uint32 raw_rgbir_to_f16_n = 45; | |||||
| } | |||||
| message MultiShapeOpParams { | |||||
| enum MultiShapeMode { | |||||
| batch = 0; //动态batch | |||||
| resolution = 1; //动态分辨率,扩展用 | |||||
| } | |||||
| MultiShapeMode mode = 1; //算子模式 | |||||
| uint32 related_input_rank = 2; //新增算子插入到哪个输入 | |||||
| repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间 | |||||
| } | |||||
| @@ -1,396 +0,0 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| enum TargetType | |||||
| { | |||||
| MINI = 0; | |||||
| TINY = 1; | |||||
| LITE = 2; | |||||
| } | |||||
| // offline model | |||||
| message ModelDef { | |||||
| string name = 1; | |||||
| uint32 version = 2; | |||||
| uint64 memory_size = 10; | |||||
| uint32 stream_num = 11; | |||||
| uint32 event_num = 12; | |||||
| uint64 weight_size = 13; | |||||
| uint32 label_num = 15; | |||||
| repeated OpDef op = 20; | |||||
| TargetType target_type = 23; | |||||
| map<string, AttrDef> attr = 30; | |||||
| }; | |||||
| // operator define | |||||
| message OpDef { | |||||
| string name = 1; | |||||
| string type = 2; | |||||
| uint32 id = 3; | |||||
| uint32 stream_id = 4; | |||||
| repeated string input_name = 5; | |||||
| repeated string src_name = 8; | |||||
| repeated int32 src_index = 9; | |||||
| repeated int64 input = 10; | |||||
| repeated int64 output = 11; | |||||
| repeated TensorDescriptor input_desc = 12; | |||||
| repeated TensorDescriptor output_desc = 13; | |||||
| repeated WeightDef weights = 14; | |||||
| repeated string dst_name = 15; | |||||
| repeated int32 dst_index = 16; | |||||
| repeated int64 workspace = 20; | |||||
| repeated uint32 workspace_bytes = 21; | |||||
| repeated string weight_name = 22; | |||||
| repeated bool is_input_const = 23; | |||||
| map<string, AttrDef> attr = 30; | |||||
| QuantizeFactorParams quantize_factor = 31; | |||||
| oneof op_params { | |||||
| // start at 100 here | |||||
| SendOpParams sender_param = 100; | |||||
| RecvOpParams receiver_param = 200; | |||||
| ConvolutionOpParams convolution_param = 300; | |||||
| PoolingOpParams pooling_param = 400; | |||||
| EltwiseOpParams eltwise_param = 500; | |||||
| BatchNormOpParams batchnorm_param = 600; | |||||
| ScaleOpParams scale_param = 700; | |||||
| FullConnectionOpParams full_connection_param = 800; | |||||
| SoftmaxOpParams softmax_param = 900; | |||||
| ActivationOpParams activation_param = 1000; | |||||
| ReshapeOpParams reshape_param = 1100; | |||||
| } | |||||
| }; | |||||
| message SendOpParams { | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| message RecvOpParams { | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| enum QuantizeScaleType | |||||
| { | |||||
| VECTOR_SCALE = 0; | |||||
| SCALAR_SCALE = 1; | |||||
| } | |||||
| enum QuantizeScaleMode | |||||
| { | |||||
| NORMAL_MODE = 0; | |||||
| SQRT_MODE = 1; | |||||
| } | |||||
| enum QuantizeAlgorithm | |||||
| { | |||||
| NON_OFFSET_ALGO = 0; | |||||
| HALF_OFFSET_ALGO = 1; | |||||
| ALL_OFFSET_ALGO = 2; | |||||
| } | |||||
| message QuantizeFactor | |||||
| { | |||||
| QuantizeScaleMode scale_mode = 1; | |||||
| bytes scale_value = 2; | |||||
| int64 scale_offset = 3; | |||||
| bytes offset_data_value = 4; | |||||
| int64 offset_data_offset = 5; | |||||
| bytes offset_weight_value = 6; | |||||
| int64 offset_weight_offset = 7; | |||||
| bytes offset_pad_value = 8; | |||||
| int64 offset_pad_offset = 9; | |||||
| }; | |||||
| message QuantizeCalcFactor | |||||
| { | |||||
| bytes offsetw = 1; | |||||
| int64 offsetw_offset = 2; | |||||
| bytes offsetd = 3; | |||||
| int64 offsetd_offset = 4; | |||||
| bytes scalereq = 5; | |||||
| int64 scaledreq_offset = 6; | |||||
| bytes offsetdnext = 7; | |||||
| int64 offsetdnext_offset = 8; | |||||
| } | |||||
| message QuantizeFactorParams | |||||
| { | |||||
| QuantizeAlgorithm quantize_algo = 1; | |||||
| QuantizeScaleType scale_type = 2; | |||||
| QuantizeFactor quantize_param = 3; | |||||
| QuantizeFactor dequantize_param = 4; | |||||
| QuantizeFactor requantize_param = 5; | |||||
| QuantizeCalcFactor quantizecalc_param = 6; | |||||
| }; | |||||
| message ConvolutionOpParams { | |||||
| int32 mode = 1; | |||||
| int32 algo = 2; | |||||
| int32 pad_mode = 3; | |||||
| uint32 group = 4; | |||||
| uint32 num_output = 5; | |||||
| repeated uint32 pad = 10; | |||||
| repeated uint32 stride = 11; | |||||
| repeated uint32 dilation = 12; | |||||
| repeated uint32 kernel = 13; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| WeightDef filter = 40; | |||||
| WeightDef bias = 41; | |||||
| bool relu_flag = 62; | |||||
| repeated uint32 adj = 70; | |||||
| repeated uint32 target_shape = 71; | |||||
| repeated uint32 before_pad = 72; | |||||
| }; | |||||
| message PoolingOpParams { | |||||
| int32 mode = 1; | |||||
| int32 nan_opt = 2; | |||||
| int32 pad_mode = 3; | |||||
| bool global_pooling = 4; | |||||
| repeated uint32 window = 10; | |||||
| repeated uint32 pad = 11; | |||||
| repeated uint32 stride = 12; | |||||
| bool ceil_mode = 13; | |||||
| int32 data_mode = 14; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| repeated uint32 before_pad = 22; | |||||
| }; | |||||
| message EltwiseOpParams { | |||||
| int32 mode = 1; | |||||
| repeated float coeff = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| repeated WeightDef weight = 5; | |||||
| bool relu_flag = 6; | |||||
| }; | |||||
| message ActivationOpParams { | |||||
| int32 mode = 1; | |||||
| float coef = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message BatchNormOpParams { | |||||
| int32 mode = 1; | |||||
| float alpha = 2; | |||||
| float beta = 3; | |||||
| double epsilon = 4;//optinal,[default = 1e-5] | |||||
| bool use_global_stats = 5; //optinal,by default true,testing mode | |||||
| float moving_average_fraction = 6; //optinal,[default = .999]; | |||||
| WeightDef estimated_mean = 7; | |||||
| WeightDef estimated_variance = 8; | |||||
| WeightDef scale = 9; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message ScaleOpParams { | |||||
| WeightDef scale = 1; | |||||
| WeightDef bias = 2; | |||||
| }; | |||||
| message ReshapeOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| ShapeDef shape = 3; | |||||
| int32 axis = 4; | |||||
| int32 num_axes = 5; | |||||
| int32 format = 6; | |||||
| }; | |||||
| message SoftmaxOpParams { | |||||
| int32 algo = 1; | |||||
| int32 mode = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message FullConnectionOpParams { | |||||
| WeightDef filter = 1; | |||||
| WeightDef bias = 2; | |||||
| uint32 num_output = 3; | |||||
| bool relu_flag = 12; | |||||
| }; | |||||
| message FlattenOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 start_axis = 3; | |||||
| int32 end_axis = 4; | |||||
| } | |||||
| message AddLimitedOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulLimitedOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message AddOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message SubOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message BiasAddOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message MatMulOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| bool transposeX = 3; | |||||
| bool transposeW = 4; | |||||
| WeightDef filter = 10; | |||||
| WeightDef bias = 12; | |||||
| }; | |||||
| message RsqrtOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| }; | |||||
| message WeightDef { | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| ShapeDef shape = 3; | |||||
| bytes data = 4; | |||||
| int64 data_offset = 5; | |||||
| uint32 cmps_size = 6; | |||||
| bytes cmps_tab = 7; | |||||
| int64 cmps_tab_offset = 10; | |||||
| CompressInfo cmps_info = 8; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 11; | |||||
| } | |||||
| message ShapeDef { | |||||
| repeated int64 dim = 1; | |||||
| } | |||||
| enum DeviceType { | |||||
| NPU = 0; // In default, we will use NPU. | |||||
| CPU = 1; // CPU | |||||
| } | |||||
| message AllOffsetQuantizeInfo { | |||||
| float scale = 1; | |||||
| int32 offset = 2; | |||||
| } | |||||
| message TensorDescriptor { | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| repeated int64 dim = 3; | |||||
| uint32 size = 4; | |||||
| bool reuse_input = 5; | |||||
| bool output_tensor = 7; | |||||
| DeviceType device_type = 8; | |||||
| bool input_tensor = 9; | |||||
| uint32 real_dim_cnt = 10; | |||||
| uint32 reuse_input_index = 11; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 12; | |||||
| } | |||||
| message CompressInfo { | |||||
| int32 blockRow = 1; // block row | |||||
| int32 blockCol = 2; // block col | |||||
| int32 fractalK = 3; // fractal K | |||||
| int32 fractalN = 4; // fractal N | |||||
| int32 lastFractalK = 5; // K of last fractal | |||||
| int32 lastFractalN = 6; // N of last fractal | |||||
| int32 cubeSize = 7; // cube's length | |||||
| int32 loadDir = 8; // data load directtiono 0:col load 1:row load | |||||
| } | |||||
| message AttrDef { | |||||
| message ListValue { | |||||
| repeated string s = 2; // "list(string)" | |||||
| repeated int64 i = 3 [packed = true]; // "list(int)" | |||||
| repeated float f = 4 [packed = true]; // "list(float)" | |||||
| repeated bool b = 5 [packed = true]; // "list(bool)" | |||||
| repeated uint32 u = 6 [packed = true]; // "list(uint)" | |||||
| repeated bytes bt = 7; | |||||
| } | |||||
| oneof value { | |||||
| string s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| uint32 u = 6; // "uint32" | |||||
| bytes bt = 7; | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs { | |||||
| string name = 1; | |||||
| map<string, AttrDef> attr = 2; | |||||
| } | |||||
| @@ -1,75 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package toolkit.aicpu.dump; | |||||
| message Shape { | |||||
| repeated uint64 dim = 1; | |||||
| } | |||||
| message Output { | |||||
| int32 data_type = 1; | |||||
| int32 format = 2; | |||||
| Shape shape = 3; | |||||
| uint64 address = 4; | |||||
| string original_name = 5; | |||||
| int32 original_output_index = 6; | |||||
| int32 original_output_data_type = 7; | |||||
| int32 original_output_format = 8; | |||||
| uint64 size = 9; | |||||
| Shape origin_shape = 10; | |||||
| } | |||||
| message Input { | |||||
| int32 data_type =1; | |||||
| int32 format = 2; | |||||
| Shape shape = 3; | |||||
| uint64 address = 4; | |||||
| uint64 size = 5; | |||||
| Shape origin_shape = 6; | |||||
| } | |||||
| enum BufferType { | |||||
| L1 = 0; | |||||
| } | |||||
| message OpBuffer { | |||||
| BufferType buffer_type = 1; | |||||
| uint64 address = 2; | |||||
| uint64 size = 3; | |||||
| } | |||||
| message Op { | |||||
| string op_name = 1; | |||||
| string op_type = 2; | |||||
| } | |||||
| message Task { | |||||
| uint32 task_id = 1; | |||||
| uint32 stream_id = 2; | |||||
| Op op = 3; | |||||
| repeated Output output = 4; | |||||
| bool end_graph = 5; | |||||
| repeated Input input = 6; | |||||
| repeated OpBuffer buffer = 7; | |||||
| } | |||||
| message OpMappingInfo { | |||||
| string dump_path = 1; | |||||
| oneof model_name_param { | |||||
| string model_name = 2; | |||||
| } | |||||
| oneof model_id_param { | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| oneof step_id { | |||||
| uint64 step_id_addr = 4; | |||||
| } | |||||
| oneof iterations_per_loop { | |||||
| uint64 iterations_per_loop_addr = 5; | |||||
| } | |||||
| oneof loop_cond { | |||||
| uint64 loop_cond_addr = 6; | |||||
| } | |||||
| uint32 flag = 7; // 0x01 load, 0x00 unload | |||||
| repeated Task task = 8; | |||||
| string dump_step = 9; | |||||
| } | |||||
| @@ -1,179 +0,0 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| message ModelTaskDef { | |||||
| string version = 1; | |||||
| map<string, string> attr = 9; // Extended field | |||||
| repeated TaskDef task = 10; | |||||
| uint64 memory_size = 11; | |||||
| uint32 stream_num = 12; | |||||
| uint32 event_num = 13; | |||||
| uint64 weight_size = 14; | |||||
| repeated bytes op = 15; // input/output opdef in bytes | |||||
| uint64 base_addr = 16; // base addr | |||||
| uint64 weight_addr = 17; // weight addr | |||||
| uint32 batch_num = 18; | |||||
| } | |||||
| message TaskDef { | |||||
| uint32 id = 1; | |||||
| uint32 type = 2; | |||||
| uint32 stream_id = 10; | |||||
| uint32 event_id = 11; | |||||
| KernelDef kernel = 20; | |||||
| KernelExDef kernel_ex = 21; | |||||
| KernelHcclDef kernel_hccl = 25; | |||||
| EventExDef event_ex = 26; | |||||
| LogTimeStampDef log_timestamp = 28; | |||||
| uint32 label_id = 30; | |||||
| MemcpyAsyncDef memcpy_async = 31; | |||||
| StreamSwitchDef stream_switch = 32; | |||||
| StreamActiveDef stream_active = 33; | |||||
| bytes private_def = 34; | |||||
| uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future | |||||
| StreamSwitchNDef stream_switch_n = 36; | |||||
| LabelSetDef label_set = 37; | |||||
| LabelGotoExDef label_goto_ex = 38; | |||||
| LabelSwitchByIndexDef label_switch_by_index = 39; | |||||
| KernelDefWithHandle kernel_with_handle = 40; | |||||
| } | |||||
| message KernelDef { | |||||
| KernelContext context = 1; | |||||
| string stub_func = 10; | |||||
| uint32 block_dim = 11; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes sm_desc = 14; | |||||
| bytes flowtable = 15; | |||||
| string so_name = 16; | |||||
| string kernel_name = 17; | |||||
| bytes kernel_ext_info = 18; | |||||
| uint32 kernel_ext_info_size = 19; | |||||
| } | |||||
| message KernelDefWithHandle { | |||||
| KernelContext context = 1; | |||||
| uint64 handle = 10; | |||||
| string dev_func = 11; | |||||
| uint32 block_dim = 12; | |||||
| uint32 args_size = 13; | |||||
| bytes args = 14; | |||||
| bytes sm_desc = 15; | |||||
| string original_kernel_key = 16; | |||||
| string node_info = 17; | |||||
| } | |||||
| message KernelContext { | |||||
| uint32 kernel_type = 1; | |||||
| uint32 op_id = 2; // OP type in CCE | |||||
| uint32 kernel_func_id = 3; | |||||
| uint32 op_index = 4; // TE/Custom operator | |||||
| bool is_flowtable = 5; // Identify whether args is a flowtable structure | |||||
| bytes args_offset = 6; // args offset information | |||||
| uint32 args_count = 7; // args count | |||||
| repeated uint32 origin_op_index = 8; | |||||
| } | |||||
| message KernelExDef { | |||||
| uint32 flags = 1; | |||||
| uint32 op_index = 4; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput | |||||
| uint32 task_info_size = 15; | |||||
| bytes kernel_ext_info = 16; | |||||
| uint32 kernel_ext_info_size = 17; | |||||
| } | |||||
| message KernelHcclDef { | |||||
| uint32 op_index = 8; | |||||
| string hccl_type = 9; | |||||
| } | |||||
| message EventExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 event_type = 2; | |||||
| } | |||||
| message LogTimeStampDef { | |||||
| uint64 logid = 1; | |||||
| bool notify = 2; | |||||
| uint32 flat = 3; | |||||
| } | |||||
| message MemcpyAsyncDef { | |||||
| uint64 dst = 1; | |||||
| uint64 dst_max = 2; | |||||
| uint64 src = 3; | |||||
| uint64 count = 4; | |||||
| uint32 kind = 5; | |||||
| uint32 op_index = 6; | |||||
| } | |||||
| message StreamSwitchDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 true_stream_id = 2; | |||||
| int64 value = 3; | |||||
| uint64 value_ptr = 4; | |||||
| uint32 data_type = 5; | |||||
| } | |||||
| message StreamActiveDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 active_stream_id = 2; | |||||
| } | |||||
| message StreamSwitchNDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 size = 2; | |||||
| repeated int64 target_value = 3; | |||||
| repeated uint32 true_stream_id = 4; | |||||
| uint32 element_size = 5; | |||||
| uint32 data_type = 6; | |||||
| } | |||||
| message LabelSetDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelGotoExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelSwitchByIndexDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_max = 2; | |||||
| } | |||||
| @@ -1,179 +0,0 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| message ModelTaskDef { | |||||
| string version = 1; | |||||
| map<string, string> attr = 9; // Extended field | |||||
| repeated TaskDef task = 10; | |||||
| uint64 memory_size = 11; | |||||
| uint32 stream_num = 12; | |||||
| uint32 event_num = 13; | |||||
| uint64 weight_size = 14; | |||||
| repeated bytes op = 15; // input/output opdef in bytes | |||||
| uint64 base_addr = 16; // base addr | |||||
| uint64 weight_addr = 17; // weight addr | |||||
| uint32 batch_num = 18; | |||||
| } | |||||
| message TaskDef { | |||||
| uint32 id = 1; | |||||
| uint32 type = 2; | |||||
| uint32 stream_id = 10; | |||||
| uint32 event_id = 11; | |||||
| KernelDef kernel = 20; | |||||
| KernelExDef kernel_ex = 21; | |||||
| KernelHcclDef kernel_hccl = 25; | |||||
| EventExDef event_ex = 26; | |||||
| LogTimeStampDef log_timestamp = 28; | |||||
| uint32 label_id = 30; | |||||
| MemcpyAsyncDef memcpy_async = 31; | |||||
| StreamSwitchDef stream_switch = 32; | |||||
| StreamActiveDef stream_active = 33; | |||||
| bytes private_def = 34; | |||||
| uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future | |||||
| StreamSwitchNDef stream_switch_n = 36; | |||||
| LabelSetDef label_set = 37; | |||||
| LabelGotoExDef label_goto_ex = 38; | |||||
| LabelSwitchByIndexDef label_switch_by_index = 39; | |||||
| KernelDefWithHandle kernel_with_handle = 40; | |||||
| } | |||||
| message KernelDef { | |||||
| KernelContext context = 1; | |||||
| string stub_func = 10; | |||||
| uint32 block_dim = 11; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes sm_desc = 14; | |||||
| bytes flowtable = 15; | |||||
| string so_name = 16; | |||||
| string kernel_name = 17; | |||||
| bytes kernel_ext_info = 18; | |||||
| uint32 kernel_ext_info_size = 19; | |||||
| } | |||||
| message KernelDefWithHandle { | |||||
| KernelContext context = 1; | |||||
| uint64 handle = 10; | |||||
| string dev_func = 11; | |||||
| uint32 block_dim = 12; | |||||
| uint32 args_size = 13; | |||||
| bytes args = 14; | |||||
| bytes sm_desc = 15; | |||||
| string original_kernel_key = 16; | |||||
| string node_info = 17; | |||||
| } | |||||
| message KernelContext { | |||||
| uint32 kernel_type = 1; | |||||
| uint32 op_id = 2; // OP type in CCE | |||||
| uint32 kernel_func_id = 3; | |||||
| uint32 op_index = 4; // TE/Custom operator | |||||
| bool is_flowtable = 5; // Identify whether args is a flowtable structure | |||||
| bytes args_offset = 6; // args offset information | |||||
| uint32 args_count = 7; // args count | |||||
| repeated uint32 origin_op_index = 8; | |||||
| } | |||||
| message KernelExDef { | |||||
| uint32 flags = 1; | |||||
| uint32 op_index = 4; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput | |||||
| uint32 task_info_size = 15; | |||||
| bytes kernel_ext_info = 16; | |||||
| uint32 kernel_ext_info_size = 17; | |||||
| } | |||||
| message KernelHcclDef { | |||||
| uint32 op_index = 8; | |||||
| string hccl_type = 9; | |||||
| } | |||||
| message EventExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 event_type = 2; | |||||
| } | |||||
| message LogTimeStampDef { | |||||
| uint64 logid = 1; | |||||
| bool notify = 2; | |||||
| uint32 flat = 3; | |||||
| } | |||||
| message MemcpyAsyncDef { | |||||
| uint64 dst = 1; | |||||
| uint64 dst_max = 2; | |||||
| uint64 src = 3; | |||||
| uint64 count = 4; | |||||
| uint32 kind = 5; | |||||
| uint32 op_index = 6; | |||||
| } | |||||
| message StreamSwitchDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 true_stream_id = 2; | |||||
| int64 value = 3; | |||||
| uint64 value_ptr = 4; | |||||
| uint32 data_type = 5; | |||||
| } | |||||
| message StreamActiveDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 active_stream_id = 2; | |||||
| } | |||||
| message StreamSwitchNDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 size = 2; | |||||
| repeated int64 target_value = 3; | |||||
| repeated uint32 true_stream_id = 4; | |||||
| uint32 element_size = 5; | |||||
| uint32 data_type = 6; | |||||
| } | |||||
| message LabelSetDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelGotoExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelSwitchByIndexDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_max = 2; | |||||
| } | |||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_opt_info/ge_opt_info.h" | |||||
| #include <string> | |||||
| #include <map> | |||||
| #include "graph/ge_local_context.h" | |||||
| #include "ge/ge_api_types.h" | |||||
| #include "common/debug/ge_log.h" | |||||
| #include "opt_info.h" | |||||
| namespace ge { | |||||
| Status GeOptInfo::SetOptInfo() { | |||||
| std::string soc_ver; | |||||
| graphStatus ret = GetThreadLocalContext().GetOption(SOC_VERSION, soc_ver); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Get soc version failed."); | |||||
| GELOGE(FAILED, "[Get][SocVersion]Get soc version failed."); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("Soc version:%s.", soc_ver.c_str()); | |||||
| std::map<std::string, std::string> opt_info; | |||||
| // the first arg does not work at present. | |||||
| if (gelc::GetOptInfo(gelc::kOffline, soc_ver, opt_info) != gelc::SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Get optional information failed, is_offline:%d, soc version:%s", | |||||
| gelc::kOffline, soc_ver.c_str()); | |||||
| GELOGE(FAILED, "[Get][OptInfo]Get optional information failed, is_offline:%d, soc version:%s", | |||||
| gelc::kOffline, soc_ver.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // do nothing if get empty information | |||||
| if (opt_info.empty()) { | |||||
| GELOGI("Optional information is empty."); | |||||
| return SUCCESS; | |||||
| } | |||||
| std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions(); | |||||
| for (const auto &itr : opt_info) { | |||||
| graph_options.emplace(itr.first, itr.second); | |||||
| GELOGI("Get optional information success, key:%s, value:%s.", itr.first.c_str(), itr.second.c_str()); | |||||
| } | |||||
| GetThreadLocalContext().SetGraphOption(graph_options); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_OPT_INFO_GE_OPT_INFO_H_ | |||||
| #define GE_OPT_INFO_GE_OPT_INFO_H_ | |||||
| #include "ge/ge_api_error_codes.h" | |||||
| #include "register/register_types.h" | |||||
| namespace ge { | |||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeOptInfo { | |||||
| public: | |||||
| GeOptInfo() = default; | |||||
| static Status SetOptInfo(); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_OPT_INFO_GE_OPT_INFO_H_ | |||||
| @@ -674,6 +674,12 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
| GELOGD("Current ctx is null."); | GELOGD("Current ctx is null."); | ||||
| ctx = nullptr; | ctx = nullptr; | ||||
| } | } | ||||
| std::function<void()> callback = [&]() { | |||||
| if (ctx != nullptr) { | |||||
| (void)rtCtxSetCurrent(ctx); | |||||
| } | |||||
| }; | |||||
| GE_MAKE_GUARD(restore, callback); | |||||
| GeRootModelPtr ge_root_model = nullptr; | GeRootModelPtr ge_root_model = nullptr; | ||||
| GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); | GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); | ||||
| @@ -712,11 +718,6 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (ctx != nullptr) { | |||||
| (void)rtCtxSetCurrent(ctx); | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -86,6 +86,11 @@ bool LabelAllocator::CollectFunctionalNode(ComputeGraphPtr &graph, std::set<Node | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (func_node->GetOpDesc() != nullptr && func_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) { | |||||
| GELOGD("Graph[%s] is ffts subgraph, skip label allocator.", graph->GetName().c_str()); | |||||
| return true; | |||||
| } | |||||
| ComputeGraphPtr owner_graph = func_node->GetOwnerComputeGraph(); | ComputeGraphPtr owner_graph = func_node->GetOwnerComputeGraph(); | ||||
| if (owner_graph == nullptr) { | if (owner_graph == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "ComputeGraph owner not set in node:%s(%s), graph:%s", | REPORT_INNER_ERROR("E19999", "ComputeGraph owner not set in node:%s(%s), graph:%s", | ||||
| @@ -474,6 +474,11 @@ Status UpdateForSkippedEnginePass::Run(ComputeGraphPtr graph, const vector<Subgr | |||||
| for (ge::NodePtr &node : graph->GetDirectNode()) { | for (ge::NodePtr &node : graph->GetDirectNode()) { | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| if (op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID)) { | |||||
| op_desc->SetStreamId(kInvalidStream); | |||||
| GELOGI("Ffts node %s of type %s reassign to invalid stream.", node->GetName().c_str(), node->GetType().c_str()); | |||||
| continue; | |||||
| } | |||||
| int64_t stream_id = op_desc->GetStreamId(); | int64_t stream_id = op_desc->GetStreamId(); | ||||
| if (ops_without_label.find(op_desc) != ops_without_label.end()) { | if (ops_without_label.find(op_desc) != ops_without_label.end()) { | ||||
| if (AreAllPredStreamsInvalid(node) && op_desc->GetSubgraphInstanceNames().empty()) { | if (AreAllPredStreamsInvalid(node) && op_desc->GetSubgraphInstanceNames().empty()) { | ||||
| @@ -432,7 +432,11 @@ Status StreamAllocator::SetActiveStreamsForSubgraphs() { | |||||
| // Insert the send/recv event id to the graph | // Insert the send/recv event id to the graph | ||||
| Status StreamAllocator::InsertSyncEvents() { | Status StreamAllocator::InsertSyncEvents() { | ||||
| for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { | |||||
| auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) { | |||||
| return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH); | |||||
| }; | |||||
| for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag(), nullptr, ffts_filter)) { | |||||
| // Take the adjacent points, then judge whether need to insert the event | // Take the adjacent points, then judge whether need to insert the event | ||||
| for (const OutDataAnchorPtr &anchor : cur_node->GetAllOutDataAnchors()) { | for (const OutDataAnchorPtr &anchor : cur_node->GetAllOutDataAnchors()) { | ||||
| for (const InDataAnchorPtr &peer_in_anchor : anchor->GetPeerInDataAnchors()) { | for (const InDataAnchorPtr &peer_in_anchor : anchor->GetPeerInDataAnchors()) { | ||||
| @@ -531,6 +535,11 @@ Status StreamAllocator::InsertOneEventInTwoNodes(const NodePtr &cur_node, const | |||||
| Status StreamAllocator::InsertEventsForSubgraph() { | Status StreamAllocator::InsertEventsForSubgraph() { | ||||
| for (const auto &subgraph : whole_graph_->GetAllSubgraphs()) { | for (const auto &subgraph : whole_graph_->GetAllSubgraphs()) { | ||||
| GE_CHECK_NOTNULL(subgraph); | GE_CHECK_NOTNULL(subgraph); | ||||
| const auto parent_node = subgraph->GetParentNode(); | |||||
| if (parent_node != nullptr && parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) { | |||||
| GELOGD("Skip ffts subgraph, parent node is %s.", parent_node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| for (const auto &node : subgraph->GetDirectNode()) { | for (const auto &node : subgraph->GetDirectNode()) { | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| @@ -354,7 +354,10 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra | |||||
| }; | }; | ||||
| GE_MAKE_GUARD(release, callback); | GE_MAKE_GUARD(release, callback); | ||||
| for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { | |||||
| auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) { | |||||
| return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH); | |||||
| }; | |||||
| for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag(), nullptr, ffts_filter)) { | |||||
| OpDescPtr op_desc = node->GetOpDesc(); | OpDescPtr op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| node_index++; | node_index++; | ||||
| @@ -380,10 +383,8 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra | |||||
| GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str()); | GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str()); | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (op_kernel_lib_name.empty()) { | |||||
| GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); | |||||
| continue; | |||||
| } | |||||
| GE_CHK_BOOL_EXEC_INFO(!op_kernel_lib_name.empty(), continue, | |||||
| "Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); | |||||
| auto kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); | auto kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); | ||||
| if (kernel_info_store == nullptr) { | if (kernel_info_store == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Get ops kernel info store failed for op:%s(%s), op_kernel_name:%s", | REPORT_INNER_ERROR("E19999", "Get ops kernel info store failed for op:%s(%s), op_kernel_name:%s", | ||||
| @@ -394,6 +395,10 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra | |||||
| } | } | ||||
| GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "[Call][UpdateAnchorStatus] node:%s(%s) failed", name.c_str(), | GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "[Call][UpdateAnchorStatus] node:%s(%s) failed", name.c_str(), | ||||
| type.c_str()); | type.c_str()); | ||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) { | |||||
| GE_CHK_STATUS_RET(UpdateAnchorStatusForFfts(node), "[Call][UpdateAnchorStatusForFfts] node:%s(%s) failed", | |||||
| name.c_str(), type.c_str()); | |||||
| } | |||||
| // Profiling task | // Profiling task | ||||
| size_t task_list_size_before = task_def_list.size(); | size_t task_list_size_before = task_def_list.size(); | ||||
| GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); | GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); | ||||
| @@ -571,7 +576,24 @@ Status TaskGenerator::GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| Status TaskGenerator::UpdateAnchorStatusForFfts(const NodePtr &node) { | |||||
| GELOGD("Start UpdateAnchorStatusForFfts for %s.", node->GetName().c_str()); | |||||
| if (!node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { | |||||
| for (size_t i = 0; i < node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { | |||||
| auto sub_graph = NodeUtils::GetSubgraph(*node, i); | |||||
| GE_CHECK_NOTNULL(sub_graph); | |||||
| GELOGD("Start update anchor status for %s.", sub_graph->GetName().c_str()); | |||||
| for (auto &ffts_node : sub_graph->GetDirectNode()) { | |||||
| GE_CHK_STATUS_RET(UpdateAnchorStatus(ffts_node), "[Call][UpdateAnchorStatus] node:%s(%s) failed", | |||||
| ffts_node->GetName().c_str(), ffts_node->GetType().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TaskGenerator::UpdateAnchorStatus(const NodePtr &node) { | Status TaskGenerator::UpdateAnchorStatus(const NodePtr &node) { | ||||
| GELOGD("Start UpdateAnchorStatus for %s.", node->GetName().c_str()); | |||||
| if (NodeUtils::SetAllAnchorStatus(node) != GRAPH_SUCCESS) { | if (NodeUtils::SetAllAnchorStatus(node) != GRAPH_SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "SetAllAnchorStatus fail for op:%s(%s)", | REPORT_CALL_ERROR("E19999", "SetAllAnchorStatus fail for op:%s(%s)", | ||||
| node->GetName().c_str(), node->GetType().c_str()); | node->GetName().c_str(), node->GetType().c_str()); | ||||
| @@ -80,6 +80,7 @@ class TaskGenerator { | |||||
| Status FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, | Status FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, | ||||
| std::vector<uint32_t> &all_reduce_nodes); | std::vector<uint32_t> &all_reduce_nodes); | ||||
| private: | private: | ||||
| Status UpdateAnchorStatusForFfts(const NodePtr &node); | |||||
| Status UpdateAnchorStatus(const NodePtr &node); | Status UpdateAnchorStatus(const NodePtr &node); | ||||
| Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id); | Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id); | ||||
| @@ -274,21 +274,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| /// | |||||
| /// @brief Set Op _force_unknown_shape flag | |||||
| /// @param [in] node | |||||
| /// @param [in] force_unknown, set attribute if true | |||||
| /// @param [in] group_index, condition group index of node. | |||||
| /// @return | |||||
| /// | |||||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) { | |||||
| if (!force_unknown) { | |||||
| return; | |||||
| } | |||||
| SetControlFlowGroup(node, group_index); | |||||
| } | |||||
| /// | /// | ||||
| /// @brief Set Op _control_flow_group flag | /// @brief Set Op _control_flow_group flag | ||||
| /// @param [in] node | /// @param [in] node | ||||
| @@ -125,15 +125,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size); | |||||
| /// | /// | ||||
| bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | ||||
| /// | |||||
| /// @brief Set Op _force_unknown_shape flag | |||||
| /// @param [in] node | |||||
| /// @param [in] force_unknown, set attribute if true | |||||
| /// @param [in] group_index, condition group index of node. | |||||
| /// @return | |||||
| /// | |||||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); | |||||
| /// | /// | ||||
| /// @brief Set Op _control_flow_group flag | /// @brief Set Op _control_flow_group flag | ||||
| /// @param [in] node | /// @param [in] node | ||||
| @@ -99,6 +99,9 @@ const uint32_t kEndOfSequenceNew = 507005; | |||||
| const int32_t kModelAbortNormal = 0x0704000e; | const int32_t kModelAbortNormal = 0x0704000e; | ||||
| const int32_t kModelAbortNormalNew = 507024; | const int32_t kModelAbortNormalNew = 507024; | ||||
| const uint32_t kInteval = 2; | const uint32_t kInteval = 2; | ||||
| const uint32_t kFftsTbeHandleElementSize = 2; | |||||
| const uint32_t kNonTailBlock = 0; | |||||
| const uint32_t kTailBlock = 1; | |||||
| const char *const kModelName = "model_name"; | const char *const kModelName = "model_name"; | ||||
| const char *const kModeleId = "model_id"; | const char *const kModeleId = "model_id"; | ||||
| const char *const kLoadStartTime = "load_start_time"; | const char *const kLoadStartTime = "load_start_time"; | ||||
| @@ -116,14 +119,15 @@ const char *const kWorkSpaceSize = "workspace_size"; | |||||
| const char *const kTotalSize = "total_size"; | const char *const kTotalSize = "total_size"; | ||||
| const char *const kTaskCount = "task_count"; | const char *const kTaskCount = "task_count"; | ||||
| const char *const kTaskId = "task_id"; | const char *const kTaskId = "task_id"; | ||||
| const char* const kRequestId = "request_id"; | |||||
| const char* const kThreadId = "thread_id"; | |||||
| const char* const kInputBeginTime = "input_begin_time"; | |||||
| const char* const kInputEndTime = "input_end_time"; | |||||
| const char* const kInferBeginTime = "infer_begin_time"; | |||||
| const char* const kInferEndTime = "infer_end_time"; | |||||
| const char* const kOutputBeginTime = "output_start_time"; | |||||
| const char* const kOutputEndTime = "output_end_time"; | |||||
| const char *const kRequestId = "request_id"; | |||||
| const char *const kThreadId = "thread_id"; | |||||
| const char *const kInputBeginTime = "input_begin_time"; | |||||
| const char *const kInputEndTime = "input_end_time"; | |||||
| const char *const kInferBeginTime = "infer_begin_time"; | |||||
| const char *const kInferEndTime = "infer_end_time"; | |||||
| const char *const kOutputBeginTime = "output_start_time"; | |||||
| const char *const kOutputEndTime = "output_end_time"; | |||||
| const char *const kStubFuncName = "_register_stub_func"; | |||||
| const uint32_t kStringHeadElems = 2; | const uint32_t kStringHeadElems = 2; | ||||
| const uint32_t kPlacementHostData = 0; | const uint32_t kPlacementHostData = 0; | ||||
| const size_t kAlignment = 64; | const size_t kAlignment = 64; | ||||
| @@ -902,10 +906,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { | |||||
| SetLabelForDynamic(node); | SetLabelForDynamic(node); | ||||
| auto it = op_desc_handle.find(op_desc->GetType()); | auto it = op_desc_handle.find(op_desc->GetType()); | ||||
| if (it != op_desc_handle.end()) { | if (it != op_desc_handle.end()) { | ||||
| if ((this->*it->second)(op_desc) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "[Init][Node] failed, Name:%s", op_desc->GetName().c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((this->*it->second)(op_desc) != SUCCESS, return PARAM_INVALID, | |||||
| "[Init][Node] failed, Name:%s", op_desc->GetName().c_str()); | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -935,7 +937,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { | |||||
| GE_TIMESTAMP_RESTART(InitTbeHandle); | GE_TIMESTAMP_RESTART(InitTbeHandle); | ||||
| if (IsTbeTask(op_desc)) { | if (IsTbeTask(op_desc)) { | ||||
| Status status = InitTbeHandle(op_desc); | |||||
| Status status = | |||||
| op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID) ? InitTbeHandleWithFfts(op_desc) : InitTbeHandle(op_desc); | |||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| GELOGE(status, "[Init][TbeHandle] failed. op:%s", op_desc->GetName().c_str()); | GELOGE(status, "[Init][TbeHandle] failed. op:%s", op_desc->GetName().c_str()); | ||||
| return status; | return status; | ||||
| @@ -3463,11 +3466,11 @@ bool DavinciModel::CheckUserAndModelSize(const int64_t &size, const int64_t &op_ | |||||
| } | } | ||||
| // The input and model input size can not be exactly equal because user input is not definite. | // The input and model input size can not be exactly equal because user input is not definite. | ||||
| if ((size + kDataMemAlignSizeCompare) < op_size) { | if ((size + kDataMemAlignSizeCompare) < op_size) { | ||||
| REPORT_INNER_ERROR("E19999", "%s size:%ld from user add align:%u < input_op_size:%ld in model, model_id:%u, " | |||||
| REPORT_INNER_ERROR("E19999", "%s size:%ld from user add align:%u < op_size:%ld in model, model_id:%u, " | |||||
| "check invalid", | "check invalid", | ||||
| input_or_output.c_str(), size, kDataMemAlignSizeCompare, op_size, model_id_); | input_or_output.c_str(), size, kDataMemAlignSizeCompare, op_size, model_id_); | ||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, | GELOGE(ACL_ERROR_GE_PARAM_INVALID, | ||||
| "[Check][Param] %s size:%ld from user add align:%u < input_op_size:%ld in model, model_id:%u", | |||||
| "[Check][Param] %s size:%ld from user add align:%u < op_size:%ld in model, model_id:%u", | |||||
| input_or_output.c_str(), size, kDataMemAlignSizeCompare, op_size, model_id_); | input_or_output.c_str(), size, kDataMemAlignSizeCompare, op_size, model_id_); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -3700,6 +3703,7 @@ Status DavinciModel::InitConstant(const OpDescPtr &op_desc) { | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) { | Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) { | ||||
| string bin_file = op_desc->GetName(); | |||||
| auto kernel = ge_model_->GetTBEKernelStore().FindKernel(op_desc->GetName()); | auto kernel = ge_model_->GetTBEKernelStore().FindKernel(op_desc->GetName()); | ||||
| auto tbe_kernel = (kernel != nullptr) ? kernel : op_desc->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); | auto tbe_kernel = (kernel != nullptr) ? kernel : op_desc->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); | ||||
| if (tbe_kernel == nullptr) { | if (tbe_kernel == nullptr) { | ||||
| @@ -3708,12 +3712,61 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) { | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file!", op_desc->GetName().c_str()); | GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file!", op_desc->GetName().c_str()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file, tbe_kernel, false), "Function register of bin file: %s failed", | |||||
| bin_file.c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| std::string session_graph_model_id; | |||||
| GetUniqueId(op_desc, session_graph_model_id); | |||||
| const char *bin_file_key = GetRegisterStub(op_desc->GetName(), session_graph_model_id); // from set, always valid. | |||||
| TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); | |||||
| Status DavinciModel::InitTbeHandleWithFfts(const OpDescPtr &op_desc) { | |||||
| std::vector<OpKernelBinPtr> tbe_kernel; | |||||
| tbe_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_NAME_THREAD_TBE_KERNEL, tbe_kernel); | |||||
| GELOGD("Kernel bin ptr vec size is %zu.", tbe_kernel.size()); | |||||
| if (tbe_kernel.size() != kFftsTbeHandleElementSize) { | |||||
| REPORT_INNER_ERROR("E19999", "Get tbe_kernel for op:%s(%s) fail, model_id:%u", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file, size is %zu when ffts", | |||||
| op_desc->GetName().c_str(), tbe_kernel.size()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (tbe_kernel[0] == nullptr || tbe_kernel[1] == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "Tbe kernel for op:%s is nullptr.", op_desc->GetName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: tvm bin file of %s is nullptr when ffts.", op_desc->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| vector<string> bin_file_keys; | |||||
| (void)AttrUtils::GetListStr(op_desc, kStubFuncName, bin_file_keys); | |||||
| if (bin_file_keys.size() != kFftsTbeHandleElementSize) { | |||||
| REPORT_INNER_ERROR("E19999", "Get bin_file for op:%s(%s) fail.", op_desc->GetName().c_str(), | |||||
| op_desc->GetType().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find bin file keys, size is %zu when ffts", | |||||
| op_desc->GetName().c_str(), bin_file_keys.size()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kNonTailBlock], tbe_kernel[kNonTailBlock], true, | |||||
| kNonTailBlock), | |||||
| "Function register of first bin file %s failed.", bin_file_keys[kNonTailBlock].c_str()); | |||||
| GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kTailBlock], tbe_kernel[kTailBlock], true, kTailBlock), | |||||
| "Function register of second bin file %s failed.", bin_file_keys[kTailBlock].c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DavinciModel::FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel, | |||||
| bool is_ffts, size_t thread_index) { | |||||
| if (thread_index > 1) { | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Thread index: %zu should less than 1.", thread_index); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| const char *bin_file_key; | |||||
| if (is_ffts) { | |||||
| bin_file_key = GetRegisterStub(bin_file, ""); | |||||
| GELOGI("Node:%s inherit func name:%s directly.", op_desc->GetName().c_str(), bin_file_key); | |||||
| } else { | |||||
| std::string session_graph_model_id; | |||||
| GetUniqueId(op_desc, session_graph_model_id); | |||||
| bin_file_key = GetRegisterStub(bin_file, session_graph_model_id); // from set, always valid. | |||||
| } | |||||
| TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); | |||||
| std::lock_guard<std::mutex> lock(tvm_bin_mutex_); | std::lock_guard<std::mutex> lock(tvm_bin_mutex_); | ||||
| if (rtQueryFunctionRegistered(bin_file_key) != RT_ERROR_NONE) { | if (rtQueryFunctionRegistered(bin_file_key) != RT_ERROR_NONE) { | ||||
| void *bin_handle = nullptr; | void *bin_handle = nullptr; | ||||
| @@ -3721,59 +3774,115 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) { | |||||
| GELOGD("TBE: can't find the kernel_name[%s] in HandleMap", bin_file_key); | GELOGD("TBE: can't find the kernel_name[%s] in HandleMap", bin_file_key); | ||||
| rtDevBinary_t binary; | rtDevBinary_t binary; | ||||
| std::string json_string; | |||||
| GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_MAGIC, json_string), | |||||
| GELOGD("Get original type of session_graph_id.")); | |||||
| if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICPU") { | |||||
| binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICPU; | |||||
| } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF") { | |||||
| binary.magic = RT_DEV_BINARY_MAGIC_ELF; | |||||
| } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AIVEC") { | |||||
| binary.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC; | |||||
| } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICUBE") { | |||||
| binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICUBE; | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid", | |||||
| TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(), | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid", | |||||
| TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(), | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| GE_CHK_STATUS_RET(InitBinaryMagic(op_desc, is_ffts, thread_index, binary), "Init binary magic of %s failed.", | |||||
| op_desc->GetName().c_str()); | |||||
| binary.version = 0; | binary.version = 0; | ||||
| binary.data = tbe_kernel->GetBinData(); | binary.data = tbe_kernel->GetBinData(); | ||||
| binary.length = tbe_kernel->GetBinDataSize(); | binary.length = tbe_kernel->GetBinDataSize(); | ||||
| GELOGD("TBE: binary.length: %lu", binary.length); | GELOGD("TBE: binary.length: %lu", binary.length); | ||||
| GE_CHK_RT_RET(rtDevBinaryRegister(&binary, &bin_handle)); | GE_CHK_RT_RET(rtDevBinaryRegister(&binary, &bin_handle)); | ||||
| std::string meta_data; | |||||
| GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_METADATA, meta_data), | |||||
| GELOGI("Get original type of json_string")); | |||||
| GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str()); | |||||
| GE_IF_BOOL_EXEC(!meta_data.empty(), GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str()))); | |||||
| GE_CHK_STATUS_RET(InitMetaData(op_desc, is_ffts, thread_index, bin_handle), "Init tvm meta data of %s failed.", | |||||
| op_desc->GetName().c_str()); | |||||
| kernel_store.StoreTBEHandle(bin_file_key, bin_handle, tbe_kernel); | kernel_store.StoreTBEHandle(bin_file_key, bin_handle, tbe_kernel); | ||||
| } else { | } else { | ||||
| GELOGI("TBE: find the kernel_name[%s] in HandleMap", bin_file_key); | GELOGI("TBE: find the kernel_name[%s] in HandleMap", bin_file_key); | ||||
| kernel_store.ReferTBEHandle(bin_file_key); | kernel_store.ReferTBEHandle(bin_file_key); | ||||
| } | } | ||||
| std::string kernel_name; | std::string kernel_name; | ||||
| GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name), | |||||
| GELOGD("Get original type of kernel_name")); | |||||
| GE_CHK_STATUS_RET(InitKernelName(op_desc, is_ffts, thread_index, kernel_name), "Init kernel name of %s failed.", | |||||
| op_desc->GetName().c_str()); | |||||
| GE_CHK_RT_RET(rtFunctionRegister(bin_handle, bin_file_key, bin_file_key, kernel_name.c_str(), 0)); | GE_CHK_RT_RET(rtFunctionRegister(bin_handle, bin_file_key, bin_file_key, kernel_name.c_str(), 0)); | ||||
| used_tbe_handle_map_[bin_file_key] = 1; // Init used num to 1. | used_tbe_handle_map_[bin_file_key] = 1; // Init used num to 1. | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| // Kernel registed, Increase used num in store. | // Kernel registed, Increase used num in store. | ||||
| StoreTbeHandle(bin_file_key); | StoreTbeHandle(bin_file_key); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status DavinciModel::InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, | |||||
| rtDevBinary_t &binary) { | |||||
| string json_string; | |||||
| const string &tvm_magic = is_ffts ? TVM_ATTR_NAME_THREAD_MAGIC : TVM_ATTR_NAME_MAGIC; | |||||
| const static std::map<std::string, uint32_t> binary_magics = { | |||||
| {"RT_DEV_BINARY_MAGIC_ELF_AICPU", RT_DEV_BINARY_MAGIC_ELF_AICPU}, | |||||
| {"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF}, | |||||
| {"RT_DEV_BINARY_MAGIC_ELF_AIVEC", RT_DEV_BINARY_MAGIC_ELF_AIVEC}, | |||||
| {"RT_DEV_BINARY_MAGIC_ELF_AICUBE", RT_DEV_BINARY_MAGIC_ELF_AICUBE} | |||||
| }; | |||||
| if (is_ffts) { | |||||
| vector<string> json_list; | |||||
| (void)AttrUtils::GetListStr(op_desc, tvm_magic, json_list); | |||||
| if (json_list.size() != kFftsTbeHandleElementSize) { | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Attr is %s, thread index is %zu, json list size is %zu.", | |||||
| tvm_magic.c_str(), thread_index, json_list.size()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| json_string = json_list[thread_index]; | |||||
| } else { | |||||
| (void)AttrUtils::GetStr(op_desc, tvm_magic, json_string); | |||||
| } | |||||
| auto iter = binary_magics.find(json_string); | |||||
| if (iter == binary_magics.end()) { | |||||
| REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid", | |||||
| tvm_magic.c_str(), json_string.c_str(), op_desc->GetName().c_str(), | |||||
| op_desc->GetType().c_str(), model_id_); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid", | |||||
| TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(), | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| binary.magic = iter->second; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DavinciModel::InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle) { | |||||
| string meta_data; | |||||
| const string &tvm_metadata = is_ffts ? TVM_ATTR_NAME_THREAD_METADATA : TVM_ATTR_NAME_METADATA; | |||||
| if (is_ffts) { | |||||
| vector<string> meta_data_list; | |||||
| (void)AttrUtils::GetListStr(op_desc, tvm_metadata, meta_data_list); | |||||
| if (meta_data_list.size() != kFftsTbeHandleElementSize) { | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, meta data list size is %zu.", | |||||
| tvm_metadata.c_str(), thread_index, meta_data_list.size()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| meta_data = meta_data_list[thread_index]; | |||||
| } else { | |||||
| (void)AttrUtils::GetStr(op_desc, tvm_metadata, meta_data); | |||||
| } | |||||
| GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str()); | |||||
| if (!meta_data.empty()) { | |||||
| GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str())); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DavinciModel::InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name) { | |||||
| if (is_ffts) { | |||||
| // delete prefix, eg: *sgt_graph_nodes*/loss_scale/gradient/fp32_vals/Mean_grad/Tile | |||||
| vector<string> kernel_name_list; | |||||
| auto pos = op_desc->GetName().find("/"); | |||||
| if (pos == std::string::npos) { | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] failed, subgraph node name: %s.", op_desc->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| string attr_kernel_name = op_desc->GetName().substr(pos + 1) + "_thread_kernelname"; | |||||
| (void)AttrUtils::GetListStr(op_desc, attr_kernel_name, kernel_name_list); | |||||
| if (kernel_name_list.size() != kFftsTbeHandleElementSize) { | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, kernel name list size is %zu.", | |||||
| attr_kernel_name.c_str(), thread_index, kernel_name_list.size()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| kernel_name = kernel_name_list[thread_index]; | |||||
| } else { | |||||
| string attr_kernel_name = op_desc->GetName() + "_kernelname"; | |||||
| (void)AttrUtils::GetStr(op_desc, attr_kernel_name, kernel_name); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void DavinciModel::StoreTbeHandle(const std::string &handle_key) { | void DavinciModel::StoreTbeHandle(const std::string &handle_key) { | ||||
| // Online mode FE may call rtFunctionRegister. | // Online mode FE may call rtFunctionRegister. | ||||
| TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); | TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); | ||||
| @@ -771,6 +771,12 @@ class DavinciModel { | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status InitTbeHandle(const OpDescPtr &op_desc); | Status InitTbeHandle(const OpDescPtr &op_desc); | ||||
| Status InitTbeHandleWithFfts(const OpDescPtr &op_desc); | |||||
| Status FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel, bool is_ffts, | |||||
| size_t thread_index = 0); | |||||
| Status InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, rtDevBinary_t &binary); | |||||
| Status InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle); | |||||
| Status InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name); | |||||
| void StoreTbeHandle(const string &handle_key); | void StoreTbeHandle(const string &handle_key); | ||||
| void CleanTbeHandle(); | void CleanTbeHandle(); | ||||
| @@ -0,0 +1,393 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/load/model_manager/task_info/ffts_task_info.h" | |||||
| #include <vector> | |||||
| #include "graph/load/model_manager/davinci_model.h" | |||||
| namespace { | |||||
| constexpr uint32_t kAddrLen = sizeof(void *); | |||||
| } | |||||
| namespace ge { | |||||
| FftsTaskInfo::~FftsTaskInfo() { | |||||
| GE_FREE_RT_LOG(args_); | |||||
| } | |||||
| Status FftsTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | |||||
| GELOGI("FftsTaskInfo Init Start."); | |||||
| GE_CHECK_NOTNULL(davinci_model); | |||||
| davinci_model_ = davinci_model; | |||||
| GE_CHK_STATUS_RET_NOLOG(SetStream(task_def.stream_id(), davinci_model_->GetStreamList())); | |||||
| const domi::FftsTaskDef &ffts_task_def = task_def.ffts_task(); | |||||
| OpDescPtr op_desc = davinci_model_->GetOpByIndex(ffts_task_def.op_index()); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if ((ffts_task_def.sub_task_size() > static_cast<int>(RT_FFTS_MAX_SUB_TASK_NUM)) || | |||||
| (ffts_task_def.ticket_cache_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_NUM))) { | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Node: %s, sub task desc size: %d, ticket cache size: %d", | |||||
| op_desc->GetName().c_str(), ffts_task_def.sub_task_size(), ffts_task_def.ticket_cache_size()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| args_size_ = kAddrLen * ffts_task_def.addr_size(); | |||||
| GE_CHK_RT_RET(rtMalloc(&args_, args_size_, RT_MEMORY_HBM)); | |||||
| InitFftsDescInfo(ffts_task_def.ffts_desc(), sub_task_info_.fftsDesc); | |||||
| sub_task_info_.fftsType = static_cast<rtFftsType_t>(ffts_task_def.ffts_type()); | |||||
| sub_task_info_.subTaskNum = ffts_task_def.sub_task_size(); | |||||
| for (int idx = 0; idx < ffts_task_def.sub_task_size(); ++idx) { | |||||
| GE_CHK_STATUS_RET_NOLOG(InitSubTaskInfo(ffts_task_def.sub_task(idx), sub_task_info_.subTask[idx])); | |||||
| } | |||||
| sub_task_info_.tickCacheNum = ffts_task_def.ticket_cache_size(); | |||||
| for (int idx = 0; idx < ffts_task_def.ticket_cache_size(); ++idx) { | |||||
| GE_CHK_STATUS_RET_NOLOG(InitTicketCache(ffts_task_def.ticket_cache(idx), sub_task_info_.ticketCache[idx])); | |||||
| } | |||||
| size_t data_size = kAddrLen * io_addrs_.size(); | |||||
| GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs_.data(), data_size, RT_MEMCPY_HOST_TO_DEVICE)); | |||||
| GELOGI("FftsTaskInfo::Init Success. Node: %s, input/output size: %zu", op_desc->GetName().c_str(), io_addrs_.size()); | |||||
| return SUCCESS; | |||||
| } | |||||
| void FftsTaskInfo::InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc) { | |||||
| ffts_desc.tm = static_cast<uint8_t>(ffts_desc_def.tm()); | |||||
| ffts_desc.di = static_cast<uint8_t>(ffts_desc_def.di()); | |||||
| ffts_desc.dw = static_cast<uint8_t>(ffts_desc_def.dw()); | |||||
| ffts_desc.df = static_cast<uint8_t>(ffts_desc_def.df()); | |||||
| ffts_desc.dataSplitUnit = static_cast<uint8_t>(ffts_desc_def.data_split_unit()); | |||||
| ffts_desc.prefetchOstNum = static_cast<uint8_t>(ffts_desc_def.prefetch_ost_num()); | |||||
| ffts_desc.cacheMaintainOstNum = static_cast<uint8_t>(ffts_desc_def.cache_maintain_ost_num()); | |||||
| ffts_desc.aicPrefetchUpper = static_cast<uint8_t>(ffts_desc_def.aic_prefetch_upper()); | |||||
| ffts_desc.aicPrefetchLower = static_cast<uint8_t>(ffts_desc_def.aic_prefetch_lower()); | |||||
| ffts_desc.aivPrefetchUpper = static_cast<uint8_t>(ffts_desc_def.aiv_prefetch_upper()); | |||||
| ffts_desc.aivPrefetchLower = static_cast<uint8_t>(ffts_desc_def.aiv_prefetch_lower()); | |||||
| } | |||||
| Status FftsTaskInfo::InitSubTaskInfo(const domi::FftsSubTaskDef &sub_task_def, rtFftsSubTaskInfo_t &sub_task_desc) { | |||||
| if ((sub_task_def.dst_tick_cache_id_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) || | |||||
| (sub_task_def.src_tick_cache_id_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) { | |||||
| GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, dst tick cache id size: %d, src tick cache id size: %d", | |||||
| sub_task_def.dst_tick_cache_id_size(), sub_task_def.src_tick_cache_id_size()); | |||||
| return FAILED; | |||||
| } | |||||
| if (sub_task_def.has_auto_thread_aic_aiv() == sub_task_def.has_manual_thread_aic_aiv()) { | |||||
| GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, auto thread aic/aiv: %d, manual thread aic/aiv: %d", | |||||
| sub_task_def.has_auto_thread_aic_aiv(), sub_task_def.has_manual_thread_aic_aiv()); | |||||
| return FAILED; | |||||
| } | |||||
| thread_dim_ = sub_task_def.thread_dim(); | |||||
| GE_CHK_BOOL_RET_STATUS(thread_dim_ != 0, FAILED, "[Get][thread_dim] failed, Invalid thread dim: %u!", thread_dim_); | |||||
| sub_task_desc.subTaskType = static_cast<rtFftsSubTaskType_t>(sub_task_def.sub_task_type()); | |||||
| sub_task_desc.threadDim = sub_task_def.thread_dim(); | |||||
| sub_task_desc.dstTickCacheVldBitmap = sub_task_def.dst_tick_cache_vld_bitmap(); | |||||
| sub_task_desc.srcTickCacheVldBitmap = sub_task_def.src_tick_cache_vld_bitmap(); | |||||
| sub_task_desc.srcDataOutOfSubGraphBitmap = sub_task_def.src_data_out_of_subgraph_bitmap(); | |||||
| for (int idx = 0; idx < sub_task_def.dst_tick_cache_id_size(); ++idx) { | |||||
| sub_task_desc.dstTickCacheID[idx] = sub_task_def.dst_tick_cache_id(idx); | |||||
| } | |||||
| for (int idx = 0; idx < sub_task_def.src_tick_cache_id_size(); ++idx) { | |||||
| sub_task_desc.srcTickCacheID[idx] = sub_task_def.src_tick_cache_id(idx); | |||||
| } | |||||
| if (sub_task_def.has_auto_thread_aic_aiv()) { | |||||
| GE_CHK_STATUS_RET_NOLOG(InitAutoAicAiv(sub_task_def.auto_thread_aic_aiv(), sub_task_desc.custom.autoThreadAicAiv)); | |||||
| } | |||||
| if (sub_task_def.has_manual_thread_aic_aiv()) { | |||||
| GE_CHK_STATUS_RET_NOLOG( | |||||
| InitManualAicAiv(sub_task_def.manual_thread_aic_aiv(), sub_task_desc.custom.manualThreadAicAiv)); | |||||
| } | |||||
| if (sub_task_def.has_manual_thread_nop()) { | |||||
| GE_CHK_STATUS_RET_NOLOG(InitManualNop(sub_task_def.manual_thread_nop(), sub_task_desc.custom.manualThreadNop)); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FftsTaskInfo::InitTicketCache(const domi::TicketCacheDef &ticket_cache_def, rtTicketCache_t &ticket_cache) { | |||||
| if (ticket_cache_def.has_auto_thread_cache() == ticket_cache_def.has_manual_thread_cache()) { | |||||
| GELOGE(FAILED, "[Check][Param] Invalid TicketCacheDef, has auto thread cache: %d, has manual thread cache: %d", | |||||
| ticket_cache_def.has_auto_thread_cache(), ticket_cache_def.has_manual_thread_cache()); | |||||
| return FAILED; | |||||
| } | |||||
| ticket_cache.cacheOption = static_cast<rtCacheOp_t>(ticket_cache_def.cache_option()); | |||||
| ticket_cache.ticketCacheWindow = ticket_cache_def.ticket_cache_window(); | |||||
| if (ticket_cache_def.has_auto_thread_cache()) { | |||||
| InitAutoCacheInfo(ticket_cache_def.auto_thread_cache(), ticket_cache.custom.autoThreadCache); | |||||
| } | |||||
| if (ticket_cache_def.has_manual_thread_cache()) { | |||||
| GE_CHK_STATUS_RET_NOLOG( | |||||
| InitManualCacheInfo(ticket_cache_def.manual_thread_cache(), ticket_cache.custom.manualThreadCache)); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| // task_addr = {0,200,700,1000,2000, 3500} | |||||
| // task_addr_offset = {20,40,2,100,200} | |||||
| template <typename T> | |||||
| Status FftsTaskInfo::InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim, | |||||
| uint32_t addr_count) { | |||||
| for (uint32_t i = 0; i < addr_count; ++i) { | |||||
| uintptr_t logic_addr = aic_aiv_def.task_addr(i) + thread_dim * aic_aiv_def.task_addr_offset(i); | |||||
| uint8_t *io_addr = nullptr; | |||||
| if (ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress]GetRtAddress failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGD("aic_aiv_def task base addr is %ld, offset is %ld, thread is %d, logic addrs is 0x%lx, io addr is %p", | |||||
| aic_aiv_def.task_addr(i), aic_aiv_def.task_addr_offset(i), thread_dim, logic_addr, io_addr); | |||||
| io_addrs_.emplace_back(io_addr); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FftsTaskInfo::InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv) { | |||||
| if (aic_aiv_def.src_prefetch_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) { | |||||
| GELOGE(FAILED, "[Check][Param] Invalid AutoThreadAicAivInfo, prefetch size: %d", aic_aiv_def.src_prefetch_size()); | |||||
| return FAILED; | |||||
| } | |||||
| aic_aiv.taskParamAddr = reinterpret_cast<uintptr_t>(args_) + kAddrLen * io_addrs_.size(); | |||||
| GELOGD("AutoThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr); | |||||
| const auto &rts_param = davinci_model_->GetRuntimeParam(); | |||||
| for (uint32_t i = 0; i < thread_dim_ - 1; ++i) { | |||||
| GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i, | |||||
| static_cast<uint32_t>(aic_aiv_def.task_addr_offset_size()))); | |||||
| } | |||||
| GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count())); | |||||
| int last_thread_workspace_size = aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size(); | |||||
| for (int k = 0; k < last_thread_workspace_size; ++k) { | |||||
| uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k); | |||||
| uint8_t *io_addr = nullptr; | |||||
| GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr)); | |||||
| GELOGD("logic addr is 0x%lx, io addr is %p.", logic_addr, io_addr); | |||||
| io_addrs_.emplace_back(io_addr); | |||||
| } | |||||
| aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset(); | |||||
| GELOGD("args_: %p, io_addrs size: %zu, task param offset: %u.", args_, io_addrs_.size(), aic_aiv.taskParamOffset); | |||||
| aic_aiv.satMode = aic_aiv_def.sat_mode(); | |||||
| aic_aiv.scheduleMode = aic_aiv_def.schedule_mode(); | |||||
| aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt(); | |||||
| aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap(); | |||||
| aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap(); | |||||
| aic_aiv.tailBlkDim = aic_aiv_def.tail_blk_dim(); | |||||
| aic_aiv.nonTailBlkDim = aic_aiv_def.non_tail_blk_dim(); | |||||
| aic_aiv.nonTailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.non_tail_task_func_stub(), ""); | |||||
| aic_aiv.tailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.tail_task_func_stub(), ""); | |||||
| GELOGI("Set func name[%s][%s] succ.", aic_aiv.nonTailTaskFuncStub, aic_aiv.tailTaskFuncStub); | |||||
| for (int idx = 0; idx < aic_aiv_def.src_prefetch_size(); ++idx) { | |||||
| InitAutoPrefetch(aic_aiv_def.src_prefetch(idx), aic_aiv.srcPrefetch[idx]); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void FftsTaskInfo::InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache) { | |||||
| cache.dataAddr = cache_def.data_addr(); | |||||
| cache.dataAddrOffset = cache_def.data_addr_offset(); | |||||
| cache.nonTailDataLen = cache_def.non_tail_data_len(); | |||||
| cache.tailDataLen = cache_def.tail_data_len(); | |||||
| cache.ticketCacheRefCnt = cache_def.ticket_cache_ref_cnt(); | |||||
| } | |||||
| void FftsTaskInfo::InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch) { | |||||
| prefetch.dataAddr = prefetch_def.data_addr(); | |||||
| prefetch.dataAddrOffset = prefetch_def.data_addr_offset(); | |||||
| prefetch.nonTailDataLen = prefetch_def.non_tail_data_len(); | |||||
| prefetch.tailDataLen = prefetch_def.tail_data_len(); | |||||
| } | |||||
| Status FftsTaskInfo::InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def, | |||||
| rtManualThreadAicAivInfo_t &aic_aiv) { | |||||
| if ((aic_aiv_def.thread_prefetch_dmu_idx_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) || | |||||
| (aic_aiv_def.thread_blk_dim_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) || | |||||
| (aic_aiv_def.thread_task_func_stub_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) || | |||||
| (aic_aiv_def.src_dep_tbl_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) { | |||||
| GELOGE(FAILED, "[Check][Param] Invalid ManualThreadAicAivInfo, thread prefetch dmu desc size: %d, " | |||||
| "thread blk dim size: %d, thread task func stub size: %d, src dep tbl size: %d", | |||||
| aic_aiv_def.thread_prefetch_dmu_idx_size(), aic_aiv_def.thread_blk_dim_size(), | |||||
| aic_aiv_def.thread_task_func_stub_size(), aic_aiv_def.src_dep_tbl_size()); | |||||
| return FAILED; | |||||
| } | |||||
| aic_aiv.taskParamAddr = reinterpret_cast<uintptr_t>(args_) + kAddrLen * io_addrs_.size(); | |||||
| GELOGD("ManualThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr); | |||||
| const auto &rts_param = davinci_model_->GetRuntimeParam(); | |||||
| for (uint32_t i = 0; i < thread_dim_ - 1; ++i) { | |||||
| GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i, | |||||
| static_cast<uint32_t>(aic_aiv_def.task_addr_offset_size()))); | |||||
| } | |||||
| GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count())); | |||||
| int last_thread_workspace_size = aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size(); | |||||
| for (int k = 0; k < last_thread_workspace_size; ++k) { | |||||
| uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k); | |||||
| uint8_t *io_addr = nullptr; | |||||
| GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr)); | |||||
| io_addrs_.emplace_back(io_addr); | |||||
| } | |||||
| aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset(); | |||||
| aic_aiv.satMode = aic_aiv_def.sat_mode(); | |||||
| aic_aiv.scheduleMode = aic_aiv_def.schedule_mode(); | |||||
| aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt(); | |||||
| aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap(); // 8 bit bitmap 1 0 1 0 | |||||
| aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap(); // 8 bit bitmap 1 0 1 0 | |||||
| aic_aiv.prefetchOnceDmuNum = aic_aiv_def.prefetch_once_dmu_num(); | |||||
| for (int idx = 0; idx < aic_aiv_def.thread_prefetch_dmu_idx_size(); ++idx) { | |||||
| aic_aiv.threadPrefetchDmuIdx[idx] = aic_aiv_def.thread_prefetch_dmu_idx(idx); | |||||
| } | |||||
| for (int idx = 0; idx < aic_aiv_def.thread_blk_dim_size(); ++idx) { | |||||
| aic_aiv.threadBlkDim[idx] = aic_aiv_def.thread_blk_dim(idx); | |||||
| } | |||||
| for (int idx = 0; idx < aic_aiv_def.thread_task_func_stub_size(); ++idx) { | |||||
| aic_aiv.threadTaskFuncStub[idx] = aic_aiv_def.thread_task_func_stub(idx).c_str(); | |||||
| } | |||||
| InitManualDmuInfo(aic_aiv_def, aic_aiv.prefetchList); | |||||
| for (int idx = 0; idx < aic_aiv_def.src_dep_tbl_size(); ++idx) { | |||||
| GE_CHK_STATUS_RET_NOLOG(InitManualDependency(aic_aiv_def.src_dep_tbl(idx), aic_aiv.srcDepTbl[idx])); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FftsTaskInfo::InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def, | |||||
| rtManualThreadCacheInfo_t &cache_info) { | |||||
| if ((cache_def.slice_dmu_idx_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) || | |||||
| (cache_def.ticket_cache_ref_cnt_tbl_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM))) { | |||||
| GELOGE(FAILED, "[Check][Param] Invalid ManualThreadCacheInfo slice dum desc index %d, ticket cache ref cnt %d", | |||||
| cache_def.slice_dmu_idx_size(), cache_def.ticket_cache_ref_cnt_tbl_size()); | |||||
| return FAILED; | |||||
| } | |||||
| InitManualDmuInfo(cache_def, cache_info.dmuList); | |||||
| for (int idx = 0; idx < cache_def.slice_dmu_idx_size(); ++idx) { | |||||
| cache_info.sliceDmuIdx[idx] = cache_def.slice_dmu_idx(idx); | |||||
| } | |||||
| for (int idx = 0; idx < cache_def.ticket_cache_ref_cnt_tbl_size(); ++idx) { | |||||
| cache_info.ticketCacheRefCntTbl[idx] = cache_def.ticket_cache_ref_cnt_tbl(idx); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FftsTaskInfo::InitManualDependency(const domi::ManualThreadDependencyDef &dependency_def, | |||||
| rtManualThreadDependency_t &dependency) { | |||||
| if (dependency_def.dependency_size() > static_cast<int>(RT_FFTS_MANUAL_SRC_DEPEND_TBL_LEN)) { | |||||
| GELOGE(FAILED, "[Check][Param] Invalid ManualThreadDependency size: %d", dependency_def.dependency_size()); | |||||
| return FAILED; | |||||
| } | |||||
| for (int idx = 0; idx < dependency_def.dependency_size(); ++idx) { | |||||
| dependency.dependency[idx] = dependency_def.dependency(idx); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FftsTaskInfo::InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop_info) { | |||||
| if (nop_def.src_dep_tbl_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) { | |||||
| GELOGE(FAILED, "[Check][Param] Invalid ManualThreadNopInfo, src dep tbl size: %d", nop_def.src_dep_tbl_size()); | |||||
| return FAILED; | |||||
| } | |||||
| for (int idx = 0; idx < nop_def.src_dep_tbl_size(); ++idx) { | |||||
| GE_CHK_STATUS_RET_NOLOG(InitManualDependency(nop_def.src_dep_tbl(idx), nop_info.srcDepTbl[idx])); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu) { | |||||
| if (aic_aiv_def.prefetch_list().empty()) { | |||||
| return; | |||||
| } | |||||
| std::vector<uint8_t> buffer(sizeof(rtManualThreadDmuInfo_t) * aic_aiv_def.prefetch_list_size()); | |||||
| dmu = reinterpret_cast<rtManualThreadDmuInfo_t *>(buffer.data()); | |||||
| for (int idx = 0; idx < aic_aiv_def.prefetch_list_size(); ++idx) { | |||||
| InitManualDmuInfo(aic_aiv_def.prefetch_list(idx), dmu[idx]); | |||||
| } | |||||
| } | |||||
| void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu) { | |||||
| if (cache_def.dmu_list().empty()) { | |||||
| return; | |||||
| } | |||||
| std::vector<uint8_t> buffer(sizeof(rtManualThreadDmuInfo_t) * cache_def.dmu_list_size()); | |||||
| dmu = reinterpret_cast<rtManualThreadDmuInfo_t *>(buffer.data()); | |||||
| for (int idx = 0; idx < cache_def.dmu_list_size(); ++idx) { | |||||
| InitManualDmuInfo(cache_def.dmu_list(idx), dmu[idx]); | |||||
| } | |||||
| } | |||||
| void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu) { | |||||
| dmu.dataAddr = dmu_def.data_addr(); | |||||
| dmu.numOuter = dmu_def.num_outer(); | |||||
| dmu.numInner = dmu_def.num_inner(); | |||||
| dmu.strideOuter = dmu_def.stride_outer(); | |||||
| dmu.lenInner = dmu_def.len_inner(); | |||||
| dmu.strideInner = dmu_def.stride_inner(); | |||||
| } | |||||
| Status FftsTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FftsTaskInfo::UpdateArgs() { | |||||
| GE_CHECK_NOTNULL(davinci_model_); | |||||
| std::vector<void *> io_addrs = io_addrs_; | |||||
| davinci_model_->UpdateKnownZeroCopyAddr(io_addrs); | |||||
| auto addr_size = kAddrLen * io_addrs.size(); | |||||
| GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs.data(), addr_size, RT_MEMCPY_HOST_TO_DEVICE)); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FftsTaskInfo::Distribute() { | |||||
| GELOGI("FftsTaskInfo Distribute Start."); | |||||
| rtError_t rt_ret = rtFftsTaskLaunch(&sub_task_info_, stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "[Check][RT_ret] Call rtFftsTaskLaunch failed, ret: 0x%X", rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| GELOGI("FftsTaskInfo Distribute Success."); | |||||
| return SUCCESS; | |||||
| } | |||||
| REGISTER_TASK_INFO(RT_MODEL_TASK_FFTS_TASK, FftsTaskInfo); | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,66 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_ | |||||
| #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_ | |||||
| #include "graph/load/model_manager/task_info/task_info.h" | |||||
| #include "graph/op_desc.h" | |||||
| namespace ge { | |||||
| class FftsTaskInfo : public TaskInfo { | |||||
| public: | |||||
| FftsTaskInfo() = default; | |||||
| ~FftsTaskInfo() override; | |||||
| Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||||
| Status Distribute() override; | |||||
| Status UpdateArgs() override; | |||||
| Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||||
| private: | |||||
| void InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc); | |||||
| Status InitSubTaskInfo(const domi::FftsSubTaskDef &task_def, rtFftsSubTaskInfo_t &task); | |||||
| Status InitTicketCache(const domi::TicketCacheDef &cache_def, rtTicketCache_t &cache); | |||||
| Status InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv); | |||||
| void InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache); | |||||
| void InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch); | |||||
| Status InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadAicAivInfo_t &aic_aiv); | |||||
| Status InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadCacheInfo_t &cache); | |||||
| Status InitManualDependency(const domi::ManualThreadDependencyDef &depend_def, rtManualThreadDependency_t &depend); | |||||
| Status InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop); | |||||
| void InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu); | |||||
| void InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu); | |||||
| void InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu); | |||||
| template<typename T> | |||||
| Status InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim, uint32_t addr_count); | |||||
| DavinciModel *davinci_model_{nullptr}; | |||||
| rtFftsTaskInfo_t sub_task_info_; | |||||
| std::vector<void *> io_addrs_; | |||||
| uint32_t thread_dim_{0}; | |||||
| void *args_{nullptr}; // runtime args memory | |||||
| uint32_t args_size_{0}; // runtime args memory length | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_ | |||||
| @@ -27,6 +27,7 @@ | |||||
| #include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
| #include "common/thread_pool.h" | #include "common/thread_pool.h" | ||||
| #include "common/dump/dump_manager.h" | #include "common/dump/dump_manager.h" | ||||
| #include "ge_opt_info/ge_opt_info.h" | |||||
| #include "analyzer/analyzer.h" | #include "analyzer/analyzer.h" | ||||
| #include "graph/common/ge_call_wrapper.h" | #include "graph/common/ge_call_wrapper.h" | ||||
| #include "graph/common/local_context.h" | #include "graph/common/local_context.h" | ||||
| @@ -949,7 +950,7 @@ Status GraphManager::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode, uint | |||||
| rtError_t rt_ret = rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId()); | rtError_t rt_ret = rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId()); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtCtxCreate faileded, session_id:%lu, graph_id:%u, mode:%d", | |||||
| REPORT_CALL_ERROR("E19999", "Call rtCtxCreate failed, session_id:%lu, graph_id:%u, mode:%d", | |||||
| session_id, graph_id, mode); | session_id, graph_id, mode); | ||||
| GELOGE(FAILED, "[Call][RtCtxCreate] faileded, session_id:%lu, graph_id:%u, mode:%d", session_id, graph_id, mode); | GELOGE(FAILED, "[Call][RtCtxCreate] faileded, session_id:%lu, graph_id:%u, mode:%d", session_id, graph_id, mode); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -1001,6 +1002,12 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<Ge | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| ret = GeOptInfo::SetOptInfo(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Set][OptInfo] Set optional information failed."); | |||||
| return ret; | |||||
| } | |||||
| /// 1. BUILD_MODE_TUNING with BUILD_STEP_AFTER_UB_MATCH no need PreRunOptimizeOriginalGraph; | /// 1. BUILD_MODE_TUNING with BUILD_STEP_AFTER_UB_MATCH no need PreRunOptimizeOriginalGraph; | ||||
| /// 2. BUILD_MODE_TUNING with BUILD_STEP_AFTER_MERGE no need PreRunOptimizeOriginalGraph. | /// 2. BUILD_MODE_TUNING with BUILD_STEP_AFTER_MERGE no need PreRunOptimizeOriginalGraph. | ||||
| /// 3. BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need PreRunOptimizeOriginalGraph. | /// 3. BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need PreRunOptimizeOriginalGraph. | ||||
| @@ -336,10 +336,8 @@ Status GraphOptimize::OptimizeAfterStage1(ComputeGraphPtr &compute_graph) { | |||||
| GELOGI("[OptimizeAfterStage1]: engine type will exclude:%s.", exclude_core_type.c_str()); | GELOGI("[OptimizeAfterStage1]: engine type will exclude:%s.", exclude_core_type.c_str()); | ||||
| continue; | continue; | ||||
| } | } | ||||
| #ifndef ONLY_COMPILE_OPEN_SRC | |||||
| GELOGI("Begin to optimize graph after stage1 by engine %s.", iter->first.c_str()); | GELOGI("Begin to optimize graph after stage1 by engine %s.", iter->first.c_str()); | ||||
| ret = (iter->second)->OptimizeAfterStage1(*compute_graph); | ret = (iter->second)->OptimizeAfterStage1(*compute_graph); | ||||
| #endif | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| REPORT_INNER_ERROR("E19999", "Call OptimizeAfterStage1 failed, ret:%d, engine_name:%s, " | REPORT_INNER_ERROR("E19999", "Call OptimizeAfterStage1 failed, ret:%d, engine_name:%s, " | ||||
| "graph_name:%s.", ret, iter->first.c_str(), compute_graph->GetName().c_str()); | "graph_name:%s.", ret, iter->first.c_str(), compute_graph->GetName().c_str()); | ||||
| @@ -364,6 +364,7 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) { | |||||
| } | } | ||||
| void DynamicShapePartitioner::MergeClustersControlFlow() { | void DynamicShapePartitioner::MergeClustersControlFlow() { | ||||
| std::unordered_set<ClusterPtr> all_merged_clusters; | |||||
| for (const auto &item : control_clusters_) { | for (const auto &item : control_clusters_) { | ||||
| const auto &control_cluster = item.second; | const auto &control_cluster = item.second; | ||||
| auto rit = control_cluster.rbegin(); | auto rit = control_cluster.rbegin(); | ||||
| @@ -373,17 +374,32 @@ void DynamicShapePartitioner::MergeClustersControlFlow() { | |||||
| } | } | ||||
| const auto &cluster = *rit; | const auto &cluster = *rit; | ||||
| if (all_merged_clusters.count(cluster) > 0) { | |||||
| continue; | |||||
| } | |||||
| bool is_unknown_cluster = cluster->IsUnknownShape(); | |||||
| for (++rit; rit != control_cluster.rend(); ++rit) { | for (++rit; rit != control_cluster.rend(); ++rit) { | ||||
| const auto &cluster_from = *rit; | const auto &cluster_from = *rit; | ||||
| if (all_merged_clusters.count(cluster_from) > 0) { | |||||
| continue; | |||||
| } | |||||
| auto merged_clusters = cluster->MergeAllPathFrom(cluster_from); | auto merged_clusters = cluster->MergeAllPathFrom(cluster_from); | ||||
| GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(), | GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(), | ||||
| ToString(merged_clusters).c_str()); | ToString(merged_clusters).c_str()); | ||||
| for (const auto &merged_cluster : merged_clusters) { | for (const auto &merged_cluster : merged_clusters) { | ||||
| all_merged_clusters.emplace(merged_cluster); | |||||
| for (const auto &node : merged_cluster->Nodes()) { | for (const auto &node : merged_cluster->Nodes()) { | ||||
| node_2_cluster_[node] = cluster; | node_2_cluster_[node] = cluster; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| if (!is_unknown_cluster && cluster->IsUnknownShape()) { | |||||
| GELOGD("Add to ordered cluster: %s", cluster->DebugString().c_str()); | |||||
| ordered_cluster_.push_back(cluster); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -703,7 +719,12 @@ void Cluster::Merge(ClusterPtr other) { | |||||
| if (other->min_ < min_) { | if (other->min_ < min_) { | ||||
| min_ = other->min_; | min_ = other->min_; | ||||
| } | } | ||||
| }; | |||||
| if (!IsUnknownShape() && other->IsUnknownShape()) { | |||||
| type_ = UNKNOWN_SHAPE; | |||||
| } | |||||
| } | |||||
| bool Cluster::TryMerge(ClusterPtr other) { | bool Cluster::TryMerge(ClusterPtr other) { | ||||
| std::queue<ClusterPtr> forward_reached; | std::queue<ClusterPtr> forward_reached; | ||||
| forward_reached.push(other); | forward_reached.push(other); | ||||
| @@ -161,7 +161,7 @@ class DynamicShapePartitioner { | |||||
| ge::ComputeGraphPtr root_graph_; // The original graph to partition | ge::ComputeGraphPtr root_graph_; // The original graph to partition | ||||
| std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to | std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to | ||||
| // V1 control flow cluster, need merge to one Graph. | // V1 control flow cluster, need merge to one Graph. | ||||
| std::unordered_map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_; | |||||
| std::map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_; | |||||
| // topological sorted clusters, this field will change with the splitting. | // topological sorted clusters, this field will change with the splitting. | ||||
| // When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters | // When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters | ||||
| // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters | // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters | ||||
| @@ -179,6 +179,7 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr | |||||
| GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret); | GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret); | ||||
| } | } | ||||
| GE_CHECK_NOTNULL(original_compute_graph); | GE_CHECK_NOTNULL(original_compute_graph); | ||||
| output_merged_compute_graph->SetName(original_compute_graph->GetName()); | |||||
| // partition sub graph | // partition sub graph | ||||
| for (const auto &sub_graph : original_compute_graph->GetAllSubgraphs()) { | for (const auto &sub_graph : original_compute_graph->GetAllSubgraphs()) { | ||||
| ComputeGraphPtr merged_sub_graph = nullptr; | ComputeGraphPtr merged_sub_graph = nullptr; | ||||
| @@ -188,8 +189,16 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr | |||||
| GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret); | GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret); | ||||
| continue; | continue; | ||||
| } | } | ||||
| // this means subgraph added in optimize subgraph and without partitions, so just add to root graph | |||||
| if (merged_sub_graph == sub_graph) { | |||||
| GELOGI("Just add subgraph %s (parent node is %s) to root graph %s.", sub_graph->GetName().c_str(), | |||||
| sub_graph->GetParentNode()->GetName().c_str(), output_merged_compute_graph->GetName().c_str()); | |||||
| sub_graph->SetParentGraph(sub_graph->GetParentNode()->GetOwnerComputeGraph()); | |||||
| GE_IF_BOOL_EXEC(output_merged_compute_graph->AddSubgraph(sub_graph->GetName(), merged_sub_graph) != SUCCESS, | |||||
| return FAILED;) | |||||
| continue; | |||||
| } | |||||
| // add sub graph | // add sub graph | ||||
| output_merged_compute_graph->SetName(original_compute_graph->GetName()); | |||||
| merged_sub_graph->SetName(sub_graph->GetName()); | merged_sub_graph->SetName(sub_graph->GetName()); | ||||
| merged_sub_graph->SetInputSize(sub_graph->GetInputSize()); | merged_sub_graph->SetInputSize(sub_graph->GetInputSize()); | ||||
| merged_sub_graph->SetOutputSize(sub_graph->GetOutputSize()); | merged_sub_graph->SetOutputSize(sub_graph->GetOutputSize()); | ||||
| @@ -245,12 +254,9 @@ Status ge::GraphPartitioner::MergeSubGraph(ge::ComputeGraphPtr &output_merged_co | |||||
| } | } | ||||
| if ((graph_2_graph_partition_info_.find(original_compute_graph) == graph_2_graph_partition_info_.end()) || | if ((graph_2_graph_partition_info_.find(original_compute_graph) == graph_2_graph_partition_info_.end()) || | ||||
| (graph_2_subgraph_list_.find(original_compute_graph) == graph_2_subgraph_list_.end())) { | (graph_2_subgraph_list_.find(original_compute_graph) == graph_2_subgraph_list_.end())) { | ||||
| REPORT_INNER_ERROR("E19999", "original_compute_graph:%s is not find in graph_2_graph_partition_info_.", | |||||
| original_compute_graph->GetName().c_str()); | |||||
| GELOGE(GE_GRAPH_NULL_INPUT, | |||||
| "[Check][Param] original_compute_graph:%s is not find in graph_2_graph_partition_info_.", | |||||
| original_compute_graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| GELOGW("[GraphPartition]: compute_graph has not found, just return original."); | |||||
| output_merged_compute_graph = original_compute_graph; | |||||
| return SUCCESS; | |||||
| } | } | ||||
| GraphPartitionInfo &subgraph_info = graph_2_graph_partition_info_[original_compute_graph]; | GraphPartitionInfo &subgraph_info = graph_2_graph_partition_info_[original_compute_graph]; | ||||
| const auto &sub_graph_list = graph_2_subgraph_list_[original_compute_graph]; | const auto &sub_graph_list = graph_2_subgraph_list_[original_compute_graph]; | ||||
| @@ -708,6 +714,7 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vector<ge::SubGraphInfoPtr | |||||
| } | } | ||||
| auto &engine_name = graph_info_.partitions_.at(sub_graph); | auto &engine_name = graph_info_.partitions_.at(sub_graph); | ||||
| (void)AttrUtils::SetStr(sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName()); | (void)AttrUtils::SetStr(sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName()); | ||||
| (void)sub_graph->SetExtAttr("part_src_graph", compute_graph); | |||||
| GELOGD("set attr success. subgraph(%s) with parent graph(%s)", sub_graph->GetName().c_str(), | GELOGD("set attr success. subgraph(%s) with parent graph(%s)", sub_graph->GetName().c_str(), | ||||
| compute_graph->GetName().c_str()); | compute_graph->GetName().c_str()); | ||||
| GE_DUMP(sub_graph, sub_graph->GetName() + "_" + mode_2_str_[graph_info_.mode_]); | GE_DUMP(sub_graph, sub_graph->GetName() + "_" + mode_2_str_[graph_info_.mode_]); | ||||
| @@ -132,39 +132,17 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||||
| /// @return | /// @return | ||||
| /// | /// | ||||
| void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | ||||
| std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | |||||
| return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP); | |||||
| }; | |||||
| for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) { | |||||
| const auto &op_node1 = it1->first; | |||||
| const auto &op_desc1 = op_node1->GetOpDesc(); | |||||
| if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||||
| for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) { | |||||
| const auto &op_node = it->first; | |||||
| const auto &op_desc = op_node->GetOpDesc(); | |||||
| if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) { | |||||
| int64_t group_index = op_desc1->GetId(); | |||||
| GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); | |||||
| MarkForceUnknownShape(op_node1, true, group_index); | |||||
| for (const auto &n : it1->second) { | |||||
| MarkForceUnknownShape(n, true, group_index); | |||||
| } | |||||
| for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { | |||||
| const auto &op_node2 = it2->first; | |||||
| const auto &op_desc2 = op_node2->GetOpDesc(); | |||||
| if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||||
| continue; | |||||
| } | |||||
| if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { | |||||
| MarkForceUnknownShape(op_node2, true, group_index); | |||||
| for (const auto &n : it2->second) { | |||||
| MarkForceUnknownShape(n, true, group_index); | |||||
| } | |||||
| } | |||||
| } | |||||
| int64_t group_index = op_desc->GetId(); | |||||
| SetControlFlowGroup(op_node, group_index); | |||||
| for (const auto &n : it->second) { | |||||
| SetControlFlowGroup(n, group_index); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -40,6 +40,12 @@ Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) { | |||||
| } | } | ||||
| } | } | ||||
| const auto &node = graph->GetParentNode(); | |||||
| if (!is_unknown_shape && node != nullptr && node->GetType() == PARTITIONEDCALL) { | |||||
| GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), | |||||
| "[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str()); | |||||
| } | |||||
| for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
| GELOGD("Set OwnerGraphIsUnknown attr to node[%s]", node->GetName().c_str()); | GELOGD("Set OwnerGraphIsUnknown attr to node[%s]", node->GetName().c_str()); | ||||
| (void)AttrUtils::SetBool(node->GetOpDesc(), kOwnerGraphIsUnknown, is_unknown_shape); | (void)AttrUtils::SetBool(node->GetOpDesc(), kOwnerGraphIsUnknown, is_unknown_shape); | ||||
| @@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||||
| REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | ||||
| return FAILED, "[Check][Param] Param of pre node is nullptr."); | return FAILED, "[Check][Param] Param of pre node is nullptr."); | ||||
| int64_t group_index = -1; | int64_t group_index = -1; | ||||
| bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
| MarkForceUnknownShape(node, force_unknown, group_index); | |||||
| (void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
| for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
| GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | ||||
| @@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||||
| GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str()); | GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| MarkForceUnknownShape(active_node, force_unknown, group_index); | |||||
| SetControlFlowGroup(active_node, group_index); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -284,13 +284,21 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
| /// @return void | /// @return void | ||||
| /// | /// | ||||
| void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { | void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { | ||||
| std::string node_type; | |||||
| for (const auto &switch_node : loop_group.switch_nodes) { | for (const auto &switch_node : loop_group.switch_nodes) { | ||||
| SetControlFlowGroup(switch_node, group_index); | SetControlFlowGroup(switch_node, group_index); | ||||
| for (const auto &node : switch_node->GetOutDataNodes()) { | for (const auto &node : switch_node->GetOutDataNodes()) { | ||||
| std::string node_type; | |||||
| (void)GetOriginalType(node, node_type); | (void)GetOriginalType(node, node_type); | ||||
| if (kExitOpTypes.count(node_type) > 0) { | if (kExitOpTypes.count(node_type) > 0) { | ||||
| SetControlFlowGroup(node, group_index); | SetControlFlowGroup(node, group_index); | ||||
| } else { | |||||
| // For: Switch -> Cast -> Exit | |||||
| for (const auto &n : node->GetOutDataNodes()) { | |||||
| (void)GetOriginalType(n, node_type); | |||||
| if (kExitOpTypes.count(node_type) > 0) { | |||||
| SetControlFlowGroup(n, group_index); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -21,7 +21,23 @@ | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/utils/node_utils.h" | |||||
| namespace { | |||||
| const std::unordered_set<std::string> kControlFlowOps = { | |||||
| ge::SWITCH, | |||||
| ge::REFSWITCH, | |||||
| ge::MERGE, | |||||
| ge::REFMERGE, | |||||
| ge::ENTER, | |||||
| ge::REFENTER, | |||||
| ge::NEXTITERATION, | |||||
| ge::REFNEXTITERATION, | |||||
| ge::EXIT, | |||||
| ge::REFEXIT, | |||||
| ge::LOOPCOND | |||||
| }; | |||||
| } | |||||
| namespace ge { | namespace ge { | ||||
| Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { | Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { | ||||
| GELOGD("ReplaceWithEmptyConstPass in."); | GELOGD("ReplaceWithEmptyConstPass in."); | ||||
| @@ -39,6 +55,10 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { | |||||
| GELOGI("Node %s is const. Ignore current pass.", node->GetName().c_str()); | GELOGI("Node %s is const. Ignore current pass.", node->GetName().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| if (kControlFlowOps.count(NodeUtils::GetNodeType(node)) != 0) { | |||||
| GELOGI("Node %s is control flow op. Ignore current pass.", node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| // Node like no op, it has no output | // Node like no op, it has no output | ||||
| if (node->GetOpDesc()->GetAllOutputsDescPtr().empty()) { | if (node->GetOpDesc()->GetAllOutputsDescPtr().empty()) { | ||||
| GELOGI("Node %s has no output desc. Ignore current pass.", node->GetName().c_str()); | GELOGI("Node %s has no output desc. Ignore current pass.", node->GetName().c_str()); | ||||
| @@ -395,8 +395,8 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||||
| peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | ||||
| int64_t group_index = -1; | int64_t group_index = -1; | ||||
| bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
| MarkForceUnknownShape(stream_switch, force_unknown, group_index); | |||||
| (void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
| SetControlFlowGroup(stream_switch, group_index); | |||||
| return stream_switch; | return stream_switch; | ||||
| } | } | ||||
| @@ -491,8 +491,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | |||||
| Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { | Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { | ||||
| for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { | for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { | ||||
| for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { | for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { | ||||
| std::list<NodePtr> false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; | |||||
| std::list<NodePtr> true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; | |||||
| const std::list<NodePtr> &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; | |||||
| const std::list<NodePtr> &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; | |||||
| std::set<NodePtr> same_cond_switch; | std::set<NodePtr> same_cond_switch; | ||||
| same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); | same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); | ||||
| same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | ||||
| @@ -524,13 +524,13 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||||
| std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) { | std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) { | ||||
| return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | ||||
| }; | }; | ||||
| bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | |||||
| MarkForceUnknownShape(active_node, is_unknown_shape, group_index); | |||||
| (void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | |||||
| SetControlFlowGroup(active_node, group_index); | |||||
| const std::string &cond_group = cond_node->GetName(); | const std::string &cond_group = cond_node->GetName(); | ||||
| for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | ||||
| bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); | bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); | ||||
| std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); | |||||
| const std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); | |||||
| GE_IF_BOOL_EXEC(switch_list.empty(), continue); | GE_IF_BOOL_EXEC(switch_list.empty(), continue); | ||||
| // select first stream_switch | // select first stream_switch | ||||
| @@ -559,7 +559,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||||
| "[Add][Edge] between %s and %s failed.", | "[Add][Edge] between %s and %s failed.", | ||||
| cast_node->GetName().c_str(), stream_switch->GetName().c_str()); | cast_node->GetName().c_str(), stream_switch->GetName().c_str()); | ||||
| MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index); | |||||
| SetControlFlowGroup(stream_switch, group_index); | |||||
| for (const NodePtr &node : switch_list) { | for (const NodePtr &node : switch_list) { | ||||
| GE_IF_BOOL_EXEC(node != stream_switch, { | GE_IF_BOOL_EXEC(node != stream_switch, { | ||||
| GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), | GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), | ||||
| @@ -1420,9 +1420,10 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag) { | |||||
| Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc) { | |||||
| auto format = desc.GetFormat(); | auto format = desc.GetFormat(); | ||||
| auto origin_format = desc.GetOriginFormat(); | auto origin_format = desc.GetOriginFormat(); | ||||
| auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); | |||||
| bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op) && (!tune_flag); | bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op) && (!tune_flag); | ||||
| if (need_check_internal_format) { | if (need_check_internal_format) { | ||||
| bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); | bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); | ||||
| @@ -1439,6 +1440,63 @@ Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTens | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphPrepare::UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc) { | |||||
| auto data_type = desc.GetDataType(); | |||||
| uint32_t length = 1; | |||||
| bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); | |||||
| if (!type_ret) { | |||||
| std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" + | |||||
| std::to_string(index) + " input tensor is not support"; | |||||
| REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason})); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not support.", | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| int64_t desc_shape = desc.GetShape().GetShapeSize(); | |||||
| FMK_INT64_UINT32_MULCHECK(desc_shape, length); | |||||
| int64_t shape_size = desc_shape * length; | |||||
| GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast<int64_t>(length)); | |||||
| int64_t size = 0; | |||||
| GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS, | |||||
| REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index); | |||||
| GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index); return FAILED); | |||||
| bool size_check = (size != 0 && shape_size != size); | |||||
| if (size_check) { | |||||
| std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) + | |||||
| "] != shape_size[" + std::to_string(size) + "], check invalid"; | |||||
| REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason})); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size); | |||||
| return FAILED; | |||||
| } | |||||
| ge::TensorUtils::SetSize(desc, shape_size); | |||||
| auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); | |||||
| if (!tune_flag) { | |||||
| graphStatus graph_ret = op->UpdateInputDesc(0, desc); | |||||
| if (graph_ret != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0", | |||||
| op->GetName().c_str(), op->GetType().c_str()); | |||||
| GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0", | |||||
| op->GetName().c_str(), op->GetType().c_str()); | |||||
| return graph_ret; | |||||
| } | |||||
| // Size will be recalculated in the build stage | |||||
| ge::TensorUtils::SetSize(desc, 0); | |||||
| graph_ret = op->UpdateOutputDesc(0, desc); | |||||
| if (graph_ret != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0", | |||||
| op->GetName().c_str(), op->GetType().c_str()); | |||||
| GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0", | |||||
| op->GetName().c_str(), op->GetType().c_str()); | |||||
| return graph_ret; | |||||
| } | |||||
| } else { | |||||
| GELOGI("data %s skip update info in tune mode", op->GetName().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input, | Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input, | ||||
| const std::map<string, string> &graph_option) { | const std::map<string, string> &graph_option) { | ||||
| // Get shape range of input in dynamic_execute mode | // Get shape range of input in dynamic_execute mode | ||||
| @@ -1471,63 +1529,18 @@ Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input, | |||||
| } | } | ||||
| GeTensorDesc desc(user_input[index].GetTensorDesc()); | GeTensorDesc desc(user_input[index].GetTensorDesc()); | ||||
| // data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM. | // data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM. | ||||
| auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); | |||||
| ret = CheckInternalFormat(input_node, desc, tune_flag); | |||||
| ret = CheckInternalFormat(input_node, desc); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, "[Check][InternalFormat] on %s failed", op->GetName().c_str()); | GELOGE(INTERNAL_ERROR, "[Check][InternalFormat] on %s failed", op->GetName().c_str()); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| auto data_type = desc.GetDataType(); | |||||
| uint32_t length = 1; | |||||
| bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); | |||||
| if (!type_ret) { | |||||
| std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" + | |||||
| std::to_string(index) + " input tensor is not support"; | |||||
| REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason})); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not support.", | |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| int64_t desc_shape = desc.GetShape().GetShapeSize(); | |||||
| FMK_INT64_UINT32_MULCHECK(desc_shape, length); | |||||
| int64_t shape_size = desc_shape * length; | |||||
| GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast<int64_t>(length)); | |||||
| int64_t size = 0; | |||||
| GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS, | |||||
| REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index); | |||||
| GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index); | |||||
| return FAILED); | |||||
| bool size_check = (size != 0 && shape_size != size); | |||||
| if (size_check) { | |||||
| std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) + | |||||
| "] != shape_size[" + std::to_string(size) + "], check invalid"; | |||||
| REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason})); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size); | |||||
| return FAILED; | |||||
| } | |||||
| ge::TensorUtils::SetSize(desc, shape_size); | |||||
| if (!tune_flag) { | |||||
| graphStatus graph_ret = op->UpdateInputDesc(0, desc); | |||||
| if (graph_ret != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0", | |||||
| op->GetName().c_str(), op->GetType().c_str()); | |||||
| GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0", | |||||
| op->GetName().c_str(), op->GetType().c_str()); | |||||
| return graph_ret; | |||||
| } | |||||
| // Size will be recalculated in the build stage | |||||
| ge::TensorUtils::SetSize(desc, 0); | |||||
| graph_ret = op->UpdateOutputDesc(0, desc); | |||||
| if (graph_ret != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0", | |||||
| op->GetName().c_str(), op->GetType().c_str()); | |||||
| GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0", | |||||
| op->GetName().c_str(), op->GetType().c_str()); | |||||
| return graph_ret; | |||||
| } | |||||
| } else { | |||||
| GELOGI("data %s skip update info in tune mode", op->GetName().c_str()); | |||||
| ret = UpdateDataInputOutputDesc(index, op, desc); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "[Update][DataInputOutputDesc] on %s failed", op->GetName().c_str()); | |||||
| return ret; | |||||
| } | } | ||||
| if (!dynamic_shape_range_vec.empty()) { | if (!dynamic_shape_range_vec.empty()) { | ||||
| ret = UpdateDynamicInputShapeRange(index, dynamic_shape_range_vec, op, desc); | ret = UpdateDynamicInputShapeRange(index, dynamic_shape_range_vec, op, desc); | ||||
| GE_CHK_STATUS_RET(ret, "[Update][DynamicInputShapeRange] on %s failed.", op->GetName().c_str()); | GE_CHK_STATUS_RET(ret, "[Update][DynamicInputShapeRange] on %s failed.", op->GetName().c_str()); | ||||
| @@ -63,7 +63,8 @@ class GraphPrepare { | |||||
| Status CheckRefOp(); | Status CheckRefOp(); | ||||
| Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); | Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); | ||||
| Status AdjustDataOpOutput(const NodePtr &node); | Status AdjustDataOpOutput(const NodePtr &node); | ||||
| Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag); | |||||
| Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc); | |||||
| Status UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc); | |||||
| Status UpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option); | Status UpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option); | ||||
| Status CheckAndUpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option); | Status CheckAndUpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option); | ||||
| Status CheckConstOp(); | Status CheckConstOp(); | ||||
| @@ -114,7 +114,7 @@ Status GetDataDimN(const ge::NodePtr &data_node, ge::Format format, int64_t &bat | |||||
| std::vector<std::string>({ | std::vector<std::string>({ | ||||
| data_node->GetName() + " format", | data_node->GetName() + " format", | ||||
| TypeUtils::FormatToSerialString(format), | TypeUtils::FormatToSerialString(format), | ||||
| "only format " + TypeUtils::FormatToSerialString(FORMAT_NCHW) + " and "+ | |||||
| "only format " + TypeUtils::FormatToSerialString(FORMAT_NCHW) + " and " + | |||||
| TypeUtils::FormatToSerialString(FORMAT_NHWC) + | TypeUtils::FormatToSerialString(FORMAT_NHWC) + | ||||
| " supported which dynamic aipp is linked"})); | " supported which dynamic aipp is linked"})); | ||||
| GELOGE(PARAM_INVALID, "[Check][Param] Not support data format:%s, node:%s", | GELOGE(PARAM_INVALID, "[Check][Param] Not support data format:%s, node:%s", | ||||
| @@ -41,6 +41,8 @@ HybridModelExecutor::~HybridModelExecutor() { | |||||
| Status HybridModelExecutor::Init() { | Status HybridModelExecutor::Init() { | ||||
| GELOGD("Start to init HybridGraphEngine."); | GELOGD("Start to init HybridGraphEngine."); | ||||
| GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | ||||
| root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_)); | |||||
| GE_CHECK_NOTNULL(root_graph_executor_); | |||||
| GELOGD("HybridGraphEngine initialized successfully."); | GELOGD("HybridGraphEngine initialized successfully."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -60,8 +62,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
| GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | ||||
| sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | ||||
| } | } | ||||
| SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); | |||||
| auto ret = ExecuteGraphInternal(executor, args); | |||||
| auto ret = ExecuteGraphInternal(args); | |||||
| Cleanup(); | Cleanup(); | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); | ||||
| GELOGD("Model executed successfully."); | GELOGD("Model executed successfully."); | ||||
| @@ -69,6 +70,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
| context_.profiler->Dump(std::cout); | context_.profiler->Dump(std::cout); | ||||
| context_.profiler->Reset(); | context_.profiler->Reset(); | ||||
| } | } | ||||
| root_graph_executor_->ReleaseContext(); | |||||
| context_.iteration += 1; | context_.iteration += 1; | ||||
| if (ret == END_OF_SEQUENCE) { | if (ret == END_OF_SEQUENCE) { | ||||
| @@ -79,8 +81,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| HybridModelExecutor::ExecuteArgs &args) { | |||||
| Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) { | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); | ||||
| GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); | GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); | ||||
| @@ -94,7 +95,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id)); | GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id)); | ||||
| } | } | ||||
| HYBRID_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc, args.outputs), | |||||
| HYBRID_CHK_STATUS_RET(root_graph_executor_->ExecuteAsync(args.inputs, args.input_desc, args.outputs), | |||||
| "Failed to execute partitioned call."); | "Failed to execute partitioned call."); | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); | ||||
| @@ -103,7 +104,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| } | } | ||||
| if (!model_->IsSingleOp()) { | if (!model_->IsSingleOp()) { | ||||
| Status ret = executor.Synchronize(); | |||||
| Status ret = root_graph_executor_->Synchronize(); | |||||
| if (ret != ge::SUCCESS) { | if (ret != ge::SUCCESS) { | ||||
| auto model_manager = ModelManager::GetInstance(); | auto model_manager = ModelManager::GetInstance(); | ||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| @@ -123,7 +124,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| } | } | ||||
| args.outputs.clear(); | args.outputs.clear(); | ||||
| HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | |||||
| HYBRID_CHK_STATUS_RET(root_graph_executor_->GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -48,7 +48,7 @@ class HybridModelExecutor { | |||||
| Status Execute(ExecuteArgs &args); | Status Execute(ExecuteArgs &args); | ||||
| private: | private: | ||||
| Status ExecuteGraphInternal(SubgraphExecutor &executor, ExecuteArgs &args); | |||||
| Status ExecuteGraphInternal(ExecuteArgs &args); | |||||
| Status Cleanup(); | Status Cleanup(); | ||||
| Status InitExecutionContext(); | Status InitExecutionContext(); | ||||
| static Status ResetExecutionContext(GraphExecutionContext &context); | static Status ResetExecutionContext(GraphExecutionContext &context); | ||||
| @@ -58,6 +58,7 @@ class HybridModelExecutor { | |||||
| uint32_t device_id_; | uint32_t device_id_; | ||||
| rtStream_t stream_; | rtStream_t stream_; | ||||
| GraphExecutionContext context_; | GraphExecutionContext context_; | ||||
| std::unique_ptr<SubgraphExecutor> root_graph_executor_; | |||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -19,8 +19,9 @@ | |||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "hybrid_execution_context.h" | |||||
| #include "subgraph_context.h" | |||||
| #include "hybrid/executor/hybrid_execution_context.h" | |||||
| #include "hybrid/executor/subgraph_context.h" | |||||
| #include "hybrid/node_executor/task_context.h" | |||||
| #define INC_ITERATION_COUNT(iteration) \ | #define INC_ITERATION_COUNT(iteration) \ | ||||
| do { \ | do { \ | ||||
| @@ -260,6 +261,16 @@ NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_contex | |||||
| this->op_desc_ = node_item.node->GetOpDesc(); | this->op_desc_ = node_item.node->GetOpDesc(); | ||||
| } | } | ||||
| Status NodeState::Init(int group, const shared_ptr<FrameState> &frame_state) { | |||||
| GE_CHECK_NOTNULL(frame_state); | |||||
| group_ = group; | |||||
| frame_state_ = frame_state; | |||||
| auto unique_task_context = TaskContext::Create(this, subgraph_context_); | |||||
| GE_CHECK_NOTNULL(unique_task_context); | |||||
| task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { | Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { | ||||
| if (node_item_->IsMergeOp()) { | if (node_item_->IsMergeOp()) { | ||||
| GELOGD("[%s] merge index %d, input nodes: %zu", GetName().c_str(), merge_index_, node_item_->data_recv_.size()); | GELOGD("[%s] merge index %d, input nodes: %zu", GetName().c_str(), merge_index_, node_item_->data_recv_.size()); | ||||
| @@ -314,15 +325,54 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||||
| return task_context_; | return task_context_; | ||||
| } | } | ||||
| void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | |||||
| if (node_item_->root_data_.count(input_idx) > 0) { | |||||
| GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||||
| root_tensor_values_[input_idx] = tensor; | |||||
| } | |||||
| if (node_item_->enter_data_.count(input_idx) > 0) { | |||||
| GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); | |||||
| root_tensor_values_[input_idx] = tensor; | |||||
| } | |||||
| } | |||||
| void NodeState::UpdatePersistTensor(int input_idx) { | |||||
| const auto it = root_tensor_values_.find(input_idx); | |||||
| if (it == root_tensor_values_.end()) { | |||||
| GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx); | |||||
| return; | |||||
| } | |||||
| auto tensor = task_context_->MutableInput(input_idx); | |||||
| if (tensor == nullptr) { | |||||
| GELOGW("[%s] Not found input tensor: %d", GetName().c_str(), input_idx); | |||||
| return; | |||||
| } | |||||
| *tensor = it->second; | |||||
| GELOGD("[%s] Update input tensor: %d", GetName().c_str(), input_idx); | |||||
| } | |||||
| void NodeState::ResetContext(uint64_t iteration) { | void NodeState::ResetContext(uint64_t iteration) { | ||||
| switch_index_ = -1; | switch_index_ = -1; | ||||
| subgraph_context_->ResetContext(node_item_->node); | subgraph_context_->ResetContext(node_item_->node); | ||||
| if (iteration == 0) { | |||||
| data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||||
| ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||||
| } else { | |||||
| data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size() + node_item_->enter_data_.size()); | |||||
| ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size()); | |||||
| auto unique_task_context = TaskContext::Create(this, subgraph_context_); | |||||
| GE_CHECK_NOTNULL_JUST_RETURN(unique_task_context); | |||||
| task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||||
| ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||||
| for (auto item : node_item_->root_data_) { | |||||
| UpdatePersistTensor(item.first); | |||||
| } | |||||
| if (iteration > 0) { | |||||
| data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | |||||
| ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | |||||
| for (auto item : node_item_->enter_data_) { | |||||
| UpdatePersistTensor(item.first); | |||||
| } | |||||
| } | } | ||||
| iteration_count_ = iteration; | iteration_count_ = iteration; | ||||
| @@ -100,6 +100,8 @@ struct NodeState { | |||||
| NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context); | NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context); | ||||
| ~NodeState() = default; | ~NodeState() = default; | ||||
| Status Init(int group, const shared_ptr<FrameState> &frame_state); | |||||
| OpDesc *GetOpDesc() const { | OpDesc *GetOpDesc() const { | ||||
| return op_desc_.get(); | return op_desc_.get(); | ||||
| } | } | ||||
| @@ -129,6 +131,8 @@ struct NodeState { | |||||
| void RunStreamActive(); | void RunStreamActive(); | ||||
| void RunNextIteration(); | void RunNextIteration(); | ||||
| void SavePersistTensor(int input_idx, const TensorValue &tensor); | |||||
| Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | ||||
| void SetScheduleFuture(std::future<Status> &&future); | void SetScheduleFuture(std::future<Status> &&future); | ||||
| @@ -150,18 +154,10 @@ struct NodeState { | |||||
| return merge_index_; | return merge_index_; | ||||
| } | } | ||||
| void SetGroup(int group) { | |||||
| group_ = group; | |||||
| } | |||||
| int GetGroup() const { | int GetGroup() const { | ||||
| return group_; | return group_; | ||||
| } | } | ||||
| void SetFrameState(const shared_ptr<FrameState> &frame_state) { | |||||
| frame_state_ = frame_state; | |||||
| } | |||||
| const shared_ptr<NodeTask> &GetKernelTask() const { | const shared_ptr<NodeTask> &GetKernelTask() const { | ||||
| return kernel_task_; | return kernel_task_; | ||||
| } | } | ||||
| @@ -181,12 +177,17 @@ struct NodeState { | |||||
| void SetTaskContext(std::shared_ptr<TaskContext> &task_context); | void SetTaskContext(std::shared_ptr<TaskContext> &task_context); | ||||
| std::shared_ptr<TaskContext> GetTaskContext(); | std::shared_ptr<TaskContext> GetTaskContext(); | ||||
| void SetSkipInferShape(bool skip_infershape) { skip_infershape_ = skip_infershape; } | |||||
| bool MaySkipShapeInference() const { return skip_infershape_; } | |||||
| private: | private: | ||||
| bool IsScheduleReady() const; | bool IsScheduleReady() const; | ||||
| void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | ||||
| void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | ||||
| void ResetContext(uint64_t iteration); | void ResetContext(uint64_t iteration); | ||||
| void ScheduleContext(const NodeState &node_state); | void ScheduleContext(const NodeState &node_state); | ||||
| void UpdatePersistTensor(int input_idx); | |||||
| const NodeItem *node_item_ = nullptr; | const NodeItem *node_item_ = nullptr; | ||||
| std::shared_ptr<NodeTask> kernel_task_ = nullptr; | std::shared_ptr<NodeTask> kernel_task_ = nullptr; | ||||
| @@ -199,6 +200,7 @@ struct NodeState { | |||||
| std::future<Status> schedule_future_; | std::future<Status> schedule_future_; | ||||
| std::shared_ptr<FrameState> frame_state_; | std::shared_ptr<FrameState> frame_state_; | ||||
| std::map<int, TensorValue> root_tensor_values_; | |||||
| uint64_t active_count_ = 0; | uint64_t active_count_ = 0; | ||||
| uint64_t iteration_count_ = 0; | uint64_t iteration_count_ = 0; | ||||
| uint32_t ctrl_scheduled_ = 0; | uint32_t ctrl_scheduled_ = 0; | ||||
| @@ -206,6 +208,7 @@ struct NodeState { | |||||
| int merge_index_ = -1; // Use for Execute (Reset after Executed). | int merge_index_ = -1; // Use for Execute (Reset after Executed). | ||||
| int switch_index_ = -1; // Use for Schedule (Reset after Prepared). | int switch_index_ = -1; // Use for Schedule (Reset after Prepared). | ||||
| int group_ = -1; | int group_ = -1; | ||||
| bool skip_infershape_ = false; | |||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -19,7 +19,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context) | |||||
| SubgraphContext::SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context) | |||||
| : graph_item_(graph_item), execution_context_(execution_context) { | : graph_item_(graph_item), execution_context_(execution_context) { | ||||
| } | } | ||||
| @@ -79,20 +79,31 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return CreateNodeState(node_item); | |||||
| } | |||||
| NodeStatePtr SubgraphContext::CreateNodeState(const NodeItem *node_item) { | |||||
| GELOGD("[%s] lock for write", node_item->NodeName().c_str()); | GELOGD("[%s] lock for write", node_item->NodeName().c_str()); | ||||
| if (mmRWLockWRLock(&rw_lock_) != EN_OK) { | if (mmRWLockWRLock(&rw_lock_) != EN_OK) { | ||||
| REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str()); | REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str()); | ||||
| GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str()); | GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str()); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto &node_state = node_states_[node_item]; | auto &node_state = node_states_[node_item]; | ||||
| if (node_state == nullptr) { | |||||
| const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | |||||
| node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | |||||
| node_state->SetFrameState(GetOrCreateFrameState(*node_item)); | |||||
| node_state->SetGroup(group_); | |||||
| (void)guard; | |||||
| } | |||||
| do { | |||||
| if (node_state == nullptr) { | |||||
| const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | |||||
| node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | |||||
| if (node_state == nullptr || node_state->Init(group_, GetOrCreateFrameState(*node_item)) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "[Create][NodeState] failed for[%s].", node_item->NodeName().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Create NodeState failed for %s.", node_item->NodeName().c_str()); | |||||
| break; | |||||
| } | |||||
| (void)guard; | |||||
| } | |||||
| } while (0); | |||||
| GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); | GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); | ||||
| if (mmWRLockUnLock(&rw_lock_) != EN_OK) { | if (mmWRLockUnLock(&rw_lock_) != EN_OK) { | ||||
| REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str()); | REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str()); | ||||
| @@ -30,7 +30,7 @@ namespace ge { | |||||
| namespace hybrid { | namespace hybrid { | ||||
| class SubgraphContext { | class SubgraphContext { | ||||
| public: | public: | ||||
| explicit SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context); | |||||
| explicit SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context); | |||||
| ~SubgraphContext(); | ~SubgraphContext(); | ||||
| Status Init(); | Status Init(); | ||||
| @@ -51,10 +51,11 @@ class SubgraphContext { | |||||
| void NodeDone(const NodePtr &node); | void NodeDone(const NodePtr &node); | ||||
| private: | private: | ||||
| NodeStatePtr CreateNodeState(const NodeItem *node_item); | |||||
| FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock | FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock | ||||
| friend class TaskContext; | friend class TaskContext; | ||||
| const GraphItem *graph_item_; | const GraphItem *graph_item_; | ||||
| const GraphExecutionContext *execution_context_; | |||||
| GraphExecutionContext *execution_context_; | |||||
| mmRWLock_t rw_lock_; | mmRWLock_t rw_lock_; | ||||
| std::vector<TensorValue> all_inputs_; | std::vector<TensorValue> all_inputs_; | ||||
| std::vector<TensorValue> all_outputs_; | std::vector<TensorValue> all_outputs_; | ||||
| @@ -103,6 +103,14 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue | |||||
| auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); | auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); | ||||
| GE_CHECK_NOTNULL(node_state); | GE_CHECK_NOTNULL(node_state); | ||||
| node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc); | node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc); | ||||
| auto op_desc = input_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| auto output_desc = op_desc->MutableOutputDesc(kDataInputIndex); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| output_desc->SetShape(tensor_desc->GetShape()); | |||||
| output_desc->SetOriginShape(tensor_desc->GetOriginShape()); | |||||
| output_desc->SetDataType(tensor_desc->GetDataType()); | |||||
| node_state->SetSkipInferShape(true); | |||||
| } | } | ||||
| } | } | ||||
| @@ -175,16 +183,12 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue | |||||
| GE_CHECK_NOTNULL(node_state); | GE_CHECK_NOTNULL(node_state); | ||||
| node_state->SetKernelTask(node_item->kernel_task); | node_state->SetKernelTask(node_item->kernel_task); | ||||
| known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | |||||
| GE_CHECK_NOTNULL(known_shape_task_context_); | |||||
| node_state->SetTaskContext(known_shape_task_context_); | |||||
| std::function<void()> callback; | std::function<void()> callback; | ||||
| GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback)); | GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback)); | ||||
| HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_, callback), | |||||
| HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, node_state->GetTaskContext(), *context_, callback), | |||||
| "[%s] Failed to execute node [%s] for known subgraph.", | "[%s] Failed to execute node [%s] for known subgraph.", | ||||
| graph_item_->GetName().c_str(), | graph_item_->GetName().c_str(), | ||||
| known_shape_task_context_->GetNodeName()); | |||||
| node_state->GetName().c_str()); | |||||
| GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str()); | GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -271,16 +275,12 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) { | |||||
| } else { | } else { | ||||
| node_state->SetKernelTask(node_item.kernel_task); | node_state->SetKernelTask(node_item.kernel_task); | ||||
| } | } | ||||
| auto unique_task_context = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | |||||
| GE_CHECK_NOTNULL(unique_task_context); | |||||
| const auto &task = node_state->GetKernelTask(); | const auto &task = node_state->GetKernelTask(); | ||||
| if (task == nullptr) { | if (task == nullptr) { | ||||
| GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str()); | GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str()); | ||||
| REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str()); | REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state->SetTaskContext(shared_task_context); | |||||
| GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state)); | GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state)); | ||||
| return AfterPrepared(p_node_state); | return AfterPrepared(p_node_state); | ||||
| } | } | ||||
| @@ -480,19 +480,15 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta | |||||
| } else { | } else { | ||||
| node_state.SetKernelTask(node_item.kernel_task); | node_state.SetKernelTask(node_item.kernel_task); | ||||
| } | } | ||||
| auto unique_task_context = TaskContext::Create(&node_state, context_, subgraph_context_.get()); | |||||
| GE_CHECK_NOTNULL(unique_task_context); | |||||
| const auto &task = node_state.GetKernelTask(); | const auto &task = node_state.GetKernelTask(); | ||||
| if (task == nullptr) { | if (task == nullptr) { | ||||
| GELOGE(INTERNAL_ERROR, "[Invoke][GetKernelTask] failed for[%s], NodeTask is null.", node_state.GetName().c_str()); | GELOGE(INTERNAL_ERROR, "[Invoke][GetKernelTask] failed for[%s], NodeTask is null.", node_state.GetName().c_str()); | ||||
| REPORT_CALL_ERROR("E19999", "invoke GetKernelTask failed for %s, NodeTask is null.", node_state.GetName().c_str()); | REPORT_CALL_ERROR("E19999", "invoke GetKernelTask failed for %s, NodeTask is null.", node_state.GetName().c_str()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state.SetTaskContext(shared_task_context); | |||||
| GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context)); | ||||
| RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start"); | RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start"); | ||||
| GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context)); // update op_desc before alloc ws | |||||
| GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*node_state.GetTaskContext())); // update op_desc before alloc ws | |||||
| RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end"); | RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -41,6 +41,8 @@ class SubgraphExecutor { | |||||
| Status PartialExecuteAsync(int task_group); | Status PartialExecuteAsync(int task_group); | ||||
| void ReleaseContext() { subgraph_context_.reset(nullptr); } | |||||
| /** | /** | ||||
| * Execute subgraph async, output tensor address(not data) and output tensor descriptions are | * Execute subgraph async, output tensor address(not data) and output tensor descriptions are | ||||
| * valid after this method returned | * valid after this method returned | ||||
| @@ -125,7 +127,6 @@ class SubgraphExecutor { | |||||
| ThreadPool pre_run_pool_; | ThreadPool pre_run_pool_; | ||||
| BlockingQueue<NodeState *> ready_queue_; | BlockingQueue<NodeState *> ready_queue_; | ||||
| std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_; | std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_; | ||||
| std::shared_ptr<TaskContext> known_shape_task_context_; | |||||
| std::mutex mu_; // Guard for prepare_queues_. | std::mutex mu_; // Guard for prepare_queues_. | ||||
| std::map<int, BlockingQueue<const NodeItem *>> prepare_queues_; | std::map<int, BlockingQueue<const NodeItem *>> prepare_queues_; | ||||
| @@ -68,8 +68,9 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { | |||||
| } | } | ||||
| // Do shape inference | // Do shape inference | ||||
| // Skipping infer shape of input node. | |||||
| GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); | GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); | ||||
| { | |||||
| if (!node_state.MaySkipShapeInference()) { | |||||
| RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | ||||
| GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | ||||
| "[Invoke][InferShapeAndType] for %s failed.", node_item.NodeName().c_str()); | "[Invoke][InferShapeAndType] for %s failed.", node_item.NodeName().c_str()); | ||||
| @@ -1227,6 +1227,28 @@ Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr | |||||
| hybrid_model_.known_shape_sub_models_.emplace(parent_node, ge_model); | hybrid_model_.known_shape_sub_models_.emplace(parent_node, ge_model); | ||||
| } | } | ||||
| GE_CHK_STATUS_RET_NOLOG(InitHcclExecutorOnDemand(ge_model)); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::InitHcclExecutorOnDemand(const GeModelPtr &ge_model) { | |||||
| if (NodeExecutorManager::GetInstance().IsExecutorInitialized(NodeExecutorManager::ExecutorType::HCCL)) { | |||||
| return SUCCESS; | |||||
| } | |||||
| // HCCL tasks in known-shaped subgraph which resides in a dynamic root graph | |||||
| // still depends on the initialization of the HcclExecutor | |||||
| auto tasks = ge_model->GetModelTaskDefPtr()->task(); | |||||
| for (int i = 0; i < tasks.size(); ++i) { | |||||
| const domi::TaskDef &task_def = tasks[i]; | |||||
| auto task_type = static_cast<rtModelTaskType_t>(task_def.type()); | |||||
| if (task_type == RT_MODEL_TASK_HCCL) { | |||||
| const NodeExecutor *unused = nullptr; | |||||
| GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance() | |||||
| .GetOrCreateExecutor(NodeExecutorManager::ExecutorType::HCCL, &unused)); | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -57,6 +57,7 @@ class HybridModelBuilder { | |||||
| Status ValidateParams(); | Status ValidateParams(); | ||||
| Status LoadGraph(); | Status LoadGraph(); | ||||
| Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | ||||
| static Status InitHcclExecutorOnDemand(const GeModelPtr &ge_model); | |||||
| Status LoadTask(NodeItem &node_item); | Status LoadTask(NodeItem &node_item); | ||||
| Status LoadTasks(); | Status LoadTasks(); | ||||
| Status IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph); | Status IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph); | ||||
| @@ -398,12 +398,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||||
| data_send_.emplace(node_item); | data_send_.emplace(node_item); | ||||
| node_item->data_recv_[this] = anchor_index; | node_item->data_recv_[this] = anchor_index; | ||||
| if (is_root_node_) { | if (is_root_node_) { | ||||
| node_item->root_data_.emplace(this); | |||||
| node_item->root_data_[anchor_index] = this; | |||||
| } | } | ||||
| // If Enter feed Not Merge, take as root Node. | // If Enter feed Not Merge, take as root Node. | ||||
| if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | ||||
| node_item->enter_data_.emplace(this); | |||||
| node_item->enter_inside_.emplace(anchor_index); | |||||
| node_item->enter_data_[anchor_index] = this; | |||||
| } | } | ||||
| GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
| } | } | ||||
| @@ -148,15 +148,14 @@ struct NodeItem { | |||||
| int64_t frame_index_ = -1; | int64_t frame_index_ = -1; | ||||
| int64_t parent_frame_ = -1; | int64_t parent_frame_ = -1; | ||||
| std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | ||||
| std::set<const NodeItem *> root_data_; // Recv data from root node | |||||
| std::map<int, const NodeItem *> root_data_; // Recv data from root node | |||||
| std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | ||||
| std::set<const NodeItem *> enter_data_; // Recv data from Enter node | |||||
| std::map<int, const NodeItem *> enter_data_; // Recv data from Enter node | |||||
| std::set<const NodeItem *> data_send_; // Send data notify to | std::set<const NodeItem *> data_send_; // Send data notify to | ||||
| std::map<const NodeItem *, int> data_recv_; // Recv data notify from | std::map<const NodeItem *, int> data_recv_; // Recv data notify from | ||||
| std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | ||||
| std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | ||||
| std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to | std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to | ||||
| std::set<int> enter_inside_; // Enter feed loop inside Node, Not cross Merge. | |||||
| std::shared_ptr<NodeTask> kernel_task; | std::shared_ptr<NodeTask> kernel_task; | ||||
| std::unique_ptr<FusedSubgraph> fused_subgraph; | std::unique_ptr<FusedSubgraph> fused_subgraph; | ||||
| @@ -64,10 +64,6 @@ Status AicpuNodeTaskBase::InitExtInfo(const std::string &kernel_ext_info, int64_ | |||||
| GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateSessionInfoSessionId(session_id), | GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateSessionInfoSessionId(session_id), | ||||
| "[Update][SessionInfoSessionId] failed, session_id:%ld.", session_id); | "[Update][SessionInfoSessionId] failed, session_id:%ld.", session_id); | ||||
| bool execute_mode = !aicpu_ext_handle_.IsNeedRefreshIOAddr() && !node_item_->is_dynamic; | |||||
| GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateExecuteMode(execute_mode), | |||||
| "[Update][ExecuteMode] failed, node:%s.", node_name_.c_str()); | |||||
| // copy task args buf | // copy task args buf | ||||
| GE_CHK_STATUS_RET(AllocTensorBuffer(aicpu_ext_handle_.GetExtInfoLen(), ext_info_addr_dev_), | GE_CHK_STATUS_RET(AllocTensorBuffer(aicpu_ext_handle_.GetExtInfoLen(), ext_info_addr_dev_), | ||||
| "[Invoke][AllocTensorBuffer]Node[%s] alloc kernel_ext_info buf failed, size=%zu", | "[Invoke][AllocTensorBuffer]Node[%s] alloc kernel_ext_info buf failed, size=%zu", | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include "graph/types.h" | #include "graph/types.h" | ||||
| #include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
| #include "hccl/hcom.h" | #include "hccl/hcom.h" | ||||
| #include "runtime/event.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| @@ -325,7 +326,7 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| rtEvent_t evt = nullptr; | rtEvent_t evt = nullptr; | ||||
| if (context.GetExecutionContext()->hccl_stream != nullptr) { | if (context.GetExecutionContext()->hccl_stream != nullptr) { | ||||
| GE_CHK_RT_RET(rtEventCreateWithFlag(&evt, 0x01)); | |||||
| GE_CHK_RT_RET(rtEventCreateWithFlag(&evt, RT_EVENT_WITH_FLAG)); | |||||
| GE_CHK_RT_RET(rtStreamWaitEvent(context.GetExecutionContext()->hccl_stream, evt)); | GE_CHK_RT_RET(rtStreamWaitEvent(context.GetExecutionContext()->hccl_stream, evt)); | ||||
| } | } | ||||
| TaskContext *p_ctx = &context; | TaskContext *p_ctx = &context; | ||||
| @@ -58,8 +58,8 @@ Status NodeExecutor::CompileTask(const HybridModel &model, const NodePtr &node, | |||||
| } | } | ||||
| Status NodeExecutorManager::EnsureInitialized() { | Status NodeExecutorManager::EnsureInitialized() { | ||||
| GE_CHK_STATUS_RET(InitializeExecutors()); | |||||
| std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
| ++ref_count_; | |||||
| if (initialized_) { | if (initialized_) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -115,17 +115,14 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node | |||||
| return it->second; | return it->second; | ||||
| } | } | ||||
| Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executor) const { | |||||
| Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executor) { | |||||
| auto executor_type = ResolveExecutorType(node); | auto executor_type = ResolveExecutorType(node); | ||||
| GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), static_cast<int>(executor_type)); | |||||
| const auto it = executors_.find(executor_type); | const auto it = executors_.find(executor_type); | ||||
| if (it == executors_.end()) { | if (it == executors_.end()) { | ||||
| REPORT_INNER_ERROR("E19999", "Failed to get executor by type: %d.", static_cast<int>(executor_type)); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][ExecutorType]Failed to get executor by type: %d.", | |||||
| static_cast<int>(executor_type)); | |||||
| return INTERNAL_ERROR; | |||||
| return GetOrCreateExecutor(executor_type, executor); | |||||
| } | } | ||||
| GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), static_cast<int>(executor_type)); | |||||
| *executor = it->second.get(); | *executor = it->second.get(); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -178,51 +175,55 @@ Status NodeExecutorManager::CalcOpRunningParam(Node &node) const { | |||||
| return OpsKernelBuilderManager::Instance().CalcOpRunningParam(node); | return OpsKernelBuilderManager::Instance().CalcOpRunningParam(node); | ||||
| } | } | ||||
| Status NodeExecutorManager::InitializeExecutors() { | |||||
| bool NodeExecutorManager::IsExecutorInitialized(NodeExecutorManager::ExecutorType executor_type) { | |||||
| std::lock_guard<std::mutex> lk(mu_); | |||||
| return executors_.find(executor_type) != executors_.end(); | |||||
| } | |||||
| Status NodeExecutorManager::GetOrCreateExecutor(ExecutorType executor_type, const NodeExecutor **out_executor) { | |||||
| std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
| if (executor_initialized_) { | |||||
| ++ref_count_; | |||||
| GELOGI("Executor is already initialized. add ref count to [%d]", ref_count_); | |||||
| const auto executor_it = executors_.find(executor_type); | |||||
| if (executor_it != executors_.end()) { | |||||
| *out_executor = executor_it->second.get(); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GELOGI("Start to Initialize NodeExecutors"); | |||||
| for (auto &it : builders_) { | |||||
| auto engine_type = it.first; | |||||
| auto build_fn = it.second; | |||||
| GE_CHECK_NOTNULL(build_fn); | |||||
| auto executor = std::unique_ptr<NodeExecutor>(build_fn()); | |||||
| if (executor == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for engine type = %d", | |||||
| static_cast<int>(engine_type)); | |||||
| GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for engine type = %d", static_cast<int>(engine_type)); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGI("Start to Initialize NodeExecutor, type = %d", static_cast<int>(executor_type)); | |||||
| auto it = builders_.find(executor_type); | |||||
| if (it == builders_.end()) { | |||||
| REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for executor type = %d", | |||||
| static_cast<int>(executor_type)); | |||||
| GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for executor type = %d", static_cast<int>(executor_type)); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGD("Executor of engine type = %d was created successfully", static_cast<int>(engine_type)); | |||||
| auto ret = executor->Initialize(); | |||||
| if (ret != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Initialize NodeExecutor failed for type = %d", static_cast<int>(engine_type)); | |||||
| GELOGE(ret, "[Initialize][NodeExecutor] failed for type = %d", static_cast<int>(engine_type)); | |||||
| for (auto &executor_it : executors_) { | |||||
| executor_it.second->Finalize(); | |||||
| } | |||||
| executors_.clear(); | |||||
| return ret; | |||||
| } | |||||
| auto build_fn = it->second; | |||||
| GE_CHECK_NOTNULL(build_fn); | |||||
| auto executor = std::unique_ptr<NodeExecutor>(build_fn()); | |||||
| if (executor == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for executor type = %d", | |||||
| static_cast<int>(executor_type)); | |||||
| GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for engine type = %d", static_cast<int>(executor_type)); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| executors_.emplace(engine_type, std::move(executor)); | |||||
| GELOGD("Executor of engine type = %d was created successfully", static_cast<int>(executor_type)); | |||||
| auto ret = executor->Initialize(); | |||||
| if (ret != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Initialize NodeExecutor failed for type = %d", static_cast<int>(executor_type)); | |||||
| GELOGE(ret, "[Initialize][NodeExecutor] failed for type = %d", static_cast<int>(executor_type)); | |||||
| return ret; | |||||
| } | } | ||||
| ++ref_count_; | |||||
| executor_initialized_ = true; | |||||
| GELOGI("Initializing NodeExecutors successfully."); | |||||
| *out_executor = executor.get(); | |||||
| executors_.emplace(executor_type, std::move(executor)); | |||||
| GELOGI("Initializing NodeExecutor successfully, type = %d", static_cast<int>(executor_type)); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void NodeExecutorManager::FinalizeExecutors() { | void NodeExecutorManager::FinalizeExecutors() { | ||||
| std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
| if (!executor_initialized_) { | |||||
| if (ref_count_ <= 0) { | |||||
| GELOGD("No need for finalizing for not initialized."); | GELOGD("No need for finalizing for not initialized."); | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -237,7 +238,6 @@ void NodeExecutorManager::FinalizeExecutors() { | |||||
| it.second->Finalize(); | it.second->Finalize(); | ||||
| } | } | ||||
| executors_.clear(); | executors_.clear(); | ||||
| executor_initialized_ = false; | |||||
| GELOGD("Done invoking Finalize successfully."); | GELOGD("Done invoking Finalize successfully."); | ||||
| } | } | ||||
| @@ -179,8 +179,6 @@ class NodeExecutorManager { | |||||
| */ | */ | ||||
| Status EnsureInitialized(); | Status EnsureInitialized(); | ||||
| Status InitializeExecutors(); | |||||
| void FinalizeExecutors(); | void FinalizeExecutors(); | ||||
| /** | /** | ||||
| @@ -196,7 +194,7 @@ class NodeExecutorManager { | |||||
| * @param executor executor | * @param executor executor | ||||
| * @return SUCCESS on success, error code otherwise | * @return SUCCESS on success, error code otherwise | ||||
| */ | */ | ||||
| Status GetExecutor(Node &node, const NodeExecutor **executor) const; | |||||
| Status GetExecutor(Node &node, const NodeExecutor **executor); | |||||
| /** | /** | ||||
| * Resolve executor type by node | * Resolve executor type by node | ||||
| @@ -205,13 +203,16 @@ class NodeExecutorManager { | |||||
| */ | */ | ||||
| ExecutorType ResolveExecutorType(Node &node) const; | ExecutorType ResolveExecutorType(Node &node) const; | ||||
| Status GetOrCreateExecutor(ExecutorType executor_type, const NodeExecutor **executor); | |||||
| bool IsExecutorInitialized(ExecutorType executor_type); | |||||
| private: | private: | ||||
| std::map<ExecutorType, std::unique_ptr<NodeExecutor>> executors_; | std::map<ExecutorType, std::unique_ptr<NodeExecutor>> executors_; | ||||
| std::map<ExecutorType, std::function<NodeExecutor *()>> builders_; | std::map<ExecutorType, std::function<NodeExecutor *()>> builders_; | ||||
| std::map<std::string, NodeExecutorManager::ExecutorType> engine_mapping_; | std::map<std::string, NodeExecutorManager::ExecutorType> engine_mapping_; | ||||
| std::mutex mu_; | std::mutex mu_; | ||||
| bool initialized_ = false; | bool initialized_ = false; | ||||
| bool executor_initialized_ = false; | |||||
| int ref_count_ = 0; | int ref_count_ = 0; | ||||
| }; | }; | ||||
| @@ -52,9 +52,7 @@ void TaskContext::ReleaseWorkspace() { | |||||
| } | } | ||||
| } | } | ||||
| std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | |||||
| GraphExecutionContext *execution_context, | |||||
| SubgraphContext *subgraph_context) { | |||||
| std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, SubgraphContext *subgraph_context) { | |||||
| const NodeItem &node_item = *node_state->GetNodeItem(); | const NodeItem &node_item = *node_state->GetNodeItem(); | ||||
| GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.", | GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| @@ -75,7 +73,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | |||||
| } | } | ||||
| auto task_context = std::unique_ptr<TaskContext>( | auto task_context = std::unique_ptr<TaskContext>( | ||||
| new(std::nothrow)TaskContext(execution_context, node_state, subgraph_context)); | |||||
| new(std::nothrow)TaskContext(subgraph_context->execution_context_, node_state, subgraph_context)); | |||||
| if (task_context == nullptr) { | if (task_context == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "Create TaskContext failed for [%s].", node_item.NodeName().c_str()); | REPORT_CALL_ERROR("E19999", "Create TaskContext failed for [%s].", node_item.NodeName().c_str()); | ||||
| GELOGE(MEMALLOC_FAILED, "[Create][TaskContext] failed for [%s].", node_item.NodeName().c_str()); | GELOGE(MEMALLOC_FAILED, "[Create][TaskContext] failed for [%s].", node_item.NodeName().c_str()); | ||||
| @@ -85,7 +83,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | |||||
| task_context->node_item_ = &node_item; | task_context->node_item_ = &node_item; | ||||
| task_context->inputs_start_ = subgraph_context->all_inputs_.data() + node_item.input_start; | task_context->inputs_start_ = subgraph_context->all_inputs_.data() + node_item.input_start; | ||||
| task_context->outputs_start_ = subgraph_context->all_outputs_.data() + node_item.output_start; | task_context->outputs_start_ = subgraph_context->all_outputs_.data() + node_item.output_start; | ||||
| task_context->iteration_ = execution_context->iteration; | |||||
| task_context->iteration_ = subgraph_context->execution_context_->iteration; | |||||
| return task_context; | return task_context; | ||||
| } | } | ||||
| @@ -460,6 +458,10 @@ Status TaskContext::PropagateOutputs() { | |||||
| subgraph_context_->all_inputs_[input_offset].SetName( | subgraph_context_->all_inputs_[input_offset].SetName( | ||||
| node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); | node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); | ||||
| } | } | ||||
| auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); | |||||
| GE_CHECK_NOTNULL(dst_node_state); | |||||
| dst_node_state->SavePersistTensor(dst_input_idx, *tensor); | |||||
| } | } | ||||
| } | } | ||||
| (void)guard; | (void)guard; | ||||
| @@ -489,11 +491,6 @@ void TaskContext::ReleaseInputsAndOutputs() { | |||||
| } | } | ||||
| void TaskContext::ReleaseInput(int index) { | void TaskContext::ReleaseInput(int index) { | ||||
| if (node_item_->enter_inside_.count(index) > 0) { | |||||
| GELOGD("[%s] Tensor of input[%d] is enter, keep it", GetNodeName(), index); | |||||
| return; | |||||
| } | |||||
| auto input_tensor = MutableInput(index); | auto input_tensor = MutableInput(index); | ||||
| if (input_tensor != nullptr) { | if (input_tensor != nullptr) { | ||||
| input_tensor->Destroy(); | input_tensor->Destroy(); | ||||
| @@ -36,9 +36,7 @@ class SubgraphContext; | |||||
| class TaskContext { | class TaskContext { | ||||
| public: | public: | ||||
| static std::unique_ptr<TaskContext> Create(NodeState *node_state, | |||||
| GraphExecutionContext *execution_context, | |||||
| SubgraphContext *subgraph_context); | |||||
| static std::unique_ptr<TaskContext> Create(NodeState *node_state, SubgraphContext *subgraph_context); | |||||
| ~TaskContext(); | ~TaskContext(); | ||||
| @@ -263,6 +263,7 @@ class Impl { | |||||
| omg_context_.user_attr_index_valid = false; | omg_context_.user_attr_index_valid = false; | ||||
| }; | }; | ||||
| ~Impl() { (void)generator_.Finalize(); }; | ~Impl() { (void)generator_.Finalize(); }; | ||||
| graphStatus CheckBuildModeAndBuildStep(); | |||||
| graphStatus GetSupportedOptions(const std::map<std::string, std::string> &in, | graphStatus GetSupportedOptions(const std::map<std::string, std::string> &in, | ||||
| std::map<std::string, std::string> &out); | std::map<std::string, std::string> &out); | ||||
| graphStatus CheckOptions(const std::map<std::string, std::string> &options); | graphStatus CheckOptions(const std::map<std::string, std::string> &options); | ||||
| @@ -451,6 +452,37 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| graphStatus Impl::CheckBuildModeAndBuildStep() { | |||||
| std::string build_mode; | |||||
| auto it = options_.find(BUILD_MODE); | |||||
| if (it != options_.end() && !(it->second.empty())) { | |||||
| if (build_mode_options.find(it->second) == build_mode_options.end()) { | |||||
| REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
| std::vector<std::string>({BUILD_MODE, it->second, "value is unsupported. Please check!"})); | |||||
| GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildMode]:%s is unsupported. Please check!", it->second.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| build_mode = it->second; | |||||
| } | |||||
| it = options_.find(BUILD_STEP); | |||||
| if (it != options_.end() && !(it->second.empty())) { | |||||
| if (build_step_options.find(it->second) == build_step_options.end()) { | |||||
| REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
| std::vector<std::string>({BUILD_STEP, it->second, "value is unsupported. Please check!"})); | |||||
| GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildStep]:%s is unsupported. Please check!", it->second.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| } else { | |||||
| if (build_mode == BUILD_MODE_TUNING) { | |||||
| REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
| std::vector<std::string>({BUILD_MODE, it->second, "tuning must specify build step. Please check!"})); | |||||
| GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildMode] tuning must specify build step. Please check!"); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus Impl::GetSupportedOptions(const std::map<std::string, std::string> &in, | graphStatus Impl::GetSupportedOptions(const std::map<std::string, std::string> &in, | ||||
| std::map<std::string, std::string> &out) { | std::map<std::string, std::string> &out) { | ||||
| for (auto &ele : in) { | for (auto &ele : in) { | ||||
| @@ -475,29 +507,12 @@ graphStatus Impl::CheckOptions(const std::map<std::string, std::string> &options | |||||
| } | } | ||||
| // Check options build_mode and build_step. | // Check options build_mode and build_step. | ||||
| std::string build_mode; | |||||
| auto it = options_.find(BUILD_MODE); | |||||
| if (it != options_.end() && !(it->second.empty())) { | |||||
| if (build_mode_options.find(it->second) == build_mode_options.end()) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildMode]:%s is unsupported. Please check!", it->second.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| build_mode = it->second; | |||||
| } | |||||
| it = options_.find(BUILD_STEP); | |||||
| if (it != options_.end() && !(it->second.empty())) { | |||||
| if (build_step_options.find(it->second) == build_step_options.end()) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildStep]:%s is unsupported. Please check!", it->second.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| } else { | |||||
| if (build_mode == BUILD_MODE_TUNING) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildMode] tuning must specify build step. Please check!"); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| ret = CheckBuildModeAndBuildStep(); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| return ret; | |||||
| } | } | ||||
| // Check option EXEC_DISABLE_REUSED_MEMORY | // Check option EXEC_DISABLE_REUSED_MEMORY | ||||
| it = options_.find(ge::ir_option::EXEC_DISABLE_REUSED_MEMORY); | |||||
| auto it = options_.find(ge::ir_option::EXEC_DISABLE_REUSED_MEMORY); | |||||
| if (it != options_.end() && (CheckDisableReuseMemoryParamValid(it->second) != GRAPH_SUCCESS)) { | if (it != options_.end() && (CheckDisableReuseMemoryParamValid(it->second) != GRAPH_SUCCESS)) { | ||||
| return GRAPH_PARAM_INVALID; | return GRAPH_PARAM_INVALID; | ||||
| } | } | ||||
| @@ -505,6 +520,18 @@ graphStatus Impl::CheckOptions(const std::map<std::string, std::string> &options | |||||
| if (ge::CheckModifyMixlistParamValid(options_) != GRAPH_SUCCESS) { | if (ge::CheckModifyMixlistParamValid(options_) != GRAPH_SUCCESS) { | ||||
| return GRAPH_PARAM_INVALID; | return GRAPH_PARAM_INVALID; | ||||
| } | } | ||||
| // Check option OP_PRECISION_MODE | |||||
| it = options_.find(ge::ir_option::OP_PRECISION_MODE); | |||||
| if (it != options_.end() && !it->second.empty() && !ge::CheckInputPathValid(it->second)) { | |||||
| REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
| std::vector<std::string>({ge::ir_option::OP_PRECISION_MODE, it->second, "path is not found"})); | |||||
| GELOGE(GRAPH_PARAM_INVALID, "[Check][OP_PRECISION_MODE] %s not found", it->second.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (it != options_.end()) { | |||||
| GELOGI("Option set successfully, option_key=%s, option_value=%s", | |||||
| ge::ir_option::OP_PRECISION_MODE, it->second.c_str()); | |||||
| } | |||||
| // Check Input Format | // Check Input Format | ||||
| if (options_.find(kInputFormat) != options_.end()) { | if (options_.find(kInputFormat) != options_.end()) { | ||||
| return CheckInputFormat(options_[kInputFormat]); | return CheckInputFormat(options_[kInputFormat]); | ||||
| @@ -559,8 +586,8 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri | |||||
| std::string output_type = GetParam(ge::ir_option::OUTPUT_TYPE); | std::string output_type = GetParam(ge::ir_option::OUTPUT_TYPE); | ||||
| GE_CHK_BOOL_EXEC(ge::CheckOutputTypeParamValid(output_type) == ge::SUCCESS, | GE_CHK_BOOL_EXEC(ge::CheckOutputTypeParamValid(output_type) == ge::SUCCESS, | ||||
| return ge::GRAPH_PARAM_INVALID, "[Check][OutputType] failed!"); | return ge::GRAPH_PARAM_INVALID, "[Check][OutputType] failed!"); | ||||
| // check insert_op_conf | |||||
| // check insert_op_conf | |||||
| std::string insert_op_conf = GetParam(ge::ir_option::INSERT_OP_FILE); | std::string insert_op_conf = GetParam(ge::ir_option::INSERT_OP_FILE); | ||||
| GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(insert_op_conf)) == ge::SUCCESS, | GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(insert_op_conf)) == ge::SUCCESS, | ||||
| return ge::GRAPH_PARAM_INVALID, "[Check][InsertOpConf] failed!"); | return ge::GRAPH_PARAM_INVALID, "[Check][InsertOpConf] failed!"); | ||||
| @@ -204,7 +204,7 @@ bool CheckDynamicImagesizeInputShapeValid(map<string, vector<int64_t>> shape_map | |||||
| if (!input_format.empty() && !ge::TypeUtils::IsFormatValid(input_format.c_str())) { | if (!input_format.empty() && !ge::TypeUtils::IsFormatValid(input_format.c_str())) { | ||||
| GELOGE(ge::PARAM_INVALID, | GELOGE(ge::PARAM_INVALID, | ||||
| "[Check][DynamicImagesizeInputShape] input_format [%s] invalid, can not support now.", input_format.c_str()); | "[Check][DynamicImagesizeInputShape] input_format [%s] invalid, can not support now.", input_format.c_str()); | ||||
| REPORT_INPUT_ERROR("E10003", std::vector<std::string>({"parameter","value","reason"}), | |||||
| REPORT_INPUT_ERROR("E10003", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
| std::vector<std::string>({"input_format", input_format, "this format is not support"})); | std::vector<std::string>({"input_format", input_format, "this format is not support"})); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -106,10 +106,14 @@ DEFINE_string(out_nodes, "", | |||||
| "Optional; output nodes designated by users." | "Optional; output nodes designated by users." | ||||
| "Format: \"node_name1:0;node_name1:1;node_name2:0\""); | "Format: \"node_name1:0;node_name1:1;node_name2:0\""); | ||||
| DEFINE_string(op_precision_mode, "", "Optional; operator precision mode configuration file path"); | |||||
| DEFINE_string(precision_mode, "force_fp16", | DEFINE_string(precision_mode, "force_fp16", | ||||
| "Optional; precision mode." | "Optional; precision mode." | ||||
| "Support force_fp16, force_fp32, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); | "Support force_fp16, force_fp32, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); | ||||
| DEFINE_string(modify_mixlist, "", "Optional; operator mixed precision configuration file path"); | |||||
| DEFINE_string(keep_dtype, "", | DEFINE_string(keep_dtype, "", | ||||
| "Optional; config file to specify the precision used by the operator during compilation."); | "Optional; config file to specify the precision used by the operator during compilation."); | ||||
| @@ -192,8 +196,11 @@ DEFINE_string(log, "null", "Optional; generate atc log. Support debug, info, war | |||||
| DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0."); | DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0."); | ||||
| DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 0(default): close debug;" | |||||
| "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler"); | |||||
| DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 0(default): close debug; " | |||||
| "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler; " | |||||
| "3: disable debug, and keep generating kernel file (.o and .json); 4: disable debug, " | |||||
| "keep generation kernel file (.o and .json) and generate the operator CCE file (.cce) " | |||||
| "and the UB fusion computing description file (.json)"); | |||||
| DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass," | DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass," | ||||
| "multiple names can be set and separated by ','."); | "multiple names can be set and separated by ','."); | ||||
| DEFINE_string(debug_dir, "", "Optional; the path to save the intermediate files of operator compilation"); | DEFINE_string(debug_dir, "", "Optional; the path to save the intermediate files of operator compilation"); | ||||
| @@ -210,8 +217,6 @@ DEFINE_string(display_model_info, "0", "Optional; display model info"); | |||||
| DEFINE_string(device_id, "0", "Optional; user device id"); | DEFINE_string(device_id, "0", "Optional; user device id"); | ||||
| DEFINE_string(modify_mixlist, "", "Optional; operator mixed precision configuration file path"); | |||||
| class GFlagUtils { | class GFlagUtils { | ||||
| public: | public: | ||||
| /** | /** | ||||
| @@ -298,8 +303,10 @@ class GFlagUtils { | |||||
| "\"l1_optimize\", \"off_optimize\"\n" | "\"l1_optimize\", \"off_optimize\"\n" | ||||
| " --mdl_bank_path Set the path of the custom repository generated after model tuning.\n" | " --mdl_bank_path Set the path of the custom repository generated after model tuning.\n" | ||||
| "\n[Operator Tuning]\n" | "\n[Operator Tuning]\n" | ||||
| " --op_precision_mode Set the path of operator precision mode configuration file (.ini)\n" | |||||
| " --precision_mode precision mode, support force_fp16(default), force_fp32, allow_mix_precision, " | " --precision_mode precision mode, support force_fp16(default), force_fp32, allow_mix_precision, " | ||||
| "allow_fp32_to_fp16, must_keep_origin_dtype.\n" | "allow_fp32_to_fp16, must_keep_origin_dtype.\n" | ||||
| " --modify_mixlist Set the path of operator mixed precision configuration file.\n" | |||||
| " --keep_dtype Retains the precision of certain operators in inference " | " --keep_dtype Retains the precision of certain operators in inference " | ||||
| "scenarios by using a configuration file.\n" | "scenarios by using a configuration file.\n" | ||||
| " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n" | " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n" | ||||
| @@ -315,7 +322,8 @@ class GFlagUtils { | |||||
| " 2: Enable TBE pipe_all, generate the operator CCE file and Python-CCE mapping file " | " 2: Enable TBE pipe_all, generate the operator CCE file and Python-CCE mapping file " | ||||
| "(.json), and enable the CCE compiler -O0-g.\n" | "(.json), and enable the CCE compiler -O0-g.\n" | ||||
| " 3: Disable debug, and keep generating kernel file (.o and .json)\n" | " 3: Disable debug, and keep generating kernel file (.o and .json)\n" | ||||
| " --modify_mixlist Set the path of operator mixed precision configuration file.\n" | |||||
| " 4: Disable debug, keep generation kernel file (.o and .json) and generate the " | |||||
| "operator CCE file (.cce) and the UB fusion computing description file (.json)" | |||||
| "\n[Debug]\n" | "\n[Debug]\n" | ||||
| " --save_original_model Control whether to output original model. E.g.: true: output original model\n" | " --save_original_model Control whether to output original model. E.g.: true: output original model\n" | ||||
| " --log Generate log with level. Support debug, info, warning, error, null\n" | " --log Generate log with level. Support debug, info, warning, error, null\n" | ||||
| @@ -365,6 +373,14 @@ class GFlagUtils { | |||||
| FLAGS_op_select_implmode) != ge::SUCCESS, | FLAGS_op_select_implmode) != ge::SUCCESS, | ||||
| ret = ge::FAILED, "[Check][ImplMode]check optypelist_for_implmode and op_select_implmode failed!"); | ret = ge::FAILED, "[Check][ImplMode]check optypelist_for_implmode and op_select_implmode failed!"); | ||||
| if (!FLAGS_op_precision_mode.empty() && !ge::CheckInputPathValid(FLAGS_op_precision_mode, "--op_precision_mode")) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||||
| {"op_precision_mode", FLAGS_op_precision_mode.c_str(), | |||||
| "path is not found"}); | |||||
| GELOGE(ge::FAILED, "[Check][op_precision_mode] %s not found", FLAGS_op_precision_mode.c_str()); | |||||
| ret = ge::FAILED; | |||||
| } | |||||
| if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) { | if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | ||||
| {"modify_mixlist", FLAGS_modify_mixlist.c_str(), | {"modify_mixlist", FLAGS_modify_mixlist.c_str(), | ||||
| @@ -847,6 +863,7 @@ domi::Status GenerateInfershapeJson() { | |||||
| ge::Graph graph; | ge::Graph graph; | ||||
| std::map<string, string> atc_params; | std::map<string, string> atc_params; | ||||
| atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format)); | atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format)); | ||||
| atc_params.insert(std::pair<string, string>("check_report", FLAGS_check_report)); | |||||
| ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework, | ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework, | ||||
| "", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false); | "", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false); | ||||
| if (ret != ge::SUCCESS) { | if (ret != ge::SUCCESS) { | ||||
| @@ -953,8 +970,7 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output | |||||
| ge::Model load_model = ge::Model("loadmodel", "version2"); | ge::Model load_model = ge::Model("loadmodel", "version2"); | ||||
| auto ret1 = load_model.LoadFromFile(FLAGS_model); | auto ret1 = load_model.LoadFromFile(FLAGS_model); | ||||
| if (ret1 != ge::GRAPH_SUCCESS) { | if (ret1 != ge::GRAPH_SUCCESS) { | ||||
| REPORT_INPUT_ERROR("E10041", std::vector<std::string>({"file"}), std::vector<std::string>({FLAGS_model})); | |||||
| REPORT_CALL_ERROR("E19999", "load from model file:%s failed", FLAGS_model.c_str()); | |||||
| REPORT_INPUT_ERROR("E10041", std::vector<std::string>({"parameter"}), std::vector<std::string>({FLAGS_model})); | |||||
| DOMI_LOGE("Load model from %s failed, please check model file or " | DOMI_LOGE("Load model from %s failed, please check model file or " | ||||
| "input parameter[--framework] is correct", FLAGS_model.c_str()); | "input parameter[--framework] is correct", FLAGS_model.c_str()); | ||||
| (void)ge_generator.Finalize(); | (void)ge_generator.Finalize(); | ||||
| @@ -1050,6 +1066,7 @@ static void SetEnvForSingleOp(std::map<string, string> &options) { | |||||
| options.emplace(ge::RUN_FLAG, flag_off); | options.emplace(ge::RUN_FLAG, flag_off); | ||||
| options.emplace(ge::OPTION_GRAPH_RUN_MODE, flag_off); | options.emplace(ge::OPTION_GRAPH_RUN_MODE, flag_off); | ||||
| options.emplace(ge::SINGLE_OP_FLAG, flag_on); | options.emplace(ge::SINGLE_OP_FLAG, flag_on); | ||||
| options.emplace(ge::OP_PRECISION_MODE, FLAGS_op_precision_mode); | |||||
| options.emplace(ge::PRECISION_MODE, FLAGS_precision_mode); | options.emplace(ge::PRECISION_MODE, FLAGS_precision_mode); | ||||
| options.emplace(ge::SOC_VERSION, FLAGS_soc_version); | options.emplace(ge::SOC_VERSION, FLAGS_soc_version); | ||||
| options.emplace(ge::CORE_TYPE, FLAGS_core_type); | options.emplace(ge::CORE_TYPE, FLAGS_core_type); | ||||
| @@ -1077,6 +1094,14 @@ domi::Status GenerateSingleOp(const std::string& json_file_path) { | |||||
| ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS, | ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS, | ||||
| return ge::FAILED, "[Check][ImplmodeParam] fail for input optypelist_for_implmode and op_select_implmode."); | return ge::FAILED, "[Check][ImplmodeParam] fail for input optypelist_for_implmode and op_select_implmode."); | ||||
| if (!FLAGS_op_precision_mode.empty() && !ge::CheckInputPathValid(FLAGS_op_precision_mode, "--op_precision_mode")) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||||
| {"op_precision_mode", FLAGS_op_precision_mode.c_str(), | |||||
| "path is not found"}); | |||||
| GELOGE(ge::FAILED, "[Check][op_precision_mode] %s not found", FLAGS_op_precision_mode.c_str()); | |||||
| return ge::FAILED; | |||||
| } | |||||
| if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) { | if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | ||||
| {"modify_mixlist", FLAGS_modify_mixlist.c_str(), | {"modify_mixlist", FLAGS_modify_mixlist.c_str(), | ||||
| @@ -1160,6 +1185,7 @@ domi::Status GenerateOmModel() { | |||||
| options.insert(std::pair<string, string>(string(ge::CALIBRATION_CONF_FILE), FLAGS_cal_conf)); | options.insert(std::pair<string, string>(string(ge::CALIBRATION_CONF_FILE), FLAGS_cal_conf)); | ||||
| options.insert(std::pair<string, string>(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes)); | options.insert(std::pair<string, string>(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes)); | ||||
| options.insert(std::pair<string, string>(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf)); | options.insert(std::pair<string, string>(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf)); | ||||
| options.insert(std::pair<string, string>(string(ge::OP_PRECISION_MODE), FLAGS_op_precision_mode)); | |||||
| options.insert(std::pair<string, string>(string(ge::PRECISION_MODE), FLAGS_precision_mode)); | options.insert(std::pair<string, string>(string(ge::PRECISION_MODE), FLAGS_precision_mode)); | ||||
| options.insert(std::pair<string, string>(string(ge::TUNE_DEVICE_IDS), FLAGS_device_id)); | options.insert(std::pair<string, string>(string(ge::TUNE_DEVICE_IDS), FLAGS_device_id)); | ||||
| @@ -1,193 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package ge.proto; | |||||
| enum DataType | |||||
| { | |||||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||||
| DT_FLOAT = 1; // float type | |||||
| DT_FLOAT16 = 2; // fp16 type | |||||
| DT_INT8 = 3; // int8 type | |||||
| DT_UINT8 = 4; // uint8 type | |||||
| DT_INT16 = 5; // int16 type | |||||
| DT_UINT16 = 6; // uint16 type | |||||
| DT_INT32 = 7; // | |||||
| DT_INT64 = 8; // int64 type | |||||
| DT_UINT32 = 9; // unsigned int32 | |||||
| DT_UINT64 = 10; // unsigned int64 | |||||
| DT_BOOL = 11; // bool type | |||||
| DT_DOUBLE = 12; // double type | |||||
| DT_STRING = 13; // string type | |||||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||||
| DT_COMPLEX64 = 16; // complex64 type | |||||
| DT_COMPLEX128 = 17; // complex128 type | |||||
| DT_QINT8 = 18; // qint8 type | |||||
| DT_QINT16 = 19; // qint16 type | |||||
| DT_QINT32 = 20; // qint32 type | |||||
| DT_QUINT8 = 21; // quint8 type | |||||
| DT_QUINT16 = 22; // quint16 type | |||||
| DT_RESOURCE = 23; // resource type | |||||
| DT_STRING_REF = 24; // string_ref type | |||||
| DT_DUAL = 25; /**< dual output type */ | |||||
| DT_VARIANT = 26; // variant type | |||||
| DT_BF16 = 27; // bf16 type | |||||
| DT_INT4 = 28; // int4 type | |||||
| } | |||||
| message AttrDef | |||||
| { | |||||
| message ListValue | |||||
| { | |||||
| enum ListValueType{ | |||||
| VT_LIST_NONE = 0; | |||||
| VT_LIST_STRING = 1; | |||||
| VT_LIST_INT = 2; | |||||
| VT_LIST_FLOAT = 3; | |||||
| VT_LIST_BOOL = 4; | |||||
| VT_LIST_BYTES = 5; | |||||
| VT_LIST_TENSOR_DESC = 6; | |||||
| VT_LIST_TENSOR = 7; | |||||
| VT_LIST_GRAPH = 8; | |||||
| VT_LIST_NAMED_ATTRS = 9; | |||||
| VT_LIST_DATA_TYPE = 10; | |||||
| } | |||||
| repeated bytes s = 2; // "list(string)" | |||||
| repeated int64 i = 3; // "list(int)" | |||||
| repeated float f = 4; // "list(float)" | |||||
| repeated bool b = 5; // "list(bool)" | |||||
| repeated bytes bt = 7; | |||||
| repeated TensorDescriptor td = 8; | |||||
| repeated TensorDef t = 9; | |||||
| repeated GraphDef g = 10; | |||||
| repeated NamedAttrs na = 11; | |||||
| repeated int64 dt = 12; // list ge::DataType | |||||
| ListValueType val_type = 20; | |||||
| } | |||||
| message ListListInt{ | |||||
| message ListInt{ | |||||
| repeated int64 list_i = 1; // list int | |||||
| } | |||||
| repeated ListInt list_list_i = 1; // list list int | |||||
| } | |||||
| oneof value | |||||
| { | |||||
| bytes s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| bytes bt = 7; | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; // Used to support attr nesting | |||||
| TensorDescriptor td = 11; // GeTensorDesc type | |||||
| TensorDef t = 12; // GeTensor type | |||||
| GraphDef g = 13; // Graph type | |||||
| ListListInt list_list_int = 14; // List List Int type | |||||
| int64 dt = 15; // ge::DataType | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs | |||||
| { | |||||
| string name = 1; | |||||
| map<string, AttrDef> attr = 2; | |||||
| } | |||||
| // Shape / dimension description, using row-major order | |||||
| message ShapeDef | |||||
| { | |||||
| repeated int64 dim = 1; // Size of each dimension | |||||
| } | |||||
| // Multidimensional data description | |||||
| message TensorDescriptor | |||||
| { | |||||
| string name = 1; // Optional parameter, tensor name | |||||
| DataType dtype = 2; // tensor datatype | |||||
| ShapeDef shape = 3; // Shape / dimension | |||||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||||
| bool has_out_attr = 9; | |||||
| int64 size = 10; | |||||
| int64 weight_size = 11; | |||||
| bool reuse_input = 12; | |||||
| bool output_tensor = 13; | |||||
| string device_type = 14; | |||||
| bool input_tensor =15; | |||||
| int64 real_dim_cnt = 16; | |||||
| int64 reuse_input_index = 17; | |||||
| int64 data_offset = 18; | |||||
| int64 cmps_size = 19; | |||||
| string cmps_tab = 20; | |||||
| int64 cmps_tab_offset = 21; | |||||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||||
| } | |||||
| // GeTensor definition | |||||
| message TensorDef | |||||
| { | |||||
| TensorDescriptor desc = 1; // Tensor description | |||||
| bytes data = 2; // Tensor data | |||||
| } | |||||
| // Operator description | |||||
| message OpDef | |||||
| { | |||||
| string name = 1; // name | |||||
| string type = 2; // type | |||||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||||
| bool has_out_attr = 20; | |||||
| int64 id = 21; | |||||
| int64 stream_id =22; | |||||
| repeated string input_name = 23; | |||||
| repeated string src_name = 24; | |||||
| repeated int64 src_index = 25; | |||||
| repeated string dst_name = 26; | |||||
| repeated int64 dst_index = 27; | |||||
| repeated int64 input_i = 28; | |||||
| repeated int64 output_i = 29; | |||||
| repeated int64 workspace = 30; | |||||
| repeated int64 workspace_bytes = 31; | |||||
| repeated bool is_input_const = 32; | |||||
| repeated TensorDescriptor input_desc = 33; | |||||
| repeated TensorDescriptor output_desc = 34; | |||||
| repeated string subgraph_name = 35; | |||||
| } | |||||
| // Graph definition | |||||
| message GraphDef | |||||
| { | |||||
| string name = 1; // name | |||||
| repeated string input = 4; // Graph input | |||||
| repeated string output = 5; // Graph output | |||||
| repeated OpDef op = 6; // List of operators | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| // model definition | |||||
| message ModelDef | |||||
| { | |||||
| string name = 1; // name | |||||
| uint32 version = 2; // IR Proto verion | |||||
| string custom_version = 3; // User model version number, passed in by user | |||||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| @@ -1,140 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| message InsertNewOps { | |||||
| repeated AippOpParams aipp_op = 1; | |||||
| repeated MultiShapeOpParams multi_shape_op = 2; | |||||
| } | |||||
| message AippOpParams { | |||||
| enum InputFormat { | |||||
| UNDEFINED = 0; | |||||
| YUV420SP_U8 = 1; | |||||
| XRGB8888_U8 = 2; | |||||
| RGB888_U8 = 3; | |||||
| YUV400_U8 = 4; | |||||
| NC1HWC0DI_FP16 = 5; | |||||
| NC1HWC0DI_S8 = 6; | |||||
| ARGB8888_U8 = 7; | |||||
| YUYV_U8 = 8; | |||||
| YUV422SP_U8 = 9; | |||||
| AYUV444_U8 = 10; | |||||
| RAW10 = 11; | |||||
| RAW12 = 12; | |||||
| RAW16 = 13; | |||||
| RAW24 = 14; | |||||
| RGB16 = 15; | |||||
| RGB20 = 16; | |||||
| RGB24 = 17; | |||||
| RGB8_IR = 18; | |||||
| RGB16_IR = 19; | |||||
| RGB24_IR = 20; | |||||
| } | |||||
| enum AippMode { | |||||
| undefined = 0; | |||||
| static = 1; | |||||
| dynamic = 2; | |||||
| } | |||||
| // AIPP模式,区分静态AIPP和动态AIPP | |||||
| AippMode aipp_mode = 1; | |||||
| // related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。 | |||||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | |||||
| uint32 related_input_rank = 2; | |||||
| // related_input_name is optional and the top name of data node which inserts aipp | |||||
| string related_input_name = 6; | |||||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | |||||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | |||||
| // 配置值 <= Data算子输出边的个数。 | |||||
| repeated uint32 input_edge_idx = 3; | |||||
| // [Begin] 动态AIPP参数,配置静态AIPP时无效 | |||||
| uint32 max_src_image_size = 4; | |||||
| // 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失 | |||||
| bool support_rotation = 5; | |||||
| // [End] 动态AIPP参数 | |||||
| // [Begin] 静态AIPP参数,配置动态AIPP时无效 | |||||
| InputFormat input_format = 51; | |||||
| bool csc_switch = 52; | |||||
| float cpadding_value = 53; | |||||
| bool rbuv_swap_switch = 54; | |||||
| bool ax_swap_switch = 55; | |||||
| bool single_line_mode = 56; | |||||
| int32 src_image_size_w = 57; | |||||
| int32 src_image_size_h = 58; | |||||
| bool crop = 59; | |||||
| int32 load_start_pos_w = 60; | |||||
| int32 load_start_pos_h = 61; | |||||
| int32 crop_size_w = 62; | |||||
| int32 crop_size_h = 63; | |||||
| bool resize = 64; | |||||
| int32 resize_output_w = 65; | |||||
| int32 resize_output_h = 66; | |||||
| bool padding = 67; | |||||
| int32 left_padding_size = 68; | |||||
| int32 right_padding_size = 69; | |||||
| int32 top_padding_size = 70; | |||||
| int32 bottom_padding_size = 71; | |||||
| float padding_value = 72; | |||||
| int32 mean_chn_0 = 10; | |||||
| int32 mean_chn_1 = 11; | |||||
| int32 mean_chn_2 = 12; | |||||
| int32 mean_chn_3 = 19; | |||||
| float min_chn_0 = 13; | |||||
| float min_chn_1 = 14; | |||||
| float min_chn_2 = 15; | |||||
| float min_chn_3 = 20; | |||||
| repeated float var_reci_chn_0 = 16; | |||||
| repeated float var_reci_chn_1 = 17; | |||||
| repeated float var_reci_chn_2 = 18; | |||||
| repeated float var_reci_chn_3 = 21; | |||||
| repeated int32 matrix_r0c0 = 30; | |||||
| repeated int32 matrix_r0c1 = 31; | |||||
| repeated int32 matrix_r0c2 = 32; | |||||
| repeated int32 matrix_r1c0 = 33; | |||||
| repeated int32 matrix_r1c1 = 34; | |||||
| repeated int32 matrix_r1c2 = 35; | |||||
| repeated int32 matrix_r2c0 = 36; | |||||
| repeated int32 matrix_r2c1 = 37; | |||||
| repeated int32 matrix_r2c2 = 38; | |||||
| repeated int32 output_bias_0 = 39; | |||||
| repeated int32 output_bias_1 = 40; | |||||
| repeated int32 output_bias_2 = 41; | |||||
| repeated int32 input_bias_0 = 42; | |||||
| repeated int32 input_bias_1 = 43; | |||||
| repeated int32 input_bias_2 = 44; | |||||
| // [End] 静态AIPP参数 | |||||
| // The n number that is used for raw/rgbir data into f16 transformation. | |||||
| // The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||||
| uint32 raw_rgbir_to_f16_n = 45; | |||||
| } | |||||
| message MultiShapeOpParams { | |||||
| enum MultiShapeMode { | |||||
| batch = 0; //动态batch | |||||
| resolution = 1; //动态分辨率,扩展用 | |||||
| } | |||||
| MultiShapeMode mode = 1; //算子模式 | |||||
| uint32 related_input_rank = 2; //新增算子插入到哪个输入 | |||||
| repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间 | |||||
| } | |||||
| @@ -1,396 +0,0 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| enum TargetType | |||||
| { | |||||
| MINI = 0; | |||||
| TINY = 1; | |||||
| LITE = 2; | |||||
| } | |||||
| // offline model | |||||
| message ModelDef { | |||||
| string name = 1; | |||||
| uint32 version = 2; | |||||
| uint64 memory_size = 10; | |||||
| uint32 stream_num = 11; | |||||
| uint32 event_num = 12; | |||||
| uint64 weight_size = 13; | |||||
| uint32 label_num = 15; | |||||
| repeated OpDef op = 20; | |||||
| TargetType target_type = 23; | |||||
| map<string, AttrDef> attr = 30; | |||||
| }; | |||||
| // operator define | |||||
| message OpDef { | |||||
| string name = 1; | |||||
| string type = 2; | |||||
| uint32 id = 3; | |||||
| uint32 stream_id = 4; | |||||
| repeated string input_name = 5; | |||||
| repeated string src_name = 8; | |||||
| repeated int32 src_index = 9; | |||||
| repeated int64 input = 10; | |||||
| repeated int64 output = 11; | |||||
| repeated TensorDescriptor input_desc = 12; | |||||
| repeated TensorDescriptor output_desc = 13; | |||||
| repeated WeightDef weights = 14; | |||||
| repeated string dst_name = 15; | |||||
| repeated int32 dst_index = 16; | |||||
| repeated int64 workspace = 20; | |||||
| repeated uint32 workspace_bytes = 21; | |||||
| repeated string weight_name = 22; | |||||
| repeated bool is_input_const = 23; | |||||
| map<string, AttrDef> attr = 30; | |||||
| QuantizeFactorParams quantize_factor = 31; | |||||
| oneof op_params { | |||||
| // start at 100 here | |||||
| SendOpParams sender_param = 100; | |||||
| RecvOpParams receiver_param = 200; | |||||
| ConvolutionOpParams convolution_param = 300; | |||||
| PoolingOpParams pooling_param = 400; | |||||
| EltwiseOpParams eltwise_param = 500; | |||||
| BatchNormOpParams batchnorm_param = 600; | |||||
| ScaleOpParams scale_param = 700; | |||||
| FullConnectionOpParams full_connection_param = 800; | |||||
| SoftmaxOpParams softmax_param = 900; | |||||
| ActivationOpParams activation_param = 1000; | |||||
| ReshapeOpParams reshape_param = 1100; | |||||
| } | |||||
| }; | |||||
| message SendOpParams { | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| message RecvOpParams { | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| enum QuantizeScaleType | |||||
| { | |||||
| VECTOR_SCALE = 0; | |||||
| SCALAR_SCALE = 1; | |||||
| } | |||||
| enum QuantizeScaleMode | |||||
| { | |||||
| NORMAL_MODE = 0; | |||||
| SQRT_MODE = 1; | |||||
| } | |||||
| enum QuantizeAlgorithm | |||||
| { | |||||
| NON_OFFSET_ALGO = 0; | |||||
| HALF_OFFSET_ALGO = 1; | |||||
| ALL_OFFSET_ALGO = 2; | |||||
| } | |||||
| message QuantizeFactor | |||||
| { | |||||
| QuantizeScaleMode scale_mode = 1; | |||||
| bytes scale_value = 2; | |||||
| int64 scale_offset = 3; | |||||
| bytes offset_data_value = 4; | |||||
| int64 offset_data_offset = 5; | |||||
| bytes offset_weight_value = 6; | |||||
| int64 offset_weight_offset = 7; | |||||
| bytes offset_pad_value = 8; | |||||
| int64 offset_pad_offset = 9; | |||||
| }; | |||||
| message QuantizeCalcFactor | |||||
| { | |||||
| bytes offsetw = 1; | |||||
| int64 offsetw_offset = 2; | |||||
| bytes offsetd = 3; | |||||
| int64 offsetd_offset = 4; | |||||
| bytes scalereq = 5; | |||||
| int64 scaledreq_offset = 6; | |||||
| bytes offsetdnext = 7; | |||||
| int64 offsetdnext_offset = 8; | |||||
| } | |||||
| message QuantizeFactorParams | |||||
| { | |||||
| QuantizeAlgorithm quantize_algo = 1; | |||||
| QuantizeScaleType scale_type = 2; | |||||
| QuantizeFactor quantize_param = 3; | |||||
| QuantizeFactor dequantize_param = 4; | |||||
| QuantizeFactor requantize_param = 5; | |||||
| QuantizeCalcFactor quantizecalc_param = 6; | |||||
| }; | |||||
| message ConvolutionOpParams { | |||||
| int32 mode = 1; | |||||
| int32 algo = 2; | |||||
| int32 pad_mode = 3; | |||||
| uint32 group = 4; | |||||
| uint32 num_output = 5; | |||||
| repeated uint32 pad = 10; | |||||
| repeated uint32 stride = 11; | |||||
| repeated uint32 dilation = 12; | |||||
| repeated uint32 kernel = 13; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| WeightDef filter = 40; | |||||
| WeightDef bias = 41; | |||||
| bool relu_flag = 62; | |||||
| repeated uint32 adj = 70; | |||||
| repeated uint32 target_shape = 71; | |||||
| repeated uint32 before_pad = 72; | |||||
| }; | |||||
| message PoolingOpParams { | |||||
| int32 mode = 1; | |||||
| int32 nan_opt = 2; | |||||
| int32 pad_mode = 3; | |||||
| bool global_pooling = 4; | |||||
| repeated uint32 window = 10; | |||||
| repeated uint32 pad = 11; | |||||
| repeated uint32 stride = 12; | |||||
| bool ceil_mode = 13; | |||||
| int32 data_mode = 14; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| repeated uint32 before_pad = 22; | |||||
| }; | |||||
| message EltwiseOpParams { | |||||
| int32 mode = 1; | |||||
| repeated float coeff = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| repeated WeightDef weight = 5; | |||||
| bool relu_flag = 6; | |||||
| }; | |||||
| message ActivationOpParams { | |||||
| int32 mode = 1; | |||||
| float coef = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message BatchNormOpParams { | |||||
| int32 mode = 1; | |||||
| float alpha = 2; | |||||
| float beta = 3; | |||||
| double epsilon = 4;//optinal,[default = 1e-5] | |||||
| bool use_global_stats = 5; //optinal,by default true,testing mode | |||||
| float moving_average_fraction = 6; //optinal,[default = .999]; | |||||
| WeightDef estimated_mean = 7; | |||||
| WeightDef estimated_variance = 8; | |||||
| WeightDef scale = 9; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message ScaleOpParams { | |||||
| WeightDef scale = 1; | |||||
| WeightDef bias = 2; | |||||
| }; | |||||
| message ReshapeOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| ShapeDef shape = 3; | |||||
| int32 axis = 4; | |||||
| int32 num_axes = 5; | |||||
| int32 format = 6; | |||||
| }; | |||||
| message SoftmaxOpParams { | |||||
| int32 algo = 1; | |||||
| int32 mode = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message FullConnectionOpParams { | |||||
| WeightDef filter = 1; | |||||
| WeightDef bias = 2; | |||||
| uint32 num_output = 3; | |||||
| bool relu_flag = 12; | |||||
| }; | |||||
| message FlattenOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 start_axis = 3; | |||||
| int32 end_axis = 4; | |||||
| } | |||||
| message AddLimitedOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulLimitedOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message AddOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message SubOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message BiasAddOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message MatMulOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| bool transposeX = 3; | |||||
| bool transposeW = 4; | |||||
| WeightDef filter = 10; | |||||
| WeightDef bias = 12; | |||||
| }; | |||||
| message RsqrtOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| }; | |||||
| message WeightDef { | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| ShapeDef shape = 3; | |||||
| bytes data = 4; | |||||
| int64 data_offset = 5; | |||||
| uint32 cmps_size = 6; | |||||
| bytes cmps_tab = 7; | |||||
| int64 cmps_tab_offset = 10; | |||||
| CompressInfo cmps_info = 8; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 11; | |||||
| } | |||||
| message ShapeDef { | |||||
| repeated int64 dim = 1; | |||||
| } | |||||
| enum DeviceType { | |||||
| NPU = 0; // In default, we will use NPU. | |||||
| CPU = 1; // CPU | |||||
| } | |||||
| message AllOffsetQuantizeInfo { | |||||
| float scale = 1; | |||||
| int32 offset = 2; | |||||
| } | |||||
| message TensorDescriptor { | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| repeated int64 dim = 3; | |||||
| uint32 size = 4; | |||||
| bool reuse_input = 5; | |||||
| bool output_tensor = 7; | |||||
| DeviceType device_type = 8; | |||||
| bool input_tensor = 9; | |||||
| uint32 real_dim_cnt = 10; | |||||
| uint32 reuse_input_index = 11; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 12; | |||||
| } | |||||
| message CompressInfo { | |||||
| int32 blockRow = 1; // block row | |||||
| int32 blockCol = 2; // block col | |||||
| int32 fractalK = 3; // fractal K | |||||
| int32 fractalN = 4; // fractal N | |||||
| int32 lastFractalK = 5; // K of last fractal | |||||
| int32 lastFractalN = 6; // N of last fractal | |||||
| int32 cubeSize = 7; // cube's length | |||||
| int32 loadDir = 8; // data load directtiono 0:col load 1:row load | |||||
| } | |||||
| message AttrDef { | |||||
| message ListValue { | |||||
| repeated string s = 2; // "list(string)" | |||||
| repeated int64 i = 3 [packed = true]; // "list(int)" | |||||
| repeated float f = 4 [packed = true]; // "list(float)" | |||||
| repeated bool b = 5 [packed = true]; // "list(bool)" | |||||
| repeated uint32 u = 6 [packed = true]; // "list(uint)" | |||||
| repeated bytes bt = 7; | |||||
| } | |||||
| oneof value { | |||||
| string s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| uint32 u = 6; // "uint32" | |||||
| bytes bt = 7; | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs { | |||||
| string name = 1; | |||||
| map<string, AttrDef> attr = 2; | |||||
| } | |||||
| @@ -1,179 +0,0 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| message ModelTaskDef { | |||||
| string version = 1; | |||||
| map<string, string> attr = 9; // Extended field | |||||
| repeated TaskDef task = 10; | |||||
| uint64 memory_size = 11; | |||||
| uint32 stream_num = 12; | |||||
| uint32 event_num = 13; | |||||
| uint64 weight_size = 14; | |||||
| repeated bytes op = 15; // input/output opdef in bytes | |||||
| uint64 base_addr = 16; // base addr | |||||
| uint64 weight_addr = 17; // weight addr | |||||
| uint32 batch_num = 18; | |||||
| } | |||||
| message TaskDef { | |||||
| uint32 id = 1; | |||||
| uint32 type = 2; | |||||
| uint32 stream_id = 10; | |||||
| uint32 event_id = 11; | |||||
| KernelDef kernel = 20; | |||||
| KernelExDef kernel_ex = 21; | |||||
| KernelHcclDef kernel_hccl = 25; | |||||
| EventExDef event_ex = 26; | |||||
| LogTimeStampDef log_timestamp = 28; | |||||
| uint32 label_id = 30; | |||||
| MemcpyAsyncDef memcpy_async = 31; | |||||
| StreamSwitchDef stream_switch = 32; | |||||
| StreamActiveDef stream_active = 33; | |||||
| bytes private_def = 34; | |||||
| uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future | |||||
| StreamSwitchNDef stream_switch_n = 36; | |||||
| LabelSetDef label_set = 37; | |||||
| LabelGotoExDef label_goto_ex = 38; | |||||
| LabelSwitchByIndexDef label_switch_by_index = 39; | |||||
| KernelDefWithHandle kernel_with_handle = 40; | |||||
| } | |||||
| message KernelDef { | |||||
| KernelContext context = 1; | |||||
| string stub_func = 10; | |||||
| uint32 block_dim = 11; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes sm_desc = 14; | |||||
| bytes flowtable = 15; | |||||
| string so_name = 16; | |||||
| string kernel_name = 17; | |||||
| bytes kernel_ext_info = 18; | |||||
| uint32 kernel_ext_info_size = 19; | |||||
| } | |||||
| message KernelDefWithHandle { | |||||
| KernelContext context = 1; | |||||
| uint64 handle = 10; | |||||
| string dev_func = 11; | |||||
| uint32 block_dim = 12; | |||||
| uint32 args_size = 13; | |||||
| bytes args = 14; | |||||
| bytes sm_desc = 15; | |||||
| string original_kernel_key = 16; | |||||
| string node_info = 17; | |||||
| } | |||||
| message KernelContext { | |||||
| uint32 kernel_type = 1; | |||||
| uint32 op_id = 2; // OP type in CCE | |||||
| uint32 kernel_func_id = 3; | |||||
| uint32 op_index = 4; // TE/Custom operator | |||||
| bool is_flowtable = 5; // Identify whether args is a flowtable structure | |||||
| bytes args_offset = 6; // args offset information | |||||
| uint32 args_count = 7; // args count | |||||
| repeated uint32 origin_op_index = 8; | |||||
| } | |||||
| message KernelExDef { | |||||
| uint32 flags = 1; | |||||
| uint32 op_index = 4; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput | |||||
| uint32 task_info_size = 15; | |||||
| bytes kernel_ext_info = 16; | |||||
| uint32 kernel_ext_info_size = 17; | |||||
| } | |||||
| message KernelHcclDef { | |||||
| uint32 op_index = 8; | |||||
| string hccl_type = 9; | |||||
| } | |||||
| message EventExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 event_type = 2; | |||||
| } | |||||
| message LogTimeStampDef { | |||||
| uint64 logid = 1; | |||||
| bool notify = 2; | |||||
| uint32 flat = 3; | |||||
| } | |||||
| message MemcpyAsyncDef { | |||||
| uint64 dst = 1; | |||||
| uint64 dst_max = 2; | |||||
| uint64 src = 3; | |||||
| uint64 count = 4; | |||||
| uint32 kind = 5; | |||||
| uint32 op_index = 6; | |||||
| } | |||||
| message StreamSwitchDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 true_stream_id = 2; | |||||
| int64 value = 3; | |||||
| uint64 value_ptr = 4; | |||||
| uint32 data_type = 5; | |||||
| } | |||||
| message StreamActiveDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 active_stream_id = 2; | |||||
| } | |||||
| message StreamSwitchNDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 size = 2; | |||||
| repeated int64 target_value = 3; | |||||
| repeated uint32 true_stream_id = 4; | |||||
| uint32 element_size = 5; | |||||
| uint32 data_type = 6; | |||||
| } | |||||
| message LabelSetDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelGotoExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelSwitchByIndexDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_max = 2; | |||||
| } | |||||
| @@ -1,113 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package toolkit.dump; | |||||
| enum OutputDataType { | |||||
| DT_UNDEFINED = 0; | |||||
| DT_FLOAT = 1; | |||||
| DT_FLOAT16 = 2; | |||||
| DT_INT8 = 3; | |||||
| DT_UINT8 = 4; | |||||
| DT_INT16 = 5; | |||||
| DT_UINT16 = 6; | |||||
| DT_INT32 = 7; | |||||
| DT_INT64 = 8; | |||||
| DT_UINT32 = 9; | |||||
| DT_UINT64 = 10; | |||||
| DT_BOOL = 11; | |||||
| DT_DOUBLE = 12; | |||||
| DT_STRING = 13; | |||||
| DT_DUAL_SUB_INT8 = 14; | |||||
| DT_DUAL_SUB_UINT8 = 15; | |||||
| DT_COMPLEX64 = 16; | |||||
| DT_COMPLEX128 = 17; | |||||
| DT_QINT8 = 18; | |||||
| DT_QINT16 = 19; | |||||
| DT_QINT32 = 20; | |||||
| DT_QUINT8 = 21; | |||||
| DT_QUINT16 = 22; | |||||
| DT_RESOURCE = 23; | |||||
| DT_STRING_REF = 24; | |||||
| DT_DUAL = 25; | |||||
| DT_VARIANT = 26; | |||||
| } | |||||
| enum OutputFormat { | |||||
| FORMAT_NCHW = 0; | |||||
| FORMAT_NHWC = 1; | |||||
| FORMAT_ND = 2; | |||||
| FORMAT_NC1HWC0 = 3; | |||||
| FORMAT_FRACTAL_Z = 4; | |||||
| FORMAT_NC1C0HWPAD = 5; | |||||
| FORMAT_NHWC1C0 = 6; | |||||
| FORMAT_FSR_NCHW = 7; | |||||
| FORMAT_FRACTAL_DECONV = 8; | |||||
| FORMAT_C1HWNC0 = 9; | |||||
| FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; | |||||
| FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; | |||||
| FORMAT_NC1HWC0_C04 = 12; | |||||
| FORMAT_FRACTAL_Z_C04 = 13; | |||||
| FORMAT_CHWN = 14; | |||||
| FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; | |||||
| FORMAT_HWCN = 16; | |||||
| FORMAT_NC1KHKWHWC0 = 17; | |||||
| FORMAT_BN_WEIGHT = 18; | |||||
| FORMAT_FILTER_HWCK = 19; | |||||
| FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20; | |||||
| FORMAT_HASHTABLE_LOOKUP_KEYS = 21; | |||||
| FORMAT_HASHTABLE_LOOKUP_VALUE = 22; | |||||
| FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; | |||||
| FORMAT_HASHTABLE_LOOKUP_HITS=24; | |||||
| FORMAT_C1HWNCoC0 = 25; | |||||
| FORMAT_MD = 26; | |||||
| FORMAT_NDHWC = 27; | |||||
| FORMAT_FRACTAL_ZZ = 28; | |||||
| FORMAT_FRACTAL_NZ = 29; | |||||
| FORMAT_RESERVED = 30; | |||||
| } | |||||
| message OriginalOp { | |||||
| string name = 1; | |||||
| uint32 output_index = 2; | |||||
| OutputDataType data_type = 3; | |||||
| OutputFormat format = 4; | |||||
| } | |||||
| message Shape { | |||||
| repeated uint64 dim = 1; | |||||
| } | |||||
| message OpOutput { | |||||
| OutputDataType data_type = 1; | |||||
| OutputFormat format = 2; | |||||
| Shape shape = 3; | |||||
| OriginalOp original_op = 4; // the original op corresponding to the output | |||||
| bytes data = 5; | |||||
| uint64 size = 6; | |||||
| } | |||||
| message OpInput { | |||||
| OutputDataType data_type = 1; | |||||
| OutputFormat format = 2; | |||||
| Shape shape = 3; | |||||
| bytes data = 4; | |||||
| uint64 size = 5; | |||||
| } | |||||
| enum BufferType { | |||||
| L1 = 0; | |||||
| } | |||||
| message OpBuffer { | |||||
| BufferType buffer_type = 1; | |||||
| bytes data = 2; | |||||
| uint64 size = 3; | |||||
| } | |||||
| message DumpData{ | |||||
| string version = 1; | |||||
| uint64 dump_time = 2; | |||||
| repeated OpOutput output = 3; | |||||
| repeated OpInput input = 4; | |||||
| repeated OpBuffer buffer = 5; | |||||
| string op_name = 6; | |||||
| } | |||||
| @@ -1,21 +0,0 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| import "om.proto"; | |||||
| package domi; | |||||
| message FusionModelDef { | |||||
| string version = 1; | |||||
| repeated OpDef fusion_op = 2; | |||||
| } | |||||
| @@ -1,37 +0,0 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package aicpu.FWKAdapter; | |||||
| option cc_enable_arenas = true; | |||||
| // Defines an struct for input and output. | |||||
| message TensorDataInfo { | |||||
| // value DataType | |||||
| uint32 dtype = 1; | |||||
| // shape dim | |||||
| repeated int64 dim = 2; | |||||
| // data point addr | |||||
| int64 data_addr = 3; | |||||
| } | |||||
| message KernelRunParam { | |||||
| // input | |||||
| repeated TensorDataInfo input = 1; | |||||
| // output | |||||
| repeated TensorDataInfo output = 2; | |||||
| } | |||||
| @@ -1,88 +0,0 @@ | |||||
| syntax = "proto3"; | |||||
| package ge.api_pb; | |||||
| import "ge_ir.proto"; | |||||
| // GE initialize | |||||
| message GEInitialize { | |||||
| map<string, string> options = 1; | |||||
| }; | |||||
| // initialize response | |||||
| message GEInitializeResponse { | |||||
| uint32 status = 1; | |||||
| uint32 clientId = 2; | |||||
| }; | |||||
| // GE finalize | |||||
| message GEFinalize { | |||||
| bool final = 1; | |||||
| uint32 clientId = 2; | |||||
| }; | |||||
| message GEFinalizeResponse { | |||||
| uint32 status = 1; | |||||
| }; | |||||
| // GE Session | |||||
| message CreateSession{ | |||||
| map<string, string> options = 1; | |||||
| }; | |||||
| message CreateSessionResponse { | |||||
| uint32 status = 1; | |||||
| uint64 sessionId = 2; | |||||
| }; | |||||
| //GE AddGraph | |||||
| //model serialize :: serializegraph | |||||
| message SessionAddGraph{ | |||||
| uint32 graphId = 1; | |||||
| uint64 sessionId = 2; | |||||
| ge.proto.GraphDef graph = 3; | |||||
| }; | |||||
| message SessionAddGraphResponse { | |||||
| uint32 status = 1; | |||||
| }; | |||||
| //GE SessionRemoveGraph | |||||
| message SessionRemoveGraph{ | |||||
| uint32 graphId = 1; | |||||
| uint64 sessionId = 2; | |||||
| }; | |||||
| message SessionRemoveGraphResponse { | |||||
| uint32 status = 1; | |||||
| }; | |||||
| message SessionRunGraph{ | |||||
| uint32 graphId = 1; | |||||
| uint64 sessionId = 2; | |||||
| repeated ge.proto.TensorDef tensor = 3; | |||||
| }; | |||||
| message SessionBuildGraph{ | |||||
| uint32 graphId = 1; | |||||
| uint64 sessionId = 2; | |||||
| repeated ge.proto.TensorDef tensor = 3; | |||||
| string savePath = 4; | |||||
| }; | |||||
| message SessionRunGraphResponse { | |||||
| uint32 status = 1; | |||||
| repeated ge.proto.TensorDef tensor = 2; | |||||
| }; | |||||
| message SessionBuildGraphResponse { | |||||
| uint32 status = 1; | |||||
| }; | |||||
| message DestroySession{ | |||||
| bool final = 1; | |||||
| uint64 sessionId = 2; | |||||
| }; | |||||
| message DestroySessionResponse { | |||||
| uint32 status = 1; | |||||
| }; | |||||