Changed files in pull/1696/head:

modified:  CMakeLists.txt
modified:  build.sh
modified:  ge/ge_runtime/runtime_model.cc
modified:  ge/ge_runtime/task/aicpu_task.cc
modified:  ge/ge_runtime/task/hccl_task.cc
modified:  ge/ge_runtime/task/label_goto_task.cc
modified:  ge/ge_runtime/task/label_switch_task.cc
new file:  tests/st/CMakeLists.txt
new file:  tests/st/cmake/graphengine.cmake
new file:  tests/st/framework/CMakeLists.txt
new file:  tests/st/framework/framework.cc
new file:  tests/st/framework/framework.h
new file:  tests/st/framework/stub_engine/CMakeLists.txt
new file:  tests/st/framework/stub_engine/common/constant/constant.h
new file:  tests/st/framework/stub_engine/engine/stub_engine.cc
new file:  tests/st/framework/stub_engine/engine/stub_engine.h
new file:  tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc
new file:  tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h
new file:  tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc
new file:  tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h
new file:  tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc
new file:  tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h
new file:  tests/st/framework/stub_engine/ops_kernel_store/op/op.h
new file:  tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc
new file:  tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h
new file:  tests/st/framework/stub_engine/proto/task.proto
new file:  tests/st/framework/stub_op_proto/array_ops.cc
new file:  tests/st/framework/stub_op_proto/array_ops.h
new file:  tests/st/framework/stub_op_proto/control_flow_ops.cc
new file:  tests/st/framework/stub_op_proto/control_flow_ops.h
new file:  tests/st/framework/stub_op_proto/elewise_calculation_ops.cc
new file:  tests/st/framework/stub_op_proto/elewise_calculation_ops.h
new file:  tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc
new file:  tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h
new file:  tests/st/framework/stub_op_proto/util/axis_util.cc
new file:  tests/st/framework/stub_op_proto/util/axis_util.h
new file:  tests/st/framework/stub_op_proto/util/common_shape_fns.cc
new file:  tests/st/framework/stub_op_proto/util/common_shape_fns.h
new file:  tests/st/framework/stub_op_proto/util/error_code.h
new file:  tests/st/framework/stub_op_proto/util/error_util.cc
new file:  tests/st/framework/stub_op_proto/util/error_util.h
new file:  tests/st/framework/stub_op_proto/util/op_common_util.h
new file:  tests/st/framework/stub_op_proto/util/op_log.h
new file:  tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc
new file:  tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h
new file:  tests/st/framework/stub_op_proto/util/util.cc
new file:  tests/st/framework/stub_op_proto/util/util.h
new file:  tests/st/framework/utils/assertion/graph_assertion.cc
new file:  tests/st/framework/utils/assertion/graph_assertion.h
new file:  tests/st/framework/utils/builder/graph_builder_utils.cc
new file:  tests/st/framework/utils/builder/graph_builder_utils.h
new file:  tests/st/framework/utils/builder/tensor_builder_utils.cc
new file:  tests/st/framework/utils/builder/tensor_builder_utils.h
new file:  tests/st/test.cc
new file:  tests/st/testcase/CMakeLists.txt
new file:  tests/st/testcase/test_framework_dummy.cc
modified: CMakeLists.txt

@@ -39,7 +39,7 @@ set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR})
 option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE)

-if (ENABLE_OPEN_SRC)
+if (ENABLE_GE_COV OR ENABLE_GE_UT OR ENABLE_GE_ST)
     set(HI_PYTHON python3)
     include(cmake/external_libs/protobuf_shared.cmake)
@@ -51,118 +51,132 @@ if (ENABLE_OPEN_SRC)
     include(cmake/external_libs/json.cmake)
     include(cmake/FindModule.cmake)
     include(cmake/intf_pub_linux.cmake)
+    add_subdirectory(tests)
+else ()
+if (ENABLE_OPEN_SRC)
+    set(HI_PYTHON python3)
+    include(cmake/external_libs/protobuf_shared.cmake)
+    include(cmake/external_libs/protobuf_static.cmake)
+    include(cmake/external_libs/protoc.cmake)
+    include(cmake/external_libs/gflags.cmake)
+    include(cmake/external_libs/gtest.cmake)
+    include(cmake/external_libs/securec.cmake)
+    include(cmake/external_libs/json.cmake)
+    include(cmake/FindModule.cmake)
+    include(cmake/intf_pub_linux.cmake)
     # if D_LINK_PATH is set in environment variables, search libraries in given path
     if(DEFINED ENV{D_LINK_PATH})
         # D_LINK_PATH is set
         set(GE_LIB_PATH $ENV{D_LINK_PATH})
         set(GE_SYS_ARCH "")
         if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64")
             # x86 ubuntu
             set(GE_SYS_ARCH "x86_64")
         elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64")
             # arm euleros
             set(GE_SYS_ARCH "aarch64")
         else()
             message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
         endif()
         set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
         set(STATIC_ACL_LIB ${GE_LIB_PATH})
         find_module(slog libalog.so ${GE_LIB_PATH})
         find_module(static_mmpa libmmpa.a ${GE_LIB_PATH})
         find_module(msprofiler_ext libmsprofiler.a ${GE_LIB_PATH})
         find_module(hccl libhccl.so ${GE_LIB_PATH})
         find_module(adump_server libadump_server.a ${GE_LIB_PATH})
         find_module(runtime libruntime.so ${GE_LIB_PATH})
         find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH})
         find_module(resource libresource.so ${GE_LIB_PATH})
         find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH})
         find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH})
         #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
-    elseif(ENABLE_GE_COV OR ENABLE_GE_UT)
-        add_subdirectory(tests)
     else()
         find_module(slog libalog.so ${ASCEND_ATC_DIR})
         find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR})
         if(PLATFORM STREQUAL "train")
             find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
             find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
             find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
             find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
             if(PRODUCT STREQUAL "flr3")
                 message(FATAL_ERROR "This platform is not supported in train mode, build terminated")
             endif()
         elseif(PLATFORM STREQUAL "inference")
             find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR})
             find_module(runtime libruntime.so ${ASCEND_ACL_DIR})
             find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
             find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
             if(PRODUCT STREQUAL "flr3")
             elseif(PRODUCT STREQUAL "flr1")
                 find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
             elseif(PRODUCT STREQUAL "flr2")
                 # flr2 ascend_hal_stub limsprof ?
             else()
                 find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
             endif()
         elseif(PLATFORM STREQUAL "all")
             find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
             find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
             find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
             find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
             find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
             find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
         else()
             message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!")
         endif()
     endif()
     set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
     set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/parser)
     set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
     add_subdirectory(metadef)
     add_subdirectory(parser)
     #add_subdirectory(metadef/graph)
     #add_subdirectory(metadef/register)
 elseif (ENABLE_D OR ENABLE_ACL)
     # compiling with MindSpore
     include(cmake/external_libs/protobuf_static.cmake)
     include(cmake/external_libs/protoc.cmake)
     include(cmake/external_libs/securec.cmake)
     include(cmake/external_libs/json.cmake)
     include(cmake/FindModule.cmake)
     include(cmake/intf_pub_linux.cmake)
     # common libraries
     find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
     find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
     if (ENABLE_D)
         # training
         find_module(runtime libruntime.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
         find_module(register libregister.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
     endif ()
     set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
     add_subdirectory(metadef)
 elseif(ENABLE_MS_TESTCASES)
     include(cmake/external_libs/protobuf_static.cmake)
     include(cmake/external_libs/protoc.cmake)
     include(cmake/external_libs/securec.cmake)
     include(cmake/FindModule.cmake)
     include(cmake/intf_pub_linux.cmake)
     # common libraries
     find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
     find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
     set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
     add_subdirectory(metadef)
 else()
     set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef)
     set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser)
     set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
 endif()

 add_subdirectory(ge)
+endif ()
| @@ -177,6 +177,9 @@ build_graphengine() | |||||
| elif [ "X$ENABLE_GE_UT" = "Xon" ] | elif [ "X$ENABLE_GE_UT" = "Xon" ] | ||||
| then | then | ||||
| TARGET="ut_libgraph ut_libge_multiparts_utest ut_libge_others_utest ut_libge_kernel_utest ut_libge_distinct_load_utest" | TARGET="ut_libgraph ut_libge_multiparts_utest ut_libge_others_utest ut_libge_kernel_utest ut_libge_distinct_load_utest" | ||||
| elif [ "X$ENABLE_GE_ST" = "Xon" ] | |||||
| then | |||||
| TARGET="graph_engine_test" | |||||
| elif [ "X$MINDSPORE_MODE" = "Xon" ] | elif [ "X$MINDSPORE_MODE" = "Xon" ] | ||||
| then | then | ||||
| TARGET="ge_common graph" | TARGET="ge_common graph" | ||||
| @@ -234,6 +237,27 @@ if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then | |||||
| genhtml coverage.info | genhtml coverage.info | ||||
| fi | fi | ||||
| if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then | |||||
| #prepare engine & opskernel so | |||||
| mkdir -p ${OUTPUT_PATH}/plugin/nnengine | |||||
| mkdir -p ${OUTPUT_PATH}/plugin/nnengine/ge_config | |||||
| mkdir -p ${OUTPUT_PATH}/plugin/opskernel | |||||
| cp ${BUILD_PATH}/tests/st/libnnengine.so ${OUTPUT_PATH}/plugin/nnengine | |||||
| cp ${BUILD_PATH}/engine_conf.json ${OUTPUT_PATH}/plugin/nnengine/ge_config | |||||
| cp ${BUILD_PATH}/tests/st/libhost_cpu_engine.so ${OUTPUT_PATH}/plugin/opskernel | |||||
| #prepare st execution bin | |||||
| cp ${BUILD_PATH}/tests/st/testcase/graph_engine_test ${OUTPUT_PATH} | |||||
| #execute st testcase | |||||
| RUN_TEST_CASE=${OUTPUT_PATH}/graph_engine_test && ${RUN_TEST_CASE} | |||||
| if [[ "$?" -ne 0 ]]; then | |||||
| echo "!!! ST FAILED, PLEASE CHECK YOUR CHANGES !!!" | |||||
| echo -e "\033[31m${RUN_TEST_CASE}\033[0m" | |||||
| exit 1; | |||||
| fi | |||||
| # remove plugin | |||||
| rm -rf ${OUTPUT_PATH}/plugin | |||||
| fi | |||||
| # generate output package in tar form, including ut/st libraries/executables | # generate output package in tar form, including ut/st libraries/executables | ||||
| generate_package() | generate_package() | ||||
| { | { | ||||
| @@ -337,7 +361,7 @@ generate_package() | |||||
| fi | fi | ||||
| } | } | ||||
| if [[ "X$ENABLE_GE_UT" = "Xoff" && "X$MINDSPORE_MODE" = "Xoff" ]]; then | |||||
| if [[ "X$ENABLE_GE_UT" = "Xoff" && "X$ENABLE_GE_ST" = "Xoff" && "X$MINDSPORE_MODE" = "Xoff" ]]; then | |||||
| generate_package | generate_package | ||||
| elif [ "X$MINDSPORE_MODE" = "Xon" ] | elif [ "X$MINDSPORE_MODE" = "Xon" ] | ||||
| then | then | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "framework/common/op/op_parser_util.h" | #include "framework/common/op/op_parser_util.h" | ||||
| #include "graph/types.h" | #include "graph/types.h" | ||||
| #include "task/task_factory.h" | #include "task/task_factory.h" | ||||
| #include "ge/common/math/math_util.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace model_runner { | namespace model_runner { | ||||
| @@ -500,7 +501,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model | |||||
| } | } | ||||
| uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data())); | uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data())); | ||||
| uint32_t head_len = kOffsetUnit * kStringHeadElems; | uint32_t head_len = kOffsetUnit * kStringHeadElems; | ||||
| if (ge::CheckInt64Uint32MulOverflow(elem_num, head_len) != SUCCESS) { | |||||
| if (CheckInt64Uint32MulOverflow(elem_num, head_len) != SUCCESS) { | |||||
| GELOGE(FAILED, "Shape size is invalid"); | GELOGE(FAILED, "Shape size is invalid"); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -83,7 +83,7 @@ bool AicpuTask::Distribute() { | |||||
| return false; | return false; | ||||
| } | } | ||||
| GELOGI("ext info size:", ext_size); | |||||
| GELOGI("ext info size: %u", ext_size); | |||||
| aicpu_param_head.extInfoLength = ext_size; | aicpu_param_head.extInfoLength = ext_size; | ||||
| aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_); | aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_); | ||||
| } | } | ||||
| @@ -130,7 +130,7 @@ bool HcclTask::SetSecondaryStream() { | |||||
| Status ret; | Status ret; | ||||
| std::lock_guard<std::mutex> lock(model_stream_mapping_mutex_); | std::lock_guard<std::mutex> lock(model_stream_mapping_mutex_); | ||||
| if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) { | if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) { | ||||
| GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %ld.", rt_model_handle_, master_stream_id); | |||||
| GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %u.", rt_model_handle_, master_stream_id); | |||||
| ret = CreateStream(hccl_secondary_stream_num, master_stream_id); | ret = CreateStream(hccl_secondary_stream_num, master_stream_id); | ||||
| if (!ret) { | if (!ret) { | ||||
| GELOGE(RT_FAILED, "Create hccl stream failed."); | GELOGE(RT_FAILED, "Create hccl stream failed."); | ||||
| @@ -189,7 +189,7 @@ bool HcclTask::SetSecondaryStream() { | |||||
| } | } | ||||
| GELOGI("Initialize hccl secondary stream success, hccl_secondary_stream_num =%ld", hccl_secondary_stream_num); | GELOGI("Initialize hccl secondary stream success, hccl_secondary_stream_num =%ld", hccl_secondary_stream_num); | ||||
| } else { | } else { | ||||
| GELOGI("Need to create secondary stream for %s with new mainstream %ld.", task_info_->op_name().c_str(), | |||||
| GELOGI("Need to create secondary stream for %s with new mainstream %u.", task_info_->op_name().c_str(), | |||||
| master_stream_id); | master_stream_id); | ||||
| ret = CreateStream(hccl_secondary_stream_num, master_stream_id); | ret = CreateStream(hccl_secondary_stream_num, master_stream_id); | ||||
| if (!ret) { | if (!ret) { | ||||
| @@ -72,7 +72,7 @@ bool LabelGotoTask::Distribute() { | |||||
| return false; | return false; | ||||
| } | } | ||||
| rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||||
| rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size); | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | ||||
| return false; | return false; | ||||
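This fix inserts a C-style cast so that label_list.data() matches the void ** first parameter the runtime call expects. A self-contained model of the situation is below; rtLabelListCpySketch and LabelHandle are hypothetical stand-ins, not the runtime's real declarations, and reinterpret_cast is simply the more searchable C++ spelling of the diff's (void**) cast:

#include <cstdint>
#include <vector>

struct LabelRec {};              // opaque stand-in for a runtime label record
using LabelHandle = LabelRec *;

// Hypothetical stand-in for an API whose first parameter is void **.
static int rtLabelListCpySketch(void **labels, uint32_t num, void *dst, uint32_t dst_max) {
  (void)labels; (void)num; (void)dst; (void)dst_max;
  return 0;  // stand-in for RT_ERROR_NONE
}

int main() {
  std::vector<LabelHandle> label_list(4);
  char label_info[64] = {};
  // label_list.data() is LabelHandle *, not void **, hence the explicit cast.
  return rtLabelListCpySketch(reinterpret_cast<void **>(label_list.data()),
                              static_cast<uint32_t>(label_list.size()),
                              label_info, sizeof(label_info));
}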
| @@ -69,7 +69,7 @@ bool LabelSwitchTask::Distribute() { | |||||
| return false; | return false; | ||||
| } | } | ||||
| label_list[i] = all_label_resource_[label_index]; | label_list[i] = all_label_resource_[label_index]; | ||||
| GELOGI("Case %zu: label id %zu.", i, label_index); | |||||
| GELOGI("Case %zu: label id %zu.", i, (size_t)label_index); | |||||
| } | } | ||||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | ||||
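The GELOGI edits above (aicpu_task.cc, hccl_task.cc twice, label_switch_task.cc) are all printf-format fixes: a format string with no specifier at all, %ld applied to what is evidently a uint32_t, and a %zu fed a narrower integer. Assuming GELOGI forwards to a printf-style formatter whose format string the compiler can check (the st CMake files in this PR build with -Werror=format, which is what turns such mistakes into build breaks), a minimal reproduction looks like this; LogSketch is a hypothetical stand-in for GELOGI, and the variable types are assumptions based on the fixes:

#include <cinttypes>
#include <cstdarg>
#include <cstdint>
#include <cstdio>

// Stand-in for GELOGI: a printf-style logger with compiler-checked format strings.
__attribute__((format(printf, 1, 2)))
static void LogSketch(const char *fmt, ...) {
  va_list args;
  va_start(args, fmt);
  vfprintf(stderr, fmt, args);
  fputc('\n', stderr);
  va_end(args);
}

int main() {
  uint32_t ext_size = 64;
  uint32_t master_stream_id = 3;
  uint32_t label_index = 7;  // the hunk does not show its declared type
  // LogSketch("ext info size:", ext_size);           // bad: argument has no specifier
  // LogSketch("mainstream %ld.", master_stream_id);  // bad: %ld expects long
  // LogSketch("label id %zu.", label_index);         // bad: %zu expects size_t
  LogSketch("ext info size: %u", ext_size);           // %u matches uint32_t here
  LogSketch("mainstream %" PRIu32 ".", master_stream_id);  // fully portable spelling
  LogSketch("label id %zu.", (size_t)label_index);    // the cast the PR applies
  return 0;
}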
| @@ -0,0 +1,6 @@ | |||||
| project(graphengine_st) | |||||
| include(cmake/graphengine.cmake) | |||||
| add_subdirectory(framework) | |||||
| add_subdirectory(testcase) | |||||
| @@ -0,0 +1,249 @@ | |||||
| # ---- Test coverage ---- | |||||
| if (ENABLE_GE_COV) | |||||
| set(COVERAGE_COMPILER_FLAGS "-g --coverage -fprofile-arcs -fPIC -O0 -ftest-coverage") | |||||
| set(CMAKE_CXX_FLAGS "${COVERAGE_COMPILER_FLAGS}") | |||||
| endif() | |||||
| # ---- Proto generate ---- | |||||
| file(GLOB_RECURSE PROTO_FILES CONFIGURE_DEPENDS "${GE_CODE_DIR}/metadef/proto/*.proto") | |||||
| protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_FILES}) | |||||
| # ---- File glob by group ---- | |||||
| file(GLOB_RECURSE METADEF_SRCS CONFIGURE_DEPENDS | |||||
| "${GE_CODE_DIR}/metadef/graph/*.cc" | |||||
| "${GE_CODE_DIR}/metadef/register/*.cc" | |||||
| "${GE_CODE_DIR}/metadef/register/*.cpp" | |||||
| "${GE_CODE_DIR}/metadef/ops/*.cc" | |||||
| "${GE_CODE_DIR}/metadef/third_party/transformer/src/*.cc" | |||||
| ) | |||||
| file(GLOB_RECURSE METADEF_REGISTER_SRCS CONFIGURE_DEPENDS | |||||
| "${GE_CODE_DIR}/metadef/register/*.cc" | |||||
| "${GE_CODE_DIR}/metadef/register/*.cpp" | |||||
| ) | |||||
| file(GLOB_RECURSE PARSER_SRCS CONFIGURE_DEPENDS | |||||
| "${GE_CODE_DIR}/parser/parser/common/*.cc" | |||||
| ) | |||||
| file(GLOB_RECURSE LOCAL_ENGINE_SRC CONFIGURE_DEPENDS | |||||
| "${GE_CODE_DIR}/ge/ge_local_engine/*.cc" | |||||
| ) | |||||
| file(GLOB_RECURSE HOST_ENGINE_SRC CONFIGURE_DEPENDS | |||||
| "${GE_CODE_DIR}/ge/host_cpu_engine/*.cc" | |||||
| ) | |||||
| file(GLOB_RECURSE NN_ENGINE_SRC CONFIGURE_DEPENDS | |||||
| "${GE_CODE_DIR}/ge/plugin/*.cc" | |||||
| ) | |||||
| file(GLOB_RECURSE OFFLINE_SRC CONFIGURE_DEPENDS | |||||
| "${GE_CODE_DIR}/ge/offline/*.cc" | |||||
| ) | |||||
| file(GLOB_RECURSE GE_SRCS CONFIGURE_DEPENDS | |||||
| "${GE_CODE_DIR}/ge/*.cc" | |||||
| ) | |||||
| list(REMOVE_ITEM GE_SRCS ${LOCAL_ENGINE_SRC} ${HOST_ENGINE_SRC} ${NN_ENGINE_SRC} ${OFFLINE_SRC}) | |||||
| list(APPEND INCLUDE_DIRECTORIES | |||||
| "${CMAKE_CURRENT_SOURCE_DIR}" | |||||
| "${GE_CODE_DIR}" | |||||
| "${GE_CODE_DIR}/inc" | |||||
| "${GE_CODE_DIR}/metadef/inc" | |||||
| "${GE_CODE_DIR}/ge" | |||||
| "${GE_CODE_DIR}/ge/inc" | |||||
| "${GE_CODE_DIR}/ge/ir_build" | |||||
| "${GE_CODE_DIR}/metadef" | |||||
| "${GE_CODE_DIR}/metadef/graph" | |||||
| "${GE_CODE_DIR}/inc/external" | |||||
| "${GE_CODE_DIR}/inc/framework/common" | |||||
| "${GE_CODE_DIR}/metadef/inc/external" | |||||
| "${GE_CODE_DIR}/metadef/inc/external/graph" | |||||
| "${GE_CODE_DIR}/metadef/inc/graph" | |||||
| "${GE_CODE_DIR}/inc/framework" | |||||
| "${GE_CODE_DIR}/metadef/inc/common" | |||||
| "${GE_CODE_DIR}/metadef/third_party" | |||||
| "${GE_CODE_DIR}/metadef/third_party/transformer/inc" | |||||
| "${GE_CODE_DIR}/parser" | |||||
| "${GE_CODE_DIR}/parser/parser" | |||||
| "${GE_CODE_DIR}/third_party/fwkacllib/inc" | |||||
| "${GE_CODE_DIR}/third_party/fwkacllib/inc/cce" | |||||
| "${GE_CODE_DIR}/third_party/fwkacllib/inc/ops" | |||||
| "${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain" | |||||
| "${GE_CODE_DIR}/tests/ut/ge" | |||||
| "${GE_CODE_DIR}/tests/ut/common" | |||||
| "${CMAKE_BINARY_DIR}" | |||||
| "${CMAKE_BINARY_DIR}/proto/ge" | |||||
| "${CMAKE_BINARY_DIR}/proto/ge/proto" | |||||
| ) | |||||
| list(APPEND STUB_LIBS | |||||
| c_sec | |||||
| slog_stub | |||||
| cce_ge_stub | |||||
| runtime_stub | |||||
| profiler_stub | |||||
| #mmpa_stub | |||||
| hccl_stub | |||||
| error_manager_stub | |||||
| ascend_protobuf | |||||
| json | |||||
| ) | |||||
| # ---- Target : Local engine ---- | |||||
| add_library(localengine STATIC ${LOCAL_ENGINE_SRC} ${METADEF_REGISTER_SRCS}) | |||||
| target_include_directories(localengine | |||||
| PUBLIC | |||||
| "${INCLUDE_DIRECTORIES}" | |||||
| "${GE_CODE_DIR}/ge/ge_local_engine" | |||||
| ) | |||||
| target_compile_definitions(localengine PRIVATE | |||||
| google=ascend_private | |||||
| ) | |||||
| target_compile_options(localengine PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(localengine PUBLIC | |||||
| $<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} -lrt -ldl -lpthread -lgcov | |||||
| ) | |||||
| set_target_properties(localengine PROPERTIES CXX_STANDARD 11) | |||||
| # ---- Target : metadef graph ---- | |||||
| add_library(metadef_graph STATIC ${METADEF_SRCS} ${PROTO_SRCS} ${PROTO_HDRS}) | |||||
| target_include_directories(metadef_graph | |||||
| PUBLIC | |||||
| "${INCLUDE_DIRECTORIES}" | |||||
| ) | |||||
| target_compile_definitions(metadef_graph PRIVATE | |||||
| google=ascend_private | |||||
| FMK_SUPPORT_DUMP | |||||
| ) | |||||
| target_compile_options(metadef_graph PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(metadef_graph PUBLIC | |||||
| $<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} -lrt -ldl -lpthread -lgcov | |||||
| ) | |||||
| set_target_properties(metadef_graph PROPERTIES CXX_STANDARD 11) | |||||
| # ---- Target : Host engine ---- | |||||
| add_library(host_cpu_engine SHARED ${HOST_ENGINE_SRC} ${PROTO_HDRS}) | |||||
| target_include_directories(host_cpu_engine | |||||
| PUBLIC | |||||
| "${INCLUDE_DIRECTORIES}" | |||||
| "${GE_CODE_DIR}/ge/host_cpu_engine" | |||||
| ) | |||||
| target_compile_definitions(host_cpu_engine PRIVATE | |||||
| google=ascend_private | |||||
| FMK_SUPPORT_DUMP | |||||
| ) | |||||
| target_compile_options(host_cpu_engine PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(host_cpu_engine PUBLIC | |||||
| $<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} metadef_graph -lmmpa -L/home/hugo/Code/ge/graphengine/build/tests/depends/mmpa -lrt -ldl -lpthread -lgcov | |||||
| ) | |||||
| set_target_properties(host_cpu_engine PROPERTIES CXX_STANDARD 11) | |||||
| # ---- Target : engine plugin---- | |||||
| # | |||||
| add_library(nnengine SHARED ${NN_ENGINE_SRC}) | |||||
| target_include_directories(nnengine | |||||
| PUBLIC | |||||
| "${INCLUDE_DIRECTORIES}" | |||||
| "${GE_CODE_DIR}/ge/plugin/engine" | |||||
| ) | |||||
| target_compile_definitions(nnengine PRIVATE | |||||
| google=ascend_private | |||||
| ) | |||||
| target_compile_options(nnengine PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(nnengine PUBLIC | |||||
| $<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} -lrt -ldl -lpthread -lgcov | |||||
| ) | |||||
| set_target_properties(nnengine PROPERTIES CXX_STANDARD 11) | |||||
| # Targe: engine_conf | |||||
| add_custom_target( | |||||
| engine_conf.json ALL | |||||
| DEPENDS ${CMAKE_BINARY_DIR}/engine_conf.json | |||||
| ) | |||||
| add_custom_command( | |||||
| OUTPUT ${CMAKE_BINARY_DIR}/engine_conf.json | |||||
| COMMAND cp ${GE_CODE_DIR}/ge/engine_manager/engine_conf.json ${CMAKE_BINARY_DIR}/ | |||||
| ) | |||||
| # Targe: optimizer priority | |||||
| add_custom_target( | |||||
| optimizer_priority.pbtxt ALL | |||||
| DEPENDS ${CMAKE_BINARY_DIR}/optimizer_priority.pbtxt | |||||
| ) | |||||
| add_custom_command( | |||||
| OUTPUT ${CMAKE_BINARY_DIR}/optimizer_priority.pbtxt | |||||
| COMMAND cp ${GE_CODE_DIR}/ge/opskernel_manager/optimizer_priority.pbtxt ${CMAKE_BINARY_DIR}/ | |||||
| ) | |||||
| # ---- Target : Graph engine ---- | |||||
| add_library(graphengine STATIC ${PARSER_SRCS} ${GE_SRCS} ${PROTO_HDRS}) | |||||
| target_include_directories(graphengine | |||||
| PUBLIC | |||||
| "${INCLUDE_DIRECTORIES}" | |||||
| "${GE_CODE_DIR}/ge/host_cpu_engine" | |||||
| ) | |||||
| target_compile_definitions(graphengine PRIVATE | |||||
| google=ascend_private | |||||
| FMK_SUPPORT_DUMP | |||||
| ) | |||||
| target_compile_options(graphengine PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(graphengine PUBLIC | |||||
| $<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} | |||||
| metadef_graph | |||||
| localengine | |||||
| host_cpu_engine | |||||
| nnengine | |||||
| mmpa -L${GE_CODE_DIR}/third_party/prebuild/x86_64 -lrt -ldl -lpthread -lgcov | |||||
| ) | |||||
| set_target_properties(graphengine PROPERTIES CXX_STANDARD 11) | |||||
| add_dependencies(graphengine engine_conf.json optimizer_priority.pbtxt) | |||||
| @@ -0,0 +1,16 @@ | |||||
| file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++") | |||||
| #todo | |||||
| file(GLOB_RECURSE stub_engine CONFIGURE_DEPENDS | |||||
| "stub_engine/*.cc" | |||||
| ) | |||||
| list(REMOVE_ITEM SOURCES ${stub_engine}) | |||||
| add_library(framework STATIC ${SOURCES}) | |||||
| target_include_directories(framework | |||||
| PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ) | |||||
| set_target_properties(framework PROPERTIES CXX_STANDARD 11) | |||||
| target_link_libraries(framework PUBLIC graphengine) | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <stdlib.h> | |||||
| #include "framework.h" | |||||
| namespace ge { | |||||
| namespace st { | |||||
| Status Framework::SetUp() { | |||||
| } | |||||
| } // namespace st | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GRAPHENGINE_LLT_ST_FRAMEWORK_H_ | |||||
| #define GRAPHENGINE_LLT_ST_FRAMEWORK_H_ | |||||
| #include <string> | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| namespace ge { | |||||
| namespace st { | |||||
| class Framework { | |||||
| public: | |||||
| explicit Framework() {}; | |||||
| Status SetUp(); | |||||
| Status TearDown(); | |||||
| }; | |||||
| } // namespace st | |||||
| }// namespace ge | |||||
| #endif // GRAPHENGINE_LLT_ST_FRAMEWORK_H_ | |||||
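One natural way to drive this class is a googletest global environment that calls SetUp before the first testcase and TearDown after the last one. The PR's actual tests/st/test.cc is not shown in this excerpt, so the wiring below is a plausible sketch rather than the PR's code:

// Hypothetical wiring, not the PR's test.cc.
#include <gtest/gtest.h>
#include "framework.h"  // resolved via the framework target's include dirs

namespace {
class StEnvironment : public ::testing::Environment {
 public:
  // Runs once before any testcase; a fatal failure here aborts the run.
  void SetUp() override { ASSERT_EQ(framework_.SetUp(), ge::SUCCESS); }
  // Runs once after the last testcase.
  void TearDown() override { EXPECT_EQ(framework_.TearDown(), ge::SUCCESS); }

 private:
  ge::st::Framework framework_;
};
}  // namespace

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  ::testing::AddGlobalTestEnvironment(new StEnvironment);  // gtest owns the pointer
  return RUN_ALL_TESTS();
}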
| @@ -0,0 +1,259 @@ | |||||
| set(PROTO_LIST | |||||
| "${METADEF_DIR}/proto/task.proto" | |||||
| ) | |||||
| protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
| protobuf_generate(ge_atcstub PROTO_ATCSTUB_SRCS PROTO_ATCSTUB_HDRS ${PROTO_LIST}) | |||||
| set(SRC_LIST | |||||
| "engine/stub_engine.cc" | |||||
| "ops_kernel_store/host_cpu_ops_kernel_info.cc" | |||||
| "ops_kernel_store/op/op_factory.cc" | |||||
| "ops_kernel_store/op/host_op.cc" | |||||
| ) | |||||
| set(CPU_OPS_KERNEL_LIST | |||||
| "ops_kernel_store/host_cpu_ops_kernel_builder.cc" | |||||
| ) | |||||
| ############ libfe.so ############ | |||||
| add_library(fe SHARED ${SRC_LIST} ${PROTO_HDRS}) | |||||
| target_compile_options(fe PRIVATE | |||||
| -Werror | |||||
| -fno-common | |||||
| -fvisibility=hidden | |||||
| ) | |||||
| target_compile_definitions(fe PRIVATE | |||||
| google=ascend_private | |||||
| FUNC_VISIBILITY | |||||
| ) | |||||
| target_include_directories(fe PRIVATE | |||||
| ${CMAKE_CURRENT_LIST_DIR} | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| #### yellow zone #### | |||||
| ${GE_CODE_DIR}/../inc | |||||
| #### blue zone #### | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ) | |||||
| target_link_options(fe PRIVATE | |||||
| -Wl,-Bsymbolic | |||||
| ) | |||||
| target_link_libraries(fe PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| -Wl,--no-as-needed | |||||
| ascend_protobuf | |||||
| c_sec | |||||
| graph | |||||
| slog | |||||
| -Wl,--as-needed | |||||
| ) | |||||
| ############ atcstub/libfe.so ############ | |||||
| add_library(atc_fe SHARED ${SRC_LIST} ${PROTO_ATCSTUB_HDRS}) | |||||
| target_compile_options(atc_fe PRIVATE | |||||
| -Werror | |||||
| -fno-common | |||||
| -fvisibility=hidden | |||||
| ) | |||||
| target_compile_definitions(atc_fe PRIVATE | |||||
| google=ascend_private | |||||
| FUNC_VISIBILITY | |||||
| ) | |||||
| target_include_directories(atc_fe PRIVATE | |||||
| ${CMAKE_CURRENT_LIST_DIR} | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge_atcstub | |||||
| #### yellow zone #### | |||||
| ${GE_CODE_DIR}/../inc | |||||
| #### blue zone #### | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ) | |||||
| target_link_options(atc_fe PRIVATE | |||||
| -Wl,-Bsymbolic | |||||
| ) | |||||
| target_link_libraries(atc_fe PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| -Wl,--no-as-needed | |||||
| ascend_protobuf | |||||
| c_sec | |||||
| graph | |||||
| slog | |||||
| -Wl,--as-needed | |||||
| ) | |||||
| set_target_properties(atc_fe PROPERTIES | |||||
| OUTPUT_NAME fe | |||||
| LIBRARY_OUTPUT_DIRECTORY atclib | |||||
| ) | |||||
| ############ libhost_cpu_opskernel_builder.so ############ | |||||
| add_library(host_cpu_opskernel_builder SHARED ${CPU_OPS_KERNEL_LIST}) | |||||
| target_compile_options(host_cpu_opskernel_builder PRIVATE | |||||
| -Werror | |||||
| -fno-common | |||||
| -fvisibility=hidden | |||||
| ) | |||||
| target_compile_definitions(host_cpu_opskernel_builder PRIVATE | |||||
| google=ascend_private | |||||
| FUNC_VISIBILITY | |||||
| ) | |||||
| target_include_directories(host_cpu_opskernel_builder PRIVATE | |||||
| ${CMAKE_CURRENT_LIST_DIR} | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| #### yellow zone #### | |||||
| ${GE_CODE_DIR}/../inc | |||||
| #### blue zone #### | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ) | |||||
| target_link_options(host_cpu_opskernel_builder PRIVATE | |||||
| -Wl,-Bsymbolic | |||||
| ) | |||||
| target_link_libraries(host_cpu_opskernel_builder PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| -Wl,--no-as-needed | |||||
| ascend_protobuf | |||||
| c_sec | |||||
| slog | |||||
| graph | |||||
| register | |||||
| -Wl,--as-needed | |||||
| ) | |||||
| ############ atclib/libhost_cpu_opskernel_builder.so ############ | |||||
| add_library(atc_host_cpu_opskernel_builder SHARED ${CPU_OPS_KERNEL_LIST}) | |||||
| target_compile_options(atc_host_cpu_opskernel_builder PRIVATE | |||||
| -Werror | |||||
| -fno-common | |||||
| -fvisibility=hidden | |||||
| ) | |||||
| target_compile_definitions(atc_host_cpu_opskernel_builder PRIVATE | |||||
| google=ascend_private | |||||
| FUNC_VISIBILITY | |||||
| ) | |||||
| target_include_directories(atc_host_cpu_opskernel_builder PRIVATE | |||||
| ${CMAKE_CURRENT_LIST_DIR} | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| #### yellow zone #### | |||||
| ${GE_CODE_DIR}/../inc | |||||
| #### blue zone #### | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ) | |||||
| target_link_options(atc_host_cpu_opskernel_builder PRIVATE | |||||
| -Wl,-Bsymbolic | |||||
| ) | |||||
| target_link_libraries(atc_host_cpu_opskernel_builder PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| -Wl,--no-as-needed | |||||
| ascend_protobuf | |||||
| c_sec | |||||
| slog | |||||
| graph | |||||
| register | |||||
| -Wl,--as-needed | |||||
| ) | |||||
| set_target_properties(atc_host_cpu_opskernel_builder PROPERTIES | |||||
| OUTPUT_NAME host_cpu_opskernel_builder | |||||
| LIBRARY_OUTPUT_DIRECTORY atclib | |||||
| ) | |||||
| ############ libhost_cpu_opskernel_builder.a ############ | |||||
| add_library(host_cpu_opskernel_builder_static STATIC ${CPU_OPS_KERNEL_LIST}) | |||||
| target_compile_options(host_cpu_opskernel_builder_static PRIVATE | |||||
| -Werror | |||||
| -fno-common | |||||
| -fvisibility=hidden | |||||
| ) | |||||
| target_compile_definitions(host_cpu_opskernel_builder_static PRIVATE | |||||
| google=ascend_private | |||||
| LOG_CPP | |||||
| FUNC_VISIBILITY | |||||
| ) | |||||
| target_include_directories(host_cpu_opskernel_builder_static PRIVATE | |||||
| ${CMAKE_CURRENT_LIST_DIR} | |||||
| ${GE_CODE_DIR}/ge | |||||
| ${GE_CODE_DIR}/inc | |||||
| ${GE_CODE_DIR}/inc/external | |||||
| ${GE_CODE_DIR}/inc/framework | |||||
| ${METADEF_DIR}/inc | |||||
| ${METADEF_DIR}/inc/external | |||||
| ${METADEF_DIR}/inc/external/graph | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| #### yellow zone #### | |||||
| ${GE_CODE_DIR}/../inc | |||||
| #### blue zone #### | |||||
| ${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
| ) | |||||
| target_link_libraries(host_cpu_opskernel_builder_static PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| ascend_protobuf | |||||
| c_sec | |||||
| ) | |||||
| ############ install ############ | |||||
| set(INSTALL_BASE_DIR "") | |||||
| set(INSTALL_LIBRARY_DIR lib) | |||||
| install(TARGETS fe host_cpu_opskernel_builder OPTIONAL | |||||
| LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} | |||||
| ) | |||||
| install(TARGETS atc_fe atc_host_cpu_opskernel_builder OPTIONAL | |||||
| LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/atclib | |||||
| ) | |||||
| @@ -0,0 +1,30 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ | |||||
| #define GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ | |||||
| #include <string> | |||||
| namespace ge { | |||||
| namespace host_cpu { | |||||
| // engine name | |||||
| const char kHostCpuEngineName[] = "DNN_VM_HOST_CPU"; | |||||
| const char kHostCpuOpKernelLibName[] = "DNN_VM_HOST_CPU_OP_STORE"; | |||||
| } // namespace host_cpu | |||||
| } // namespace ge | |||||
| #endif // GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ | |||||
| @@ -0,0 +1,74 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "stub_engine.h" | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <securec.h> | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "host_cpu_engine/common/constant/constant.h" | |||||
| #include "host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h" | |||||
| namespace fe { | |||||
| AICEngine &AICEngine::Instance() { | |||||
| static AICEngine instance; | |||||
| return instance; | |||||
| } | |||||
| Status AICEngine::Initialize(const std::map<string, string> &options) { | |||||
| if (ops_kernel_store_ == nullptr) { | |||||
| ops_kernel_store_ = MakeShared<HostCpuOpsKernelInfoStore>(); | |||||
| if (ops_kernel_store_ == nullptr) { | |||||
| GELOGE(FAILED, "[Create][AICEngine]Make HostCpuOpsKernelInfoStore failed."); | |||||
| REPORT_INNER_ERROR("E19999", "AICEngine::Initialize failed for new AICEngine."); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void AICEngine::GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map) { | |||||
| if (ops_kernel_store_ != nullptr) { | |||||
| // add buildin opsKernel to opsKernelInfoMap | |||||
| ops_kernel_map[kHostCpuOpKernelLibName] = ops_kernel_store_; | |||||
| } | |||||
| } | |||||
| void AICEngine::GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &) { | |||||
| // no optimizer for host cpu engine | |||||
| } | |||||
| Status AICEngine::Finalize() { | |||||
| ops_kernel_store_ = nullptr; | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace fe | |||||
| ge::Status Initialize(const std::map<string, string> &options) { | |||||
| return fe::AICEngine::Instance().Initialize(options); | |||||
| } | |||||
| void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map) { | |||||
| fe::AICEngine::Instance().GetOpsKernelInfoStores(ops_kernel_map); | |||||
| } | |||||
| void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers) { | |||||
| fe::AICEngine::Instance().GetGraphOptimizerObjs(graph_optimizers); | |||||
| } | |||||
| ge::Status Finalize() { return fe::AICEngine::Instance().Finalize(); } | |||||
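The global functions at the end of this file are the plugin contract: GE's engine manager dlopens each engine library and resolves Initialize, GetOpsKernelInfoStores, GetGraphOptimizerObjs and Finalize by name, which is why build.sh stages libnnengine.so and the ops kernel library under ${OUTPUT_PATH}/plugin before running the st binary. A stripped-down sketch of that consumer side follows; the path, RTLD flags and error handling are illustrative choices, not GE's actual plugin-manager code:

#include <dlfcn.h>
#include <cstdint>
#include <cstdio>
#include <map>
#include <string>

// ge::Status is an unsigned integral error code; 0 means SUCCESS.
using InitializeFunc = uint32_t (*)(const std::map<std::string, std::string> &);
using FinalizeFunc = uint32_t (*)();

int main() {
  void *handle = dlopen("./plugin/opskernel/libfe.so", RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  auto init = reinterpret_cast<InitializeFunc>(dlsym(handle, "Initialize"));
  auto fini = reinterpret_cast<FinalizeFunc>(dlsym(handle, "Finalize"));
  // GetOpsKernelInfoStores / GetGraphOptimizerObjs would be resolved the same way.
  if (init == nullptr || fini == nullptr) {
    std::fprintf(stderr, "missing engine entry point: %s\n", dlerror());
    dlclose(handle);
    return 1;
  }
  std::map<std::string, std::string> options;
  uint32_t status = init(options);
  std::printf("Initialize returned %u\n", status);
  fini();
  dlclose(handle);
  return status == 0 ? 0 : 1;
}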
| @@ -0,0 +1,126 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_ | |||||
| #define GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_ | |||||
| #if defined(_MSC_VER) | |||||
| #ifdef FUNC_VISIBILITY | |||||
| #define GE_FUNC_VISIBILITY _declspec(dllexport) | |||||
| #else | |||||
| #define GE_FUNC_VISIBILITY | |||||
| #endif | |||||
| #else | |||||
| #ifdef FUNC_VISIBILITY | |||||
| #define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) | |||||
| #else | |||||
| #define GE_FUNC_VISIBILITY | |||||
| #endif | |||||
| #endif | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "common/opskernel/ops_kernel_info_store.h" | |||||
| #include "common/optimizer/graph_optimizer.h" | |||||
| using OpsKernelInfoStorePtr = std::shared_ptr<ge::OpsKernelInfoStore>; | |||||
| using GraphOptimizerPtr = std::shared_ptr<ge::GraphOptimizer>; | |||||
| namespace ge { | |||||
| namespace { | |||||
| std::vector<string> extern_engine_name_vec = {"fe","rts_engine","aicpu_ascend_engine","aicpu_tf_engine",} | |||||
| } // namespace | |||||
| /** | |||||
| * host cpu engine. | |||||
| * Used for the ops which executes on host. | |||||
| */ | |||||
| class GE_FUNC_VISIBILITY StubEngine { | |||||
| public: | |||||
| /** | |||||
| * get HostCpuEngine instance. | |||||
| * @return HostCpuEngine instance. | |||||
| */ | |||||
| static StubEngine &Instance(); | |||||
| virtual ~StubEngine() = default; | |||||
| /** | |||||
| * When Ge start, GE will invoke this interface | |||||
| * @return The status whether initialize successfully | |||||
| */ | |||||
| Status Initialize(const std::map<string, string> &options); | |||||
| /** | |||||
| * After the initialize, GE will invoke this interface | |||||
| * to get the Ops kernel Store. | |||||
| * @param ops_kernel_map The host cpu's ops kernel info | |||||
| */ | |||||
| void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map); | |||||
| /** | |||||
| * After the initialize, GE will invoke this interface | |||||
| * to get the Graph Optimizer. | |||||
| * @param graph_optimizers The host cpu's Graph Optimizer objs | |||||
| */ | |||||
| void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers); | |||||
| /** | |||||
| * When the graph finished, GE will invoke this interface | |||||
| * @return The status whether initialize successfully | |||||
| */ | |||||
| Status Finalize(); | |||||
| StubEngine(const StubEngine &StubEngine) = delete; | |||||
| StubEngine(const StubEngine &&StubEngine) = delete; | |||||
| StubEngine &operator=(const StubEngine &StubEngine) = delete; | |||||
| StubEngine &operator=(StubEngine &&StubEngine) = delete; | |||||
| private: | |||||
| StubEngine() = default; | |||||
| OpsKernelInfoStorePtr ops_kernel_store_ = nullptr; | |||||
| }; | |||||
| } // namespace ge | |||||
| extern "C" { | |||||
| /** | |||||
| * When Ge start, GE will invoke this interface | |||||
| * @return The status whether initialize successfully | |||||
| */ | |||||
| GE_FUNC_VISIBILITY ge::Status Initialize(const map<string, string> &options); | |||||
| /** | |||||
| * After the initialize, GE will invoke this interface to get the Ops kernel Store | |||||
| * @param ops_kernel_map The host cpu's ops kernel info | |||||
| */ | |||||
| GE_FUNC_VISIBILITY void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map); | |||||
| /** | |||||
| * After the initialize, GE will invoke this interface to get the Graph Optimizer | |||||
| * @param graph_optimizers The host cpu's Graph Optimizer objs | |||||
| */ | |||||
| GE_FUNC_VISIBILITY void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers); | |||||
| /** | |||||
| * When the graph finished, GE will invoke this interface | |||||
| * @return The status whether initialize successfully | |||||
| */ | |||||
| GE_FUNC_VISIBILITY ge::Status Finalize(); | |||||
| } | |||||
| #endif // GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_ | |||||
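Editor's note: GE discovers engines like this one by loading them as shared libraries and resolving the extern "C" entry points above by name. A minimal loader sketch follows; the library name "libst_stub_engine.so" is an assumption for illustration, not taken from this change.

// Hypothetical loader sketch: resolve the engine's unmangled C-linkage symbols at runtime.
#include <dlfcn.h>
#include <iostream>
#include <map>
#include <string>
#include "stub_engine.h"  // for ge::Status and the OpsKernelInfoStorePtr alias

int main() {
  void *handle = dlopen("./libst_stub_engine.so", RTLD_NOW);  // library name is an assumption
  if (handle == nullptr) {
    std::cerr << "dlopen failed: " << dlerror() << std::endl;
    return 1;
  }
  // C linkage means the symbol name is exactly "Initialize", so dlsym can find it directly.
  using InitFunc = ge::Status (*)(const std::map<std::string, std::string> &);
  auto init = reinterpret_cast<InitFunc>(dlsym(handle, "Initialize"));
  if (init != nullptr) {
    (void)init(std::map<std::string, std::string>{});
  }
  dlclose(handle);
  return 0;
}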
@@ -0,0 +1,114 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "host_cpu_ops_kernel_builder.h"
#include <memory>
#include "common/ge_inner_error_codes.h"
#include "ge/ge_api_types.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include <securec.h>
#include "framework/common/debug/ge_log.h"
#include "host_cpu_engine/common/constant/constant.h"
#include "register/ops_kernel_builder_registry.h"
namespace ge {
namespace host_cpu {
REGISTER_OPS_KERNEL_BUILDER(kHostCpuOpKernelLibName, HostCpuOpsKernelBuilder);
Status HostCpuOpsKernelBuilder::Finalize() {
  return SUCCESS;
}
Status HostCpuOpsKernelBuilder::Initialize(const map<std::string, std::string> &options) {
  return SUCCESS;
}
Status HostCpuOpsKernelBuilder::CalcOpRunningParam(Node &ge_node) {
  OpDescPtr op_desc = ge_node.GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(FAILED, "[Get][OpDesc]CalcOpRunningParam failed, as op desc is null");
    REPORT_INNER_ERROR("E19999", "GetOpDesc failed.");
    return FAILED;
  }
  bool is_shape_unknown = false;
  if (NodeUtils::GetNodeUnknownShapeStatus(ge_node, is_shape_unknown) == GRAPH_SUCCESS) {
    if (is_shape_unknown) {
      GELOGI("op:%s is unknown shape, does not need to calc output size.", ge_node.GetName().c_str());
      return SUCCESS;
    }
  }
  const string name = ge_node.GetName();
  const string type = ge_node.GetType();
  GELOGD("Calc op[%s:%s] running param, output size=%zu.", name.c_str(), type.c_str(), op_desc->GetOutputsSize());
  for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) {
    GeTensorDesc output_tensor = op_desc->GetOutputDesc(static_cast<uint32_t>(i));
    Format format = output_tensor.GetFormat();
    DataType data_type = output_tensor.GetDataType();
    int64_t mem_size = 0;
    // If mem size has already been set, there is no need to reset it.
    if ((TensorUtils::GetSize(output_tensor, mem_size) == GRAPH_SUCCESS) && (mem_size > 0)) {
      GELOGD("Op[%s:%s] out[%zu] mem size has been set, no need calc again, format=%s, data_type=%s, mem_size=%ld.",
             name.c_str(), type.c_str(), i, TypeUtils::FormatToSerialString(format).c_str(),
             TypeUtils::DataTypeToSerialString(data_type).c_str(), mem_size);
      continue;
    }
    int64_t output_mem_size = 0;
    GeShape output_shape = output_tensor.GetShape();
    if ((TensorUtils::CalcTensorMemSize(output_shape, format, data_type, output_mem_size) != GRAPH_SUCCESS) ||
        (output_mem_size < 0)) {
      GELOGE(FAILED,
             "[Calc][TensorMemSize] fail for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.",
             name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(),
             TypeUtils::DataTypeToSerialString(data_type).c_str());
      REPORT_CALL_ERROR("E19999",
                        "CalcTensorMemSize failed for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.",
                        name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(),
                        TypeUtils::DataTypeToSerialString(data_type).c_str());
      return FAILED;
    }
    GELOGI("Calc op[%s:%s] out[%zu] mem size is %ld, format=%s, data_type=%s.",
           name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(),
           TypeUtils::DataTypeToSerialString(data_type).c_str());
    TensorUtils::SetSize(output_tensor, output_mem_size);
    if (op_desc->UpdateOutputDesc(static_cast<uint32_t>(i), output_tensor) != GRAPH_SUCCESS) {
      GELOGE(FAILED,
             "[Update][OutputDesc] fail for op[%s:%s] out[%zu] desc, format=%s, data_type=%s.",
             name.c_str(), type.c_str(), i,
             TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str());
      REPORT_CALL_ERROR("E19999", "UpdateOutputDesc failed for op[%s:%s] out[%zu] desc, format=%s, data_type=%s.",
                        name.c_str(), type.c_str(), i,
                        TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str());
      return FAILED;
    }
  }
  GELOGD("Calc op[%s:%s] running param success.", name.c_str(), type.c_str());
  return SUCCESS;
}
Status HostCpuOpsKernelBuilder::GenerateTask(const Node &node, RunContext &context, vector<domi::TaskDef> &tasks) {
  // no need to generate device task
  return SUCCESS;
}
}  // namespace host_cpu
}  // namespace ge
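Editor's note: the size computation that CalcOpRunningParam delegates to TensorUtils::CalcTensorMemSize boils down, for dense tensors, to element count times element width. A quick worked sketch under that assumption (alignment padding, if the runtime adds any, is ignored here):

// Sketch: size = num_elements * sizeof(dtype).
// A FLOAT (4-byte) tensor of shape [2, 3, 4] has 2 * 3 * 4 = 24 elements -> 96 bytes.
#include <cstdint>
#include <vector>

int64_t CalcDenseTensorSize(const std::vector<int64_t> &dims, int64_t bytes_per_element) {
  int64_t elements = 1;
  for (int64_t d : dims) {
    elements *= d;  // no overflow or alignment handling in this sketch
  }
  return elements * bytes_per_element;
}
// CalcDenseTensorSize({2, 3, 4}, 4) == 96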
@@ -0,0 +1,51 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_
#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif
#include "common/opskernel/ops_kernel_builder.h"
namespace ge {
namespace host_cpu {
class GE_FUNC_VISIBILITY HostCpuOpsKernelBuilder : public OpsKernelBuilder {
 public:
  Status Initialize(const map<std::string, std::string> &options) override;
  Status Finalize() override;
  Status CalcOpRunningParam(Node &node) override;
  Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) override;
};
}  // namespace host_cpu
}  // namespace ge
#endif  // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_
@@ -0,0 +1,67 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h"
#include <memory>
#include "common/constant/constant.h"
#include "ge/ge_api_types.h"
#include "framework/common/debug/ge_log.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "op/op_factory.h"
namespace ge {
namespace host_cpu {
using domi::TaskDef;
using std::map;
using std::string;
using std::vector;
Status HostCpuOpsKernelInfoStore::Initialize(const map<string, string> &options) {
  GELOGI("HostCpuOpsKernelInfoStore init start.");
  OpInfo default_op_info = {.engine = kHostCpuEngineName,
                            .opKernelLib = kHostCpuOpKernelLibName,
                            .computeCost = 0,
                            .flagPartial = false,
                            .flagAsync = false,
                            .isAtomic = false};
  // Init op_info_map_
  auto all_ops = OpFactory::Instance().GetAllOps();
  for (auto &op : all_ops) {
    op_info_map_[op] = default_op_info;
  }
  GELOGI("HostCpuOpsKernelInfoStore initialized successfully. op num=%zu", op_info_map_.size());
  return SUCCESS;
}
Status HostCpuOpsKernelInfoStore::Finalize() {
  op_info_map_.clear();
  return SUCCESS;
}
void HostCpuOpsKernelInfoStore::GetAllOpsKernelInfo(map<string, OpInfo> &infos) const { infos = op_info_map_; }
bool HostCpuOpsKernelInfoStore::CheckSupported(const OpDescPtr &op_desc, std::string &) const {
  if (op_desc == nullptr) {
    return false;
  }
  return op_info_map_.count(op_desc->GetType()) > 0;
}
}  // namespace host_cpu
}  // namespace ge
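Editor's note: once Initialize has populated op_info_map_ from the factory, the store can be probed directly. A sketch of that usage; the op name "add_0" and type "Add" are illustrative:

// Sketch: probing the store after Initialize.
#include <map>
#include <memory>
#include <string>
#include "graph/op_desc.h"
#include "host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h"

void ProbeStore() {
  ge::host_cpu::HostCpuOpsKernelInfoStore store;
  (void)store.Initialize({});
  std::map<std::string, ge::OpInfo> infos;
  store.GetAllOpsKernelInfo(infos);  // one entry per op type registered in OpFactory
  std::string reason;
  ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>("add_0", "Add");
  bool supported = store.CheckSupported(op_desc, reason);  // true iff "Add" is registered
  (void)supported;
}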
@@ -0,0 +1,86 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_
#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif
#include <map>
#include <string>
#include <vector>
#include "common/opskernel/ops_kernel_info_store.h"
namespace ge {
namespace host_cpu {
class GE_FUNC_VISIBILITY HostCpuOpsKernelInfoStore : public OpsKernelInfoStore {
 public:
  HostCpuOpsKernelInfoStore() {}
  ~HostCpuOpsKernelInfoStore() override = default;
  /**
   * Initialize related resources of the host cpu kernel info store.
   * @return status indicating whether the operation succeeded
   */
  Status Initialize(const std::map<std::string, std::string> &options) override;
  /**
   * Release related resources of the host cpu kernel info store.
   * @return status indicating whether the operation succeeded
   */
  Status Finalize() override;
  /**
   * Check whether an operator is fully or partially supported.
   * @param op_desc OpDesc information
   * @param reason unsupported reason
   * @return bool value indicating whether the operator is fully supported
   */
  bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override;
  /**
   * Returns the full operator information.
   * @param infos reference of a map,
   * containing each operator's name and detailed information
   */
  void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override;
  HostCpuOpsKernelInfoStore(const HostCpuOpsKernelInfoStore &) = delete;
  HostCpuOpsKernelInfoStore(HostCpuOpsKernelInfoStore &&) = delete;
  HostCpuOpsKernelInfoStore &operator=(const HostCpuOpsKernelInfoStore &) = delete;
  HostCpuOpsKernelInfoStore &operator=(HostCpuOpsKernelInfoStore &&) = delete;
 private:
  // store op name and OpInfo key-value pair
  std::map<std::string, ge::OpInfo> op_info_map_;
};
}  // namespace host_cpu
}  // namespace ge
#endif  // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_
@@ -0,0 +1,40 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "host_cpu_engine/ops_kernel_store/op/host_op.h"
#include "framework/common/util.h"
#include "host_cpu_engine/ops_kernel_store/op/op_factory.h"
namespace ge {
namespace host_cpu {
Status HostOp::Run() {
  // the stub host op performs no computation
  return SUCCESS;
}
REGISTER_OP_CREATOR(NoOp, HostOp);
REGISTER_OP_CREATOR(Variable, HostOp);
REGISTER_OP_CREATOR(Constant, HostOp);
REGISTER_OP_CREATOR(Assign, HostOp);
REGISTER_OP_CREATOR(RandomUniform, HostOp);
REGISTER_OP_CREATOR(Add, HostOp);
REGISTER_OP_CREATOR(Mul, HostOp);
REGISTER_OP_CREATOR(ConcatV2, HostOp);
REGISTER_OP_CREATOR(Data, HostOp);
REGISTER_OP_CREATOR(Fill, HostOp);
REGISTER_OP_CREATOR(NetOutput, HostOp);
}  // namespace host_cpu
}  // namespace ge
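Editor's note: REGISTER_OP_CREATOR (defined in op_factory.h, later in this diff) binds an op-type string to a creator function, so extending the supported set is a one-liner inside the same ge::host_cpu namespace. The "Sub" type below is a hypothetical example, not part of this change:

// Hypothetical extra registration; maps the type string "Sub" to the same stub HostOp.
REGISTER_OP_CREATOR(Sub, HostOp);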
@@ -0,0 +1,36 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_
#include "host_cpu_engine/ops_kernel_store/op/op.h"
namespace ge {
namespace host_cpu {
class GE_FUNC_VISIBILITY HostOp : public Op {
 public:
  HostOp(const Node &node, RunContext &run_context) : Op(node, run_context) {}
  ~HostOp() override = default;
  HostOp &operator=(const HostOp &op) = delete;
  HostOp(const HostOp &op) = delete;
  Status Run() override;
};
}  // namespace host_cpu
}  // namespace ge
#endif  // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_
@@ -0,0 +1,45 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_
#include <climits>
#include <string>
#include <vector>
#include "common/ge_inner_error_codes.h"
#include "common/opskernel/ops_kernel_info_types.h"
#include "graph/node.h"
namespace ge {
namespace host_cpu {
/**
 * The base class for all ops.
 */
class GE_FUNC_VISIBILITY Op {
 public:
  Op(const Node &node, RunContext &run_context) : run_context_(run_context), node_(node) {}
  virtual ~Op() = default;
  virtual Status Run() = 0;
 protected:
  const RunContext &run_context_;
  const Node &node_;
};
}  // namespace host_cpu
}  // namespace ge
#endif  // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_
@@ -0,0 +1,55 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "host_cpu_engine/ops_kernel_store/op/op_factory.h"
#include "framework/common/debug/ge_log.h"
#include "common/ge_inner_error_codes.h"
#include "graph/op_desc.h"
namespace ge {
namespace host_cpu {
OpFactory &OpFactory::Instance() {
  static OpFactory instance;
  return instance;
}
std::shared_ptr<Op> OpFactory::CreateOp(const Node &node, RunContext &run_context) {
  auto iter = op_creator_map_.find(node.GetType());
  if (iter != op_creator_map_.end()) {
    return iter->second(node, run_context);
  }
  GELOGE(FAILED, "Not supported OP, type = %s, name = %s", node.GetType().c_str(), node.GetName().c_str());
  return nullptr;
}
void OpFactory::RegisterCreator(const std::string &type, const OP_CREATOR_FUNC &func) {
  if (func == nullptr) {
    GELOGW("Func is NULL.");
    return;
  }
  auto iter = op_creator_map_.find(type);
  if (iter != op_creator_map_.end()) {
    GELOGW("%s creator already exist", type.c_str());
    return;
  }
  op_creator_map_[type] = func;
  all_ops_.emplace_back(type);
}
}  // namespace host_cpu
}  // namespace ge
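Editor's note: a sketch of the lookup-and-run path this factory enables, assuming `node` and `run_context` were obtained from an already-built graph. CreateOp returns nullptr for unregistered types, so the caller checks before running:

// Sketch: create an op by node type, then run it.
std::shared_ptr<ge::host_cpu::Op> op = ge::host_cpu::OpFactory::Instance().CreateOp(node, run_context);
if (op != nullptr) {
  ge::Status ret = op->Run();  // HostOp::Run is a stub that simply returns SUCCESS
  (void)ret;
}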
@@ -0,0 +1,94 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "common/ge/ge_util.h"
#include "host_cpu_engine/ops_kernel_store/op/op.h"
namespace ge {
namespace host_cpu {
using OP_CREATOR_FUNC = std::function<std::shared_ptr<Op>(const Node &, RunContext &)>;
/**
 * Manages all ops and supports op creation.
 */
class GE_FUNC_VISIBILITY OpFactory {
 public:
  static OpFactory &Instance();
  /**
   * @brief create Op.
   * @param [in] node reference of the node
   * @param [in] run_context run context
   * @return not nullptr success
   * @return nullptr fail
   */
  std::shared_ptr<Op> CreateOp(const Node &node, RunContext &run_context);
  /**
   * @brief Register Op create function.
   * @param [in] type Op type
   * @param [in] func Op create func
   */
  void RegisterCreator(const std::string &type, const OP_CREATOR_FUNC &func);
  const std::vector<std::string> &GetAllOps() const { return all_ops_; }
  bool CheckSupported(const std::string &type) { return op_creator_map_.find(type) != op_creator_map_.end(); }
  OpFactory(const OpFactory &) = delete;
  OpFactory &operator=(const OpFactory &) = delete;
  OpFactory(OpFactory &&) = delete;
  OpFactory &operator=(OpFactory &&) = delete;
 private:
  OpFactory() = default;
  ~OpFactory() = default;
  // the op creator function map
  std::map<std::string, OP_CREATOR_FUNC> op_creator_map_;
  std::vector<std::string> all_ops_;
};
class GE_FUNC_VISIBILITY OpRegistrar {
 public:
  OpRegistrar(const std::string &type, const OP_CREATOR_FUNC &func) {
    OpFactory::Instance().RegisterCreator(type, func);
  }
  ~OpRegistrar() = default;
  OpRegistrar(const OpRegistrar &) = delete;
  OpRegistrar &operator=(const OpRegistrar &) = delete;
  OpRegistrar(OpRegistrar &&) = delete;
  OpRegistrar &operator=(OpRegistrar &&) = delete;
};
#define REGISTER_OP_CREATOR(type, clazz)                                              \
  std::shared_ptr<Op> Creator_##type##Op(const Node &node, RunContext &run_context) { \
    return MakeShared<clazz>(node, run_context);                                      \
  }                                                                                   \
  OpRegistrar g_##type##Op_creator(#type, Creator_##type##Op)
}  // namespace host_cpu
}  // namespace ge
#endif  // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_
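Editor's note: for reference, this is (roughly) what one registration from host_op.cc looks like after macro substitution:

// REGISTER_OP_CREATOR(NoOp, HostOp) expands to:
std::shared_ptr<Op> Creator_NoOpOp(const Node &node, RunContext &run_context) {
  return MakeShared<HostOp>(node, run_context);
}
OpRegistrar g_NoOpOp_creator("NoOp", Creator_NoOpOp);
// The OpRegistrar global is constructed during static initialization, so every
// registered type is already present in OpFactory's map before any graph runs.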
@@ -0,0 +1,179 @@
/* Copyright 2021 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the Apache License Version 2.0. You may not use this file
 * except in compliance with the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * Apache License for more details at
 * http://www.apache.org/licenses/LICENSE-2.0
 */
| syntax = "proto3"; | |||||
| package domi; | |||||
| message ModelTaskDef { | |||||
| string version = 1; | |||||
| map<string, string> attr = 9; // Extended field | |||||
| repeated TaskDef task = 10; | |||||
| uint64 memory_size = 11; | |||||
| uint32 stream_num = 12; | |||||
| uint32 event_num = 13; | |||||
| uint64 weight_size = 14; | |||||
| repeated bytes op = 15; // input/output opdef in bytes | |||||
| uint64 base_addr = 16; // base addr | |||||
| uint64 weight_addr = 17; // weight addr | |||||
| uint32 batch_num = 18; | |||||
| } | |||||
| message TaskDef { | |||||
| uint32 id = 1; | |||||
| uint32 type = 2; | |||||
| uint32 stream_id = 10; | |||||
| uint32 event_id = 11; | |||||
| KernelDef kernel = 20; | |||||
| KernelExDef kernel_ex = 21; | |||||
| KernelHcclDef kernel_hccl = 25; | |||||
| EventExDef event_ex = 26; | |||||
| LogTimeStampDef log_timestamp = 28; | |||||
| uint32 label_id = 30; | |||||
| MemcpyAsyncDef memcpy_async = 31; | |||||
| StreamSwitchDef stream_switch = 32; | |||||
| StreamActiveDef stream_active = 33; | |||||
| bytes private_def = 34; | |||||
| uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future | |||||
| StreamSwitchNDef stream_switch_n = 36; | |||||
| LabelSetDef label_set = 37; | |||||
| LabelGotoExDef label_goto_ex = 38; | |||||
| LabelSwitchByIndexDef label_switch_by_index = 39; | |||||
| KernelDefWithHandle kernel_with_handle = 40; | |||||
| } | |||||
| message KernelDef { | |||||
| KernelContext context = 1; | |||||
| string stub_func = 10; | |||||
| uint32 block_dim = 11; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes sm_desc = 14; | |||||
| bytes flowtable = 15; | |||||
| string so_name = 16; | |||||
| string kernel_name = 17; | |||||
| bytes kernel_ext_info = 18; | |||||
| uint32 kernel_ext_info_size = 19; | |||||
| } | |||||
| message KernelDefWithHandle { | |||||
| KernelContext context = 1; | |||||
| uint64 handle = 10; | |||||
| string dev_func = 11; | |||||
| uint32 block_dim = 12; | |||||
| uint32 args_size = 13; | |||||
| bytes args = 14; | |||||
| bytes sm_desc = 15; | |||||
| string original_kernel_key = 16; | |||||
| string node_info = 17; | |||||
| } | |||||
| message KernelContext { | |||||
| uint32 kernel_type = 1; | |||||
| uint32 op_id = 2; // OP type in CCE | |||||
| uint32 kernel_func_id = 3; | |||||
| uint32 op_index = 4; // TE/Custom operator | |||||
| bool is_flowtable = 5; // Identify whether args is a flowtable structure | |||||
| bytes args_offset = 6; // args offset information | |||||
| uint32 args_count = 7; // args count | |||||
| repeated uint32 origin_op_index = 8; | |||||
| } | |||||
| message KernelExDef { | |||||
| uint32 flags = 1; | |||||
| uint32 op_index = 4; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput | |||||
| uint32 task_info_size = 15; | |||||
| bytes kernel_ext_info = 16; | |||||
| uint32 kernel_ext_info_size = 17; | |||||
| } | |||||
| message KernelHcclDef { | |||||
| uint32 op_index = 8; | |||||
| string hccl_type = 9; | |||||
| } | |||||
| message EventExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 event_type = 2; | |||||
| } | |||||
| message LogTimeStampDef { | |||||
| uint64 logid = 1; | |||||
| bool notify = 2; | |||||
| uint32 flat = 3; | |||||
| } | |||||
| message MemcpyAsyncDef { | |||||
| uint64 dst = 1; | |||||
| uint64 dst_max = 2; | |||||
| uint64 src = 3; | |||||
| uint64 count = 4; | |||||
| uint32 kind = 5; | |||||
| uint32 op_index = 6; | |||||
| } | |||||
| message StreamSwitchDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 true_stream_id = 2; | |||||
| int64 value = 3; | |||||
| uint64 value_ptr = 4; | |||||
| uint32 data_type = 5; | |||||
| } | |||||
| message StreamActiveDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 active_stream_id = 2; | |||||
| } | |||||
| message StreamSwitchNDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 size = 2; | |||||
| repeated int64 target_value = 3; | |||||
| repeated uint32 true_stream_id = 4; | |||||
| uint32 element_size = 5; | |||||
| uint32 data_type = 6; | |||||
| } | |||||
| message LabelSetDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelGotoExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelSwitchByIndexDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_max = 2; | |||||
| } | |||||
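Editor's note: the C++ API for these messages is generated by protoc following the standard protobuf conventions (set_<field>, mutable_<field>). A minimal sketch of filling a TaskDef; all field values below are illustrative:

// Sketch: building a TaskDef with the generated API (task.pb.h).
#include "task.pb.h"

domi::TaskDef MakeKernelTask() {
  domi::TaskDef task;
  task.set_id(0);
  task.set_type(0);  // illustrative; task type codes are defined by the runtime
  task.set_stream_id(0);
  domi::KernelDef *kernel = task.mutable_kernel();
  kernel->set_so_name("libstub_kernels.so");  // illustrative name
  kernel->set_kernel_name("stub_kernel");     // illustrative name
  kernel->set_block_dim(1);
  return task;
}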
@@ -0,0 +1,711 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file array_ops.h
 * \brief
 */
#ifndef OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_
#include "graph/operator_reg.h"
#include "graph/operator.h"
namespace ge {
/**
*@brief Finds unique elements in a 1D tensor. \n
*@par Inputs:
*x: 1D tensor. \n
*@par Attributes:
*out_idx: An optional DType from: "int32, int64". Defaults to "int32". \n
*@par Outputs:
*@li y: The unique elements of "x".
*@li idx: A tensor the same size as "x". The index of each value of "x". \n
*@attention Constraints:
*Unique runs on the Ascend AI CPU, which delivers poor performance. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Unique.
*/
REG_OP(Unique)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \
        DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \
        DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE}))
    .OUTPUT(idx, TensorType({DT_INT32, DT_INT64}))
    .ATTR(out_idx, Type, DT_INT32)
    .OP_END_FACTORY_REG(Unique)
/**
*@brief Creates a constant tensor from a tensor-like object. This operator is used for inference.
Operator Const has the same definition as operator Constant. \n
*@par Attributes:
*value: Required. The value and type of the resulting tensor, with no restrictions on type. \n
*@par Outputs:
*y: A constant tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Const.
*/
REG_OP(Const)
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \
        DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .ATTR(value, Tensor, Tensor())
    .OP_END_FACTORY_REG(Const)
/**
*@brief Creates a constant tensor for training. \n
*@par Attributes:
*value: Required. The value and type of the resulting tensor, with no restrictions on type. \n
*@par Outputs:
*y: The constant tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Const.
*/
REG_OP(Constant)
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \
        DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .ATTR(value, Tensor, Tensor())
    .OP_END_FACTORY_REG(Constant)
/**
*@brief Returns a copy of the input tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Outputs:
*y: A tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Snapshot.
*/
REG_OP(Snapshot)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \
        DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \
        DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OP_END_FACTORY_REG(Snapshot)
/**
*@brief Gives a guarantee to the runtime that the input tensor is a constant. \n
*@par Inputs:
*x: A tensor. \n
*@par Outputs:
*y: The input tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator GuaranteeConst.
*/
REG_OP(GuaranteeConst)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OP_END_FACTORY_REG(GuaranteeConst)
/**
*@brief Returns the target shape for broadcasting shapes "x1" and "x2". \n
*@par Inputs:
*@li x1: A tensor of type int32 or int64. A shape.
*@li x2: A tensor of the same type as "x1". The other shape. \n
*@par Outputs:
*y: A tensor. The broadcasted shape. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator BroadcastArgs.
*/
REG_OP(BroadcastArgs)
    .INPUT(x1, TensorType({DT_INT32, DT_INT64}))
    .INPUT(x2, TensorType({DT_INT32, DT_INT64}))
    .OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
    .OP_END_FACTORY_REG(BroadcastArgs)
/**
*@brief Outputs its input tensor as is and triggers an error if a gradient is requested. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*message: Will be printed in the error at the attempt to request a gradient. \n
*@par Outputs:
*y: The input tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator PreventGradient.
*/
REG_OP(PreventGradient)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .ATTR(message, String, "")
    .OP_END_FACTORY_REG(PreventGradient)
/**
*@brief Returns the reduction indices for computing gradients of "x1" and "x2" with broadcast. \n
*@par Inputs:
*@li x1: A tensor of type int32 or int64.
*@li x2: A tensor of type int32 or int64.
"x2" has the same type as "x1". \n
*@par Outputs:
*@li y1: A tensor. Reduction indices of "x1".
*@li y2: A tensor. Reduction indices of "x2". \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator BroadcastGradientArgs.
*/
REG_OP(BroadcastGradientArgs)
    .INPUT(x1, TensorType({DT_INT32, DT_INT64}))
    .INPUT(x2, TensorType({DT_INT32, DT_INT64}))
    .OUTPUT(y1, TensorType({DT_INT32, DT_INT64}))
    .OUTPUT(y2, TensorType({DT_INT32, DT_INT64}))
    .OP_END_FACTORY_REG(BroadcastGradientArgs)
/**
*@brief Stops gradient computation. None is returned for the node where the gradient computation is stopped.
*@par Inputs:
*x: A tensor. \n
*@par Outputs:
*y: The input tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator StopGradient.
*/
REG_OP(StopGradient)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OP_END_FACTORY_REG(StopGradient)
/**
*@brief Returns a tensor with the same shape and contents as the input. \n
*@par Inputs:
*x: A tensor. \n
*@par Outputs:
*y: A tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Identity.
*/
REG_OP(Identity)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OP_END_FACTORY_REG(Identity)
/**
*@brief Returns a list of tensors with the same shapes and contents as the input tensors. \n
*@par Inputs:
*x: A list of input tensors. It's a dynamic input. \n
*@par Outputs:
*y: A list of Tensor objects, with the same length as the input tensor list.
It's a dynamic output. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator IdentityN.
*/
REG_OP(IdentityN)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OP_END_FACTORY_REG(IdentityN)
/**
*@brief Inserts a dimension of 1 into a tensor's shape. Only the tensor shape is changed, without changing the data. \n
*@par Inputs:
*@li x: A tensor.
*@li axis: The dimension index at which to expand. \n
*@par Outputs:
*y: A tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ExpandDims.
*/
REG_OP(ExpandDims)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32,
        DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .INPUT(axis, TensorType({DT_INT32, DT_INT64}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32,
        DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OP_END_FACTORY_REG(ExpandDims)
/**
*@brief Inserts a dimension of 1 into a tensor's shape. Only the tensor shape is changed, without changing the data. \n
*@par Inputs:
*@li x: Original tensor.
*@li axis: List of ints. \n
*@par Outputs:
*y: Reshaped tensor with the same data as the input. \n
*@par Third-party framework compatibility
*Compatible with the Onnx operator Unsqueeze.
*/
REG_OP(Unsqueeze)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_UINT8, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_UINT8, DT_BOOL}))
    .ATTR(axes, ListInt, {})
    .OP_END_FACTORY_REG(Unsqueeze)
/**
*@brief Reshapes a tensor. Only the tensor shape is changed, without changing the data. \n
*@par Inputs:
*@li x: A tensor.
*@li shape: A tensor. Defines the shape of the output tensor. \n
*@par Attributes:
*@li axis: An optional int32 or int64. The first dimension to reshape. Defaults to "0".
*@li num_axes: An optional int32 or int64. The extent of the reshape. Defaults to "-1". \n
*@par Outputs:
*y: A tensor. \n
*@par Attention:
*This operator cannot be directly called by the aclopExecute API. \n
*@par Third-party framework compatibility
*@li Compatible with the TensorFlow operator Reshape.
*@li Compatible with the Caffe operator Reshape.
*/
REG_OP(Reshape)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32,
        DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .INPUT(shape, TensorType({DT_INT32, DT_INT64}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32,
        DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .ATTR(axis, Int, 0)
    .ATTR(num_axes, Int, -1)
    .OP_END_FACTORY_REG(Reshape)
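Editor's note: each REG_OP(...) above generates an operator class under ge::op with typed input/attr setters, so these stub protos can be used directly with GE's graph-construction API. A sketch of wiring Data through Reshape; the Graph calls assume the standard GE graph API, and the Const value tensor is omitted for brevity:

// Sketch: building a tiny graph from the stub op protos above.
#include "graph/graph.h"
#include "array_ops.h"  // provides the ge::op::Data / Const / Reshape classes via REG_OP

ge::Graph BuildReshapeGraph() {
  auto data = ge::op::Data("data_0");        // class generated by REG_OP(Data)
  auto shape = ge::op::Const("shape_0");     // would also need its value set via set_attr_value
  auto reshape = ge::op::Reshape("reshape_0");
  reshape.set_input_x(data);                 // setter generated from .INPUT(x, ...)
  reshape.set_input_shape(shape);            // setter generated from .INPUT(shape, ...)
  ge::Graph graph("reshape_graph");
  graph.SetInputs({data}).SetOutputs({reshape});
  return graph;
}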
/**
*@brief Removes dimensions of size 1 from the shape of a tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*axis: An optional list of int32 or int64. If not specified, squeezes all dimensions of size 1. If specified, only squeezes the dimensions listed. It is an error to squeeze a dimension that is not 1. \n
*@par Outputs:
*y: A tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Squeeze.
*/
REG_OP(Squeeze)
    .INPUT(x, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(axis, ListInt, {})
    .OP_END_FACTORY_REG(Squeeze)
/**
*@brief Returns an integer representing the rank of the input tensor. The rank of a tensor is the number of indices required to uniquely select each element of the tensor, that is, the dimension size of the tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Outputs:
*y: A tensor. The rank of the input tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Rank.
*/
REG_OP(Rank)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .OP_END_FACTORY_REG(Rank)
/**
*@brief Returns the size of a tensor, that is, an integer of the number of elements of the tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*dtype: An optional int32 or int64. The output data type. Defaults to "int32". \n
*@par Outputs:
*y: A tensor. The size of the input tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Size.
*/
REG_OP(Size)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
    .ATTR(dtype, Int, DT_INT32)
    .OP_END_FACTORY_REG(Size)
/**
*@brief Input data for other operators. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*index: Index of the input tensor. The data type must be int32 or int64.
Assume that the net has three data nodes; the first should be set to 0, the
second to 1, and the third to 2. \n
*@par Outputs:
*y: A tensor. \n
*@par Third-party framework compatibility
*Compatible with the Caffe operator Data.
*/
REG_OP(Data)
    .INPUT(x, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(index, Int, 0)
    .OP_END_FACTORY_REG(Data)
/**
*@brief Inserts a placeholder for a tensor that will always be fed. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*@li peerIndex: An integer type. The index of the corresponding "end" node connected to.
*@li parentId: A string, used to check if the nodes are from the saved parent node.
*@li parentOpType: A string. Op type of the original node.
*@li anchorIndex: An integer, used to check if the node is from the saved anchor. \n
*@par Outputs:
*y: The created placeholder tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator PlaceHolder.
*/
REG_OP(PlaceHolder)
    .INPUT(x, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(peerIndex, Int, 0)      // the index of the corresponding 'end' node it's connected to
    .ATTR(parentId, String, "")   // check if these nodes are from the saved parent node
    .ATTR(parentOpType, String, "")  // op type of the original node
    .ATTR(anchorIndex, Int, 0)    // check if these nodes are from the saved anchor
    .OP_END_FACTORY_REG(PlaceHolder)
/**
*@brief Inserts a placeholder with a default value for a tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*shape: A required list of ints. The tensor shape. \n
*@par Outputs:
*y: The created placeholder tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator PlaceholderWithDefault.
*/
REG_OP(PlaceholderWithDefault)
    .INPUT(x, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .REQUIRED_ATTR(shape, ListInt)
    .OP_END_FACTORY_REG(PlaceholderWithDefault)
/**
*@brief Reads and returns the value of the input variable tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*dtype: An optional int32 or int64. The output data type. Defaults to int32. \n
*@par Outputs:
*y: A tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ReadVariableOp.
*/
REG_OP(ReadVariableOp)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .ATTR(dtype, Int, DT_INT32)
    .OP_END_FACTORY_REG(ReadVariableOp)
/**
*@brief Marks the outputs of one subgraph partitioned by engine type.
*@par Inputs:
*x: A tensor. \n
*@par Outputs:
*y: A tensor. \n
*@par Attributes:
*@li peerIndex: The index of the corresponding 'placeholder' node it's connected to.
*@li parentOpType: Op type of the original node.
*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(End)
    .INPUT(x, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(peerIndex, Int, 0)
    .ATTR(parentOpType, String, "")
    .OP_END_FACTORY_REG(End)
/**
*@brief Operations for writing summary data, for use in analysis and visualization.
*@par Inputs:
* One input:
*x: Collections of summary data.
*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(Summary)
    .INPUT(x, TensorType::ALL())
    .OP_END_FACTORY_REG(Summary)
/**
*@brief Returns the shape of a tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*dtype: An optional int32 or int64. The output data type. Defaults to int32. \n
*@par Outputs:
*y: A tensor. The shape of the input tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Shape.
*/
REG_OP(Shape)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
        DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
    .ATTR(dtype, Int, DT_INT32)
    .OP_END_FACTORY_REG(Shape)
| /** | |||||
| *@brief Returns shape of tensors. \n | |||||
| *@par Inputs: | |||||
| *x: A list of input tensors. It's a dynamic input. \n | |||||
| *@par Attributes: | |||||
| *dtype: An optional int32 or int64. The output data type. Defaults to "int32". \n | |||||
| *@par Outputs: | |||||
| *y: A list of tensors with the same length as the input list of tensors. | |||||
| It's a dynamic output. \n | |||||
| *@par Third-party framework compatibility | |||||
| *Compatible with the TensorFlow operator ShapeN. | |||||
| */ | |||||
| REG_OP(ShapeN) | |||||
| .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .DYNAMIC_OUTPUT(y, TensorType({DT_INT32, DT_INT64})) | |||||
| .ATTR(dtype, Int, DT_INT32) | |||||
| .OP_END_FACTORY_REG(ShapeN) | |||||
| /** | |||||
| *@brief Creates a tensor with the given "shape" and "dtype". \n | |||||
| *@par Inputs: | |||||
| *shape: The shape of the output tensor. \n | |||||
| *@par Attributes: | |||||
| *@li dtype: Optional. The data type of the output tensor. Defaults to "int32". | |||||
| *@li init: An optional bool. If true, initializes the returned tensor with the default value of "dtype". Defaults to "false". \n | |||||
| *@par Outputs: | |||||
| *y: A tensor. \n | |||||
| *@par Third-party framework compatibility | |||||
| *Compatible with the TensorFlow operator Empty. | |||||
| */ | |||||
| REG_OP(Empty) | |||||
| .INPUT(shape, TensorType({DT_INT32})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .ATTR(dtype, Int, DT_INT32) | |||||
| .ATTR(init, Bool, 0) | |||||
| .OP_END_FACTORY_REG(Empty) | |||||
| /** | |||||
| *@brief Returns locations of nonzero / true values in a tensor. \n | |||||
| *@par Inputs: | |||||
| *Including: | |||||
| *x: A Tensor. Must be one of the following types: | |||||
| DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, | |||||
| DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL. \n | |||||
| *@par Outputs: | |||||
| *y: A Tensor of type DT_INT64. \n | |||||
| *@attention Constraints: | |||||
| *Where runs on the Ascend AI CPU, which delivers poor performance.\n | |||||
| *@par Third-party framework compatibility | |||||
| *Compatible with the TensorFlow operator Where. | |||||
| */ | |||||
| REG_OP(Where) | |||||
| .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ | |||||
| DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL})) | |||||
| .OUTPUT(y, TensorType({DT_INT64})) | |||||
| .OP_END_FACTORY_REG(Where) | |||||
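| // Semantics example: for x = [[1, 0], [0, 2]], y = [[0, 0], [1, 1]] -- one row of | |||||
| // coordinates per nonzero/true element of "x". | |||||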
| /** | |||||
| *@brief Changes the shape of the output according to the attribute "outShape". | |||||
| * | |||||
| *@par Inputs: | |||||
| *x: A Tensor. \n | |||||
| *@par Attributes: | |||||
| *outShape: A list of ints. The output shape is inferred from this attribute. | |||||
| *@par Outputs: | |||||
| *y: A Tensor. Has the same type as "x". \n | |||||
| */ | |||||
| REG_OP(TransShape) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(outShape, ListInt, {}) | |||||
| .OP_END_FACTORY_REG(TransShape) | |||||
| /** | |||||
| * @brief Sorts a tensor along its last dimension. | |||||
| * @par Inputs: | |||||
| * @li x: An ND tensor of type float16, float32 or float64. | |||||
| * @par Attributes: | |||||
| * @li axis: An optional int. The dimension to sort along. Defaults to -1. | |||||
| * @li descending: An optional bool. Controls the sorting order (ascending or descending). Defaults to "false". | |||||
| * @par Outputs: | |||||
| * @li y: An ND tensor with the same type and shape as "x". | |||||
| * @attention Constraints: | |||||
| * @li "axis" must select the last dimension. | |||||
| * @li This TBE operator is recommended when the data to be sorted is smaller than 150K; | |||||
| descending sort performs better than ascending sort. | |||||
| * @li On Ascend 910, the data volume is capped at 2000K. | |||||
| */ | |||||
| REG_OP(SortV2) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) | |||||
| .ATTR(axis, Int, -1) | |||||
| .ATTR(descending, Bool, false) | |||||
| .OP_END_FACTORY_REG(SortV2) | |||||
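| // Semantics example: with the defaults (axis = -1, descending = false), | |||||
| // x = [3, 1, 2] sorts to y = [1, 2, 3]; with descending = true, y = [3, 2, 1]. | |||||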
| /** | |||||
| * @brief Expand the input tensor to a compatible shape. \n | |||||
| * @par Inputs: | |||||
| * Two inputs, including: | |||||
| * @li x: A Tensor. Must be one of the following types: | |||||
| * float16, float32, int32, int8, uint8. \n | |||||
| * @li shape: A Tensor that specifies the shape to which "x" is expanded. \n | |||||
| * @par Outputs: | |||||
| * @li y: A Tensor. Has the same type as "x", with the shape specified by "shape". \n | |||||
| * @par Third-party framework compatibility | |||||
| * Compatible with the ONNX operator Expand. | |||||
| */ | |||||
| REG_OP(Expand) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) | |||||
| .INPUT(shape, TensorType({DT_INT16, DT_INT32, DT_INT64})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) | |||||
| .OP_END_FACTORY_REG(Expand) | |||||
| /** | |||||
| * @brief Expand the input tensor to a compatible shape. \n | |||||
| * @par Inputs: | |||||
| * One input, including: | |||||
| * @li x: A Tensor. Must be one of the following types: | |||||
| * float16, float32, int32, int8, uint8. \n | |||||
| * @par Attributes: | |||||
| * @li shape: A required list of ints, specifying the shape to which "x" is expanded. \n | |||||
| * @par Outputs: | |||||
| * @li y: A Tensor. Has the same type as "x", with the shape specified by the "shape" attribute. \n | |||||
| * @par Third-party framework compatibility | |||||
| * Compatible with the ONNX operator Expand. | |||||
| */ | |||||
| REG_OP(ExpandD) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) | |||||
| .REQUIRED_ATTR(shape, ListInt) | |||||
| .OP_END_FACTORY_REG(ExpandD) | |||||
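| // The two variants differ only in where the target shape lives: Expand takes it as a | |||||
| // runtime tensor input, ExpandD as a compile-time attribute. A hypothetical comparison: | |||||
| //   expand.set_input_x(x).set_input_shape(shape_const);  // shape known at runtime | |||||
| //   expand_d.set_input_x(x).set_attr_shape({2, 3, 4});   // shape fixed at build time | |||||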
| } // namespace ge | |||||
| #endif // OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_ | |||||
| @@ -0,0 +1,392 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| /*! | |||||
| * \file control_flow_ops.cc | |||||
| * \brief | |||||
| */ | |||||
| #include "control_flow_ops.h" | |||||
| #include "./util/common_shape_fns.h" | |||||
| #include "./util/error_util.h" | |||||
| #include "util/util.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| graphStatus MergeInferImpl(Operator& op) { | |||||
| TensorDesc td = op.GetOutputDesc("value_index"); | |||||
| TensorDesc td_y = op.GetOutputDesc("y"); | |||||
| td.SetShape(ge::Shape()); | |||||
| td.SetDataType(DT_INT32); | |||||
| auto ret = op.UpdateOutputDesc("value_index", td); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| // check N of "x" >= 1 | |||||
| size_t in_num = op.GetInputsSize(); | |||||
| if (in_num < 1) { | |||||
| string reason = "inputs size[" + std::to_string(in_num) + "] must be greater than or equal to 1"; | |||||
| GeInfershapeErrReport(op.GetName(), op.GetOpType(), "input", reason); | |||||
| return GRAPH_FAILED; | |||||
| } else if (in_num == 2) { | |||||
| // Check whether this is a loop merge. The InferShape order is Enter->Merge->NextIteration, | |||||
| // so when InferShape runs on the Merge op, the shape & datatype of the NextIteration op are still the defaults. | |||||
| // Therefore, the shape & datatype of the Merge op should be taken from the Enter op. | |||||
| auto x0_type = op.GetDynamicInputDesc("x", 0).GetDataType(); | |||||
| auto x0_dims = op.GetDynamicInputDesc("x", 0).GetShape().GetDims(); | |||||
| bool not_handle_flag0 = (x0_type == DT_FLOAT) && (x0_dims.size() == 0); | |||||
| auto x1_type = op.GetDynamicInputDesc("x", 1).GetDataType(); | |||||
| auto x1_dims = op.GetDynamicInputDesc("x", 1).GetShape().GetDims(); | |||||
| bool not_handle_flag1 = (x1_type == DT_FLOAT) && (x1_dims.size() == 0); | |||||
| if ((x0_type != x1_type) && (not_handle_flag0 || not_handle_flag1)) { | |||||
| if (not_handle_flag0) { | |||||
| td_y.SetShape(ge::Shape(x1_dims)); | |||||
| td_y.SetDataType(x1_type); | |||||
| } else { | |||||
| td_y.SetShape(ge::Shape(x0_dims)); | |||||
| td_y.SetDataType(x0_type); | |||||
| } | |||||
| (void)op.UpdateOutputDesc("y", td_y); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } | |||||
| // check "x" be same type | |||||
| auto x0_type = op.GetDynamicInputDesc("x", 0).GetDataType(); | |||||
| for (size_t i = 1; i < op.GetInputsSize(); i++) { | |||||
| auto xi_type = op.GetDynamicInputDesc("x", i).GetDataType(); | |||||
| if (xi_type != x0_type) { | |||||
| string reason = "x[0]'s dtype[" + std::to_string(x0_type) + "] must be equal to x[" + std::to_string(i) + | |||||
| "]'s dtype[" + std::to_string(xi_type) + "]"; | |||||
| GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", reason); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| // infer "y" be unknown shape | |||||
| auto x0_dims = op.GetDynamicInputDesc("x", 0).GetShape().GetDims(); | |||||
| bool x0_unknown = (x0_dims.size() == 1) && (x0_dims[0] == 0); | |||||
| if (x0_unknown) { | |||||
| Shape unknown_shape(ge::UNKNOWN_SHAPE); | |||||
| td_y.SetShape(unknown_shape); | |||||
| td_y.SetDataType(x0_type); | |||||
| (void)op.UpdateOutputDesc("y", td_y); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| // find the input with the max size among all inputs, and set its data type/shape on the output | |||||
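| // std::map orders its keys ascending, so rbegin() below yields the entry with the | |||||
| // largest byte size; the count() guard keeps the smallest input index when sizes tie. | |||||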
| std::map<int64_t, size_t> size_to_index; | |||||
| for (size_t i = 0; i < op.GetInputsSize(); i++) { | |||||
| auto xi_dims = op.GetDynamicInputDesc("x", i).GetShape().GetDims(); | |||||
| bool xi_unknown = (xi_dims.size() == 1) && (xi_dims[0] == 0); | |||||
| if (xi_unknown) { | |||||
| continue; | |||||
| } | |||||
| int64_t size = static_cast<int64_t>(GetSizeByDataType(op.GetDynamicInputDesc("x", i).GetDataType())); | |||||
| if (size < 0) { | |||||
| continue; | |||||
| } | |||||
| if (!xi_dims.empty()) { | |||||
| for (auto& dim : xi_dims) { | |||||
| if (dim <= 0) { | |||||
| size = -1; | |||||
| break; | |||||
| } | |||||
| if (size != 0 && INT64_MAX / size < dim) { | |||||
| GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dim", "the dim size is overflow"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| size *= dim; | |||||
| } | |||||
| if (size < 0) { | |||||
| continue; | |||||
| } | |||||
| } | |||||
| if (size_to_index.count(size) == 0) { | |||||
| size_to_index[size] = i; | |||||
| } | |||||
| } | |||||
| if (size_to_index.empty()) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto index = size_to_index.rbegin()->second; | |||||
| td_y.SetShape(ge::Shape(op.GetDynamicInputDesc("x", index).GetShape().GetDims())); | |||||
| td_y.SetDataType(op.GetDynamicInputDesc("x", index).GetDataType()); | |||||
| (void)op.UpdateOutputDesc("y", td_y); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus SwitchInferImpl(Operator& op) { | |||||
| auto op_desc = OpDescUtils::GetOpDescFromOperator(op); | |||||
| auto data_desc = op_desc->MutableInputDesc("data"); | |||||
| auto pred_desc = op_desc->MutableInputDesc("pred"); | |||||
| auto output_false_desc = op_desc->MutableOutputDesc("output_false"); | |||||
| auto output_true_desc = op_desc->MutableOutputDesc("output_true"); | |||||
| std::vector<std::pair<int64_t, int64_t>> data_range; | |||||
| data_desc->GetShapeRange(data_range); | |||||
| // check "pred" scalar type be bool | |||||
| auto pred_dims = pred_desc->GetShape().GetDims(); | |||||
| if (pred_dims.size() != 0) { | |||||
| GeInfershapeErrReport(op.GetName(), op.GetOpType(), "pred dims", "pred should be a scalar"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| DataType pred_type = pred_desc->GetDataType(); | |||||
| if (pred_type != DT_BOOL) { | |||||
| GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", "pred should be bool type"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| DataType data_type = data_desc->GetDataType(); | |||||
| auto data_dims = data_desc->GetShape().GetDims(); | |||||
| output_false_desc->SetShapeRange(data_range); | |||||
| output_true_desc->SetShapeRange(data_range); | |||||
| output_false_desc->SetShape(GeShape(data_dims)); | |||||
| output_false_desc->SetOriginShape(GeShape(data_dims)); | |||||
| output_true_desc->SetShape(GeShape(data_dims)); | |||||
| output_true_desc->SetOriginShape(GeShape(data_dims)); | |||||
| output_false_desc->SetDataType(data_type); | |||||
| output_true_desc->SetDataType(data_type); | |||||
| auto context = op.GetInferenceContext(); | |||||
| std::vector<std::vector<ShapeAndType>> in_shapes_and_types = context->GetInputHandleShapesAndTypes(); | |||||
| if ((!in_shapes_and_types.empty()) && (!in_shapes_and_types.at(0).empty())) { | |||||
| ShapeAndType shape_and_type = in_shapes_and_types.at(0).at(0); | |||||
| std::vector<ShapeAndType> grad_handle_shape_and_type; | |||||
| grad_handle_shape_and_type.reserve(1); | |||||
| grad_handle_shape_and_type.emplace_back(shape_and_type); | |||||
| std::vector<std::vector<ShapeAndType>> shapes_and_types(2); | |||||
| shapes_and_types[0] = grad_handle_shape_and_type; | |||||
| shapes_and_types[1] = grad_handle_shape_and_type; | |||||
| context->SetOutputHandleShapesAndTypes(shapes_and_types); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus EnterInferImpl(Operator& op) { | |||||
| auto op_desc = OpDescUtils::GetOpDescFromOperator(op); | |||||
| auto input_desc_x = op_desc->MutableInputDesc("x"); | |||||
| auto output_desc_y = op_desc->MutableOutputDesc("y"); | |||||
| std::vector<std::pair<int64_t, int64_t>> x_range; | |||||
| std::vector<std::pair<int64_t, int64_t>> y_range; | |||||
| input_desc_x->GetShapeRange(x_range); | |||||
| auto input_dims = input_desc_x->MutableShape().GetDims(); | |||||
| DataType input_type = input_desc_x->GetDataType(); | |||||
| output_desc_y->SetShape(ge::GeShape(input_dims)); | |||||
| output_desc_y->SetOriginShape(ge::GeShape(input_dims)); | |||||
| output_desc_y->SetDataType(input_type); | |||||
| if (!x_range.empty()) { | |||||
| output_desc_y->SetShapeRange(x_range); | |||||
| } | |||||
| auto context = op.GetInferenceContext(); | |||||
| std::vector<std::vector<ShapeAndType>> in_shapes_and_types = context->GetInputHandleShapesAndTypes(); | |||||
| if ((!in_shapes_and_types.empty()) && (!in_shapes_and_types.at(0).empty())) { | |||||
| ShapeAndType shape_and_type = in_shapes_and_types.at(0).at(0); | |||||
| std::vector<ShapeAndType> grad_handle_shape_and_type; | |||||
| grad_handle_shape_and_type.reserve(1); | |||||
| grad_handle_shape_and_type.emplace_back(shape_and_type); | |||||
| std::vector<std::vector<ShapeAndType>> shapes_and_types(1); | |||||
| shapes_and_types[0] = grad_handle_shape_and_type; | |||||
| context->SetOutputHandleShapesAndTypes(shapes_and_types); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus PassThroughInferImpl(Operator& op, const std::string& in_name, const std::string& out_name) { | |||||
| auto input_dims = op.GetInputDesc(in_name).GetShape().GetDims(); | |||||
| DataType input_type = op.GetInputDesc(in_name).GetDataType(); | |||||
| TensorDesc tensordesc_output = op.GetOutputDesc(out_name); | |||||
| tensordesc_output.SetShape(ge::Shape(input_dims)); | |||||
| tensordesc_output.SetDataType(input_type); | |||||
| (void)op.UpdateOutputDesc(out_name, tensordesc_output); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus LoopCondInferImpl(Operator& op) { | |||||
| auto input_dims = op.GetInputDesc("x").GetShape().GetDims(); | |||||
| if (input_dims.size() != 0) { | |||||
| GeInfershapeErrReport(op.GetName(), op.GetOpType(), "x dims", "x should be a scalar"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| TensorDesc tensordesc_output = op.GetOutputDesc("y"); | |||||
| tensordesc_output.SetShape(ge::Shape(input_dims)); | |||||
| DataType input_type = op.GetInputDesc("x").GetDataType(); | |||||
| if (input_type != DT_BOOL) { | |||||
| GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", "x should be bool type"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| tensordesc_output.SetDataType(input_type); | |||||
| (void)op.UpdateOutputDesc("y", tensordesc_output); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| IMPLEMT_INFERFUNC(Merge, MergeInfer) { | |||||
| return MergeInferImpl(op); | |||||
| } | |||||
| INFER_FUNC_REG(Merge, MergeInfer); | |||||
| IMPLEMT_INFERFUNC(RefMerge, RefMergeInfer) { | |||||
| return MergeInferImpl(op); | |||||
| } | |||||
| INFER_FUNC_REG(RefMerge, RefMergeInfer); | |||||
| IMPLEMT_INFERFUNC(Switch, SwitchInfer) { | |||||
| return SwitchInferImpl(op); | |||||
| } | |||||
| INFER_FUNC_REG(Switch, SwitchInfer); | |||||
| IMPLEMT_INFERFUNC(RefSwitch, RefSwitchInfer) { | |||||
| return SwitchInferImpl(op); | |||||
| } | |||||
| INFER_FUNC_REG(RefSwitch, RefSwitchInfer); | |||||
| IMPLEMT_INFERFUNC(SwitchN, SwitchNInfer) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| INFER_FUNC_REG(SwitchN, SwitchNInfer); | |||||
| IMPLEMT_INFERFUNC(Enter, EnterInfer) { | |||||
| return EnterInferImpl(op); | |||||
| } | |||||
| INFER_FUNC_REG(Enter, EnterInfer); | |||||
| IMPLEMT_INFERFUNC(RefEnter, RefEnterInfer) { | |||||
| return PassThroughInferImpl(op, "x", "y"); | |||||
| } | |||||
| INFER_FUNC_REG(RefEnter, RefEnterInfer); | |||||
| IMPLEMT_INFERFUNC(LoopCond, LoopCondInfer) { | |||||
| return LoopCondInferImpl(op); | |||||
| } | |||||
| INFER_FUNC_REG(LoopCond, LoopCondInfer); | |||||
| IMPLEMT_INFERFUNC(NextIteration, NextIterationInfer) { | |||||
| return PassThroughInferImpl(op, "x", "y"); | |||||
| } | |||||
| INFER_FUNC_REG(NextIteration, NextIterationInfer); | |||||
| IMPLEMT_INFERFUNC(RefNextIteration, RefNextIterationInfer) { | |||||
| return PassThroughInferImpl(op, "x", "y"); | |||||
| } | |||||
| INFER_FUNC_REG(RefNextIteration, RefNextIterationInfer); | |||||
| IMPLEMT_INFERFUNC(Exit, ExitInfer) { | |||||
| return PassThroughInferImpl(op, "x", "y"); | |||||
| } | |||||
| INFER_FUNC_REG(Exit, ExitInfer); | |||||
| IMPLEMT_INFERFUNC(RefExit, RefExitInfer) { | |||||
| return PassThroughInferImpl(op, "x", "y"); | |||||
| } | |||||
| INFER_FUNC_REG(RefExit, RefExitInfer); | |||||
| // ----------------MapIndex------------------- | |||||
| IMPLEMT_VERIFIER(MapIndex, MapIndexVerify) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| IMPLEMT_COMMON_INFERFUNC(MapIndexInferShape) { | |||||
| OP_LOGI("MapIndex", "infer shape begin---"); | |||||
| auto x_shape = op.GetInputDesc("x").GetShape().GetDims(); | |||||
| if (x_shape.empty()) { | |||||
| OP_LOGE(op.GetName().c_str(), "x_shape is empty"); | |||||
| OpsOneInputShapeErrReport(op.GetName().c_str(), "x", "x_shape is empty"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| int64_t x_length = x_shape[0]; | |||||
| auto data_seq_shape = op.GetInputDesc("data_seq").GetShape().GetDims(); | |||||
| if (data_seq_shape.empty()) { | |||||
| OP_LOGE(op.GetName().c_str(), "data_seq_shape is empty"); | |||||
| OpsOneInputShapeErrReport(op.GetName().c_str(), "data_seq", "data_seq_shape is empty"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| int64_t data_seq_length = data_seq_shape[0]; | |||||
| if (x_length > 8 || x_length == 0) { | |||||
| OP_LOGE(op.GetName().c_str(), "the length of x should be less than or equal to 8"); | |||||
| OpsOneInputShapeErrReport(op.GetName().c_str(), "x", "the length of x should be less than or equal to 8 and not 0"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (data_seq_length % x_length != 0) { | |||||
| OP_LOGE(op.GetName().c_str(), "the length of data_seq must be multiple of the length of x"); | |||||
| OpsTwoInputShapeErrReport(op.GetName().c_str(), "data_seq", "x", | |||||
| "the length of data_seq must be multiple of the length of x"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (data_seq_length / x_length > 100) { | |||||
| OP_LOGE(op.GetName().c_str(), "data_seq_length / x_length should be be less than or equal to 100"); | |||||
| OpsTwoInputShapeErrReport(op.GetName().c_str(), "data_seq", "x", | |||||
| "data_seq_length / x_length should be be less than or equal to 100"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto level_index_shape = op.GetInputDesc("level_index").GetShape().GetDims(); | |||||
| if (!level_index_shape.empty()) { | |||||
| int64_t level_index_length = level_index_shape[0]; | |||||
| if (level_index_length != (data_seq_length / x_length)) { | |||||
| OP_LOGE(op.GetName().c_str(), | |||||
| "the length of level_index must be equal to " | |||||
| "the length of data_seq divided by the length of x"); | |||||
| OpsOneInputShapeErrReport(op.GetName().c_str(), "level_index", | |||||
| "the length of level_index must be equal to " | |||||
| "the length of data_seq divided by the length of x"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| TensorDesc y_desc = op.GetOutputDesc("y"); | |||||
| y_desc.SetShape(ge::Shape()); | |||||
| y_desc.SetDataType(ge::DT_INT32); | |||||
| (void)op.UpdateOutputDesc("y", y_desc); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| COMMON_INFER_FUNC_REG(MapIndex, MapIndexInferShape); | |||||
| VERIFY_FUNC_REG(MapIndex, MapIndexVerify); | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,407 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| /*! | |||||
| * \file control_flow_ops.h | |||||
| * \brief | |||||
| */ | |||||
| #ifndef OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_ | |||||
| #define OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_ | |||||
| #include "graph/operator_reg.h" | |||||
| #include "graph/operator.h" | |||||
| namespace ge { | |||||
| /** | |||||
| *@brief Forwards the value of an available tensor from input "x" to output "y". | |||||
| * Merge waits for at least one of the input tensors to become available. | |||||
| * It is usually combined with Switch to implement branching. | |||||
| * Merge forwards the first tensor to become available to output "y", | |||||
| * and sets "value_index" the index of the tensor in inputs . \n | |||||
| *@par Inputs: | |||||
| *x: The input tensors, one of which will become available. | |||||
| * Must be one of the following types: float16, float32, float64, int8, | |||||
| * int16, int32, int64, uint8, uint16, uint32, uint64, bool . It's a dynamic input. \n | |||||
| *@par Outputs: | |||||
| *@li y: The available tensor. Has the same type as "x". | |||||
| *@li value_index: A scalar of type int32, for the index of the chosen input | |||||
| * tensor . \n | |||||
| *@see Switch() | |||||
| *@par Third-party framework compatibility | |||||
| *@Compatible with the TensorFlow operator Merge. | |||||
| */ | |||||
| REG_OP(Merge) | |||||
| .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OUTPUT(value_index, TensorType({DT_INT32})) | |||||
| .OP_END_FACTORY_REG(Merge) | |||||
| /** | |||||
| *@brief Forwards the value of an available tensor from input "x" to output "y". | |||||
| * Merge waits for at least one of the input tensors to become available. | |||||
| * It is usually combined with Switch to implement branching. | |||||
| * Merge forwards the first tensor to become available to output "y", | |||||
| * and sets "value_index" the index of the tensor in inputs . \n | |||||
| *@par Inputs: | |||||
| *x: The input tensors, one of which will become available. | |||||
| * Must be one of the following types: float16, float32, float64, int8, | |||||
| * int16, int32, int64, uint8, uint16, uint32, uint64, bool . It's a dynamic input. \n | |||||
| *@par Outputs: | |||||
| *@li y: The available tensor. Has the same type as "x". | |||||
| *@li value_index: A scalar of type int32, for the index of the chosen input | |||||
| * tensor . \n | |||||
| *@see Switch() | Merge() | |||||
| *@par Third-party framework compatibility | |||||
| *@Compatible with the TensorFlow operator RefMerge. | |||||
| */ | |||||
| REG_OP(RefMerge) | |||||
| .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OUTPUT(value_index, TensorType({DT_INT32})) | |||||
| .OP_END_FACTORY_REG(RefMerge) | |||||
| /** | |||||
| *@brief Forwards "data" to the output port determined by "pred". | |||||
| * If "pred" is "true", the data input is forwarded to "output_true". | |||||
| * Otherwise, the data is forwarded to "output_false" . \n | |||||
| *@par Inputs: | |||||
| *@li data: The tensor to be forwarded. \n | |||||
| * Must be one of the following types: float16, float32, float64, | |||||
| * int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool. | |||||
| *@li pred: A boolean scalar. The output port that will receive data . \n | |||||
| *@par Outputs: | |||||
| *@li output_false: If "pred" is "false", data will be forwarded to this output. | |||||
| * Has the same type as "data". | |||||
| *@li output_true: If "pred" is "true", data will be forwarded to this output. | |||||
| * Has the same type as "data" . \n | |||||
| *@see Merge() | |||||
| *@par Third-party framework compatibility | |||||
| *@Compatible with the TensorFlow operator Switch. | |||||
| */ | |||||
| REG_OP(Switch) | |||||
| .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .INPUT(pred, TensorType({DT_BOOL})) | |||||
| .OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OP_END_FACTORY_REG(Switch) | |||||
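| // Minimal branching sketch (illustrative; assumes existing nodes "x" and "pred"): | |||||
| // Switch routes "x" to exactly one output, and a downstream Merge re-joins the branches. | |||||
| //   auto sw = ge::op::Switch("switch").set_input_data(x).set_input_pred(pred); | |||||
| //   auto merge = ge::op::Merge("merge"); | |||||
| //   merge.create_dynamic_input_x(2); | |||||
| //   merge.set_dynamic_input_x(0, sw, "output_false"); | |||||
| //   merge.set_dynamic_input_x(1, sw, "output_true"); | |||||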
| /** | |||||
| *@brief Forwards "data" to the output port determined by "pred". | |||||
| * If "pred" is "true", the data input is forwarded to "output_true". | |||||
| * Otherwise, the data is forwarded to "output_false" . \n | |||||
| *@par Inputs: | |||||
| *@li data: The ref tensor to be forwarded. | |||||
| * Must be one of the following types: float16, float32, float64, | |||||
| * int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool. | |||||
| *@li pred: A boolean scalar. The output port that will receive data . \n | |||||
| *@par Outputs: | |||||
| *@li output_false: If "pred" is "false", data will be forwarded to this output. | |||||
| * Has the same type as "data". | |||||
| *@li output_true: If "pred" is "true", data will be forwarded to this output. | |||||
| * Has the same type as "data" . \n | |||||
| *@see Merge() | Switch() | |||||
| *@par Third-party framework compatibility | |||||
| *@Compatible with the TensorFlow operator RefSwitch. | |||||
| */ | |||||
| REG_OP(RefSwitch) | |||||
| .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .INPUT(pred, TensorType({DT_BOOL})) | |||||
| .OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OP_END_FACTORY_REG(RefSwitch) | |||||
| /** | |||||
| *@brief Forwards "data" to the output port determined by "pred_value" . \n | |||||
| *@par Inputs: | |||||
| *@li data: The tensor to be forwarded. \n | |||||
| * Must be one of the following types: float16, float32, float64, | |||||
| * int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool. | |||||
| *@li pred_value: An int64 tensor that determines the output port that will receive data . \n | |||||
| *@par Outputs: | |||||
| *output: The output tensors, one of which will become available. | |||||
| * Has the same type as "data". | |||||
| */ | |||||
| REG_OP(SwitchN) | |||||
| .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .INPUT(pred_value, TensorType({DT_INT64})) | |||||
| .DYNAMIC_OUTPUT(output, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OP_END_FACTORY_REG(SwitchN) | |||||
| /** | |||||
| *@brief Creates or finds a child frame, and makes "x" available to the child | |||||
| * frame. This op is used together with Exit to create loops in the graph. | |||||
| * The Executor uses the unique "frame_name" to identify frames. | |||||
| * If "is_constant" is "true", output "y" is a constant in the child | |||||
| * frame; otherwise it may be changed in the child frame . \n | |||||
| *@par Inputs: | |||||
| *x: The tensor to be made available to the child frame. | |||||
| * Must be one of the following types: float16, float32, float64, int8, | |||||
| * int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||||
| *@par Attributes: | |||||
| *@li frame_name: A required string. The name of the child frame. | |||||
| *@li is_constant: A required bool. If true, the output is constant in | |||||
| * the child frame . \n | |||||
| *@par Outputs: | |||||
| *y: A Tensor. Has the same type as "x" . \n | |||||
| *@see Exit() | |||||
| *@par Third-party framework compatibility | |||||
| *@Compatible with the TensorFlow operator Enter. | |||||
| */ | |||||
| REG_OP(Enter) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .REQUIRED_ATTR(frame_name, String) | |||||
| .REQUIRED_ATTR(is_constant, Bool) | |||||
| .OP_END_FACTORY_REG(Enter) | |||||
| /** | |||||
| *@brief Creates or finds a child frame, and makes "x" available to the child | |||||
| * frame. This op is used together with Exit to create loops in the graph. | |||||
| * The Executor uses the unique "frame_name" to identify frames. | |||||
| * If "is_constant" is "true", output "y" is a constant in the child | |||||
| * frame; otherwise it may be changed in the child frame . \n | |||||
| *@par Inputs: | |||||
| *x: The tensor to be made available to the child frame. | |||||
| * Must be one of the following types: float16, float32, float64, int8, | |||||
| * int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||||
| *@par Attributes: | |||||
| *@li frame_name: A required string. The name of the child frame. | |||||
| *@li is_constant: A required bool. If true, the output is constant in | |||||
| * the child frame . \n | |||||
| *@par Outputs: | |||||
| *y: A tensor. Has the same type as "x" . \n | |||||
| *@see Exit() | Enter() | |||||
| *@par Third-party framework compatibility | |||||
| *@Compatible with the TensorFlow operator RefEnter. | |||||
| */ | |||||
| REG_OP(RefEnter) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .REQUIRED_ATTR(frame_name, String) | |||||
| .REQUIRED_ATTR(is_constant, Bool) | |||||
| .OP_END_FACTORY_REG(RefEnter) | |||||
| /** | |||||
| *@brief Forwards the input to the output. This op represents the loop | |||||
| * termination condition . \n | |||||
| *@par Inputs: | |||||
| *x: A boolean scalar. The condition of the Switch op . \n | |||||
| *@par Outputs: | |||||
| *y: The tensor "x" . \n | |||||
| *@see Switch() | |||||
| *@par Third-party framework compatibility | |||||
| *@Compatible with the TensorFlow operator LoopCond. | |||||
| */ | |||||
| REG_OP(LoopCond) | |||||
| .INPUT(x, TensorType({DT_BOOL})) | |||||
| .OUTPUT(y, TensorType({DT_BOOL})) | |||||
| .OP_END_FACTORY_REG(LoopCond) | |||||
| /** | |||||
| *@brief Makes the input available to the next iteration . \n | |||||
| *@par Inputs: | |||||
| *x: The tensor to be made available to the next iteration. | |||||
| * Must be one of the following types: float16, float32, float64, int8, | |||||
| * int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||||
| *@par Outputs: | |||||
| *y: A Tensor. Has the same type as "x" . \n | |||||
| *@par Third-party framework compatibility | |||||
| *@Compatible with the TensorFlow operator NextIteration. | |||||
| */ | |||||
| REG_OP(NextIteration) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OP_END_FACTORY_REG(NextIteration) | |||||
| /** | |||||
| *@brief Makes the input available to the next iteration . \n | |||||
| *@par Inputs: | |||||
| *x: The tensor to be made available to the next iteration. | |||||
| * Must be one of the following types: float16, float32, float64, int8, | |||||
| * int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||||
| *@par Outputs: | |||||
| *y: A tensor. Has the same type as "x" . \n | |||||
| *@par Third-party framework compatibility | |||||
| *@Compatible with the TensorFlow operator RefNextIteration. | |||||
| */ | |||||
| REG_OP(RefNextIteration) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OP_END_FACTORY_REG(RefNextIteration) | |||||
| /** | |||||
| *@brief Exits the current frame to its parent frame . \n | |||||
| *@par Inputs: | |||||
| *x: The tensor to be made available to the parent frame. | |||||
| * Must be one of the following types: float16, float32, float64, int8, | |||||
| * int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||||
| *@par Outputs: | |||||
| *y: A Tensor. Has the same type as "x" . \n | |||||
| *@see Enter() | |||||
| *@par Third-party framework compatibility | |||||
| *@Compatible with the TensorFlow operator Exit. | |||||
| */ | |||||
| REG_OP(Exit) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OP_END_FACTORY_REG(Exit) | |||||
| /** | |||||
| *@brief Exits the current frame to its parent frame . \n | |||||
| *@par Inputs: | |||||
| *x: The tensor to be made available to the parent frame. | |||||
| * Must be one of the following types: float16, float32, float64, int8, | |||||
| * int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||||
| *@par Outputs: | |||||
| *y: A tensor. Has the same type as "x" . \n | |||||
| *@see Enter() | Exit() | |||||
| *@par Third-party framework compatibility | |||||
| *@Compatible with the TensorFlow operator RefExit. | |||||
| */ | |||||
| REG_OP(RefExit) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
| DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
| DT_UINT64, DT_BOOL})) | |||||
| .OP_END_FACTORY_REG(RefExit) | |||||
| /** | |||||
| *@brief Only useful as a placeholder for control edges. | |||||
| * It is similar to a no-op that always produces a live control output | |||||
| * even when some control inputs are dead . \n | |||||
| *@par Third-party framework compatibility | |||||
| *@Compatible with the TensorFlow operator ControlTrigger. | |||||
| */ | |||||
| REG_OP(ControlTrigger) | |||||
| .OP_END_FACTORY_REG(ControlTrigger) | |||||
| /** | |||||
| *@brief Returns index of shape in the map. | |||||
| *@par Inputs: | |||||
| * Three inputs, including: | |||||
| *@li x: A one-dimensional tensor of type int32, specifying the queried shape; max size is 8. | |||||
| *@li data_seq: A one-dimensional tensor of type int32, specifying the mapping table to be queried. | |||||
| *@li level_index: An optional one-dimensional tensor of type int32, specifying the secondary index. \n | |||||
| *@par Outputs: | |||||
| *@li y: A scalar tensor of type int32, specifying the index of the shape in the map. | |||||
| *@par Third-party framework compatibility | |||||
| * It is a custom operator. It has no corresponding operator in Caffe. | |||||
| */ | |||||
| REG_OP(MapIndex) | |||||
| .INPUT(x, TensorType({DT_INT32})) | |||||
| .INPUT(data_seq, TensorType({DT_INT32})) | |||||
| .OPTIONAL_INPUT(level_index, TensorType({DT_INT32})) | |||||
| .OUTPUT(y, TensorType({DT_INT32})) | |||||
| .OP_END_FACTORY_REG(MapIndex) | |||||
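| // Layout note (inferred from MapIndexInferShape in control_flow_ops.cc): for a queried | |||||
| // shape "x" of length L (1 <= L <= 8), "data_seq" holds N * L elements (N <= 100) | |||||
| // encoding N candidate shapes, and "y" is the index of the matching candidate. | |||||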
| } // namespace ge | |||||
| #endif // OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_ | |||||
| @@ -0,0 +1,234 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| /*! | |||||
| * \file array_ops_shape_fns.cpp | |||||
| * \brief | |||||
| */ | |||||
| #include "array_ops_shape_fns.h" | |||||
| #include "graph/types.h" | |||||
| #include "op_log.h" | |||||
| #include "error_util.h" | |||||
| #include "common_shape_fns.h" | |||||
| #include "axis_util.h" | |||||
| namespace ge { | |||||
| static graphStatus PadKnown(Operator& op, const Tensor& paddings_tensor, const int64_t input_dim_num) { | |||||
| TensorDesc paddings_tensor_desc = paddings_tensor.GetTensorDesc(); | |||||
| DataType data_type = paddings_tensor_desc.GetDataType(); | |||||
| std::vector<int64_t> data; | |||||
| // every dim has 2 elements | |||||
| int64_t element_num = input_dim_num * 2; | |||||
| data.reserve(element_num); | |||||
| if (data_type == DT_INT32) { | |||||
| const int32_t* paddings_data = reinterpret_cast<const int32_t*>(paddings_tensor.GetData()); | |||||
| CHECK(paddings_tensor.GetSize() / sizeof(int32_t) < element_num, | |||||
| OP_LOGE(op.GetName().c_str(), "invalid padding data."), return GRAPH_FAILED); | |||||
| for (int64_t i = 0; i < element_num; ++i) { | |||||
| data.push_back(static_cast<int64_t>(paddings_data[i])); | |||||
| } | |||||
| } else if (data_type == DT_INT64) { | |||||
| const int64_t* paddings_data = reinterpret_cast<const int64_t*>(paddings_tensor.GetData()); | |||||
| CHECK(paddings_tensor.GetSize() / sizeof(int64_t) < element_num, | |||||
| OP_LOGE(op.GetName().c_str(), "invalid padding data."), return GRAPH_FAILED); | |||||
| for (int64_t i = 0; i < element_num; ++i) { | |||||
| data.push_back(paddings_data[i]); | |||||
| } | |||||
| } else { | |||||
| string err_msg = ConcatString("paddings data type invalid, ", "should be DT_INT32 or DT_INT64"); | |||||
| InferShapeOtherErrReport(op.GetName(), err_msg); | |||||
| OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto dims = op.GetInputDesc(0).GetShape().GetDims(); | |||||
| std::vector<int64_t> output_dims(input_dim_num, UNKNOWN_DIM); | |||||
| if (dims != UNKNOWN_SHAPE) { | |||||
| output_dims.assign(dims.begin(), dims.end()); | |||||
| } | |||||
| for (size_t i = 0; i < data.size(); i += 2) { | |||||
| if ((data[i] < 0) || (data[i + 1] < 0)) { | |||||
| std::string err_msg = ConcatString("paddings", DebugString(data), " must be non-negative"); | |||||
| InferShapeOtherErrReport(op.GetName(), err_msg); | |||||
| OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graphStatus status = Add(output_dims[i / 2], data[i] + data[i + 1], output_dims[i / 2]); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| std::string err_msg = ConcatString("the sum input[0] shape", DebugString(dims), " and input[1] value", | |||||
| DebugString(data), " must be non-negative"); | |||||
| OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| auto output_desc = op.GetOutputDesc("y"); | |||||
| output_desc.SetShape(Shape(output_dims)); | |||||
| return op.UpdateOutputDesc("y", output_desc); | |||||
| } | |||||
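| // Worked example of the arithmetic in PadKnown: input shape [2, 3] with paddings | |||||
| // [[1, 1], [2, 2]] yields output dims [2+1+1, 3+2+2] = [4, 7]. | |||||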
| graphStatus PadShapeFn(Operator& op) { | |||||
| Shape paddings; | |||||
| int64_t input_dim_num; | |||||
| graphStatus status = WithRank(op.GetInputDesc(1), 2, paddings, op.GetName().c_str()); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), "2D"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| status = WithValue(paddings.GetDim(1), 2, input_dim_num, op.GetName().c_str()); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), | |||||
| ConcatString(2, " of dim[1]")); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| Shape input; | |||||
| int64_t dim0 = paddings.GetDim(0); | |||||
| if (dim0 != UNKNOWN_DIM) { | |||||
| status = WithRank(op.GetInputDesc(0), dim0, input, op.GetName().c_str()); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(dim0, "D")); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } else if (op.GetInputDesc(0).GetShape().GetDim(0) != 0) { | |||||
| status = WithValue(dim0, op.GetInputDesc(0).GetShape().GetDimNum(), input_dim_num, op.GetName().c_str()); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(dim0, "D")); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| TensorDesc output_desc = op.GetOutputDesc("y"); | |||||
| Tensor paddings_tensor; | |||||
| status = op.GetInputConstData("paddings", paddings_tensor); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| if (dim0 != UNKNOWN_DIM) { | |||||
| std::vector<int64_t> output_shape(dim0, UNKNOWN_DIM); | |||||
| output_desc.SetShape(Shape(output_shape)); | |||||
| } else { | |||||
| output_desc.SetShape(Shape(UNKNOWN_SHAPE)); | |||||
| } | |||||
| return op.UpdateOutputDesc("y", output_desc); | |||||
| } | |||||
| input_dim_num = paddings_tensor.GetTensorDesc().GetShape().GetDim(0); | |||||
| status = WithRank(op.GetInputDesc(0), input_dim_num, input, op.GetName().c_str()); | |||||
| if (status == GRAPH_FAILED) { | |||||
| OP_LOGE(op.GetName().c_str(), "WithRank fail"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| status = WithValue(dim0, input_dim_num, dim0, op.GetName().c_str()); | |||||
| if (status == GRAPH_FAILED) { | |||||
| OP_LOGE(op.GetName().c_str(), "WithValue fail"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return PadKnown(op, paddings_tensor, input_dim_num); | |||||
| } | |||||
| static graphStatus CalcPadGradOutDims(const Shape& input_shape, const Tensor& paddings_tensor, | |||||
| std::vector<int64_t>& output_dims, const char* op_name) { | |||||
| graphStatus status; | |||||
| size_t input_rank = input_shape.GetDimNum(); | |||||
| if (output_dims.size() < input_rank) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| DataType padding_type = paddings_tensor.GetTensorDesc().GetDataType(); | |||||
| if (padding_type == DT_INT32) { | |||||
| const int32_t* paddings_data = reinterpret_cast<const int32_t*>(paddings_tensor.GetData()); | |||||
| CHECK(paddings_tensor.GetSize() / sizeof(int32_t) < input_rank, | |||||
| OP_LOGE(op_name, "invalid padding data."), return GRAPH_FAILED); | |||||
| for (size_t i = 0; i < input_rank; ++i) { | |||||
| const int64_t pad0 = static_cast<int64_t>(paddings_data[2 * i]); | |||||
| const int64_t pad1 = static_cast<int64_t>(paddings_data[(2 * i) + 1]); | |||||
| if ((pad0 < 0) || (pad1 < 0)) { | |||||
| OP_LOGE(op_name, "Paddings must be non-negative, pad0= %lld, pad1=%lld.", pad0, pad1); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| status = Subtract(input_shape.GetDim(i), pad0 + pad1, output_dims[i], op_name); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| } else if (padding_type == DT_INT64) { | |||||
| const int64_t* paddings_data = reinterpret_cast<const int64_t*>(paddings_tensor.GetData()); | |||||
| CHECK(paddings_tensor.GetSize() / sizeof(int64_t) < input_rank, | |||||
| OP_LOGE(op_name, "invalid padding data."), return GRAPH_FAILED); | |||||
| for (size_t i = 0; i < input_rank; ++i) { | |||||
| const int64_t pad0 = paddings_data[2 * i]; | |||||
| const int64_t pad1 = paddings_data[(2 * i) + 1]; | |||||
| if ((pad0 < 0) || (pad1 < 0)) { | |||||
| OP_LOGE(op_name, "Paddings must be non-negative, pad0=%lld, pad1=%lld.", pad0, pad1); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| status = Subtract(input_shape.GetDim(i), pad0 + pad1, output_dims[i], op_name); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| OP_LOGE(op_name, "Data type invalid, should be DT_INT32 or DT_INT64"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus PadGradShapeFn(Operator& op) { | |||||
| Shape paddings; | |||||
| graphStatus status = WithRank(op.GetInputDesc(1), 2, paddings, op.GetName().c_str()); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), "2D"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| int64_t input_rank = paddings.GetDim(0); | |||||
| TensorDesc output_desc = op.GetOutputDesc("y"); | |||||
| output_desc.SetDataType(op.GetInputDesc(0).GetDataType()); | |||||
| if (input_rank == UNKNOWN_DIM) { | |||||
| OP_LOGE(op.GetName().c_str(), "paddings inputShape of 0 dims is unknown, set out shape unknown."); | |||||
| output_desc.SetShape(Shape(UNKNOWN_SHAPE)); | |||||
| return op.UpdateOutputDesc("y", output_desc); | |||||
| } | |||||
| Shape input_shape; | |||||
| if (WithRank(op.GetInputDesc(0), input_rank, input_shape, op.GetName().c_str()) != GRAPH_SUCCESS) { | |||||
| ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(input_rank)); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| Shape check_shape({input_rank, 2}); | |||||
| if (Merge(paddings, check_shape, paddings, op.GetName().c_str()) != GRAPH_SUCCESS) { | |||||
| string err_msg = ConcatString("merge input[1] shape", DebugString(paddings.GetDims()), " and shape", | |||||
| DebugString(check_shape.GetDims()), " failed"); | |||||
| InferShapeOtherErrReport(op.GetName(), err_msg); | |||||
| OP_LOGE(op.GetName().c_str(), "Input dimension mismatch, inputRank=%lld.", input_rank); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| Tensor paddings_tensor; | |||||
| if (op.GetInputConstData("paddings", paddings_tensor) != GRAPH_SUCCESS) { | |||||
| std::vector<int64_t> unknow_dim_vec(input_rank, UNKNOWN_DIM); | |||||
| OP_LOGE(op.GetName().c_str(), "Get paddings input tensor fail, set outPut shape unknown."); | |||||
| output_desc.SetShape(Shape(unknow_dim_vec)); | |||||
| return op.UpdateOutputDesc("y", output_desc); | |||||
| } | |||||
| std::vector<int64_t> output_dims(input_rank); | |||||
| auto result = CalcPadGradOutDims(input_shape, paddings_tensor, output_dims, op.GetName().c_str()); | |||||
| if (result != GRAPH_SUCCESS) { | |||||
| string err_msg = ConcatString("calculate out dims failed,", "please check the validity of input and attribute"); | |||||
| InferShapeOtherErrReport(op.GetName(), err_msg); | |||||
| OP_LOGE(op.GetName().c_str(), "Calculation PadGrad out dimensions failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| output_desc.SetShape(Shape(output_dims)); | |||||
| return op.UpdateOutputDesc("y", output_desc); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| /*! | |||||
| * \file array_ops_shape_fns.h | |||||
| * \brief | |||||
| */ | |||||
| #ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_ | |||||
| #define OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_ | |||||
| #include "graph/operator.h" | |||||
| namespace ge { | |||||
| /** | |||||
| * Infer the output shape of the Pad op | |||||
| * @param op Operator which needs to infer its shape | |||||
| * @return status indicating whether infershape succeeded | |||||
| */ | |||||
| graphStatus PadShapeFn(Operator& op); | |||||
| /** | |||||
| * Infer the output shape of the PadGrad op | |||||
| * @param op Operator which needs to infer its shape | |||||
| * @return status indicating whether infershape succeeded | |||||
| */ | |||||
| graphStatus PadGradShapeFn(Operator& op); | |||||
| } // namespace ge | |||||
| #endif // OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_ | |||||
| @@ -0,0 +1,195 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| /*! | |||||
| * \file axis_util.cc | |||||
| * \brief get the axis value | |||||
| */ | |||||
| #include "axis_util.h" | |||||
| #include "framework/omg/omg_inner_types.h" | |||||
| #include "framework/common/types.h" | |||||
| namespace ge { | |||||
| AxisUtil::AxisUtil() { | |||||
| getAxisValueFuncMap = {{FORMAT_NCHW, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNCHW)}, | |||||
| {FORMAT_NHWC, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNHWC)}, | |||||
| {FORMAT_NC1HWC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNC1HWC0)}, | |||||
| {FORMAT_HWCN, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByHWCN)}, | |||||
| {FORMAT_ND, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByND)}, | |||||
| {FORMAT_C1HWNCoC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByC1HWNCoC0)}}; | |||||
| } | |||||
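| // Ceiling division used below to derive C1 from C and C0, | |||||
| // e.g. DivisionCeiling(5, 4) == 2; a zero divisor yields 0. | |||||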
| int64_t DivisionCeiling(int64_t dividend, int64_t divisor) { | |||||
| if (divisor == 0) { | |||||
| return 0; | |||||
| } else { | |||||
| return (dividend + divisor - 1) / divisor; | |||||
| } | |||||
| } | |||||
| bool AxisUtil::GetAxisValueByOriginFormat(const Format& format, const vector<int64_t>& dimVec, const uint32_t& c0, | |||||
| vector<int64_t>& axisValue, vector<int64_t>& ndValue) { | |||||
| auto iterGetAxisFunc = getAxisValueFuncMap.find(format); | |||||
| if (iterGetAxisFunc == getAxisValueFuncMap.end()) { | |||||
| LOG_INFO("Can not get axis value of old format %u!", format); | |||||
| return false; | |||||
| } | |||||
| GetAxisValueInfoByFormatPtr getAxisFunc = iterGetAxisFunc->second; | |||||
| CHECK_NOTNULL(getAxisFunc); | |||||
| return (*getAxisFunc)(dimVec, c0, axisValue, ndValue); | |||||
| } | |||||
| bool AxisUtil::HasAxisValueFunc(const Format& format) { | |||||
| auto iterGetAxisFunc = getAxisValueFuncMap.find(format); | |||||
| if (iterGetAxisFunc == getAxisValueFuncMap.end()) { | |||||
| LOG_INFO("Can not get axis value of format %u!", format); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool AxisUtil::CheckParams(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||||
| vector<int64_t>& ndValue) { | |||||
| ndValue = originalDimVec; | |||||
| auto dimSize = originalDimVec.size(); | |||||
| if (dimSize < ge::DIM_DEFAULT_SIZE) { | |||||
| /* Before this function, we should call function PadDimensionTo4. */ | |||||
| LOG_INFO("Dimension size %zu is invalid.", dimSize); | |||||
| return false; | |||||
| } | |||||
| if (c0 == 0) { | |||||
| LOG_ERROR("[ERROR]c0 is zero!"); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool AxisUtil::GetAxisValueByND(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||||
| vector<int64_t>& ndValue) { | |||||
| CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
| CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||||
| ndValue = originalDimVec; | |||||
| /* To differentiate the input datatype of int8 and others */ | |||||
| axisValue[AXIS_C0] = c0; | |||||
| if (originalDimVec.size() == NCHW_DIMENSION_NUM) { | |||||
| axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; | |||||
| axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; | |||||
| axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; | |||||
| axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; | |||||
| axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); | |||||
| axisValue[AXIS_Co] = c0; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool AxisUtil::GetAxisValueByNCHW(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||||
| vector<int64_t>& ndValue) { | |||||
| CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
| CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||||
| /* C0 Must be set for case ND or 2D-NCHW to NZ */ | |||||
| axisValue[AXIS_C0] = c0; | |||||
| CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||||
| return false); | |||||
| axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; | |||||
| axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; | |||||
| axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; | |||||
| axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; | |||||
| axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); | |||||
| axisValue[AXIS_Co] = c0; | |||||
| return true; | |||||
| } | |||||
| bool AxisUtil::GetAxisValueByNHWC(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||||
| vector<int64_t>& ndValue) { | |||||
| CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
| CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||||
| /* C0 Must be set for case ND or 2D-NHWC to NZ */ | |||||
| axisValue[AXIS_C0] = c0; | |||||
| CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||||
| return false); | |||||
| axisValue[AXIS_N] = originalDimVec[AXIS_NHWC_DIM_N]; | |||||
| axisValue[AXIS_C] = originalDimVec[AXIS_NHWC_DIM_C]; | |||||
| axisValue[AXIS_H] = originalDimVec[AXIS_NHWC_DIM_H]; | |||||
| axisValue[AXIS_W] = originalDimVec[AXIS_NHWC_DIM_W]; | |||||
| axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NHWC_DIM_C], (int64_t)c0); | |||||
| axisValue[AXIS_Co] = c0; | |||||
| return true; | |||||
| } | |||||
| bool AxisUtil::GetAxisValueByNC1HWC0(const vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
| vector<int64_t>& axisValue, vector<int64_t>& ndValue) { | |||||
| CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
| CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||||
| CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||||
| return false); | |||||
| auto dimSize = originalDimVec.size(); | |||||
| if (dimSize == ge::DIM_DEFAULT_SIZE + 1) { | |||||
| axisValue[AXIS_C1] = originalDimVec[AXIS_NC1HWC0_DIM_C1]; | |||||
| axisValue[AXIS_C0] = originalDimVec[AXIS_NC1HWC0_DIM_C0]; | |||||
| axisValue[AXIS_C] = axisValue[AXIS_C1] * axisValue[AXIS_C0]; | |||||
| } else { | |||||
| axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); | |||||
| axisValue[AXIS_C0] = c0; | |||||
| axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; | |||||
| } | |||||
| axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; | |||||
| axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; | |||||
| axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; | |||||
| return true; | |||||
| } | |||||
| bool AxisUtil::GetAxisValueByHWCN(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||||
| vector<int64_t>& ndValue) { | |||||
| CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
| CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||||
| /* C0 Must be set for case ND or 2D-NHWC to NZ */ | |||||
| axisValue[AXIS_C0] = c0; | |||||
| CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||||
| return false); | |||||
| axisValue[AXIS_N] = originalDimVec[AXIS_HWCN_DIM_N]; | |||||
| axisValue[AXIS_C] = originalDimVec[AXIS_HWCN_DIM_C]; | |||||
| axisValue[AXIS_H] = originalDimVec[AXIS_HWCN_DIM_H]; | |||||
| axisValue[AXIS_W] = originalDimVec[AXIS_HWCN_DIM_W]; | |||||
| axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_HWCN_DIM_C], (int64_t)c0); | |||||
| axisValue[AXIS_Co] = c0; | |||||
| return true; | |||||
| } | |||||
| bool AxisUtil::GetAxisValueByC1HWNCoC0(const vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
| vector<int64_t>& axisValue, vector<int64_t>& ndValue) { | |||||
| CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
| CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||||
| /* C0 Must be set for case ND or 2D-NHWC to NZ */ | |||||
| axisValue[AXIS_C0] = c0; | |||||
| CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||||
| return false); | |||||
| axisValue[AXIS_N] = originalDimVec[AXIS_C1HWNCoC0_DIM_N]; | |||||
| axisValue[AXIS_C] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1] * c0; | |||||
| axisValue[AXIS_H] = originalDimVec[AXIS_C1HWNCoC0_DIM_H]; | |||||
| axisValue[AXIS_W] = originalDimVec[AXIS_C1HWNCoC0_DIM_W]; | |||||
| axisValue[AXIS_C1] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1]; | |||||
| axisValue[AXIS_Co] = originalDimVec[AXIS_C1HWNCoC0_DIM_Co]; | |||||
| return true; | |||||
| } | |||||
| }; // namespace ge | |||||
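// ---------------------------------------------------------------------------
// Editor's sketch (standalone, not part of the diff): what GetAxisValueByNCHW
// above computes for one concrete shape. The axisValue layout {N,C,H,W,C1,C0,
// Co} and the ceiling division for C1 mirror the code above; c0 = 16 and the
// local names (kAxisN, CeilDiv, ...) are assumptions for illustration only.
#include <cstdint>
#include <iostream>
#include <vector>

namespace {
enum { kAxisN = 0, kAxisC, kAxisH, kAxisW, kAxisC1, kAxisC0, kAxisCo, kAxisBottom };

// Mirrors DivisionCeiling: ceil(dividend / divisor) in integer arithmetic.
int64_t CeilDiv(int64_t dividend, int64_t divisor) {
  return divisor == 0 ? 0 : (dividend + divisor - 1) / divisor;
}
}  // namespace

int main() {
  const std::vector<int64_t> nchw = {8, 35, 224, 224};  // N, C, H, W
  const int64_t c0 = 16;                                // assumed block size
  std::vector<int64_t> axis(kAxisBottom, 0);
  axis[kAxisN] = nchw[0];
  axis[kAxisC] = nchw[1];
  axis[kAxisH] = nchw[2];
  axis[kAxisW] = nchw[3];
  axis[kAxisC1] = CeilDiv(nchw[1], c0);  // 35 channels -> 3 blocks of 16
  axis[kAxisC0] = c0;
  axis[kAxisCo] = c0;
  for (int64_t v : axis) std::cout << v << " ";  // prints: 8 35 224 224 3 16 16 0
  std::cout << "\n";
  return 0;
}
// ---------------------------------------------------------------------------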
@@ -0,0 +1,144 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file axis_util.h
 * \brief get the axis value
 */
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_

#include <memory.h>
#include <functional>
#include <map>
#include <memory>
#include <vector>

#include "framework/omg/omg_inner_types.h"
#include "operator.h"
#include "graph/operator_reg.h"
#include "op_log.h"

#define LOG_ERROR(format, args...) printf(format, ##args)
#define LOG_INFO(format, args...) printf(format, ##args)

namespace ge {
const uint32_t NCHW_DIMENSION_NUM = 4;

const int32_t AXIS_NCHW_DIM_N = 0;
const int32_t AXIS_NCHW_DIM_C = 1;
const int32_t AXIS_NCHW_DIM_H = 2;
const int32_t AXIS_NCHW_DIM_W = 3;

const int32_t AXIS_NHWC_DIM_N = 0;
const int32_t AXIS_NHWC_DIM_H = 1;
const int32_t AXIS_NHWC_DIM_W = 2;
const int32_t AXIS_NHWC_DIM_C = 3;

const int32_t AXIS_NC1HWC0_DIM_N = 0;
const int32_t AXIS_NC1HWC0_DIM_C1 = 1;
const int32_t AXIS_NC1HWC0_DIM_C0 = 4;
const int32_t AXIS_NC1HWC0_DIM_H = 2;
const int32_t AXIS_NC1HWC0_DIM_W = 3;

const int32_t AXIS_HWCN_DIM_H = 0;
const int32_t AXIS_HWCN_DIM_W = 1;
const int32_t AXIS_HWCN_DIM_C = 2;
const int32_t AXIS_HWCN_DIM_N = 3;

const int32_t AXIS_C1HWNCoC0_DIM_C1 = 0;
const int32_t AXIS_C1HWNCoC0_DIM_H = 1;
const int32_t AXIS_C1HWNCoC0_DIM_W = 2;
const int32_t AXIS_C1HWNCoC0_DIM_N = 3;
const int32_t AXIS_C1HWNCoC0_DIM_Co = 4;
const int32_t AXIS_C1HWNCoC0_DIM_C0 = 5;

#define CHECK_NOTNULL(val)                                       \
  do {                                                           \
    if ((val) == nullptr) {                                      \
      LOG_ERROR("[ERROR]Parameter[%s] must not be null.", #val); \
      return false;                                              \
    }                                                            \
  } while (0)

#define CHECK(cond, log_func, return_expr) \
  do {                                     \
    if (cond) {                            \
      log_func;                            \
      return_expr;                         \
    }                                      \
  } while (0)

enum AxisValueType {
  AXIS_N = 0,
  AXIS_C = 1,
  AXIS_H = 2,
  AXIS_W = 3,
  AXIS_C1 = 4,
  AXIS_C0 = 5,
  AXIS_Co = 6,
  AXIS_D = 7,
  AXIS_BOTTOM = 8
};

int64_t DivisionCeiling(int64_t dividend, int64_t divisor);

/* Axis value is arranged as {N,C,H,W,C1,C0,...} */
/* The first parameter is the old shape's dimensions, the second is c0,
 * the third is the axis value, and the fourth is the nd value. */
using GetAxisValueInfoByFormat =
    std::function<bool(const std::vector<int64_t>&, const uint32_t&, std::vector<int64_t>&, std::vector<int64_t>&)>;

using GetAxisValueInfoByFormatPtr = std::shared_ptr<GetAxisValueInfoByFormat>;

class AxisUtil {
 public:
  AxisUtil();
  ~AxisUtil() {}
  bool GetAxisValueByOriginFormat(const ge::Format& format, const std::vector<int64_t>& dimVec, const uint32_t& c0,
                                  std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  bool HasAxisValueFunc(const ge::Format& format);

 private:
  static bool CheckParams(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
                          std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  static bool GetAxisValueByNCHW(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
                                 std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  static bool GetAxisValueByNHWC(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
                                 std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  static bool GetAxisValueByNC1HWC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
                                    std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  static bool GetAxisValueByFz(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
                               std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  static bool GetAxisValueByHWCN(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
                                 std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  static bool GetAxisValueByND(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
                               std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
  static bool GetAxisValueByC1HWNCoC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
                                      std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);

  /* Map of GetAxisValueInfoByFormat: gets the axis value for different
   * original formats. */
  std::map<ge::Format, GetAxisValueInfoByFormatPtr> getAxisValueFuncMap;
};
}  // namespace ge

#endif  // OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_
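// ---------------------------------------------------------------------------
// Editor's sketch (standalone): the format-to-handler dispatch pattern behind
// getAxisValueFuncMap above, i.e. a map from a format enum to a shared
// std::function, where a failed lookup mirrors HasAxisValueFunc. The Fmt enum
// and the handler body are illustrative stand-ins, not the real ge types.
#include <cstdint>
#include <cstdio>
#include <functional>
#include <map>
#include <memory>
#include <vector>

namespace {
enum class Fmt { NCHW, NHWC };
using AxisFn = std::function<bool(const std::vector<int64_t>&, uint32_t, std::vector<int64_t>&)>;

std::map<Fmt, std::shared_ptr<AxisFn>> BuildMap() {
  auto by_nchw = std::make_shared<AxisFn>(
      [](const std::vector<int64_t>& dims, uint32_t c0, std::vector<int64_t>& axis) {
        if (dims.size() != 4 || axis.size() < 2 || c0 == 0) return false;
        axis[0] = dims[0];                  // N
        axis[1] = (dims[1] + c0 - 1) / c0;  // C1 = ceil(C / c0)
        return true;
      });
  return {{Fmt::NCHW, by_nchw}};
}
}  // namespace

int main() {
  auto fmap = BuildMap();
  if (fmap.find(Fmt::NHWC) == fmap.end()) {
    std::printf("no handler registered for this format\n");  // lookup miss path
  }
  std::vector<int64_t> axis(2, 0);
  (*fmap.at(Fmt::NCHW))({8, 35, 224, 224}, 16, axis);
  std::printf("N=%lld C1=%lld\n", (long long)axis[0], (long long)axis[1]);  // N=8 C1=3
  return 0;
}
// ---------------------------------------------------------------------------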
@@ -0,0 +1,417 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file common_shape_fns.h
 * \brief
 */
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_

#include <string>
#include <vector>

#include "graph/tensor.h"
#include "graph/operator.h"
#include "graph/op_desc.h"
#include "graph/ge_tensor.h"
#include "error_code.h"

namespace ge {
/**
 * Check whether Shape's rank is at least rank
 * @param tensor Input tensor
 * @param rank expected minimum rank of Shape
 * @param out Output Shape
 * @return status whether Shape's condition is satisfied
 */
graphStatus WithRankAtLeast(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name);

/**
 * Check whether Shape's rank is at least rank
 * @param tensorDesc Input tensor desc
 * @param rank expected minimum rank of Shape
 * @param out_shape Output Shape
 * @return status whether Shape's condition is satisfied
 */
graphStatus WithRankAtLeast(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape);

/**
 * Check whether Shape's rank is equal to rank
 * @param tensor Input tensor
 * @param rank expected rank of Shape
 * @param out Output Shape
 * @return status whether Shape's condition is satisfied
 */
graphStatus WithRank(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name);

/**
 * Check whether Shape's rank is equal to rank
 * @param tensorDesc Input tensor desc
 * @param rank expected rank of Shape
 * @param out_shape Output Shape
 * @return status whether Shape's condition is satisfied
 */
graphStatus WithRank(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape);

/**
 * Check whether Shape's rank is equal to rank
 * @param tensorDesc Input tensor desc
 * @param rank expected rank of Shape
 * @param out_shape Output Shape
 * @return status whether Shape's condition is satisfied
 */
graphStatus WithRank(const GeTensorDescPtr& tensorDesc, int64_t rank, Shape& out_shape);

/**
 * Check whether dim is equal to value
 * @param dim Input dim
 * @param value expected val of dim
 * @param out Output dim
 * @return status whether dim is equal to value
 */
graphStatus WithValue(int64_t dim, int64_t value, int64_t& out, const char* op_name);

/**
 * Merge two dims of Shape
 * @param dim1 first dim val
 * @param dim2 second dim val
 * @param out merged dim val
 * @return status whether this operation succeeded
 */
graphStatus Merge(int64_t dim1, int64_t dim2, int64_t& out);

/**
 * Merge two shapes
 * @param s0 first shape val
 * @param s1 second shape val
 * @param out merged shape val
 * @return status whether this operation succeeded
 */
graphStatus Merge(const Shape& s0, const Shape& s1, Shape& out, const char* op_name);

/**
 * Merge two shapes
 * @param s0 first GeShape val
 * @param s1 second GeShape val
 * @param out merged GeShape val
 * @return status whether this operation succeeded
 */
graphStatus Merge(const GeShape& s0, const GeShape& s1, GeShape& out, const char* op_name);

/**
 * Replace one dim in a given shape
 * @param s original shape
 * @param dim_index_in dim index
 * @param new_dim new dim value
 * @param out new shape
 * @return status whether this operation succeeded
 */
graphStatus ReplaceDim(const Shape& s, int64_t dim_index_in, int64_t new_dim, Shape& out, const char* op_name);

/**
 * Replace one dim in a given shape
 * @param s original shape
 * @param dim_index_in dim index
 * @param new_dim new dim value
 * @param out new shape
 * @return status whether this operation succeeded
 */
graphStatus ReplaceDim(const GeShape& s, int64_t dim_index_in, int64_t new_dim, GeShape& out, const char* op_name);

/**
 * Check if it satisfies 0 <= index < limit
 * @param index first input
 * @param limit second input
 * @return status whether this operation succeeded
 */
template <typename Ta, typename Tb>
bool FastBoundsCheck(const Ta index, const Tb limit);

/**
 * Add two dims
 * @param dim1 first dim val
 * @param dim2 second dim val
 * @param out sum dim val
 * @return status whether this operation succeeded
 */
graphStatus Add(int64_t dim1, int64_t dim2, int64_t& out);

/**
 * Subtract two dims
 * @param dim1 first dim val
 * @param dim2 second dim val
 * @param out subtracted dim val
 * @return status whether this operation succeeded
 */
graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t& out, const char* op_name);

/**
 * Get SubShape according to start/end index and step size stride
 * @param s input Shape
 * @param start sub start index
 * @param end sub end index
 * @param stride sub step size
 * @param out sub shape output
 * @return status whether this operation succeeded
 */
graphStatus SubShape(const Shape& s, int64_t start, int64_t end, int64_t stride, Shape& out, const char* op_name);

/**
 * Get SubShape according to start/end index and step size stride
 * @param s input Shape
 * @param start sub start index
 * @param end sub end index
 * @param stride sub step size
 * @param out sub shape output
 * @return status whether this operation succeeded
 */
graphStatus SubShape(const GeShape& s, size_t start, size_t end, size_t stride, GeShape& out);

/**
 * Get SubShape according to start/end index and step size stride
 * @param s input Shape
 * @param start sub start index
 * @param end sub end index
 * @param stride sub step size
 * @param out sub shape output
 * @return status whether this operation succeeded
 */
graphStatus SubShape(const GeShape& s, int64_t start, int64_t end, int64_t stride, GeShape& out, const char* op_name);

/**
 * Concatenate two shapes
 * @param s1 first shape
 * @param s2 second shape
 * @param out concatenated shape
 * @return status whether this operation succeeded
 */
graphStatus Concatenate(const Shape& s1, const Shape& s2, Shape& out);

/**
 * Concatenate two shapes
 * @param s1 first shape
 * @param s2 second shape
 * @param out concatenated shape
 * @return status whether this operation succeeded
 */
graphStatus Concatenate(const GeShape& s1, const GeShape& s2, GeShape& out);

/**
 * Gen matrix shape according to dim1 and dim2
 * @param dim1 first dim val
 * @param dim2 second dim val
 * @param out matrix shape
 * @return status whether this operation succeeded
 */
graphStatus Matrix(int64_t dim1, int64_t dim2, Shape& out);

/**
 * Gen vector shape according to dim
 * @param dim dim val
 * @param out vector shape
 * @return status whether this operation succeeded
 */
graphStatus Vector(int64_t dim, Shape& out);

/**
 * Make shape from shape tensor
 * @param tensor shape tensor
 * @param out shape
 * @return status whether this operation succeeded
 */
graphStatus MakeShapeFromShapeTensor(const Tensor& tensor, Shape& out, const char* op_name);

/**
 * Make shape from shape tensor
 * @param op Operator
 * @param dst_name const string &
 * @param out GeShape
 * @param op_name const char *
 * @return status whether this operation succeeded
 */
graphStatus MakeShapeFromShapeTensor(Operator& op, const string& dst_name, GeShape& out, const char* op_name);

/**
 * Make dim from scalar tensor
 * @param tensor shape tensor
 * @param out shape
 * @return status whether this operation succeeded
 */
graphStatus MakeDimForScalarInput(const Tensor& tensor, int64_t& out, const char* op_name);

/**
 * Check whether Shape's rank is at most rank
 * @param tensor input tensor
 * @param rank expected maximum rank of Shape
 * @param out output Shape
 * @return status whether Shape's condition is satisfied
 */
graphStatus WithRankAtMost(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name);

/**
 * Check whether Shape's rank is at most rank
 * @param tensorDesc input tensor desc
 * @param rank expected maximum rank of Shape
 * @param out_shape output Shape
 * @return status whether Shape's condition is satisfied
 */
graphStatus WithRankAtMost(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape);

/**
 * Make an empty-dim (scalar) shape
 * @param out output Shape
 * @return status whether this operation succeeded
 */
graphStatus Scalar(Shape& out);

/**
 * Set input_name shape to output_name shape
 * @param op Operator which needs to infer shape
 * @param input_name input name of Operator
 * @param output_name output name of Operator
 * @return status whether infershape succeeded
 */
graphStatus UnchangedShape(Operator& op, const string input_name, const string output_name);

/**
 * Divide dim
 * @param dividend dividend dim val
 * @param divisor divisor dim val
 * @param evenlyDivisible whether the division must be exact
 * @param out output dim
 * @return status whether this operation succeeded
 */
graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t& out,
                   const char* op_name);

/**
 * Check whether shape is fully defined
 * @param shape Shape to be checked
 * @return whether shape is fully defined
 */
bool ShapeFullDefined(const Shape& shape);

/**
 * Check whether shape is fully defined
 * @param shape GeShape to be checked
 * @return whether shape is fully defined
 */
bool ShapeFullyDefined(const GeShape& shape);

/**
 * Check whether shape's rank is known
 * @param shape Shape to be checked
 * @return whether rank is known
 */
bool RankKnown(const Shape& shape);

/**
 * Check whether ge_shape's rank is known
 * @param shape GeShape to be checked
 * @return whether rank is known
 */
bool RankKnown(const GeShape& shape);

/**
 * Make an unknown shape with rank
 * @return unknown shape
 */
Shape UnknownShapeOfRank(int64_t rank);

/**
 * Check whether dim value is known
 * @param shape Shape whose dim value is checked
 * @param dim_index the index of dim
 * @return whether dim value is known
 */
bool ValueKnown(const Shape& shape, const size_t& dim_index);

/**
 * Validates that the 3 component tensors of a sparse tensor
 * have the proper shapes.
 * @param indices sparse indices shape
 * @param values sparse values shape
 * @param shape sparse shape
 * @return status whether this operation succeeded
 */
graphStatus ValidateSparseTensor(const TensorDesc& indices, const TensorDesc& values, const TensorDesc& shape,
                                 const char* op_name);

/**
 * DecodeWavShapeFn, infershape function of DecodeWav op
 * @param op Operator
 * @return status whether Shape's condition is satisfied
 */
graphStatus DecodeWavShapeFn(Operator& op);

/**
 * EncodeWavShapeFn, infershape function of EncodeWav op
 * @param op Operator
 * @return status whether Shape's condition is satisfied
 */
graphStatus EncodeWavShapeFn(Operator& op);

/**
 * Infershape function of SparseSegmentReduction op
 * @param op Operator
 * @return status whether Shape's condition is satisfied
 */
graphStatus SparseSegmentReductionShapeFn(Operator& op);

/**
 * Infershape function of SparseSegmentReductionGrad op
 * @param op Operator
 * @return status whether Shape's condition is satisfied
 */
graphStatus SparseSegmentReductionGradShapeFn(Operator& op);

/**
 * Validates variable resource handle
 * @param op Operator
 * @param shape_and_type ShapeAndType vector
 * @return status whether this operation succeeded
 */
graphStatus ValidateVariableResourceHandle(Operator& op, std::vector<ShapeAndType>& shape_and_type);

/**
 * Fill op_desc with input shape
 * @param op_desc Operator desc ptr
 * @param shape input tensor shape
 * @param data_type input tensor datatype
 */
void FillOpDesc(GeTensorDescPtr& op_desc, const GeShape& shape, const DataType& data_type = DT_FLOAT);

/**
 * InferShapeErrorReport info
 * @param op_name Operator name
 * @param op_type Operator type
 * @param value Operator value
 * @param reason error reason
 */
void InferShapeErrorReport(const std::string& op_name, const std::string& op_type,
                           const std::string& value, const std::string& reason);
}  // namespace ge

#endif  // OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
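// ---------------------------------------------------------------------------
// Editor's sketch (standalone analogue, not the real ge types): how an infer-
// shape function typically combines a WithRank-style check with a Merge of
// dims, where -1 marks an unknown dim. kGraphSuccess/kGraphFailed and the
// local Shape alias are stand-ins for graphStatus and ge::Shape.
#include <cstdint>
#include <cstdio>
#include <vector>

namespace {
constexpr int kGraphSuccess = 0;
constexpr int kGraphFailed = 1;
constexpr int64_t kUnknownDim = -1;
using Shape = std::vector<int64_t>;

// WithRank analogue: succeed only if the shape has exactly `rank` dims.
int WithRank(const Shape& s, int64_t rank, Shape& out) {
  if ((int64_t)s.size() != rank) return kGraphFailed;
  out = s;
  return kGraphSuccess;
}

// Merge analogue: an unknown dim defers to the known side; known dims must match.
int MergeDim(int64_t d1, int64_t d2, int64_t& out) {
  if (d1 == kUnknownDim) { out = d2; return kGraphSuccess; }
  if (d2 == kUnknownDim || d1 == d2) { out = d1; return kGraphSuccess; }
  return kGraphFailed;  // conflicting known dims
}
}  // namespace

int main() {
  Shape a = {8, kUnknownDim}, b = {kUnknownDim, 128}, merged(2), checked;
  if (WithRank(a, 2, checked) != kGraphSuccess) return 1;
  for (size_t i = 0; i < merged.size(); ++i) {
    if (MergeDim(a[i], b[i], merged[i]) != kGraphSuccess) return 1;
  }
  std::printf("merged: [%lld, %lld]\n", (long long)merged[0], (long long)merged[1]);  // [8, 128]
  return 0;
}
// ---------------------------------------------------------------------------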
@@ -0,0 +1,60 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file error_code.h
 * \brief
 */
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_

namespace ge {
// error code for report purpose.
// 30000~34999 for aicpu engine error
// and 35000~39999 for infershape error of aicpu op
enum ViewErrorCode {
  INVALID_INFER_SHAPE = 14001,
  INVALID_INPUT_SHAPE = 35000,
  INVALID_ATTR_VALUE = 35001,
  INVALID_ATTR_SIZE = 35002,
  OTHER_ERROR = 35003,
  INVALID_CONV_ATTR_VALUE = 50029,
  INVALID_CONV_SET_ATTR = 50057,
  INVALID_CONV_SHAPE = 50058,
  INVALID_MISS_INPUT = 70001,
  INVALID_INPUT_FORMAT = 70002,
  INVALID_INPUT_DTYPE = 70003,
  INVALID_INPUT_TYPE = 70004,
  INVALID_GET_ATTR = 70005,
  INVALID_SET_ATTR = 70006,
  INVALID_OPS_ATTR_VALUE = 70007,
  FAILED_UPDATE_OP = 70008,
  INVALID_SHAPE = 70009,
  INVALID_SHAPE_SIZE = 70010,
  INVALID_SHAPE_DIM = 70011,
  INVALID_BROADCAST_SHAPE = 70012,
  INVALID_TWO_INPUT_DTYPE = 70013,
  INVALID_AIPP_ERROR = 70014,
  INVALID_ONE_INPUT_SHAPE = 70015,
  INVALID_TWO_INPUT_SHAPE = 70016,
  INVALID_ONE_OUTPUT_SHAPE = 70017,
  FAILED_GET_COMPILIE_PARAMS = 70018,
};
}  // namespace ge

#endif  // OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_
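// ---------------------------------------------------------------------------
// Editor's sketch: the reporting convention pairs each ViewErrorCode with a
// string key of the form "E" + <code> (see GetViewErrorCodeStr in the
// error_util source below). A minimal standalone illustration; the local
// names here are assumptions, only the "E<code>" format is from the source.
#include <cstdio>
#include <string>

namespace {
enum ViewErrorCodeDemo { kInvalidInputShape = 35000 };  // mirrors INVALID_INPUT_SHAPE

std::string ToErrorKey(int code) { return "E" + std::to_string(code); }
}  // namespace

int main() {
  std::printf("%s\n", ToErrorKey(kInvalidInputShape).c_str());  // prints: E35000
  return 0;
}
// ---------------------------------------------------------------------------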
@@ -0,0 +1,318 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file error_util.cc
 * \brief
 */
#include <map>

#include "common/util/error_manager/error_manager.h"
#include "error_util.h"
#include "error_code.h"
#include "op_log.h"

using namespace std;
using namespace ge;

namespace ge {
inline static std::string GetViewErrorCodeStr(ge::ViewErrorCode errCode) {
  return "E" + std::to_string(errCode);
}

void ShapeErrReport(uint32_t index, const std::string& opname, const std::string& wrong_shape,
                    const std::string& correct_shape) {
  map<string, string> err_map;
  err_map["index"] = std::to_string(index);
  err_map["opname"] = opname;
  err_map["wrong_shape"] = wrong_shape;
  err_map["correct_shape"] = correct_shape;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_SHAPE);
  (void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void AttrValueErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_value,
                        const std::string& correct_value) {
  map<string, string> err_map;
  err_map["attrname"] = attrName;
  err_map["opname"] = opname;
  err_map["wrong_value"] = wrong_value;
  err_map["correct_value"] = correct_value;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ATTR_VALUE);
  (void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void AttrSizeErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_size,
                       const std::string& correct_size) {
  map<string, string> err_map;
  err_map["attrname"] = attrName;
  err_map["opname"] = opname;
  err_map["wrong_size"] = wrong_size;
  err_map["correct_size"] = correct_size;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ATTR_SIZE);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void InferShapeOtherErrReport(const std::string& opname, const std::string& err_msg) {
  map<string, string> err_map;
  err_map["opname"] = opname;
  err_map["err_msg"] = err_msg;
  string report_error_code = GetViewErrorCodeStr(ViewErrorCode::OTHER_ERROR);
  (void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsMissInputErrReport(const std::string& op_name, const std::string& param_name) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_MISS_INPUT);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsInputFormatErrReport(const std::string& op_name, const std::string& param_name,
                             const std::string& expected_format_list, const std::string& data_format) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  err_map["expected_format_list"] = expected_format_list;
  err_map["format"] = data_format;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_FORMAT);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsInputDtypeErrReport(const std::string& op_name, const std::string& param_name,
                            const std::string& expected_data_type_list, const std::string& data_type) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  err_map["expected_data_type_list"] = expected_data_type_list;
  err_map["data_type"] = data_type;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_DTYPE);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsInputTypeErrReport(const std::string& op_name, const std::string& param_name, const std::string& param_type,
                           const std::string& actual_type) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  err_map["param_type"] = param_type;
  err_map["actual_type"] = actual_type;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_TYPE);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsGetAttrErrReport(const std::string& op_name, const std::string& param_name) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_GET_ATTR);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsSetAttrErrReport(const std::string& op_name, const std::string& param_name) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SET_ATTR);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& excepted_value,
                           const std::string& input_value) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  err_map["excepted_value"] = excepted_value;
  err_map["input_value"] = input_value;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_OPS_ATTR_VALUE);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsOPUpdateErrReport(const std::string& op_name, const std::string& param_name) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::FAILED_UPDATE_OP);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsInputShapeErrReport(const std::string& op_name, const std::string& rule_desc, const std::string& param_name,
                            const std::string& param_value) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["rule_desc"] = rule_desc;
  err_map["param_name"] = param_name;
  err_map["param_value"] = param_value;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsOneInputShapeErrReport(const std::string& op_name, const std::string& param_name,
                               const std::string& error_detail) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  err_map["error_detail"] = error_detail;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ONE_INPUT_SHAPE);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsTwoInputShapeErrReport(const std::string& op_name, const std::string& param_name1,
                               const std::string& param_name2, const std::string& error_detail) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name1"] = param_name1;
  err_map["param_name2"] = param_name2;
  err_map["error_detail"] = error_detail;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_TWO_INPUT_SHAPE);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsOneOutputShapeErrReport(const std::string& op_name, const std::string& param_name,
                                const std::string& error_detail) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  err_map["error_detail"] = error_detail;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ONE_OUTPUT_SHAPE);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsGetCompileParamsErrReport(const std::string& op_name, const std::string& param_name) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::FAILED_GET_COMPILIE_PARAMS);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsInputShapeSizeErrReport(const std::string& op_name, const std::string& input_name, const std::string& max_value,
                                const std::string& real_value) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["input_name"] = input_name;
  err_map["max_value"] = max_value;
  err_map["real_value"] = real_value;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE_SIZE);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsInputShapeDimErrReport(const std::string& op_name, const std::string& param_name, const std::string& max_value,
                               const std::string& min_value, const std::string& real_value) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  err_map["max_value"] = max_value;
  err_map["min_value"] = min_value;
  err_map["real_value"] = real_value;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE_DIM);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsInputShapeBroadcastErrReport(const std::string& op_name, const std::string& input1_name,
                                     const std::string& input2_name, const std::string& input1_shape,
                                     const std::string& input2_shape) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["input1_name"] = input1_name;
  err_map["input2_name"] = input2_name;
  err_map["input1_shape"] = input1_shape;
  err_map["input2_shape"] = input2_shape;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_BROADCAST_SHAPE);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void TbeInputDataTypeErrReport(const std::string& op_name, const std::string& param_name,
                               const std::string& expected_dtype_list, const std::string& dtype) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  err_map["expected_dtype_list"] = expected_dtype_list;
  err_map["dtype"] = dtype;
  std::string report_error_code = "E50034";
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsTwoInputDtypeErrReport(const std::string& op_name, const std::string& input1_name,
                               const std::string& input2_name, const std::string& input1_dtype,
                               const std::string& input2_dtype) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["input1_name"] = input1_name;
  err_map["input2_name"] = input2_name;
  err_map["input1_dtype"] = input1_dtype;
  err_map["input2_dtype"] = input2_dtype;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_TWO_INPUT_DTYPE);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsAippErrReport(const std::string& aipp_output_H, const std::string& aipp_output_W, const std::string& data_H,
                      const std::string& data_W) {
  map<string, string> err_map;
  err_map["aipp_output_H"] = aipp_output_H;
  err_map["aipp_output_W"] = aipp_output_W;
  err_map["data_H"] = data_H;
  err_map["data_W"] = data_W;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_AIPP_ERROR);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsConvAttrValueErrReport(const std::string& op_name, const std::string& param_name,
                               const std::string& expected_value, const std::string& input_value) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param_name"] = param_name;
  err_map["expected_value"] = expected_value;
  err_map["input_value"] = input_value;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_ATTR_VALUE);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsConvSetAttrErrReport(const std::string& op_name, const std::string& param1_name,
                             const std::string& param2_name) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["param1_name"] = param1_name;
  err_map["param2_name"] = param2_name;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_SET_ATTR);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void OpsConvShapeErrReport(const std::string& op_name, const std::string& description) {
  map<string, string> err_map;
  err_map["op_name"] = op_name;
  err_map["description"] = description;
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_SHAPE);
  ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}

void GeInfershapeErrReport(const std::string& op_name, const std::string& op_type, const std::string& value,
                           const std::string& reason) {
  std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INFER_SHAPE);
  ErrorManager::GetInstance().ATCReportErrMessage(report_error_code, {"opname", "optype", "value", "reason"},
                                                  {op_name, op_type, value, reason});
}

void CommonRuntimeErrLog(const std::string& opname, const std::string& description) {
  map<string, string> err_map;
  err_map["op_name"] = opname;
  err_map["description"] = description;
  // Pass the message through "%s" so the OP_LOGE format argument is a literal;
  // passing the std::string directly would only log the stringized token.
  OP_LOGE(opname.c_str(), "%s", description.c_str());
  (void)ErrorManager::GetInstance().ReportErrMessage("E50058", err_map);
}
}  // namespace ge
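// ---------------------------------------------------------------------------
// Editor's sketch (standalone): what a ShapeErrReport call above effectively
// hands to ErrorManager, i.e. the key "E35000" plus a key/value payload.
// ErrorManager itself is not available here, so only the payload construction
// is shown; the op name and shapes are hypothetical.
#include <cstdio>
#include <map>
#include <string>

int main() {
  std::map<std::string, std::string> err_map;
  err_map["index"] = "0";
  err_map["opname"] = "MyConv2D";  // hypothetical op name
  err_map["wrong_shape"] = "[8,224,224]";
  err_map["correct_shape"] = "[8,3,224,224]";
  // A real ShapeErrReport would now call:
  //   ErrorManager::GetInstance().ReportErrMessage("E35000", err_map);
  for (const auto& kv : err_map) {
    std::printf("%s=%s\n", kv.first.c_str(), kv.second.c_str());
  }
  return 0;
}
// ---------------------------------------------------------------------------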
@@ -0,0 +1,184 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file error_util.h
 * \brief
 */
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_

#include <sstream>
#include <string>
#include <vector>

#include "operator.h"

namespace ge {
/*
 * get debug string of vector
 * param[in] v vector
 * return vector's debug string
 */
template <typename T>
std::string DebugString(const std::vector<T>& v) {
  std::ostringstream oss;
  oss << "[";
  if (v.size() > 0) {
    for (size_t i = 0; i < v.size() - 1; ++i) {
      oss << v[i] << ", ";
    }
    oss << v[v.size() - 1];
  }
  oss << "]";
  return oss.str();
}

/*
 * string concatenation utility
 * param[in] args values to concatenate into one string
 * return concatenated string
 */
template <typename T>
std::string ConcatString(T arg) {
  std::ostringstream oss;
  oss << arg;
  return oss.str();
}

template <typename T, typename... Ts>
std::string ConcatString(T arg, Ts... arg_left) {
  std::ostringstream oss;
  oss << arg;
  oss << ConcatString(arg_left...);
  return oss.str();
}

/*
 * report input shape error of infer shape
 * param[in] index the index of input
 * param[in] opname op name
 * param[in] wrong_shape wrong input shape
 * param[in] correct_shape correct input shape
 * return void
 */
void ShapeErrReport(uint32_t index, const std::string& opname, const std::string& wrong_shape,
                    const std::string& correct_shape);

/*
 * report attr value error of infer shape
 * param[in] attrName the attr name
 * param[in] opname op name
 * param[in] wrong_value wrong attr value
 * param[in] correct_value correct attr value
 * return void
 */
void AttrValueErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_value,
                        const std::string& correct_value);

/*
 * report attr size error of infer shape
 * param[in] attrName the attr name
 * param[in] opname op name
 * param[in] wrong_size wrong attr size
 * param[in] correct_size correct attr size
 * return void
 */
void AttrSizeErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_size,
                       const std::string& correct_size);

/*
 * report common error of infer shape
 * param[in] opname op name
 * param[in] err_msg error message
 * return void
 */
void InferShapeOtherErrReport(const std::string& opname, const std::string& err_msg);

void OpsMissInputErrReport(const std::string& op_name, const std::string& param_name);

void OpsInputFormatErrReport(const std::string& op_name, const std::string& param_name,
                             const std::string& expected_format_list, const std::string& data_format);

void OpsInputDtypeErrReport(const std::string& op_name, const std::string& param_name,
                            const std::string& expected_data_type_list, const std::string& data_type);

void OpsInputTypeErrReport(const std::string& op_name, const std::string& param_name, const std::string& param_type,
                           const std::string& actual_type);

void OpsGetAttrErrReport(const std::string& op_name, const std::string& param_name);

void OpsSetAttrErrReport(const std::string& op_name, const std::string& param_name);

void OpsAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& excepted_value,
                           const std::string& input_value);

void OpsOPUpdateErrReport(const std::string& op_name, const std::string& param_name);

void OpsInputShapeErrReport(const std::string& op_name, const std::string& rule_desc, const std::string& param_name,
                            const std::string& param_value);

void OpsOneInputShapeErrReport(const std::string& op_name, const std::string& param_name,
                               const std::string& error_detail);

void OpsTwoInputShapeErrReport(const std::string& op_name, const std::string& param_name1,
                               const std::string& param_name2, const std::string& error_detail);

void OpsOneOutputShapeErrReport(const std::string& op_name, const std::string& param_name,
                                const std::string& error_detail);

void OpsGetCompileParamsErrReport(const std::string& op_name, const std::string& param_name);

void OpsInputShapeSizeErrReport(const std::string& op_name, const std::string& input_name, const std::string& max_value,
                                const std::string& real_value);

void OpsInputShapeDimErrReport(const std::string& op_name, const std::string& param_name, const std::string& max_value,
                               const std::string& min_value, const std::string& real_value);

void OpsInputShapeBroadcastErrReport(const std::string& op_name, const std::string& input1_name,
                                     const std::string& input2_name, const std::string& input1_shape,
                                     const std::string& input2_shape);

void TbeInputDataTypeErrReport(const std::string& op_name, const std::string& param_name,
                               const std::string& expected_dtype_list, const std::string& dtype);

void OpsTwoInputDtypeErrReport(const std::string& op_name, const std::string& input1_name,
                               const std::string& input2_name, const std::string& input1_dtype,
                               const std::string& input2_dtype);

void OpsAippErrReport(const std::string& aipp_output_H, const std::string& aipp_output_W, const std::string& data_H,
                      const std::string& data_W);

void OpsConvAttrValueErrReport(const std::string& op_name, const std::string& param_name,
                               const std::string& expected_value, const std::string& input_value);

void OpsConvSetAttrErrReport(const std::string& op_name, const std::string& param1_name,
                             const std::string& param2_name);

void OpsConvShapeErrReport(const std::string& op_name, const std::string& description);

void GeInfershapeErrReport(const std::string& op_name, const std::string& op_type, const std::string& value,
                           const std::string& reason);

/*
 * log common runtime error
 * param[in] opname op name
 * param[in] description error description
 * return void
 */
void CommonRuntimeErrLog(const std::string& opname, const std::string& description);
}  // namespace ge

#endif  // OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_
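// ---------------------------------------------------------------------------
// Editor's sketch (standalone): how the DebugString/ConcatString templates
// declared above compose an error detail for the report helpers. Local copies
// are inlined so the snippet compiles without the repo's other headers; they
// mirror the templates above.
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

namespace {
template <typename T>
std::string DebugString(const std::vector<T>& v) {
  std::ostringstream oss;
  oss << "[";
  for (size_t i = 0; i + 1 < v.size(); ++i) oss << v[i] << ", ";
  if (!v.empty()) oss << v.back();
  oss << "]";
  return oss.str();
}

template <typename T>
std::string ConcatString(T arg) {
  std::ostringstream oss;
  oss << arg;
  return oss.str();
}

template <typename T, typename... Ts>
std::string ConcatString(T arg, Ts... rest) {
  return ConcatString(arg) + ConcatString(rest...);
}
}  // namespace

int main() {
  std::vector<int64_t> dims = {8, 3, 224, 224};
  // A message like this would typically be fed to InferShapeOtherErrReport.
  std::string msg = ConcatString("expected rank 4, got shape ", DebugString(dims));
  std::cout << msg << "\n";  // expected rank 4, got shape [8, 3, 224, 224]
  return 0;
}
// ---------------------------------------------------------------------------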
@@ -0,0 +1,73 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file op_common_util.h
 * \brief common util for op; only built-in types and standard C++ classes are allowed in this file
 */
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_

#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include <vector>

// Stream helpers for logging containers. Note: the vector/set helpers below
// emit a trailing ", " after the last element.
template <typename T1, typename T2>
std::ostream& operator<<(std::ostream& os, const std::pair<T1, T2>& values) {
  os << "[" << values.first << ", " << values.second << "]";
  return os;
}

template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
  os << "[";
  for (const auto& item : values) {
    os << item << ", ";
  }
  os << "]";
  return os;
}

namespace ops {
template <typename T>
std::string to_string(const std::vector<T>& items) {
  std::ostringstream oss;
  oss << "[";
  for (const auto& item : items) {
    oss << item << ", ";
  }
  oss << "]";
  return oss.str();
}

template <typename T>
std::string to_string(const std::set<T>& items) {
  std::ostringstream oss;
  oss << "[";
  for (const auto& item : items) {
    oss << item << ", ";
  }
  oss << "]";
  return oss.str();
}
}  // namespace ops

#endif  // OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
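// ---------------------------------------------------------------------------
// Editor's sketch: op_common_util.h above is self-contained (std headers
// only), so it can be exercised directly; the include only resolves within
// this repo's include paths. Note the trailing ", " the helpers produce.
#include "op_common_util.h"  // the header above
#include <iostream>
#include <set>
#include <utility>
#include <vector>

int main() {
  std::vector<int> v = {1, 2, 3};
  std::set<int> s = {4, 5};
  std::cout << ops::to_string(v) << "\n";     // [1, 2, 3, ]
  std::cout << ops::to_string(s) << "\n";     // [4, 5, ]
  std::cout << std::make_pair(6, 7) << "\n";  // [6, 7] via operator<< above
  return 0;
}
// ---------------------------------------------------------------------------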
| @@ -0,0 +1,89 @@ | |||||
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file op_log.h
 * \brief logging macros for op proto and fusion pass code
 */
#ifndef GE_OP_LOG_H
#define GE_OP_LOG_H
#include <type_traits>
#if !defined(__ANDROID__) && !defined(ANDROID)
#include "toolchain/slog.h"
#else
#include <utils/Log.h>
#endif
#define OPPROTO_SUBMOD_NAME "OP_PROTO"
#if !defined(__ANDROID__) && !defined(ANDROID)
#define OP_LOGI(opname, ...) D_OP_LOGI(opname, __VA_ARGS__)
#define OP_LOGW(opname, ...) D_OP_LOGW(opname, __VA_ARGS__)
#define OP_LOGE(opname, ...) D_OP_LOGE(opname, __VA_ARGS__)
#define OP_LOGD(opname, ...) D_OP_LOGD(opname, __VA_ARGS__)
#define GE_OP_LOGI(opname, ...) GE_D_OP_LOGI(opname, __VA_ARGS__)
#define GE_OP_LOGW(opname, ...) GE_D_OP_LOGW(opname, __VA_ARGS__)
#define GE_OP_LOGE(opname, ...) GE_D_OP_LOGE(opname, __VA_ARGS__)
#define GE_OP_LOGD(opname, ...) GE_D_OP_LOGD(opname, __VA_ARGS__)
#define FUSION_PASS_LOGI(...) D_FUSION_PASS_LOGI(__VA_ARGS__)
#define FUSION_PASS_LOGW(...) D_FUSION_PASS_LOGW(__VA_ARGS__)
#define FUSION_PASS_LOGE(...) D_FUSION_PASS_LOGE(__VA_ARGS__)
#define FUSION_PASS_LOGD(...) D_FUSION_PASS_LOGD(__VA_ARGS__)
#else
#define OP_LOGI(opname, ...)
#define OP_LOGW(opname, ...)
#define OP_LOGE(opname, ...)
#define OP_LOGD(opname, ...)
#define FUSION_PASS_LOGI(...)
#define FUSION_PASS_LOGW(...)
#define FUSION_PASS_LOGE(...)
#define FUSION_PASS_LOGD(...)
#endif
#if !defined(__ANDROID__) && !defined(ANDROID)
#define D_OP_LOGI(opname, fmt, ...) \
  DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d OpName:[%s] " #fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define D_OP_LOGW(opname, fmt, ...) \
  DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d OpName:[%s] " #fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define D_OP_LOGE(opname, fmt, ...) \
  DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d OpName:[%s] " #fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define D_OP_LOGD(opname, fmt, ...) \
  DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d OpName:[%s] " #fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define GE_D_OP_LOGI(opname, fmt, ...) \
  DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d OpName:[%s] " #fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define GE_D_OP_LOGW(opname, fmt, ...) \
  DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d OpName:[%s] " #fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define GE_D_OP_LOGE(opname, fmt, ...) \
  DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d OpName:[%s] " #fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define GE_D_OP_LOGD(opname, fmt, ...) \
  DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d OpName:[%s] " #fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define D_FUSION_PASS_LOGI(fmt, ...) \
  DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d " #fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#define D_FUSION_PASS_LOGW(fmt, ...) \
  DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d " #fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#define D_FUSION_PASS_LOGE(fmt, ...) \
  DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d " #fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#define D_FUSION_PASS_LOGD(fmt, ...) \
  DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d " #fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#else
#define D_OP_LOGI(opname, fmt, ...)
#define D_OP_LOGW(opname, fmt, ...)
#define D_OP_LOGE(opname, fmt, ...)
#define D_OP_LOGD(opname, fmt, ...)
#define D_FUSION_PASS_LOGI(fmt, ...)
#define D_FUSION_PASS_LOGW(fmt, ...)
#define D_FUSION_PASS_LOGE(fmt, ...)
#define D_FUSION_PASS_LOGD(fmt, ...)
#endif
#define OP_CHECK(condition, log_func, do_expr)                                                                  \
  static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
  do {                                                                                                          \
    if (condition) {                                                                                            \
      log_func;                                                                                                 \
      do_expr;                                                                                                  \
    }                                                                                                           \
  } while (0)
#endif  // GE_OP_LOG_H
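
A quick usage sketch for these macros (not part of the patch): the helper below is hypothetical, and it assumes toolchain/slog.h is available so OP_LOGI/OP_LOGE expand to DlogSub calls. Note that OP_CHECK runs its recovery expression only when the condition is true.

// Hypothetical example, not part of the patch: validate a dim count with OP_CHECK.
#include "op_log.h"
static bool CheckDimCount(const char* op_name, size_t dim_cnt) {
  OP_LOGI(op_name, "checking dim count %zu", dim_cnt);
  // Condition true => log the error, then run the recovery expression.
  OP_CHECK(dim_cnt > 8u,
           OP_LOGE(op_name, "dim count %zu exceeds the supported maximum of 8", dim_cnt),
           return false);
  return true;
}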
@@ -0,0 +1,258 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file transfer_shape_according_to_format.cc
 * \brief set shape according to original format and current format
 */
#include "transfer_shape_according_to_format.h"
#include <memory>
#include "framework/omg/omg_inner_types.h"
namespace ge {
ShapeTransferAccordingToFormat::ShapeTransferAccordingToFormat() {
  getNewShapeFuncMap = {
      {ge::FORMAT_NCHW, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNCHWShapeByAxisValue)},
      {ge::FORMAT_NHWC, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNHWCShapeByAxisValue)},
      {ge::FORMAT_NC1HWC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNC1HWC0ShapeByAxisValue)},
      {ge::FORMAT_FRACTAL_Z, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetFzShapeByAxisValue)},
      {ge::FORMAT_HWCN, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetHWCNShapeByAxisValue)},
      {ge::FORMAT_C1HWNCoC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetC1HWNCoC0ShapeByAxisValue)},
      {ge::FORMAT_FRACTAL_NZ, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNzShapeByAxisValue)}};
  mapOfDtypeAndC0 = {
      {ge::DT_FLOAT16, SHAPE_NUMBER_16}, {ge::DT_FLOAT, SHAPE_NUMBER_16},  {ge::DT_INT8, SHAPE_NUMBER_32},
      {ge::DT_INT16, SHAPE_NUMBER_16},   {ge::DT_INT32, SHAPE_NUMBER_16},  {ge::DT_INT64, SHAPE_NUMBER_16},
      {ge::DT_UINT8, SHAPE_NUMBER_16},   {ge::DT_UINT16, SHAPE_NUMBER_32}, {ge::DT_UINT32, SHAPE_NUMBER_16},
      {ge::DT_UINT64, SHAPE_NUMBER_16},  {ge::DT_BOOL, SHAPE_NUMBER_16}};
}
bool ShapeTransferAccordingToFormat::GetNCHWShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
                                                             const vector<int64_t>& axisValue,
                                                             const vector<int64_t>& ndValue) {
  CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true);
  /* axisValue is initialized as a size 6 vector. */
  std::vector<int64_t> newDimVec;
  newDimVec.push_back(axisValue[AXIS_N]);
  newDimVec.push_back(axisValue[AXIS_C]);
  newDimVec.push_back(axisValue[AXIS_H]);
  newDimVec.push_back(axisValue[AXIS_W]);
  newShape = ge::GeShape(newDimVec);
  return true;
}
bool ShapeTransferAccordingToFormat::GetNHWCShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
                                                             const vector<int64_t>& axisValue,
                                                             const vector<int64_t>& ndValue) {
  CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true);
  /* axisValue is initialized as a size 6 vector. */
  std::vector<int64_t> newDimVec;
  newDimVec.push_back(axisValue[AXIS_N]);
  newDimVec.push_back(axisValue[AXIS_H]);
  newDimVec.push_back(axisValue[AXIS_W]);
  newDimVec.push_back(axisValue[AXIS_C]);
  newShape = ge::GeShape(newDimVec);
  return true;
}
bool ShapeTransferAccordingToFormat::GetNC1HWC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
                                                                const vector<int64_t>& axisValue,
                                                                const vector<int64_t>& ndValue) {
  CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true);
  /* axisValue is initialized as a size 6 vector. */
  std::vector<int64_t> newDimVec;
  if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) {
    CHECK(axisValue.size() <= AXIS_C0, LOG_INFO("AxisValue is not correct!"), return true);
    newDimVec.push_back(axisValue[AXIS_N]);
    newDimVec.push_back(axisValue[AXIS_C1]);
    newDimVec.push_back(axisValue[AXIS_H]);
    newDimVec.push_back(axisValue[AXIS_W]);
    newDimVec.push_back(axisValue[AXIS_C0]);
    newShape = ge::GeShape(newDimVec);
  } else {
    CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true);
    newDimVec.push_back(axisValue[AXIS_N]);
    newDimVec.push_back(axisValue[AXIS_C]);
    newDimVec.push_back(axisValue[AXIS_H]);
    newDimVec.push_back(axisValue[AXIS_W]);
    newShape = ge::GeShape(newDimVec);
  }
  return true;
}
bool ShapeTransferAccordingToFormat::GetFzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
                                                           const vector<int64_t>& axisValue,
                                                           const vector<int64_t>& ndValue) {
  CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true);
  /* axisValue is initialized as a size 6 vector. */
  std::vector<int64_t> newDimVec;
  if (ndValue.size() == SIZE_OF_CN) {
    CHECK(axisValue.size() <= AXIS_C0, LOG_INFO("AxisValue is not correct!"), return true);
    auto sizeOfOriginalVec = ndValue.size();
    newDimVec = ndValue;
    /* sizeOfOriginalVec - 1 is the last value of the original vec,
     * sizeOfOriginalVec - 2 is the second-to-last value of the original vec. */
    newDimVec[sizeOfOriginalVec - MINUS_VALUE_ONE] =
        DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], SHAPE_NUMBER_16);
    newDimVec[sizeOfOriginalVec - MINUS_VALUE_TWO] =
        DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], axisValue[AXIS_C0]);
    newDimVec.push_back(SHAPE_NUMBER_16);
    newDimVec.push_back(axisValue[AXIS_C0]);
    newShape = ge::GeShape(newDimVec);
  } else {
    if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) {
      CHECK(axisValue.size() <= AXIS_C1, LOG_INFO("AxisValue is not correct!"), return true);
      int64_t hwc1 = axisValue[AXIS_C1] * axisValue[AXIS_H] * axisValue[AXIS_W];
      newDimVec.push_back(hwc1);
      newDimVec.push_back(DivisionCeiling(axisValue[AXIS_N], NI));
      newDimVec.push_back(NI);
      newDimVec.push_back(axisValue[AXIS_C0]);
      newShape = ge::GeShape(newDimVec);
    } else {
      CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true);
      newDimVec.push_back(axisValue[AXIS_N]);
      newDimVec.push_back(axisValue[AXIS_C]);
      newDimVec.push_back(axisValue[AXIS_H]);
      newDimVec.push_back(axisValue[AXIS_W]);
      newShape = ge::GeShape(newDimVec);
    }
  }
  return true;
}
bool ShapeTransferAccordingToFormat::GetHWCNShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
                                                             const vector<int64_t>& axisValue,
                                                             const vector<int64_t>& ndValue) {
  CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true);
  /* axisValue is initialized as a size 6 vector. */
  std::vector<int64_t> newDimVec;
  newDimVec.push_back(axisValue[AXIS_H]);
  newDimVec.push_back(axisValue[AXIS_W]);
  newDimVec.push_back(axisValue[AXIS_C]);
  newDimVec.push_back(axisValue[AXIS_N]);
  newShape = ge::GeShape(newDimVec);
  return true;
}
bool ShapeTransferAccordingToFormat::GetC1HWNCoC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
                                                                  const vector<int64_t>& axisValue,
                                                                  const vector<int64_t>& ndValue) {
  CHECK(axisValue.size() <= AXIS_Co, LOG_INFO("AxisValue is not correct!"), return true);
  /* axisValue is initialized as a size 6 vector. */
  std::vector<int64_t> newDimVec;
  newDimVec.push_back(axisValue[AXIS_C1]);
  newDimVec.push_back(axisValue[AXIS_H]);
  newDimVec.push_back(axisValue[AXIS_W]);
  newDimVec.push_back(axisValue[AXIS_N]);
  newDimVec.push_back(axisValue[AXIS_Co]);
  newDimVec.push_back(axisValue[AXIS_C0]);
  newShape = ge::GeShape(newDimVec);
  return true;
}
bool ShapeTransferAccordingToFormat::GetNzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
                                                           const vector<int64_t>& axisValue,
                                                           const vector<int64_t>& ndValue) {
  CHECK(ndValue.empty(), LOG_INFO("ndValue is empty!"), return true);
  CHECK(axisValue.empty() || axisValue.size() <= AXIS_C0,
        LOG_INFO("AxisValue is empty or its size %zu <= AXIS_C0[%u]", axisValue.size(), AXIS_C0), return true);
  size_t sizeOfOriginalVec = ndValue.size();
  if (sizeOfOriginalVec < MINIMUM_NZ_SHAPE_DIM_NUM) {
    LOG_INFO("ndValue's dim num is less than 2!");
    return true;
  }
  /* axisValue is initialized as a size 6 vector. */
  std::vector<int64_t> newDimVec = ndValue;
  /* sizeOfOriginalVec - 1 is the last value of the original vec,
   * sizeOfOriginalVec - 2 is the second-to-last value of the original vec. */
  newDimVec[sizeOfOriginalVec - MINUS_VALUE_ONE] =
      DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], (int64_t)SHAPE_NUMBER_16);
  newDimVec[sizeOfOriginalVec - MINUS_VALUE_TWO] =
      DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], axisValue[AXIS_C0]);
  newDimVec.push_back(SHAPE_NUMBER_16);
  newDimVec.push_back(axisValue[AXIS_C0]);
  newShape = ge::GeShape(newDimVec);
  return true;
}
bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& shapeAndFormatInfo, int64_t* c) {
  /* The default new shape is the old shape. */
  shapeAndFormatInfo.newShape = shapeAndFormatInfo.oldShape;
  if (shapeAndFormatInfo.oldFormat >= ge::FORMAT_RESERVED || shapeAndFormatInfo.newFormat >= ge::FORMAT_RESERVED) {
    LOG_ERROR("Old format %u or new format %u is invalid!", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat);
    return false;
  }
  if (shapeAndFormatInfo.currentDataType >= ge::DT_UNDEFINED) {
    LOG_ERROR("currentDataType %u is invalid!", shapeAndFormatInfo.currentDataType);
    return false;
  }
  std::unique_ptr<AxisUtil> axisutil_object(new AxisUtil());
  if (!axisutil_object->HasAxisValueFunc(shapeAndFormatInfo.oldFormat)) {
    return true;
  }
  auto iterGetNewShapeFunc = getNewShapeFuncMap.find(shapeAndFormatInfo.newFormat);
  if (iterGetNewShapeFunc == getNewShapeFuncMap.end()) {
    LOG_INFO("Can not get new shape of new format %u!", shapeAndFormatInfo.newFormat);
    return true;
  }
  LOG_INFO("Original format %u, new format %u", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat);
  GetNewShapeByAxisValueAndFormatPtr getNewShapeFunc = iterGetNewShapeFunc->second;
  CHECK_NOTNULL(getNewShapeFunc);
  std::vector<int64_t> axisValue(AXIS_BOTTOM, 1);
  std::vector<int64_t> ndValue;
  uint32_t c0;
  if (mapOfDtypeAndC0.empty()) {
    c0 = SHAPE_NUMBER_16;
  } else {
    auto iterGetC0 = mapOfDtypeAndC0.find(shapeAndFormatInfo.currentDataType);
    if (iterGetC0 == mapOfDtypeAndC0.end()) {
      LOG_ERROR("Dtype is not supported.");
      return true;
    }
    c0 = iterGetC0->second;
  }
  // The value of C0 should be 4 while the format is 5HD-4 or FRAZ-4.
  if (shapeAndFormatInfo.newFormat == ge::FORMAT_NC1HWC0_C04) {
    c0 = SHAPE_DIM_VALUE_C04;
  }
  bool status = axisutil_object->GetAxisValueByOriginFormat(
      shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.oldShape.GetDims(), c0, axisValue, ndValue);
  if (!status && shapeAndFormatInfo.newFormat != ge::FORMAT_FRACTAL_NZ) {
    return true;
  }
  (*getNewShapeFunc)(shapeAndFormatInfo.newShape, shapeAndFormatInfo.opImplType, axisValue, ndValue);
  if (c != nullptr) {
    *c = axisValue[AXIS_C];
  }
  return true;
}
}  // namespace ge
@@ -0,0 +1,129 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file transfer_shape_according_to_format.h
 * \brief set shape according to original format and current format
 */
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#include "axis_util.h"
#include <memory>
#include <functional>
#include <vector>
#include "framework/omg/omg_inner_types.h"
#include "operator.h"
#include "graph/operator_reg.h"
#include "graph/tensor.h"
#include "graph/utils/op_desc_utils.h"
#include "op_log.h"
#define LOG_ERROR(format, args...) printf(format, ##args)
#define LOG_INFO(format, args...) printf(format, ##args)
namespace ge {
enum OpImplType {
  EN_IMPL_CUSTOM_CONSTANT_CCE = 0,    // custom constant op
  EN_IMPL_CUSTOM_TIK,                 // custom tik op
  EN_IMPL_CUSTOM_TBE,                 // custom tbe op
  EN_IMPL_HW_CONSTANT_CCE,            // Huawei built-in constant op
  EN_IMPL_HW_GENERAL_CCE,             // Huawei built-in cce op
  EN_IMPL_HW_TIK,                     // Huawei built-in tik op
  EN_IMPL_HW_TBE,                     // Huawei built-in tbe op
  EN_IMPL_RL,                         // RL op
  EN_IMPL_PLUGIN_TBE,                 // Huawei built-in tbe plugin op
  EN_IMPL_VECTOR_CORE_HW_TBE,         // Huawei built-in tbe op
  EN_IMPL_VECTOR_CORE_CUSTOM_TBE,     // custom tbe op
  EN_IMPL_NON_PERSISTENT_CUSTOM_TBE,  // custom tbe op
  EN_RESERVED                         // reserved value
};
const uint32_t SHAPE_NUMBER_16 = 16;
const uint32_t SHAPE_NUMBER_32 = 32;
const uint32_t SHAPE_DIM_VALUE_C04 = 4;
const uint32_t NI = 16;
const uint32_t MINUS_VALUE_ONE = 1;
const uint32_t MINUS_VALUE_TWO = 2;
const uint32_t SIZE_OF_CN = 2;
const uint32_t MINIMUM_NZ_SHAPE_DIM_NUM = 2;
/* Shape-computation callback: the first parameter is the new shape, the second
 * is the op implementation type, the third the axis values and the fourth the
 * original ND dim values. */
using GetNewShapeByAxisValueAndFormat =
    std::function<bool(ge::GeShape&, const int64_t&, vector<int64_t>&, vector<int64_t>&)>;
using GetNewShapeByAxisValueAndFormatPtr = std::shared_ptr<GetNewShapeByAxisValueAndFormat>;
struct ShapeAndFormatInfo {
  const ge::GeShape& oldShape;
  ge::GeShape& newShape;
  const ge::Format& oldFormat;
  const ge::Format& newFormat;
  const ge::DataType& currentDataType;
  const int64_t& opImplType;
};
using ShapeAndFormat = struct ShapeAndFormatInfo;
class ShapeTransferAccordingToFormat {
 public:
  ShapeTransferAccordingToFormat();
  ~ShapeTransferAccordingToFormat() {}
  ShapeTransferAccordingToFormat(const ShapeTransferAccordingToFormat&) = delete;
  ShapeTransferAccordingToFormat& operator=(const ShapeTransferAccordingToFormat&) = delete;
  bool GetShapeAccordingToFormat(ShapeAndFormat& inputAndOutputInfo, int64_t* c = nullptr);
  /* ----------Below are the functions that derive the new shape---------------------- */
  static bool GetNCHWShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue,
                                      const vector<int64_t>& ndValue);
  static bool GetNHWCShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue,
                                      const vector<int64_t>& ndValue);
  static bool GetNC1HWC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
                                         const vector<int64_t>& axisValue, const vector<int64_t>& ndValue);
  static bool GetFzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue,
                                    const vector<int64_t>& ndValue);
  static bool GetHWCNShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue,
                                      const vector<int64_t>& ndValue);
  static bool GetC1HWNCoC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
                                           const vector<int64_t>& axisValue, const vector<int64_t>& ndValue);
  static bool GetNzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue,
                                    const vector<int64_t>& ndValue);
 private:
  /* Map from target format to the function that derives the new shape from the
   * axis values of the original format. */
  std::map<ge::Format, GetNewShapeByAxisValueAndFormatPtr> getNewShapeFuncMap;
  std::map<ge::DataType, uint32_t> mapOfDtypeAndC0;
};
}  // namespace ge
#endif  // OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
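
As a usage sketch (not part of the patch): ShapeAndFormat holds references, so callers keep the operands alive and pass lvalues. The wrapper below is hypothetical; the shapes in the comment assume float16, whose C0 is 16 per mapOfDtypeAndC0.

// Hypothetical example: NCHW -> NC1HWC0 via the transfer class above.
#include "transfer_shape_according_to_format.h"
static bool ToNC1HWC0(const ge::GeShape& old_shape, ge::GeShape& new_shape) {
  ge::Format old_format = ge::FORMAT_NCHW;
  ge::Format new_format = ge::FORMAT_NC1HWC0;
  ge::DataType dtype = ge::DT_FLOAT16;
  int64_t impl_type = ge::EN_IMPL_HW_TBE;
  // ShapeAndFormat members are references, so it must be built from lvalues.
  ge::ShapeAndFormat info{old_shape, new_shape, old_format, new_format, dtype, impl_type};
  ge::ShapeTransferAccordingToFormat transfer;
  int64_t c = 0;
  return transfer.GetShapeAccordingToFormat(info, &c);  // e.g. {1,16,7,7} -> {1,1,7,7,16}
}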
@@ -0,0 +1,363 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file util.h
 * \brief common helpers for op proto shape and dtype inference
 */
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
#include <memory.h>
#include <string>
#include <vector>
#include <map>
#include <algorithm>
#include "framework/omg/omg_inner_types.h"
#include "operator.h"
#include "graph/operator_reg.h"
#include "transfer_shape_according_to_format.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/tensor.h"
#include "graph/node.h"
#include "graph/ge_tensor.h"
#include "op_log.h"
#ifndef LOG_ERROR
#define LOG_ERROR(format, args...) printf(format, ##args)
#endif
namespace ge {
// enum type and string type mapping
static const std::map<ge::DataType, std::string> DTYPE_STR_MAP{
    {ge::DT_FLOAT16, "float16"}, {ge::DT_FLOAT, "float32"}, {ge::DT_INT8, "int8"},   {ge::DT_INT16, "int16"},
    {ge::DT_INT32, "int32"},     {ge::DT_INT64, "int64"},   {ge::DT_UINT8, "uint8"}, {ge::DT_UINT16, "uint16"},
    {ge::DT_UINT32, "uint32"},   {ge::DT_UINT64, "uint64"}, {ge::DT_BOOL, "bool"}};
// define the input num of shape
const size_t INPUT_NUM0 = 0;
const size_t INPUT_NUM1 = 1;
const size_t INPUT_NUM2 = 2;
const size_t INPUT_NUM3 = 3;
const size_t INPUT_NUM4 = 4;
const size_t INPUT_NUM5 = 5;
const size_t INPUT_NUM6 = 6;
const size_t INPUT_NUM7 = 7;
const size_t INPUT_NUM8 = 8;
const size_t INPUT_NUM9 = 9;
// define the dims size of shape
const size_t DIM_SIZE0 = 0;
const size_t DIM_SIZE1 = 1;
const size_t DIM_SIZE2 = 2;
const size_t DIM_SIZE3 = 3;
const size_t DIM_SIZE4 = 4;
const size_t DIM_SIZE5 = 5;
const size_t DIM_SIZE6 = 6;
const size_t DIM_SIZE7 = 7;
const size_t DIM_SIZE8 = 8;
// define the index of shape dim
const size_t DIM_INDEX0 = 0;
const size_t DIM_INDEX1 = 1;
const size_t DIM_INDEX2 = 2;
const size_t DIM_INDEX3 = 3;
const size_t DIM_INDEX4 = 4;
const size_t DIM_INDEX5 = 5;
const size_t DIM_INDEX6 = 6;
const size_t DIM_INDEX7 = 7;
const size_t DIM_INDEX8 = 8;
/*
 * check whether the input datatype is within the op's supported range
 * param[in] data_type input datatype enum value
 * param[in] supportList datatypes supported by the op
 * param[out] dType (second overload) string form of the datatype
 * return true: datatype supported
 *        false: datatype not supported
 */
bool GetInputDataType(const ge::DataType& data_type, const std::vector<ge::DataType>& supportList);
bool GetInputDataType(const ge::DataType& dataType, const std::vector<ge::DataType>& supportList, std::string& dType);
/*
 * check whether the datatype of a named input is in the support list
 * param[in] op op desc supplied by ge
 * param[in] input_name input name
 * param[in] support_list supported datatypes
 * return true: check pass
 *        false: check failed
 */
bool CheckInputDataType(const Operator& op, const std::string& input_name,
                        const std::vector<ge::DataType>& support_list);
/*
 * check the datatype and shape of input
 * param[in] op the operator
 * param[in] inputTensorMap the map of input name and supported datatypes
 * return true
 *        false
 */
bool CheckInputDtypeAndShape(const Operator& op, const std::map<std::string, std::vector<DataType>>& inputTensorMap);
/*
 * infer shape and dtype for two inputs and one output with broadcast
 * param[in] op op desc supplied by ge
 * param[in] input_name1 first input name
 * param[in] input_name2 second input name
 * param[in] output_name output name
 * return SUCCESS: infer success
 *        FAILED: infer failed, e.g. unsupported broadcast input shape
 */
bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2,
                                           const string& output_name);
/*
 * infer shape and dtype for two inputs and one output with broadcast
 * param[in] op op desc supplied by ge
 * param[in] input_name1 first input name
 * param[in] input_name2 second input name
 * param[in] output_name output name
 * param[out] is_dynamic whether the output shape is dynamic
 * return SUCCESS: infer success
 *        FAILED: infer failed, e.g. unsupported broadcast input shape
 */
bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2,
                                           const string& output_name, bool& is_dynamic);
bool InferShapeRangeTwoInOneOutBroadcase(Operator& op, const string& input_name1, const string& input_name2,
                                         const string& output_name);
bool CheckInputDataType(const Operator& op, std::string* data_type, const std::string& input_name,
                        const std::vector<ge::DataType>& supportList);
bool CheckTwoInputDtypeSame(const Operator& op, const string& input_name1, const string& input_name2);
bool CheckInputDtypeSame(const Operator& op, std::vector<std::string>& input_tensors);
bool CheckInputsShapeDtypeSame(const Operator& op, const std::vector<std::string>& input_names);
bool GetConstValue(const ge::Operator& op, const std::string& key_name, float& attr_value);
bool GetConstValue(const ge::Operator& op, const std::string& key_name, int64_t& attr_value);
bool GetConstValue(const ge::Operator& op, const std::string& key_name, bool& attr_value);
bool GetConstValue(const ge::Operator& op, const std::string& key_name, std::vector<int32_t>& attr_value);
/**
 * Get int type const value from tensor data
 * @param [in] data const tensor data
 * @param [in] data_type DT_INT8, DT_INT16, DT_INT32, DT_INT64
 * @param [out] const_values const int values
 * @return true:success, false:failed.
 */
bool GetConstIntData(const Tensor& data, DataType data_type, std::vector<int64_t>& const_values);
bool GetConstValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype,
                   std::vector<int64_t>& const_data);
bool GetConstValue(const Operator& op, const GeTensorPtr& const_tensor, const DataType& dtype,
                   std::vector<int64_t>& const_data);
bool GetScalerValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype, std::int64_t& const_data);
/*
 * Check that every input from inputNumBeg to inputNumEnd has a dtype (or format)
 * contained in supportList
 * param[in] op op desc supplied by ge
 * param[in] inputNumBeg input index begin, [0, N]
 * param[in] inputNumEnd input index end (exclusive)
 * param[in] supportList supported values, of type ge::DataType or ge::Format
 * return true: check pass
 *        false: check failed
 */
template <typename T>
bool CheckSimilarInputDtypeAndFormat(const Operator& op, std::size_t inputNumBeg, std::size_t inputNumEnd,
                                     const std::vector<T>& supportList) {
  for (std::size_t i = inputNumBeg; i < inputNumEnd; i++) {
    if (std::is_same<typename std::decay<T>::type, ge::DataType>::value) {
      ge::DataType inType = op.GetInputDesc(i).GetDataType();
      const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
      if (findDtype == supportList.end()) {
        return false;
      }
    } else if (std::is_same<typename std::decay<T>::type, ge::Format>::value) {
      ge::Format inType = op.GetInputDesc(i).GetFormat();
      const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
      if (findDtype == supportList.end()) {
        return false;
      }
    }
  }
  return true;
}
/*
 * Check that each input whose index appears in indexNeedCheck has a dtype (or
 * format) contained in supportList
 * param[in] op op desc supplied by ge
 * param[in] indexNeedCheck input indices to check
 * param[in] supportList supported values, of type ge::DataType or ge::Format
 * return true: check pass
 *        false: check failed
 */
template <typename T>
bool CheckSimilarInputDtypeAndFormat(const Operator& op, const std::vector<std::size_t>& indexNeedCheck,
                                     const std::vector<T>& supportList) {
  for (auto i : indexNeedCheck) {
    if (std::is_same<typename std::decay<T>::type, ge::DataType>::value) {
      ge::DataType inType = op.GetInputDesc(i).GetDataType();
      const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
      if (findDtype == supportList.end()) {
        return false;
      }
    } else if (std::is_same<typename std::decay<T>::type, ge::Format>::value) {
      ge::Format inType = op.GetInputDesc(i).GetFormat();
      const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
      if (findDtype == supportList.end()) {
        return false;
      }
    }
  }
  return true;
}
/*
 * get scalar attrs by name
 * param[in] op op desc supplied by ge
 * param[in] attrNameList attribute names to read
 * param[out] attrVec retrieved attribute values
 * return true: get success
 *        false: get failed
 */
template <typename T>
bool GetConstAttr(const Operator& op, const std::vector<std::string>& attrNameList, std::vector<T>& attrVec) {
  T value;
  for (const auto& name : attrNameList) {
    if (op.GetAttr(name, value) != ge::GRAPH_SUCCESS) {
      return false;
    }
    attrVec.push_back(value);
  }
  return true;
}
/*
 * get list attrs by name
 * param[in] op op desc supplied by ge
 * param[in] attrNameList attribute names to read
 * param[out] attrListVec retrieved attribute value lists
 * return true: get success
 *        false: get failed
 */
template <typename T>
bool GetConstAttr(const Operator& op, const std::vector<std::string>& attrNameList,
                  std::vector<std::vector<T>>& attrListVec) {
  for (const auto& name : attrNameList) {
    std::vector<T> valueList;
    if (op.GetAttr(name, valueList) != ge::GRAPH_SUCCESS) {
      return false;
    }
    attrListVec.push_back(valueList);
  }
  return true;
}
std::string to_string(const vector<int64_t>& shape);
std::string to_string(const ge::Shape& shape);
std::string to_string(const ge::GeShape& shape);
std::string to_string(const vector<pair<int64_t, int64_t>>& ranges);
class DynamicShapeInfer {
 public:
  std::map<std::string, Format> map_format;
  std::map<std::string, DataType> map_dtype;
  std::map<std::string, uint32_t> inputs;
  std::map<std::string, uint32_t> outputs;
  Operator& op;
  OpDescPtr& op_desc;
  DynamicShapeInfer(Operator& op_v, OpDescPtr& opDesc_v) : op(op_v), op_desc(opDesc_v) {}
  bool CatchFormatAndShape();
  bool UpdateFormatAndShape();
  ~DynamicShapeInfer() {
    UpdateFormatAndShape();
  }
};
#define PREPARE_DYNAMIC_SHAPE(depends_names)                \
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);    \
  do {                                                      \
    if (!depends_names.empty()) {                           \
      op_desc->SetOpInferDepends(depends_names);            \
    }                                                       \
  } while (0)
bool IsEmptyTensor(const std::vector<int64_t>& dims);
bool IsUnknownRank(const Operator& op, const std::string& tensor_name, const std::string& types = "input");
bool IsUnknownRankShape(const std::vector<int64_t>& shape_vec);
bool IsUnKnownShape(const std::vector<int64_t>& shape_vec);
bool IsUnknownShape(const Operator& op, const std::string& tensor_name, const std::string& types = "input");
bool IsUnknownVec(std::vector<int64_t>& shape_vec);
bool IsUnknown(const std::vector<int64_t>& shape_vec);
void MakeUpShapeRange(const std::vector<int64_t>& shape, std::vector<std::pair<int64_t, int64_t>>& range);
std::string DataTypeToStringDesc(const ge::DataType& dataType);
bool OneInOneOutDynamicInfer(const Operator& op,
                             const std::string& input_name,
                             const std::vector<std::string>& output_name_list);
bool TwoInOneOutDynamicInferNoBroadcast(Operator& op,
                                        const string& input1_name,
                                        const string& input2_name,
                                        const std::vector<string>& output_name_list);
void FixShapeRangeWithDims(const std::vector<int64_t>& dims,
                           std::vector<int64_t>& shape_1,
                           std::vector<int64_t>& shape_2,
                           std::vector<std::pair<int64_t, int64_t>>& range_1,
                           std::vector<std::pair<int64_t, int64_t>>& range_2);
bool SetScalarOutputDesc(const string& input,
                         const string& output,
                         OpDescPtr op_desc,
                         GeShape& output_shape);
namespace array_ops {
bool CheckInt64MulOverflow(int64_t a, int64_t b);
void ReshapeRangeInfer(const Operator &op, const std::vector<std::pair<int64_t, int64_t>>& x_range,
                       int64_t& range_max);
void ReshapeRangeInfer(const Operator &op, const std::vector<std::pair<int64_t, int64_t>>& x_range,
                       std::vector<std::pair<int64_t, int64_t>>& y_range, GeShape& output_shape);
}  // namespace array_ops
}  // namespace ge
#endif  // OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
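
A small hedged sketch of how the helpers above compose (LogShapeAndRange is hypothetical): MakeUpShapeRange fills a per-dimension (min, max) range for a shape (its exact policy for -1 dims lives in util.cc), and the to_string overloads render both for logging.

// Hypothetical example: derive a shape range and log both shape and range.
#include <cstdio>
#include "util.h"
static void LogShapeAndRange(const std::vector<int64_t>& dims) {
  std::vector<std::pair<int64_t, int64_t>> range;
  ge::MakeUpShapeRange(dims, range);  // one (min, max) pair per dimension
  printf("shape %s, range %s\n", ge::to_string(dims).c_str(), ge::to_string(range).c_str());
}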
@@ -0,0 +1,17 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "graph_assertion.h"
@@ -0,0 +1,34 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H
#define GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H
/*
 * Compare graph node size, node_attr
 */
#define ASSERT_GRAPH_EQUAL(g1, g2) \
  do {                             \
  } while (0)
#define ASSERT_GRAPH_CORRECT(g) \
  do {                          \
  } while (0)
#define ASSERT_GRAPH_SHAPE_CONTINOUS(g) \
  do {                                  \
  } while (0)
#endif  // GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H
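
The assertion bodies are intentionally empty for now. One possible direction, sketched here with a hypothetical name and only shallow checks, assuming ComputeGraph exposes GetDirectNodesSize/GetDirectNode/FindNode and gtest's ASSERT macros are in scope; this is not the project's actual comparison strategy.

// Hypothetical sketch of a filled-in graph-equality assertion (shallow: node
// counts and node names only, no attribute or edge comparison).
#define ASSERT_GRAPH_EQUAL_SKETCH(g1, g2)                              \
  do {                                                                 \
    ASSERT_EQ((g1)->GetDirectNodesSize(), (g2)->GetDirectNodesSize()); \
    for (const auto &node : (g1)->GetDirectNode()) {                   \
      ASSERT_NE((g2)->FindNode(node->GetName()), nullptr);             \
    }                                                                  \
  } while (0)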
@@ -0,0 +1,48 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "graph_builder_utils.h"
#include "inc/external/graph/operator.h"
#include "inc/external/graph/operator_factory.h"
#include "graph/utils/graph_utils.h"
namespace ge {
namespace st {
NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
                                     Format format, DataType data_type, std::vector<int64_t> shape) {
  auto tensor_desc = std::make_shared<GeTensorDesc>();
  tensor_desc->SetShape(GeShape(std::move(shape)));
  tensor_desc->SetFormat(format);
  tensor_desc->SetDataType(data_type);
  auto op_desc = std::make_shared<OpDesc>(name, type);
  for (int i = 0; i < in_cnt; ++i) {
    op_desc->AddInputDesc(tensor_desc->Clone());
  }
  for (int i = 0; i < out_cnt; ++i) {
    op_desc->AddOutputDesc(tensor_desc->Clone());
  }
  return graph_->AddNode(op_desc);
}
void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) {
  GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
}
void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) {
  GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
}
}  // namespace st
}  // namespace ge
@@ -0,0 +1,53 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H
#define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H
#include <string>
#include <vector>
#include "graph/compute_graph.h"
#include "graph/utils/graph_utils.h"
#include "graph/graph.h"
#include "graph/node.h"
namespace ge {
namespace st {
class ComputeGraphBuilder {
 public:
  explicit ComputeGraphBuilder(const std::string &name) { graph_ = std::make_shared<ComputeGraph>(name); }
  NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
                  Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT,
                  std::vector<int64_t> shape = {1, 1, 224, 224});
  void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx);
  void AddControlEdge(NodePtr &src_node, NodePtr &dst_node);
  ComputeGraphPtr GetComputeGraph() {
    graph_->TopologicalSorting();
    return graph_;
  }
  Graph GetGraph() {
    graph_->TopologicalSorting();
    return GraphUtils::CreateGraphFromComputeGraph(graph_);
  }
 private:
  ComputeGraphPtr graph_;
};
}  // namespace st
}  // namespace ge
#endif  // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H
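
A usage sketch for the builder (hypothetical op types and graph, not part of the patch); it exercises AddControlEdge, which the dummy testcase further below does not:

// Hypothetical example: a diamond graph with one explicit control dependency.
#include "framework/utils/builder/graph_builder_utils.h"
static ge::Graph BuildDiamond() {
  ge::st::ComputeGraphBuilder builder("diamond");
  auto data = builder.AddNode("data", "Data", 1, 1);
  auto relu = builder.AddNode("relu", "Relu", 1, 1);
  auto exp = builder.AddNode("exp", "Exp", 1, 1);
  auto add = builder.AddNode("add", "Add", 2, 1);
  builder.AddDataEdge(data, 0, relu, 0);
  builder.AddDataEdge(data, 0, exp, 0);
  builder.AddDataEdge(relu, 0, add, 0);
  builder.AddDataEdge(exp, 0, add, 1);
  builder.AddControlEdge(relu, exp);  // force relu to run before exp
  return builder.GetGraph();          // topologically sorted on the way out
}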
@@ -0,0 +1,17 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tensor_builder_utils.h"
@@ -0,0 +1,22 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H
#define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H
class tensor_builder_utils {};
#endif  // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H
@@ -0,0 +1,15 @@
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++")
add_executable(graph_engine_test ${SOURCES})
target_include_directories(graph_engine_test
    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}
)
set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 11)
target_link_libraries(graph_engine_test PRIVATE gtest gtest_main framework)
include(CTest)
enable_testing()
add_test(NAME test COMMAND graph_engine_test)
@@ -0,0 +1,58 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <gtest/gtest.h>
#include <map>
#include "external/ge/ge_api.h"
#include "framework/common/types.h"
#include "framework.h"
#include "framework/utils/builder/graph_builder_utils.h"
using namespace std;
using namespace ge;
class FrameworkTest : public testing::Test {
 protected:
  void SetUp() override {
    // ge initialize
    map<AscendString, AscendString> options;
    auto ret = ge::GEInitialize(options);
    EXPECT_EQ(ret, SUCCESS);
  }
  void TearDown() override {}
};
TEST_F(FrameworkTest, test_framework_dummy) {
  // build graph
  st::ComputeGraphBuilder graphBuilder("g1");
  auto data1 = graphBuilder.AddNode("data1", DATA, 1, 1);
  auto data2 = graphBuilder.AddNode("data2", DATA, 1, 1);
  auto add = graphBuilder.AddNode("add", ADD, 2, 1);
  graphBuilder.AddDataEdge(data1, 0, add, 0);
  graphBuilder.AddDataEdge(data2, 0, add, 1);
  Graph graph = graphBuilder.GetGraph();
  // new session & add graph
  map<AscendString, AscendString> options;
  Session session(options);
  auto ret = session.AddGraph(1, graph, options);
  EXPECT_EQ(ret, SUCCESS);
  // build input tensor
  std::vector<InputTensorInfo> inputs;
  // build graph through session
  ret = session.BuildGraph(1, inputs);
  // TODO check result
}